/*
 * Decompiled with CFR 0.152.
 */
package de.dfki.madm.paren.operator.learner.functions.neuralnet;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.neuralnet.ActivationFunction;
import com.rapidminer.operator.learner.functions.neuralnet.InnerNode;
import com.rapidminer.operator.learner.functions.neuralnet.InputNode;
import com.rapidminer.operator.learner.functions.neuralnet.LinearFunction;
import com.rapidminer.operator.learner.functions.neuralnet.Node;
import com.rapidminer.operator.learner.functions.neuralnet.OutputNode;
import com.rapidminer.operator.learner.functions.neuralnet.SigmoidFunction;
import com.rapidminer.tools.RandomGenerator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
 * Feed-forward neural network model used by the AutoMLP learner.
 *
 * <p>The network is stored as three node arrays:
 * <ul>
 * <li>{@link #inputNodes}: one per regular attribute.</li>
 * <li>{@link #outputNodes}: one per label class, or a single node for a numerical label.</li>
 * <li>{@link #innerNodes}: when built via {@link #train}, holds the inner nodes feeding the output
 * nodes first (layer index {@code -2}), followed by the hidden-layer nodes in layer order.</li>
 * </ul>
 *
 * <p>The network is trained example by example using a learning rate and momentum. Optionally, its
 * weights are warm-started from a previously trained model of the same class.
 */
public class AutoMLPImprovedNeuralNetModel
extends PredictionModel {
    private static final long serialVersionUID = -2206598483097451366L;
    /** Layer index assigned to the inner nodes created in {@link #initOutputLayer}. */
    private static final int OUTPUT_LAYER_INDEX = -2;
    public static final ActivationFunction SIGMOID_FUNCTION = new SigmoidFunction();
    public static final ActivationFunction LINEAR_FUNCTION = new LinearFunction();
    /** Regular attribute names of the training example set. */
    public String[] attributeNames;
    public InputNode[] inputNodes = new InputNode[0];
    public InnerNode[] innerNodes = new InnerNode[0];
    public OutputNode[] outputNodes = new OutputNode[0];
    /** Weighted mean error of the most recent training cycle. */
    double error;

    /**
     * Returns the error of the last completed training cycle.
     *
     * @return the weighted error computed at the end of the last cycle run by {@link #train}
     */
    public double getError() {
        return this.error;
    }

    /**
     * Creates an untrained model for the given training data.
     *
     * @param trainingExampleSet the training data; its regular attribute names are recorded
     */
    public AutoMLPImprovedNeuralNetModel(ExampleSet trainingExampleSet) {
        super(trainingExampleSet);
        this.attributeNames = Tools.getRegularAttributeNames(trainingExampleSet);
    }

    /**
     * Builds the network topology and trains it for {@code maxCycles} passes over the data.
     *
     * @param exampleSet      training data; must have a label, may have a weight attribute
     * @param hiddenLayers    hidden layer definitions as {@code {name, size}} pairs; a size
     *                        {@code <= 0} selects {@link #getDefaultLayerSize}
     * @param maxCycles       number of passes over the training data
     * @param maxError        NOTE(review): currently unused, since training always runs
     *                        {@code maxCycles} cycles
     * @param learningRate    base learning rate, scaled per example by its weight
     * @param momentum        momentum passed to the weight update
     * @param decay           if true, the rate is divided by {@code cycle + 1}
     * @param shuffle         if true, examples are visited in one fixed random order
     * @param normalize       if true, inputs are scaled using attribute min/max statistics
     * @param randomGenerator source of randomness for weight initialisation and shuffling
     * @param is_old_model    if true, copy compatible weights from {@code old_model}
     * @param old_model       previously trained model; only read when {@code is_old_model} is true
     * @throws RuntimeException if a cycle's error is NaN/infinite and {@code learningRate <= 0}
     */
    public void train(ExampleSet exampleSet, List<String[]> hiddenLayers, int maxCycles, double maxError, double learningRate, double momentum, boolean decay, boolean shuffle, boolean normalize, RandomGenerator randomGenerator, boolean is_old_model, AutoMLPImprovedNeuralNetModel old_model) {
        Attribute label = exampleSet.getAttributes().getLabel();
        int numberOfClasses = this.getNumberOfClasses(label);
        if (normalize) {
            exampleSet.recalculateAllAttributeStatistics();
        } else {
            exampleSet.recalculateAttributeStatistics(label);
        }
        this.initInputLayer(exampleSet, normalize);
        double labelMin = exampleSet.getStatistics(label, "minimum");
        double labelMax = exampleSet.getStatistics(label, "maximum");
        this.initOutputLayer(label, numberOfClasses, labelMin, labelMax, randomGenerator);
        if (is_old_model) {
            this.initHiddenLayers(exampleSet, label, hiddenLayers, randomGenerator, old_model);
        } else {
            this.initHiddenLayers(exampleSet, label, hiddenLayers, randomGenerator);
        }

        Attribute weightAttribute = exampleSet.getAttributes().getWeight();
        double totalWeight = computeTotalWeight(exampleSet, weightAttribute);
        // null means "visit examples in their natural order"
        int[] exampleIndices = shuffle ? createShuffledIndices(exampleSet.size(), randomGenerator) : null;

        for (int cycle = 0; cycle < maxCycles; ++cycle) {
            this.error = 0.0;
            int maxSize = exampleSet.size();
            for (int index = 0; index < maxSize; ++index) {
                int exampleIndex = exampleIndices != null ? exampleIndices[index] : index;
                Example example = exampleSet.getExample(exampleIndex);
                // Forward pass from a clean state, then accumulate error and update weights.
                this.resetNetwork();
                this.calculateValue(example);
                double weight = exampleWeight(example, weightAttribute);
                double tempRate = learningRate * weight;
                if (decay) {
                    tempRate /= (double)(cycle + 1);
                }
                this.error += this.calculateError(example) / (double)numberOfClasses * weight;
                this.update(example, tempRate, momentum);
            }
            this.error /= totalWeight;
            // NOTE(review): a diverged error is only fatal for a non-positive learning rate. With a
            // positive rate, training silently continues and the rate is never reduced.
            boolean diverged = Double.isInfinite(this.error) || Double.isNaN(this.error);
            if (diverged && com.rapidminer.tools.Tools.isLessEqual(learningRate, 0.0)) {
                throw new RuntimeException("Cannot reset network to a smaller learning rate.");
            }
        }
    }

    /** Returns the value of the weight attribute for {@code example}, or 1.0 if there is none. */
    private static double exampleWeight(Example example, Attribute weightAttribute) {
        return weightAttribute == null ? 1.0 : example.getValue(weightAttribute);
    }

    /** Sums the example weights (1.0 each if there is no weight attribute) over the whole set. */
    private static double computeTotalWeight(ExampleSet exampleSet, Attribute weightAttribute) {
        double totalWeight = 0.0;
        for (Example example : exampleSet) {
            totalWeight += exampleWeight(example, weightAttribute);
        }
        return totalWeight;
    }

    /** Returns the indices {@code 0..size-1} in an order shuffled by {@code randomGenerator}. */
    private static int[] createShuffledIndices(int size, RandomGenerator randomGenerator) {
        List<Integer> indices = new ArrayList<Integer>(size);
        for (int i = 0; i < size; ++i) {
            indices.add(i);
        }
        Collections.shuffle(indices, randomGenerator);
        int[] shuffled = new int[size];
        for (int i = 0; i < size; ++i) {
            shuffled[i] = indices.get(i);
        }
        return shuffled;
    }

    /**
     * Applies the network to every example.
     *
     * <p>For a nominal prediction attribute, the output values are normalised to sum to one and
     * stored as confidences. The class with the highest confidence (first one on ties) becomes the
     * prediction. Otherwise the single output node's value is used as the prediction.
     *
     * @param exampleSet     examples to predict; modified in place
     * @param predictedLabel attribute receiving the prediction
     * @return {@code exampleSet}
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        boolean nominal = predictedLabel.isNominal();
        Attribute trainingLabel = this.getLabel();
        int numberOfClasses = this.getNumberOfClasses(trainingLabel);
        for (Example example : exampleSet) {
            this.resetNetwork();
            if (nominal) {
                this.predictNominal(example, predictedLabel, trainingLabel, numberOfClasses);
            } else {
                example.setValue(predictedLabel, this.outputNodes[0].calculateValue(true, example));
            }
        }
        return exampleSet;
    }

    /** Computes normalised class confidences for one example and stores prediction and confidences. */
    private void predictNominal(Example example, Attribute predictedLabel, Attribute trainingLabel, int numberOfClasses) {
        double[] classProbabilities = new double[numberOfClasses];
        double total = 0.0;
        for (int c = 0; c < numberOfClasses; ++c) {
            classProbabilities[c] = this.outputNodes[c].calculateValue(true, example);
            total += classProbabilities[c];
        }
        double maxConfidence = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int c = 0; c < numberOfClasses; ++c) {
            classProbabilities[c] /= total;
            if (classProbabilities[c] > maxConfidence) {
                maxIndex = c;
                maxConfidence = classProbabilities[c];
            }
        }
        example.setValue(predictedLabel, predictedLabel.getMapping().mapString(trainingLabel.getMapping().mapIndex(maxIndex)));
        for (int c = 0; c < numberOfClasses; ++c) {
            example.setConfidence(trainingLabel.getMapping().mapIndex(c), classProbabilities[c]);
        }
    }

    /** Returns the regular attribute names recorded at construction time. */
    public String[] getAttributeNames() {
        return this.attributeNames;
    }

    /** Returns the input node array (not a copy). */
    public InputNode[] getInputNodes() {
        return this.inputNodes;
    }

    /** Returns the output node array (not a copy). */
    public OutputNode[] getOutputNodes() {
        return this.outputNodes;
    }

    /** Returns the inner node array (not a copy). */
    public InnerNode[] getInnerNodes() {
        return this.innerNodes;
    }

    /**
     * Returns the number of output nodes needed for {@code label}.
     *
     * @param label the label attribute
     * @return the size of the nominal mapping for a nominal label, otherwise 1
     */
    public int getNumberOfClasses(Attribute label) {
        return label.isNominal() ? label.getMapping().size() : 1;
    }

    /**
     * Appends {@code node} to {@link #innerNodes}.
     *
     * @param node the node to append
     */
    public void addNode(InnerNode node) {
        InnerNode[] newInnerNodes = new InnerNode[this.innerNodes.length + 1];
        System.arraycopy(this.innerNodes, 0, newInnerNodes, 0, this.innerNodes.length);
        newInnerNodes[newInnerNodes.length - 1] = node;
        this.innerNodes = newInnerNodes;
    }

    /**
     * Calls {@code reset()} on every output node.
     *
     * <p>Presumably the reset propagates through the connected network; verify in {@code OutputNode}.
     */
    public void resetNetwork() {
        for (OutputNode outputNode : this.outputNodes) {
            outputNode.reset();
        }
    }

    /**
     * Calls {@code update} on every output node with the given rate and momentum.
     *
     * @param example      the example just processed
     * @param learningRate effective learning rate for this example
     * @param momentum     momentum term
     */
    public void update(Example example, double learningRate, double momentum) {
        for (OutputNode outputNode : this.outputNodes) {
            outputNode.update(example, learningRate, momentum);
        }
    }

    /**
     * Runs the forward computation of every output node for {@code example}.
     *
     * @param example the example to evaluate
     */
    public void calculateValue(Example example) {
        for (OutputNode outputNode : this.outputNodes) {
            outputNode.calculateValue(true, example);
        }
    }

    /**
     * Triggers error calculation from the input nodes, then sums the squared output-node errors.
     *
     * @param example the example whose forward pass was just computed
     * @return sum over output nodes of the squared error
     */
    public double calculateError(Example example) {
        for (InputNode inputNode : this.inputNodes) {
            inputNode.calculateError(true, example);
        }
        double totalError = 0.0;
        for (OutputNode outputNode : this.outputNodes) {
            double nodeError = outputNode.calculateError(false, example);
            totalError += nodeError * nodeError;
        }
        return totalError;
    }

    /**
     * Returns the default hidden layer size.
     *
     * @return {@code round((#attributes + #classes) / 2) + 1}, which is always at least 1
     */
    public int getDefaultLayerSize(ExampleSet exampleSet, Attribute label) {
        return (int)Math.round((double)(exampleSet.getAttributes().size() + this.getNumberOfClasses(label)) / 2.0) + 1;
    }

    /**
     * Creates one input node per regular attribute.
     *
     * @param exampleSet data whose attributes (and, if normalising, statistics) are used
     * @param normalize  if true, each input uses half the min/max span as range and the midpoint as
     *                   offset; otherwise range 1 and offset 0
     */
    public void initInputLayer(ExampleSet exampleSet, boolean normalize) {
        this.inputNodes = new InputNode[exampleSet.getAttributes().size()];
        int a = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            this.inputNodes[a] = new InputNode(attribute.getName());
            double range = 1.0;
            double offset = 0.0;
            if (normalize) {
                double min = exampleSet.getStatistics(attribute, "minimum");
                double max = exampleSet.getStatistics(attribute, "maximum");
                range = (max - min) / 2.0;
                offset = (max + min) / 2.0;
            }
            this.inputNodes[a].setAttribute(attribute, range, offset, normalize);
            ++a;
        }
    }

    /**
     * Creates the output nodes and, for each, one inner node feeding it.
     *
     * <p>The inner node uses the sigmoid activation for a nominal label and the linear activation
     * for a numerical label. The inner nodes get layer index {@code -2} and are appended to
     * {@link #innerNodes}.
     *
     * @param label           the label attribute
     * @param numberOfClasses number of output nodes to create
     * @param min             label minimum (used for the output range/offset)
     * @param max             label maximum (used for the output range/offset)
     * @param randomGenerator passed to each new inner node
     */
    public void initOutputLayer(Attribute label, int numberOfClasses, double min, double max, RandomGenerator randomGenerator) {
        double range = (max - min) / 2.0;
        double offset = (max + min) / 2.0;
        boolean nominal = label.isNominal();
        this.outputNodes = new OutputNode[numberOfClasses];
        for (int o = 0; o < numberOfClasses; ++o) {
            InnerNode actualOutput;
            if (nominal) {
                String classValue = label.getMapping().mapIndex(o);
                this.outputNodes[o] = new OutputNode(classValue, label, range, offset);
                this.outputNodes[o].setClassIndex(o);
                actualOutput = new InnerNode("Class '" + classValue + "'", OUTPUT_LAYER_INDEX, randomGenerator, SIGMOID_FUNCTION);
            } else {
                this.outputNodes[o] = new OutputNode(label.getName(), label, range, offset);
                actualOutput = new InnerNode("Regression", OUTPUT_LAYER_INDEX, randomGenerator, LINEAR_FUNCTION);
            }
            this.addNode(actualOutput);
            Node.connect(actualOutput, this.outputNodes[o]);
        }
    }

    /**
     * Creates the hidden layers and wires them.
     *
     * <p>The wiring is: inputs to the first hidden layer, each layer fully to the next, and the last
     * hidden layer to the output inner nodes. This method expects {@link #initInputLayer} and
     * {@link #initOutputLayer} to have run already. The output inner nodes are taken to occupy the
     * first {@code #classes} slots of {@link #innerNodes}.
     *
     * @param exampleSet      training data (attribute count, default layer size)
     * @param label           the label attribute
     * @param hiddenLayerList {@code {name, size}} pairs; a size {@code <= 0} selects the default size.
     *                        An empty or {@code null} list creates one default-sized layer.
     * @param randomGenerator passed to each new inner node
     * @throws NumberFormatException if a layer size is not an integer
     */
    public void initHiddenLayers(ExampleSet exampleSet, Attribute label, List<String[]> hiddenLayerList, RandomGenerator randomGenerator) {
        int[] layerSizes;
        if (hiddenLayerList != null && !hiddenLayerList.isEmpty()) {
            layerSizes = new int[hiddenLayerList.size()];
            int index = 0;
            for (String[] nameSizePair : hiddenLayerList) {
                // nameSizePair[0] is the layer name, which is not used for the topology
                int layerSize = Integer.parseInt(nameSizePair[1]);
                layerSizes[index++] = layerSize > 0 ? layerSize : this.getDefaultLayerSize(exampleSet, label);
            }
        } else {
            this.log("No hidden layers defined. Using default hidden layer.");
            layerSizes = new int[]{this.getDefaultLayerSize(exampleSet, label)};
        }

        int lastLayerSize = 0;
        for (int layerIndex = 0; layerIndex < layerSizes.length; ++layerIndex) {
            int numberOfNodes = layerSizes[layerIndex];
            for (int nodeIndex = 0; nodeIndex < numberOfNodes; ++nodeIndex) {
                InnerNode innerNode = new InnerNode("Node " + (nodeIndex + 1), layerIndex, randomGenerator, SIGMOID_FUNCTION);
                this.addNode(innerNode);
                if (layerIndex > 0) {
                    // This layer starts nodeIndex slots before the just-added node. The previous
                    // layer occupies the lastLayerSize slots immediately before that start.
                    int currentLayerStart = this.innerNodes.length - nodeIndex - 1;
                    for (int i = currentLayerStart - lastLayerSize; i < currentLayerStart; ++i) {
                        Node.connect(this.innerNodes[i], innerNode);
                    }
                }
            }
            lastLayerSize = numberOfNodes;
        }

        int firstLayerSize = layerSizes[0];
        int numberOfAttributes = exampleSet.getAttributes().size();
        int numberOfClasses = this.getNumberOfClasses(label);
        if (firstLayerSize == 0) {
            // Direct input-to-output wiring. Unreachable with the size substitution above, which
            // never yields 0; kept for safety.
            for (int i = 0; i < numberOfAttributes; ++i) {
                for (int o = 0; o < numberOfClasses; ++o) {
                    Node.connect(this.inputNodes[i], this.innerNodes[o]);
                }
            }
        } else {
            // inputs -> first hidden layer (which follows the numberOfClasses output inner nodes)
            for (int i = 0; i < numberOfAttributes; ++i) {
                for (int o = numberOfClasses; o < numberOfClasses + firstLayerSize; ++o) {
                    Node.connect(this.inputNodes[i], this.innerNodes[o]);
                }
            }
            // last hidden layer (the tail of innerNodes) -> output inner nodes
            for (int i = this.innerNodes.length - lastLayerSize; i < this.innerNodes.length; ++i) {
                for (int o = 0; o < numberOfClasses; ++o) {
                    Node.connect(this.innerNodes[i], this.innerNodes[o]);
                }
            }
        }
    }

    /**
     * Creates the hidden layers, then warm-starts them from {@code old_model}.
     *
     * <p>Inner nodes are paired by position in {@link #innerNodes}. A pair is copied only when both
     * nodes have the same layer index. For each copied pair, the threshold and as many leading input
     * weights as both nodes have inputs are taken over. All other weights keep their fresh random
     * initialisation.
     *
     * @param old_model previously trained model whose weights are copied
     */
    public void initHiddenLayers(ExampleSet exampleSet, Attribute label, List<String[]> hiddenLayerList, RandomGenerator randomGenerator, AutoMLPImprovedNeuralNetModel old_model) {
        this.initHiddenLayers(exampleSet, label, hiddenLayerList, randomGenerator);
        for (int i = 0; i < old_model.innerNodes.length && i < this.innerNodes.length; ++i) {
            InnerNode oldInnerNode = old_model.innerNodes[i];
            InnerNode newInnerNode = this.innerNodes[i];
            if (oldInnerNode.getLayerIndex() == newInnerNode.getLayerIndex()) {
                copySharedWeights(oldInnerNode, newInnerNode);
            }
        }
    }

    /**
     * Copies the threshold (index 0) and the leading input weights that both nodes share from
     * {@code source} to {@code target}.
     *
     * <p>Bounded by the smaller input count, so a target with fewer inputs cannot overflow.
     */
    private static void copySharedWeights(InnerNode source, InnerNode target) {
        double[] sourceWeights = source.getWeights();
        double[] targetWeights = target.getWeights();
        int sharedInputs = Math.min(source.getInputNodes().length, target.getInputNodes().length);
        // weights[0] is the threshold and weights[k + 1] belongs to input k, hence "<=".
        for (int j = 0; j <= sharedInputs; ++j) {
            targetWeights[j] = sourceWeights[j];
        }
        target.setWeights(targetWeights);
    }

    /**
     * Renders all hidden layers, then the output layer.
     *
     * <p>For each node, the output lists the node's input weights and its threshold.
     */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        boolean firstHiddenLayer = true;
        int lastLayerIndex = -99;
        for (InnerNode innerNode : this.innerNodes) {
            int layerIndex = innerNode.getLayerIndex();
            if (layerIndex == OUTPUT_LAYER_INDEX) {
                continue;
            }
            if (firstHiddenLayer || layerIndex != lastLayerIndex) {
                if (!firstHiddenLayer) {
                    result.append(com.rapidminer.tools.Tools.getLineSeparators(2));
                }
                firstHiddenLayer = false;
                appendUnderlined(result, "Hidden " + (layerIndex + 1), '=');
                lastLayerIndex = layerIndex;
            }
            appendNodeDescription(result, innerNode);
        }

        boolean outputHeaderWritten = false;
        for (InnerNode innerNode : this.innerNodes) {
            if (innerNode.getLayerIndex() != OUTPUT_LAYER_INDEX) {
                continue;
            }
            if (!outputHeaderWritten) {
                result.append(com.rapidminer.tools.Tools.getLineSeparators(2));
                appendUnderlined(result, "Output", '=');
                outputHeaderWritten = true;
            }
            appendNodeDescription(result, innerNode);
        }
        return result.toString();
    }

    /** Appends {@code title}, a line separator, an underline of equal length, and a line separator. */
    private static void appendUnderlined(StringBuilder result, String title, char underline) {
        String separator = com.rapidminer.tools.Tools.getLineSeparator();
        result.append(title).append(separator);
        for (int t = 0; t < title.length(); ++t) {
            result.append(underline);
        }
        result.append(separator);
    }

    /** Appends a node heading, one {@code "name: weight"} line per input, and the threshold line. */
    private static void appendNodeDescription(StringBuilder result, InnerNode innerNode) {
        String separator = com.rapidminer.tools.Tools.getLineSeparator();
        String nodeName = innerNode.getNodeName() + " (" + innerNode.getActivationFunction().getTypeName() + ")";
        result.append(separator);
        appendUnderlined(result, nodeName, '-');
        double[] weights = innerNode.getWeights();
        Node[] inputs = innerNode.getInputNodes();
        for (int i = 0; i < inputs.length; ++i) {
            result.append(inputs[i].getNodeName()).append(": ")
                  .append(com.rapidminer.tools.Tools.formatNumber(weights[i + 1])).append(separator);
        }
        result.append("Threshold: ").append(com.rapidminer.tools.Tools.formatNumber(weights[0])).append(separator);
    }
}

