def testCvPrune(self): 
     numExamples = 500
     X, y = data.make_regression(numExamples)  
     
     y = Standardiser().standardiseArray(y)
     
     numTrain = int(numpy.round(numExamples * 0.33))
     numValid = int(numpy.round(numExamples * 0.33))
     
     #Split the data into training, validation and test sets
     trainX = X[0:numTrain, :]
     trainY = y[0:numTrain]
     validX = X[numTrain:numTrain+numValid, :]
     validY = y[numTrain:numTrain+numValid]
     testX = X[numTrain+numValid:, :]
     testY = y[numTrain+numValid:]
     
     #Train an unpruned tree and record its test error for later comparison
     learner = DecisionTreeLearner()
     learner.learnModel(trainX, trainY)
     error1 = Evaluator.rootMeanSqError(learner.predict(testX), testY)
     
     #print(learner.getTree())
     unprunedTree = learner.tree.copy() 
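     #With a very large gamma, cvPrune is expected to leave the tree unchanged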
     learner.setGamma(1000)
     learner.cvPrune(trainX, trainY)
     
     self.assertEqual(unprunedTree.getNumVertices(), learner.tree.getNumVertices())
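     #Prune again with a smaller gamma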
     learner.setGamma(100)
     learner.cvPrune(trainX, trainY)
     
     #Test if pruned tree is subtree of current: 
     for vertexId in learner.tree.getAllVertexIds(): 
         self.assertTrue(vertexId in unprunedTree.getAllVertexIds())
         
     #The test error should be no worse after pruning on the validation set
     learner.learnModel(trainX, trainY)
     #learner.cvPrune(validX, validY, 0.0, 5)
     learner.repPrune(validX, validY)
   
     error2 = Evaluator.rootMeanSqError(learner.predict(testX), testY)
     
     self.assertTrue(error1 >= error2)
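
These snippets assume imports along the following lines. The module paths for DecisionTreeLearner, Standardiser and Evaluator are a best guess based on the APGL-style names used, and the synthetic regression data is assumed to come from scikit-learn's make_regression; adjust the paths to match your installation.

#Assumed imports; the apgl module paths below are a guess, not confirmed by the source
import numpy
import matplotlib.pyplot as plt
import sklearn.datasets as data
from apgl.data.Standardiser import Standardiser
from apgl.util.Evaluator import Evaluator
from apgl.predictors.DecisionTreeLearner import DecisionTreeLearner
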
Example #2
#Find the minimum and maximum alpha values over the tree vertices
minAlpha = float("inf")
maxAlpha = -float("inf")
for vertexId in learner.tree.getAllVertexIds(): 
    alpha = learner.tree.getVertex(vertexId).alpha
    
    if alpha < minAlpha: 
        minAlpha = alpha 
    if alpha > maxAlpha: 
        maxAlpha = alpha 
        
#Evaluate the test error over a grid of alpha thresholds, from just above the maximum down to the minimum
numAlphas = 100
alphas = numpy.linspace(maxAlpha+0.1, minAlpha, numAlphas)
errors = numpy.zeros(numAlphas)

for i in range(alphas.shape[0]): 
    #learner.learnModel(trainX, trainY)
    learner.setAlphaThreshold(alphas[i])
    learner.cvPrune(trainX, trainY)
    #learner.cvPrune(validX, validY, alphas[numpy.argmin(errors)])
    #learner.prune(validX, validY, alphas[i])
    predY = learner.predict(testX)
    errors[i] = Evaluator.rootMeanSqError(predY, testY)
    
#Plot the test error against each alpha threshold
plt.figure(3)
plt.scatter(alphas, errors)

#Now plot best tree 
plt.figure(4)
learner.learnModel(trainX, trainY)
#learner.cvPrune(validX, validY, alphas[numpy.argmin(errors)])
learner.setAlphaThreshold(alphas[numpy.argmin(errors)])
learner.cvPrune(trainX, trainY)
rootId = learner.tree.getRootId()
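
Example #2 ends just after retrieving the root id of the best pruned tree. A possible continuation, sketched here using only names that already appear above (bestAlpha is a hypothetical helper variable), is to report the chosen threshold and the size of the pruned tree:

#Sketch of a possible follow-up, not part of the original example
bestAlpha = alphas[numpy.argmin(errors)]
print("Best alpha threshold: " + str(bestAlpha))
print("Test RMSE at best alpha: " + str(numpy.min(errors)))
print("Vertices in pruned tree: " + str(learner.tree.getNumVertices()))
plt.show()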