Пример #1
0
def CrossValidationDriver(examples, attrs, nPossibleValues, nQuantBounds, mEstimateVal=0.0,
                          holdOutFrac=0.3, modelBuilder=makeNBClassificationModel, silent=0,
                          calcTotalError=0, **kwargs):
  """ Splits the data, builds a Naive Bayes model on the training piece and
  measures the model's cross-validation error.

  **Arguments**

    - examples: the full list of examples

    - attrs, nPossibleValues, nQuantBounds, mEstimateVal: handed unchanged to
      _modelBuilder_ (together with all extra keyword arguments)

    - holdOutFrac: the fraction of _examples_ placed in the hold-out set

    - modelBuilder: the callable used to construct the model

    - silent: if false, the validation error is printed

    - calcTotalError: if true the error is measured on all of _examples_,
      otherwise only on the hold-out set

    - replacementSelection (keyword): if true, the split is drawn with
      replacement

  **Returns**

    a 2-tuple: (the model, its cross-validation error)
  """
  useReplacement = 1 if kwargs.get('replacementSelection', 0) else 0
  # legacy splitting is requested only when sampling without replacement
  testIndices, trainIndices = SplitData.SplitIndices(len(examples), holdOutFrac, silent=1,
                                                     legacy=1 - useReplacement,
                                                     replacement=useReplacement)

  trainExamples = [examples[idx] for idx in trainIndices]
  testExamples = [examples[idx] for idx in testIndices]

  model = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds, mEstimateVal,
                       **kwargs)

  if calcTotalError:
    xValError, _ = CrossValidate(model, examples, appendExamples=0)
  else:
    xValError, _ = CrossValidate(model, testExamples, appendExamples=1)

  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))
  # remember which examples were used for training
  model._trainIndices = trainIndices
  return model, xValError
Пример #2
0
    def test_exceptions(self):
        # fractions outside [0, 1] must be rejected
        for badFrac in (-0.1, 1.1):
            self.assertRaises(ValueError, SplitData.SplitIndices, 10, badFrac)

        # with silent=False a summary of the split is written to stdout
        captured = StringIO()
        with redirect_stdout(captured):
            SplitData.SplitIndices(10, 0.5, replacement=True, silent=False)
        output = captured.getvalue()
        for expected in ('Training', 'hold-out'):
            self.assertIn(expected, output)
Пример #3
0
    def test_SplitData(self):
        # out-of-range fractions raise before the data is inspected
        for badFrac in (-1.1, 1.1):
            self.assertRaises(ValueError, SplitData.SplitDataSet, None, badFrac)

        data = list(range(10))
        DataUtils.InitRandomNumbers((23, 42))
        captured = StringIO()
        with redirect_stdout(captured):
            pieces = SplitData.SplitDataSet(data, 0.5)
        # the two pieces must be disjoint, and the first holds half the points
        self.assertEqual(set(pieces[0]).intersection(pieces[1]), set())
        self.assertEqual(len(pieces[0]), 5)
        output = captured.getvalue()
        for expected in ('Training', 'hold-out'):
            self.assertIn(expected, output)
Пример #4
0
def CrossValidationDriver(examples,
                          attrs,
                          nPossibleValues,
                          numNeigh,
                          modelBuilder=makeClassificationModel,
                          distFunc=DistFunctions.EuclideanDist,
                          holdOutFrac=0.3,
                          silent=0,
                          calcTotalError=0,
                          **kwargs):
    """ Driver function for building a KNN model of a specified type

    **Arguments**

      - examples: the full set of examples

      - attrs: the attributes to use; passed on to _modelBuilder_

      - nPossibleValues: accepted for interface compatibility with the other
        CrossValidationDriver functions; not used here

      - numNeigh: number of neighbors for the KNN model (basically k in k-NN)

      - modelBuilder: callable invoked as modelBuilder(numNeigh, attrs, distFunc)
        to create the KNN model (e.g. a classification or regression model)

      - distFunc: the distance function handed to _modelBuilder_

      - holdOutFrac: the fraction of the data which should be reserved for the
        hold-out set (used to calculate error)

      - silent: a toggle used to control how much visual noise this makes as it goes

      - calcTotalError: a toggle used to indicate whether the error of the model
        should be calculated using the entire data set (when true) or just the
        hold-out set (when false)

      - replacementSelection (keyword): if true, the split is drawn with
        replacement; other extra keyword arguments are ignored

    **Returns**

      a 2-tuple containing:

        1) the model

        2) the cross-validation error of the model
    """
    nTot = len(examples)
    if not kwargs.get('replacementSelection', 0):
        testIndices, trainIndices = SplitData.SplitIndices(nTot,
                                                           holdOutFrac,
                                                           silent=1,
                                                           legacy=1,
                                                           replacement=0)
    else:
        testIndices, trainIndices = SplitData.SplitIndices(nTot,
                                                           holdOutFrac,
                                                           silent=1,
                                                           legacy=0,
                                                           replacement=1)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    if not silent:
        # function-call form works on both Python 2 and Python 3
        print('Training with %d examples' % len(trainExamples))

    knnMod = modelBuilder(numNeigh, attrs, distFunc)

    knnMod.SetTrainingExamples(trainExamples)
    knnMod.SetTestExamples(testExamples)

    if not calcTotalError:
        xValError, _ = CrossValidate(knnMod, testExamples, appendExamples=1)
    else:
        xValError, _ = CrossValidate(knnMod, examples, appendExamples=0)

    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))

    knnMod._trainIndices = trainIndices
    return knnMod, xValError
Пример #5
0
    # NOTE(review): pickVects and halfwayPts are not used in the visible lines;
    # presumably they are filled/consumed later in this loop.
    pickVects = {}
    halfwayPts = [1e8] * len(models)
    for whichModel, model in enumerate(models):
        # NOTE(review): tmpD is an alias of dataSet, not a copy, so any in-place
        # change made below (e.g. by RandomizeActivities) is seen by every later
        # model as well -- confirm this is intended.
        tmpD = dataSet
        # Re-seed the RNG with the seed stored on the model (if any); presumably
        # this lets the split below reproduce the one used when building it.
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=1)
        # Re-split only when the model records a split fraction and either the
        # hold-out or the training portion was requested.
        if hasattr(model, '_splitFrac') and (details.doHoldout
                                             or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(),
                                                       model._splitFrac,
                                                       silent=1)
            # Optional filtering of the training indices: the second list
            # returned by FilterData is appended to testIdx and trainIdx is
            # replaced by the first list.
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD,
                                                       details.filterVal,
                                                       details.filterFrac,
                                                       -1,
                                                       indicesToUse=trainIdx,
                                                       indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            # When the training portion was requested, swap so that testIdx
            # holds the training indices.
            if details.doTraining:
                testIdx, trainIdx = trainIdx, testIdx
        else:
            # no split information (or no portion requested): use every point
            testIdx = range(tmpD.GetNPts())
Пример #6
0
def _QuantizeActivities(examples, bounds):
    """ Returns copies of _examples_ whose last element (the activity) is
    replaced by its bin index.

    The bin index is the position of the first entry of _bounds_ that the
    activity is strictly less than, or len(bounds) if there is none.
    Each example is shallow-copied, so the input is not modified.
    """
    quantized = []
    for pt in examples:
        pt = pt[:]
        act = pt[-1]
        for binIdx, bound in enumerate(bounds):
            if act < bound:
                pt[-1] = binIdx
                break
        else:
            pt[-1] = len(bounds)
        quantized.append(pt)
    return quantized


def _ReportResultCounts(title, examples, bounds):
    """ Reports (via message) how many of _examples_ have each result value,
    as counted by DataUtils.CountResults, in sorted key order. """
    counts = DataUtils.CountResults(examples, bounds=bounds)
    message(title)
    for k in sorted(counts.keys()):
        message(str((k, counts[k])))


def RunOnData(details, data, progressCallback=None, saveIt=1, setDescNames=0):
    """ Builds, evaluates and (optionally) saves a composite model.

    **Arguments**

      - details: the run-details object holding the build parameters; the
        pickled composite is stored back on it as details.model

      - data: the data set to build from (GetNPts, GetNamedData, GetNVars,
        GetNPossibleVals and GetVarNames are used)

      - progressCallback: (optional) forwarded to composite.Grow for the
        decision-tree and signature-tree builders

      - saveIt: if true, the composite is pickled to details.outName

      - setDescNames: if true, the input order is set from the data set's
        variable names and the descriptor names from details._descNames;
        otherwise the data set's variable names are used as descriptor names

    **Returns**

      the composite model
    """
    if details.lockRandom:
        seed = details.randomSeed
    else:
        import random
        # randint() needs integer bounds: float bounds such as 1e6 raise a
        # TypeError on Python >= 3.12
        seed = (random.randint(0, 1000000), random.randint(0, 1000000))
    DataUtils.InitRandomNumbers(seed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

    namedExamples = data.GetNamedData()
    if details.splitRun == 1:
        trainIdx, testIdx = SplitData.SplitIndices(len(namedExamples),
                                                   details.splitFrac,
                                                   silent=not _verbose)

        trainExamples = [namedExamples[x] for x in trainIdx]
        testExamples = [namedExamples[x] for x in testIdx]
    else:
        testExamples = []
        trainExamples = namedExamples

    if details.filterFrac != 0.0:
        # if we're doing quantization on the fly, the filter has to see the
        # quantized activities:
        if hasattr(details, 'activityBounds') and details.activityBounds:
            bounds = details.activityBounds
            tExamples = _QuantizeActivities(trainExamples, bounds)
        else:
            bounds = None
            tExamples = trainExamples
        # tExamples lines up index-for-index with trainExamples; examples at the
        # first returned index list stay in the training set, those at the
        # second are moved into the test set
        keptIdx, movedIdx = DataUtils.FilterData(tExamples,
                                                 details.filterVal,
                                                 details.filterFrac,
                                                 -1,
                                                 indicesOnly=1)
        testExamples += [trainExamples[x] for x in movedIdx]
        trainExamples = [trainExamples[x] for x in keptIdx]

        _ReportResultCounts('Result Counts in training set:', trainExamples, bounds)
        _ReportResultCounts('Result Counts in test set:', testExamples, bounds)
    nExamples = len(trainExamples)
    message('Training with %d examples' % (nExamples))

    nVars = data.GetNVars()
    # a real list is needed because entries are removed below
    attrs = list(range(1, nVars + 1))
    nPossibleVals = data.GetNPossibleVals()
    for i in range(1, len(nPossibleVals)):
        # variables flagged with -1 possible values are dropped from attrs
        if nPossibleVals[i - 1] == -1:
            attrs.remove(i)

    if details.pickleDataFileName != '':
        with open(details.pickleDataFileName, 'wb+') as pickleDataFile:
            cPickle.dump(trainExamples, pickleDataFile)
            cPickle.dump(testExamples, pickleDataFile)

    if details.bayesModel:
        composite = BayesComposite.BayesComposite()
    else:
        composite = Composite.Composite()

    # record how the data were prepared so the split can be reproduced later
    composite._randomSeed = seed
    composite._splitFrac = details.splitFrac
    composite._shuffleActivities = details.shuffleActivities
    composite._randomizeActivities = details.randomActivities

    if hasattr(details, 'filterFrac'):
        composite._filterFrac = details.filterFrac
    if hasattr(details, 'filterVal'):
        composite._filterVal = details.filterVal

    composite.SetModelFilterData(details.modelFilterFrac,
                                 details.modelFilterVal)

    composite.SetActivityQuantBounds(details.activityBounds)
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
        nPossibleVals[-1] = len(details.activityBounds) + 1

    if setDescNames:
        composite.SetInputOrder(data.GetVarNames())
        composite.SetDescriptorNames(details._descNames)
    else:
        composite.SetDescriptorNames(data.GetVarNames())
    composite.SetActivityQuantBounds(details.activityBounds)
    if details.nModels == 1:
        details.internalHoldoutFrac = 0.0
    if details.useTrees:
        from rdkit.ML.DecTree import CrossValidate, PruneTree
        if details.qBounds != []:
            from rdkit.ML.DecTree import BuildQuantTree
            builder = BuildQuantTree.QuantTreeBoot
        else:
            from rdkit.ML.DecTree import ID3
            builder = ID3.ID3Boot
        driver = CrossValidate.CrossValidationDriver
        pruner = PruneTree.PruneTree

        composite.SetQuantBounds(details.qBounds)
        nPossibleVals = data.GetNPossibleVals()
        if details.activityBounds:
            nPossibleVals[-1] = len(details.activityBounds) + 1
        composite.Grow(trainExamples,
                       attrs,
                       nPossibleVals=[0] + nPossibleVals,
                       buildDriver=driver,
                       pruner=pruner,
                       nTries=details.nModels,
                       pruneIt=details.pruneIt,
                       lessGreedy=details.lessGreedy,
                       needsQuantization=0,
                       treeBuilder=builder,
                       nQuantBounds=details.qBounds,
                       startAt=details.startAt,
                       maxDepth=details.limitDepth,
                       progressCallback=progressCallback,
                       holdOutFrac=details.internalHoldoutFrac,
                       replacementSelection=details.replacementSelection,
                       recycleVars=details.recycleVars,
                       randomDescriptors=details.randomDescriptors,
                       silent=not _verbose)

    elif details.useSigTrees:
        from rdkit.ML.DecTree import CrossValidate
        from rdkit.ML.DecTree import BuildSigTree
        builder = BuildSigTree.SigTreeBuilder
        driver = CrossValidate.CrossValidationDriver
        nPossibleVals = data.GetNPossibleVals()
        if details.activityBounds:
            nPossibleVals[-1] = len(details.activityBounds) + 1
        if hasattr(details, 'sigTreeBiasList'):
            biasList = details.sigTreeBiasList
        else:
            biasList = None
        if hasattr(details, 'useCMIM'):
            useCMIM = details.useCMIM
        else:
            useCMIM = 0
        if hasattr(details, 'allowCollections'):
            allowCollections = details.allowCollections
        else:
            allowCollections = False
        composite.Grow(trainExamples,
                       attrs,
                       nPossibleVals=[0] + nPossibleVals,
                       buildDriver=driver,
                       nTries=details.nModels,
                       needsQuantization=0,
                       treeBuilder=builder,
                       maxDepth=details.limitDepth,
                       progressCallback=progressCallback,
                       holdOutFrac=details.internalHoldoutFrac,
                       replacementSelection=details.replacementSelection,
                       recycleVars=details.recycleVars,
                       randomDescriptors=details.randomDescriptors,
                       biasList=biasList,
                       useCMIM=useCMIM,
                       allowCollection=allowCollections,
                       silent=not _verbose)

    elif details.useKNN:
        from rdkit.ML.KNN import CrossValidate
        from rdkit.ML.KNN import DistFunctions

        driver = CrossValidate.CrossValidationDriver
        dfunc = ''
        if (details.knnDistFunc == "Euclidean"):
            dfunc = DistFunctions.EuclideanDist
        elif (details.knnDistFunc == "Tanimoto"):
            dfunc = DistFunctions.TanimotoDist
        else:
            assert 0, "Bad KNN distance metric value"

        composite.Grow(trainExamples,
                       attrs,
                       nPossibleVals=[0] + nPossibleVals,
                       buildDriver=driver,
                       nTries=details.nModels,
                       needsQuantization=0,
                       numNeigh=details.knnNeighs,
                       holdOutFrac=details.internalHoldoutFrac,
                       distFunc=dfunc)

    elif details.useNaiveBayes or details.useSigBayes:
        from rdkit.ML.NaiveBayes import CrossValidate
        driver = CrossValidate.CrossValidationDriver
        if not (hasattr(details, 'useSigBayes') and details.useSigBayes):
            composite.Grow(trainExamples,
                           attrs,
                           nPossibleVals=[0] + nPossibleVals,
                           buildDriver=driver,
                           nTries=details.nModels,
                           needsQuantization=0,
                           nQuantBounds=details.qBounds,
                           holdOutFrac=details.internalHoldoutFrac,
                           replacementSelection=details.replacementSelection,
                           mEstimateVal=details.mEstimateVal,
                           silent=not _verbose)
        else:
            if hasattr(details, 'useCMIM'):
                useCMIM = details.useCMIM
            else:
                useCMIM = 0

            composite.Grow(trainExamples,
                           attrs,
                           nPossibleVals=[0] + nPossibleVals,
                           buildDriver=driver,
                           nTries=details.nModels,
                           needsQuantization=0,
                           nQuantBounds=details.qBounds,
                           mEstimateVal=details.mEstimateVal,
                           useSigs=True,
                           useCMIM=useCMIM,
                           holdOutFrac=details.internalHoldoutFrac,
                           replacementSelection=details.replacementSelection,
                           silent=not _verbose)

    else:
        from rdkit.ML.Neural import CrossValidate
        driver = CrossValidate.CrossValidationDriver
        composite.Grow(trainExamples,
                       attrs, [0] + nPossibleVals,
                       nTries=details.nModels,
                       buildDriver=driver,
                       needsQuantization=0)

    composite.AverageErrors()
    composite.SortModels()
    modelList, counts, avgErrs = composite.GetAllData()
    counts = numpy.array(counts)
    avgErrs = numpy.array(avgErrs)
    composite._varNames = data.GetVarNames()

    for model in modelList:
        model.NameModel(composite._varNames)

    # do final statistics: count-weighted mean error and mean absolute deviation
    weightedErrs = counts * avgErrs
    averageErr = sum(weightedErrs) / sum(counts)
    devs = (avgErrs - averageErr)
    devs = devs * counts
    devs = numpy.sqrt(devs * devs)
    avgDev = sum(devs) / sum(counts)
    message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' %
            (100. * averageErr, 100. * avgDev))

    if details.bayesModel:
        composite.Train(trainExamples, verbose=0)

    # blow out the saved examples and then save the composite:
    composite.ClearModelExamples()
    if saveIt:
        composite.Pickle(details.outName)
    details.model = DbModule.binaryHolder(cPickle.dumps(composite))

    # NOTE(review): the error statistics below are recorded on the module-level
    # _runDetails rather than on details; details.Store() only persists them
    # if both are the same object -- confirm.
    badExamples = []
    # wrong is bound up front so the bad-example dump below never refers to an
    # unbound name when the screening branch is skipped
    wrong = []
    if not details.detailedRes and (not hasattr(details, 'noScreen')
                                    or not details.noScreen):
        if details.splitRun:
            message('Testing all hold-out examples')
            wrong = testall(composite, testExamples, badExamples)
            message('%d examples (%% %5.2f) were misclassified' %
                    (len(wrong),
                     100. * float(len(wrong)) / float(len(testExamples))))
            _runDetails.holdout_error = float(len(wrong)) / len(testExamples)
        else:
            message('Testing all examples')
            wrong = testall(composite, namedExamples, badExamples)
            message('%d examples (%% %5.2f) were misclassified' %
                    (len(wrong),
                     100. * float(len(wrong)) / float(len(namedExamples))))
            _runDetails.overall_error = float(len(wrong)) / len(namedExamples)

    if details.detailedRes:
        message('\nEntire data set:')
        resTup = ScreenComposite.ShowVoteResults(range(data.GetNPts()), data,
                                                 composite, nPossibleVals[-1],
                                                 details.threshold)
        nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
        nPts = len(namedExamples)
        nClass = nGood + nBad
        _runDetails.overall_error = float(nBad) / nClass
        _runDetails.overall_correct_conf = avgGood
        _runDetails.overall_incorrect_conf = avgBad
        _runDetails.overall_result_matrix = repr(voteTab)
        # points that were not classified (nClass can never exceed nPts, so
        # the difference must be taken this way round to ever be positive)
        nRej = nPts - nClass
        if nRej > 0:
            _runDetails.overall_fraction_dropped = float(nRej) / nPts

        if details.splitRun:
            message('\nHold-out data:')
            resTup = ScreenComposite.ShowVoteResults(range(len(testExamples)),
                                                     testExamples, composite,
                                                     nPossibleVals[-1],
                                                     details.threshold)
            nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
            nPts = len(testExamples)
            nClass = nGood + nBad
            _runDetails.holdout_error = float(nBad) / nClass
            _runDetails.holdout_correct_conf = avgGood
            _runDetails.holdout_incorrect_conf = avgBad
            _runDetails.holdout_result_matrix = repr(voteTab)
            nRej = nPts - nClass
            if nRej > 0:
                _runDetails.holdout_fraction_dropped = float(nRej) / nPts

    if details.persistTblName and details.dbName:
        message('Updating results table %s:%s' %
                (details.dbName, details.persistTblName))
        details.Store(db=details.dbName, table=details.persistTblName)

    if details.badName != '':
        with open(details.badName, 'w+') as badFile:
            # badExamples[i] is paired with vote wrong[i]
            for ex, vote in zip(badExamples, wrong):
                badFile.write('%s\t%s\n' % (ex, vote))

    composite.ClearModelExamples()
    return composite
Пример #7
0
def _ResetRandomization(details, composite, dataSet):
  """ Copies the split fraction and random seed stored on _composite_ into
  _details_, reseeds the random number generator with that seed and then
  reapplies the activity shuffling/randomization requested in _details_ to
  _dataSet_.
  """
  details.splitFrac = composite._splitFrac
  details.randomSeed = composite._randomSeed
  DataUtils.InitRandomNumbers(details.randomSeed)
  if details.shuffleActivities == 1:
    DataUtils.RandomizeActivities(dataSet,shuffle=1,runDetails=details)
  elif details.randomActivities == 1:
    DataUtils.RandomizeActivities(dataSet,shuffle=0,runDetails=details)


def BalanceComposite(details,composite,data1=None,data2=None):
  """ balances the composite using the parameters provided in details

   **Arguments**

     - details a _CompositeRun.RunDetails_ object

     - composite: the composite model to be balanced

     - data1: (optional) if provided, this should be the
       data set used to construct the original models

     - data2: (optional) if provided, this should be the
       data set used to construct the new individual models

   **Returns**

     _composite_ itself if balancing is disabled (details.balCnt is false or
     larger than len(composite)) or if a data set could not be obtained;
     otherwise a list holding one result of AdjustComposite.BalanceComposite
     per weight in details.balWeight

  """
  if not details.balCnt or details.balCnt > len(composite):
    return composite
  message("Balancing Composite")

  #
  # start by getting data set 1: which is the data set used to build the
  #  original models
  #
  if data1 is None:
    message("\tReading First Data Set")
    savedTable = details.tableName
    savedDb = details.dbName
    # temporarily point details at the balancing table; always restore the
    # caller's settings, even if loading the data fails
    try:
      details.tableName = details.balTable.strip()
      details.dbName = details.balDb
      data1 = details.GetDataSet()
    finally:
      details.tableName = savedTable
      details.dbName = savedDb
  if data1 is None:
    return composite
  _ResetRandomization(details, composite, data1)
  namedExamples = data1.GetNamedData()
  if details.balDoHoldout or details.balDoTrain:
    trainIdx,testIdx = SplitData.SplitIndices(len(namedExamples),details.splitFrac,
                                              silent=1)
    trainExamples = [namedExamples[x] for x in trainIdx]
    testExamples = [namedExamples[x] for x in testIdx]
    if details.filterFrac != 0.0:
      trainIdx,temp = DataUtils.FilterData(trainExamples,details.filterVal,
                                           details.filterFrac,-1,
                                           indicesOnly=1)
      tmp = [trainExamples[x] for x in trainIdx]
      testExamples += [trainExamples[x] for x in temp]
      trainExamples = tmp
    if details.balDoHoldout:
      # balance against the hold-out portion instead of the training portion
      testExamples,trainExamples = trainExamples,testExamples
  else:
    trainExamples = namedExamples
  dataSet1 = trainExamples
  cols1 = [x.upper() for x in data1.GetVarNames()]
  data1 = None

  #
  # now grab data set 2: the data used to build the new individual models
  #
  if data2 is None:
    message("\tReading Second Data Set")
    data2 = details.GetDataSet()
  if data2 is None:
    return composite
  _ResetRandomization(details, composite, data2)
  dataSet2 = data2.GetNamedData()
  cols2 = [x.upper() for x in data2.GetVarNames()]
  data2 = None

  # and balance it (once per requested weight):
  res = []
  weights = details.balWeight
  if not isinstance(weights, (tuple, list)):
    weights = (weights,)
  for weight in weights:
    message("\tBalancing with Weight: %.4f"%(weight))
    res.append(AdjustComposite.BalanceComposite(composite,dataSet1,dataSet2,
                                                weight,
                                                details.balCnt,
                                                names1=cols1,names2=cols2))
  return res
Пример #8
0
def CrossValidationDriver(examples,
                          attrs=None,
                          nPossibleVals=None,
                          holdOutFrac=.3,
                          silent=0,
                          tolerance=0.3,
                          calcTotalError=0,
                          hiddenSizes=None,
                          **kwargs):
    """ Driver function for building a neural net and doing cross validation

    **Arguments**

      - examples: the full set of examples

      - attrs: a list of attributes to consider in the tree building
         *This argument is ignored*

      - nPossibleVals: a list of the number of possible values each variable can adopt
         *This argument is ignored*

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
         (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - tolerance: the tolerance for convergence of the net

      - calcTotalError: if this is true the entire data set is used to calculate
           accuracy of the net

      - hiddenSizes: a sequence (list or tuple) containing the size(s) of the
           hidden layers in the network.  If _hiddenSizes_ is None, one hidden
           layer containing the same number of nodes as the input layer will
           be used

      - replacementSelection (keyword): if true, the split is drawn with
           replacement

    **Returns**

       a 2-tuple containing:

         1) the net

         2) the cross-validation error of the net

    **Note**
      At the moment, this is specific to nets with only one output

    """
    # attrs and nPossibleVals are ignored; they default to None instead of
    # shared mutable lists
    nTot = len(examples)
    if not kwargs.get('replacementSelection', 0):
        testIndices, trainIndices = SplitData.SplitIndices(nTot,
                                                           holdOutFrac,
                                                           silent=1,
                                                           legacy=1,
                                                           replacement=0)
    else:
        testIndices, trainIndices = SplitData.SplitIndices(nTot,
                                                           holdOutFrac,
                                                           silent=1,
                                                           legacy=0,
                                                           replacement=1)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)
    if not silent:
        print('Training with %d examples' % (nTrain))

    # every entry but the last is an input; the last is presumably the target
    nInput = len(examples[0]) - 1
    nOutput = 1
    if hiddenSizes is None:
        hiddenLayers = [nInput]
    else:
        # list() so tuples (or other sequences) are accepted too
        hiddenLayers = list(hiddenSizes)
    netSize = [nInput] + hiddenLayers + [nOutput]
    net = Network.Network(netSize)
    t = Trainers.BackProp()
    t.TrainOnLine(trainExamples,
                  net,
                  errTol=tolerance,
                  useAvgErr=0,
                  silent=silent)

    nTest = len(testExamples)
    if not silent:
        print('Testing with %d examples' % nTest)
    if not calcTotalError:
        xValError, _ = CrossValidate(net, testExamples, tolerance)
    else:
        xValError, _ = CrossValidate(net, examples, tolerance)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    net._trainIndices = trainIndices
    return net, xValError
Пример #9
0
def CrossValidationDriver(examples,
                          attrs,
                          nPossibleVals,
                          holdOutFrac=.3,
                          silent=0,
                          calcTotalError=0,
                          treeBuilder=ID3.ID3Boot,
                          lessGreedy=0,
                          startAt=None,
                          nQuantBounds=[],
                          maxDepth=-1,
                          **kwargs):
    """ Builds a decision tree on a random training split and measures its
    cross-validation error.

    **Arguments**

      - examples: the full set of examples

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - holdOutFrac: the fraction of the data reserved for the hold-out set
        (used to calculate the error)

      - silent: if false, progress messages are printed

      - calcTotalError: if true the error is computed on the entire data set,
        otherwise only on the hold-out set

      - treeBuilder: the function called to build the tree

      - lessGreedy: if true, the tree is grown via _ChooseOptimalRoot_

      - startAt: forces the tree to be rooted at this descriptor

      - nQuantBounds: optional quantization bounds; when non-empty they are
        passed positionally to _treeBuilder_ (for building QuantTrees)

      - maxDepth: maximum tree depth, forwarded to the builder

      - replacementSelection (keyword): if true, the split is drawn with
        replacement; all extra keyword arguments reach the builder

    **Returns**

       a 2-tuple: (the tree, the cross-validation error of the tree)
    """
    withReplacement = 1 if kwargs.get('replacementSelection', 0) else 0
    # legacy splitting is requested only when sampling without replacement
    testIndices, trainIndices = SplitData.SplitIndices(len(examples),
                                                       holdOutFrac,
                                                       silent=1,
                                                       legacy=1 - withReplacement,
                                                       replacement=withReplacement)
    trainExamples = [examples[idx] for idx in trainIndices]
    testExamples = [examples[idx] for idx in testIndices]

    if not silent:
        print('Training with %d examples' % (len(trainExamples)))

    if lessGreedy:
        tree = ChooseOptimalRoot(examples,
                                 trainExamples,
                                 testExamples,
                                 attrs,
                                 nPossibleVals,
                                 treeBuilder,
                                 nQuantBounds,
                                 maxDepth=maxDepth,
                                 **kwargs)
    else:
        builderArgs = [trainExamples, attrs, nPossibleVals]
        # quantization bounds are only handed over when some were supplied
        if not (nQuantBounds is None or nQuantBounds == []):
            builderArgs.append(nQuantBounds)
        tree = treeBuilder(*builderArgs,
                           initialVar=startAt,
                           maxDepth=maxDepth,
                           **kwargs)

    if not silent:
        print('Testing with %d examples' % len(testExamples))
    if calcTotalError:
        evalExamples, appendFlag = examples, 0
    else:
        evalExamples, appendFlag = testExamples, 1
    xValError, badExamples = CrossValidate(tree,
                                           evalExamples,
                                           appendExamples=appendFlag)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    # attach the bookkeeping data to the tree before returning it
    tree.SetBadExamples(badExamples)
    tree.SetTrainingExamples(trainExamples)
    tree.SetTestExamples(testExamples)
    tree._trainIndices = trainIndices
    return tree, xValError