예제 #1
0
    def test_cv_score(self):
        """Check dtree.cv_score on two hand-built folds with weighted instances.

        Two equally sized halves are generated; the second half has every
        attribute shifted by ``cValues + 1``.  Each half serves once as the
        first and once as the second argument of a ``dtree.TreeFold``.  The
        expected score is the weight, in each half, of the instances carrying
        the *other* half's majority label, divided by the total weight.
        NOTE(review): presumably the shift makes the halves' attribute values
        disjoint so the tree can only fall back to the majority label --
        confirm against dtree.
        """
        def weight_with_label(candidates, wanted):
            # Explicit running total (starting from a float zero) of the
            # weights of the instances whose label equals `wanted`.
            running = 0.0
            for candidate in candidates:
                if candidate.fLabel != wanted:
                    continue
                running += candidate.dblWeight
            return running

        value_count = 4
        make_instances = build_consistent_generator(
            cValues=value_count, fxnGenWeight=random.random)
        half_size = random.randint(30, 60)

        left_half = make_instances(half_size)
        right_half = []
        for source in make_instances(half_size):
            moved_attrs = [attr + value_count + 1 for attr in source.listAttrs]
            right_half.append(dtree.Instance(moved_attrs, source.fLabel))

        left_majority = dtree.majority_label(left_half)
        right_majority = dtree.majority_label(right_half)

        forward_fold = dtree.TreeFold(left_half, right_half)
        backward_fold = dtree.TreeFold(right_half, left_half)
        actual_score = dtree.cv_score([forward_fold, backward_fold])

        # Weight in each half that agrees with the opposite half's majority.
        right_agreeing = weight_with_label(right_half, left_majority)
        left_agreeing = weight_with_label(left_half, right_majority)
        overall_weight = sum(
            member.dblWeight for member in right_half + left_half)

        expected_score = (right_agreeing + left_agreeing) / overall_weight
        self.assertAlmostEqual(expected_score, actual_score)
    def test_evaluate_classification(self):
        """Check that evaluate_classification splits correct from incorrect.

        The training set is reduced to instances that a tree built from it
        labels correctly.  The test set is that training set plus one altered
        copy per instance: attributes shifted by ``cValues + 1`` and the label
        set to the opposite of the majority label.  The test expects every
        training instance in ``rslt.listInstCorrect`` and every altered copy
        in ``rslt.listInstIncorrect``.
        NOTE(review): presumably the shifted attribute values lie outside
        anything the tree was trained on -- confirm against the generator.
        """
        def increase_values(inst):
            # Shifted copy of `inst`, labelled with the non-majority label.
            listIncreased = [c + cValues + 1 for c in inst.listAttrs]
            return dtree.Instance(listIncreased, not fMajorityLabel)

        def filter_unclassifiable(listInst):
            # Keep only instances that a tree built from listInst classifies
            # with their own label.
            dt = dtree.build_tree(listInst)
            return [inst for inst in listInst
                    if dtree.classify(dt, inst) == inst.fLabel]

        cValues = 2
        fxnGen = build_instance_generator(cValues=cValues)
        listInst = fxnGen(15)
        force_instance_consistency(listInst)
        listInst = filter_unclassifiable(listInst)
        fMajorityLabel = dtree.majority_label(listInst)
        # Build a real list: under Python 3 map() returns an iterator, which
        # can neither be concatenated to a list nor passed to len().
        listInstImpossible = [increase_values(inst) for inst in listInst]
        listInstTest = listInst + listInstImpossible
        cvf = dtree.TreeFold(listInst, listInstTest)
        rslt = dtree.evaluate_classification(cvf)
        self.assertEqual(len(listInst), len(rslt.listInstCorrect))
        self.assertEqual(len(listInstImpossible), len(rslt.listInstIncorrect))
        self.assertTrue(check_instance_membership(
            listInst, rslt.listInstCorrect), "Missing correct instances")
        self.assertTrue(check_instance_membership(
            listInstImpossible, rslt.listInstIncorrect),
            "Missing incorrect instances")
예제 #3
0
 def test_evaluate_classification(self):
     """Check that evaluate_classification splits correct from incorrect.

     Trains on instances that a tree built from them labels correctly, then
     evaluates on those instances plus one altered copy of each (attributes
     shifted by ``cValues + 1``, label set to the opposite of the majority
     label).  Every training instance is expected in
     ``rslt.listInstCorrect`` and every altered copy in
     ``rslt.listInstIncorrect``.
     NOTE(review): presumably the shifted attribute values are ones the
     tree never saw during training -- confirm against the generator.
     """
     def increase_values(inst):
         # Shifted copy of `inst`, labelled with the non-majority label.
         listIncreased = [c + cValues + 1 for c in inst.listAttrs]
         return dtree.Instance(listIncreased, not fMajorityLabel)

     def filter_unclassifiable(listInst):
         # Keep only instances that a tree built from listInst classifies
         # with their own label.
         dt = dtree.build_tree(listInst)
         return [inst for inst in listInst
                 if dtree.classify(dt, inst) == inst.fLabel]

     cValues = 2
     fxnGen = build_instance_generator(cValues=cValues)
     listInst = fxnGen(15)
     force_instance_consistency(listInst)
     listInst = filter_unclassifiable(listInst)
     fMajorityLabel = dtree.majority_label(listInst)
     # Materialise a list: Python 3's map() yields an iterator, so the
     # concatenation and len() calls below would raise TypeError.
     listInstImpossible = [increase_values(inst) for inst in listInst]
     listInstTest = listInst + listInstImpossible
     cvf = dtree.TreeFold(listInst, listInstTest)
     rslt = dtree.evaluate_classification(cvf)
     self.assertEqual(len(listInst), len(rslt.listInstCorrect))
     self.assertEqual(len(listInstImpossible), len(rslt.listInstIncorrect))
     self.assertTrue(check_instance_membership(
         listInst, rslt.listInstCorrect), "Missing correct instances")
     self.assertTrue(check_instance_membership(
         listInstImpossible, rslt.listInstIncorrect),
                     "Missing incorrect instances")
예제 #4
0
 def test_build_tree_no_gain(self):
     """Check that build_tree yields a majority-label leaf for identical data.

     The instance list holds 25 to 30 references to one single instance, so
     every entry has the same attributes and label.  The resulting tree is
     expected to be a leaf whose label equals ``dtree.majority_label``.
     """
     listAttr = randlist(0, 5, 10)
     # List multiplication repeats the same Instance object, which is what
     # makes all entries identical.
     listInst = [dtree.Instance(listAttr, randbool())] * random.randint(25, 30)
     dt = dtree.build_tree(listInst)
     fMajorityLabel = dtree.majority_label(listInst)
     self.assertTrue(dt.is_leaf())
     # assertEqual replaces the deprecated assertEquals alias, which was
     # removed from unittest in Python 3.12.
     self.assertEqual(dt.fLabel, fMajorityLabel)
 def test_majority_label(self):
     """Check dtree.majority_label against two groups of unequal size.

     One generator is built with argument 1.0 and one with 0.0; the test
     expects the majority label to be True exactly when the first group is
     the larger one.
     NOTE(review): presumably the argument is the probability of a True
     label -- confirm against build_instance_generator.
     """
     make_true = build_instance_generator(1.0)
     make_false = build_instance_generator(0.0)

     n_true = random.randint(5, 10)
     n_false = random.randint(5, 10)
     # Break a tie so that one group is strictly larger than the other.
     n_true = n_true + 1 if n_true == n_false else n_true

     true_group = make_true(n_true)
     false_group = make_false(n_false)
     winner = dtree.majority_label(true_group + false_group)

     true_is_larger = n_true > n_false
     self.assertEqual(winner, true_is_larger)
예제 #6
0
 def test_majority_label(self):
     """Check that dtree.majority_label reports the larger of two groups.

     Instances come from two generators, built with 1.0 and 0.0
     respectively, in randomly chosen and deliberately unequal quantities.
     The expected label is True when the first group outnumbers the second.
     NOTE(review): the generator argument looks like a True-label
     probability -- confirm against build_instance_generator.
     """
     gen_for_true, gen_for_false = (
         build_instance_generator(1.0),
         build_instance_generator(0.0),
     )
     count_true = random.randint(5, 10)
     count_false = random.randint(5, 10)
     if not count_true != count_false:
         # Equal sizes would make the majority ambiguous; bump one side.
         count_true = count_true + 1

     combined = gen_for_true(count_true) + gen_for_false(count_false)
     self.assertEqual(dtree.majority_label(combined),
                      count_true > count_false)
예제 #7
0
파일: testdtree.py 프로젝트: dzhu/cs181
 def test_cv_score(self):
     """Check dtree.cv_score on a 2-fold split produced by yield_cv_folds.

     Two halves of ``cInst`` instances each are generated; the second half
     has every attribute shifted by ``cValues + 1``.  The halves are
     concatenated and handed to ``dtree.yield_cv_folds`` with a fold count
     of 2.  The expected score is the weight, in each half, of instances
     carrying the other half's majority label, divided by ``2.0 * cInst``.
     NOTE(review): this presumes the two folds coincide with the two halves
     and that every instance has unit weight -- confirm against dtree.
     """
     def matching_weight(pool, label):
         # Running float total of weights of instances labelled `label`.
         acc = 0.0
         for member in pool:
             if member.fLabel != label:
                 continue
             acc += member.dblWeight
         return acc

     n_values = 4
     produce = build_consistent_generator(cValues=n_values)
     per_half = random.randint(30, 60)

     first_half = produce(per_half)
     second_half = []
     for template in produce(per_half):
         bumped = [attr + n_values + 1 for attr in template.listAttrs]
         second_half.append(dtree.Instance(bumped, template.fLabel))

     first_majority = dtree.majority_label(first_half)
     second_majority = dtree.majority_label(second_half)

     folds = dtree.yield_cv_folds(first_half + second_half, 2)
     observed = dtree.cv_score(folds)

     from_second = matching_weight(second_half, first_majority)
     from_first = matching_weight(first_half, second_majority)
     expected = (from_second + from_first) / (2.0 * per_half)
     self.assertAlmostEqual(expected, observed)
예제 #8
0
 def test_majority_label_weighted(self):
     """Check that dtree.majority_label follows total weight per label.

     For each label, attribute-less instances with random weights are
     appended until their summed weight reaches a random target below
     ``dblScale``.  After shuffling both groups together, the majority
     label is expected to be True exactly when the True group's summed
     weight exceeds the False group's.
     """
     weight_cap = 25.0

     def build_group(label):
         # Draw the target first, then keep adding randomly weighted
         # instances until the accumulated weight reaches it.
         target = random.random() * weight_cap
         members = []
         accumulated = 0.0
         while accumulated < target:
             weight = random.random()
             members.append(dtree.Instance([], label, weight))
             accumulated += weight
         return members, accumulated

     true_group = build_group(True)
     false_group = build_group(False)
     true_members, true_weight = true_group
     false_members, false_weight = false_group

     everyone = true_members + false_members
     random.shuffle(everyone)
     reported = dtree.majority_label(everyone)
     self.assertEqual(true_weight > false_weight, reported)