def test_extract_model_with_local_function(self) -> None:
    """Extract sub-models from a graph that calls model-local functions.

    Builds the model described below, extracts sub-models for several output
    combinations (always from inputs ``i0``, ``i1``, ``i2``) and checks each
    extracted model with ``self._verify_function_set`` against the set of
    local function names it is expected to carry. Functions that are only
    reachable through the body of another function are expected as well.

    Model graph (all local functions live in domain ``local``)::

        t0          = func_add(i0, i1)
        t1          = func_identity(i1)
        t2          = Add(i1, i2)
        t3          = Identity(i1)
        o_func_add  = Add(t0, t2)
        o_no_func   = Add(t3, t2)
        o_all_func0 = func_nested_identity_add(t0, t1)
        o_all_func1 = func_nested_identity_add(t3, t2)

    where ``func_nested_identity_add`` is itself defined with other local
    functions::

        func_nested_identity_add(a, b) = func_add(func_identity(a), func_identity(b))
    """
    # Settings shared by all local functions.
    func_domain = 'local'
    func_opset_imports = [onnx.helper.make_opsetid("", 14)]
    # The nested function calls other functions of `func_domain`, so it must
    # import that domain in addition to the default one.
    func_nested_opset_imports = [
        onnx.helper.make_opsetid("", 14),
        onnx.helper.make_opsetid(func_domain, 1),
    ]

    # func_add(a, b) -> c = Add(a, b)
    func_add_name = 'func_add'
    func_add = onnx.helper.make_function(
        func_domain,
        func_add_name,
        ['a', 'b'],
        ['c'],
        [onnx.helper.make_node('Add', ['a', 'b'], ['c'])],
        func_opset_imports)

    # func_identity(a) -> b = Identity(a)
    func_identity_name = 'func_identity'
    func_identity = onnx.helper.make_function(
        func_domain,
        func_identity_name,
        ['a'],
        ['b'],
        [onnx.helper.make_node('Identity', ['a'], ['b'])],
        func_opset_imports)

    # func_nested_identity_add(a, b) -> c = func_add(func_identity(a), func_identity(b))
    func_nested_identity_add_name = 'func_nested_identity_add'
    func_nested_identity_add = onnx.helper.make_function(
        func_domain,
        func_nested_identity_add_name,
        ['a', 'b'],
        ['c'],
        [
            onnx.helper.make_node(func_identity_name, ['a'], ['a1'], domain=func_domain),
            onnx.helper.make_node(func_identity_name, ['b'], ['b1'], domain=func_domain),
            onnx.helper.make_node(func_add_name, ['a1', 'b1'], ['c'], domain=func_domain),
        ],
        func_nested_opset_imports)

    # Graph nodes (see the docstring for the resulting dataflow).
    node_func_add = onnx.helper.make_node(func_add_name, ['i0', 'i1'], ['t0'], domain=func_domain)
    node_add0 = onnx.helper.make_node('Add', ['i1', 'i2'], ['t2'])
    node_add1 = onnx.helper.make_node('Add', ['t0', 't2'], ['o_func_add'])
    node_func_identity = onnx.helper.make_node(func_identity_name, ['i1'], ['t1'], domain=func_domain)
    node_identity = onnx.helper.make_node('Identity', ['i1'], ['t3'])
    node_add2 = onnx.helper.make_node('Add', ['t3', 't2'], ['o_no_func'])
    node_func_nested0 = onnx.helper.make_node(
        func_nested_identity_add_name, ['t0', 't1'], ['o_all_func0'], domain=func_domain)
    node_func_nested1 = onnx.helper.make_node(
        func_nested_identity_add_name, ['t3', 't2'], ['o_all_func1'], domain=func_domain)

    # Every graph input/output is a 1-D uint8 tensor of length 5.
    tensor_type_proto = onnx.helper.make_tensor_type_proto(
        elem_type=onnx.TensorProto.UINT8, shape=[5])

    def value_info(value_name):
        return onnx.helper.make_value_info(name=value_name, type_proto=tensor_type_proto)

    graph = onnx.helper.make_graph(
        [node_func_add, node_add0, node_add1, node_func_identity, node_identity,
         node_func_nested0, node_func_nested1, node_add2],
        'graph_with_imbedded_functions',
        [value_info('i0'), value_info('i1'), value_info('i2')],
        [value_info('o_no_func'), value_info('o_func_add'),
         value_info('o_all_func0'), value_info('o_all_func1')])

    model = onnx.helper.make_model(
        graph,
        ir_version=8,
        opset_imports=[onnx.helper.make_opsetid("", 14),
                       onnx.helper.make_opsetid(func_domain, 1)],
        producer_name='test_extract_model_with_local_function',
        functions=[func_identity, func_add, func_nested_identity_add])
    checker.check_model(model)

    all_functions = {func_add_name, func_identity_name, func_nested_identity_add_name}
    # (requested outputs, local function names the extracted model must carry)
    cases = [
        (['o_no_func'], set()),  # only built-in ops are reachable
        (['o_func_add'], {func_add_name}),
        # The nested function pulls in the functions used in its body, even
        # when the graph nodes calling them directly are not extracted.
        (['o_all_func0'], all_functions),
        (['o_all_func1'], all_functions),
        (['o_no_func', 'o_func_add', 'o_all_func0', 'o_all_func1'], all_functions),
    ]
    for output_names, expected_functions in cases:
        extracted = utils.Extractor(model).extract_model(['i0', 'i1', 'i2'], output_names)
        self._verify_function_set(extracted, expected_functions, func_domain)
def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: List[Tuple[Text, Text]],
    inputs: Optional[List[Text]] = None,
    outputs: Optional[List[Text]] = None,
    prefix1: Optional[Text] = None,
    prefix2: Optional[Text] = None,
    name: Optional[Text] = None,
    doc_string: Optional[Text] = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto: The combined graph. Nodes of g1 come first, followed by the nodes of g2 whose
        inputs listed in ``io_map`` are rewired to the corresponding g1 outputs.

    Raises:
        ValueError: If g1 or g2 is not a GraphProto, if an ``io_map`` name is not an output of g1 /
                    an input of g2 (after prefixing), or if the two graphs have overlapping names.
    """
    if type(g1) is not GraphProto:
        raise ValueError("g1 argument is not an ONNX graph")
    if type(g2) is not GraphProto:
        raise ValueError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly
    if prefix1 or prefix2:
        if prefix1:
            g1_copy = GraphProto()
            g1_copy.CopyFrom(g1)
            g1 = g1_copy
            g1 = add_prefix_graph(g1, prefix=prefix1)
        if prefix2:
            g2_copy = GraphProto()
            g2_copy.CopyFrom(g2)
            g2 = g2_copy
            g2 = add_prefix_graph(g2, prefix=prefix2)
        io_map = [(prefix1 + io[0] if prefix1 else io[0],
                   prefix2 + io[1] if prefix2 else io[1])
                  for io in io_map]

    io_map_g1_outs = {out_name for out_name, _ in io_map}
    io_map_g2_ins = {in_name for _, in_name in io_map}
    # g2 input name -> g1 output name that feeds it
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # Check that input/output names specified in the io_map argument are valid input/output names.
    # Done before any subgraph extraction so an invalid io_map fails fast with a clear error.
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # If necessary extract subgraphs
    if inputs or outputs:
        if not inputs:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # g2 inputs fed through io_map must survive extraction so they can be connected
            g2_inputs = [
                i.name for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if not outputs:
            # Keep every output of both graphs (these are the graphs' *outputs*, not inputs).
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            output_set = set(outputs)
            # g1 outputs consumed through io_map must survive extraction so they can be connected
            g1_outputs = [
                o.name for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph
        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: "
            + ", ".join(names)
            + "\n"
            + "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)

    # Connecting outputs of the first graph with the inputs of the second
    for node in g.node[g2_nodes_begin:]:
        for index, input_name in enumerate(node.input):
            if input_name in reversed_io_map:
                node.input[index] = reversed_io_map[input_name]

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        # g2 inputs connected via io_map are now internal values, not graph inputs
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        # g1 outputs consumed via io_map are now internal values, not graph outputs
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins])
    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend([
        init for init in g2.sparse_initializer
        if init.values.name not in io_map_g2_ins
    ])
    g.value_info.extend(g1.value_info)
    g.value_info.extend(
        [vi for vi in g2.value_info if vi.name not in io_map_g2_ins])

    g.name = name if name is not None else "_".join([g1.name, g2.name])

    if doc_string is None:
        doc_string = f"Graph combining {g1.name} and {g2.name}\n" + \
            g1.name + "\n\n" + g1.doc_string + "\n\n" + g2.name + "\n\n" + g2.doc_string

    g.doc_string = doc_string

    return g