/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.example;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.PartitionBuilder;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.RandomGenerator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Builds partitions of an example set in which every partition receives
 * (approximately) the same share of each label class.
 *
 * <p>The returned array holds, for each example (in the order delivered by the
 * example set's reader), the zero-based number of the partition the example was
 * assigned to. The example set must have a nominal label.
 */
public class StratifiedPartitionBuilder
implements PartitionBuilder {
    private final ExampleSet exampleSet;

    /**
     * @param exampleSet the example set whose examples are to be partitioned
     */
    public StratifiedPartitionBuilder(ExampleSet exampleSet) {
        this.exampleSet = exampleSet;
    }

    /**
     * Creates a stratified partition of the example set.
     *
     * <p>If all ratio values are identical the "equal" strategy is used
     * (shuffle, sort by class, interleave); otherwise each class is shuffled and
     * split separately according to the ratios.
     *
     * @param ratio relative size of each partition; one entry per partition
     * @param size  number of examples; must equal the size of the example set
     * @return for each example index the zero-based partition number
     * @throws RuntimeException if {@code size} differs from the example set size,
     *         the example set has no label, the label is not nominal, or
     *         {@code ratio} is null or empty
     */
    public int[] createPartition(double[] ratio, int size) {
        Attribute label = this.exampleSet.getLabel();
        if (size != this.exampleSet.getSize()) {
            throw new RuntimeException("Cannot create stratified Partition: given size and size of the example set must be equal!");
        }
        if (label == null) {
            throw new RuntimeException("Cannot create stratified Partition: example set must have a label!");
        }
        if (!label.isNominal()) {
            throw new RuntimeException("Cannot create stratified Partition: label of example set must be nominal!");
        }
        // Checked after the example-set checks so that the precedence of the
        // pre-existing error messages is unchanged.
        if (ratio == null || ratio.length == 0) {
            throw new RuntimeException("Cannot create stratified Partition: at least one ratio value is required!");
        }
        double firstValue = ratio[0];
        for (int i = 1; i < ratio.length; ++i) {
            if (ratio[i] != firstValue) {
                LogService.logMessage("Not all ratio values are equal: using non-equal stratified sampling.", 2);
                return this.createNonEqualPartition(ratio, size, label);
            }
        }
        LogService.logMessage("All ratio values are equal: using stratified sampling.", 2);
        return this.createEqualPartition(ratio, size, label);
    }

    /**
     * Stratified partitioning for equally sized partitions: examples are
     * shuffled, sorted by class value and then interleaved so that consecutive
     * runs of the resulting list contain every class in similar proportion.
     */
    private int[] createEqualPartition(double[] ratio, int size, Attribute label) {
        // Collect (example index, class value) pairs in reader order.
        ArrayList<ExampleIndex> examples = new ArrayList<ExampleIndex>(size);
        ExampleReader reader = this.exampleSet.getExampleReader();
        int index = 0;
        while (reader.hasNext()) {
            Example example = reader.next();
            examples.add(new ExampleIndex(index++, example.getValue(label)));
        }

        // Shuffle first, then sort by class: the sort only compares class values,
        // so the order inside each class stays random.
        Collections.shuffle(examples, RandomGenerator.getGlobalRandomGenerator());
        Collections.sort(examples);

        // Interleave: pass s takes elements s, s + k, s + 2k, ... (k = number of
        // partitions). Exactly k passes visit every element once, so the loop is
        // bounded even if the reader delivered fewer examples than 'size'.
        int numberOfPartitions = ratio.length;
        ArrayList<ExampleIndex> newExamples = new ArrayList<ExampleIndex>(size);
        for (int start = 0; start < numberOfPartitions; ++start) {
            for (int i = start; i < examples.size(); i += numberOfPartitions) {
                newExamples.add(examples.get(i));
            }
        }

        int[] startNewP = computePartitionStarts(ratio, newExamples.size());

        // Walk the interleaved list and hand out partition numbers run by run.
        int[] part = new int[newExamples.size()];
        int p = 0;
        int counter = 0;
        for (ExampleIndex exampleIndex : newExamples) {
            p = advancePartition(startNewP, p, counter);
            part[exampleIndex.exampleIndex] = p;
            ++counter;
        }
        return part;
    }

    /**
     * Stratified partitioning for partitions of different sizes: the example
     * indices of each class are shuffled and split according to the ratios
     * independently of the other classes.
     */
    private int[] createNonEqualPartition(double[] ratio, int size, Attribute label) {
        // Group example indices by the label's string value.
        HashMap<String, LinkedList<Integer>> classLists = new HashMap<String, LinkedList<Integer>>();
        ExampleReader reader = this.exampleSet.getExampleReader();
        int index = 0;
        while (reader.hasNext()) {
            Example example = reader.next();
            String value = label.mapIndex((int)example.getValue(label));
            LinkedList<Integer> classList = classLists.get(value);
            if (classList == null) {
                classList = new LinkedList<Integer>();
                classLists.put(value, classList);
            }
            classList.add(Integer.valueOf(index++));
        }

        int[] part = new int[this.exampleSet.getSize()];
        for (LinkedList<Integer> classList : classLists.values()) {
            Collections.shuffle(classList, RandomGenerator.getGlobalRandomGenerator());
            int[] startNewP = computePartitionStarts(ratio, classList.size());
            int p = 0;
            int counter = 0;
            for (Integer exampleIndex : classList) {
                p = advancePartition(startNewP, p, counter);
                part[exampleIndex.intValue()] = p;
                ++counter;
            }
        }
        return part;
    }

    /**
     * Computes the start offsets of the partitions for {@code count} elements.
     *
     * @return an array of length {@code ratio.length + 1}; entry {@code i} is the
     *         rounded cumulative share of the first {@code i} ratios, entry 0 is 0
     */
    private static int[] computePartitionStarts(double[] ratio, int count) {
        int[] starts = new int[ratio.length + 1];
        double ratioSum = 0.0;
        for (int i = 1; i < starts.length; ++i) {
            ratioSum += ratio[i - 1];
            starts[i] = (int)Math.round((double)count * ratioSum);
        }
        return starts;
    }

    /**
     * Returns the partition that the element at {@code position} belongs to,
     * starting the search at partition {@code current}.
     *
     * <p>A loop (rather than a single step) is used so that empty partitions are
     * skipped, and the result never exceeds the last partition even when the
     * ratios sum to less than one and the final boundary falls short of the
     * element count.
     */
    private static int advancePartition(int[] starts, int current, int position) {
        int lastPartition = starts.length - 2;
        while (current < lastPartition && position >= starts[current + 1]) {
            ++current;
        }
        return current;
    }

    /** Pairs the position of an example with its class value; ordered by class value only. */
    private static class ExampleIndex
    implements Comparable<ExampleIndex> {
        int exampleIndex;
        double classIndex;

        public ExampleIndex(int exampleIndex, double classIndex) {
            this.exampleIndex = exampleIndex;
            this.classIndex = classIndex;
        }

        public int compareTo(ExampleIndex other) {
            return Double.compare(this.classIndex, other.classIndex);
        }

        public String toString() {
            return this.exampleIndex + "(" + this.classIndex + ")";
        }
    }
}

