/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.gui.tools.SwingTools;
import com.rapidminer.operator.learner.functions.neuralnet.NeuralNetModel;
import com.rapidminer.report.Renderable;
import com.rapidminer.tools.Tools;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Rectangle;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Rectangle2D;
import java.util.Iterator;
import java.util.Locale;
import java.util.Vector;
import javax.swing.JPanel;
import org.joone.engine.Layer;
import org.joone.engine.Matrix;
import org.joone.engine.Synapse;
import org.joone.net.NeuralNet;

/**
 * Swing panel that visualizes a Joone {@link NeuralNet}.
 *
 * <p>Each layer is drawn as a column of circular nodes with the layer name above it. Layers whose
 * name contains "input" or "output" (case-insensitive) are painted yellow, all others blue.
 * The output synapses of every layer except the last are drawn as lines to the next column. A
 * line's grey level reflects the absolute weight relative to the largest absolute weight in the
 * net: the heavier the weight, the darker the line.
 *
 * <p>Releasing the mouse over a node selects it (red outline) and shows a key box at the pointer.
 * For nodes of the first layer the box holds the corresponding attribute name. For all other
 * layers it holds the weights of the layer's first input synapse that lead into the node.
 * Releasing over the selected node again, or outside any node, clears the selection.
 *
 * <p>The panel also implements {@link Renderable} so it can be rendered at (at least) its
 * preferred size.
 */
public class NeuralNetVisualizer extends JPanel implements MouseListener, Renderable {
    private static final long serialVersionUID = 1511167115976161350L;

    /** Vertical distance between the tops of two consecutive nodes of one layer. */
    private static final int ROW_HEIGHT = 36;

    /** Horizontal distance between two consecutive layer columns. */
    private static final int LAYER_WIDTH = 150;

    /** Blank border kept around the drawing on every side. */
    private static final int MARGIN = 30;

    /** Diameter of a node circle; lines attach at its center (half of this value). */
    private static final int NODE_DIAMETER = 24;

    /** Extra vertical space inserted after each line of the key box. */
    private static final int KEY_LINE_SPACING = 3;

    /** Horizontal padding on the left and right of the key text. */
    private static final int KEY_HORIZONTAL_PADDING = 4;

    /** Extra height added to the key box below the text. */
    private static final int KEY_BOTTOM_PADDING = 6;

    private static final Font LABEL_FONT = new Font("SansSerif", Font.PLAIN, 11);

    private final NeuralNet neuralNet;

    /** Names shown for the nodes of the first layer; may be {@code null}. */
    private final String[] attributeNames;

    /**
     * Largest absolute synapse weight among the drawn synapses. It stays negative infinity when
     * there are none, and the line shading is scaled against it.
     */
    private final double maxAbsoluteWeight;

    private int selectedLayerIndex = -1;
    private int selectedRowIndex = -1;

    /** Text of the key box, or {@code null} when no key is shown. */
    private String key = null;
    private int keyX = -1;
    private int keyY = -1;

    /**
     * Creates a visualizer for the net and attribute names held by the given model.
     *
     * @param neuralNetModel model providing the net and the input attribute names
     */
    public NeuralNetVisualizer(NeuralNetModel neuralNetModel) {
        this(neuralNetModel.getNeuralNet(), neuralNetModel.getAttributeNames());
    }

    /**
     * Creates a visualizer for the given net.
     *
     * @param neuralNet the net to draw
     * @param attributeNames names shown for the first-layer nodes; may be {@code null}
     */
    public NeuralNetVisualizer(NeuralNet neuralNet, String[] attributeNames) {
        this.neuralNet = neuralNet;
        this.attributeNames = attributeNames;
        this.maxAbsoluteWeight = computeMaxAbsoluteWeight(neuralNet);
        this.addMouseListener(this);
    }

    /** Returns the largest absolute weight of the output synapses of all layers but the last. */
    private static double computeMaxAbsoluteWeight(NeuralNet net) {
        double max = Double.NEGATIVE_INFINITY;
        Vector<?> layers = net.getLayers();
        // The last layer's outputs are never drawn, so they do not take part in the scaling.
        for (int l = 0; l < layers.size() - 1; ++l) {
            Layer layer = (Layer) layers.get(l);
            for (Object output : layer.getAllOutputs()) {
                Matrix weights = ((Synapse) output).getWeights();
                int inputRows = weights.getM_rows();
                int outputRows = weights.getM_cols();
                for (int x = 0; x < inputRows; ++x) {
                    for (int y = 0; y < outputRows; ++y) {
                        max = Math.max(max, Math.abs(weights.value[x][y]));
                    }
                }
            }
        }
        return max;
    }

    /**
     * Returns the y coordinate (relative to the margin) of the top of the first node. This
     * centers a column of {@code rows} nodes vertically in a drawing of height {@code height}.
     */
    private static int firstRowY(int height, int rows) {
        return height / 2 - rows * ROW_HEIGHT / 2;
    }

    /** Returns {@code true} if the layer name marks an input or output layer (locale-independent). */
    private static boolean isInputOrOutputLayer(String layerName) {
        // Locale.ROOT: under e.g. a Turkish default locale "INPUT".toLowerCase() yields "ınput".
        String lower = layerName.toLowerCase(Locale.ROOT);
        return lower.contains("input") || lower.contains("output");
    }

    /**
     * Returns the grey level for a synapse line: 1 (white) for a zero weight, 0 (black) for the
     * largest absolute weight in the net.
     */
    private float weightShade(double weight) {
        if (!(this.maxAbsoluteWeight > 0.0)) {
            // All weights zero (or no usable maximum): 0/0 would give NaN, which renders black.
            return 1.0f;
        }
        return 1.0f - (float) Math.min(1.0, Math.abs(weight) / this.maxAbsoluteWeight);
    }

    /**
     * Returns a size that fits all layer columns side by side and the tallest layer, plus the
     * margin on every side.
     */
    @Override
    public Dimension getPreferredSize() {
        Vector<?> layers = this.neuralNet.getLayers();
        int maxRows = 0;
        for (Object layer : layers) {
            maxRows = Math.max(maxRows, ((Layer) layer).getRows());
        }
        return new Dimension(layers.size() * LAYER_WIDTH + 2 * MARGIN, maxRows * ROW_HEIGHT + 2 * MARGIN);
    }

    /** Paints the white background, the synapses, the nodes and, if present, the key box. */
    @Override
    public void paint(Graphics graphics) {
        int panelWidth = this.getWidth();
        int panelHeight = this.getHeight();
        graphics.clearRect(0, 0, panelWidth, panelHeight);
        graphics.setColor(Color.WHITE);
        graphics.fillRect(0, 0, panelWidth, panelHeight);

        int height = this.getPreferredSize().height;
        Graphics2D g = (Graphics2D) graphics;
        Graphics2D translated = (Graphics2D) g.create();
        try {
            translated.translate(MARGIN, MARGIN);
            translated.setFont(LABEL_FONT);
            // Separate copies: paintSynapses/paintNodes translate their graphics per layer.
            Graphics2D synapsesG = (Graphics2D) translated.create();
            try {
                this.paintSynapses(synapsesG, height);
            } finally {
                synapsesG.dispose();
            }
            Graphics2D nodeG = (Graphics2D) translated.create();
            try {
                this.paintNodes(nodeG, height);
            } finally {
                nodeG.dispose();
            }
        } finally {
            translated.dispose();
        }

        if (this.key != null) {
            this.paintKey(g);
        }
    }

    /** Draws the key text in a yellow box whose top-left text corner is (keyX, keyY). */
    private void paintKey(Graphics2D g) {
        String[] lines = Tools.transformAllLineSeparators(this.key).split("\n");
        if (lines.length == 0) {
            return; // key consisted of line separators only; nothing to show
        }
        Rectangle2D[] bounds = new Rectangle2D[lines.length];
        double maxWidth = 0.0;
        double totalHeight = 0.0;
        for (int i = 0; i < lines.length; ++i) {
            bounds[i] = g.getFontMetrics().getStringBounds(lines[i], g);
            maxWidth = Math.max(maxWidth, bounds[i].getWidth());
            totalHeight += bounds[i].getHeight();
        }
        totalHeight += (lines.length - 1) * KEY_LINE_SPACING;

        Rectangle frame = new Rectangle(
                this.keyX - KEY_HORIZONTAL_PADDING,
                this.keyY,
                (int) maxWidth + 2 * KEY_HORIZONTAL_PADDING,
                (int) totalHeight + KEY_BOTTOM_PADDING);
        g.setColor(SwingTools.LIGHTEST_YELLOW);
        g.fill(frame);
        g.setColor(SwingTools.DARK_YELLOW);
        g.draw(frame);

        g.setColor(Color.BLACK);
        int yPos = this.keyY;
        for (int i = 0; i < lines.length; ++i) {
            // drawString takes the baseline, so advance by the line height first.
            yPos += (int) bounds[i].getHeight();
            g.drawString(lines[i], this.keyX, yPos);
            yPos += KEY_LINE_SPACING;
        }
    }

    /**
     * Draws one line per weight of every output synapse of every layer but the last. Each line
     * runs from a node center in the current column to a node center in the next column.
     */
    private void paintSynapses(Graphics2D g, int height) {
        int center = NODE_DIAMETER / 2;
        Vector<?> layers = this.neuralNet.getLayers();
        for (int l = 0; l < layers.size() - 1; ++l) {
            Layer layer = (Layer) layers.get(l);
            for (Object output : layer.getAllOutputs()) {
                Matrix weights = ((Synapse) output).getWeights();
                int inputRows = weights.getM_rows();
                int outputRows = weights.getM_cols();
                int inputY = firstRowY(height, inputRows);
                for (int x = 0; x < inputRows; ++x) {
                    int outputY = firstRowY(height, outputRows);
                    for (int y = 0; y < outputRows; ++y) {
                        float shade = this.weightShade(weights.value[x][y]);
                        g.setColor(new Color(shade, shade, shade));
                        g.drawLine(center, inputY + center, LAYER_WIDTH + center, outputY + center);
                        outputY += ROW_HEIGHT;
                    }
                    inputY += ROW_HEIGHT;
                }
            }
            g.translate(LAYER_WIDTH, 0);
        }
    }

    /** Draws each layer's name and its column of nodes, outlining the selected node in red. */
    private void paintNodes(Graphics2D g, int height) {
        Vector<?> layers = this.neuralNet.getLayers();
        for (int layerIndex = 0; layerIndex < layers.size(); ++layerIndex) {
            Layer layer = (Layer) layers.get(layerIndex);
            String layerName = layer.getLayerName();

            // Center the layer name above the node column.
            Rectangle2D stringBounds = LABEL_FONT.getStringBounds(layerName, g.getFontRenderContext());
            g.setColor(Color.BLACK);
            g.drawString(layerName, (int) (NODE_DIAMETER / 2.0 - stringBounds.getWidth() / 2.0), 0);

            boolean ioLayer = isInputOrOutputLayer(layerName);
            int rows = layer.getRows();
            int yPos = firstRowY(height, rows);
            for (int r = 0; r < rows; ++r) {
                Ellipse2D.Double node = new Ellipse2D.Double(0.0, yPos, NODE_DIAMETER, NODE_DIAMETER);
                if (ioLayer) {
                    g.setPaint(SwingTools.makeYellowPaint(NODE_DIAMETER, NODE_DIAMETER));
                } else {
                    g.setPaint(SwingTools.makeBluePaint(NODE_DIAMETER, NODE_DIAMETER));
                }
                g.fill(node);
                boolean selected = layerIndex == this.selectedLayerIndex && r == this.selectedRowIndex;
                g.setColor(selected ? Color.RED : Color.BLACK);
                g.draw(node);
                yPos += ROW_HEIGHT;
            }
            g.translate(LAYER_WIDTH, 0);
        }
    }

    /** Sets the key text (or {@code null} to hide it) and its position, then repaints. */
    private void setKey(String key, int keyX, int keyY) {
        this.key = key;
        this.keyX = keyX;
        this.keyY = keyY;
        this.repaint();
    }

    /**
     * Records the selected node and updates the key box accordingly. It shows the incoming
     * weights for layers after the first and the attribute name for the first layer. Passing
     * {@code -1} indices clears the selection and hides the key. setKey always repaints.
     */
    private void setSelectedNode(int layerIndex, int rowIndex, int xPos, int yPos) {
        this.selectedLayerIndex = layerIndex;
        this.selectedRowIndex = rowIndex;
        if (layerIndex >= 1) {
            Layer layer = (Layer) this.neuralNet.getLayers().get(layerIndex);
            Vector<?> inputs = layer.getAllInputs();
            if (inputs.isEmpty()) {
                this.setKey(null, -1, -1);
                return;
            }
            Matrix weights = ((Synapse) inputs.get(0)).getWeights();
            if (rowIndex < 0 || rowIndex >= weights.getM_cols()) {
                // The node has no column in the first input synapse's weight matrix.
                this.setKey(null, -1, -1);
                return;
            }
            String separator = Tools.getLineSeparator();
            StringBuilder toolTip = new StringBuilder("Weights:").append(separator);
            for (int x = 0; x < weights.getM_rows(); ++x) {
                toolTip.append(Tools.formatNumber(weights.value[x][rowIndex])).append(separator);
            }
            this.setKey(toolTip.toString(), xPos, yPos);
        } else if (this.attributeNames != null && rowIndex >= 0 && rowIndex < this.attributeNames.length) {
            this.setKey(this.attributeNames[rowIndex], xPos, yPos);
        } else {
            this.setKey(null, -1, -1);
        }
    }

    /**
     * Hit-tests a panel coordinate against the node circles, using their bounding squares. A hit
     * on the selected node deselects it, a hit on another node selects that node, and a miss
     * clears the selection.
     */
    private void checkMousePos(int xPos, int yPos) {
        int x = xPos - MARGIN;
        int y = yPos - MARGIN;
        int layerIndex = x / LAYER_WIDTH;
        int layerOffset = x % LAYER_WIDTH;
        Vector<?> layers = this.neuralNet.getLayers();
        boolean layerHit = layerOffset > 0 && layerOffset < NODE_DIAMETER;
        if (layerHit && layerIndex >= 0 && layerIndex < layers.size()) {
            int rows = ((Layer) layers.get(layerIndex)).getRows();
            int rowTop = firstRowY(this.getPreferredSize().height, rows);
            for (int r = 0; r < rows; ++r) {
                if (y > rowTop && y < rowTop + NODE_DIAMETER) {
                    if (this.selectedLayerIndex == layerIndex && this.selectedRowIndex == r) {
                        this.setSelectedNode(-1, -1, -1, -1);
                    } else {
                        this.setSelectedNode(layerIndex, r, xPos, yPos);
                    }
                    return;
                }
                rowTop += ROW_HEIGHT;
            }
        }
        this.setSelectedNode(-1, -1, -1, -1);
    }

    /** Intentionally empty; selection reacts to mouse release only. */
    @Override
    public void mouseClicked(MouseEvent e) {
    }

    /** Intentionally empty. */
    @Override
    public void mouseEntered(MouseEvent e) {
    }

    /** Intentionally empty. */
    @Override
    public void mouseExited(MouseEvent e) {
    }

    /** Intentionally empty; selection reacts to mouse release only. */
    @Override
    public void mousePressed(MouseEvent e) {
    }

    /** Toggles or changes the node selection at the release position. */
    @Override
    public void mouseReleased(MouseEvent e) {
        this.checkMousePos(e.getX(), e.getY());
    }

    /** Nothing to prepare. */
    @Override
    public void prepareRendering() {
    }

    /** Nothing to clean up. */
    @Override
    public void finishRendering() {
    }

    /** Returns the preferred height, or the requested height if that is larger. */
    @Override
    public int getRenderHeight(int preferredHeight) {
        return renderExtent(this.getPreferredSize().height, preferredHeight);
    }

    /** Returns the preferred width, or the requested width if that is larger. */
    @Override
    public int getRenderWidth(int preferredWidth) {
        return renderExtent(this.getPreferredSize().width, preferredWidth);
    }

    /** Uses {@code requested} when {@code own} is unusable (&lt; 1), otherwise the larger of both. */
    private static int renderExtent(int own, int requested) {
        return own < 1 ? requested : Math.max(own, requested);
    }

    /** Resizes the panel to the given size and paints it onto the given graphics. */
    @Override
    public void render(Graphics graphics, int width, int height) {
        this.setSize(width, height);
        this.paint(graphics);
    }
}

