/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.example;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.example.PartitionBuilder;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.RandomGenerator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;

public class StratifiedPartitionBuilder
implements PartitionBuilder {
    private ExampleSet exampleSet;
    private Random random;

    public StratifiedPartitionBuilder(ExampleSet exampleSet, int seed) {
        this.exampleSet = exampleSet;
        this.random = RandomGenerator.getRandomGenerator(seed);
    }

    public int[] createPartition(double[] ratio, int size) {
        Attribute label = this.exampleSet.getAttributes().getLabel();
        if (size != this.exampleSet.size()) {
            throw new RuntimeException("Cannot create stratified Partition: given size and size of the example set must be equal!");
        }
        if (label == null) {
            throw new RuntimeException("Cannot create stratified Partition: example set must have a label!");
        }
        if (!label.isNominal()) {
            throw new RuntimeException("Cannot create stratified Partition: label of example set must be nominal!");
        }
        double firstValue = ratio[0];
        int i = 1;
        while (i < ratio.length) {
            if (ratio[i] != firstValue) {
                LogService.logMessage("Not all ratio values are equal: using non-equal stratified sampling.", 2);
                return this.createNonEqualPartition(ratio, size, label);
            }
            ++i;
        }
        LogService.logMessage("All ratio values are equal: using stratified sampling.", 2);
        return this.createEqualPartition(ratio, size, label);
    }

    private int[] createEqualPartition(double[] ratio, int size, Attribute label) {
        ArrayList<ExampleIndex> examples = new ArrayList<ExampleIndex>(size);
        Iterator reader = this.exampleSet.iterator();
        int index = 0;
        while (reader.hasNext()) {
            Example example = (Example)reader.next();
            examples.add(new ExampleIndex(index++, example.getValue(label)));
        }
        Collections.shuffle(examples, this.random);
        Collections.sort(examples);
        ArrayList<ExampleIndex> newExamples = new ArrayList<ExampleIndex>(size);
        int start = 0;
        int numberOfPartitions = ratio.length;
        while (newExamples.size() < size) {
            int i = start;
            while (i < examples.size()) {
                newExamples.add((ExampleIndex)examples.get(i));
                i += numberOfPartitions;
            }
            ++start;
        }
        int[] startNewP = new int[ratio.length + 1];
        startNewP[0] = 0;
        double ratioSum = 0.0;
        int i = 1;
        while (i < startNewP.length) {
            startNewP[i] = (int)Math.round((double)newExamples.size() * (ratioSum += ratio[i - 1]));
            ++i;
        }
        int[] part = new int[newExamples.size()];
        int p = 0;
        int counter = 0;
        Iterator n = newExamples.iterator();
        while (n.hasNext()) {
            if (counter >= startNewP[p + 1]) {
                // empty if block
            }
            ExampleIndex exampleIndex = (ExampleIndex)n.next();
            part[exampleIndex.exampleIndex] = ++p;
            ++counter;
        }
        return part;
    }

    private int[] createNonEqualPartition(double[] ratio, int size, Attribute label) {
        HashMap<String, LinkedList<Integer>> classLists = new HashMap<String, LinkedList<Integer>>();
        Iterator reader = this.exampleSet.iterator();
        int index = 0;
        while (reader.hasNext()) {
            Example example = (Example)reader.next();
            String value = label.getMapping().mapIndex((int)example.getValue(label));
            List<Integer> classList = (List)classLists.get(value);
            if (classList == null) {
                classList = new LinkedList<Integer>();
                classList.add(index++);
                classLists.put(value, (LinkedList<Integer>)classList);
                continue;
            }
            classList.add(index++);
        }
        int[] part = new int[this.exampleSet.size()];
        for (List<Integer> classList : classLists.values()) {
            Collections.shuffle(classList, this.random);
            int[] startNewP = new int[ratio.length + 1];
            startNewP[0] = 0;
            double ratioSum = 0.0;
            int i = 1;
            while (i < startNewP.length) {
                startNewP[i] = (int)Math.round((double)classList.size() * (ratioSum += ratio[i - 1]));
                ++i;
            }
            int p = 0;
            int counter = 0;
            Iterator n = classList.iterator();
            while (n.hasNext()) {
                if (counter >= startNewP[p + 1]) {
                    // empty if block
                }
                Integer exampleIndex = (Integer)n.next();
                part[exampleIndex.intValue()] = ++p;
                ++counter;
            }
        }
        return part;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private static class ExampleIndex
    implements Comparable<ExampleIndex> {
        int exampleIndex;
        double classIndex;

        public ExampleIndex(int exampleIndex, double classIndex) {
            this.exampleIndex = exampleIndex;
            this.classIndex = classIndex;
        }

        @Override
        public int compareTo(ExampleIndex e) {
            return Double.compare(this.classIndex, e.classIndex);
        }

        public boolean equals(Object o) {
            if (!(o instanceof ExampleIndex)) {
                return false;
            }
            ExampleIndex other = (ExampleIndex)o;
            return this.exampleIndex == other.exampleIndex;
        }

        public int hashCode() {
            return Integer.valueOf(this.exampleIndex).hashCode();
        }

        public String toString() {
            return String.valueOf(this.exampleIndex) + "(" + this.classIndex + ")";
        }
    }
}

