def test_add_pruned_collection_proto_in_bytes_list(self):
  # Checks pruning of a bytes_list collection whose entries are the
  # byte-encoded string form of protos built by _make_asset_file_def_any.
  # NOTE(review): another method with this exact name appears later in this
  # file view; if both live in the same class, only the later definition is
  # kept by Python -- confirm.

  def _as_bytes_entry(op_name):
    # Same conversion chain the collection entries use: proto -> str -> bytes.
    proto = _make_asset_file_def_any(op_name)
    return compat.as_bytes(compat.as_str_any(proto))

  key = 'proto_collection'
  source_graph = meta_graph_pb2.MetaGraphDef()
  source_graph.collection_def[key].bytes_list.value.extend(
      [_as_bytes_entry(op) for op in ('node1', 'node2', 'node3', 'node4')])

  pruned_graph = meta_graph_pb2.MetaGraphDef()
  # 'node5' is listed as removed although no entry was built for it.
  dropped_ops = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_collection(
      source_graph, pruned_graph, key, dropped_ops)

  pruned_collection = pruned_graph.collection_def[key]
  survivors = [_as_bytes_entry(op) for op in ('node1', 'node3')]
  self.assertEqual(survivors, pruned_collection.bytes_list.value[:])
def test_add_pruned_collection_proto_in_any_list(self):
  # Note: This also tests _is_removed_mentioned().
  # Checks pruning of an any_list collection, including slash-separated
  # names: removing '/a' must not drop the entry built for '/a/a_1', while
  # removing '/b/b_1' must drop the entry built for '/b/b_1'.
  # NOTE(review): another method with this exact name appears later in this
  # file view; if both live in the same class, only the later definition is
  # kept by Python -- confirm.
  key = 'proto_collection'
  source_ops = ('node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1')
  kept_ops = ('node1', 'node3', '/a/a_1')

  source_graph = meta_graph_pb2.MetaGraphDef()
  source_graph.collection_def[key].any_list.value.extend(
      [_make_asset_file_def_any(op) for op in source_ops])

  pruned_graph = meta_graph_pb2.MetaGraphDef()
  dropped_ops = ['node2', 'node4', 'node5', '/a', '/b/b_1']
  meta_graph_transform._add_pruned_collection(
      source_graph, pruned_graph, key, dropped_ops)

  pruned_collection = pruned_graph.collection_def[key]
  survivors = [_make_asset_file_def_any(op) for op in kept_ops]
  self.assertEqual(survivors, pruned_collection.any_list.value[:])
def test_add_pruned_collection_proto_in_bytes_list(self):
  # Note: This also tests _is_removed_mentioned().
  # Checks pruning of a bytes_list collection of byte-encoded proto strings,
  # including slash-separated names: removing '/a' must not drop the entry
  # built for '/a/a_1', while removing '/b/b_1' must drop its entry.
  # NOTE(review): an earlier method with this exact name appears in this
  # file view; if both live in the same class, this later definition
  # replaces the earlier one -- confirm.

  def _as_bytes_entry(op_name):
    # Same conversion chain the collection entries use: proto -> str -> bytes.
    proto = _make_asset_file_def_any(op_name)
    return compat.as_bytes(compat.as_str_any(proto))

  key = 'proto_collection'
  source_ops = ('node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1')
  kept_ops = ('node1', 'node3', '/a/a_1')

  source_graph = meta_graph_pb2.MetaGraphDef()
  source_graph.collection_def[key].bytes_list.value.extend(
      [_as_bytes_entry(op) for op in source_ops])

  pruned_graph = meta_graph_pb2.MetaGraphDef()
  dropped_ops = ['node2', 'node4', 'node5', '/a', '/b/b_1']
  meta_graph_transform._add_pruned_collection(
      source_graph, pruned_graph, key, dropped_ops)

  pruned_collection = pruned_graph.collection_def[key]
  survivors = [_as_bytes_entry(op) for op in kept_ops]
  self.assertEqual(survivors, pruned_collection.bytes_list.value[:])
def test_add_pruned_collection_int(self):
  # Checks that an int64_list collection comes through pruning unchanged:
  # all four ints are expected in the output regardless of the removed ops.
  key = 'int_collection'
  source_graph = meta_graph_pb2.MetaGraphDef()
  source_ints = source_graph.collection_def[key].int64_list
  source_ints.value[:] = [10, 20, 30, 40]

  pruned_graph = meta_graph_pb2.MetaGraphDef()
  dropped_ops = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_collection(
      source_graph, pruned_graph, key, dropped_ops)

  pruned_collection = pruned_graph.collection_def[key]
  self.assertEqual([10, 20, 30, 40], pruned_collection.int64_list.value)
def test_add_pruned_collection_node(self):
  # Checks pruning of a node_list collection: names listed as removed are
  # expected to be absent from the output, the rest kept in order.
  key = 'node_collection'
  source_graph = meta_graph_pb2.MetaGraphDef()
  source_nodes = source_graph.collection_def[key].node_list
  source_nodes.value.extend(['node1', 'node2', 'node3', 'node4'])

  pruned_graph = meta_graph_pb2.MetaGraphDef()
  # 'node5' is listed as removed although it was never in the collection.
  dropped_ops = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_collection(
      source_graph, pruned_graph, key, dropped_ops)

  pruned_collection = pruned_graph.collection_def[key]
  self.assertEqual(['node1', 'node3'], pruned_collection.node_list.value)
def test_add_pruned_collection_proto_in_any_list(self):
  # Checks pruning of an any_list collection: entries built for removed op
  # names are expected to be dropped, the rest kept in order.
  # NOTE(review): an earlier method with this exact name appears in this
  # file view; if both live in the same class, this later definition
  # replaces the earlier one -- confirm.
  key = 'proto_collection'
  source_graph = meta_graph_pb2.MetaGraphDef()
  source_graph.collection_def[key].any_list.value.extend(
      [_make_asset_file_def_any(op)
       for op in ('node1', 'node2', 'node3', 'node4')])

  pruned_graph = meta_graph_pb2.MetaGraphDef()
  # 'node5' is listed as removed although no entry was built for it.
  dropped_ops = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_collection(
      source_graph, pruned_graph, key, dropped_ops)

  pruned_collection = pruned_graph.collection_def[key]
  survivors = [_make_asset_file_def_any(op) for op in ('node1', 'node3')]
  self.assertEqual(survivors, pruned_collection.any_list.value[:])