Ejemplo n.º 1
0
def main():
    """
    Demonstrates four ways of saving and restoring a trained classifier.

    A J48 classifier is built on the iris dataset and then round-tripped
    through a file in the temp directory: first alone and then together with
    the dataset header via the ``serialization`` module, and afterwards the
    same two variants via ``Classifier.serialize``/``Classifier.deserialize``.
    Every restored object is printed.
    """

    # dataset: iris.arff from the data directory, class attribute is last
    dataset_path = os.sep.join([helper.get_data_dir(), "iris.arff"])
    helper.print_info("Loading dataset: " + dataset_path)
    iris = Loader("weka.core.converters.ArffLoader").load_file(dataset_path)
    iris.class_is_last()

    # classifier trained on the complete dataset
    tree = Classifier("weka.classifiers.trees.J48")
    tree.build_classifier(iris)

    # 1) single object through the serialization module
    helper.print_title("I/O: model (using serialization module)")
    model_path = os.sep.join([tempfile.gettempdir(), "j48.model"])
    serialization.write(model_path, tree)
    restored_tree = Classifier(jobject=serialization.read(model_path))
    print(restored_tree)

    # 2) several objects (model + empty dataset template) in one file
    helper.print_title("I/O: model and header (using serialization module)")
    payload = [tree, Instances.template_instances(iris)]
    serialization.write_all(model_path, payload)
    for position, jobj in enumerate(serialization.read_all(model_path),
                                    start=1):
        helper.print_info("Object #" + str(position) + ":")
        # read_all yields raw Java objects; wrap the ones we recognise so
        # that printing them goes through the Python wrapper classes
        env = javabridge.get_env()
        if env.is_instance_of(jobj, env.find_class("weka/core/Instances")):
            wrapped = Instances(jobject=jobj)
        elif env.is_instance_of(
                jobj, env.find_class("weka/classifiers/Classifier")):
            wrapped = Classifier(jobject=jobj)
        else:
            wrapped = jobj
        print(wrapped)

    # 3) model only, through the Classifier convenience methods
    helper.print_title("I/O: just model (using Classifier class)")
    tree.serialize(model_path)
    restored_tree, _unused_header = Classifier.deserialize(model_path)
    print(restored_tree)

    # 4) model plus dataset header, through the Classifier convenience methods
    helper.print_title("I/O: model and header (using Classifier class)")
    tree.serialize(model_path, header=iris)
    restored_tree, restored_header = Classifier.deserialize(model_path)
    print(restored_tree)
    if restored_header is not None:
        print(restored_header)
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


def TrainingModel(arff, modelOutput, clsfier):
    """Train a Weka classifier on an ARFF file, evaluate it and save it.

    Starts the JVM, loads ``arff`` with the first attribute as the class,
    builds the requested classifier, runs a 10-fold cross-validation and
    prints the evaluation together with sensitivity, specificity, PPV and
    NPV taken from the confusion matrix. The trained classifier is then
    serialized to ``modelOutput`` together with the dataset header.

    The JVM is stopped even when one of the steps raises, so a failure does
    not leave it running.

    Args:
        arff: path of the ARFF training file.
        modelOutput: path the serialized model is written to.
        clsfier: classifier class name relative to ``weka.classifiers.``
            (it is appended to that prefix).

    Returns:
        None
    """
    # start the Java virtual machine
    jvm.start()
    try:
        # load the training set; the class attribute is the first column
        loader = Loader(classname="weka.core.converters.ArffLoader")
        train = loader.load_file(arff)
        train.class_is_first()

        # The algorithm is chosen by the caller. Original author's note:
        # RandomForest gave the highest TPR and TNR when several algorithms
        # were compared in the Weka GUI.
        cls_name = "weka.classifiers." + clsfier
        clsf = Classifier(classname=cls_name)
        clsf.build_classifier(train)
        print(clsf)

        # cross-validate the classifier wrapped in a FilteredClassifier
        fc = FilteredClassifier()
        fc.classifier = clsf
        evl = Evaluation(train)
        evl.crossvalidate_model(fc, train, 10, Random(1))
        print(evl.percent_correct)
        print(evl.summary())
        print(evl.class_details())
        print(evl.matrix())

        # Derived statistics. The indexing treats the matrix as 2x2 with
        # class index 0 as the negative class and index 1 as the positive one.
        matrixResults = evl.confusion_matrix
        TN = float(matrixResults[0][0])
        FP = float(matrixResults[0][1])
        FN = float(matrixResults[1][0])
        TP = float(matrixResults[1][1])
        # A ratio is undefined (NaN) when its row/column of the matrix is
        # empty; this must not abort the run before the model is saved.
        TPR = _safe_ratio(TP, TP + FN)
        TNR = _safe_ratio(TN, FP + TN)
        PPV = _safe_ratio(TP, TP + FP)
        NPV = _safe_ratio(TN, TN + FN)
        print("算法: " + clsfier)
        print("敏感度 TPR: " + str(TPR))
        print("特异度 TNR: " + str(TNR))
        print("PPV: " + str(PPV))
        print("NPV: " + str(NPV))

        # save the model together with the dataset header
        clsf.serialize(modelOutput, header=train)
    finally:
        # shut the JVM down on success and on failure alike
        jvm.stop()
    print("分析模型建立完成")
Ejemplo n.º 3
0
def _summary_row(classifier_name, evl):
    """Format one tab-separated, fixed-width row of the summary table."""
    cells = [
        classifier_name.ljust(16),
        "{:.3f}".format(evl.weighted_precision * 100).ljust(12),
        "{:.3f}".format(evl.weighted_recall * 100).ljust(12),
        "{:.3f}".format(evl.weighted_f_measure * 100).ljust(12),
        "{:.3f}".format(evl.percent_correct).ljust(12),
        "{:.3f}".format(evl.weighted_false_positive_rate * 100).ljust(12),
    ]
    return '\t'.join(cells)


def train(training_dataset_path, model_cache_file_name, evaluation_is_on,
          summary_file_path):
    """Train every configured classifier and optionally cross-validate it.

    Loads the training data (last attribute is the class) and, for each
    entry of the module-level ``__classifiers`` mapping (class name ->
    option string), builds a classifier, serializes it to
    ``caches/model/<model_cache_file_name>.cache``, reads it back and prints
    it, and stores the trained classifier in the module-level
    ``__predictors`` mapping under its short class name.

    The model is serialized without the dataset header. All classifiers are
    written to the same cache path, so the file ends up holding the last
    one trained.
    NOTE(review): confirm that a single shared cache file is intended when
    more than one classifier is configured.

    When ``evaluation_is_on`` is truthy, each classifier is additionally
    evaluated with a 10-fold cross-validation and ``summary_file_path`` is
    rewritten with the details collected so far followed by a summary table.

    Args:
        training_dataset_path: path of the input dataset file.
        model_cache_file_name: base name (without ``.cache``) of the model
            cache file inside ``caches/model``.
        evaluation_is_on: run the cross-validation and write the summary
            file when truthy.
        summary_file_path: path of the evaluation summary file; only
            written when ``evaluation_is_on`` is truthy.

    Returns:
        None
    """

    global __classifiers
    global __predictors

    training_data = converters.load_any_file(training_dataset_path)
    training_data.class_is_last()

    # Loop invariants: the option tokenizer and the cache location.
    # A token is either a double-quoted run or a run of non-space characters.
    option_token = re.compile(r'"(?:[^"]+)"|(?:[^ ]+)')
    cache_dir = os.path.join('caches', 'model')
    cache_path = os.path.join(cache_dir, model_cache_file_name + '.cache')

    lines = []
    header_cells = [
        'Model'.ljust(16), 'Precision'.ljust(12), 'Recall'.ljust(12),
        'F-measure'.ljust(12), 'Accuracy'.ljust(12), 'FPR'.ljust(12)
    ]
    summaries = ['\t'.join(header_cells)]

    for classifier, option_str in __classifiers.items():
        # Tokenize the option string, then strip the quote characters.
        option_list = [
            token.replace('"', '')
            for token in option_token.findall(option_str)
        ]

        classifier_name = classifier.split('.')[-1]
        info_str = "Using classifier: {classifier}, options: {options}".format(
            classifier=classifier_name, options=str(option_list))
        localizer_log.msg(info_str)
        lines.append(info_str)

        # Train
        model = Classifier(classname=classifier, options=option_list)
        localizer_log.msg("Start building classifier")
        model.build_classifier(training_data)
        localizer_log.msg("Completed building classifier")
        localizer_log.msg("Saving trained model to {model_cache_name}".format(
            model_cache_name=model_cache_file_name))

        # exist_ok makes a separate existence check unnecessary
        os.makedirs(cache_dir, exist_ok=True)
        model.serialize(cache_path)
        localizer_log.msg("Trained model saved")

        # Read the model back and print it to show the cache is loadable.
        reloaded, _ = Classifier.deserialize(cache_path)
        print(reloaded)

        __predictors[classifier_name] = model

        if not evaluation_is_on:
            continue

        # Model evaluation
        localizer_log.msg("Start evaluation classifier")
        evl = Evaluation(training_data)
        localizer_log.msg("Complete evaluation classifier")

        localizer_log.msg("Start cross-validating classifier")
        evl.crossvalidate_model(model, training_data, 10, Random(1))
        localizer_log.msg("Complete cross-validating classifier")

        lines.append(evl.summary())
        lines.append(evl.class_details())
        summaries.append(_summary_row(classifier_name, evl))

        # Save evaluation summary to file. It is rewritten after every
        # classifier so it always reflects the ones evaluated so far.
        with open(summary_file_path, 'w') as f:
            f.write('\n'.join(lines))
            f.write('\n' * 5)
            f.write('\n'.join(summaries))