Example #1
0
def QuantTreeBoot(examples,
                  attrs,
                  nPossibleVals,
                  nBoundsPerVar,
                  initialVar=None,
                  maxDepth=-1,
                  **kwargs):
    """ Bootstrapping code for the QuantTree

    Creates the root node, selects the variable for the first split,
    partitions the examples on that variable and hands every non-empty
    partition to BuildQuantTree.  Empty partitions become terminal
    children labelled with the most common result code of the full
    example set.

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach).  Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples: sequence of examples; the last element of each example
        is read as its (integer) result code

      - attrs: the variables that may be used for splitting.  Variables
        whose _nBoundsPerVar_ entry is -1 are dropped from a local copy;
        the caller's sequence is not modified.

      - nPossibleVals: per-variable number of possible values; the last
        entry is used as the number of possible result codes

      - nBoundsPerVar: per-variable number of quantization bounds.
        > 0: the variable is quantized, 0: the variable is used as-is
        via ID3.GenVarTable, < 0: the variable is not evaluated.

      - initialVar: (optional) variable to force as the first split

      - maxDepth: passed through to BuildQuantTree

      - kwargs: forwarded to FindBest and BuildQuantTree.  'recycleVars'
        is read here: when true the split variable stays available to
        the children.

    **Returns**

      the root QuantTree.QuantTreeNode

    """
    attrs = list(attrs)
    for i, nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                           nPossibleRes, nPossibleVals, attrs,
                                           **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            # a real list, not a lazy map object: under Python 3 map()
            # returns a one-shot iterator without len() or indexing
            vTable = [ex[best] for ex in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBoundsPerVar[best], resCodes, nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)

    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        # A forced initialVar may already be absent from attrs (e.g. it was
        # filtered out above because its bound count is -1); in that case
        # there is nothing to remove.  The automatically chosen variable is
        # removed unconditionally, exactly as before.
        if initialVar is None or best in nextAttrs:
            nextAttrs.remove(best)

    def _addBranch(subset):
        # Grow a subtree for a non-empty partition; otherwise add a terminal
        # child carrying the most common result code of the whole data set.
        if subset:
            tree.AddChildNode(
                BuildQuantTree(subset,
                               best,
                               nextAttrs,
                               nPossibleVals,
                               nBoundsPerVar,
                               depth=1,
                               maxDepth=maxDepth,
                               **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # Quantized variable: each bound peels off the still-unassigned
        # examples lying below it (single pass per bound, order preserved).
        remaining = [examples[i] for i in range(len(examples))]
        for bound in qBounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            remaining = rest
            _addBranch(below)
        # add the last points remaining
        _addBranch(remaining)
    else:
        # Non-quantized variable: one branch per possible value.
        for val in range(nPossibleVals[best]):
            _addBranch([ex for ex in examples if ex[best] == val])
    return tree
Example #2
0
def FindBest(resCodes,
             examples,
             nBoundsPerVar,
             nPossibleRes,
             nPossibleVals,
             attrs,
             exIndices=None,
             **kwargs):
    """ Finds the variable in _attrs_ with the highest gain

    **Arguments**

      - resCodes: result codes, passed to
        Quantize.FindVarMultQuantBounds for quantized variables

      - examples: indexable collection of examples; only those selected
        by _exIndices_ are looked at

      - nBoundsPerVar: per-variable number of quantization bounds.
        > 0: the variable is quantized with that many bounds and the gain
        reported by Quantize.FindVarMultQuantBounds is used,
        0: the gain is entropy.InfoGain of the ID3.GenVarTable table,
        < 0: the variable is given a gain of -1e6 (never selected).

      - nPossibleRes: number of possible result codes

      - nPossibleVals: per-variable number of possible values

      - attrs: the candidate variables

      - exIndices: (optional) indices into _examples_ to consider;
        defaults to all examples

      - kwargs: 'randomDescriptors', when > 0 and smaller than
        len(attrs), restricts the search to that many randomly chosen
        candidates.  Other keys are ignored.

    **Returns**

      a 3-tuple: (best variable, its gain, its quantization bounds).
      If no examples are selected or no candidate beats a gain of -1e6
      the result is (-1, -1e6, []).

    **Notes**

      - on equal gain the variable with fewer quantization bounds wins

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = list(range(len(examples)))

    if not len(exIndices):
        return best, bestGain, bestBounds

    # "or 0" lets an explicit randomDescriptors=None mean "disabled"
    nToTake = kwargs.get('randomDescriptors', 0) or 0
    if 0 < nToTake < len(attrs):
        ids = list(range(len(attrs)))
        # the optional "random" argument of shuffle() was removed in
        # Python 3.11; the default generator is what was being passed anyway
        random.shuffle(ids)
        attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBounds, resCodes, nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable((examples[x] for x in exIndices),
                                     nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif bestGain == gainHere and len(qBounds) < len(bestBounds):
            # tie: prefer the simpler split (fewer bounds)
            best = var
            bestBounds = qBounds

    if best == -1:
        # diagnostic dump: nothing was selectable
        print('best unaltered')
        print('\tattrs:', attrs)
        print('\tnBounds:', numpy.take(nBoundsPerVar, attrs))
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t', example)

    return best, bestGain, bestBounds