/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.hmc.DerivativeWrtParameterProvider;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.hmc.ParallelGradientExecutor;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.DerivativeOrder;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Future;

public class JointGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
DerivativeWrtParameterProvider,
Reportable {
    private final int dimension;
    private final Likelihood likelihood;
    private final Parameter parameter;
    private final ParallelGradientExecutor parallelExecutor;
    final List<GradientWrtParameterProvider> derivativeList;
    final List<DerivativeWrtParameterProvider> newDerivativeList;
    private final DerivativeOrder highestOrder;
    private static final boolean DEBUG = false;
    private static final boolean DEBUG_KILL = false;

    public JointGradient(List<GradientWrtParameterProvider> derivativeList) {
        this(derivativeList, 0);
    }

    public JointGradient(List<GradientWrtParameterProvider> derivativeList, int threadCount) {
        this.derivativeList = derivativeList;
        GradientWrtParameterProvider first = derivativeList.get(0);
        this.dimension = first.getDimension();
        this.parameter = first.getParameter();
        if (derivativeList.size() == 1) {
            this.likelihood = first.getLikelihood();
        } else {
            ArrayList<Likelihood> likelihoodList = new ArrayList<Likelihood>();
            for (GradientWrtParameterProvider grad : derivativeList) {
                if (grad.getDimension() != this.dimension) {
                    throw new RuntimeException("Unequal parameter dimensions");
                }
                if (!Arrays.equals(grad.getParameter().getParameterValues(), this.parameter.getParameterValues())) {
                    throw new RuntimeException("Unequal parameter values");
                }
                for (Likelihood likelihood : grad.getLikelihood().getLikelihoodSet()) {
                    if (likelihoodList.contains(likelihood)) continue;
                    likelihoodList.add(likelihood);
                }
            }
            this.likelihood = new CompoundLikelihood(likelihoodList);
        }
        this.newDerivativeList = new ArrayList<DerivativeWrtParameterProvider>();
        for (GradientWrtParameterProvider p : derivativeList) {
            if (!(p instanceof DerivativeWrtParameterProvider)) continue;
            DerivativeWrtParameterProvider provider = (DerivativeWrtParameterProvider)p;
            this.newDerivativeList.add(provider);
        }
        this.highestOrder = DerivativeWrtParameterProvider.getHighestOrder(this.newDerivativeList);
        this.parallelExecutor = threadCount > 1 || threadCount < 0 ? new ParallelGradientExecutor(threadCount, derivativeList) : null;
    }

    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    public Parameter getParameter() {
        return this.parameter;
    }

    public int getDimension(DerivativeOrder order) {
        return order.getDerivativeDimension(this.dimension);
    }

    public int getDimension() {
        return this.dimension;
    }

    public double[] getDerivativeLogDensity(DerivativeOrder type) {
        assert (this.highestOrder.getValue() >= type.getValue());
        int size = this.newDerivativeList.size();
        double[] derivative = this.newDerivativeList.get(0).getDerivativeLogDensity(type);
        int i = 1;
        while (i < size) {
            double[] temp = this.newDerivativeList.get(i).getDerivativeLogDensity(type);
            int j = 0;
            while (j < temp.length) {
                int n = j;
                derivative[n] = derivative[n] + temp[j];
                ++j;
            }
            ++i;
        }
        return derivative;
    }

    public DerivativeOrder getHighestOrder() {
        return this.highestOrder;
    }

    public double[] getDiagonalHessianLogDensity() {
        return this.getDerivativeLogDensity(DerivativeType.DIAGONAL_HESSIAN);
    }

    public double[][] getHessianLogDensity() {
        assert (this.derivativeList.get(0) instanceof HessianWrtParameterProvider);
        int size = this.derivativeList.size();
        double[][] hessian = ((HessianWrtParameterProvider)this.derivativeList.get(0)).getHessianLogDensity();
        int i = 1;
        while (i < size) {
            assert (this.derivativeList.get(i) instanceof HessianWrtParameterProvider);
            double[][] temp = ((HessianWrtParameterProvider)this.derivativeList.get(i)).getHessianLogDensity();
            int j = 0;
            while (j < temp[0].length) {
                int k = 0;
                while (k < temp[0].length) {
                    double[] dArray = hessian[j];
                    int n = k;
                    dArray[n] = dArray[n] + temp[j][k];
                    ++k;
                }
                ++j;
            }
            ++i;
        }
        return hessian;
    }

    double[] getDerivativeLogDensity(DerivativeType derivativeType) {
        if (this.parallelExecutor != null) {
            return this.getDerivativeLogDensityParallelImpl(derivativeType);
        }
        return this.getDerivativeLogDensitySerialImpl(derivativeType);
    }

    private double[] getDerivativeLogDensityParallelImpl(DerivativeType derivativeType) {
        return this.parallelExecutor.getDerivativeLogDensityInParallel(derivativeType, (gradients, length) -> {
            double[] reduction = new double[length];
            for (Future result : gradients) {
                double[] tmp = (double[])result.get();
                int j = 0;
                while (j < length) {
                    int n = j;
                    reduction[n] = reduction[n] + tmp[j];
                    ++j;
                }
            }
            return reduction;
        }, this.dimension);
    }

    private double[] getDerivativeLogDensitySerialImpl(DerivativeType derivativeType) {
        int size = this.derivativeList.size();
        double[] derivative = derivativeType.getDerivativeLogDensity(this.derivativeList.get(0));
        int i = 1;
        while (i < size) {
            double[] temp = derivativeType.getDerivativeLogDensity(this.derivativeList.get(i));
            int j = 0;
            while (j < temp.length) {
                int n = j;
                derivative[n] = derivative[n] + temp[j];
                ++j;
            }
            ++i;
        }
        return derivative;
    }

    public double[] getGradientLogDensity() {
        return this.getDerivativeLogDensity(DerivativeType.GRADIENT);
    }

    public String getReport() {
        return "jointGradient." + this.parameter.getParameterName() + "\n" + GradientWrtParameterProvider.getReportAndCheckForError((GradientWrtParameterProvider)this, (double)Double.NEGATIVE_INFINITY, (double)Double.POSITIVE_INFINITY, (Double)GradientWrtParameterProvider.TOLERANCE);
    }

    static enum DerivativeType {
        GRADIENT("gradient"){

            @Override
            public double[] getDerivativeLogDensity(GradientWrtParameterProvider gradientWrtParameterProvider) {
                return gradientWrtParameterProvider.getGradientLogDensity();
            }
        }
        ,
        DIAGONAL_HESSIAN("diagonalHessian"){

            @Override
            public double[] getDerivativeLogDensity(GradientWrtParameterProvider gradientWrtParameterProvider) {
                return ((HessianWrtParameterProvider)gradientWrtParameterProvider).getDiagonalHessianLogDensity();
            }
        };

        private final String type;

        private DerivativeType(String type) {
            this.type = type;
        }

        public abstract double[] getDerivativeLogDensity(GradientWrtParameterProvider var1);
    }
}

