Example #1
0
 def _test_merge_models(self,
                        m1def: str,
                        m2def: str,
                        io_map: List[Tuple[str, str]],
                        check_expectations: Callable[
                            [GraphProto, GraphProto, GraphProto], None],
                        inputs: Optional[List[str]] = None,
                        outputs: Optional[List[str]] = None,
                        prefix1: Optional[str] = None,
                        prefix2: Optional[str] = None) -> None:
     """Merge two models at graph level and at model level, then validate both.

     ``m1def`` / ``m2def`` are passed to ``_load_model``. The same merge
     options go to ``compose.merge_graphs`` and to ``compose.merge_models``.
     Each result is validated with the ONNX checker and then passed, together
     with the two source graphs, to ``check_expectations``.
     """
     first = _load_model(m1def)
     second = _load_model(m2def)

     # Both merge entry points must be exercised with identical options.
     merge_options = dict(
         io_map=io_map,
         inputs=inputs,
         outputs=outputs,
         prefix1=prefix1,
         prefix2=prefix2,
     )

     merged_graph = compose.merge_graphs(first.graph, second.graph,
                                         **merge_options)
     checker.check_graph(merged_graph)
     check_expectations(first.graph, second.graph, merged_graph)

     merged_model = compose.merge_models(first, second, **merge_options)
     checker.check_model(merged_model)
     check_expectations(first.graph, second.graph, merged_model.graph)
Example #2
0
 def _test_merge_models(
     self,
     m1def,  # type: Text
     m2def,  # type: Text
     io_map,  # type: List[Tuple[Text, Text]]
     check_expectations,  # type: Callable[[GraphProto, GraphProto, GraphProto], None]
     inputs=None,  # type: Optional[List[Text]]
     outputs=None,  # type: Optional[List[Text]]
     prefix1=None,  # type: Optional[Text]
     prefix2=None  # type: Optional[Text]
 ):  # type: (...) -> None
     """Merge two models at graph level and at model level, then validate both.

     ``m1def`` / ``m2def`` are passed to ``_load_model``. The same merge
     options go to ``compose.merge_graphs`` and to ``compose.merge_models``.
     Each result is validated with the ONNX checker and then passed, together
     with the two source graphs, to ``check_expectations``.
     """
     first = _load_model(m1def)
     second = _load_model(m2def)

     # Both merge entry points must be exercised with identical options.
     merge_options = dict(
         io_map=io_map,
         inputs=inputs,
         outputs=outputs,
         prefix1=prefix1,
         prefix2=prefix2,
     )

     merged_graph = compose.merge_graphs(first.graph, second.graph,
                                         **merge_options)
     checker.check_graph(merged_graph)
     check_expectations(first.graph, second.graph, merged_graph)

     merged_model = compose.merge_models(first, second, **merge_options)
     checker.check_model(merged_model)
     check_expectations(first.graph, second.graph, merged_model.graph)
Example #3
0
    def test_merge_models_with_metadata_props(self) -> None:
        '''
        Tests that metadata_props of both models are combined when merging,
        and that conflicting values for the same key produce an error
        '''
        m1 = _load_model(m1_def)
        helper.set_model_props(m1, {'p1': 'v1', 'p2': 'v2'})

        m2 = _load_model(m2_def)
        helper.set_model_props(m2, {'p3': 'v3', 'p4': 'v4'})

        io_map = [("B00", "B01")]

        # Disjoint keys: all four entries are present in the result
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        self.assertEqual(4, len(m3.metadata_props))

        # Overlap, but same value: the shared key 'p1' appears only once
        helper.set_model_props(m2, {'p1': 'v1', 'p4': 'v4'})
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        self.assertEqual(3, len(m3.metadata_props))

        # Same keys but not same value. Error
        helper.set_model_props(m2, {'p1': 'v5', 'p4': 'v4'})
        with self.assertRaises(ValueError):
            compose.merge_models(m1, m2, io_map=io_map)
Example #4
0
    def test_error_opset_import_mismatch(self) -> None:
        '''
        Tests that merging models which import different versions of the
        default operator set fails, and that it succeeds once the versions
        are aligned with the version converter
        '''
        base1, base2 = _load_model(m1_def), _load_model(m2_def)
        m1 = helper.make_model(base1.graph, producer_name='test',
                               opset_imports=[helper.make_opsetid("", 10)])
        m2 = helper.make_model(base2.graph, producer_name='test',
                               opset_imports=[helper.make_opsetid("", 15)])
        io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]

        # Default-domain opset 10 vs 15: the merge must be rejected
        with self.assertRaises(ValueError):
            compose.merge_models(m1, m2, io_map)

        # Upgrading m1 to opset 15 makes the two imports agree
        m1_upgraded = version_converter.convert_version(m1, 15)
        merged = compose.merge_models(m1_upgraded, m2, io_map=io_map)
        checker.check_model(merged)
Example #5
0
    def test_merge_drop_unnecessary_initializers_and_value_info(self) -> None:
        '''
        Tests automatic removal of initializers, sparse initializers and
        value_info entries of the second model whose names refer to an input
        that is consumed by the io_map when merging
        '''
        ops = [helper.make_opsetid("", 10)]

        # Base graph: y = Identity(x)
        g = GraphProto()
        g.input.extend(
            [helper.make_tensor_value_info('x', TensorProto.FLOAT, [])])
        g.output.extend(
            [helper.make_tensor_value_info('y', TensorProto.FLOAT, [])])
        g.node.extend(
            [helper.make_node('Identity', inputs=['x'], outputs=['y'])])

        g1 = GraphProto()
        g1.CopyFrom(g)
        g1.name = 'g1'
        m1 = helper.make_model(g1, producer_name='test', opset_imports=ops)
        checker.check_model(m1)

        # Same graph, plus a dense initializer for input 'x'
        g2 = GraphProto()
        g2.CopyFrom(g)
        g2.name = 'g2'
        g2.initializer.extend([
            helper.make_tensor(name='x',
                               data_type=TensorProto.FLOAT,
                               dims=(),
                               vals=[0])
        ])
        m2 = helper.make_model(g2, producer_name='test', opset_imports=ops)
        checker.check_model(m2)

        # Same graph, plus a sparse initializer for input 'x'
        g3 = GraphProto()
        g3.CopyFrom(g)
        g3.name = 'g3'
        g3.sparse_initializer.extend([_make_sparse_tensor('x')])
        m3 = helper.make_model(g3, producer_name='test', opset_imports=ops)
        checker.check_model(m3)

        # Same graph, plus a value_info entry for 'x'
        g4 = GraphProto()
        g4.CopyFrom(g)
        g4.name = 'g4'
        g4.value_info.extend(
            [helper.make_tensor_value_info('x', TensorProto.FLOAT, [])])
        m4 = helper.make_model(g4, producer_name='test', opset_imports=ops)
        checker.check_model(m4)

        # Initializer 'x' from m2 is removed, because m2's input 'x' is now
        # fed by m1's output 'y'
        out_m1 = compose.merge_models(m1,
                                      m2,
                                      prefix1='m1/',
                                      io_map=[('y', 'x')])
        self.assertEqual(0, len(out_m1.graph.initializer))

        # Sparse initializer 'x' from m3 is removed, because m3's input 'x'
        # is now fed by m1's output 'y'. Check sparse_initializer explicitly:
        # the dense initializer list would be empty here regardless.
        out_m2 = compose.merge_models(m1,
                                      m3,
                                      prefix1='m1/',
                                      io_map=[('y', 'x')])
        self.assertEqual(0, len(out_m2.graph.initializer))
        self.assertEqual(0, len(out_m2.graph.sparse_initializer))

        # Value info 'x' from m4 is removed, because m4's input 'x' is now
        # fed by m1's output 'y'
        out_m3 = compose.merge_models(m1,
                                      m4,
                                      prefix1='m1/',
                                      io_map=[('y', 'x')])
        self.assertEqual(0, len(out_m3.graph.value_info))
Example #6
0
    def test_overlapping_function_names(self) -> None:
        '''
        Tests error checking when the name of local function entries overlaps
        '''
        ops = [helper.make_opsetid("", 10), helper.make_opsetid("local", 10)]

        def _make_function(
            domain: str,
            fname: str,
            inputs: List[str],
            outputs: List[str],
            nodes: List[NodeProto],
        ) -> FunctionProto:
            # Build a FunctionProto that imports the same opsets as the models
            f = FunctionProto()
            f.domain = domain
            f.name = fname
            f.input.extend(inputs)
            f.output.extend(outputs)
            f.node.extend(nodes)
            f.opset_import.extend(ops)
            return f

        # Base graph: y = local.f1(x0, x1)
        g = GraphProto()
        g.input.extend([
            helper.make_tensor_value_info('x0', TensorProto.FLOAT, []),
            helper.make_tensor_value_info('x1', TensorProto.FLOAT, [])
        ])
        g.output.extend([
            helper.make_tensor_value_info('y', TensorProto.FLOAT, []),
        ])
        g.node.extend([
            helper.make_node('f1',
                             domain='local',
                             inputs=['x0', 'x1'],
                             outputs=['y'])
        ])

        g1 = GraphProto()
        g1.CopyFrom(g)
        g1.name = 'g1'
        m1 = helper.make_model(g1, producer_name='test', opset_imports=ops)
        m1.functions.extend([
            _make_function(
                'local', 'f1', ['x0', 'x1'], ['y'],
                [helper.make_node('Add', inputs=['x0', 'x1'], outputs=['y'])])
        ])
        checker.check_model(m1)

        # m2 defines a different function under the same name 'local.f1'
        g2 = GraphProto()
        g2.CopyFrom(g)
        g2.name = 'g2'
        m2 = helper.make_model(g2, producer_name='test', opset_imports=ops)
        m2.functions.extend([
            _make_function(
                'local', 'f1', ['x0', 'x1'], ['y'],
                [helper.make_node('Mul', inputs=['x0', 'x1'], outputs=['y'])])
        ])
        checker.check_model(m2)

        m = compose.merge_models(m1,
                                 m2,
                                 io_map=[('y', 'x0'), ('y', 'x1')],
                                 prefix1='m1/',
                                 prefix2='m2/')
        checker.check_model(m)

        # The prefixes disambiguate both the function names and the op_type
        # of the graph nodes that call them
        nodes = [n.op_type for n in m.graph.node]
        self.assertEqual(['m1/f1', 'm2/f1'], nodes)

        functions = [f.name for f in m.functions]
        self.assertEqual(['m1/f1', 'm2/f1'], functions)

        # m3's graph calls f2, and f2 itself calls the local function f1
        g3 = GraphProto()
        g3.CopyFrom(g)
        g3.name = 'g3'
        g3.node[0].op_type = 'f2'
        m3 = helper.make_model(g3, producer_name='test', opset_imports=ops)
        m3.functions.extend([
            _make_function('local', 'f1', ['x0', 'x1'], ['y'], [
                helper.make_node('Add', inputs=['x0', 'x1'], outputs=['y0']),
                helper.make_node('Mul', inputs=['x0', 'x1'], outputs=['y1']),
                helper.make_node('Add', inputs=['y0', 'y1'], outputs=['y'])
            ]),
            _make_function('local', 'f2', ['x0', 'x1'], ['y'], [
                helper.make_node(
                    'f1', domain='local', inputs=['x0', 'x1'], outputs=['y0']),
                helper.make_node('Mul', inputs=['x0', 'x1'], outputs=['y1']),
                helper.make_node('Add', inputs=['y0', 'y1'], outputs=['y'])
            ])
        ])
        checker.check_model(m3)

        m = compose.merge_models(m1,
                                 m3,
                                 io_map=[('y', 'x0'), ('y', 'x1')],
                                 prefix1='m1/',
                                 prefix2='m3/')
        checker.check_model(m)

        nodes = [n.op_type for n in m.graph.node]
        self.assertEqual(['m1/f1', 'm3/f2'], nodes)

        functions = [f.name for f in m.functions]
        self.assertEqual(['m1/f1', 'm3/f1', 'm3/f2'], functions)

        self.assertEqual(['Add'], [n.op_type for n in m.functions[0].node])
        self.assertEqual(['Add', 'Mul', 'Add'],
                         [n.op_type for n in m.functions[1].node])
        # Calls to local functions nested inside a function body are
        # prefixed too, while standard ops are left untouched
        self.assertEqual(['m3/f1', 'Mul', 'Add'],
                         [n.op_type for n in m.functions[2].node])