def test_huffman_tree_hierarchy(self):
    """Check the per-label paths derived from the HuffmanTreeHierarchy op.

    Labels 0..5 are fed with occurrence counts 1..6. The operator emits a
    serialized TreeProto, which is converted into per-label paths with
    ``hsmu.create_hierarchy``. Each label's path is then compared, node by
    node, against the expected sequence of (node index, branch target).
    """
    workspace.GlobalInit(['caffe2'])
    label_set = list(range(0, 6))
    counts = [1, 2, 3, 4, 5, 6]
    # Repeat each label `count` times: [0, 1, 1, 2, 2, 2, ...].
    labels = [
        label
        for label, count in zip(label_set, counts)
        for _ in range(count)
    ]
    Y = np.array(labels).astype(np.int64)
    workspace.FeedBlob("labels", Y)

    arg = caffe2_pb2.Argument()
    arg.name = 'num_classes'
    arg.i = len(label_set)  # 6 distinct labels
    op = core.CreateOperator(
        'HuffmanTreeHierarchy',
        ['labels'],
        ['huffman_tree'],
        'HuffmanTreeHierarchy',
        arg=[arg])
    workspace.RunOperatorOnce(op)

    huffmanTreeOutput = workspace.FetchBlob('huffman_tree')
    treeOutput = hsm_pb2.TreeProto()
    # The first element of the fetched blob is parsed as a serialized
    # TreeProto.
    treeOutput.ParseFromString(huffmanTreeOutput[0])
    treePathOutput = hsmu.create_hierarchy(treeOutput)

    label_to_path = {path.word_id: path for path in treePathOutput.paths}

    def check_path(label, indices, code):
        """Assert that `label`'s path visits `indices` taking branches `code`."""
        path = label_to_path[label]
        # Check both expected sequences against the real path length:
        # zip() below silently truncates, so a length mismatch in either
        # list would otherwise skip comparisons instead of failing.
        self.assertEqual(len(path.path_nodes), len(code))
        self.assertEqual(len(path.path_nodes), len(indices))
        for path_node, index, target in zip(path.path_nodes, indices, code):
            self.assertEqual(path_node.index, index)
            self.assertEqual(path_node.target, target)

    check_path(0, [0, 4, 6, 8], [1, 0, 0, 0])
    check_path(1, [0, 4, 6, 8], [1, 0, 0, 1])
    check_path(2, [0, 4, 6], [1, 0, 1])
    check_path(3, [0, 2], [0, 0])
    check_path(4, [0, 2], [0, 1])
    check_path(5, [0, 4], [1, 1])
import unittest

from caffe2.proto import caffe2_pb2, hsm_pb2
from caffe2.python import workspace, core, gradient_checker
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.hsm_util as hsmu

# The test hierarchy is assembled with the python helpers from hsm_util (a
# user could equally supply it as a protobuf file). Its shape is shown below;
# the bottom-level subtrees, whose leaves are word_ids, are collapsed:
#
#            *
#          /   \
#         *     5,6,7,8
#        / \
#   0,1,2   3,4
tree = hsm_pb2.TreeProto()
words = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]

# One node per group of word_ids, created in the same order as `words`.
node1, node2, node3 = (hsmu.create_node_with_words(group) for group in words)
# Left child of the root: joins the first two word groups.
node4 = hsmu.create_node_with_nodes([node1, node2])
# Root: the joined node on the left, the last word group on the right.
node = hsmu.create_node_with_nodes([node4, node3])
tree.root_node.MergeFrom(node)

# hsmu.create_hierarchy translates the tree into a list of (word_id, path)
# entries. The serialized result is passed into the operator_def as the
# string argument named "hierarchy".
hierarchy_proto = hsmu.create_hierarchy(tree)
arg = caffe2_pb2.Argument()
arg.name = "hierarchy"
arg.s = hierarchy_proto.SerializeToString()