def mode2(examples, target_attr):
    """Return ``mode`` applied to a class vector built from ``examples``.

    Args:
        examples: Records handed to the ``ShroomDatabase`` constructor.
        target_attr: When equal to ``'class'`` the class vector is fetched
            with no argument; any other value is passed through to
            ``fetch_class_vector``.

    Returns:
        Whatever ``mode`` returns for the fetched vector. ``mode`` is defined
        outside this block; callers in this file take element ``[0]`` of the
        result.
    """
    database = ShroomDatabase(examples)
    if target_attr == 'class':
        class_vector = database.fetch_class_vector()
    else:
        class_vector = database.fetch_class_vector(target_attr)
    return mode(class_vector)
def mode2(examples, target_attr):
    """Compute the mode of the class vector for a list of records.

    The records are wrapped in a ``ShroomDatabase``. For a ``target_attr``
    other than ``'class'`` the attribute name is forwarded to
    ``fetch_class_vector``; for ``'class'`` the vector is fetched without
    arguments. The result of ``mode`` (defined outside this block) is
    returned unchanged.
    """
    db = ShroomDatabase(examples)
    # Handle the non-default target first so the common case falls through.
    if target_attr != 'class':
        return mode(db.fetch_class_vector(target_attr))
    return mode(db.fetch_class_vector())
def id3(criteria, db, target_attr, attributes, defs):
    """Recursively build an ID3 decision tree from the records in ``db``.

    Args:
        criteria: Attribute-selection strategy. Must provide
            ``recommend_next_attr(attributes, db, defs)`` returning an
            ``(attribute, statistic)`` pair.
        db: ``ShroomDatabase`` holding the records for this (sub)tree.
        target_attr: Name of the attribute being predicted. Used for the
            majority label when no attributes remain and forwarded to
            recursive calls.
        attributes: Set of attribute names still available for splitting.
        defs: Definitions object whose ``attr_values`` maps each attribute
            to the values it can take.

    Returns:
        An ``ID3Tree``. It is a single leaf when the class vector is
        homogeneous or when ``attributes`` is empty; otherwise it is rooted
        at a decision node with one edge per value of the chosen attribute.
    """
    class_vector = db.fetch_class_vector()
    homogeneous, label = is_homogeneous(class_vector)
    if homogeneous:
        return ID3Tree(ID3LeafNode(label))

    if len(attributes) == 0:
        # Nothing left to split on: label the leaf with the most common
        # value among this node's own records.
        label = mode2(db.records, target_attr)[0]
        return ID3Tree(ID3LeafNode(label))

    A, stat = criteria.recommend_next_attr(attributes, db, defs)

    # The statistic goes in the second constructor slot for information
    # gain and in the third slot for any other criteria; the unused slot
    # is 0.0.
    if isinstance(criteria, InformationGainCriteria):
        decision_node = ID3DecisionNode(A, stat, 0.0)
    else:
        decision_node = ID3DecisionNode(A, 0.0, stat)
    tree = ID3Tree(decision_node)

    for value in defs.attr_values[A]:
        edge = ID3Edge(A, value)
        # Records whose value for A matches this edge; '?' never matches.
        subset_records = [
            x for x in db.records
            if (x.attributes[A] == value) and (x.attributes[A] != '?')
        ]

        if len(subset_records) == 0:
            # No record takes this value: fall back to the most common
            # class among the records at this node.
            label = mode2(db.records, 'class')[0]
            leaf_node = ID3LeafNode(label)
            tree.add_node(decision_node, edge, leaf_node)
        else:
            subattributes = attributes - set([A])
            subset_db = ShroomDatabase(subset_records)
            subtree = id3(criteria, subset_db, target_attr,
                          subattributes, defs)
            tree.add_tree(decision_node, edge, subtree)

    return tree
def calc_all_class_error(attributes, db, defs):
    """Tabulate the classification-error reduction for each attribute.

    For every attribute the records of ``db`` are partitioned by the
    attribute's possible values (``defs.attr_values[attr]``). Each
    partition's ``calc_class_error`` is weighted by its share of the total
    record count, and the weighted sum is subtracted from the error of the
    unsplit class vector.

    Args:
        attributes: Iterable of attribute names to evaluate.
        db: ``ShroomDatabase`` supplying ``records`` and the class vector.
        defs: Definitions object providing ``attr_values``.

    Returns:
        dict mapping each attribute to (error before split) minus
        (weighted error after split).
    """
    total_count = len(db.fetch_class_vector())
    error_table = {}
    error_before = calc_class_error(db.fetch_class_vector())

    for attr in attributes:
        # One sub-database per possible value of the attribute.
        partitions = [
            ShroomDatabase(
                [rec for rec in db.records if rec.attributes[attr] == symbol]
            )
            for symbol in defs.attr_values[attr]
        ]

        # Accumulate with += (not sum()) to keep the exact float result.
        weighted_error = 0
        for partition in partitions:
            labels = partition.fetch_class_vector()
            weighted_error += (len(labels) / total_count
                               * calc_class_error(labels))

        error_table[attr] = error_before - weighted_error

    return error_table
#!/usr/bin/env python
#
# filename: test_database.py
# authors: Jon David and Jarrett Decker
# date: Thursday, February 10, 2016
#
# Builds ShroomDatabase objects from the training and testing files, then
# saves the testing records to an output file so the two files can be
# compared by hand.
#
import pdb
from database import ShroomDatabase

trainfilename = "./data/training.dat"
testfilename = "./data/testing.dat"
outfilename = "./output.dat"

print("\n\n==== Test database ====")
for label, path in (
    ("training file: ", trainfilename),
    ("testing file: ", testfilename),
    ("output file: ", outfilename),
):
    print(label, path)

# training_db is not used afterwards; it is still constructed here
# (presumably as a smoke test of loading the training file -- confirm).
training_db = ShroomDatabase([], trainfilename)
test_db = ShroomDatabase([], testfilename)

# Copy the testing records into a fresh database and write them back out.
out_db = ShroomDatabase(test_db.records)
out_db.save_data(outfilename)

print("\nVerify {} and {} are equal.".format(testfilename, outfilename))
print("Use diff -Z, to ignore whitespace at end of line")
from datadef import ShroomDefs
from database import ShroomDatabase
from id3_v2 import *

# Script: build an ID3 tree from the training set using the information-gain
# criteria, print it, then classify the first record of the testing set.

deffilename = "_shroom.data.definition"
trainfilename = "./data/training.dat"
testfilename = "./data/testing.dat"

print("\n\n==== Test id3_v2 ====")
for label, path in (
    ("definition file: ", deffilename),
    ("training set: ", trainfilename),
):
    print(label, path)

mydefs = ShroomDefs(deffilename)
mydb = ShroomDatabase([], trainfilename)

# Debugging hook, disabled (would also need `import pdb`): pdb.set_trace()
gain_criteria = InformationGainCriteria()
tree = id3(gain_criteria, mydb, 'class', mydefs.attr_set, mydefs)
tree.print_entire_tree()

print("\n\n==== Test classification====")
testdb = ShroomDatabase([], testfilename)
record = testdb.records[0]
classification = tree.classify(record)
print("record to classify: " + record.get_raw_string())
print("classification: " + classification)
#!/usr/bin/env python
#
# filename: test_gain.py
# authors: Jon David and Jarett Decker
# date: Thursday, February 4, 2016
#
# Prints the gain computed for every attribute of the training set and the
# attribute recommended as the next split.
#
from datadef import ShroomDefs
from database import ShroomDatabase
from id3 import *

deffilename = "_shroom.data.definition"
trainfilename = "./data/training.dat"

print("\n")
for label, path in (
    ("definition file: ", deffilename),
    ("training set: ", trainfilename),
):
    print(label, path)

mydefs = ShroomDefs(deffilename)
mydb = ShroomDatabase([], trainfilename)
gain_table = calc_all_gain(mydefs.attr_set, mydb, mydefs)

print("\nGain table:")
print("===========================")
# Assumes gain_table is a dict keyed by attribute name -- TODO confirm
# against calc_all_gain.
for attr, gain in gain_table.items():
    print(attr, ": ", gain)

rmend_attr, gain = recommend_next_attr(gain_table)
print("\nRecommend: ", rmend_attr, ", ", gain)