/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.operators;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.continuous.DummyLatentTruncationProvider;
import dr.evomodel.continuous.LatentTruncation;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.RepeatedMeasuresTraitDataModel;
import dr.evomodel.treedatalikelihood.preorder.ConditionalPrecisionAndTransform;
import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.WrappedTipFullConditionalDistributionDelegate;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.List;
import org.ejml.alg.dense.mult.MatrixVectorMult;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Operator that resamples the latent continuous trait values of one randomly chosen tip.
 *
 * <p>Values are drawn from the tip's multivariate-normal full conditional, repeatedly (up to
 * {@code numAttempts} times) until {@link LatentTruncation#validTraitForTip(int)} accepts the draw.
 *
 * <p>An optional mask selects which trait dimensions are resampled. Entries equal to 1 mark latent
 * (resampled) dimensions. The remaining dimensions are held fixed and conditioned on. The mask is
 * either:
 * <ul>
 *   <li>one entry per trait dimension, shared by all tips ({@code missingByColumn = true}); or</li>
 *   <li>one entry per tip and dimension, indexed {@code dim * tip + d}
 *       ({@code missingByColumn = false}).</li>
 * </ul>
 * Without a mask, every dimension of every tip is resampled.
 */
public class NewLatentLiabilityGibbs extends SimpleMCMCOperator {
    private static final String NEW_LATENT_LIABILITY_GIBBS_OPERATOR = "newlatentLiabilityGibbsOperator";
    private static final String MAX_ATTEMPTS = "numAttempts";
    private static final String MISSING_BY_COLUMN = "missingByColumn";
    private static final String FORCE_ALL_MISSING = "forceAllMissing";

    private final LatentTruncation latentLiability;
    private final CompoundParameter tipTraitParameter;
    private final TreeTrait<List<WrappedNormalSufficientStatistics>> fullConditionalDensity;
    /** Optional extra tip-level information combined with the full conditional; may be null. */
    private final ContinuousTraitPartialsProvider extensionProvider;
    private final int maxAttempts;
    private final Tree treeModel;
    private final int dim;
    /** Latent/observed mask; null means every dimension is latent. */
    private final Parameter mask;
    private final MaskIndicesDelegate maskDelegate;
    private final boolean missingByColumn;
    /** External-node indices that have at least one latent dimension. */
    private final int[] needSampling;

    // Scratch buffers reused across calls to sampleNode.
    private final double[] fcdMean;
    private final double[][] fcdPrecision;
    private double[] maskedMean;
    private double[][] maskedPrecision;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private static final String MASK = "mask";
        private static final String PARTIALS_PROVIDER = "partialsProvider";

        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newBooleanRule(MISSING_BY_COLUMN, true),
                // Both attributes below are read by parseXMLObject, so declare them here as well.
                AttributeRule.newIntegerRule(MAX_ATTEMPTS, true),
                AttributeRule.newBooleanRule(FORCE_ALL_MISSING, true),
                new ElementRule(TreeDataLikelihood.class, "The model for the latent random variables"),
                new ElementRule(LatentTruncation.class, "The model that links latent and observed variables"),
                new ElementRule(MASK, Parameter.class, "Mask: 1 for latent variables that should be sampled", true),
                new ElementRule(CompoundParameter.class, "The parameter of tip locations from the tree"),
                new ElementRule(PARTIALS_PROVIDER, RepeatedMeasuresTraitDataModel.class, "Provides information about model extensions", true)
        };

        @Override
        public String getParserName() {
            return NEW_LATENT_LIABILITY_GIBBS_OPERATOR;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            if (xo.getChildCount() < 3) {
                throw new XMLParseException("Element with id = '" + xo.getName() + "' should contain:\n\t 1 conjugate multivariateTraitLikelihood, 1 latentLiabilityLikelihood and one parameter \n");
            }
            double weight = xo.getDoubleAttribute("weight");
            TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
            LatentTruncation latentTruncation = (LatentTruncation) xo.getChild(LatentTruncation.class);
            CompoundParameter tipTraitParameter = (CompoundParameter) xo.getChild(CompoundParameter.class);

            int numAttempts = xo.getAttribute(MAX_ATTEMPTS, 100000);
            if (numAttempts < 1) {
                // With no attempts the operator could never produce a valid draw.
                throw new XMLParseException("Attribute '" + MAX_ATTEMPTS + "' must be positive, got " + numAttempts);
            }
            boolean missingByColumn = xo.getAttribute(MISSING_BY_COLUMN, true);

            Parameter mask = null;
            if (xo.hasChildNamed(MASK)) {
                mask = (Parameter) xo.getElementFirstChild(MASK);
            }

            RepeatedMeasuresTraitDataModel partialsProvider = null;
            if (xo.hasChildNamed(PARTIALS_PROVIDER)) {
                partialsProvider = (RepeatedMeasuresTraitDataModel) xo.getElementFirstChild(PARTIALS_PROVIDER);
            }

            if (xo.getAttribute(FORCE_ALL_MISSING, false)) {
                // Replace any user mask with a column mask that marks every dimension latent.
                int traitDim = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
                mask = new Parameter.Default(traitDim);
                for (int i = 0; i < traitDim; ++i) {
                    mask.setParameterValue(i, 1.0);
                }
                missingByColumn = true;
            }

            return new NewLatentLiabilityGibbs(treeDataLikelihood, latentTruncation, tipTraitParameter,
                    partialsProvider, mask, weight, "latent", numAttempts, missingByColumn);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a gibbs sampler on tip latent trais for latent liability model.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the operator.
     *
     * <p>If the likelihood does not yet expose the wrapped tip full-conditional trait for
     * {@code traitName}, this constructor registers it on the likelihood's delegate.
     *
     * @param treeDataLikelihood likelihood whose delegate must be a {@link ContinuousDataLikelihoodDelegate}
     * @param latentLiability    truncation model used to validate each draw
     * @param tipTraitParameter  compound parameter holding one sub-parameter per tip
     * @param extensionProvider  optional provider whose tip partial is combined with the full conditional; may be null
     * @param mask               optional latent/observed mask (see class documentation); may be null
     * @param weight             operator weight
     * @param traitName          name used to look up / create the full-conditional tree trait
     * @param maxAttempts        maximum number of draws per operation before the move is rejected
     * @param missingByColumn    true if {@code mask} has one entry per dimension, false if one per tip and dimension
     */
    public NewLatentLiabilityGibbs(TreeDataLikelihood treeDataLikelihood, LatentTruncation latentLiability,
                                   CompoundParameter tipTraitParameter,
                                   ContinuousTraitPartialsProvider extensionProvider, Parameter mask,
                                   double weight, String traitName, int maxAttempts, boolean missingByColumn) {
        this.latentLiability = latentLiability;
        this.tipTraitParameter = tipTraitParameter;
        this.treeModel = treeDataLikelihood.getTree();
        this.extensionProvider = extensionProvider;

        ContinuousDataLikelihoodDelegate delegate =
                (ContinuousDataLikelihoodDelegate) treeDataLikelihood.getDataLikelihoodDelegate();
        this.dim = delegate.getTraitDim();

        String fcdTraitName = WrappedTipFullConditionalDistributionDelegate.getName(traitName);
        if (treeDataLikelihood.getTreeTrait(fcdTraitName) == null) {
            delegate.addWrappedFullConditionalDensityTrait(traitName);
        }
        this.fullConditionalDensity = castTreeTrait(treeDataLikelihood.getTreeTrait(fcdTraitName));

        // mask, dim and missingByColumn must be set before the delegate is built; it reads them.
        this.missingByColumn = missingByColumn;
        this.mask = mask;
        this.maskDelegate = new MaskIndicesDelegate();
        this.needSampling = setupNeedSampling();

        this.fcdMean = new double[dim];
        this.fcdPrecision = new double[dim][dim];
        this.maxAttempts = maxAttempts;
        setWeight(weight);
    }

    public int getStepCount() {
        return 1;
    }

    /**
     * Resamples one uniformly chosen tip among those with at least one latent dimension.
     *
     * <p>NOTE(review): if no tip needs sampling, {@code MathUtils.nextInt(0)} is called; confirm
     * how that behaves.
     *
     * @return the log Hastings ratio computed by {@link #sampleNode}
     */
    @Override
    public double doOperation() {
        int tipIndex = needSampling[MathUtils.nextInt(needSampling.length)];
        NodeRef node = treeModel.getExternalNode(tipIndex);
        List<WrappedNormalSufficientStatistics> statistics = fullConditionalDensity.getTrait(treeModel, node);
        double logHastingsRatio = sampleNode(node, statistics.get(0));
        tipTraitParameter.fireParameterChangedEvent();
        return logHastingsRatio;
    }

    /** Returns the current values of the tip's trait sub-parameter. */
    private double[] getNodeTrait(NodeRef node) {
        return tipTraitParameter.getParameter(node.getNumber()).getParameterValues();
    }

    /**
     * Writes a sampled vector into the tip's trait sub-parameter.
     *
     * <p>Without a mask, all {@code dim} entries are written. With a mask, entry {@code k} of
     * {@code value} goes to the k-th latent index of the tip. A single change event is fired
     * afterwards.
     */
    private void setNodeTrait(NodeRef node, double[] value) {
        Parameter tipParameter = tipTraitParameter.getParameter(node.getNumber());
        if (mask == null) {
            for (int i = 0; i < dim; ++i) {
                tipParameter.setParameterValueQuietly(i, value[i]);
            }
        } else {
            int[] latent = maskDelegate.getLatentIndices(node);
            for (int k = 0; k < latent.length; ++k) {
                tipParameter.setParameterValueQuietly(latent[k], value[k]);
            }
        }
        tipParameter.fireParameterChangedEvent(-1, Variable.ChangeType.ALL_VALUES_CHANGED);
    }

    /**
     * Draws new latent values for one tip.
     *
     * @return 0 if the tip has no latent dimension;
     *         negative infinity if no valid draw was found within {@code maxAttempts};
     *         positive infinity when the truncation model is a {@link DummyLatentTruncationProvider};
     *         otherwise log f(old) - log f(new) under the full-conditional normal f
     */
    private double sampleNode(NodeRef node, WrappedNormalSufficientStatistics statistics) {
        final int tipNumber = node.getNumber();
        final int observedDim = maskDelegate.getObservedIndices(node).length;
        if (observedDim == dim) {
            return 0.0;
        }

        loadFullConditional(statistics);
        if (extensionProvider != null) {
            combineWithTipPartial(tipNumber);
        }

        final MultivariateNormalDistribution fullDistribution =
                new MultivariateNormalDistribution(fcdMean, fcdPrecision);
        final MultivariateNormalDistribution drawDistribution;
        if (mask != null && observedDim > 0) {
            // Draw only the latent coordinates, conditioned on the observed ones.
            addMaskOnContiuousTraitsPrecisionSpace(tipNumber);
            drawDistribution = new MultivariateNormalDistribution(maskedMean, maskedPrecision);
        } else {
            drawDistribution = fullDistribution;
        }

        final double[] oldValue = getNodeTrait(node);

        boolean validTip = false;
        for (int attempt = 0; attempt < maxAttempts && !validTip; ++attempt) {
            setNodeTrait(node, drawDistribution.nextMultivariateNormal());
            validTip = latentLiability.validTraitForTip(tipNumber);
        }

        // Check the success flag rather than the attempt count: a valid draw found on the
        // final attempt must not be rejected.
        if (!validTip) {
            return Double.NEGATIVE_INFINITY;
        }

        if (latentLiability instanceof DummyLatentTruncationProvider) {
            return Double.POSITIVE_INFINITY;
        }

        final double[] newValue = getNodeTrait(node);
        return fullDistribution.logPdf(oldValue) - fullDistribution.logPdf(newValue);
    }

    /** Copies the full-conditional mean and scaled precision into the scratch buffers. */
    private void loadFullConditional(WrappedNormalSufficientStatistics statistics) {
        WrappedVector mean = statistics.getMean();
        WrappedMatrix precision = statistics.getPrecision();
        double precisionScalar = statistics.getPrecisionScalar();
        int fcdDim = mean.getDim();
        for (int i = 0; i < fcdDim; ++i) {
            fcdMean[i] = mean.get(i);
            for (int j = 0; j < fcdDim; ++j) {
                fcdPrecision[i][j] = precision.get(i, j) * precisionScalar;
            }
        }
    }

    /**
     * Combines the full conditional with the extension provider's tip partial.
     *
     * <p>The result is written back into the scratch buffers:
     * P = P_fcd + P_tip and mu = P^-1 (P_fcd mu_fcd + P_tip mu_tip).
     */
    private void combineWithTipPartial(int tipNumber) {
        double[] tipPartial = extensionProvider.getTipPartial(tipNumber, false);

        DenseMatrix64F fcdPrec = new DenseMatrix64F(fcdPrecision);
        // Tip partial is read as: mean in [0, dim), then a dim x dim precision starting at offset dim.
        DenseMatrix64F tipPrec = MissingOps.wrap(tipPartial, dim, dim, dim);

        DenseMatrix64F fcdMeanVector = new DenseMatrix64F(dim, 1);
        DenseMatrix64F tipMeanVector = new DenseMatrix64F(dim, 1);
        for (int i = 0; i < dim; ++i) {
            fcdMeanVector.set(i, 0, fcdMean[i]);
            tipMeanVector.set(i, 0, tipPartial[i]);
        }

        DenseMatrix64F weightedMeanSum = new DenseMatrix64F(dim, 1);
        MatrixVectorMult.mult(fcdPrec, fcdMeanVector, weightedMeanSum);
        MatrixVectorMult.multAdd(tipPrec, tipMeanVector, weightedMeanSum);

        // tipPrec now holds the combined precision.
        CommonOps.addEquals(tipPrec, fcdPrec);
        DenseMatrix64F combinedVariance = new DenseMatrix64F(dim, dim);
        // NOTE(review): the boolean result of invert (false if singular) is not checked.
        CommonOps.invert(tipPrec, combinedVariance);

        DenseMatrix64F combinedMean = new DenseMatrix64F(dim, 1);
        MatrixVectorMult.mult(combinedVariance, weightedMeanSum, combinedMean);

        for (int i = 0; i < dim; ++i) {
            fcdMean[i] = combinedMean.get(i);
            for (int j = 0; j < dim; ++j) {
                fcdPrecision[i][j] = tipPrec.get(i, j);
            }
        }
    }

    /**
     * Computes the distribution of the tip's latent coordinates given its current observed
     * coordinates.
     *
     * <p>Stores the result in {@code maskedMean} / {@code maskedPrecision}.
     */
    private void addMaskOnContiuousTraitsPrecisionSpace(int tipNumber) {
        // Take a single snapshot instead of re-copying all tip values for every dimension.
        final double[] allTraits = tipTraitParameter.getParameterValues();
        final double[] currentTrait = new double[dim];
        System.arraycopy(allTraits, tipNumber * dim, currentTrait, 0, dim);

        ConditionalPrecisionAndTransform transform = new ConditionalPrecisionAndTransform(
                new Matrix(fcdPrecision),
                maskDelegate.getLatentIndices(tipNumber),
                maskDelegate.getObservedIndices(tipNumber));
        maskedPrecision = transform.getConditionalPrecision().toComponents();
        maskedMean = transform.getConditionalMean(currentTrait, 0, fcdMean, 0);
    }

    private static int[] convertListToArray(List<Integer> list) {
        int[] array = new int[list.size()];
        int k = 0;
        for (int value : list) {
            array[k++] = value;
        }
        return array;
    }

    /** Collects the external-node indices that have at least one non-observed dimension. */
    private int[] setupNeedSampling() {
        int tipCount = treeModel.getExternalNodeCount();
        List<Integer> tips = new ArrayList<>();
        for (int i = 0; i < tipCount; ++i) {
            if (maskDelegate.getObservedIndices(i).length != dim) {
                tips.add(i);
            }
        }
        return convertListToArray(tips);
    }

    @SuppressWarnings("unchecked") // the trait is registered with this element type by the delegate
    private TreeTrait<List<WrappedNormalSufficientStatistics>> castTreeTrait(TreeTrait treeTrait) {
        return (TreeTrait<List<WrappedNormalSufficientStatistics>>) treeTrait;
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return NEW_LATENT_LIABILITY_GIBBS_OPERATOR;
    }

    /**
     * Resolves the latent and observed trait indices of each tip from the mask.
     *
     * <ul>
     *   <li>No mask: every dimension is latent.</li>
     *   <li>Column mode: one shared partition, where 1 means latent and anything else observed.</li>
     *   <li>Per-tip mode: the partition is read from entries {@code dim * tip .. dim * tip + dim - 1},
     *       where 1 means latent and 0 means observed.</li>
     * </ul>
     */
    private class MaskIndicesDelegate {
        private final int[] latentColumns;
        private final int[] observedColumns;

        private MaskIndicesDelegate() {
            if (mask == null) {
                // No mask: sample everything (previously observedColumns stayed null -> NPE).
                int[] all = new int[dim];
                for (int i = 0; i < dim; ++i) {
                    all[i] = i;
                }
                latentColumns = all;
                observedColumns = new int[0];
            } else if (missingByColumn) {
                List<Integer> latent = new ArrayList<>();
                List<Integer> observed = new ArrayList<>();
                for (int i = 0; i < dim; ++i) {
                    if (mask.getParameterValue(i) == 1.0) {
                        latent.add(i);
                    } else {
                        observed.add(i);
                    }
                }
                latentColumns = convertListToArray(latent);
                observedColumns = convertListToArray(observed);
            } else {
                // Per-tip mode: indices are computed on demand.
                latentColumns = null;
                observedColumns = null;
            }
        }

        private int[] getLatentIndices(NodeRef node) {
            return getLatentIndices(node.getNumber());
        }

        private int[] getLatentIndices(int tip) {
            if (mask == null || missingByColumn) {
                return latentColumns;
            }
            return indicesWithMaskValue(tip, 1.0);
        }

        private int[] getObservedIndices(int tip) {
            if (mask == null || missingByColumn) {
                return observedColumns;
            }
            return indicesWithMaskValue(tip, 0.0);
        }

        private int[] getObservedIndices(NodeRef node) {
            return getObservedIndices(node.getNumber());
        }

        /** Returns the dimensions d of {@code tip} whose per-tip mask entry equals {@code value}. */
        private int[] indicesWithMaskValue(int tip, double value) {
            int offset = dim * tip;
            List<Integer> indices = new ArrayList<>();
            for (int d = 0; d < dim; ++d) {
                if (mask.getParameterValue(offset + d) == value) {
                    indices.add(d);
                }
            }
            return convertListToArray(indices);
        }
    }

    protected class MaskIndices {
        final int[] discreteIndices;
        final int[] continuousIndex;

        private MaskIndices(int[] discreteIndices, int[] continuousIndex) {
            this.discreteIndices = discreteIndices;
            this.continuousIndex = continuousIndex;
        }
    }
}

