/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.preprocessing.sampling;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.annotation.ResourceConsumptionEstimator;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.WeightedPerformanceMeasures;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MDInteger;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.preprocessing.sampling.AbstractSamplingOperator;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.tools.OperatorResourceConsumptionHandler;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/**
 * Sampling operator driven by a {@link PredictionModel}.
 *
 * <p>The model received at the {@code model} port is applied to the input example set. The
 * examples are then reweighted through {@link WeightedPerformanceMeasures}. Finally, a random
 * subset is drawn in which each example's chance of selection depends on its weight relative to
 * the maximum weight.
 */
public class ModelBasedSampling
extends AbstractSamplingOperator {
    /** Input port delivering the prediction model that is applied before sampling. */
    private final InputPort modelInput = this.getInputPorts().createPort("model", PredictionModel.class);

    /**
     * Creates the operator.
     *
     * @param description operator description handed to the superclass
     */
    public ModelBasedSampling(OperatorDescription description) {
        super(description);
    }

    /**
     * Derives the output meta data from the input example set meta data.
     *
     * <p>If the model port carries {@link PredictionModelMetaData} with known prediction attributes,
     * those attributes are added and the prediction attribute set relation is merged in. A weight
     * attribute is always added, and the example count is set from {@link #getSampledSize}.
     *
     * @param metaData meta data of the incoming example set; modified in place
     * @return the same {@code metaData} instance after modification
     */
    @Override
    protected MetaData modifyMetaData(ExampleSetMetaData metaData) {
        MetaData incoming = this.modelInput.getMetaData();
        if (incoming instanceof PredictionModelMetaData) {
            PredictionModelMetaData modelMetaData = (PredictionModelMetaData) incoming;
            List<AttributeMetaData> predicted = modelMetaData.getPredictionAttributeMetaData();
            if (predicted != null) {
                metaData.addAllAttributes(predicted);
                metaData.mergeSetRelation(modelMetaData.getPredictionAttributeSetRelation());
            }
        }
        metaData.addAttribute(Tools.createWeightAttributeMetaData(metaData));
        metaData.setNumberOfExamples(this.getSampledSize(metaData));
        return metaData;
    }

    /**
     * Returns the expected number of sampled examples.
     *
     * <p>The actual size depends on per-example random draws in {@link #apply(ExampleSet)}, so it
     * cannot be computed from the meta data. This returns a default-constructed {@link MDInteger}
     * (presumably "unknown"; confirm against the MDInteger API).
     *
     * @param emd meta data of the incoming example set (unused)
     * @return a freshly created {@link MDInteger}
     */
    @Override
    protected MDInteger getSampledSize(ExampleSetMetaData emd) {
        return new MDInteger();
    }

    /**
     * Applies the model, reweights the examples, and returns a randomly selected subset.
     *
     * <p>For each example, a value is drawn from the operator's {@link RandomGenerator}. An example
     * is selected when that draw exceeds {@code weight / maxWeight}. Selected examples get weight
     * {@code 1.0} and are placed in partition 1. Only partition 1 is returned.
     *
     * <p>NOTE(review): assuming {@code nextDouble()} is uniform on [0, 1), an example is kept with
     * probability {@code 1 - weight / maxWeight}. Confirm this is the intended direction given what
     * {@code reweightExamples(..., true)} does.
     *
     * @param exampleSet the example set to sample from
     * @return a {@link SplittedExampleSet} restricted to the selected examples
     * @throws OperatorException if the model is missing or the process is stopped
     */
    @Override
    public ExampleSet apply(ExampleSet exampleSet) throws OperatorException {
        PredictionModel predictionModel = (PredictionModel)this.modelInput.getData();
        ExampleSet scored = predictionModel.apply(exampleSet);

        // Make sure a weight attribute exists before the examples are reweighted.
        Attribute weight = scored.getAttributes().getWeight();
        if (weight == null) {
            weight = Tools.createWeightAttribute(scored);
        }

        WeightedPerformanceMeasures measures = new WeightedPerformanceMeasures(scored);
        WeightedPerformanceMeasures.reweightExamples(scored, measures.getContingencyMatrix(), true);

        // The maximum weight after reweighting normalizes each example's weight into a threshold.
        scored.recalculateAttributeStatistics(scored.getAttributes().getWeight());
        double maxWeight = scored.getStatistics(scored.getAttributes().getWeight(), "maximum");

        RandomGenerator random = RandomGenerator.getRandomGenerator(this);
        int[] partitionOf = new int[scored.size()];
        int row = 0;
        for (Example example : scored) {
            // Draw first so the sequence of random numbers matches one draw per example, in order.
            double draw = random.nextDouble();
            if (draw > example.getValue(weight) / maxWeight) {
                example.setValue(weight, 1.0);
                partitionOf[row] = 1;
            }
            row++;
        }
        this.checkForStop();

        SplittedExampleSet sample = new SplittedExampleSet(scored, new Partition(partitionOf, 2));
        sample.selectSingleSubset(1);
        return sample;
    }

    /**
     * Returns the superclass parameters extended by the random generator parameters.
     *
     * @return the parameter types of this operator
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }

    /**
     * Returns the resource consumption estimator registered for this operator class.
     *
     * @return the estimator looked up for {@code ModelBasedSampling}
     */
    @Override
    public ResourceConsumptionEstimator getResourceConsumptionEstimator() {
        return OperatorResourceConsumptionHandler.getResourceConsumptionEstimator(this.getInputPort(), ModelBasedSampling.class, null);
    }
}

