package weka.classifiers.trees.lmt;

import java.util.Collections;
import java.util.Vector;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.ModelSelection;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.gui.GenericObjectEditorHistory;

/* loaded from: input_file:lib/weka-3.7.9.jar:weka/classifiers/trees/lmt/LMTNode.class */
public class LMTNode extends LogisticBase {
    /** Serialization version identifier. */
    static final long serialVersionUID = 1862737145870398755L;
    /** Total instance weight handed to {@code buildTree}; {@code calculateAlphas} divides the error reduction by it. */
    protected double m_totalInstanceWeight;
    /** Node id written by {@code assignIDs} and used as the node name in the {@code graph} output. */
    protected int m_id;
    /** Leaf model number written by {@code assignLeafModelNumbers}; set to 0 for non-leaf nodes. */
    protected int m_leafModelNum;
    /** Alpha value written by {@code calculateAlphas}; {@code Double.MAX_VALUE} for leaves. */
    public double m_alpha;
    /** Number of incorrect predictions on {@code m_train} when this node is evaluated as a leaf (written by {@code modelErrors}). */
    public double m_numIncorrectModel;
    /** For a leaf, equal to {@code m_numIncorrectModel}; otherwise the sum over the sons (written by {@code treeErrors}). */
    public double m_numIncorrectTree;
    /** {@code buildTree} only looks for a split when the node holds more instances than this. */
    protected int m_minNumInstances;
    /** Strategy object asked for the split model in {@code buildTree}. */
    protected ModelSelection m_modelSelection;
    /** Filter created in {@code getNumericData} and reused in {@code modelDistributionForInstance}. */
    protected NominalToBinary m_nominalToBinary;
    /** Regression functions handed down from the parent node, indexed [class][regression]. */
    protected SimpleLinearRegression[][] m_higherRegressions;
    /** Length of {@code m_higherRegressions[0]}, set in {@code buildTree}. */
    protected int m_numHigherRegressions = 0;
    /** Number of cross-validation folds used by {@code buildClassifier} when selecting the pruning alpha. */
    protected static int m_numFoldsPruning = 5;
    /** When set and no non-negative fixed iteration count is given, {@code buildClassifier} obtains one from {@code tryLogistic}. */
    protected boolean m_fastRegression;
    /** Number of training instances at this node, set in {@code buildTree}. */
    protected int m_numInstances;
    /** Split model selected for this node in {@code buildTree}. */
    protected ClassifierSplitModel m_localModel;
    /** Child nodes; {@code null} when the node was never split or was collapsed during pruning. */
    protected LMTNode[] m_sons;
    /** Whether this node currently acts as a leaf. */
    protected boolean m_isLeaf;

    /**
     * Creates a node and stores the settings it needs for building the tree.
     *
     * @param modelSelection strategy used to select the split at each node
     * @param i fixed number of iterations (stored in {@code m_fixedNumIterations})
     * @param z value for {@code m_fastRegression}
     * @param z2 value for {@code m_errorOnProbabilities}
     * @param i2 minimum number of instances (stored in {@code m_minNumInstances})
     * @param d value passed to {@code setWeightTrimBeta}
     * @param z3 value passed to {@code setUseAIC}
     */
    public LMTNode(ModelSelection modelSelection, int i, boolean z, boolean z2, int i2, double d, boolean z3) {
        // NOTE(review): a GUI constant is used as the iteration cap; this looks like a
        // decompiler constant substitution for a plain literal -- confirm before replacing.
        m_maxIterations = GenericObjectEditorHistory.MAX_HISTORY_LENGTH;

        m_modelSelection = modelSelection;
        m_minNumInstances = i2;

        m_fixedNumIterations = i;
        m_fastRegression = z;
        m_errorOnProbabilities = z2;

        setWeightTrimBeta(d);
        setUseAIC(z3);
    }

    /**
     * Builds the tree: grows it, selects the pruning parameter alpha by
     * cross-validation and prunes the final tree with the selected value.
     *
     * <p>Steps, as implemented below:
     * <ol>
     *   <li>optionally fix the number of iterations via {@code tryLogistic};</li>
     *   <li>for each of {@code m_numFoldsPruning} folds, grow a tree on the training
     *       part and record the alpha sequence and test error rates while pruning it;</li>
     *   <li>grow the tree on all data and record its own alpha sequence;</li>
     *   <li>for every geometric-midpoint alpha of the final tree, sum the fold errors
     *       that correspond to that alpha and pick the alpha with the lowest sum;</li>
     *   <li>unprune, prune with the chosen alpha and clean up.</li>
     * </ol>
     *
     * @param instances the training data
     * @throws Exception if building or evaluating a tree fails
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase, weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        if (m_fastRegression && m_fixedNumIterations < 0) {
            m_fixedNumIterations = tryLogistic(instances);
        }

        Instances cvData = new Instances(instances);
        cvData.stratify(m_numFoldsPruning);

        // One row per fold; each row is allocated once that fold's tree size is known.
        double[][] foldAlphas = new double[m_numFoldsPruning][];
        double[][] foldErrors = new double[m_numFoldsPruning][];
        for (int fold = 0; fold < m_numFoldsPruning; fold++) {
            Instances train = cvData.trainCV(m_numFoldsPruning, fold);
            Instances test = cvData.testCV(m_numFoldsPruning, fold);
            // NOTE(review): KStarConstants.FLOOR comes from an unrelated package and is used
            // as the initial parameter count; presumably a decompiler substitution -- confirm.
            buildTree(train, (SimpleLinearRegression[][]) null, train.numInstances(), KStarConstants.FLOOR);
            int foldInnerNodes = getNumInnerNodes();
            foldAlphas[fold] = new double[foldInnerNodes + 2];
            foldErrors[fold] = new double[foldInnerNodes + 2];
            prune(foldAlphas[fold], foldErrors[fold], test);
        }

        // Final tree on all data; only its alpha sequence is recorded (no test set).
        buildTree(instances, (SimpleLinearRegression[][]) null, instances.numInstances(), KStarConstants.FLOOR);
        int innerNodes = getNumInnerNodes();
        double[] treeAlphas = new double[innerNodes + 2];
        int iterations = prune(treeAlphas, null, null);

        double[] treeErrors = new double[innerNodes + 2];
        for (int step = 0; step <= iterations; step++) {
            // Geometric midpoint between two consecutive alphas of the final tree.
            double midAlpha = Math.sqrt(treeAlphas[step] * treeAlphas[step + 1]);
            double errorSum = 0.0d;
            for (int fold = 0; fold < m_numFoldsPruning; fold++) {
                // Advance past every fold alpha that does not exceed the midpoint.
                int pos = 0;
                while (foldAlphas[fold][pos] <= midAlpha) {
                    pos++;
                }
                errorSum += foldErrors[fold][pos - 1];
            }
            treeErrors[step] = errorSum;
        }

        // Scan from the most pruned tree down; only a strictly smaller error replaces the best.
        int best = -1;
        double bestError = Double.MAX_VALUE;
        for (int step = iterations; step >= 0; step--) {
            if (treeErrors[step] < bestError) {
                bestError = treeErrors[step];
                best = step;
            }
        }
        double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

        // Restore the inner nodes that were only flagged as leaves, then prune for real.
        unprune();
        prune(bestAlpha);
        cleanup();
    }

    /**
     * Fits the model at this node and, if a useful split is found, recursively
     * builds the child nodes.
     *
     * @param instances training data for this node
     * @param simpleLinearRegressionArr regression functions handed down from the parent,
     *        or {@code null} if there are none
     * @param d total instance weight, stored in {@code m_totalInstanceWeight}
     * @param d2 initial value for {@code m_numParameters}
     * @throws Exception if fitting the model or selecting the split fails
     */
    public void buildTree(Instances instances, SimpleLinearRegression[][] simpleLinearRegressionArr, double d, double d2) throws Exception {
        // Reset the node state; a freshly built node starts out as a leaf.
        m_totalInstanceWeight = d;
        m_train = new Instances(instances);
        m_isLeaf = true;
        m_sons = null;
        m_numInstances = m_train.numInstances();
        m_numClasses = m_train.numClasses();
        m_numericData = getNumericData(m_train);
        m_numericDataHeader = new Instances(m_numericData, 0);
        m_regressions = initRegressions();
        m_numRegressions = 0;
        m_higherRegressions = (simpleLinearRegressionArr == null)
                ? new SimpleLinearRegression[m_numClasses][0]
                : simpleLinearRegressionArr;
        m_numHigherRegressions = m_higherRegressions[0].length;
        m_numParameters = d2;

        // Boosting is skipped when the node holds fewer instances than m_numFoldsBoosting.
        if (m_numInstances >= m_numFoldsBoosting) {
            if (m_fixedNumIterations > 0) {
                performBoosting(m_fixedNumIterations);
            } else if (!getUseAIC()) {
                performBoostingCV();
            } else {
                performBoostingInfCriterion();
            }
        }
        m_numParameters += m_numRegressions;
        m_regressions = selectRegressions(m_regressions);

        // No split is attempted unless the node holds more than m_minNumInstances instances.
        if (m_numInstances <= m_minNumInstances) {
            return;
        }
        if (m_modelSelection instanceof ResidualModelSelection) {
            double[][] probabilities = getProbs(getFs(m_numericData));
            double[][] actuals = getYs(m_train);
            ResidualModelSelection residualSelection = (ResidualModelSelection) m_modelSelection;
            m_localModel = residualSelection.selectModel(m_train, getZs(probabilities, actuals), getWs(probabilities, actuals));
        } else {
            m_localModel = m_modelSelection.selectModel(m_train);
        }
        if (m_localModel.numSubsets() <= 1) {
            return;
        }

        m_isLeaf = false;
        Instances[] subsets = m_localModel.split(m_train);
        m_sons = new LMTNode[m_localModel.numSubsets()];
        for (int branch = 0; branch < m_sons.length; branch++) {
            LMTNode son = new LMTNode(m_modelSelection, m_fixedNumIterations, m_fastRegression, m_errorOnProbabilities, m_minNumInstances, getWeightTrimBeta(), getUseAIC());
            m_sons[branch] = son;
            son.buildTree(subsets[branch], mergeArrays(m_regressions, m_higherRegressions), m_totalInstanceWeight, m_numParameters);
            // Drop the reference so the subset can be reclaimed.
            subsets[branch] = null;
        }
    }

    /**
     * Prunes the tree: repeatedly collapses the inner node with the smallest alpha
     * until that smallest alpha exceeds the given threshold or no inner node is left.
     * Collapsed nodes lose their sons.
     *
     * @param d the alpha threshold
     * @throws Exception if computing the errors fails
     */
    public void prune(double d) throws Exception {
        CompareNode byAlpha = new CompareNode();
        modelErrors();
        treeErrors();
        calculateAlphas();

        Vector innerNodes = getNodes();
        while (!innerNodes.isEmpty()) {
            LMTNode weakest = (LMTNode) Collections.min(innerNodes, byAlpha);
            if (weakest.m_alpha > d) {
                return;
            }
            weakest.m_isLeaf = true;
            weakest.m_sons = null;

            // Errors and alphas above the collapsed node have changed.
            treeErrors();
            calculateAlphas();
            innerNodes = getNodes();
        }
    }

    /**
     * Prunes the tree all the way back while recording, for every step, the alpha of
     * the node that was turned into a leaf and (optionally) the error rate on a test set.
     * Pruned nodes are only flagged as leaves here; their sons are not cleared by this
     * method itself.
     *
     * @param dArr receives the alphas: index 0 is set to 0, index {@code k} holds the alpha
     *        of the k-th pruned node, and the entry after the last step is set to 1
     * @param dArr2 receives the error rate after each step (index 0 is the unpruned tree),
     *        or {@code null} to skip evaluation
     * @param instances test data used for the error rates; only used when {@code dArr2}
     *        is not {@code null}
     * @return the number of pruning steps performed
     * @throws Exception if evaluation fails
     */
    public int prune(double[] dArr, double[] dArr2, Instances instances) throws Exception {
        CompareNode byAlpha = new CompareNode();
        modelErrors();
        treeErrors();
        calculateAlphas();
        Vector innerNodes = getNodes();

        final boolean recordErrors = dArr2 != null;
        dArr[0] = 0.0d;
        if (recordErrors) {
            Evaluation unprunedEval = new Evaluation(instances);
            unprunedEval.evaluateModel(this, instances, new Object[0]);
            dArr2[0] = unprunedEval.errorRate();
        }

        int step = 0;
        while (!innerNodes.isEmpty()) {
            step++;
            LMTNode weakest = (LMTNode) Collections.min(innerNodes, byAlpha);
            weakest.m_isLeaf = true;
            dArr[step] = weakest.m_alpha;
            if (recordErrors) {
                Evaluation stepEval = new Evaluation(instances);
                stepEval.evaluateModel(this, instances, new Object[0]);
                dArr2[step] = stepEval.errorRate();
            }
            treeErrors();
            calculateAlphas();
            innerNodes = getNodes();
        }

        // Terminating entry after the last recorded alpha.
        dArr[step + 1] = 1.0d;
        return step;
    }

    /**
     * Turns every node that still has sons back into an inner node, recursively.
     * Nodes whose sons are {@code null} are left untouched.
     */
    protected void unprune() {
        if (m_sons == null) {
            return;
        }
        m_isLeaf = false;
        for (LMTNode son : m_sons) {
            son.unprune();
        }
    }

    /**
     * Builds a stand-alone {@code LogisticBase} model on the nominal-to-binary filtered
     * data and reports how many regressions it ended up with.
     *
     * @param instances the training data
     * @return the value of {@code getNumRegressions()} of the stand-alone model
     * @throws Exception if filtering or building the model fails
     */
    protected int tryLogistic(Instances instances) throws Exception {
        Instances copy = new Instances(instances);

        NominalToBinary toBinary = new NominalToBinary();
        toBinary.setInputFormat(copy);
        Instances filtered = Filter.useFilter(copy, toBinary);

        LogisticBase standalone = new LogisticBase(0, true, m_errorOnProbabilities);
        // NOTE(review): GUI constant used as iteration cap; presumably a decompiler
        // substitution for a literal -- confirm.
        standalone.setMaxIterations(GenericObjectEditorHistory.MAX_HISTORY_LENGTH);
        standalone.setWeightTrimBeta(getWeightTrimBeta());
        standalone.setUseAIC(getUseAIC());
        standalone.buildClassifier(filtered);

        return standalone.getNumRegressions();
    }

    /**
     * Counts the inner (non-leaf) nodes of the subtree rooted at this node.
     *
     * @return 0 for a leaf, otherwise 1 plus the inner-node counts of all sons
     */
    public int getNumInnerNodes() {
        if (m_isLeaf) {
            return 0;
        }
        int count = 1;
        for (LMTNode son : m_sons) {
            count += son.getNumInnerNodes();
        }
        return count;
    }

    /**
     * Counts the leaves below this node. Among the direct sons, leaves that have no
     * models ({@code hasModels()} is false) are counted only once in total.
     *
     * @return 1 for a leaf, otherwise the adjusted sum over the sons
     */
    public int getNumLeaves() {
        if (m_isLeaf) {
            return 1;
        }
        int leaves = 0;
        int emptyLeaves = 0;
        for (LMTNode son : m_sons) {
            leaves += son.getNumLeaves();
            if (son.m_isLeaf && !son.hasModels()) {
                emptyLeaves++;
            }
        }
        // Several model-less leaf sons together count as a single leaf.
        if (emptyLeaves > 1) {
            leaves -= emptyLeaves - 1;
        }
        return leaves;
    }

    /**
     * Stores in {@code m_numIncorrectModel} the number of incorrect predictions on
     * {@code m_train} that this node makes when it is evaluated as a leaf, then does
     * the same for all sons of an inner node.
     *
     * <p>An inner node is temporarily flagged as a leaf for the evaluation. The flag is
     * restored in a {@code finally} block so that a failing evaluation cannot leave the
     * node permanently marked as a leaf.
     *
     * @throws Exception if the evaluation fails
     */
    public void modelErrors() throws Exception {
        Evaluation eval = new Evaluation(m_train);
        if (m_isLeaf) {
            eval.evaluateModel(this, m_train, new Object[0]);
            m_numIncorrectModel = eval.incorrect();
            return;
        }

        m_isLeaf = true;
        try {
            eval.evaluateModel(this, m_train, new Object[0]);
        } finally {
            m_isLeaf = false;
        }
        m_numIncorrectModel = eval.incorrect();

        for (int i = 0; i < m_sons.length; i++) {
            m_sons[i].modelErrors();
        }
    }

    /**
     * Updates {@code m_numIncorrectTree} for this node and all nodes below it: a leaf
     * takes its own {@code m_numIncorrectModel}, an inner node the sum over its sons.
     */
    public void treeErrors() {
        if (m_isLeaf) {
            m_numIncorrectTree = m_numIncorrectModel;
            return;
        }
        double subtreeErrors = KStarConstants.FLOOR;
        for (LMTNode son : m_sons) {
            son.treeErrors();
            subtreeErrors += son.m_numIncorrectTree;
        }
        m_numIncorrectTree = subtreeErrors;
    }

    /**
     * Computes {@code m_alpha} for this node and the nodes below it.
     *
     * <p>Leaves get {@code Double.MAX_VALUE}. An inner node whose subtree does not
     * reduce the error compared to its own model is collapsed into a leaf (its sons are
     * discarded). Otherwise alpha is the error reduction divided by the total instance
     * weight and by the number of leaves minus one.
     *
     * @throws Exception declared by the original signature
     */
    public void calculateAlphas() throws Exception {
        if (m_isLeaf) {
            m_alpha = Double.MAX_VALUE;
            return;
        }

        double errorReduction = m_numIncorrectModel - m_numIncorrectTree;
        if (errorReduction <= KStarConstants.FLOOR) {
            // The subtree is no better than the node's own model: collapse it for good.
            m_isLeaf = true;
            m_sons = null;
            m_alpha = Double.MAX_VALUE;
            return;
        }

        m_alpha = (errorReduction / m_totalInstanceWeight) / (getNumLeaves() - 1);
        for (LMTNode son : m_sons) {
            son.calculateAlphas();
        }
    }

    /**
     * Concatenates two per-class regression arrays: for every class the result holds
     * the entries of the first array followed by the entries of the second.
     *
     * @param simpleLinearRegressionArr first array, indexed [class][regression]
     * @param simpleLinearRegressionArr2 second array, indexed [class][regression]
     * @return a new array with {@code m_numClasses} rows
     */
    protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] simpleLinearRegressionArr, SimpleLinearRegression[][] simpleLinearRegressionArr2) {
        // Row 0 determines the number of entries taken from each source.
        final int firstCount = simpleLinearRegressionArr[0].length;
        final int secondCount = simpleLinearRegressionArr2[0].length;

        SimpleLinearRegression[][] merged = new SimpleLinearRegression[m_numClasses][firstCount + secondCount];
        for (int cls = 0; cls < m_numClasses; cls++) {
            SimpleLinearRegression[] row = merged[cls];
            for (int k = 0; k < firstCount; k++) {
                row[k] = simpleLinearRegressionArr[cls][k];
            }
            for (int k = 0; k < secondCount; k++) {
                row[firstCount + k] = simpleLinearRegressionArr2[cls][k];
            }
        }
        return merged;
    }

    /**
     * Collects all inner nodes of the subtree rooted at this node.
     *
     * @return a new vector holding the inner nodes
     */
    public Vector getNodes() {
        Vector innerNodes = new Vector();
        getNodes(innerNodes);
        return innerNodes;
    }

    /**
     * Adds this node and all inner nodes below it to the given vector; leaves add nothing.
     *
     * @param vector the vector that receives the inner nodes
     */
    public void getNodes(Vector vector) {
        if (!m_isLeaf) {
            vector.add(this);
            for (LMTNode son : m_sons) {
                son.getNodes(vector);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs the data through a fresh {@code NominalToBinary} filter (kept in
     * {@code m_nominalToBinary}) and passes the result to the superclass implementation.
     *
     * @param instances the data to convert
     * @return whatever {@code super.getNumericData} returns for the filtered data
     * @throws Exception if filtering fails
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase
    public Instances getNumericData(Instances instances) throws Exception {
        Instances copy = new Instances(instances);
        NominalToBinary toBinary = new NominalToBinary();
        m_nominalToBinary = toBinary;
        toBinary.setInputFormat(copy);
        Instances binarized = Filter.useFilter(copy, toBinary);
        return super.getNumericData(binarized);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the per-class F-values for an instance: the values from the superclass
     * plus the contribution of every regression handed down from the parent node.
     * Each contribution is the class prediction minus the mean over all classes,
     * scaled by {@code (numClasses - 1) / numClasses}.
     *
     * @param instance the instance to compute the values for
     * @return one value per class
     * @throws Exception if a regression function fails to classify the instance
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase
    public double[] getFs(Instance instance) throws Exception {
        double[] predictions = new double[m_numClasses];
        double[] fs = super.getFs(instance);

        for (int reg = 0; reg < m_numHigherRegressions; reg++) {
            double sum = 0.0d;
            for (int cls = 0; cls < m_numClasses; cls++) {
                predictions[cls] = m_higherRegressions[cls][reg].classifyInstance(instance);
                sum += predictions[cls];
            }
            double mean = sum / m_numClasses;
            for (int cls = 0; cls < m_numClasses; cls++) {
                fs[cls] += ((predictions[cls] - mean) * (m_numClasses - 1)) / m_numClasses;
            }
        }
        return fs;
    }

    /**
     * Tells whether at least one regression was fitted at this node.
     *
     * @return {@code true} if {@code m_numRegressions} is positive
     */
    public boolean hasModels() {
        return m_numRegressions > 0;
    }

    /**
     * Computes the class probabilities from the model at this node, ignoring any sons.
     * A copy of the instance is pushed through the stored nominal-to-binary filter and
     * attached to the numeric data header before the F-values are computed.
     *
     * @param instance the instance to classify
     * @return the result of {@code probs} applied to the F-values
     * @throws Exception if filtering or computing the F-values fails
     */
    public double[] modelDistributionForInstance(Instance instance) throws Exception {
        Instance copy = (Instance) instance.copy();
        m_nominalToBinary.input(copy);
        Instance filtered = m_nominalToBinary.output();
        filtered.setDataset(m_numericDataHeader);
        double[] fs = getFs(filtered);
        return probs(fs);
    }

    /**
     * Returns the class distribution for an instance: a leaf uses its own model, an
     * inner node delegates to the son selected by the local split model.
     *
     * @param instance the instance to classify
     * @return the class distribution
     * @throws Exception if the distribution cannot be computed
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase, weka.classifiers.AbstractClassifier, weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_isLeaf) {
            return modelDistributionForInstance(instance);
        }
        int branch = m_localModel.whichSubset(instance);
        return m_sons[branch].distributionForInstance(instance);
    }

    /**
     * Counts the leaves of the subtree rooted at this node.
     *
     * @return 1 for a leaf, otherwise the sum over all sons
     */
    public int numLeaves() {
        if (m_isLeaf) {
            return 1;
        }
        int leaves = 0;
        for (LMTNode son : m_sons) {
            leaves += son.numLeaves();
        }
        return leaves;
    }

    /**
     * Counts all nodes (inner nodes and leaves) of the subtree rooted at this node.
     *
     * @return 1 for a leaf, otherwise 1 plus the node counts of all sons
     */
    public int numNodes() {
        int nodes = 1;
        if (!m_isLeaf) {
            for (LMTNode son : m_sons) {
                nodes += son.numNodes();
            }
        }
        return nodes;
    }

    /**
     * Returns a textual description of the tree: its structure, the number of leaves,
     * the number of nodes and the models at the leaves. Leaf model numbers are
     * (re)assigned first. Any exception while building the text yields a fixed
     * fallback message.
     *
     * @return the description, or the fallback message on failure
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase
    public String toString() {
        assignLeafModelNumbers(0);
        try {
            StringBuffer text = new StringBuffer();
            if (!m_isLeaf) {
                dumpTree(0, text);
            } else {
                text.append(": ");
                text.append("LM_" + m_leafModelNum + ":" + getModelParameters());
            }
            text.append("\n\nNumber of Leaves  : \t" + numLeaves() + "\n");
            text.append("\nSize of the Tree : \t" + numNodes() + "\n");
            text.append(modelsToString());
            return text.toString();
        } catch (Exception ignored) {
            // Deliberate best effort: printing must not fail.
            return "Can't print logistic model tree";
        }
    }

    /**
     * Returns a short summary of the model at this node in the form
     * {@code "<own regressions>/<own + inherited regressions> (<instances>)"}.
     *
     * @return the summary string
     */
    public String getModelParameters() {
        return m_numRegressions + "/" + (m_numRegressions + m_numHigherRegressions) + " (" + m_numInstances + ")";
    }

    /**
     * Appends a textual rendering of the subtree below this node, one branch per line,
     * indented with {@code "|   "} per level.
     *
     * @param i the current depth (number of indentation units)
     * @param stringBuffer the buffer to append to
     * @throws Exception if the split model cannot describe a branch
     */
    protected void dumpTree(int i, StringBuffer stringBuffer) throws Exception {
        for (int branch = 0; branch < m_sons.length; branch++) {
            stringBuffer.append("\n");
            for (int level = 0; level < i; level++) {
                stringBuffer.append("|   ");
            }
            stringBuffer.append(m_localModel.leftSide(m_train));
            stringBuffer.append(m_localModel.rightSide(branch, m_train));

            LMTNode son = m_sons[branch];
            if (!son.m_isLeaf) {
                son.dumpTree(i + 1, stringBuffer);
            } else {
                stringBuffer.append(": ");
                stringBuffer.append("LM_" + son.m_leafModelNum + ":" + son.getModelParameters());
            }
        }
    }

    /**
     * Numbers this node and every node reachable through {@code m_sons} consecutively,
     * starting with {@code i + 1} for this node.
     *
     * @param i the last id that was handed out before this node
     * @return the last id handed out in this subtree
     */
    public int assignIDs(int i) {
        int lastId = i + 1;
        m_id = lastId;
        if (m_sons == null) {
            return lastId;
        }
        for (LMTNode son : m_sons) {
            lastId = son.assignIDs(lastId);
        }
        return lastId;
    }

    /**
     * Numbers the leaves of this subtree consecutively; inner nodes get number 0.
     *
     * @param i the last leaf number handed out before this subtree
     * @return the last leaf number handed out in this subtree
     */
    public int assignLeafModelNumbers(int i) {
        if (m_isLeaf) {
            int number = i + 1;
            m_leafModelNum = number;
            return number;
        }
        m_leafModelNum = 0;
        int last = i;
        for (LMTNode son : m_sons) {
            last = son.assignLeafModelNumbers(last);
        }
        return last;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the coefficients of the model at this node: the coefficients reported by
     * the superclass plus the contribution of every regression handed down from the
     * parent node. Each inherited regression adds its intercept to column 0 and its
     * slope to column {@code attributeIndex + 1}, both scaled by
     * {@code (numClasses - 1) / numClasses}.
     *
     * <p>The scaling factor is computed in floating point. In integer arithmetic the
     * quotient truncates to 0 for any positive number of classes, which would silently
     * drop the inherited regressions from the result.
     *
     * @return the coefficient array, indexed [class][0 = intercept, attribute + 1]
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase
    public double[][] getCoefficients() {
        double[][] coefficients = super.getCoefficients();
        // (J - 1) / J, the same factor getFs applies to the inherited regressions.
        double constFactor = (m_numClasses - 1) / (double) m_numClasses;
        for (int cls = 0; cls < m_numClasses; cls++) {
            for (int reg = 0; reg < m_numHigherRegressions; reg++) {
                SimpleLinearRegression regression = m_higherRegressions[cls][reg];
                double slope = regression.getSlope();
                double intercept = regression.getIntercept();
                int attributeIndex = regression.getAttributeIndex();
                coefficients[cls][0] += constFactor * intercept;
                coefficients[cls][attributeIndex + 1] += constFactor * slope;
            }
        }
        return coefficients;
    }

    /**
     * Returns the textual form of all leaf models in this subtree. A leaf yields its
     * label followed by the superclass description; an inner node yields the texts of
     * its sons, each preceded by a newline.
     *
     * @return the concatenated model descriptions
     */
    public String modelsToString() {
        if (m_isLeaf) {
            return "LM_" + m_leafModelNum + ":" + super.toString();
        }
        StringBuilder text = new StringBuilder();
        for (LMTNode son : m_sons) {
            text.append("\n" + son.modelsToString());
        }
        return text.toString();
    }

    /**
     * Returns the tree as a {@code digraph} description. Node ids and leaf model
     * numbers are (re)assigned before the text is generated.
     *
     * @return the graph description
     * @throws Exception if the split model cannot describe a node
     */
    public String graph() throws Exception {
        StringBuffer text = new StringBuffer();
        assignIDs(-1);
        assignLeafModelNumbers(0);

        text.append("digraph LMTree {\n");
        if (!m_isLeaf) {
            text.append("N" + m_id + " [label=\"" + m_localModel.leftSide(m_train) + "\" ");
            text.append("]\n");
            graphTree(text);
        } else {
            text.append("N" + m_id + " [label=\"LM_" + m_leafModelNum + ":" + getModelParameters() + "\" shape=box style=filled");
            text.append("]\n");
        }
        text.append("}\n");
        return text.toString();
    }

    /**
     * Appends the edges to all sons and the node descriptions of the sons, recursing
     * into sons that are inner nodes.
     *
     * @param stringBuffer the buffer to append to
     * @throws Exception if the split model cannot describe a branch
     */
    private void graphTree(StringBuffer stringBuffer) throws Exception {
        for (int branch = 0; branch < m_sons.length; branch++) {
            LMTNode son = m_sons[branch];
            String edgeLabel = m_localModel.rightSide(branch, m_train).trim();
            stringBuffer.append("N" + m_id + "->N" + son.m_id + " [label=\"" + edgeLabel + "\"]\n");

            if (!son.m_isLeaf) {
                // The son's split attribute is looked up against this node's training data.
                stringBuffer.append("N" + son.m_id + " [label=\"" + son.m_localModel.leftSide(m_train) + "\" ");
                stringBuffer.append("]\n");
                son.graphTree(stringBuffer);
            } else {
                stringBuffer.append("N" + son.m_id + " [label=\"LM_" + son.m_leafModelNum + ":" + son.getModelParameters() + "\" shape=box style=filled");
                stringBuffer.append("]\n");
            }
        }
    }

    /**
     * Runs the superclass cleanup on this node and then on every node below it.
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase
    public void cleanup() {
        super.cleanup();
        if (!m_isLeaf) {
            for (LMTNode son : m_sons) {
                son.cleanup();
            }
        }
    }

    /**
     * Returns the revision string extracted from the embedded revision tag.
     *
     * @return the result of {@code RevisionUtils.extract} for the tag
     */
    @Override // weka.classifiers.trees.lmt.LogisticBase, weka.classifiers.AbstractClassifier, weka.core.RevisionHandler
    public String getRevision() {
        final String revisionTag = "$Revision: 8034 $";
        return RevisionUtils.extract(revisionTag);
    }
}
