Exemplo n.º 1
0
    def handle_trivial(cls, node, input_dict):
        """Convert an ONNX node that maps directly onto a single TF op.

        Attribute pipeline, in order:

        1. Start from the per-op defaults in ``cls.DEFAULT_ONNX_ATTR_PER_OP``
           (if any) and overlay the node's own attributes.
        2. Translate attribute values with ``cls.attr_translator``.
        3. Rename ONNX attribute names to TF names: an entry for this op in
           ``ONNX_ATTR_TO_TF_ATTR_PER_OP`` wins over the global
           ``ONNX_ATTR_TO_TF_ATTR`` map; unmapped names are kept as-is.
        4. Drop the (already renamed) attributes listed for this op in
           ``ONNX_ATTR_TO_REMOVE_PER_OP``.

        The callable from ``ONNX_OP_TO_TF_OP`` is then invoked with the
        node's inputs positionally (in ``node.inputs`` order) and the
        attributes as keyword arguments.

        Args:
          node: ONNX node wrapper exposing ``op_type``, ``attrs`` and
            ``inputs``.
          input_dict: Mapping from input name to the value fed to the TF op.

        Returns:
          A one-element list holding the TF op's result.
        """
        op_name_lowered = op_name_to_lower(node.op_type)

        # Build a fresh dict: the defaults are copied, never updated in place,
        # so one node's attributes cannot leak into the shared class-level
        # table and silently show up on later nodes of the same op.
        attrs = {}
        if op_name_lowered in cls.DEFAULT_ONNX_ATTR_PER_OP:
            attrs.update(cls.DEFAULT_ONNX_ATTR_PER_OP[op_name_lowered])
        attrs.update(node.attrs)

        # Perform automatic attribute value translation.
        translators = cls.attr_translator
        attrs = {
            key: translators[key](cls, value) if key in translators else value
            for key, value in attrs.items()
        }

        # Rename ONNX attribute names to TF names. The per-op map (keyed by
        # the ONNX name) has the final say over the global map.
        per_op_names = ONNX_ATTR_TO_TF_ATTR_PER_OP.get(op_name_lowered, {})

        def to_tf_name(onnx_name):
            if onnx_name in per_op_names:
                return per_op_names[onnx_name]
            return ONNX_ATTR_TO_TF_ATTR.get(onnx_name, onnx_name)

        attrs = {to_tf_name(key): value for key, value in attrs.items()}

        # Remove the attributes (by TF name) listed for removal for this op.
        to_remove = ONNX_ATTR_TO_REMOVE_PER_OP.get(op_name_lowered, ())
        attrs = {
            key: value
            for key, value in attrs.items() if key not in to_remove
        }

        inputs = [input_dict[input_name] for input_name in node.inputs]
        return [ONNX_OP_TO_TF_OP[op_name_lowered](*inputs, **attrs)]
Exemplo n.º 2
0
    def _onnx_node_to_tensorflow_op(cls, node, input_dict, opset=0):
        """Convert an onnx node to tensorflow op(s) via a versioned backend.

        The backend used is ``onnx_tf.backends.backend_v<N>`` for the newest
        handler version ``N`` that does not exceed ``opset`` (the newest
        overall when ``opset`` is 0). A specialized ``handle_<op>`` method on
        that backend wins; otherwise ops listed in ``ONNX_OP_TO_TF_OP`` go
        through the backend's ``handle_trivial``.

        Args:
          node: Onnx node object.
          input_dict: Inputs dict of graph.
          opset: Opset version of the operator set. Default 0 means using
            latest version.

        Returns:
          Whatever the selected handler returns (the tensorflow op(s)).

        Raises:
          NotImplementedError: if the op has no handler at all, or none at a
            version <= ``opset``.
        """
        op_name_lowered = op_name_to_lower(node.op_type)
        handler_name = "handle_" + op_name_lowered

        # Versions at which some backend provides a handler for this op.
        versions = backend_opset_version.get(op_name_lowered, [])
        if opset != 0:
            # Newest version not exceeding the requested opset. An opset older
            # than every supported version must not wrap around to the newest.
            versions = [v for v in versions if v <= opset]
        if not versions:
            if op_name_lowered in backend_opset_version:
                raise NotImplementedError(
                    "{} op is not implemented for opset {}.".format(
                        node.op_type, opset))
            raise NotImplementedError("{} op is not implemented.".format(
                node.op_type))
        version = max(versions)

        backend_ver = 'backend_v{}'.format(version)
        # Import only on a cache miss; passing the import as the default of
        # ``setdefault`` would run it on every call.
        backend = cls.backend_version_cache.get(backend_ver)
        if backend is None:
            backend = importlib.import_module('onnx_tf.backends.' +
                                              backend_ver).TensorflowBackend
            cls.backend_version_cache[backend_ver] = backend

        # Check if specialized handler exists.
        if hasattr(backend, handler_name):
            return getattr(backend, handler_name)(node, input_dict)
        if op_name_lowered in ONNX_OP_TO_TF_OP:
            return backend.handle_trivial(node, input_dict)
        raise NotImplementedError("{} op is not implemented.".format(
            node.op_type))
Exemplo n.º 3
0
  def tensorflow_graph_to_onnx_graph(cls, graph_def, output, name="graph"):
    """Translate a tensorflow GraphDef into an onnx GraphProto.

    Placeholders become graph inputs. Const nodes become both initializers
    and graph inputs, and their values are also recorded by node name so
    that handlers can read them as attributes. Ops listed in
    ``TF_OP_STR_TO_ONNX_OP`` are mapped one-to-one; anything else is
    delegated to a ``handle_<op>`` method on ``cls``.

    Args:
        graph_def: Tensorflow Graph Proto object.
        output: A Tensorflow NodeDef object specifying which node
          to be taken as output of the ONNX graph.
        name: The name of the output ONNX Graph.

    Returns:
        The equivalent ONNX Graph Proto object.

    Raises:
        NotImplementedError: if an op has neither a direct mapping nor a
          specialized handler.
    """
    # ValueInfoProto objects describing the graph inputs.
    graph_inputs = []
    # NodeProto objects for the converted ops.
    graph_nodes = []
    # Const node name -> value array. TF is less eager than ONNX to know
    # values at graph-construction time (some ONNX attributes are input
    # tensors in TF), so handlers look up known constant values here.
    const_values = {}
    # Initializers exposing the constants as global tensors, so ops can
    # also consume them as inputs rather than as attributes.
    initializers = []
    # Tensorflow-specific attrs that are not needed/allowed in ONNX.
    tf_only_attrs = ["_output_shapes", "T", "seed2", "Tidx"]

    for raw_node in graph_def.node:
      node = TensorflowNode(raw_node)
      op = node.op

      if op == "Placeholder":
        # Tensorflow requires dtype to be known.
        # TODO: currently `dtype` is translated to `to`.
        graph_inputs.append(
            make_tensor_value_info(node.name, node.attr["dtype"],
                                   node.attr["shape"]))
        continue

      if op == "Const":
        value = node.attr["value"]
        dtype = node.attr["dtype"]
        const_values[node.name] = value
        if len(value.shape) == 0:
          # Scalars are wrapped, giving a one-element tensor of shape (1,).
          flat_values = [value.tolist()]
          dims = np.array([value]).shape
        else:
          flat_values = value.flatten().tolist()
          dims = np.array(value).shape
        initializers.append(
            make_tensor(name=node.name,
                        data_type=dtype,
                        dims=dims,
                        vals=flat_values))
        graph_inputs.append(make_tensor_value_info(node.name, dtype, dims))
        continue

      if op in TF_OP_STR_TO_ONNX_OP:
        node.attr = {
            key: val
            for key, val in node.attr.items() if key not in tf_only_attrs
        }
        graph_nodes.append(
            make_node(TF_OP_STR_TO_ONNX_OP[op],
                      node.inputs, [node.name],
                      name=node.name,
                      **node.attr))
        continue

      handler_name = "handle_" + op_name_to_lower(op)
      if handler_name not in dir(cls):
        raise NotImplementedError("{} op is not implemented.".format(op))
      graph_nodes.append(getattr(cls, handler_name)(node, const_values))

    out_node = TensorflowNode(output)
    # TODO: deal with multi-output case.
    # TODO: default to BOOL, cf.
    # https://github.com/tensorflow/tensorflow/issues/14769
    out_type = out_node.attr.get("T", TensorProto.BOOL)
    out_info = make_tensor_value_info(out_node.name, out_type,
                                      out_node.attr["_output_shapes"][0])
    return make_graph(graph_nodes, name, graph_inputs, [out_info],
                      initializers)
Exemplo n.º 4
0
    def tensorflow_graph_to_onnx_graph(cls,
                                       graph_def,
                                       output,
                                       opset=(("", 0), ),
                                       name="graph"):
        """Converts a Tensorflow Graph Proto to an ONNX graph.

        Placeholders become graph inputs; Const nodes become both
        initializers and graph inputs (their values are also handed to the
        frontend handlers through ``consts``). Every other node is converted
        by the ``handle_<op>`` method of the versioned frontend selected for
        its domain and opset, or, failing that, mapped one-to-one through
        ``TF_OP_STR_TO_ONNX_OP``.

        :param graph_def: Tensorflow Graph Proto object.
        :param output: A Tensorflow NodeDef object specifying which node
          to be taken as output of the ONNX graph.
        :param opset: Opset, which should be ((str domain: int version number),).
          A version of 0 selects the newest frontend available for the op;
          the domain "ai.onnx" is treated as the default domain "".
        :param name: The name of the output ONNX Graph.

        :returns: The equivalent ONNX Graph Proto object. Graph inputs and
          initializers that no converted node consumes are dropped.
        :raises AssertionError: if an opset version is not an int between 0
          and the newest version ONNX reports for that domain.
        :raises NotImplementedError: if no frontend class exists for an op's
          domain, or no handler / direct mapping exists for the op at the
          requested opset.
        """
        # Normalize and validate the requested opsets once, up front.
        # ``defs.onnx_opset_version()`` is queried with the module-level
        # ``defs.ONNX_DOMAIN`` switched to each domain; ``finally`` restores
        # the global even when validation fails.
        opset_dict = {}
        saved_onnx_domain = defs.ONNX_DOMAIN
        try:
            for domain, version in opset:
                if domain == "ai.onnx":
                    domain = ""
                opset_dict[domain] = version
                defs.ONNX_DOMAIN = domain
                assert isinstance(
                    version, int
                ) and (version <= defs.onnx_opset_version()) and (
                    version >= 0
                ), "Opset should be an int less than or equal to {}, but {}: {}".format(
                    defs.onnx_opset_version(), type(version), version)
        finally:
            defs.ONNX_DOMAIN = saved_onnx_domain

        # ValueInfoProto objects for the inputs of the converted ONNX graph.
        inputs_proto = []

        # NodeProto objects for the ops of the converted ONNX graph.
        ops_proto = []

        # Const node name -> value array. Tensorflow is less eager than ONNX
        # to know input values at graph construction time (some ONNX
        # attributes are input tensors in TF), so handlers read the constant
        # values known at construction time from here.
        consts = {}

        # Initializers exposing the constants as global tensors, so ops can
        # also consume them as inputs (as opposed to attributes, which are
        # supplied through the `consts` map above).
        consts_proto = []

        node_tup = [(node.name, TensorflowNode(node))
                    for node in graph_def.node]

        # The loop variable must not be called ``name``: that would clobber
        # the graph-name parameter passed to ``make_graph`` at the end.
        for node_name, node in node_tup:

            if node.op == "Placeholder":
                # Tensorflow requires dtype to be known.
                # TODO: currently `dtype` is translated to `to`.
                onnx_type = node.attr["dtype"]
                shape = node.attr["shape"]
                inputs_proto.append(
                    make_tensor_value_info(node_name, onnx_type, shape))
            elif node.op == "Const":
                value = node.attr["value"]
                dtype = node.attr["dtype"]
                consts[node_name] = value
                if len(value.shape) == 0:
                    # Scalars are wrapped into a one-element tensor, shape (1,).
                    raw_values = [value.tolist()]
                    shape = np.array([value]).shape
                else:
                    raw_values = value.flatten().tolist()
                    shape = np.array(value).shape
                consts_proto.append(
                    make_tensor(name=node_name,
                                data_type=dtype,
                                dims=shape,
                                vals=raw_values))
                inputs_proto.append(
                    make_tensor_value_info(node_name, dtype, shape))
            else:
                # A dotted op name "a.b.Op" is op "Op" in domain "a.b".
                splitted_op_name = node.op.split(".")
                op_domain = ".".join(splitted_op_name[:-1])
                op_name = splitted_op_name[-1]
                op_name_lowered = op_name_to_lower(op_name)
                handler_name = "handle_" + op_name_lowered

                # Pick the newest frontend version implementing this op that
                # does not exceed the requested opset (0 = newest overall).
                # An opset older than every version must not wrap around to
                # the newest one.
                # TODO per domain frontend_tf_opset_version?
                versions = frontend_tf_opset_version.get(op_name_lowered, [])
                opset_ver = opset_dict[op_domain]
                if opset_ver != 0:
                    versions = [v for v in versions if v <= opset_ver]
                if not versions:
                    raise NotImplementedError(
                        "{} op is not implemented for opset {}.".format(
                            node.op, opset_ver))
                version = max(versions)

                camel_domain = "".join(w.title() for w in op_domain.split("."))
                frontend_ver = "frontend_v{}".format(version)
                frontend_class_name = "{}TensorflowFrontend".format(
                    camel_domain)
                # Import only on a cache miss.
                frontend_module = cls.frontend_version_cache.get(frontend_ver)
                if frontend_module is None:
                    frontend_module = importlib.import_module(
                        "onnx_tf.frontends." + frontend_ver)
                    cls.frontend_version_cache[frontend_ver] = frontend_module
                if not hasattr(frontend_module, frontend_class_name):
                    # Must raise: continuing would use an unbound ``frontend``
                    # or the one left over from a previous node.
                    raise NotImplementedError(
                        "{} for domain {} is not implemented".format(
                            frontend_ver, op_domain))
                frontend = getattr(frontend_module, frontend_class_name)

                # Check if specialized handler exists.
                if hasattr(frontend, handler_name):
                    converted = getattr(frontend, handler_name)(
                        node, consts=consts, node_dict=dict(node_tup))
                    # Handlers may emit one NodeProto or a list of them.
                    if isinstance(converted, list):
                        ops_proto.extend(converted)
                    else:
                        ops_proto.append(converted)
                elif node.op in TF_OP_STR_TO_ONNX_OP:
                    # Remove tensorflow-specific attrs that are not
                    # needed/allowed in ONNX.
                    # NOTE(review): ``cls.DEFAULT_TF_ATTR_PER_OP`` used to be
                    # looked up here but was never applied; that dead lookup
                    # was dropped. If per-op defaults are meant to be merged,
                    # merge a *copy* here, never the shared class dict.
                    filtered_attr = {
                        key: val
                        for key, val in node.attr.items()
                        if key not in TF_ATTR_TO_REMOVE
                    }
                    ops_proto.append(
                        make_node(TF_OP_STR_TO_ONNX_OP[node.op],
                                  node.inputs, [node_name],
                                  name=node_name,
                                  **filtered_attr))
                else:
                    raise NotImplementedError(
                        "{} op is not implemented.".format(node.op))

        output = TensorflowNode(output)
        # making output proto
        # TODO: deal with multi-output case.
        # TODO: default to BOOL, cf.
        # https://github.com/tensorflow/tensorflow/issues/14769
        output_onnx_type = output.attr.get("T", TensorProto.BOOL)
        output_proto = []
        for i, output_shape in enumerate(output.attr["_output_shapes"]):
            # The first output keeps the bare node name; later ones get ":i".
            output_name = output.name if i == 0 else "{}:{}".format(
                output.name, i)
            output_proto.append(
                make_tensor_value_info(output_name, output_onnx_type,
                                       output_shape))

        # Remove proto in inputs_proto and consts_proto if proto is not used
        # as input in ONNX. A set keeps the membership tests O(1).
        used_inputs = set(chain.from_iterable(p.input for p in ops_proto))
        inputs_proto = [p for p in inputs_proto if p.name in used_inputs]
        consts_proto = [p for p in consts_proto if p.name in used_inputs]

        return make_graph(ops_proto, name, inputs_proto, output_proto,
                          consts_proto)
Exemplo n.º 5
0
  def tensorflow_graph_to_onnx_graph(cls, graph_def, output, name="graph"):
    """Function that converts a tensorflow graph to an onnx graph.

    Placeholders become graph inputs, Const values are collected by node
    name (for the handlers; they are not emitted as nodes), ops listed in
    ``TF_OP_STR_TO_ONNX_OP`` are mapped one-to-one, and any other op is
    passed to a ``handle_<op>`` method on ``cls`` when one exists.

    Args:
        graph_def: Tensorflow Graph Proto object.
        output: A Tensorflow NodeDef object specifying which node
          to be taken as output of the ONNX graph.
        name: The name of the output ONNX Graph.

    Returns:
        The equivalent ONNX Graph Proto object.

    """

    # This list holds the protobuf objects of type ValueInfoProto
    # representing the input to the converted ONNX graph.
    inputs_proto = []

    # This list holds the protobuf objects of type NodeProto
    # representing the ops in the converted ONNX graph.
    ops_proto = []

    # This dictionary contains a map from the name of the constant
    # op to the array of values it holds.
    consts = {}

    for node in graph_def.node:
      node = TensorflowNode(node)
      # Exactly one branch per node: a Placeholder must not fall through
      # into the op-mapping / handler dispatch below.
      if node.op == "Placeholder":
        # Tensorflow requires dtype to be known.
        # TODO: currently `dtype` is translated to `to`.
        onnx_type = node.attr["to"]
        shape = node.attr["shape"]
        input_proto = make_tensor_value_info(node.name,
                                             onnx_type,
                                             shape)
        inputs_proto.append(input_proto)
      elif node.op == "Const":
        consts[node.name] = node.attr["value"]
      elif node.op in TF_OP_STR_TO_ONNX_OP:
        # Remove tensorflow-specific attrs that are not
        # needed/allowed in ONNX.
        attr_to_remove = ["_output_shapes", "T"]
        node.attr = {
            key: val
            for key, val in node.attr.items() if key not in attr_to_remove
        }

        node_output = node.name
        ops_proto.append(make_node(TF_OP_STR_TO_ONNX_OP[node.op],
                                   node.inputs,
                                   [node_output],
                                   name=node.name,
                                   **node.attr))
      else:
        handler_name = "handle_" + op_name_to_lower(node.op)

        # Check if specialized handler exists.
        # NOTE(review): an op with neither a direct mapping nor a handler
        # is silently left out of the graph here; confirm that is intended.
        if handler_name in dir(cls):
          method_to_call = getattr(cls, handler_name)
          ops_proto.append(method_to_call(node, consts))

    output = TensorflowNode(output)
    # making output proto
    # TODO: deal with multi-output case.
    # TODO: default to BOOL, cf.
    # https://github.com/tensorflow/tensorflow/issues/14769
    output_onnx_type = output.attr.get("T", TensorProto.BOOL)
    output_proto = make_tensor_value_info(output.name,
                                          output_onnx_type,
                                          output.attr["_output_shapes"][0])

    return make_graph(ops_proto,
                      name,
                      inputs_proto,
                      [output_proto])
Exemplo n.º 6
0
def main():
    """Generate ``opset_version.py`` listing the handler versions per op.

    Modules ``backends.backend_v<N>`` and ``frontends.frontend_v<N>`` are
    imported for N = 1, 2, ... until one of them cannot be imported or lacks
    its ``TensorflowBackend`` / ``TensorflowFrontend`` class. The generated
    file defines three dicts:

    * ``backend_opset_version``: ONNX op name -> versions N whose backend has
      a ``handle_<op>`` method (ops in ``ONNX_OP_TO_TF_OP`` are also recorded
      at N == 1);
    * ``frontend_opset_version``: ONNX op name -> versions N at which the
      frontend registers that op;
    * ``frontend_tf_opset_version``: lower-cased TF op name -> versions N at
      which the frontend handles it (each version listed once).
    """
    backend_opset_dict = {}
    frontend_opset_dict = {}
    frontend_tf_opset_dict = {}

    for schema in defs.get_all_schemas():
        backend_opset_dict[schema.name] = []
        frontend_opset_dict[schema.name] = []

    version = 1
    while True:
        try:
            backend = (importlib.import_module(
                'backends.backend_v{}'.format(version)).TensorflowBackend)
            frontend = (importlib.import_module(
                'frontends.frontend_v{}'.format(version)).TensorflowFrontend)
        except (ImportError, AttributeError):
            # No module (or class) for this version: all versions are seen.
            # Other errors (e.g. a broken version module) propagate instead
            # of silently truncating the generated table.
            break

        # Lower-cased names of every TF op the frontend handles at this
        # version. A set, so an op reachable several ways is recorded once.
        tf_op_names = set()
        onnx_to_handler = frontend.ONNX_TO_HANDLER.get(
            'frontend_v{}'.format(version), {})
        for handler in onnx_to_handler.values():
            if isinstance(handler, list):
                tf_op_names.update(op_name_to_lower(h) for h in handler)
            else:
                tf_op_names.add(op_name_to_lower(handler))

        for schema in defs.get_all_schemas():
            op_name = schema.name
            lower_op_name = op_name_to_lower(op_name)
            has_backend_handler = hasattr(backend, 'handle_' + lower_op_name)
            # Record only one version for trivial ops
            if has_backend_handler or (version == 1
                                       and lower_op_name in ONNX_OP_TO_TF_OP):
                backend_opset_dict[op_name].append(version)

            # Register once (at version 1) if onnx op in ONNX_OP_TO_TF_OP_STR;
            # the TF name is lower-cased like every other entry in the set.
            if version == 1 and op_name in ONNX_OP_TO_TF_OP_STR:
                tf_op_names.add(
                    op_name_to_lower(ONNX_OP_TO_TF_OP_STR[op_name]))
                frontend_opset_dict[op_name].append(version)
            # Register if onnx op in ONNX_TO_HANDLER
            elif op_name in onnx_to_handler:
                frontend_opset_dict[op_name].append(version)

        for tf_op_name in tf_op_names:
            frontend_tf_opset_dict.setdefault(str(tf_op_name),
                                              []).append(version)

        version += 1

    with open('opset_version.py', 'w') as version_file:
        pp = pprint.PrettyPrinter(indent=4)
        # pformat(...)[1:-1] strips the outer braces so they can be re-added
        # around the variable assignment.
        version_file.write("backend_opset_version = {\n " +
                           pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_opset_version = {\n " +
                           pp.pformat(frontend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_tf_opset_version = {\n " +
                           pp.pformat(frontend_tf_opset_dict)[1:-1] + "\n}\n")