import os
import tempfile

import javabridge

import weka.core.serialization as serialization
from weka.classifiers import Classifier
from weka.core.converters import Loader
from weka.core.dataset import Instances
import wekaexamples.helper as helper  # helper module from the python-weka-wrapper examples repository


def main():
    """
    Just runs some example code.
    """
    # load a dataset
    iris_file = helper.get_data_dir() + os.sep + "iris.arff"
    helper.print_info("Loading dataset: " + iris_file)
    loader = Loader("weka.core.converters.ArffLoader")
    iris_data = loader.load_file(iris_file)
    iris_data.class_is_last()

    # train classifier
    classifier = Classifier("weka.classifiers.trees.J48")
    classifier.build_classifier(iris_data)

    # save and read object
    helper.print_title("I/O: model (using serialization module)")
    outfile = tempfile.gettempdir() + os.sep + "j48.model"
    serialization.write(outfile, classifier)
    model = Classifier(jobject=serialization.read(outfile))
    print(model)

    # save classifier and dataset header (multiple objects)
    helper.print_title("I/O: model and header (using serialization module)")
    serialization.write_all(
        outfile, [classifier, Instances.template_instances(iris_data)])
    objects = serialization.read_all(outfile)
    for i, obj in enumerate(objects):
        helper.print_info("Object #" + str(i + 1) + ":")
        if javabridge.get_env().is_instance_of(
                obj, javabridge.get_env().find_class("weka/core/Instances")):
            obj = Instances(jobject=obj)
        elif javabridge.get_env().is_instance_of(
                obj, javabridge.get_env().find_class("weka/classifiers/Classifier")):
            obj = Classifier(jobject=obj)
        print(obj)

    # save and read object
    helper.print_title("I/O: just model (using Classifier class)")
    outfile = tempfile.gettempdir() + os.sep + "j48.model"
    classifier.serialize(outfile)
    model, _ = Classifier.deserialize(outfile)
    print(model)

    # save classifier and dataset header (multiple objects)
    helper.print_title("I/O: model and header (using Classifier class)")
    classifier.serialize(outfile, header=iris_data)
    model, header = Classifier.deserialize(outfile)
    print(model)
    if header is not None:
        print(header)
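# Usage sketch (an assumption, not part of the original example): python-weka-wrapper
# calls only work inside a running JVM, so a typical entry point wraps main() with
# jvm.start() / jvm.stop().
import traceback
import weka.core.jvm as jvm

if __name__ == "__main__":
    try:
        jvm.start()
        main()
    except Exception:
        print(traceback.format_exc())
    finally:
        jvm.stop()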
import weka.core.jvm as jvm
from weka.classifiers import Classifier, Evaluation, FilteredClassifier
from weka.core.classes import Random
from weka.core.converters import Loader


def TrainingModel(arff, modelOutput, clsfier):
    # start the Java virtual machine
    jvm.start()
    # load the training set
    loader = Loader(classname="weka.core.converters.ArffLoader")
    train = loader.load_file(arff)
    train.class_is_first()
    # train with the RandomForest algorithm: after trying several schemes in the
    # WEKA GUI, this one showed the highest TPR and TNR
    cls_name = "weka.classifiers." + clsfier
    clsf = Classifier(classname=cls_name)
    clsf.build_classifier(train)
    print(clsf)
    # build the model
    fc = FilteredClassifier()
    fc.classifier = clsf
    evl = Evaluation(train)
    evl.crossvalidate_model(fc, train, 10, Random(1))
    print(evl.percent_correct)
    print(evl.summary())
    print(evl.class_details())
    print(evl.matrix())
    # tally the results from the confusion matrix
    matrixResults = evl.confusion_matrix
    TN = float(matrixResults[0][0])
    FP = float(matrixResults[0][1])
    FN = float(matrixResults[1][0])
    TP = float(matrixResults[1][1])
    TPR = TP / (TP + FN)
    TNR = TN / (FP + TN)
    PPV = TP / (TP + FP)
    NPV = TN / (TN + FN)
    print("Algorithm: " + clsfier)
    print("Sensitivity TPR: " + str(TPR))
    print("Specificity TNR: " + str(TNR))
    print("PPV: " + str(PPV))
    print("NPV: " + str(NPV))
    # save the model
    clsf.serialize(modelOutput, header=train)
    # shut down the JVM
    jvm.stop()
    print("Finished building the analysis model")
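# Usage sketch (file names and classifier choice are hypothetical): the classifier
# argument is appended to "weka.classifiers.", so "trees.RandomForest" selects
# weka.classifiers.trees.RandomForest. TrainingModel() starts and stops the JVM itself.
TrainingModel("train.arff", "rf.model", "trees.RandomForest")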
import os
import re

import weka.core.converters as converters
from weka.classifiers import Classifier, Evaluation
from weka.core.classes import Random

import localizer_log  # project-local logging module, assumed importable


def train(training_dataset_path, model_cache_file_name, evaluation_is_on,
          summary_file_path):
    """Model training function.

    The function uses the WEKA machine learning library via the
    python-weka-wrapper Python library. It splits the data into the given
    number of folds and performs training and evaluation. The trained model is
    copied to the __predictors global variable and also saved (together with
    the training dataset header) under model_cache_file_name. The evaluation
    summary is written to summary_file_path.

    Args:
        :param training_dataset_path: the path of the input ARFF file.
        :param model_cache_file_name: base name of the model cache file.
        :param evaluation_is_on: run evaluation after training (True / False).
        :param summary_file_path: the path of the model evaluation summary file.

    Returns:
        None
    """
    global __classifiers
    global __predictors

    training_data = converters.load_any_file(training_dataset_path)
    training_data.class_is_last()

    lines = []
    summaries = []
    summary_line = [
        'Model'.ljust(16),
        'Precision'.ljust(12),
        'Recall'.ljust(12),
        'F-measure'.ljust(12),
        'Accuracy'.ljust(12),
        'FPR'.ljust(12)
    ]
    summaries.append('\t'.join(summary_line))

    for classifier, option_str in __classifiers.items():
        # tokenize the option string, keeping quoted chunks intact
        option_list = re.findall(r'"(?:[^"]+)"|(?:[^ ]+)', option_str)
        option_list = [s.replace('"', '') for s in option_list]

        classifier_name = classifier.split('.')[-1]
        info_str = "Using classifier: {classifier}, options: {options}".format(
            classifier=classifier_name, options=str(option_list))
        localizer_log.msg(info_str)
        lines.append(info_str)

        # Train
        cls = Classifier(classname=classifier, options=option_list)
        localizer_log.msg("Start building classifier")
        cls.build_classifier(training_data)
        localizer_log.msg("Completed building classifier")

        localizer_log.msg("Saving trained model to {model_cache_name}".format(
            model_cache_name=model_cache_file_name))
        path = os.path.join('caches', 'model')
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, model_cache_file_name + '.cache')
        cls.serialize(path)
        localizer_log.msg("Trained model saved")

        classifier2, _ = Classifier.deserialize(path)
        print(classifier2)

        __predictors[classifier_name] = cls

        if evaluation_is_on:
            # Model evaluation via 10-fold cross-validation
            localizer_log.msg("Start evaluating classifier")
            evl = Evaluation(training_data)
            localizer_log.msg("Completed evaluating classifier")

            localizer_log.msg("Start cross-validating classifier")
            evl.crossvalidate_model(cls, training_data, 10, Random(1))
            localizer_log.msg("Completed cross-validating classifier")

            lines.append(evl.summary())
            lines.append(evl.class_details())

            summary_line = []
            summary_line.append(classifier_name.ljust(16))
            summary_line.append("{:.3f}".format(evl.weighted_precision * 100).ljust(12))
            summary_line.append("{:.3f}".format(evl.weighted_recall * 100).ljust(12))
            summary_line.append("{:.3f}".format(evl.weighted_f_measure * 100).ljust(12))
            summary_line.append("{:.3f}".format(evl.percent_correct).ljust(12))
            summary_line.append("{:.3f}".format(
                evl.weighted_false_positive_rate * 100).ljust(12))
            summaries.append('\t'.join(summary_line))

    # Save evaluation summary to file
    with open(summary_file_path, 'w') as f:
        f.writelines('\n'.join(lines))
        f.writelines('\n' * 5)
        f.writelines('\n'.join(summaries))
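# Minimal sketch of the surrounding module state (all values here are assumptions,
# not from the original project): __classifiers maps WEKA classifier class names to
# option strings that train() tokenizes, and __predictors receives the trained
# models. The JVM must already be running (weka.core.jvm.start()) before calling
# train().
__classifiers = {
    'weka.classifiers.trees.J48': '-C 0.25 -M 2',
    'weka.classifiers.bayes.NaiveBayes': '',
}
__predictors = {}

# train("train.arff", "localizer", True, "summary.txt")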