/*
 * Decompiled with CFR 0.152.
 */
package de.dfki.madm.paren.operator.learner.functions.neuralnet;

import com.rapidminer.gui.tools.SwingTools;
import com.rapidminer.operator.learner.functions.neuralnet.InnerNode;
import com.rapidminer.operator.learner.functions.neuralnet.InputNode;
import com.rapidminer.operator.learner.functions.neuralnet.Node;
import com.rapidminer.report.Renderable;
import com.rapidminer.tools.Tools;
import de.dfki.madm.paren.operator.learner.functions.neuralnet.AutoMLPImprovedNeuralNetModel;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Rectangle;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.swing.JPanel;

public class AutoMLPImprovedNeuralNetVisualizer
extends JPanel
implements MouseListener,
Renderable {
    private static final long serialVersionUID = -26826681541601736L;
    private static final int ROW_HEIGHT = 36;
    private static final int LAYER_WIDTH = 150;
    private static final int MARGIN = 30;
    private static final int NODE_RADIUS = 24;
    private static final Font LABEL_FONT = new Font("SansSerif", 0, 11);
    private AutoMLPImprovedNeuralNetModel neuralNet;
    private int selectedLayerIndex = -1;
    private int selectedRowIndex = -1;
    private double maxAbsoluteWeight = Double.NEGATIVE_INFINITY;
    private String key = null;
    private int keyX = -1;
    private int keyY = -1;
    private String[] attributeNames;
    private Map<Integer, List<Node>> layers = new LinkedHashMap<Integer, List<Node>>();

    public AutoMLPImprovedNeuralNetVisualizer(AutoMLPImprovedNeuralNetModel neuralNet, String[] attributeNames) {
        this.neuralNet = neuralNet;
        this.attributeNames = attributeNames;
        this.addMouseListener(this);
        this.maxAbsoluteWeight = Double.NEGATIVE_INFINITY;
        ArrayList<InputNode> inputNodes = new ArrayList<InputNode>();
        for (InputNode inputNode : this.neuralNet.getInputNodes()) {
            inputNodes.add(inputNode);
        }
        this.layers.put(0, inputNodes);
        for (Node node : this.neuralNet.getInnerNodes()) {
            List<Node> layer;
            double[] weights;
            for (double w : weights = ((InnerNode)node).getWeights()) {
                this.maxAbsoluteWeight = Math.max(this.maxAbsoluteWeight, Math.abs(w));
            }
            int layerIndex = node.getLayerIndex();
            if (layerIndex == -2) continue;
            if ((layer = this.layers.get(++layerIndex)) == null) {
                layer = new ArrayList<Node>();
                this.layers.put(layerIndex, layer);
            }
            layer.add(node);
        }
        int trueLayerIndex = this.layers.size();
        ArrayList<InnerNode> outputNodes = new ArrayList<InnerNode>();
        for (InnerNode innerNode : this.neuralNet.getInnerNodes()) {
            int layerIndex = innerNode.getLayerIndex();
            if (layerIndex != -2) continue;
            outputNodes.add(innerNode);
        }
        this.layers.put(trueLayerIndex, outputNodes);
    }

    @Override
    public Dimension getPreferredSize() {
        int maxRows = -1;
        for (Map.Entry<Integer, List<Node>> entry : this.layers.entrySet()) {
            int layerIndex = entry.getKey();
            int nodes = entry.getValue().size();
            if (layerIndex != -2) {
                ++nodes;
            }
            maxRows = Math.max(maxRows, nodes);
        }
        return new Dimension(this.layers.size() * 150 + 60, maxRows * 36 + 60);
    }

    @Override
    public void paint(Graphics graphics) {
        graphics.clearRect(0, 0, this.getWidth(), this.getHeight());
        graphics.setColor(Color.WHITE);
        graphics.fillRect(0, 0, this.getWidth(), this.getHeight());
        Dimension dim = this.getPreferredSize();
        int height = dim.height;
        Graphics2D g = (Graphics2D)graphics;
        Graphics2D translated = (Graphics2D)g.create();
        translated.translate(30, 30);
        translated.setFont(LABEL_FONT);
        Graphics2D synapsesG = (Graphics2D)translated.create();
        this.paintSynapses(synapsesG, height);
        synapsesG.dispose();
        Graphics2D nodeG = (Graphics2D)translated.create();
        this.paintNodes(nodeG, height);
        nodeG.dispose();
        translated.dispose();
        if (this.key != null) {
            this.key = Tools.transformAllLineSeparators(this.key);
            String[] lines = this.key.split("\n");
            double maxWidth = Double.NEGATIVE_INFINITY;
            double totalHeight = 0.0;
            for (String line : lines) {
                Rectangle2D keyBounds = g.getFontMetrics().getStringBounds(line, g);
                maxWidth = Math.max(maxWidth, keyBounds.getWidth());
                totalHeight += keyBounds.getHeight();
            }
            Rectangle frame = new Rectangle(this.keyX - 4, this.keyY, (int)maxWidth + 8, (int)(totalHeight += (double)((lines.length - 1) * 3)) + 6);
            g.setColor(SwingTools.LIGHTEST_YELLOW);
            g.fill(frame);
            g.setColor(SwingTools.DARK_YELLOW);
            g.draw(frame);
            g.setColor(Color.BLACK);
            int xPos = this.keyX;
            int yPos = this.keyY;
            for (String line : lines) {
                Rectangle2D keyBounds = g.getFontMetrics().getStringBounds(line, g);
                g.drawString(line, xPos, yPos += (int)keyBounds.getHeight());
                yPos += 3;
            }
        }
    }

    private void paintSynapses(Graphics2D g, int height) {
        for (int i = 1; i < this.layers.size(); ++i) {
            int layerIndex = i;
            List<Node> layer = this.layers.get(layerIndex);
            int offset = layerIndex == this.layers.size() - 1 ? 0 : 1;
            int outputY = height / 2 - (layer.size() + offset) * 36 / 2;
            for (Node node : layer) {
                if (node instanceof InnerNode) {
                    Node[] inputNodes = node.getInputNodes();
                    double[] weights = ((InnerNode)node).getWeights();
                    int inputY = height / 2 - (inputNodes.length + 1) * 36 / 2;
                    for (int j = 0; j < inputNodes.length; ++j) {
                        float weight = 1.0f - (float)(Math.abs(weights[j + 1]) / this.maxAbsoluteWeight);
                        Color color = new Color(weight, weight, weight);
                        g.setColor(color);
                        g.drawLine(12, inputY + 12, 162, outputY + 12);
                        inputY += 36;
                    }
                    float weight = 1.0f - (float)(Math.abs(weights[0]) / this.maxAbsoluteWeight);
                    Color color = new Color(weight, weight, weight);
                    g.setColor(color);
                    g.drawLine(12, inputY + 12, 162, outputY + 12);
                }
                outputY += 36;
            }
            g.translate(150, 0);
            ++layerIndex;
        }
    }

    private void paintNodes(Graphics2D g, int height) {
        for (Map.Entry<Integer, List<Node>> entry : this.layers.entrySet()) {
            int layerIndex = entry.getKey();
            List<Node> layer = entry.getValue();
            int nodes = layer.size();
            if (layerIndex < this.layers.size() - 1) {
                ++nodes;
            }
            String layerName = null;
            layerName = layerIndex == 0 ? "Input" : (layerIndex == this.layers.size() - 1 ? "Output" : "Hidden " + layerIndex);
            Rectangle2D stringBounds = LABEL_FONT.getStringBounds(layerName, g.getFontRenderContext());
            g.setColor(Color.BLACK);
            g.drawString(layerName, (int)(-1.0 * stringBounds.getWidth() / 2.0 + 12.0), 0);
            int yPos = height / 2 - nodes * 36 / 2;
            for (int r = 0; r < nodes; ++r) {
                Ellipse2D.Double node = new Ellipse2D.Double(0.0, yPos, 24.0, 24.0);
                if (layerIndex == 0 || layerIndex == this.layers.size() - 1) {
                    if (r < nodes - 1 || layerIndex == this.layers.size() - 1) {
                        g.setPaint(SwingTools.makeYellowPaint(24.0, 24.0));
                    } else {
                        g.setPaint(new Color(233, 233, 233));
                    }
                } else if (r < nodes - 1) {
                    g.setPaint(SwingTools.makeBluePaint(24.0, 24.0));
                } else {
                    g.setPaint(new Color(233, 233, 233));
                }
                g.fill(node);
                if (layerIndex == this.selectedLayerIndex && r == this.selectedRowIndex) {
                    g.setColor(Color.RED);
                } else {
                    g.setColor(Color.BLACK);
                }
                g.draw(node);
                yPos += 36;
            }
            g.translate(150, 0);
            ++layerIndex;
        }
    }

    private void setKey(String key, int keyX, int keyY) {
        this.key = key;
        this.keyX = keyX;
        this.keyY = keyY;
        this.repaint();
    }

    private void setSelectedNode(int layerIndex, int rowIndex, int xPos, int yPos) {
        this.selectedLayerIndex = layerIndex;
        this.selectedRowIndex = rowIndex;
        if (this.selectedLayerIndex < 0 || this.selectedRowIndex < 0) {
            this.setKey(null, -1, -1);
            return;
        }
        if (layerIndex == 0) {
            if (rowIndex >= 0 && rowIndex < this.attributeNames.length) {
                this.setKey(this.attributeNames[rowIndex], xPos, yPos);
            } else if (rowIndex == this.attributeNames.length) {
                this.setKey("Threshold Node", xPos, yPos);
            } else {
                this.setKey(null, -1, -1);
            }
        } else {
            List<Node> currentLayer = this.layers.get(this.selectedLayerIndex);
            if (rowIndex >= 0 && rowIndex < currentLayer.size()) {
                StringBuffer toolTip = new StringBuffer("Weights:" + Tools.getLineSeparator());
                Node node = currentLayer.get(this.selectedRowIndex);
                if (node instanceof InnerNode) {
                    InnerNode innerNode = (InnerNode)node;
                    double[] weights = innerNode.getWeights();
                    for (int w = 1; w < weights.length; ++w) {
                        toolTip.append(Tools.formatNumber(weights[w]) + Tools.getLineSeparator());
                    }
                    toolTip.append(Tools.formatNumber(weights[0]) + " (Threshold)");
                }
                this.setKey(toolTip.toString(), xPos, yPos);
            } else {
                this.setKey("Threshold Node", xPos, yPos);
            }
        }
        this.repaint();
    }

    private void checkMousePos(int xPos, int yPos) {
        boolean layerHit;
        int x = xPos - 30;
        int y = yPos - 30;
        int layerIndex = x / 150;
        int layerMod = x % 150;
        boolean bl = layerHit = layerMod > 0 && layerMod < 24;
        if (layerHit && layerIndex >= 0 && layerIndex < this.layers.size()) {
            int yMargin;
            List<Node> layer = this.layers.get(layerIndex);
            int rows = layer.size();
            if (layerIndex < this.layers.size() - 1) {
                ++rows;
            }
            if (y > (yMargin = this.getPreferredSize().height / 2 - rows * 36 / 2)) {
                for (int i = 0; i < rows; ++i) {
                    if (y > yMargin && y < yMargin + 24) {
                        if (this.selectedLayerIndex == layerIndex && this.selectedRowIndex == i) {
                            this.setSelectedNode(-1, -1, -1, -1);
                        } else {
                            this.setSelectedNode(layerIndex, i, xPos, yPos);
                        }
                        return;
                    }
                    yMargin += 36;
                }
            }
        }
        this.setSelectedNode(-1, -1, -1, -1);
    }

    @Override
    public void mouseClicked(MouseEvent e) {
    }

    @Override
    public void mouseEntered(MouseEvent e) {
    }

    @Override
    public void mouseExited(MouseEvent e) {
    }

    @Override
    public void mousePressed(MouseEvent e) {
    }

    @Override
    public void mouseReleased(MouseEvent e) {
        int xPos = e.getX();
        int yPos = e.getY();
        this.checkMousePos(xPos, yPos);
    }

    @Override
    public void prepareRendering() {
    }

    @Override
    public void finishRendering() {
    }

    @Override
    public int getRenderHeight(int preferredHeight) {
        int height = this.getPreferredSize().height;
        if (height < 1) {
            height = preferredHeight;
        }
        if (preferredHeight > height) {
            height = preferredHeight;
        }
        return height;
    }

    @Override
    public int getRenderWidth(int preferredWidth) {
        int width = this.getPreferredSize().width;
        if (width < 1) {
            width = preferredWidth;
        }
        if (preferredWidth > width) {
            width = preferredWidth;
        }
        return width;
    }

    @Override
    public void render(Graphics graphics, int width, int height) {
        this.setSize(width, height);
        this.paint(graphics);
    }
}

