/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.Model;
import edu.udo.cs.yale.operator.OperatorDescription;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.AbstractLearner;
import edu.udo.cs.yale.operator.learner.ConjunctiveRuleModel;
import edu.udo.cs.yale.operator.learner.LearnerCapability;
import edu.udo.cs.yale.operator.parameter.ParameterType;
import edu.udo.cs.yale.operator.parameter.ParameterTypeBoolean;
import edu.udo.cs.yale.operator.parameter.ParameterTypeCategory;
import edu.udo.cs.yale.operator.parameter.ParameterTypeInt;
import edu.udo.cs.yale.operator.parameter.UndefinedParameterError;
import edu.udo.cs.yale.tools.LogService;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learns the single best conjunctive rule for a binominal label by a level-wise search
 * (rules of length 1, 2, ... up to {@code max_depth}) with optimistic-estimate pruning.
 *
 * <p>Every evaluated rule is scored twice: once predicting the positive label and once
 * predicting the negative label. The better of the two competes for the highscore.
 * Two utility functions are available:
 * <ul>
 *   <li>weighted relative accuracy: {@code cov * (rate - prior)}</li>
 *   <li>binomial test function: {@code sqrt(cov) * (rate - prior)}</li>
 * </ul>
 * Here {@code cov} is the covered share of the total example weight and {@code rate} is
 * the share of the predicted class among the covered examples. If
 * {@code relative_to_predictions} is set and the example set has a predicted label,
 * {@code prior} is the confidence-weighted prediction rate among the covered examples
 * instead of the global class rate.
 *
 * <p>A candidate is discarded in either of two cases:
 * <ul>
 *   <li>its optimistic score bound does not exceed the best score found so far;</li>
 *   <li>it refines a rule that was already discarded.</li>
 * </ul>
 * At most {@code max_cache} candidates, those with the highest bounds, are examined per
 * level. The search is therefore incomplete when the cache overflows.
 *
 * <p>The search state is kept in instance fields, so a single instance must not run
 * {@link #learn(ExampleSet)} concurrently.
 */
public class BestRuleInduction
extends AbstractLearner {
    private static final String DEPTH_BOUND = "max_depth";
    private static final String UTILITY_FUNCTION = "utility_function";
    private static final String MAX_CACHE = "max_cache";
    private static final String REL_TO_PRED = "relative_to_predictions";
    private static final String WRACC = "weighted relative accuracy";
    private static final String BINOMIAL = "binomial test function";
    private static final String[] UTILITY_FUNCTION_LIST = new String[]{WRACC, BINOMIAL};

    /** Total weight of positive examples in the training set (set by {@link #learn}). */
    private double globalP;
    /** Total weight of negative examples in the training set (set by {@link #learn}). */
    private double globalN;
    /** Best rule found so far; {@code null} until the first rule enters the highscore. */
    protected ConjunctiveRuleModel bestRule;
    /** Score of {@link #bestRule}; negative infinity before any rule was recorded. */
    private double bestScore;
    /** Candidates of the next rule length, each tagged with its optimistic score bound. */
    private final Vector<RuleWithScoreUpperBound> openNodes = new Vector<RuleWithScoreUpperBound>();
    /** Discarded rules; any refinement of one of these is skipped without evaluation. */
    private final Vector<ConjunctiveRuleModel> prunedNodes = new Vector<ConjunctiveRuleModel>();

    public BestRuleInduction(OperatorDescription description) {
        super(description);
    }

    /**
     * Accepts polynominal and binominal attributes and a binominal class only.
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        return lc == LearnerCapability.POLYNOMINAL_ATTRIBUTES
            || lc == LearnerCapability.BINOMINAL_ATTRIBUTES
            || lc == LearnerCapability.BINOMINAL_CLASS;
    }

    /** Resets the highscore so that any finite score beats it. */
    protected void initHighscore() {
        this.bestRule = null;
        this.bestScore = Double.NEGATIVE_INFINITY;
    }

    /**
     * Offers a rule (and its negation) to the highscore.
     *
     * @param rule   the evaluated rule; its conclusion is the positive label
     * @param counts the counts returned by {@link #getCounts} for {@code rule}
     * @return {@code true} if the rule can be pruned. That happens when its optimistic
     *         score does not exceed the current pruning score; the highscore is then
     *         left untouched. Otherwise returns {@code false}.
     * @throws UndefinedParameterError if a required parameter is undefined
     */
    protected boolean communicateToHighscore(ConjunctiveRuleModel rule, double[] counts) throws UndefinedParameterError {
        if (this.getOptimisticScore(counts) <= this.getPruningScore()) {
            return true;
        }
        double posScore = this.getScore(counts, true);
        double negScore = this.getScore(counts, false);
        if (posScore > this.bestScore) {
            this.bestRule = rule;
            this.bestScore = posScore;
        }
        // Compared against the possibly just-updated best score, so the negated rule wins
        // only if it beats the positive one as well.
        if (negScore > this.bestScore) {
            this.bestRule = new ConjunctiveRuleModel(rule, rule.getLabel().getMapping().getNegativeIndex());
            this.bestScore = negScore;
        }
        return false;
    }

    /** Returns the best rule recorded so far, or {@code null} if none was recorded. */
    protected ConjunctiveRuleModel getBestRule() {
        return this.bestRule;
    }

    /** Returns the score a candidate's optimistic bound must exceed to be kept. */
    protected double getPruningScore() {
        return this.bestScore;
    }

    /**
     * Runs the level-wise search and returns the best rule found.
     *
     * @param exampleSet training data with a binominal label
     * @return the best rule, as a {@link Model}
     * @throws OperatorException on parameter errors or if the process is stopped
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        this.initHighscore();
        Attribute label = exampleSet.getAttributes().getLabel();
        int positiveLabel = label.getMapping().getPositiveIndex();
        ConjunctiveRuleModel defaultRule = new ConjunctiveRuleModel(label, positiveLabel);
        double[] globalCounts = this.getCounts(defaultRule, exampleSet);
        this.globalP = globalCounts[0];
        this.globalN = globalCounts[1];
        this.communicateToHighscore(defaultRule, globalCounts);
        double optimisticScore = this.getOptimisticScore(globalCounts);
        this.openNodes.clear();
        this.prunedNodes.clear();
        this.addRulesToOpenNodes(defaultRule.getAllRefinedRules(exampleSet), optimisticScore);

        // Parameters are constant for the whole run; read them once.
        final int maxDepth = this.getParameterAsInt(DEPTH_BOUND);
        final int maxCache = this.getParameterAsInt(MAX_CACHE);

        for (int length = 1; !this.openNodes.isEmpty() && length <= maxDepth; ++length) {
            LogService.logMessage("Evaluating " + this.openNodes.size() + " rules of length " + length, 2);
            if (this.openNodes.size() > maxCache) {
                LogService.logMessage("Ignoring all but the " + maxCache + " rules with highest support.", 2);
            }
            RuleWithScoreUpperBound[] candidates =
                this.openNodes.toArray(new RuleWithScoreUpperBound[this.openNodes.size()]);
            // Sorted ascending by optimistic bound, so the most promising candidates sit at
            // the end. Only the top maxCache of them are examined.
            Arrays.sort(candidates);
            int stopAtIndex = Math.max(0, candidates.length - maxCache);
            // expandNode refills openNodes with the next level's candidates.
            this.openNodes.clear();
            int ignored = 0;
            for (int i = candidates.length - 1; i >= stopAtIndex; --i) {
                RuleWithScoreUpperBound candidate = candidates[i];
                ConjunctiveRuleModel rule = candidate.getRule();
                if (this.isRefinementOfPrunedRule(rule)) {
                    ++ignored;
                } else if (candidate.getScoreBound() <= this.getPruningScore()) {
                    // The highscore has improved since this candidate was queued.
                    ++ignored;
                    this.prunedNodes.add(rule);
                } else {
                    this.expandNode(rule, exampleSet);
                }
                this.checkForStop();
            }
            LogService.logMessage("Could ignore " + ignored + " rules as refinements of pruned rules or by optimistic estimates.", 2);
            LogService.logMessage("Number of pruned rules in cache: " + this.prunedNodes.size(), 2);
            // String concatenation instead of toString(): it tolerates a null best rule.
            LogService.logMessage("Best rule is " + this.getBestRule(), 2);
            LogService.logMessage("Score is " + this.getPruningScore(), 2);
        }
        this.openNodes.clear();
        this.prunedNodes.clear();
        return this.getBestRule();
    }

    /**
     * Queues all given rules as candidates sharing one optimistic bound. Nothing is
     * queued if that bound cannot beat the current pruning score.
     *
     * @param rules           rules to queue; every element must be a {@link ConjunctiveRuleModel}
     * @param scoreUpperBound optimistic score bound of the rules' parent
     */
    private void addRulesToOpenNodes(Collection<?> rules, double scoreUpperBound) {
        if (scoreUpperBound <= this.getPruningScore()) {
            return;
        }
        for (Object rule : rules) {
            this.openNodes.add(new RuleWithScoreUpperBound((ConjunctiveRuleModel) rule, scoreUpperBound));
        }
    }

    /**
     * Evaluates a rule and either prunes it or, if it is shorter than {@code max_depth},
     * queues all of its refinements.
     */
    private void expandNode(ConjunctiveRuleModel rule, ExampleSet exampleSet) throws OperatorException {
        double[] counts = this.getCounts(rule, exampleSet);
        if (this.communicateToHighscore(rule, counts)) {
            this.prunedNodes.add(rule);
        } else if (rule.getRuleLength() < this.getParameterAsInt(DEPTH_BOUND)) {
            this.addRulesToOpenNodes(rule.getAllRefinedRules(exampleSet), this.getOptimisticScore(counts));
        }
    }

    /** Returns {@code true} if {@code rule} refines any rule discarded so far. */
    public boolean isRefinementOfPrunedRule(ConjunctiveRuleModel rule) {
        for (ConjunctiveRuleModel prunedRule : this.prunedNodes) {
            if (rule.isRefinementOf(prunedRule)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Scores a rule with the configured utility function.
     *
     * @param counts           {@code {p, n}} or {@code {p, n, estP, estN}}: covered positive and
     *                         negative weight, optionally followed by the confidence-weighted
     *                         predicted positive and negative weight (see {@link #getCounts})
     * @param predictPositives {@code true} to score the rule as predicting the positive
     *                         label, {@code false} for the negative label
     * @return the utility. A rule covering no weight scores 0, which is the limit of both
     *         functions as coverage goes to 0. Returning 0 instead of NaN keeps pruning and
     *         sorting well defined.
     * @throws UndefinedParameterError if the utility function is not recognized
     */
    protected double getScore(double[] counts, boolean predictPositives) throws UndefinedParameterError {
        String function = UTILITY_FUNCTION_LIST[this.getParameterAsInt(UTILITY_FUNCTION)];
        boolean isWracc = function.equals(WRACC);
        if (!isWracc && !function.equals(BINOMIAL)) {
            // Created only when actually thrown: getScore is on the hot path of the search.
            throw new UndefinedParameterError("Missing parameter 'utility_function'!");
        }
        double p = counts[0];
        double n = counts[1];
        double covered = p + n;
        if (covered == 0.0) {
            return 0.0;
        }
        double cov = covered / (this.globalP + this.globalN);
        double observedRate = (predictPositives ? p : n) / covered;
        double expectedRate;
        if (this.getParameterAsBoolean(REL_TO_PRED) && counts.length == 4) {
            double estP = counts[2];
            double estN = counts[3];
            double estTotal = estP + estN;
            // No predicted weight at all means nothing is expected for either class.
            expectedRate = estTotal == 0.0 ? 0.0 : (predictPositives ? estP : estN) / estTotal;
        } else {
            expectedRate = (predictPositives ? this.globalP : this.globalN) / (this.globalP + this.globalN);
        }
        double gain = observedRate - expectedRate;
        return isWracc ? cov * gain : Math.sqrt(cov) * gain;
    }

    /**
     * Returns an upper bound on the score of any refinement of a rule with the given
     * counts. It takes the better of two scores:
     * <ul>
     *   <li>keeping only the covered positives, scored as a positive rule;</li>
     *   <li>keeping only the covered negatives, scored as a negative rule.</li>
     * </ul>
     * In relative-to-predictions mode, the predicted weight of the rule's own class is
     * assumed to be 0 in both cases.
     */
    protected double getOptimisticScore(double[] counts) throws UndefinedParameterError {
        double p = counts[0];
        double n = counts[1];
        if (this.getParameterAsBoolean(REL_TO_PRED) && counts.length == 4) {
            double estP = counts[2];
            double estN = counts[3];
            return Math.max(this.getScore(new double[]{p, 0.0, 0.0, estN}, true),
                            this.getScore(new double[]{0.0, n, 0.0, estP}, false));
        }
        return Math.max(this.getScore(new double[]{p, 0.0}, true),
                        this.getScore(new double[]{0.0, n}, false));
    }

    /**
     * Sums the example weights covered by a rule, i.e. the examples for which the rule
     * predicts its own conclusion.
     *
     * @return {@code {p, n}} (covered positive and negative weight). If the example set
     *         has a predicted label and {@code relative_to_predictions} is set, returns
     *         {@code {p, n, estP, estN}}. The last two entries are the covered weight
     *         split by the normalized prediction confidences.
     */
    protected double[] getCounts(ConjunctiveRuleModel rule, ExampleSet exampleSet) throws OperatorException {
        Attribute label = exampleSet.getAttributes().getLabel();
        Attribute weightAttr = exampleSet.getAttributes().getWeight();
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        boolean relToPred = predictedLabel != null && this.getParameterAsBoolean(REL_TO_PRED);
        int coveredPart = rule.getConclusion();
        int posIndex = label.getMapping().getPositiveIndex();
        int negIndex = label.getMapping().getNegativeIndex();
        String posS = null;
        String negS = null;
        if (relToPred) {
            posS = predictedLabel.getMapping().mapIndex(posIndex);
            negS = predictedLabel.getMapping().mapIndex(negIndex);
        }
        double p = 0.0;
        double n = 0.0;
        double estP = 0.0;
        double estN = 0.0;
        for (Example example : exampleSet) {
            if (rule.predict(example) != (double) coveredPart) {
                continue;
            }
            double weight = weightAttr == null ? 1.0 : example.getValue(weightAttr);
            if (example.getValue(label) == (double) posIndex) {
                p += weight;
            } else {
                n += weight;
            }
            if (relToPred) {
                double posConf = example.getConfidence(posS);
                double negConf = example.getConfidence(negS);
                double sum = posConf + negConf;
                // An example with no confidence mass for either class would add 0/0 = NaN
                // and poison the estimates of every rule covering it; skip its contribution.
                if (sum != 0.0) {
                    estP += weight * posConf / sum;
                    estN += weight * negConf / sum;
                }
            }
        }
        return relToPred ? new double[]{p, n, estP, estN} : new double[]{p, n};
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeInt(DEPTH_BOUND, "An upper bound for the number of literals.", 1, Integer.MAX_VALUE, 2));
        types.add(new ParameterTypeCategory(UTILITY_FUNCTION, "The function to be optimized by the rule.", UTILITY_FUNCTION_LIST, 0));
        types.add(new ParameterTypeInt(MAX_CACHE, "Bounds the number of rules considered per depth to avoid high memory consumption, but leads to incomplete search.", 1, Integer.MAX_VALUE, 10000));
        types.add(new ParameterTypeBoolean(REL_TO_PRED, "Searches for rules with a maximum difference to the predicted label.", false));
        return types;
    }

    /**
     * A candidate rule tagged with the optimistic score bound inherited from its parent.
     * The natural ordering is by bound, ascending. It is not consistent with
     * {@link #equals}, which compares rules only.
     */
    public static class RuleWithScoreUpperBound
    implements Comparable {
        private final ConjunctiveRuleModel rule;
        private final double scoreUpperBound;

        public RuleWithScoreUpperBound(ConjunctiveRuleModel rule, double scoreUpperBound) {
            this.rule = rule;
            this.scoreUpperBound = scoreUpperBound;
        }

        public ConjunctiveRuleModel getRule() {
            return this.rule;
        }

        public double getScoreBound() {
            return this.scoreUpperBound;
        }

        /**
         * Orders candidates by score bound. {@link Double#compare} gives a total order
         * even for NaN bounds, which {@code Arrays.sort} requires. Objects of other
         * classes are ordered by class name.
         */
        public int compareTo(Object obj) {
            if (obj instanceof RuleWithScoreUpperBound) {
                return Double.compare(this.getScoreBound(), ((RuleWithScoreUpperBound) obj).getScoreBound());
            }
            return this.getClass().getName().compareTo(obj.getClass().getName());
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof RuleWithScoreUpperBound)) {
                return false;
            }
            return this.rule.equals(((RuleWithScoreUpperBound) o).rule);
        }

        @Override
        public int hashCode() {
            return this.rule.hashCode();
        }
    }
}

