/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.meta;

import edu.udo.cs.yale.operator.ContainerModel;
import edu.udo.cs.yale.operator.IOContainer;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorChain;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.Value;
import edu.udo.cs.yale.operator.condition.InnerOperatorCondition;
import edu.udo.cs.yale.operator.condition.SimpleChainInnerOperatorCondition;
import edu.udo.cs.yale.operator.learner.PredictionModel;
import edu.udo.cs.yale.operator.learner.meta.AdaBoostModel;
import edu.udo.cs.yale.operator.learner.meta.BayBoostModel;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.operator.performance.EstimatedPerformance;
import edu.udo.cs.yale.operator.performance.PerformanceVector;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class MartinsIterationOperator
extends OperatorChain {
    private int iteration = 0;
    private PerformanceVector performance;

    public MartinsIterationOperator(OperatorDescription description) {
        super(description);
        this.addValue(new Value("performance", "best performance"){

            public double getValue() {
                if (MartinsIterationOperator.this.performance != null) {
                    return MartinsIterationOperator.this.performance.getMainCriterion().getAverage();
                }
                return Double.NaN;
            }
        });
        this.addValue(new Value("iteration", "current iteration"){

            public double getValue() {
                return MartinsIterationOperator.this.iteration;
            }
        });
    }

    @Override
    public IOObject[] apply() throws OperatorException {
        IOContainer input = this.getInput();
        PredictionModel ensemble = this.findEnsembleModel(input.get(Model.class));
        int maxIteration = this.getParameterAsInt("max_iteration");
        PerformanceVector pv = new PerformanceVector();
        String perfName = "";
        this.iteration = 1;
        while (this.iteration <= maxIteration) {
            if (ensemble instanceof BayBoostModel) {
                ((BayBoostModel)ensemble).setMaxModelNumber(this.iteration);
            } else if (ensemble instanceof AdaBoostModel) {
                ((AdaBoostModel)ensemble).setMaxModelNumber(this.iteration);
            } else {
                ensemble.setParameter("iteration", String.valueOf(this.iteration));
            }
            this.setInput(input.copy());
            this.performance = this.getPerformance();
            double auc = this.performance.getCriterion("AUC").getAverage();
            perfName = "AUC_" + this.iteration;
            EstimatedPerformance perfCrit = new EstimatedPerformance(perfName, auc, 1, false);
            pv.addCriterion(perfCrit);
            double acc = this.performance.getCriterion("accuracy").getAverage();
            String perfNameAcc = "ACC_" + this.iteration;
            perfCrit = new EstimatedPerformance(perfNameAcc, acc, 1, false);
            pv.addCriterion(perfCrit);
            double rms = this.performance.getCriterion("root_mean_squared_error").getAverage();
            String perfNameRms = "RMS_" + this.iteration;
            perfCrit = new EstimatedPerformance(perfNameRms, rms, 1, false);
            pv.addCriterion(perfCrit);
            System.out.println(String.valueOf(this.iteration) + ": " + auc + ", " + acc + ", " + rms);
            this.inApplyLoop();
            ++this.iteration;
        }
        pv.setMainCriterionName(perfName);
        return new IOObject[]{pv};
    }

    private PredictionModel findEnsembleModel(Model model) {
        if (model == null) {
            return null;
        }
        if (model instanceof BayBoostModel || model instanceof AdaBoostModel) {
            return (PredictionModel)model;
        }
        if (model instanceof ContainerModel) {
            ContainerModel cm = (ContainerModel)model;
            int i = 0;
            while (i < cm.getNumberOfModels()) {
                PredictionModel res = this.findEnsembleModel(cm.getModel(i));
                if (res != null) {
                    return res;
                }
                ++i;
            }
        }
        return null;
    }

    private PerformanceVector getPerformance() throws OperatorException {
        IOObject[] evalout = super.apply();
        IOContainer evalCont = new IOContainer(evalout);
        return evalCont.remove(PerformanceVector.class);
    }

    @Override
    public Class[] getInputClasses() {
        return new Class[0];
    }

    @Override
    public Class[] getOutputClasses() {
        return new Class[]{PerformanceVector.class};
    }

    @Override
    public InnerOperatorCondition getInnerOperatorCondition() {
        return new SimpleChainInnerOperatorCondition();
    }

    @Override
    public int getMaxNumberOfInnerOperators() {
        return Integer.MAX_VALUE;
    }

    @Override
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeInt("max_iteration", "The maximum iteration.", 1, Integer.MAX_VALUE, 10));
        return types;
    }
}

