def test_generate_thresholdcurve_data(self):
        """
        Checks the dataset returned by plot.generate_thresholdcurve_data.

        Loads diabetes.arff, cross-validates NaiveBayes (10 folds, Random(1)),
        requests the threshold-curve data with index 0 and verifies its
        dimensions plus the presence of a few expected attributes.
        """
        arff_loader = converters.Loader(classname="weka.core.converters.ArffLoader")
        dataset = arff_loader.load_file(self.datafile("diabetes.arff"))
        dataset.class_is_last()

        # NOTE(review): the filtered classifier is configured here but never
        # handed to the evaluation; cross-validation below runs on the plain
        # NaiveBayes instance -- confirm this is intended.
        attribute_remover = filters.Filter(
            classname="weka.filters.unsupervised.attribute.Remove",
            options=["-R", "1-3"],
        )
        naive_bayes = classifiers.Classifier(classname="weka.classifiers.bayes.NaiveBayes")
        filtered = classifiers.FilteredClassifier()
        filtered.filter = attribute_remover
        filtered.classifier = naive_bayes

        evaluation = classifiers.Evaluation(dataset)
        evaluation.crossvalidate_model(naive_bayes, dataset, 10, Random(1))
        curve = plot.generate_thresholdcurve_data(evaluation, 0)

        self.assertEqual(13, curve.num_attributes, msg="number of attributes differs")
        self.assertEqual(769, curve.num_instances, msg="number of rows differs")

        # Each of these columns must be present in the generated curve data.
        for attname in ("True Positives", "False Positive Rate", "Lift"):
            self.assertIsNotNone(
                curve.attribute_by_name(attname),
                msg="Failed to locate attribute: " + attname,
            )
    def test_generate_thresholdcurve_data(self):
        """
        Verifies the output of plot.generate_thresholdcurve_data.

        NaiveBayes is cross-validated on diabetes.arff (10 folds, Random(1));
        the threshold-curve data requested with index 0 must have the
        expected number of attributes and rows and contain selected columns.
        """
        source = converters.Loader(classname="weka.core.converters.ArffLoader")
        instances = source.load_file(self.datafile("diabetes.arff"))
        instances.class_is_last()

        # NOTE(review): this filtered classifier is assembled but not used by
        # the evaluation, which receives the bare NaiveBayes classifier --
        # confirm this is intended.
        remove_filter = filters.Filter(
            classname="weka.filters.unsupervised.attribute.Remove", options=["-R", "1-3"])
        bayes = classifiers.Classifier(classname="weka.classifiers.bayes.NaiveBayes")
        wrapped = classifiers.FilteredClassifier()
        wrapped.filter = remove_filter
        wrapped.classifier = bayes

        evaluation = classifiers.Evaluation(instances)
        evaluation.crossvalidate_model(bayes, instances, 10, Random(1))
        curve_data = plot.generate_thresholdcurve_data(evaluation, 0)

        self.assertEqual(13, curve_data.num_attributes, msg="number of attributes differs")
        self.assertEqual(769, curve_data.num_instances, msg="number of rows differs")

        expected_columns = ("True Positives", "False Positive Rate", "Lift")
        for attname in expected_columns:
            located = curve_data.attribute_by_name(attname)
            self.assertIsNotNone(located, msg="Failed to locate attribute: " + attname)
    def test_get_prc(self):
        """
        Checks the value returned by plot.get_prc.

        Cross-validates NaiveBayes on diabetes.arff (10 folds, Random(1)),
        generates the threshold-curve data with index 0 and compares the
        result of get_prc against 0.892 to three decimal places.
        """
        arff_loader = converters.Loader(classname="weka.core.converters.ArffLoader")
        dataset = arff_loader.load_file(self.datafile("diabetes.arff"))
        dataset.class_is_last()

        # NOTE(review): the filtered classifier is configured but the
        # evaluation below is run with the plain NaiveBayes classifier --
        # confirm this is intended.
        attribute_remover = filters.Filter(
            classname="weka.filters.unsupervised.attribute.Remove",
            options=["-R", "1-3"],
        )
        naive_bayes = classifiers.Classifier(classname="weka.classifiers.bayes.NaiveBayes")
        filtered = classifiers.FilteredClassifier()
        filtered.filter = attribute_remover
        filtered.classifier = naive_bayes

        evaluation = classifiers.Evaluation(dataset)
        evaluation.crossvalidate_model(naive_bayes, dataset, 10, Random(1))
        curve = plot.generate_thresholdcurve_data(evaluation, 0)

        self.assertAlmostEqual(0.892, plot.get_prc(curve), places=3, msg="PRC differs")
Exemplo n.º 4
0
    def test_get_prc(self):
        """
        Verifies plot.get_prc on threshold-curve data.

        The curve data comes from a 10-fold cross-validation (Random(1)) of
        NaiveBayes on diabetes.arff, requested with index 0; the value
        returned by get_prc must match 0.892 to three decimal places.
        """
        source = converters.Loader(classname="weka.core.converters.ArffLoader")
        instances = source.load_file(self.datafile("diabetes.arff"))
        instances.class_is_last()

        # NOTE(review): this filtered classifier is assembled but never
        # evaluated; the bare NaiveBayes classifier is cross-validated
        # instead -- confirm this is intended.
        remove_filter = filters.Filter(
            classname="weka.filters.unsupervised.attribute.Remove", options=["-R", "1-3"])
        bayes = classifiers.Classifier(classname="weka.classifiers.bayes.NaiveBayes")
        wrapped = classifiers.FilteredClassifier()
        wrapped.filter = remove_filter
        wrapped.classifier = bayes

        evaluation = classifiers.Evaluation(instances)
        evaluation.crossvalidate_model(bayes, instances, 10, Random(1))

        curve_data = plot.generate_thresholdcurve_data(evaluation, 0)
        prc_area = plot.get_prc(curve_data)
        self.assertAlmostEqual(0.892, prc_area, places=3, msg="PRC differs")
Exemplo n.º 5
0
# Load the test dataset and mark its last attribute as the class attribute.
fname = data_dir + os.sep + "ReutersGrain-test.arff"
print("\nLoading dataset: " + fname + "\n")
loader = Loader(classname="weka.core.converters.ArffLoader")
test = loader.load_file(fname)
last_index = test.num_attributes() - 1
test.set_class_index(last_index)

# Each setup pairs a base classifier class name with the option list that is
# passed to the StringToWordVector filter.
setups = (
    ("weka.classifiers.trees.J48", []),
    ("weka.classifiers.bayes.NaiveBayes", []),
    ("weka.classifiers.bayes.NaiveBayesMultinomial", []),
    ("weka.classifiers.bayes.NaiveBayesMultinomial", ["-C"]),
    ("weka.classifiers.bayes.NaiveBayesMultinomial", ["-C", "-L", "-S"]),
)

# Build every setup on `data` (defined outside this excerpt), evaluate it on
# the `test` dataset and report accuracy, AUC and the confusion matrix.
for classifier, opt in setups:
    print("\n--> %s (filter options: %s)\n" % (classifier, " ".join(opt)))

    cls = FilteredClassifier()
    base_classifier = Classifier(classname=classifier)
    cls.set_classifier(base_classifier)
    word_vector_filter = Filter(
        classname="weka.filters.unsupervised.attribute.StringToWordVector",
        options=opt,
    )
    cls.set_filter(word_vector_filter)
    cls.build_classifier(data)

    evl = Evaluation(test)
    evl.test_model(cls, test)
    print("Accuracy: %0.0f%%" % evl.percent_correct())

    tcdata = plc.generate_thresholdcurve_data(evl, 0)
    auc = plc.get_auc(tcdata)
    print("AUC: %0.3f" % auc)
    print(evl.to_matrix("Matrix:"))

jvm.stop()
Exemplo n.º 6
0
# Load the test dataset; its last attribute is used as the class attribute.
fname = data_dir + os.sep + "ReutersGrain-test.arff"
print("\nLoading dataset: " + fname + "\n")
loader = Loader(classname="weka.core.converters.ArffLoader")
test = loader.load_file(fname)
test.class_is_last()

# Each setup pairs a base classifier class name with the option list that is
# passed to the StringToWordVector filter.
setups = (
    ("weka.classifiers.trees.J48", []),
    ("weka.classifiers.bayes.NaiveBayes", []),
    ("weka.classifiers.bayes.NaiveBayesMultinomial", []),
    ("weka.classifiers.bayes.NaiveBayesMultinomial", ["-C"]),
    (
        "weka.classifiers.bayes.NaiveBayesMultinomial",
        ["-C", "-L", "-stopwords-handler", "weka.core.stopwords.Rainbow"],
    ),
)

# Build every setup on `data` (defined outside this excerpt), evaluate it on
# the `test` dataset and report accuracy, AUC and the confusion matrix.
for classifier, opt in setups:
    print("\n--> %s (filter options: %s)\n" % (classifier, " ".join(opt)))

    cls = FilteredClassifier()
    base_classifier = Classifier(classname=classifier)
    cls.classifier = base_classifier
    word_vector_filter = Filter(
        classname="weka.filters.unsupervised.attribute.StringToWordVector",
        options=opt,
    )
    cls.filter = word_vector_filter
    cls.build_classifier(data)

    evl = Evaluation(test)
    evl.test_model(cls, test)
    print("Accuracy: %0.0f%%" % evl.percent_correct)

    tcdata = plc.generate_thresholdcurve_data(evl, 0)
    auc = plc.get_auc(tcdata)
    print("AUC: %0.3f" % auc)
    print(evl.matrix("Matrix:"))

jvm.stop()