Ejemplo n.º 1
0
def get_all_frontend_handlers(opset_dict):
    """Get a dict of all frontend handler classes, keyed by domain and TF op.

    e.g. {'domain': {'Abs': Abs handler class, ...}, ...}

    As a side effect, every ``FrontendHandler`` subclass is validated with
    ``check_cls()``, and gets ``VERSION`` (the opset version requested for
    its domain) and ``SINCE_VERSION`` (the ``since_version`` of the newest
    ONNX schema of its op not newer than ``VERSION``, or 1 when that cannot
    be determined) assigned.

    :param opset_dict: A dict of opset. e.g. {'domain': version, ...}.
      Handlers whose domain is missing from it are given version 1.
    :return: Dict mapping domain -> {TF op type -> handler class}.
    """
    # Depending on the onnx build, a missing schema in defs.get_schema surfaces
    # either as RuntimeError or as defs.SchemaError; catch both.
    schema_errors = (RuntimeError, getattr(defs, "SchemaError", RuntimeError))

    handlers = {}
    for handler in FrontendHandler.__subclasses__():
        handler.check_cls()

        domain = handler.DOMAIN
        # Indexing opset_dict directly raised KeyError (and aborted collecting
        # every other handler) for any handler in a domain the caller did not
        # request; fall back to version 1 instead.
        version = opset_dict.get(domain, 1)
        handler.VERSION = version

        since_version = 1
        if handler.ONNX_OP and defs.has(handler.ONNX_OP,
                                        domain=handler.DOMAIN):
            try:
                since_version = defs.get_schema(
                    handler.ONNX_OP,
                    domain=handler.DOMAIN,
                    max_inclusive_version=version).since_version
            except schema_errors:
                # The op is known to ONNX, but no schema exists at or below the
                # requested opset version (e.g. the op was introduced later).
                warnings.warn(
                    "Op {} in domain `{}` has no schema at opset version {}. "
                    "Fall back to since_version 1.".format(
                        handler.ONNX_OP, handler.DOMAIN or "ai.onnx",
                        version))
        else:
            warnings.warn("Unknown op {} in domain `{}`. "
                          "Can't check specification by ONNX. "
                          "Please set should_check flag to False "
                          "when call make_node method in handler.".format(
                              handler.ONNX_OP or "Undefined", handler.DOMAIN
                              or "ai.onnx"))
        handler.SINCE_VERSION = since_version

        for tf_op in handler.TF_OP:
            handlers.setdefault(domain, {})[tf_op] = handler
    return handlers
Ejemplo n.º 2
0
def get_frontend_coverage():
  """Collect the frontend op coverage used to build documentation.

  Every ``FrontendHandler`` subclass is validated with ``check_cls()`` and its
  versions (from ``get_versions()``) are recorded under its TF ops and, when
  set, under its ONNX op.

  :return: dict with the following entries:
    onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
    tf_coverage: e.g. {'domain': {'TF_OP': [versions], ...}, ...}
    experimental_op: set of ONNX op names whose handler has a truthy
      ``EXPERIMENTAL`` attribute, e.g. {'ONNX_OP', ...}
  """
  coverage_by_tf_op = {}
  coverage_by_onnx_op = {}
  experimental = set()

  for handler_cls in FrontendHandler.__subclasses__():
    handler_cls.check_cls()
    handler_versions = handler_cls.get_versions()
    handler_domain = handler_cls.DOMAIN

    for op_name in handler_cls.TF_OP:
      _update_coverage(coverage_by_tf_op, handler_domain, op_name,
                       handler_versions)

    onnx_name = handler_cls.ONNX_OP
    # Handlers without an ONNX op only contribute to the TF coverage.
    if not onnx_name:
      continue
    if getattr(handler_cls, "EXPERIMENTAL", False):
      experimental.add(onnx_name)
    _update_coverage(coverage_by_onnx_op, handler_domain, onnx_name,
                     handler_versions)

  return {
      "onnx_coverage": coverage_by_onnx_op,
      "tf_coverage": coverage_by_tf_op,
      "experimental_op": experimental,
  }
Ejemplo n.º 3
0
def get_frontend_coverage():
    """Collect frontend op coverage for the documentation.

    Every ``FrontendHandler`` subclass is validated with ``check_cls()`` and
    its versions (from ``get_versions()``) are recorded under each of its TF
    ops and, when set, under its ONNX op.

    :return: A ``(onnx_coverage, tf_coverage)`` tuple, where
      onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
      tf_coverage: e.g. {'domain': {'TF_OP': [versions], ...}, ...}
    """
    per_tf_op = {}
    per_onnx_op = {}
    for handler_cls in FrontendHandler.__subclasses__():
        handler_cls.check_cls()
        supported = handler_cls.get_versions()
        dom = handler_cls.DOMAIN

        for op_type in handler_cls.TF_OP:
            _update_coverage(per_tf_op, dom, op_type, supported)

        onnx_op = handler_cls.ONNX_OP
        if onnx_op:
            _update_coverage(per_onnx_op, dom, onnx_op, supported)

    return per_onnx_op, per_tf_op
Ejemplo n.º 4
0
    def tensorflow_graph_to_onnx_graph(cls,
                                       graph_def,
                                       output,
                                       opset=((defs.ONNX_DOMAIN,
                                               defs.onnx_opset_version()), ),
                                       name="graph",
                                       ignore_unimplemented=False):
        """Converts a Tensorflow Graph Proto to an ONNX graph.

        This function converts a Tensorflow Graph proto to an equivalent
        representation of ONNX graph.

        :param graph_def: Tensorflow Graph Proto object.
        :param output: List of Tensorflow NodeDef object specifying which nodes
          to be taken as outputs of the ONNX graph.
        :param opset: Opset, which should be ((str domain: int version number),).
          The domain "ai.onnx" is treated as an alias of defs.ONNX_DOMAIN.
        :param name: The name of the output ONNX Graph.
        :param ignore_unimplemented: Convert to ONNX model and ignore all the
          operators that are not currently supported by onnx-tensorflow.
          This is an experimental feature. By enabling this feature,
          the graph would not be guaranteed to match the ONNX specifications.
          Note: this is stored on the module-level flag
          ``exception.IGNORE_UNIMPLEMENTED`` and stays set after returning.

        :returns: The equivalent ONNX Graph Proto object.
        """
        onnx_graph = OnnxGraph(name)
        # Presumably consulted by exception.OP_UNIMPLEMENTED_EXCEPT below to
        # decide whether to raise or just report; verify in the exception module.
        exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented

        opset_dict = {}
        for domain, version in opset:
            if domain == "ai.onnx":
                domain = defs.ONNX_DOMAIN
            opset_dict[domain] = version

        handlers = get_all_frontend_handlers(opset_dict)

        node_tup = [(node.name, TensorflowNode(node))
                    for node in graph_def.node]
        # Build the name -> node lookup once; rebuilding it for every handled
        # node made the conversion quadratic in the number of nodes.
        node_dict = dict(node_tup)

        # The node name is not needed in the loop; using ``_`` also stops the
        # loop from clobbering the ``name`` parameter.
        for _, node in node_tup:

            if node.op_type == "Placeholder":
                onnx_graph.add_input_proto(node)
            elif node.op_type == "Const":
                onnx_graph.add_const(node)
                onnx_graph.add_const_proto(node)
                onnx_graph.add_input_proto(node)
            else:
                onnx_graph.add_value_info_proto(node)
                handler = handlers.get(node.domain, {}).get(node.op_type, None)
                node_proto = None
                if handler:
                    node_proto = handler.handle(
                        node,
                        consts=onnx_graph.consts,
                        node_dict=node_dict,
                        data_type_cast_map=onnx_graph.data_type_cast_map)
                else:
                    # Report the domain only when it is the reason no handler
                    # was found (i.e. no handler is registered for it at all).
                    exception.OP_UNIMPLEMENTED_EXCEPT(
                        node.op_type,
                        domain=None
                        if node.domain in handlers else node.domain)

                # No (or an empty) conversion: emit the TF node unchecked.
                if node_proto is None:
                    node_proto = FrontendHandler.make_node_from_tf_node(
                        node, op_type=node.op_type, should_check=False)
                onnx_graph.add_node_proto(node_proto)

        for o in output:
            output_node = TensorflowNode(o)
            onnx_graph.add_output_proto(output_node)

        return onnx_graph.make_graph_proto()
Ejemplo n.º 5
0
    def tensorflow_graph_to_onnx_graph(cls,
                                       tf_graph,
                                       opset=((defs.ONNX_DOMAIN,
                                               defs.onnx_opset_version()), ),
                                       ignore_unimplemented=False):
        """Converts a TensorflowGraph to an ONNX graph.

        This function converts a TensorflowGraph to an equivalent
        representation of ONNX graph. Nodes whose op type is listed in
        ``training_ops_to_remove`` are dropped (and logged), and each
        ``QueueDequeueManyV2`` node is replaced by one explicit graph input
        per dequeued component, named "<node name>:<index>".

        :param tf_graph: TensorflowGraph object.
        :param opset: Opset, which should be ((str domain: int version number),).
          The domain "ai.onnx" is treated as an alias of defs.ONNX_DOMAIN.
        :param ignore_unimplemented: Convert to ONNX model and ignore all the
          operators that are not currently supported by onnx-tensorflow.
          This is an experimental feature. By enabling this feature,
          the graph would not be guaranteed to match the ONNX specifications.
          Note: this is stored on the module-level flag
          ``exception.IGNORE_UNIMPLEMENTED`` and stays set after returning.

        :returns: The equivalent ONNX Graph Proto object.
        """
        onnx_graph = OnnxGraph(tf_graph.graph_name)
        # Presumably consulted by exception.OP_UNIMPLEMENTED_EXCEPT below to
        # decide whether to raise or just report; verify in the exception module.
        exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented
        training_ops_to_remove = ["RandomShuffleQueueV2"]

        opset_dict = {}
        for domain, version in opset:
            if domain == "ai.onnx":
                domain = defs.ONNX_DOMAIN
            opset_dict[domain] = version

        handlers = get_all_frontend_handlers(opset_dict)

        node_tup = [(n.name, n) for n in tf_graph.nodes]
        # Build the name -> node lookup once; rebuilding it for every handled
        # node made the conversion quadratic in the number of nodes.
        node_dict = dict(node_tup)

        for _, node in node_tup:
            if node.op_type == "Placeholder":
                onnx_graph.add_input_proto(node)
            elif node.op_type == "Const":
                onnx_graph.add_const(node)
                onnx_graph.add_const_proto(node)
                onnx_graph.add_input_proto(node)
            elif node.op_type in training_ops_to_remove:
                logger.info(
                    "A training op with name {} type {} has been removed.".
                    format(node.name, node.op_type))
            elif node.op_type == "QueueDequeueManyV2":
                output_shapes = node.attr["_output_shapes"]
                component_types = node.attr["component_types"]
                # One explicit graph input per dequeued component.
                for index, (shape, onnx_type) in enumerate(
                        zip(output_shapes, component_types)):
                    onnx_graph.add_input_proto_explicit(node.name + ":" +
                                                        str(index),
                                                        shape,
                                                        onnx_dtype=onnx_type)
            else:
                onnx_graph.add_value_info_proto(node)
                handler = handlers.get(node.domain, {}).get(node.op_type, None)
                node_proto = None
                if handler:
                    node_proto = handler.handle(
                        node,
                        consts=onnx_graph.consts,
                        node_dict=node_dict,
                        data_type_cast_map=onnx_graph.data_type_cast_map)
                else:
                    # Report the domain only when it is the reason no handler
                    # was found (i.e. no handler is registered for it at all).
                    exception.OP_UNIMPLEMENTED_EXCEPT(
                        node.op_type,
                        domain=None
                        if node.domain in handlers else node.domain)

                # No (or an empty) conversion: emit the TF node unchecked.
                if node_proto is None:
                    node_proto = FrontendHandler.make_node_from_tf_node(
                        node, op_type=node.op_type, should_check=False)
                onnx_graph.add_node_proto(node_proto)

        for o in tf_graph.outputs:
            output_node = tf_graph.get_node_by_name(o)
            onnx_graph.add_output_proto(output_node)

        return onnx_graph.make_graph_proto()