def test_cv_score(self):
    """Compare dtree.cv_score on two mirrored folds with a hand-computed ratio.

    Two instance lists are generated; the second has every attribute value
    shifted by ``cValues + 1``.  Each list is used once as the first and once
    as the second argument of a ``TreeFold``.  The expected score is the
    weight of the instances in each list whose label equals the majority
    label of the *other* list, divided by the total weight of all instances.
    """
    def shifted_copy(inst):
        # Same label as the source instance, every attribute moved up by
        # cValues + 1 (presumably so the two lists share no attribute
        # values -- confirm against build_consistent_generator).
        moved = [cAttr + cValues + 1 for cAttr in inst.listAttrs]
        return dtree.Instance(moved, inst.fLabel)

    def matching_weight(instances, label):
        # Accumulate the weights of the instances carrying `label`,
        # in list order, starting from 0.0.
        total = 0.0
        for weight in (i.dblWeight for i in instances if i.fLabel == label):
            total += weight
        return total

    cValues = 4
    make_instances = build_consistent_generator(
        cValues=cValues, fxnGenWeight=random.random)
    count = random.randint(30, 60)
    left = make_instances(count)
    right = [shifted_copy(inst) for inst in make_instances(count)]

    majority_left = dtree.majority_label(left)
    majority_right = dtree.majority_label(right)

    # Each list takes both roles exactly once.
    folds = [dtree.TreeFold(left, right), dtree.TreeFold(right, left)]
    score = dtree.cv_score(folds)

    hits_on_right = matching_weight(right, majority_left)
    hits_on_left = matching_weight(left, majority_right)
    total_weight = sum(inst.dblWeight for inst in right + left)

    self.assertAlmostEqual(
        (hits_on_right + hits_on_left) / total_weight, score)
def task(self):
    """Return chart options comparing accuracy on the clean and noisy sets.

    For each dataset the same instance list is passed as both arguments of
    ``dtree.TreeFold``, the fold is evaluated with
    ``dtree.evaluate_classification``, and the accuracy is computed as
    correct weight divided by total (correct + incorrect) weight.

    Returns:
        A dict describing a column chart with one "Training Set Accuracy"
        series holding the two accuracies, in the order "Clean", "Noisy".

    Raises:
        ZeroDivisionError: if a dataset's correct and incorrect weights are
            both zero.
    """
    listNames = ["Clean", "Noisy"]
    # Same order as listNames so the series lines up with the categories.
    listDatasets = [get_clean_insts(), get_noisy_insts()]

    listData = []
    for listInst in listDatasets:
        # The evaluation receives only the fold, so no tree is built here;
        # a separately built tree would never be read.
        tf = dtree.TreeFold(listInst, listInst)
        rslt = dtree.evaluate_classification(tf)
        dblCorrect, dblIncorrect = dtree.weight_correct_incorrect(rslt)
        listData.append(dblCorrect / (dblCorrect + dblIncorrect))

    return {
        "chart": {"defaultSeriesType": "column"},
        "title": {"text": "Clean vs. Noisy Training Set Accuracy"},
        "xAxis": {"categories": listNames},
        "yAxis": {
            "title": {"text": "Accuracy"},
            "min": 0.0,
            "max": 1.0,
        },
        "series": [{
            "name": "Training Set Accuracy",
            "data": listData,
        }],
    }
def test_evaluate_classification(self):
    """Check that evaluate_classification separates correct from incorrect.

    The first fold argument is a set of instances that a tree built from
    them classifies correctly.  The second is that same set plus one
    "impossible" counterpart per instance: attributes shifted by
    ``cValues + 1`` and the label set to the opposite of the majority
    label.  The test expects every original instance to be reported as
    correct and every impossible one as incorrect.
    """
    def increase_values(inst):
        # Shift every attribute by cValues + 1 and attach the non-majority
        # label; the assertions below expect these to be misclassified.
        listIncreased = [c + cValues + 1 for c in inst.listAttrs]
        return dtree.Instance(listIncreased, not fMajorityLabel)

    def filter_unclassifiable(listInst):
        # Keep only instances that a tree built from the full list
        # classifies with their own label.
        dt = dtree.build_tree(listInst)
        return [inst for inst in listInst
                if dtree.classify(dt, inst) == inst.fLabel]

    cValues = 2
    fxnGen = build_instance_generator(cValues=cValues)
    listInst = fxnGen(15)
    force_instance_consistency(listInst)
    listInst = filter_unclassifiable(listInst)
    fMajorityLabel = dtree.majority_label(listInst)

    # Must be a real list: it is concatenated and measured with len()
    # below, which a lazy map() iterator (Python 3) does not support.
    listInstImpossible = [increase_values(inst) for inst in listInst]
    listInstTest = listInst + listInstImpossible

    cvf = dtree.TreeFold(listInst, listInstTest)
    rslt = dtree.evaluate_classification(cvf)

    self.assertEqual(len(listInst), len(rslt.listInstCorrect))
    self.assertEqual(len(listInstImpossible), len(rslt.listInstIncorrect))
    self.assertTrue(
        check_instance_membership(listInst, rslt.listInstCorrect),
        "Missing correct instances")
    self.assertTrue(
        check_instance_membership(listInstImpossible,
                                  rslt.listInstIncorrect),
        "Missing incorrect instances")