package weka.classifiers;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.kstar.KStarConstants;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka-3-2/weka.jar:weka/classifiers/AdaBoostM1.class */
/**
 * Boosts a nominal-class base classifier using the AdaBoost.M1 method.
 *
 * <p>Each boosting iteration trains one copy of the base classifier. If resampling is requested,
 * or the base learner does not implement {@link WeightedInstancesHandler}, it is trained on a
 * sample drawn according to the current instance weights. Otherwise it is trained directly on
 * the weighted data.
 *
 * <p>After each iteration, the weight of every misclassified instance is multiplied by
 * {@code (1 - e) / e}, where {@code e} is the model's error rate. All weights are then rescaled
 * back to their previous total. Each model votes with weight {@code log((1 - e) / e)}.
 *
 * <p>Boosting stops early when a model's error rate is {@code >= 0.5} or exactly zero. The first
 * model is always kept.
 *
 * <p>Options: {@code -D}, {@code -I <num>}, {@code -P <num>}, {@code -Q}, {@code -S <num>} and
 * {@code -W <class name>}, optionally followed by {@code --} and options for the base classifier.
 */
public class AdaBoostM1 extends DistributionClassifier implements OptionHandler, WeightedInstancesHandler, Sourcable {

    /** How many times a resampled model with zero error is rebuilt on a fresh sample. */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;

    /** Copies of the base classifier, one per potential boosting iteration. */
    protected Classifier[] m_Classifiers;

    /** Voting weight {@code log((1 - e) / e)} of each accepted model. */
    protected double[] m_Betas;

    /** Number of models in {@link #m_Classifiers} that are used for prediction. */
    protected int m_NumIterations;

    /** Whether progress information is printed to {@code System.err}. */
    protected boolean m_Debug;

    /** Whether to train on weighted resamples instead of the weighted data itself. */
    protected boolean m_UseResampling;

    /** Number of classes seen at training time (used by {@link #toSource(String)}). */
    protected int m_NumClasses;

    /** The base classifier to boost. */
    protected Classifier m_Classifier = new ZeroR();

    /** Maximum number of boosting iterations. */
    protected int m_MaxIterations = 10;

    /** Percentage of the total weight mass each model is trained on (100 = all instances). */
    protected int m_WeightThreshold = 100;

    /** Seed for the random number generator used when resampling. */
    protected int m_Seed = 1;

    /**
     * Selects the heaviest instances whose accumulated weight exceeds the given fraction of the
     * total weight.
     *
     * <p>Instances with the same weight as the last selected one are also taken, so a run of
     * equal weights is never split.
     *
     * @param instances the weighted data to select from (not modified)
     * @param d the fraction of total weight mass to keep, e.g. {@code 0.9}
     * @return a new dataset holding copies of the selected instances, heaviest first
     */
    protected Instances selectWeightQuantile(Instances instances, double d) {
        int count = instances.numInstances();
        Instances selected = new Instances(instances, count);
        double[] weights = new double[count];
        double total = 0.0;
        for (int i = 0; i < count; i++) {
            weights[i] = instances.instance(i).weight();
            total += weights[i];
        }
        double threshold = total * d;
        int[] order = Utils.sort(weights);
        double accumulated = 0.0;
        // Walk the sorted indices from the end, i.e. from the heaviest instance downwards.
        for (int pos = count - 1; pos >= 0; pos--) {
            int idx = order[pos];
            selected.add((Instance) instances.instance(idx).copy());
            accumulated += weights[idx];
            // Stop once past the threshold, but only where the weight changes, so ties are kept together.
            if (accumulated > threshold && pos > 0 && weights[idx] != weights[order[pos - 1]]) {
                break;
            }
        }
        if (m_Debug) {
            System.err.println("Selected " + selected.numInstances() + " out of " + count);
        }
        return selected;
    }

    /**
     * Lists the options of this classifier.
     *
     * <p>If the base classifier handles options, its options are appended after a separator entry.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector options = new Vector(6);
        options.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        options.addElement(new Option("\tMaximum number of boost iterations.\n\t(default 10)", "I", 1, "-I <num>"));
        options.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        options.addElement(new Option("\tFull name of classifier to boost.\n\teg: weka.classifiers.NaiveBayes", "W", 1, "-W <class name>"));
        options.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        options.addElement(new Option("\tSeed for resampling. (Default 1)", "S", 1, "-S <num>"));
        if (m_Classifier instanceof OptionHandler) {
            options.addElement(new Option("", "", 0,
                    "\nOptions specific to classifier " + m_Classifier.getClass().getName() + ":"));
            Enumeration baseOptions = ((OptionHandler) m_Classifier).listOptions();
            while (baseOptions.hasMoreElements()) {
                options.addElement(baseOptions.nextElement());
            }
        }
        return options.elements();
    }

    /**
     * Parses the given options.
     *
     * <p>Absent options fall back to their defaults: 10 iterations, threshold 100 and seed 1.
     *
     * @param strArr the option strings
     * @throws Exception if {@code -Q} and {@code -P} are both given, if {@code -W} is missing,
     *     or if the base classifier cannot be created
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        setDebug(Utils.getFlag('D', strArr));

        String iterations = Utils.getOption('I', strArr);
        setMaxIterations(iterations.length() != 0 ? Integer.parseInt(iterations) : 10);

        String threshold = Utils.getOption('P', strArr);
        setWeightThreshold(threshold.length() != 0 ? Integer.parseInt(threshold) : 100);

        setUseResampling(Utils.getFlag('Q', strArr));
        if (m_UseResampling && threshold.length() != 0) {
            throw new Exception("Weight pruning with resampling not allowed.");
        }

        String seed = Utils.getOption('S', strArr);
        setSeed(seed.length() != 0 ? Integer.parseInt(seed) : 1);

        String classifierName = Utils.getOption('W', strArr);
        if (classifierName.length() == 0) {
            throw new Exception("A classifier must be specified with the -W option.");
        }
        setClassifier(Classifier.forName(classifierName, Utils.partitionOptions(strArr)));
    }

    /**
     * Returns the current settings as an option array.
     *
     * <p>The array includes the base classifier's options after {@code --}. Any unused trailing
     * slots are filled with empty strings.
     *
     * @return the option strings
     */
    @Override
    public String[] getOptions() {
        String[] baseOptions = new String[0];
        if (m_Classifier instanceof OptionHandler) {
            baseOptions = ((OptionHandler) m_Classifier).getOptions();
        }
        // Own options need at most 10 slots: -D, (-Q | -P n), -I n, -S n, -W name, --
        String[] options = new String[baseOptions.length + 10];
        int current = 0;
        if (getDebug()) {
            options[current++] = "-D";
        }
        if (getUseResampling()) {
            options[current++] = "-Q";
        } else {
            options[current++] = "-P";
            options[current++] = String.valueOf(getWeightThreshold());
        }
        options[current++] = "-I";
        options[current++] = String.valueOf(getMaxIterations());
        options[current++] = "-S";
        options[current++] = String.valueOf(getSeed());
        if (getClassifier() != null) {
            options[current++] = "-W";
            options[current++] = getClassifier().getClass().getName();
        }
        options[current++] = "--";
        System.arraycopy(baseOptions, 0, options, current, baseOptions.length);
        current += baseOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** Sets the base classifier to boost. */
    public void setClassifier(Classifier classifier) {
        this.m_Classifier = classifier;
    }

    /** Returns the base classifier to boost. */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /** Sets the maximum number of boosting iterations. */
    public void setMaxIterations(int i) {
        this.m_MaxIterations = i;
    }

    /** Returns the maximum number of boosting iterations. */
    public int getMaxIterations() {
        return this.m_MaxIterations;
    }

    /** Sets the percentage of weight mass used for training (100 = all). */
    public void setWeightThreshold(int i) {
        this.m_WeightThreshold = i;
    }

    /** Returns the percentage of weight mass used for training. */
    public int getWeightThreshold() {
        return this.m_WeightThreshold;
    }

    /** Sets the seed used for resampling. */
    public void setSeed(int i) {
        this.m_Seed = i;
    }

    /** Returns the seed used for resampling. */
    public int getSeed() {
        return this.m_Seed;
    }

    /** Turns debugging output on or off. */
    public void setDebug(boolean z) {
        this.m_Debug = z;
    }

    /** Returns whether debugging output is on. */
    public boolean getDebug() {
        return this.m_Debug;
    }

    /** Sets whether boosting trains on weighted resamples. */
    public void setUseResampling(boolean z) {
        this.m_UseResampling = z;
    }

    /** Returns whether boosting trains on weighted resamples. */
    public boolean getUseResampling() {
        return this.m_UseResampling;
    }

    /**
     * Builds the boosted ensemble.
     *
     * <p>Instances with a missing class are ignored. Resampling is used when requested, or when
     * the base classifier cannot handle instance weights.
     *
     * @param instances the training data (not modified)
     * @throws Exception if the data has string attributes or a numeric class, if no instance has
     *     a class value, if no base classifier is set, or if a base classifier fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        if (instances.checkForStringAttributes()) {
            throw new Exception("Can't handle string attributes!");
        }
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();
        if (data.numInstances() == 0) {
            throw new Exception("No train instances without class missing!");
        }
        if (data.classAttribute().isNumeric()) {
            throw new Exception("AdaBoostM1 can't handle a numeric class!");
        }
        if (m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        m_NumClasses = data.numClasses();
        m_Classifiers = Classifier.makeCopies(m_Classifier, getMaxIterations());
        if (m_UseResampling || !(m_Classifier instanceof WeightedInstancesHandler)) {
            buildClassifierUsingResampling(data);
        } else {
            buildClassifierWithWeights(data);
        }
    }

    /**
     * Runs boosting by training each model on a weighted resample of the data.
     *
     * <p>The training copy's weights are first normalised to sum to 1. A model with zero error is
     * rebuilt on fresh samples up to {@link #MAX_NUM_RESAMPLING_ITERATIONS} times.
     *
     * @param instances the cleaned training data (its weights are not modified)
     * @throws Exception if a base classifier fails
     */
    protected void buildClassifierUsingResampling(Instances instances) throws Exception {
        int numInstances = instances.numInstances();
        Random random = new Random(m_Seed);
        m_Betas = new double[m_Classifiers.length];
        m_NumIterations = 0;

        Instances training = new Instances(instances, 0, numInstances);
        double totalWeight = training.sumOfWeights();
        for (int i = 0; i < training.numInstances(); i++) {
            training.instance(i).setWeight(training.instance(i).weight() / totalWeight);
        }

        while (m_NumIterations < m_Classifiers.length) {
            if (m_Debug) {
                System.err.println("Training classifier " + (m_NumIterations + 1));
            }
            Instances sample = m_WeightThreshold < 100
                    ? selectWeightQuantile(training, m_WeightThreshold / 100.0)
                    : new Instances(training);
            double[] weights = new double[sample.numInstances()];
            for (int j = 0; j < weights.length; j++) {
                weights[j] = sample.instance(j).weight();
            }

            // A zero-error model would end boosting, so retry with fresh samples a few times first.
            double errorRate;
            int attempts = 0;
            do {
                m_Classifiers[m_NumIterations].buildClassifier(sample.resampleWithWeights(random, weights));
                errorRate = trainingErrorOfCurrentModel(instances, training);
                attempts++;
            } while (Utils.eq(errorRate, 0.0) && attempts < MAX_NUM_RESAMPLING_ITERATIONS);

            if (!acceptModelOrStop(errorRate, training)) {
                return;
            }
            m_NumIterations++;
        }
    }

    /**
     * Runs boosting by training each model directly on the weighted data.
     *
     * @param instances the cleaned training data (reweighting is applied to a copy)
     * @throws Exception if a base classifier fails
     */
    protected void buildClassifierWithWeights(Instances instances) throws Exception {
        int numInstances = instances.numInstances();
        m_Betas = new double[m_Classifiers.length];
        m_NumIterations = 0;
        Instances training = new Instances(instances, 0, numInstances);

        while (m_NumIterations < m_Classifiers.length) {
            if (m_Debug) {
                System.err.println("Training classifier " + (m_NumIterations + 1));
            }
            Instances trainData = m_WeightThreshold < 100
                    ? selectWeightQuantile(training, m_WeightThreshold / 100.0)
                    : new Instances(training, 0, numInstances);
            m_Classifiers[m_NumIterations].buildClassifier(trainData);
            double errorRate = trainingErrorOfCurrentModel(instances, training);
            if (!acceptModelOrStop(errorRate, training)) {
                return;
            }
            m_NumIterations++;
        }
    }

    /**
     * Evaluates the model at index {@code m_NumIterations} on the current (reweighted) training
     * copy.
     *
     * @param header dataset passed to the {@link Evaluation} constructor
     * @param training the reweighted training copy to evaluate on
     * @return the error rate reported by {@link Evaluation#errorRate()}
     */
    private double trainingErrorOfCurrentModel(Instances header, Instances training) throws Exception {
        Evaluation evaluation = new Evaluation(header);
        evaluation.evaluateModel(m_Classifiers[m_NumIterations], training);
        return evaluation.errorRate();
    }

    /**
     * Applies the AdaBoost.M1 update for the model at index {@code m_NumIterations}.
     *
     * <p>If the error rate is {@code >= 0.5} or exactly zero, boosting stops. In that case
     * {@code m_NumIterations} is raised to 1 if it was 0, so the first model is still used.
     * Otherwise this stores the model's beta and reweights {@code training}.
     *
     * @return {@code true} if boosting should continue, {@code false} if it must stop
     */
    private boolean acceptModelOrStop(double errorRate, Instances training) throws Exception {
        if (Utils.grOrEq(errorRate, 0.5) || Utils.eq(errorRate, 0.0)) {
            if (m_NumIterations == 0) {
                m_NumIterations = 1;
            }
            return false;
        }
        double reweight = (1.0 - errorRate) / errorRate;
        m_Betas[m_NumIterations] = Math.log(reweight);
        if (m_Debug) {
            System.err.println("\terror rate = " + errorRate + "  beta = " + m_Betas[m_NumIterations]);
        }
        reweightMisclassified(training, reweight);
        return true;
    }

    /**
     * Updates instance weights after a boosting iteration.
     *
     * <p>Multiplies the weight of every instance the current model misclassifies by
     * {@code reweight}. All weights are then rescaled so their total equals the total before the
     * update.
     */
    private void reweightMisclassified(Instances training, double reweight) throws Exception {
        Classifier current = m_Classifiers[m_NumIterations];
        double oldSum = training.sumOfWeights();
        Enumeration all = training.enumerateInstances();
        while (all.hasMoreElements()) {
            Instance inst = (Instance) all.nextElement();
            if (!Utils.eq(current.classifyInstance(inst), inst.classValue())) {
                inst.setWeight(inst.weight() * reweight);
            }
        }
        double newSum = training.sumOfWeights();
        all = training.enumerateInstances();
        while (all.hasMoreElements()) {
            Instance inst = (Instance) all.nextElement();
            inst.setWeight(inst.weight() * oldSum / newSum);
        }
    }

    /**
     * Returns the class distribution for an instance.
     *
     * <p>With a single model, that model's own distribution is returned if it has one; otherwise
     * its prediction gets all the mass. With several models, each adds its beta to the class it
     * predicts, and the sums are normalised.
     *
     * @param instance the instance to classify
     * @return the class probability estimates
     * @throws Exception if no model has been built or a base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_NumIterations == 0) {
            throw new Exception("No model built");
        }
        double[] sums = new double[instance.numClasses()];
        if (m_NumIterations == 1) {
            if (m_Classifiers[0] instanceof DistributionClassifier) {
                return ((DistributionClassifier) m_Classifiers[0]).distributionForInstance(instance);
            }
            sums[(int) m_Classifiers[0].classifyInstance(instance)] += 1.0;
        } else {
            // NOTE(review): a NaN (missing) prediction casts to class index 0 — confirm base learners never return one.
            for (int i = 0; i < m_NumIterations; i++) {
                sums[(int) m_Classifiers[i].classifyInstance(instance)] += m_Betas[i];
            }
        }
        Utils.normalize(sums);
        return sums;
    }

    /**
     * Generates Java source for the boosted model.
     *
     * <p>The output is a class named {@code str} plus one generated class {@code str_i} per model
     * used for prediction. Only the first {@code m_NumIterations} models are emitted: copies left
     * untrained after early stopping are not referenced by the generated code.
     *
     * @param str the name of the generated class
     * @return the generated source code
     * @throws Exception if no model has been built or the base learner is not {@link Sourcable}
     */
    @Override
    public String toSource(String str) throws Exception {
        if (m_NumIterations == 0) {
            throw new Exception("No model built yet");
        }
        if (!(m_Classifiers[0] instanceof Sourcable)) {
            throw new Exception("Base learner " + m_Classifier.getClass().getName() + " is not Sourcable");
        }
        StringBuffer text = new StringBuffer("class ");
        text.append(str).append(" {\n\n");
        text.append("  public static double classify(Object [] i) {\n");
        if (m_NumIterations == 1) {
            text.append("    return " + str + "_0.classify(i);\n");
        } else {
            text.append("    double [] sums = new double [" + m_NumClasses + "];\n");
            for (int i = 0; i < m_NumIterations; i++) {
                text.append("    sums[(int) " + str + '_' + i + ".classify(i)] += " + m_Betas[i] + ";\n");
            }
            text.append("    double maxV = sums[0];\n    int maxI = 0;\n    for (int j = 1; j < " + m_NumClasses + "; j++) {\n"
                    + "      if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n"
                    + "    }\n    return (double) maxI;\n");
        }
        text.append("  }\n}\n");
        // Only models referenced above are emitted; later copies may never have been trained.
        for (int i = 0; i < m_NumIterations; i++) {
            text.append(((Sourcable) m_Classifiers[i]).toSource(str + '_' + i));
        }
        return text.toString();
    }

    /**
     * Returns a textual description of the boosted model.
     *
     * @return the description of every model used for prediction, with its voting weight
     */
    @Override
    public String toString() {
        StringBuffer text = new StringBuffer();
        if (m_NumIterations == 0) {
            text.append("AdaBoostM1: No model built yet.\n");
        } else if (m_NumIterations == 1) {
            text.append("AdaBoostM1: No boosting possible, one classifier used!\n");
            text.append(m_Classifiers[0].toString() + "\n");
        } else {
            text.append("AdaBoostM1: Base classifiers and their weights: \n\n");
            for (int i = 0; i < m_NumIterations; i++) {
                text.append(m_Classifiers[i].toString() + "\n\n");
                text.append("Weight: " + Utils.roundDouble(m_Betas[i], 2) + "\n\n");
            }
            text.append("Number of performed Iterations: " + m_NumIterations + "\n");
        }
        return text.toString();
    }

    /**
     * Command-line entry point.
     *
     * <p>Evaluates an {@code AdaBoostM1} model with the given options and prints the result.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        try {
            System.out.println(Evaluation.evaluateModel(new AdaBoostM1(), strArr));
        } catch (Exception e) {
            System.err.println(e.getMessage());
        }
    }
}
