/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Vector;
import org.joone.engine.InputPatternListener;
import org.joone.engine.Layer;
import org.joone.engine.Matrix;
import org.joone.engine.OutputPatternListener;
import org.joone.engine.Synapse;
import org.joone.io.MemoryInputSynapse;
import org.joone.io.MemoryOutputSynapse;
import org.joone.net.NeuralNet;

/**
 * Prediction model wrapping a trained Joone {@link NeuralNet}.
 *
 * <p>The first value of the network's output for each example is mapped back to the label domain:
 * for nominal labels it is compared against 0.5 to choose the positive or negative class (with
 * sigmoid-shaped confidences), for numerical labels it is rescaled linearly via
 * {@code prediction * (maxLabel - minLabel) + minLabel}.
 *
 * <p>NOTE(review): {@link #performPrediction} rewires the input and output synapses of the shared
 * network instance, so one model object is presumably not safe to apply concurrently.
 */
public class NeuralNetModel
extends PredictionModel {
    private static final long serialVersionUID = 776221623930869372L;
    /** The trained network used for prediction. */
    private NeuralNet neuralNet;
    /** Names of the regular attributes of the training example set, recorded at construction time. */
    private String[] attributeNames;
    /** Number of leading input columns handed to the network (column selector {@code "1-n"}). */
    private int numberOfInputAttributes;
    /** Lower bound used to rescale numerical predictions (presumably the training label minimum). */
    private double minLabel;
    /** Upper bound used to rescale numerical predictions (presumably the training label maximum). */
    private double maxLabel;

    /**
     * Creates a model for the given trained network.
     *
     * @param exampleSet the training example set; its regular attribute names are recorded
     * @param neuralNet the trained network
     * @param numberOfInputAttributes number of input columns the network reads
     * @param minLabel lower bound used to rescale numerical predictions
     * @param maxLabel upper bound used to rescale numerical predictions
     */
    public NeuralNetModel(ExampleSet exampleSet, NeuralNet neuralNet, int numberOfInputAttributes, double minLabel, double maxLabel) {
        super(exampleSet);
        this.attributeNames = Tools.getRegularAttributeNames(exampleSet);
        this.neuralNet = neuralNet;
        this.numberOfInputAttributes = numberOfInputAttributes;
        this.minLabel = minLabel;
        this.maxLabel = maxLabel;
    }

    /** Returns the wrapped network (not a copy). */
    public NeuralNet getNeuralNet() {
        return this.neuralNet;
    }

    /** Returns the regular attribute names recorded at training time (the internal array, not a copy). */
    public String[] getAttributeNames() {
        return this.attributeNames;
    }

    /**
     * Runs every example through the network once and writes the prediction into {@code predictedLabel}.
     *
     * @param exampleSet the examples to predict; modified in place and returned
     * @param predictedLabel the attribute receiving the predictions (and, if nominal, the confidences)
     * @return the same {@code exampleSet}
     * @throws OperatorException if an attribute used during training is missing from {@code exampleSet}
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        // Nothing to predict: do not start the network with zero patterns.
        if (exampleSet.size() == 0) {
            return exampleSet;
        }
        // Replace whatever feeds the input layer with an in-memory synapse holding the example values.
        Layer input = this.neuralNet.getInputLayer();
        input.removeAllInputs();
        MemoryInputSynapse memInp = new MemoryInputSynapse();
        memInp.setFirstRow(1);
        memInp.setAdvancedColumnSelector("1-" + this.numberOfInputAttributes);
        input.addInputSynapse((InputPatternListener)memInp);
        memInp.setInputArray(this.createInputData(exampleSet));
        // Detach all current output listeners; recall() attaches its own collector.
        Layer output = this.neuralNet.getOutputLayer();
        output.removeAllOutputs();
        // One cycle over all examples with learning switched off.
        this.neuralNet.getMonitor().setTotCicles(1);
        this.neuralNet.getMonitor().setTrainingPatterns(exampleSet.size());
        this.neuralNet.getMonitor().setLearning(false);
        double[] predictions = this.recall(this.neuralNet);

        boolean nominal = predictedLabel.isNominal();
        int counter = 0;
        for (Example example : exampleSet) {
            double prediction = predictions[counter++];
            if (nominal) {
                // Map the output to [-1, 1] around the 0.5 decision threshold.
                double scaled = (prediction - 0.5) * 2.0;
                int index = scaled > 0.0 ? predictedLabel.getMapping().getPositiveIndex() : predictedLabel.getMapping().getNegativeIndex();
                example.setValue(predictedLabel, index);
                example.setConfidence(predictedLabel.getMapping().getPositiveString(), 1.0 / (1.0 + Math.exp(-scaled)));
                example.setConfidence(predictedLabel.getMapping().getNegativeString(), 1.0 / (1.0 + Math.exp(scaled)));
            } else {
                example.setValue(predictedLabel, prediction * (this.maxLabel - this.minLabel) + this.minLabel);
            }
        }
        return exampleSet;
    }

    /**
     * Runs {@code net} once and collects the first output value of every pattern.
     *
     * <p>The network is always stopped afterwards, even if running or reading the output fails.
     *
     * @param net the network to run; its input synapse and monitor must already be configured
     * @return one value per pattern, as many as the monitor's training-pattern count
     */
    private double[] recall(NeuralNet net) {
        MemoryOutputSynapse output = new MemoryOutputSynapse();
        net.addOutputSynapse((OutputPatternListener)output);
        net.start();
        try {
            net.getMonitor().Go();
            net.join();
            int patternCount = net.getMonitor().getTrainingPatterns();
            double[] result = new double[patternCount];
            for (int i = 0; i < patternCount; ++i) {
                double[] pattern = output.getNextPattern();
                result[i] = pattern[0];
            }
            return result;
        } finally {
            net.stop();
        }
    }

    /**
     * Builds the network input matrix: one row per example, one column per training attribute.
     *
     * <p>Columns are resolved by the attribute names recorded at training time rather than by the
     * position of attributes in {@code exampleSet}. A reordered example set, or one with additional
     * attributes, therefore still yields the columns in the recorded order.
     *
     * @param exampleSet the examples to convert
     * @return the input values, {@code [exampleSet.size()][attributeNames.length]}
     * @throws OperatorException if a training attribute is not present in {@code exampleSet}
     */
    private double[][] createInputData(ExampleSet exampleSet) throws OperatorException {
        Attribute[] inputAttributes = new Attribute[this.attributeNames.length];
        for (int a = 0; a < inputAttributes.length; ++a) {
            inputAttributes[a] = exampleSet.getAttributes().get(this.attributeNames[a]);
            if (inputAttributes[a] == null) {
                throw new OperatorException("NeuralNetModel: the attribute '" + this.attributeNames[a] + "' used during training is missing in the input example set.");
            }
        }
        double[][] result = new double[exampleSet.size()][inputAttributes.length];
        int row = 0;
        for (Example example : exampleSet) {
            for (int a = 0; a < inputAttributes.length; ++a) {
                result[row][a] = example.getValue(inputAttributes[a]);
            }
            ++row;
        }
        return result;
    }

    /**
     * Describes every layer of the network.
     *
     * <p>For each layer the output contains its name and node count. The first layer additionally
     * lists the training attribute names. Every later layer lists, for each incoming synapse with a
     * weight matrix, the input weights of each of its nodes.
     */
    @Override
    public String toString() {
        String lineSeparator = com.rapidminer.tools.Tools.getLineSeparator();
        StringBuilder result = new StringBuilder();
        int layerIndex = 0;
        for (Object layerObject : this.neuralNet.getLayers()) {
            Layer layer = (Layer)layerObject;
            String nodeString = layer.getRows() == 1 ? "1 node" : layer.getRows() + " nodes";
            String titleString = "Layer '" + layer.getLayerName() + "' (" + nodeString + ")";
            result.append(titleString).append(lineSeparator);
            // Underline the title with one dash per character.
            for (int t = 0; t < titleString.length(); ++t) {
                result.append('-');
            }
            result.append(lineSeparator);
            if (layerIndex == 0) {
                result.append(Arrays.asList(this.attributeNames).toString()).append(lineSeparator);
            } else {
                result.append("Input Weights:").append(lineSeparator);
                for (Object inputObject : layer.getAllInputs()) {
                    if (!(inputObject instanceof Synapse)) {
                        continue;
                    }
                    Matrix weights = ((Synapse)inputObject).getWeights();
                    if (weights == null) {
                        continue;
                    }
                    int inputRows = weights.getM_rows();
                    int outputRows = weights.getM_cols();
                    // Column y of the weight matrix holds the incoming weights of node y + 1.
                    for (int y = 0; y < outputRows; ++y) {
                        result.append("Node ").append(y + 1).append(lineSeparator);
                        for (int x = 0; x < inputRows; ++x) {
                            result.append(weights.value[x][y]).append(lineSeparator);
                        }
                    }
                }
            }
            result.append(lineSeparator);
            ++layerIndex;
        }
        return result.toString();
    }
}

