Beispiel #1
0
def make_graph(
    nodes,  # type: Sequence[NodeProto]
    name,  # type: Text
    inputs,  # type: Sequence[ValueInfoProto]
    outputs,  # type: Sequence[ValueInfoProto]
    initializer=None,  # type: Optional[Sequence[TensorProto]]
    doc_string=None,  # type: Optional[Text]
    value_info=None,  # type: Optional[Sequence[ValueInfoProto]]
    sparse_initializer=None,  # type: Optional[Sequence[SparseTensorProto]]
):  # type: (...) -> GraphProto
    """Construct a GraphProto from its components.

    Arguments:
        nodes: nodes appended to ``graph.node``
        name: value assigned to ``graph.name``
        inputs: entries appended to ``graph.input``
        outputs: entries appended to ``graph.output``
        initializer: entries appended to ``graph.initializer``;
            ``None`` is treated as an empty sequence
        doc_string: assigned to ``graph.doc_string`` only when truthy
        value_info: entries appended to ``graph.value_info``;
            ``None`` is treated as an empty sequence
        sparse_initializer: entries appended to ``graph.sparse_initializer``;
            ``None`` is treated as an empty sequence
    Returns:
        The newly created GraphProto
    """
    # ``None`` (rather than a shared mutable ``[]``) is the default for every
    # optional collection; normalise them to empty lists before use.
    if initializer is None:
        initializer = []
    if sparse_initializer is None:
        sparse_initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.sparse_initializer.extend(sparse_initializer)
    graph.value_info.extend(value_info)
    # An empty or missing doc string leaves the field unset.
    if doc_string:
        graph.doc_string = doc_string
    return graph
Beispiel #2
0
def make_graph(
    nodes: Sequence[NodeProto],
    name: Text,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Optional[Sequence[TensorProto]] = None,
    doc_string: Optional[Text] = None,
    value_info: Optional[Sequence[ValueInfoProto]] = None,
    sparse_initializer: Optional[Sequence[SparseTensorProto]] = None,
) -> GraphProto:
    """Construct a GraphProto from its components.

    Arguments:
        nodes: nodes appended to ``graph.node``
        name: value assigned to ``graph.name``
        inputs: entries appended to ``graph.input``
        outputs: entries appended to ``graph.output``
        initializer: entries appended to ``graph.initializer``;
            ``None`` is treated as an empty sequence
        doc_string: assigned to ``graph.doc_string`` only when truthy
        value_info: entries appended to ``graph.value_info``;
            ``None`` is treated as an empty sequence
        sparse_initializer: entries appended to ``graph.sparse_initializer``;
            ``None`` is treated as an empty sequence
    Returns:
        The newly created GraphProto
    """
    # ``None`` (rather than a shared mutable ``[]``) is the default for every
    # optional collection; normalise them to empty lists before use.
    if initializer is None:
        initializer = []
    if sparse_initializer is None:
        sparse_initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.sparse_initializer.extend(sparse_initializer)
    graph.value_info.extend(value_info)
    # An empty or missing doc string leaves the field unset.
    if doc_string:
        graph.doc_string = doc_string
    return graph
Beispiel #3
0
def make_graph(nodes, name, inputs, outputs, initializer=None, doc_string=None):
    """Assemble a GraphProto from nodes, a name, inputs and outputs.

    ``initializer`` may be omitted (``None``), in which case nothing is added
    to ``graph.initializer``. ``doc_string`` is stored only when it is truthy.
    Returns the freshly built GraphProto.
    """
    # Resolve the optional initializer list up front.
    initial_tensors = [] if initializer is None else initializer

    result = GraphProto()
    result.node.extend(nodes)
    result.name = name

    # The remaining repeated fields are filled in declaration order.
    for repeated_field, entries in (
        (result.input, inputs),
        (result.output, outputs),
        (result.initializer, initial_tensors),
    ):
        repeated_field.extend(entries)

    if not doc_string:
        return result
    result.doc_string = doc_string
    return result
Beispiel #4
0
def make_graph(
    nodes: Sequence[NodeProto],
    name: Text,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Optional[Sequence[TensorProto]] = None,
    doc_string: Optional[Text] = None,
    value_info: Optional[Sequence[ValueInfoProto]] = None,
    sparse_initializer: Optional[Sequence[SparseTensorProto]] = None,
) -> GraphProto:
    """Construct a GraphProto

    Arguments:
        nodes: list of NodeProto
        name (string): graph name
        inputs: list of ValueInfoProto
        outputs: list of ValueInfoProto
        initializer: list of TensorProto; ``None`` is treated as empty
        doc_string (string): graph documentation; stored only when truthy
        value_info: list of ValueInfoProto; ``None`` is treated as empty
        sparse_initializer: list of SparseTensorProto; ``None`` is treated
            as empty
    Returns:
        GraphProto
    """
    # ``None`` (rather than a shared mutable ``[]``) is the default for every
    # optional collection; normalise them to empty lists before use.
    if initializer is None:
        initializer = []
    if sparse_initializer is None:
        sparse_initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.sparse_initializer.extend(sparse_initializer)
    graph.value_info.extend(value_info)
    # An empty or missing doc string leaves the field unset.
    if doc_string:
        graph.doc_string = doc_string
    return graph
Beispiel #5
0
def make_graph(
    nodes,  # type: Sequence[NodeProto]
    name,  # type: Text
    inputs,  # type: Sequence[ValueInfoProto]
    outputs,  # type: Sequence[ValueInfoProto]
    initializer=None,  # type: Optional[Sequence[TensorProto]]
    doc_string=None,  # type: Optional[Text]
    value_info=None,  # type: Optional[Sequence[ValueInfoProto]]
):  # type: (...) -> GraphProto
    """Construct a GraphProto from its components.

    Arguments:
        nodes: nodes appended to ``graph.node``
        name: value assigned to ``graph.name``
        inputs: entries appended to ``graph.input``
        outputs: entries appended to ``graph.output``
        initializer: entries appended to ``graph.initializer``;
            ``None`` is treated as an empty sequence
        doc_string: assigned to ``graph.doc_string`` only when truthy
        value_info: entries appended to ``graph.value_info``;
            ``None`` is treated as an empty sequence
    Returns:
        The newly created GraphProto
    """
    # ``None`` (rather than a shared mutable ``[]``) is the default for every
    # optional collection; normalise them to empty lists before use.
    if initializer is None:
        initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.value_info.extend(value_info)
    # An empty or missing doc string leaves the field unset.
    if doc_string:
        graph.doc_string = doc_string
    return graph
Beispiel #6
0
def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: List[Tuple[Text, Text]],
    inputs: Optional[List[Text]] = None,
    outputs: Optional[List[Text]] = None,
    prefix1: Optional[Text] = None,
    prefix2: Optional[Text] = None,
    name: Optional[Text] = None,
    doc_string: Optional[Text] = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used
    Returns:
        GraphProto: a new graph holding the nodes, inputs, outputs, initializers,
        sparse initializers and value infos of both graphs, with the g2 node
        inputs listed in ``io_map`` rewired to the matching g1 output names
    Raises:
        ValueError: if ``g1`` or ``g2`` is not a GraphProto, if a name in
            ``io_map`` is not an output of g1 / an input of g2, or if
            ``check_overlapping_names`` reports colliding names
    """
    if not isinstance(g1, GraphProto):
        raise ValueError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise ValueError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly
    if prefix1 or prefix2:
        if prefix1:
            # Work on a copy so the graph passed in by the caller is not renamed
            g1_copy = GraphProto()
            g1_copy.CopyFrom(g1)
            g1 = g1_copy
            g1 = add_prefix_graph(g1, prefix=prefix1)
        if prefix2:
            g2_copy = GraphProto()
            g2_copy.CopyFrom(g2)
            g2 = g2_copy
            g2 = add_prefix_graph(g2, prefix=prefix2)
        io_map = [(prefix1 + io[0] if prefix1 else io[0],
                   prefix2 + io[1] if prefix2 else io[1]) for io in io_map]

    io_map_g1_outs = {io[0] for io in io_map}
    io_map_g2_ins = {io[1] for io in io_map}
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    # Captured before any subgraph extraction; used below to validate io_map
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # If necessary extract subgraphs
    if inputs or outputs:
        if not inputs:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # Inputs of g2 that are fed by g1 must be kept even if not requested
            g2_inputs = [
                i.name for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if not outputs:
            # No output selection requested: keep every output of both graphs
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            output_set = set(outputs)
            # Outputs of g1 that feed g2 must be kept even if not requested
            g1_outputs = [
                o.name for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        # Only run the extractor when the selection actually drops something
        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph

        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check that input/output names specified in the io_map argument are valid input/output names
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        # Only the first reported category is included in the error message
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: " + ", ".join(names) + "\n" +
            "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    # Remember which slice of g.node came from g2 so only those get rewired
    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)
    g2_nodes_end = len(g.node)

    # Connecting outputs of the first graph with the inputs of the second
    for node_idx in range(g2_nodes_begin, g2_nodes_end):
        node = g.node[node_idx]
        for index, name_ in enumerate(node.input):
            if name_ in reversed_io_map:
                node.input[index] = reversed_io_map[name_]

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        # g2 inputs that are now fed by g1 are no longer graph inputs
        g.input.extend(g1.input)
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        # g1 outputs consumed by g2 are no longer graph outputs
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    # Entries of g2 named after a connected input are dropped from the
    # initializer, sparse initializer and value_info lists.
    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins])

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend([
        init for init in g2.sparse_initializer
        if init.values.name not in io_map_g2_ins
    ])

    g.value_info.extend(g1.value_info)
    g.value_info.extend(
        [vi for vi in g2.value_info if vi.name not in io_map_g2_ins])

    g.name = name if name is not None else "_".join([g1.name, g2.name])

    if doc_string is None:
        doc_string = f"Graph combining {g1.name} and {g2.name}\n" + \
            g1.name + "\n\n" + g1.doc_string + "\n\n" + g2.name + "\n\n" + g2.doc_string
    g.doc_string = doc_string

    return g