def test1(self):
    """Check pruning of a small ID3 tree against known results.

    A tree is built from five examples (``oPts``) and must classify them
    without error.  Pruning against those same examples must return an
    equal tree; pruning against ``tPts`` (``oPts`` plus two copies of
    ``[0, 1, 1, 0]``, which differs from the ``oPts`` row ``[0, 1, 1, 1]``
    only in its last column) must change the tree and report an error of
    0.14286 (about 1/7; ``tPts`` holds seven examples).
    """
    # testing pruning with known results
    oPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
    tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    self.assertEqual(err, 0.0, 'bad initial error')
    self.assertEqual(len(badEx), 0, 'bad initial error')

    # prune with original data, shouldn't do anything
    f = StringIO()
    with redirect_stdout(f):
        PruneTree._verbose = True
        try:
            newTree, err = PruneTree.PruneTree(tree, [], oPts)
        finally:
            # Always restore the module-level flag so a failure here does
            # not leave verbose output switched on for later tests.
            PruneTree._verbose = False
    self.assertIn('Pruner', f.getvalue())
    self.assertEqual(newTree, tree, 'improper pruning')

    # prune with train data
    newTree, err = PruneTree.PruneTree(tree, [], tPts)
    self.assertNotEqual(newTree, tree, 'bad pruning')
    self.assertTrue(feq(err, 0.14286), 'bad error result')
def test1(self):
    """Exercise tree pruning on a tiny data set with known outcomes."""
    clean_examples = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    # Two extra, separately-allocated rows are appended for the second
    # pruning pass.
    extended_examples = clean_examples + [[0, 1, 1, 0] for _ in range(2)]

    tree = ID3.ID3Boot(clean_examples, attrs=range(3), nPossibleVals=[2, 2, 2, 2])

    # The freshly built tree must reproduce the data it was built from.
    initial_error, misclassified = CrossValidate.CrossValidate(tree, clean_examples)
    assert initial_error == 0.0, 'bad initial error'
    assert len(misclassified) == 0, 'bad initial error'

    # Pruning against the original data must leave the tree as it was.
    unchanged_tree, err = PruneTree.PruneTree(tree, [], clean_examples)
    assert unchanged_tree == tree, 'improper pruning'

    # Pruning against the extended data must alter the tree and give the
    # known error value.
    pruned_tree, err = PruneTree.PruneTree(tree, [], extended_examples)
    assert pruned_tree != tree, 'bad pruning'
    assert feq(err, 0.14286), 'bad error result'
def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0, lessGreedy=0):
    """ Grows the forest by adding trees

    **Arguments**

      - examples: the examples to be used for training

      - attrs: a list of the attributes to be used in training

      - nPossibleVals: a list with the number of possible values each variable
        (as well as the result) can take on

      - nTries: the number of new trees to add

      - pruneIt: a toggle for whether or not the tree should be pruned

      - lessGreedy: toggles the use of a less greedy construction algorithm where
        each possible tree root is used.  The best tree from each step is actually
        added to the forest.

    **Notes**

      - nPossibleVals is stored on the forest as ``self._nPossible``

      - each new tree is added with ``self.AddTree(tree, frac)``, where frac is
        the second value returned by the cross-validation driver or, when
        pruning is on, by the pruner

      - progress is printed to stdout: a 'prune: ' line for every pruned tree
        and a 'Cycle: ' line once every ``max(nTries // 10, 1)`` iterations

    """
    self._nPossible = nPossibleVals
    # Use an integer reporting interval of at least 1.  The previous
    # ``nTries / 10`` is true division, so it produced a fractional modulus
    # for any nTries not divisible by 10 (and a zero divisor under integer
    # division when nTries < 10).
    reportInterval = max(nTries // 10, 1)
    for i in range(nTries):
        tree, frac = CrossValidate.CrossValidationDriver(
            examples, attrs, nPossibleVals, silent=1, calcTotalError=1,
            lessGreedy=lessGreedy)
        if pruneIt:
            tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                              tree.GetTestExamples(),
                                              minimizeTestErrorOnly=0)
            print('prune: ', frac, frac2)
            # the value from pruning replaces the cross-validation value
            frac = frac2
        self.AddTree(tree, frac)
        if i % reportInterval == 0:
            print('Cycle: % 4d' % (i))
def test_exampleCode(self):
    """Run the example routines in PruneTree and check what they print.

    ``_testRandom`` must leave a ``prune.pkl`` file behind; that file and
    ``orig.pkl`` are removed afterwards if present.  ``_testSpecific`` and
    ``_testChain`` must each print a 'pruned holdout error' line.
    """
    captured = StringIO()
    with redirect_stdout(captured):
        try:
            PruneTree._testRandom()
            self.assertTrue(os.path.isfile('prune.pkl'))
        finally:
            # Remove the pickle files whether or not the check passed.
            for leftover in ('orig.pkl', 'prune.pkl'):
                if os.path.isfile(leftover):
                    os.remove(leftover)
    self.assertIn('pruned error', captured.getvalue())

    # The remaining examples share one expected marker.  Each routine is
    # looked up just before it runs, preserving the original call order.
    for routine_name in ('_testSpecific', '_testChain'):
        captured = StringIO()
        with redirect_stdout(captured):
            getattr(PruneTree, routine_name)()
        self.assertIn('pruned holdout error', captured.getvalue())