def test_huffman_tree_hierarchy(self):
    """Check the Huffman paths produced by the HuffmanTreeHierarchy operator.

    Labels 1..6 are fed with frequencies 2, 4, 10, 15, 25, 40. The operator's
    output blob is parsed as a ``HierarchyProto``. For each label, the node
    indices and binary targets along its path are compared against the
    expected values.
    """
    workspace.GlobalInit(['caffe2'])
    labelSet = range(1, 7)
    counts = [2, 4, 10, 15, 25, 40]
    # Repeat every label `count` times so the operator sees these frequencies
    # (comprehension instead of quadratic sum(list_of_lists, [])).
    labels = [label for label, count in zip(labelSet, counts)
              for _ in range(count)]
    Y = np.array(labels).astype(np.int64)
    workspace.FeedBlob("labels", Y)
    arg = caffe2_pb2.Argument()
    arg.name = 'num_classes'
    arg.i = 6
    op = core.CreateOperator(
        'HuffmanTreeHierarchy',
        ['labels'],
        ['huffman_tree'],
        'HuffmanTreeHierarchy',
        arg=[arg])
    workspace.RunOperatorOnce(op)
    huffmanTreePaths = workspace.FetchBlob('huffman_tree')
    treePathOutput = hsm_pb2.HierarchyProto()
    # The serialized proto is stored as the first element of the blob.
    treePathOutput.ParseFromString(huffmanTreePaths[0])

    def checkPath(label, path, indices, code):
        """Assert `path` is for `label` and matches `indices` and `code`."""
        self.assertEqual(path.word_id, label)
        # Check both expected lengths: zip() below stops at the shortest
        # input, so a mismatch in either list would otherwise go unnoticed.
        self.assertEqual(len(path.path_nodes), len(indices))
        self.assertEqual(len(path.path_nodes), len(code))
        for path_node, index, target in zip(path.path_nodes, indices, code):
            self.assertEqual(path_node.index, index)
            self.assertEqual(path_node.target, target)

    checkPath(1, treePathOutput.paths[0], [4, 3, 2, 1, 0], [1, 1, 1, 0, 0])
    checkPath(2, treePathOutput.paths[1], [4, 3, 2, 1, 0], [1, 1, 1, 0, 1])
    checkPath(3, treePathOutput.paths[2], [4, 3, 2, 1], [1, 1, 1, 1])
    checkPath(4, treePathOutput.paths[3], [4, 3, 2], [1, 1, 0])
    checkPath(5, treePathOutput.paths[4], [4, 3], [1, 0])
    checkPath(6, treePathOutput.paths[5], [4], [0])
def create_hierarchy(tree_proto):
    """Flatten a tree proto into a HierarchyProto of per-word paths.

    Walks ``tree_proto.root_node`` depth-first (node first, then its children
    in order). Each node is assigned ``offset`` = the next free index, and
    reserves ``len(children) + len(word_ids)`` consecutive indices. For every
    word, a ``PathProto`` is appended whose ``path_nodes`` list, for each node
    from the root down to the word's node:

    * ``index``: that node's offset,
    * ``length``: ``len(children) + len(word_ids)`` of that node,
    * ``target``: position taken within that node. Children occupy
      ``0 .. len(children) - 1`` and words follow at
      ``len(children) .. length - 1``.

    Side effect: sets ``offset`` on every node of ``tree_proto`` in place.

    Args:
        tree_proto: message with a ``root_node``. Each node exposes
            ``word_ids``, ``children`` and a writable ``offset``.

    Returns:
        hsm_pb2.HierarchyProto. ``size`` is the total number of indices
        reserved. ``paths`` are in visit order: within a node, all words
        under its children (in child order) come before the node's own
        words.
    """
    hierarchy_proto = hsm_pb2.HierarchyProto()

    def add_path(path, word):
        # Fill the new repeated entry directly rather than building a
        # temporary PathProto and copying it in with MergeFrom.
        path_proto = hierarchy_proto.paths.add()
        path_proto.word_id = word
        for index, length, target in path:
            path_node = path_proto.path_nodes.add()
            path_node.index = index
            path_node.length = length
            path_node.target = target

    def recursive_path_builder(node_proto, path, next_index):
        """Lay out the subtree at `node_proto`; return the next free index."""
        num_children = len(node_proto.children)
        length = len(node_proto.word_ids) + num_children
        node_proto.offset = next_index
        # Entry is [index, length, target]; target is rewritten per branch.
        path.append([next_index, length, 0])
        next_index += length
        if hierarchy_proto.size < next_index:
            hierarchy_proto.size = next_index
        for target, child in enumerate(node_proto.children):
            path[-1][2] = target
            next_index = recursive_path_builder(child, path, next_index)
        # Words are numbered after the children within this node.
        for target, word in enumerate(node_proto.word_ids, num_children):
            path[-1][2] = target
            add_path(path, word)
        path.pop()
        return next_index

    recursive_path_builder(tree_proto.root_node, [], 0)
    return hierarchy_proto