def _test_merge_models(
    self,
    m1def: str,
    m2def: str,
    io_map: List[Tuple[str, str]],
    check_expectations: Callable[[GraphProto, GraphProto, GraphProto], None],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
) -> None:
    """Merge two models at graph level and at model level, checking each result.

    ``m1def`` and ``m2def`` are loaded with ``_load_model``. The pair is
    combined first with ``compose.merge_graphs`` and then with
    ``compose.merge_models``, both times with the same ``io_map``,
    ``inputs``, ``outputs``, ``prefix1`` and ``prefix2``. Each result is run
    through the ONNX checker and then handed to
    ``check_expectations(graph1, graph2, merged_graph)``.

    NOTE(review): another ``def _test_merge_models`` follows this one in the
    file; if both live in the same class body only the later one is bound.
    """
    # Identical merge options for both the graph-level and model-level merge.
    merge_options = dict(
        io_map=io_map,
        inputs=inputs,
        outputs=outputs,
        prefix1=prefix1,
        prefix2=prefix2,
    )

    first = _load_model(m1def)
    second = _load_model(m2def)

    # Graph-level merge.
    merged_graph = compose.merge_graphs(first.graph, second.graph, **merge_options)
    checker.check_graph(merged_graph)
    check_expectations(first.graph, second.graph, merged_graph)

    # Model-level merge with the same options.
    merged_model = compose.merge_models(first, second, **merge_options)
    checker.check_model(merged_model)
    check_expectations(first.graph, second.graph, merged_model.graph)
def _test_merge_models(
    self,
    m1def: str,
    m2def: str,
    io_map: List[Tuple[str, str]],
    check_expectations: Callable[[GraphProto, GraphProto, GraphProto], None],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
) -> None:
    """Load two models, merge them at graph and model level, and verify both.

    Args:
        m1def: definition of the first model, passed to ``_load_model``.
        m2def: definition of the second model, passed to ``_load_model``.
        io_map: (output of first, input of second) name pairs to connect.
        check_expectations: called as ``(graph1, graph2, merged_graph)``
            after each merge to make test-specific assertions.
        inputs, outputs, prefix1, prefix2: forwarded unchanged to both
            ``compose.merge_graphs`` and ``compose.merge_models``.

    NOTE(review): a ``def _test_merge_models`` with the same signature
    appears just before this one in the file; if both sit in the same class
    body, this later definition is the one that gets bound.
    """
    m1, m2 = _load_model(m1def), _load_model(m2def)

    # Graph-level merge.
    g3 = compose.merge_graphs(
        m1.graph,
        m2.graph,
        io_map=io_map,
        inputs=inputs,
        outputs=outputs,
        prefix1=prefix1,
        prefix2=prefix2,
    )
    checker.check_graph(g3)
    check_expectations(m1.graph, m2.graph, g3)

    # Model-level merge with identical options.
    m3 = compose.merge_models(
        m1,
        m2,
        io_map=io_map,
        inputs=inputs,
        outputs=outputs,
        prefix1=prefix1,
        prefix2=prefix2,
    )
    checker.check_model(m3)
    check_expectations(m1.graph, m2.graph, m3.graph)
def test_merge_models_with_metadata_props(self) -> None:
    """Tests merging of ``metadata_props`` from the two input models.

    Disjoint keys are all kept, identical key/value pairs are deduplicated,
    and a key present in both models with different values raises
    ``ValueError``.
    """
    io_map = [("B00", "B01")]

    m1 = _load_model(m1_def)
    helper.set_model_props(m1, {'p1': 'v1', 'p2': 'v2'})
    m2 = _load_model(m2_def)
    helper.set_model_props(m2, {'p3': 'v3', 'p4': 'v4'})

    # Disjoint keys: all four entries survive.
    m3 = compose.merge_models(m1, m2, io_map=io_map)
    self.assertEqual(4, len(m3.metadata_props))

    # Overlap, but same value.
    # The expected count of 3 assumes set_model_props replaces m2's earlier
    # props rather than appending to them.
    helper.set_model_props(m2, {'p1': 'v1', 'p4': 'v4'})
    m3 = compose.merge_models(m1, m2, io_map=io_map)
    self.assertEqual(3, len(m3.metadata_props))

    # Same keys but not same value. Error
    helper.set_model_props(m2, {'p1': 'v5', 'p4': 'v4'})
    with self.assertRaises(ValueError):
        compose.merge_models(m1, m2, io_map=io_map)
def test_error_opset_import_mismatch(self) -> None:
    """Tests that merging models importing different operator-set versions fails.

    After converting the version-10 model to version 15 with the version
    converter, the merge succeeds and the result passes the checker.
    """
    def _with_default_opset(graph, version):
        # Rebuild a model around `graph` that imports only the given
        # default-domain opset version.
        return helper.make_model(
            graph,
            producer_name='test',
            opset_imports=[helper.make_opsetid("", version)],
        )

    loaded1, loaded2 = _load_model(m1_def), _load_model(m2_def)
    m1 = _with_default_opset(loaded1.graph, 10)
    m2 = _with_default_opset(loaded2.graph, 15)
    io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]

    # Opset 10 vs. opset 15: the merge is rejected.
    with self.assertRaises(ValueError):
        compose.merge_models(m1, m2, io_map)

    # Converting to the same Operator set version, should work
    m1 = version_converter.convert_version(m1, 15)
    merged = compose.merge_models(m1, m2, io_map=io_map)
    checker.check_model(merged)
def test_merge_drop_unnecessary_initializers_and_value_info(self) -> None:
    """Tests automatic removal of initializers and value_info when merging graphs.

    The template graph is ``x -> Identity -> y``. ``m2``, ``m3`` and ``m4``
    attach, respectively, a dense initializer, a sparse initializer and a
    value_info entry for ``x``. Merging ``m1`` into each of them with
    ``io_map=[('y', 'x')]`` wires ``m1``'s output into ``x``, so ``x`` is no
    longer a graph input and the entry attached to it should be dropped.
    """
    ops = [helper.make_opsetid("", 10)]

    g = GraphProto()
    g.input.extend(
        [helper.make_tensor_value_info('x', TensorProto.FLOAT, [])])
    g.output.extend(
        [helper.make_tensor_value_info('y', TensorProto.FLOAT, [])])
    g.node.extend(
        [helper.make_node('Identity', inputs=['x'], outputs=['y'])])

    def _graph_copy(name):
        # Independent, renamed copy of the template graph.
        copy = GraphProto()
        copy.CopyFrom(g)
        copy.name = name
        return copy

    def _checked_model(graph):
        # Wrap the graph in a model and validate it before it is merged.
        model = helper.make_model(graph, producer_name='test', opset_imports=ops)
        checker.check_model(model)
        return model

    m1 = _checked_model(_graph_copy('g1'))

    g2 = _graph_copy('g2')
    g2.initializer.extend([
        helper.make_tensor(name='x', data_type=TensorProto.FLOAT, dims=(), vals=[0])
    ])
    m2 = _checked_model(g2)

    g3 = _graph_copy('g3')
    g3.sparse_initializer.extend([_make_sparse_tensor('x')])
    m3 = _checked_model(g3)

    g4 = _graph_copy('g4')
    g4.value_info.extend(
        [helper.make_tensor_value_info('x', TensorProto.FLOAT, [])])
    m4 = _checked_model(g4)

    # Initializer 'x' from m2 is removed, because there is no longer an input with that name
    out_m1 = compose.merge_models(m1, m2, prefix1='m1/', io_map=[('y', 'x')])
    self.assertEqual(0, len(out_m1.graph.initializer))

    # Sparse initializer 'x' from m3 is removed, because there is no longer an input with that name
    out_m2 = compose.merge_models(m1, m3, prefix1='m1/', io_map=[('y', 'x')])
    # NOTE(review): this inspects the dense `initializer` field, which neither
    # m1 nor m3 populates. Asserting on `out_m2.graph.sparse_initializer`
    # looks intended; confirm how `_make_sparse_tensor` names its values
    # tensor before tightening the check.
    self.assertEqual(0, len(out_m2.graph.initializer))

    # Value info 'x' from m4 is removed, because there is no longer an input with that name
    out_m3 = compose.merge_models(m1, m4, prefix1='m1/', io_map=[('y', 'x')])
    self.assertEqual(0, len(out_m3.graph.value_info))
def test_overlapping_function_names(self) -> None:
    """Tests merging models whose local function names overlap.

    Both models call a ``local``-domain function named ``f1`` with different
    bodies. With ``prefix1``/``prefix2`` applied, the merged model must keep
    every function under a distinct prefixed name and rewrite the calling
    nodes to match, including calls made from inside other local functions.
    """
    # Default domain plus the custom 'local' domain that hosts the functions.
    ops = [helper.make_opsetid("", 10), helper.make_opsetid("local", 10)]

    def _make_function(
        domain: str,
        fname: str,
        inputs: List[str],
        outputs: List[str],
        nodes: List[NodeProto],
    ) -> FunctionProto:
        # Build a FunctionProto that imports the same opsets as the models.
        f = FunctionProto()
        f.domain = domain
        f.name = fname
        f.input.extend(inputs)
        f.output.extend(outputs)
        f.node.extend(nodes)
        f.opset_import.extend(ops)
        return f

    # Template graph: a single call to local function f1(x0, x1) -> y.
    g = GraphProto()
    g.input.extend([
        helper.make_tensor_value_info('x0', TensorProto.FLOAT, []),
        helper.make_tensor_value_info('x1', TensorProto.FLOAT, []),
    ])
    g.output.extend([
        helper.make_tensor_value_info('y', TensorProto.FLOAT, []),
    ])
    g.node.extend([
        helper.make_node('f1', domain='local', inputs=['x0', 'x1'], outputs=['y'])
    ])

    g1 = GraphProto()
    g1.CopyFrom(g)
    g1.name = 'g1'
    m1 = helper.make_model(g1, producer_name='test', opset_imports=ops)
    m1.functions.extend([
        _make_function(
            'local', 'f1', ['x0', 'x1'], ['y'],
            [helper.make_node('Add', inputs=['x0', 'x1'], outputs=['y'])]),
    ])
    checker.check_model(m1)

    g2 = GraphProto()
    g2.CopyFrom(g)
    g2.name = 'g2'
    m2 = helper.make_model(g2, producer_name='test', opset_imports=ops)
    m2.functions.extend([
        _make_function(
            'local', 'f1', ['x0', 'x1'], ['y'],
            [helper.make_node('Mul', inputs=['x0', 'x1'], outputs=['y'])]),
    ])
    checker.check_model(m2)

    # Both models define 'f1'; the prefixes must keep the two apart.
    m = compose.merge_models(
        m1, m2, io_map=[('y', 'x0'), ('y', 'x1')], prefix1='m1/', prefix2='m2/')
    checker.check_model(m)

    nodes = [n.op_type for n in m.graph.node]
    self.assertEqual(['m1/f1', 'm2/f1'], nodes)

    functions = [f.name for f in m.functions]
    self.assertEqual(['m1/f1', 'm2/f1'], functions)

    # m3's graph calls f2, and f2 itself calls the local f1.
    g3 = GraphProto()
    g3.CopyFrom(g)
    g3.name = 'g3'
    g3.node[0].op_type = 'f2'
    m3 = helper.make_model(g3, producer_name='test', opset_imports=ops)
    m3.functions.extend([
        _make_function('local', 'f1', ['x0', 'x1'], ['y'], [
            helper.make_node('Add', inputs=['x0', 'x1'], outputs=['y0']),
            helper.make_node('Mul', inputs=['x0', 'x1'], outputs=['y1']),
            helper.make_node('Add', inputs=['y0', 'y1'], outputs=['y']),
        ]),
        _make_function('local', 'f2', ['x0', 'x1'], ['y'], [
            helper.make_node(
                'f1', domain='local', inputs=['x0', 'x1'], outputs=['y0']),
            helper.make_node('Mul', inputs=['x0', 'x1'], outputs=['y1']),
            helper.make_node('Add', inputs=['y0', 'y1'], outputs=['y']),
        ]),
    ])
    checker.check_model(m3)

    m = compose.merge_models(
        m1, m3, io_map=[('y', 'x0'), ('y', 'x1')], prefix1='m1/', prefix2='m3/')
    checker.check_model(m)

    nodes = [n.op_type for n in m.graph.node]
    self.assertEqual(['m1/f1', 'm3/f2'], nodes)

    functions = [f.name for f in m.functions]
    self.assertEqual(['m1/f1', 'm3/f1', 'm3/f2'], functions)

    self.assertEqual(['Add'], [n.op_type for n in m.functions[0].node])
    self.assertEqual(['Add', 'Mul', 'Add'], [n.op_type for n in m.functions[1].node])
    # The call to f1 inside m3's f2 is renamed along with the function itself.
    self.assertEqual(['m3/f1', 'Mul', 'Add'], [n.op_type for n in m.functions[2].node])