Ejemplo n.º 1
0
def _onnx_rewrite_operator_node(existing_names, node, sub_onx):
    """
    Replaces a node by a subgraph.

    The inputs and outputs of the subgraph are mapped positionally onto the
    inputs and outputs of *node*. Initializers and intermediate results of
    the subgraph are renamed with `_unique_name` so that they are given
    names derived from *existing_names*.

    :param existing_names: existing results names
    :param node: onnx node to replace
    :param sub_onx: onnx model whose graph is used as a replacement
    :return: new_initializer, new_nodes
    :raises ValueError: if *node* and the subgraph do not have the same
        number of inputs or the same number of outputs
    """
    if len(node.input) != len(sub_onx.graph.input):
        raise ValueError(  # pragma: no cover
            "Mismatch with the number of inputs for operator type %r. "
            "%d != %d." %
            (node.op_type, len(node.input), len(sub_onx.graph.input)))
    if len(node.output) != len(sub_onx.graph.output):
        raise ValueError(  # pragma: no cover
            "Mismatch with the number of outputs for operator type %r. "
            "%d != %d." %
            (node.op_type, len(node.output), len(sub_onx.graph.output)))

    # Maps every result name used inside the subgraph to the name it takes
    # in the rewritten graph.
    replaces = {}
    for inp, name in zip(sub_onx.graph.input, node.input):
        replaces[inp.name] = name
    for out, name in zip(sub_onx.graph.output, node.output):
        replaces[out.name] = name

    new_inits = []
    for init in sub_onx.graph.initializer:
        name = _unique_name(existing_names, init.name)
        replaces[init.name] = name
        tensor = from_array(to_array(init), name=name)
        new_inits.append(tensor)

    new_nodes = []
    for sub_node in sub_onx.graph.node:
        new_node = NodeProto()
        new_node.op_type = sub_node.op_type
        new_node.attribute.extend(sub_node.attribute)  # pylint: disable=E1101
        # An empty name denotes an omitted optional input in ONNX: keep it.
        new_node.input.extend(  # pylint: disable=E1101
            [replaces[i] if i else i for i in sub_node.input])
        new_node.domain = sub_node.domain
        new_out = []
        for o in sub_node.output:
            if not o:
                # Omitted optional output: nothing to rename.
                new_out.append(o)
                continue
            if o not in replaces:
                # Intermediate result: give it a fresh name and record it so
                # that the downstream nodes of the subgraph can resolve it.
                replaces[o] = _unique_name(existing_names, o)
            new_out.append(replaces[o])
        new_node.output.extend(new_out)  # pylint: disable=E1101
        new_nodes.append(new_node)

    return new_inits, new_nodes
Ejemplo n.º 2
0
def make_node(
    op_type,  # type: Text
    inputs,  # type: Sequence[Text]
    outputs,  # type: Sequence[Text]
    name=None,  # type: Optional[Text]
    doc_string=None,  # type: Optional[Text]
    domain=None,  # type: Optional[Text]
    _dtype=None,  # type: [np.float32, np.float64]
    **kwargs  # type: Any
):  # type: (...) -> NodeProto
    """Build a NodeProto whose attributes are created with a given dtype.

    Arguments:
        op_type (string): operator type of the node
        inputs (list of string): names of the node inputs
        outputs (list of string): names of the node outputs
        name (string, default None): optional unique identifier, only
            written when it is not empty
        doc_string (string, default None): optional documentation string,
            only written when it is not empty
        domain (string, default None): optional domain for the node. When
            None the field is left unset (default domain, which is empty).
            It is also forwarded to :func:`make_attribute`.
        _dtype: dtype forwarded to :func:`make_attribute` for every
            attribute; it must not be None
        **kwargs (dict): the attributes of the node, converted in sorted
            key order by :func:`make_attribute`

    Raises:
        ValueError: if *_dtype* is None
    """
    if _dtype is None:
        raise ValueError("dtype cannot be None")

    proto = NodeProto()
    proto.op_type = op_type
    proto.input.extend(inputs)
    proto.output.extend(outputs)
    if name:
        proto.name = name
    if doc_string:
        proto.doc_string = doc_string
    if domain is not None:
        proto.domain = domain

    if kwargs:
        ordered_items = sorted(kwargs.items())
        proto.attribute.extend(
            make_attribute(attr_name, attr_value, dtype=_dtype, domain=domain)
            for attr_name, attr_value in ordered_items)
    return proto
Ejemplo n.º 3
0
def make_node(
        op_type: Text,
        inputs: Sequence[Text],
        outputs: Sequence[Text],
        name: Optional[Text] = None,
        doc_string: Optional[Text] = None,
        domain: Optional[Text] = None,
        **kwargs: Any
) -> NodeProto:
    """Build a NodeProto from its operator type, inputs, outputs and attributes.

    Arguments:
        op_type (string): operator type of the node
        inputs (list of string): names of the node inputs
        outputs (list of string): names of the node outputs
        name (string, default None): optional unique identifier, only
            written when it is not empty
        doc_string (string, default None): optional documentation string,
            only written when it is not empty
        domain (string, default None): optional domain for the node. When
            None the field is left unset (default domain, which is empty).
        **kwargs (dict): the attributes of the node, converted in sorted key
            order by :func:`make_attribute`. Entries whose value is None
            are skipped.
    Returns:
        NodeProto
    """
    result = NodeProto()
    result.op_type = op_type
    result.input.extend(inputs)
    result.output.extend(outputs)
    if name:
        result.name = name
    if doc_string:
        result.doc_string = doc_string
    if domain is not None:
        result.domain = domain

    if kwargs:
        # None-valued keyword arguments mean "attribute not provided".
        provided = [(key, value) for key, value in sorted(kwargs.items())
                    if value is not None]
        result.attribute.extend(
            make_attribute(key, value) for key, value in provided)
    return result
def _make_node(op_type,
               inputs,
               outputs,
               name=None,
               doc_string=None,
               domain=None,
               attributes=None):
    """
    Constructs a NodeProto.

    :param op_type: operator type of the node
    :param inputs: list of input names
    :param outputs: list of output names
    :param name: optional unique identifier, only written when not empty
    :param doc_string: optional documentation string, only written when
        not empty
    :param domain: optional domain for the node; left unset when None
        (the default domain is empty)
    :param attributes: either a dictionary ``{name: value}`` whose items are
        converted with :func:`make_attribute` in sorted key order, or an
        iterable of already built attributes which are added one by one
    :return: node
    """
    built = NodeProto()
    built.op_type = op_type
    built.input.extend(inputs)
    built.output.extend(outputs)
    if name:
        built.name = name
    if doc_string:
        built.doc_string = doc_string
    if domain is not None:
        built.domain = domain

    # Nothing to add: None, an empty dictionary or an empty sequence.
    if not attributes:
        return built

    if isinstance(attributes, dict):
        sorted_pairs = sorted(attributes.items())
        built.attribute.extend(
            make_attribute(att_name, att_value)
            for att_name, att_value in sorted_pairs)
    else:
        for prebuilt in attributes:
            built.attribute.extend([prebuilt])
    return built