Example #1
0
    def test1(self):
        # " testing pruning with known results "
        # Each point has four entries; the tree is built on attribute indices
        # 0-2 (presumably the last entry is the result -- confirm against ID3).
        oPts = [
            [0, 0, 1, 0],
            [0, 1, 1, 1],
            [1, 0, 1, 1],
            [1, 1, 0, 0],
            [1, 1, 1, 1],
        ]
        # Two extra points that share their first three entries with
        # [0, 1, 1, 1] above but differ in the last one.
        tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
        tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
        err, badEx = CrossValidate.CrossValidate(tree, oPts)
        self.assertEqual(err, 0.0, 'bad initial error')
        self.assertEqual(len(badEx), 0, 'bad initial error')

        # prune with original data, shouldn't do anything
        f = StringIO()
        with redirect_stdout(f):
            PruneTree._verbose = True
            try:
                newTree, err = PruneTree.PruneTree(tree, [], oPts)
            finally:
                # _verbose is module-level state: reset it even when pruning
                # raises so the flag cannot leak into other tests.
                PruneTree._verbose = False
        self.assertIn('Pruner', f.getvalue())
        self.assertEqual(newTree, tree, 'improper pruning')

        # prune with train data
        newTree, err = PruneTree.PruneTree(tree, [], tPts)
        self.assertNotEqual(newTree, tree, 'bad pruning')
        # 0.14286 is 1/7 to five places; tPts holds seven points.
        self.assertTrue(feq(err, 0.14286), 'bad error result')
Example #2
0
  def test1(self):
    """Check tree pruning against known results."""
    cleanPts = [[0, 0, 1, 0], [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 0], [1, 1, 1, 1]]
    # Same five points plus two copies of a point whose first three entries
    # match [0, 1, 1, 1] but whose last entry differs.
    noisyPts = cleanPts + [[0, 1, 1, 0] for _ in range(2)]

    tree = ID3.ID3Boot(cleanPts, attrs=range(3), nPossibleVals=[2, 2, 2, 2])

    # The freshly built tree must reproduce the data it was built from.
    initialErr, misses = CrossValidate.CrossValidate(tree, cleanPts)
    assert initialErr == 0.0, 'bad initial error'
    assert len(misses) == 0, 'bad initial error'

    # Pruning against the original data should leave the tree untouched.
    unchanged, _ = PruneTree.PruneTree(tree, [], cleanPts)
    assert unchanged == tree, 'improper pruning'

    # Pruning against the extended data must change the tree.
    pruned, prunedErr = PruneTree.PruneTree(tree, [], noisyPts)
    assert pruned != tree, 'bad pruning'
    assert feq(prunedErr, 0.14286), 'bad error result'
Example #3
0
    def Grow(self,
             examples,
             attrs,
             nPossibleVals,
             nTries=10,
             pruneIt=0,
             lessGreedy=0):
        """ Grows the forest by adding trees

        **Arguments**

          - examples: the examples to be used for training

          - attrs: a list of the attributes to be used in training

          - nPossibleVals: a list with the number of possible values each
            variable (as well as the result) can take on

          - nTries: the number of new trees to add

          - pruneIt: a toggle for whether or not the tree should be pruned

          - lessGreedy: toggles the use of a less greedy construction algorithm
            where each possible tree root is used.  The best tree from each
            step is actually added to the forest.

        **Notes**

          - progress is printed to stdout about ten times per call (every
            cycle when nTries is below 10); when pruning is on, the values
            returned before and after pruning are printed as well.

        """
        self._nPossible = nPossibleVals
        # Integer step between progress reports.  True division produced a
        # float divisor here (e.g. 2.5 for nTries=25), which made the modulo
        # test fire at the wrong cycles; clamping to 1 also keeps the divisor
        # non-zero when nTries < 10.
        reportEvery = max(nTries // 10, 1)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(
                examples,
                attrs,
                nPossibleVals,
                silent=1,
                calcTotalError=1,
                lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree,
                                                  tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                print('prune: ', frac, frac2)
                # the pruned tree is stored with the value reported by the pruner
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))
Example #4
0
    def test_exampleCode(self):
        """Run the PruneTree demo routines and check what they print."""
        captured = StringIO()
        with redirect_stdout(captured):
            try:
                PruneTree._testRandom()
                self.assertTrue(os.path.isfile('prune.pkl'))
            finally:
                # Remove any pickle files left behind in the working directory,
                # whether or not the checks above passed.
                for leftover in ('orig.pkl', 'prune.pkl'):
                    if os.path.isfile(leftover):
                        os.remove(leftover)
        self.assertIn('pruned error', captured.getvalue())

        # The two remaining demos are checked the same way: run with stdout
        # captured, then look for the holdout-error line in the output.
        for demoName in ('_testSpecific', '_testChain'):
            captured = StringIO()
            with redirect_stdout(captured):
                getattr(PruneTree, demoName)()
            self.assertIn('pruned holdout error', captured.getvalue())
Example #5
0
    def test_exampleCode(self):
        """Exercise PruneTree's built-in example code and verify its output."""
        pickleNames = ('orig.pkl', 'prune.pkl')

        # _testRandom: must leave prune.pkl behind and print the pruned error.
        randomOut = StringIO()
        with redirect_stdout(randomOut):
            try:
                PruneTree._testRandom()
                self.assertTrue(os.path.isfile(pickleNames[1]))
            finally:
                # filter() is lazy, so each file is checked just before removal.
                for stale in filter(os.path.isfile, pickleNames):
                    os.remove(stale)
        self.assertIn('pruned error', randomOut.getvalue())

        # _testSpecific: must print the pruned holdout error.
        specificOut = StringIO()
        with redirect_stdout(specificOut):
            PruneTree._testSpecific()
        specificText = specificOut.getvalue()
        self.assertIn('pruned holdout error', specificText)

        # _testChain: must print the pruned holdout error.
        chainOut = StringIO()
        with redirect_stdout(chainOut):
            PruneTree._testChain()
        chainText = chainOut.getvalue()
        self.assertIn('pruned holdout error', chainText)