package com.aliasi.crf;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.features.Features;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.MarginalTagger;
import com.aliasi.tag.NBestTagger;
import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.TagLattice;
import com.aliasi.tag.Tagger;
import com.aliasi.tag.Tagging;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Iterators;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import com.sleepycat.asm.Opcodes;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:lib/lingpipe-4.1.0.jar:com/aliasi/crf/ChainCrf.class */
public class ChainCrf<E> implements Tagger<E>, NBestTagger<E>, MarginalTagger<E>, Serializable {
    static final long serialVersionUID = -4868542587460878290L;

    /** Tag names; position {@code k} in this list is tag index {@code k}. */
    private final List<String> mTagList;

    /** {@code mLegalTagStarts[k]} is true if tag {@code k} may label the first token. */
    private final boolean[] mLegalTagStarts;

    /** {@code mLegalTagEnds[k]} is true if tag {@code k} may label the last token. */
    private final boolean[] mLegalTagEnds;

    /** {@code mLegalTagTransitions[prev][next]} is true if tag {@code next} may follow tag {@code prev}. */
    private final boolean[][] mLegalTagTransitions;

    /** One coefficient vector per tag, all with the same number of dimensions. */
    private final Vector[] mCoefficients;

    /** Symbol table used when converting extracted features into vectors. */
    private final SymbolTable mFeatureSymbolTable;

    /** Extracts node and edge features from a token list. */
    private final ChainCrfFeatureExtractor<E> mFeatureExtractor;

    /** Passed through to feature-vector conversion to control the intercept feature. */
    private final boolean mAddInterceptFeature;

    /** Dimensionality of every coefficient (and feature) vector. */
    private final int mNumDimensions;

    /** Reserved feature name, presumably for the intercept (not referenced in this view). */
    static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";

    // Shared empty arrays. The dimension suffixes ([] / [][]) are required:
    // a plain new double[0] has type double[] and cannot initialize these.
    static final double[][] EMPTY_DOUBLE_2D_ARRAY = new double[0][];
    static final double[][][] EMPTY_DOUBLE_3D_ARRAY = new double[0][][];

    /**
     * Precomputed feature vectors for one input token list.
     */
    public static class FeatureVectors {
        /** Node feature vector for each token position. */
        final Vector[] mNodeFeatureVectors;

        /** Edge feature vectors indexed by {@code [position - 1][previous tag]}. */
        final Vector[][] mEdgeFeatureVectorss;

        FeatureVectors(Vector[] nodeVectors, Vector[][] edgeVectors) {
            mNodeFeatureVectors = nodeVectors;
            mEdgeFeatureVectorss = edgeVectors;
        }
    }

    /**
     * Immutable singly-linked list node describing the already-fixed suffix
     * of an n-best candidate: the tag at this step, the rest of the suffix,
     * and the accumulated suffix score.
     */
    public static class ForwardPointer {
        /** Tag index at this step of the suffix. */
        final int mK;

        /** Remaining suffix, or {@code null} at the end. */
        final ForwardPointer mPointer;

        /** Accumulated score of this suffix. */
        final double mScore;

        ForwardPointer(int tag, ForwardPointer next, double score) {
            mK = tag;
            mPointer = next;
            mScore = score;
        }
    }

    /**
     * Lazily enumerates scored taggings of a token list.
     *
     * <p>The constructor precomputes per-position transition scores and runs a
     * Viterbi pass that records, for each position and tag, the best prefix
     * score and the predecessor tag achieving it. Search states are held in a
     * {@link BoundedPriorityQueue} ordered by {@code ScoredObject.comparator()}.
     * Each {@link #bufferNext()} polls a state, completes it by following the
     * Viterbi back pointers, and offers the alternative predecessors that
     * branch off that completed path.
     */
    class NBestIterator extends Iterators.Buffered<ScoredTagging<E>> {
        /** Tokens being tagged. */
        final List<E> mTokens;

        /** Log normalizer subtracted from every score; {@code 0.0} when unnormalized. */
        final double mLogZ;

        /**
         * {@code [n - 1][prev][next]}: node score of {@code next} at {@code n}
         * plus the edge score from {@code prev}; negative infinity when illegal.
         */
        final double[][][] mTransitionScores;

        /** {@code [n][k]}: best score of any prefix ending in tag {@code k} at position {@code n}. */
        final double[][] mViterbiScores;

        /** {@code [n - 1][k]}: best predecessor of tag {@code k} at position {@code n}, or -1. */
        final int[][] mBackPointers;

        /** Pending search states. */
        final BoundedPriorityQueue<NBestState> mPriorityQueue;

        /**
         * Builds the lattice and seeds the queue with one state per final tag.
         *
         * @param list tokens to tag; must be non-empty (callers handle the empty case)
         * @param z if true, scores are normalized by the log partition function
         * @param i capacity of the bounded priority queue
         */
        NBestIterator(List<E> list, boolean z, int i) {
            this.mPriorityQueue = new BoundedPriorityQueue<>(ScoredObject.comparator(), i);
            this.mTokens = list;
            int size = list.size();
            int size2 = ChainCrf.this.mTagList.size();
            this.mTransitionScores = new double[size - 1][size2][size2];
            for (double[][] dArr : this.mTransitionScores) {
                for (double[] dArr2 : dArr) {
                    Arrays.fill(dArr2, Double.NEGATIVE_INFINITY);
                }
            }
            this.mViterbiScores = new double[size][size2];
            for (double[] dArr3 : this.mViterbiScores) {
                Arrays.fill(dArr3, Double.NEGATIVE_INFINITY);
            }
            this.mBackPointers = new int[size - 1][size2];
            for (int[] iArr : this.mBackPointers) {
                Arrays.fill(iArr, -1);
            }
            // Fill the transition scores; illegal starts/ends/transitions stay -inf.
            Vector[] vectorArr = new Vector[size2];
            ChainCrfFeatures<E> extract = ChainCrf.this.mFeatureExtractor.extract(list, ChainCrf.this.mTagList);
            for (int i2 = 1; i2 < size; i2++) {
                Vector nodeFeatures = ChainCrf.this.nodeFeatures(i2, extract);
                for (int i3 = 0; i3 < size2; i3++) {
                    // At position 1 the predecessor is the first token, so only legal start tags matter.
                    if (i2 != 1 || ChainCrf.this.mLegalTagStarts[i3]) {
                        vectorArr[i3] = ChainCrf.this.edgeFeatures(i2, i3, extract);
                    }
                }
                for (int i4 = 0; i4 < size2; i4++) {
                    if (i2 != size - 1 || ChainCrf.this.mLegalTagEnds[i4]) {
                        double dotProduct = nodeFeatures.dotProduct(ChainCrf.this.mCoefficients[i4]);
                        for (int i5 = 0; i5 < size2; i5++) {
                            if (ChainCrf.this.mLegalTagTransitions[i5][i4] && (i2 != 1 || ChainCrf.this.mLegalTagStarts[i5])) {
                                this.mTransitionScores[i2 - 1][i5][i4] = dotProduct + vectorArr[i5].dotProduct(ChainCrf.this.mCoefficients[i4]);
                            }
                        }
                    }
                }
            }
            // Viterbi initialization: node scores of legal start tags at position 0.
            Vector nodeFeatures2 = ChainCrf.this.nodeFeatures(0, extract);
            for (int i6 = 0; i6 < size2; i6++) {
                if (ChainCrf.this.mLegalTagStarts[i6]) {
                    this.mViterbiScores[0][i6] = nodeFeatures2.dotProduct(ChainCrf.this.mCoefficients[i6]);
                }
            }
            // Viterbi recursion with back pointers.
            for (int i7 = 1; i7 < size; i7++) {
                for (int i8 = 0; i8 < size2; i8++) {
                    if (i7 != size - 1 || ChainCrf.this.mLegalTagEnds[i8]) {
                        double d = Double.NEGATIVE_INFINITY;
                        int i9 = -1;
                        for (int i10 = 0; i10 < size2; i10++) {
                            if (ChainCrf.this.mLegalTagTransitions[i10][i8]) {
                                double d2 = this.mViterbiScores[i7 - 1][i10] + this.mTransitionScores[i7 - 1][i10][i8];
                                if (d2 > d) {
                                    d = d2;
                                    i9 = i10;
                                }
                            }
                        }
                        this.mViterbiScores[i7][i8] = d;
                        this.mBackPointers[i7 - 1][i8] = i9;
                    }
                }
            }
            // 0.0 means "no normalization" (a plain literal, not a borrowed library constant).
            this.mLogZ = z ? logZ() : 0.0;
            for (int i11 = 0; i11 < size2; i11++) {
                offer(this.mViterbiScores[size - 1][i11], null, size - 1, i11);
            }
        }

        /**
         * Computes the log partition function. Runs a log-space forward pass
         * over {@link #mTransitionScores}, starting from the position-0 scores,
         * and returns the log-sum-exp of the final column.
         */
        double logZ() {
            double[] dArr = (double[]) this.mViterbiScores[0].clone();
            int length = dArr.length;
            double[] dArr2 = new double[length];
            double[] dArr3 = new double[length];
            for (int i = 0; i < this.mTransitionScores.length; i++) {
                // Swap buffers: dArr2 holds the previous column, dArr receives the new one.
                double[] dArr4 = dArr2;
                dArr2 = dArr;
                dArr = dArr4;
                for (int i2 = 0; i2 < length; i2++) {
                    for (int i3 = 0; i3 < length; i3++) {
                        dArr3[i3] = dArr2[i3] + this.mTransitionScores[i][i3][i2];
                    }
                    dArr[i2] = Math.logSumOfExponentials(dArr3);
                }
            }
            return Math.logSumOfExponentials(dArr);
        }

        /**
         * Queues a search state unless its prefix score or its suffix score
         * is negative infinity (i.e. the state cannot be completed legally).
         */
        void offer(double d, ForwardPointer forwardPointer, int i, int i2) {
            if (d == Double.NEGATIVE_INFINITY) {
                return;
            }
            if (forwardPointer == null || forwardPointer.mScore != Double.NEGATIVE_INFINITY) {
                this.mPriorityQueue.offer(new NBestState(d, forwardPointer, i, i2));
            }
        }

        /**
         * Polls the next state and completes it along its Viterbi back pointers.
         * While walking back, it offers the alternative branches at each step.
         * Returns {@code null} once the queue is empty.
         */
        @Override // com.aliasi.util.Iterators.Buffered
        public ScoredTagging<E> bufferNext() {
            NBestState poll = this.mPriorityQueue.poll();
            if (poll == null) {
                return null;
            }
            int i = poll.mK;
            ForwardPointer forwardPointer = poll.mForwardPointer;
            for (int i2 = poll.mN - 1; i2 >= 0; i2--) {
                addAlternatives(i2, i, forwardPointer);
                int i3 = this.mBackPointers[i2][i];
                double d = this.mTransitionScores[i2][i3][i];
                if (forwardPointer != null) {
                    d += forwardPointer.mScore;
                }
                forwardPointer = new ForwardPointer(i, forwardPointer, d);
                i = i3;
            }
            return toScoredTagging(poll);
        }

        /**
         * Offers every predecessor of tag {@code i2} at position {@code i + 1}
         * other than its Viterbi back pointer, each extended by the suffix
         * {@code forwardPointer}.
         */
        void addAlternatives(int i, int i2, ForwardPointer forwardPointer) {
            int size = ChainCrf.this.mTagList.size();
            for (int i3 = 0; i3 < size; i3++) {
                if (i3 != this.mBackPointers[i][i2]) {
                    double d = this.mViterbiScores[i][i3];
                    double d2 = this.mTransitionScores[i][i3][i2];
                    if (forwardPointer != null) {
                        d2 += forwardPointer.mScore;
                    }
                    offer(d, new ForwardPointer(i2, forwardPointer, d2), i, i3);
                }
            }
        }

        /**
         * Rebuilds the full tag sequence for a state and wraps it with its score.
         * The sequence is the Viterbi prefix ending at {@code (mN, mK)} followed
         * by the forward-pointer suffix. The score is reduced by {@link #mLogZ}.
         */
        public ScoredTagging<E> toScoredTagging(NBestState nBestState) {
            List<String> tags = new ArrayList<>(this.mTokens.size());
            int tag = nBestState.mK;
            tags.add(ChainCrf.this.mTagList.get(tag));
            for (int n = nBestState.mN; n > 0; n--) {
                tag = this.mBackPointers[n - 1][tag];
                tags.add(ChainCrf.this.mTagList.get(tag));
            }
            Collections.reverse(tags);
            for (ForwardPointer fp = nBestState.mForwardPointer; fp != null; fp = fp.mPointer) {
                tags.add(ChainCrf.this.mTagList.get(fp.mK));
            }
            return new ScoredTagging<>(this.mTokens, tags, nBestState.score() - this.mLogZ);
        }
    }

    /**
     * A pending n-best search state: a Viterbi prefix ending in tag
     * {@code mK} at position {@code mN}, plus an optional fixed suffix.
     */
    public static class NBestState implements Scored {
        /** Score of the prefix ending at {@code (mN, mK)}. */
        final double mScore;

        /** Fixed suffix after position {@code mN}, or {@code null} if none. */
        final ForwardPointer mForwardPointer;

        /** Position of the prefix's last token. */
        final int mN;

        /** Tag index at position {@code mN}. */
        final int mK;

        NBestState(double prefixScore, ForwardPointer suffix, int position, int tag) {
            mScore = prefixScore;
            mForwardPointer = suffix;
            mN = position;
            mK = tag;
        }

        /** Returns the prefix score plus the suffix score, if there is a suffix. */
        @Override // com.aliasi.util.Scored
        public double score() {
            if (mForwardPointer == null) {
                return mScore;
            }
            return mScore + mForwardPointer.mScore;
        }
    }

    /**
     * Externalizable serialized form of a {@link ChainCrf}.
     *
     * <p>Wire format, in order:
     * <ol>
     * <li>tag count {@code n};</li>
     * <li>{@code n} tag names (UTF);</li>
     * <li>{@code n} legal-start flags;</li>
     * <li>{@code n} legal-end flags;</li>
     * <li>{@code n * n} legal-transition flags, row by row;</li>
     * <li>{@code n} coefficient vectors;</li>
     * <li>the feature symbol table;</li>
     * <li>the feature extractor;</li>
     * <li>the add-intercept flag.</li>
     * </ol>
     */
    static class Serializer<F> extends AbstractExternalizable {
        static final long serialVersionUID = -4140295941325870709L;
        final ChainCrf<F> mCrf;

        public Serializer(ChainCrf<F> chainCrf) {
            this.mCrf = chainCrf;
        }

        /** No-arg constructor, used when deserializing. */
        public Serializer() {
            this(null);
        }

        @Override // com.aliasi.util.AbstractExternalizable, java.io.Externalizable
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            ChainCrf<F> crf = this.mCrf;
            int numTags = crf.mTagList.size();
            objectOutput.writeInt(numTags);
            // Typed iteration; the decompiled Iterator<E> referenced a type
            // variable that is not in scope in this static nested class.
            for (String tag : crf.mTagList) {
                objectOutput.writeUTF(tag);
            }
            for (int k = 0; k < numTags; k++) {
                objectOutput.writeBoolean(crf.mLegalTagStarts[k]);
            }
            for (int k = 0; k < numTags; k++) {
                objectOutput.writeBoolean(crf.mLegalTagEnds[k]);
            }
            for (int prev = 0; prev < numTags; prev++) {
                for (int next = 0; next < numTags; next++) {
                    objectOutput.writeBoolean(crf.mLegalTagTransitions[prev][next]);
                }
            }
            for (Vector coefficient : crf.mCoefficients) {
                objectOutput.writeObject(coefficient);
            }
            objectOutput.writeObject(crf.mFeatureSymbolTable);
            objectOutput.writeObject(crf.mFeatureExtractor);
            objectOutput.writeBoolean(crf.mAddInterceptFeature);
        }

        @Override // com.aliasi.util.AbstractExternalizable
        public Object read(ObjectInput objectInput) throws ClassNotFoundException, IOException {
            int numTags = objectInput.readInt();
            String[] tags = new String[numTags];
            for (int k = 0; k < numTags; k++) {
                tags[k] = objectInput.readUTF();
            }
            boolean[] legalStarts = new boolean[numTags];
            for (int k = 0; k < numTags; k++) {
                legalStarts[k] = objectInput.readBoolean();
            }
            boolean[] legalEnds = new boolean[numTags];
            for (int k = 0; k < numTags; k++) {
                legalEnds[k] = objectInput.readBoolean();
            }
            boolean[][] legalTransitions = new boolean[numTags][numTags];
            for (int prev = 0; prev < numTags; prev++) {
                for (int next = 0; next < numTags; next++) {
                    legalTransitions[prev][next] = objectInput.readBoolean();
                }
            }
            Vector[] coefficients = new Vector[numTags];
            for (int k = 0; k < numTags; k++) {
                coefficients[k] = (Vector) objectInput.readObject();
            }
            // Read the trailing fields in exactly the order writeExternal wrote them.
            SymbolTable featureSymbolTable = (SymbolTable) objectInput.readObject();
            @SuppressWarnings("unchecked") // type argument erased in the stream; written from a ChainCrf<F>
            ChainCrfFeatureExtractor<F> featureExtractor = (ChainCrfFeatureExtractor<F>) objectInput.readObject();
            boolean addInterceptFeature = objectInput.readBoolean();
            return new ChainCrf<F>(tags, legalStarts, legalEnds, legalTransitions, coefficients, featureSymbolTable, featureExtractor, addInterceptFeature);
        }
    }

    /**
     * Constructs an unconstrained chain CRF. Every tag may start or end a
     * sequence, and every tag may follow every other tag.
     *
     * @param strArr tag names
     * @param vectorArr one coefficient vector per tag
     * @param symbolTable feature symbol table
     * @param chainCrfFeatureExtractor feature extractor
     * @param z whether to add the intercept feature
     */
    public ChainCrf(String[] strArr, Vector[] vectorArr, SymbolTable symbolTable, ChainCrfFeatureExtractor<E> chainCrfFeatureExtractor, boolean z) {
        this(strArr,
             trueArray(strArr.length),
             trueArray(strArr.length),
             trueArray(strArr.length, strArr.length),
             vectorArr,
             symbolTable,
             chainCrfFeatureExtractor,
             z);
    }

    /**
     * Constructs a chain CRF with explicit start, end and transition constraints.
     *
     * <p>All arrays are stored by reference, not copied.
     *
     * @param strArr tag names; must be non-empty
     * @param zArr {@code zArr[k]} is true if tag {@code k} may start a sequence
     * @param zArr2 {@code zArr2[k]} is true if tag {@code k} may end a sequence
     * @param zArr3 {@code zArr3[prev][next]} is true if {@code next} may follow {@code prev}
     * @param vectorArr one coefficient vector per tag, all of equal dimensionality
     * @param symbolTable feature symbol table
     * @param chainCrfFeatureExtractor feature extractor
     * @param z whether to add the intercept feature
     * @throws IllegalArgumentException if there are no tags or any array size is inconsistent
     */
    public ChainCrf(String[] strArr, boolean[] zArr, boolean[] zArr2, boolean[][] zArr3, Vector[] vectorArr, SymbolTable symbolTable, ChainCrfFeatureExtractor<E> chainCrfFeatureExtractor, boolean z) {
        // Without this check, an empty tag set failed later on vectorArr[0]
        // with an opaque ArrayIndexOutOfBoundsException.
        if (strArr.length < 1) {
            throw new IllegalArgumentException("Require at least one tag. Found tags.length=" + strArr.length);
        }
        if (strArr.length != vectorArr.length) {
            throw new IllegalArgumentException("Require tags and coefficients to be same length. Found tags.length=" + strArr.length + " coefficients.length=" + vectorArr.length);
        }
        if (strArr.length != zArr.length) {
            throw new IllegalArgumentException("Tags and starts must be same length. Found tags.length=" + strArr.length + " legalTagStarts.length=" + zArr.length);
        }
        if (strArr.length != zArr2.length) {
            // Previously reported the starts array here (copy-paste error).
            throw new IllegalArgumentException("Tags and ends must be same length. Found tags.length=" + strArr.length + " legalTagEnds.length=" + zArr2.length);
        }
        if (strArr.length != zArr3.length) {
            throw new IllegalArgumentException("Tags and transitions must be same length. Found tags.length=" + strArr.length + " legalTagTransitions.length=" + zArr3.length);
        }
        for (int i = 0; i < zArr3.length; i++) {
            if (strArr.length != zArr3[i].length) {
                throw new IllegalArgumentException("Tags and transition rows must be same length. Found tags.length=" + strArr.length + " legalTagTransitions[" + i + "].length=" + zArr3[i].length);
            }
        }
        for (int i2 = 1; i2 < vectorArr.length; i2++) {
            if (vectorArr[0].numDimensions() != vectorArr[i2].numDimensions()) {
                throw new IllegalArgumentException("All coefficients must be same length. Found coefficents[0].numDimensions()=" + vectorArr[0].numDimensions() + " coefficients[" + i2 + "].numDimensions()=" + vectorArr[i2].numDimensions());
            }
        }
        this.mTagList = Arrays.asList(strArr);
        this.mLegalTagStarts = zArr;
        this.mLegalTagEnds = zArr2;
        this.mLegalTagTransitions = zArr3;
        this.mCoefficients = vectorArr;
        this.mNumDimensions = vectorArr[0].numDimensions();
        this.mFeatureSymbolTable = symbolTable;
        this.mFeatureExtractor = chainCrfFeatureExtractor;
        this.mAddInterceptFeature = z;
    }

    /**
     * Returns a read-only view of the tag names, in tag-index order.
     *
     * @return unmodifiable tag list
     */
    public List<String> tags() {
        List<String> view = Collections.unmodifiableList(this.mTagList);
        return view;
    }

    /**
     * Returns the name of the tag with the given index.
     *
     * @param i tag index
     * @return tag name
     */
    public String tag(int i) {
        String name = this.mTagList.get(i);
        return name;
    }

    /**
     * Returns a fresh array of unmodifiable views of each tag's coefficient vector.
     *
     * @return per-tag coefficient views
     */
    public Vector[] coefficients() {
        Vector[] views = new Vector[this.mCoefficients.length];
        int k = 0;
        for (Vector coefficient : this.mCoefficients) {
            views[k++] = Matrices.unmodifiableVector(coefficient);
        }
        return views;
    }

    /**
     * Returns an unmodifiable view of the feature symbol table.
     *
     * @return read-only feature symbol table
     */
    public SymbolTable featureSymbolTable() {
        SymbolTable view = MapSymbolTable.unmodifiableView(this.mFeatureSymbolTable);
        return view;
    }

    /**
     * Returns the feature extractor (the instance itself, not a copy).
     *
     * @return feature extractor
     */
    public ChainCrfFeatureExtractor<E> featureExtractor() {
        ChainCrfFeatureExtractor<E> extractor = this.mFeatureExtractor;
        return extractor;
    }

    /**
     * Reports whether feature vectors are built with the intercept feature.
     *
     * @return the add-intercept flag
     */
    public boolean addInterceptFeature() {
        boolean flag = this.mAddInterceptFeature;
        return flag;
    }

    /**
     * Returns the highest-scoring legal tagging of the tokens, using Viterbi decoding.
     *
     * @param list tokens to tag
     * @return best tagging; an empty tagging for empty input
     * @throws IllegalArgumentException if the start/end/transition constraints
     *     admit no tag sequence of this length
     */
    @Override // com.aliasi.tag.Tagger
    public Tagging<E> tag(List<E> list) {
        int numTokens = list.size();
        if (numTokens == 0) {
            return new Tagging<E>(list, Collections.<String>emptyList());
        }
        int numTags = this.mTagList.size();
        // bestScores[n][k]: best score of a prefix ending in tag k at position n.
        double[][] bestScores = new double[numTokens][numTags];
        // backPointers[n - 1][k]: predecessor of tag k at position n on that best prefix.
        int[][] backPointers = new int[numTokens - 1][numTags];
        ChainCrfFeatures<E> features = this.mFeatureExtractor.extract(list, this.mTagList);

        Vector startVector = nodeFeatures(0, features);
        for (int k = 0; k < numTags; k++) {
            bestScores[0][k] = this.mLegalTagStarts[k]
                ? startVector.dotProduct(this.mCoefficients[k])
                : Double.NEGATIVE_INFINITY;
        }

        Vector[] edgeVectors = new Vector[numTags];
        for (int n = 1; n < numTokens; n++) {
            Vector nodeVector = nodeFeatures(n, features);
            for (int prev = 0; prev < numTags; prev++) {
                edgeVectors[prev] = edgeFeatures(n, prev, features);
            }
            for (int k = 0; k < numTags; k++) {
                if (n == numTokens - 1 && !this.mLegalTagEnds[k]) {
                    bestScores[n][k] = Double.NEGATIVE_INFINITY;
                    backPointers[n - 1][k] = -1;
                    continue;
                }
                double best = Double.NEGATIVE_INFINITY;
                int bestPrev = -1;
                double nodeScore = nodeVector.dotProduct(this.mCoefficients[k]);
                for (int prev = 0; prev < numTags; prev++) {
                    if (!this.mLegalTagTransitions[prev][k]) {
                        continue;
                    }
                    double score = nodeScore + edgeVectors[prev].dotProduct(this.mCoefficients[k]) + bestScores[n - 1][prev];
                    if (score > best) {
                        best = score;
                        bestPrev = prev;
                    }
                }
                bestScores[n][k] = best;
                backPointers[n - 1][k] = bestPrev;
            }
        }

        double bestFinalScore = Double.NEGATIVE_INFINITY;
        int bestFinalTag = -1;
        for (int k = 0; k < numTags; k++) {
            if (bestScores[numTokens - 1][k] > bestFinalScore) {
                bestFinalScore = bestScores[numTokens - 1][k];
                bestFinalTag = k;
            }
        }
        // All final scores are -inf only when the constraints admit no path;
        // previously this surfaced as an index exception from get(-1).
        if (bestFinalTag < 0) {
            throw new IllegalArgumentException("No legal tag sequence for input of length=" + numTokens);
        }

        List<String> tags = new ArrayList<>(numTokens);
        int tag = bestFinalTag;
        tags.add(this.mTagList.get(tag));
        for (int n = numTokens - 2; n >= 0; n--) {
            tag = backPointers[n][tag];
            tags.add(this.mTagList.get(tag));
        }
        Collections.reverse(tags);
        return new Tagging<E>(list, tags);
    }

    /**
     * Returns an iterator over scored taggings, using unnormalized scores.
     *
     * @param list tokens to tag
     * @param i capacity of the n-best search queue
     * @return scored taggings; a single empty tagging with score 0.0 for empty input
     */
    @Override // com.aliasi.tag.NBestTagger
    public Iterator<ScoredTagging<E>> tagNBest(List<E> list, int i) {
        if (list.size() == 0) {
            // Plain 0.0 literal; the decompiled code borrowed an unrelated Weka constant.
            ScoredTagging<E> empty = new ScoredTagging<E>(list, Collections.<String>emptyList(), 0.0);
            return Iterators.singleton(empty);
        }
        return new NBestIterator(list, false, i);
    }

    /**
     * Returns an iterator over scored taggings. Each score is normalized by
     * subtracting the log partition function.
     *
     * @param list tokens to tag
     * @param i capacity of the n-best search queue
     * @return scored taggings; a single empty tagging with score 0.0 for empty input
     */
    @Override // com.aliasi.tag.NBestTagger
    public Iterator<ScoredTagging<E>> tagNBestConditional(List<E> list, int i) {
        if (list.size() == 0) {
            // Plain 0.0 literal; the decompiled code borrowed an unrelated Weka constant.
            ScoredTagging<E> empty = new ScoredTagging<E>(list, Collections.<String>emptyList(), 0.0);
            return Iterators.singleton(empty);
        }
        return new NBestIterator(list, true, i);
    }

    /**
     * Returns the forward-backward lattice for the tokens.
     *
     * @param list tokens to tag
     * @return tag lattice; an empty lattice with log normalizer 0.0 for empty input
     */
    @Override // com.aliasi.tag.MarginalTagger
    public TagLattice<E> tagMarginal(List<E> list) {
        if (list.size() == 0) {
            // Plain 0.0 literal; the decompiled code borrowed an unrelated Weka constant.
            return new ForwardBackwardTagLattice<E>(list, this.mTagList, EMPTY_DOUBLE_2D_ARRAY, EMPTY_DOUBLE_2D_ARRAY, EMPTY_DOUBLE_3D_ARRAY, 0.0);
        }
        return forwardBackward(list, features(list));
    }

    /**
     * Returns a multi-line description with these parts:
     * <ul>
     * <li>the feature extractor;</li>
     * <li>the intercept flag;</li>
     * <li>the tags;</li>
     * <li>for each tag, its non-zero coefficients as {@code feature=value} pairs.</li>
     * </ul>
     */
    public String toString() {
        List<String> tagNames = tags();
        Vector[] weights = coefficients();
        SymbolTable symbols = featureSymbolTable();
        StringBuilder out = new StringBuilder();
        out.append("Feature Extractor=").append(featureExtractor()).append("\n");
        out.append("Add intercept=").append(addInterceptFeature()).append("\n");
        out.append("Tags=").append(tagNames).append("\n");
        out.append("Coefficients=\n");
        for (int k = 0; k < weights.length; k++) {
            out.append(tagNames.get(k)).append("  ");
            String separator = "";
            for (int dim : weights[k].nonZeroDimensions()) {
                out.append(separator);
                out.append(symbols.idToSymbol(dim)).append("=").append(weights[k].value(dim));
                separator = ", ";
            }
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Java serialization hook: serializes this CRF through its
     * {@code Serializer} proxy instead of directly.
     */
    Object writeReplace() {
        Serializer<E> proxy = new Serializer<E>(this);
        return proxy;
    }

    /**
     * Converts the node features at a position into a vector in this CRF's
     * feature space.
     *
     * @param i token position
     * @param chainCrfFeatures extracted features for the whole input
     * @return node feature vector
     */
    public Vector nodeFeatures(int i, ChainCrfFeatures<E> chainCrfFeatures) {
        Vector vector = Features.toVector(chainCrfFeatures.nodeFeatures(i),
                                          this.mFeatureSymbolTable,
                                          this.mNumDimensions,
                                          this.mAddInterceptFeature);
        return vector;
    }

    /**
     * Converts the edge features at a position for a given previous tag into
     * a vector in this CRF's feature space.
     *
     * @param i token position (at least 1 in the visible callers)
     * @param i2 index of the tag at position {@code i - 1}
     * @param chainCrfFeatures extracted features for the whole input
     * @return edge feature vector
     */
    public Vector edgeFeatures(int i, int i2, ChainCrfFeatures<E> chainCrfFeatures) {
        Vector vector = Features.toVector(chainCrfFeatures.edgeFeatures(i, i2),
                                          this.mFeatureSymbolTable,
                                          this.mNumDimensions,
                                          this.mAddInterceptFeature);
        return vector;
    }

    /**
     * Extracts and vectorizes all node and edge features for a token list.
     *
     * @param list tokens
     * @return node vectors per position, and edge vectors indexed by
     *     {@code [position - 1][previous tag]}; {@code null} for empty input
     */
    private FeatureVectors features(List<E> list) {
        int numTokens = list.size();
        if (numTokens == 0) {
            return null;
        }
        int numTags = this.mTagList.size();
        ChainCrfFeatures<E> extracted = this.mFeatureExtractor.extract(list, this.mTagList);
        Vector[] nodeVectors = new Vector[numTokens];
        for (int n = 0; n < numTokens; n++) {
            nodeVectors[n] = nodeFeatures(n, extracted);
        }
        Vector[][] edgeVectors = new Vector[numTokens - 1][numTags];
        for (int n = 1; n < numTokens; n++) {
            for (int prev = 0; prev < numTags; prev++) {
                edgeVectors[n - 1][prev] = edgeFeatures(n, prev, extracted);
            }
        }
        return new FeatureVectors(nodeVectors, edgeVectors);
    }

    /**
     * Runs the log-space forward-backward recurrences and packages them as a tag lattice.
     *
     * @param list tokens; assumed non-empty (visible callers check size before calling)
     * @param featureVectors precomputed node/edge vectors for {@code list}
     * @return lattice over forward scores, backward scores, transition scores,
     *     and log-sum-exp of the final forward column
     */
    TagLattice<E> forwardBackward(List<E> list, FeatureVectors featureVectors) {
        int size = list.size();
        int size2 = this.mTagList.size();
        // dArr[k]: position-0 node score for legal start tags, -inf otherwise.
        double[] dArr = new double[size2];
        for (int i = 0; i < size2; i++) {
            dArr[i] = this.mLegalTagStarts[i] ? featureVectors.mNodeFeatureVectors[0].dotProduct(this.mCoefficients[i]) : Double.NEGATIVE_INFINITY;
        }
        // dArr2[n - 1][prev][next]: edge score (prev -> next) plus node score of next at n.
        // Entries stay -inf for illegal transitions and for illegal end tags at the last position.
        double[][][] dArr2 = new double[size - 1][size2][size2];
        for (double[][] dArr3 : dArr2) {
            for (double[] dArr4 : dArr3) {
                Arrays.fill(dArr4, Double.NEGATIVE_INFINITY);
            }
        }
        for (int i2 = 1; i2 < size; i2++) {
            for (int i3 = 0; i3 < size2; i3++) {
                if (i2 != size - 1 || this.mLegalTagEnds[i3]) {
                    double dotProduct = featureVectors.mNodeFeatureVectors[i2].dotProduct(this.mCoefficients[i3]);
                    for (int i4 = 0; i4 < size2; i4++) {
                        if (this.mLegalTagTransitions[i4][i3]) {
                            dArr2[i2 - 1][i4][i3] = featureVectors.mEdgeFeatureVectorss[i2 - 1][i4].dotProduct(this.mCoefficients[i3]) + dotProduct;
                        }
                    }
                }
            }
        }
        // dArr5: scratch buffer reused by both passes for log-sum-exp inputs.
        double[] dArr5 = new double[size2];
        // Forward pass: dArr6[n][k] = logSumExp_prev(dArr6[n - 1][prev] + dArr2[n - 1][prev][k]).
        double[][] dArr6 = new double[size][size2];
        for (int i5 = 0; i5 < size2; i5++) {
            dArr6[0][i5] = dArr[i5];
        }
        for (int i6 = 1; i6 < size; i6++) {
            for (int i7 = 0; i7 < size2; i7++) {
                for (int i8 = 0; i8 < size2; i8++) {
                    dArr5[i8] = dArr6[i6 - 1][i8] + dArr2[i6 - 1][i8][i7];
                }
                dArr6[i6][i7] = Math.logSumOfExponentials(dArr5);
            }
        }
        // Backward pass: the last row keeps Java's default 0.0 (log of 1);
        // dArr7[n][k] = logSumExp_next(dArr7[n + 1][next] + dArr2[n][k][next]), for n from size - 2 down to 0.
        double[][] dArr7 = new double[size][size2];
        int i9 = size - 1;
        while (true) {
            i9--;
            if (i9 < 0) {
                // NOTE(review): meaning of the trailing `false` argument is defined in
                // ForwardBackwardTagLattice, not visible here — confirm there.
                return new ForwardBackwardTagLattice(list, this.mTagList, dArr6, dArr7, dArr2, Math.logSumOfExponentials(dArr6[size - 1]), false);
            }
            for (int i10 = 0; i10 < size2; i10++) {
                for (int i11 = 0; i11 < size2; i11++) {
                    dArr5[i11] = dArr7[i9 + 1][i11] + dArr2[i9][i10][i11];
                }
                dArr7[i9][i10] = Math.logSumOfExponentials(dArr5);
            }
        }
    }

    /**
     * Marks every tag that begins at least one non-empty tag-index sequence.
     *
     * @param iArr tag-index sequences; empty sequences are ignored
     * @param i number of tags
     * @return {@code result[k]} is true if some sequence starts with tag {@code k}
     */
    static boolean[] legalStarts(int[][] iArr, int i) {
        boolean[] starts = new boolean[i];
        for (int[] sequence : iArr) {
            if (sequence.length == 0) {
                continue;
            }
            starts[sequence[0]] = true;
        }
        return starts;
    }

    /**
     * Marks every tag that ends at least one non-empty tag-index sequence.
     *
     * @param iArr tag-index sequences; empty sequences are ignored
     * @param i number of tags
     * @return {@code result[k]} is true if some sequence ends with tag {@code k}
     */
    static boolean[] legalEnds(int[][] iArr, int i) {
        boolean[] ends = new boolean[i];
        for (int[] sequence : iArr) {
            int last = sequence.length - 1;
            if (last < 0) {
                continue;
            }
            ends[sequence[last]] = true;
        }
        return ends;
    }

    /**
     * Marks every adjacent tag pair observed in the given sequences.
     *
     * @param iArr tag-index sequences
     * @param i number of tags
     * @return {@code result[prev][next]} is true if {@code next} directly follows
     *     {@code prev} somewhere
     */
    static boolean[][] legalTransitions(int[][] iArr, int i) {
        boolean[][] transitions = new boolean[i][i];
        for (int[] sequence : iArr) {
            for (int pos = 0; pos + 1 < sequence.length; pos++) {
                transitions[sequence[pos]][sequence[pos + 1]] = true;
            }
        }
        return transitions;
    }

    /**
     * Returns a new array of the given length with every entry set to {@code true}.
     *
     * @param i array length
     * @return all-true array
     */
    static boolean[] trueArray(int i) {
        boolean[] result = new boolean[i];
        for (int k = 0; k < result.length; k++) {
            result[k] = true;
        }
        return result;
    }

    /**
     * Returns a new {@code i} by {@code i2} matrix with every entry set to {@code true}.
     *
     * @param i number of rows
     * @param i2 number of columns
     * @return all-true matrix
     */
    static boolean[][] trueArray(int i, int i2) {
        boolean[][] result = new boolean[i][i2];
        for (int row = 0; row < result.length; row++) {
            for (int col = 0; col < result[row].length; col++) {
                result[row][col] = true;
            }
        }
        return result;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v30, types: [int[], int[][]] */
    public static <F> ChainCrf<F> estimate(Corpus<ObjectHandler<Tagging<F>>> corpus, ChainCrfFeatureExtractor<F> chainCrfFeatureExtractor, boolean z, int i, boolean z2, boolean z3, RegressionPrior regressionPrior, int i2, AnnealingSchedule annealingSchedule, double d, int i3, int i4, Reporter reporter) throws IOException {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("ChainCrf.estimate Parameters");
        reporter.info("featureExtractor=" + chainCrfFeatureExtractor);
        reporter.info("addInterceptFeature=" + z);
        reporter.info("minFeatureCount=" + i);
        reporter.info("cacheFeatureVectors=" + z2);
        reporter.info("allowUnseenTransitions=" + z3);
        reporter.info("prior=" + regressionPrior);
        reporter.info("annealingSchedule=" + annealingSchedule);
        reporter.info("minImprovement=" + d);
        reporter.info("minEpochs=" + i3);
        reporter.info("maxEpochs=" + i4);
        reporter.info("priorBlockSize=" + i2);
        reporter.info("Computing corpus tokens and features");
        List corpusTokens = corpusTokens(corpus);
        String[][] corpusTags = corpusTags(corpus);
        int length = corpusTags.length;
        longestInput(corpusTags);
        long j = 0;
        for (String[] strArr : corpusTags) {
            j += strArr.length;
        }
        ?? r0 = new int[corpusTags.length];
        MapSymbolTable tagSymbolTable = tagSymbolTable(corpusTags, r0);
        MapSymbolTable featureSymbolTable = featureSymbolTable(corpusTags, corpusTokens, z, chainCrfFeatureExtractor, i);
        int numSymbols = tagSymbolTable.numSymbols();
        String[] strArr2 = new String[numSymbols];
        for (int i5 = 0; i5 < numSymbols; i5++) {
            strArr2[i5] = tagSymbolTable.idToSymbol(i5);
        }
        boolean[] trueArray = z3 ? trueArray(numSymbols) : legalStarts(r0, numSymbols);
        boolean[] trueArray2 = z3 ? trueArray(numSymbols) : legalEnds(r0, numSymbols);
        boolean[][] trueArray3 = z3 ? trueArray(numSymbols, numSymbols) : legalTransitions(r0, numSymbols);
        int numSymbols2 = featureSymbolTable.numSymbols();
        DenseVector[] denseVectorArr = new DenseVector[numSymbols];
        for (int i6 = 0; i6 < denseVectorArr.length; i6++) {
            denseVectorArr[i6] = new DenseVector(numSymbols2);
        }
        reporter.info("Corpus Statistics");
        reporter.info("Num Training Instances=" + length);
        reporter.info("Num Training Tokens=" + j);
        reporter.info("Num Dimensions After Pruning=" + numSymbols2);
        reporter.info("Tags=" + tagSymbolTable);
        ChainCrf<F> chainCrf = new ChainCrf<>(strArr2, trueArray, trueArray2, trueArray3, denseVectorArr, featureSymbolTable, chainCrfFeatureExtractor, z);
        FeatureVectors[] featureVectorsArr = z2 ? new FeatureVectors[length] : null;
        if (z2) {
            reporter.info("Caching Feature Vectors");
            for (int i7 = 0; i7 < length; i7++) {
                featureVectorsArr[i7] = chainCrf.features((List) corpusTokens.get(i7));
            }
        }
        double d2 = -8.988465674311579E307d;
        double d3 = 1.0d;
        double d4 = Double.NEGATIVE_INFINITY;
        long j2 = 0;
        long j3 = 0;
        long j4 = 0;
        long j5 = 0;
        long j6 = 0;
        int i8 = 0;
        while (true) {
            if (i8 >= i4) {
                break;
            }
            int i9 = 0;
            double learningRate = annealingSchedule.learningRate(i8);
            double d5 = learningRate / length;
            for (int i10 = 0; i10 < length; i10++) {
                Object[] objArr = r0[i10];
                List<F> list = (List) corpusTokens.get(i10);
                int size = list.size();
                if (size >= 1) {
                    long currentTimeMillis = System.currentTimeMillis();
                    FeatureVectors features = z2 ? featureVectorsArr[i10] : chainCrf.features(list);
                    long currentTimeMillis2 = System.currentTimeMillis();
                    j2 += currentTimeMillis2 - currentTimeMillis;
                    TagLattice<F> forwardBackward = chainCrf.forwardBackward(list, features);
                    long currentTimeMillis3 = System.currentTimeMillis();
                    j3 += currentTimeMillis3 - currentTimeMillis2;
                    for (int i11 = 0; i11 < size; i11++) {
                        denseVectorArr[objArr[i11]].increment(learningRate, features.mNodeFeatureVectors[i11]);
                    }
                    for (int i12 = 1; i12 < size; i12++) {
                        denseVectorArr[objArr[i12]].increment(learningRate, features.mEdgeFeatureVectorss[i12 - 1][objArr[i12 - 1]]);
                    }
                    for (int i13 = 0; i13 < size; i13++) {
                        for (int i14 = 0; i14 < numSymbols; i14++) {
                            double logProbability = forwardBackward.logProbability(i13, i14);
                            if (logProbability >= -400.0d) {
                                denseVectorArr[i14].increment((-Math.exp(logProbability)) * learningRate, features.mNodeFeatureVectors[i13]);
                            }
                        }
                    }
                    for (int i15 = 1; i15 < size; i15++) {
                        for (int i16 = 0; i16 < numSymbols; i16++) {
                            for (int i17 = 0; i17 < numSymbols; i17++) {
                                double logProbability2 = forwardBackward.logProbability(i15, i16, i17);
                                if (logProbability2 >= -400.0d) {
                                    denseVectorArr[i17].increment((-Math.exp(logProbability2)) * learningRate, features.mEdgeFeatureVectorss[i15 - 1][i16]);
                                }
                            }
                        }
                    }
                    long currentTimeMillis4 = System.currentTimeMillis();
                    j4 += currentTimeMillis4 - currentTimeMillis3;
                    i9++;
                    if (i9 == i2) {
                        adjustWeightsWithPrior(denseVectorArr, regressionPrior, i9 * d5);
                        i9 = 0;
                    }
                    j6 += System.currentTimeMillis() - currentTimeMillis4;
                }
            }
            long currentTimeMillis5 = System.currentTimeMillis();
            adjustWeightsWithPrior(denseVectorArr, regressionPrior, i9 * d5);
            j6 += System.currentTimeMillis() - currentTimeMillis5;
            long currentTimeMillis6 = System.currentTimeMillis();
            double d6 = 0.0d;
            for (int i18 = 0; i18 < length; i18++) {
                if (((List) corpusTokens.get(i18)).size() >= 1) {
                    d6 += chainCrf.forwardBackward((List) corpusTokens.get(i18), z2 ? featureVectorsArr[i18] : chainCrf.features((List) corpusTokens.get(i18))).logProbability(0, r0[i18]);
                }
            }
            double log2Prior = regressionPrior == null ? KStarConstants.FLOOR : regressionPrior.log2Prior(denseVectorArr);
            double d7 = d6 + log2Prior;
            d3 = ((9.0d * d3) + Math.relativeAbsoluteDifference(d2, d7)) / 10.0d;
            d2 = d7;
            if (d7 > d4) {
                d4 = d7;
            }
            j5 += System.currentTimeMillis() - currentTimeMillis6;
            if (reporter.isDebugEnabled()) {
                Formatter formatter = null;
                try {
                    try {
                        formatter = new Formatter(Locale.ENGLISH);
                        formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f", Integer.valueOf(i8), Double.valueOf(learningRate), Double.valueOf(d6), Double.valueOf(log2Prior), Double.valueOf(d7), Double.valueOf(d4));
                        reporter.debug(formatter.toString());
                        if (formatter != null) {
                            formatter.close();
                        }
                    } catch (IllegalFormatException e) {
                        reporter.warn("Illegal format in Logistic Regression");
                        if (formatter != null) {
                            formatter.close();
                        }
                    }
                } catch (Throwable th) {
                    if (formatter != null) {
                        formatter.close();
                    }
                    throw th;
                }
            }
            if (d3 < d) {
                reporter.info("Converged with rollingAverageRelativeDiff=" + d3);
                break;
            }
            i8++;
        }
        reporter.info("Feat Extraction Time=" + Strings.msToString(j2));
        reporter.info("Forward Backward Time=" + Strings.msToString(j3));
        reporter.info("Update Time=" + Strings.msToString(j4));
        reporter.info("Prior Update Time=" + Strings.msToString(j6));
        reporter.info("Loss Time=" + Strings.msToString(j5));
        return chainCrf;
    }

    /**
     * Applies one prior (regularization) step to every coefficient vector in place.
     *
     * <p>For each dimension whose value differs from the prior's mode, the value is moved by
     * {@code gradient(value, dim) * d}. The result is clipped at the mode so that a single
     * step never carries a coefficient past it.
     *
     * @param denseVectorArr coefficient vectors to adjust in place.
     * @param regressionPrior prior supplying mode and gradient per dimension; {@code null} or a
     *     uniform prior means no adjustment (the training loop treats a {@code null} prior as
     *     legal when computing the log prior).
     * @param d scale applied to the prior gradient for this step.
     */
    static void adjustWeightsWithPrior(DenseVector[] denseVectorArr, RegressionPrior regressionPrior, double d) {
        // A null prior is accepted by the caller's log-prior computation; treat it like a
        // uniform prior instead of throwing a NullPointerException.
        if (regressionPrior == null || regressionPrior.isUniform()) {
            return;
        }
        for (DenseVector weights : denseVectorArr) {
            int numDims = weights.numDimensions();
            for (int dim = 0; dim < numDims; dim++) {
                double value = weights.value(dim);
                double mode = regressionPrior.mode(dim);
                if (value == mode) {
                    continue;
                }
                double updated = value - regressionPrior.gradient(value, dim) * d;
                // Fully qualified: the file imports com.aliasi.util.Math, which shadows
                // java.lang.Math. Clamp so the step never crosses the mode.
                weights.setValue(dim, value > mode
                        ? java.lang.Math.max(mode, updated)
                        : java.lang.Math.min(mode, updated));
            }
        }
    }

    /**
     * Builds a symbol table over all tags in the corpus and records each tag's symbol id.
     *
     * @param strArr tag sequences, one array per training instance.
     * @param iArr output array; slot {@code k} receives the symbol ids of {@code strArr[k]}.
     * @return the symbol table containing every tag seen.
     */
    static MapSymbolTable tagSymbolTable(String[][] strArr, int[][] iArr) {
        MapSymbolTable symbols = new MapSymbolTable();
        int row = 0;
        for (String[] tags : strArr) {
            int[] ids = new int[tags.length];
            iArr[row++] = ids;
            for (int pos = 0; pos < ids.length; pos++) {
                ids[pos] = symbols.getOrAddSymbol(tags[pos]);
            }
        }
        return symbols;
    }

    /**
     * Builds the feature symbol table from the node and edge features extracted from the
     * training corpus.
     *
     * <p>Features are counted across all instances and then pruned with
     * {@code ObjectToCounterMap.prune(i)}, presumably dropping features seen fewer than
     * {@code i} times (confirm against the LingPipe docs). If {@code z} is set, the intercept
     * feature is added first, so it receives the first symbol id.
     *
     * @param strArr gold tag sequences, one per instance.
     * @param list token sequences, parallel to {@code strArr}.
     * @param z whether to add the intercept feature.
     * @param chainCrfFeatureExtractor extractor used to produce node/edge features.
     * @param i minimum-count threshold passed to {@code prune}.
     * @return the symbol table of surviving features.
     */
    static <F> MapSymbolTable featureSymbolTable(String[][] strArr, List<List<F>> list, boolean z, ChainCrfFeatureExtractor<F> chainCrfFeatureExtractor, int i) {
        // Typed counter: the previous Iterator<E> referenced the class's type parameter from a
        // static context, which does not compile.
        ObjectToCounterMap<String> featureCounts = new ObjectToCounterMap<String>();
        for (int seq = 0; seq < strArr.length; seq++) {
            String[] tags = strArr[seq];
            ChainCrfFeatures<F> features = chainCrfFeatureExtractor.extract(list.get(seq), Arrays.asList(tags));
            for (int pos = 0; pos < tags.length; pos++) {
                for (String feature : features.nodeFeatures(pos).keySet()) {
                    featureCounts.increment(feature);
                }
            }
            // Edge features at position pos use index pos - 1 into the gold tag list
            // passed to the extractor above.
            for (int pos = 1; pos < tags.length; pos++) {
                for (String feature : features.edgeFeatures(pos, pos - 1).keySet()) {
                    featureCounts.increment(feature);
                }
            }
        }
        featureCounts.prune(i);
        MapSymbolTable symbols = new MapSymbolTable();
        if (z) {
            symbols.getOrAddSymbol("*&^INTERCEPT%$^&**");
        }
        for (String feature : featureCounts.keySet()) {
            symbols.getOrAddSymbol(feature);
        }
        return symbols;
    }

    /**
     * Collects the token list of every training tagging in the corpus, in visit order.
     *
     * @param corpus corpus whose training section is visited.
     * @return one token list per training tagging.
     * @throws IOException if visiting the corpus fails.
     */
    static <F> List<List<F>> corpusTokens(Corpus<ObjectHandler<Tagging<F>>> corpus) throws IOException {
        // Typed list avoids the raw-type unchecked conversion on return.
        final List<List<F>> tokenLists = new ArrayList<List<F>>();
        corpus.visitTrain(new ObjectHandler<Tagging<F>>() {
            @Override
            public void handle(Tagging<F> tagging) {
                tokenLists.add(tagging.tokens());
            }
        });
        return tokenLists;
    }

    /**
     * Collects the tag sequence of every training tagging in the corpus, in visit order.
     *
     * @param corpus corpus whose training section is visited.
     * @return one tag array per training tagging.
     * @throws IOException if visiting the corpus fails.
     */
    static <F> String[][] corpusTags(Corpus<ObjectHandler<Tagging<F>>> corpus) throws IOException {
        // Initial capacity 1024 written as a literal; previously it was expressed via an
        // unrelated bytecode-library constant (Opcodes.ACC_ABSTRACT == 1024).
        final List<String[]> tagArrays = new ArrayList<String[]>(1024);
        corpus.visitTrain(new ObjectHandler<Tagging<F>>() {
            @Override
            public void handle(Tagging<F> tagging) {
                tagArrays.add(tagging.tags().toArray(Strings.EMPTY_STRING_ARRAY));
            }
        });
        return tagArrays.toArray(Strings.EMPTY_STRING_2D_ARRAY);
    }

    /**
     * Returns a new array holding an independent copy of each vector.
     *
     * @param denseVectorArr vectors to copy.
     * @return array of freshly constructed copies, in the same order.
     */
    static DenseVector[] copy(DenseVector[] denseVectorArr) {
        DenseVector[] result = new DenseVector[denseVectorArr.length];
        int k = 0;
        for (DenseVector source : denseVectorArr) {
            result[k++] = new DenseVector(source);
        }
        return result;
    }

    /**
     * Returns the length of the longest inner array, or 0 if there are none.
     *
     * @param strArr arrays to inspect.
     * @return maximum inner-array length, at least 0.
     */
    static int longestInput(String[][] strArr) {
        int longest = 0;
        for (int k = 0; k < strArr.length; k++) {
            // Fully qualified: the file imports com.aliasi.util.Math, which shadows java.lang.Math.
            longest = java.lang.Math.max(longest, strArr[k].length);
        }
        return longest;
    }
}
