#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(cpp11)]]

using namespace Rcpp;
using namespace arma;

//--------------------------------------------------------------
// l1_proj: projects a vector v onto an L1 ball of radius b
// Uses Duchi et al. (2008) algorithm for efficient projection
// NOTE: kept for compatibility; no longer used by the projector.
//--------------------------------------------------------------
vec l1_proj(const vec &v, const double b) {
  // Only a strictly positive radius is accepted.
  if (b <= 0) {
    stop("Radius b must be positive, got %g", b);
  }

  // v is returned untouched when its L1 norm is already within the radius
  // (a 1e-12 slack absorbs rounding noise).
  if (norm(v, 1) <= b + 1e-12) {
    return v;
  }

  // Element magnitudes, the same magnitudes ranked largest-first, and the
  // running totals of that ranking.
  const vec magnitude = abs(v);
  const vec ranked    = sort(magnitude, "descend");
  const vec running   = cumsum(ranked);

  // Scan from the smallest magnitude upwards: the first index that passes the
  // threshold test is the largest passing index. Falls back to index 0 when
  // no index passes.
  uword pivot = 0;
  for (uword k = ranked.n_elem; k-- > 0; ) {
    const double candidate = (running(k) - b) / (k + 1.0);
    if (ranked(k) >= candidate - 1e-15) {
      pivot = k;
      break;
    }
  }

  // Shrinkage level (never negative); shrink every magnitude by it, clamp at
  // zero, and put the original signs back.
  const double shrink = std::max(0.0, (running(pivot) - b) / (pivot + 1.0));
  return sign(v) % max(magnitude - shrink, zeros<vec>(v.n_elem));
}

//--------------------------------------------------------------
// Internal function to analyze matrix and recommend parameters
//--------------------------------------------------------------
// Parameter bundle filled in by analyze_matrix().
// Every member has a default so that a default-constructed instance never
// exposes indeterminate values; the defaults mirror the fallbacks that
// analyze_matrix() applies when no override is supplied.
struct OptimizationParams {
  double epsilon          = 0.0;    // minimum-eigenvalue floor
  double mu               = 1.0;    // placeholder; the projector no longer uses mu
  double tolerance        = 1e-4;   // advisory only (projector is non-iterative)
  int    max_iterations   = 1000;   // advisory only (projector is non-iterative)
  bool   needs_projection = false;  // true when the smallest eigenvalue is below epsilon
  double mu_decay_rate    = 2.0;    // not used by the non-iterative projector
  int    mu_decay_freq    = 20;     // not used by the non-iterative projector
};

// Inspects the spectrum of the symmetric part of A and fills an
// OptimizationParams with a recommended eigenvalue floor plus legacy
// (advisory) iteration settings.
//
// A                  square matrix to analyse
// epsilon_override   used as the floor when > 0, otherwise one is derived
// mu_override        used as mu when > 0, otherwise max(1, mean diagonal)
// tol_override       used as tolerance when > 0, otherwise 1e-4
// max_iter_override  used as max_iterations when > 0, otherwise size-based
// verbose            prints the analysis to Rcout
//
// Stops with an error when the eigenvalue decomposition fails or yields no
// eigenvalues (e.g. an empty matrix).
OptimizationParams analyze_matrix(const mat& A,
                                  double epsilon_override = -1.0,
                                  double mu_override = -1.0,
                                  double tol_override = -1.0,
                                  int max_iter_override = -1,
                                  bool verbose = false) {
  OptimizationParams params;
  
  const uword p = A.n_rows;
  
  // eig_sym reports failure through its return value; check it instead of
  // letting eigval.min() fail later with an unrelated "no elements" error.
  vec eigval;
  if (!eig_sym(eigval, 0.5 * (A + A.t())) || eigval.is_empty()) {
    stop("Eigenvalue decomposition failed or matrix is empty");
  }
  
  const double min_eig  = eigval.min();
  const double avg_diag = mean(A.diag());
  
  // Epsilon: minimum eigenvalue constraint
  if (epsilon_override > 0) {
    params.epsilon = epsilon_override;
  } else {
    // Simple relative floor tied to scale of diagonal
    const double rel = std::max(1e-12, 1e-8 * std::max(1.0, std::abs(avg_diag)));
    params.epsilon = (min_eig > 0.0) ? std::min(1e-4, 0.1 * min_eig) : rel;
    // Adjust for large p (slightly larger floor helps numerics at scale)
    if (p > 500) params.epsilon = std::max(params.epsilon, 5e-8 * std::abs(avg_diag));
  }
  
  // Projection is needed when the smallest eigenvalue sits below the floor
  // (1e-14 slack for rounding).
  params.needs_projection = (min_eig < params.epsilon - 1e-14);
  
  // The rest are retained for backward compatibility with verbose reports
  if (mu_override > 0) {
    params.mu = mu_override;
  } else {
    // Placeholder; projector no longer uses mu
    params.mu = std::max(1.0, avg_diag);
  }
  
  // Note: these are not used by the non-iterative projector
  params.mu_decay_rate = 2.0;
  params.mu_decay_freq = 20;
  
  params.tolerance = (tol_override > 0) ? tol_override : 1e-4;
  params.max_iterations = (max_iter_override > 0) ? max_iter_override :
    (p > 200 ? std::min(2000, 1000 + static_cast<int>(p)) : 1000);
  
  if (verbose) {
    // The largest eigenvalue and the Frobenius norm are only reported, never
    // used for parameter selection, so they are computed only here.
    Rcout << "Matrix Analysis:\n";
    Rcout << "  Dimension: " << p << " x " << p << "\n";
    Rcout << "  Min eigenvalue: " << min_eig << "\n";
    Rcout << "  Max eigenvalue: " << eigval.max() << "\n";
    Rcout << "  Frobenius norm: " << norm(A, "fro") << "\n";
    Rcout << "  Needs projection: " << (params.needs_projection ? "Yes" : "No") << "\n";
    Rcout << "Selected parameters:\n";
    Rcout << "  epsilon: " << params.epsilon << "\n";
    Rcout << "  (mu/tol/max_iter retained for compatibility; projector is non-iterative)\n";
  }
  
  return params;
}

//--------------------------------------------------------------
// Fast, nearest-PSD projection with eigenvalue floor
// - Returns the Frobenius-nearest matrix to A with lambda_min >= eps
// - Early-exits via Cholesky test when already feasible
//--------------------------------------------------------------
// Symmetrises A_in and raises every eigenvalue below the floor up to it.
//
// A_in     square matrix to project
// eps      eigenvalue floor; a value <= 0 means "plain PSD", i.e. a floor of 0
// verbose  prints progress to Rcout
//
// Returns the symmetrised input unchanged when a Cholesky factorisation of
// (A - floor*I) succeeds; otherwise the matrix rebuilt from its clipped
// eigenvalues. Stops with an error if the eigendecomposition fails twice.
static inline arma::mat psd_clip_with_floor(const arma::mat& A_in,
                                            double eps,
                                            bool verbose) {
  const uword p = A_in.n_rows;
  arma::mat A = 0.5 * (A_in + A_in.t());
  
  // A negative eps must not let negative eigenvalues survive the clip below,
  // so the effective floor is never below zero.
  const double floor_val = std::max(eps, 0.0);
  
  // Early exit via Cholesky on (A - floor*I)
  if (floor_val <= 0.0) {
    // no floor required beyond PSD; test plain PD
    arma::mat chol_factor;
    if (chol(chol_factor, A)) {
      if (verbose) Rcout << "PSD check via Cholesky: already PD; returning input.\n";
      return A;
    }
  } else {
    arma::mat shifted = A;
    shifted.diag() -= floor_val;
    arma::mat chol_factor;
    if (chol(chol_factor, shifted)) {
      if (verbose) Rcout << "PSD floor check via Cholesky: already ≥ epsilon; returning input.\n";
      return A;
    }
  }
  
  // One-shot eigendecomposition + clipping
  arma::vec eigval;
  arma::mat eigvec;
  if (!eig_sym(eigval, eigvec, A)) {
    // Retry once with a tiny diagonal jitter; if that also fails, report it
    // here rather than continuing with empty eigenvalue/eigenvector outputs.
    A.diag() += 1e-12;
    if (!eig_sym(eigval, eigvec, A)) {
      stop("Eigenvalue decomposition failed; cannot project matrix");
    }
  }
  
  // Clip eigenvalues at the floor
  eigval = arma::max(eigval, floor_val * ones<vec>(p));
  
  // Rebuild and re-symmetrise to remove rounding asymmetry from the product
  arma::mat projected = eigvec * diagmat(eigval) * eigvec.t();
  projected = 0.5 * (projected + projected.t());
  
  if (verbose) {
    Rcout << "psd_clip: min eigenvalue after clip = " << eigval.min() << "\n";
  }
  return projected;
}

//--------------------------------------------------------------
// maxproj_cov: positive definite projection
// [[Rcpp::export]]
arma::mat maxproj_cov(const arma::mat& input_mat,
                      double epsilon = -1.0,   // <0 means auto-select
                      double mu = -1.0,        // ignored (kept for API)
                      int nitr_max = -1,       // ignored (kept for API)
                      double etol = -1.0,      // ignored (kept for API)
                      bool verbose = false) {
  // Projects input_mat onto the set of symmetric matrices whose smallest
  // eigenvalue is at least epsilon, via a single eigenvalue clip.
  //
  // Stops with an error for non-square or non-finite input and for a
  // non-finite result. An empty matrix is returned unchanged.
  
  // Validate the shape up front: symmetrising a non-square matrix would
  // otherwise fail with an unrelated dimension error.
  if (input_mat.n_rows != input_mat.n_cols) {
    stop("Matrix must be square");
  }
  // Nothing to project; also keeps mean() below from seeing an empty diagonal.
  if (input_mat.is_empty()) {
    return input_mat;
  }
  
  mat A = 0.5 * (input_mat + input_mat.t());
  
  if (!A.is_finite()) {
    stop("Input matrix contains NaN or infinite values");
  }
  
  // Compute recommended epsilon cheaply and only as needed
  double eps_use = epsilon;
  if (eps_use < 0.0) {
    const double avg_diag = mean(A.diag());
    // relative floor anchored to scale of A; ensure tiny absolute minimum
    eps_use = std::max(1e-12, 1e-8 * std::max(1.0, std::abs(avg_diag)));
  }
  
  // Single-shot projection (no ADMM)
  arma::mat projected = psd_clip_with_floor(A, eps_use, verbose);
  
  if (!projected.is_finite()) {
    stop("Result contains NaN or infinite values");
  }
  
  return projected;
}

//--------------------------------------------------------------
// getResCov: Compute residual covariance with automatic PD projection
// [[Rcpp::export]]
arma::mat getResCov(const arma::mat& E_input,
                    int n,
                    const arma::mat& rho_mat,
                    double eps = -1.0,      // <0 means auto-select
                    double mu = -1.0,       // kept for API; unused by projector
                    int max_iter = -1,      // kept for API; unused by projector
                    double tol = -1.0,      // kept for API; unused by projector
                    bool use_pairwise = false,
                    bool verbose = false) {
  // Builds a covariance matrix from the columns of E_input, divides it
  // element-wise by rho_mat, symmetrises it and projects it with maxproj_cov.
  //
  // E_input       data matrix; columns are the variables
  // n             sample size; (n - 1) is the divisor in the non-pairwise path
  // rho_mat       ncols x ncols matrix of element-wise divisors
  // eps           eigenvalue floor; < 0 selects one from n, ncols and the scale
  // use_pairwise  when true and E_input has non-finite entries, each entry is
  //               computed from the rows where both columns are finite;
  //               otherwise non-finite entries are replaced by zero after centring
  // verbose       prints diagnostics to Rcout
  //
  // Stops with an error on dimension mismatch, n <= 1, zero columns, or a
  // non-finite covariance matrix.
  
  const uword nrows = E_input.n_rows;
  const uword ncols = E_input.n_cols;
  
  if (rho_mat.n_rows != ncols || rho_mat.n_cols != ncols) {
    stop("Dimension mismatch: rho_mat must be %d x %d, got %d x %d",
         ncols, ncols, rho_mat.n_rows, rho_mat.n_cols);
  }
  if (n <= 1) stop("Sample size n must be greater than 1");
  // Without columns the per-column minimum below has nothing to reduce over.
  if (ncols == 0) stop("E_input must have at least one column");
  
  arma::mat E = E_input;
  
  // Missingness is only looked for when pairwise mode is requested
  const bool has_missing = use_pairwise && !E_input.is_finite();
  
  // Column centering (robust to missing if use_pairwise; otherwise zeros-impute)
  vec  col_means(ncols, fill::zeros);
  uvec valid_counts(ncols, fill::zeros);
  
  for (uword j = 0; j < ncols; ++j) {
    vec col = E.col(j);
    uvec idx = find_finite(col);
    valid_counts(j) = idx.n_elem;
    
    if (idx.n_elem > 0) {
      // Mean over the finite entries only
      const double m = mean(col(idx));
      col_means(j) = m;
      
      for (uword i = 0; i < nrows; ++i) {
        if (std::isfinite(E(i, j))) E(i, j) -= m;
        else                        E(i, j) = 0.0; // zero-impute in non-pairwise path
      }
    } else {
      E.col(j).zeros();
      col_means(j) = 0.0;
      warning("Column %d has no valid observations", j + 1);
    }
  }
  
  const uword min_valid = valid_counts.min();
  if (min_valid < 2) {
    warning("Some columns have fewer than 2 valid observations");
  }
  
  arma::mat residual_cov;
  
  if (use_pairwise && has_missing) {
    // Pairwise complete observations
    residual_cov = zeros<arma::mat>(ncols, ncols);
    arma::mat pair_counts = zeros<arma::mat>(ncols, ncols);
    
    for (uword i = 0; i < ncols; ++i) {
      for (uword j = i; j < ncols; ++j) {
        double sum_prod = 0.0;
        uword count = 0;
        
        for (uword k = 0; k < nrows; ++k) {
          const bool fi = std::isfinite(E_input(k, i));
          const bool fj = std::isfinite(E_input(k, j));
          if (fi && fj) {
            sum_prod += (E_input(k, i) - col_means(i)) * (E_input(k, j) - col_means(j));
            ++count;
          }
        }
        
        if (count > 1) {
          const double val = (sum_prod / (count - 1.0)) / rho_mat(i, j);
          residual_cov(i, j) = val;
          residual_cov(j, i) = val;
          pair_counts(i, j) = pair_counts(j, i) = count;
        } else {
          // Too few joint observations: small positive diagonal, zero off-diagonal
          residual_cov(i, j) = (i == j) ? 0.01 : 0.0;
          residual_cov(j, i) = residual_cov(i, j);
        }
      }
    }
    
    if (verbose && pair_counts.min() < 10) {
      Rcout << "Warning: Some variable pairs have fewer than 10 observations\n";
    }
  } else {
    // Standard computation (no missing or not pairwise)
    arma::mat XtX = E.t() * E;
    residual_cov = XtX / static_cast<double>(n - 1);
    
    // True element-wise division by rho_mat, matching the pairwise path
    // (no intermediate reciprocal matrix, one rounding instead of two).
    residual_cov /= rho_mat;
  }
  
  residual_cov = 0.5 * (residual_cov + residual_cov.t());
  
  if (!residual_cov.is_finite()) {
    stop("Computed covariance matrix contains NaN or infinite values");
  }
  
  // Auto-select epsilon if not provided
  const bool auto_eps = (eps < 0);
  double epsilon_to_use = eps;
  if (auto_eps) {
    const double matrix_scale = norm(residual_cov, "fro");
    const double data_scale   = std::sqrt(matrix_scale / static_cast<double>(ncols));
    // Smaller samples get a larger floor
    if (n < 30)       epsilon_to_use = std::max(1e-4,  0.01   * data_scale);
    else if (n < 100) epsilon_to_use = std::max(1e-5,  0.001  * data_scale);
    else              epsilon_to_use = std::max(1e-6,  0.0001 * data_scale);
    if (ncols > 50)   epsilon_to_use = std::max(epsilon_to_use,
        std::sqrt(static_cast<double>(ncols)/50.0) * 1e-6);
  }
  
  // Report where the floor came from so a caller-supplied value is not
  // mislabelled as auto-selected.
  if (verbose) {
    Rcout << (auto_eps ? "Auto-selected epsilon: " : "Using supplied epsilon: ")
          << epsilon_to_use << "\n";
  }
  
  // fast PSD projection with eigenvalue floor
  arma::mat proj_cov = maxproj_cov(residual_cov, epsilon_to_use, -1.0, -1, -1.0, verbose);
  return proj_cov;
}

//--------------------------------------------------------------
// Convenience function for simple positive definite projection
// [[Rcpp::export]]
arma::mat make_positive_definite(const arma::mat& A,
                                 double epsilon = -1.0,
                                 bool verbose = false) {
  // Thin front end to maxproj_cov: only the matrix, the eigenvalue floor and
  // the verbosity flag are forwarded. The remaining arguments of maxproj_cov
  // are passed as their "unset" sentinel values.
  const double mu_unset       = -1.0;
  const int    max_iter_unset = -1;
  const double tol_unset      = -1.0;
  
  arma::mat projected = maxproj_cov(A, epsilon, mu_unset, max_iter_unset, tol_unset, verbose);
  return projected;
}

//--------------------------------------------------------------
// Diagnostic function for matrix properties
// [[Rcpp::export]]
List diagnose_matrix(const arma::mat& A, bool verbose = false) {
  // Returns a named list of diagnostics for a square matrix: spectrum
  // extremes, condition number, definiteness flags, symmetry error, sparsity
  // ratios and the parameters recommended by analyze_matrix().
  //
  // Eigenvalue-based entries are computed on the symmetric part of A;
  // sparsity, symmetry error and the Frobenius norm use A itself.
  // Stops with an error for a non-square or empty matrix, or when the
  // eigenvalue decomposition fails.
  if (A.n_rows != A.n_cols) {
    stop("Matrix must be square");
  }
  // An empty matrix has no spectrum and would make the sparsity ratios 0/0.
  if (A.is_empty()) {
    stop("Matrix must be non-empty");
  }
  
  const uword p = A.n_rows;
  
  const mat A_sym = 0.5 * (A + A.t());
  
  vec eigval;
  if (!eig_sym(eigval, A_sym)) {
    stop("Eigenvalue decomposition failed");
  }
  
  const double min_eig = eigval.min();
  const double max_eig = eigval.max();
  
  // Condition number of a symmetric matrix: largest over smallest eigenvalue
  // magnitude. Using magnitudes keeps the value meaningful for indefinite
  // input (where max_eig / |min_eig| can drop below 1); for positive definite
  // input it equals max_eig / min_eig.
  const vec    abs_eig = arma::abs(eigval);
  const double min_abs = abs_eig.min();
  const double cond_num = (min_abs > 1e-16)
    ? abs_eig.max() / min_abs
    : std::numeric_limits<double>::infinity();
  
  // Recommended parameters
  OptimizationParams params = analyze_matrix(A_sym, -1.0, -1.0, -1.0, -1, verbose);
  
  // Sparsity diagnostics: entries below 1e-10 count as zero, below 1e-3 as
  // small. Columns are walked in the outer loop to follow the storage order.
  uword zero_count = 0, small_count = 0;
  for (uword j = 0; j < p; ++j) {
    for (uword i = 0; i < p; ++i) {
      const double a = std::abs(A(i, j));
      if (a < 1e-10) ++zero_count;
      if (a < 1e-3)  ++small_count;
    }
  }
  
  const double n_entries     = static_cast<double>(A.n_elem);
  const double sparsity      = static_cast<double>(zero_count)  / n_entries;
  const double near_sparsity = static_cast<double>(small_count) / n_entries;
  
  // Frobenius norm is needed for both the symmetry error and the report
  const double fro_norm = norm(A, "fro");
  const double sym_err  = norm(A - A.t(), "fro") / std::max(1e-16, fro_norm);
  const bool   is_sym   = sym_err < 1e-12;
  
  return List::create(
    Named("dimension")               = p,
    Named("min_eigenvalue")          = min_eig,
    Named("max_eigenvalue")          = max_eig,
    Named("condition_number")        = cond_num,
    Named("is_positive_definite")    = (min_eig > 1e-12),
    Named("is_positive_semidefinite")= (min_eig > -1e-12),
    Named("needs_projection")        = params.needs_projection,
    Named("frobenius_norm")          = fro_norm,
    Named("is_symmetric")            = is_sym,
    Named("symmetry_error")          = sym_err,
    Named("sparsity")                = sparsity,
    Named("near_sparsity")           = near_sparsity,
    Named("recommended_epsilon")     = params.epsilon,
    // Note: mu/tol/max_iter are advisory only (projector is non-iterative)
    Named("recommended_mu")          = params.mu,
    Named("recommended_tolerance")   = params.tolerance,
    Named("recommended_max_iter")    = params.max_iterations
  );
}

//--------------------------------------------------------------
// Batch processing for multiple matrices
// [[Rcpp::export]]
List process_covariance_batch(const List& E_list,
                              const IntegerVector& n_vec,
                              const List& rho_list,
                              double eps = -1.0,
                              bool verbose = false) {
  // Runs getResCov on every (E, n, rho) triple and returns the results as a
  // list in input order.
  //
  // E_list    list of data matrices
  // n_vec     sample size for each matrix
  // rho_list  list of element-wise divisor matrices
  // eps       eigenvalue floor passed to every call (< 0 means auto-select)
  // verbose   prints progress and forwards the flag to getResCov
  //
  // Stops with an error when the three inputs differ in length.
  
  const int n_matrices = E_list.size();
  if (n_vec.size() != n_matrices || rho_list.size() != n_matrices) {
    stop("All input lists must have the same length");
  }
  
  List results(n_matrices);
  for (int i = 0; i < n_matrices; ++i) {
    if (verbose) {
      Rcout << "\nProcessing matrix " << (i + 1) << " of " << n_matrices << "...\n";
    }
    const mat E   = as<mat>(E_list[i]);
    const int n   = n_vec[i];
    const mat rho = as<mat>(rho_list[i]);
    
    // getResCov's trailing parameters are (..., use_pairwise, verbose).
    // use_pairwise is passed explicitly so that verbose lands in its own slot
    // and does not switch on pairwise handling.
    const bool use_pairwise = false;
    results[i] = getResCov(E, n, rho, eps, -1.0, -1, -1.0, use_pairwise, verbose);
  }
  return results;
}
