Beispiel #1
0
    def _test_overlapping_names(
        self,
        inputs0: List[str] = ['i0', 'i1'],
        inputs1: List[str] = ['i2', 'i3'],
        outputs0: List[str] = ['o0', 'o1'],
        outputs1: List[str] = ['o2', 'o3'],
        value_info0: List[str] = ['v0', 'v1'],
        value_info1: List[str] = ['v2', 'v3'],
        initializer0: List[str] = ['init0', 'init1'],
        initializer1: List[str] = ['init2', 'init3'],
        sparse_initializer0: List[str] = ['sparse_init0', 'sparse_init1'],
        sparse_initializer1: List[str] = ['sparse_init2', 'sparse_init3'],
    ) -> None:
        """Build two models from the given names and verify overlap detection.

        One graph of ``Identity`` nodes is created per name set (suffix ``0``
        and suffix ``1``) and both graphs are handed to
        ``compose.check_overlapping_names``. The reported entries are expected
        in the order edge, value_info, initializer, sparse_initializer, where
        a category only appears if the corresponding name lists intersect.
        Afterwards the first model is prefixed with ``'g0/'`` and the check is
        expected to report no overlap at all.

        The names inside each reported entry are compared without regard to
        order: the expected names are derived from set intersections, whose
        iteration order is not something this test should depend on.

        Args:
            inputs0, inputs1: Graph input names; ``inputsN[i]`` feeds the
                i-th ``Identity`` node of graph ``N``.
            outputs0, outputs1: Graph output names; ``outputsN[i]`` is
                produced by the i-th ``Identity`` node, so each list needs at
                least as many entries as the matching inputs list.
            value_info0, value_info1: Names of the value_info entries.
            initializer0, initializer1: Names of the dense initializers.
            sparse_initializer0, sparse_initializer1: Names passed to
                ``_make_sparse_tensor`` for the sparse initializers.

        The list defaults are only read, never mutated.
        """

        def _build_model(graph_name: str,
                         input_names: List[str],
                         output_names: List[str],
                         value_info_names: List[str],
                         initializer_names: List[str],
                         sparse_initializer_names: List[str]):
            # Index into output_names (rather than zip) so that a too-short
            # outputs list fails loudly instead of silently dropping nodes.
            nodes = [
                helper.make_node('Identity',
                                 inputs=[input_name],
                                 outputs=[output_names[idx]])
                for idx, input_name in enumerate(input_names)
            ]
            graph = helper.make_graph(
                nodes=nodes,
                name=graph_name,
                inputs=[
                    helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
                    for name in input_names
                ],
                outputs=[
                    helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
                    for name in output_names
                ],
                value_info=[
                    helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
                    for name in value_info_names
                ],
                initializer=[
                    helper.make_tensor(name=name,
                                       data_type=TensorProto.INT64,
                                       dims=(),
                                       vals=[1])
                    for name in initializer_names
                ],
                sparse_initializer=[
                    _make_sparse_tensor(name)
                    for name in sparse_initializer_names
                ])
            return helper.make_model(
                graph,
                producer_name='test',
                opset_imports=[helper.make_opsetid("", 10)])

        def _assert_entry(entry, category: str,
                          expected_names: List[str]) -> None:
            # Category must match exactly; names are compared as sorted lists
            # so the result does not hinge on set iteration order.
            found_category, found_names = entry
            self.assertEqual(found_category, category)
            self.assertEqual(sorted(found_names), sorted(expected_names))

        m0 = _build_model('g0', inputs0, outputs0, value_info0, initializer0,
                          sparse_initializer0)
        m1 = _build_model('g1', inputs1, outputs1, value_info1, initializer1,
                          sparse_initializer1)

        overlap = compose.check_overlapping_names(m0.graph, m1.graph)
        # Position of the next expected entry in ``overlap``; it only
        # advances for categories that are expected to be reported.
        next_entry = 0

        overlapping_inputs = list(set(inputs0) & set(inputs1))
        overlapping_outputs = list(set(outputs0) & set(outputs1))
        overlapping_edges = list(set(overlapping_inputs + overlapping_outputs))
        if overlapping_edges:
            _assert_entry(overlap[next_entry], 'edge', overlapping_edges)
            next_entry += 1

        overlapping_vis = list(set(value_info0) & set(value_info1))
        if overlapping_vis:
            _assert_entry(overlap[next_entry], 'value_info', overlapping_vis)
            next_entry += 1

        overlapping_init = list(set(initializer0) & set(initializer1))
        if overlapping_init:
            _assert_entry(overlap[next_entry], 'initializer', overlapping_init)
            next_entry += 1

        overlapping_sparse_init = list(
            set(sparse_initializer0) & set(sparse_initializer1))
        if overlapping_sparse_init:
            # Each sparse initializer is expected to contribute two names.
            # NOTE(review): presumably '<name>_values' / '<name>_idx' are the
            # tensor names assigned by _make_sparse_tensor - confirm there.
            expected_overlap = []
            for overlapping_name in overlapping_sparse_init:
                expected_overlap.append(overlapping_name + '_values')
                expected_overlap.append(overlapping_name + '_idx')
            _assert_entry(overlap[next_entry], 'sparse_initializer',
                          expected_overlap)

        # Once the first model is prefixed, no overlap may be reported.
        m0_new = compose.add_prefix(m0, prefix='g0/')
        overlap = compose.check_overlapping_names(m0_new.graph, m1.graph)
        self.assertEqual(0, len(overlap))
Beispiel #2
0
    def _test_add_prefix(self,
                         rename_nodes: bool = False,
                         rename_edges: bool = False,
                         rename_inputs: bool = False,
                         rename_outputs: bool = False,
                         rename_initializers: bool = False,
                         rename_value_infos: bool = False,
                         inplace: bool = False) -> None:
        """Run ``compose.add_prefix`` on ``m1_def`` and verify the renaming.

        The model loaded from ``m1_def`` is prefixed with ``'pre/'`` using the
        given flags. An expected old-name -> new-name mapping is derived from
        the untouched input graph and compared against node inputs/outputs,
        graph inputs/outputs, initializers, sparse initializers and
        value_infos of the result. Names missing from the mapping are
        expected to be unchanged.

        Args:
            rename_nodes: Forwarded to ``compose.add_prefix``; when set, every
                node name of the result must equal the prefixed original.
            rename_edges: Forwarded; when set, all node input/output names
                are expected to be prefixed (graph inputs/outputs are then
                not added to the mapping separately).
            rename_inputs: Forwarded; graph input names expected prefixed.
            rename_outputs: Forwarded; graph output names expected prefixed.
            rename_initializers: Forwarded; dense initializer names and the
                values/indices names of sparse initializers expected prefixed.
            rename_value_infos: Forwarded; see the review note in the body.
            inplace: If True, a copy of the loaded model is modified with
                ``inplace=True``; otherwise the model returned by
                ``compose.add_prefix`` is checked.
        """
        m1 = _load_model(m1_def)

        prefix = 'pre/'

        rename_flags = dict(rename_nodes=rename_nodes,
                            rename_edges=rename_edges,
                            rename_inputs=rename_inputs,
                            rename_outputs=rename_outputs,
                            rename_initializers=rename_initializers,
                            rename_value_infos=rename_value_infos)

        if inplace:
            # Modify a copy so that m1 remains available as the reference.
            m2 = ModelProto()
            m2.CopyFrom(m1)
            compose.add_prefix(m2, prefix, inplace=True, **rename_flags)
        else:
            m2 = compose.add_prefix(m1, prefix, **rename_flags)
        g_in = m1.graph
        g_out = m2.graph

        if (rename_edges or rename_inputs or rename_outputs
                or rename_initializers or rename_value_infos):
            name_mapping = {}

            # Rename inputs/outputs/edges. Propagate name changes from and to edges
            if rename_edges:
                for n in g_in.node:
                    for e in n.input:
                        name_mapping[e] = _prefixed(prefix, e)
                    for e in n.output:
                        name_mapping[e] = _prefixed(prefix, e)
            else:
                if rename_inputs:
                    for elem in g_in.input:
                        name_mapping[elem.name] = _prefixed(prefix, elem.name)
                if rename_outputs:
                    for elem in g_in.output:
                        name_mapping[elem.name] = _prefixed(prefix, elem.name)

            if rename_initializers:
                for init in g_in.initializer:
                    name_mapping[init.name] = _prefixed(prefix, init.name)
                for sparse_init in g_in.sparse_initializer:
                    name_mapping[sparse_init.values.name] = \
                        _prefixed(prefix, sparse_init.values.name)
                    name_mapping[sparse_init.indices.name] = \
                        _prefixed(prefix, sparse_init.indices.name)

            if rename_value_infos:
                # NOTE(review): this iterates g_in.output, not
                # g_in.value_info. Presumably it mirrors what
                # compose.add_prefix does for this flag - confirm against the
                # implementation before changing either side.
                for value_info in g_in.output:
                    name_mapping[value_info.name] = _prefixed(
                        prefix, value_info.name)

            for n1, n0 in zip(g_out.node, g_in.node):
                for e1, e0 in zip(n1.input, n0.input):
                    self.assertEqual(name_mapping.get(e0, e0), e1)
                for e1, e0 in zip(n1.output, n0.output):
                    self.assertEqual(name_mapping.get(e0, e0), e1)
            for i1, i0 in zip(g_out.input, g_in.input):
                self.assertEqual(name_mapping.get(i0.name, i0.name), i1.name)
            for o1, o0 in zip(g_out.output, g_in.output):
                self.assertEqual(name_mapping.get(o0.name, o0.name), o1.name)

            for init1, init0 in zip(g_out.initializer, g_in.initializer):
                self.assertEqual(name_mapping.get(init0.name, init0.name),
                                 init1.name)

            for sparse_init1, sparse_init0 in zip(g_out.sparse_initializer,
                                                  g_in.sparse_initializer):
                self.assertEqual(
                    name_mapping.get(sparse_init0.values.name,
                                     sparse_init0.values.name),
                    sparse_init1.values.name)
                self.assertEqual(
                    name_mapping.get(sparse_init0.indices.name,
                                     sparse_init0.indices.name),
                    sparse_init1.indices.name)

            for vi1, vi0 in zip(g_out.value_info, g_in.value_info):
                self.assertEqual(name_mapping.get(vi0.name, vi0.name),
                                 vi1.name)

        # Node names are independent of the edge/initializer flags, so this
        # check must not sit behind the guard above; otherwise a call with
        # only rename_nodes=True would verify nothing.
        if rename_nodes:
            for n1, n0 in zip(g_out.node, g_in.node):
                self.assertEqual(_prefixed(prefix, n0.name), n1.name)