def get_all_frontend_handlers(opset_dict):
    """Build a lookup of every frontend handler class.

    Every direct subclass of ``FrontendHandler`` is validated with
    ``check_cls()`` and then annotated in place:

    * ``VERSION`` is set to the opset version requested for its ``DOMAIN``.
    * ``SINCE_VERSION`` is set to the ``since_version`` of the ONNX schema
      for ``ONNX_OP`` looked up with ``max_inclusive_version=VERSION``.
      When ``ONNX_OP`` is unset or unknown to ``onnx.defs``, it is set to 1
      and a warning is emitted.

    :param opset_dict: A dict of opset, e.g. {'domain': version, ...}.
        It must contain the DOMAIN of every handler; a missing domain
        raises KeyError.
    :return: Dict of the form {'domain': {'TF_OP': handler class, ...}, ...}.
        A domain only appears if at least one of its handlers lists a TF op.
    """
    by_domain = {}
    for handler_cls in FrontendHandler.__subclasses__():
        handler_cls.check_cls()

        domain = handler_cls.DOMAIN
        requested_version = opset_dict[domain]
        handler_cls.VERSION = requested_version

        onnx_op = handler_cls.ONNX_OP
        if not (onnx_op and defs.has(onnx_op, domain=domain)):
            # No ONNX schema is available, so the handler's output cannot be
            # checked against the ONNX specification.
            warnings.warn("Unknown op {} in domain `{}`. "
                          "Can't check specification by ONNX. "
                          "Please set should_check flag to False "
                          "when call make_node method in handler.".format(
                              onnx_op or "Undefined", domain or "ai.onnx"))
            handler_cls.SINCE_VERSION = 1
        else:
            schema = defs.get_schema(
                onnx_op,
                domain=domain,
                max_inclusive_version=requested_version)
            handler_cls.SINCE_VERSION = schema.since_version

        for tf_op in handler_cls.TF_OP:
            # Create the per-domain table lazily so handlers without TF ops
            # do not introduce empty domains.
            if domain not in by_domain:
                by_domain[domain] = {}
            by_domain[domain][tf_op] = handler_cls

    return by_domain
def get_frontend_coverage():
    """Get frontend coverage for document.

    Walks every direct subclass of ``FrontendHandler``, validating each with
    ``check_cls()``. The versions it reports via ``get_versions()`` are
    recorded under each of its TF ops and, when set, under its ONNX op.

    :return: dict of frontend coverages with keys
      onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
      tf_coverage: e.g. {'domain': {'TF_OP': [versions], ...}, ...}
      experimental_op: set of ONNX op names whose handler has a truthy
        ``EXPERIMENTAL`` attribute, e.g. {'ONNX_OP', ...}
    """
    tf_cov, onnx_cov, experimental = {}, {}, set()

    for handler_cls in FrontendHandler.__subclasses__():
        handler_cls.check_cls()
        supported = handler_cls.get_versions()
        domain = handler_cls.DOMAIN

        for tf_op in handler_cls.TF_OP:
            _update_coverage(tf_cov, domain, tf_op, supported)

        onnx_op = handler_cls.ONNX_OP
        if not onnx_op:
            # Handlers without an ONNX op only contribute TF coverage.
            continue
        if getattr(handler_cls, "EXPERIMENTAL", False):
            experimental.add(onnx_op)
        _update_coverage(onnx_cov, domain, onnx_op, supported)

    return {
        "onnx_coverage": onnx_cov,
        "tf_coverage": tf_cov,
        "experimental_op": experimental,
    }
def get_frontend_coverage():
    """Get frontend coverage for document.

    Walks every direct subclass of ``FrontendHandler``, validating each with
    ``check_cls()``. The versions it reports via ``get_versions()`` are
    recorded under each of its TF ops and, when set, under its ONNX op.

    :return: Tuple ``(onnx_coverage, tf_coverage)`` where
      onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
      tf_coverage: e.g. {'domain': {'TF_OP': [versions], ...}, ...}
    """
    tf_cov = {}
    onnx_cov = {}

    for handler_cls in FrontendHandler.__subclasses__():
        handler_cls.check_cls()
        supported = handler_cls.get_versions()
        domain = handler_cls.DOMAIN

        for tf_op in handler_cls.TF_OP:
            _update_coverage(tf_cov, domain, tf_op, supported)

        onnx_op = handler_cls.ONNX_OP
        if not onnx_op:
            # Handlers without an ONNX op only contribute TF coverage.
            continue
        _update_coverage(onnx_cov, domain, onnx_op, supported)

    return onnx_cov, tf_cov
def tensorflow_graph_to_onnx_graph(cls,
                                   graph_def,
                                   output,
                                   opset=((defs.ONNX_DOMAIN,
                                           defs.onnx_opset_version()), ),
                                   name="graph",
                                   ignore_unimplemented=False):
    """Converts a Tensorflow Graph Proto to an ONNX graph.

    This function converts a Tensorflow Graph proto to an equivalent
    representation of ONNX graph.

    Node handling, in graph order:

    * ``Placeholder`` nodes become graph inputs.
    * ``Const`` nodes are registered as constants, emitted as constant
      protos, and also exposed as graph inputs.
    * Every other node gets value info. It is then converted by the
      registered frontend handler for its (domain, op_type). Without a
      handler, ``exception.OP_UNIMPLEMENTED_EXCEPT`` is invoked. If no proto
      results, the node is copied verbatim via
      ``FrontendHandler.make_node_from_tf_node`` with checking disabled.

    Side effect: sets the module-level ``exception.IGNORE_UNIMPLEMENTED``
    flag to ``ignore_unimplemented``. It is not restored afterwards.

    :param graph_def: Tensorflow Graph Proto object.
    :param output: List of Tensorflow NodeDef object specifying which nodes
      to be taken as outputs of the ONNX graph.
    :param opset: Opset, which should be ((str domain: int version number),).
      The domain name "ai.onnx" is normalized to ``defs.ONNX_DOMAIN``.
    :param name: The name of the output ONNX Graph.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the
      operators that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.

    :returns: The equivalent ONNX Graph Proto object.
    """
    onnx_graph = OnnxGraph(name)
    exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented

    opset_dict = {}
    for domain, version in opset:
        # Accept "ai.onnx" as another spelling of the default ONNX domain.
        if domain == "ai.onnx":
            domain = defs.ONNX_DOMAIN
        opset_dict[domain] = version

    handlers = get_all_frontend_handlers(opset_dict)

    node_tup = [(node.name, TensorflowNode(node)) for node in graph_def.node]
    # Build the name -> node lookup once instead of once per handled node
    # (previously quadratic in graph size).
    # NOTE(review): assumes handlers treat node_dict as read-only — confirm.
    node_dict = dict(node_tup)

    # The node name is unused here; "_" also avoids shadowing the `name` arg.
    for _, node in node_tup:
        if node.op_type == "Placeholder":
            onnx_graph.add_input_proto(node)
        elif node.op_type == "Const":
            onnx_graph.add_const(node)
            onnx_graph.add_const_proto(node)
            onnx_graph.add_input_proto(node)
        else:
            onnx_graph.add_value_info_proto(node)
            handler = handlers.get(node.domain, {}).get(node.op_type, None)
            node_proto = None
            if handler:
                node_proto = handler.handle(
                    node,
                    consts=onnx_graph.consts,
                    node_dict=node_dict,
                    data_type_cast_map=onnx_graph.data_type_cast_map)
            else:
                # Only report the domain when it has no handlers at all.
                exception.OP_UNIMPLEMENTED_EXCEPT(
                    node.op_type,
                    domain=None if node.domain in handlers else node.domain)

            if node_proto is None:
                # Fall back to a pass-through node that skips ONNX checks.
                node_proto = FrontendHandler.make_node_from_tf_node(
                    node, op_type=node.op_type, should_check=False)
            onnx_graph.add_node_proto(node_proto)

    for o in output:
        output_node = TensorflowNode(o)
        onnx_graph.add_output_proto(output_node)

    return onnx_graph.make_graph_proto()
def tensorflow_graph_to_onnx_graph(cls,
                                   tf_graph,
                                   opset=((defs.ONNX_DOMAIN,
                                           defs.onnx_opset_version()), ),
                                   ignore_unimplemented=False):
    """Converts a TensorflowGraph to an ONNX graph.

    This function converts a TensorflowGraph to an equivalent
    representation of ONNX graph.

    Node handling, in ``tf_graph.nodes`` order:

    * ``Placeholder`` nodes become graph inputs.
    * ``Const`` nodes are registered as constants, emitted as constant
      protos, and also exposed as graph inputs.
    * Training-only ops (currently ``RandomShuffleQueueV2``) are dropped and
      logged.
    * ``QueueDequeueManyV2`` nodes are replaced by one explicit graph input
      per output, named ``"<node name>:<index>"``. The shape comes from the
      node's ``_output_shapes`` attr and the dtype from ``component_types``.
    * Every other node gets value info. It is then converted by the
      registered frontend handler for its (domain, op_type). Without a
      handler, ``exception.OP_UNIMPLEMENTED_EXCEPT`` is invoked. If no proto
      results, the node is copied verbatim via
      ``FrontendHandler.make_node_from_tf_node`` with checking disabled.

    Side effect: sets the module-level ``exception.IGNORE_UNIMPLEMENTED``
    flag to ``ignore_unimplemented``. It is not restored afterwards.

    :param tf_graph: TensorflowGraph object.
    :param opset: Opset, which should be ((str domain: int version number),).
      The domain name "ai.onnx" is normalized to ``defs.ONNX_DOMAIN``.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the
      operators that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.

    :returns: The equivalent ONNX Graph Proto object.
    """
    onnx_graph = OnnxGraph(tf_graph.graph_name)
    exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented
    # Op types that are dropped from the exported graph (set for O(1) lookup).
    training_ops_to_remove = {"RandomShuffleQueueV2"}

    opset_dict = {}
    for domain, version in opset:
        # Accept "ai.onnx" as another spelling of the default ONNX domain.
        if domain == "ai.onnx":
            domain = defs.ONNX_DOMAIN
        opset_dict[domain] = version

    handlers = get_all_frontend_handlers(opset_dict)

    node_tup = [(n.name, n) for n in tf_graph.nodes]
    # Build the name -> node lookup once instead of once per handled node
    # (previously quadratic in graph size).
    # NOTE(review): assumes handlers treat node_dict as read-only — confirm.
    node_dict = dict(node_tup)

    for _, node in node_tup:
        if node.op_type == "Placeholder":
            onnx_graph.add_input_proto(node)
        elif node.op_type == "Const":
            onnx_graph.add_const(node)
            onnx_graph.add_const_proto(node)
            onnx_graph.add_input_proto(node)
        elif node.op_type in training_ops_to_remove:
            logger.info(
                "A training op with name {} type {} has been removed.".format(
                    node.name, node.op_type))
        elif node.op_type == "QueueDequeueManyV2":
            # Expose each dequeued component as its own graph input.
            for index, (shape, onnx_type) in enumerate(
                    zip(node.attr["_output_shapes"],
                        node.attr["component_types"])):
                onnx_graph.add_input_proto_explicit(
                    node.name + ":" + str(index), shape, onnx_dtype=onnx_type)
        else:
            onnx_graph.add_value_info_proto(node)
            handler = handlers.get(node.domain, {}).get(node.op_type, None)
            node_proto = None
            if handler:
                node_proto = handler.handle(
                    node,
                    consts=onnx_graph.consts,
                    node_dict=node_dict,
                    data_type_cast_map=onnx_graph.data_type_cast_map)
            else:
                # Only report the domain when it has no handlers at all.
                exception.OP_UNIMPLEMENTED_EXCEPT(
                    node.op_type,
                    domain=None if node.domain in handlers else node.domain)

            if node_proto is None:
                # Fall back to a pass-through node that skips ONNX checks.
                node_proto = FrontendHandler.make_node_from_tf_node(
                    node, op_type=node.op_type, should_check=False)
            onnx_graph.add_node_proto(node_proto)

    for o in tf_graph.outputs:
        output_node = tf_graph.get_node_by_name(o)
        onnx_graph.add_output_proto(output_node)

    return onnx_graph.make_graph_proto()