def handle_trivial(cls, node, input_dict):
    """Convert ``node`` to a TF op through the generic attribute mapping.

    Used for ONNX ops that have a direct counterpart in ``ONNX_OP_TO_TF_OP``.
    The node's attributes are processed in this order:

    1. merged over the per-op defaults in ``cls.DEFAULT_ONNX_ATTR_PER_OP``
       (node values win);
    2. value-translated through ``cls.attr_translator``;
    3. renamed via ``ONNX_ATTR_TO_TF_ATTR``, with
       ``ONNX_ATTR_TO_TF_ATTR_PER_OP`` taking precedence;
    4. filtered by ``ONNX_ATTR_TO_REMOVE_PER_OP``.

    Args:
      node: Onnx node object.
      input_dict: Mapping from input name to the TF value fed to the op.

    Returns:
      A one-element list holding the result of calling the mapped TF op.
    """
    op_name_lowered = op_name_to_lower(node.op_type)

    attrs = {x: node.attrs[x] for x in node.attrs.keys()}

    if op_name_lowered in cls.DEFAULT_ONNX_ATTR_PER_OP:
        # Copy before merging: updating the class-level defaults in place
        # would leak this node's attributes into every later node of the
        # same op type.
        merged = dict(cls.DEFAULT_ONNX_ATTR_PER_OP[op_name_lowered])
        merged.update(attrs)
        attrs = merged

    # Perform automatic attribute value translation.
    attrs = {
        x: (cls.attr_translator[x](cls, attrs[x])
            if x in cls.attr_translator else attrs[x])
        for x in attrs
    }

    # Per-op attribute name mapping has the final say over the global one.
    per_op_map = (ONNX_ATTR_TO_TF_ATTR_PER_OP[op_name_lowered]
                  if op_name_lowered in ONNX_ATTR_TO_TF_ATTR_PER_OP else {})

    def _tf_attr_name(onnx_name):
        if onnx_name in per_op_map:
            return per_op_map[onnx_name]
        if onnx_name in ONNX_ATTR_TO_TF_ATTR:
            return ONNX_ATTR_TO_TF_ATTR[onnx_name]
        return onnx_name

    # Substitute attribute names in attrs.
    attrs = {_tf_attr_name(x): y for x, y in attrs.items()}

    # Remove the keys listed for this op in ONNX_ATTR_TO_REMOVE_PER_OP.
    to_remove = (ONNX_ATTR_TO_REMOVE_PER_OP[op_name_lowered]
                 if op_name_lowered in ONNX_ATTR_TO_REMOVE_PER_OP else ())
    attrs = {x: y for x, y in attrs.items() if x not in to_remove}

    inputs = [input_dict[name] for name in node.inputs]
    return [ONNX_OP_TO_TF_OP[op_name_lowered](*inputs, **attrs)]
def _onnx_node_to_tensorflow_op(cls, node, input_dict, opset=0):
    """Convert onnx node to tensorflow op.

    The backend used is the one with the largest version, among those listed
    for the op in ``backend_opset_version``, that does not exceed ``opset``.

    Args:
      node: Onnx node object.
      input_dict: Inputs dict of graph.
      opset: Opset version of the operator set. Default 0 means using latest
        version.

    Returns:
      Tensorflow op

    Raises:
      NotImplementedError: if the op is unknown, has no backend version at
        or below ``opset``, or the selected backend cannot handle it.
    """
    op_name_lowered = op_name_to_lower(node.op_type)
    handler_name = "handle_" + op_name_lowered

    versions = backend_opset_version.get(op_name_lowered, [])
    if opset == 0:
        candidates = list(versions)
    else:
        # Largest implemented version <= requested opset. (The previous
        # index-based lookup wrapped to the *latest* version when opset was
        # below every implemented version.)
        candidates = [v for v in versions if v <= opset]
    if not candidates:
        raise NotImplementedError("{} op is not implemented for opset {}.".format(
            node.op_type, opset))
    version = max(candidates)

    backend_ver = 'backend_v{}'.format(version)
    backend = cls.backend_version_cache.setdefault(
        backend_ver,
        importlib.import_module('onnx_tf.backends.' + backend_ver).TensorflowBackend)

    # A specialized handler takes precedence over the generic mapping.
    if hasattr(backend, handler_name):
        return getattr(backend, handler_name)(node, input_dict)
    if op_name_lowered in ONNX_OP_TO_TF_OP:
        return backend.handle_trivial(node, input_dict)
    raise NotImplementedError("{} op is not implemented.".format(node.op_type))
def tensorflow_graph_to_onnx_graph(cls, graph_def, output, name="graph"):
    """Translate a TensorFlow GraphDef into an ONNX GraphProto.

    Placeholder nodes become graph inputs. Const nodes become initializers,
    and are also registered as graph inputs. Ops listed in
    ``TF_OP_STR_TO_ONNX_OP`` are emitted directly via ``make_node``. Any
    other op is delegated to a ``handle_<op>`` method on ``cls``.

    Args:
      graph_def: Tensorflow Graph Proto object.
      output: A Tensorflow NodeDef object specifying which node to be taken
        as output of the ONNX graph.
      name: The name of the output ONNX Graph.

    Returns:
      The equivalent ONNX Graph Proto object.

    Raises:
      NotImplementedError: if an op has neither a direct ONNX mapping nor a
        ``handle_<op>`` method on ``cls``.
    """
    # TF-only attributes that ONNX does not need or does not accept.
    dropped_attrs = ["_output_shapes", "T", "seed2", "Tidx"]

    graph_inputs = []   # ValueInfoProto per Placeholder / Const.
    graph_nodes = []    # NodeProto per translated op.
    initializers = []   # TensorProto per Const.
    # Const name -> value. Some ONNX attributes are TF input tensors; the
    # handlers read their known-at-construction-time values from here.
    const_values = {}

    for raw_node in graph_def.node:
        tf_node = TensorflowNode(raw_node)
        op = tf_node.op

        if op == "Placeholder":
            # NOTE(review): an earlier comment claimed `dtype` is translated
            # to `to`, yet this reads `dtype` directly — confirm which key
            # TensorflowNode exposes.
            graph_inputs.append(
                make_tensor_value_info(tf_node.name, tf_node.attr["dtype"],
                                       tf_node.attr["shape"]))
            continue

        if op == "Const":
            value = tf_node.attr["value"]
            const_values[tf_node.name] = value
            if len(value.shape) == 0:
                # Scalars are stored as a single-element tensor.
                flat = [value.tolist()]
                dims = np.array([value]).shape
            else:
                flat = value.flatten().tolist()
                dims = np.array(value).shape
            initializers.append(
                make_tensor(name=tf_node.name,
                            data_type=tf_node.attr["dtype"],
                            dims=dims,
                            vals=flat))
            graph_inputs.append(
                make_tensor_value_info(tf_node.name, tf_node.attr["dtype"], dims))
            continue

        if op in TF_OP_STR_TO_ONNX_OP:
            kept_attrs = {
                key: val
                for key, val in tf_node.attr.items()
                if key not in dropped_attrs
            }
            graph_nodes.append(
                make_node(TF_OP_STR_TO_ONNX_OP[op],
                          tf_node.inputs, [tf_node.name],
                          name=tf_node.name,
                          **kept_attrs))
            continue

        handler_name = "handle_" + op_name_to_lower(op)
        if handler_name not in dir(cls):
            raise NotImplementedError("{} op is not implemented.".format(op))
        graph_nodes.append(getattr(cls, handler_name)(tf_node, const_values))

    out_node = TensorflowNode(output)
    # TODO: only a single output is described.
    # Output dtype falls back to BOOL when the node has no "T" attr, cf.
    # https://github.com/tensorflow/tensorflow/issues/14769
    out_type = out_node.attr.get("T", TensorProto.BOOL)
    out_info = make_tensor_value_info(out_node.name, out_type,
                                      out_node.attr["_output_shapes"][0])
    return make_graph(graph_nodes, name, graph_inputs, [out_info], initializers)
def tensorflow_graph_to_onnx_graph(cls,
                                   graph_def,
                                   output,
                                   opset=(("", 0), ),
                                   name="graph"):
    """Converts a Tensorflow Graph Proto to an ONNX graph

    This function converts a Tensorflow Graph proto to an equivalent
    representation of ONNX graph.

    :param graph_def: Tensorflow Graph Proto object.
    :param output: A Tensorflow NodeDef object specifying which node
      to be taken as output of the ONNX graph.
    :param opset: Opset, which should be ((str domain: int version number),).
    :param name: The name of the output ONNX Graph.

    :returns: The equivalent ONNX Graph Proto object.
    :raises AssertionError: if an opset version is not an int in
      [0, defs.onnx_opset_version()] for its domain.
    :raises NotImplementedError: if an op, or the frontend for its domain
      and selected version, is not implemented.
    """
    # Validate the requested opsets once, up front (this is independent of
    # the nodes). The domain swap is restored even when validation fails.
    # NOTE(review): the swap presumably exists because
    # defs.onnx_opset_version() reads defs.ONNX_DOMAIN — confirm against the
    # installed onnx version.
    opset_dict = {}
    onnx_domain = defs.ONNX_DOMAIN
    try:
        for domain, version in opset:
            if domain == "ai.onnx":
                domain = ""
            opset_dict[domain] = version
            defs.ONNX_DOMAIN = domain
            assert isinstance(
                version, int
            ) and (version <= defs.onnx_opset_version()) and (
                version >= 0
            ), "Opset should be an int less than or equal to {}, but {}: {}".format(
                defs.onnx_opset_version(), type(version), version)
    finally:
        defs.ONNX_DOMAIN = onnx_domain

    # ValueInfoProto for each graph input (Placeholders and Consts).
    inputs_proto = []
    # NodeProto for each op in the converted graph.
    ops_proto = []
    # Const name -> value. Some ONNX attributes are TF input tensors; the
    # handlers read their known-at-construction-time values from here.
    consts = {}
    # Initializers for Consts used as op inputs (as opposed to attributes,
    # which are supplied through `consts`).
    consts_proto = []

    node_tup = [(node.name, TensorflowNode(node)) for node in graph_def.node]
    # Built once rather than per node (was O(n^2)).
    node_dict = dict(node_tup)

    # `node_name`, not `name`: reusing `name` clobbered the graph-name
    # parameter passed to make_graph below.
    for node_name, node in node_tup:
        if node.op == "Placeholder":
            # Tensorflow requires dtype to be known.
            onnx_type = node.attr["dtype"]
            shape = node.attr["shape"]
            inputs_proto.append(
                make_tensor_value_info(node_name, onnx_type, shape))
        elif node.op == "Const":
            value = node.attr["value"]
            const_dim = len(value.shape)
            consts[node_name] = value
            raw_values = ([value.tolist()]
                          if const_dim == 0 else value.flatten().tolist())
            # Scalars are stored as a single-element tensor.
            values = [value] if const_dim == 0 else value
            shape = np.array(values).shape
            consts_proto.append(
                make_tensor(name=node_name,
                            data_type=node.attr["dtype"],
                            dims=shape,
                            vals=raw_values))
            inputs_proto.append(
                make_tensor_value_info(node_name, node.attr["dtype"], shape))
        else:
            # Op names may be domain-qualified, e.g. "some.domain.OpName".
            splitted_op_name = node.op.split(".")
            op_domain = "" if len(splitted_op_name) == 1 else ".".join(
                splitted_op_name[:-1])
            op_name = splitted_op_name[-1]
            op_name_lowered = op_name_to_lower(op_name)
            handler_name = "handle_" + op_name_lowered

            # TODO per domain frontend_tf_opset_version?
            if op_name_lowered not in frontend_tf_opset_version:
                raise NotImplementedError(
                    "{} op is not implemented.".format(node.op))
            versions = frontend_tf_opset_version[op_name_lowered]

            opset_ver = opset_dict[op_domain]
            if opset_ver == 0:
                candidates = list(versions)
            else:
                # Largest implemented version <= requested opset. (The
                # previous index-based lookup wrapped to the *latest*
                # version when opset was below every implemented one.)
                candidates = [v for v in versions if v <= opset_ver]
            if not candidates:
                raise NotImplementedError(
                    "{} op is not implemented for opset {}.".format(
                        node.op, opset_ver))
            version = max(candidates)

            camel_domain = "".join(w.title() for w in op_domain.split("."))
            frontend_ver = "frontend_v{}".format(version)
            frontend_class_name = "{}TensorflowFrontend".format(camel_domain)
            frontend_module = cls.frontend_version_cache.setdefault(
                frontend_ver,
                importlib.import_module("onnx_tf.frontends." + frontend_ver))
            if not hasattr(frontend_module, frontend_class_name):
                # Was `assert NotImplementedError, ...`, which never fails
                # and left `frontend` unbound or stale from a prior node.
                raise NotImplementedError(
                    "{} for domain {} is not implemented".format(
                        frontend_ver, op_domain))
            frontend = getattr(frontend_module, frontend_class_name)

            # A specialized handler takes precedence over the direct mapping.
            if hasattr(frontend, handler_name):
                converted = getattr(frontend, handler_name)(
                    node, consts=consts, node_dict=node_dict)
                # Handlers may return a single NodeProto or a list of them.
                if isinstance(converted, list):
                    ops_proto.extend(converted)
                else:
                    ops_proto.append(converted)
            elif node.op in TF_OP_STR_TO_ONNX_OP:
                # Remove tensorflow-specific attrs that are not
                # needed/allowed in ONNX.
                filtered_attr = {
                    key: val
                    for key, val in node.attr.items()
                    if key not in TF_ATTR_TO_REMOVE
                }
                ops_proto.append(
                    make_node(TF_OP_STR_TO_ONNX_OP[node.op],
                              node.inputs, [node_name],
                              name=node_name,
                              **filtered_attr))
            else:
                raise NotImplementedError(
                    "{} op is not implemented.".format(node.op))

    output = TensorflowNode(output)
    # TODO: default to BOOL, cf.
    # https://github.com/tensorflow/tensorflow/issues/14769
    output_onnx_type = output.attr.get("T", TensorProto.BOOL)
    output_proto = []
    # One ValueInfoProto per recorded output shape; output i > 0 is named
    # "<node>:<i>".
    for i, out_shape in enumerate(output.attr["_output_shapes"]):
        output_name = output.name + ":{}".format(i) if i > 0 else output.name
        output_proto.append(
            make_tensor_value_info(output_name, output_onnx_type, out_shape))

    # Drop graph inputs / initializers that no ONNX op consumes.
    used_inputs = {inp for op in ops_proto for inp in op.input}
    inputs_proto = [p for p in inputs_proto if p.name in used_inputs]
    consts_proto = [p for p in consts_proto if p.name in used_inputs]
    return make_graph(ops_proto, name, inputs_proto, output_proto, consts_proto)
def tensorflow_graph_to_onnx_graph(cls, graph_def, output, name="graph"):
    """Function that converts a tensorflow graph to an onnx graph.

    Placeholder nodes become graph inputs. Const values are collected into a
    name -> value map that is passed to ``handle_<op>`` methods. Ops in
    ``TF_OP_STR_TO_ONNX_OP`` are emitted directly with ``make_node``.

    Args:
      graph_def: Tensorflow Graph Proto object.
      output: A Tensorflow NodeDef object specifying which node
        to be taken as output of the ONNX graph.
      name: The name of the output ONNX Graph.

    Returns:
      The equivalent ONNX Graph Proto object.

    Raises:
      NotImplementedError: if an op is not Placeholder/Const, has no direct
        ONNX mapping, and has no ``handle_<op>`` method on ``cls``.
    """
    # ValueInfoProto for each graph input.
    inputs_proto = []
    # NodeProto for each op in the converted graph.
    ops_proto = []
    # Const name -> value it holds.
    consts = {}
    # TF-only attributes that ONNX does not need or does not accept.
    attr_to_remove = ["_output_shapes", "T"]

    for node in graph_def.node:
        node = TensorflowNode(node)
        if node.op == "Placeholder":
            # Tensorflow requires dtype to be known.
            # TODO: currently `dtype` is translated to `to`.
            onnx_type = node.attr["to"]
            shape = node.attr["shape"]
            inputs_proto.append(
                make_tensor_value_info(node.name, onnx_type, shape))
        # `elif` so a Placeholder is not also run through the op branches.
        elif node.op == "Const":
            consts[node.name] = node.attr["value"]
        elif node.op in TF_OP_STR_TO_ONNX_OP:
            node.attr = {
                key: val
                for key, val in node.attr.items()
                if key not in attr_to_remove
            }
            ops_proto.append(
                make_node(TF_OP_STR_TO_ONNX_OP[node.op],
                          node.inputs, [node.name],
                          name=node.name,
                          **node.attr))
        else:
            handler_name = "handle_" + op_name_to_lower(node.op)
            # Check if specialized handler exists; silently skipping an
            # unsupported op would leave dangling references in the graph.
            if handler_name not in dir(cls):
                raise NotImplementedError(
                    "{} op is not implemented.".format(node.op))
            ops_proto.append(getattr(cls, handler_name)(node, consts))

    output = TensorflowNode(output)
    # making output proto
    # TODO: deal with multi-output case.
    # TODO: default to BOOL, cf.
    # https://github.com/tensorflow/tensorflow/issues/14769
    output_onnx_type = output.attr.get("T", TensorProto.BOOL)
    output_proto = make_tensor_value_info(output.name, output_onnx_type,
                                          output.attr["_output_shapes"][0])
    return make_graph(ops_proto, name, inputs_proto, [output_proto])
def main():
    """Scan versioned backends/frontends and write ``opset_version.py``.

    Imports ``backends.backend_v{N}`` and ``frontends.frontend_v{N}`` for
    N = 1, 2, ... until either fails to import or lacks its
    ``TensorflowBackend`` / ``TensorflowFrontend`` class. It records:

    * ``backend_opset_version``: ONNX op -> versions whose backend defines
      ``handle_<op>``. Ops in ``ONNX_OP_TO_TF_OP`` are recorded once, at v1.
    * ``frontend_opset_version``: ONNX op -> versions registered for it,
      either via the direct ``ONNX_OP_TO_TF_OP_STR`` mapping (at v1) or via
      the frontend's ``ONNX_TO_HANDLER`` table.
    * ``frontend_tf_opset_version``: lowered TF op name -> versions in which
      it is registered.

    The three dicts are written to ``opset_version.py`` in the current
    directory.
    """
    schemas = defs.get_all_schemas()
    backend_opset_dict = {schema.name: [] for schema in schemas}
    frontend_opset_dict = {schema.name: [] for schema in schemas}
    frontend_tf_opset_dict = {}

    version = 1
    while True:
        try:
            backend = (importlib.import_module(
                'backends.backend_v{}'.format(version)).TensorflowBackend)
            frontend = (importlib.import_module(
                'frontends.frontend_v{}'.format(version)).TensorflowFrontend)
        except (ImportError, AttributeError):
            # No such version: stop scanning. (A bare `except:` also hid real
            # errors, e.g. bugs inside a backend module, and Ctrl-C.)
            break

        # Register all tf ops in ONNX_TO_HANDLER (stored lowered).
        tf_op_names = []
        onnx_to_handler = frontend.ONNX_TO_HANDLER.get(
            'frontend_v{}'.format(version), {})
        for handler in onnx_to_handler.values():
            if isinstance(handler, list):
                tf_op_names.extend(op_name_to_lower(h) for h in handler)
            else:
                tf_op_names.append(op_name_to_lower(handler))

        for schema in schemas:
            op_name = schema.name
            lower_op_name = op_name_to_lower(op_name)
            has_backend_handler = hasattr(backend, 'handle_' + lower_op_name)
            # Record only one version for trivial ops.
            if has_backend_handler or (version == 1 and
                                       lower_op_name in ONNX_OP_TO_TF_OP):
                backend_opset_dict[op_name].append(version)

            # Register once if onnx op in ONNX_OP_TO_TF_OP_STR. Compare the
            # *lowered* name: tf_op_names holds lowered names, so comparing
            # the raw string never matched and duplicates were appended.
            direct_tf_name = None
            if version == 1 and op_name in ONNX_OP_TO_TF_OP_STR:
                direct_tf_name = op_name_to_lower(ONNX_OP_TO_TF_OP_STR[op_name])
            if direct_tf_name is not None and direct_tf_name not in tf_op_names:
                tf_op_names.append(direct_tf_name)
                frontend_opset_dict[op_name].append(version)
            # Register if onnx op in ONNX_TO_HANDLER.
            elif op_name in onnx_to_handler:
                frontend_opset_dict[op_name].append(version)

        for tf_op_name in tf_op_names:
            frontend_tf_opset_dict.setdefault(str(tf_op_name), []).append(version)
        version += 1

    pp = pprint.PrettyPrinter(indent=4)
    with open('opset_version.py', 'w') as version_file:
        # pformat(...)[1:-1] strips the outer braces so the dicts can be
        # re-wrapped with a "name = {" header.
        version_file.write("backend_opset_version = {\n " +
                           pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_opset_version = {\n " +
                           pp.pformat(frontend_opset_dict)[1:-1] + "\n}\n\n")
        version_file.write("frontend_tf_opset_version = {\n " +
                           pp.pformat(frontend_tf_opset_dict)[1:-1] + "\n}\n")