예제 #1
0
def _testChain():
    """Smoke-test PruneTree on a small hand-built dataset, printing each stage.

    The same list of points is used to build the tree and as the holdout set,
    so the printed "holdout" errors are measured on the training data itself.
    Prints the unpruned tree, its errors, the pruned tree, the pruned error and
    the examples the pruned tree misclassifies.
    """
    from rdkit.ML.DecTree import ID3

    # Each row: four binary descriptors followed by the binary class label.
    # The first seven rows are identical (but distinct list objects).
    trainingPoints = [[1, 0, 0, 0, 1] for _ in range(7)] + [
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    holdout = trainingPoints

    nColumns = len(trainingPoints[0])
    tree = ID3.ID3Boot(trainingPoints,
                       attrs=range(nColumns - 1),
                       nPossibleVals=[2] * nColumns)
    tree.Print()

    # Error of the unpruned tree, first on the training points, then on the
    # holdout points.
    for label, points in (('original error:', trainingPoints),
                          ('original holdout error:', holdout)):
        unprunedErr, _ = CrossValidate.CrossValidate(tree, points)
        print(label, unprunedErr)

    pruned, _ = PruneTree(tree, trainingPoints, holdout)
    pruned.Print()
    prunedErr, misclassified = CrossValidate.CrossValidate(pruned, holdout)
    print('pruned holdout error is:', prunedErr)
    print(misclassified)
예제 #2
0
def _testSpecific():
    """Smoke-test PruneTree on a tiny dataset with an enlarged holdout set.

    The holdout set is the training set plus two extra copies of the point
    [0, 1, 1, 0].  Prints the tree before and after pruning, the errors at
    each stage, the misclassified holdout examples, and finally the sizes of
    the original and pruned trees.
    """
    from rdkit.ML.DecTree import ID3

    # Each row: three binary descriptors followed by the binary class label.
    trainingPoints = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    holdout = trainingPoints + [[0, 1, 1, 0] for _ in range(2)]

    tree = ID3.ID3Boot(trainingPoints, attrs=range(3), nPossibleVals=[2] * 4)
    tree.Print()

    # Error of the unpruned tree on the training points, then on the holdout.
    for label, points in (('original error:', trainingPoints),
                          ('original holdout error:', holdout)):
        unprunedErr, _ = CrossValidate.CrossValidate(tree, points)
        print(label, unprunedErr)

    pruned, _ = PruneTree(tree, trainingPoints, holdout)
    pruned.Print()
    prunedErr, misclassified = CrossValidate.CrossValidate(pruned, holdout)
    print('pruned holdout error is:', prunedErr)
    print(misclassified)

    print(len(tree), len(pruned))
예제 #3
0
def PruneTree(tree, trainExamples, testExamples, minimizeTestErrorOnly=1):
    """ implements a reduced-error pruning of decision trees

    This algorithm is described on page 69 of Mitchell's book.

    The error that pruning tries to reduce is measured either on the
    testExamples alone (the validation set) or, when minimizeTestErrorOnly
    is false, on trainExamples + testExamples.

    **Arguments**

      - tree: the initial tree to be pruned.  As a side effect, any examples
        already stored on it are cleared, and the evaluation examples are
        screened into it (CrossValidate with appendExamples=1).

      - trainExamples: the examples used to train the tree

      - testExamples: the examples held out for testing the tree

      - minimizeTestErrorOnly: if this toggle is false (e.g. 0), all
        examples (i.e. trainExamples + testExamples) are used to evaluate
        the error.

    **Returns**

      a 2-tuple containing:

        1) the pruned tree; its bad examples are set via SetBadExamples

        2) the error of that tree on the evaluation examples

    """
    evalSet = (testExamples if minimizeTestErrorOnly
               else trainExamples + testExamples)

    # Start from a clean slate, then push the evaluation data through the tree
    # so that each node ends up storing the points that reach it before
    # pruning.  The errors from this pass are not used; the call matters only
    # for its side effect on the tree.
    tree.ClearExamples()
    _screenErr, _screenBad = CrossValidate.CrossValidate(tree,
                                                         evalSet,
                                                         appendExamples=1)

    pruned = _Pruner(tree)

    # Re-score the pruned tree and record which examples it gets wrong.
    prunedErr, misclassified = CrossValidate.CrossValidate(pruned, evalSet)
    pruned.SetBadExamples(misclassified)

    return pruned, prunedErr
예제 #4
0
    def test1(self):
        # " testing pruning with known results "
        # Each row: three binary descriptors followed by the class label.
        oPts = [
            [0, 0, 1, 0],
            [0, 1, 1, 1],
            [1, 0, 1, 1],
            [1, 1, 0, 0],
            [1, 1, 1, 1],
        ]
        # Holdout set: the training points plus two copies of [0, 1, 1, 0];
        # the assertions below expect these extra points to change the tree.
        tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
        tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
        err, badEx = CrossValidate.CrossValidate(tree, oPts)
        self.assertEqual(err, 0.0, 'bad initial error')
        self.assertEqual(len(badEx), 0, 'bad initial error')

        # prune with original data, shouldn't do anything
        f = StringIO()
        # _verbose is module-global state: restore it even if pruning raises,
        # so a failure here cannot make later tests noisy.
        oldVerbose = getattr(PruneTree, '_verbose', False)
        PruneTree._verbose = True
        try:
            with redirect_stdout(f):
                newTree, err = PruneTree.PruneTree(tree, [], oPts)
        finally:
            PruneTree._verbose = oldVerbose
        self.assertIn('Pruner', f.getvalue())
        self.assertEqual(newTree, tree, 'improper pruning')

        # prune with train data
        newTree, err = PruneTree.PruneTree(tree, [], tPts)
        self.assertNotEqual(newTree, tree, 'bad pruning')
        self.assertTrue(feq(err, 0.14286), 'bad error result')
예제 #5
0
  def test1(self):
    """ testing pruning with known results """
    # Each row: three binary descriptors followed by the class label.
    trainPoints = [
      [0, 0, 1, 0],
      [0, 1, 1, 1],
      [1, 0, 1, 1],
      [1, 1, 0, 0],
      [1, 1, 1, 1],
    ]
    holdoutPoints = trainPoints + [[0, 1, 1, 0] for _ in range(2)]
    tree = ID3.ID3Boot(trainPoints, attrs=range(3), nPossibleVals=[2] * 4)

    # the freshly built tree must classify its own training data perfectly
    initialErr, initialBad = CrossValidate.CrossValidate(tree, trainPoints)
    assert initialErr == 0.0, 'bad initial error'
    assert len(initialBad) == 0, 'bad initial error'

    # pruning against the training data itself must leave the tree unchanged
    unchanged, _ = PruneTree.PruneTree(tree, [], trainPoints)
    assert unchanged == tree, 'improper pruning'

    # the extra holdout points are expected to change the tree
    pruned, prunedErr = PruneTree.PruneTree(tree, [], holdoutPoints)
    assert pruned != tree, 'bad pruning'
    assert feq(prunedErr, 0.14286), 'bad error result'