Beispiel #1
0
    def test_huffman_tree_hierarchy(self):
        """Run HuffmanTreeHierarchy on skewed label counts and check each path.

        Labels 1..6 are fed with the frequencies in ``counts``. The first
        element of the operator's output blob is parsed as an
        ``hsm_pb2.HierarchyProto``, and every path in it is compared against
        the expected node indices and branch targets.
        """
        workspace.GlobalInit(['caffe2'])
        labelSet = range(1, 7)
        counts = [2, 4, 10, 15, 25, 40]
        # Repeat each label `count` times; a nested comprehension avoids the
        # quadratic list concatenation of sum(list_of_lists, []).
        labels = [
            label
            for label, count in zip(labelSet, counts)
            for _ in range(count)
        ]
        Y = np.array(labels).astype(np.int64)
        workspace.FeedBlob("labels", Y)
        arg = caffe2_pb2.Argument()
        arg.name = 'num_classes'
        # One class per entry in `counts` (6 for the data above).
        arg.i = len(counts)
        op = core.CreateOperator(
            'HuffmanTreeHierarchy',
            ['labels'],
            ['huffman_tree'],
            'HuffmanTreeHierarchy',
            arg=[arg])
        workspace.RunOperatorOnce(op)
        huffmanTreePaths = workspace.FetchBlob('huffman_tree')
        treePathOutput = hsm_pb2.HierarchyProto()
        treePathOutput.ParseFromString(huffmanTreePaths[0])

        def checkPath(label, path, indices, code):
            """Assert that `path` is for `label` and matches indices/code."""
            self.assertEqual(path.word_id, label)
            # Check BOTH expected sequences against the path length: zip()
            # below stops at the shortest input, so a too-short `indices`
            # or `code` would otherwise silently skip node checks.
            self.assertEqual(len(path.path_nodes), len(indices))
            self.assertEqual(len(path.path_nodes), len(code))
            for path_node, index, target in zip(
                    path.path_nodes, indices, code):
                self.assertEqual(path_node.index, index)
                self.assertEqual(path_node.target, target)

        # (label, expected node indices, expected branch targets), listed in
        # the order the paths appear in the parsed hierarchy.
        expected = [
            (1, [4, 3, 2, 1, 0], [1, 1, 1, 0, 0]),
            (2, [4, 3, 2, 1, 0], [1, 1, 1, 0, 1]),
            (3, [4, 3, 2, 1], [1, 1, 1, 1]),
            (4, [4, 3, 2], [1, 1, 0]),
            (5, [4, 3], [1, 0]),
            (6, [4], [0]),
        ]
        for position, (label, indices, code) in enumerate(expected):
            checkPath(label, treePathOutput.paths[position], indices, code)
Beispiel #2
0
def create_hierarchy(tree_proto):
    """Flatten the tree under ``tree_proto.root_node`` into a HierarchyProto.

    The tree is walked depth-first, children before words. Every visited
    node gets its ``offset`` field set to the running index total at the
    time it is reached (this mutates ``tree_proto``). For each word id a
    path is appended to the result; each path node stores

    * ``index``  - the offset of the tree node on the way to the word,
    * ``length`` - that node's number of word ids plus children,
    * ``target`` - the position taken within the node (children are
      numbered first, words after them).

    The result's ``size`` is raised to the largest running index total
    seen during the walk.

    Args:
        tree_proto: message exposing ``root_node``; nodes expose
            ``children``, ``word_ids`` and a writable ``offset``.

    Returns:
        A new ``hsm_pb2.HierarchyProto`` holding one path per word.
    """
    hierarchy = hsm_pb2.HierarchyProto()

    def build_path_proto(trail, word):
        # Snapshot the current trail of [index, length, target] entries.
        proto = hsm_pb2.PathProto()
        proto.word_id = word
        for index, length, target in trail:
            step = proto.path_nodes.add()
            step.index = index
            step.length = length
            step.target = target
        return proto

    def walk(current, trail, next_index):
        child_count = len(current.children)
        fan_out = len(current.word_ids) + child_count

        current.offset = next_index
        # Mutable entry: its target slot is rewritten for every branch
        # taken out of this node.
        entry = [next_index, fan_out, 0]
        trail.append(entry)

        next_index += fan_out
        # Only write `size` when it actually grows.
        if next_index > hierarchy.size:
            hierarchy.size = next_index

        for position, child in enumerate(current.children):
            entry[2] = position
            next_index = walk(child, trail, next_index)

        for position, word in enumerate(current.word_ids):
            entry[2] = child_count + position
            finished = build_path_proto(trail, word)
            hierarchy.paths.add().MergeFrom(finished)

        trail.pop()
        return next_index

    walk(tree_proto.root_node, [], 0)
    return hierarchy