/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.DiscreteTraitBranchRateGradient;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.Loggable;
import dr.inference.model.Parameter;
import dr.math.NumericalDerivative;
import dr.xml.Reportable;

/**
 * Gradient (and numerical diagonal Hessian) of the discrete-trait data log-likelihood with
 * respect to the hyper-parameters of a
 * {@link ArbitraryBranchRates.BranchRateTransform.LocationScaleLogNormal} branch-rate transform.
 *
 * <p>The gradient is assembled by the chain rule. The per-branch gradient returned by
 * {@link DiscreteTraitBranchRateGradient#getGradientLogDensity()} is weighted by the
 * per-branch differential that subclasses supply through {@link #getDifferential(Tree, NodeRef)},
 * and the products are summed over all non-root nodes.
 */
public abstract class HyperParameterBranchRateGradient
extends DiscreteTraitBranchRateGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
Reportable,
Loggable {

    /** The branch-rate model's transform, verified at construction to be LocationScaleLogNormal. */
    protected final ArbitraryBranchRates.BranchRateTransform.LocationScaleLogNormal locationScaleTransform;

    /**
     * Creates the gradient provider and checks that the branch-rate model uses a
     * location/scale log-normal transform.
     *
     * <p>All arguments are forwarded unchanged to the
     * {@link DiscreteTraitBranchRateGradient} constructor.
     *
     * @param traitName          forwarded to the superclass (presumably the trait name; confirm)
     * @param treeDataLikelihood the tree data likelihood, forwarded to the superclass
     * @param likelihoodDelegate the BEAGLE likelihood delegate, forwarded to the superclass
     * @param parameter          forwarded to the superclass. Its dimension fixes the length of
     *                           the gradient returned by {@link #getGradientLogDensity()}
     * @param useHessian         forwarded to the superclass (presumably toggles Hessian support; confirm)
     * @throws IllegalArgumentException if the branch-rate model's transform is not a
     *                                  {@code LocationScaleLogNormal}
     */
    protected HyperParameterBranchRateGradient(String traitName,
                                               TreeDataLikelihood treeDataLikelihood,
                                               BeagleDataLikelihoodDelegate likelihoodDelegate,
                                               Parameter parameter,
                                               boolean useHessian) {
        super(traitName, treeDataLikelihood, likelihoodDelegate, parameter, useHessian);
        if (!(this.branchRateModel.getTransform() instanceof ArbitraryBranchRates.BranchRateTransform.LocationScaleLogNormal)) {
            throw new IllegalArgumentException("Must provide a LocationScaleLogNormal transform.");
        }
        this.locationScaleTransform =
                (ArbitraryBranchRates.BranchRateTransform.LocationScaleLogNormal) this.branchRateModel.getTransform();
    }

    /**
     * Computes the gradient of the log density with respect to the hyper-parameters by the
     * chain rule.
     *
     * <p>For each parameter index {@code j}, the result is the sum over non-root nodes of
     * {@code branchGradient[b] * differential[j]}, where {@code b} counts non-root nodes in
     * node-index order.
     *
     * @return a new array of length {@code rateParameter.getDimension()}
     */
    @Override
    public double[] getGradientLogDensity() {
        final double[] result = new double[this.rateParameter.getDimension()];
        final double[] branchGradient = super.getGradientLogDensity();

        // NOTE(review): assumes the superclass gradient has one entry per non-root node, ordered
        // by node index with the root skipped. Confirm against DiscreteTraitBranchRateGradient.
        int branch = 0;
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            final NodeRef node = this.tree.getNode(i);
            if (this.tree.isRoot(node)) {
                continue;
            }
            final double[] differential = this.getDifferential(this.tree, node);
            final double branchValue = branchGradient[branch];
            for (int j = 0; j < result.length; ++j) {
                result[j] += branchValue * differential[j];
            }
            ++branch;
        }
        return result;
    }

    /**
     * Returns the diagonal of the Hessian, estimated numerically by
     * {@link NumericalDerivative#diagonalHessian} at the current values of
     * {@code rateParameter}.
     *
     * @return the numerically estimated Hessian diagonal
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        // numeric1 is inherited; presumably the log density as a function of the parameter (confirm).
        return NumericalDerivative.diagonalHessian(this.numeric1, this.rateParameter.getParameterValues());
    }

    /**
     * Returns the chain-rule factors for the branch above {@code node}. Entry {@code j} is
     * multiplied by that branch's gradient entry and accumulated into hyper-parameter {@code j}.
     * Presumably this is the derivative of the branch rate with respect to each hyper-parameter
     * (confirm in subclasses).
     *
     * <p>The returned array must have at least {@code rateParameter.getDimension()} entries.
     * Entries beyond that length are ignored.
     *
     * @param tree the tree being evaluated
     * @param node a non-root node identifying the branch
     * @return the per-hyper-parameter differential for this branch
     */
    abstract double[] getDifferential(Tree tree, NodeRef node);
}

