def test_cv_score(self):
    """Compare dtree.cv_score on two mirrored folds with a hand-computed value.

    Two instance lists are generated; the second has every attribute value
    shifted up by cValues + 1. Each list appears once in each TreeFold
    position. The expected score is the weight of the instances in each list
    whose label equals the *other* list's majority label, divided by the
    total weight of both lists.
    """
    def label_weight(listInst, fLabel):
        """Return the summed dblWeight of the instances labelled fLabel."""
        listMatching = [inst.dblWeight for inst in listInst
                        if inst.fLabel == fLabel]
        dblSum = 0.0
        for dblOne in listMatching:
            dblSum += dblOne
        return dblSum

    cValues = 4
    fxnMake = build_consistent_generator(cValues=cValues,
                                         fxnGenWeight=random.random)
    cPerList = random.randint(30, 60)

    listFirst = fxnMake(cPerList)
    # Shift the second batch's attribute values by cValues + 1.
    # NOTE(review): presumably this keeps the two lists' attribute values
    # disjoint -- confirm against build_consistent_generator.
    cShift = cValues + 1
    listSecond = []
    for instSrc in fxnMake(cPerList):
        listShifted = [ixValue + cShift for ixValue in instSrc.listAttrs]
        listSecond.append(dtree.Instance(listShifted, instSrc.fLabel))

    fMajFirst = dtree.majority_label(listFirst)
    fMajSecond = dtree.majority_label(listSecond)

    listFolds = [
        dtree.TreeFold(listFirst, listSecond),
        dtree.TreeFold(listSecond, listFirst),
    ]
    dblScore = dtree.cv_score(listFolds)

    dblMatched = (label_weight(listSecond, fMajFirst)
                  + label_weight(listFirst, fMajSecond))
    dblTotal = sum(inst.dblWeight for inst in listSecond + listFirst)
    self.assertAlmostEqual(dblMatched / dblTotal, dblScore)
def build_jagged_instances():
    """Return a list of instances whose attribute lists differ in length.

    Between 25 and 30 dtree.Instance objects are created. Each one is given
    an all-zero attribute list of a random length between 5 and 10, and no
    label argument.
    """
    # The instance count is drawn first, then one length per instance.
    # range() rather than the Python-2-only xrange() so this helper runs on
    # both Python 2 and Python 3; the count is at most 30.
    cInst = random.randint(25, 30)
    return [dtree.Instance([0] * random.randint(5, 10))
            for _ in range(cInst)]
def test_prune_tree(self):
    """
    Test bottom-up pruning with a validation set.

    The test builds a random tree, then randomly chooses a root-to-leaf
    path and prunes at each node on it, deepest internal node first.
    To induce pruning, the test does the following:
      - set the default label of every node, and the actual label of
        every leaf, to F
      - set the default label of the target node to T
      - generate a handful (1-10) of T instances that follow the path
        to the target node
      - prune a fresh copy of the tree with those instances
      - check that the ancestors on the path are still nodes and that
        the target has become a leaf
      - repeat for the node's parent, continuing up to the root.
    """
    def set_labels(dtRoot, f):
        """Set every default label, and every leaf label, under dtRoot to f."""
        def down(dt):
            if dt.is_leaf():
                dt.fLabel = f
            dt.fDefaultLabel = f
            # Explicit loop instead of map(): map() is lazy on Python 3,
            # so using it for side effects would never visit the children.
            for dtChild in dt.dictChildren.values():
                down(dtChild)
        down(dtRoot)

    def check_passes(dtRoot, dtCheck, inst):
        """Sanity-check the root of the tree for the given instance.

        NOTE(review): down() takes a single step from dtRoot and returns
        whether or not dtCheck was reached, so this only asserts that the
        root is an internal node with cValue children and that the
        instance's value selects an existing child. A full walk from the
        root to dtCheck looks like the intent -- confirm before changing.
        """
        def down(dt):
            assert not dt.is_leaf()
            assert len(dt.dictChildren) == cValue
            dt = dt.dictChildren[inst.listAttrs[dt.ixAttr]]
            if dt == dtCheck:
                return
        down(dtRoot)

    cAttr = random.randint(2, 4)
    cValue = random.randint(2, 4)
    dtBase = build_random_tree(cAttr, cValue)
    listPath = []   # child keys followed from the root down to a leaf
    listAttrs = []  # ixAttr of each node visited along that path
    fTargetValue = True  # fixed; randomising via randbool() is disabled
    set_labels(dtBase, not fTargetValue)

    # Pick a random root-to-leaf path through the base tree.
    dt = dtBase
    while not dt.is_leaf():
        # list(...) because dict.keys() is a non-indexable view on
        # Python 3, which random.choice() cannot sample from.
        ixValue = random.choice(list(dt.dictChildren.keys()))
        listPath.append(ixValue)
        listAttrs.append(dt.ixAttr)
        dt = dt.dictChildren[ixValue]

    # Walk back up the path: each pass targets the node one level higher.
    while listPath:
        listPath.pop()
        dt = dtRoot = dtBase.copy()
        for ixValue in listPath:
            dt = dt.dictChildren[ixValue]
        assert dt.is_node()
        dt.fDefaultLabel = fTargetValue

        # Build validation instances labelled fTargetValue. Their first
        # len(listPath) values follow the path; the rest are random.
        listInst = []
        for _ in range(random.randint(1, 10)):
            listValue = listPath + randlist(0, cValue - 1,
                                            cAttr - len(listPath))
            listInstAttr = [None] * cAttr
            assert len(listValue) == cAttr
            for ixValue, ixAttr in zip(listValue, listAttrs):
                listInstAttr[ixAttr] = ixValue
            inst = dtree.Instance(listInstAttr, fTargetValue)
            check_passes(dtRoot, dt, inst)
            listInst.append(inst)

        dtree.prune_tree(dtRoot, listInst)

        # Every ancestor on the path must survive as an internal node...
        dt = dtRoot
        for ix, ixValue in enumerate(listPath):
            assert dt.ixAttr == listAttrs[ix]
            self.assertTrue(dt.is_node(), str(dtRoot))
            self.assertTrue(ixValue in dt.dictChildren)
            dt = dt.dictChildren[ixValue]
        # ...and the target itself must have been collapsed into a leaf.
        self.assertTrue(dt.is_leaf(), str(dt))
def increase_values(inst):
    """Build a new instance from inst with shifted values and a flipped label.

    Every entry of inst.listAttrs is increased by cValues + 1, and the new
    instance is labelled `not fMajorityLabel`. Both cValues and
    fMajorityLabel are not defined here; they are read from the enclosing
    scope.
    """
    listShifted = []
    for ixValue in inst.listAttrs:
        listShifted.append(ixValue + cValues + 1)
    fOpposite = not fMajorityLabel
    return dtree.Instance(listShifted, fOpposite)
def build_foldable_instances(lo=3, hi=10):
    """Return (instances, fold count) where the instance count divides evenly.

    The fold count is drawn from [lo, hi]; the number of instances is that
    count times a random multiplier in [1, 10]. Instance i has the single
    attribute value i and a label from randbool().
    """
    cFold = random.randint(lo, hi)
    cPerFold = random.randint(1, 10)
    listInst = []
    for ix in range(cPerFold * cFold):
        listInst.append(dtree.Instance([ix], randbool()))
    return listInst, cFold
def test_classify_unknown(self):
    """Classifying an instance with out-of-range values yields the root default.

    A random tree is built for 4 attributes and cValue values. The instance
    takes its 4 attribute values from [cValue + 1, cValue + 5], and
    dtree.classify is expected to return the root's fDefaultLabel.
    """
    cValue = 3
    cAttr = 4
    dtRoot = build_random_tree(cAttr, cValue)
    # NOTE(review): presumably these values lie outside what
    # build_random_tree uses for its branches -- confirm.
    listUnseen = randlist(cValue + 1, cValue + 5, cAttr)
    instUnknown = dtree.Instance(listUnseen)
    self.assertEqual(dtree.classify(dtRoot, instUnknown),
                     dtRoot.fDefaultLabel)
def build_one_instance(cAttrs, cValues, fxnGenWeight, fxnGenLabel):
    """Create one random instance with a generated label and weight.

    cAttrs attribute values are drawn via randlist(0, cValues - 1, cAttrs).
    fxnGenLabel is called with that attribute list to produce the label,
    then fxnGenWeight is called with no arguments to produce the weight.
    """
    listValues = randlist(0, cValues - 1, cAttrs)
    # Label before weight: keeps the generator call order stable.
    fLabel = fxnGenLabel(listValues)
    dblWeight = fxnGenWeight()
    return dtree.Instance(listValues, fLabel, dblWeight)
def test_build_tree_rec_leaf(self):
    """build_tree_rec on attribute-less, same-label instances gives a leaf.

    One instance with no attributes and a random label is repeated 1-3
    times (the list holds the same object each time). The resulting tree
    must be a leaf carrying that label.
    """
    fLabel = randbool()
    instOnly = dtree.Instance([], fLabel)
    cCopies = random.randint(1, 3)
    listInst = [instOnly] * cCopies
    # NOTE(review): the meaning of the 0.0 and -1 arguments is not visible
    # from this test -- see dtree.build_tree_rec.
    dtResult = dtree.build_tree_rec([], listInst, 0.0, -1)
    self.assertTrue(dtResult.is_leaf(), "dt was not a leaf")
    self.assertEqual(dtResult.fLabel, fLabel)
def dt_predict(data, m):
    """Classify an input, returning 1 or 0.

    data is wrapped in a dtree.Instance with a placeholder label of False
    and passed, together with m, to dtree.classify_boosted. A truthy result
    maps to 1, anything else to 0.
    """
    instQuery = dtree.Instance(data, False)
    return 1 if dtree.classify_boosted(m, instQuery) else 0
def dt_insts(y, x):
    """Make a list of dt instances out of targets and data.

    For every index i in range(len(y)) a dtree.Instance is built from the
    data row x[i] and the target y[i] converted to bool.

    Raises whatever x[i] raises (e.g. IndexError) if x has fewer entries
    than y.
    """
    # range() instead of the Python-2-only xrange() so this runs on both
    # Python 2 and 3. Both sequences are indexed rather than zipped so that
    # a too-short x still fails loudly instead of being silently truncated.
    return [dtree.Instance(x[ix], bool(y[ix])) for ix in range(len(y))]