예제 #1
0
파일: hsm_test.py 프로젝트: Yangqing/caffe2
    def test_huffman_tree_hierarchy(self):
        """Check the per-label paths built from a HuffmanTreeHierarchy tree.

        Label ``i`` (0..5) is fed ``counts[i]`` times, so the operator sees
        class frequencies 1..6 with ``num_classes=6``.  The serialized
        ``TreeProto`` it emits is converted to word paths with
        ``hsmu.create_hierarchy`` and every label's path is compared node by
        node against the expected node indices and targets.
        """
        workspace.GlobalInit(['caffe2'])
        labelSet = list(range(0, 6))
        counts = [1, 2, 3, 4, 5, 6]
        # Repeat each label `count` times; a flat comprehension avoids the
        # quadratic list concatenation performed by sum(..., []).
        labels = [label for (label, count) in zip(labelSet, counts)
                  for _ in range(count)]
        Y = np.array(labels).astype(np.int64)
        workspace.FeedBlob("labels", Y)
        arg = caffe2_pb2.Argument()
        arg.name = 'num_classes'
        arg.i = 6
        op = core.CreateOperator(
            'HuffmanTreeHierarchy',
            ['labels'],
            ['huffman_tree'],
            'HuffmanTreeHierarchy',
            arg=[arg])
        workspace.RunOperatorOnce(op)
        huffmanTreeOutput = workspace.FetchBlob('huffman_tree')
        treeOutput = hsm_pb2.TreeProto()
        # The serialized tree is read from the first element of the blob.
        treeOutput.ParseFromString(huffmanTreeOutput[0])
        treePathOutput = hsmu.create_hierarchy(treeOutput)

        label_to_path = {path.word_id: path for path in treePathOutput.paths}

        def checkPath(label, indices, code):
            """Assert that `label`'s path has node `indices` and `code` targets."""
            path = label_to_path[label]
            # Check the length against BOTH expectation lists: zip() below
            # stops at its shortest input, so a mismatch would go unnoticed.
            self.assertEqual(len(path.path_nodes), len(code))
            self.assertEqual(len(path.path_nodes), len(indices))
            for path_node, index, target in \
                    zip(path.path_nodes, indices, code):
                self.assertEqual(path_node.index, index)
                self.assertEqual(path_node.target, target)
        checkPath(0, [0, 4, 6, 8], [1, 0, 0, 0])
        checkPath(1, [0, 4, 6, 8], [1, 0, 0, 1])
        checkPath(2, [0, 4, 6], [1, 0, 1])
        checkPath(3, [0, 2], [0, 0])
        checkPath(4, [0, 2], [0, 1])
        checkPath(5, [0, 4], [1, 1])
예제 #2
0
    def test_huffman_tree_hierarchy(self):
        """Check the per-label paths built from a HuffmanTreeHierarchy tree.

        Label ``i`` (0..5) is fed ``counts[i]`` times, so the operator sees
        class frequencies 1..6 with ``num_classes=6``.  The serialized
        ``TreeProto`` it emits is converted to word paths with
        ``hsmu.create_hierarchy`` and every label's path is compared node by
        node against the expected node indices and targets.
        """
        workspace.GlobalInit(['caffe2'])
        labelSet = list(range(0, 6))
        counts = [1, 2, 3, 4, 5, 6]
        # Repeat each label `count` times; a flat comprehension avoids the
        # quadratic list concatenation performed by sum(..., []).
        labels = [label for (label, count) in zip(labelSet, counts)
                  for _ in range(count)]
        Y = np.array(labels).astype(np.int64)
        workspace.FeedBlob("labels", Y)
        arg = caffe2_pb2.Argument()
        arg.name = 'num_classes'
        arg.i = 6
        op = core.CreateOperator(
            'HuffmanTreeHierarchy',
            ['labels'],
            ['huffman_tree'],
            'HuffmanTreeHierarchy',
            arg=[arg])
        workspace.RunOperatorOnce(op)
        huffmanTreeOutput = workspace.FetchBlob('huffman_tree')
        treeOutput = hsm_pb2.TreeProto()
        # The serialized tree is read from the first element of the blob.
        treeOutput.ParseFromString(huffmanTreeOutput[0])
        treePathOutput = hsmu.create_hierarchy(treeOutput)

        label_to_path = {path.word_id: path for path in treePathOutput.paths}

        def checkPath(label, indices, code):
            """Assert that `label`'s path has node `indices` and `code` targets."""
            path = label_to_path[label]
            # Check the length against BOTH expectation lists: zip() below
            # stops at its shortest input, so a mismatch would go unnoticed.
            self.assertEqual(len(path.path_nodes), len(code))
            self.assertEqual(len(path.path_nodes), len(indices))
            for path_node, index, target in \
                    zip(path.path_nodes, indices, code):
                self.assertEqual(path_node.index, index)
                self.assertEqual(path_node.target, target)
        checkPath(0, [0, 4, 6, 8], [1, 0, 0, 0])
        checkPath(1, [0, 4, 6, 8], [1, 0, 0, 1])
        checkPath(2, [0, 4, 6], [1, 0, 1])
        checkPath(3, [0, 2], [0, 0])
        checkPath(4, [0, 2], [0, 1])
        checkPath(5, [0, 4], [1, 1])
예제 #3
0
파일: hsm_test.py 프로젝트: Yangqing/caffe2
# Tree layout, one entry per node as [offset, length, children, node_name]:
#   node5: [0, 2, ["node4", "node3"]]
#   node4: [2, 2, ["node1", "node2"]]
#   node1: [4, 3, [0, 1, 2]]
#   node2: [7, 2, [3, 4]]
#   node3: [9, 4, [5, 6, 7, 8]]
struct = [
    [0, 2, ["node4", "node3"], "node5"],
    [2, 2, ["node1", "node2"], "node4"],
    [4, 3, [0, 1, 2], "node1"],
    [7, 2, [3, 4], "node2"],
    [9, 4, [5, 6, 7, 8], "node3"],
]

# hsmu.create_hierarchy translates the input `tree` (expected to be defined
# before this point) into a list of (word_id, path) entries.  The serialized
# hierarchy is passed to the operator_def as the "hierarchy" string argument.
hierarchy_proto = hsmu.create_hierarchy(tree)
arg = caffe2_pb2.Argument()
arg.name = "hierarchy"
arg.s = hierarchy_proto.SerializeToString()

beam = 5

# Search-operator arguments: the serialized tree ("tree") and the beam
# value ("beam").  `arg_search` is left bound to the last one created.
arg_search = caffe2_pb2.Argument()
arg_search.name = "tree"
arg_search.s = tree.SerializeToString()
args_search = [arg_search]

arg_search = caffe2_pb2.Argument()
arg_search.name = "beam"
arg_search.f = beam
args_search.append(arg_search)
예제 #4
0
파일: hsm_test.py 프로젝트: zxsted/caffe2
# Tree built below (leaves list the word ids of each leaf node):
#           *
#         /  \
#        *    5,6,7,8
#       / \
#  0,1,2   3,4
tree = hsm_pb2.TreeProto()
words = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]
# One leaf node per word group, in the same order as `words`.
node1, node2, node3 = (hsmu.create_node_with_words(group) for group in words)
node4 = hsmu.create_node_with_nodes([node1, node2])
# Root: the inner node holding 0..4 plus the leaf holding 5..8.
node = hsmu.create_node_with_nodes([node4, node3])
tree.root_node.MergeFrom(node)

# hsmu.create_hierarchy translates `tree` into a list of (word_id, path)
# entries; the serialized hierarchy is passed to the operator_def as the
# "hierarchy" string argument.
hierarchy_proto = hsmu.create_hierarchy(tree)
arg = caffe2_pb2.Argument()
arg.name = "hierarchy"
arg.s = hierarchy_proto.SerializeToString()


class TestHsm(hu.HypothesisTestCase):
    def test_hsm_run_once(self):
        workspace.GlobalInit(['caffe2'])
        workspace.FeedBlob("data",
                           np.random.randn(1000, 100).astype(np.float32))
        workspace.FeedBlob("weights",
                           np.random.randn(1000, 100).astype(np.float32))
        workspace.FeedBlob("bias", np.random.randn(1000).astype(np.float32))
        workspace.FeedBlob("labels", np.random.randn(1000).astype(np.int32))
        op = core.CreateOperator(