def test_preorder(self):
    """preorder() lists a node's value before its children's, children in insertion order.

    Tree under test (children listed in the order they are added):

        10
        +-- 2
        |   +-- 3
        |   +-- 5
        +-- 4
    """
    self.tree.add_root(MultiTreeNode(val=10))
    root = self.tree.root
    # Only the node that receives children needs to be kept.
    first_node = self.tree.add_child(root, MultiTreeNode(val=2))
    self.tree.add_child(first_node, MultiTreeNode(val=3))
    self.tree.add_child(root, MultiTreeNode(val=4))
    # 5 is added after 4 but under 2, so it must still appear before 4.
    self.tree.add_child(first_node, MultiTreeNode(val=5))
    assert self.tree.preorder() == [10, 2, 3, 5, 4]
def test_find(self):
    """find(x) resolves to the most specific group node whose val contains x.

    Before [3, 4] has children, find(3) is expected to return the [3, 4]
    node; once [3] and [4] are added beneath it, find(4) is expected to
    return the [4] node.
    """
    tree = MultiTree()
    root = tree.add_root(MultiTreeNode([1, 2, 3, 4]))
    tree.add_child(root, MultiTreeNode([1]))
    tree.add_child(root, MultiTreeNode([2]))
    group_34 = tree.add_child(root, MultiTreeNode([3, 4]))
    assert tree.find(3).val == [3, 4]
    # Split [3, 4] further; find must now descend to the narrower group.
    tree.add_child(group_34, MultiTreeNode([3]))
    tree.add_child(group_34, MultiTreeNode([4]))
    assert tree.find(4).val == [4]
def runner(current):
    """Split ``current``'s class group by cutting heavy MST links, then recurse.

    Relies on the enclosing-scope names ``mst_graph``, ``tree``,
    ``BinaryTree``, ``BinaryTreeNode`` and ``MultiTreeNode``.

    Links returned by ``mst_graph.sum_weight(current.val[0])`` are visited
    heaviest first. Every link whose weight is >= their average weight is
    removed from ``mst_graph``; each removal is recorded in a binary tree as
    a split of a group into the two components on either side of the cut.
    When the first below-average link is reached (or the last link has been
    processed), the leaves of that binary tree become the children of
    ``current`` in ``tree``, and each child is processed recursively.

    Args:
        current: a ``MultiTreeNode`` whose ``val`` is the list of class ids
            in this group.
    """
    sum_weight, edges = mst_graph.sum_weight(current.val[0])
    # No links left to cut: this group is a leaf of the multi tree.
    if not edges:
        return
    # Heaviest links first (edge[2] is the weight).
    edges.sort(key=lambda edge: -edge[2])
    avg_weight = sum_weight / len(edges)

    # Records how this group is split by each cut.
    btree = BinaryTree()
    btree.add_root(BinaryTreeNode(current.val))
    last_index = len(edges) - 1
    for index, edge in enumerate(edges):
        if edge[2] >= avg_weight:
            mst_graph.double_unlink(edge[0], edge[1])
            # presumably the current leaf whose group contains edge[0]
            parent = btree.find(edge[0])
            btree.add_left(parent, BinaryTreeNode(mst_graph.connected_with(edge[0])))
            btree.add_right(parent, BinaryTreeNode(mst_graph.connected_with(edge[1])))
        # Edges are sorted descending, so once one falls below the average
        # (or we processed the final edge) no further cuts can happen.
        if edge[2] < avg_weight or index == last_index:
            for group in btree.leaves():
                new_node = MultiTreeNode(group)
                tree.add_child(current, new_node)
                runner(new_node)
            break
def test_add_child(self):
    """add_child attaches the node under the given parent and back-links it."""
    self.tree.add_root(MultiTreeNode(val=10))
    root_node = self.tree.root
    child = self.tree.add_child(root_node, MultiTreeNode(val=1))
    first_child = root_node.children[0]
    assert first_child.val == 1
    assert child.parent.val == 10
def test_add_root(self):
    """add_root installs the given node as the tree's root."""
    root_node = MultiTreeNode(val=10)
    self.tree.add_root(root_node)
    assert self.tree.root.val == 10
def train(self, training_classes):
    """Build the class-grouping tree and train one SVM per tree edge.

    Steps:
      1. Compute pairwise class separability (``self._find_separability``)
         and link every pair of classes in a ``Graph`` weighted by it.
      2. Take that graph's minimum spanning tree and store its links in an
         undirected ``Graph`` (``mst_graph``).
      3. Recursively split groups of classes by cutting MST links whose
         weight is >= the average of the links considered, recording the
         resulting groups as children in a ``MultiTree``.
      4. Walk the ``MultiTree``; for every child of every node, train an RBF
         ``SVC`` that separates the child's classes (label 0) from the other
         classes of the parent node (label 1).

    Args:
        training_classes: mapping of class label -> samples. Samples are
            converted with ``.tolist()`` and their count taken with ``len``,
            so presumably numpy arrays of shape (n_samples, n_features) —
            TODO confirm.

    Returns:
        The number of SVMs trained.

    Side effects:
        Sets ``self.tree``, ``self.int_to_label`` and ``self.label_to_int``.
        When ``self.verbose`` is true, prints the tree-building time.
    """
    separability, label_to_int, int_to_label = self._find_separability(training_classes)

    # Fully linked graph weighted by pairwise separability.
    class_cnt = len(training_classes)
    mesh = Graph(class_cnt)
    for i, row in enumerate(separability):
        for j, sep in enumerate(row):
            mesh.link(i, j, sep)

    # Undirected graph holding only the MST links of the mesh.
    mst_links = mesh.mst()
    mst_graph = Graph(class_cnt)
    for link in mst_links:
        mst_graph.double_link(link[0], link[1], link[2])

    tree = MultiTree()
    # The root groups every class; assumes the MST connects all classes to 0.
    tree.add_root(MultiTreeNode(mst_graph.connected_with(0)))

    def split_group(current):
        """Cut heavy MST links of ``current``'s group and recurse on the parts."""
        sum_weight, edges = mst_graph.sum_weight(current.val[0])
        # No links left to cut: this group is a leaf of the multi tree.
        if not edges:
            return
        # Heaviest links first (edge[2] is the weight).
        edges.sort(key=lambda edge: -edge[2])
        avg_weight = sum_weight / len(edges)

        # Records how this group is split by each cut.
        btree = BinaryTree()
        btree.add_root(BinaryTreeNode(current.val))
        last_index = len(edges) - 1
        for index, edge in enumerate(edges):
            if edge[2] >= avg_weight:
                mst_graph.double_unlink(edge[0], edge[1])
                # presumably the current leaf whose group contains edge[0]
                parent = btree.find(edge[0])
                btree.add_left(parent, BinaryTreeNode(mst_graph.connected_with(edge[0])))
                btree.add_right(parent, BinaryTreeNode(mst_graph.connected_with(edge[1])))
            # Edges are sorted descending, so once one falls below the
            # average (or the final edge was processed) no more cuts happen:
            # the btree leaves are this level's groups.
            if edge[2] < avg_weight or index == last_index:
                for group in btree.leaves():
                    new_node = MultiTreeNode(group)
                    tree.add_child(current, new_node)
                    split_group(new_node)
                break

    if self.verbose:
        start_time = time.process_time()
    split_group(tree.root)
    if self.verbose:
        print('train: %.4f' % (time.process_time() - start_time))

    svm_cnt = 0

    def train_node_svms(current, universe):
        """Train one-vs-rest SVMs for each child of ``current`` and recurse.

        ``universe`` maps class int -> samples for the classes that reached
        ``current``.
        """
        nonlocal svm_cnt
        if current.children is None:
            return

        # Hand each child the samples of the classes it contains.
        child_universes = [{} for _ in current.children]
        for class_int, samples in universe.items():
            for i, child in enumerate(current.children):
                if class_int in child.val:
                    child_universes[i][class_int] = samples

        current.svms = [None] * len(current.children)
        for i, child in enumerate(current.children):
            training = []
            labels = []
            for class_int, samples in universe.items():
                training += samples.tolist()
                # One-vs-rest: 0 = class belongs to this child, 1 = the rest.
                labels += [0 if class_int in child.val else 1] * len(samples)
            current.svms[i] = sklearn.svm.SVC(kernel='rbf', gamma=self.gamma, C=self.C) \
                .fit(numpy.array(training), numpy.array(labels))
            svm_cnt += 1
            train_node_svms(child, child_universes[i])

    # Re-key the training data by the integer class ids used in the tree.
    universe = {label_to_int[key]: val for key, val in training_classes.items()}
    train_node_svms(tree.root, universe)

    # Expose the trained structures to the rest of the class.
    self.tree = tree
    self.int_to_label = int_to_label
    self.label_to_int = label_to_int
    return svm_cnt