/*
 * Decompiled with CFR 0.152.
 */
package edu.udo.cs.yale.operator.learner.neuralnet;

import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.IOContainer;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.PredictionModel;
import edu.udo.cs.yale.operator.learner.neuralnet.NeuralNetVisualizer;
import edu.udo.cs.yale.tools.Tools;
import java.awt.BorderLayout;
import java.awt.Component;
import java.awt.FlowLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Iterator;
import java.util.Vector;
import javax.swing.ButtonGroup;
import javax.swing.JPanel;
import javax.swing.JRadioButton;
import javax.swing.JScrollPane;
import org.joone.engine.InputPatternListener;
import org.joone.engine.Layer;
import org.joone.engine.Matrix;
import org.joone.engine.OutputPatternListener;
import org.joone.engine.Synapse;
import org.joone.io.MemoryInputSynapse;
import org.joone.io.MemoryOutputSynapse;
import org.joone.net.NeuralNet;

public class NeuralNetModel
extends PredictionModel {
    private NeuralNet neuralNet;
    private int numberOfInputAttributes;
    private double minLabel;
    private double maxLabel;

    public NeuralNetModel() {
    }

    public NeuralNetModel(Attribute label, NeuralNet neuralNet, int numberOfInputAttributes, double minLabel, double maxLabel) {
        super(label);
        this.neuralNet = neuralNet;
        this.numberOfInputAttributes = numberOfInputAttributes;
        this.minLabel = minLabel;
        this.maxLabel = maxLabel;
    }

    public void performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        Layer input = this.neuralNet.getInputLayer();
        input.removeAllInputs();
        MemoryInputSynapse memInp = new MemoryInputSynapse();
        memInp.setFirstRow(1);
        memInp.setAdvancedColumnSelector("1-" + this.numberOfInputAttributes);
        input.addInputSynapse((InputPatternListener)memInp);
        memInp.setInputArray(this.createInputData(exampleSet));
        Layer output = this.neuralNet.getOutputLayer();
        output.removeAllOutputs();
        this.neuralNet.getMonitor().setTotCicles(1);
        this.neuralNet.getMonitor().setTrainingPatterns(exampleSet.size());
        this.neuralNet.getMonitor().setLearning(false);
        double[] predictions = this.recall(this.neuralNet);
        Iterator i = exampleSet.iterator();
        int counter = 0;
        while (i.hasNext()) {
            Example example = (Example)i.next();
            double prediction = predictions[counter];
            if (predictedLabel.isNominal()) {
                double scaled = (prediction - 0.5) * 2.0;
                int index = scaled > 0.0 ? predictedLabel.getMapping().getPositiveIndex() : predictedLabel.getMapping().getNegativeIndex();
                example.setValue(predictedLabel, index);
                example.setConfidence(predictedLabel.getMapping().getPositiveString(), 1.0 / (1.0 + Math.exp(-scaled)));
                example.setConfidence(predictedLabel.getMapping().getNegativeString(), 1.0 / (1.0 + Math.exp(scaled)));
            } else {
                example.setValue(predictedLabel, prediction * (this.maxLabel - this.minLabel) + this.minLabel);
            }
            ++counter;
        }
    }

    private double[] recall(NeuralNet net) {
        MemoryOutputSynapse output = new MemoryOutputSynapse();
        this.neuralNet.addOutputSynapse((OutputPatternListener)output);
        this.neuralNet.start();
        this.neuralNet.getMonitor().Go();
        this.neuralNet.join();
        int cc = this.neuralNet.getMonitor().getTrainingPatterns();
        double[] result = new double[cc];
        int i = 0;
        while (i < cc) {
            double[] pattern = output.getNextPattern();
            result[i] = pattern[0];
            ++i;
        }
        this.neuralNet.stop();
        return result;
    }

    private double[][] createInputData(ExampleSet exampleSet) {
        double[][] result = new double[exampleSet.size()][exampleSet.getAttributes().size()];
        int counter = 0;
        for (Example example : exampleSet) {
            int a = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                result[counter][a++] = example.getValue(attribute);
            }
            ++counter;
        }
        return result;
    }

    public Component getVisualizationComponent(IOContainer ioContainer) {
        final JPanel mainPanel = new JPanel();
        mainPanel.setLayout(new BorderLayout());
        final JScrollPane graphView = new JScrollPane(new NeuralNetVisualizer(this.neuralNet));
        graphView.getVerticalScrollBar().setUnitIncrement(10);
        final JRadioButton graphViewButton = new JRadioButton("graph view", true);
        graphViewButton.setToolTipText("Changes to a graphical view of this model.");
        graphViewButton.addActionListener(new ActionListener(){

            public void actionPerformed(ActionEvent e) {
                if (graphViewButton.isSelected()) {
                    mainPanel.remove(1);
                    mainPanel.add((Component)graphView, "Center");
                    mainPanel.repaint();
                }
            }
        });
        final Component textView = super.getVisualizationComponent(ioContainer);
        final JRadioButton textViewButton = new JRadioButton("text view", true);
        textViewButton.setToolTipText("Changes to a textual view of this model.");
        textViewButton.addActionListener(new ActionListener(){

            public void actionPerformed(ActionEvent e) {
                if (textViewButton.isSelected()) {
                    mainPanel.remove(1);
                    mainPanel.add(textView, "Center");
                    mainPanel.repaint();
                }
            }
        });
        ButtonGroup group = new ButtonGroup();
        group.add(graphViewButton);
        group.add(textViewButton);
        JPanel togglePanel = new JPanel(new FlowLayout(0));
        togglePanel.add(graphViewButton);
        togglePanel.add(textViewButton);
        mainPanel.add((Component)togglePanel, "North");
        mainPanel.add((Component)graphView, "Center");
        return mainPanel;
    }

    public String toString() {
        StringBuffer result = new StringBuffer("NeuralNet" + Tools.getLineSeparator());
        Vector layers = this.neuralNet.getLayers();
        Iterator i = layers.iterator();
        int layerIndex = 0;
        while (i.hasNext()) {
            Layer layer = (Layer)i.next();
            String nodeString = layer.getRows() == 1 ? "1 node" : String.valueOf(layer.getRows()) + " nodes";
            result.append("Layer '" + layer.getLayerName() + "' (" + nodeString + ")" + Tools.getLineSeparator());
            if (layerIndex > 0) {
                result.append("Input Weights:" + Tools.getLineSeparator());
                Vector inputs = layer.getAllInputs();
                for (Object object : inputs) {
                    Synapse synapse;
                    Matrix weights;
                    if (!(object instanceof Synapse) || (weights = (synapse = (Synapse)object).getWeights()) == null) continue;
                    int inputRows = weights.getM_rows();
                    int outputRows = weights.getM_cols();
                    int y = 0;
                    while (y < outputRows) {
                        result.append("Node " + (y + 1) + Tools.getLineSeparator());
                        int x = 0;
                        while (x < inputRows) {
                            result.append(String.valueOf(weights.value[x][y]) + Tools.getLineSeparator());
                            ++x;
                        }
                        ++y;
                    }
                }
            }
            ++layerIndex;
        }
        return result.toString();
    }
}

