#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(cpp11)]]

using namespace Rcpp;
using namespace arma;

//--------------------------------------------------------------
// Fast cache-friendly computation of evaluation error
// Computes: ||Y - X*Beta||_F^2 / n using sufficient statistics
// [[Rcpp::export]]
double getEvalErr(const arma::mat& yty,
                  const arma::mat& ytx,
                  const arma::mat& xty,
                  const arma::mat& xtx,
                  const arma::mat& Beta,
                  const int n) {
  // Mean squared residual ||Y - X*Beta||_F^2 / n computed from sufficient
  // statistics only:
  //   ||Y - X B||_F^2 = tr(Y'Y) - 2 tr(B' X'Y) + tr(B' X'X B)
  //
  // yty  : Y'Y (only its trace is used)
  // ytx  : accepted for API compatibility, not read
  // xty  : X'Y, must have the same dimensions as Beta
  // xtx  : X'X, must be square with side Beta.n_rows
  // Beta : coefficient matrix
  // n    : number of observations, must be positive
  //
  // Returns a non-negative value. A small negative result produced by
  // floating-point cancellation (relative to the size of the three terms)
  // is silently clamped to 0; larger negatives or non-finite results raise
  // an R warning and are replaced by max(0, result).
  // Stops on non-positive n, NaN/Inf inputs or dimension mismatches.

  (void)ytx; // ytx not needed here; keep param for API compatibility

  if (n <= 0) {
    stop("n must be positive, got %d", n);
  }
  if (!yty.is_finite() || !xty.is_finite() || !xtx.is_finite() || !Beta.is_finite()) {
    stop("Input matrices contain NaN or infinite values");
  }
  if (xty.n_rows != Beta.n_rows || xty.n_cols != Beta.n_cols) {
    stop("xty must have the same dimensions as Beta: xty is %d x %d, Beta is %d x %d",
         xty.n_rows, xty.n_cols, Beta.n_rows, Beta.n_cols);
  }
  if (xtx.n_rows != Beta.n_rows || xtx.n_cols != Beta.n_rows) {
    stop("xtx must be %d x %d to match Beta, got %d x %d",
         Beta.n_rows, Beta.n_rows, xtx.n_rows, xtx.n_cols);
  }

  const double yy = trace(yty);                  // ||Y||_F^2
  // tr(Beta' * xty) = accu(xty % Beta): avoids forming the q x q product
  const double cross = 2.0 * accu(xty % Beta);
  // tr(B' * xtx * B) = sum_ij B_ij * (xtx * B)_ij; xtx * B is p x q
  const double quad = accu(Beta % (xtx * Beta));

  const double n_d = static_cast<double>(n);
  const double result = (yy - cross + quad) / n_d;

  if (std::isfinite(result) && result >= 0.0) {
    return result;
  }

  // Cancellation error grows with the magnitude of the terms being
  // subtracted, so the "tiny negative" window is scaled by that magnitude
  // (with the previous absolute 1e-12 as a floor).
  const double magnitude = (std::abs(yy) + std::abs(cross) + std::abs(quad)) / n_d;
  const double clamp_tol = std::max(1e-12, 1e-10 * magnitude);
  if (result < 0.0 && result > -clamp_tol) {
    return 0.0;
  }

  warning("Computed evaluation error is negative or non-finite: %g", result);
  return std::max(0.0, result);
}

//--------------------------------------------------------------
// Q function for Gaussian graphical models
// Q(S, Theta) = trace(S * Theta) - log|Theta|
// [[Rcpp::export]]
double Q_func(const arma::mat& S, const arma::mat& Theta) {
  // Objective Q(S, Theta) = tr(S Theta) - log det(Theta).
  // Both inputs are first symmetrized as 0.5 * (M + M').
  // log det(Theta) is taken from the Cholesky factor when the factorization
  // succeeds; otherwise the eigenvalues are used, and a smallest eigenvalue
  // at or below max(1e-12, 1e-12 * max(1, |largest eigenvalue|)) makes the
  // function warn and return +Inf. A non-finite final value also yields +Inf
  // with a warning. Stops on NaN/Inf input or a failed eigendecomposition.

  const bool inputs_finite = S.is_finite() && Theta.is_finite();
  if (!inputs_finite) {
    stop("Input matrices contain NaN or infinite values");
  }

  const mat sym_S     = 0.5 * (S + S.t());
  const mat sym_Theta = 0.5 * (Theta + Theta.t());

  // tr(S Theta) for symmetric matrices equals the sum of the Hadamard product
  const double trace_term = accu(sym_S % sym_Theta);

  double logdet = 0.0;
  mat chol_factor;
  const bool chol_ok = chol(chol_factor, sym_Theta);

  if (chol_ok) {
    // Theta = R'R  =>  log det(Theta) = 2 * sum(log(diag(R)))
    const vec chol_diag = chol_factor.diag();
    logdet = 2.0 * accu(log(chol_diag));
  } else {
    vec lambda;
    const bool eig_ok = eig_sym(lambda, sym_Theta);
    if (!eig_ok) {
      stop("Failed to compute eigenvalues of Theta");
    }

    const double reference = std::max(1.0, std::abs(lambda.max()));
    const double eig_floor = std::max(1e-12, 1e-12 * reference);
    const double smallest  = lambda.min();

    if (!(smallest > eig_floor)) {
      warning("Theta is not positive definite (min eigenvalue: %g, tol: %g)", smallest, eig_floor);
      return datum::inf;
    }
    logdet = accu(log(lambda));
  }

  const double q_value = trace_term - logdet;
  if (std::isfinite(q_value)) {
    return q_value;
  }
  warning("Q function result is not finite");
  return datum::inf;
}

//--------------------------------------------------------------
// Goodness-of-fit criterion (AIC, BIC or eBIC) for a fitted Gaussian model
// [[Rcpp::export]]
double GoF_func(const std::string& GoF,
                const double gamma,
                const int n,
                const int p,
                const int q,
                const arma::mat& Theta,
                const arma::mat& S,
                const arma::mat& Beta,
                const double sparsity_tol = 1e-8) {
  // Information criterion  -2 * loglik + penalty  where
  //   loglik = -n*q/2 * log(2*pi) - n/2 * Q(S, Theta)     (Q from Q_func)
  // and, with df = dfB + dfTheta,
  //   "AIC"  : penalty = 2 * df
  //   "BIC"  : penalty = log(n) * df
  //   "eBIC" : penalty = log(n) * df + 2*gamma*(2*log(q)*dfTheta + log(p)*dfB)
  // dfB counts entries of Beta, and dfTheta counts strictly-lower-triangular
  // entries of Theta, whose magnitude exceeds
  //   max(sparsity_tol, 1e-12 * max(1, max|Beta|, max|offdiag(Theta)|)).
  //
  // Theta and S must be q x q. Returns +Inf (with an R warning) when the
  // likelihood, penalty or result is not finite. Stops on invalid
  // arguments or an unknown GoF name.

  // Negated range test so that a NaN gamma is rejected as well.
  if (!(gamma >= 0.0 && gamma <= 1.0)) {
    stop("gamma must be in [0, 1], got %g", gamma);
  }
  if (n <= 0) {
    stop("n must be positive, got %d", n);
  }
  if (q < 0) {
    stop("q must be non-negative, got %d", q);
  }
  const uword q_u = static_cast<uword>(q);
  // The df loop below indexes Theta by q; a mismatch would silently count
  // only part of Theta (or index out of range).
  if (Theta.n_rows != q_u || Theta.n_cols != q_u) {
    stop("Theta must be %d x %d, got %d x %d", q, q, Theta.n_rows, Theta.n_cols);
  }
  if (S.n_rows != q_u || S.n_cols != q_u) {
    stop("S must be %d x %d, got %d x %d", q, q, S.n_rows, S.n_cols);
  }
  if (!Theta.is_finite() || !S.is_finite() || !Beta.is_finite()) {
    stop("Input matrices contain NaN or infinite values");
  }

  // Log-likelihood constants. n and q are multiplied in double precision:
  // the int product n * q can overflow (undefined behaviour) for large data.
  const double log_2pi = std::log(2.0 * M_PI);
  const double nq      = static_cast<double>(n) * static_cast<double>(q);
  const double const1  = -0.5 * nq * log_2pi;
  const double const2  =  0.5 * static_cast<double>(n);

  const double Q_val = Q_func(S, Theta);
  if (!std::isfinite(Q_val)) {
    warning("Q function returned non-finite value");
    return datum::inf;
  }
  const double log_likelihood = const1 - const2 * Q_val;

  // Scale-aware sparsity threshold
  const double maxB  = (Beta.n_elem > 0) ? abs(Beta).max() : 0.0;
  const mat Theta_off = Theta - diagmat(Theta.diag());
  const double maxTh = (Theta_off.n_elem > 0) ? abs(Theta_off).max() : 0.0;
  const double scale = std::max(1.0, std::max(maxB, maxTh));
  const double tol   = std::max(sparsity_tol, 1e-12 * scale);

  // Degrees of freedom
  const uword dfB = accu(abs(Beta) > tol);

  uword dfTheta = 0;
  for (uword j = 0; j + 1 < q_u; ++j) {
    dfTheta += accu(abs(Theta(span(j + 1, q_u - 1), j)) > tol);  // strictly lower triangle
  }
  const uword total_df = dfB + dfTheta;

  // Information criteria
  const double log_n = std::log(static_cast<double>(n));
  double penalty = 0.0;

  if (GoF == "AIC") {
    penalty = 2.0 * static_cast<double>(total_df);
  } else if (GoF == "BIC") {
    penalty = log_n * static_cast<double>(total_df);
  } else if (GoF == "eBIC") {
    const double term1 = log_n * static_cast<double>(total_df);
    // A log factor with a zero count contributes nothing; skipping it avoids
    // 0 * log(0) = NaN when p or q is degenerate.
    const double theta_part = (dfTheta > 0)
      ? 2.0 * std::log(static_cast<double>(q)) * static_cast<double>(dfTheta)
      : 0.0;
    const double beta_part = (dfB > 0)
      ? std::log(static_cast<double>(p)) * static_cast<double>(dfB)
      : 0.0;
    const double term2 = 2.0 * gamma * (theta_part + beta_part);
    if (!std::isfinite(term1) || !std::isfinite(term2)) {
      warning("eBIC penalty computation resulted in overflow");
      return datum::inf;
    }
    penalty = term1 + term2;
  } else {
    stop("Unknown GoF criterion '%s'. Use 'AIC', 'BIC', or 'eBIC'", GoF.c_str());
  }

  const double result = -2.0 * log_likelihood + penalty;
  if (!std::isfinite(result)) {
    warning("GoF function result is not finite");
    return datum::inf;
  }
  return result;
}

//--------------------------------------------------------------
// Elementwise soft-thresholding (matrix lambda)
// [[Rcpp::export]]
arma::mat soft_threshold_mat(const arma::mat& X, const arma::mat& lambda) {
  // Elementwise soft-thresholding with a per-entry penalty:
  //   out_ij = sign(X_ij) * max(|X_ij| - lambda_ij, 0)
  // lambda must have the same dimensions as X and be non-negative.
  // Empty input is returned unchanged; if every lambda_ij <= 1e-16, X is
  // returned unchanged as well. Stops on NaN/Inf, mismatched dimensions or
  // negative lambda.

  if (!X.is_finite() || !lambda.is_finite()) {
    stop("Input matrices contain NaN or infinite values");
  }
  if (X.n_rows != lambda.n_rows || X.n_cols != lambda.n_cols) {
    stop("X and lambda must have the same dimensions: X is %d x %d, lambda is %d x %d",
         X.n_rows, X.n_cols, lambda.n_rows, lambda.n_cols);
  }
  // lambda.max() below throws on an empty matrix
  if (X.is_empty()) return X;

  if (any(vectorise(lambda) < 0)) {
    stop("lambda must be non-negative");
  }

  // if effectively zero penalty, return X
  if (lambda.max() <= 1e-16) return X;

  return sign(X) % max(abs(X) - lambda, zeros<mat>(size(X)));
}

//--------------------------------------------------------------
// Elementwise soft-thresholding (scalar lambda)
// [[Rcpp::export]]
arma::mat soft_threshold_scalar(const arma::mat& X, const double lambda) {
  // Elementwise soft-thresholding with a single penalty:
  //   out_ij = sign(X_ij) * max(|X_ij| - lambda, 0)
  // Returns X unchanged when it is empty or lambda <= 1e-16, and an all-zero
  // matrix of X's size when lambda >= max|X|. Stops on NaN/Inf in X or a
  // negative / non-finite lambda.

  if (!X.is_finite()) {
    stop("Input matrix X contains NaN or infinite values");
  }
  if (!std::isfinite(lambda) || lambda < 0) {
    stop("lambda must be a non-negative finite value, got %g", lambda);
  }

  if (lambda <= 1e-16) {                 // effectively zero threshold
    return X;
  }
  // abs(X).max() below throws on an empty matrix
  if (X.is_empty()) {
    return X;
  }

  // if lambda >= max|X|, every entry is zeroed; skip the elementwise pass
  const double m = abs(X).max();
  if (lambda >= m) return arma::zeros<mat>(size(X));

  return sign(X) % max(abs(X) - lambda, zeros<mat>(size(X)));
}

//--------------------------------------------------------------
// Matrix norm of A-B with several types
// [[Rcpp::export]]
double matrix_norm_diff(const arma::mat& A, const arma::mat& B,
                        const std::string& type = "fro") {
  // Norm of A - B for equally sized finite matrices.
  //   "fro" / "F"       : Frobenius norm
  //   "inf"             : Armadillo norm(D, "inf")
  //   "1"               : Armadillo norm(D, 1)
  //   "2" / "spectral"  : largest singular value (0 for an empty matrix)
  // If the computed value is not finite (e.g. A - B overflows even though A
  // and B are finite), an R warning is issued and +Inf is returned. The
  // previous behaviour of returning 0 reported the matrices as identical,
  // which is the wrong answer for convergence or distance checks.
  // Stops on mismatched dimensions, NaN/Inf input or an unknown type.

  if (A.n_rows != B.n_rows || A.n_cols != B.n_cols) {
    stop("Matrices must have same dimensions: A is %d x %d, B is %d x %d",
         A.n_rows, A.n_cols, B.n_rows, B.n_cols);
  }
  if (!A.is_finite() || !B.is_finite()) {
    stop("Input matrices contain NaN or infinite values");
  }

  const mat D = A - B;
  double result = 0.0;

  if (type == "fro" || type == "F") {
    result = norm(D, "fro");
  } else if (type == "inf") {
    result = norm(D, "inf");
  } else if (type == "1") {
    result = norm(D, 1);
  } else if (type == "2" || type == "spectral") {
    const vec s = svd(D);               // singular values, descending
    result = s.n_elem > 0 ? s(0) : 0.0;
  } else {
    stop("Unknown norm type '%s'. Use 'fro', 'inf', '1', '2', or 'spectral'", type.c_str());
  }

  if (!std::isfinite(result) || result < 0) {
    warning("Computed norm is negative or non-finite: %g", result);
    return datum::inf;
  }
  return result;
}

//--------------------------------------------------------------
// Symmetric eigendecomposition with optional eigenvalue floor
// [[Rcpp::export]]
List eigen_decomp_sym(const arma::mat& A, const double min_eigval = 0.0) {
  // Eigendecomposition of the symmetrized matrix 0.5 * (A + A').
  // If min_eigval exceeds the smallest eigenvalue, every eigenvalue is
  // floored at min_eigval (eigenvectors are left unchanged) and
  // projection_applied is TRUE.
  // Returned list: values, vectors, projection_applied, min_eigenvalue,
  // max_eigenvalue and condition_number = max / max(min, 1e-16).
  // An empty (0 x 0) input returns empty values/vectors and NaN summaries.
  // Emits an R warning if the eigenvector matrix deviates from orthonormal
  // by more than 1e-10 in Frobenius norm.
  // Stops on non-square or non-finite input, non-finite min_eigval, or a
  // failed / non-finite decomposition.

  if (A.n_rows != A.n_cols) {
    stop("Matrix must be square, got %d x %d", A.n_rows, A.n_cols);
  }
  if (!A.is_finite()) {
    stop("Input matrix contains NaN or infinite values");
  }
  if (!std::isfinite(min_eigval)) {
    stop("min_eigval must be finite, got %g", min_eigval);
  }

  // min()/max() on an empty eigenvalue vector would throw below
  if (A.is_empty()) {
    return List::create(
      Named("values")             = vec(),
      Named("vectors")            = mat(),
      Named("projection_applied") = false,
      Named("min_eigenvalue")     = R_NaN,
      Named("max_eigenvalue")     = R_NaN,
      Named("condition_number")   = R_NaN
    );
  }

  const mat A_sym = 0.5 * (A + A.t());

  vec eigval;
  mat eigvec;
  if (!eig_sym(eigval, eigvec, A_sym)) {
    stop("Eigendecomposition failed");
  }
  if (!eigval.is_finite()) {
    stop("Computed eigenvalues contain NaN or infinite values");
  }

  bool projection_applied = false;
  if (min_eigval > eigval.min()) {
    eigval = max(eigval, min_eigval * ones<vec>(eigval.n_elem));
    projection_applied = true;
  }

  // Orthogonality check
  const mat Iq = eye<mat>(eigvec.n_cols, eigvec.n_cols);
  const double orth_err = norm(eigvec.t() * eigvec - Iq, "fro");
  if (orth_err > 1e-10) {
    warning("Eigenvectors are not perfectly orthogonal (error: %g)", orth_err);
  }

  const double ev_min = eigval.min();
  const double ev_max = eigval.max();

  return List::create(
    Named("values")             = eigval,
    Named("vectors")            = eigvec,
    Named("projection_applied") = projection_applied,
    Named("min_eigenvalue")     = ev_min,
    Named("max_eigenvalue")     = ev_max,
    Named("condition_number")   = ev_max / std::max(ev_min, 1e-16)
  );
}

//--------------------------------------------------------------
// Make a matrix symmetric
// [[Rcpp::export]]
arma::mat make_symmetric(const arma::mat& A) {
  // Returns the symmetric part 0.5 * (A + A') of a square matrix.
  // 0 x 0 and 1 x 1 inputs are returned unchanged.
  // Stops on non-square input. Previously a 1 x k row vector was silently
  // returned as-is, while other non-square shapes failed inside Armadillo
  // with an opaque addition error.

  if (A.n_rows != A.n_cols) {
    stop("Matrix must be square, got %d x %d", A.n_rows, A.n_cols);
  }
  if (A.n_rows <= 1) return A;

  return 0.5 * (A + A.t());
}

//--------------------------------------------------------------
// PD test with Cholesky first, then eigenvalues
// [[Rcpp::export]]
List is_positive_definite(const arma::mat& A, const double tol = 1e-12) {
  // Tests whether A is symmetric positive definite, returning a diagnostic
  // list whose "method" entry records which test decided the answer.
  //   1. Symmetry: relative asymmetry ||A - A'||_F / max(1e-16, ||A||_F)
  //      must be < 1e-12, otherwise "not PD" (method "symmetry_check").
  //   2. Cholesky A = R'R: reported PD (method "cholesky") when it succeeds
  //      and min_i R_ii^2 > tol. Since lambda_min(A) <= min_i R_ii^2, a
  //      smaller diagonal proves lambda_min(A) <= tol, so that case falls
  //      through to step 3 instead of being reported PD regardless of tol.
  //      Note: passing this step guarantees lambda_min(A) > 0, but not
  //      necessarily lambda_min(A) > tol.
  //   3. Eigenvalues: PD iff the smallest eigenvalue > tol
  //      (method "eigenvalues").
  // Stops on non-square, empty or non-finite A, or a non-positive / NaN tol.

  if (A.n_rows != A.n_cols) {
    stop("Matrix must be square, got %d x %d", A.n_rows, A.n_cols);
  }
  if (!A.is_finite()) {
    stop("Input matrix contains NaN or infinite values");
  }
  // Negated test so that a NaN tolerance is rejected too.
  if (!(tol > 0)) {
    stop("Tolerance must be positive, got %g", tol);
  }
  // An empty matrix would make R.diag().min() throw inside Armadillo.
  if (A.is_empty()) {
    stop("Matrix must be non-empty");
  }

  // relative symmetry error
  const double rel_asym = norm(A - A.t(), "fro") / std::max(1e-16, norm(A, "fro"));
  const bool is_sym = rel_asym < 1e-12;

  if (!is_sym) {
    return List::create(
      Named("is_positive_definite") = false,
      Named("is_symmetric")         = false,
      Named("asymmetry")            = rel_asym,
      Named("min_eigenvalue")       = R_NaN,
      Named("method")               = "symmetry_check"
    );
  }

  // Cholesky is the fastest PD test
  mat R;
  if (chol(R, A)) {
    const double min_diag = R.diag().min();
    if (min_diag * min_diag > tol) {
      return List::create(
        Named("is_positive_definite") = true,
        Named("is_symmetric")         = true,
        Named("asymmetry")            = rel_asym,
        Named("min_diagonal_chol")    = min_diag,
        Named("method")               = "cholesky"
      );
    }
    // Otherwise lambda_min(A) <= tol; let the eigenvalue test report it.
  }

  // Fallback: eigenvalues
  vec evals;
  if (!eig_sym(evals, A)) {
    warning("Eigendecomposition failed");
    return List::create(
      Named("is_positive_definite") = false,
      Named("is_symmetric")         = true,
      Named("asymmetry")            = rel_asym,
      Named("method")               = "eigendecomposition_failed"
    );
  }

  const double min_eig = evals.min();
  const double max_eig = evals.max();
  const bool is_pd = min_eig > tol;

  return List::create(
    Named("is_positive_definite") = is_pd,
    Named("is_symmetric")         = true,
    Named("asymmetry")            = rel_asym,
    Named("min_eigenvalue")       = min_eig,
    Named("max_eigenvalue")       = max_eig,
    Named("condition_number")     = max_eig / std::max(min_eig, 1e-16),
    Named("method")               = "eigenvalues"
  );
}

//--------------------------------------------------------------
// Simple fast trace
// [[Rcpp::export]]
double fast_trace(const arma::mat& A) {
  // Sum of the main diagonal of A.
  // The squareness check runs first, then the finiteness check; each
  // failure stops with an R error.
  const bool square = A.is_square();
  if (!square) {
    stop("Matrix must be square for trace computation");
  }

  const bool all_finite = A.is_finite();
  if (!all_finite) {
    stop("Input matrix contains NaN or infinite values");
  }

  const double diag_sum = trace(A);
  return diag_sum;
}
