/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.preorder;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.MultivariateDiffusionModel;
import dr.evomodel.treedatalikelihood.continuous.ConjugateRootTraitPrior;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousRateTransformation;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.NormalSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.TipFullConditionalDistributionDelegate;
import dr.inference.model.MatrixParameterInterface;
import dr.math.matrixAlgebra.missingData.MissingOps;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Delegate that turns the per-tip full-conditional statistics produced by
 * {@link TipFullConditionalDistributionDelegate} into a per-tip gradient vector.
 *
 * <p>Only {@link PrecisionType#SCALAR} and {@link PrecisionType#FULL} likelihoods are
 * supported, and only for a single trait ({@code numTraits == 1}). For the FULL precision
 * type the gradient can be restricted to the contiguous trait dimensions
 * {@code [offset, offset + dimGradient)}.
 */
public class TipGradientViaFullConditionalDelegate
extends TipFullConditionalDistributionDelegate {
    /** Index of the first trait dimension included in the gradient. */
    private final int offset;
    /** Number of consecutive trait dimensions included in the gradient. */
    private final int dimGradient;
    /** Trait indices {@code offset, offset + 1, ..., offset + dimGradient - 1}. */
    private final int[] subInds;
    /** True when the requested window is not the whole trait vector, so gathering is needed. */
    private final boolean doSubset;

    /**
     * Creates the delegate and precomputes the trait-dimension window used by the FULL path.
     *
     * @param name               name forwarded to the superclass
     * @param treeModel          tree forwarded to the superclass
     * @param diffusion          diffusion model forwarded to the superclass
     * @param partialsProvider   data model / partials provider forwarded to the superclass
     * @param rootPrior          root prior forwarded to the superclass
     * @param rateTransformation rate transformation forwarded to the superclass
     * @param delegate           continuous likelihood delegate forwarded to the superclass
     * @param offset             index of the first trait dimension included in the gradient
     * @param dimGradient        number of consecutive trait dimensions included in the gradient
     */
    public TipGradientViaFullConditionalDelegate(String name, MutableTreeModel treeModel, MultivariateDiffusionModel diffusion, ContinuousTraitPartialsProvider partialsProvider, ConjugateRootTraitPrior rootPrior, ContinuousRateTransformation rateTransformation, ContinuousDataLikelihoodDelegate delegate, int offset, int dimGradient) {
        super(name, treeModel, diffusion, partialsProvider, rootPrior, rateTransformation, delegate);
        this.offset = offset;
        this.dimGradient = dimGradient;
        this.subInds = new int[dimGradient];
        for (int k = 0; k < dimGradient; ++k) {
            this.subInds[k] = offset + k;
        }
        // Gathering is only required when the window differs from the full trait vector.
        this.doSubset = offset != 0 || dimGradient != this.dimTrait;
    }

    /**
     * Builds the trait name under which this delegate's output is registered.
     *
     * @param name base trait name
     * @return {@code "grad."} followed by {@code name}
     */
    public static String getName(String name) {
        return "grad." + name;
    }

    @Override
    public String getTraitName(String name) {
        return TipGradientViaFullConditionalDelegate.getName(name);
    }

    /**
     * Dispatches on the likelihood's precision type.
     *
     * @throws RuntimeException for precision types other than SCALAR and FULL
     */
    @Override
    protected double[] getTraitForNode(NodeRef nodeRef) {
        final PrecisionType precisionType = this.likelihoodDelegate.getPrecisionType();
        if (precisionType == PrecisionType.SCALAR) {
            return this.getTraitForNodeScalar(nodeRef);
        }
        if (precisionType == PrecisionType.FULL) {
            return this.getTraitForNodeFull(nodeRef);
        }
        throw new RuntimeException("Tip gradients are not implemented for '" + precisionType.toString() + "' likelihoods");
    }

    /**
     * SCALAR path: returns {@code g_i = sum_j (m_j - y_j) * d * P(i, j)}.
     *
     * <p>Here {@code m} is the first {@code dimTrait} entries of the full-conditional array.
     * {@code d} is the entry at index {@code dimTrait} of that array; it is presumably the
     * scalar full-conditional precision (NOTE(review): confirm). {@code y} is the first
     * {@code dimTrait} entries of the node's post-order partial. {@code P} is the diffusion
     * precision parameter.
     *
     * <p>NOTE(review): this path returns all {@code dimTrait} dimensions and ignores
     * {@code offset}/{@code dimGradient}, unlike the FULL path.
     *
     * @throws RuntimeException if {@code numTraits > 1}
     */
    private double[] getTraitForNodeScalar(NodeRef nodeRef) {
        // Fail fast before doing any work, matching the FULL path.
        if (this.numTraits > 1) {
            throw new RuntimeException("Not yet implemented");
        }

        final double[] fullConditional = super.getTraitForNode(nodeRef);

        final double[] postOrderPartial = new double[this.dimPartial * this.numTraits];
        this.cdi.getPostOrderPartial(this.likelihoodDelegate.getActiveNodeIndex(nodeRef.getNumber()), postOrderPartial);

        final MatrixParameterInterface diffusionPrecision = this.diffusionModel.getPrecisionParameter();
        final double scalarPrecision = fullConditional[this.dimTrait];

        // Hoisted out of the i-loop: (m_j - y_j) * d does not depend on i. The evaluation
        // order ((m - y) * d) * P is unchanged, so results are bit-identical.
        final double[] scaledResidual = new double[this.dimTrait];
        for (int j = 0; j < this.dimTrait; ++j) {
            scaledResidual[j] = (fullConditional[j] - postOrderPartial[j]) * scalarPrecision;
        }

        final double[] gradient = new double[this.dimTrait * this.numTraits];
        for (int i = 0; i < this.dimTrait; ++i) {
            double sum = 0.0;
            for (int j = 0; j < this.dimTrait; ++j) {
                sum += scaledResidual[j] * diffusionPrecision.getParameterValue(i, j);
            }
            gradient[i] = sum;
        }
        return gradient;
    }

    /**
     * FULL path: returns {@code P_s * (m_s - y_s)} as a flat array of length {@code dimGradient}.
     *
     * <p>{@code P} and {@code m} are the raw precision and mean of the node's full conditional.
     * {@code y} is the first {@code dimTrait} entries of its post-order partial. The subscript
     * {@code s} denotes restriction to the rows and columns {@code subInds} when
     * {@code doSubset} is true.
     *
     * <p>If the data model asks for it, the same (restricted) precision and variance are
     * passed to {@code dataModel.updateTipDataGradient} before the gradient is computed.
     *
     * @throws RuntimeException if {@code numTraits > 1}
     */
    protected double[] getTraitForNodeFull(NodeRef nodeRef) {
        if (this.numTraits > 1) {
            throw new RuntimeException("Not yet implemented");
        }

        final double[] fullConditional = super.getTraitForNode(nodeRef);
        final NormalSufficientStatistics statistics = new NormalSufficientStatistics(fullConditional, 0, this.dimTrait, this.Pd, this.likelihoodDelegate.getPrecisionType());
        DenseMatrix64F precision = statistics.getRawPrecisionCopy();
        DenseMatrix64F mean = statistics.getRawMeanCopy();

        final double[] postOrderPartial = new double[this.dimPartial * this.numTraits];
        final int nodeIndex = this.likelihoodDelegate.getActiveNodeIndex(nodeRef.getNumber());
        this.cdi.getPostOrderPartial(nodeIndex, postOrderPartial);
        DenseMatrix64F data = MissingOps.wrap(postOrderPartial, 0, this.dimTrait, 1);

        if (this.doSubset) {
            precision = MissingOps.gatherRowsAndColumns(precision, this.subInds, this.subInds);
            mean = MissingOps.gatherRowsAndColumns(mean, this.subInds, new int[]{0});
            data = MissingOps.gatherRowsAndColumns(data, this.subInds, new int[]{0});
        }

        if (this.dataModel.needToUpdateTipDataGradient(this.offset, this.dimGradient)) {
            DenseMatrix64F variance = statistics.getRawVarianceCopy();
            if (this.doSubset) {
                // Restrict the variance itself. Previously this gathered from the already
                // restricted precision, which passed a precision block as the variance and
                // indexed past its bounds whenever offset > 0.
                variance = MissingOps.gatherRowsAndColumns(variance, this.subInds, this.subInds);
            }
            this.dataModel.updateTipDataGradient(precision, variance, nodeRef, this.offset, this.dimGradient);
        }

        // Turn `data` into the residual (mean - data) in place.
        CommonOps.addEquals(data, -1.0, mean); // data <- data - mean
        CommonOps.changeSign(data);            // data <- mean - data

        final DenseMatrix64F gradient = new DenseMatrix64F(this.dimGradient, this.numTraits);
        CommonOps.mult(precision, data, gradient);
        return gradient.getData();
    }
}

