#ifndef lossFunction_h
#define lossFunction_h

#include <cmath>
#include <cstdint>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paramType.h"
#include "matrice.h"
#include "types.h"

/// Abstract base for losses that compare a candidate matrix against a set of
/// reference matrices ("MRPs").
///
/// Stores the shared list of reference matrices and a map of per-index
/// weights; concrete subclasses implement operator().
class LossFunctionMRP {
protected:
    /// Call counter reported by to_string(). Default-initialised to zero:
    /// it used to be left uninitialised, so to_string() read an
    /// indeterminate value. NOTE(review): nothing in this class increments
    /// it — confirm whether subclasses are expected to.
    std::uint64_t calls{0};
    /// Reference matrices the candidate is compared against.
    std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> mrps;
    /// Index -> weight map. Never null after construction.
    std::shared_ptr<std::map<std::uint64_t, double>> weights;
public:
    /// @param p reference matrices (shared, not copied).
    /// @param w optional weights; when null an empty map is allocated.
    LossFunctionMRP(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p, std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr)
        : mrps(std::move(p)), weights(std::move(w)) {
        if (weights == nullptr) {
            weights = std::make_shared<std::map<std::uint64_t, double>>();
        }
    }
    virtual ~LossFunctionMRP() = default;

    /// @return a human-readable summary of the form "Calls: <calls>".
    virtual std::string to_string() const {
        return "Calls: " + std::to_string(this->calls);
    }

    /// Evaluate the loss of a candidate matrix against the references.
    virtual double operator()(Matrice<double>&) = 0;
};

//************************************
//************************************
//************************************

class LBMRP : public LossFunctionMRP {
public:
    LBMRP(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p, std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr) : LossFunctionMRP(p, w) {
        
    }
    ~LBMRP() {}
    double operator()(Matrice<double>& x) {
        double risultato = 0.0;
        
        for (auto ym : *this->mrps) {
            Matrice<double>& mc = *(ym);
            double denominatore = 0.0;
            double numeratore = 0.0;
            
            for (auto const& [riga, peso_r] : *weights) {
                for (auto const& [colonna, peso_c] : *weights) {
            //for (std::uint64_t riga = 0; riga < x.Rows(); ++riga) {
                //for (std::uint64_t colonna = 0; colonna < x.Rows(); ++colonna) {
                    //if (weights->find(riga) != weights->end() && weights->find(colonna) != weights->end()) {
                        //auto peso_r = weights->at(riga);
                        //auto peso_c = weights->at(colonna);
                        auto v_x = x.at(riga, colonna);
                        auto v_mc = mc.at(riga, colonna);
                        numeratore += peso_r * std::fabs(v_x - v_mc) * peso_c;
                        denominatore += peso_r * v_mc * peso_c;
                    //}
                }
            }
            risultato += numeratore / denominatore;
        }
        //risultato = risultato / this->mrps->size();
        
        return risultato;
    }
};

//************************************
//************************************
//************************************

class L1LossFunctionMRP : public LossFunctionMRP {
public:
    L1LossFunctionMRP(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p, std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr) : LossFunctionMRP(p, w) {
        
    }
    ~L1LossFunctionMRP() {}
    double operator()(Matrice<double>& x) {
        double risultato = 0.0;
        double denominatore = (x.Rows() * (x.Rows() + 1)) / 2.0;
        for (auto ym : *this->mrps) {
            Matrice<double>& mc = *(ym);
            double numeratore = 0.0;
            for (auto const& [riga, peso_r] : *weights) {
                for (auto const& [colonna, peso_c] : *weights) {
            //for (std::uint64_t riga = 0; riga < x.Rows(); ++riga) {
            //    for (std::uint64_t colonna = riga + 1; colonna < x.Rows(); ++colonna) {
            //        if (weights->find(riga) != weights->end() && weights->find(colonna) != weights->end()) {
                    if (riga < colonna) {
                        numeratore += weights->at(riga) * std::fabs(x.at(riga, colonna) - mc.at(riga, colonna)) * weights->at(colonna);
                    }
                }
            }
            risultato += (2.0 * numeratore) / denominatore;
            
        }
        risultato = risultato / this->mrps->size();
        
        return risultato;
    }
};

//************************************
//************************************
//************************************

class JS0LossFunctionMRP : public LossFunctionMRP {
protected:
    std::uint64_t base;
public:
    JS0LossFunctionMRP(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p, std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr, std::uint64_t b=2) : LossFunctionMRP(p, w), base(b) {
    }
    ~JS0LossFunctionMRP() {}
    double operator()(Matrice<double>& x) {
        auto n = x.Rows();
        auto k = this->mrps->size();
        double somma = 0.0;
        for (std::uint64_t j = 0; j < k; ++j) {
            double somma_js = 0.0;
            for (std::uint64_t r = 0; r < n; ++r) {
                for (std::uint64_t s = 0; s < r; ++s) {
                    double mrp_j_r_s = (*this->mrps->at(j))(r, s);
                    std::vector<double> B_j(2);
                    B_j[0] = mrp_j_r_s;
                    B_j[1] = 1 - mrp_j_r_s;
                    std::vector<double> B_x(2);
                    B_x[0] = x(r, s);
                    B_x[1] = 1 - x(r, s);
                    somma_js += JSDivergence(B_j, B_x, base);
                }
            }
            somma += somma_js;
        }
        double risultato = (2.0 / (1.0 * k * n * (n - 1))) * somma;
        return risultato;
    }

    static double Entropy(std::vector<double>& p, std::vector<double>& w, std::uint64_t base=2) {
        double numeratore = 0.0;
        double denominatore = 0.0;
        for (std::uint64_t k = 0; k < p.size(); ++k) {
            auto p_at = p.at(k);
            auto w_at = w.at(k);
            denominatore += w_at;
            if (p_at <= 0) {
                numeratore += w_at;
            } else {
                numeratore += (std::log(p_at) * w_at / std::log(base));
            }
        }
        return -(numeratore / denominatore);
    }
    
    static double KLDivergence(std::vector<double>& p, std::vector<double>& q, std::uint64_t base=2) {
        auto e1 = Entropy(q, p, base);
        auto e2 = Entropy(p, p, base);
        return e1 - e2;
    }
    
    static double JSDivergence(std::vector<double>& p, std::vector<double>& q, std::uint64_t base=2) {
        std::vector<double> m(p.size());
        for (std::uint64_t k = 0; k < p.size(); ++k) {
            m[k] = (p[k] + q[k]) / 2.0;
        }
       return (KLDivergence(p, m, base) + KLDivergence(q, m, base)) / 2.0;
    }
    
};

//************************************
//************************************
//************************************

class JS1LossFunctionMRP : public LossFunctionMRP {
protected:
    std::uint64_t base;
public:
    JS1LossFunctionMRP(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p, std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr, std::uint64_t b=2) : LossFunctionMRP(p, w), base(2) {
    }
    ~JS1LossFunctionMRP() {}
    double operator()(Matrice<double>& x) {
        auto n = x.Rows();
        auto k = this->mrps->size();
        double somma_numeratore = 0.0;
        double somma_denominatore = 0.0;
        for (std::uint64_t j = 0; j < k; ++j) {
            double somma_js = 0.0;
            double somma_uno_meno_entropia = 0.0;
            for (std::uint64_t r = 0; r < n; ++r) {
                for (std::uint64_t s = 0; s < r; ++s) {
                    double mrp_j_r_s = (*this->mrps->at(j))(r, s);
                    std::vector<double> B_j(2);
                    B_j[0] = mrp_j_r_s;
                    B_j[1] = 1.0 - mrp_j_r_s;
                    std::vector<double> B_x(2);
                    B_x[0] = x(r, s);
                    B_x[1] = 1 - x(r, s);
                    double uno_meno_h = (1 - JS0LossFunctionMRP::Entropy(B_j, B_j, base));
                    somma_js += JS0LossFunctionMRP::JSDivergence(B_j, B_x, base) * uno_meno_h;
                    somma_uno_meno_entropia += uno_meno_h;
                }
            }
            somma_numeratore += somma_js;
            somma_denominatore += somma_uno_meno_entropia;

        }
        double risultato = somma_numeratore / somma_denominatore;
        return risultato;
    }

};

//************************************
//************************************
//************************************

class JS2LossFunctionMRP : public LossFunctionMRP {
protected:
    std::uint64_t base;
public:
    JS2LossFunctionMRP(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p, std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr, std::uint64_t b=2) : LossFunctionMRP(p, w), base(2) {
    }
    ~JS2LossFunctionMRP() {}
    double operator()(Matrice<double>& x) {
        auto n = x.Rows();
        auto k = this->mrps->size();
        double somma_numeratore = 0.0;
        double somma_denominatore = 0.0;
        for (std::uint64_t j = 0; j < k; ++j) {
            double somma_h = 0.0;
            double somma_js = 0.0;
            for (std::uint64_t r = 0; r < n; ++r) {
                for (std::uint64_t s = 0; s < r; ++s) {
                    double mrp_j_r_s = (*this->mrps->at(j))(r, s);
                    std::vector<double> B_j(2);
                    B_j[0] = mrp_j_r_s;
                    B_j[1] = 1 - mrp_j_r_s;
                    std::vector<double> B_x(2);
                    B_x[0] = x(r, s);
                    B_x[1] = 1 - x(r, s);
                    somma_js += JS0LossFunctionMRP::JSDivergence(B_j, B_x, base);
                    somma_h += JS0LossFunctionMRP::Entropy(B_j, B_j, base);
                }
            }
            double h_medio = (2.0 / (n * (n - 1) * 1.0)) * somma_h;
            somma_numeratore += (1 - h_medio) * somma_js;
            somma_denominatore += (1 - h_medio);
        }
        double risultato = (somma_numeratore / somma_denominatore) * 2.0 / (1.0 * n * (n - 1));
        return risultato;
    }

};

//************************************
//************************************
//************************************

class LossFunctionMRPV2 : public LossFunctionMRP  {
public:
    LossFunctionMRPV2(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p,
                      std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr) : LossFunctionMRP(p, w) {
    }
    virtual ~LossFunctionMRPV2() {}
    
    virtual std::string to_string() const {
        std::string result = "";
        result += "Calls: " + std::to_string(this->calls);
        return result;
    }
    
    virtual double operator()(Matrice<double>&) = 0;
    virtual double operator()(Matrice<double>&,
                              std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>&,
                              std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>&) = 0;
    
    virtual void operator()(Matrice<double>&,
                    std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>&,
                    std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>&,
                    std::list<std::pair<std::uint_fast64_t, double>>&) = 0;
};

//************************************
//************************************
//************************************

class LBMRP2 : public LossFunctionMRPV2 {
public:
    LBMRP2(std::shared_ptr<std::vector<std::shared_ptr<Matrice<double>>>> p,
           std::shared_ptr<std::map<std::uint64_t, double>> w=nullptr) : LossFunctionMRPV2(p, w) {
        
    }
    ~LBMRP2() {}
    double operator()(Matrice<double>& x) {
        double risultato = 0.0;
        
        for (auto ym : *this->mrps) {
            Matrice<double>& mc = *(ym);
            double denominatore = 0.0;
            double numeratore = 0.0;
            
            for (auto const& [riga, peso_r] : *weights) {
                for (auto const& [colonna, peso_c] : *weights) {
                    auto v_x = x.at(riga, colonna);
                    auto v_mc = mc.at(riga, colonna);
                    numeratore += peso_r * std::fabs(v_x - v_mc) * peso_c;
                    denominatore += peso_r * v_mc * peso_c;
                }
            }
            risultato += numeratore / denominatore;
        }
        
        return risultato;
    }
    
    double operator()(Matrice<double>& x,
                      std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>& rows_function,
                      std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>& cols_function) {
        double risultato = 0.0;
        //std::uint_fast64_t conta = 0;
        for (auto ym : *this->mrps) {
            Matrice<double>& mc = *(ym);
            double denominatore = 0.0;
            double numeratore = 0.0;
            
            for (auto const& [r, peso_r] : *weights) {
                auto riga = r;
                for (auto& l : rows_function) {
                    riga = l->at(riga);
                }
                for (auto const& [c, peso_c] : *weights) {
                    //++conta;
                    auto colonna = c;
                    for (auto& l : cols_function) {
                        colonna = l->at(colonna);
                    }
                    
                    auto v_x = x.at(riga, colonna);
                    auto v_mc = mc.at(r, c);
                    numeratore += peso_r * std::fabs(v_x - v_mc) * peso_c;
                    denominatore += peso_r * v_mc * peso_c;
                }
            }
            risultato += numeratore / denominatore;
        }
        //std::cout << conta << std::endl;
        return risultato;
    }
    
    
    void operator()(Matrice<double>& x,
                    std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>& rows_function,
                    std::list<std::shared_ptr<std::vector<std::uint_fast64_t>>>& cols_function,
                    std::list<std::pair<std::uint_fast64_t, double>>& result) {
        result.clear();
        //std::uint_fast64_t conta = 0;
        for (auto ym : *this->mrps) {
            Matrice<double>& mc = *(ym);
            for (auto const& [c, peso_c] : *weights) {
                double denominatore = 0.0;
                double numeratore = 0.0;
                for (auto const& [r, peso_r] : *weights) {
                    //++conta;
                    auto riga = r;
                    for (auto& l : rows_function) {
                        riga = l->at(riga);
                    }
                    auto colonna = c;
                    for (auto& l : cols_function) {
                        colonna = l->at(colonna);
                    }
                    
                    auto v_x = x.at(riga, colonna);
                    auto v_mc = mc.at(r, c);
                    numeratore += peso_r * std::fabs(v_x - v_mc);
                    denominatore += peso_r * v_mc;
                }
                result.push_back({c, (denominatore == 0 ? 0 : numeratore / denominatore)});
            }
        }
        return;
    }
    
};



#endif /* lossFunction_hpp */
