def ID3Boot(examples, attrs, nPossibleVals, initialVar=None, depth=0, maxDepth=-1, **kwargs):
  """ Kicks off construction of an ID3 decision tree from scratch.

    The arguments have the same meanings as in ID3 (q.v.).

    If _initialVar_ is None, the root split variable is chosen
    greedily (the attribute with the largest information gain);
    otherwise _initialVar_ is used as the root split.

    Returns a DecTree.DecTreeNode that is the root of the new tree.

  """
  rootEntropy = CalcTotalEntropy(examples, nPossibleVals)
  varTable = GenVarTable(examples, nPossibleVals, attrs)

  tree = DecTree.DecTreeNode(None, 'node')
  # the last entry of nPossibleVals is the number of result codes
  tree._nResultCodes = nPossibleVals[-1]

  if initialVar is not None:
    splitVar = initialVar
  else:
    # standard greedy choice: maximize information gain at the root
    gains = [entropy.InfoGain(col) for col in varTable]
    splitVar = attrs[numpy.argmax(gains)]

  tree.SetName('Var: %d' % splitVar)
  tree.SetData(rootEntropy)
  tree.SetLabel(splitVar)
  tree.SetTerminal(0)

  remainingAttrs = list(attrs)
  if not kwargs.get('recycleVars', 0):
    remainingAttrs.remove(splitVar)

  # grow one subtree per possible value of the chosen split variable
  for val in range(nPossibleVals[splitVar]):
    subset = [example for example in examples if example[splitVar] == val]
    tree.AddChildNode(
      ID3(subset, splitVar, remainingAttrs, nPossibleVals, depth, maxDepth, **kwargs))
  return tree
def ID3(examples, target, attrs, nPossibleVals, depth=0, maxDepth=-1, **kwargs):
  """ Builds a decision tree with the ID3 algorithm (Mitchell, p. 56),
    slightly extended to support multivalued (non-binary) results.

    **Arguments**

      - examples: a list (nInstances long) of lists of variable values
        + instance values

      - target: an int, the variable this subtree was split on

      - attrs: a list of ints indicating which variables can still be
        used in the tree

      - nPossibleVals: a list containing the number of possible values
        of every variable

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree will
        be grown

    **Returns**

      a DecTree.DecTreeNode with the decision tree

    **NOTE:** this code cannot bootstrap (start from nothing); use
      _ID3Boot_ for that.

  """
  varTable = GenVarTable(examples, nPossibleVals, attrs)
  tree = DecTree.DecTreeNode(None, 'node')
  # keep the total entropy around in case a caller wants it
  tree.SetData(CalcTotalEntropy(examples, nPossibleVals))

  # the matrix of results for this target, and the count of each result code
  tMat = GenVarTable(examples, nPossibleVals, [target])[0]
  counts = sum(tMat)
  nzCounts = numpy.nonzero(counts)[0]

  if len(nzCounts) == 1:
    # pure node: every remaining example has the same result code,
    #  so we can stop here (this is GOOD!)
    res = nzCounts[0]
    tree.SetLabel(res)
    tree.SetName(str(res))
    tree.SetTerminal(1)
  elif len(attrs) == 0 or (maxDepth >= 0 and depth >= maxDepth):
    # ran out of variables (or hit the depth cap); fall back to the
    #  heuristic of labelling with the most prevalent result
    majority = numpy.argmax(counts)
    tree.SetLabel(majority)
    tree.SetName('%d?' % majority)
    tree.SetTerminal(1)
  else:
    # split on the attribute that yields the largest information gain
    gains = [entropy.InfoGain(col) for col in varTable]
    best = attrs[numpy.argmax(gains)]

    # by default the split variable is consumed; recycleVars keeps it available
    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
      nextAttrs.remove(best)

    tree.SetName('Var: %d' % best)
    tree.SetLabel(best)
    tree.SetTerminal(0)

    # build one subtree per possible value of the split variable
    for val in range(nPossibleVals[best]):
      subset = [example for example in examples if example[best] == val]
      if len(subset) == 0:
        # no examples carry this value (this can and does happen);
        #  attach a terminal labelled with the majority result instead
        #  of recursing
        majority = numpy.argmax(counts)
        tree.AddChild('%d' % majority, label=majority, data=0.0, isTerminal=1)
      else:
        tree.AddChildNode(
          ID3(subset, best, nextAttrs, nPossibleVals, depth + 1, maxDepth, **kwargs))
  return tree
    # (tail of a method whose `def` precedes this chunk — delegates to
    #  GetDataTuple for the tree/count/error triple at index `which`)
    return self.GetDataTuple(which)

  def __str__(self):
    """ allows the forest to show itself as a string

      One line per tree: its occurrence count and its average error,
      shown as a percentage.
    """
    outStr = 'Forest\n'
    for i in range(len(self.treeList)):
      outStr = (
        outStr + ' Tree % 4d: % 5d occurances %%% 5.2f average error\n' %
        (i, self.countList[i], 100. * self.errList[i]))
    return outStr

  def __init__(self):
    # parallel lists, one entry per tree in the forest:
    self.treeList = []   # the trees themselves
    self.errList = []    # per-tree error values (fractions; printed as %)
    self.countList = []  # per-tree occurrence counts
    self.treeVotes = []  # NOTE(review): usage not visible in this chunk -- confirm


if __name__ == '__main__':
  # quick smoke test: build a two-tree forest, process it, and print it
  from rdkit.ML.DecTree import DecTree
  f = Forest()
  n = DecTree.DecTreeNode(None, 'foo')
  f.AddTree(n, 0.5)
  f.AddTree(n, 0.5)
  f.AverageErrors()
  f.SortTrees()
  print(f)
def _Pruner(node, level=0):
  """Recursively finds and removes the nodes whose removals improve classification

   **Arguments**

     - node: the tree to be pruned. The pruning data should already
       be contained within node (i.e. node.GetExamples() should
       return the pruning data)

     - level: (optional) the level of recursion, used only in _verbose printing

   **Returns**

     the pruned version of node

   **Notes**

    - This uses a greedy algorithm which basically does a DFS traversal of the
      tree, removing nodes whenever possible.

    - If removing a node does not affect the accuracy, it *will be* removed.
      We favor smaller trees.

  """
  if _verbose:
    print(' ' * level, '<%d> ' % level, '>>> Pruner')
  # copy the child list so in-place replacements below don't disturb iteration
  children = node.GetChildren()[:]

  # bestTree holds the best (lowest-error) version of this subtree found
  # so far; candidate edits are tried on deep copies so a failed attempt
  # never corrupts it.
  bestTree = copy.deepcopy(node)
  bestErr = 1e6
  # NOTE(review): emptyChildren is assigned but never used in this function
  emptyChildren = []
  #
  # Loop over the children of this node, removing them when doing so
  # either improves the local error or leaves it unchanged (we're
  # introducing a bias for simpler trees).
  #
  for i in range(len(children)):
    child = children[i]
    examples = child.GetExamples()
    if _verbose:
      print(' ' * level, '<%d> ' % level, ' Child:', i, child.GetLabel())
      bestTree.Print()
      print()
    if len(examples):
      if _verbose:
        print(' ' * level, '<%d> ' % level, ' Examples', len(examples))
      if not child.GetTerminal():
        if _verbose:
          print(' ' * level, '<%d> ' % level, ' Nonterminal')
        # work on a scratch copy of the current best tree
        workTree = copy.deepcopy(bestTree)
        #
        # First recurse on the child (try removing things below it)
        #
        newNode = _Pruner(child, level=level + 1)
        workTree.ReplaceChildIndex(i, newNode)
        tempErr = _GetLocalError(workTree)
        # <= (not <): ties go to the pruned tree, favoring smaller trees
        if tempErr <= bestErr:
          bestErr = tempErr
          bestTree = copy.deepcopy(workTree)
          if _verbose:
            print(' ' * level, '<%d> ' % level, '>->->->->->')
            print(' ' * level, '<%d> ' % level, 'replacing:', i, child.GetLabel())
            child.Print()
            print(' ' * level, '<%d> ' % level, 'with:')
            newNode.Print()
            print(' ' * level, '<%d> ' % level, '<-<-<-<-<-<')
        else:
          # the recursive prune didn't help; restore the original child
          workTree.ReplaceChildIndex(i, child)
        #
        # Now try replacing the child entirely
        #  (collapse it to a terminal labelled with the most common
        #   result among its examples)
        #
        bestGuess = MaxCount(child.GetExamples())
        newNode = DecTree.DecTreeNode(workTree, 'L:%d' % (bestGuess), label=bestGuess,
                                      isTerminal=1)
        newNode.SetExamples(child.GetExamples())
        workTree.ReplaceChildIndex(i, newNode)
        if _verbose:
          print(' ' * level, '<%d> ' % level, 'ATTEMPT:')
          workTree.Print()
        newErr = _GetLocalError(workTree)
        if _verbose:
          print(' ' * level, '<%d> ' % level, '---> ', newErr, bestErr)
        if newErr <= bestErr:
          bestErr = newErr
          bestTree = copy.deepcopy(workTree)
          if _verbose:
            print(' ' * level, '<%d> ' % level, 'PRUNING:')
            workTree.Print()
        else:
          if _verbose:
            print(' ' * level, '<%d> ' % level, 'FAIL')
          # whoops... put the child back in:
          workTree.ReplaceChildIndex(i, child)
      else:
        # terminal child: nothing below it to prune
        if _verbose:
          print(' ' * level, '<%d> ' % level, ' Terminal')
    else:
      if _verbose:
        print(' ' * level, '<%d> ' % level, ' No Examples', len(examples))
      #
      # FIX: we need to figure out what to do here (nodes that contain
      # no examples in the testing set). I can concoct arguments for
      # leaving them in and for removing them. At the moment they are
      # left intact.
      #
      pass
  if _verbose:
    print(' ' * level, '<%d> ' % level, '<<< out')
  return bestTree