package weka.classifiers;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;

/* loaded from: input_file:weka-3-2/weka.jar:weka/classifiers/Stacking.class */
/**
 * Stacking ensemble: combines several base classifiers with a meta classifier.
 *
 * <p>During training the (class-non-missing, randomized) data is split into
 * {@code m_NumFolds} cross-validation folds. For each fold the base classifiers are
 * trained on the training split, and their predictions on the test split become one
 * meta-level instance per test instance. The base classifiers are then retrained on
 * all the data, and the meta classifier is trained on the collected meta-level data.
 * At prediction time the base classifiers' outputs for an instance are fed to the
 * meta classifier.
 */
public class Stacking extends Classifier implements OptionHandler {
    /** Header (no instances) of the meta-level data; {@code null} until a model is built. */
    protected Instances m_MetaFormat;
    /** Header (no instances) of the training data passed to {@link #buildClassifier}. */
    protected Instances m_BaseFormat;
    /** Classifier trained on the base classifiers' predictions. */
    protected Classifier m_MetaClassifier = new ZeroR();
    /** Level-0 classifiers whose predictions form the meta-level attributes. */
    protected Classifier[] m_BaseClassifiers = {new ZeroR()};
    /** Number of cross-validation folds used to generate meta-level training data. */
    protected int m_NumFolds = 10;
    /** Seed used to randomize the training data before splitting it into folds. */
    protected int m_Seed = 1;

    /**
     * Describes the command-line options accepted by {@link #setOptions}.
     *
     * @return an enumeration of {@link Option} objects for -B, -M, -X and -S
     */
    @Override // weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector options = new Vector(4);
        options.addElement(new Option("\tFull class name of base classifiers to include, followed by scheme options\n\t(may be specified multiple times).\n\teg: \"weka.classifiers.NaiveBayes -K\"", "B", 1, "-B <scheme specification>"));
        // -M takes one (quoted) scheme specification argument, exactly like -B.
        options.addElement(new Option("\tFull name of meta classifier, followed by options.", "M", 1, "-M <scheme specification>"));
        options.addElement(new Option("\tSets the number of cross-validation folds.", "X", 1, "-X <number of folds>"));
        options.addElement(new Option("\tSets the random number seed.", "S", 1, "-S <random number seed>"));
        return options.elements();
    }

    /**
     * Parses the options -X (folds, default 10), -S (seed, default 1), one or more -B
     * (base classifier specs) and -M (meta classifier spec). Consumed options are
     * removed from {@code options} by {@link Utils#getOption}.
     *
     * @param options the option array to parse
     * @throws Exception if no -B option is given, -M is missing, a specification is
     *     empty, a number cannot be parsed, or a classifier cannot be instantiated
     */
    @Override // weka.core.OptionHandler
    public void setOptions(String[] options) throws Exception {
        String foldsString = Utils.getOption('X', options);
        setNumFolds(foldsString.length() != 0 ? Integer.parseInt(foldsString) : 10);

        String seedString = Utils.getOption('S', options);
        setSeed(seedString.length() != 0 ? Integer.parseInt(seedString) : 1);

        FastVector parsed = new FastVector();
        String baseSpec;
        // getOption removes each -B it returns, so this loop visits every occurrence once.
        while ((baseSpec = Utils.getOption('B', options)).length() != 0) {
            parsed.addElement(classifierFromSpec(baseSpec, "Invalid classifier specification string"));
        }
        if (parsed.size() == 0) {
            throw new Exception("At least one base classifier must be specified with the -B option.");
        }
        Classifier[] bases = new Classifier[parsed.size()];
        for (int i = 0; i < bases.length; i++) {
            bases[i] = (Classifier) parsed.elementAt(i);
        }
        setBaseClassifiers(bases);

        setMetaClassifier(classifierFromSpec(Utils.getOption('M', options), "Meta classifier has to be provided."));
    }

    /**
     * Instantiates a classifier from a specification string of the form
     * "classname [options...]".
     *
     * @param spec the specification string (possibly empty)
     * @param emptyMessage message of the exception thrown when {@code spec} has no tokens
     * @return the configured classifier
     * @throws Exception if {@code spec} is empty or the classifier cannot be created
     */
    private static Classifier classifierFromSpec(String spec, String emptyMessage) throws Exception {
        String[] specOptions = Utils.splitOptions(spec);
        if (specOptions.length == 0) {
            throw new Exception(emptyMessage);
        }
        String className = specOptions[0];
        // Blank out the class name so only the scheme's own options are passed on.
        specOptions[0] = "";
        return Classifier.forName(className, specOptions);
    }

    /**
     * Returns the current settings as an option array: one "-B spec" pair per base
     * classifier, then -X, -S and (if set) -M. Unused trailing slots are "".
     *
     * @return the current option settings
     */
    @Override // weka.core.OptionHandler
    public String[] getOptions() {
        // 2 slots per base classifier + 6 for -X n, -S n, -M spec.
        String[] options = new String[m_BaseClassifiers.length * 2 + 6];
        int current = 0;
        for (int i = 0; i < m_BaseClassifiers.length; i++) {
            options[current++] = "-B";
            options[current++] = String.valueOf(getBaseClassifierSpec(i));
        }
        options[current++] = "-X";
        options[current++] = String.valueOf(getNumFolds());
        options[current++] = "-S";
        options[current++] = String.valueOf(getSeed());
        if (getMetaClassifier() != null) {
            options[current++] = "-M";
            options[current++] = getClassifierSpec(getMetaClassifier());
        }
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** @param i the seed used to randomize the data before fold generation */
    public void setSeed(int i) {
        this.m_Seed = i;
    }

    /** @return the seed used to randomize the data before fold generation */
    public int getSeed() {
        return this.m_Seed;
    }

    /** @return the number of cross-validation folds */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * <p>NOTE(review): 0 is accepted here even though the message says "positive"; with 0
     * folds {@link #buildClassifier} produces no meta-level training instances. Confirm
     * whether 0 should be rejected.
     *
     * @param i number of folds
     * @throws Exception if {@code i} is negative
     */
    public void setNumFolds(int i) throws Exception {
        if (i < 0) {
            throw new Exception("Stacking: Number of cross-validation folds must be positive.");
        }
        this.m_NumFolds = i;
    }

    /** @param classifierArr the base classifiers (stored by reference, not copied) */
    public void setBaseClassifiers(Classifier[] classifierArr) {
        this.m_BaseClassifiers = classifierArr;
    }

    /** @return the base classifier array (the internal array, not a copy) */
    public Classifier[] getBaseClassifiers() {
        return this.m_BaseClassifiers;
    }

    /**
     * @param i index of the base classifier
     * @return the {@code i}-th base classifier
     */
    public Classifier getBaseClassifier(int i) {
        return this.m_BaseClassifiers[i];
    }

    /** @param classifier the meta classifier */
    public void setMetaClassifier(Classifier classifier) {
        this.m_MetaClassifier = classifier;
    }

    /** @return the meta classifier */
    public Classifier getMetaClassifier() {
        return this.m_MetaClassifier;
    }

    /**
     * Builds the stacked model.
     *
     * @param instances training data; the class must be nominal or numeric
     * @throws Exception if no base or meta classifier is set, the class type is
     *     unsupported, no instance has a class value, or a component classifier fails
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        if (m_BaseClassifiers.length == 0) {
            throw new Exception("No base classifiers have been set");
        }
        if (m_MetaClassifier == null) {
            throw new Exception("No meta classifier has been set");
        }
        if (!instances.classAttribute().isNominal() && !instances.classAttribute().isNumeric()) {
            throw new Exception("Class attribute has to be nominal or numeric!");
        }
        // Work on a copy so the caller's instance order and contents are untouched.
        Instances data = new Instances(instances);
        m_BaseFormat = new Instances(instances, 0);
        data.deleteWithMissingClass();
        if (data.numInstances() == 0) {
            throw new Exception("No training instances without missing class!");
        }
        data.randomize(new Random(m_Seed));
        if (data.classAttribute().isNominal()) {
            data.stratify(m_NumFolds);
        }

        Instances metaData = metaFormat(data);
        m_MetaFormat = new Instances(metaData, 0);

        // Meta-level training data comes from predictions on held-out folds only.
        for (int fold = 0; fold < m_NumFolds; fold++) {
            Instances train = data.trainCV(m_NumFolds, fold);
            for (int b = 0; b < m_BaseClassifiers.length; b++) {
                getBaseClassifier(b).buildClassifier(train);
            }
            Instances test = data.testCV(m_NumFolds, fold);
            for (int j = 0; j < test.numInstances(); j++) {
                metaData.add(metaInstance(test.instance(j)));
            }
        }

        // Final base models use all the (class-non-missing) training data.
        for (int b = 0; b < m_BaseClassifiers.length; b++) {
            getBaseClassifier(b).buildClassifier(data);
        }
        m_MetaClassifier.buildClassifier(metaData);
    }

    /**
     * Classifies an instance by passing the base classifiers' outputs to the meta classifier.
     *
     * @param instance the instance to classify (in the training data's format)
     * @return the meta classifier's prediction
     * @throws Exception if a component classifier fails
     */
    @Override // weka.classifiers.Classifier
    public double classifyInstance(Instance instance) throws Exception {
        return m_MetaClassifier.classifyInstance(metaInstance(instance));
    }

    /**
     * Returns a textual description of the base and meta models, or a status message
     * if the scheme is incomplete or no model has been built.
     */
    public String toString() {
        if (m_BaseClassifiers.length == 0) {
            return "Stacking: No base schemes entered.";
        }
        if (m_MetaClassifier == null) {
            return "Stacking: No meta scheme selected.";
        }
        if (m_MetaFormat == null) {
            return "Stacking: No model built yet.";
        }
        StringBuffer text = new StringBuffer("Stacking\n\nBase classifiers\n\n");
        for (int i = 0; i < m_BaseClassifiers.length; i++) {
            text.append(getBaseClassifier(i).toString()).append("\n\n");
        }
        text.append("\n\nMeta classifier\n\n").append(m_MetaClassifier.toString());
        return text.toString();
    }

    /**
     * Command-line entry point: evaluates a Stacking model with the given options.
     * Errors are reported as their message on standard error.
     *
     * @param strArr command-line options
     */
    public static void main(String[] strArr) {
        try {
            System.out.println(Evaluation.evaluateModel(new Stacking(), strArr));
        } catch (Exception e) {
            System.err.println(e.getMessage());
        }
    }

    /**
     * Creates an empty dataset describing the meta-level data. It is built from
     * {@code m_BaseFormat}; the argument itself is not read.
     *
     * <p>Each base classifier contributes attributes depending on the class type:
     * <ul>
     *   <li>numeric class: one numeric attribute;
     *   <li>nominal class with a {@link DistributionClassifier}: one numeric attribute per
     *       class value;
     *   <li>otherwise: one nominal attribute with the class values.
     * </ul>
     * A copy of the class attribute is appended last and set as the class.
     *
     * @param instances the training data (unused; kept for subclass compatibility)
     * @return the empty meta-level dataset
     * @throws Exception declared for subclasses
     */
    protected Instances metaFormat(Instances instances) throws Exception {
        Attribute classAttribute = m_BaseFormat.classAttribute();
        FastVector attributes = new FastVector();
        for (int i = 0; i < m_BaseClassifiers.length; i++) {
            Classifier base = getBaseClassifier(i);
            String name = base.getClass().getName();
            if (classAttribute.isNumeric()) {
                attributes.addElement(new Attribute(name));
            } else if (base instanceof DistributionClassifier) {
                for (int j = 0; j < classAttribute.numValues(); j++) {
                    attributes.addElement(new Attribute(name + ":" + classAttribute.value(j)));
                }
            } else {
                FastVector values = new FastVector();
                for (int j = 0; j < classAttribute.numValues(); j++) {
                    values.addElement(classAttribute.value(j));
                }
                attributes.addElement(new Attribute(name, values));
            }
        }
        // Add a copy: the Instances constructor re-indexes the attributes it receives, and
        // the original class attribute object is shared with the training data.
        attributes.addElement(classAttribute.copy());
        Instances meta = new Instances("Meta format", attributes, 0);
        meta.setClassIndex(meta.numAttributes() - 1);
        return meta;
    }

    /**
     * @param i index of the base classifier
     * @return its "classname options" specification, or "" if {@code i} is out of range
     */
    protected String getBaseClassifierSpec(int i) {
        if (i < 0 || i >= m_BaseClassifiers.length) {
            return "";
        }
        return getClassifierSpec(getBaseClassifier(i));
    }

    /**
     * @param classifier the classifier to describe
     * @return its class name followed by its joined options when it is an
     *     {@link OptionHandler}, otherwise just its class name
     */
    protected String getClassifierSpec(Classifier classifier) {
        String name = classifier.getClass().getName();
        if (classifier instanceof OptionHandler) {
            return name + " " + Utils.joinOptions(((OptionHandler) classifier).getOptions());
        }
        return name;
    }

    /**
     * Converts an instance into a meta-level instance by filling the meta attributes
     * (in the layout produced by {@link #metaFormat}) with base classifier outputs and
     * copying the class value into the last slot.
     *
     * @param instance an instance in the training data's format
     * @return the meta-level instance, with weight 1 and dataset {@code m_MetaFormat}
     * @throws Exception if a base classifier fails
     */
    protected Instance metaInstance(Instance instance) throws Exception {
        double[] values = new double[m_MetaFormat.numAttributes()];
        int current = 0;
        boolean numericClass = m_BaseFormat.classAttribute().isNumeric();
        for (int i = 0; i < m_BaseClassifiers.length; i++) {
            Classifier base = getBaseClassifier(i);
            if (!numericClass && base instanceof DistributionClassifier) {
                double[] dist = ((DistributionClassifier) base).distributionForInstance(instance);
                for (int j = 0; j < dist.length; j++) {
                    values[current++] = dist[j];
                }
            } else {
                values[current++] = base.classifyInstance(instance);
            }
        }
        values[current] = instance.classValue();
        Instance meta = new Instance(1.0d, values);
        meta.setDataset(m_MetaFormat);
        return meta;
    }
}
