def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None, maxDepth=-1,
                  **kwargs):
    """ Bootstrapping code for the QuantTree

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach).  Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples: a sequence of examples. Each example is indexed by
        variable number; its last element is converted with int() and
        used as the result code.

      - attrs: the variable indices that may be split on. The sequence
        is copied, the caller's object is not modified.

      - nPossibleVals: indexed by variable number to get the number of
        values looped over for an unquantized split; the last entry is
        used as the number of possible result codes.

      - nBoundsPerVar: indexed by variable number. A positive entry is
        the number of quantization bounds to find for that variable, 0
        means the variable is split on its raw values, and variables
        whose entry is -1 are removed from the candidate list.

      - initialVar: (optional) the variable to use for the first split.

      - maxDepth: forwarded unchanged to BuildQuantTree.

      - kwargs: 'recycleVars' is read here (when it is false or absent,
        the chosen variable is removed from the list handed to the
        children); every keyword is forwarded to FindBest and
        BuildQuantTree.

    **Returns**

      the root QuantTree.QuantTreeNode of the new tree.

    """
    attrs = list(attrs)
    for varIdx, boundCount in enumerate(nBoundsPerVar):
        if boundCount == -1 and varIdx in attrs:
            attrs.remove(varIdx)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
                                           nPossibleVals, attrs, **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            # Build a real list (as FindBest does) rather than a lazy map()
            # iterator, so the quantizer receives a proper sequence on Python 3.
            vTable = [ex[best] for ex in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBoundsPerVar[best],
                                                                resCodes, nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)

    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    def _addBranch(branchExamples):
        # Grow a subtree from the examples that fell into this branch, or,
        # when the branch is empty, add a terminal placeholder labelled with
        # the most common result code seen at this node.
        if len(branchExamples) != 0:
            tree.AddChildNode(
                BuildQuantTree(branchExamples, best, nextAttrs, nPossibleVals, nBoundsPerVar,
                               depth=1, maxDepth=maxDepth, **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # Quantized split: walk the bounds in order, peeling off the examples
        # that fall below each one.  A single pass per bound replaces the
        # repeated list.remove() calls while keeping the original ordering.
        remaining = list(examples)
        for bound in qBounds:
            below = []
            notBelow = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    notBelow.append(ex)
            remaining = notBelow
            _addBranch(below)
        # add the last points remaining
        _addBranch(remaining)
    else:
        # Unquantized split: one branch per possible value of the variable.
        for val in range(nPossibleVals[best]):
            _addBranch([ex for ex in examples if ex[best] == val])
    return tree
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes, nPossibleVals, attrs, exIndices=None,
             **kwargs):
    """ Finds the variable with the highest information gain

    **Arguments**

      - resCodes: the result codes handed to the quantizer together
        with the variable values.

      - examples: a sequence of examples, each indexed by variable
        number.

      - nBoundsPerVar: indexed by variable number. A positive entry is
        the number of quantization bounds to find, 0 means the gain is
        computed from the variable's raw values, and a negative entry
        gives the variable a gain of -1e6 (so it is never preferred).

      - nPossibleRes: the number of possible result codes (forwarded to
        the quantizer).

      - nPossibleVals: forwarded to ID3.GenVarTable for unquantized
        variables.

      - attrs: the candidate variable indices.

      - exIndices: (optional) indices into _examples_ selecting the
        examples to consider; all examples are used when it is None.

      - kwargs: only 'randomDescriptors' is read. When it is positive
        and smaller than the number of candidates, that many candidates
        are picked at random and only those are evaluated.

    **Returns**

      a 3-tuple:

        1) the index of the best variable (-1 if none was selected)

        2) the gain for that variable (-1e6 if none was selected)

        3) the quantization bounds for that variable (possibly empty)

    **Notes**

      - on ties in gain, the variable with fewer quantization bounds
        is preferred.

      - a diagnostic dump is printed when no variable is selected.

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = list(range(len(examples)))

    if not len(exIndices):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if 0 < nToTake < len(attrs):
        ids = list(range(len(attrs)))
        # The optional `random` argument of shuffle() no longer exists on
        # Python 3.11+; the default generator is used instead.
        random.shuffle(ids)
        attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBounds, resCodes,
                                                                nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable((examples[x] for x in exIndices), nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif bestGain == gainHere:
            # tie: prefer the variable that needs fewer quantization bounds
            if len(qBounds) < len(bestBounds):
                best = var
                bestBounds = qBounds

    if best == -1:
        print('best unaltered')
        print('\tattrs:', attrs)
        print('\tnBounds:', numpy.take(nBoundsPerVar, attrs))
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t', example)

    return best, bestGain, bestBounds