def _testChain():
    """Grow, evaluate and prune an ID3 tree on a small fixed data set.

    The same 13 points are used both to build the tree and as the hold-out
    set handed to PruneTree.  The tree and the error figures are printed to
    stdout; nothing is returned.
    """
    from rdkit.ML.DecTree import ID3

    # Each row has five binary columns; the last column is excluded from
    # `attrs` below, i.e. it is the value being predicted.
    trainPts = (
        [[1, 0, 0, 0, 1] for _ in range(7)]
        + [[0, 0, 1, 1, 0] for _ in range(2)]
        + [[0, 0, 1, 1, 1]]
        + [[0, 1, 0, 1, 0] for _ in range(2)]
        + [[0, 1, 0, 0, 1]]
    )
    holdoutPts = trainPts
    nCols = len(trainPts[0])

    tree = ID3.ID3Boot(trainPts, attrs=range(nCols - 1),
                       nPossibleVals=[2] * nCols)
    tree.Print()

    trainErr, _ = CrossValidate.CrossValidate(tree, trainPts)
    print('original error:', trainErr)
    holdoutErr, _ = CrossValidate.CrossValidate(tree, holdoutPts)
    print('original holdout error:', holdoutErr)

    prunedTree, _ = PruneTree(tree, trainPts, holdoutPts)
    prunedTree.Print()
    prunedErr, misclassified = CrossValidate.CrossValidate(prunedTree,
                                                           holdoutPts)
    print('pruned holdout error is:', prunedErr)
    print(misclassified)
def _testSpecific():
    """Prune an ID3 tree using a hold-out set that extends the training set.

    The tree is built from five points; the hold-out set is those five
    points plus two copies of [0, 1, 1, 0].  Trees, errors and the sizes of
    the original and pruned trees are printed to stdout.
    """
    from rdkit.ML.DecTree import ID3

    trainPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    holdoutPts = trainPts + [[0, 1, 1, 0] for _ in range(2)]

    tree = ID3.ID3Boot(trainPts, attrs=range(3), nPossibleVals=[2] * 4)
    tree.Print()

    trainErr, _ = CrossValidate.CrossValidate(tree, trainPts)
    print('original error:', trainErr)
    holdoutErr, _ = CrossValidate.CrossValidate(tree, holdoutPts)
    print('original holdout error:', holdoutErr)

    prunedTree, _ = PruneTree(tree, trainPts, holdoutPts)
    prunedTree.Print()
    prunedErr, misclassified = CrossValidate.CrossValidate(prunedTree,
                                                           holdoutPts)
    print('pruned holdout error is:', prunedErr)
    print(misclassified)
    print(len(tree), len(prunedTree))
def PruneTree(tree, trainExamples, testExamples, minimizeTestErrorOnly=1):
    """ Reduced-error pruning of a decision tree

    The algorithm is the one described on page 69 of Mitchell's book.

    The error that drives the pruning is measured either on the validation
    set alone (_testExamples_) or, when minimizeTestErrorOnly is 0, on the
    training and validation examples together.

    **Arguments**

      - tree: the initial tree to be pruned.  Any examples stored on it
        are cleared before pruning starts.

      - trainExamples: the examples used to train the tree

      - testExamples: the examples held out for testing the tree

      - minimizeTestErrorOnly: if this toggle is zero, all examples
        (i.e. _trainExamples_ + _testExamples_) are used to evaluate
        the error.

    **Returns**

      a 2-tuple containing:

        1) the best tree

        2) the best error (the one which corresponds to that tree)

    """
    evalSet = testExamples if minimizeTestErrorOnly else trainExamples + testExamples

    # Start from a clean slate, then screen the evaluation data through the
    # tree so that every node ends up holding the points that reach it.
    # Only that side effect matters here; the error is recomputed below.
    tree.ClearExamples()
    CrossValidate.CrossValidate(tree, evalSet, appendExamples=1)

    prunedTree = _Pruner(tree)

    # Re-evaluate the pruned tree and record what it still gets wrong.
    prunedErr, misclassified = CrossValidate.CrossValidate(prunedTree, evalSet)
    prunedTree.SetBadExamples(misclassified)
    return prunedTree, prunedErr
def test1(self):
    # testing pruning with known results
    #
    # A tree is grown from five points and must classify them perfectly.
    # Pruning against those same points must leave the tree unchanged;
    # pruning against a set with two extra [0, 1, 1, 0] points must change
    # the tree and give an error of ~0.14286.
    oPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
    tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    self.assertEqual(err, 0.0, 'bad initial error')
    self.assertEqual(len(badEx), 0, 'bad initial error')

    # prune with original data, shouldn't do anything
    f = StringIO()
    PruneTree._verbose = True
    try:
        with redirect_stdout(f):
            newTree, err = PruneTree.PruneTree(tree, [], oPts)
    finally:
        # Always switch the module-level flag back off, even when pruning
        # raises, so a failure here cannot make later tests verbose.
        PruneTree._verbose = False
    self.assertIn('Pruner', f.getvalue())
    self.assertEqual(newTree, tree, 'improper pruning')

    # prune with the extended data set passed as the hold-out examples
    newTree, err = PruneTree.PruneTree(tree, [], tPts)
    self.assertNotEqual(newTree, tree, 'bad pruning')
    self.assertTrue(feq(err, 0.14286), 'bad error result')
def testRun(self):
    """ test that the CrossValidationDriver runs

    Smoke test only: a tree is built from 200 randomly generated examples
    and nothing about the result is asserted.
    """
    from rdkit.ML.DecTree import randomtest

    pts, attrList, nVals = randomtest.GenRandomExamples(nExamples=200)
    _tree, _frac = CrossValidate.CrossValidationDriver(pts, attrList, nVals,
                                                       silent=1)
def testReplacementSelection(self):
    # use selection with replacement
    #
    # Both random sources are seeded from the fixture, so the resulting
    # error fraction is compared against a fixed value.
    RDRandom.seed(self.randomSeed)
    pts, attrList, nVals = randomtest.GenRandomExamples(
        nExamples=200, seed=self.randomArraySeed)

    driverOptions = dict(silent=1, replacementSelection=1)
    builtTree, errFrac = CrossValidate.CrossValidationDriver(
        pts, attrList, nVals, **driverOptions)

    self.assertTrue(builtTree)
    self.assertAlmostEqual(errFrac, 0.01666, places=4)
def test_TestRun(self):
    # Runs CrossValidate.TestRun() with stdout captured, then checks that
    # it left 'save.pkl' behind and printed 't1 == t2 True'.  The pickle
    # file is removed afterwards whether or not the checks pass.
    pklName = 'save.pkl'
    try:
        captured = StringIO()
        with redirect_stdout(captured):
            CrossValidate.TestRun()
        self.assertTrue(os.path.isfile(pklName))
        output = captured.getvalue()
        self.assertIn('t1 == t2 True', output)
    finally:
        if os.path.isfile(pklName):
            os.remove(pklName)
def testRun(self):
    # test that the CrossValidationDriver runs
    #
    # First a non-silent run whose printed output is captured and checked,
    # then a silent run exercising the lessGreedy / calcTotalError options
    # (that second run is only required not to raise).
    pts, attrList, nVals = randomtest.GenRandomExamples(nExamples=200)

    captured = StringIO()
    with redirect_stdout(captured):
        builtTree, errFrac = CrossValidate.CrossValidationDriver(
            pts, attrList, nVals, silent=False)

    self.assertGreater(errFrac, 0)
    self.assertEqual('Var: 1', builtTree.GetName())
    self.assertIn('Validation error', captured.getvalue())

    CrossValidate.CrossValidationDriver(
        pts, attrList, nVals,
        lessGreedy=True, calcTotalError=True, silent=True)
def testReplacementSelection(self):
    """ use selection with replacement

    Both random sources are seeded from the fixture, so the error fraction
    returned by the driver is compared against a fixed reference value.
    """
    from rdkit.ML.DecTree import randomtest
    from rdkit import RDRandom

    RDRandom.seed(self.randomSeed)
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
        nExamples=200, seed=self.randomArraySeed)
    tree, frac = CrossValidate.CrossValidationDriver(
        examples, attrs, nPossibleVals, silent=1, replacementSelection=1)

    # unittest assertions rather than bare `assert`: they still run under
    # `python -O` and report the offending value when they fail.
    self.assertTrue(tree, 'no tree was returned')
    self.assertTrue(feq(frac, 0.0833),
                    'unexpected error fraction: %r' % (frac, ))
def _testRandom():
    """Build a tree from random data, prune it, and pickle both trees.

    The unpruned tree is written to 'orig.pkl' and the pruned one to
    'prune.pkl'; both trees and their errors are printed to stdout.
    """
    from rdkit.ML.DecTree import randomtest

    # NOTE: nVars=20, randScale=0.25 is an alternative parameter set that
    # was used here at some point.
    pts, attrList, nVals = randomtest.GenRandomExamples(nVars=10,
                                                        randScale=0.5,
                                                        nExamples=200)
    origTree, origErr = CrossValidate.CrossValidationDriver(pts, attrList,
                                                            nVals)
    origTree.Print()
    origTree.Pickle('orig.pkl')
    print('original error is:', origErr)

    print('----Pruning')
    trainPts = origTree.GetTrainingExamples()
    holdoutPts = origTree.GetTestExamples()
    prunedTree, prunedErr = PruneTree(origTree, trainPts, holdoutPts)
    prunedTree.Print()
    print('pruned error is:', prunedErr)
    prunedTree.Pickle('prune.pkl')
def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0,
         lessGreedy=0):
    """ Grows the forest by adding trees

    **Arguments**

      - examples: the examples to be used for training

      - attrs: a list of the attributes to be used in training

      - nPossibleVals: a list with the number of possible values each
        variable (as well as the result) can take on

      - nTries: the number of new trees to add

      - pruneIt: a toggle for whether or not the tree should be pruned

      - lessGreedy: toggles the use of a less greedy construction
        algorithm where each possible tree root is used.  The best tree
        from each step is actually added to the forest.

    **Notes**

      - progress ('Cycle: ...') and, when pruning, the before/after
        errors ('prune: ...') are printed to stdout.

    """
    self._nPossible = nPossibleVals

    # Report progress about every tenth of the run.  Integer division with
    # a floor of 1 avoids the float modulo that `nTries / 10` gives under
    # true division (which only hits 0 erratically) and the
    # ZeroDivisionError that integer division gives when nTries < 10.
    reportEvery = max(nTries // 10, 1)

    for i in range(nTries):
        tree, frac = CrossValidate.CrossValidationDriver(
            examples, attrs, nPossibleVals, silent=1, calcTotalError=1,
            lessGreedy=lessGreedy)
        if pruneIt:
            # Prune against training + test examples together and keep the
            # pruned tree's error as the one stored with the tree.
            tree, frac2 = PruneTree.PruneTree(tree,
                                              tree.GetTrainingExamples(),
                                              tree.GetTestExamples(),
                                              minimizeTestErrorOnly=0)
            print('prune: ', frac, frac2)
            frac = frac2
        self.AddTree(tree, frac)
        if i % reportEvery == 0:
            print('Cycle: % 4d' % (i))
def testResults(self):
    # test the results of CrossValidation
    #
    # With both random sources seeded from the fixture, the tree built by
    # the driver must equal the reference tree pickled in
    # self.origTreeName.
    RDRandom.seed(self.randomSeed)
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
        nExamples=200, seed=self.randomArraySeed)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs,
                                                     nPossibleVals, silent=1)
    self.assertGreater(frac, 0)

    # The reference pickle is read as text so that CRLF line endings can be
    # normalised before unpickling.  The `with` block closes the file; no
    # explicit close() is needed.
    with open(self.origTreeName, 'r') as inTFile:
        buf = inTFile.read().replace('\r\n', '\n').encode('utf-8')
    with BytesIO(buf) as inFile:
        oTree = pickle.load(inFile)

    # unittest assertion rather than bare `assert`, so the check survives
    # `python -O`.
    self.assertEqual(oTree, tree, 'Random CrossValidation test failed')
def testResults(self):
    """ test the results of CrossValidation

    With both random sources seeded from the fixture, the tree built by the
    driver must equal the reference tree pickled in self.origTreeName.
    """
    from rdkit.ML.DecTree import randomtest
    from rdkit import RDRandom

    RDRandom.seed(self.randomSeed)
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
        nExamples=200, seed=self.randomArraySeed)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs,
                                                     nPossibleVals, silent=1)

    import cPickle
    # To regenerate the reference tree:
    #   cPickle.dump(tree, file(self.origTreeName, 'w+'))

    # Use a context manager so the file handle is closed even when
    # unpickling raises (it used to be left open).
    with open(self.origTreeName, 'r') as inFile:
        oTree = cPickle.load(inFile)

    # unittest assertion rather than bare `assert`, so the check survives
    # `python -O`.
    self.assertEqual(oTree, tree, 'Random CrossValidation test failed')
def testResults(self):
    """ test the results of CrossValidation

    With both random sources seeded from the fixture, the tree built by the
    driver must equal the reference tree pickled in self.origTreeName.
    """
    from rdkit.ML.DecTree import randomtest
    from rdkit import RDRandom

    RDRandom.seed(self.randomSeed)
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
        nExamples=200, seed=self.randomArraySeed)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs,
                                                     nPossibleVals, silent=1)

    from rdkit.six.moves import cPickle
    # To regenerate the reference tree:
    #   cPickle.dump(tree, open(self.origTreeName, 'w+'))

    # The reference pickle is read as text so that CRLF line endings can be
    # normalised before unpickling.  The `with` block closes the file; no
    # explicit close() is needed.
    with open(self.origTreeName, 'r') as inTFile:
        buf = inTFile.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as inFile:
        oTree = cPickle.load(inFile)

    # unittest assertion rather than bare `assert`, so the check survives
    # `python -O`.
    self.assertEqual(oTree, tree, 'Random CrossValidation test failed')
def test1(self):
    """ testing pruning with known results

    A tree is grown from five points and must classify them perfectly.
    Pruning against those same points must leave the tree unchanged;
    pruning against a set with two extra [0, 1, 1, 0] points must change
    the tree and give an error of ~0.14286.
    """
    oPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
    tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)

    # unittest assertions are used throughout instead of bare `assert`, so
    # the checks survive `python -O` and report the values involved.
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    self.assertEqual(err, 0.0, 'bad initial error')
    self.assertEqual(len(badEx), 0, 'bad initial error')

    # prune with original data, shouldn't do anything
    newTree, err = PruneTree.PruneTree(tree, [], oPts)
    self.assertEqual(newTree, tree, 'improper pruning')

    # prune with the extended data set passed as the hold-out examples
    newTree, err = PruneTree.PruneTree(tree, [], tPts)
    self.assertNotEqual(newTree, tree, 'bad pruning')
    self.assertTrue(feq(err, 0.14286), 'bad error result')