Code Example #1
    def _onnx_node_to_tensorflow_op(cls,
                                    node,
                                    tensor_dict,
                                    handlers=None,
                                    opset=None,
                                    strict=True):
        """
    Convert onnx node to tensorflow op.

    Args:
      node: Onnx node object.
      tensor_dict: Tensor dict of graph.
      opset: Opset version of the operator set. Default 0 means using latest version.
      strict: whether to enforce semantic equivalence between the original model
        and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
        Changing to False is strongly discouraged.
    Returns:
      Tensorflow op
    """
        handlers = handlers or cls._get_handlers(opset)
        handler = handlers[node.domain].get(node.op_type, None)
        if handler:
            return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
        else:
            exception.OP_UNIMPLEMENTED_EXCEPT(node.op_type)
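
The method above is essentially a two-level handler lookup (domain, then op type) that turns one ONNX node into TensorFlow tensors stored in `tensor_dict`. The following self-contained toy, with made-up handler and node structures that are not part of onnx-tensorflow, sketches that pattern so it can be run on its own.

    # Toy sketch of the dispatch pattern above: handlers are looked up by
    # (domain, op_type) and each handler turns inputs found in tensor_dict
    # into new TensorFlow tensors. All names here are illustrative only.
    import tensorflow as tf

    def _relu_handler(node, tensor_dict):
        # Convert the node's single input and return outputs as a list.
        return [tf.nn.relu(tensor_dict[node["inputs"][0]])]

    toy_handlers = {"": {"Relu": _relu_handler}}  # "" is the default ONNX domain

    node = {"domain": "", "op_type": "Relu", "inputs": ["x"], "outputs": ["y"]}
    tensor_dict = {"x": tf.constant([[-1.0, 2.0]])}

    handler = toy_handlers[node["domain"]].get(node["op_type"], None)
    if handler:
        # Store the handler's outputs under the node's output names so that
        # later nodes can consume them, mirroring the real conversion loop.
        tensor_dict[node["outputs"][0]] = handler(node, tensor_dict)[0]
    else:
        raise NotImplementedError(node["op_type"])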
Code Example #2
    def handle(cls, node, **kwargs):
        """ Main method in handler. It will find corresponding versioned handle method,
    whose name format is `version_%d`. So prefix `version_` is reserved in onnx-tensorflow.
    DON'T use it for other purpose.

    :param node: NodeProto for backend or TensorflowNode for frontend.
    :param kwargs: Other args.
    :return: NodeProto for frontend or TensorflowNode for backend.
    """
        ver_handle = getattr(cls, "version_{}".format(cls.SINCE_VERSION), None)
        if ver_handle:
            cls.args_check(node, **kwargs)
            return ver_handle(node, **kwargs)
        exception.OP_UNIMPLEMENTED_EXCEPT(node.op_type, cls.SINCE_VERSION)
        return None
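
To make the `version_%d` convention concrete, here is a toy, self-contained handler sketch. It re-implements the dispatch above in simplified form; real onnx-tensorflow handlers inherit `handle` from the handler base class and have `SINCE_VERSION` set from the model's opset, and the class and return values below are made up for illustration.

    # Toy sketch of the `version_%d` dispatch convention (not a real handler).
    class ToyReluHandler(object):
        SINCE_VERSION = 6  # in onnx-tensorflow this is derived from the model opset

        @classmethod
        def handle(cls, node, **kwargs):
            # Simplified form of the lookup shown above.
            ver_handle = getattr(cls, "version_{}".format(cls.SINCE_VERSION), None)
            if ver_handle:
                return ver_handle(node, **kwargs)
            raise NotImplementedError(
                "Relu is not implemented for opset {}".format(cls.SINCE_VERSION))

        @classmethod
        def version_1(cls, node, **kwargs):
            return "handled {} with opset-1 rules".format(node)

        @classmethod
        def version_6(cls, node, **kwargs):
            return "handled {} with opset-6 rules".format(node)

    print(ToyReluHandler.handle("Relu"))  # -> handled Relu with opset-6 rules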
Code Example #3
    def tensorflow_graph_to_onnx_graph(cls,
                                       graph_def,
                                       output,
                                       opset=((defs.ONNX_DOMAIN,
                                               defs.onnx_opset_version()), ),
                                       name="graph",
                                       ignore_unimplemented=False):
        """Converts a Tensorflow Graph Proto to an ONNX graph

    This function converts a Tensorflow Graph proto to an equivalent
    representation of ONNX graph.

    :param graph_def: Tensorflow Graph Proto object.
    :param output: List of Tensorflow NodeDef object specifying which nodes
      to be taken as outputs of the ONNX graph.
    :param opset: Opset, which should be ((str domain: int version number),).
    :param name: The name of the output ONNX Graph.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the operators
      that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.

    :returns: The equivalent ONNX Graph Proto object.
    """
        onnx_graph = OnnxGraph(name)
        exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented

        opset_dict = {}
        for domain, version in opset:
            if domain == "ai.onnx":
                domain = defs.ONNX_DOMAIN
            opset_dict[domain] = version

        handlers = get_all_frontend_handlers(opset_dict)

        node_tup = [(node.name, TensorflowNode(node))
                    for node in graph_def.node]
        for name, node in node_tup:

            if node.op_type == "Placeholder":
                onnx_graph.add_input_proto(node)
            elif node.op_type == "Const":
                onnx_graph.add_const(node)
                onnx_graph.add_const_proto(node)
                onnx_graph.add_input_proto(node)
            else:
                onnx_graph.add_value_info_proto(node)
                handler = handlers.get(node.domain, {}).get(node.op_type, None)
                node_proto = None
                if handler:
                    node_proto = handler.handle(
                        node,
                        consts=onnx_graph.consts,
                        node_dict=dict(node_tup),
                        data_type_cast_map=onnx_graph.data_type_cast_map)
                else:
                    exception.OP_UNIMPLEMENTED_EXCEPT(
                        node.op_type,
                        domain=None
                        if node.domain in handlers else node.domain)

                if node_proto is None:
                    node_proto = FrontendHandler.make_node_from_tf_node(
                        node, op_type=node.op_type, should_check=False)
                onnx_graph.add_node_proto(node_proto)

        for o in output:
            output_node = TensorflowNode(o)
            onnx_graph.add_output_proto(output_node)

        return onnx_graph.make_graph_proto()
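
A hypothetical usage sketch for the conversion above. It assumes the method is a classmethod of the onnx-tensorflow frontend class (imported here as `onnx_tf.frontend.TensorflowFrontend`, a name that may differ between versions) and uses the TensorFlow 1.x graph API that this code base targets; the tiny graph is illustrative only.

    # Hypothetical usage sketch; the import path and class name are assumptions.
    import tensorflow as tf
    from onnx_tf.frontend import TensorflowFrontend

    # Build a tiny TF 1.x graph: one placeholder feeding a Relu.
    with tf.Graph().as_default() as graph:
        x = tf.placeholder(tf.float32, shape=[1, 4], name="x")
        y = tf.nn.relu(x, name="y")

    graph_def = graph.as_graph_def()

    # `output` is a list of NodeDef objects to expose as ONNX graph outputs.
    output_nodes = [n for n in graph_def.node if n.name == "y"]

    onnx_graph_proto = TensorflowFrontend.tensorflow_graph_to_onnx_graph(
        graph_def, output_nodes, name="toy_graph")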
Code Example #4
    def tensorflow_graph_to_onnx_graph(cls,
                                       tf_graph,
                                       opset=((defs.ONNX_DOMAIN,
                                               defs.onnx_opset_version()), ),
                                       ignore_unimplemented=False):
        """Converts a TensorflowGraph to an ONNX graph

    This function converts a TensorflowGraph to an equivalent
    representation of ONNX graph.

    :param tf_graph: TensorflowGraph object.
    :param opset: Opset, which should be ((str domain: int version number),).
    :param ignore_unimplemented: Convert to ONNX model and ignore all the operators
      that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.

    :returns: The equivalent ONNX Graph Proto object.
    """
        onnx_graph = OnnxGraph(tf_graph.graph_name)
        exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented
        training_ops_to_remove = ["RandomShuffleQueueV2"]

        opset_dict = {}
        for domain, version in opset:
            if domain == "ai.onnx":
                domain = defs.ONNX_DOMAIN
            opset_dict[domain] = version

        handlers = get_all_frontend_handlers(opset_dict)

        node_tup = [(n.name, n) for n in tf_graph.nodes]
        for name, node in node_tup:
            if node.op_type == "Placeholder":
                onnx_graph.add_input_proto(node)
            elif node.op_type == "Const":
                onnx_graph.add_const(node)
                onnx_graph.add_const_proto(node)
                onnx_graph.add_input_proto(node)
            elif node.op_type in training_ops_to_remove:
                logger.info(
                    "A training op with name {} type {} has been removed.".
                    format(node.name, node.op_type))
            elif node.op_type == "QueueDequeueManyV2":
                num_output = len(node.attr["_output_shapes"])
                for index, shape, onnx_type in zip(
                        range(num_output), node.attr["_output_shapes"],
                        node.attr["component_types"]):
                    onnx_graph.add_input_proto_explicit(node.name + ":" +
                                                        str(index),
                                                        shape,
                                                        onnx_dtype=onnx_type)
            else:
                onnx_graph.add_value_info_proto(node)
                handler = handlers.get(node.domain, {}).get(node.op_type, None)
                node_proto = None
                if handler:
                    node_proto = handler.handle(
                        node,
                        consts=onnx_graph.consts,
                        node_dict=dict(node_tup),
                        data_type_cast_map=onnx_graph.data_type_cast_map)
                else:
                    exception.OP_UNIMPLEMENTED_EXCEPT(
                        node.op_type,
                        domain=None
                        if node.domain in handlers else node.domain)

                if node_proto is None:
                    node_proto = FrontendHandler.make_node_from_tf_node(
                        node, op_type=node.op_type, should_check=False)
                onnx_graph.add_node_proto(node_proto)

        for o in tf_graph.outputs:
            output_node = tf_graph.get_node_by_name(o)
            onnx_graph.add_output_proto(output_node)

        return onnx_graph.make_graph_proto()
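
Both conversion variants (Code Examples #3 and #4) normalize the `opset` argument the same way before looking up frontend handlers: the "ai.onnx" alias is folded into the default ONNX domain so handler lookup uses a single key. The standalone sketch below isolates just that step; the "custom.domain" entry is made up for illustration.

    # Standalone sketch of the opset normalization used in both methods.
    from onnx import defs

    opset = (("ai.onnx", 9), ("custom.domain", 1))  # "custom.domain" is made up

    opset_dict = {}
    for domain, version in opset:
        if domain == "ai.onnx":
            domain = defs.ONNX_DOMAIN  # "", the default ONNX operator domain
        opset_dict[domain] = version

    print(opset_dict)  # {'': 9, 'custom.domain': 1}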