def train_kea(model_filename,
              train_directory,
              file_extension,
              ref_extension,
              ref_tokenization_function,
              stemmer,
              pre_processor,
              candidate_extractor,
              candidate_clusterer,
              tfidf_ranker):
    """Return the KEA classifier, training it only when no saved model exists.

    If ``model_filename`` already exists it is unpickled and returned as-is;
    none of the other arguments are used in that case. Otherwise one
    extraction job is built per file of ``train_directory`` whose name ends
    with ``file_extension``, the jobs are run in a process pool through
    ``feature_class_extraction_pool_worker``, a Weka ``NaiveBayesSimple``
    classifier is trained on the collected (features, target) pairs and the
    result is pickled to ``model_filename``.

    Args:
        model_filename: Path of the pickled classifier. Weka's own model file
            is written next to it as ``model_filename + "_train"``.
        train_directory: Directory listed to find the training files.
        file_extension: Suffix a file name must end with to be used.
        ref_extension, ref_tokenization_function, stemmer, pre_processor,
        candidate_extractor, candidate_clusterer, tfidf_ranker: Forwarded
            unchanged to ``feature_class_extraction_pool_worker`` for every
            selected file; they must be picklable because they cross a
            process boundary.

    Returns:
        The classifier object, either freshly trained or loaded from disk.
    """
    if path.exists(model_filename):
        # NOTE(security): unpickling runs arbitrary code; only load model
        # files that come from a trusted location.
        # Pickle data is binary, so the file must be opened in "rb" mode.
        with open(model_filename, "rb") as classifier_file:
            return pickle.load(classifier_file)

    # One argument tuple per training file (the pool worker takes a single
    # tuple argument).
    pool_args = [
        (filename,
         train_directory,
         file_extension,
         ref_extension,
         ref_tokenization_function,
         stemmer,
         pre_processor,
         candidate_extractor,
         candidate_clusterer,
         tfidf_ranker)
        for filename in listdir(train_directory)
        if filename.endswith(file_extension)
    ]

    working_pool = Pool()
    try:
        feature_sets_and_target_sets = working_pool.map(
            feature_class_extraction_pool_worker, pool_args)
    finally:
        # Always release the worker processes, even if a job failed.
        working_pool.close()
        working_pool.join()

    # Flatten the per-file results into the (features, label) pairs that
    # WekaClassifier.train expects.
    feature_sets = []
    for feature_set, target_set in feature_sets_and_target_sets:
        for i, features in enumerate(feature_set):
            feature_sets.append((features, target_set[i]))

    # NaiveBayesSimple is not in NLTK's classifier table; register it first.
    WekaClassifier._CLASSIFIER_CLASS["naivebayessimple"] = (
        "weka.classifiers.bayes.NaiveBayesSimple")
    classifier = WekaClassifier.train(model_filename + "_train",
                                      feature_sets,
                                      "naivebayessimple")

    # Serialize the classifier so later calls can skip training.
    with open(model_filename, "wb") as classifier_file:
        pickle.dump(classifier, classifier_file)

    return classifier
def train(self, train_set, java_bin="/usr/bin/java", java_options=None,
          weka_classpath='/Applications/weka-3-6-5.app/Contents/Resources/Java/weka.jar'):
    """Train a Weka classifier on ``train_set`` and store it on the instance.

    Configures NLTK's Java and Weka bridges, registers the Weka classifier
    class names this wrapper supports, then trains the classifier selected
    by ``self.weka_classifier``.

    Args:
        train_set: Training data passed straight to ``WekaClassifier.train``.
        java_bin: Path of the Java executable handed to NLTK.
        java_options: List of JVM options; defaults to ``["-Xmx15g"]``.
        weka_classpath: Location of ``weka.jar``. The default matches a
            Weka 3.6.5 application bundle; a MacPorts install keeps it under
            ``/Applications/MacPorts/Weka.app/Contents/Resources/Java/``.

    Side effects:
        Sets ``self.fname`` to the path of the Weka model file (created
        under ``/tmp``) and ``self.classifier`` to the trained classifier.
    """
    # Local import so this method does not depend on the file's import block.
    import tempfile

    if java_options is None:
        java_options = ["-Xmx15g"]
    nltk.classify.config_java(bin=java_bin, options=java_options)
    nltk.classify.config_weka(classpath=weka_classpath)

    # Update in place rather than rebinding the table, so classifier names
    # registered elsewhere are not discarded.
    WekaClassifier._CLASSIFIER_CLASS.update({
        'naivebayes': 'weka.classifiers.bayes.NaiveBayes',
        'C4.5': 'weka.classifiers.trees.J48',
        'log_regression': 'weka.classifiers.functions.Logistic',
        'svm': 'weka.classifiers.functions.SMO',
        # Java class names are case-sensitive: the class is KStar.
        'kstar': 'weka.classifiers.lazy.KStar',
        'ripper': 'weka.classifiers.rules.JRip',
        'MultilayerPerceptron': 'weka.classifiers.functions.MultilayerPerceptron',
    })

    # Reserve a unique model path atomically instead of guessing a random
    # number, which could collide with another run or a pre-existing file.
    with tempfile.NamedTemporaryFile(prefix='weka.', suffix='.model',
                                     dir='/tmp', delete=False) as model_file:
        self.fname = model_file.name

    self.classifier = WekaClassifier.train(self.fname, train_set,
                                           classifier=self.weka_classifier)
def call_train_weka(train_feats, **train_kwargs):
    """Train a Weka classifier of type ``'C4.5'`` on ``train_feats``.

    Args:
        train_feats: Training data handed to ``WekaClassifier.train``.
        **train_kwargs: Must contain ``'model_filename'``, the path given to
            ``WekaClassifier.train`` as the model file. Any other keyword
            is ignored.

    Returns:
        Whatever ``WekaClassifier.train`` returns.

    Raises:
        KeyError: If ``'model_filename'`` is not supplied.
    """
    model_path = train_kwargs['model_filename']
    return WekaClassifier.train(model_path, train_feats, classifier='C4.5')