/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.postag;

import java.io.File;
import java.io.IOException;
import java.util.Map;
import opennlp.tools.postag.MutableTagDictionary;
import opennlp.tools.postag.POSEvaluator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSTagFormat;
import opennlp.tools.postag.POSTaggerEvaluationMonitor;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.postag.TagDictionary;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.eval.CrossValidationPartitioner;
import opennlp.tools.util.eval.Mean;

public class POSTaggerCrossValidator {
    private final String languageCode;
    private final TrainingParameters params;
    private final POSTagFormat posTagFormat;
    private byte[] featureGeneratorBytes;
    private Map<String, Object> resources;
    private final Mean wordAccuracy = new Mean();
    private final POSTaggerEvaluationMonitor[] listeners;
    private String factoryClassName;
    private POSTaggerFactory factory;
    private final Integer tagdicCutoff;
    private File tagDictionaryFile;

    public POSTaggerCrossValidator(String languageCode, TrainingParameters trainParam, File tagDictionary, byte[] featureGeneratorBytes, Map<String, Object> resources, Integer tagdicCutoff, String factoryClass, POSTagFormat format, POSTaggerEvaluationMonitor ... listeners) {
        this.languageCode = languageCode;
        this.params = trainParam;
        this.featureGeneratorBytes = featureGeneratorBytes;
        this.resources = resources;
        this.listeners = listeners;
        this.factoryClassName = factoryClass;
        this.tagdicCutoff = tagdicCutoff;
        this.tagDictionaryFile = tagDictionary;
        this.posTagFormat = format;
    }

    public POSTaggerCrossValidator(String languageCode, TrainingParameters trainParam, File tagDictionary, byte[] featureGeneratorBytes, Map<String, Object> resources, Integer tagdicCutoff, String factoryClass, POSTaggerEvaluationMonitor ... listeners) {
        this(languageCode, trainParam, tagDictionary, featureGeneratorBytes, resources, tagdicCutoff, factoryClass, POSTagFormat.UD, listeners);
    }

    public POSTaggerCrossValidator(String languageCode, TrainingParameters trainParam, POSTaggerFactory factory, POSTaggerEvaluationMonitor ... listeners) {
        this(languageCode, trainParam, factory, POSTagFormat.UD, listeners);
    }

    public POSTaggerCrossValidator(String languageCode, TrainingParameters trainParam, POSTaggerFactory factory, POSTagFormat format, POSTaggerEvaluationMonitor ... listeners) {
        this.languageCode = languageCode;
        this.params = trainParam;
        this.listeners = listeners;
        this.factory = factory;
        this.posTagFormat = format;
        this.tagdicCutoff = null;
    }

    public void evaluate(ObjectStream<POSSample> samples, int nFolds) throws IOException {
        CrossValidationPartitioner<POSSample> partitioner = new CrossValidationPartitioner<POSSample>(samples, nFolds);
        while (partitioner.hasNext()) {
            CrossValidationPartitioner.TrainingSampleStream<POSSample> trainingSampleStream = partitioner.next();
            if (this.tagDictionaryFile != null && this.factory.getTagDictionary() == null) {
                this.factory.setTagDictionary(this.factory.createTagDictionary(this.tagDictionaryFile));
            }
            TagDictionary dict = null;
            if (this.tagdicCutoff != null) {
                dict = this.factory.getTagDictionary();
                if (dict == null) {
                    dict = this.factory.createEmptyTagDictionary();
                }
                if (!(dict instanceof MutableTagDictionary)) {
                    throw new IllegalArgumentException("Can't extend a TagDictionary that does not implement MutableTagDictionary.");
                }
                POSTaggerME.populatePOSDictionary(trainingSampleStream, (MutableTagDictionary)dict, this.tagdicCutoff);
                trainingSampleStream.reset();
            }
            if (this.factory == null) {
                this.factory = POSTaggerFactory.create(this.factoryClassName, null, null);
            }
            this.factory.init(this.featureGeneratorBytes, this.resources, dict);
            POSModel model = POSTaggerME.train(this.languageCode, trainingSampleStream, this.params, this.factory);
            POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model, this.posTagFormat), this.listeners);
            evaluator.evaluate(trainingSampleStream.getTestSampleStream());
            this.wordAccuracy.add(evaluator.getWordAccuracy(), evaluator.getWordCount());
            if (this.tagdicCutoff == null) continue;
            this.factory.setTagDictionary(null);
        }
    }

    public double getWordAccuracy() {
        return this.wordAccuracy.mean();
    }

    public long getWordCount() {
        return this.wordAccuracy.count();
    }
}

