def _onnx_rewrite_operator_node(existing_names, node, sub_onx): """ Replaces a node by a subgraph. :param existing_names: existing results names :param node: onnx node to replace :param sub_onx: onnx sub_graph to use as a replacement :return: new_initializer, new_nodes """ if len(node.input) != len(sub_onx.graph.input): raise ValueError( # pragma: no cover "Mismatch with the number of inputs for operator type %r. " "%d != %d." % (node.op_type, len(node.input), len(sub_onx.graph.nput))) if len(node.output) != len(sub_onx.graph.output): raise ValueError( # pragma: no cover "Mismatch with the number of outputs for operator type %r. " "%d != %d." % (node.op_type, len(node.output), len(sub_onx.graph.output))) replaces = {} for inp, name in zip(sub_onx.graph.input, node.input): replaces[inp.name] = name for inp, name in zip(sub_onx.graph.output, node.output): replaces[inp.name] = name new_inits = [] for init in sub_onx.graph.initializer: name = _unique_name(existing_names, init.name) replaces[init.name] = name tensor = from_array(to_array(init), name=name) new_inits.append(tensor) new_nodes = [] for n in sub_onx.graph.node: new_node = NodeProto() new_node.op_type = n.op_type new_node.attribute.extend(n.attribute) # pylint: disable=E1101 new_node.input.extend( # pylint: disable=E1101 [replaces[i] for i in n.input]) # pylint: disable=E1101 new_node.domain = n.domain new_out = [] for o in n.output: if o in replaces: new_out.append(replaces[o]) else: n = _unique_name(existing_names, o) new_out.append(n) new_node.output.extend(new_out) # pylint: disable=E1101 new_nodes.append(new_node) return new_inits, new_nodes
def make_node(
        op_type,  # type: Text
        inputs,  # type: Sequence[Text]
        outputs,  # type: Sequence[Text]
        name=None,  # type: Optional[Text]
        doc_string=None,  # type: Optional[Text]
        domain=None,  # type: Optional[Text]
        _dtype=None,  # type: Any
        **kwargs  # type: Any
):  # type: (...) -> NodeProto
    """Construct a NodeProto.

    Arguments:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto
        doc_string (string, default None): optional documentation string
            for NodeProto
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty).
            It is also forwarded to :func:`make_attribute`.
        _dtype: dtype forwarded to :func:`make_attribute` to build the
            attribute values (presumably numpy.float32 or numpy.float64,
            the types the original annotation listed). It cannot be None.
        **kwargs (dict): the attributes of the node.  The acceptable values
            are documented in :func:`make_attribute`. They are added in the
            alphabetical order of their names.

    Returns:
        NodeProto

    Raises:
        ValueError: if *_dtype* is None
    """
    # Checked first so that nothing is built when the dtype is missing.
    if _dtype is None:
        raise ValueError("dtype cannot be None")
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    # An empty-string domain is still written, only None is skipped.
    if domain is not None:
        node.domain = domain
    if kwargs:
        node.attribute.extend(
            make_attribute(key, value, dtype=_dtype, domain=domain)
            for key, value in sorted(kwargs.items()))
    return node
def make_node(
    op_type: Text,
    inputs: Sequence[Text],
    outputs: Sequence[Text],
    name: Optional[Text] = None,
    doc_string: Optional[Text] = None,
    domain: Optional[Text] = None,
    **kwargs: Any
) -> NodeProto:
    """Build a NodeProto from an operator type, its I/O names and attributes.

    Arguments:
        op_type (string): type of the operator the node runs
        inputs (list of string): names of the node inputs
        outputs (list of string): names of the node outputs
        name (string, default None): optional unique identifier of the node,
            only stored when not empty
        doc_string (string, default None): optional documentation string,
            only stored when not empty
        domain (string, default None): optional domain of the operator.
            When None, the default (empty) domain is kept.
        **kwargs (dict): the node attributes, converted with
            :func:`make_attribute` in the alphabetical order of their names.
            Attributes whose value is None are skipped.

    Returns:
        NodeProto
    """
    proto = NodeProto()
    proto.op_type = op_type
    proto.input.extend(inputs)
    proto.output.extend(outputs)
    # Optional text fields are written only when a non-empty value is given.
    if name:
        proto.name = name
    if doc_string:
        proto.doc_string = doc_string
    # An explicit empty-string domain is still written; only None is skipped.
    if domain is not None:
        proto.domain = domain
    if kwargs:
        converted = [
            make_attribute(key, kwargs[key])
            for key in sorted(kwargs)
            if kwargs[key] is not None
        ]
        proto.attribute.extend(converted)
    return proto
def _make_node(op_type, inputs, outputs, name=None, doc_string=None,
               domain=None, attributes=None):
    """
    Constructs a NodeProto.

    :param op_type: (string): type of the operator the node runs
    :param inputs: names of the node inputs
    :param outputs: names of the node outputs
    :param name: optional unique identifier of the node, only stored
        when not empty
    :param doc_string: optional documentation string, only stored
        when not empty
    :param domain: optional domain of the operator, the default (empty)
        domain is kept when it is None
    :param attributes: either a dictionary ``{name: value}`` whose values
        are converted with :func:`make_attribute` (in the alphabetical
        order of the names), or any other sequence of attributes appended
        as they are
    :return: node
    """
    proto = NodeProto()
    proto.op_type = op_type
    proto.input.extend(inputs)
    proto.output.extend(outputs)
    if name:
        proto.name = name
    if doc_string:
        proto.doc_string = doc_string
    if domain is not None:
        proto.domain = domain
    if isinstance(attributes, dict):
        # Raw values still need to be converted into attributes.
        if attributes:
            proto.attribute.extend(
                [make_attribute(key, attributes[key])
                 for key in sorted(attributes)])
    elif attributes:
        # Already-built attributes, appended in the given order.
        proto.attribute.extend(list(attributes))
    return proto