def _testChain():
  """ Debugging driver: build an ID3 tree on a small 4-descriptor data set,
  prune it, and print the trees and their errors.

  Note: the "holdout" set used here is the very same list as the training
  set, so the holdout errors printed are errors on the training data.
  """
  from rdkit.ML.DecTree import ID3
  # 13 examples; the first four columns are binary descriptors and the last
  # column is the (binary) result
  oPts = [
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 1],
    [0, 1, 0, 1, 0],
    [0, 1, 0, 1, 0],
    [0, 1, 0, 0, 1],
  ]
  tPts = oPts
  nCols = len(oPts[0])
  tree = ID3.ID3Boot(oPts, attrs=range(nCols - 1), nPossibleVals=[2] * nCols)
  tree.Print()

  # error of the unpruned tree, first on the training data, then on the holdout data
  for label, examples in (('original error:', oPts), ('original holdout error:', tPts)):
    unprunedErr, _ = CrossValidate.CrossValidate(tree, examples)
    print(label, unprunedErr)

  prunedTree, _ = PruneTree(tree, oPts, tPts)
  prunedTree.Print()
  prunedErr, misclassified = CrossValidate.CrossValidate(prunedTree, tPts)
  print('pruned holdout error is:', prunedErr)
  print(misclassified)
def _testSpecific():
  """ Debugging driver: fit an ID3 tree to five examples, prune it against
  a holdout set that adds two conflicting examples, and print the trees,
  their errors and their sizes.
  """
  from rdkit.ML.DecTree import ID3
  # three binary descriptors followed by the binary result
  oPts = [
    [0, 0, 1, 0],
    [0, 1, 1, 1],
    [1, 0, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 1, 1],
  ]
  # the two extra holdout rows share descriptors with [0, 1, 1, 1] but
  # carry the opposite result
  extraRows = [[0, 1, 1, 0] for _ in range(2)]
  tPts = oPts + extraRows

  tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
  tree.Print()

  for label, examples in (('original error:', oPts), ('original holdout error:', tPts)):
    unprunedErr, _ = CrossValidate.CrossValidate(tree, examples)
    print(label, unprunedErr)

  prunedTree, _ = PruneTree(tree, oPts, tPts)
  prunedTree.Print()
  prunedErr, misclassified = CrossValidate.CrossValidate(prunedTree, tPts)
  print('pruned holdout error is:', prunedErr)
  print(misclassified)
  print(len(tree), len(prunedTree))
def PruneTree(tree, trainExamples, testExamples, minimizeTestErrorOnly=1):
  """ Reduced-error pruning of a decision tree.

  The algorithm is the one described on page 69 of Mitchell's book.
  By default candidate prunings are scored with the held-out examples
  only (_testExamples_, i.e. the validation set); passing a false value
  for minimizeTestErrorOnly scores them with
  _trainExamples_ + _testExamples_ instead.

  **Arguments**

    - tree: the initial tree to be pruned.  Any examples already stored
      on it are cleared, and the evaluation examples are screened through
      it and stored at its nodes.

    - trainExamples: the examples used to train the tree

    - testExamples: the examples held out for testing the tree

    - minimizeTestErrorOnly: when this toggle is zero, all examples
      (i.e. _trainExamples_ + _testExamples_) are used to evaluate the
      error.

  **Returns**

    a 2-tuple containing:

      1) the best tree (its bad examples are set to the examples it
         misclassifies on the evaluation set)

      2) the error of that tree on the evaluation set

  """
  evalSet = testExamples if minimizeTestErrorOnly else trainExamples + testExamples

  # start from a clean slate, then run the evaluation set through the tree
  # with appendExamples=1 so that each node ends up holding the points that
  # reach it; only this side effect is needed, so the results are discarded
  tree.ClearExamples()
  CrossValidate.CrossValidate(tree, evalSet, appendExamples=1)

  prunedTree = _Pruner(tree)

  # re-score the pruned tree and remember which examples it gets wrong
  prunedErr, misclassified = CrossValidate.CrossValidate(prunedTree, evalSet)
  prunedTree.SetBadExamples(misclassified)
  return prunedTree, prunedErr
def test1(self):
  """ testing pruning with known results """
  # three binary descriptors followed by the binary result; the tree
  # built from these is checked below to reproduce them with zero error
  oPts = [
    [0, 0, 1, 0],
    [0, 1, 1, 1],
    [1, 0, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 1, 1],
  ]
  # holdout set: the training points plus two copies of a point whose
  # descriptors match [0, 1, 1, 1] but whose result is the opposite
  tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
  tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
  err, badEx = CrossValidate.CrossValidate(tree, oPts)
  self.assertEqual(err, 0.0, 'bad initial error')
  self.assertEqual(len(badEx), 0, 'bad initial error')

  # prune with original data, shouldn't do anything; also check that the
  # verbose flag produces output mentioning 'Pruner' on stdout
  f = StringIO()
  PruneTree._verbose = True
  try:
    with redirect_stdout(f):
      newTree, err = PruneTree.PruneTree(tree, [], oPts)
  finally:
    # reset the module-level flag even if pruning raises, so that the
    # verbose setting does not leak into other tests
    PruneTree._verbose = False
  self.assertIn('Pruner', f.getvalue())
  self.assertEqual(newTree, tree, 'improper pruning')

  # prune with the extended holdout set: the tree must change, and the
  # expected error is 0.14286 (1 of the 7 holdout points)
  newTree, err = PruneTree.PruneTree(tree, [], tPts)
  self.assertNotEqual(newTree, tree, 'bad pruning')
  self.assertTrue(feq(err, 0.14286), 'bad error result')
def test1(self):
  """ testing pruning with known results """
  # three binary descriptors followed by the binary result
  oPts = [
    [0, 0, 1, 0],
    [0, 1, 1, 1],
    [1, 0, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 1, 1],
  ]
  # holdout set: the training points plus two copies of a point whose
  # descriptors match [0, 1, 1, 1] but whose result is the opposite
  tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
  tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
  # unittest assertion methods are used instead of bare asserts so the
  # checks still run under `python -O`
  err, badEx = CrossValidate.CrossValidate(tree, oPts)
  self.assertEqual(err, 0.0, 'bad initial error')
  self.assertEqual(len(badEx), 0, 'bad initial error')

  # prune with original data, shouldn't do anything
  newTree, err = PruneTree.PruneTree(tree, [], oPts)
  self.assertEqual(newTree, tree, 'improper pruning')

  # prune with the extended holdout set: the tree must change, and the
  # expected error is 0.14286 (1 of the 7 holdout points)
  newTree, err = PruneTree.PruneTree(tree, [], tPts)
  self.assertNotEqual(newTree, tree, 'bad pruning')
  self.assertTrue(feq(err, 0.14286), 'bad error result')