D_oracle_simple_test.Rmd
library(stats)
library(MissImp)
knitr::opts_chunk$set(eval = FALSE)
We generate a complete data set \((Y_j)_{1 \leq j \leq 4}\) with \((Y_1, Y_2) \sim \mathcal{N}(\mu_1, \Sigma_1)\), \(Y_3 \sim \text{Bin}(5, 0.4)\), and \(Y_4\) a binary variable that depends deterministically on the values of \(Y_1\) and \(Y_2\) through the threshold rule \(Y_4 = \mathbb{1}\{0.5\, Y_1 + 0.2\, Y_2 > 8\}\).
# Simulate a complete four-column data set.
#
# Y1, Y2: bivariate normal with mean (9, 8) and covariance
#         [[16, 14], [14, 25]].
# Y3:     Binomial(size = 5, prob = 0.4), stored as a factor.
# Y4:     indicator of 0.5 * Y1 + 0.2 * Y2 > 8, stored as a factor.
#
# n: number of rows to generate.
# Returns a data.frame with columns "Y1", "Y2", "Y3", "Y4".
complete_df_generator <- function(n) {
  gauss_mean <- c(9, 8)
  gauss_cov <- matrix(c(16, 14, 14, 25), nrow = 2)

  # The Gaussian draw comes before the binomial draw so that the random
  # stream is consumed in a fixed order.
  gauss_part <- MASS::mvrnorm(n, gauss_mean, gauss_cov)
  count_part <- stats::rbinom(n, size = 5, prob = 0.4)

  first_cont <- gauss_part[, 1]
  second_cont <- gauss_part[, 2]
  above_cut <- (0.5 * first_cont + 0.2 * second_cont > 8) * 1

  data.frame(
    Y1 = first_cont,
    Y2 = second_cont,
    Y3 = as.factor(count_part),
    Y4 = as.factor(above_cut)
  )
}
# ---- Reference data ----
# Complete data set; it is kept under both names because later code refers to
# `X.complete` as well as `df_complete`.
n <- 4000
X.complete <- complete_df_generator(n)
df_complete <- X.complete

# ---- Column roles (indices into the four generated columns) ----
col_num <- c(1, 2)
col_cat <- c(3, 4)
col_dis <- NULL

# ---- One incomplete copy of the data ----
mech <- "MAR1"
miss_prop <- 0.3
rs <- generate_miss(X.complete, miss_prop, mechanism = mech)
df <- rs$X.incomp

# ---- Settings forwarded to MissImp() under the same argument names ----
imp_method <- "PCA"
resample_method <- "bootstrap"
n_resample <- 4
num_mi <- 4
maxiter_tree <- 10
maxiter_pca <- 200
maxiter_mice <- 10
ncp_pca <- round(ncol(df_complete) / 2)
learn_ncp <- FALSE
cat_combine_by <- "onehot"
var_cat <- "wilcox_va"

# Number of incomplete data sets drawn per missing proportion in the benchmark.
n_df <- 5
# Oracle imputation of one observation (Y1, Y2, Y3, Y4).
#
# Missing entries are filled in using the data-generating model itself:
#   * Y3 is drawn from Binomial(size = 5, prob = 0.4);
#   * (Y1, Y2) are drawn jointly from the bivariate normal when both are
#     missing, otherwise the missing one is drawn from its conditional normal
#     distribution given the observed one;
#   * Y4 is recomputed as the indicator of 0.5 * Y1 + 0.2 * Y2 > 8, using the
#     (possibly just imputed) values of Y1 and Y2.
#
# x: a length-4 vector or one-row record; entries may be NA. Entries may be
#    character (as produced by `apply()` on a data frame with factor columns),
#    which is why Y1/Y2 are passed through `as.numeric()` before use.
# Returns `unlist(x)` with the NA entries replaced; `x` is returned unchanged
# (apart from `unlist()`) when it contains no NA.
oracle_impute <- function(x) {
  mask <- is.na(x)
  if (!any(mask)) {
    return(unlist(x))
  }

  mu <- c(9, 8)
  sigma <- matrix(c(16, 14, 14, 25), nrow = 2)
  rho <- sigma[1, 2] / sqrt(sigma[1, 1] * sigma[2, 2])

  # Y3 is drawn first so the order of random draws stays fixed.
  if (mask[3]) {
    x[3] <- stats::rbinom(1, size = 5, prob = 0.4)
  }

  if (mask[1] && mask[2]) {
    x[c(1, 2)] <- MASS::mvrnorm(1, mu, sigma)
  } else if (mask[1]) {
    x_2 <- as.numeric(x[2])
    cond_mean <- mu[1] + sqrt(sigma[1, 1] / sigma[2, 2]) * rho * (x_2 - mu[2])
    # (1 - rho^2) * sigma[1, 1] is the conditional VARIANCE; rnorm() expects a
    # standard deviation, hence the square root.
    cond_sd <- sqrt((1 - rho^2) * sigma[1, 1])
    x[1] <- stats::rnorm(1, mean = cond_mean, sd = cond_sd)
  } else if (mask[2]) {
    x_1 <- as.numeric(x[1])
    cond_mean <- mu[2] + sqrt(sigma[2, 2] / sigma[1, 1]) * rho * (x_1 - mu[1])
    cond_sd <- sqrt((1 - rho^2) * sigma[2, 2])
    x[2] <- stats::rnorm(1, mean = cond_mean, sd = cond_sd)
  }

  if (mask[4]) {
    x_1 <- as.numeric(x[1])
    x_2 <- as.numeric(x[2])
    x[4] <- (0.5 * x_1 + 0.2 * x_2 > 8) * 1
  }

  unlist(x)
}
# df.imp <- data.frame(t(apply(df, 1, oracle_impute)))
# for(i in col_num){
# df.imp[, i] <- as.numeric(df.imp[, i])
# }
# ls_MSE(X.complete, list(df.imp), is.na(df), col_num, resample_method = "none")
# ls_F1(X.complete, list(df.imp), is.na(df), col_cat, col_cat, resample_method = "none")
# Benchmark MissImp() against the oracle imputer.
#
# For each target missing proportion, `n_df` incomplete copies of
# `df_complete` are generated, each one is imputed both with MissImp() and
# with oracle_impute(), and the mean and variance of the resulting scores
# over the `n_df` repetitions are stored in `result` (one entry per
# proportion). `result` is finally shown as a data frame.
miss_props <- c(0.005, 0.01, 0.03, 0.05, 0.07, 0.1, 0.3, 0.5)
result <- list()
for (i in seq_along(miss_props)) {
  miss_prop <- miss_props[i]
  print(miss_prop)

  # Per-repetition scores for the current proportion.
  miss_prop_real <- c()
  ls_mse <- c()
  ls_scaled_mse <- c()
  ls_f1 <- c()
  ls_mse_oracle <- c()
  ls_scaled_mse_oracle <- c()
  ls_f1_oracle <- c()

  # One seed per repetition, drawn with replacement from 1:100.
  ls_seed <- sample(1:100, n_df, replace = TRUE)
  for (j in seq_len(n_df)) {
    set.seed(ls_seed[j])
    rs <- generate_miss(df_complete, miss_prop, mechanism = mech, mar2.col.ctrl = 2)
    miss_prop_real <- c(miss_prop_real, rs$real_miss_perc)
    df_incomp <- rs$X.incomp

    res_imp <- MissImp(
      df = df_incomp, imp_method = imp_method,
      resample_method = resample_method, n_resample = n_resample,
      col_cat = col_cat, col_dis = col_dis,
      maxiter_tree = maxiter_tree, maxiter_pca = maxiter_pca,
      ncp_pca = ncp_pca, learn_ncp = learn_ncp,
      cat_combine_by = cat_combine_by, var_cat = var_cat,
      df_complete = df_complete, num_mi = num_mi,
      maxiter_mice = maxiter_mice
    )
    ls_mse <- c(ls_mse, res_imp$MSE$Mean_MSE)
    ls_scaled_mse <- c(ls_scaled_mse, res_imp$MSE$Mean_MSE_scale)
    ls_f1 <- c(ls_f1, res_imp$F1$Mean_F1)

    # Oracle: impute row by row, then restore the numeric columns, which
    # apply() returned as character because the data frame has factor columns.
    df.imp <- data.frame(t(apply(df_incomp, 1, oracle_impute)))
    for (s in col_num) {
      df.imp[, s] <- as.numeric(df.imp[, s])
    }
    # Score the oracle on the cells that are missing in THIS incomplete data
    # set (`df_incomp`), not on the mask of an unrelated earlier data set.
    mask_incomp <- is.na(df_incomp)
    ls.MSE <- ls_MSE(df_complete, list(df.imp), mask_incomp, col_num, resample_method = "none")
    ls.F1 <- ls_F1(df_complete, list(df.imp), mask_incomp, col_cat, col_cat, resample_method = "none")
    ls_mse_oracle <- c(ls_mse_oracle, ls.MSE$Mean_MSE)
    ls_scaled_mse_oracle <- c(ls_scaled_mse_oracle, ls.MSE$Mean_MSE_scale)
    ls_f1_oracle <- c(ls_f1_oracle, ls.F1$Mean_F1)
  }

  result[["mechanisme"]][i] <- mech
  result[["miss_perc"]][i] <- mean(miss_prop_real)
  result[["method"]][i] <- imp_method
  result[["scaled_MSE"]][i] <- mean(ls_scaled_mse)
  result[["Var_scaled_MSE"]][i] <- var(ls_scaled_mse)
  result[["MSE"]][i] <- mean(ls_mse)
  result[["Var_MSE"]][i] <- var(ls_mse)
  result[["F1"]][i] <- mean(ls_f1)
  result[["Var_F1"]][i] <- var(ls_f1)
  result[["scaled_MSE_oracle"]][i] <- mean(ls_scaled_mse_oracle)
  result[["Var_scaled_MSE_oracle"]][i] <- var(ls_scaled_mse_oracle)
  result[["MSE_oracle"]][i] <- mean(ls_mse_oracle)
  result[["Var_MSE_oracle"]][i] <- var(ls_mse_oracle)
  result[["F1_oracle"]][i] <- mean(ls_f1_oracle)
  result[["Var_F1_oracle"]][i] <- var(ls_f1_oracle)
}
data.frame(result)
# write.csv(result,"test_MAR1_PCA_boot_oracle.csv")
library(ggplot2)
# Plot the scaled MSE of the EM-based imputation against the oracle, using
# results previously saved to CSV. Each curve gets a band of +/- one standard
# deviation (square root of the stored variance).
res_EM <- read.csv("test_MAR1_EM_boot_oracle.csv")

sd_EM <- sqrt(res_EM$Var_scaled_MSE)
sd_oracle <- sqrt(res_EM$Var_scaled_MSE_oracle)

predframe <- data.frame(
  miss_perc = c(0.005, 0.01, 0.03, 0.05, 0.07, 0.1, 0.3, 0.5),
  scaled_MSE_EM = res_EM$scaled_MSE,
  lwr_EM = res_EM$scaled_MSE - sd_EM,
  upr_EM = res_EM$scaled_MSE + sd_EM,
  scaled_MSE_oracle = res_EM$scaled_MSE_oracle,
  lwr_oracle = res_EM$scaled_MSE_oracle - sd_oracle,
  upr_oracle = res_EM$scaled_MSE_oracle + sd_oracle,
  row.names = 1:8
)

colors <- c("EM" = "red", "Oracle" = "blue")

# Layers are added in the same order as before (EM first, then oracle) so the
# drawing order is unchanged; the ribbons inherit `predframe` from ggplot().
p1 <- ggplot(predframe, aes(x = miss_perc)) +
  geom_line(aes(y = scaled_MSE_EM, colour = "EM")) +
  geom_point(aes(y = scaled_MSE_EM, colour = "EM")) +
  geom_ribbon(
    aes(ymin = lwr_EM, ymax = upr_EM, fill = "EM"),
    alpha = 0.3, show.legend = FALSE
  ) +
  geom_line(aes(y = scaled_MSE_oracle, colour = "Oracle")) +
  geom_point(aes(y = scaled_MSE_oracle, colour = "Oracle")) +
  geom_ribbon(
    aes(ymin = lwr_oracle, ymax = upr_oracle, fill = "Oracle"),
    alpha = 0.3, show.legend = FALSE
  ) +
  scale_colour_manual(values = colors) +
  scale_fill_manual(values = colors) +
  xlab("Proportion of missing data") +
  ylab("Scaled MSE") +
  labs(colour = "Imputation method")
show(p1)