コード例 #1
0
    def test_add_pruned_collection_proto_in_bytes_list(self):
        """Pruning a bytes_list drops serialized entries of removed nodes.

        The base collection holds the protos built by
        _make_asset_file_def_any for 'node1'..'node4', each serialized to
        bytes. After pruning with 'node2' and 'node4' removed, only the
        entries for 'node1' and 'node3' remain, in their original order.
        'node5' is listed as removed but never appears in the collection.
        """

        def serialized_asset(node_name):
            # Encode exactly as the collection entries are encoded:
            # proto -> str -> bytes.
            return compat.as_bytes(
                compat.as_str_any(_make_asset_file_def_any(node_name)))

        collection_name = 'proto_collection'
        base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
        base_collection = base_meta_graph_def.collection_def[collection_name]
        base_collection.bytes_list.value.extend([
            serialized_asset(name)
            for name in ('node1', 'node2', 'node3', 'node4')
        ])

        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        removed_op_names = ['node2', 'node4', 'node5']
        meta_graph_transform._add_pruned_collection(
            base_meta_graph_def, meta_graph_def, collection_name,
            removed_op_names)

        pruned = meta_graph_def.collection_def[collection_name]
        self.assertEqual(
            [serialized_asset(name) for name in ('node1', 'node3')],
            pruned.bytes_list.value[:])
コード例 #2
0
    def test_add_pruned_collection_proto_in_any_list(self):
        """Pruning an any_list keeps only protos for surviving node names.

        This also exercises _is_removed_mentioned(): the expected result
        keeps '/a/a_1' even though '/a' is in removed_op_names, while
        '/b/b_1' is dropped because it is listed by its exact name.
        """
        collection_name = 'proto_collection'
        base_names = ('node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1')
        base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
        base_collection = base_meta_graph_def.collection_def[collection_name]
        base_collection.any_list.value.extend(
            [_make_asset_file_def_any(name) for name in base_names])

        meta_graph_def = meta_graph_pb2.MetaGraphDef()
        removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
        meta_graph_transform._add_pruned_collection(
            base_meta_graph_def, meta_graph_def, collection_name,
            removed_op_names)

        pruned = meta_graph_def.collection_def[collection_name]
        expected_protos = [
            _make_asset_file_def_any(name)
            for name in ('node1', 'node3', '/a/a_1')
        ]
        self.assertEqual(expected_protos, pruned.any_list.value[:])
コード例 #3
0
  def test_add_pruned_collection_proto_in_bytes_list(self):
    """Pruning a bytes_list drops serialized protos of removed nodes.

    Also exercises _is_removed_mentioned(): the expected result keeps the
    entry for '/a/a_1' although '/a' is listed as removed, whereas the
    entry for '/b/b_1' (removed by exact name) is dropped.
    """

    def as_serialized(name):
      # Entries are encoded the same way on both sides: proto -> str -> bytes.
      return compat.as_bytes(
          compat.as_str_any(_make_asset_file_def_any(name)))

    collection_name = 'proto_collection'
    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_bytes = base_meta_graph_def.collection_def[collection_name].bytes_list
    original_names = ('node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1')
    base_bytes.value.extend([as_serialized(name) for name in original_names])

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)

    pruned = meta_graph_def.collection_def[collection_name]
    expected_values = [
        as_serialized(name) for name in ('node1', 'node3', '/a/a_1')
    ]
    self.assertEqual(expected_values, pruned.bytes_list.value[:])
コード例 #4
0
  def test_add_pruned_collection_int(self):
    """An int64_list collection is copied through pruning unchanged.

    removed_op_names is non-empty, yet the expected result keeps every
    integer value in its original order.
    """
    collection_name = 'int_collection'
    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_meta_graph_def.collection_def[collection_name].int64_list.value.extend(
        [10, 20, 30, 40])

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    removed_op_names = ['node2', 'node4', 'node5']
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)

    collection = meta_graph_def.collection_def[collection_name]

    expected_ints = [10, 20, 30, 40]
    # Slice the repeated field into a plain list, as the other collection
    # tests do. With list-vs-list operands, assertEqual can dispatch to
    # assertListEqual and report an element-wise diff on failure. It does
    # not rely on the protobuf container's reflected __eq__.
    self.assertEqual(expected_ints, collection.int64_list.value[:])
コード例 #5
0
  def test_add_pruned_collection_node(self):
    """A node_list collection loses exactly the removed node names.

    'node5' is listed as removed but is absent from the collection. The
    expected result shows that only 'node2' and 'node4' are dropped.
    """
    collection_name = 'node_collection'
    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_meta_graph_def.collection_def[collection_name].node_list.value.extend(
        ['node1', 'node2', 'node3', 'node4'])

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    removed_op_names = ['node2', 'node4', 'node5']
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)

    collection = meta_graph_def.collection_def[collection_name]

    expected_nodes = ['node1', 'node3']
    # Slice the repeated field into a plain list, as the other collection
    # tests do. The comparison is then list-vs-list, and assertEqual can use
    # assertListEqual's element-wise diff on failure.
    self.assertEqual(expected_nodes, collection.node_list.value[:])
コード例 #6
0
  def test_add_pruned_collection_proto_in_any_list(self):
    """Pruning an any_list keeps only protos for surviving node names.

    Protos for 'node2' and 'node4' are dropped; 'node1' and 'node3' remain
    in order. 'node5' is listed as removed but never present.
    """
    collection_name = 'proto_collection'
    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_any = base_meta_graph_def.collection_def[collection_name].any_list
    base_any.value.extend(
        [_make_asset_file_def_any(name)
         for name in ('node1', 'node2', 'node3', 'node4')])

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    removed_op_names = ['node2', 'node4', 'node5']
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)

    pruned = meta_graph_def.collection_def[collection_name]
    self.assertEqual(
        [_make_asset_file_def_any(name) for name in ('node1', 'node3')],
        pruned.any_list.value[:])