# Example #1 (file: ID3.py, project: Kaziaa/rdkit-1)
def ID3Boot(examples,
            attrs,
            nPossibleVals,
            initialVar=None,
            depth=0,
            maxDepth=-1,
            **kwargs):
    """ Bootstrapping code for the ID3 algorithm

    see ID3 for descriptions of the arguments

    If _initialVar_ is not set, the algorithm will automatically
     choose the first variable in the tree (the standard greedy
     approach).  Otherwise, _initialVar_ will be used as the first
     split.

  """
    totEntropy = CalcTotalEntropy(examples, nPossibleVals)

    tree = DecTree.DecTreeNode(None, 'node')
    # the last entry of nPossibleVals is the number of possible result codes
    tree._nResultCodes = nPossibleVals[-1]

    if initialVar is None:
        # Greedy choice: pick the attribute with the largest information gain.
        # The variable table is only needed for this selection, so we skip
        # building it entirely when the caller supplied the first split.
        varTable = GenVarTable(examples, nPossibleVals, attrs)
        best = attrs[numpy.argmax([entropy.InfoGain(x) for x in varTable])]
    else:
        best = initialVar

    tree.SetName('Var: %d' % best)
    tree.SetData(totEntropy)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    # grow one subtree per possible value of the chosen variable
    for val in range(nPossibleVals[best]):
        nextExamples = [example for example in examples if example[best] == val]
        tree.AddChildNode(
            ID3(nextExamples, best, nextAttrs, nPossibleVals, depth, maxDepth,
                **kwargs))
    return tree
# Example #2 (file: ID3.py, project: Kaziaa/rdkit-1)
def ID3(examples,
        target,
        attrs,
        nPossibleVals,
        depth=0,
        maxDepth=-1,
        **kwargs):
    """ Implements the ID3 algorithm for constructing decision trees.

    From Mitchell's book, page 56

    This is *slightly* modified from Mitchell's book because it supports
      multivalued (non-binary) results.

    **Arguments**

      - examples: a list (nInstances long) of lists of variable values + instance
              values

      - target: an int

      - attrs: a list of ints indicating which variables can be used in the tree

      - nPossibleVals: a list containing the number of possible values of
                   every variable.

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
                   will be grown

    **Returns**

     a DecTree.DecTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
          use _ID3Boot_ (below) for that.
  """
    varTable = GenVarTable(examples, nPossibleVals, attrs)
    tree = DecTree.DecTreeNode(None, 'node')

    # store the total entropy... in case that is interesting
    totEntropy = CalcTotalEntropy(examples, nPossibleVals)
    tree.SetData(totEntropy)

    # the matrix of results for this target:
    tMat = GenVarTable(examples, nPossibleVals, [target])[0]
    # counts of each result code (summed across the rows of tMat):
    counts = sum(tMat)
    nzCounts = numpy.nonzero(counts)[0]

    if len(nzCounts) == 1:
        # bottomed out because there is only one result code left
        #  with any counts (i.e. there's only one type of example
        #  left... this is GOOD!).
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
    elif len(attrs) == 0 or (maxDepth >= 0 and depth >= maxDepth):
        # Bottomed out: no variables left or max depth hit
        #  We don't really know what to do here, so
        #  use the heuristic of picking the most prevalent
        #  result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
    else:
        # find the variable which gives us the largest information gain
        gains = [entropy.InfoGain(x) for x in varTable]
        best = attrs[numpy.argmax(gains)]

        # remove that variable from the lists of possible variables
        nextAttrs = attrs[:]
        if not kwargs.get('recycleVars', 0):
            nextAttrs.remove(best)

        # set some info at this node
        tree.SetName('Var: %d' % best)
        tree.SetLabel(best)
        tree.SetTerminal(0)

        # loop over possible values of the new variable and
        #  build a subtree for each one
        for val in range(nPossibleVals[best]):
            nextExamples = [example for example in examples
                            if example[best] == val]
            if len(nextExamples) == 0:
                # this particular value of the variable has no examples,
                #  so there's not much sense in recursing.
                #  This can (and does) happen.
                v = numpy.argmax(counts)
                tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
            else:
                # recurse
                tree.AddChildNode(
                    ID3(nextExamples, best, nextAttrs, nPossibleVals,
                        depth + 1, maxDepth, **kwargs))
    return tree
# Example #3
        return self.GetDataTuple(which)

    def __str__(self):
        """Render the forest as a multi-line text summary, one line per tree."""
        pieces = ['Forest\n']
        for idx in range(len(self.treeList)):
            pieces.append(
                '  Tree % 4d:  % 5d occurances  %%% 5.2f average error\n' %
                (idx, self.countList[idx], 100. * self.errList[idx]))
        return ''.join(pieces)

    def __init__(self):
        """Create an empty forest with no trees, errors, counts, or votes."""
        self.treeList, self.errList = [], []
        self.countList, self.treeVotes = [], []


if __name__ == '__main__':
    # Manual smoke test: build a tiny forest from one node added twice,
    # average/sort it, and print the summary.
    from rdkit.ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None, 'foo')
    for _ in range(2):
        forest.AddTree(node, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)
# Example #4
def _Pruner(node, level=0):
    """Recursively finds and removes the nodes whose removals improve classification

     **Arguments**

       - node: the tree to be pruned.  The pruning data should already be contained
         within node (i.e. node.GetExamples() should return the pruning data)

       - level: (optional) the level of recursion, used only in _verbose printing

     **Returns**

        the pruned version of node

     **Notes**

      - This uses a greedy algorithm which basically does a DFS traversal of the tree,
        removing nodes whenever possible.

      - If removing a node does not affect the accuracy, it *will be* removed.  We
        favor smaller trees.

  """
    if _verbose:
        print('  ' * level, '<%d>  ' % level, '>>> Pruner')
    # copy the child list so that ReplaceChildIndex calls below cannot
    # perturb the iteration
    children = node.GetChildren()[:]

    # bestTree/bestErr track the best (lowest-error) variant found so far;
    # every candidate is evaluated on a deep copy so rejected attempts
    # never leak into the result
    bestTree = copy.deepcopy(node)
    bestErr = 1e6
    emptyChildren = []  # NOTE(review): never appended to or read below -- looks vestigial
    #
    # Loop over the children of this node, removing them when doing so
    #  either improves the local error or leaves it unchanged (we're
    #  introducing a bias for simpler trees).
    #
    for i in range(len(children)):
        child = children[i]
        examples = child.GetExamples()
        if _verbose:
            print('  ' * level, '<%d>  ' % level, ' Child:', i,
                  child.GetLabel())
            bestTree.Print()
            print()
        if len(examples):
            if _verbose:
                print('  ' * level, '<%d>  ' % level, '  Examples',
                      len(examples))
            if not child.GetTerminal():
                if _verbose:
                    print('  ' * level, '<%d>  ' % level, '    Nonterminal')

                # scratch copy of the current best; mutated in place and only
                # promoted to bestTree when its error is <= bestErr
                workTree = copy.deepcopy(bestTree)
                #
                # First recurse on the child (try removing things below it)
                #
                newNode = _Pruner(child, level=level + 1)
                workTree.ReplaceChildIndex(i, newNode)
                tempErr = _GetLocalError(workTree)
                # <= (not <): equal error still prefers the pruned, smaller tree
                if tempErr <= bestErr:
                    bestErr = tempErr
                    bestTree = copy.deepcopy(workTree)
                    if _verbose:
                        print('  ' * level, '<%d>  ' % level, '>->->->->->')
                        print('  ' * level, '<%d>  ' % level, 'replacing:', i,
                              child.GetLabel())
                        child.Print()
                        print('  ' * level, '<%d>  ' % level, 'with:')
                        newNode.Print()
                        print('  ' * level, '<%d>  ' % level, '<-<-<-<-<-<')
                else:
                    # recursion didn't help; undo it before the next attempt
                    workTree.ReplaceChildIndex(i, child)
                #
                # Now try replacing the child entirely
                #  (collapse it to a leaf predicting its most common result)
                #
                bestGuess = MaxCount(child.GetExamples())
                newNode = DecTree.DecTreeNode(workTree,
                                              'L:%d' % (bestGuess),
                                              label=bestGuess,
                                              isTerminal=1)
                newNode.SetExamples(child.GetExamples())
                workTree.ReplaceChildIndex(i, newNode)
                if _verbose:
                    print('  ' * level, '<%d>  ' % level, 'ATTEMPT:')
                    workTree.Print()
                newErr = _GetLocalError(workTree)
                if _verbose:
                    print('  ' * level, '<%d>  ' % level, '---> ', newErr,
                          bestErr)
                if newErr <= bestErr:
                    bestErr = newErr
                    bestTree = copy.deepcopy(workTree)
                    if _verbose:
                        print('  ' * level, '<%d>  ' % level, 'PRUNING:')
                        workTree.Print()
                else:
                    if _verbose:
                        print('  ' * level, '<%d>  ' % level, 'FAIL')
                    # whoops... put the child back in:
                    workTree.ReplaceChildIndex(i, child)
            else:
                # terminal children cannot be simplified further
                if _verbose:
                    print('  ' * level, '<%d>  ' % level, '    Terminal')
        else:
            if _verbose:
                print('  ' * level, '<%d>  ' % level, '  No Examples',
                      len(examples))
            #
            # FIX:  we need to figure out what to do here (nodes that contain
            #   no examples in the testing set).  I can concoct arguments for
            #   leaving them in and for removing them.  At the moment they are
            #   left intact.
            #
            pass

    if _verbose:
        print('  ' * level, '<%d>  ' % level, '<<< out')
    return bestTree