/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.postprocessing;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.ContainerModel;
import edu.udo.cs.yale.operator.IOObject;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.Operator;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.UserError;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.learner.PredictionModel;
import edu.udo.cs.yale.operator.postprocessing.PlattParameters;
import edu.udo.cs.yale.operator.postprocessing.PlattScalingModel;
import edu.udo.cs.yale.tools.LogService;
import java.util.Iterator;

/**
 * Calibrates the confidences of a binary classification model with Platt's scaling.
 * <p>
 * The operator applies the given model to a copy of the given example set and fits the two
 * parameters A and B of the sigmoid {@code p = 1 / (1 + exp(A * f + B))}, where {@code f} is the
 * value returned by {@link #getLogOddsPosConfidence(double)} for the model's positive-class
 * confidence. The fitted parameters are wrapped together with the original model and its label
 * into a {@link PlattScalingModel}.
 * <p>
 * The fitting procedure follows Platt's published pseudo-code (damped Newton iterations on the
 * cross-entropy with regularised targets), extended by example weights. Two known issues of that
 * pseudo-code are corrected: the target is recomputed per example when evaluating the error, and
 * logarithms are bounded below so that log(0) does not produce infinities.
 */
public class PlattScaling
extends Operator {

    /** Maximum number of outer (Newton) iterations of the sigmoid fit. */
    private static final int MAX_ITERATIONS = 100;

    /** Lower bound for logarithms in the error computation (Platt suggests -200 for log(0)). */
    private static final double LOG_FLOOR = -200.0;

    /** Smallest confidence and smallest complement used when converting confidences to log odds. */
    private static final double CONFIDENCE_EPSILON = 1.0E-30;

    public PlattScaling(OperatorDescription description) {
        super(description);
    }

    /**
     * Platt's scaling only handles binary problems, so numerical and polynominal labels are
     * rejected; every other capability is accepted.
     */
    public boolean supportsCapability(LearnerCapability lc) {
        return lc != LearnerCapability.NUMERICAL_CLASS && lc != LearnerCapability.POLYNOMINAL_CLASS;
    }

    /** Consumes an example set (used for calibration) and the model to calibrate. */
    public Class[] getInputClasses() {
        return new Class[]{ExampleSet.class, Model.class};
    }

    /** Delivers the calibrated model. */
    public Class[] getOutputClasses() {
        return new Class[]{Model.class};
    }

    /**
     * Fits Platt's sigmoid to the predictions of the input model on the input example set.
     *
     * @return a single {@link PlattScalingModel} wrapping the input model
     * @throws OperatorException if the example set has no label (error 105) or no attributes
     *         (error 106), or if applying the model fails
     */
    public IOObject[] apply() throws OperatorException {
        ExampleSet exampleSet = this.getInput(ExampleSet.class);
        Model model = this.getInput(Model.class);
        if (exampleSet.getAttributes().getLabel() == null) {
            throw new UserError((Operator)this, 105, new Object[0]);
        }
        if (exampleSet.getAttributes().size() == 0) {
            throw new UserError((Operator)this, 106, new Object[0]);
        }
        Attribute label = this.extractLabel(model, exampleSet);
        // Apply the model to a copy of the input; the predicted label is removed again once the
        // sigmoid parameters have been computed from the predictions.
        ExampleSet calibrationSet = (ExampleSet)exampleSet.clone();
        model.apply(calibrationSet);
        PlattParameters plattParams = PlattScaling.computeParameters(calibrationSet, label);
        PredictionModel.removePredictedLabel(calibrationSet);
        return new IOObject[]{new PlattScalingModel(label, model, plattParams)};
    }

    /**
     * Determines the label attribute the model was trained on.
     * <p>
     * A {@link PredictionModel} provides its label directly. For a {@link ContainerModel} the
     * contained models are searched from the last one backwards and the label of the first
     * {@link PredictionModel} found is used. Otherwise the label of the example set is returned.
     */
    private Attribute extractLabel(Model model, ExampleSet exampleSet) {
        if (model instanceof PredictionModel) {
            return ((PredictionModel)model).getLabel();
        }
        if (model instanceof ContainerModel) {
            ContainerModel cm = (ContainerModel)model;
            for (int i = cm.getNumberOfModels() - 1; i >= 0; i--) {
                if (cm.getModel(i) instanceof PredictionModel) {
                    return ((PredictionModel)cm.getModel(i)).getLabel();
                }
            }
        }
        LogService.logMessage("Could not find label in model for Platt's Scaling, using Label of provided ExampleSet instead.", 4);
        return exampleSet.getAttributes().getLabel();
    }

    /**
     * Fits the parameters of Platt's sigmoid to the predictions stored in the given example set.
     * <p>
     * For every example the confidence of the positive class of {@code label} is converted with
     * {@link #getLogOddsPosConfidence(double)} into a value {@code f}. A and B are then chosen to
     * minimise the weighted cross-entropy between {@code 1 / (1 + exp(A * f + B))} and Platt's
     * regularised targets {@code (N+ + 1) / (N+ + 2)} for positive and {@code 1 / (N- + 2)} for all
     * other examples, where N+ and N- are the (weighted) class counts.
     *
     * @param exampleSet examples with a label, to which the model's predictions have already been applied
     * @param label the model's label attribute; its positive class name defines the positive examples
     * @return the fitted parameters, or A = 1, B = 0 (the identity mapping) if the fit produced a
     *         NaN or infinite result
     */
    public static PlattParameters computeParameters(ExampleSet exampleSet, Attribute label) {
        String posLabelS = label.getMapping().getPositiveString();
        // The example set's label may use different indices than the model's label, so translate
        // the positive class name into the example set's mapping.
        int posLabel = exampleSet.getAttributes().getLabel().getMapping().mapString(posLabelS);
        Attribute weightAttr = exampleSet.getAttributes().getWeight();

        // Precompute everything that does not change between iterations.
        int size = exampleSet.size();
        double[] deci = new double[size];
        double[] weights = new double[size];
        boolean[] positive = new boolean[size];
        double prior1 = 0.0; // weighted number of positive examples
        double prior0 = 0.0; // weighted number of all other examples
        int index = 0;
        for (Example example : exampleSet) {
            double weight = weightAttr == null ? 1.0 : example.getWeight();
            boolean isPositive = example.getLabel() == (double)posLabel;
            deci[index] = PlattScaling.getLogOddsPosConfidence(example.getConfidence(posLabelS));
            weights[index] = weight;
            positive[index] = isPositive;
            if (isPositive) {
                prior1 += weight;
            } else {
                prior0 += weight;
            }
            index++;
        }

        double hiTarget = (prior1 + 1.0) / (prior1 + 2.0);
        double loTarget = 1.0 / (prior0 + 2.0);
        double[] target = new double[size];
        double[] pp = new double[size];
        double initialP = (prior1 + 1.0) / (prior0 + prior1 + 2.0);
        for (int i = 0; i < size; i++) {
            target[i] = positive[i] ? hiTarget : loTarget;
            pp[i] = initialP;
        }

        double A = 0.0;
        double B = Math.log((prior0 + 1.0) / (prior1 + 1.0));
        double lambda = 0.001;
        double olderr = 1.0E300;
        int count = 0;
        for (int it = 1; it <= MAX_ITERATIONS; it++) {
            // Hessian entries (a, b, c) and gradient terms (d, e) of the weighted error w.r.t. A and B.
            double a = 0.0;
            double b = 0.0;
            double c = 0.0;
            double d = 0.0;
            double e = 0.0;
            for (int i = 0; i < size; i++) {
                double d1 = weights[i] * (pp[i] - target[i]);
                double d2 = weights[i] * (pp[i] * (1.0 - pp[i]));
                a += deci[i] * deci[i] * d2;
                b += d2;
                c += deci[i] * d2;
                d += deci[i] * d1;
                e += d1;
            }
            // Stop once the gradient is negligible.
            if (Math.abs(d) < 1.0E-9 && Math.abs(e) < 1.0E-9) {
                break;
            }
            double oldA = A;
            double oldB = B;
            double err = 0.0;
            // Damped Newton step: increase lambda until the error does not grow (or give up).
            while (true) {
                double det = (a + lambda) * (b + lambda) - c * c;
                if (det == 0.0) {
                    lambda *= 10.0;
                    continue;
                }
                A = oldA + ((b + lambda) * d - c * e) / det;
                B = oldB + ((a + lambda) * e - c * d) / det;
                err = 0.0;
                for (int i = 0; i < size; i++) {
                    double oddsVal = Math.min(1.0E30, Math.exp(deci[i] * A + B));
                    double p = Math.min(1.0, 1.0 / (1.0 + oddsVal));
                    pp[i] = p;
                    err -= weights[i] * (target[i] * safeLog(p) + (1.0 - target[i]) * safeLog(1.0 - p));
                }
                if (err < olderr * 1.0000001) {
                    lambda *= 0.1;
                    break;
                }
                lambda *= 10.0;
                if (lambda >= 1000000.0) {
                    break;
                }
            }
            // Stop after three consecutive iterations with (almost) no relative improvement.
            double diff = err - olderr;
            double scale = 0.5 * (err + olderr + 1.0);
            if (diff > -0.001 * scale && diff < 1.0E-7 * scale) {
                count++;
            } else {
                count = 0;
            }
            olderr = err;
            if (count == 3) {
                break;
            }
        }
        if (Double.isNaN(A) || Double.isNaN(B) || Double.isInfinite(A) || Double.isInfinite(B)) {
            A = 1.0;
            B = 0.0;
            LogService.logMessage("Discarding invalid result of Platt's scaling, using identity instead.", 4);
        }
        return new PlattParameters(A, B);
    }

    /**
     * Natural logarithm bounded below by {@link #LOG_FLOOR}, so that log(0) contributes a large
     * finite penalty instead of -Infinity. NaN inputs still yield NaN.
     */
    private static double safeLog(double x) {
        return Math.max(LOG_FLOOR, Math.log(x));
    }

    /**
     * Converts a positive-class confidence {@code c} into {@code log((1 - c) / c)}.
     * <p>
     * With this convention the sigmoid {@code 1 / (1 + exp(A * f + B))} using A = 1 and B = 0
     * reproduces the original confidence. Confidences are clamped to [1e-30, 1]. The complement
     * {@code 1 - c} is clamped separately to at least 1e-30, because {@code 1.0 - 1e-30} rounds to
     * 1.0 in double precision. As a result, the returned value is always finite. A NaN confidence
     * is logged and treated as 0.5.
     *
     * @param originalConfidence the model's confidence for the positive class
     * @return the finite log odds value used as sigmoid input
     */
    public static double getLogOddsPosConfidence(double originalConfidence) {
        double confidence = Math.min(Math.max(CONFIDENCE_EPSILON, originalConfidence), 1.0);
        if (Double.isNaN(confidence)) {
            confidence = 0.5;
            LogService.logMessage("Found a NaN confidence during Platt's Scaling.", 4);
        }
        double complement = Math.max(CONFIDENCE_EPSILON, 1.0 - confidence);
        return Math.log(complement / confidence);
    }
}

