package com.aliasi.classify;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:lib/lingpipe-4.1.0.jar:com/aliasi/classify/TfIdfClassifierTrainer.class */
public class TfIdfClassifierTrainer<E> implements ObjectHandler<Classified<E>>, Compilable, Serializable {
    static final long serialVersionUID = -2793388723202924633L;
    final FeatureExtractor<? super E> mFeatureExtractor;
    final Map<Integer, ObjectToDoubleMap<Integer>> mFeatureToCategoryCount;
    final MapSymbolTable mFeatureSymbolTable;
    final MapSymbolTable mCategorySymbolTable;

    /* loaded from: input_file:lib/lingpipe-4.1.0.jar:com/aliasi/classify/TfIdfClassifierTrainer$Externalizer.class */
    static class Externalizer<F> extends AbstractExternalizable {
        static final long serialVersionUID = 5578122239615646843L;
        final TfIdfClassifierTrainer<F> mTrainer;

        public Externalizer() {
            this(null);
        }

        public Externalizer(TfIdfClassifierTrainer<F> tfIdfClassifierTrainer) {
            this.mTrainer = tfIdfClassifierTrainer;
        }

        @Override // com.aliasi.util.AbstractExternalizable, java.io.Externalizable
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            AbstractExternalizable.compileOrSerialize(this.mTrainer.mFeatureExtractor, objectOutput);
            int numSymbols = this.mTrainer.mFeatureSymbolTable.numSymbols();
            objectOutput.writeObject(this.mTrainer.mFeatureSymbolTable);
            int numSymbols2 = this.mTrainer.mCategorySymbolTable.numSymbols();
            double d = numSymbols2;
            objectOutput.writeInt(numSymbols2);
            for (int i = 0; i < numSymbols2; i++) {
                objectOutput.writeUTF(this.mTrainer.mCategorySymbolTable.idToSymbol(i));
            }
            for (int i2 = 0; i2 < this.mTrainer.mFeatureSymbolTable.numSymbols(); i2++) {
                objectOutput.writeFloat((float) TfIdfClassifierTrainer.idf(this.mTrainer.mFeatureToCategoryCount.get(Integer.valueOf(i2)).size(), d));
            }
            int i3 = 0;
            for (int i4 = 0; i4 < numSymbols; i4++) {
                objectOutput.writeInt(i3);
                i3 += this.mTrainer.mFeatureToCategoryCount.get(Integer.valueOf(i4)).size();
            }
            objectOutput.writeInt(i3);
            double[] dArr = new double[numSymbols2];
            Iterator<Map.Entry<Integer, ObjectToDoubleMap<Integer>>> it = this.mTrainer.mFeatureToCategoryCount.entrySet().iterator();
            while (it.hasNext()) {
                ObjectToDoubleMap<Integer> value = it.next().getValue();
                double idf = TfIdfClassifierTrainer.idf(value.size(), d);
                for (Map.Entry<Integer, Double> entry : value.entrySet()) {
                    int intValue = entry.getKey().intValue();
                    double tf = TfIdfClassifierTrainer.tf(entry.getValue().doubleValue()) * idf;
                    dArr[intValue] = dArr[intValue] + (tf * tf);
                }
            }
            for (int i5 = 0; i5 < dArr.length; i5++) {
                dArr[i5] = Math.sqrt(dArr[i5]);
            }
            for (int i6 = 0; i6 < numSymbols; i6++) {
                ObjectToDoubleMap<Integer> objectToDoubleMap = this.mTrainer.mFeatureToCategoryCount.get(Integer.valueOf(i6));
                double idf2 = TfIdfClassifierTrainer.idf(objectToDoubleMap.size(), d);
                for (Map.Entry<Integer, Double> entry2 : objectToDoubleMap.entrySet()) {
                    int intValue2 = entry2.getKey().intValue();
                    float tf2 = (float) ((TfIdfClassifierTrainer.tf(entry2.getValue().doubleValue()) * idf2) / dArr[intValue2]);
                    objectOutput.writeInt(intValue2);
                    objectOutput.writeFloat(tf2);
                }
            }
        }

        @Override // com.aliasi.util.AbstractExternalizable
        public Object read(ObjectInput objectInput) throws ClassNotFoundException, IOException {
            FeatureExtractor featureExtractor = (FeatureExtractor) objectInput.readObject();
            MapSymbolTable mapSymbolTable = (MapSymbolTable) objectInput.readObject();
            int numSymbols = mapSymbolTable.numSymbols();
            int readInt = objectInput.readInt();
            String[] strArr = new String[readInt];
            for (int i = 0; i < readInt; i++) {
                strArr[i] = objectInput.readUTF();
            }
            float[] fArr = new float[mapSymbolTable.numSymbols()];
            for (int i2 = 0; i2 < fArr.length; i2++) {
                fArr[i2] = objectInput.readFloat();
            }
            int[] iArr = new int[numSymbols + 1];
            for (int i3 = 0; i3 < numSymbols; i3++) {
                iArr[i3] = objectInput.readInt();
            }
            int readInt2 = objectInput.readInt();
            iArr[iArr.length - 1] = readInt2;
            int[] iArr2 = new int[readInt2];
            float[] fArr2 = new float[readInt2];
            for (int i4 = 0; i4 < readInt2; i4++) {
                iArr2[i4] = objectInput.readInt();
                fArr2[i4] = objectInput.readFloat();
            }
            return new TfIdfClassifier(featureExtractor, mapSymbolTable, strArr, fArr, iArr, iArr2, fArr2);
        }
    }

    /* loaded from: input_file:lib/lingpipe-4.1.0.jar:com/aliasi/classify/TfIdfClassifierTrainer$Serializer.class */
    static class Serializer<F> extends AbstractExternalizable {
        static final long serialVersionUID = -4757808688956812832L;
        final TfIdfClassifierTrainer<F> mTrainer;

        public Serializer() {
            this(null);
        }

        public Serializer(TfIdfClassifierTrainer<F> tfIdfClassifierTrainer) {
            this.mTrainer = tfIdfClassifierTrainer;
        }

        @Override // com.aliasi.util.AbstractExternalizable, java.io.Externalizable
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            AbstractExternalizable.serializeOrCompile(this.mTrainer.mFeatureExtractor, objectOutput);
            objectOutput.writeObject(this.mTrainer.mFeatureToCategoryCount);
            objectOutput.writeObject(this.mTrainer.mFeatureSymbolTable);
            objectOutput.writeObject(this.mTrainer.mCategorySymbolTable);
        }

        @Override // com.aliasi.util.AbstractExternalizable
        public Object read(ObjectInput objectInput) throws ClassNotFoundException, IOException {
            return new TfIdfClassifierTrainer((FeatureExtractor) objectInput.readObject(), (Map) objectInput.readObject(), (MapSymbolTable) objectInput.readObject(), (MapSymbolTable) objectInput.readObject());
        }
    }

    /* loaded from: input_file:lib/lingpipe-4.1.0.jar:com/aliasi/classify/TfIdfClassifierTrainer$TfIdfClassifier.class */
    static class TfIdfClassifier<G> implements ScoredClassifier<G> {
        final FeatureExtractor<? super G> mFeatureExtractor;
        final MapSymbolTable mFeatureSymbolTable;
        final String[] mCategories;
        final float[] mFeatureIdfs;
        final int[] mFeatureOffsets;
        final int[] mCategoryIds;
        final float[] mTfIdfs;

        TfIdfClassifier(FeatureExtractor<? super G> featureExtractor, MapSymbolTable mapSymbolTable, String[] strArr, float[] fArr, int[] iArr, int[] iArr2, float[] fArr2) {
            this.mFeatureExtractor = featureExtractor;
            this.mFeatureSymbolTable = mapSymbolTable;
            this.mCategories = strArr;
            this.mFeatureIdfs = fArr;
            this.mFeatureOffsets = iArr;
            this.mCategoryIds = iArr2;
            this.mTfIdfs = fArr2;
        }

        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("TfIdfClassifierTrainer.TfIdfClassifier\n");
            sb.append("Feature Symbol Table\n  ");
            sb.append(this.mFeatureSymbolTable.toString());
            sb.append("\n");
            sb.append("Categories\n");
            for (int i = 0; i < this.mCategories.length; i++) {
                sb.append("  " + i + "=" + this.mCategories[i] + "\n");
            }
            sb.append("Index  Feature IDF  offset\n");
            for (int i2 = 0; i2 < this.mFeatureIdfs.length; i2++) {
                sb.append("  " + i2 + "  " + this.mFeatureSymbolTable.idToSymbol(i2) + "   " + this.mFeatureIdfs[i2] + "   " + this.mFeatureOffsets[i2] + "\n");
            }
            sb.append("Index  CategoryID  TF-IDF\n");
            for (int i3 = 0; i3 < this.mCategoryIds.length; i3++) {
                sb.append("  " + i3 + "   " + this.mCategoryIds[i3] + "    " + this.mTfIdfs[i3] + "\n");
            }
            return sb.toString();
        }

        @Override // com.aliasi.classify.ScoredClassifier, com.aliasi.classify.RankedClassifier, com.aliasi.classify.BaseClassifier
        public ScoredClassification classify(G g) {
            Map<String, ? extends Number> features = this.mFeatureExtractor.features(g);
            double[] dArr = new double[this.mCategories.length];
            double d = 0.0d;
            for (Map.Entry<String, ? extends Number> entry : features.entrySet()) {
                int symbolToID = this.mFeatureSymbolTable.symbolToID(entry.getKey());
                if (symbolToID != -1) {
                    double tf = TfIdfClassifierTrainer.tf(entry.getValue().doubleValue()) * this.mFeatureIdfs[symbolToID];
                    d += tf * tf;
                    for (int i = this.mFeatureOffsets[symbolToID]; i < this.mFeatureOffsets[symbolToID + 1]; i++) {
                        int i2 = this.mCategoryIds[i];
                        dArr[i2] = dArr[i2] + (this.mTfIdfs[i] * tf);
                    }
                }
            }
            double sqrt = Math.sqrt(d);
            ArrayList arrayList = new ArrayList(this.mCategories.length);
            for (int i3 = 0; i3 < dArr.length; i3++) {
                arrayList.add(new ScoredObject(this.mCategories[i3], dArr[i3] / sqrt));
            }
            return ScoredClassification.create(arrayList);
        }

        /* JADX WARN: Multi-variable type inference failed */
        @Override // com.aliasi.classify.RankedClassifier, com.aliasi.classify.BaseClassifier
        public /* bridge */ /* synthetic */ RankedClassification classify(Object obj) {
            return classify((TfIdfClassifier<G>) obj);
        }

        /* JADX WARN: Multi-variable type inference failed */
        @Override // com.aliasi.classify.BaseClassifier
        public /* bridge */ /* synthetic */ Classification classify(Object obj) {
            return classify((TfIdfClassifier<G>) obj);
        }
    }

    public TfIdfClassifierTrainer(FeatureExtractor<? super E> featureExtractor) {
        this(featureExtractor, new HashMap(), new MapSymbolTable(), new MapSymbolTable());
    }

    TfIdfClassifierTrainer(FeatureExtractor<? super E> featureExtractor, Map<Integer, ObjectToDoubleMap<Integer>> map, MapSymbolTable mapSymbolTable, MapSymbolTable mapSymbolTable2) {
        this.mFeatureExtractor = featureExtractor;
        this.mFeatureToCategoryCount = map;
        this.mFeatureSymbolTable = mapSymbolTable;
        this.mCategorySymbolTable = mapSymbolTable2;
    }

    public double idf(String str) {
        if (this.mFeatureSymbolTable.symbolToIDInteger(str) == null) {
            return KStarConstants.FLOOR;
        }
        return idf(this.mFeatureToCategoryCount.get(r0).size(), this.mCategorySymbolTable.numSymbols());
    }

    public double tfIdf(String str, String str2) {
        Integer symbolToIDInteger = this.mFeatureSymbolTable.symbolToIDInteger(str);
        if (symbolToIDInteger == null) {
            return KStarConstants.FLOOR;
        }
        ObjectToDoubleMap<Integer> objectToDoubleMap = this.mFeatureToCategoryCount.get(symbolToIDInteger);
        Integer symbolToIDInteger2 = this.mCategorySymbolTable.symbolToIDInteger(str2);
        if (symbolToIDInteger2 == null) {
            return KStarConstants.FLOOR;
        }
        double value = objectToDoubleMap.getValue(symbolToIDInteger2);
        if (value == KStarConstants.FLOOR) {
            return KStarConstants.FLOOR;
        }
        return tf(value) * idf(objectToDoubleMap.size(), this.mCategorySymbolTable.numSymbols());
    }

    public double tf(String str, String str2) {
        Integer symbolToIDInteger = this.mFeatureSymbolTable.symbolToIDInteger(str);
        if (symbolToIDInteger == null) {
            return KStarConstants.FLOOR;
        }
        ObjectToDoubleMap<Integer> objectToDoubleMap = this.mFeatureToCategoryCount.get(symbolToIDInteger);
        Integer symbolToIDInteger2 = this.mCategorySymbolTable.symbolToIDInteger(str2);
        return symbolToIDInteger2 == null ? KStarConstants.FLOOR : tf(objectToDoubleMap.getValue(symbolToIDInteger2));
    }

    public FeatureExtractor<? super E> featureExtractor() {
        return this.mFeatureExtractor;
    }

    public Set<String> categories() {
        return this.mCategorySymbolTable.symbolSet();
    }

    void handle(E e, Classification classification) {
        int orAddSymbol = this.mCategorySymbolTable.getOrAddSymbol(classification.bestCategory());
        for (Map.Entry<String, ? extends Number> entry : this.mFeatureExtractor.features(e).entrySet()) {
            String key = entry.getKey();
            double doubleValue = entry.getValue().doubleValue();
            int orAddSymbol2 = this.mFeatureSymbolTable.getOrAddSymbol(key);
            ObjectToDoubleMap<Integer> objectToDoubleMap = this.mFeatureToCategoryCount.get(Integer.valueOf(orAddSymbol2));
            if (objectToDoubleMap == null) {
                objectToDoubleMap = new ObjectToDoubleMap<>();
                this.mFeatureToCategoryCount.put(Integer.valueOf(orAddSymbol2), objectToDoubleMap);
            }
            objectToDoubleMap.increment(Integer.valueOf(orAddSymbol), doubleValue);
        }
    }

    @Override // com.aliasi.corpus.ObjectHandler
    public void handle(Classified<E> classified) {
        handle(classified.getObject(), classified.getClassification());
    }

    @Override // com.aliasi.util.Compilable
    public void compileTo(ObjectOutput objectOutput) throws IOException {
        objectOutput.writeObject(new Externalizer(this));
    }

    Object writeReplace() {
        return new Serializer(this);
    }

    static double idf(double d, double d2) {
        return Math.log(d2 / d);
    }

    static double tf(double d) {
        return Math.sqrt(d);
    }
}
