library(stats)
library(MissImp)
knitr::opts_chunk$set(eval = FALSE)

Generate data

We generate a complete data set \((Y_j)_{1 \leq j \leq 4}\) with \((Y_1, Y_2) \sim \mathcal{N}(\mu_1, \Sigma_1)\), \(Y_3 \sim \text{Bin}(5, 0.4)\), and \(Y_4\) a binary variable that depends deterministically on \(Y_1\) and \(Y_2\) through the linear threshold rule \(Y_4 = \mathbb{1}\{0.5\,Y_1 + 0.2\,Y_2 > 8\}\).

# Simulate a complete (no missing values) data set with four variables.
#
# Args:
#   n: number of rows to simulate.
#
# Returns:
#   A data.frame with columns
#     Y1, Y2 - numeric, drawn jointly from a bivariate normal distribution
#              with mean (9, 8) and covariance [[16, 14], [14, 25]];
#     Y3     - factor, Binomial(size = 5, prob = 0.4) draws;
#     Y4     - factor with values 0/1, equal to 1 when
#              0.5 * Y1 + 0.2 * Y2 > 8.
complete_df_generator <- function(n) {
  # Continuous part: the bivariate normal draw comes first so that the
  # random-number stream is consumed in a fixed order.
  gauss_mean <- c(9, 8)
  gauss_cov <- matrix(c(16, 14, 14, 25), nrow = 2, ncol = 2)
  gauss_draws <- MASS::mvrnorm(n, gauss_mean, gauss_cov)

  # Categorical part: a binomial count ...
  binom_draws <- stats::rbinom(n, size = 5, prob = 0.4)

  # ... and a 0/1 indicator that is a deterministic function of Y1 and Y2.
  linear_score <- 0.5 * gauss_draws[, 1] + 0.2 * gauss_draws[, 2]
  indicator <- as.numeric(linear_score > 8)

  data.frame(
    Y1 = gauss_draws[, 1],
    Y2 = gauss_draws[, 2],
    Y3 = as.factor(binom_draws),
    Y4 = as.factor(indicator)
  )
}

# ---- Experiment configuration ----
# Number of rows in the simulated complete data set.
n <- 4000
X.complete <- complete_df_generator(n)
# Missing-data mechanism label and target missing proportion handed to
# generate_miss(). Note that `miss_prop` is reassigned by the simulation loop
# further down, which iterates over several proportions.
mech <- "MAR1"
miss_prop <- 0.3
# Column indices by variable type: Y3 and Y4 (columns 3, 4) are factors,
# Y1 and Y2 (columns 1, 2) are numeric, and no column is declared discrete.
col_cat <- c(3, 4)
col_dis <- c()
col_num <- c(1, 2)
# One incomplete copy of the data, generated once with the settings above.
rs <- generate_miss(X.complete, miss_prop, mechanism = mech)
df <- rs$X.incomp
# Alias of the complete data used as the reference in the calls below.
df_complete <- X.complete
# Iteration limits and PCA settings forwarded to MissImp().
# NOTE(review): exact meaning of each argument is defined by MissImp() - confirm
# against the package documentation.
maxiter_tree <- 10
maxiter_pca <- 200
maxiter_mice <- 10
# Number of PCA components: half the number of columns, rounded.
ncp_pca <- round(ncol(df_complete) / 2)
learn_ncp <- FALSE

# num_mi and n_resample are forwarded to MissImp(); n_df is the number of
# incomplete data sets generated per missing proportion in the simulation loop.
num_mi <- 4
n_resample <- 4
n_df <- 5

# Imputation method, resampling scheme and categorical-handling options
# forwarded to MissImp().
imp_method <- "PCA"
resample_method <- "bootstrap"
cat_combine_by <- "onehot"
var_cat <- "wilcox_va"
# Impute the missing entries of ONE row using the true data-generating model
# (the "oracle"), i.e. the same parameters as complete_df_generator().
#
# Args:
#   x: a length-4 vector (or one-row list) ordered as (Y1, Y2, Y3, Y4), with
#      NA marking missing entries. When called through apply() on a data
#      frame that contains factors, x arrives as a character vector; observed
#      values are therefore converted with as.numeric() before use.
#
# Returns:
#   unlist(x) with every NA replaced:
#     Y3       - a fresh Binomial(5, 0.4) draw;
#     Y1, Y2   - a joint bivariate-normal draw when both are missing, otherwise
#                a draw from the conditional normal given the observed one;
#     Y4       - the deterministic rule (0.5 * Y1 + 0.2 * Y2 > 8), evaluated
#                after Y1/Y2 have been filled in.
#   Rows without missing values are returned unchanged (after unlist()).
oracle_impute <- function(x) {
  mask <- is.na(x)
  if (!any(mask)) {
    return(unlist(x))
  }

  # True parameters of (Y1, Y2).
  mu1.X <- c(9, 8)
  Sigma1.X <- matrix(c(
    16, 14,
    14, 25
  ), nrow = 2)
  # Correlation between Y1 and Y2.
  ro <- Sigma1.X[1, 2] / sqrt(Sigma1.X[1, 1] * Sigma1.X[2, 2])

  if (mask[3]) {
    x[3] <- stats::rbinom(1, size = 5, prob = 0.4)
  }

  if (mask[1] && mask[2]) {
    # Nothing to condition on: draw the pair jointly.
    x[c(1, 2)] <- MASS::mvrnorm(1, mu1.X, Sigma1.X)
  } else if (mask[1]) {
    # Y1 | Y2 = x_2 is normal with the mean and variance below.
    x_2 <- as.numeric(x[2])
    mu_new <- mu1.X[1] +
      sqrt(Sigma1.X[1, 1] / Sigma1.X[2, 2]) * ro * (x_2 - mu1.X[2])
    var_new <- (1 - ro * ro) * Sigma1.X[1, 1]
    # rnorm() expects a standard deviation, not a variance.
    x[1] <- stats::rnorm(1, mean = mu_new, sd = sqrt(var_new))
  } else if (mask[2]) {
    # Y2 | Y1 = x_1, symmetric to the case above.
    x_1 <- as.numeric(x[1])
    mu_new <- mu1.X[2] +
      sqrt(Sigma1.X[2, 2] / Sigma1.X[1, 1]) * ro * (x_1 - mu1.X[1])
    var_new <- (1 - ro * ro) * Sigma1.X[2, 2]
    x[2] <- stats::rnorm(1, mean = mu_new, sd = sqrt(var_new))
  }

  if (mask[4]) {
    # Y4 is a deterministic function of (possibly just imputed) Y1 and Y2.
    x_1 <- as.numeric(x[1])
    x_2 <- as.numeric(x[2])
    x[4] <- (0.5 * x_1 + 0.2 * x_2 > 8) * 1
  }

  unlist(x)
}

# df.imp <- data.frame(t(apply(df, 1, oracle_impute)))
# for(i in col_num){
#   df.imp[, i] <- as.numeric(df.imp[, i])
# }
# ls_MSE(X.complete, list(df.imp), is.na(df), col_num, resample_method = "none")
# ls_F1(X.complete, list(df.imp), is.na(df), col_cat, col_cat, resample_method = "none")
# ---- Simulation study ----
# For each target missing proportion, generate `n_df` incomplete data sets,
# impute each one with MissImp() and with the oracle, and store the mean and
# variance (across the `n_df` replicates) of the evaluation metrics.
i <- 1
result <- list()
for (miss_prop in c(0.005, 0.01, 0.03, 0.05, 0.07, 0.1, 0.3, 0.5)) {
  print(miss_prop)
  # Per-replicate accumulators for this missing proportion.
  miss_prop_real <- c()
  ls_mse <- c()
  ls_scared_mse <- c()
  ls_f1 <- c()
  ls_mse_oracle <- c()
  ls_scared_mse_oracle <- c()
  ls_f1_oracle <- c()
  # One seed per replicate so each incomplete data set can be regenerated.
  ls_seed <- sample(1:100, n_df, replace = TRUE)
  for (j in seq_len(n_df)) {
    set.seed(ls_seed[j])
    rs <- generate_miss(df_complete, miss_prop, mechanism = mech, mar2.col.ctrl = 2)
    miss_prop_real <- c(miss_prop_real, rs$real_miss_perc)
    df_incomp <- rs$X.incomp
    res_imp <- MissImp(
      df = df_incomp, imp_method = imp_method,
      resample_method = resample_method, n_resample = n_resample,
      col_cat = col_cat, col_dis = col_dis,
      maxiter_tree = maxiter_tree, maxiter_pca = maxiter_pca,
      ncp_pca = ncp_pca, learn_ncp = learn_ncp,
      cat_combine_by = cat_combine_by, var_cat = var_cat,
      df_complete = df_complete, num_mi = num_mi,
      maxiter_mice = maxiter_mice
    )
    ls_mse <- c(ls_mse, res_imp$MSE$Mean_MSE)
    ls_scared_mse <- c(ls_scared_mse, res_imp$MSE$Mean_MSE_scale)
    ls_f1 <- c(ls_f1, res_imp$F1$Mean_F1)

    # Oracle: impute every row of the current incomplete data set with the
    # true model. apply() hands each row over as a character vector, so the
    # numeric columns are converted back afterwards.
    df.imp <- data.frame(t(apply(df_incomp, 1, oracle_impute)))
    for (s in col_num) {
      df.imp[, s] <- as.numeric(df.imp[, s])
    }
    # The metrics must be evaluated on the cells that are missing in THIS
    # replicate (df_incomp), not on the mask of an unrelated incomplete set.
    mask_incomp <- is.na(df_incomp)
    ls.MSE <- ls_MSE(df_complete, list(df.imp), mask_incomp, col_num, resample_method = "none")
    ls.F1 <- ls_F1(df_complete, list(df.imp), mask_incomp, col_cat, col_cat, resample_method = "none")
    ls_mse_oracle <- c(ls_mse_oracle, ls.MSE$Mean_MSE)
    ls_scared_mse_oracle <- c(ls_scared_mse_oracle, ls.MSE$Mean_MSE_scale)
    ls_f1_oracle <- c(ls_f1_oracle, ls.F1$Mean_F1)
  }

  # Summary row i: mean and variance of each metric over the replicates.
  result[["mechanisme"]][i] <- mech
  result[["miss_perc"]][i] <- mean(miss_prop_real)
  result[["method"]][i] <- imp_method
  result[["scaled_MSE"]][i] <- mean(ls_scared_mse)
  result[["Var_scaled_MSE"]][i] <- var(ls_scared_mse)
  result[["MSE"]][i] <- mean(ls_mse)
  result[["Var_MSE"]][i] <- var(ls_mse)
  result[["F1"]][i] <- mean(ls_f1)
  result[["Var_F1"]][i] <- var(ls_f1)
  result[["scaled_MSE_oracle"]][i] <- mean(ls_scared_mse_oracle)
  result[["Var_scaled_MSE_oracle"]][i] <- var(ls_scared_mse_oracle)
  result[["MSE_oracle"]][i] <- mean(ls_mse_oracle)
  result[["Var_MSE_oracle"]][i] <- var(ls_mse_oracle)
  result[["F1_oracle"]][i] <- mean(ls_f1_oracle)
  result[["Var_F1_oracle"]][i] <- var(ls_f1_oracle)
  i <- i + 1
}
# One row per missing proportion.
data.frame(result)
# write.csv(result,"test_MAR1_PCA_boot_oracle.csv")
library(ggplot2)
# ---- Plot: scaled MSE of EM imputation vs. oracle ----
# Read the saved summary table and draw, for each method, the mean scaled MSE
# against the nominal missing proportion, with a band of +/- one standard
# deviation (square root of the stored variance).
res_EM <- read.csv("test_MAR1_EM_boot_oracle.csv")

# Plotting frame: one row per nominal missing proportion.
predframe <- data.frame(row.names = seq_len(8))
predframe$miss_perc <- c(0.005, 0.01, 0.03, 0.05, 0.07, 0.1, 0.3, 0.5)

# EM curve and its band.
predframe$scaled_MSE_EM <- res_EM$scaled_MSE
predframe$lwr_EM <- res_EM$scaled_MSE - sqrt(res_EM$Var_scaled_MSE)
predframe$upr_EM <- res_EM$scaled_MSE + sqrt(res_EM$Var_scaled_MSE)

# Oracle curve and its band.
predframe$scaled_MSE_oracle <- res_EM$scaled_MSE_oracle
predframe$lwr_oracle <- res_EM$scaled_MSE_oracle - sqrt(res_EM$Var_scaled_MSE_oracle)
predframe$upr_oracle <- res_EM$scaled_MSE_oracle + sqrt(res_EM$Var_scaled_MSE_oracle)

# Shared palette for line colour and ribbon fill.
colors <- c("EM" = "red", "Oracle" = "blue")

# Layers are added in the same order as before so the draw order is unchanged.
p1 <- ggplot(predframe, aes(x = miss_perc)) +
  geom_line(aes(y = scaled_MSE_EM, colour = "EM")) +
  geom_point(aes(y = scaled_MSE_EM, colour = "EM")) +
  geom_ribbon(
    aes(ymin = lwr_EM, ymax = upr_EM, fill = "EM"),
    alpha = 0.3, show.legend = FALSE
  ) +
  geom_line(aes(y = scaled_MSE_oracle, colour = "Oracle")) +
  geom_point(aes(y = scaled_MSE_oracle, colour = "Oracle")) +
  geom_ribbon(
    aes(ymin = lwr_oracle, ymax = upr_oracle, fill = "Oracle"),
    alpha = 0.3, show.legend = FALSE
  ) +
  scale_colour_manual(values = colors) +
  scale_fill_manual(values = colors) +
  labs(
    x = "Proportion of missing data",
    y = "Scaled MSE",
    colour = "Imputation method"
  )

show(p1)