package com.aliasi.classify;

import com.aliasi.stats.Statistics;
import com.aliasi.util.Math;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:lib/lingpipe-4.1.0.jar:com/aliasi/classify/ConfusionMatrix.class */
/**
 * A confusion matrix: a square table of counts indexed by a reference
 * category (row) and a response category (column).
 *
 * <p>Entry {@code count(i, j)} is the number of cases whose reference category
 * is {@code categories()[i]} and whose response category is
 * {@code categories()[j]}; the diagonal therefore holds the correct cases.
 * From these counts the class derives accuracy, kappa statistics, entropy
 * measures, chi-squared based measures, lambda statistics and per-category
 * one-versus-all {@link PrecisionRecallEvaluation}s.
 *
 * <p>Counts are mutable through the {@code increment} methods. No
 * synchronization is performed, so concurrent updates must be guarded by the
 * caller.
 */
public class ConfusionMatrix {
    /** Category labels; position {@code i} is the label of category index {@code i}. */
    private final String[] mCategories;
    /** Counts indexed as {@code mMatrix[referenceIndex][responseIndex]}. */
    private final int[][] mMatrix;
    /** Reverse lookup from category label to category index. */
    private final Map<String, Integer> mCategoryToIndex = new HashMap<String, Integer>();

    /**
     * Constructs a confusion matrix over the specified categories with every
     * count initialized to zero.
     *
     * <p>The category array is copied, so later changes to it do not affect
     * this matrix. If a label occurs more than once, {@link #getIndex(String)}
     * reports its last position.
     *
     * @param categories category labels in index order
     */
    public ConfusionMatrix(String[] categories) {
        mCategories = categories.clone();
        // Java zero-initializes new int arrays, so no explicit fill is needed.
        mMatrix = new int[categories.length][categories.length];
        indexCategories();
    }

    /**
     * Constructs a confusion matrix over the specified categories with the
     * specified initial counts.
     *
     * <p>Both arguments are copied (the matrix row by row), so later changes
     * to them do not affect this matrix and cannot bypass the non-negativity
     * check. If a label occurs more than once, {@link #getIndex(String)}
     * reports its last position.
     *
     * @param categories category labels in index order
     * @param matrix initial counts indexed as {@code matrix[reference][response]}
     * @throws IllegalArgumentException if {@code matrix} is not square with side
     *     {@code categories.length}, or if any entry is negative
     */
    public ConfusionMatrix(String[] categories, int[][] matrix) {
        if (categories.length != matrix.length) {
            throw new IllegalArgumentException("Categories and matrix must be of same length. Found categories length=" + categories.length + " and matrix length=" + matrix.length);
        }
        for (int row = 0; row < matrix.length; row++) {
            if (categories.length != matrix[row].length) {
                throw new IllegalArgumentException("Categories and all matrix rows must be of same length. Found categories length=" + categories.length + " Found row " + row + " length=" + matrix[row].length);
            }
        }
        for (int row = 0; row < matrix.length; row++) {
            for (int col = 0; col < matrix.length; col++) {
                if (matrix[row][col] < 0) {
                    throw new IllegalArgumentException("Matrix entries must be non-negative. matrix[" + row + "][" + col + "]=" + matrix[row][col]);
                }
            }
        }
        mCategories = categories.clone();
        mMatrix = deepCopy(matrix);
        // Needed so getIndex(String) and increment(String, String) work for
        // matrices built from existing counts, not only for empty ones.
        indexCategories();
    }

    /** Fills {@link #mCategoryToIndex} from {@link #mCategories}; later duplicates win. */
    private void indexCategories() {
        for (int i = 0; i < mCategories.length; i++) {
            mCategoryToIndex.put(mCategories[i], Integer.valueOf(i));
        }
    }

    /** Returns a copy of {@code source} in which every row is itself a fresh copy. */
    private static int[][] deepCopy(int[][] source) {
        int[][] copy = new int[source.length][];
        for (int i = 0; i < source.length; i++) {
            copy[i] = source[i].clone();
        }
        return copy;
    }

    /**
     * Returns a copy of the category labels in index order.
     *
     * @return a fresh array; modifying it does not affect this matrix
     */
    public String[] categories() {
        return mCategories.clone();
    }

    /**
     * Returns the number of categories, which is also the side length of the
     * matrix.
     *
     * @return number of categories
     */
    public int numCategories() {
        // Read the length directly; categories() would clone the array on every call.
        return mCategories.length;
    }

    /**
     * Returns the index of the specified category label.
     *
     * @param category category label
     * @return the label's index, or {@code -1} if it is not a category of this matrix
     */
    public int getIndex(String category) {
        Integer index = mCategoryToIndex.get(category);
        return index == null ? -1 : index.intValue();
    }

    /**
     * Returns a copy of the counts, indexed as {@code [reference][response]}.
     *
     * @return a fresh two-dimensional array; modifying it (including its rows)
     *     does not affect this matrix
     */
    public int[][] matrix() {
        return deepCopy(mMatrix);
    }

    /**
     * Adds one to the count for the specified reference and response indices.
     *
     * @param referenceIndex reference (row) category index
     * @param responseIndex response (column) category index
     * @throws IllegalArgumentException if either index is out of range
     */
    public void increment(int referenceIndex, int responseIndex) {
        checkIndex("reference", referenceIndex);
        checkIndex("response", responseIndex);
        mMatrix[referenceIndex][responseIndex]++;
    }

    /**
     * Adds {@code n} (which may be negative) to the count for the specified
     * reference and response indices.
     *
     * @param referenceIndex reference (row) category index
     * @param responseIndex response (column) category index
     * @param n amount to add
     * @throws IllegalArgumentException if either index is out of range, if the
     *     result would be negative, or if it would exceed {@link Integer#MAX_VALUE}
     */
    public void incrementByN(int referenceIndex, int responseIndex, int n) {
        checkIndex("reference", referenceIndex);
        checkIndex("response", responseIndex);
        int current = mMatrix[referenceIndex][responseIndex];
        // Widen to long so int overflow cannot masquerade as a negative result.
        long updated = (long) current + n;
        if (updated < 0) {
            throw new IllegalArgumentException("Cannot decrement to less than 0 value. referenceCategoryIndex=" + referenceIndex + " responseCategoryIndex=" + responseIndex + " matrix[referenceCategoryIndex][responseCategoryIndex]=" + current + " increment=" + n);
        }
        if (updated > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Count would overflow int. referenceCategoryIndex=" + referenceIndex + " responseCategoryIndex=" + responseIndex + " matrix[referenceCategoryIndex][responseCategoryIndex]=" + current + " increment=" + n);
        }
        mMatrix[referenceIndex][responseIndex] = (int) updated;
    }

    /**
     * Adds one to the count for the specified reference and response labels.
     *
     * @param referenceCategory reference category label
     * @param responseCategory response category label
     * @throws IllegalArgumentException if either label is not a category of this matrix
     */
    public void increment(String referenceCategory, String responseCategory) {
        increment(indexOfKnown("reference", referenceCategory), indexOfKnown("response", responseCategory));
    }

    /** Returns the index of {@code category}, failing with a descriptive message if it is unknown. */
    private int indexOfKnown(String role, String category) {
        int index = getIndex(category);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown " + role + " category=" + category + " Known categories=" + Arrays.asList(mCategories));
        }
        return index;
    }

    /**
     * Returns the count for the specified reference and response indices.
     *
     * @param referenceIndex reference (row) category index
     * @param responseIndex response (column) category index
     * @return number of cases with that reference and response
     * @throws IllegalArgumentException if either index is out of range
     */
    public int count(int referenceIndex, int responseIndex) {
        checkIndex("reference", referenceIndex);
        checkIndex("response", responseIndex);
        return mMatrix[referenceIndex][responseIndex];
    }

    /**
     * Returns the sum of all counts in the matrix.
     *
     * @return total number of cases
     */
    public int totalCount() {
        int total = 0;
        for (int[] row : mMatrix) {
            for (int cell : row) {
                total += cell;
            }
        }
        return total;
    }

    /**
     * Returns the sum of the diagonal counts, i.e. the cases whose response
     * equals their reference.
     *
     * @return number of correct cases
     */
    public int totalCorrect() {
        int correct = 0;
        for (int i = 0; i < mMatrix.length; i++) {
            correct += mMatrix[i][i];
        }
        return correct;
    }

    /**
     * Returns {@code totalCorrect() / totalCount()} as a floating-point ratio.
     *
     * @return accuracy in {@code [0, 1]}, or NaN for an empty matrix
     */
    public double totalAccuracy() {
        // Floating-point division; int division would truncate to 0 or 1.
        return (double) totalCorrect() / (double) totalCount();
    }

    /**
     * Returns the half-width of a 95% normal-approximation interval around
     * {@link #totalAccuracy()} (z = 1.96).
     *
     * @return interval half-width
     */
    public double confidence95() {
        return confidence(1.96d);
    }

    /**
     * Returns the half-width of a 99% normal-approximation interval around
     * {@link #totalAccuracy()} (z = 2.58).
     *
     * @return interval half-width
     */
    public double confidence99() {
        return confidence(2.58d);
    }

    /**
     * Returns {@code z * sqrt(a * (1 - a) / totalCount())}, where {@code a} is
     * {@link #totalAccuracy()}.
     *
     * @param z number of standard deviations
     * @return interval half-width
     */
    public double confidence(double z) {
        double accuracy = totalAccuracy();
        // Fully qualified: the com.aliasi.util.Math import shadows java.lang.Math.
        return z * java.lang.Math.sqrt((accuracy * (1.0d - accuracy)) / totalCount());
    }

    /**
     * Returns {@code p * log2(p)}, or 0 when {@code p} is not above
     * {@code KStarConstants.FLOOR}. This follows the convention 0·log 0 = 0, and
     * NaN inputs also yield 0.
     *
     * <p>NOTE(review): the Weka {@code KStarConstants.FLOOR} reference looks like
     * a decompiler substitution for the literal {@code 0.0}; confirm and consider
     * replacing it to drop the Weka dependency.
     */
    private static double plog2p(double p) {
        return p > KStarConstants.FLOOR ? p * Math.log2(p) : 0.0d;
    }

    /**
     * Returns {@code -sum_i p_i log2 p_i}, where {@code p_i} is
     * {@code oneVsAll(i).referenceLikelihood()}. Categories with no mass
     * contribute 0.
     *
     * @return entropy of the reference distribution, in bits
     */
    public double referenceEntropy() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            sum += plog2p(oneVsAll(i).referenceLikelihood());
        }
        return -sum;
    }

    /**
     * Returns {@code -sum_i q_i log2 q_i}, where {@code q_i} is
     * {@code oneVsAll(i).responseLikelihood()}. Categories with no mass
     * contribute 0.
     *
     * @return entropy of the response distribution, in bits
     */
    public double responseEntropy() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            sum += plog2p(oneVsAll(i).responseLikelihood());
        }
        return -sum;
    }

    /**
     * Returns {@code -sum_i p_i log2 q_i}, with {@code p_i} the reference
     * likelihood and {@code q_i} the response likelihood of category {@code i}.
     * Terms with {@code p_i} at or below the floor are skipped.
     *
     * @return cross entropy in bits; infinite if some {@code q_i} is 0 while
     *     {@code p_i} is positive
     */
    public double crossEntropy() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            PrecisionRecallEvaluation eval = oneVsAll(i);
            double referenceLikelihood = eval.referenceLikelihood();
            if (referenceLikelihood > KStarConstants.FLOOR) {
                sum += referenceLikelihood * Math.log2(eval.responseLikelihood());
            }
        }
        return -sum;
    }

    /**
     * Returns {@code -sum_{i,j} P(i,j) log2 P(i,j)}, where
     * {@code P(i,j) = count(i,j) / totalCount()}.
     *
     * @return joint entropy of reference and response, in bits
     */
    public double jointEntropy() {
        double total = totalCount();
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            for (int j = 0; j < numCategories(); j++) {
                sum += plog2p(count(i, j) / total);
            }
        }
        return -sum;
    }

    /**
     * Returns the entropy of the response distribution within reference row
     * {@code referenceIndex}: {@code -sum_j P(j|i) log2 P(j|i)}, with
     * {@code P(j|i) = count(i,j) / rowTotal(i)}.
     *
     * @param referenceIndex reference (row) category index
     * @return conditional entropy in bits (zero for an empty row)
     * @throws IllegalArgumentException if the index is out of range
     */
    public double conditionalEntropy(int referenceIndex) {
        checkIndex("reference", referenceIndex);
        // Held as double so the division below is floating-point, not long division.
        double rowTotal = oneVsAll(referenceIndex).positiveReference();
        double sum = 0.0d;
        for (int j = 0; j < numCategories(); j++) {
            sum += plog2p(count(referenceIndex, j) / rowTotal);
        }
        return -sum;
    }

    /**
     * Returns {@code sum_i p_i * conditionalEntropy(i)}, with {@code p_i} the
     * reference likelihood of category {@code i}.
     *
     * @return conditional entropy of response given reference, in bits
     */
    public double conditionalEntropy() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            sum += oneVsAll(i).referenceLikelihood() * conditionalEntropy(i);
        }
        return sum;
    }

    /**
     * Returns {@code (accuracy - r) / (1 - r)} with {@code r = randomAccuracy()}.
     *
     * @return kappa statistic
     */
    public double kappa() {
        return kappa(randomAccuracy());
    }

    /**
     * Returns {@code (accuracy - r) / (1 - r)} with
     * {@code r = randomAccuracyUnbiased()}.
     *
     * @return kappa statistic using the pooled chance-agreement estimate
     */
    public double kappaUnbiased() {
        return kappa(randomAccuracyUnbiased());
    }

    /**
     * Returns {@code 2 * totalAccuracy() - 1}.
     *
     * @return kappa with chance agreement fixed at one half
     */
    public double kappaNoPrevalence() {
        return (2.0d * totalAccuracy()) - 1.0d;
    }

    /** Returns {@code (totalAccuracy() - randomAccuracy) / (1 - randomAccuracy)}. */
    private double kappa(double randomAccuracy) {
        return (totalAccuracy() - randomAccuracy) / (1.0d - randomAccuracy);
    }

    /**
     * Returns {@code sum_i p_i * q_i}, the agreement expected if reference and
     * response were drawn independently from their own marginals.
     *
     * @return expected chance accuracy
     */
    public double randomAccuracy() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            PrecisionRecallEvaluation eval = oneVsAll(i);
            sum += eval.referenceLikelihood() * eval.responseLikelihood();
        }
        return sum;
    }

    /**
     * Returns {@code sum_i ((p_i + q_i) / 2)^2}, chance agreement computed from
     * the average of the reference and response marginals.
     *
     * @return expected chance accuracy under pooled marginals
     */
    public double randomAccuracyUnbiased() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            PrecisionRecallEvaluation eval = oneVsAll(i);
            double pooled = (eval.referenceLikelihood() + eval.responseLikelihood()) / 2.0d;
            sum += pooled * pooled;
        }
        return sum;
    }

    /**
     * Returns {@code (numCategories() - 1)^2}.
     *
     * @return degrees of freedom for the chi-squared independence test
     */
    public int chiSquaredDegreesOfFreedom() {
        int sideMinusOne = numCategories() - 1;
        return sideMinusOne * sideMinusOne;
    }

    /**
     * Returns the chi-squared independence statistic of the count table, as
     * computed by {@code Statistics.chiSquaredIndependence}.
     *
     * @return chi-squared statistic
     */
    public double chiSquared() {
        int n = numCategories();
        double[][] table = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                table[i][j] = mMatrix[i][j];
            }
        }
        return Statistics.chiSquaredIndependence(table);
    }

    /**
     * Returns {@code chiSquared() / totalCount()}.
     *
     * @return phi-squared
     */
    public double phiSquared() {
        return chiSquared() / totalCount();
    }

    /**
     * Returns {@code sqrt(phiSquared() / (numCategories() - 1))}.
     *
     * @return Cramer's V; not finite when there is only one category
     */
    public double cramersV() {
        // Fully qualified: the com.aliasi.util.Math import shadows java.lang.Math.
        return java.lang.Math.sqrt(phiSquared() / (numCategories() - 1));
    }

    /**
     * Returns a binary evaluation that treats {@code categoryIndex} as the
     * positive class and every other category as negative.
     *
     * @param categoryIndex index of the positive category
     * @return a freshly built evaluation
     * @throws IllegalArgumentException if the index is out of range
     */
    public PrecisionRecallEvaluation oneVsAll(int categoryIndex) {
        checkIndex("category", categoryIndex);
        PrecisionRecallEvaluation eval = new PrecisionRecallEvaluation();
        int n = numCategories();
        for (int reference = 0; reference < n; reference++) {
            for (int response = 0; response < n; response++) {
                eval.addCase(reference == categoryIndex, response == categoryIndex, mMatrix[reference][response]);
            }
        }
        return eval;
    }

    /**
     * Returns an evaluation whose true/false positive/negative counts are the
     * sums of those counts over every category's {@link #oneVsAll(int)}
     * evaluation.
     *
     * @return micro-averaged evaluation
     */
    public PrecisionRecallEvaluation microAverage() {
        long truePositive = 0;
        long falsePositive = 0;
        long falseNegative = 0;
        long trueNegative = 0;
        for (int i = 0; i < numCategories(); i++) {
            PrecisionRecallEvaluation eval = oneVsAll(i);
            truePositive += eval.truePositive();
            falsePositive += eval.falsePositive();
            trueNegative += eval.trueNegative();
            falseNegative += eval.falseNegative();
        }
        // NOTE(review): assumes the constructor order is (TP, FN, FP, TN), as the
        // original call used; confirm against PrecisionRecallEvaluation.
        return new PrecisionRecallEvaluation(truePositive, falseNegative, falsePositive, trueNegative);
    }

    /**
     * Returns the unweighted mean of {@code oneVsAll(i).precision()} over all
     * categories.
     *
     * @return macro-averaged precision
     */
    public double macroAvgPrecision() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            sum += oneVsAll(i).precision();
        }
        return sum / numCategories();
    }

    /**
     * Returns the unweighted mean of {@code oneVsAll(i).recall()} over all
     * categories.
     *
     * @return macro-averaged recall
     */
    public double macroAvgRecall() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            sum += oneVsAll(i).recall();
        }
        return sum / numCategories();
    }

    /**
     * Returns the unweighted mean of {@code oneVsAll(i).fMeasure()} over all
     * categories.
     *
     * @return macro-averaged F-measure
     */
    public double macroAvgFMeasure() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            sum += oneVsAll(i).fMeasure();
        }
        return sum / numCategories();
    }

    /**
     * Returns {@code (sum_j max_i count(i,j) - M) / (totalCount() - M)}, where
     * {@code M} is the largest reference (row) total.
     *
     * <p>The numerator compares guessing each column's most common reference
     * against always guessing the single most common reference.
     *
     * @return lambda for predicting reference from response
     */
    public double lambdaA() {
        double maxReferenceTotal = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            maxReferenceTotal = java.lang.Math.max(maxReferenceTotal, oneVsAll(i).positiveReference());
        }
        double sumOfColumnMaxima = 0.0d;
        for (int response = 0; response < numCategories(); response++) {
            int columnMax = 0;
            for (int reference = 0; reference < numCategories(); reference++) {
                columnMax = java.lang.Math.max(columnMax, mMatrix[reference][response]);
            }
            sumOfColumnMaxima += columnMax;
        }
        return (sumOfColumnMaxima - maxReferenceTotal) / (totalCount() - maxReferenceTotal);
    }

    /**
     * Returns {@code (sum_i max_j count(i,j) - M) / (totalCount() - M)}, where
     * {@code M} is the largest response (column) total.
     *
     * @return lambda for predicting response from reference
     */
    public double lambdaB() {
        double maxResponseTotal = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            maxResponseTotal = java.lang.Math.max(maxResponseTotal, oneVsAll(i).positiveResponse());
        }
        double sumOfRowMaxima = 0.0d;
        for (int reference = 0; reference < numCategories(); reference++) {
            int rowMax = 0;
            for (int response = 0; response < numCategories(); response++) {
                rowMax = java.lang.Math.max(rowMax, mMatrix[reference][response]);
            }
            sumOfRowMaxima += rowMax;
        }
        return (sumOfRowMaxima - maxResponseTotal) / (totalCount() - maxResponseTotal);
    }

    /**
     * Returns {@code sum_{i,j} P(i,j) log2(P(i,j) / (p_i q_j))}.
     *
     * <p>{@code P(i,j) = count(i,j) / totalCount()}; {@code p_i} and {@code q_j}
     * are the reference and response likelihoods. Terms with any factor at or
     * below the floor are skipped.
     *
     * @return mutual information between reference and response, in bits
     */
    public double mutualInformation() {
        int n = numCategories();
        double total = totalCount();
        // Compute each marginal once instead of rebuilding oneVsAll(j) inside the double loop.
        double[] referenceLikelihoods = new double[n];
        double[] responseLikelihoods = new double[n];
        for (int i = 0; i < n; i++) {
            PrecisionRecallEvaluation eval = oneVsAll(i);
            referenceLikelihoods[i] = eval.referenceLikelihood();
            responseLikelihoods[i] = eval.responseLikelihood();
        }
        double sum = 0.0d;
        for (int i = 0; i < n; i++) {
            double referenceLikelihood = referenceLikelihoods[i];
            if (referenceLikelihood > KStarConstants.FLOOR) {
                for (int j = 0; j < n; j++) {
                    double responseLikelihood = responseLikelihoods[j];
                    if (responseLikelihood > KStarConstants.FLOOR) {
                        double joint = mMatrix[i][j] / total;
                        if (joint > KStarConstants.FLOOR) {
                            sum += joint * Math.log2(joint / (referenceLikelihood * responseLikelihood));
                        }
                    }
                }
            }
        }
        return sum;
    }

    /**
     * Returns {@code sum_i p_i log2(p_i / q_i)}, with {@code p_i} the reference
     * likelihood and {@code q_i} the response likelihood. Terms with {@code p_i}
     * at or below the floor are skipped.
     *
     * @return KL divergence of the response marginal from the reference
     *     marginal, in bits; infinite if some {@code q_i} is 0 while {@code p_i}
     *     is positive
     */
    public double klDivergence() {
        double sum = 0.0d;
        for (int i = 0; i < numCategories(); i++) {
            PrecisionRecallEvaluation eval = oneVsAll(i);
            double referenceLikelihood = eval.referenceLikelihood();
            if (referenceLikelihood > KStarConstants.FLOOR) {
                sum += referenceLikelihood * Math.log2(referenceLikelihood / eval.responseLikelihood());
            }
        }
        return sum;
    }

    /**
     * Returns a multi-line report containing the global statistics followed,
     * for each category, by its conditional entropy and one-versus-all
     * evaluation.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("GLOBAL CONFUSION MATRIX STATISTICS\n");
        toStringGlobal(sb);
        for (int i = 0; i < numCategories(); i++) {
            sb.append("CATEGORY " + i + "=" + mCategories[i] + " VS. ALL\n");
            sb.append("  Conditional Entropy=" + conditionalEntropy(i));
            sb.append('\n');
            sb.append(oneVsAll(i).toString());
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Appends the global (not per-category) statistics report to {@code sb},
     * one statistic per line, including the CSV rendering of the counts.
     *
     * @param sb builder to append to
     */
    public void toStringGlobal(StringBuilder sb) {
        sb.append("Categories=" + Arrays.asList(categories()));
        sb.append('\n');
        sb.append("Total Count=" + totalCount());
        sb.append('\n');
        sb.append("Total Correct=" + totalCorrect());
        sb.append('\n');
        sb.append("Total Accuracy=" + totalAccuracy());
        sb.append('\n');
        sb.append("95% Confidence Interval=" + totalAccuracy() + " +/- " + confidence95());
        sb.append('\n');
        sb.append("Confusion Matrix\n");
        sb.append("reference \\ response\n");
        sb.append(matrixToCSV());
        sb.append('\n');
        sb.append("Macro-averaged Precision=" + macroAvgPrecision());
        sb.append('\n');
        sb.append("Macro-averaged Recall=" + macroAvgRecall());
        sb.append('\n');
        sb.append("Macro-averaged F=" + macroAvgFMeasure());
        sb.append('\n');
        sb.append("Micro-averaged Results\n");
        sb.append("         the following symmetries are expected:\n");
        sb.append("           TP=TN, FN=FP\n");
        sb.append("           PosRef=PosResp=NegRef=NegResp\n");
        sb.append("           Acc=Prec=Rec=F\n");
        sb.append(microAverage().toString());
        sb.append('\n');
        sb.append("Random Accuracy=" + randomAccuracy());
        sb.append('\n');
        sb.append("Random Accuracy Unbiased=" + randomAccuracyUnbiased());
        sb.append('\n');
        sb.append("kappa=" + kappa());
        sb.append('\n');
        sb.append("kappa Unbiased=" + kappaUnbiased());
        sb.append('\n');
        sb.append("kappa No Prevalence =" + kappaNoPrevalence());
        sb.append('\n');
        sb.append("Reference Entropy=" + referenceEntropy());
        sb.append('\n');
        sb.append("Response Entropy=" + responseEntropy());
        sb.append('\n');
        sb.append("Cross Entropy=" + crossEntropy());
        sb.append('\n');
        sb.append("Joint Entropy=" + jointEntropy());
        sb.append('\n');
        sb.append("Conditional Entropy=" + conditionalEntropy());
        sb.append('\n');
        sb.append("Mutual Information=" + mutualInformation());
        sb.append('\n');
        sb.append("Kullback-Liebler Divergence=" + klDivergence());
        sb.append('\n');
        sb.append("chi Squared=" + chiSquared());
        sb.append('\n');
        sb.append("chi-Squared Degrees of Freedom=" + chiSquaredDegreesOfFreedom());
        sb.append('\n');
        sb.append("phi Squared=" + phiSquared());
        sb.append('\n');
        sb.append("Cramer's V=" + cramersV());
        sb.append('\n');
        sb.append("lambda A=" + lambdaA());
        sb.append('\n');
        sb.append("lambda B=" + lambdaB());
        sb.append('\n');
    }

    /**
     * Renders the counts as comma-separated rows. Each line is indented by two
     * spaces, the header row lists the response categories, and each following
     * row starts with its reference category.
     */
    String matrixToCSV() {
        StringBuilder sb = new StringBuilder();
        sb.append("  ");
        for (int i = 0; i < numCategories(); i++) {
            sb.append(',');
            sb.append(mCategories[i]);
        }
        for (int reference = 0; reference < numCategories(); reference++) {
            sb.append("\n  ");
            sb.append(mCategories[reference]);
            for (int response = 0; response < numCategories(); response++) {
                sb.append(',');
                sb.append(mMatrix[reference][response]);
            }
        }
        return sb.toString();
    }

    /**
     * Renders the counts as an HTML table with colored cells. Diagonal cells
     * are light green, zero off-diagonal cells yellow, and non-zero off-diagonal
     * cells red. Category labels are HTML-escaped.
     */
    String matrixToHTML() {
        StringBuilder sb = new StringBuilder();
        sb.append("<html>\n");
        sb.append("<table border='1' cellpadding='5'>");
        sb.append('\n');
        sb.append("<tr>\n  <td colspan='2' rowspan='2'>&nbsp;</td>");
        sb.append("\n  <td colspan='" + numCategories() + "' align='center' bgcolor='darkgray'><b>Response</b></td></tr>");
        sb.append("<tr>");
        for (int i = 0; i < numCategories(); i++) {
            sb.append("\n  <td align='right' bgcolor='lightgray'><i>" + escapeHtml(mCategories[i]) + "</i></td>");
        }
        sb.append("</tr>\n");
        for (int reference = 0; reference < numCategories(); reference++) {
            sb.append("<tr>");
            if (reference == 0) {
                sb.append("\n  <td rowspan='" + numCategories() + "' bgcolor='darkgray'><b>Ref-<br>erence</b></td>");
            }
            sb.append("\n  <td align='right' bgcolor='lightgray'><i>" + escapeHtml(mCategories[reference]) + "</i></td>");
            for (int response = 0; response < numCategories(); response++) {
                int cell = mMatrix[reference][response];
                if (reference == response) {
                    sb.append("\n  <td align='right' bgcolor='lightgreen'>");
                } else if (cell == 0) {
                    sb.append("\n  <td align='right' bgcolor='yellow'>");
                } else {
                    sb.append("\n  <td align='right' bgcolor='red'>");
                }
                sb.append(cell);
                sb.append("</td>");
            }
            sb.append("</tr>\n");
        }
        sb.append("</table>\n");
        sb.append("</html>\n");
        return sb.toString();
    }

    /** Escapes the HTML-significant characters of {@code text}; {@code null} becomes "null". */
    private static String escapeHtml(String text) {
        String raw = String.valueOf(text);
        StringBuilder sb = new StringBuilder(raw.length());
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            switch (c) {
                case '&':
                    sb.append("&amp;");
                    break;
                case '<':
                    sb.append("&lt;");
                    break;
                case '>':
                    sb.append("&gt;");
                    break;
                case '"':
                    sb.append("&quot;");
                    break;
                case '\'':
                    sb.append("&#39;");
                    break;
                default:
                    sb.append(c);
            }
        }
        return sb.toString();
    }

    /**
     * Throws if {@code index} is not in {@code [0, numCategories())}.
     *
     * @param role description of the index used in the error message
     * @param index index to check
     * @throws IllegalArgumentException if the index is out of range
     */
    private void checkIndex(String role, int index) {
        if (index < 0) {
            throw new IllegalArgumentException("Index for " + role + " must be >= 0. Found index=" + index);
        }
        if (index >= numCategories()) {
            throw new IllegalArgumentException("Index for " + role + " must be < numCategories()=" + numCategories() + ". Found index=" + index);
        }
    }
}
