## ============================================================
## SVEM (wAIC / wAICc / wSSE, no debias) vs glmnet_with_cv (defaults)
## Focus: small training sizes; compare test RMSE head-to-head
## ============================================================

if (!requireNamespace("SVEMnet", quietly = TRUE)) install.packages("SVEMnet")
if (!requireNamespace("ggplot2", quietly = TRUE)) install.packages("ggplot2")
if (!requireNamespace("dplyr", quietly = TRUE)) install.packages("dplyr")
if (!requireNamespace("tidyr", quietly = TRUE)) install.packages("tidyr")
if (!requireNamespace("patchwork", quietly = TRUE)) install.packages("patchwork")

library(SVEMnet)
library(ggplot2)
library(dplyr)
library(tidyr)
library(patchwork)

# ## ---------- Controls (full study grid; commented out) ----------
# set.seed(425)
# REPS         <- 5
# NBOOT        <- 200
# R2_GRID      <- c(0.30, 0.50, 0.70, 0.90)
# RHO_GRID     <- c(0, -0.5, 0.5, -0.9, 0.9)   # correlation among factors
# P_GRID       <- c(3, 5, 7)
# MODELS       <- c("main_plus_int", "full_quadratic")
# NTEST        <- 1000
# DENSITY_GRID <- c(.05, .10, .20, .30, .50)   # fraction of non-intercept terms active


## ---------- Controls (active configuration) ----------
## Single-scenario settings; each *_GRID may hold several values.
# set.seed(425)
REPS  <- 5                      # replicates per scenario
NBOOT <- 200                    # passed to SVEMnet as nBoot
NTEST <- 1000                   # held-out rows per replicate

P_GRID       <- 5                 # number of factors
MODELS       <- "full_quadratic"  # names looked up in build_formulas()
RHO_GRID     <- -0.5              # correlation among factors
R2_GRID      <- 0.30              # target R^2 of the simulated truth
DENSITY_GRID <- 0.20              # fraction of non-intercept terms active

## ---------- Helpers ----------
# Root-mean-squared error over the positions where both `obs` and `pred`
# are finite. Returns NA_real_ when no such position exists, so a few
# missing or infinite predictions do not turn the whole score into NA.
rmse <- function(obs, pred) {
  usable <- which(is.finite(obs) & is.finite(pred))
  if (length(usable) == 0L) {
    return(NA_real_)
  }
  resid <- obs[usable] - pred[usable]
  sqrt(mean(resid^2))
}

# Predict from an SVEM fit without letting a prediction error propagate.
# Calls predict() with debias = FALSE and the requested `agg`; on error it
# returns one NA per row of `newdata`. Non-finite predictions become NA.
safe_predict_svem <- function(fit, newdata, agg = "mean") {
  # Wrap the value in a list so that "predict() failed" (NULL) cannot be
  # confused with any value predict() itself might return.
  attempt <- tryCatch(
    list(value = predict(fit, newdata = newdata, debias = FALSE, agg = agg)),
    error = function(e) NULL
  )
  if (is.null(attempt)) {
    return(rep(NA_real_, nrow(newdata)))
  }
  values <- as.numeric(attempt$value)
  replace(values, !is.finite(values), NA_real_)
}
# Predict from a glmnet_with_cv fit without letting a prediction error
# propagate. Calls predict_cv() with debias = FALSE; on error it returns one
# NA per row of `newdata`. Non-finite predictions become NA.
safe_predict_cv <- function(fitcv, newdata) {
  # Wrap the value in a list so that "predict_cv() failed" (NULL) cannot be
  # confused with any value predict_cv() itself might return.
  attempt <- tryCatch(
    list(value = predict_cv(fitcv, newdata = newdata, debias = FALSE)),
    error = function(e) NULL
  )
  if (is.null(attempt)) {
    return(rep(NA_real_, nrow(newdata)))
  }
  values <- as.numeric(attempt$value)
  replace(values, !is.finite(values), NA_real_)
}

## Count p_full: the number of model-matrix columns, excluding the intercept,
## that the right-hand side of `fm` expands to for predictors X1..Xp.
##
## fm : model formula whose predictors are named X1, ..., Xp.
## p  : number of predictors.
## Returns a single integer count.
count_p_full <- function(fm, p) {
  rhs <- stats::reformulate(attr(stats::delete.response(stats::terms(fm)), "term.labels"))
  # Only the column structure matters, so use a deterministic probe frame.
  # Drawing random numbers here would advance the global RNG stream as a side
  # effect and change every random draw made after this call.
  probe <- as.data.frame(matrix(0, nrow = 2L, ncol = p))
  names(probe) <- paste0("X", seq_len(p))
  mm <- stats::model.matrix(rhs, data = probe)
  sum(colnames(mm) != "(Intercept)")
}

# Build the candidate model formulas for predictors X1..Xp with response y.
# Returns a named list:
#   main_effects   : y ~ X1 + ... + Xp
#   main_plus_int  : y ~ (X1 + ... + Xp)^2
#   full_quadratic : main_plus_int plus I(Xj^2) for every predictor
build_formulas <- function(p) {
  x_names   <- paste0("X", seq_len(p))
  additive  <- paste(x_names, collapse = " + ")
  two_way   <- paste0("(", additive, ")^2")
  pure_quad <- paste(sprintf("I(%s^2)", x_names), collapse = " + ")
  quadratic <- paste0(two_way, " + ", pure_quad)

  out <- list()
  out$main_effects   <- as.formula(paste("y ~", additive))
  out$main_plus_int  <- as.formula(paste("y ~", two_way))
  out$full_quadratic <- as.formula(paste("y ~", quadratic))
  out
}

# Simulate an n x p predictor data frame with columns X1..Xp.
# Columns start as independent standard-normal draws. When p >= 2 and rho is
# non-zero, column j (j = 2..p) is replaced by
#   rho * standardized(column j-1) + sqrt(1 - rho^2) * standardized(column j),
# where column j-1 is the already-updated neighbour.
gen_X <- function(n, p, rho = 0.5) {
  draws <- matrix(rnorm(n * p), nrow = n, ncol = p)

  needs_mixing <- p >= 2 && abs(rho) > 0
  if (needs_mixing) {
    for (col in 2:p) {
      prev_std <- scale(draws[, col - 1], center = TRUE, scale = TRUE)
      curr_std <- scale(draws[, col], center = TRUE, scale = TRUE)
      draws[, col] <- rho * prev_std + sqrt(1 - rho^2) * curr_std
    }
  }

  colnames(draws) <- paste0("X", seq_len(p))
  as.data.frame(draws)
}

## Density-driven truth generator using the selected formula `fm`.
##
## n         : number of rows to simulate.
## p, rho    : forwarded to gen_X() to draw predictors X1..Xp.
## target_R2 : desired Var(signal) / Var(y); must lie in (0, 1].
## fm        : formula whose right-hand side defines the candidate terms.
## density   : fraction of non-intercept model-matrix columns given a non-zero
##             coefficient (at least one, at most all of them).
## seed      : optional; when not NULL, set.seed(seed) is called first.
##
## Returns a data frame with y, X1..Xp and the noise-free y_signal; row names
## are "row000001", "row000002", ...
make_sparse_data_R2 <- function(n, p, rho, target_R2, fm, density = 0.2, seed = NULL) {
  # target_R2 <= 0 would give infinite noise sd and target_R2 > 1 a NaN sd;
  # fail loudly instead of returning a non-finite response.
  if (!is.numeric(target_R2) || length(target_R2) != 1L || is.na(target_R2) ||
      target_R2 <= 0 || target_R2 > 1) {
    stop("`target_R2` must be a single number in (0, 1].", call. = FALSE)
  }
  if (!is.numeric(density) || length(density) != 1L || is.na(density)) {
    stop("`density` must be a single non-missing number.", call. = FALSE)
  }

  if (!is.null(seed)) set.seed(seed)
  Xdf <- gen_X(n, p, rho)
  names(Xdf) <- paste0("X", seq_len(p))

  # terms() needs a `y` column to resolve the response; strip it afterwards so
  # the model matrix is built from the predictors only.
  terms_obj <- stats::delete.response(stats::terms(fm, data = transform(Xdf, y = 0)))
  mf <- stats::model.frame(terms_obj, data = Xdf)
  MM <- stats::model.matrix(terms_obj, mf)   # n x (M+1) with intercept at column 1

  M <- ncol(MM) - 1L
  if (M < 1L) stop("Formula produces no predictors beyond intercept.")
  # At least one active term; never more than exist (density > 1 used to make
  # sample.int() fail).
  n_active <- min(M, max(1L, floor(density * M)))
  active_idx <- sample.int(M, n_active, replace = FALSE) + 1L  # skip intercept

  # Active coefficients: difference of two independent rexp() draws.
  beta <- numeric(M + 1L)
  beta[active_idx] <- rexp(n_active) - rexp(n_active)
  y_signal <- drop(MM %*% beta)

  # Scale the noise so that Var(signal) / (Var(signal) + Var(noise)) = target_R2.
  sd_sig <- sqrt(max(var(y_signal), .Machine$double.eps))
  sd_eps <- sd_sig * sqrt((1 - target_R2) / target_R2)
  y <- y_signal + rnorm(n, 0, sd_eps)

  out <- cbind.data.frame(y = y, Xdf)
  out$y_signal <- y_signal
  rownames(out) <- sprintf("row%06d", seq_len(n))
  out
}

## Convenience wrapper: fit SVEMnet for one objective and predict on test data.
##
## fm         : model formula.
## train_used : training data frame passed to SVEMnet().
## test_df    : data frame to predict on.
## objective  : SVEMnet objective name (the script uses "wAIC", "wAICc", "wSSE").
## nBoot      : forwarded to SVEMnet(); defaults to the global NBOOT.
## agg        : forwarded to safe_predict_svem().
##
## Returns list(pred = numeric vector with one entry per test row,
##              time = elapsed seconds for the fit, rounded to 2 decimals).
## A failed fit does not abort the caller: a warning carrying the error message
## is raised and `pred` is all NA.
fit_and_predict_svem <- function(fm, train_used, test_df, objective, nBoot = NBOOT, agg = "mean") {
  t0 <- proc.time()[["elapsed"]]
  mod <- tryCatch(
    SVEMnet(
      formula = fm, data = train_used,
      nBoot = nBoot,
      glmnet_alpha = c(1),
      weight_scheme = "SVEM",
      objective = objective,
      standardize = TRUE
    ),
    error = function(e) {
      # Surface the failure instead of silently producing NA rows downstream.
      warning(
        sprintf("SVEMnet fit failed (objective = %s): %s", objective, conditionMessage(e)),
        call. = FALSE
      )
      NULL
    }
  )
  elapsed <- round(proc.time()[["elapsed"]] - t0, 2)

  preds <- if (is.null(mod)) {
    rep(NA_real_, nrow(test_df))
  } else {
    safe_predict_svem(mod, test_df, agg = agg)
  }
  list(pred = preds, time = elapsed)
}

## ---------- Run sims ----------
rows <- list(); rid <- 1L

# Realised signal-to-total variance ratio of a simulated data set.
r2_true <- function(d) var(d$y_signal) / var(d$y)

# Map 1-based grid positions to a unique integer seed via a mixed-radix code.
# Summing scaled parameter values (the previous scheme) let distinct scenarios
# share a seed, e.g. R2 = 0.3 / density = 0.5 vs R2 = 0.5 / density = 0.3.
# NOTE: seeds depend on grid positions, so editing a grid re-numbers them.
scenario_seed <- function(pos, sizes, base = 3000L) {
  code <- 0
  for (k in seq_along(pos)) {
    code <- code * sizes[k] + (pos[k] - 1L)
  }
  base + code
}
grid_sizes <- c(
  length(P_GRID), length(MODELS), length(RHO_GRID),
  length(R2_GRID), length(DENSITY_GRID), REPS
)

for (i_p in seq_along(P_GRID)) {
  p   <- P_GRID[i_p]
  fms <- build_formulas(p)

  for (i_model in seq_along(MODELS)) {
    model  <- MODELS[i_model]
    fm     <- fms[[model]]
    p_full <- count_p_full(fm, p)

    # choose n around the p_full boundary; clamp to [12, 40]
    n_grid <- sort(unique(pmin(pmax(p_full + c(-12, -8, -4, 0, 4, 8, 12), 12), 40)))

    for (i_rho in seq_along(RHO_GRID)) {
      rho <- RHO_GRID[i_rho]

      for (i_r2 in seq_along(R2_GRID)) {
        R2_tgt <- R2_GRID[i_r2]

        for (i_dens in seq_along(DENSITY_GRID)) {
          dens <- DENSITY_GRID[i_dens]

          # `rep_i`, not `rep`, so the loop index does not shadow base::rep().
          for (rep_i in seq_len(REPS)) {

            # Seed BEFORE the first random draw of the replicate so that the
            # training size, the data, the split and the fits are reproducible.
            seed_i <- scenario_seed(
              c(i_p, i_model, i_rho, i_r2, i_dens, rep_i), grid_sizes
            )
            set.seed(seed_i)

            # Index into n_grid: sample(n_grid, 1L) would draw from 1:n_grid
            # whenever the grid collapses to a single value.
            n_tr <- n_grid[sample.int(length(n_grid), 1L)]
            n_te <- NTEST

            # No `seed` argument: the stream seeded above simply continues.
            df <- make_sparse_data_R2(
              n = n_tr + n_te, p = p, rho = rho, target_R2 = R2_tgt,
              fm = fm, density = dens
            )

            idx <- sample(seq_len(nrow(df)), size = n_tr)
            train_df <- df[idx, ]
            test_df  <- df[-idx, ]

            # Keep only training rows with complete model-frame variables.
            keep <- complete.cases(model.frame(fm, data = train_df, na.action = stats::na.pass))
            if (sum(keep) < 2) next
            train_used <- train_df[keep, , drop = FALSE]
            n_used <- nrow(train_used)

            R2_train_true <- r2_true(train_used)
            R2_test_true  <- r2_true(test_df)

            ## ---- SVEMnet variants (alpha=1; no debias) ----
            sv_waic  <- fit_and_predict_svem(fm, train_used, test_df, "wAIC",  nBoot = NBOOT)
            sv_waicc <- fit_and_predict_svem(fm, train_used, test_df, "wAICc", nBoot = NBOOT)
            sv_wsse  <- fit_and_predict_svem(fm, train_used, test_df, "wSSE",  nBoot = NBOOT)

            ## ---- glmnet_with_cv (alpha=1; no debias) ----
            # Guarded like the SVEM fits: one failed CV fit yields an NA row
            # entry plus a warning instead of aborting the whole simulation.
            t1 <- proc.time()[["elapsed"]]
            fit_cv <- tryCatch(
              glmnet_with_cv(
                formula = fm, data = train_used,
                glmnet_alpha = c(1),
                standardize = TRUE
              ),
              error = function(e) {
                warning("glmnet_with_cv failed: ", conditionMessage(e), call. = FALSE)
                NULL
              }
            )
            pr_te_cv <- if (is.null(fit_cv)) {
              rep(NA_real_, nrow(test_df))
            } else {
              safe_predict_cv(fit_cv, test_df)
            }
            time_cv <- round(proc.time()[["elapsed"]] - t1, 2)

            rows[[rid]] <- data.frame(
              p = p, model = model, rho = rho, R2_target = R2_tgt, density = dens,
              p_full = p_full, n_train = n_tr, n_train_used = n_used, n_test = n_te,
              n_train_minus_p_full = n_tr - p_full,
              ratio_n_over_p = n_used / p_full,
              above_boundary = as.integer(n_used > p_full),
              R2_true_train = R2_train_true, R2_true_test = R2_test_true,
              rmse_test_svem_waic  = rmse(test_df$y, sv_waic$pred),
              rmse_test_svem_waicc = rmse(test_df$y, sv_waicc$pred),
              rmse_test_svem_wsse  = rmse(test_df$y, sv_wsse$pred),
              rmse_test_glmnet_cv  = rmse(test_df$y, pr_te_cv),
              time_sec_svem_waic   = sv_waic$time,
              time_sec_svem_waicc  = sv_waicc$time,
              time_sec_svem_wsse   = sv_wsse$time,
              time_sec_cv          = time_cv,
              stringsAsFactors = FALSE
            )
            rid <- rid + 1L
          }
        }
      }
    }
  }
}

stopifnot(length(rows) > 0)
res <- dplyr::bind_rows(rows)

## ---------- Summary & win-rates (robust, single block) ----------
stopifnot(all(c("rmse_test_svem_waic","rmse_test_svem_waicc","rmse_test_svem_wsse","rmse_test_glmnet_cv") %in% names(res)))

# Long format for plotting and summaries: one row per (simulation row, method)
# holding the log test RMSE (`lmrse`). `Row` identifies the simulation row so
# methods can be paired later. RMSE is floored at machine epsilon before the
# log so an exact zero cannot produce -Inf.
res_long <- res %>%
  mutate(
    Row           = dplyr::row_number(),
    lmrse_waic    = log(pmax(rmse_test_svem_waic,  .Machine$double.eps)),
    lmrse_waicc   = log(pmax(rmse_test_svem_waicc, .Machine$double.eps)),
    lmrse_wsse    = log(pmax(rmse_test_svem_wsse,  .Machine$double.eps)),
    lmrse_cv      = log(pmax(rmse_test_glmnet_cv,  .Machine$double.eps))
  ) %>%
  tidyr::pivot_longer(
    cols = c(lmrse_waic, lmrse_waicc, lmrse_wsse, lmrse_cv),
    names_to = "metric", values_to = "lmrse"
  ) %>%
  mutate(
    # Relabel the metric names as an ordered method factor in one step.
    settings = factor(
      metric,
      levels = c("lmrse_waic", "lmrse_waicc", "lmrse_wsse", "lmrse_cv"),
      labels = c("SVEM_wAIC", "SVEM_wAICc", "SVEM_wSSE", "glmnet_cv")
    )
  )

# Head-to-head win-rates vs glmnet_cv (paired by row); rows where either RMSE
# is NA are excluded from the mean.
win_rates_vs_cv <- res %>%
  summarise(
    SVEM_wAIC  = mean(rmse_test_svem_waic  < rmse_test_glmnet_cv, na.rm = TRUE),
    SVEM_wAICc = mean(rmse_test_svem_waicc < rmse_test_glmnet_cv, na.rm = TRUE),
    SVEM_wSSE  = mean(rmse_test_svem_wsse  < rmse_test_glmnet_cv, na.rm = TRUE)
  ) %>%
  tidyr::pivot_longer(everything(), names_to = "settings", values_to = "winrate_vs_cv")

# Overall summary (log RMSE) with a normal-approximation 95% interval.
summ <- res_long %>%
  group_by(settings) %>%
  summarise(
    mean_lmrse = mean(lmrse, na.rm = TRUE),
    sd_lmrse   = sd(lmrse,   na.rm = TRUE),
    # Count only the values that entered mean/sd; counting NA rows as well
    # would understate the standard error whenever a fit failed.
    n          = sum(!is.na(lmrse)),
    se         = sd_lmrse / sqrt(pmax(n, 1)),
    ci_lo      = mean_lmrse - 1.96 * se,
    ci_hi      = mean_lmrse + 1.96 * se,
    .groups    = "drop"
  ) %>%
  left_join(win_rates_vs_cv, by = "settings")

cat("\n================ SUMMARY (log RMSE) ================\n")
print(summ, row.names = FALSE, digits = 4)

## ---------- By-boundary summaries (attach boundary to long) ----------
# Boundary indicator for each row of `d`, taken from the first source available:
#   1. an existing `above_boundary` column,
#   2. n_train_used > p_full,
#   3. ratio_n_over_p > 1,
#   4. otherwise NA.
boundary_flag <- function(d) {
  cols <- names(d)
  if ("above_boundary" %in% cols) {
    d[["above_boundary"]]
  } else if (all(c("n_train_used", "p_full") %in% cols)) {
    as.integer(d[["n_train_used"]] > d[["p_full"]])
  } else if ("ratio_n_over_p" %in% cols) {
    as.integer(d[["ratio_n_over_p"]] > 1)
  } else {
    rep(NA_integer_, nrow(d))
  }
}

bound_key <- data.frame(
  Row = seq_len(nrow(res)),
  above_boundary = boundary_flag(res)
)

# Drop any copy of the flag already carried by the long table before joining,
# so the join yields a single `above_boundary` column (no .x/.y suffixes).
res_long <- res_long %>%
  dplyr::select(-dplyr::any_of("above_boundary")) %>%
  dplyr::left_join(bound_key, by = "Row")

# Means of log RMSE by boundary & setting
summ_by_boundary <- res_long %>%
  group_by(above_boundary, settings) %>%
  summarise(
    mean_lmrse = mean(lmrse, na.rm = TRUE),
    sd_lmrse   = sd(lmrse, na.rm = TRUE),
    n          = dplyr::n(),
    .groups    = "drop"
  )

cat("\n---- Means by boundary (n_train_used > p_full) ----\n")
print(summ_by_boundary, row.names = FALSE, digits = 4)

# Boundary-specific win-rate for each SVEM variant vs glmnet
win_by_boundary <- res %>%
  mutate(
    above_boundary = boundary_flag(res),
    win_waic  = rmse_test_svem_waic  < rmse_test_glmnet_cv,
    win_waicc = rmse_test_svem_waicc < rmse_test_glmnet_cv,
    win_wsse  = rmse_test_svem_wsse  < rmse_test_glmnet_cv
  ) %>%
  group_by(above_boundary) %>%
  summarise(
    winrate_waic  = mean(win_waic,  na.rm = TRUE),
    winrate_waicc = mean(win_waicc, na.rm = TRUE),
    winrate_wsse  = mean(win_wsse,  na.rm = TRUE),
    # rows where at least one head-to-head comparison is available
    n             = sum(!(is.na(win_waic) & is.na(win_waicc) & is.na(win_wsse))),
    .groups       = "drop"
  )

cat("\n---- Win-rate vs glmnet (by boundary) ----\n")
print(win_by_boundary, row.names = FALSE, digits = 4)

## ========== ANOM-style, row-blocked plot ==========
# Methods are treatments and simulation rows are blocks. Filter to finite
# values FIRST and only then build the factors: converting `Row` before
# filtering leaves empty levels behind, and nlevels() would count blocks (or
# methods) that contribute no data, understating the standard error below.
res_long_anom <- res_long %>%
  dplyr::filter(is.finite(lmrse)) %>%
  dplyr::mutate(
    Row      = factor(Row),
    settings = droplevels(settings)
  )

fit_anom <- lm(lmrse ~ settings + Row, data = res_long_anom)
aov_tbl  <- anova(fit_anom)

# The last row of the ANOVA table is the residual line.
MSE    <- aov_tbl[["Mean Sq"]][nrow(aov_tbl)]
df_res <- aov_tbl[["Df"]][nrow(aov_tbl)]

t_methods <- nlevels(res_long_anom$settings)
b_blocks  <- nlevels(res_long_anom$Row)
grand_mu  <- mean(res_long_anom$lmrse, na.rm = TRUE)

# Var(ȳ_i. − ȳ..) = σ^2 * (t − 1) / (t * b)  under RBD
# (exact only for complete blocks; approximate if some rows lack a method).
se_group <- sqrt(MSE * (t_methods - 1) / (t_methods * b_blocks))

# Two-sided limits with the error rate split across the t methods.
alpha <- 0.05
crit  <- qt(1 - alpha / (2 * t_methods), df = df_res)
UCL   <- grand_mu + crit * se_group
LCL   <- grand_mu - crit * se_group

means_df <- res_long_anom %>%
  dplyr::group_by(settings) %>%
  dplyr::summarise(mean_lmrse = mean(lmrse, na.rm = TRUE), .groups = "drop") %>%
  dplyr::mutate(flag = dplyr::case_when(
    mean_lmrse > UCL ~ "Above UCL",
    mean_lmrse < LCL ~ "Below LCL",
    TRUE             ~ "Within Limits"
  ))

p_anom <- ggplot(means_df, aes(x = settings, y = mean_lmrse)) +
  geom_hline(yintercept = grand_mu, linetype = 2) +
  geom_hline(yintercept = UCL,      linetype = 3) +
  geom_hline(yintercept = LCL,      linetype = 3) +
  geom_segment(aes(xend = settings, y = grand_mu, yend = mean_lmrse), linewidth = 1) +
  geom_point(aes(color = flag), size = 3) +
  scale_color_manual(values = c("Within Limits" = "black",
                                "Above UCL"     = "red",
                                "Below LCL"     = "red")) +
  labs(
    title = "Blocked ANOM-style plot for lmrse (Row = block)",
    subtitle = sprintf("Grand mean = %.3f | Limits = [%.3f, %.3f] | df = %d",
                       grand_mu, LCL, UCL, df_res),
    x = NULL, y = "Mean lmrse", color = NULL
  ) +
  theme_bw() +
  theme(plot.title = element_text(face = "bold"))

# One line per simulation row connecting its four methods.
p_pairs <- ggplot(res_long_anom, aes(x = settings, y = lmrse, group = Row)) +
  geom_line(alpha = 0.25) +
  geom_point(alpha = 0.6, position = position_dodge(width = 0.05)) +
  stat_summary(fun = mean, geom = "point", size = 3, color = "black") +
  labs(title = "Paired runs by Row", x = NULL, y = "lmrse") +
  theme_bw()

# Explicit print(): top-level auto-printing does not happen under source().
print(p_anom / p_pairs)

## ---------- Pairwise GM ratios vs glmnet ----------
eps <- .Machine$double.eps

# Paired geometric-mean ratio of x to y with a percentile-bootstrap interval.
#
# x, y : numeric vectors of equal length (paired; here RMSEs of two methods).
# B    : number of bootstrap resamples.
# seed : seed used for the resampling, so repeated calls are comparable.
#
# Returns a named numeric vector c(gm, lo, hi, win):
#   gm     geometric mean of x / y over pairs with a finite log-ratio,
#   lo, hi 2.5% and 97.5% bootstrap quantiles of that geometric mean,
#   win    fraction of pairs with x / y < 1.
# All four are NA when no pair has a finite log-ratio.
gm_for <- function(x, y, B = 2000L, seed = 123L) {
  # Floor at eps so an exact zero cannot produce -Inf.
  ld <- log(pmax(x, eps)) - log(pmax(y, eps))
  ld <- ld[is.finite(ld)]         # drop NA/inf before bootstrap
  if (!length(ld)) return(c(gm = NA_real_, lo = NA_real_, hi = NA_real_, win = NA_real_))

  # Resample under a fixed seed, then put the caller's RNG state back so this
  # helper does not alter random draws made after it returns.
  had_seed <- exists(".Random.seed", envir = globalenv(), inherits = FALSE)
  old_seed <- if (had_seed) get(".Random.seed", envir = globalenv(), inherits = FALSE) else NULL
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = globalenv())
    } else if (exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
      rm(".Random.seed", envir = globalenv())
    }
  }, add = TRUE)
  set.seed(seed)

  n <- length(ld)
  gm_boot <- replicate(B, {
    ii <- sample.int(n, replace = TRUE)
    exp(mean(ld[ii]))
  })

  # quantile() names its result "2.5%" / "97.5%"; without unname(),
  # c(lo = ci[1]) would be named "lo.2.5%" and a lookup by "lo" would give NA.
  ci  <- unname(quantile(gm_boot, c(0.025, 0.975), na.rm = TRUE))
  win <- mean(exp(ld) < 1)        # fraction of pairs where method beats cv
  c(gm = exp(mean(ld)), lo = ci[1], hi = ci[2], win = win)
}

cat("\nPaired geometric-mean RMSE ratio (SVEM/CV):\n")
stats_waic  <- gm_for(res$rmse_test_svem_waic,  res$rmse_test_glmnet_cv)
stats_waicc <- gm_for(res$rmse_test_svem_waicc, res$rmse_test_glmnet_cv)
stats_wsse  <- gm_for(res$rmse_test_svem_wsse,  res$rmse_test_glmnet_cv)

# One report line per SVEM objective; labels are left-padded to five characters
# so the colons line up.
gm_report <- list(wAIC = stats_waic, wAICc = stats_waicc, wSSE = stats_wsse)
for (label in names(gm_report)) {
  s <- gm_report[[label]]
  cat(sprintf("  %-5s: GM=%.3f  (95%% CI %.3f–%.3f) | win=%.1f%%\n",
              label, s["gm"], s["lo"], s["hi"], 100 * s["win"]))
}

# Symmetric trimming helper (kept handy for trimmed-ratio summaries): returns
# the elements of `x` lying between its p-th and (1 - p)-th quantiles,
# inclusive. Missing values are ignored when computing the quantiles and are
# dropped from the result; a bare logical subset would pass them through as NA.
trim <- function(x, p = 0.05) {
  bounds <- quantile(x, c(p, 1 - p), na.rm = TRUE, names = FALSE)
  keep <- !is.na(x) & x >= bounds[1] & x <= bounds[2]
  x[keep]
}
