/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.tree;

import dr.evolution.MetagenomeData;
import dr.evolution.alignment.Alignment;
import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.AminoAcids;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.MathUtils;
import dr.oldevomodel.substmodel.SubstitutionModel;
import dr.oldevomodel.treelikelihood.GeneralLikelihoodCore;
import dr.oldevomodel.treelikelihood.LikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeAminoAcidLikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeNucleotideLikelihoodCore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

public class HiddenLinkageModel
extends TipStatesModel
implements PatternList {
    int linkageGroupCount = 0;
    ArrayList<HashSet<Taxon>> groups = null;
    MetagenomeData data = null;
    ArrayList<Taxon> alignmentTaxa;
    double[][] tipPartials;
    double[][] storedTipPartials;
    boolean[] dirtyTipPartials;
    LikelihoodCore core;
    double blen = 0.001;
    SubstitutionModel substitutionModel;
    double[] tipMatrix;
    double[] internalMatrix;
    ArrayList<Move> movesMade = new ArrayList();
    int[] nodeIdToMyTaxaMap;

    public HiddenLinkageModel(int n, MetagenomeData metagenomeData) {
        super("HiddenLinkageModel", metagenomeData.getReferenceTaxa(), metagenomeData.getReadsTaxa());
        int n2;
        int n3;
        this.linkageGroupCount = n;
        this.data = metagenomeData;
        this.groups = new ArrayList(n);
        for (int i = 0; i < n; ++i) {
            this.groups.add(new HashSet());
        }
        TaxonList taxonList = metagenomeData.getReadsTaxa();
        for (n3 = 0; n3 < taxonList.getTaxonCount(); ++n3) {
            n2 = MathUtils.nextInt(n);
            this.groups.get(n2).add(taxonList.getTaxon(n3));
        }
        this.alignmentTaxa = new ArrayList<Taxon>(metagenomeData.getReferenceTaxa().asList());
        for (n3 = 0; n3 < n; ++n3) {
            this.alignmentTaxa.add(new Taxon("LinkageGroup_" + n3));
        }
        n3 = metagenomeData.getAlignment().getSiteCount() * metagenomeData.getAlignment().getStateCount();
        this.tipPartials = new double[this.alignmentTaxa.size()][n3];
        this.storedTipPartials = new double[this.alignmentTaxa.size()][n3];
        this.dirtyTipPartials = new boolean[this.alignmentTaxa.size()];
        this.initCore();
        this.setupMatrices();
        for (n2 = 0; n2 < this.tipPartials.length; ++n2) {
            this.computeTipPartials(n2);
        }
    }

    @Override
    public boolean areUnique() {
        return false;
    }

    @Override
    public boolean areUncertain() {
        return false;
    }

    private void initCore() {
        int n;
        if (this.data.getAlignment().getDataType() instanceof Nucleotides) {
            this.core = new NativeNucleotideLikelihoodCore();
        }
        if (this.data.getAlignment().getDataType() instanceof AminoAcids) {
            this.core = new NativeAminoAcidLikelihoodCore();
        }
        if (this.data.getAlignment().getDataType() instanceof Codons) {
            this.core = new GeneralLikelihoodCore(this.data.getAlignment().getStateCount());
        }
        this.core.initialize(this.data.getReadsTaxa().getTaxonCount() * 2, this.data.getAlignment().getSiteCount(), 1, false);
        for (n = 0; n < this.data.getReadsTaxa().getTaxonCount(); ++n) {
            int n2 = this.data.getAlignment().getTaxonIndex(this.data.getReadsTaxa().getTaxon(n));
            int[] nArray = new int[this.data.getAlignment().getSiteCount()];
            for (int i = 0; i < nArray.length; ++i) {
                nArray[i] = this.data.getAlignment().getState(n2, i);
            }
            this.core.setNodeStates(n, nArray);
        }
        for (n = 0; n < this.data.getReadsTaxa().getTaxonCount(); ++n) {
            this.core.createNodePartials(n + this.data.getReadsTaxa().getTaxonCount());
        }
    }

    private void setupMatrices() {
        int n;
        this.tipMatrix = new double[this.data.getAlignment().getStateCount() * this.data.getAlignment().getStateCount()];
        this.internalMatrix = new double[this.data.getAlignment().getStateCount() * this.data.getAlignment().getStateCount()];
        double d = 1.0 - this.blen;
        double d2 = this.blen / (double)(this.data.getAlignment().getStateCount() - 1);
        double d3 = 0.99999999999999;
        double d4 = (1.0 - d3) / (double)(this.data.getAlignment().getStateCount() - 1);
        for (n = 0; n < this.tipMatrix.length; ++n) {
            this.tipMatrix[n] = d2;
            this.internalMatrix[n] = d4;
        }
        for (n = 0; n < this.data.getAlignment().getStateCount(); ++n) {
            this.tipMatrix[n * this.data.getAlignment().getStateCount() + n] = d;
            this.internalMatrix[n * this.data.getAlignment().getStateCount() + n] = d3;
        }
        for (n = 0; n < this.data.getReadsTaxa().getTaxonCount(); ++n) {
            this.core.setNodeMatrix(n, 0, this.tipMatrix);
        }
        for (n = 0; n < this.data.getReadsTaxa().getTaxonCount(); ++n) {
            this.core.setNodeMatrix(n + this.data.getReadsTaxa().getTaxonCount(), 0, this.internalMatrix);
        }
    }

    public int getLinkageGroupCount() {
        return this.linkageGroupCount;
    }

    public MetagenomeData getData() {
        return this.data;
    }

    public int getLinkageGroupId(Taxon taxon) {
        int n = 0;
        for (HashSet<Taxon> hashSet : this.groups) {
            if (hashSet.contains(taxon)) break;
            ++n;
        }
        return n;
    }

    @Override
    protected void acceptState() {
        this.movesMade.clear();
        for (int i = 0; i < this.dirtyTipPartials.length; ++i) {
            this.dirtyTipPartials[i] = false;
        }
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
    }

    @Override
    protected void restoreState() {
        int n;
        for (n = this.movesMade.size(); n > 0; --n) {
            Move move = this.movesMade.get(n - 1);
            this.groups.get(move.toGroup).remove(move.read);
            this.groups.get(move.fromGroup).add(move.read);
        }
        this.movesMade.clear();
        for (n = 0; n < this.dirtyTipPartials.length; ++n) {
            if (!this.dirtyTipPartials[n]) continue;
            this.swapTipPartials(n);
            this.dirtyTipPartials[n] = false;
        }
    }

    @Override
    protected void storeState() {
        this.movesMade.clear();
        for (int i = 0; i < this.dirtyTipPartials.length; ++i) {
            this.dirtyTipPartials[i] = false;
        }
    }

    public Set<Taxon> getGroup(int n) {
        return this.groups.get(n);
    }

    public void moveReadGroup(Taxon taxon, int n, int n2) {
        boolean bl = this.groups.get(n).remove(taxon);
        if (!bl) {
            throw new RuntimeException("Error, could not find read " + taxon + " in linkage group " + n);
        }
        this.groups.get(n2).add(taxon);
        this.movesMade.add(new Move(taxon, n, n2));
        this.computeTipPartials(this.data.getReferenceTaxa().getTaxonCount() + n);
        this.computeTipPartials(this.data.getReferenceTaxa().getTaxonCount() + n2);
        this.fireModelChanged(this.alignmentTaxa.get(this.alignmentTaxa.size() - this.groups.size() + n));
        this.fireModelChanged(this.alignmentTaxa.get(this.alignmentTaxa.size() - this.groups.size() + n2));
    }

    private void swapTipPartials(int n) {
        double[] dArray = this.storedTipPartials[n];
        this.storedTipPartials[n] = this.tipPartials[n];
        this.tipPartials[n] = dArray;
    }

    private void computeTipPartials(int n) {
        int n2;
        if (!this.dirtyTipPartials[n]) {
            this.swapTipPartials(n);
            this.dirtyTipPartials[n] = true;
        }
        double[] dArray = this.tipPartials[n];
        Alignment alignment = this.data.getAlignment();
        int n3 = alignment.getStateCount();
        for (n2 = 0; n2 < dArray.length; ++n2) {
            dArray[n2] = 0.0;
        }
        if (n < this.data.getReferenceTaxa().getTaxonCount()) {
            n2 = 0;
            for (int i = 0; i < alignment.getSiteCount(); ++i) {
                int n4 = alignment.getState(n, i);
                if (n4 >= n3) {
                    for (int j = 0; j < n3; ++j) {
                        dArray[n2 + j] = 1.0;
                    }
                } else {
                    dArray[n2 + n4] = 1.0;
                }
                n2 += n3;
            }
        } else {
            n2 = n - this.data.getReferenceTaxa().getTaxonCount();
            HashSet<Taxon> hashSet = this.groups.get(n2);
            int n5 = this.data.getReadsTaxa().getTaxonCount();
            Taxon taxon = null;
            boolean bl = false;
            for (Taxon taxon2 : hashSet) {
                if (taxon == null) {
                    taxon = taxon2;
                    continue;
                }
                int n6 = this.data.getReadsTaxa().getTaxonIndex(taxon2);
                if (!bl) {
                    int n7 = this.data.getReadsTaxa().getTaxonIndex(taxon);
                    this.core.setNodePartialsForUpdate(n5);
                    this.core.calculatePartials(n7, n6, n5);
                } else {
                    this.core.setNodePartialsForUpdate(n5);
                    this.core.calculatePartials(n5 - 1, n6, n5);
                }
                ++n5;
                bl = true;
            }
            if (hashSet.size() == 0) {
                for (int i = 0; i < dArray.length; ++i) {
                    dArray[i] = 1.0;
                }
            } else if (!bl) {
                this.getPartialsForGroupSizeOne(taxon, dArray);
            } else {
                this.core.getPartials(n5 - 1, dArray);
            }
        }
    }

    private void getPartialsForGroupSizeOne(Taxon taxon, double[] dArray) {
        Alignment alignment = this.data.getAlignment();
        int n = alignment.getStateCount();
        int n2 = alignment.getTaxonIndex(taxon);
        int n3 = 0;
        for (int i = 0; i < alignment.getSiteCount(); ++i) {
            int n4 = alignment.getState(n2, i);
            if (n4 >= n) {
                for (int j = 0; j < n; ++j) {
                    dArray[n3 + j] = 1.0;
                }
            } else {
                System.arraycopy(this.internalMatrix, n4 * n, dArray, n3, n);
            }
            n3 += n;
        }
    }

    public int newGroup() {
        throw new RuntimeException("Not implemented!");
    }

    public void deleteGroup() {
        throw new RuntimeException("Not implemented!");
    }

    @Override
    public TipStatesModel.Type getModelType() {
        return TipStatesModel.Type.PARTIALS;
    }

    @Override
    public void getTipStates(int n, int[] nArray) {
        throw new IllegalArgumentException("This model emits only tip partials");
    }

    @Override
    public void getTipPartials(int n, double[] dArray) {
        int n2 = this.nodeIdToMyTaxaMap[this.tree.getNode(n).getNumber()];
        System.arraycopy(this.tipPartials[n2], 0, dArray, 0, dArray.length);
    }

    @Override
    protected void taxaChanged() {
        this.nodeIdToMyTaxaMap = new int[this.tree.getNodeCount()];
        block0: for (int i = 0; i < this.nodeIdToMyTaxaMap.length; ++i) {
            for (int j = 0; j < this.alignmentTaxa.size(); ++j) {
                if (this.tree.getTaxon(i) == null) continue;
                if (this.tree.getTaxon(i) == null || this.alignmentTaxa.get(j) == null) {
                    System.err.print("asdgasdg\n");
                } else if (this.tree.getTaxon(i).getId() == null || this.alignmentTaxa.get(j).getId() == null) {
                    System.err.print("asdgasdg\n");
                }
                if (!this.tree.getTaxon(i).getId().equalsIgnoreCase(this.alignmentTaxa.get(j).getId())) continue;
                this.nodeIdToMyTaxaMap[this.tree.getExternalNode((int)i).getNumber()] = j;
                continue block0;
            }
        }
    }

    @Override
    public DataType getDataType() {
        return this.data.getAlignment().getDataType();
    }

    @Override
    public int[] getPattern(int n) {
        return this.data.getAlignment().getPattern(n);
    }

    @Override
    public double[][] getUncertainPattern(int n) {
        return new double[0][];
    }

    @Override
    public int getPatternCount() {
        return this.data.getAlignment().getPatternCount();
    }

    @Override
    public int getPatternLength() {
        return this.data.getAlignment().getPatternLength();
    }

    @Override
    public int getPatternState(int n, int n2) {
        if (n < this.data.getReferenceTaxa().getTaxonCount()) {
            return this.data.getAlignment().getPatternState(n, n2);
        }
        return 0;
    }

    @Override
    public double[] getUncertainPatternState(int n, int n2) {
        return new double[0];
    }

    @Override
    public double getPatternWeight(int n) {
        return this.data.getAlignment().getPatternWeight(n);
    }

    @Override
    public double[] getPatternWeights() {
        return this.data.getAlignment().getPatternWeights();
    }

    @Override
    public int getPatternIndex(int n) {
        return -1;
    }

    @Override
    public int getStateCount() {
        return this.data.getAlignment().getStateCount();
    }

    @Override
    public double[] getStateFrequencies() {
        return this.data.getAlignment().getStateFrequencies();
    }

    @Override
    public List<Taxon> asList() {
        return this.alignmentTaxa;
    }

    @Override
    public Taxon getTaxon(int n) {
        return this.alignmentTaxa.get(n);
    }

    @Override
    public Object getTaxonAttribute(int n, String string) {
        return this.alignmentTaxa.get(n).getAttribute(string);
    }

    @Override
    public int getTaxonCount() {
        return this.alignmentTaxa.size();
    }

    @Override
    public String getTaxonId(int n) {
        return this.alignmentTaxa.get(n).getId();
    }

    @Override
    public int getTaxonIndex(String string) {
        for (int i = 0; i < this.alignmentTaxa.size(); ++i) {
            if (!this.alignmentTaxa.get(i).getId().equals(string)) continue;
            return i;
        }
        return -1;
    }

    @Override
    public int getTaxonIndex(Taxon taxon) {
        for (int i = 0; i < this.alignmentTaxa.size(); ++i) {
            if (this.alignmentTaxa.get(i).compareTo(taxon) != 0) continue;
            return i;
        }
        return -1;
    }

    @Override
    public Iterator<Taxon> iterator() {
        return this.alignmentTaxa.iterator();
    }

    private class Move {
        Taxon read;
        int fromGroup;
        int toGroup;

        public Move(Taxon taxon, int n, int n2) {
            this.read = taxon;
            this.fromGroup = n;
            this.toGroup = n2;
        }
    }
}

