/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.neuralnet.ImprovedNeuralNetModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/**
 * Learner operator that builds and trains an {@link ImprovedNeuralNetModel}.
 * <p>
 * The learner reads these operator parameters and hands them unchanged to
 * {@link ImprovedNeuralNetModel#train}:
 * <ul>
 * <li>the hidden layer layout</li>
 * <li>the number of training cycles</li>
 * <li>the error epsilon</li>
 * <li>the learning rate</li>
 * <li>the momentum</li>
 * <li>the decay, shuffle and normalize flags</li>
 * <li>the random generator settings</li>
 * </ul>
 */
public class ImprovedNeuralNetLearner
extends AbstractLearner {
    /** List parameter: one (name, size) entry per hidden layer. */
    public static final String PARAMETER_HIDDEN_LAYERS = "hidden_layers";
    /** Maximum number of training cycles. */
    public static final String PARAMETER_TRAINING_CYCLES = "training_cycles";
    /** Training stops once the training error falls below this value. */
    public static final String PARAMETER_ERROR_EPSILON = "error_epsilon";
    /** Step size used when changing the weights. */
    public static final String PARAMETER_LEARNING_RATE = "learning_rate";
    /** Fraction of the previous weight update added to the current one. */
    public static final String PARAMETER_MOMENTUM = "momentum";
    /** Whether the learning rate is decreased during learning. */
    public static final String PARAMETER_DECAY = "decay";
    /** Whether the input data is shuffled before learning. */
    public static final String PARAMETER_SHUFFLE = "shuffle";
    /** Whether the input data is normalized to [-1, +1] before learning. */
    public static final String PARAMETER_NORMALIZE = "normalize";

    /**
     * Creates the learner operator.
     *
     * @param description the operator description supplied by the framework
     */
    public ImprovedNeuralNetLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Creates a new {@link ImprovedNeuralNetModel} for {@code exampleSet} and trains it
     * using the current operator parameter values.
     *
     * @param exampleSet the training data
     * @return the trained model
     * @throws OperatorException propagated from parameter access or from model training
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        ImprovedNeuralNetModel model = new ImprovedNeuralNetModel(exampleSet);
        List<String[]> hiddenLayers = this.getParameterList(PARAMETER_HIDDEN_LAYERS);
        int maxCycles = this.getParameterAsInt(PARAMETER_TRAINING_CYCLES);
        double maxError = this.getParameterAsDouble(PARAMETER_ERROR_EPSILON);
        double learningRate = this.getParameterAsDouble(PARAMETER_LEARNING_RATE);
        double momentum = this.getParameterAsDouble(PARAMETER_MOMENTUM);
        boolean decay = this.getParameterAsBoolean(PARAMETER_DECAY);
        boolean shuffle = this.getParameterAsBoolean(PARAMETER_SHUFFLE);
        boolean normalize = this.getParameterAsBoolean(PARAMETER_NORMALIZE);
        // The generator is configured from the random-generator parameters registered
        // in getParameterTypes().
        RandomGenerator randomGenerator = RandomGenerator.getRandomGenerator(this);
        model.train(exampleSet, hiddenLayers, maxCycles, maxError, learningRate, momentum, decay, shuffle, normalize, randomGenerator);
        return model;
    }

    /**
     * Returns the class of the model produced by {@link #learn(ExampleSet)}.
     *
     * @return {@code ImprovedNeuralNetModel.class}
     */
    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return ImprovedNeuralNetModel.class;
    }

    /**
     * Reports the data capabilities of this learner.
     * <p>
     * Supported capabilities:
     * <ul>
     * <li>numerical attributes</li>
     * <li>polynominal, binominal and numerical labels</li>
     * <li>weighted examples</li>
     * </ul>
     * Every other capability is reported as unsupported.
     *
     * @param lc the capability to check
     * @return {@code true} if the capability is supported
     */
    @Override
    public boolean supportsCapability(OperatorCapability lc) {
        return lc == OperatorCapability.NUMERICAL_ATTRIBUTES
            || lc == OperatorCapability.POLYNOMINAL_LABEL
            || lc == OperatorCapability.BINOMINAL_LABEL
            || lc == OperatorCapability.NUMERICAL_LABEL
            || lc == OperatorCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Registers the parameters of this learner on top of those of the superclass.
     * <p>
     * The hidden layer list, the training cycles and the learning rate are marked as
     * non-expert. The random-generator parameters are appended last.
     *
     * @return the mutable list of parameter types
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        // Declared with static type ParameterType so the ParameterTypeList constructor
        // resolves exactly as with the decompiler's explicit casts.
        ParameterType layerNameType = new ParameterTypeString("hidden_layer_name", "The name of the hidden layer.");
        ParameterType layerSizeType = new ParameterTypeInt("hidden_layer_sizes", "The size of the hidden layers. A size of < 0 leads to a layer size of (number_of_attributes + number of classes) / 2 + 1.", -1, Integer.MAX_VALUE, -1);
        ParameterType type = new ParameterTypeList(PARAMETER_HIDDEN_LAYERS, "Describes the name and the size of all hidden layers.", layerNameType, layerSizeType);
        type.setExpert(false);
        types.add(type);

        type = new ParameterTypeInt(PARAMETER_TRAINING_CYCLES, "The number of training cycles used for the neural network training.", 1, Integer.MAX_VALUE, 500);
        type.setExpert(false);
        types.add(type);

        // Double.MIN_VALUE is the smallest positive double, so a learning rate of exactly 0
        // falls outside the allowed range.
        type = new ParameterTypeDouble(PARAMETER_LEARNING_RATE, "The learning rate determines by how much we change the weights at each step. May not be 0.", Double.MIN_VALUE, 1.0, 0.3);
        type.setExpert(false);
        types.add(type);

        types.add(new ParameterTypeDouble(PARAMETER_MOMENTUM, "The momentum simply adds a fraction of the previous weight update to the current one (prevent local maxima and smoothes optimization directions).", 0.0, 1.0, 0.2));
        types.add(new ParameterTypeBoolean(PARAMETER_DECAY, "Indicates if the learning rate should be decreased during learning", false));
        types.add(new ParameterTypeBoolean(PARAMETER_SHUFFLE, "Indicates if the input data should be shuffled before learning (increases memory usage but is recommended if data is sorted before)", true));
        types.add(new ParameterTypeBoolean(PARAMETER_NORMALIZE, "Indicates if the input data should be normalized between -1 and +1 before learning (increases runtime but is in most cases necessary)", true));
        types.add(new ParameterTypeDouble(PARAMETER_ERROR_EPSILON, "The optimization is stopped if the training error gets below this epsilon value.", 0.0, Double.POSITIVE_INFINITY, 1.0E-5));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }
}

