def _test_overlapping_names(
    self,
    inputs0: List[str] = ['i0', 'i1'],
    inputs1: List[str] = ['i2', 'i3'],
    outputs0: List[str] = ['o0', 'o1'],
    outputs1: List[str] = ['o2', 'o3'],
    value_info0: List[str] = ['v0', 'v1'],
    value_info1: List[str] = ['v2', 'v3'],
    initializer0: List[str] = ['init0', 'init1'],
    initializer1: List[str] = ['init2', 'init3'],
    sparse_initializer0: List[str] = ['sparse_init0', 'sparse_init1'],
    sparse_initializer1: List[str] = ['sparse_init2', 'sparse_init3'],
) -> None:
    """Check ``compose.check_overlapping_names`` against two synthetic models.

    Two models ('g0' and 'g1') are built from the given name lists. Each one
    holds one Identity node per input name (wired to the output name at the
    same position), plus value infos, initializers and sparse initializers.

    The result of ``compose.check_overlapping_names`` is expected to contain
    one ``(category, names)`` entry per category that has shared names, in
    the order 'edge', 'value_info', 'initializer', 'sparse_initializer'.
    Finally, the first model is prefixed with 'g0/' and the check is expected
    to report nothing.

    Args:
        inputs0, inputs1: graph input names of the first / second model.
        outputs0, outputs1: graph output names; ``outputsN[i]`` is the output
            of the Identity node fed by ``inputsN[i]``.
        value_info0, value_info1: value_info names.
        initializer0, initializer1: dense initializer names.
        sparse_initializer0, sparse_initializer1: names handed to
            ``_make_sparse_tensor``.

    Note:
        The list defaults are shared between calls; this method only reads
        them and never mutates them.
    """
    # A single opset list is shared by both models.
    ops = [helper.make_opsetid("", 10)]

    def _build_model(graph_name: str,
                     inputs: List[str],
                     outputs: List[str],
                     value_infos: List[str],
                     initializers: List[str],
                     sparse_initializers: List[str]):
        # Indexing outputs by position (rather than zip) keeps an IndexError
        # when fewer outputs than inputs are supplied.
        nodes = [
            helper.make_node('Identity', inputs=[src], outputs=[outputs[idx]])
            for idx, src in enumerate(inputs)
        ]
        graph_inputs = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in inputs
        ]
        graph_outputs = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in outputs
        ]
        graph_value_infos = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in value_infos
        ]
        dense_inits = [
            helper.make_tensor(name=name, data_type=TensorProto.INT64,
                               dims=(), vals=[1])
            for name in initializers
        ]
        sparse_inits = [_make_sparse_tensor(name) for name in sparse_initializers]
        graph = helper.make_graph(
            nodes=nodes, name=graph_name,
            inputs=graph_inputs, outputs=graph_outputs,
            value_info=graph_value_infos,
            initializer=dense_inits,
            sparse_initializer=sparse_inits)
        return helper.make_model(graph, producer_name='test', opset_imports=ops)

    m0 = _build_model('g0', inputs0, outputs0, value_info0,
                      initializer0, sparse_initializer0)
    m1 = _build_model('g1', inputs1, outputs1, value_info1,
                      initializer1, sparse_initializer1)

    overlap = compose.check_overlapping_names(m0.graph, m1.graph)

    # Expected (category, sorted names) entries, in reporting order.
    # Categories without shared names are expected to be absent.
    expected = []

    shared_edges = (set(inputs0) & set(inputs1)) | (set(outputs0) & set(outputs1))
    if shared_edges:
        expected.append(('edge', sorted(shared_edges)))

    shared_value_infos = set(value_info0) & set(value_info1)
    if shared_value_infos:
        expected.append(('value_info', sorted(shared_value_infos)))

    shared_inits = set(initializer0) & set(initializer1)
    if shared_inits:
        expected.append(('initializer', sorted(shared_inits)))

    shared_sparse = set(sparse_initializer0) & set(sparse_initializer1)
    if shared_sparse:
        # Each shared sparse initializer is expected to show up twice, under
        # its '_values' and '_idx' tensor names (presumably the naming used
        # by _make_sparse_tensor -- confirm against that helper).
        sparse_names = []
        for name in shared_sparse:
            sparse_names.append(name + '_values')
            sparse_names.append(name + '_idx')
        expected.append(('sparse_initializer', sorted(sparse_names)))

    for position, (category, names) in enumerate(expected):
        entry = overlap[position]
        self.assertEqual(entry[0], category)
        # Both sides are derived from sets, whose iteration order is not
        # guaranteed, so the names are compared order-insensitively.
        self.assertEqual(sorted(entry[1]), names)

    # After prefixing every name of the first model nothing may overlap.
    m0_new = compose.add_prefix(m0, prefix='g0/')
    overlap = compose.check_overlapping_names(m0_new.graph, m1.graph)
    self.assertEqual(0, len(overlap))
def _test_add_prefix(self,
                     rename_nodes: bool = False,
                     rename_edges: bool = False,
                     rename_inputs: bool = False,
                     rename_outputs: bool = False,
                     rename_initializers: bool = False,
                     rename_value_infos: bool = False,
                     inplace: bool = False) -> None:
    """Run ``compose.add_prefix`` with the given flags and verify the names.

    The model loaded from ``m1_def`` is prefixed with 'pre/'. With
    ``inplace=True`` the call is made on a copy of the model (passing
    ``inplace=True``) and the copy is inspected; otherwise the returned model
    is inspected. The expected renaming is rebuilt here from the original
    graph and compared, position by position, with the resulting graph.

    Args:
        rename_nodes: expect every node name to be prefixed.
        rename_edges: expect every node input/output name to be prefixed.
        rename_inputs: expect graph input names to be prefixed (only
            consulted when ``rename_edges`` is false).
        rename_outputs: expect graph output names to be prefixed (only
            consulted when ``rename_edges`` is false).
        rename_initializers: expect initializer names, and the values/indices
            tensor names of sparse initializers, to be prefixed.
        rename_value_infos: forwarded to ``compose.add_prefix``; see the
            review note in the body for how the expectation is built.
        inplace: exercise the in-place variant of ``compose.add_prefix``.
    """
    m1 = _load_model(m1_def)
    prefix = 'pre/'
    rename_flags = dict(
        rename_nodes=rename_nodes,
        rename_edges=rename_edges,
        rename_inputs=rename_inputs,
        rename_outputs=rename_outputs,
        rename_initializers=rename_initializers,
        rename_value_infos=rename_value_infos,
    )

    if not inplace:
        m2 = compose.add_prefix(m1, prefix, **rename_flags)
    else:
        # Work on a copy so that m1 keeps the original names for comparison.
        m2 = ModelProto()
        m2.CopyFrom(m1)
        compose.add_prefix(m2, prefix, inplace=True, **rename_flags)

    g_in = m1.graph
    g_out = m2.graph

    renames_any_value = any((rename_edges, rename_inputs, rename_outputs,
                             rename_initializers, rename_value_infos))
    if renames_any_value:
        name_mapping = {}

        def _expect_prefixed(names):
            # Record that each of these names should come out prefixed.
            for name in names:
                name_mapping[name] = _prefixed(prefix, name)

        def _expected(name):
            # Names that were not recorded are expected to stay unchanged.
            return name_mapping.get(name, name)

        # Rename inputs/outputs/edges. Propagate name changes from and to edges
        if rename_edges:
            for node in g_in.node:
                _expect_prefixed(node.input)
                _expect_prefixed(node.output)
        else:
            if rename_inputs:
                _expect_prefixed(entry.name for entry in g_in.input)
            if rename_outputs:
                _expect_prefixed(entry.name for entry in g_in.output)

        if rename_initializers:
            _expect_prefixed(init.name for init in g_in.initializer)
            for sparse in g_in.sparse_initializer:
                _expect_prefixed((sparse.values.name, sparse.indices.name))

        if rename_value_infos:
            # NOTE(review): this walks g_in.output, not g_in.value_info --
            # confirm it mirrors what compose.add_prefix does for this flag.
            _expect_prefixed(entry.name for entry in g_in.output)

        for node_out, node_in in zip(g_out.node, g_in.node):
            for actual, original in zip(node_out.input, node_in.input):
                self.assertEqual(_expected(original), actual)
            for actual, original in zip(node_out.output, node_in.output):
                self.assertEqual(_expected(original), actual)

        for entry_out, entry_in in zip(g_out.input, g_in.input):
            self.assertEqual(_expected(entry_in.name), entry_out.name)

        for entry_out, entry_in in zip(g_out.output, g_in.output):
            self.assertEqual(_expected(entry_in.name), entry_out.name)

        for init_out, init_in in zip(g_out.initializer, g_in.initializer):
            self.assertEqual(_expected(init_in.name), init_out.name)

        for sparse_out, sparse_in in zip(g_out.sparse_initializer,
                                         g_in.sparse_initializer):
            self.assertEqual(_expected(sparse_in.values.name),
                             sparse_out.values.name)
            self.assertEqual(_expected(sparse_in.indices.name),
                             sparse_out.indices.name)

        for info_out, info_in in zip(g_out.value_info, g_in.value_info):
            self.assertEqual(_expected(info_in.name), info_out.name)

    if rename_nodes:
        for node_out, node_in in zip(g_out.node, g_in.node):
            self.assertEqual(_prefixed(prefix, node_in.name), node_out.name)