Example #1
0
    def test_extract_model_with_local_function(self) -> None:
        r'''
        #   1. build a model with graph below. extract models with output combinations
        #   2. validate extracted models' local functions
        #
        # model graph:
        #      i0                    i1                 i2
        #      |   __________________|__________________/_________
        #      |  |                  |             |   /          |
        #      |  |                  |             |  /           |
        #   func_add        func_identity          add         identity
        #    |  ___\___________\____________________|_________    |
        #    | |    \           \                   |  _______|___|
        #    | |     \           \                  | |       |   |
        #    add     func_nested_identity_add       add     func_nested_identity_add
        #     |                 |                    |              |
        #     |                 |                    |              |
        #   o_func_add      o_all_func0           o_no_func     o_all_func1
        #
        # where func_nested_identity_add is a function that is defined with functions:
        #       a               b
        #       |               |
        #   func_identity   func_identity
        #             \       /
        #             func_add
        #                |
        #                c
        #
        # Exact node list built below (use this when the diagram is ambiguous):
        #   t0          = func_add(i0, i1)
        #   t2          = Add(i1, i2)
        #   o_func_add  = Add(t0, t2)
        #   t1          = func_identity(i1)
        #   t3          = Identity(i1)
        #   o_no_func   = Add(t3, t2)
        #   o_all_func0 = func_nested_identity_add(t0, t1)
        #   o_all_func1 = func_nested_identity_add(t3, t2)
        '''

        # function common
        func_domain = 'local'
        func_opset_imports = [onnx.helper.make_opsetid("", 14)]
        # The nested function calls other local functions, so it must also import func_domain.
        func_nested_opset_imports = [
            onnx.helper.make_opsetid("", 14), onnx.helper.make_opsetid(func_domain, 1)]

        # add function: c = a + b
        func_add_name = 'func_add'
        func_add = onnx.helper.make_function(
            func_domain,
            func_add_name,
            ['a', 'b'],
            ['c'],
            [onnx.helper.make_node('Add', ['a', 'b'], ['c'])],
            func_opset_imports)

        # identity function: b = a
        func_identity_name = 'func_identity'
        func_identity = onnx.helper.make_function(
            func_domain,
            func_identity_name,
            ['a'],
            ['b'],
            [onnx.helper.make_node('Identity', ['a'], ['b'])],
            func_opset_imports)

        # nested identity/add function: c = func_add(func_identity(a), func_identity(b))
        func_nested_identity_add_name = 'func_nested_identity_add'
        func_nested_identity_add = onnx.helper.make_function(
            func_domain,
            func_nested_identity_add_name,
            ['a', 'b'],
            ['c'],
            [
                onnx.helper.make_node(func_identity_name, ['a'], ['a1'], domain=func_domain),
                onnx.helper.make_node(func_identity_name, ['b'], ['b1'], domain=func_domain),
                onnx.helper.make_node(func_add_name, ['a1', 'b1'], ['c'], domain=func_domain)],
            func_nested_opset_imports)

        # create graph nodes
        node_func_add = onnx.helper.make_node(func_add_name, ['i0', 'i1'], ['t0'], domain=func_domain)
        node_add0 = onnx.helper.make_node('Add', ['i1', 'i2'], ['t2'])
        node_add1 = onnx.helper.make_node('Add', ['t0', 't2'], ['o_func_add'])
        node_func_identity = onnx.helper.make_node(func_identity_name, ['i1'], ['t1'], domain=func_domain)
        node_identity = onnx.helper.make_node('Identity', ['i1'], ['t3'])
        node_add2 = onnx.helper.make_node('Add', ['t3', 't2'], ['o_no_func'])
        node_func_nested0 = onnx.helper.make_node(
            func_nested_identity_add_name,
            ['t0', 't1'],
            ['o_all_func0'],
            domain=func_domain)
        node_func_nested1 = onnx.helper.make_node(
            func_nested_identity_add_name,
            ['t3', 't2'],
            ['o_all_func1'],
            domain=func_domain)

        graph_name = 'graph_with_imbedded_functions'
        ir_version = 8
        # Reuse func_domain rather than repeating the literal so the two cannot drift apart.
        opset_imports = [onnx.helper.make_opsetid("", 14), onnx.helper.make_opsetid(func_domain, 1)]
        tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=2, shape=[5])

        graph = onnx.helper.make_graph(
            [node_func_add, node_add0, node_add1, node_func_identity, node_identity,
             node_func_nested0, node_func_nested1, node_add2],
            graph_name,
            [
                onnx.helper.make_value_info(name='i0', type_proto=tensor_type_proto),
                onnx.helper.make_value_info(name='i1', type_proto=tensor_type_proto),
                onnx.helper.make_value_info(name='i2', type_proto=tensor_type_proto)],
            [
                onnx.helper.make_value_info(name='o_no_func', type_proto=tensor_type_proto),
                onnx.helper.make_value_info(name='o_func_add', type_proto=tensor_type_proto),
                onnx.helper.make_value_info(name='o_all_func0', type_proto=tensor_type_proto),
                onnx.helper.make_value_info(name='o_all_func1', type_proto=tensor_type_proto)])

        meta = {
            'ir_version': ir_version,
            'opset_imports': opset_imports,
            'producer_name': 'test_extract_model_with_local_function',
            'functions': [func_identity, func_add, func_nested_identity_add],
        }
        model = onnx.helper.make_model(graph, **meta)

        checker.check_model(model)

        all_input_names = ['i0', 'i1', 'i2']
        all_function_names = {func_add_name, func_identity_name, func_nested_identity_add_name}

        # o_no_func depends only on plain ops, so no local function should survive.
        # The expected value is set(): a bare {} would be an empty dict, not an empty set.
        extracted_with_no_function = utils.Extractor(model).extract_model(
            all_input_names, ['o_no_func'])
        self._verify_function_set(extracted_with_no_function, set(), func_domain)

        # o_func_add reaches only func_add.
        extracted_with_add_function = utils.Extractor(model).extract_model(
            all_input_names, ['o_func_add'])
        self._verify_function_set(extracted_with_add_function, {func_add_name}, func_domain)

        # The nested function transitively pulls in func_identity and func_add.
        extracted_with_all_functions0 = utils.Extractor(model).extract_model(
            all_input_names, ['o_all_func0'])
        self._verify_function_set(extracted_with_all_functions0, all_function_names, func_domain)

        extracted_with_all_functions1 = utils.Extractor(model).extract_model(
            all_input_names, ['o_all_func1'])
        self._verify_function_set(extracted_with_all_functions1, all_function_names, func_domain)

        # Extracting every output keeps every function exactly once.
        extracted_with_all_functions2 = utils.Extractor(model).extract_model(
            all_input_names, ['o_no_func', 'o_func_add', 'o_all_func0', 'o_all_func1'])
        self._verify_function_set(extracted_with_all_functions2, all_function_names, func_domain)
Example #2
0
def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: List[Tuple[Text, Text]],
    inputs: Optional[List[Text]] = None,
    outputs: Optional[List[Text]] = None,
    prefix1: Optional[Text] = None,
    prefix2: Optional[Text] = None,
    name: Optional[Text] = None,
    doc_string: Optional[Text] = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected. Names are given without prefixes; ``prefix1`` /
                                          ``prefix2`` are applied to them automatically.
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model. When prefixes are used, these names
                                 are matched against the already-prefixed graph inputs.
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model. When prefixes are used, these names
                                  are matched against the already-prefixed graph outputs.
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto: the combined graph. The input graphs passed by the caller are not modified.

    Raises:
        ValueError: if g1 or g2 is not a GraphProto, if an ``io_map`` entry names an output that
                    g1 does not have or an input that g2 does not have, or if the two graphs
                    have overlapping names.
    """
    if not isinstance(g1, GraphProto):
        raise ValueError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise ValueError("g2 argument is not an ONNX graph")

    # Prefix names in the graphs if requested. Work on copies so the caller's graphs stay
    # untouched, and rewrite io_map so it refers to the prefixed names.
    if prefix1:
        g1_copy = GraphProto()
        g1_copy.CopyFrom(g1)
        g1 = add_prefix_graph(g1_copy, prefix=prefix1)
    if prefix2:
        g2_copy = GraphProto()
        g2_copy.CopyFrom(g2)
        g2 = add_prefix_graph(g2_copy, prefix=prefix2)
    if prefix1 or prefix2:
        io_map = [
            (prefix1 + out_name if prefix1 else out_name,
             prefix2 + in_name if prefix2 else in_name)
            for out_name, in_name in io_map
        ]

    io_map_g1_outs = {out_name for out_name, _ in io_map}
    io_map_g2_ins = {in_name for _, in_name in io_map}
    # g2 input name -> g1 output name that feeds it.
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # Check that the names in io_map are real outputs of g1 and inputs of g2. Do this before
    # the (potentially expensive) subgraph extraction below.
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # If the caller restricted inputs/outputs, extract the subgraphs that are actually needed.
    if inputs or outputs:
        if inputs:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # g2 inputs fed through io_map must survive extraction so they can be connected.
            g2_inputs = [
                i.name for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]
        else:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]

        if outputs:
            output_set = set(outputs)
            # g1 outputs used by io_map must survive extraction so they can feed g2.
            g1_outputs = [
                o.name for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]
        else:
            # Keep every output. These must be read from .output, not .input;
            # using .input would extract the wrong subgraph.
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]

        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph

        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: " + ", ".join(names) + "\n" +
            "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)

    # Connect outputs of the first graph with the inputs of the second. Only the g2 nodes,
    # which now live in g, are rewritten; the source graphs are not touched.
    for node_idx in range(g2_nodes_begin, len(g.node)):
        node = g.node[node_idx]
        for index, input_name in enumerate(node.input):
            if input_name in reversed_io_map:
                node.input[index] = reversed_io_map[input_name]

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        # Inputs of g2 that are fed by g1 are no longer graph inputs.
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        # Outputs of g1 consumed by g2 are no longer graph outputs.
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins])

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend([
        init for init in g2.sparse_initializer
        if init.values.name not in io_map_g2_ins
    ])

    g.value_info.extend(g1.value_info)
    g.value_info.extend(
        [vi for vi in g2.value_info if vi.name not in io_map_g2_ins])

    g.name = name if name is not None else "_".join([g1.name, g2.name])

    if doc_string is None:
        doc_string = f"Graph combining {g1.name} and {g2.name}\n" + \
            g1.name + "\n\n" + g1.doc_string + "\n\n" + g2.name + "\n\n" + g2.doc_string
    g.doc_string = doc_string

    return g