def testCvPrune(self):
    """Exercise cvPrune and repPrune on a synthetic regression problem.

    The generated data is split into roughly one third each of training,
    validation and test examples. The test checks that:

    * with gamma set to 1000, cvPrune leaves the number of vertices of the
      learned tree unchanged;
    * with gamma set to 100, every vertex of the pruned tree is also a
      vertex of the unpruned tree;
    * after relearning and calling repPrune with the validation set, the
      test RMSE is no larger than that of the unpruned tree.
    """
    numExamples = 500
    X, y = data.make_regression(numExamples)
    y = Standardiser().standardiseArray(y)

    # numpy.round returns a float, which is not a valid slice index, so the
    # split sizes are converted to plain ints before being used below.
    numTrain = int(numpy.round(numExamples * 0.33))
    numValid = int(numpy.round(numExamples * 0.33))
    validEnd = numTrain + numValid

    trainX = X[0:numTrain, :]
    trainY = y[0:numTrain]
    validX = X[numTrain:validEnd, :]
    validY = y[numTrain:validEnd]
    testX = X[validEnd:, :]
    testY = y[validEnd:]

    learner = DecisionTreeLearner()
    learner.learnModel(trainX, trainY)
    error1 = Evaluator.rootMeanSqError(learner.predict(testX), testY)

    # Keep a copy of the full tree to compare against the pruned versions.
    unprunedTree = learner.tree.copy()

    # With this gamma the tree is expected to keep all of its vertices.
    learner.setGamma(1000)
    learner.cvPrune(trainX, trainY)
    self.assertEqual(unprunedTree.getNumVertices(), learner.tree.getNumVertices())

    learner.setGamma(100)
    learner.cvPrune(trainX, trainY)

    # Test if pruned tree is subtree of current. The unpruned ids are
    # fetched once rather than on every iteration of the loop.
    unprunedVertexIds = unprunedTree.getAllVertexIds()
    for vertexId in learner.tree.getAllVertexIds():
        self.assertIn(vertexId, unprunedVertexIds)

    # The error should not get worse after pruning with the validation set.
    learner.learnModel(trainX, trainY)
    learner.repPrune(validX, validY)
    error2 = Evaluator.rootMeanSqError(learner.predict(testX), testY)

    self.assertGreaterEqual(error1, error2)
for vertexId in learner.tree.getAllVertexIds(): alpha = learner.tree.getVertex(vertexId).alpha if alpha < minAlpha: minAlpha = alpha if alpha > maxAlpha: maxAlpha = alpha numAlphas = 100 alphas = numpy.linspace(maxAlpha+0.1, minAlpha, numAlphas) errors = numpy.zeros(numAlphas) for i in range(alphas.shape[0]): #learner.learnModel(trainX, trainY) learner.setAlphaThreshold(alphas[i]) learner.cvPrune(trainX, trainY) #learner.cvPrune(validX, validY, alphas[numpy.argmin(errors)]) #learner.prune(validX, validY, alphas[i]) predY = learner.predict(testX) errors[i] = Evaluator.rootMeanSqError(predY, testY) plt.figure(3) plt.scatter(alphas, errors) #Now plot best tree plt.figure(4) learner.learnModel(trainX, trainY) #learner.cvPrune(validX, validY, alphas[numpy.argmin(errors)]) learner.setAlphaThreshold(alphas[numpy.argmin(errors)]) learner.cvPrune(trainX, trainY) rootId = learner.tree.getRootId()