package com.aliasi.hmm;

import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.TagLattice;
import com.aliasi.util.Math;
import com.aliasi.util.ScoredObject;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:lib/lingpipe-4.1.0.jar:com/aliasi/hmm/TagWordLattice.class */
/**
 * A {@link TagLattice} over string tags that stores scaled forward and backward tables for a
 * token sequence.
 *
 * <p>To keep long inputs from underflowing, every row of the forward and backward tables is
 * rescaled by a power of two so that its largest entry lies in {@code [0.5, 1]} (or the row is
 * all zero). The removed base-2 exponents are accumulated per position in {@code mForwardExps}
 * and {@code mBackExps}. The true forward value at position {@code n} for tag {@code t} is
 * {@code mForwards[n][t] * 2^mForwardExps[n]}, and likewise for the backward table.
 *
 * <p>Transition values are indexed {@code [position][fromTag][toTag]}. Slot {@code 0} of the
 * first dimension is never read, because position 0 has no predecessor. The token strings are
 * used by the computation only for their count. NOTE(review): presumably emission probabilities
 * are already folded into the start and transition values; confirm with the caller.
 *
 * <p>The constructor stores the supplied arrays without copying them.
 */
class TagWordLattice extends TagLattice<String> {
    /** Transition values indexed {@code [position][fromTag][toTag]}; slot 0 unused. */
    final double[][][] mTransitions;
    /** Scaled forward values indexed {@code [position][tag]}. */
    final double[][] mForwards;
    /** Base-2 exponent to add back to each row of {@link #mForwards}. */
    final double[] mForwardExps;
    /** Scaled backward values indexed {@code [position][tag]}. */
    final double[][] mBacks;
    /** Base-2 exponent to add back to each row of {@link #mBacks}. */
    final double[] mBackExps;
    /** Start values indexed by tag (the caller's array, not a copy). */
    final double[] mStarts;
    /** End values indexed by tag (the caller's array, not a copy). */
    final double[] mEnds;
    /** Tokens supplied at construction (the caller's array, not a copy). */
    final String[] mTokens;
    /** Maps tag identifiers to tag strings. */
    final SymbolTable mTagSymbolTable;
    /** Linear total; set by {@code computeTotal()}. */
    double mTotal = Double.NaN;
    /** Base-2 log of the total; set by {@code computeTotal()}. */
    double mLog2Total = Double.NaN;

    /**
     * Constructs a lattice and eagerly computes the forward, backward and total values.
     *
     * @param tokens the tokens; the computation uses only their count
     * @param symbolTable maps tag identifiers to tag strings
     * @param startProbs start values indexed by tag; each must be in {@code [0, 1]}
     * @param endProbs end values indexed by tag; each must be in {@code [0, 1]}
     * @param transitProbs transition values indexed {@code [position][fromTag][toTag]};
     *     entries for positions {@code >= 1} must be in {@code [0, 1]}, position 0 is ignored
     * @throws IllegalArgumentException if a start, end or transition value is outside
     *     {@code [0, 1]} or is NaN
     */
    public TagWordLattice(String[] tokens, SymbolTable symbolTable, double[] startProbs,
                          double[] endProbs, double[][][] transitProbs) {
        for (int i = 0; i < startProbs.length; i++) {
            if (!isProbability(startProbs[i])) {
                throw new IllegalArgumentException("startProbs[" + i + "]=" + startProbs[i]);
            }
        }
        for (int i = 0; i < endProbs.length; i++) {
            if (!isProbability(endProbs[i])) {
                throw new IllegalArgumentException("endProbs[" + i + "]=" + endProbs[i]);
            }
        }
        // Position 0 has no predecessor, so transitProbs[0] is never read and not validated.
        for (int n = 1; n < transitProbs.length; n++) {
            for (int from = 0; from < transitProbs[n].length; from++) {
                for (int to = 0; to < transitProbs[n][from].length; to++) {
                    double p = transitProbs[n][from][to];
                    if (!isProbability(p)) {
                        throw new IllegalArgumentException(
                            "transitProbs[" + n + "][" + from + "][" + to + "]=" + p);
                    }
                }
            }
        }
        int numTags = symbolTable.numSymbols();
        int numTokens = tokens.length;
        this.mStarts = startProbs;
        this.mEnds = endProbs;
        this.mTransitions = transitProbs;
        this.mTokens = tokens;
        this.mTagSymbolTable = symbolTable;
        this.mForwards = new double[numTokens][numTags];
        this.mForwardExps = new double[numTokens];
        this.mBacks = new double[numTokens][numTags];
        this.mBackExps = new double[numTokens];
        computeAll();
    }

    /** Returns {@code true} if {@code p} lies in {@code [0, 1]}; NaN is rejected. */
    private static boolean isProbability(double p) {
        return p >= 0.0 && p <= 1.0;
    }

    /**
     * Returns the token array supplied at construction (not a copy).
     *
     * @return the tokens
     */
    public String[] tokens() {
        return this.mTokens;
    }

    @Override // com.aliasi.tag.TagLattice
    public SymbolTable tagSymbolTable() {
        return this.mTagSymbolTable;
    }

    /**
     * Returns every tag at a position, scored by its log2 conditional probability and sorted
     * with {@link ScoredObject#reverseComparator()}.
     *
     * @param i token position
     * @return a new mutable list with one entry per tag
     */
    public List<ScoredObject<String>> log2ConditionalTagList(int i) {
        double log2Total = log2Total();
        int numTags = this.mTagSymbolTable.numSymbols();
        List<ScoredObject<String>> scored = new ArrayList<>(numTags);
        for (int tagId = 0; tagId < numTags; tagId++) {
            scored.add(log2ConditionalTag(i, tagId, log2Total));
        }
        Collections.sort(scored, ScoredObject.reverseComparator());
        return scored;
    }

    /**
     * Array form of {@link #log2ConditionalTagList(int)}: every tag at a position, scored by its
     * log2 conditional probability and sorted with {@link ScoredObject#reverseComparator()}.
     *
     * @param i token position
     * @return a new array with one entry per tag
     */
    public ScoredObject<String>[] log2ConditionalTags(int i) {
        double log2Total = log2Total();
        int numTags = this.mTagSymbolTable.numSymbols();
        // Generic arrays cannot be created directly; every element stored is a ScoredObject<String>.
        @SuppressWarnings("unchecked")
        ScoredObject<String>[] scored = (ScoredObject<String>[]) new ScoredObject[numTags];
        for (int tagId = 0; tagId < numTags; tagId++) {
            scored[tagId] = log2ConditionalTag(i, tagId, log2Total);
        }
        Arrays.sort(scored, ScoredObject.reverseComparator());
        return scored;
    }

    /**
     * Scores one tag at one position by {@code log2ForwardBackward - log2Total}.
     *
     * <p>Values above 0 (including +Infinity) are capped at 0; NaN and -Infinity are replaced by
     * {@code log2(Double.MIN_VALUE)}, so the score is always finite and at most 0.
     */
    private ScoredObject<String> log2ConditionalTag(int tokenIndex, int tagId, double log2Total) {
        double log2Conditional = log2ForwardBackward(tokenIndex, tagId) - log2Total;
        if (log2Conditional > 0.0) {
            // A probability cannot exceed 1; presumably an excess here is rounding error.
            log2Conditional = 0.0;
        } else if (Double.isNaN(log2Conditional) || Double.isInfinite(log2Conditional)) {
            log2Conditional = Math.log2(Double.MIN_VALUE);
        }
        return new ScoredObject<>(this.mTagSymbolTable.idToSymbol(tagId), log2Conditional);
    }

    /**
     * Returns, for each position, the tag with the largest forward-backward product. Ties go to
     * the lower tag identifier.
     *
     * @return one tag per token
     */
    public String[] bestForwardBackward() {
        int numTokens = this.mTokens.length;
        int numTags = this.mTagSymbolTable.numSymbols();
        String[] best = new String[numTokens];
        for (int n = 0; n < numTokens; n++) {
            // The scaling exponent mForwardExps[n] + mBackExps[n] is shared by every tag at
            // position n, so comparing scaled products yields the same argmax as
            // forwardBackward(n, tag). It avoids multiplying in 2^exponent, which underflows
            // to 0 on long inputs and would make every tag tie.
            double[] fwd = this.mForwards[n];
            double[] back = this.mBacks[n];
            int bestTag = 0;
            double bestScore = fwd[0] * back[0];
            for (int tagId = 1; tagId < numTags; tagId++) {
                double score = fwd[tagId] * back[tagId];
                if (score > bestScore) {
                    bestScore = score;
                    bestTag = tagId;
                }
            }
            best[n] = this.mTagSymbolTable.idToSymbol(bestTag);
        }
        return best;
    }

    /** Returns the start value for a tag. */
    public double start(int i) {
        return this.mStarts[i];
    }

    /** Returns the base-2 log of the start value for a tag. */
    public double log2Start(int i) {
        return Math.log2(start(i));
    }

    /** Returns the end value for a tag. */
    public double end(int i) {
        return this.mEnds[i];
    }

    /** Returns the base-2 log of the end value for a tag. */
    public double log2End(int i) {
        return Math.log2(end(i));
    }

    /**
     * Returns the transition value into position {@code i} from tag {@code i2} to tag {@code i3}.
     *
     * @throws IndexOutOfBoundsException if {@code i == 0}
     */
    public double transition(int i, int i2, int i3) {
        if (i == 0) {
            throw new IndexOutOfBoundsException("Token index must be > 0.");
        }
        return this.mTransitions[i][i2][i3];
    }

    /** Returns the base-2 log of {@link #transition(int, int, int)}. */
    public double log2Transitions(int i, int i2, int i3) {
        return Math.log2(transition(i, i2, i3));
    }

    /** Returns the unscaled forward value; may underflow to 0 on long inputs. */
    public double forward(int i, int i2) {
        return this.mForwards[i][i2] * java.lang.Math.pow(2.0, this.mForwardExps[i]);
    }

    /** Returns the base-2 log of the forward value, computed without underflow. */
    public double log2Forward(int i, int i2) {
        return Math.log2(this.mForwards[i][i2]) + this.mForwardExps[i];
    }

    /** Returns the unscaled backward value; may underflow to 0 on long inputs. */
    public double backward(int i, int i2) {
        return this.mBacks[i][i2] * java.lang.Math.pow(2.0, this.mBackExps[i]);
    }

    /** Returns the base-2 log of the backward value, computed without underflow. */
    public double log2Backward(int i, int i2) {
        return Math.log2(this.mBacks[i][i2]) + this.mBackExps[i];
    }

    /** Returns {@code forward(i, i2) * backward(i, i2)}; may underflow on long inputs. */
    public double forwardBackward(int i, int i2) {
        return forward(i, i2) * backward(i, i2);
    }

    /** Returns {@code log2Forward(i, i2) + log2Backward(i, i2)}. */
    public double log2ForwardBackward(int i, int i2) {
        return log2Forward(i, i2) + log2Backward(i, i2);
    }

    /** Returns the linear total (1 for an empty input); may underflow on long inputs. */
    public double total() {
        return this.mTotal;
    }

    /** Returns the base-2 log of the total (0 for an empty input). */
    public double log2Total() {
        return this.mLog2Total;
    }

    @Override // com.aliasi.tag.TagLattice
    public double logForward(int i, int i2) {
        return Math.logBase2ToNaturalLog(log2Forward(i, i2));
    }

    @Override // com.aliasi.tag.TagLattice
    public double logBackward(int i, int i2) {
        return Math.logBase2ToNaturalLog(log2Backward(i, i2));
    }

    @Override // com.aliasi.tag.TagLattice
    public double logZ() {
        return Math.logBase2ToNaturalLog(log2Total());
    }

    /**
     * Returns the natural log of the transition from tag {@code i2} at position {@code i} to tag
     * {@code i3} at position {@code i + 1}. It reads {@code mTransitions[i + 1]}.
     */
    @Override // com.aliasi.tag.TagLattice
    public double logTransition(int i, int i2, int i3) {
        return Math.logBase2ToNaturalLog(log2Transitions(i + 1, i2, i3));
    }

    /** Returns the natural log of the forward-backward product at a position and tag. */
    @Override // com.aliasi.tag.TagLattice
    public double logProbability(int i, int i2) {
        return Math.logBase2ToNaturalLog(log2ForwardBackward(i, i2));
    }

    /** Returns the log probability of tag {@code i2} at position {@code i - 1} followed by {@code i3} at {@code i}. */
    @Override // com.aliasi.tag.TagLattice
    public double logProbability(int i, int i2, int i3) {
        return logProbability(i - 1, new int[]{i2, i3});
    }

    /**
     * Returns the natural-log probability of the tag sequence {@code iArr} starting at
     * position {@code i}. The value is the forward value of the first tag, plus the backward
     * value of the last tag, plus every intermediate transition, minus {@link #logZ()}.
     * {@code iArr} must be non-empty.
     */
    @Override // com.aliasi.tag.TagLattice
    public double logProbability(int i, int[] iArr) {
        int lastPos = i + iArr.length - 1;
        double logProb = logForward(i, iArr[0])
            + logBackward(lastPos, iArr[iArr.length - 1])
            - logZ();
        for (int k = 1; k < iArr.length; k++) {
            logProb += logTransition(i + k - 1, iArr[k - 1], iArr[k]);
        }
        return logProb;
    }

    @Override // com.aliasi.tag.TagLattice
    public int numTokens() {
        return this.mTokens.length;
    }

    /** Returns a fixed-size list view backed by the token array. */
    @Override // com.aliasi.tag.TagLattice
    public List<String> tokenList() {
        return Arrays.asList(this.mTokens);
    }

    @Override // com.aliasi.tag.TagLattice
    public String token(int i) {
        return this.mTokens[i];
    }

    @Override // com.aliasi.tag.TagLattice
    public int numTags() {
        return this.mTagSymbolTable.numSymbols();
    }

    @Override // com.aliasi.tag.TagLattice
    public String tag(int i) {
        return this.mTagSymbolTable.idToSymbol(i);
    }

    /** Returns a new mutable list of all tags in identifier order. */
    @Override // com.aliasi.tag.TagLattice
    public List<String> tagList() {
        int numTags = numTags();
        List<String> tags = new ArrayList<>(numTags);
        for (int id = 0; id < numTags; id++) {
            tags.add(tag(id));
        }
        return tags;
    }

    /** Fills the forward and backward tables, then the totals (the totals read both tables). */
    final void computeAll() {
        computeForward();
        computeBackward();
        computeTotal();
    }

    /** Sums forward times backward at position 0 to obtain the total and its base-2 log. */
    private void computeTotal() {
        if (this.mForwards.length == 0) {
            // Empty input: total 1, log2 total 0.
            this.mTotal = 1.0;
            this.mLog2Total = 0.0;
            return;
        }
        int numTags = tagSymbolTable().numSymbols();
        double scaledSum = 0.0;
        for (int tagId = 0; tagId < numTags; tagId++) {
            scaledSum += this.mForwards[0][tagId] * this.mBacks[0][tagId];
        }
        double exp = this.mForwardExps[0] + this.mBackExps[0];
        this.mLog2Total = Math.log2(scaledSum) + exp;
        this.mTotal = scaledSum * java.lang.Math.pow(2.0, exp);
    }

    /** Fills {@code mForwards}/{@code mForwardExps} left to right, rescaling each row. */
    private void computeForward() {
        if (this.mForwards.length == 0) {
            return;
        }
        int numTags = tagSymbolTable().numSymbols();
        double[] first = this.mForwards[0];
        // Starts were validated to lie in [0, 1]; copy them without touching the caller's array.
        System.arraycopy(this.mStarts, 0, first, 0, numTags);
        this.mForwardExps[0] = log2ScaleExp(first);
        for (int n = 1; n < this.mTokens.length; n++) {
            double[] prev = this.mForwards[n - 1];
            double[][] trans = this.mTransitions[n];
            double[] cur = this.mForwards[n];
            for (int to = 0; to < numTags; to++) {
                // NOTE(review): this is a sum over predecessor tags, so it could exceed 1 and
                // make log2ScaleExp throw; confirm callers' transition values rule this out.
                double sum = 0.0;
                for (int from = 0; from < numTags; from++) {
                    sum += prev[from] * trans[from][to];
                }
                cur[to] = sum;
            }
            this.mForwardExps[n] = log2ScaleExp(cur) + this.mForwardExps[n - 1];
        }
    }

    /** Fills {@code mBacks}/{@code mBackExps} right to left, rescaling each row. */
    private void computeBackward() {
        if (this.mBacks.length == 0) {
            return;
        }
        int numTags = tagSymbolTable().numSymbols();
        int last = this.mTokens.length - 1;
        double[] lastRow = this.mBacks[last];
        System.arraycopy(this.mEnds, 0, lastRow, 0, numTags);
        this.mBackExps[last] = log2ScaleExp(lastRow);
        for (int n = last - 1; n >= 0; n--) {
            double[] next = this.mBacks[n + 1];
            double[][] trans = this.mTransitions[n + 1];
            double[] cur = this.mBacks[n];
            for (int from = 0; from < numTags; from++) {
                double sum = 0.0;
                for (int to = 0; to < numTags; to++) {
                    sum += next[to] * trans[from][to];
                }
                cur[from] = sum;
            }
            this.mBackExps[n] = log2ScaleExp(cur) + this.mBackExps[n + 1];
        }
    }

    /**
     * Rescales {@code dArr} in place by a power of two so that its maximum lies in
     * {@code [0.5, 1]}, and returns the base-2 exponent (at most 0) that undoes the scaling.
     * Empty and all-zero arrays are left unchanged and yield 0. NaN entries are not detected.
     *
     * @param dArr values to rescale in place
     * @return exponent {@code e} such that original value = scaled value * 2^e
     * @throws IllegalArgumentException if the maximum is negative or greater than 1
     */
    static double log2ScaleExp(double[] dArr) {
        if (dArr.length == 0) {
            return 0.0;
        }
        double max = dArr[0];
        for (int i = 1; i < dArr.length; i++) {
            if (max < dArr[i]) {
                max = dArr[i];
            }
        }
        if (max < 0.0 || max > 1.0) {
            throw new IllegalArgumentException("Max must be >= 0 and <= 1. Found max=" + max);
        }
        if (max == 0.0) {
            return 0.0;
        }
        // Count doublings needed to reach [0.5, 1]; max stays below 1, so this cannot overflow.
        int shift = 0;
        while (max < 0.5) {
            max *= 2.0;
            shift++;
        }
        // scalb avoids materialising 2^shift as a double. That factor would overflow to infinity
        // when max is subnormal (shift > 1023) and turn zero entries into NaN.
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = java.lang.Math.scalb(dArr[i], shift);
        }
        return -shift;
    }
}
