Example #1
0
def _testChain():
    """ exercise PruneTree on a small two-class data set

    An ID3 tree is built from the training points, its error is printed for
    the training and holdout points (here the very same list), then the tree
    is pruned with PruneTree and the pruned tree's holdout error and bad
    examples are printed.
    """
    from rdkit.ML.DecTree import ID3
    # seven copies of the dominant pattern, each a distinct list object
    trainPts = [[1, 0, 0, 0, 1] for _ in range(7)]
    trainPts.extend([
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ])
    holdoutPts = trainPts
    nCols = len(trainPts[0])

    # the last column is the result, so it is excluded from the attributes
    tree = ID3.ID3Boot(trainPts,
                       attrs=range(nCols - 1),
                       nPossibleVals=[2] * nCols)
    tree.Print()

    trainErr, _ = CrossValidate.CrossValidate(tree, trainPts)
    print('original error:', trainErr)

    holdoutErr, _ = CrossValidate.CrossValidate(tree, holdoutPts)
    print('original holdout error:', holdoutErr)

    prunedTree, _ = PruneTree(tree, trainPts, holdoutPts)
    prunedTree.Print()
    prunedErr, prunedBad = CrossValidate.CrossValidate(prunedTree, holdoutPts)
    print('pruned holdout error is:', prunedErr)
    print(prunedBad)
Example #2
0
def _testSpecific():
    """ prune a tiny hand-made tree and report the effect

    A tree over three binary attributes is trained on five points. The
    holdout set adds two copies of [0, 1, 1, 0], which disagrees with the
    training point [0, 1, 1, 1]. Errors before and after pruning are
    printed, followed by the sizes of the original and pruned trees.
    """
    from rdkit.ML.DecTree import ID3
    trainPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    holdoutPts = trainPts + [[0, 1, 1, 0], [0, 1, 1, 0]]

    tree = ID3.ID3Boot(trainPts, attrs=range(3), nPossibleVals=[2] * 4)
    tree.Print()
    for label, pts in (('original error:', trainPts),
                       ('original holdout error:', holdoutPts)):
        err, _ = CrossValidate.CrossValidate(tree, pts)
        print(label, err)

    prunedTree, _ = PruneTree(tree, trainPts, holdoutPts)
    prunedTree.Print()
    prunedErr, prunedBad = CrossValidate.CrossValidate(prunedTree, holdoutPts)
    print('pruned holdout error is:', prunedErr)
    print(prunedBad)

    print(len(tree), len(prunedTree))
Example #3
0
def PruneTree(tree, trainExamples, testExamples, minimizeTestErrorOnly=1):
    """ implements a reduced-error pruning of decision trees

    This algorithm is described on page 69 of Mitchell's book.

    Pruning can be done using just the set of testExamples (the validation
    set) or both the testExamples and the trainExamples by setting
    minimizeTestErrorOnly to 0.

    **Arguments**

      - tree: the initial tree to be pruned.  Any examples already stored
        on it are cleared before the pruning data is screened through it.

      - trainExamples: the examples used to train the tree

      - testExamples: the examples held out for testing the tree

      - minimizeTestErrorOnly: if this toggle is zero, all examples (i.e.
        _trainExamples_ + _testExamples_) will be used to evaluate the error.

    **Returns**

      a 2-tuple containing:

         1) the best tree

         2) the best error (the one which corresponds to that tree)

    """
    if minimizeTestErrorOnly:
        testSet = testExamples
    else:
        # concatenate as plain lists so that mixed sequence types (e.g. a
        # tuple and a list) do not raise, and array-like containers are
        # joined rather than added element-wise
        testSet = list(trainExamples) + list(testExamples)

    # remove any stored examples the tree may have
    tree.ClearExamples()

    #
    # screen the test data through the tree so that we end up with the
    #  appropriate points stored at each node of the tree; only that side
    #  effect is needed here, the returned error values are discarded
    #
    CrossValidate.CrossValidate(tree, testSet, appendExamples=1)

    #
    # Prune
    #
    newTree = _Pruner(tree)

    #
    # And recalculate the errors for the pruned tree
    #
    totErr, badEx = CrossValidate.CrossValidate(newTree, testSet)
    newTree.SetBadExamples(badEx)

    return newTree, totErr
Example #4
0
    def test1(self):
        # " testing pruning with known results "
        oPts = [
            [0, 0, 1, 0],
            [0, 1, 1, 1],
            [1, 0, 1, 1],
            [1, 1, 0, 0],
            [1, 1, 1, 1],
        ]
        # the holdout set adds two copies of a point that disagrees with the
        # training point [0, 1, 1, 1]
        tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
        tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
        err, badEx = CrossValidate.CrossValidate(tree, oPts)
        assert err == 0.0, 'bad initial error'
        assert len(badEx) == 0, 'bad initial error'

        # prune with original data, shouldn't do anything
        f = StringIO()
        with redirect_stdout(f):
            PruneTree._verbose = True
            try:
                newTree, err = PruneTree.PruneTree(tree, [], oPts)
            finally:
                # always reset the module-level flag so a failure here cannot
                # leave verbose output switched on for subsequent tests
                PruneTree._verbose = False
        self.assertIn('Pruner', f.getvalue())
        assert newTree == tree, 'improper pruning'

        # prune with train data
        newTree, err = PruneTree.PruneTree(tree, [], tPts)
        assert newTree != tree, 'bad pruning'
        assert feq(err, 0.14286), 'bad error result'
Example #5
0
 def testRun(self):
     " test that the CrossValidationDriver runs "
     from rdkit.ML.DecTree import randomtest
     examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
         nExamples=200)
     tree, frac = CrossValidate.CrossValidationDriver(examples,
                                                      attrs,
                                                      nPossibleVals,
                                                      silent=1)
     # beyond "does not raise": the driver must hand back a tree and a
     # non-negative error fraction
     self.assertIsNotNone(tree)
     self.assertGreaterEqual(frac, 0)
Example #6
0
 def testReplacementSelection(self):
     # " use selection with replacement "
     # both random sources are seeded so the expected error is reproducible
     RDRandom.seed(self.randomSeed)
     genKwargs = dict(nExamples=200, seed=self.randomArraySeed)
     examples, attrs, nPossibleVals = randomtest.GenRandomExamples(**genKwargs)
     driverKwargs = dict(silent=1, replacementSelection=1)
     tree, frac = CrossValidate.CrossValidationDriver(
         examples, attrs, nPossibleVals, **driverKwargs)
     self.assertTrue(tree)
     self.assertAlmostEqual(frac, 0.01666, 4)
Example #7
0
 def test_TestRun(self):
     # CrossValidate.TestRun() is expected to leave 'save.pkl' in the current
     # directory.  Run it inside a scratch directory so that a pre-existing
     # file of that name can neither satisfy the assertion by accident nor be
     # deleted by the cleanup.
     import tempfile
     origDir = os.getcwd()
     with tempfile.TemporaryDirectory() as scratchDir:
         os.chdir(scratchDir)
         try:
             f = StringIO()
             with redirect_stdout(f):
                 CrossValidate.TestRun()
             self.assertTrue(os.path.isfile('save.pkl'))
             s = f.getvalue()
             self.assertIn('t1 == t2 True', s)
         finally:
             # step out before TemporaryDirectory removes the directory
             os.chdir(origDir)
Example #8
0
    def testRun(self):
        # " test that the CrossValidationDriver runs "
        examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
            nExamples=200)
        driverArgs = (examples, attrs, nPossibleVals)

        # verbose run: capture what the driver prints and check the result
        captured = StringIO()
        with redirect_stdout(captured):
            tree, frac = CrossValidate.CrossValidationDriver(*driverArgs,
                                                             silent=False)
        self.assertGreater(frac, 0)
        self.assertEqual('Var: 1', tree.GetName())
        self.assertIn('Validation error', captured.getvalue())

        # quiet run that only exercises the lessGreedy / calcTotalError paths
        CrossValidate.CrossValidationDriver(*driverArgs,
                                            lessGreedy=True,
                                            calcTotalError=True,
                                            silent=True)
Example #9
0
 def testReplacementSelection(self):
     " use selection with replacement "
     from rdkit.ML.DecTree import randomtest
     from rdkit import RDRandom
     # seed both random sources so the expected error fraction is reproducible
     RDRandom.seed(self.randomSeed)
     examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
         nExamples=200, seed=self.randomArraySeed)
     result = CrossValidate.CrossValidationDriver(
         examples, attrs, nPossibleVals, replacementSelection=1, silent=1)
     tree, frac = result
     assert tree
     assert feq(frac, 0.0833)
Example #10
0
def _testRandom():
    """ grow a tree on random data, prune it, and pickle both versions

    The unpruned tree is pickled to 'orig.pkl' and the pruned one to
    'prune.pkl'; both trees and their errors are printed along the way.
    """
    from rdkit.ML.DecTree import randomtest
    # alternative settings that were also tried:
    #   nVars=20, randScale=0.25, nExamples=200
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
        nVars=10, randScale=0.5, nExamples=200)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals)
    tree.Print()
    tree.Pickle('orig.pkl')
    print('original error is:', frac)

    print('----Pruning')
    trainExamples = tree.GetTrainingExamples()
    testExamples = tree.GetTestExamples()
    prunedTree, prunedFrac = PruneTree(tree, trainExamples, testExamples)
    prunedTree.Print()
    print('pruned error is:', prunedFrac)
    prunedTree.Pickle('prune.pkl')
Example #11
0
    def Grow(self,
             examples,
             attrs,
             nPossibleVals,
             nTries=10,
             pruneIt=0,
             lessGreedy=0):
        """ Grows the forest by adding trees

        Each new tree is built by CrossValidate.CrossValidationDriver (with
        calcTotalError enabled). It is added to the forest, together with its
        error fraction, via self.AddTree.

        **Arguments**

          - examples: the examples to be used for training

          - attrs: a list of the attributes to be used in training

          - nPossibleVals: a list with the number of possible values each
            variable (as well as the result) can take on

          - nTries: the number of new trees to add

          - pruneIt: a toggle for whether or not the tree should be pruned.
            Pruning uses both the tree's training and test examples
            (minimizeTestErrorOnly=0), and the pruned error replaces the
            original one.

          - lessGreedy: toggles the use of a less greedy construction
            algorithm where each possible tree root is used.  The best tree
            from each step is actually added to the forest.

        A 'Cycle:' progress line is printed roughly every tenth of nTries
        (on every iteration when nTries < 10).
        """
        self._nPossible = nPossibleVals
        # use an integer step: a float step such as 7 / 10 == 0.7 makes
        # `i % step == 0` fire irregularly; clamp to 1 so that nTries < 10
        # cannot cause a modulo by zero
        progressStep = max(1, nTries // 10)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(
                examples,
                attrs,
                nPossibleVals,
                silent=1,
                calcTotalError=1,
                lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree,
                                                  tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                print('prune: ', frac, frac2)
                frac = frac2
            self.AddTree(tree, frac)
            if i % progressStep == 0:
                print('Cycle: % 4d' % (i))
Example #12
0
    def testResults(self):
        # " test the results of CrossValidation "
        RDRandom.seed(self.randomSeed)
        examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
            nExamples=200, seed=self.randomArraySeed)
        tree, frac = CrossValidate.CrossValidationDriver(examples,
                                                         attrs,
                                                         nPossibleVals,
                                                         silent=1)
        self.assertGreater(frac, 0)

        # the reference pickle is read in text mode with Windows line endings
        # normalized, then re-encoded to bytes for pickle; the `with` blocks
        # close both the file and the buffer
        with open(self.origTreeName, 'r') as inTFile:
            buf = inTFile.read().replace('\r\n', '\n').encode('utf-8')
        with BytesIO(buf) as inFile:
            oTree = pickle.load(inFile)

        assert oTree == tree, 'Random CrossValidation test failed'
Example #13
0
    def testResults(self):
        " test the results of CrossValidation "
        from rdkit.ML.DecTree import randomtest
        from rdkit import RDRandom
        RDRandom.seed(self.randomSeed)
        examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
            nExamples=200, seed=self.randomArraySeed)
        tree, frac = CrossValidate.CrossValidationDriver(examples,
                                                         attrs,
                                                         nPossibleVals,
                                                         silent=1)

        import io
        try:
            import cPickle as pickle  # Python 2
        except ImportError:
            import pickle  # Python 3: cPickle no longer exists
        # to regenerate the reference file, pickle.dump `tree` into
        # self.origTreeName
        # read the reference as text (normalizing Windows line endings) and
        # hand pickle a bytes buffer, which Python 3's pickle.load requires;
        # the `with` blocks make sure the file handle is closed
        with open(self.origTreeName, 'r') as inTFile:
            buf = inTFile.read().replace('\r\n', '\n').encode('utf-8')
        with io.BytesIO(buf) as inFile:
            oTree = pickle.load(inFile)

        assert oTree == tree, 'Random CrossValidation test failed'
Example #14
0
  def testResults(self):
    " test the results of CrossValidation "
    from rdkit.ML.DecTree import randomtest
    from rdkit import RDRandom
    RDRandom.seed(self.randomSeed)
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
      nExamples=200, seed=self.randomArraySeed)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs,
                                                     nPossibleVals, silent=1)

    from rdkit.six.moves import cPickle
    # to regenerate the reference file, cPickle.dump `tree` into
    # self.origTreeName
    # read the reference in text mode, normalize Windows line endings and
    # re-encode to bytes for cPickle; the `with` statements close both the
    # file and the buffer, so no explicit close() is needed
    with open(self.origTreeName, 'r') as inTFile:
      buf = inTFile.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as inFile:
      oTree = cPickle.load(inFile)

    assert oTree == tree, 'Random CrossValidation test failed'
Example #15
0
  def test1(self):
    " testing pruning with known results "
    trainPts = [
      [0, 0, 1, 0],
      [0, 1, 1, 1],
      [1, 0, 1, 1],
      [1, 1, 0, 0],
      [1, 1, 1, 1],
    ]
    # the holdout set repeats a point that disagrees with [0, 1, 1, 1]
    holdoutPts = trainPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
    tree = ID3.ID3Boot(trainPts, attrs=range(3), nPossibleVals=[2] * 4)
    initialErr, initialBad = CrossValidate.CrossValidate(tree, trainPts)
    assert initialErr == 0.0, 'bad initial error'
    assert len(initialBad) == 0, 'bad initial error'

    # pruning against the training data alone should leave the tree unchanged
    unchangedTree, _ = PruneTree.PruneTree(tree, [], trainPts)
    assert unchangedTree == tree, 'improper pruning'

    # pruning against the holdout data is expected to alter the tree
    prunedTree, prunedErr = PruneTree.PruneTree(tree, [], holdoutPts)
    assert prunedTree != tree, 'bad pruning'
    assert feq(prunedErr, 0.14286), 'bad error result'