Example 1
    def learnModel(self, X, y):
        nodeId = (0, )
        self.tree = DictTree()
        rootNode = DecisionNode(numpy.arange(X.shape[0]), y.mean())
        self.tree.setVertex(nodeId, rootNode)

        #Compute the rank of each element within its column (argsort applied twice)
        argsortX = numpy.zeros(X.shape, numpy.int64)
        for i in range(X.shape[1]):
            argsortX[:, i] = numpy.argsort(X[:, i])
            argsortX[:, i] = numpy.argsort(argsortX[:, i])

        self.growSkLearn(X, y)
        #self.recursiveSplit(X, y, argsortX, nodeId)
        self.unprunedTreeSize = self.tree.size

        if self.pruneType == "REP":
            #Note: This should be a separate validation set
            self.repPrune(X, y)
        elif self.pruneType == "REP-CV":
            self.cvPrune(X, y)
        elif self.pruneType == "CART":
            self.cartPrune(X, y)
        elif self.pruneType == "none":
            pass
        else:
            raise ValueError("Unknown pruning type " + self.pruneType)
Example 2
    def setUp(self):
        self.dictTree = DictTree()
        self.dictTree.setVertex("a", "foo")

        self.dictTree.addEdge("a", "b")
        self.dictTree.addEdge("a", "c")
        self.dictTree.addEdge("b", "d")
        self.dictTree.addEdge("b", "e")
        self.dictTree.addEdge("e", "f")
Example 3
    def testGetRoot(self):
        dictTree = DictTree()

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertEquals(dictTree.getRootId(), "d")

        dictTree.addEdge("e", "d")
        self.assertEquals(dictTree.getRootId(), "e")
Example 5
    def learnModel(self, X, y):
        if numpy.unique(y).shape[0] != 2:
            raise ValueError("Must provide binary labels")
        if not numpy.issubdtype(y.dtype, numpy.integer):
            raise ValueError("Labels must be integers")

        self.shapeX = X.shape
        argsortX = numpy.zeros(X.shape, numpy.int64)
        for i in range(X.shape[1]):
            argsortX[:, i] = numpy.argsort(X[:, i])
            argsortX[:, i] = numpy.argsort(argsortX[:, i])

        rootId = (0, )
        idStack = [rootId]
        self.tree = DictTree()
        rootNode = DecisionNode(numpy.arange(X.shape[0]), Util.mode(y))
        self.tree.setVertex(rootId, rootNode)
        bestError = float("inf")
        bestTree = self.tree

        #First grow a selection of trees

        while len(idStack) != 0:
            #Prune the current node away and grow from that node
            nodeId = idStack.pop()

            for i in range(self.sampleSize):
                self.tree = bestTree.deepCopy()
                try:
                    node = self.tree.getVertex(nodeId)
                except ValueError:
                    print(nodeId)
                    print(self.tree)
                    raise

                self.tree.pruneVertex(nodeId)
                self.growTree(X, y, argsortX, nodeId)
                self.prune(X, y)
                error = self.treeObjective(X, y)

                if error < bestError:
                    bestError = error
                    bestTree = self.tree.deepCopy()

            children = bestTree.children(nodeId)
            idStack.extend(children)

        self.tree = bestTree
Example 7
    def testSetVertex(self):
        dictTree = DictTree()

        dictTree.setVertex("a")
        self.assertEquals(dictTree.getVertex("a"), None)
        self.assertRaises(RuntimeError, dictTree.setVertex, "b")

        dictTree.setVertex("a", 12)
        self.assertEquals(dictTree.getVertex("a"), 12)
Example 8
    def testStr(self):
        dictTree = DictTree()

        dictTree.addEdge(0, 1)
        dictTree.addEdge(0, 2)
        dictTree.addEdge(2, 3)
        dictTree.addEdge(2, 4)
        dictTree.addEdge(0, 5)
        dictTree.addEdge(4, 6)
Example 9
    def testAddChild(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")
        dictTree.addChild("a", "c", 2)
        dictTree.addChild("a", "d", 5)

        self.assertTrue(set(dictTree.leaves()) == set(["c", "d"]))

        self.assertEquals(dictTree.getVertex("c"), 2)
        self.assertEquals(dictTree.getVertex("d"), 5)

        self.assertEquals(dictTree.getEdge("a", "d"), 1.0)
        self.assertEquals(dictTree.getEdge("a", "c"), 1.0)
Example 10
    def testAddEdge(self):

        dictTree = DictTree()

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        #Add duplicate edge
        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertRaises(ValueError, dictTree.addEdge, "e", "a")

        #Add isolated edge
        self.assertRaises(ValueError, dictTree.addEdge, "r", "s")
Example 16
    def testSplitNode(self):
        d = 0
        k = 0
        maxDepth = 1
        inds = numpy.arange(self.y.shape[0])
        treeRank = TreeRank(self.leafRanklearner)
        treeRank.setMaxDepth(maxDepth)

        node = RankNode(inds, numpy.arange(self.X.shape[1]))

        tree = DictTree()
        tree.setVertex((0, 0), node)
        tree = treeRank.splitNode(tree, self.X, self.y, d, k)

        self.assertEquals(tree.getNumVertices(), 3)
        self.assertEquals(tree.getNumEdges(), 2)
        self.assertEquals(tree.getRootId(), (0, 0))
        self.assertTrue(not tree.getVertex((0, 0)).isLeafNode())
        self.assertTrue(tree.getVertex((1, 0)).isLeafNode())
        self.assertTrue(tree.getVertex((1, 1)).isLeafNode())

        self.assertTrue(tree.depth() <= maxDepth)
Example 18
    def testSubtree(self):
        newTree = DictTree()
        newTree.addEdge("a", "b")
        newTree.addEdge("a", "c")

        subtree = newTree.subtreeAt("b")
        self.assertEquals(subtree.getAllVertexIds(), ["b"])

        subtree = newTree.subtreeAt("c")
        self.assertEquals(subtree.getAllVertexIds(), ["c"])

        subtree = newTree.subtreeAt("a")
        self.assertEquals(set(subtree.getAllVertexIds()), set(["a", "c", "b"]))
Example 19
    def testPrune(self):
        startId = (0, )
        minSplit = 20
        maxDepth = 5
        gamma = 0.05
        learner = PenaltyDecisionTree(minSplit=minSplit,
                                      maxDepth=maxDepth,
                                      gamma=gamma,
                                      pruning=False)

        trainX = self.X[100:, :]
        trainY = self.y[100:]
        testX = self.X[0:100, :]
        testY = self.y[0:100]

        argsortX = numpy.zeros(trainX.shape, numpy.int64)
        for i in range(trainX.shape[1]):
            argsortX[:, i] = numpy.argsort(trainX[:, i])
            argsortX[:, i] = numpy.argsort(argsortX[:, i])

        learner.tree = DictTree()
        rootNode = DecisionNode(numpy.arange(trainX.shape[0]),
                                Util.mode(trainY))
        learner.tree.setVertex(startId, rootNode)
        learner.growTree(trainX, trainY, argsortX, startId)
        learner.shapeX = trainX.shape
        learner.predict(trainX, trainY)
        learner.computeAlphas()

        obj1 = learner.treeObjective(trainX, trainY)
        size1 = learner.tree.getNumVertices()

        #Now we'll prune
        learner.prune(trainX, trainY)

        obj2 = learner.treeObjective(trainX, trainY)
        size2 = learner.tree.getNumVertices()

        self.assertTrue(obj1 >= obj2)
        self.assertTrue(size1 >= size2)

        #Check there are no nodes with alpha>alphaThreshold
        for vertexId in learner.tree.getAllVertexIds():
            self.assertTrue(
                learner.tree.getVertex(vertexId).alpha <=
                learner.alphaThreshold)
Example 22
    def learnModel(self, X, Y):
        """
        Learn a model for a set of examples given as the rows of the matrix X,
        with corresponding labels given in the elements of 1D array Y.

        :param X: A matrix with examples as rows
        :type X: :class:`ndarray`

        :param Y: A vector of binary labels as a 1D array
        :type Y: :class:`ndarray`
        """
        Parameter.checkClass(X, numpy.ndarray)
        Parameter.checkClass(Y, numpy.ndarray)
        Parameter.checkArray(X)
        Parameter.checkArray(Y)
        labels = numpy.unique(Y)
        if labels.shape[0] != 2:
            raise ValueError("Can only accept binary labelled data: " + str(labels))
        if (labels != numpy.array([-1, 1])).any(): 
            raise ValueError("Labels must be -1/+1: " + str(labels))
        if self.featureSize is None:
            featureSize = numpy.sqrt(X.shape[1])/float(X.shape[1])
        else: 
            featureSize = self.featureSize

        tree = DictTree()
        trainInds = numpy.arange(Y.shape[0])
        featureInds = numpy.sort(numpy.random.permutation(X.shape[1])[0:int(numpy.round(X.shape[1]*featureSize))]) 

        #Seed the tree
        node = RankNode(trainInds, featureInds)
        tree.setVertex((0, 0), node)

        for d in range(self.maxDepth):
            for k in range(2**d):
                if tree.vertexExists((d, k)):
                    node = tree.getVertex((d, k))

                    if not node.isPure() and not node.isLeafNode():
                        self.splitNode(tree, X, Y, d, k)

        self.tree = tree 
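
The (d, k) vertex ids used here index nodes by depth d and within-depth position k, so the growing loop simply scans every possible id level by level; splitting (d, k) presumably yields children (d + 1, 2k) and (d + 1, 2k + 1), which is consistent with testSplitNode above, where the root (0, 0) gains children (1, 0) and (1, 1). A standalone sketch of the enumeration (ids only, no tree):

    maxDepth = 2
    for d in range(maxDepth):
        for k in range(2**d):
            print((d, k), "->", (d + 1, 2 * k), (d + 1, 2 * k + 1))
    # (0, 0) -> (1, 0) (1, 1)
    # (1, 0) -> (2, 0) (2, 1)
    # (1, 1) -> (2, 2) (2, 3)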
Example 23
    def testCutTree(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")
        dictTree.addEdge("a", "b", 2)
        dictTree.addEdge("a", "c")
        dictTree.addEdge("c", "d", 5)
        dictTree.addEdge("c", "f")

        A = numpy.array([10, 2])
        dictTree.setVertex("b", A)

        newTree = dictTree.cut(2)
        self.assertEquals(newTree.getVertex("a"), "foo")
        self.assertTrue((newTree.getVertex("b") == A).all())
        self.assertEquals(newTree.getEdge("a", "b"), 2)
        self.assertEquals(newTree.getEdge("a", "c"), 1)
        self.assertEquals(newTree.getEdge("c", "d"), 5)
        self.assertEquals(newTree.getEdge("c", "f"), 1)
        self.assertEquals(newTree.getNumVertices(), dictTree.getNumVertices())
        self.assertEquals(newTree.getNumEdges(), dictTree.getNumEdges())

        newTree = dictTree.cut(1)
        self.assertEquals(newTree.getEdge("a", "b"), 2)
        self.assertEquals(newTree.getEdge("a", "c"), 1)
        self.assertEquals(newTree.getNumVertices(), 3)
        self.assertEquals(newTree.getNumEdges(), 2)

        newTree = dictTree.cut(0)
        self.assertEquals(newTree.getNumVertices(), 1)
        self.assertEquals(newTree.getNumEdges(), 0)
    def testDepth(self):
        dictTree = DictTree()
        self.assertEquals(dictTree.depth(), 0)
        dictTree.setVertex("a")
        self.assertEquals(dictTree.depth(), 0)

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertEquals(dictTree.depth(), 2)

        dictTree.addEdge("c", "e")
        self.assertEquals(dictTree.depth(), 3)
Example 26
class DecisionTreeLearner(AbstractPredictor): 
    def __init__(self, criterion="mse", maxDepth=10, minSplit=30, type="reg", pruneType="none", gamma=1000, folds=5, processes=None):
        """
        Need a minSplit for the internal nodes and one for leaves. 
        
        :param gamma: An upper bound on the number of vertices in the pruned tree (see setGamma). 
        """
        super(DecisionTreeLearner, self).__init__()
        self.maxDepth = maxDepth
        self.minSplit = minSplit
        self.criterion = criterion
        self.type = type
        self.pruneType = pruneType 
        self.setGamma(gamma)
        self.folds = folds
        self.processes = processes
        self.alphas = numpy.array([])
    
    def learnModel(self, X, y):
        nodeId = (0, )         
        self.tree = DictTree()
        rootNode = DecisionNode(numpy.arange(X.shape[0]), y.mean())
        self.tree.setVertex(nodeId, rootNode)

        #Compute the rank of each element within its column (argsort applied twice)
        argsortX = numpy.zeros(X.shape, numpy.int64)
        for i in range(X.shape[1]): 
            argsortX[:, i] = numpy.argsort(X[:, i])
            argsortX[:, i] = numpy.argsort(argsortX[:, i])
        
        self.growSkLearn(X, y)
        #self.recursiveSplit(X, y, argsortX, nodeId)
        self.unprunedTreeSize = self.tree.size
        
        if self.pruneType == "REP": 
            #Note: This should be a separate validation set 
            self.repPrune(X, y)
        elif self.pruneType == "REP-CV":
            self.cvPrune(X, y)
        elif self.pruneType == "CART": 
            self.cartPrune(X, y)
        elif self.pruneType == "none": 
            pass
        else:
            raise ValueError("Unknown pruning type " + self.pruneType)
     
    #@profile 
    def recursiveSplit(self, X, y, argsortX, nodeId): 
        """
        Given a sample of data and a node index, we find the best split and 
        add children to the tree accordingly. 
        """
        if len(nodeId)-1 >= self.maxDepth: 
            return 
        
        node = self.tree.getVertex(nodeId)
        bestError, bestFeatureInd, bestThreshold, bestLeftInds, bestRightInds = findBestSplit(self.minSplit, X, y, node.getTrainInds(), argsortX)
    
        #The split may have 0 items in one set, so don't split 
        if bestLeftInds.sum() != 0 and bestRightInds.sum() != 0: 
            node.setError(bestError)
            node.setFeatureInd(bestFeatureInd)
            node.setThreshold(bestThreshold)
            
            leftChildId = self.getLeftChildId(nodeId)
            leftChild = DecisionNode(bestLeftInds, y[bestLeftInds].mean())
            self.tree.addChild(nodeId, leftChildId, leftChild)
            
            if leftChild.getTrainInds().shape[0] >= self.minSplit: 
                self.recursiveSplit(X, y, argsortX, leftChildId)
            
            rightChildId = self.getRightChildId(nodeId)
            rightChild = DecisionNode(bestRightInds, y[bestRightInds].mean())
            self.tree.addChild(nodeId, rightChildId, rightChild)
            
            if rightChild.getTrainInds().shape[0] >= self.minSplit: 
                self.recursiveSplit(X, y, argsortX, rightChildId)
    
    def growSkLearn(self, X, y): 
        """
        Grow a decision tree from sklearn. 
        """
        
        from sklearn.tree import DecisionTreeRegressor
        regressor = DecisionTreeRegressor(max_depth = self.maxDepth, min_samples_split=self.minSplit)
        regressor.fit(X, y)
        
        #Convert the sklearn tree into our tree (note: this relies on the legacy
        #sklearn tree_ attribute layout: value, best_error, feature, threshold, children)
        nodeId = (0, )          
        nodeStack = [(nodeId, 0)] 
        
        node = DecisionNode(numpy.arange(X.shape[0]), regressor.tree_.value[0])
        self.tree.setVertex(nodeId, node)
        
        while len(nodeStack) != 0: 
            nodeId, nodeInd = nodeStack.pop()
            
            node = self.tree.getVertex(nodeId)
            node.setError(regressor.tree_.best_error[nodeInd])
            node.setFeatureInd(regressor.tree_.feature[nodeInd])
            node.setThreshold(regressor.tree_.threshold[nodeInd])
                
            if regressor.tree_.children[nodeInd, 0] != -1: 
                leftChildInds = node.getTrainInds()[X[node.getTrainInds(), node.getFeatureInd()] < node.getThreshold()] 
                leftChildId = self.getLeftChildId(nodeId)
                leftChild = DecisionNode(leftChildInds, regressor.tree_.value[regressor.tree_.children[nodeInd, 0]])
                self.tree.addChild(nodeId, leftChildId, leftChild)
                nodeStack.append((self.getLeftChildId(nodeId), regressor.tree_.children[nodeInd, 0]))
                
            if regressor.tree_.children[nodeInd, 1] != -1: 
                rightChildInds = node.getTrainInds()[X[node.getTrainInds(), node.getFeatureInd()] >= node.getThreshold()]
                rightChildId = self.getRightChildId(nodeId)
                rightChild = DecisionNode(rightChildInds, regressor.tree_.value[regressor.tree_.children[nodeInd, 1]])
                self.tree.addChild(nodeId, rightChildId, rightChild)
                nodeStack.append((self.getRightChildId(nodeId), regressor.tree_.children[nodeInd, 1]))

    
    def predict(self, X): 
        """
        Make a prediction for the set of examples given in the matrix X. 
        """
        rootId = (0,)
        predY = numpy.zeros(X.shape[0])
        self.tree.getVertex(rootId).setTestInds(numpy.arange(X.shape[0]))
        predY = self.recursivePredict(X, predY, rootId)
        
        return predY 
        
    def recursivePredict(self, X, y, nodeId): 
        """
        Recurse through the tree and assign examples to the correct vertex. 
        """        
        node = self.tree.getVertex(nodeId)
        testInds = node.getTestInds()
        
        if self.tree.isLeaf(nodeId): 
            y[testInds] = node.getValue()
        else: 
             
            for childId in [self.getLeftChildId(nodeId), self.getRightChildId(nodeId)]:
                if self.tree.vertexExists(childId):
                    child = self.tree.getVertex(childId)
    
                    if childId[-1] == 0: 
                        childInds = X[testInds, node.getFeatureInd()] < node.getThreshold() 
                    else:
                        childInds = X[testInds, node.getFeatureInd()] >= node.getThreshold()
                    
                    child.setTestInds(testInds[childInds])   
                    y = self.recursivePredict(X, y, childId)
                
        return y
        
    def recursiveSetPrune(self, X, y, nodeId):
        """
        This computes test errors on nodes by passing in the test X and y. 
        """
        node = self.tree.getVertex(nodeId)
        testInds = node.getTestInds()
        node.setTestError(self.vertexTestError(y[testInds], node.getValue()))
    
        for childId in [self.getLeftChildId(nodeId), self.getRightChildId(nodeId)]:
            if self.tree.vertexExists(childId):
                child = self.tree.getVertex(childId)
                
                if childId[-1] == 0: 
                    childInds = X[testInds, node.getFeatureInd()] < node.getThreshold() 
                else:
                    childInds = X[testInds, node.getFeatureInd()] >= node.getThreshold()
                child.setTestInds(testInds[childInds])
                self.recursiveSetPrune(X, y, childId)
    
    def vertexTestError(self, trueY, predY):
        """
        This is the error used for pruning. We compute it at each node. 
        """
        return numpy.sum((trueY - predY)**2)
    
    def computeAlphas(self): 
        self.minAlpha = float("inf")
        self.maxAlpha = -float("inf")        
        
        for vertexId in self.tree.getAllVertexIds(): 
            currentNode = self.tree.getVertex(vertexId)
            subtreeLeaves = self.tree.leaves(vertexId)

            testErrorSum = 0 
            for leaf in subtreeLeaves: 
                testErrorSum += self.tree.getVertex(leaf).getTestError()
            
            #Alpha is normalised difference in error 
            if currentNode.getTestInds().shape[0] != 0: 
                currentNode.alpha = (testErrorSum - currentNode.getTestError())/float(currentNode.getTestInds().shape[0])       
                
                if currentNode.alpha < self.minAlpha:
                    self.minAlpha = currentNode.alpha 
                
                if currentNode.alpha > self.maxAlpha: 
                    self.maxAlpha = currentNode.alpha
                    
    def computeCARTAlphas(self, X):
        """
        Solve for the CART complexity based pruning. 
        """
        self.minAlpha = float("inf")
        self.maxAlpha = -float("inf")      
        alphas = [] 
        
        for vertexId in self.tree.getAllVertexIds(): 
            currentNode = self.tree.getVertex(vertexId)
            subtreeLeaves = self.tree.leaves(vertexId)

            testErrorSum = 0 
            for leaf in subtreeLeaves: 
                testErrorSum += self.tree.getVertex(leaf).getTestError()
            
            #Alpha is reduction in error per leaf - larger alphas are better 
            if currentNode.getTestInds().shape[0] != 0 and len(subtreeLeaves) != 1: 
                currentNode.alpha = (currentNode.getTestError() - testErrorSum)/float(X.shape[0]*(len(subtreeLeaves)-1))
                #Flip alpha so that pruning works 
                currentNode.alpha = -currentNode.alpha
                
                alphas.append(currentNode.alpha)
                
                """
                if currentNode.alpha < self.minAlpha:
                    self.minAlpha = currentNode.alpha 
                
                if currentNode.alpha > self.maxAlpha: 
                    self.maxAlpha = currentNode.alpha   
                """
        alphas = numpy.array(alphas)
        self.alphas = numpy.unique(alphas)
        self.minAlpha = numpy.min(self.alphas)
        self.maxAlpha = numpy.max(self.alphas)

    def repPrune(self, validX, validY): 
        """
        Prune the decision tree using reduced error pruning. 
        """
        rootId = (0,)
        self.tree.getVertex(rootId).setTestInds(numpy.arange(validX.shape[0]))
        self.recursiveSetPrune(validX, validY, rootId)        
        self.computeAlphas()        
        self.prune()
                            
    def prune(self): 
        """
        We prune as early as possible and make sure the final tree has at most 
        gamma vertices. 
        """
        i = self.alphas.shape[0]-1 
        #print(self.alphas)
        
        while self.tree.getNumVertices() > self.gamma and i >= 0: 
            #print(self.alphas[i], self.tree.getNumVertices())
            alphaThreshold = self.alphas[i] 
            toPrune = []
            
            for vertexId in self.tree.getAllVertexIds(): 
                if self.tree.getVertex(vertexId).alpha >= alphaThreshold: 
                    toPrune.append(vertexId)

            for vertexId in toPrune: 
                if self.tree.vertexExists(vertexId):
                    self.tree.pruneVertex(vertexId)                    
                    
            i -= 1

                    
    def cartPrune(self, trainX, trainY): 
        """
        Prune the tree according to the CART algorithm. Here, the chosen 
        tree is selected by thresholding alpha. In CART itself the best 
        tree is selected by using an independent pruning set. 
        """
        rootId = (0,)
        self.tree.getVertex(rootId).setTestInds(numpy.arange(trainX.shape[0]))
        self.recursiveSetPrune(trainX, trainY, rootId)        
        self.computeCARTAlphas(trainX)    
        self.prune()
                
    def cvPrune(self, validX, validY): 
        """
        We do something like reduced error pruning but we use cross validation 
        to decide which nodes to prune. 
        """
        
        #First set the value of the vertices using the training set. 
        #Reset all alphas to zero 
        inds = Sampling.crossValidation(self.folds, validX.shape[0])
        
        for i in self.tree.getAllVertexIds(): 
            self.tree.getVertex(i).setAlpha(0.0)
            self.tree.getVertex(i).setTestError(0.0)
        
        for trainInds, testInds in inds:             
            rootId = (0,)
            root = self.tree.getVertex(rootId)
            root.setTrainInds(trainInds)
            root.setTestInds(testInds)
            root.tempValue = numpy.mean(validY[trainInds])
            
            nodeStack = [(rootId, root.tempValue)]
            
            while len(nodeStack) != 0: 
                (nodeId, value) = nodeStack.pop()
                node = self.tree.getVertex(nodeId)
                tempTrainInds = node.getTrainInds()
                tempTestInds = node.getTestInds()
                node.setTestError(numpy.sum((validY[tempTestInds] - node.tempValue)**2) + node.getTestError())
                childIds = [self.getLeftChildId(nodeId), self.getRightChildId(nodeId)]
                
                for childId in childIds:                 
                    if self.tree.vertexExists(childId): 
                        child = self.tree.getVertex(childId)
                        
                        if childId[-1] == 0: 
                            childInds = validX[tempTrainInds, node.getFeatureInd()] < node.getThreshold()
                        else: 
                            childInds = validX[tempTrainInds, node.getFeatureInd()] >= node.getThreshold()
                        
                        if childInds.sum() !=0:   
                            value = numpy.mean(validY[tempTrainInds[childInds]])
                            
                        child.tempValue = value 
                        child.setTrainInds(tempTrainInds[childInds])
                        nodeStack.append((childId, value))
                        
                        if childId[-1] == 0: 
                            childInds = validX[tempTestInds, node.getFeatureInd()] < node.getThreshold() 
                        else: 
                            childInds = validX[tempTestInds, node.getFeatureInd()] >= node.getThreshold()  
                         
                        child.setTestInds(tempTestInds[childInds])
        
        self.computeAlphas()
        self.prune()
        
    def copy(self): 
        """
        Copies parameter values only 
        """
        newLearner = DecisionTreeLearner(self.criterion, self.maxDepth, self.minSplit, self.type, self.pruneType, self.gamma, self.folds)
        return newLearner 
        
    def getMetricMethod(self): 
        if self.type == "reg": 
            #return Evaluator.rootMeanSqError
            return Evaluator.meanAbsError
            #return Evaluator.meanSqError
        else:
            return Evaluator.binaryError      
            
    def getAlphaThreshold(self): 
        #return self.maxAlpha - (self.maxAlpha - self.minAlpha)*self.gamma
        #A more natural way of defining gamma 
        return self.alphas[int(numpy.round((1-self.gamma)*(self.alphas.shape[0]-1)))]
        
    def setGamma(self, gamma): 
        """
        Gamma is an upper bound on the number of nodes in the tree. 
        """
        Parameter.checkInt(gamma, 1, float("inf"))
        self.gamma = gamma
        
    def getGamma(self): 
        return self.gamma 
        
    def setPruneCV(self, folds): 
        Parameter.checkInt(folds, 1, float("inf"))
        self.folds = folds
        
    def getPruneCV(self): 
        return self.folds
        
    def getLeftChildId(self, nodeId): 
        leftChildId = list(nodeId)
        leftChildId.append(0)
        leftChildId = tuple(leftChildId)
        return leftChildId

    def getRightChildId(self, nodeId): 
        rightChildId = list(nodeId)
        rightChildId.append(1)
        rightChildId = tuple(rightChildId) 
        return rightChildId
   
    def getTree(self): 
        return self.tree 
        
    def complexity(self): 
        return self.tree.size
        
    def getBestLearner(self, meanErrors, paramDict, X, y, idx=None): 
        """
        Given a grid of errors, paramDict and examples, labels, find the 
        best learner and train it. In this case we set gamma to the real 
        size of the tree as learnt using CV. If idx == None then we simply 
        use the gamma corresponding to the lowest error. 
        """
        if idx is None: 
            return super(DecisionTreeLearner, self).getBestLearner(meanErrors, paramDict, X, y, idx)
        
        bestInds = numpy.unravel_index(numpy.argmin(meanErrors), meanErrors.shape)
        currentInd = 0    
        learner = self.copy()         
    
        for key, val in paramDict.items():
            method = getattr(learner, key)
            method(val[bestInds[currentInd]])
            currentInd += 1 
         
        treeSizes = []
        for trainInds, testInds in idx: 
            validX = X[trainInds, :]
            validY = y[trainInds]
            learner.learnModel(validX, validY)
            
            treeSizes.append(learner.tree.getNumVertices())
        
        bestGamma = int(numpy.round(numpy.array(treeSizes).mean()))
        
        learner.setGamma(bestGamma)
        learner.learnModel(X, y)            
        return learner 
        
    def getUnprunedTreeSize(self): 
        """
        Return the size of the tree before pruning was performed. 
        """
        return self.unprunedTreeSize

    def parallelPen(self, X, y, idx, paramDict, Cvs):
        """
        Perform parallel penalisation using any learner. 
        Using the best set of parameters train using the whole dataset. In this 
        case if gamma > max(treeSize) the penalty is infinite. 

        :param X: The examples as rows
        :type X: :class:`numpy.ndarray`

        :param y: The binary -1/+1 labels 
        :type y: :class:`numpy.ndarray`

        :param idx: A list of train/test splits

        :param paramDict: A dictionary indexed by method name whose values are arrays of parameter values
        :type paramDict: :class:`dict`

        """
        return super(DecisionTreeLearner, self).parallelPen(X, y, idx, paramDict, Cvs, computeVFPenTree)
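
In contrast to the (d, k) scheme above, DecisionTreeLearner identifies a node by its root-to-node path: the root is (0, ) and getLeftChildId/getRightChildId append a 0 or 1 per step, so len(nodeId) - 1 is the node's depth, exactly the quantity recursiveSplit checks against maxDepth. A minimal standalone sketch of the scheme (hypothetical helper names):

    def leftChildId(nodeId):
        #Append 0 for a left step, mirroring getLeftChildId above
        return nodeId + (0,)

    def rightChildId(nodeId):
        #Append 1 for a right step, mirroring getRightChildId above
        return nodeId + (1,)

    root = (0,)
    print(leftChildId(root))                # (0, 0)
    print(rightChildId(leftChildId(root)))  # (0, 0, 1): left then right
    print(len((0, 0, 1)) - 1)               # depth 2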
        
class DictGraphTest(unittest.TestCase):
    def setUp(self):
        self.dictTree = DictTree()
        self.dictTree.setVertex("a", "foo")
        
        self.dictTree.addEdge("a", "b")
        self.dictTree.addEdge("a", "c")
        self.dictTree.addEdge("b", "d")
        self.dictTree.addEdge("b", "e")
        self.dictTree.addEdge("e", "f")
    

    def testInit(self):
        dictTree = DictTree()

    def testAddEdge(self):

        dictTree = DictTree()

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        #Add duplicate edge 
        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertRaises(ValueError, dictTree.addEdge, "e", "a")
        
        #Add isolated edge
        self.assertRaises(ValueError, dictTree.addEdge, "r", "s")


    def testGetRoot(self):
        dictTree = DictTree()

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertEquals(dictTree.getRootId(), "d")

        dictTree.addEdge("e", "d")
        self.assertEquals(dictTree.getRootId(), "e")

    def testSetVertex(self):
        dictTree = DictTree()

        dictTree.setVertex("a")
        self.assertEquals(dictTree.getVertex("a"), None)
        self.assertRaises(RuntimeError, dictTree.setVertex, "b")

        dictTree.setVertex("a", 12)
        self.assertEquals(dictTree.getVertex("a"), 12)

    def testStr(self):
        dictTree = DictTree()

        dictTree.addEdge(0, 1)
        dictTree.addEdge(0, 2)
        dictTree.addEdge(2, 3)
        dictTree.addEdge(2, 4)
        dictTree.addEdge(0, 5)
        dictTree.addEdge(4, 6)


    def testDepth(self):
        dictTree = DictTree()
        self.assertEquals(dictTree.depth(), 0)
        dictTree.setVertex("a")
        self.assertEquals(dictTree.depth(), 0)

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertEquals(dictTree.depth(), 2)

        dictTree.addEdge("c", "e")
        self.assertEquals(dictTree.depth(), 3)

    def testCutTree(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")
        dictTree.addEdge("a", "b", 2)
        dictTree.addEdge("a", "c")
        dictTree.addEdge("c", "d", 5)
        dictTree.addEdge("c", "f")

        A = numpy.array([10, 2])
        dictTree.setVertex("b", A)

        newTree = dictTree.cut(2)
        self.assertEquals(newTree.getVertex("a"), "foo")
        self.assertTrue((newTree.getVertex("b") == A).all())
        self.assertEquals(newTree.getEdge("a", "b"), 2)
        self.assertEquals(newTree.getEdge("a", "c"), 1)
        self.assertEquals(newTree.getEdge("c", "d"), 5)
        self.assertEquals(newTree.getEdge("c", "f"), 1)
        self.assertEquals(newTree.getNumVertices(), dictTree.getNumVertices())
        self.assertEquals(newTree.getNumEdges(), dictTree.getNumEdges())

        newTree = dictTree.cut(1)
        self.assertEquals(newTree.getEdge("a", "b"), 2)
        self.assertEquals(newTree.getEdge("a", "c"), 1)
        self.assertEquals(newTree.getNumVertices(), 3)
        self.assertEquals(newTree.getNumEdges(), 2)

        newTree = dictTree.cut(0)
        self.assertEquals(newTree.getNumVertices(), 1)
        self.assertEquals(newTree.getNumEdges(), 0)

    def testLeaves(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")

        self.assertTrue(set(dictTree.leaves()) == set(["a"]))
        dictTree.addEdge("a", "b", 2)
        dictTree.addEdge("a", "c")
        dictTree.addEdge("c", "d", 5)
        dictTree.addEdge("c", "f")

        self.assertTrue(set(dictTree.leaves()) == set(["b", "d", "f"]))

        dictTree.addEdge("b", 1)
        dictTree.addEdge("b", 2)
        self.assertTrue(set(dictTree.leaves()) == set([1, 2, "d", "f"]))
        
        #Test isSubtree leaves 
        self.assertTrue(set(dictTree.leaves("c")) == set(["d", "f"]))
        self.assertTrue(set(dictTree.leaves("b")) == set([1, 2]))


    def testAddChild(self): 
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")
        dictTree.addChild("a", "c", 2)
        dictTree.addChild("a", "d", 5)

        self.assertTrue(set(dictTree.leaves()) == set(["c", "d"]))
        
        self.assertEquals(dictTree.getVertex("c"), 2)
        self.assertEquals(dictTree.getVertex("d"), 5)
        
        self.assertEquals(dictTree.getEdge("a", "d"), 1.0)
        self.assertEquals(dictTree.getEdge("a", "c"), 1.0)
        
    def testPruneVertex(self): 
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")
        
        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("b", "d")
        dictTree.addEdge("b", "e")
        dictTree.addEdge("e", "f")
    
        dictTree.pruneVertex("b")
        self.assertFalse(dictTree.edgeExists("b", "e"))
        self.assertFalse(dictTree.edgeExists("b", "d"))
        self.assertFalse(dictTree.edgeExists("e", "f"))
        self.assertTrue(dictTree.vertexExists("b"))
        self.assertFalse(dictTree.vertexExists("d"))
        self.assertFalse(dictTree.vertexExists("e"))
        self.assertFalse(dictTree.vertexExists("f"))

        dictTree.pruneVertex("a")
        self.assertEquals(dictTree.getNumVertices(), 1)
        
    def testIsLeaf(self):         
        self.assertTrue(self.dictTree.isLeaf("c"))
        self.assertTrue(self.dictTree.isLeaf("d"))
        self.assertTrue(self.dictTree.isLeaf("f"))
        self.assertFalse(self.dictTree.isLeaf("a"))
        self.assertFalse(self.dictTree.isLeaf("b"))
        self.assertFalse(self.dictTree.isLeaf("e"))
        
    def testIsNonLeaf(self):         
        self.assertFalse(self.dictTree.isNonLeaf("c"))
        self.assertFalse(self.dictTree.isNonLeaf("d"))
        self.assertFalse(self.dictTree.isNonLeaf("f"))
        self.assertTrue(self.dictTree.isNonLeaf("a"))
        self.assertTrue(self.dictTree.isNonLeaf("b"))
        self.assertTrue(self.dictTree.isNonLeaf("e"))
        
    def testCopy(self): 
        newTree = self.dictTree.copy()
        
        newTree.addEdge("f", "x")
        newTree.addEdge("f", "y")
        
        self.assertEquals(newTree.getNumVertices(), self.dictTree.getNumVertices()+2)
        self.assertTrue(newTree.vertexExists("x"))
        self.assertTrue(newTree.vertexExists("y"))
        self.assertTrue(not self.dictTree.vertexExists("x"))
        self.assertTrue(not self.dictTree.vertexExists("y"))
        
    def testisSubtree(self): 
        newTree = DictTree()
        newTree.addEdge("a", "b")
        newTree.addEdge("a", "c")

        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        newTree.addEdge("b", "d")
        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        newTree.addEdge("b", "e")        
        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        newTree.addEdge("e", "f")
        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        newTree.addEdge("a", "g")
        self.assertFalse(newTree.isSubtree(self.dictTree))
        
        newTree = DictTree()
        newTree.addEdge("b", "d")
        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        newTree.addEdge("b", "e")        
        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        newTree.addEdge("e", "f")
        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        newTree.addEdge("f", "g")
        self.assertFalse(newTree.isSubtree(self.dictTree))
        
        newTree = DictTree()
        newTree.setVertex("b")
        self.assertTrue(newTree.isSubtree(self.dictTree))
        
        self.assertFalse(self.dictTree.isSubtree(newTree))
        
        self.assertTrue(self.dictTree.isSubtree(self.dictTree))

    def testDeepCopy(self): 
        class A: 
            def __init__(self, x, y): 
                self.x = x      
                self.y = y

        a = A(1, numpy.array([1, 2]))        
        self.dictTree.setVertex("a", a)
        newTree = self.dictTree.deepCopy()  
        newTree.addEdge("f", "x")
        newTree.addEdge("f", "y")
        
        self.assertEquals(newTree.getNumVertices(), self.dictTree.getNumVertices()+2)
        self.assertTrue(newTree.vertexExists("x"))
        self.assertTrue(newTree.vertexExists("y"))
        self.assertTrue(not self.dictTree.vertexExists("x"))
        self.assertTrue(not self.dictTree.vertexExists("y"))
        self.assertEquals(self.dictTree.getVertex("a"), a)
        
        self.assertEquals(newTree.getVertex("a").x, 1)
        self.assertEquals(self.dictTree.getVertex("a").x, 1)
        a.x = 10
        self.assertEquals(newTree.getVertex("a").x, 1)
        self.assertEquals(self.dictTree.getVertex("a").x, 10)
        
        nptst.assert_array_equal(newTree.getVertex("a").y, numpy.array([1, 2])) 
        nptst.assert_array_equal(self.dictTree.getVertex("a").y, numpy.array([1, 2]))
        a.y = numpy.array([1,2,3])
        nptst.assert_array_equal(newTree.getVertex("a").y, numpy.array([1, 2])) 
        nptst.assert_array_equal(self.dictTree.getVertex("a").y, numpy.array([1, 2, 3]))

    def testSubtree(self): 
        newTree = DictTree()
        newTree.addEdge("a", "b")
        newTree.addEdge("a", "c")     
        
        subtree = newTree.subtreeAt("b")
        self.assertEquals(subtree.getAllVertexIds(), ["b"])        
        

        subtree = newTree.subtreeAt("c")
        self.assertEquals(subtree.getAllVertexIds(), ["c"])    
        
        subtree = newTree.subtreeAt("a")
        self.assertEquals(set(subtree.getAllVertexIds()), set(["a", "c", "b"]))
Example 35
class PenaltyDecisionTree(AbstractPredictor):
    def __init__(self,
                 criterion="gain",
                 maxDepth=10,
                 minSplit=30,
                 learnType="reg",
                 pruning=True,
                 gamma=0.01,
                 sampleSize=10):
        """
        Learn a decision tree with penalty proportional to the square root of the 
        size of the tree, as in Nobel 2002. We use a stochastic approach in which we 
        learn a set of trees randomly and choose the best one. 

        :param criterion: The splitting criterion, currently only information gain 

        :param maxDepth: The maximum depth of the tree 
        :type maxDepth: `int`

        :param minSplit: The minimum size of a node for it to be split. 
        :type minSplit: `int`
        
        :param learnType: The type of learning to perform. Currently only regression 
        
        :param pruning: Whether to perform pruning or not. 
        :type pruning: `boolean`
        
        :param gamma: The weight on the penalty factor between 0 and 1
        :type gamma: `float`
        
        :param sampleSize: The number of trees to learn in the stochastic search. 
        :type sampleSize: `int`
        """
        super(PenaltyDecisionTree, self).__init__()
        self.maxDepth = maxDepth
        self.minSplit = minSplit
        self.criterion = criterion
        self.learnType = learnType
        self.setGamma(gamma)
        self.setSampleSize(sampleSize)
        self.pruning = pruning
        self.alphaThreshold = 0.0

    def setGamma(self, gamma):
        Parameter.checkFloat(gamma, 0.0, 1.0)
        self.gamma = gamma

    def setSampleSize(self, sampleSize):
        Parameter.checkInt(sampleSize, 1, float("inf"))
        self.sampleSize = sampleSize

    def setAlphaThreshold(self, alphaThreshold):
        Parameter.checkFloat(alphaThreshold, -float("inf"), float("inf"))
        self.alphaThreshold = alphaThreshold

    def getAlphaThreshold(self):
        return self.alphaThreshold

    def getLeftChildId(self, nodeId):
        leftChildId = list(nodeId)
        leftChildId.append(0)
        leftChildId = tuple(leftChildId)
        return leftChildId

    def getRightChildId(self, nodeId):
        rightChildId = list(nodeId)
        rightChildId.append(1)
        rightChildId = tuple(rightChildId)
        return rightChildId

    def getTree(self):
        return self.tree

    def learnModel(self, X, y):
        if numpy.unique(y).shape[0] != 2:
            raise ValueError("Must provide binary labels")
        if not numpy.issubdtype(y.dtype, numpy.integer):
            raise ValueError("Labels must be integers")

        self.shapeX = X.shape
        argsortX = numpy.zeros(X.shape, numpy.int64)
        for i in range(X.shape[1]):
            argsortX[:, i] = numpy.argsort(X[:, i])
            argsortX[:, i] = numpy.argsort(argsortX[:, i])

        rootId = (0, )
        idStack = [rootId]
        self.tree = DictTree()
        rootNode = DecisionNode(numpy.arange(X.shape[0]), Util.mode(y))
        self.tree.setVertex(rootId, rootNode)
        bestError = float("inf")
        bestTree = self.tree

        #First grow a selection of trees

        while len(idStack) != 0:
            #Prune the current node away and grow from that node
            nodeId = idStack.pop()

            for i in range(self.sampleSize):
                self.tree = bestTree.deepCopy()
                try:
                    node = self.tree.getVertex(nodeId)
                except ValueError:
                    print(nodeId)
                    print(self.tree)
                    raise

                self.tree.pruneVertex(nodeId)
                self.growTree(X, y, argsortX, nodeId)
                self.prune(X, y)
                error = self.treeObjective(X, y)

                if error < bestError:
                    bestError = error
                    bestTree = self.tree.deepCopy()

            children = bestTree.children(nodeId)
            idStack.extend(children)

        self.tree = bestTree

    def growTree(self, X, y, argsortX, startId):
        """
        Grow a tree using a stack. Given a sample of data and a node index, we 
        find the best split and add children to the tree accordingly. We perform 
        pre-pruning based on the penalty. 
        """
        eps = 10**-4
        idStack = [startId]

        while len(idStack) != 0:
            nodeId = idStack.pop()
            node = self.tree.getVertex(nodeId)
            accuracies, thresholds = findBestSplitRisk(self.minSplit, X, y,
                                                       node.getTrainInds(),
                                                       argsortX)

            #Choose best feature based on gains
            accuracies += eps
            bestFeatureInd = Util.randomChoice(accuracies)[0]
            bestThreshold = thresholds[bestFeatureInd]

            nodeInds = node.getTrainInds()
            bestLeftInds = numpy.sort(
                nodeInds[X[nodeInds, bestFeatureInd] < bestThreshold])
            bestRightInds = numpy.sort(
                nodeInds[X[nodeInds, bestFeatureInd] >= bestThreshold])

            #The split may leave 0 items in one set, so don't split in that
            #case (emptiness is tested with shape, since an index array
            #containing only index 0 also sums to 0)
            if (bestLeftInds.shape[0] != 0 and bestRightInds.shape[0] != 0
                    and self.tree.depth() < self.maxDepth):
                node.setError(1 - accuracies[bestFeatureInd])
                node.setFeatureInd(bestFeatureInd)
                node.setThreshold(bestThreshold)

                leftChildId = self.getLeftChildId(nodeId)
                leftChild = DecisionNode(bestLeftInds,
                                         Util.mode(y[bestLeftInds]))
                self.tree.addChild(nodeId, leftChildId, leftChild)

                if leftChild.getTrainInds().shape[0] >= self.minSplit:
                    idStack.append(leftChildId)

                rightChildId = self.getRightChildId(nodeId)
                rightChild = DecisionNode(bestRightInds,
                                          Util.mode(y[bestRightInds]))
                self.tree.addChild(nodeId, rightChildId, rightChild)

                if rightChild.getTrainInds().shape[0] >= self.minSplit:
                    idStack.append(rightChildId)

    def predict(self, X, y=None):
        """
        Make a prediction for the set of examples given in the matrix X. If
        a label vector y is passed in, we set the test error of each node;
        if y is None, no errors are set.
        """
        rootId = (0, )
        predY = numpy.zeros(X.shape[0])
        self.tree.getVertex(rootId).setTestInds(numpy.arange(X.shape[0]))
        idStack = [rootId]

        while len(idStack) != 0:
            nodeId = idStack.pop()
            node = self.tree.getVertex(nodeId)
            testInds = node.getTestInds()
            if y is not None:
                node.setTestError(
                    self.vertexTestError(y[testInds], node.getValue()))

            if self.tree.isLeaf(nodeId):
                predY[testInds] = node.getValue()
            else:

                for childId in [
                        self.getLeftChildId(nodeId),
                        self.getRightChildId(nodeId)
                ]:
                    if self.tree.vertexExists(childId):
                        child = self.tree.getVertex(childId)

                        if childId[-1] == 0:
                            childInds = X[
                                testInds,
                                node.getFeatureInd()] < node.getThreshold()
                        else:
                            childInds = X[
                                testInds,
                                node.getFeatureInd()] >= node.getThreshold()

                        child.setTestInds(testInds[childInds])
                        idStack.append(childId)

        return predY

    def treeObjective(self, X, y):
        """
        Return the empirical risk plus penalty for the tree:
        (1 - gamma) * (error rate) + gamma * sqrt(number of vertices).
        """
        predY = self.predict(X)
        (n, d) = X.shape
        return (1 - self.gamma) * numpy.sum(predY != y) / float(
            n) + self.gamma * numpy.sqrt(self.tree.getNumVertices())
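
        #Worked example with made-up numbers: gamma = 0.01, 5 mistakes on
        #n = 100 examples and a tree of 15 vertices gives
        #    0.99 * 0.05 + 0.01 * sqrt(15) ~= 0.0495 + 0.0387 = 0.0882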

    def prune(self, X, y):
        """
        Post-prune the tree greedily: remove any vertex whose alpha exceeds
        alphaThreshold, recomputing alphas after each removal.
        """
        self.predict(X, y)
        self.computeAlphas()

        #Do the pruning, recomputing alpha along the way
        rootId = (0, )
        idStack = [rootId]

        while len(idStack) != 0:
            nodeId = idStack.pop()
            node = self.tree.getVertex(nodeId)

            if node.alpha > self.alphaThreshold:
                self.tree.pruneVertex(nodeId)
                self.computeAlphas()
            else:
                for childId in [
                        self.getLeftChildId(nodeId),
                        self.getRightChildId(nodeId)
                ]:
                    if self.tree.vertexExists(childId):
                        idStack.append(childId)

    def vertexTestError(self, trueY, predY):
        """
        The error used for pruning: the number of misclassified examples at
        a node.
        """
        return numpy.sum(trueY != predY)

    def computeAlphas(self):
        """
        The alpha value at each vertex is the improvement in the objective by 
        pruning at that vertex.  
        """
        n = self.shapeX[0]

        for vertexId in self.tree.getAllVertexIds():
            currentNode = self.tree.getVertex(vertexId)
            subtreeLeaves = self.tree.leaves(vertexId)

            subtreeError = 0
            for leaf in subtreeLeaves:
                subtreeError += self.tree.getVertex(leaf).getTestError()

            T = self.tree.getNumVertices()
            T2 = T - len(self.tree.subtreeIds(vertexId)) + 1
            currentNode.alpha = (1 - self.gamma) * (subtreeError -
                                                    currentNode.getTestError())
            currentNode.alpha /= n
            currentNode.alpha += self.gamma * numpy.sqrt(T)
            currentNode.alpha -= self.gamma * numpy.sqrt(T2)

    def copy(self):
        """
        Create a new learner with the same parameters.
        """
        newLearner = PenaltyDecisionTree(criterion=self.criterion,
                                         maxDepth=self.maxDepth,
                                         minSplit=self.minSplit,
                                         learnType=self.learnType,
                                         pruning=self.pruning,
                                         gamma=self.gamma,
                                         sampleSize=self.sampleSize)
        return newLearner

    def getMetricMethod(self):
        """ 
        Returns a way to measure the performance of the classifier.
        """
        return Evaluator.binaryError
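
    #A hedged usage sketch (trainX, trainY and testX are assumed to exist;
    #they are not defined in this snippet):
    #
    #    learner = PenaltyDecisionTree(gamma=0.05, sampleSize=5)
    #    learner.learnModel(trainX, trainY)    #binary integer labels required
    #    predY = learner.predict(testX)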
Example n. 37
class DictGraphTest(unittest.TestCase):
    def setUp(self):
        self.dictTree = DictTree()
        self.dictTree.setVertex("a", "foo")

        self.dictTree.addEdge("a", "b")
        self.dictTree.addEdge("a", "c")
        self.dictTree.addEdge("b", "d")
        self.dictTree.addEdge("b", "e")
        self.dictTree.addEdge("e", "f")

    def testInit(self):
        dictTree = DictTree()

    def testAddEdge(self):

        dictTree = DictTree()

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        #Add duplicate edge
        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertRaises(ValueError, dictTree.addEdge, "e", "a")

        #Add isolated edge
        self.assertRaises(ValueError, dictTree.addEdge, "r", "s")

    def testGetRoot(self):
        dictTree = DictTree()

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertEquals(dictTree.getRootId(), "d")

        dictTree.addEdge("e", "d")
        self.assertEquals(dictTree.getRootId(), "e")

    def testSetVertex(self):
        dictTree = DictTree()

        dictTree.setVertex("a")
        self.assertEquals(dictTree.getVertex("a"), None)
        self.assertRaises(RuntimeError, dictTree.setVertex, "b")

        dictTree.setVertex("a", 12)
        self.assertEquals(dictTree.getVertex("a"), 12)

    def testStr(self):
        dictTree = DictTree()

        dictTree.addEdge(0, 1)
        dictTree.addEdge(0, 2)
        dictTree.addEdge(2, 3)
        dictTree.addEdge(2, 4)
        dictTree.addEdge(0, 5)
        dictTree.addEdge(4, 6)

        #Smoke check: string conversion of the tree should not raise
        str(dictTree)

    def testDepth(self):
        dictTree = DictTree()
        self.assertEquals(dictTree.depth(), 0)
        dictTree.setVertex("a")
        self.assertEquals(dictTree.depth(), 0)

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("d", "a")

        self.assertEquals(dictTree.depth(), 2)

        dictTree.addEdge("c", "e")
        self.assertEquals(dictTree.depth(), 3)

    def testCutTree(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")
        dictTree.addEdge("a", "b", 2)
        dictTree.addEdge("a", "c")
        dictTree.addEdge("c", "d", 5)
        dictTree.addEdge("c", "f")

        A = numpy.array([10, 2])
        dictTree.setVertex("b", A)

        newTree = dictTree.cut(2)
        self.assertEquals(newTree.getVertex("a"), "foo")
        self.assertTrue((newTree.getVertex("b") == A).all())
        self.assertEquals(newTree.getEdge("a", "b"), 2)
        self.assertEquals(newTree.getEdge("a", "c"), 1)
        self.assertEquals(newTree.getEdge("c", "d"), 5)
        self.assertEquals(newTree.getEdge("c", "f"), 1)
        self.assertEquals(newTree.getNumVertices(), dictTree.getNumVertices())
        self.assertEquals(newTree.getNumEdges(), dictTree.getNumEdges())

        newTree = dictTree.cut(1)
        self.assertEquals(newTree.getEdge("a", "b"), 2)
        self.assertEquals(newTree.getEdge("a", "c"), 1)
        self.assertEquals(newTree.getNumVertices(), 3)
        self.assertEquals(newTree.getNumEdges(), 2)

        newTree = dictTree.cut(0)
        self.assertEquals(newTree.getNumVertices(), 1)
        self.assertEquals(newTree.getNumEdges(), 0)

    def testLeaves(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")

        self.assertTrue(set(dictTree.leaves()) == set(["a"]))
        dictTree.addEdge("a", "b", 2)
        dictTree.addEdge("a", "c")
        dictTree.addEdge("c", "d", 5)
        dictTree.addEdge("c", "f")

        self.assertTrue(set(dictTree.leaves()) == set(["b", "d", "f"]))

        dictTree.addEdge("b", 1)
        dictTree.addEdge("b", 2)
        self.assertTrue(set(dictTree.leaves()) == set([1, 2, "d", "f"]))

        #Test isSubtree leaves
        self.assertTrue(set(dictTree.leaves("c")) == set(["d", "f"]))
        self.assertTrue(set(dictTree.leaves("b")) == set([1, 2]))

    def testAddChild(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")
        dictTree.addChild("a", "c", 2)
        dictTree.addChild("a", "d", 5)

        self.assertTrue(set(dictTree.leaves()) == set(["c", "d"]))

        self.assertEquals(dictTree.getVertex("c"), 2)
        self.assertEquals(dictTree.getVertex("d"), 5)

        self.assertEquals(dictTree.getEdge("a", "d"), 1.0)
        self.assertEquals(dictTree.getEdge("a", "c"), 1.0)

    def testPruneVertex(self):
        dictTree = DictTree()
        dictTree.setVertex("a", "foo")

        dictTree.addEdge("a", "b")
        dictTree.addEdge("a", "c")
        dictTree.addEdge("b", "d")
        dictTree.addEdge("b", "e")
        dictTree.addEdge("e", "f")

        dictTree.pruneVertex("b")
        self.assertFalse(dictTree.edgeExists("b", "e"))
        self.assertFalse(dictTree.edgeExists("b", "d"))
        self.assertFalse(dictTree.edgeExists("e", "f"))
        self.assertTrue(dictTree.vertexExists("b"))
        self.assertFalse(dictTree.vertexExists("d"))
        self.assertFalse(dictTree.vertexExists("e"))
        self.assertFalse(dictTree.vertexExists("f"))

        dictTree.pruneVertex("a")
        self.assertEquals(dictTree.getNumVertices(), 1)

    def testIsLeaf(self):
        self.assertTrue(self.dictTree.isLeaf("c"))
        self.assertTrue(self.dictTree.isLeaf("d"))
        self.assertTrue(self.dictTree.isLeaf("f"))
        self.assertFalse(self.dictTree.isLeaf("a"))
        self.assertFalse(self.dictTree.isLeaf("b"))
        self.assertFalse(self.dictTree.isLeaf("e"))

    def testIsNonLeaf(self):
        self.assertFalse(self.dictTree.isNonLeaf("c"))
        self.assertFalse(self.dictTree.isNonLeaf("d"))
        self.assertFalse(self.dictTree.isNonLeaf("f"))
        self.assertTrue(self.dictTree.isNonLeaf("a"))
        self.assertTrue(self.dictTree.isNonLeaf("b"))
        self.assertTrue(self.dictTree.isNonLeaf("e"))

    def testCopy(self):
        newTree = self.dictTree.copy()

        newTree.addEdge("f", "x")
        newTree.addEdge("f", "y")

        self.assertEquals(newTree.getNumVertices(),
                          self.dictTree.getNumVertices() + 2)
        self.assertTrue(newTree.vertexExists("x"))
        self.assertTrue(newTree.vertexExists("y"))
        self.assertTrue(not self.dictTree.vertexExists("x"))
        self.assertTrue(not self.dictTree.vertexExists("y"))

    def testisSubtree(self):
        newTree = DictTree()
        newTree.addEdge("a", "b")
        newTree.addEdge("a", "c")

        self.assertTrue(newTree.isSubtree(self.dictTree))

        newTree.addEdge("b", "d")
        self.assertTrue(newTree.isSubtree(self.dictTree))

        newTree.addEdge("b", "e")
        self.assertTrue(newTree.isSubtree(self.dictTree))

        newTree.addEdge("e", "f")
        self.assertTrue(newTree.isSubtree(self.dictTree))

        newTree.addEdge("a", "g")
        self.assertFalse(newTree.isSubtree(self.dictTree))

        newTree = DictTree()
        newTree.addEdge("b", "d")
        self.assertTrue(newTree.isSubtree(self.dictTree))

        newTree.addEdge("b", "e")
        self.assertTrue(newTree.isSubtree(self.dictTree))

        newTree.addEdge("e", "f")
        self.assertTrue(newTree.isSubtree(self.dictTree))

        newTree.addEdge("f", "g")
        self.assertFalse(newTree.isSubtree(self.dictTree))

        newTree = DictTree()
        newTree.setVertex("b")
        self.assertTrue(newTree.isSubtree(self.dictTree))

        self.assertFalse(self.dictTree.isSubtree(newTree))

        self.assertTrue(self.dictTree.isSubtree(self.dictTree))

    def testDeepCopy(self):
        class A:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        a = A(1, numpy.array([1, 2]))
        self.dictTree.setVertex("a", a)
        newTree = self.dictTree.deepCopy()
        newTree.addEdge("f", "x")
        newTree.addEdge("f", "y")

        self.assertEquals(newTree.getNumVertices(),
                          self.dictTree.getNumVertices() + 2)
        self.assertTrue(newTree.vertexExists("x"))
        self.assertTrue(newTree.vertexExists("y"))
        self.assertTrue(not self.dictTree.vertexExists("x"))
        self.assertTrue(not self.dictTree.vertexExists("y"))
        self.assertEquals(self.dictTree.getVertex("a"), a)

        self.assertEquals(newTree.getVertex("a").x, 1)
        self.assertEquals(self.dictTree.getVertex("a").x, 1)
        a.x = 10
        self.assertEquals(newTree.getVertex("a").x, 1)
        self.assertEquals(self.dictTree.getVertex("a").x, 10)

        nptst.assert_array_equal(newTree.getVertex("a").y, numpy.array([1, 2]))
        nptst.assert_array_equal(
            self.dictTree.getVertex("a").y, numpy.array([1, 2]))
        a.y = numpy.array([1, 2, 3])
        nptst.assert_array_equal(newTree.getVertex("a").y, numpy.array([1, 2]))
        nptst.assert_array_equal(
            self.dictTree.getVertex("a").y, numpy.array([1, 2, 3]))

    def testSubtree(self):
        newTree = DictTree()
        newTree.addEdge("a", "b")
        newTree.addEdge("a", "c")

        subtree = newTree.subtreeAt("b")
        self.assertEquals(subtree.getAllVertexIds(), ["b"])

        subtree = newTree.subtreeAt("c")
        self.assertEquals(subtree.getAllVertexIds(), ["c"])

        subtree = newTree.subtreeAt("a")
        self.assertEquals(set(subtree.getAllVertexIds()), set(["a", "c", "b"]))

    def testGrowTree(self):
        startId = (0, )
        minSplit = 20
        maxDepth = 3
        gamma = 0.01
        learner = PenaltyDecisionTree(minSplit=minSplit,
                                      maxDepth=maxDepth,
                                      gamma=gamma,
                                      pruning=False)

        trainX = self.X[100:, :]
        trainY = self.y[100:]
        testX = self.X[0:100, :]
        testY = self.y[0:100]

        argsortX = numpy.zeros(trainX.shape, numpy.int)
        for i in range(trainX.shape[1]):
            argsortX[:, i] = numpy.argsort(trainX[:, i])
            argsortX[:, i] = numpy.argsort(argsortX[:, i])

        learner.tree = DictTree()
        rootNode = DecisionNode(numpy.arange(trainX.shape[0]),
                                Util.mode(trainY))
        learner.tree.setVertex(startId, rootNode)

        #Note that this matches with the case where we create a new tree each time
        numpy.random.seed(21)
        bestError = float("inf")

        for i in range(20):
            learner.tree.pruneVertex(startId)
            learner.growTree(trainX, trainY, argsortX, startId)

            predTestY = learner.predict(testX)
            error = Evaluator.binaryError(predTestY, testY)
            #print(Evaluator.binaryError(predTestY, testY), learner.tree.getNumVertices())

            if error < bestError:
                bestError = error
                bestTree = learner.tree.copy()

            self.assertTrue(learner.tree.depth() <= maxDepth)

            for vertexId in learner.tree.nonLeaves():
                self.assertTrue(
                    learner.tree.getVertex(vertexId).getTrainInds().shape[0] >=
                    minSplit)

        bestError1 = bestError
        learner.tree = bestTree

        #Now we test growing a tree from a non-root vertex
        numpy.random.seed(21)
        for i in range(20):
            learner.tree.pruneVertex((0, 1))
            learner.growTree(trainX, trainY, argsortX, (0, 1))

            self.assertTrue(
                learner.tree.getVertex((0, )) == bestTree.getVertex((0, )))
            self.assertTrue(
                learner.tree.getVertex((0, 0)) == bestTree.getVertex((0, 0)))

            predTestY = learner.predict(testX)
            error = Evaluator.binaryError(predTestY, testY)

            if error < bestError:
                bestError = error
                bestTree = learner.tree.copy()
            #print(Evaluator.binaryError(predTestY, testY), learner.tree.getNumVertices())
        self.assertTrue(bestError1 >= bestError)
Example n. 39
class PenaltyDecisionTree(AbstractPredictor):
    def __init__(self,
                 criterion="gain",
                 maxDepth=10,
                 minSplit=30,
                 learnType="reg",
                 pruning=True,
                 gamma=0.01,
                 sampleSize=10):
        """
        Learn a decision tree with a penalty proportional to the square root
        of the size of the tree, as in Nobel 2002. We use a stochastic
        approach in which we learn a set of trees randomly and choose the
        best one.

        :param criterion: The splitting criterion, which is only information gain currently

        :param maxDepth: The maximum depth of the tree
        :type maxDepth: `int`

        :param minSplit: The minimum size of a node for it to be split.
        :type minSplit: `int`

        :param learnType: The type of learning to perform. Currently only regression

        :param pruning: Whether to perform pruning or not.
        :type pruning: `boolean`

        :param gamma: The weight on the penalty factor, between 0 and 1
        :type gamma: `float`

        :param sampleSize: The number of trees to learn in the stochastic search.
        :type sampleSize: `int`
        """
        super(PenaltyDecisionTree, self).__init__()
        self.maxDepth = maxDepth
        self.minSplit = minSplit
        self.criterion = criterion
        self.learnType = learnType
        self.setGamma(gamma)
        self.setSampleSize(sampleSize)
        self.pruning = pruning
        self.alphaThreshold = 0.0
Example n. 40
class DecisionTreeLearner(AbstractPredictor):
    def __init__(self,
                 criterion="mse",
                 maxDepth=10,
                 minSplit=30,
                 type="reg",
                 pruneType="none",
                 gamma=1000,
                 folds=5,
                 processes=None):
        """
        A decision tree learner (regression by default) with several pruning
        methods. Ideally there would be a separate minSplit for internal
        nodes and one for leaves.

        :param gamma: An upper bound on the number of vertices in the pruned tree (smaller values mean more pruning).
        """
        super(DecisionTreeLearner, self).__init__()
        self.maxDepth = maxDepth
        self.minSplit = minSplit
        self.criterion = criterion
        self.type = type
        self.pruneType = pruneType
        self.setGamma(gamma)
        self.folds = folds
        self.processes = processes
        self.alphas = numpy.array([])

    def learnModel(self, X, y):
        nodeId = (0, )
        self.tree = DictTree()
        rootNode = DecisionNode(numpy.arange(X.shape[0]), y.mean())
        self.tree.setVertex(nodeId, rootNode)

        #Double argsort gives, for each column, the rank of every entry
        argsortX = numpy.zeros(X.shape, numpy.int)
        for i in range(X.shape[1]):
            argsortX[:, i] = numpy.argsort(X[:, i])
            argsortX[:, i] = numpy.argsort(argsortX[:, i])

        self.growSkLearn(X, y)
        #self.recursiveSplit(X, y, argsortX, nodeId)
        self.unprunedTreeSize = self.tree.size

        if self.pruneType == "REP":
            #Note: this should use a separate validation set
            self.repPrune(X, y)
        elif self.pruneType == "REP-CV":
            self.cvPrune(X, y)
        elif self.pruneType == "CART":
            self.cartPrune(X, y)
        elif self.pruneType == "none":
            pass
        else:
            raise ValueError("Unknown pruning type " + self.pruneType)

    #@profile
    def recursiveSplit(self, X, y, argsortX, nodeId):
        """
        Given a sample of data and a node index, we find the best split and
        add children to the tree accordingly. 
        """
        if len(nodeId) - 1 >= self.maxDepth:
            return

        node = self.tree.getVertex(nodeId)
        bestError, bestFeatureInd, bestThreshold, bestLeftInds, bestRightInds = findBestSplit(
            self.minSplit, X, y, node.getTrainInds(), argsortX)

        #The split may leave 0 items in one set, so don't split in that case
        #(emptiness is tested with shape, not sum)
        if bestLeftInds.shape[0] != 0 and bestRightInds.shape[0] != 0:
            node.setError(bestError)
            node.setFeatureInd(bestFeatureInd)
            node.setThreshold(bestThreshold)

            leftChildId = self.getLeftChildId(nodeId)
            leftChild = DecisionNode(bestLeftInds, y[bestLeftInds].mean())
            self.tree.addChild(nodeId, leftChildId, leftChild)

            if leftChild.getTrainInds().shape[0] >= self.minSplit:
                self.recursiveSplit(X, y, argsortX, leftChildId)

            rightChildId = self.getRightChildId(nodeId)
            rightChild = DecisionNode(bestRightInds, y[bestRightInds].mean())
            self.tree.addChild(nodeId, rightChildId, rightChild)

            if rightChild.getTrainInds().shape[0] >= self.minSplit:
                self.recursiveSplit(X, y, argsortX, rightChildId)

    def growSkLearn(self, X, y):
        """
        Grow a decision tree using scikit-learn and convert it into our
        DictTree representation.
        """

        from sklearn.tree import DecisionTreeRegressor
        regressor = DecisionTreeRegressor(max_depth=self.maxDepth,
                                          min_samples_split=self.minSplit)
        regressor.fit(X, y)

        #Convert the sklearn tree into our tree
        nodeId = (0, )
        nodeStack = [(nodeId, 0)]

        node = DecisionNode(numpy.arange(X.shape[0]), regressor.tree_.value[0])
        self.tree.setVertex(nodeId, node)

        while len(nodeStack) != 0:
            nodeId, nodeInd = nodeStack.pop()

            node = self.tree.getVertex(nodeId)
            node.setError(regressor.tree_.best_error[nodeInd])
            node.setFeatureInd(regressor.tree_.feature[nodeInd])
            node.setThreshold(regressor.tree_.threshold[nodeInd])

            if regressor.tree_.children[nodeInd, 0] != -1:
                leftChildInds = node.getTrainInds()[
                    X[node.getTrainInds(),
                      node.getFeatureInd()] < node.getThreshold()]
                leftChildId = self.getLeftChildId(nodeId)
                leftChild = DecisionNode(
                    leftChildInds,
                    regressor.tree_.value[regressor.tree_.children[nodeInd,
                                                                   0]])
                self.tree.addChild(nodeId, leftChildId, leftChild)
                nodeStack.append((self.getLeftChildId(nodeId),
                                  regressor.tree_.children[nodeInd, 0]))

            if regressor.tree_.children[nodeInd, 1] != -1:
                rightChildInds = node.getTrainInds()[
                    X[node.getTrainInds(),
                      node.getFeatureInd()] >= node.getThreshold()]
                rightChildId = self.getRightChildId(nodeId)
                rightChild = DecisionNode(
                    rightChildInds,
                    regressor.tree_.value[regressor.tree_.children[nodeInd,
                                                                   1]])
                self.tree.addChild(nodeId, rightChildId, rightChild)
                nodeStack.append((self.getRightChildId(nodeId),
                                  regressor.tree_.children[nodeInd, 1]))

    def predict(self, X):
        """
        Make a prediction for the set of examples given in the matrix X. 
        """
        rootId = (0, )
        predY = numpy.zeros(X.shape[0])
        self.tree.getVertex(rootId).setTestInds(numpy.arange(X.shape[0]))
        predY = self.recursivePredict(X, predY, rootId)

        return predY

    def recursivePredict(self, X, y, nodeId):
        """
        Recurse through the tree and assign examples to the correct vertex. 
        """
        node = self.tree.getVertex(nodeId)
        testInds = node.getTestInds()

        if self.tree.isLeaf(nodeId):
            y[testInds] = node.getValue()
        else:

            for childId in [
                    self.getLeftChildId(nodeId),
                    self.getRightChildId(nodeId)
            ]:
                if self.tree.vertexExists(childId):
                    child = self.tree.getVertex(childId)

                    if childId[-1] == 0:
                        childInds = X[testInds, node.getFeatureInd(
                        )] < node.getThreshold()
                    else:
                        childInds = X[testInds, node.getFeatureInd(
                        )] >= node.getThreshold()

                    child.setTestInds(testInds[childInds])
                    y = self.recursivePredict(X, y, childId)

        return y

    def recursiveSetPrune(self, X, y, nodeId):
        """
        This computes test errors on nodes by passing in the test X and y. 
        """
        node = self.tree.getVertex(nodeId)
        testInds = node.getTestInds()
        node.setTestError(self.vertexTestError(y[testInds], node.getValue()))

        for childId in [
                self.getLeftChildId(nodeId),
                self.getRightChildId(nodeId)
        ]:
            if self.tree.vertexExists(childId):
                child = self.tree.getVertex(childId)

                if childId[-1] == 0:
                    childInds = X[testInds,
                                  node.getFeatureInd()] < node.getThreshold()
                else:
                    childInds = X[testInds,
                                  node.getFeatureInd()] >= node.getThreshold()
                child.setTestInds(testInds[childInds])
                self.recursiveSetPrune(X, y, childId)

    def vertexTestError(self, trueY, predY):
        """
        The error used for pruning: the sum of squared errors at a node.
        """
        return numpy.sum((trueY - predY)**2)

    def computeAlphas(self):
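        #Reduced-error-pruning alpha: the reduction in validation error per
        #test example obtained by collapsing the subtree at a vertex into a
        #single leaf (positive alpha means pruning helps)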
        self.minAlpha = float("inf")
        self.maxAlpha = -float("inf")

        for vertexId in self.tree.getAllVertexIds():
            currentNode = self.tree.getVertex(vertexId)
            subtreeLeaves = self.tree.leaves(vertexId)

            testErrorSum = 0
            for leaf in subtreeLeaves:
                testErrorSum += self.tree.getVertex(leaf).getTestError()

            #Alpha is normalised difference in error
            if currentNode.getTestInds().shape[0] != 0:
                currentNode.alpha = (testErrorSum -
                                     currentNode.getTestError()) / float(
                                         currentNode.getTestInds().shape[0])

                if currentNode.alpha < self.minAlpha:
                    self.minAlpha = currentNode.alpha

                if currentNode.alpha > self.maxAlpha:
                    self.maxAlpha = currentNode.alpha

    def computeCARTAlphas(self, X):
        """
        Solve for the CART complexity based pruning. 
        """
        self.minAlpha = float("inf")
        self.maxAlpha = -float("inf")
        alphas = []

        for vertexId in self.tree.getAllVertexIds():
            currentNode = self.tree.getVertex(vertexId)
            subtreeLeaves = self.tree.leaves(vertexId)

            testErrorSum = 0
            for leaf in subtreeLeaves:
                testErrorSum += self.tree.getVertex(leaf).getTestError()

            #Alpha is reduction in error per leaf - larger alphas are better
            if currentNode.getTestInds().shape[0] != 0 and len(
                    subtreeLeaves) != 1:
                currentNode.alpha = (currentNode.getTestError() -
                                     testErrorSum) / float(
                                         X.shape[0] * (len(subtreeLeaves) - 1))
                #Flip alpha so that pruning works
                currentNode.alpha = -currentNode.alpha

                alphas.append(currentNode.alpha)
                """
                if currentNode.alpha < self.minAlpha:
                    self.minAlpha = currentNode.alpha 
                
                if currentNode.alpha > self.maxAlpha: 
                    self.maxAlpha = currentNode.alpha   
                """
        alphas = numpy.array(alphas)
        self.alphas = numpy.unique(alphas)
        self.minAlpha = numpy.min(self.alphas)
        self.maxAlpha = numpy.max(self.alphas)

    def repPrune(self, validX, validY):
        """
        Prune the decision tree using reduced error pruning. 
        """
        rootId = (0, )
        self.tree.getVertex(rootId).setTestInds(numpy.arange(validX.shape[0]))
        self.recursiveSetPrune(validX, validY, rootId)
        self.computeAlphas()
        self.prune()

    def prune(self):
        """
        Prune vertices in decreasing order of alpha until the tree has at
        most gamma vertices.
        """
        i = self.alphas.shape[0] - 1
        #print(self.alphas)

        while self.tree.getNumVertices() > self.gamma and i >= 0:
            #print(self.alphas[i], self.tree.getNumVertices())
            alphaThreshold = self.alphas[i]
            toPrune = []

            for vertexId in self.tree.getAllVertexIds():
                if self.tree.getVertex(vertexId).alpha >= alphaThreshold:
                    toPrune.append(vertexId)

            for vertexId in toPrune:
                if self.tree.vertexExists(vertexId):
                    self.tree.pruneVertex(vertexId)

            i -= 1

    def cartPrune(self, trainX, trainY):
        """
        Prune the tree according to the CART algorithm. Here, the chosen 
        tree is selected by thresholding alpha. In CART itself the best 
        tree is selected by using an independent pruning set. 
        """
        rootId = (0, )
        self.tree.getVertex(rootId).setTestInds(numpy.arange(trainX.shape[0]))
        self.recursiveSetPrune(trainX, trainY, rootId)
        self.computeCARTAlphas(trainX)
        self.prune()

    def cvPrune(self, validX, validY):
        """
        Reduced error pruning in which node test errors are accumulated over
        cross-validation folds rather than a single validation set.
        """

        #First set the value of the vertices using the training set.
        #Reset all alphas to zero
        inds = Sampling.crossValidation(self.folds, validX.shape[0])

        for i in self.tree.getAllVertexIds():
            self.tree.getVertex(i).setAlpha(0.0)
            self.tree.getVertex(i).setTestError(0.0)

        for trainInds, testInds in inds:
            rootId = (0, )
            root = self.tree.getVertex(rootId)
            root.setTrainInds(trainInds)
            root.setTestInds(testInds)
            root.tempValue = numpy.mean(validY[trainInds])

            nodeStack = [(rootId, root.tempValue)]

            while len(nodeStack) != 0:
                (nodeId, value) = nodeStack.pop()
                node = self.tree.getVertex(nodeId)
                tempTrainInds = node.getTrainInds()
                tempTestInds = node.getTestInds()
                node.setTestError(
                    numpy.sum((validY[tempTestInds] - node.tempValue)**2) +
                    node.getTestError())
                childIds = [
                    self.getLeftChildId(nodeId),
                    self.getRightChildId(nodeId)
                ]

                for childId in childIds:
                    if self.tree.vertexExists(childId):
                        child = self.tree.getVertex(childId)

                        if childId[-1] == 0:
                            childInds = validX[
                                tempTrainInds,
                                node.getFeatureInd()] < node.getThreshold()
                        else:
                            childInds = validX[
                                tempTrainInds,
                                node.getFeatureInd()] >= node.getThreshold()

                        if childInds.sum() != 0:
                            value = numpy.mean(
                                validY[tempTrainInds[childInds]])

                        child.tempValue = value
                        child.setTrainInds(tempTrainInds[childInds])
                        nodeStack.append((childId, value))

                        if childId[-1] == 0:
                            childInds = validX[
                                tempTestInds,
                                node.getFeatureInd()] < node.getThreshold()
                        else:
                            childInds = validX[
                                tempTestInds,
                                node.getFeatureInd()] >= node.getThreshold()

                        child.setTestInds(tempTestInds[childInds])

        self.computeAlphas()
        self.prune()

    def copy(self):
        """
        Copies parameter values only 
        """
        newLearner = DecisionTreeLearner(self.criterion, self.maxDepth,
                                         self.minSplit, self.type,
                                         self.pruneType, self.gamma,
                                         self.folds)
        return newLearner

    def getMetricMethod(self):
        if self.type == "reg":
            #return Evaluator.rootMeanSqError
            return Evaluator.meanAbsError
            #return Evaluator.meanSqError
        else:
            return Evaluator.binaryError

    def getAlphaThreshold(self):
        #return self.maxAlpha - (self.maxAlpha - self.minAlpha)*self.gamma
        #A more natural way of defining gamma; cast to int since numpy.round
        #returns a float, which cannot be used as an index. Note this assumes
        #gamma in [0, 1], unlike setGamma which treats gamma as a vertex count
        return self.alphas[int(numpy.round(
            (1 - self.gamma) * (self.alphas.shape[0] - 1)))]

    def setGamma(self, gamma):
        """
        Gamma is an upper bound on the number of nodes in the tree. 
        """
        Parameter.checkInt(gamma, 1, float("inf"))
        self.gamma = gamma

    def getGamma(self):
        return self.gamma

    def setPruneCV(self, folds):
        Parameter.checkInt(folds, 1, float("inf"))
        self.folds = folds

    def getPruneCV(self):
        return self.folds

    def getLeftChildId(self, nodeId):
        leftChildId = list(nodeId)
        leftChildId.append(0)
        leftChildId = tuple(leftChildId)
        return leftChildId

    def getRightChildId(self, nodeId):
        rightChildId = list(nodeId)
        rightChildId.append(1)
        rightChildId = tuple(rightChildId)
        return rightChildId

    def getTree(self):
        return self.tree

    def complexity(self):
        return self.tree.size

    def getBestLearner(self, meanErrors, paramDict, X, y, idx=None):
        """
        Given a grid of errors, a paramDict and examples and labels, find the
        best learner and train it. In this case we set gamma to the real
        size of the tree as learnt using CV. If idx is None then we simply
        use the gamma corresponding to the lowest error.
        """
        if idx is None:
            return super(DecisionTreeLearner,
                         self).getBestLearner(meanErrors, paramDict, X, y, idx)

        bestInds = numpy.unravel_index(numpy.argmin(meanErrors),
                                       meanErrors.shape)
        currentInd = 0
        learner = self.copy()

        for key, val in paramDict.items():
            method = getattr(learner, key)
            method(val[bestInds[currentInd]])
            currentInd += 1

        treeSizes = []
        for trainInds, testInds in idx:
            validX = X[trainInds, :]
            validY = y[trainInds]
            learner.learnModel(validX, validY)

            treeSizes.append(learner.tree.getNumVertices())

        bestGamma = int(numpy.round(numpy.array(treeSizes).mean()))

        learner.setGamma(bestGamma)
        learner.learnModel(X, y)
        return learner
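
        #A hedged sketch of the paramDict shape this method expects (keys
        #name setter methods on the learner, values are parameter grids):
        #
        #    paramDict = {"setGamma": numpy.array([50, 100, 200]),
        #                 "setPruneCV": numpy.array([3, 5])}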

    def getUnprunedTreeSize(self):
        """
        Return the size of the tree before pruning was performed. 
        """
        return self.unprunedTreeSize

    def parallelPen(self, X, y, idx, paramDict, Cvs):
        """
        Perform parallel penalisation using any learner. The learner is then
        trained on the whole dataset using the best parameters. In this case
        if gamma > max(treeSize) the penalty is infinite.

        :param X: The examples as rows
        :type X: :class:`numpy.ndarray`

        :param y: The binary -1/+1 labels 
        :type y: :class:`numpy.ndarray`

        :param idx: A list of train/test splits

        :param paramDict: A dictionary indexed by method name whose values are arrays of parameter values
        :type paramDict: :class:`dict`

        """
        return super(DecisionTreeLearner,
                     self).parallelPen(X, y, idx, paramDict, Cvs,
                                       computeVFPenTree)