#' segen
#'
#' @param df data.frame of time features (all numeric OR all categorical).
#' @param seq_len integer, forecasting horizon. If NULL, auto-sampled.
#' @param similarity numeric in (0,1), similarity quantile. If NULL, sampled.
#' @param dist_method character. Options:
#'   "euclidean","manhattan","maximum","minkowski","correlation","dtw".
#'   If NULL, sampled from available methods (skips 'dtw' if pkg missing).
#' @param rescale logical, rescale weights before normalization.
#' @param smoother logical, apply loess smoothing for numeric features.
#' @param ci numeric in (0,1), confidence level.
#' @param error_scale "naive" or "deviation".
#' @param error_benchmark "naive" or "average".
#' @param n_windows integer, rolling validation windows.
#' @param n_samp integer, random search samples.
#' @param dates Date vector aligned with rows of df (optional).
#' @param seed integer, RNG seed.
#' @param use_parallel logical, use furrr/future for parallel exploration.
#' @param parallel_workers NULL or integer, number of workers when parallel.
#' @return list with exploration, history, best_model, time_log.
#'
#' @author Giancarlo Vercellino \email{giancarlo.vercellino@gmail.com}
#'
#' @return This function returns a list including:
#' \itemize{
#' \item exploration: list of all not-null models, complete with predictions and error metrics
#' \item history: a table with the sampled models, hyper-parameters, validation errors
#' \item best_model: results for the best selected model according to the weighted average rank, including:
#' \itemize{
#' \item predictions: for continuous variables, min, max, q25, q50, q75, quantiles at selected ci, mean, sd, mode, skewness, kurtosis, IQR to range, risk ratio, upside probability and divergence for each point of the predicted sequences; for factor variables, min, max, q25, q50, q75, quantiles at selected ci, proportions, difformity (deviation of proportions normalized over the maximum possible deviation), entropy, upgrade probability and divergence for each point of the predicted sequences
#' \item testing_errors: testing errors for each time feature for the best selected model (for continuous variables: me, mae, mse, rmsse, mpe, mape, rmae, rrmse, rame, mase, smse, sce, gmrae; for factor variables: czekanowski, tanimoto, cosine, hassebrook, jaccard, dice, canberra, gower, lorentzian, clark)
#' \item plots: standard plots with confidence interval for each time feature
#' }
#' \item time_log
#' }
#'
#' @export
#'
#' @import purrr
#' @import tictoc
#' @importFrom readr parse_number
#' @importFrom lubridate seconds_to_period is.Date as.duration
#' @importFrom stats weighted.mean ecdf na.omit
#' @importFrom imputeTS na_kalman
#' @importFrom fANCOVA loess.as
#' @importFrom modeest mlv1
#' @import ggplot2
#' @importFrom moments skewness kurtosis
#' @importFrom stats quantile sd lm rnorm pnorm fft runif
#' @importFrom scales number
#' @importFrom utils tail head
#' @import greybox
#' @importFrom philentropy distance
#' @importFrom entropy entropy
#' @importFrom Rfast Dist
#' @importFrom narray split
#' @import fastDummies
#' @importFrom future plan multisession sequential availableCores
#' @importFrom digest digest
#' @importFrom furrr future_map furrr_options
#' @importFrom dtw dtw
#'
#'
#'@examples
#'segen(time_features[, 1, drop = FALSE], seq_len = 30, similarity = 0.7, n_windows = 3, n_samp = 1)
#'
#'
#' @export
segen <- function(df,
                  seq_len = NULL,
                  similarity = NULL,
                  dist_method = NULL,
                  rescale = NULL,
                  smoother = FALSE,
                  ci = 0.8,
                  error_scale = "naive",
                  error_benchmark = "naive",
                  n_windows = 10,
                  n_samp = 30,
                  dates = NULL,
                  seed = 42,
                  use_parallel = FALSE,
                  parallel_workers = NULL) {

  # Start the timer reported as `time_log` and fix the RNG stream used by the
  # hyper-parameter samplers below.
  tic.clearlog(); tic("time")
  set.seed(seed)

  # --- argument validation ---
  if (!is.data.frame(df)) stop("`df` must be a data.frame.")
  if (!(is.numeric(ci) && ci > 0 && ci < 1)) stop("`ci` must be a number in (0,1).")
  if (!(is.numeric(n_windows) && n_windows >= 1)) stop("`n_windows` must be >= 1.")
  if (!(is.numeric(n_samp) && n_samp >= 1)) stop("`n_samp` must be >= 1.")

  n_length <- nrow(df)
  validate_dates(dates, n_length)

  # Columns must be homogeneous: either all categorical or all numeric.
  class_index <- purrr::map_lgl(df, ~ is.factor(.x) || is.character(.x))
  numeric_index <- purrr::map_lgl(df, ~ is.integer(.x) || is.numeric(.x))
  all_classes <- all(class_index)
  all_numerics <- all(numeric_index)
  if (!(all_classes || all_numerics)) stop("All columns must be either numeric OR categorical (not mixed).")

  # Categorical features are expanded to dummy columns and then handled as
  # binary series by the downstream helpers.
  if (all_classes) {
    df <- dummy_cols(df, select_columns = NULL,
                     remove_first_dummy = FALSE, remove_most_frequent_dummy = TRUE,
                     ignore_na = FALSE, split = NULL, remove_selected_columns = TRUE)
  }
  binary_class <- rep(all_classes, ncol(df))

  # --- missing values and optional smoothing ---
  if (anyNA(df) && all_numerics) {
    df <- as.data.frame(na_kalman(df)); message("Kalman imputation applied to numeric time features.")
  }
  if (anyNA(df) && all_classes) {
    df <- floor(as.data.frame(na_kalman(df))); message("Kalman imputation applied to categorical time features (then floored).")
  }
  if (isTRUE(smoother) && all_numerics) {
    n_len <- nrow(df)
    df <- as.data.frame(purrr::map(df, ~ suppressWarnings(loess.as(x = 1:n_len, y = .x)$fitted)))
    message("Loess smoothing (automatic span) applied.")
  }

  n_feats <- ncol(df)
  feat_names <- colnames(df)

  # With every hyper-parameter fixed by the caller there is a single model.
  if (length(seq_len) == 1 && length(similarity) == 1 &&
      length(rescale) == 1 && length(dist_method) == 1) n_samp <- 1

  # Upper bound for seq_len, derived from the series length, the number of
  # validation windows and the largest differencing order across features.
  deriv <- purrr::map_dbl(df, ~ best_deriv(.x))
  max_limit <- floor((n_length / (n_windows + 1) - max(deriv)) / 2)
  if (is.na(max_limit) || max_limit < 2) stop("Series too short for given n_windows; reduce `n_windows` or provide longer series.")

  # --- random search over hyper-parameters ---
  sqln_set <- sampler(seq_len, n_samp, range = c(2, max_limit), integer = TRUE)
  sml_set  <- sampler(similarity, n_samp, range = c(0.01, 0.99), integer = FALSE)

  avail_methods <- c("euclidean","manhattan","maximum","minkowski","correlation")
  if (requireNamespace("dtw", quietly = TRUE)) avail_methods <- c(avail_methods, "dtw")
  dst_set  <- sampler(dist_method, n_samp, range = avail_methods, integer = FALSE, multi = n_feats)

  rscl_set <- as.logical(sampler(rescale, n_samp, range = c(0, 1), integer = TRUE))

  if (any(sqln_set < 2)) { sqln_set[sqln_set < 2] <- 2; message("Setting min seq_len to 2.") }
  if (any(sqln_set > max_limit)) { sqln_set[sqln_set > max_limit] <- max_limit; message(paste0("Setting max seq_len to ", max_limit, ".")) }

  # --- optional parallel plan ---
  map_p <- pmap
  if (isTRUE(use_parallel)) {
    if (requireNamespace("furrr", quietly = TRUE) && requireNamespace("future", quietly = TRUE)) {
      # plan() returns the previously active plan; restore exactly that plan on
      # exit (also on error) instead of forcing `sequential`, so the caller's
      # own future configuration is left untouched.
      previous_plan <- future::plan(future::multisession,
                                    workers = parallel_workers %||% future::availableCores())
      on.exit(future::plan(previous_plan), add = TRUE)
      map_p <- function(.l, .f, ...) furrr::future_pmap(.l, .f, ..., .options = furrr::furrr_options(seed = TRUE))
    } else {
      message("Parallel requested but packages 'furrr' and/or 'future' not available; running sequentially.")
    }
  }

  # One validated model per (feature, sampled configuration) pair.
  exploration <- mapply(
    function(ftn) {
      map_p(
        list(sqln_set, sml_set, dst_set, rscl_set),
        ~ windower(
          ts           = df[, ftn],
          seq_len      = ..1,
          similarity   = clamp(..2, 0.0001, 0.9999),
          dist_method  = ..3[ftn],
          rescale      = ..4,
          n_windows    = n_windows,
          ci           = ci,
          error_scale  = error_scale,
          error_benchmark = error_benchmark,
          dates        = dates,
          binary_class = binary_class[ftn],
          seed         = seed
        )
      )
    },
    ftn = 1:n_feats,
    SIMPLIFY = FALSE
  )
  # Reorganise from feature -> configuration to configuration -> feature.
  exploration <- transpose(exploration)

  models <- map_depth(exploration, 2, ~ .x$quant_pred)
  errors <- map_depth(exploration, 2, ~ .x$errors)

  # Average the validation errors across features for each configuration.
  aggr_errors <- if (n_feats > 1) map(errors, ~ colMeans(Reduce(rbind, .x))) else errors
  aggr_errors <- t(as.data.frame(aggr_errors))

  history <- data.frame(seq_len = unlist(sqln_set))
  history$similarity   <- sml_set
  history$dist_method  <- dst_set
  history$rescale      <- rscl_set
  history <- data.frame(history, round(aggr_errors, 4))
  rownames(history) <- NULL

  # Rank configurations on the error columns only (the first four columns are
  # hyper-parameters); signed numeric errors are compared by magnitude.
  if (all_numerics) {
    history <- ranker(history, focus = -c(1, 2, 3, 4), inverse = NULL, absolute = c("me", "mpe", "sce"), reverse = FALSE)
  } else {
    history <- ranker(history, focus = -c(1, 2, 3, 4), inverse = NULL, absolute = NULL, reverse = FALSE)
  }

  # Row names survive the reordering, so the first one is the index of the
  # best configuration in `models` / `errors`.
  best_index <- as.numeric(rownames(history[1, ]))
  predictions <- models[[best_index]]
  testing_errors <- t(as.data.frame(errors[[best_index]]))

  plots <- pmap(list(predictions, feat_names),
                ~ plotter(..1, ci, df[, ..2], dates, ..2))

  names(predictions) <- feat_names
  rownames(testing_errors) <- feat_names
  names(plots) <- feat_names

  best_model <- list(predictions = predictions, testing_errors = testing_errors, plots = plots)

  toc(log = TRUE)
  time_log <- seconds_to_period(round(parse_number(unlist(tic.log())), 0))

  list(exploration = exploration, history = history, best_model = best_model, time_log = time_log)
}


# ---------------------------------------------------------------------------
# Everything below is INTERNAL
# ---------------------------------------------------------------------------


#' @keywords internal
# --- internal cache env (for distance matrices) ---
# Environment in which compute_distance() stores and looks up distance
# matrices under the key returned by .compute_key(). Its parent is the empty
# environment, so a lookup can never resolve to an object in another scope.
.segen_cache <- new.env(parent = emptyenv())

# ---------------- utils ----------------
# Null-coalescing operator: the left operand unless it is NULL, else the right.
`%||%` <- function(a, b) {
  if (is.null(a)) {
    return(b)
  }
  a
}

# Restrict `x` element-wise to the closed interval [lo, hi].
clamp <- function(x, lo, hi) {
  capped <- pmin(hi, x)
  pmax(lo, capped)
}

# Return `x` when it can be used as a divisor; otherwise return `fallback`.
# Missing, non-finite and zero values are all replaced.
safe_scale <- function(x, fallback = 1e-8) {
  unusable <- is.na(x) || !is.finite(x) || x == 0
  if (unusable) {
    return(fallback)
  }
  x
}

# Check the optional `dates` argument: when supplied it must be a Date vector
# with exactly `n` elements. Stops with an error otherwise; a NULL `dates` is
# accepted without further checks.
validate_dates <- function(dates, n) {
  if (is.null(dates)) {
    return(invisible(NULL))
  }
  if (!is.Date(dates)) stop("`dates` must be a Date vector.")
  if (length(dates) != n) stop("`dates` length must match nrow(df).")
  invisible(NULL)
}

# Linearly rescale `x` so that its observed range maps onto [min_v, max_v].
# A constant input has no spread to stretch, so every element is mapped to the
# midpoint of the target interval instead.
minmax <- function(x, min_v, max_v) {
  bounds <- range(x, na.rm = TRUE)
  width <- diff(bounds)
  if (width == 0) {
    midpoint <- (min_v + max_v) / 2
    return(rep(midpoint, length(x)))
  }
  unit <- (x - bounds[1]) / width
  unit * (max_v - min_v) + min_v
}

# ---- distance computation with memoization ----
# Build the memoisation key used by compute_distance().
#
# The key is a digest of the full matrix contents together with the distance
# method and the Minkowski exponent. Hashing only summary statistics (dim,
# mean, sd) would give the same key to different matrices, e.g. two matrices
# holding the same values in a different order, and the cache would then hand
# back a distance matrix computed for other data.
#
# Returns a character hash, or NULL when the 'digest' package is unavailable
# (callers treat NULL as "do not cache").
.compute_key <- function(X, method, p) {
  if (!requireNamespace("digest", quietly = TRUE)) {
    return(NULL)
  }
  digest::digest(list(X, method, p))
}

# Pairwise dynamic-time-warping distances between the rows of `X`.
# Only the lower triangle (diagonal included) is computed; each value is then
# mirrored so the returned n x n matrix is symmetric. Requires package 'dtw'.
compute_dtw_distance <- function(X) {
  has_dtw <- requireNamespace("dtw", quietly = TRUE)
  if (!has_dtw) {
    stop("Install 'dtw' to use dist_method = 'dtw'.")
  }

  n_series <- nrow(X)
  dist_mat <- matrix(0, n_series, n_series)

  for (row_i in seq_len(n_series)) {
    for (row_j in seq_len(row_i)) {
      cost <- dtw::dtw(X[row_i, ], X[row_j, ])$distance
      dist_mat[row_i, row_j] <- cost
      dist_mat[row_j, row_i] <- cost
    }
  }

  dist_mat
}

# Pairwise distance matrix between the rows of `X` for the chosen `method`.
#
# X          numeric matrix, one sequence per row.
# method     one of "euclidean", "manhattan", "maximum", "minkowski",
#            "correlation", "dtw"; anything else raises an error.
# p          exponent passed on for the Minkowski distance.
# use_cache  when TRUE, results are memoised in `.segen_cache` under the key
#            from .compute_key(); a NULL key disables caching for that call.
compute_distance <- function(X, method, p = 3, use_cache = TRUE) {
  cache_key <- NULL
  if (use_cache) cache_key <- .compute_key(X, method, p)
  cacheable <- !is.null(cache_key)

  # Serve a previously computed matrix when one is stored under this key.
  if (cacheable && exists(cache_key, envir = .segen_cache, inherits = FALSE)) {
    return(get(cache_key, envir = .segen_cache))
  }

  dist_mat <- switch(
    method,
    # The three Rfast metrics share one call; `method` is forwarded verbatim.
    euclidean = ,
    manhattan = ,
    maximum = as.matrix(Rfast::Dist(X, method = method)),
    minkowski = as.matrix(Rfast::Dist(X, method = "minkowski", p = p)),
    # Correlation distance: one minus the row-wise correlation.
    correlation = 1 - stats::cor(t(X), use = "pairwise.complete.obs"),
    dtw = compute_dtw_distance(X),
    stop("Unsupported distance method: ", method)
  )

  if (cacheable) assign(cache_key, dist_mat, envir = .segen_cache)
  dist_mat
}

# ---------------- core helpers ----------------
# Difference a series `deriv` times, remembering the first and last value seen
# before each differencing step so the operation can be inverted later.
#
# Returns a list with:
#   vector      the differenced series,
#   head_value  first value at each differencing level (NULL when deriv is 0),
#   tail_value  last value at each differencing level (NULL when deriv is 0).
recursive_diff <- function(vector, deriv) {
  series <- as.numeric(vector)
  firsts <- base::vector("numeric", deriv)
  lasts <- base::vector("numeric", deriv)

  if (deriv == 0) {
    firsts <- NULL
    lasts <- NULL
  } else if (deriv > 0) {
    for (level in 1:deriv) {
      firsts[level] <- head(series, 1)
      lasts[level] <- tail(series, 1)
      series <- diff(series)
    }
  }

  list(vector = series, head_value = firsts, tail_value = lasts)
}

# Invert recursive differencing.
#
# vector  differenced values to integrate.
# heads   starting value for each differencing level, in the order produced by
#         recursive_diff() (level 1 first); integration runs from the deepest
#         level back to level 1.
# add     when TRUE the starting values stay at the front of the result; when
#         FALSE (default) they are dropped so the output is as long as the input.
#
# With no starting values (NULL or zero-length `heads`) there is nothing to
# integrate and the input is returned as a numeric vector. The zero-length case
# must be guarded explicitly: `length(heads):1` would iterate over 0 and 1.
invdiff <- function(vector, heads, add = FALSE) {
  vector <- as.numeric(vector)
  if (length(heads) == 0) return(vector)
  for (d in rev(seq_along(heads))) vector <- cumsum(c(heads[d], vector))
  if (!add) vector[-seq_along(heads)] else vector
}

# Score each forecast step by where the observed value falls in the empirical
# distribution of the predictions for that step (one column per step).
# The score is 1 when the observation sits at the empirical median and
# decreases linearly to 0 towards either tail.
prediction_score <- function(integrated_preds, ground_truth) {
  step_cdfs <- apply(integrated_preds, 2, stats::ecdf)
  cdf_at_truth <- purrr::map2_dbl(
    step_cdfs, ground_truth,
    function(cdf, observed) cdf(observed)
  )
  1 - 2 * abs(cdf_at_truth - 0.5)
}

# Estimate how many times a series should be differenced.
#
# For differencing orders 0..max_diff the (progressively differenced) series is
# regressed on a time index, and the p-value of the regression F-test is kept.
# The result is the number of tested orders whose p-value is below `thresh`.
#
# Only orders that were actually tested are counted. The loop stops early when
# the differenced series has fewer than 3 points; slots for untested orders
# must not enter the count, otherwise a short series would be reported as
# needing extra differencing.
#
# Returns an integer; NA when any tested p-value is not a number.
best_deriv <- function(ts, max_diff = 3, thresh = 0.001) {
  n_orders <- as.integer(max_diff) + 1L
  pvalues <- rep(NA_real_, n_orders)
  n_tested <- 0L
  x <- ts
  for (d in seq_len(n_orders)) {
    model <- lm(x ~ t, data.frame(x, t = seq_along(x)))
    pvalues[d] <- with(summary(model),
                       stats::pf(fstatistic[1], fstatistic[2], fstatistic[3], lower.tail = FALSE))
    n_tested <- d
    x <- diff(x)
    if (length(x) < 3) break
  }
  tested <- pvalues[seq_len(n_tested)]
  as.integer(sum(tested < thresh))
}

# Order the rows of `df` by the mean of its standardised `focus` columns.
#
# focus     column selection (any index accepted by `[`) used for scoring.
# inverse   columns whose sign is flipped before scoring.
# absolute  columns replaced by their absolute value before scoring.
# reverse   FALSE sorts by ascending score, TRUE by descending score.
#
# Returns `df` with rows reordered; row names are carried along.
ranker <- function(df, focus, inverse = NULL, absolute = NULL, reverse = FALSE) {
  metrics <- df[, focus, drop = FALSE]
  if (!is.null(inverse)) {
    metrics[, inverse] <- -metrics[, inverse]
  }
  if (!is.null(absolute)) {
    metrics[, absolute] <- abs(metrics[, absolute])
  }

  score <- apply(scale(metrics), 1, mean, na.rm = TRUE)
  row_order <- if (reverse) order(-score) else order(score)
  df[row_order, , drop = FALSE]
}

# Draw the history of a series followed by its forecast, with an optional
# shaded band between `lower` and `upper`.
#
# Character inputs are converted to factors that share the levels of the
# history; factor values are plotted through their integer codes and the y axis
# is labelled with the level names. When `dbreak` is given the x axis is a date
# scale using `dbreak` / `date_format`.
#
# Returns a ggplot object.
ts_graph <- function(x_hist, y_hist, x_forcat, y_forcat,
                     lower = NULL, upper = NULL, line_size = 1.3, label_size = 11,
                     forcat_band = "seagreen2", forcat_line = "seagreen4",
                     hist_line = "gray43", label_x = "Horizon", label_y = "Forecasted Var",
                     dbreak = NULL, date_format = "%b-%d-%Y") {

  # Harmonise categorical inputs on the levels of the historical series.
  if (is.character(y_hist)) y_hist <- as.factor(y_hist)
  hist_levels <- levels(y_hist)
  if (is.character(y_forcat)) y_forcat <- factor(y_forcat, levels = hist_levels)
  if (is.character(lower)) lower <- factor(lower, levels = hist_levels)
  if (is.character(upper)) upper <- factor(upper, levels = hist_levels)

  is_categorical <- is.factor(y_hist)
  has_band <- !is.null(lower) && !is.null(upper)

  # Plot data: full path (history + forecast) and the forecast segment alone.
  path_data <- data.frame(x_all = c(x_hist, x_forcat),
                          y_all = as.numeric(c(y_hist, y_forcat)))
  forecast_data <- data.frame(x_forcat = x_forcat, y_forcat = as.numeric(y_forcat))
  if (has_band) {
    forecast_data$lower <- as.numeric(lower)
    forecast_data$upper <- as.numeric(upper)
  }

  x_axis <- if (!is.null(dbreak)) {
    scale_x_date(name = paste0("\n", label_x),
                 date_breaks = dbreak, date_labels = date_format)
  } else {
    xlab(label_x)
  }

  y_axis <- if (is_categorical) {
    scale_y_continuous(name = paste0(label_y, "\n"),
                       breaks = 1:length(hist_levels), labels = hist_levels)
  } else {
    scale_y_continuous(name = paste0(label_y, "\n"), labels = scales::number)
  }

  # Layers are added in drawing order: path, band, forecast line.
  chart <- ggplot() +
    geom_line(data = path_data, aes_string(x = "x_all", y = "y_all"),
              color = hist_line, size = line_size)

  if (has_band) {
    chart <- chart +
      geom_ribbon(data = forecast_data,
                  aes_string(x = "x_forcat", ymin = "lower", ymax = "upper"),
                  alpha = 0.3, fill = forcat_band)
  }

  chart +
    geom_line(data = forecast_data, aes_string(x = "x_forcat", y = "y_forcat"),
              color = forcat_line, size = line_size) +
    x_axis +
    y_axis +
    ylab(label_y) +
    theme_bw() +
    theme(axis.text = element_text(size = label_size),
          axis.title = element_text(size = label_size + 2))
}

# Produce `n_samp` hyper-parameter values for the random search.
#
# vect     user-supplied value(s). NULL means "sample from `range`"; a single
#          value is repeated as is; several values are sampled with replacement.
# n_samp   number of samples to return.
# range    numeric c(min, max) or a character vector of admissible options,
#          used only when `vect` is NULL.
# integer  for a numeric `range`: TRUE samples from the integers in the range,
#          FALSE from a 1000-point grid over it.
# fun      when `vect` is NULL and `fun` is not, `fun` is returned unchanged.
# multi    when not NULL, each sample is a vector of `multi` draws and the
#          result is a list of `n_samp` such vectors.
#
# Draws are made by index rather than with sample(pool, ...): sample() treats a
# single number x >= 1 as the set 1:x, so a degenerate range such as c(2, 2)
# would yield values outside the range. For pools with more than one element
# the index-based draw is the same computation sample() performs internally.
sampler <- function(vect, n_samp, range = NULL, integer = FALSE, fun = NULL, multi = NULL) {
  draw <- function(pool, size) {
    pool[sample.int(length(pool), size, replace = TRUE)]
  }

  if (is.null(vect) && !is.null(fun)) {
    return(fun)
  }

  if (is.null(vect)) {
    if (is.character(range)) {
      pool <- range
    } else if (integer) {
      pool <- seq.int(min(range), max(range))
    } else {
      pool <- seq(min(range), max(range), length.out = 1000)
    }
    pick <- function(size) draw(pool, size)
  } else if (length(vect) == 1) {
    # A fixed value is repeated without touching the RNG stream.
    pick <- function(size) rep(vect, size)
  } else {
    pick <- function(size) draw(vect, size)
  }

  if (is.null(multi)) {
    pick(n_samp)
  } else {
    replicate(n_samp, pick(multi), simplify = FALSE)
  }
}

# Build the forecast plot for one feature from its prediction table.
#
# quant_pred  prediction table with one row per forecast step; must contain
#             the "50%" column and the two columns for the `ci` bounds.
# ci          confidence level used to pick the lower/upper bound columns.
# ts          historical values of the feature.
# dates       optional Date vector for the history; when given, forecast
#             dates continue it using the mean spacing of `dates`.
# feat_name   feature name used in the y-axis label.
#
# Returns the ggplot object produced by ts_graph().
plotter <- function(quant_pred, ci, ts, dates = NULL, feat_name) {
  seq_h <- nrow(quant_pred)
  n_ts  <- length(ts)

  if (is.Date(dates)) {
    step <- mean(diff(dates))
    new_dates <- seq.Date(from = tail(dates, 1) + step, by = step, length.out = seq_h)
    x_hist <- dates; x_forc <- new_dates
    rownames(quant_pred) <- as.character(new_dates)
  } else {
    x_hist <- seq_len(n_ts); x_forc <- (n_ts + 1):(n_ts + seq_h)
    rownames(quant_pred) <- paste0("t", seq_len(seq_h))
  }

  # Column labels are taken from stats::quantile() itself so they match the
  # names it generated for the prediction table. Pasting the raw percentage
  # gives a different label whenever the bound needs more than 7 significant
  # digits (e.g. ci = 2/3), and the column lookup below would then fail.
  band_probs <- c((1 - ci) / 2, ci + (1 - ci) / 2)
  band_labels <- names(stats::quantile(0, probs = band_probs, names = TRUE))
  lower_b <- band_labels[1]
  upper_b <- band_labels[2]

  x_lab <- paste0("Forecasting Horizon for sequence n = ", seq_h)
  y_lab <- paste0("Forecasting Values for ", feat_name)

  ts_graph(x_hist = x_hist, y_hist = ts, x_forcat = x_forc,
           y_forcat = quant_pred[, "50%"],
           lower = quant_pred[, lower_b],
           upper = quant_pred[, upper_b],
           label_x = x_lab, label_y = y_lab)
}

# Reshape a series into a matrix of sliding windows of length `seq_len`.
#
# Leading values are dropped first so the series length is a multiple of
# `seq_len`. All unit-step windows are built, then every `stride`-th window is
# kept counting back from the last one, so the most recent window is always
# retained. Columns are named t1..t<seq_len>.
smart_reframer <- function(ts, seq_len, stride)
{
  total <- length(ts)
  if (seq_len > total | stride > total) {
    stop("vector too short for sequence length or stride")
  }

  # Trim from the front; the most recent observations are kept.
  leftover <- total %% seq_len
  if (leftover > 0) {
    ts <- tail(ts, -leftover)
  }
  total <- length(ts)

  starts <- base::seq(from = 1, to = (total - seq_len + 1), by = 1)
  windows <- t(sapply(starts, function(first) ts[first:(first + seq_len - 1)]))
  # sapply() returns a plain vector for windows of length 1; restore the
  # one-column orientation.
  if (seq_len == 1) {
    windows <- t(windows)
  }

  keep <- rev(base::seq(nrow(windows), 1, -stride))
  windows <- windows[keep, , drop = FALSE]
  colnames(windows) <- paste0("t", 1:seq_len)
  windows
}

# --------------- inner workflow (with conformal CI) ----------------
# Validate one hyper-parameter configuration on rolling windows, then fit the
# final model on the whole series.
#
# The series is cut into n_windows + 1 consecutive blocks (leftover points go
# to the first block). For window k the model is fitted on blocks 1..k and
# scored on the first `seq_len` points of block k + 1. Absolute residuals of
# the median forecast across all windows give the half-width of the interval
# placed around the final median forecast.
#
# Returns a list with:
#   quant_pred  prediction table of the final model, with the `ci` bound
#               columns replaced by the residual-based bounds and a
#               `pred_scores` column averaged over the validation windows,
#   errors      validation errors averaged over the windows.
windower <- function(ts, seq_len, similarity, dist_method,
                     rescale = FALSE, n_windows = 10, ci = 0.8,
                     error_scale = "naive", error_benchmark = "naive",
                     dates = NULL, binary_class, seed = 42) {

  n_length <- length(ts)
  # Block membership of every observation.
  idx <- c(rep(1, n_length %% (n_windows + 1)),
           rep(1:(n_windows + 1), each = floor(n_length / (n_windows + 1))))

  holdouts <- map(1:n_windows, ~ head(ts[idx == (.x + 1)], seq_len))

  window_results <- map(
    1:n_windows,
    ~ engine(
      ts        = ts[idx <= .x],
      seq_len   = seq_len,
      similarity = similarity,
      dist_method = dist_method,
      rescale   = rescale,
      ci        = ci,
      holdout   = holdouts[[.x]],
      error_scale = error_scale,
      error_benchmark = error_benchmark,
      dates     = dates,
      binary_class = binary_class,
      seed      = seed
    )
  )

  errors <- colMeans(Reduce(rbind, map(window_results, ~ .x$testing_error)))
  pred_scores <- rowMeans(Reduce(cbind, map(window_results, ~ .x$quant_pred$pred_scores)))

  # Half-width of the interval: the `ci` quantile of the absolute residuals of
  # the median forecast over all validation windows.
  abs_resids <- unlist(
    map2(window_results, holdouts, ~ abs(as.numeric(.y) - .x$quant_pred[, "50%"]))
  )
  alpha <- 1 - ci
  q_conf <- stats::quantile(abs_resids, probs = 1 - alpha, na.rm = TRUE, names = FALSE)

  model <- engine(ts        = ts,
                  seq_len   = seq_len,
                  similarity = similarity,
                  dist_method = dist_method,
                  rescale   = rescale,
                  ci        = ci,
                  holdout   = NULL,
                  error_scale = error_scale,
                  error_benchmark = error_benchmark,
                  dates     = dates,
                  binary_class = binary_class,
                  seed      = seed)

  quant_pred <- model$quant_pred
  quant_pred <- cbind(quant_pred, pred_scores = pred_scores)

  # Bound column labels come from stats::quantile() so they match the names it
  # generated for the prediction table. Pasting the raw percentage produces a
  # different label whenever the bound needs more than 7 significant digits
  # (e.g. ci = 2/3); the assignments below would then append two new columns
  # and leave the original bounds untouched.
  band_probs <- c((1 - ci) / 2, ci + (1 - ci) / 2)
  band_labels <- names(stats::quantile(0, probs = band_probs, names = TRUE))
  lower_b <- band_labels[1]
  upper_b <- band_labels[2]

  center <- quant_pred[, "50%"]
  lower_c <- center - q_conf
  upper_c <- center + q_conf
  if (isTRUE(binary_class)) {
    # Binary series live in [0, 1].
    lower_c <- pmax(0, lower_c); upper_c <- pmin(1, upper_c)
  }
  quant_pred[, lower_b] <- lower_c
  quant_pred[, upper_b] <- upper_c

  list(quant_pred = quant_pred, errors = errors)
}

# --------------- engine (kernel + Dirichlet uncertainty) ----------
# Forecast the next `seq_len` values of one series.
#
# Steps:
#   1. difference the series (order chosen by best_deriv) and cut it into
#      non-overlapping segments of length `seq_len`;
#   2. measure the distance of every segment to the most recent one and turn
#      the distances into Gaussian-kernel weights, with the bandwidth set from
#      the (1 - similarity) quantile of those distances;
#   3. the point forecast is the weighted average of the segments; uncertainty
#      comes from 128 re-weightings drawn as normalised gamma variates centred
#      on the kernel weights;
#   4. all forecasts are integrated back to the original scale and summarised
#      by qpred().
#
# `holdout`, when given, is passed to qpred() to compute testing errors; the
# distance cache is only used when there is no holdout.
#
# Returns the list produced by qpred() (quant_pred, testing_error).
engine <- function(ts, seq_len, similarity, dist_method,
                   rescale = FALSE, ci = 0.8, holdout = NULL,
                   error_scale = "naive", error_benchmark = "naive",
                   dates = NULL, binary_class, seed = 42) {

  diffmodel <- recursive_diff(ts, best_deriv(ts))
  dts <- diffmodel$vector

  reframed <- smart_reframer(dts, seq_len, seq_len)
  n_row <- nrow(reframed)

  D <- compute_distance(reframed, method = dist_method, p = 3, use_cache = is.null(holdout))
  D[upper.tri(D)] <- NA

  # Distances from the most recent segment (last row) to every segment.
  dvec <- D[n_row, ]
  bw <- stats::quantile(dvec, probs = 1 - similarity, na.rm = TRUE)
  # Fall back to the median positive distance, then to 1, when the quantile
  # cannot serve as a bandwidth.
  if (!is.finite(bw) || bw <= 0) bw <- stats::median(dvec[dvec > 0], na.rm = TRUE)
  if (!is.finite(bw) || bw <= 0) bw <- 1.0

  kern <- function(d) exp(-(d * d) / (2 * (bw * bw) + 1e-12))
  w <- kern(dvec); w[is.na(w)] <- 0
  if (isTRUE(rescale)) w <- minmax(w, 0, 1)
  # If every weight vanished, rely on the most recent segment alone.
  if (sum(w) <= 0) w[n_row] <- 1
  w <- w / sum(w)

  # `w` is recycled down the columns, so row i is multiplied by w[i].
  point_fc <- as.numeric(colSums(reframed * w))

  set.seed(seed)
  n_draws <- 128L
  kappa <- 256
  alpha <- pmax(w, 1e-12) * kappa

  draw_one <- function() {
    z <- stats::rgamma(n_row, shape = alpha, rate = 1)
    wi <- z / sum(z)
    colSums(reframed * wi)
  }
  raw_pred <- t(replicate(n_draws, draw_one()))
  rownames(raw_pred) <- NULL

  # Integrate each draw from the differenced scale back to the level scale.
  raw_pred <- t(as.data.frame(
    map(narray::split(raw_pred, along = 1),
        ~ invdiff(.x, diffmodel$tail_value))
  ))
  # The point forecast is also on the differenced scale and must be integrated
  # the same way before it joins the draws; otherwise one row of the sample
  # would hold increments while all others hold levels.
  point_fc <- invdiff(point_fc, diffmodel$tail_value)
  raw_pred <- rbind(raw_pred, point_fc)

  q <- qpred(raw_pred, holdout, ts, ci, error_scale, error_benchmark, binary_class, dates, seed)
  return(q)
}

# ---------------- qpred & guards ----------------
# Summarise a matrix of sampled forecasts into a per-step prediction table.
#
# raw_pred       matrix of forecasts, one row per sample and one column per
#                forecast step.
# holdout_truth  optional observed values for the forecast steps; when given,
#                testing errors and per-step prediction scores are computed.
# ts             historical series; used by doxa_filter(), as the reference for
#                the first-step probability and as `actuals` for the metrics.
# ci             confidence level; sets the outer quantiles of the table.
# error_scale, error_benchmark, binary_class  forwarded to custom_metrics().
# dates          optional Date vector; when it is a Date vector, rows are
#                labelled with future dates, otherwise with t1, t2, ...
# seed           RNG seed set on entry.
#
# Returns list(quant_pred = <data.frame, one row per forecast step>,
#              testing_error = <metrics vector, or NULL without holdout>).
qpred <- function(raw_pred, holdout_truth = NULL, ts, ci,
                  error_scale = "naive", error_benchmark = "naive",
                  binary_class = FALSE, dates, seed = 42) {

  set.seed(seed)
  # Constrain the samples according to properties of the history.
  raw_pred <- doxa_filter(ts, raw_pred, binary_class)
  # Quartiles plus the two bounds implied by `ci`, de-duplicated and sorted.
  quants <- sort(unique(c((1 - ci) / 2, 0.25, 0.5, 0.75, ci + (1 - ci) / 2)))

  if (!binary_class) {
    # Per-step descriptive statistics over the finite samples.
    p_stats <- function(x) {
      nx <- x[is.finite(x)]
      c(
        min       = suppressWarnings(min(nx, na.rm = TRUE)),
        stats::quantile(nx, probs = quants, na.rm = TRUE),
        max       = suppressWarnings(max(nx, na.rm = TRUE)),
        mean      = mean(nx, na.rm = TRUE),
        sd        = sd(nx, na.rm = TRUE),
        mode      = suppressWarnings(mlv1(nx, method = "shorth")),
        kurtosis  = suppressWarnings(kurtosis(nx, na.rm = TRUE)),
        skewness  = suppressWarnings(skewness(nx, na.rm = TRUE))
      )
    }
    quant_pred <- as.data.frame(t(apply(raw_pred, 2, p_stats)))

    # Empirical CDF of every step, evaluated on a common 1000-point grid that
    # spans the range of all samples.
    p_value <- apply(raw_pred, 2, function(x) {
      rng <- range(raw_pred, na.rm = TRUE)
      stats::ecdf(x)(seq(rng[1], rng[2], length.out = 1000))
    })
    # Divergence: largest CDF difference between a step and the previous one;
    # the first step is compared with an evenly spaced 0..1 sequence.
    divergence <- c(
      max(p_value[, 1] - seq(0, 1, length.out = 1000)),
      apply(p_value[, -1, drop = FALSE] - p_value[, -ncol(p_value), drop = FALSE],
            2, function(x) abs(max(x, na.rm = TRUE)))
    )

    # Share of samples whose ratio to the previous value exceeds 1; the first
    # step is compared with the last historical value.
    last_val <- tail(ts, 1)
    ratio_first <- raw_pred[, 1] / last_val
    upside_prob <- c(mean(ratio_first > 1, na.rm = TRUE),
                     apply(raw_pred[, -1, drop = FALSE] / raw_pred[, -ncol(raw_pred), drop = FALSE] > 1,
                           2, mean, na.rm = TRUE))

    # Shape ratios; denominators are floored at 1e-12 to avoid division by zero.
    iqr_to_range <- (quant_pred[, "75%"] - quant_pred[, "25%"]) /
      pmax(quant_pred[, "max"] - quant_pred[, "min"], 1e-12)
    above_to_below_range <- (quant_pred[, "max"] - quant_pred[, "50%"]) /
      pmax(quant_pred[, "50%"] - quant_pred[, "min"], 1e-12)

    quant_pred <- round(cbind(quant_pred,
                              iqr_to_range = iqr_to_range,
                              above_to_below_range = above_to_below_range,
                              upside_prob = upside_prob,
                              divergence = divergence), 4)
  } else {
    # Binary branch: proportion, spread and entropy instead of moments.
    p_stats <- function(x) {
      nx <- x[is.finite(x)]
      c(
        min     = suppressWarnings(min(nx, na.rm = TRUE)),
        stats::quantile(nx, probs = quants, na.rm = TRUE),
        max     = suppressWarnings(max(nx, na.rm = TRUE)),
        prop    = mean(nx, na.rm = TRUE),
        sd      = sd(nx, na.rm = TRUE),
        entropy = entropy(nx)
      )
    }
    quant_pred <- as.data.frame(t(apply(raw_pred, 2, p_stats)))

    # Empirical CDF of every step evaluated at 0 and 1 only.
    p_value <- apply(raw_pred, 2, function(x) stats::ecdf(x)(c(0, 1)))
    divergence <- c(
      max(p_value[, 1] - c(0, 1)),
      apply(p_value[, -1, drop = FALSE] - p_value[, -ncol(p_value), drop = FALSE],
            2, function(x) abs(max(x, na.rm = TRUE)))
    )
    # Same ratio test as the numeric branch, with both terms shifted by 1 so a
    # zero denominator cannot occur for values in [0, 1].
    upgrade_prob <- c(
      mean(((raw_pred[, 1] + 1) / (tail(ts, 1) + 1)) > 1, na.rm = TRUE),
      apply(((raw_pred[, -1, drop = FALSE] + 1) / (raw_pred[, -ncol(raw_pred), drop = FALSE] + 1)) > 1,
            2, mean, na.rm = TRUE)
    )
    quant_pred <- round(cbind(quant_pred,
                              upgrade_prob = upgrade_prob,
                              divergence = divergence), 4)
  }

  # With a holdout: errors of the mean forecast plus per-step scores.
  testing_error <- NULL
  if (!is.null(holdout_truth)) {
    mean_pred <- colMeans(raw_pred)
    testing_error <- custom_metrics(holdout_truth, mean_pred, ts, error_scale, error_benchmark, binary_class)
    pred_scores <- round(prediction_score(raw_pred, holdout_truth), 4)
    quant_pred <- cbind(quant_pred, pred_scores = pred_scores)
  }

  # Row labels: future dates continuing `dates` at its mean spacing, or t1..tn.
  if (is.Date(dates)) {
    step <- mean(diff(dates))
    new_dates <- seq.Date(from = tail(dates, 1) + step, by = step, length.out = nrow(quant_pred))
    rownames(quant_pred) <- as.character(new_dates)
  } else {
    rownames(quant_pred) <- paste0("t", seq_len(nrow(quant_pred)))
  }

  list(quant_pred = quant_pred, testing_error = testing_error)
}

# Force sampled forecasts to respect properties observed in the history.
#
# Depending on what holds for `ts`, the forecast matrix `mat` (one row per
# sample) is, in this order: clipped at zero from below and/or above, rounded
# to whole numbers, and made monotone row by row. Binary series are finally
# clipped to [0, 1]. Rows containing missing values are dropped.
doxa_filter <- function(ts, mat, binary_class = FALSE) {
  is_discrete <- all(ts %% 1 == 0)
  non_negative <- all(ts >= 0)
  non_positive <- all(ts <= 0)
  never_falls <- all(diff(ts) >= 0)
  never_rises <- all(diff(ts) <= 0)

  # Rebuild a path from its first value after zeroing the increments that go
  # the wrong way: negative ones for "up", positive ones for "down".
  enforce_monotone <- function(path, direction) {
    parts <- recursive_diff(path, 1)
    increments <- parts$vector
    if (direction == "up") {
      increments[increments < 0] <- 0
    } else {
      increments[increments > 0] <- 0
    }
    invdiff(increments, parts$head_value, add = TRUE)
  }

  if (non_negative) {
    mat[mat < 0] <- 0
  }
  if (non_positive) {
    mat[mat > 0] <- 0
  }
  if (is_discrete) {
    mat <- round(mat)
  }
  if (never_falls) {
    mat <- t(apply(mat, 1, enforce_monotone, direction = "up"))
  }
  if (never_rises) {
    mat <- t(apply(mat, 1, enforce_monotone, direction = "down"))
  }

  if (binary_class) {
    mat[mat > 1] <- 1
    mat[mat < 0] <- 0
  }

  na.omit(mat)
}

# Compute forecast accuracy metrics for one holdout window.
#
# holdout          observed values.
# forecast         predicted values, same length as `holdout`.
# actuals          historical series used for scaling and benchmarks.
# error_scale      "naive" (mean absolute first difference of `actuals`) or
#                  "deviation" (standard deviation of `actuals`).
# error_benchmark  "naive" (last value of `actuals`) or "average" (its mean).
# binary_class     FALSE: numeric error measures; TRUE: distance measures
#                  between the holdout and forecast vectors.
#
# Unknown `error_scale` / `error_benchmark` values stop with an explicit
# message; without a default branch switch() would return NULL and the failure
# would surface later with an unrelated error.
#
# Returns a numeric vector of metrics (13 values in either branch).
custom_metrics <- function(holdout, forecast, actuals,
                           error_scale = "naive", error_benchmark = "naive",
                           binary_class = FALSE) {

  if (!binary_class) {
    scale_raw <- switch(error_scale,
                        "deviation" = sd(actuals),
                        "naive"     = mean(abs(diff(actuals))),
                        stop("`error_scale` must be \"naive\" or \"deviation\".", call. = FALSE))

    benchmark <- switch(error_benchmark,
                        "average" = rep(mean(actuals), length(forecast)),
                        "naive"   = rep(tail(actuals, 1), length(forecast)),
                        stop("`error_benchmark` must be \"naive\" or \"average\".", call. = FALSE))

    # Guard against a zero or undefined scale before dividing by it.
    scale_val <- safe_scale(scale_raw)

    me    <- ME(holdout, forecast, na.rm = TRUE)
    mae   <- MAE(holdout, forecast, na.rm = TRUE)
    mse   <- MSE(holdout, forecast, na.rm = TRUE)
    rmsse <- RMSSE(holdout, forecast, scale_val, na.rm = TRUE)
    mpe   <- MPE(holdout, forecast, na.rm = TRUE)
    mape  <- MAPE(holdout, forecast, na.rm = TRUE)
    rmae  <- rMAE(holdout, forecast, benchmark, na.rm = TRUE)
    rrmse <- rRMSE(holdout, forecast, benchmark, na.rm = TRUE)
    rame  <- rAME(holdout, forecast, benchmark, na.rm = TRUE)
    mase  <- MASE(holdout, forecast, scale_val, na.rm = TRUE)
    smse  <- sMSE(holdout, forecast, scale_val, na.rm = TRUE)
    sce   <- sCE(holdout, forecast, scale_val, na.rm = TRUE)
    gmrae <- GMRAE(holdout, forecast, benchmark, na.rm = TRUE)

    out <- round(c(
      me = me, mae = mae, mse = mse, rmsse = rmsse, mpe = mpe, mape = mape,
      rmae = rmae, rrmse = rrmse, rame = rame, mase = mase, smse = smse,
      sce = sce, gmrae = gmrae
    ), 3)
  } else {
    # Two-row matrix: observed vs predicted.
    M <- rbind(holdout, forecast)
    dice        <- suppressMessages(distance(M, method = "dice"))
    jaccard     <- suppressMessages(distance(M, method = "jaccard"))
    cosine      <- suppressMessages(distance(M, method = "cosine"))
    canberra    <- suppressMessages(distance(M, method = "canberra"))
    gower       <- suppressMessages(distance(M, method = "gower"))
    tanimoto    <- suppressMessages(distance(M, method = "tanimoto"))
    # Reported as one minus the value so it can be ranked with the distances.
    hassebrook  <- 1 - suppressMessages(distance(M, method = "hassebrook"))
    taneja      <- suppressMessages(distance(M, method = "taneja"))
    lorentzian  <- suppressMessages(distance(M, method = "lorentzian"))
    clark       <- suppressMessages(distance(M, method = "clark"))
    sorensen    <- suppressMessages(distance(M, method = "sorensen"))
    harmonic_m  <- suppressMessages(distance(M, method = "harmonic_mean"))
    avg         <- suppressMessages(distance(M, method = "avg"))

    out <- round(c(dice, jaccard, cosine, canberra, gower, tanimoto, hassebrook,
                   taneja, lorentzian, clark, sorensen, harmonic_m, avg), 4)
  }
  out
}
