/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.antigenic;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.GammaFunction;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.StringAttributeRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;

public class NPAntigenicLikelihood
extends AbstractModelLikelihood {
    /** Name passed to the superclass constructor and reported by {@link #PARSER} as its parser name. */
    public static final String NP_ANTIGENIC_LIKELIHOOD = "NPAntigenicLikelihood";
    // External nodes of the tree below the root, collected once in the constructor.
    Set<NodeRef> allTips;
    // Data: getData(n, n2) reads component n2 of sub-parameter n; numdata is its sub-parameter count.
    CompoundParameter traitParameter;
    // Constructor argument that the parser fills from the "chi" element; computeLogLikelihood uses
    // its first value as the weight of an item linking to itself.
    Parameter alpha;
    // clusterPrec / priorPrec / priorMean are only read in the constructor, where their values are
    // folded into k0, T0Inv, logDetT0 and m. They are not registered as variables of this model.
    Parameter clusterPrec;
    Parameter priorPrec;
    Parameter priorMean;
    // Cluster index of every data item (values are cast to int when read).
    Parameter assignments;
    // For every data item, the index of the item it links to; a value equal to its own index is a self link.
    Parameter links;
    // Written by setMeans (component 1 and component 0 of the supplied pair, respectively).
    Parameter means2;
    Parameter means1;
    // Stored by the constructor; neither is read anywhere else in this class.
    Parameter locationDrift;
    Parameter offsets;
    // Always set to false by the constructor, regardless of the flag passed in.
    boolean hasDrift;
    // True while depMatrix/logDepMatrix reflect the current tree and transformFactor.
    private boolean depMatrixKnown = false;
    // Never assigned or read in this class.
    private boolean[] dataMatrixKnown;
    // True while logLikelihood holds the value of the current state.
    private boolean logLikelihoodKnown = false;
    private double logLikelihood = 0.0;
    // Per-cluster validity flags for logLikelihoodsVector; set by getLogLikGroup, cleared on trait changes.
    private boolean[] logLikelihoodsVectorKnown;
    // Set when transformFactor changes; consulted by restoreState and cleared by acceptState.
    boolean proposedChangeDepMatrix = false;
    // Only ever reset (in acceptState); nothing in this class sets it to true.
    boolean proposedChangeDataMatrix = false;
    TreeModel treeModel;
    // Never assigned in this class.
    String traitName;
    // Symmetric numdata x numdata matrix, indexed by node number. After setDepMatrix the off-diagonal
    // entries hold 1 / (summed branch length between the two tips)^transformFactor; the diagonal stays 0.
    double[][] depMatrix;
    // Element-wise natural log of the off-diagonal entries of depMatrix.
    double[][] logDepMatrix;
    // Cached per-cluster log-likelihoods (length links.getDimension() + 1) and the copy used by store/restore.
    double[] logLikelihoodsVector;
    double[] storedLogLikelihoodsVector;
    // Number of data items = traitParameter.getParameterCount().
    int numdata;
    // First value is the exponent applied to tree distances in logCorrectMatrix.
    Parameter transformFactor;
    // Hyper-parameters fixed in the constructor and used by getLogLikGroup:
    // k0 = priorPrec[0] / clusterPrec[0]
    double k0;
    // v0 = 2.0
    double v0;
    // 2x2 diagonal matrix with v0 / clusterPrec[0] on the diagonal
    double[][] T0Inv;
    // the first two values of priorMean
    double[] m;
    // -log(T0Inv[0][0] * T0Inv[1][1])
    double logDetT0;
    /**
     * XML parser that builds an {@link NPAntigenicLikelihood} from an element named
     * {@link #NP_ANTIGENIC_LIKELIHOOD}. Every model parameter is read from a named child element
     * wrapping a single {@link Parameter}; the trait data are read through
     * {@link TreeTraitParserUtilities}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        public static final String CLUSTER_PREC = "clusterPrec";
        public static final String PRIOR_PREC = "priorPrec";
        public static final String PRIOR_MEAN = "priorMean";
        public static final String ASSIGNMENTS = "assignments";
        public static final String LINKS = "links";
        public static final String MEANS_1 = "clusterMeans1";
        public static final String MEANS_2 = "clusterMeans2";
        public static final String TRANSFORM_FACTOR = "transformFactor";
        public static final String CHI = "chi";
        public static final String OFFSETS = "offsets";
        public static final String LOCATION_DRIFT = "locationDrift";
        boolean integrate = false;
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                new StringAttributeRule("traitName", "The name of the trait for which a likelihood should be calculated"),
                new ElementRule("traitParameter", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(PRIOR_PREC, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(CLUSTER_PREC, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(PRIOR_MEAN, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(ASSIGNMENTS, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(LINKS, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(TRANSFORM_FACTOR, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(MEANS_1, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(MEANS_2, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(CHI, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(OFFSETS, Parameter.class),
                new ElementRule(LOCATION_DRIFT, Parameter.class),
                new ElementRule(TreeModel.class)
        };

        @Override
        public String getParserName() {
            return NPAntigenicLikelihood.NP_ANTIGENIC_LIKELIHOOD;
        }

        /** Returns the {@link Parameter} nested inside the child element of {@code parent} called {@code childName}. */
        private Parameter nestedParameter(XMLObject parent, String childName) {
            XMLObject wrapper = parent.getChild(childName);
            return (Parameter)wrapper.getChild(Parameter.class);
        }

        /**
         * Reads the tree, all named parameters and the trait data, then constructs the likelihood.
         * The drift flag handed to the constructor is true when the offsets parameter has more than
         * one dimension.
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            TreeModel tree = (TreeModel)xMLObject.getChild(TreeModel.class);

            // Children are looked up in the same order as before so a malformed document fails at the same element.
            Parameter clusterPrec = this.nestedParameter(xMLObject, CLUSTER_PREC);
            Parameter priorPrec = this.nestedParameter(xMLObject, PRIOR_PREC);
            Parameter priorMean = this.nestedParameter(xMLObject, PRIOR_MEAN);
            Parameter assignments = this.nestedParameter(xMLObject, ASSIGNMENTS);
            Parameter links = this.nestedParameter(xMLObject, LINKS);
            Parameter means2 = this.nestedParameter(xMLObject, MEANS_2);
            Parameter means1 = this.nestedParameter(xMLObject, MEANS_1);
            Parameter chi = this.nestedParameter(xMLObject, CHI);
            Parameter transformFactor = this.nestedParameter(xMLObject, TRANSFORM_FACTOR);
            Parameter locationDrift = this.nestedParameter(xMLObject, LOCATION_DRIFT);
            Parameter offsets = this.nestedParameter(xMLObject, OFFSETS);

            boolean hasDrift = offsets.getDimension() > 1;

            TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
            TreeTraitParserUtilities.TraitsAndMissingIndices parsedTraits =
                    utilities.parseTraitsFromTaxonAttributes(xMLObject, (Tree)tree, this.integrate);
            CompoundParameter traits = parsedTraits.traitParameter;

            return new NPAntigenicLikelihood(tree, traits, assignments, links, chi, clusterPrec, priorMean,
                    priorPrec, transformFactor, means1, means2, locationDrift, offsets, hasDrift);
        }

        @Override
        public String getParserDescription() {
            return "conditional likelihood ddCRP";
        }

        @Override
        public Class getReturnType() {
            return NPAntigenicLikelihood.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    public NPAntigenicLikelihood(TreeModel treeModel, CompoundParameter compoundParameter, Parameter parameter, Parameter parameter2, Parameter parameter3, Parameter parameter4, Parameter parameter5, Parameter parameter6, Parameter parameter7, Parameter parameter8, Parameter parameter9, Parameter parameter10, Parameter parameter11, Boolean bl) {
        super(NP_ANTIGENIC_LIKELIHOOD);
        this.assignments = parameter;
        this.links = parameter2;
        this.clusterPrec = parameter4;
        this.priorPrec = parameter6;
        this.priorMean = parameter5;
        this.treeModel = treeModel;
        this.traitParameter = compoundParameter;
        this.transformFactor = parameter7;
        this.means1 = parameter8;
        this.means2 = parameter9;
        this.alpha = parameter3;
        this.locationDrift = parameter10;
        this.offsets = parameter11;
        this.hasDrift = false;
        this.addVariable(compoundParameter);
        this.addVariable(parameter);
        this.addVariable(parameter2);
        this.addModel(treeModel);
        this.addVariable(parameter3);
        this.addVariable(parameter7);
        this.addVariable(this.alpha);
        this.addVariable(parameter11);
        this.numdata = compoundParameter.getParameterCount();
        this.allTips = TreeUtils.getExternalNodes(treeModel, treeModel.getRoot());
        this.setDepMatrix();
        for (int i = 0; i < this.numdata; ++i) {
            parameter.setParameterValue(i, i);
            parameter2.setParameterValue(i, i);
        }
        this.logLikelihoodsVector = new double[parameter2.getDimension() + 1];
        this.logLikelihoodsVectorKnown = new boolean[parameter2.getDimension() + 1];
        this.storedLogLikelihoodsVector = new double[parameter2.getDimension() + 1];
        this.m = new double[2];
        this.m[0] = parameter5.getParameterValue(0);
        this.m[1] = parameter5.getParameterValue(1);
        this.v0 = 2.0;
        this.k0 = parameter6.getParameterValue(0) / parameter4.getParameterValue(0);
        this.T0Inv = new double[2][2];
        this.T0Inv[0][0] = this.v0 / parameter4.getParameterValue(0);
        this.T0Inv[1][1] = this.v0 / parameter4.getParameterValue(0);
        this.T0Inv[1][0] = 0.0;
        this.T0Inv[0][1] = 0.0;
        this.logDetT0 = -Math.log(this.T0Inv[0][0] * this.T0Inv[1][1]);
    }

    /**
     * Rebuilds {@code depMatrix} and {@code logDepMatrix} from the current tree and transform factor.
     *
     * <p>{@code recursion} first accumulates the summed branch length between every pair of tips;
     * {@code logCorrectMatrix} then replaces each off-diagonal entry by
     * {@code 1 / distance^transformFactor}. {@code logDepMatrix} receives the natural log of every
     * off-diagonal entry, mirrored so that it is symmetric like {@code depMatrix}. Diagonal entries
     * of both matrices stay 0.
     */
    private void setDepMatrix() {
        this.depMatrixKnown = true;
        this.depMatrix = new double[this.numdata][this.numdata];

        List<NodeRef> tipsBelowRoot = new ArrayList<NodeRef>();
        this.recursion(this.treeModel.getRoot(), tipsBelowRoot);
        this.logCorrectMatrix(this.transformFactor.getParameterValue(0));

        this.logDepMatrix = new double[this.numdata][this.numdata];
        for (int i = 0; i < this.numdata; ++i) {
            for (int j = 0; j < i; ++j) {
                double logValue = Math.log(this.depMatrix[i][j]);
                this.logDepMatrix[i][j] = logValue;
                // Mirror into the upper triangle (previously a self-assignment left it at 0).
                this.logDepMatrix[j][i] = logValue;
            }
        }
    }

    /**
     * Computes the log marginal likelihood of the two-dimensional data assigned to cluster {@code n}
     * and marks that cluster's cache entry as known.
     *
     * <p>The value combines the prior quantities {@code k0}, {@code v0}, {@code m}, {@code T0Inv}
     * and {@code logDetT0} with the cluster size, its mean and its scatter matrix. An empty cluster
     * contributes 0.
     *
     * @param n cluster index, compared against the integer part of each entry of {@code assignments}
     * @return the log-likelihood contribution of cluster {@code n}
     */
    public double getLogLikGroup(int n) {
        final int itemCount = this.assignments.getDimension();

        int groupSize = 0;
        for (int i = 0; i < itemCount; ++i) {
            if ((int)this.assignments.getParameterValue(i) == n) {
                ++groupSize;
            }
        }

        double logLik = 0.0;
        if (groupSize != 0) {
            // Collect the members' coordinates and their mean.
            double[][] points = new double[groupSize][2];
            double[] mean = new double[2];
            int row = 0;
            for (int i = 0; i < itemCount; ++i) {
                if ((int)this.assignments.getParameterValue(i) != n) {
                    continue;
                }
                points[row][0] = this.getData(i, 0);
                // Second coordinate comes from component 1 (it used to re-read component 0).
                points[row][1] = this.getData(i, 1);
                mean[0] += points[row][0];
                mean[1] += points[row][1];
                ++row;
            }
            mean[0] /= (double)groupSize;
            mean[1] /= (double)groupSize;

            final double kn = this.k0 + (double)groupSize;
            final double vn = this.v0 + (double)groupSize;

            // Scatter matrix of the members around their mean.
            double[][] scatter = new double[2][2];
            for (int i = 0; i < groupSize; ++i) {
                double dx = points[i][0] - mean[0];
                double dy = points[i][1] - mean[1];
                scatter[0][0] += dx * dx;
                scatter[0][1] += dx * dy;
                scatter[1][0] += dx * dy;
                scatter[1][1] += dy * dy;
            }

            // Updated inverse scale matrix: prior + shrinkage of the mean towards m + scatter.
            final double shrink = (double)groupSize * (this.k0 / kn);
            final double mx = mean[0] - this.m[0];
            final double my = mean[1] - this.m[1];
            double[][] tnInv = new double[2][2];
            tnInv[0][0] = this.T0Inv[0][0] + shrink * mx * mx + scatter[0][0];
            tnInv[0][1] = this.T0Inv[0][1] + shrink * my * mx + scatter[0][1];
            tnInv[1][0] = this.T0Inv[1][0] + shrink * mx * my + scatter[1][0];
            tnInv[1][1] = this.T0Inv[1][1] + shrink * my * my + scatter[1][1];
            // Negative log-determinant of the 2x2 matrix tnInv.
            final double logDetTn = -Math.log(tnInv[0][0] * tnInv[1][1] - tnInv[0][1] * tnInv[1][0]);

            logLik += (double)(-groupSize) * Math.log(Math.PI);
            logLik += Math.log(this.k0) - Math.log(kn);
            logLik += vn / 2.0 * logDetTn - this.v0 / 2.0 * this.logDetT0;
            logLik += GammaFunction.lnGamma(vn / 2.0) + GammaFunction.lnGamma(vn / 2.0 - 0.5);
            logLik += -GammaFunction.lnGamma(this.v0 / 2.0) - GammaFunction.lnGamma(this.v0 / 2.0 - 0.5);
        }
        this.logLikelihoodsVectorKnown[n] = true;
        return logLik;
    }

    /**
     * Returns this object, which serves as its own model.
     */
    @Override
    public Model getModel() {
        final Model self = this;
        return self;
    }

    /**
     * Returns the per-cluster log-likelihood cache.
     *
     * @return the backing array itself, not a copy
     */
    public double[] getLogLikelihoodsVector() {
        return logLikelihoodsVector;
    }

    /**
     * @return the parameter holding, for each data item, the index of the item it links to
     */
    public Parameter getLinks() {
        return links;
    }

    /**
     * @return the parameter holding the cluster index of each data item
     */
    public Parameter getAssignments() {
        return assignments;
    }

    /**
     * Reads one coordinate of one data item from the trait parameter.
     *
     * @param n  index of the data item (sub-parameter of the trait parameter)
     * @param n2 index of the coordinate within that item
     * @return the stored value
     */
    public double getData(int n, int n2) {
        Parameter item = traitParameter.getParameter(n);
        return item.getParameterValue(n2);
    }

    /**
     * @return the current dependency matrix (the backing array, not a copy); it is replaced by a
     *         new array whenever the matrix is rebuilt
     */
    public double[][] getDepMatrix() {
        return depMatrix;
    }

    /**
     * @return the current log dependency matrix (the backing array, not a copy); it is replaced by
     *         a new array whenever the matrix is rebuilt
     */
    public double[][] getLogDepMatrix() {
        return logDepMatrix;
    }

    /**
     * @return the prior-mean parameter supplied at construction
     */
    public Parameter getPriorMean() {
        return priorMean;
    }

    /**
     * @return the prior-precision parameter supplied at construction
     */
    public Parameter getPriorPrec() {
        return priorPrec;
    }

    /**
     * @return the cluster-precision parameter supplied at construction
     */
    public Parameter getClusterPrec() {
        return clusterPrec;
    }

    /**
     * Overwrites one entry of the per-cluster log-likelihood cache. The matching "known" flag is
     * not touched.
     *
     * @param n cluster index
     * @param d value to store
     */
    public void setLogLikelihoodsVector(int n, double d) {
        double[] cache = logLikelihoodsVector;
        cache[n] = d;
    }

    /**
     * Sets the cluster assignment of one data item.
     *
     * @param n index of the data item
     * @param d cluster index to store
     */
    public void setAssingments(int n, double d) {
        Parameter target = assignments;
        target.setParameterValue(n, d);
    }

    /**
     * Sets the link of one data item.
     *
     * @param n index of the data item
     * @param d index of the item it should link to
     */
    public void setLinks(int n, double d) {
        Parameter target = links;
        target.setParameterValue(n, d);
    }

    /**
     * Stores a two-component mean for entry {@code n}: component 0 goes to {@code means1},
     * component 1 to {@code means2}.
     *
     * @param n      entry to update in both mean parameters
     * @param dArray array whose first two elements are the new values
     */
    public void setMeans(int n, double[] dArray) {
        final double first = dArray[0];
        means1.setParameterValue(n, first);
        final double second = dArray[1];
        means2.setParameterValue(n, second);
    }

    /**
     * Returns the log-likelihood, recomputing it only when the cached value has been invalidated.
     */
    @Override
    public double getLogLikelihood() {
        if (logLikelihoodKnown) {
            return logLikelihood;
        }
        logLikelihood = computeLogLikelihood();
        return logLikelihood;
    }

    /**
     * Computes the full log-likelihood and marks the cached total as known (the caller stores the
     * returned value).
     *
     * <p>The result is the sum of the per-cluster terms (recomputed only where flagged unknown)
     * plus, for every item, the log weight of its link: {@code alpha} for a self link and the
     * dependency-matrix entry otherwise. Each link term is normalised by {@code alpha} plus the
     * sum of that item's dependency-matrix column, excluding the diagonal.
     *
     * @return the log-likelihood of the current state
     */
    public double computeLogLikelihood() {
        if (!depMatrixKnown) {
            setDepMatrix();
        }

        double total = 0.0;

        // Per-cluster data terms.
        for (int group = 0; group < logLikelihoodsVector.length; ++group) {
            if (!logLikelihoodsVectorKnown[group]) {
                logLikelihoodsVector[group] = getLogLikGroup(group);
            }
            total += logLikelihoodsVector[group];
        }

        // Link terms.
        for (int item = 0; item < links.getDimension(); ++item) {
            double target = links.getParameterValue(item);
            if (target == (double)item) {
                total += Math.log(alpha.getParameterValue(0));
            } else {
                total += Math.log(depMatrix[item][(int)target]);
            }

            double columnSum = 0.0;
            for (int other = 0; other < numdata; ++other) {
                if (other != item) {
                    columnSum += depMatrix[other][item];
                }
            }
            total -= Math.log(alpha.getParameterValue(0) + columnSum);
        }

        logLikelihoodKnown = true;
        return total;
    }

    /**
     * Post-order traversal that accumulates tip-to-tip path lengths into {@code depMatrix}.
     *
     * <p>For an internal node, the branch leading to each of its first two children (indices 0 and
     * 1; further children are not visited) is added to the matrix entry of every pair of tips that
     * branch separates, i.e. one tip below the child and one tip anywhere else in the tree.
     * After the whole tree has been visited each off-diagonal entry holds the summed branch length
     * on the path between the two tips.
     *
     * @param nodeRef node to process
     * @param list    receives the tips found at or below {@code nodeRef}
     */
    void recursion(NodeRef nodeRef, List<NodeRef> list) {
        if (this.treeModel.isExternal(nodeRef)) {
            list.add(nodeRef);
            return;
        }

        NodeRef firstChild = this.treeModel.getChild(nodeRef, 0);
        NodeRef secondChild = this.treeModel.getChild(nodeRef, 1);

        List<NodeRef> tipsBelowFirst = new ArrayList<NodeRef>();
        List<NodeRef> tipsBelowSecond = new ArrayList<NodeRef>();
        this.recursion(firstChild, tipsBelowFirst);
        this.recursion(secondChild, tipsBelowSecond);

        this.addBranchToSeparatedPairs(tipsBelowFirst, this.treeModel.getBranchLength(firstChild));
        this.addBranchToSeparatedPairs(tipsBelowSecond, this.treeModel.getBranchLength(secondChild));

        list.addAll(tipsBelowFirst);
        list.addAll(tipsBelowSecond);
    }

    /** Adds {@code branchLength} to both mirrored entries of every (tip below the branch, tip elsewhere) pair. */
    private void addBranchToSeparatedPairs(List<NodeRef> tipsBelow, double branchLength) {
        Set<NodeRef> tipsElsewhere = new HashSet<NodeRef>(this.allTips);
        for (NodeRef tip : tipsBelow) {
            tipsElsewhere.remove(tip);
        }
        for (NodeRef inside : tipsBelow) {
            int insideIndex = inside.getNumber();
            for (NodeRef outside : tipsElsewhere) {
                int outsideIndex = outside.getNumber();
                this.depMatrix[outsideIndex][insideIndex] += branchLength;
                this.depMatrix[insideIndex][outsideIndex] += branchLength;
            }
        }
    }

    /**
     * Replaces every off-diagonal entry {@code x} of {@code depMatrix} by {@code 1 / x^d}, keeping
     * the matrix symmetric. Diagonal entries are left untouched.
     *
     * @param d exponent applied to each entry before inversion
     */
    void logCorrectMatrix(double d) {
        for (int row = 1; row < numdata; ++row) {
            for (int col = 0; col < row; ++col) {
                double transformed = 1.0 / Math.pow(depMatrix[row][col], d);
                depMatrix[row][col] = transformed;
                depMatrix[col][row] = transformed;
            }
        }
    }

    /**
     * Returns the summed branch length on the path between two tips, obtained by walking from each
     * tip up to their most recent common ancestor.
     *
     * @param n  index of the first external node
     * @param n2 index of the second external node
     * @return total branch length between the two tips
     */
    public double getTreeDist(int n, int n2) {
        final NodeRef ancestor = findMRCA(n, n2);
        double distance = 0.0;
        for (int tip : new int[]{n, n2}) {
            for (NodeRef node = treeModel.getExternalNode(tip); node != ancestor; node = treeModel.getParent(node)) {
                distance += treeModel.getBranchLength(node);
            }
        }
        return distance;
    }

    /**
     * Looks up the common ancestor node of the two taxa with the given indices.
     *
     * @param n  index of the first taxon
     * @param n2 index of the second taxon
     * @return the node returned by {@code TreeUtils.getCommonAncestorNode} for the two taxon ids
     */
    private NodeRef findMRCA(int n, int n2) {
        String firstId = treeModel.getTaxonId(n);
        String secondId = treeModel.getTaxonId(n2);
        HashSet<String> taxonIds = new HashSet<String>();
        taxonIds.add(firstId);
        taxonIds.add(secondId);
        return TreeUtils.getCommonAncestorNode(treeModel, taxonIds);
    }

    /**
     * Logs the leading {@code numdata x numdata} part of a matrix at INFO level on the
     * "dr.evomodel" logger, one row per line.
     *
     * @param dArray matrix to print; must have at least {@code numdata} rows and columns
     */
    public void printInformtion(double[][] dArray) {
        StringBuilder text = new StringBuilder("matrix \n");
        for (int row = 0; row < numdata; ++row) {
            text.append(" \n");
            for (int col = 0; col < numdata; ++col) {
                text.append(dArray[row][col]).append(" \t");
            }
        }
        Logger.getLogger("dr.evomodel").info(text.toString());
    }

    /**
     * Logs the first {@code numdata} values of a parameter at INFO level on the "dr.evomodel" logger.
     *
     * @param parameter parameter to print; must have at least {@code numdata} values
     */
    public void printInformation(Parameter parameter) {
        StringBuilder text = new StringBuilder("Vector \n");
        int index = 0;
        while (index < numdata) {
            text.append(parameter.getParameterValue(index)).append(" \t");
            ++index;
        }
        Logger.getLogger("dr.evomodel").info(text.toString());
    }

    /**
     * Logs the first {@code numdata} entries of an int array at INFO level on the "dr.evomodel" logger.
     *
     * @param nArray values to print; must have at least {@code numdata} entries
     */
    public void printInformation(int[] nArray) {
        StringBuilder text = new StringBuilder("Vector \n");
        int index = 0;
        while (index < numdata) {
            text.append(nArray[index]).append(" \t");
            ++index;
        }
        Logger.getLogger("dr.evomodel").info(text.toString());
    }

    /**
     * Logs the taxon ids for indices {@code 0 .. numdata-1}, one per line, at INFO level on the
     * "dr.evomodel" logger.
     */
    public void printOrder() {
        StringBuilder text = new StringBuilder("taxa \n");
        for (int taxon = 0; taxon < numdata; ++taxon) {
            text.append(" \n").append(treeModel.getTaxonId(taxon));
        }
        Logger.getLogger("dr.evomodel").info(text.toString());
    }

    /**
     * Logs a single number at INFO level on the "dr.evomodel" logger.
     *
     * @param d value to print
     */
    public void printInformation(double d) {
        String message = new StringBuilder("Info \n").append(d).toString();
        Logger.getLogger("dr.evomodel").info(message);
    }

    /**
     * Logs a text message at INFO level on the "dr.evomodel" logger.
     *
     * @param string text to print
     */
    public void printInformation(String string) {
        String message = new StringBuilder("Info \n").append(string).toString();
        Logger.getLogger("dr.evomodel").info(message);
    }

    /**
     * Logs two text values joined by " and " at INFO level on the "dr.evomodel" logger.
     *
     * @param string  first text
     * @param string2 second text
     */
    public void printInformation(String string, String string2) {
        StringBuilder text = new StringBuilder("Info \n");
        text.append(string).append(" and ").append(string2);
        Logger.getLogger("dr.evomodel").info(text.toString());
    }

    /**
     * Saves a copy of the per-cluster log-likelihood cache so that {@code restoreState} can bring
     * it back. Nothing else is stored.
     */
    @Override
    protected void storeState() {
        final double[] current = logLikelihoodsVector;
        final double[] backup = storedLogLikelihoodsVector;
        System.arraycopy(current, 0, backup, 0, current.length);
    }

    /**
     * Reverts to the stored state: swaps the current and stored per-cluster caches, marks the
     * dependency matrix as stale if a change to it was proposed since the last accept/restore,
     * and invalidates the cached total log-likelihood.
     */
    @Override
    protected void restoreState() {
        // Swap rather than copy; the old "current" array becomes the next store target.
        final double[] rejected = logLikelihoodsVector;
        logLikelihoodsVector = storedLogLikelihoodsVector;
        storedLogLikelihoodsVector = rejected;

        final boolean matrixWasProposed = proposedChangeDepMatrix;
        depMatrixKnown = !matrixWasProposed;
        proposedChangeDepMatrix = false;

        logLikelihoodKnown = false;
    }

    /**
     * Invalidates every cached quantity so that the next call to {@code getLogLikelihood}
     * recomputes the dependency matrix, all per-cluster terms and the total from scratch.
     * (The body used to be empty, so a forced recalculation silently returned cached values.)
     */
    @Override
    public void makeDirty() {
        this.logLikelihoodKnown = false;
        this.depMatrixKnown = false;
        // Guard against being called before the constructor has allocated the flag array.
        if (this.logLikelihoodsVectorKnown != null) {
            for (int i = 0; i < this.logLikelihoodsVectorKnown.length; ++i) {
                this.logLikelihoodsVectorKnown[i] = false;
            }
        }
    }

    /**
     * Called when the proposed state is accepted: clears both "change proposed" flags.
     */
    @Override
    public void acceptState() {
        proposedChangeDataMatrix = false;
        proposedChangeDepMatrix = false;
    }

    /**
     * Reacts to a change in a registered sub-model. Any change invalidates the cached total; a
     * change in the tree additionally invalidates the dependency matrix.
     *
     * <p>The tree change is also recorded in {@code proposedChangeDepMatrix}. Without that,
     * {@code restoreState} would flag a matrix rebuilt from a rejected tree as valid, because it
     * derives {@code depMatrixKnown} from that flag.
     *
     * @param model  the model that changed
     * @param object not used
     * @param n      not used
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == this.treeModel) {
            this.depMatrixKnown = false;
            this.proposedChangeDepMatrix = true;
        }
        this.logLikelihoodKnown = false;
    }

    /**
     * Reacts to a change in a registered variable. Every change invalidates the cached total.
     * A change of the transform factor also invalidates the dependency matrix and records that a
     * matrix change was proposed. A change of the trait data invalidates the per-cluster term of
     * the cluster that item {@code n / 2} is assigned to.
     *
     * @param variable   the variable that changed
     * @param n          index of the changed value; for trait data it is halved to get the item
     *                   index (presumably because each item has two coordinates; confirm)
     * @param changeType not used
     */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        logLikelihoodKnown = false;

        if (variable == transformFactor) {
            depMatrixKnown = false;
            proposedChangeDepMatrix = true;
        }

        if (variable != traitParameter) {
            return;
        }
        int item = n / 2;
        int cluster = (int)assignments.getParameterValue(item);
        logLikelihoodsVectorKnown[cluster] = false;
    }
}

