def test_remove_input(self):
    """Dropping the second input of n4 must leave only the n2 -> n4 edge feeding it."""
    graph = Graph(self.sample_net().node, output_shapes={}, dtypes={})
    add_node = graph.get_node_by_name("n4")
    graph.remove_input(add_node, add_node.input[1])

    # Node declarations first, then the edges; note there is no "n3:0 -> n4" edge.
    expected = (
        'digraph { n1 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] '
        'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] '
        'input -> n1 n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
    )
    self.assertEqual(expected, onnx_to_graphviz(graph))
def graphs_from_tf(tf_graph, input_names, output_names, shape_override=None,
                   const_node_values=None, ignore_default=None, use_default=None):
    """Build the tf2onnx main Graph plus one sub-Graph per tensorflow function.

    Returns:
        tuple ``(main_graph, subgraphs)`` where ``subgraphs`` follows the order
        produced by ``resolve_functions(tf_graph)``.
    """
    shape_override = {} if shape_override is None else shape_override

    subgraphs = []
    for tf_func in resolve_functions(tf_graph):
        func_inputs = [tensor.name for tensor in func_tensor_list(tf_func.inputs)] \
            if False else [tensor.name for tensor in tf_func.inputs]
        func_outputs = [tensor.name for tensor in tf_func.outputs]

        # Constant folding for the function body still receives the caller's output_names.
        folded_values, _ = compute_const_folding_using_tf(
            tf_func, const_node_values, output_names)
        func_nodes, _, _, func_shapes, func_dtypes, _ = tensorflow_to_onnx(
            tf_func, shape_override, const_node_values, ignore_default, use_default)

        sub_graph = Graph(func_nodes, func_shapes, func_dtypes,
                          input_names=func_inputs, output_names=func_outputs,
                          is_subgraph=True, graph_name=tf_func.name)
        fold_constants_using_tf(sub_graph, folded_values)
        subgraphs.append(sub_graph)

    # Shape inference is only run on plain graphs, not on function objects.
    if not is_function(tf_graph):
        tf_graph = infer_shape(tf_graph, shape_override)

    folded_values, _ = compute_const_folding_using_tf(
        tf_graph, const_node_values, output_names)
    main_nodes, _, _, main_shapes, main_dtypes, _ = tensorflow_to_onnx(
        tf_graph, shape_override, const_node_values, ignore_default, use_default)

    utils.check_io(input_names, output_names, main_shapes.keys())
    main_graph = Graph(main_nodes, main_shapes, main_dtypes,
                       input_names=input_names, output_names=output_names)
    fold_constants_using_tf(main_graph, folded_values)
    return main_graph, subgraphs
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None):
    """Convert tensorflow graph to onnx graph.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue
        verbose: print summary stats
        target: list of workarounds applied to help certain platforms
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
    Return:
        onnx graph
    """

    def topological_sort(ops):
        """Sort the nodes of ``g``; sort failures are ignored when continue_on_error is set."""
        if not continue_on_error:
            g.topological_sort(ops)
            return
        try:
            g.topological_sort(ops)
        except Exception:  # pylint: disable=broad-except
            # If we continue on error, ignore graph cycles so we can report all
            # missing ops. Only Exception is caught so that KeyboardInterrupt and
            # SystemExit still propagate.
            pass

    if target is None:
        target = DEFAULT_TARGET

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph)
    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset)
    ops = g.get_nodes()

    # rewrite graph: built-in rewriters first, then any caller supplied ones
    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_random_uniform,
                 rewrite_random_normal, rewrite_dropout]
    if custom_rewriter is not None:
        rewriters.extend(custom_rewriter)
    for rewrite in rewriters:
        ops = rewrite(g, ops)
        g.set_nodes(ops)
    topological_sort(g.get_nodes())

    if custom_op_handlers is None:
        custom_op_handlers = {}
    mapped_op, unmapped_op = tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers)
    # the mapping step may have changed the node list, so sort again
    topological_sort(g.get_nodes())

    g.update_proto()
    if verbose:
        print("tensorflow ops: {}".format(op_cnt))
        print("tensorflow attr: {}".format(attr_cnt))
        print("onnx mapped: {}".format(mapped_op))
        print("onnx unmapped: {}".format(unmapped_op))
    return g
def test_insert_node1(self):
    """Inserting an Abs node on n2's input "n1:0" must splice n7 between n1 and n2."""
    graph = Graph(self.sample_net().node, output_shapes={}, dtypes={})
    target_node = graph.get_node_by_name("n2")
    inserted = graph.insert_new_node_on_input(target_node, "Abs", "n1:0", name="n7")

    # The new node is appended to the node list by hand before re-sorting.
    all_nodes = graph.get_nodes()
    all_nodes.append(inserted)
    graph.topological_sort(all_nodes)

    expected = (
        'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] '
        'n3 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] '
        'input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 '
        'n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
    )
    self.assertEqual(expected, onnx_to_graphviz(graph))
def graphs_from_tflite(tflite_path, input_names=None, output_names=None):
    """Read a tflite model file and turn every graph in it into a graph.py Graph.

    The last graph in the model is treated as the main graph; all others become
    subgraphs whose tensor names are prefixed with the graph name. The main graph's
    inputs/outputs come from the model unless input_names/output_names are given.

    Returns:
        tuple ``(main_graph, subgraphs)``
    """
    tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
    main_index = len(tflite_graphs) - 1

    main_graph = None
    function_graphs = []
    for index, tfl_graph in enumerate(tflite_graphs):
        is_main = index == main_index
        if is_main:
            name_prefix = ''
            interpreter_shapes = tensor_shapes
        else:
            name_prefix = tfl_graph.Name().decode() + '_'
            interpreter_shapes = None

        onnx_nodes, _, _, output_shapes, dtypes, parsed_inputs, parsed_outputs, graph_name = \
            parse_tflite_graph(tfl_graph, opcodes, model, name_prefix, interpreter_shapes)

        graph_inputs, graph_outputs = parsed_inputs, parsed_outputs
        if is_main:
            # Only the main graph accepts caller supplied IO names.
            utils.check_io(input_names, output_names, output_shapes.keys())
            graph_inputs = parsed_inputs if input_names is None else input_names
            graph_outputs = parsed_outputs if output_names is None else output_names

        converted = Graph(onnx_nodes, output_shapes, dtypes,
                          input_names=graph_inputs, output_names=graph_outputs,
                          is_subgraph=not is_main, graph_name=graph_name)
        if is_main:
            main_graph = converted
        else:
            function_graphs.append(converted)

    return main_graph, function_graphs
def load_graph(fname):
    """Load a serialized ONNX model from ``fname`` and wrap its nodes in a Graph.

    Initializers, graph inputs and graph outputs have their dtype and shape
    registered on the Graph. Dimensions without a concrete ``dim_value`` (for
    example symbolic ``dim_param`` dimensions) are recorded as -1.

    Args:
        fname: path of the ONNX model file

    Returns:
        tuple ``(graph, producer_name)`` with producer_name taken from the model proto
    """

    def static_dims(value_info):
        # dim_value is a member of each Dimension message, not of the tensor type,
        # so presence has to be tested per dimension.
        return [dim.dim_value if dim.HasField("dim_value") else -1
                for dim in value_info.type.tensor_type.shape.dim]

    with open(fname, "rb") as f:
        data = f.read()
    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(data)
    onnx_nodes = model_proto.graph.node
    # Handed to Graph while still empty and filled in place further down.
    output_names = []

    # some pytorch model had empty names - make one up
    for node in onnx_nodes:
        if not node.name:
            node.name = tf2onnx.utils.make_name("was_empty")

    g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)

    for initializer in model_proto.graph.initializer:
        value = numpy_helper.to_array(initializer)
        name = initializer.name
        g.initializers[name] = initializer
        g.set_dtype(name, initializer.data_type)
        g.set_shape(name, value.shape)

    for graph_input in model_proto.graph.input:
        name = graph_input.name
        if name in g.initializers:
            # ignore if it is not a model input
            continue
        shape = static_dims(graph_input)
        dtype = graph_input.type.tensor_type.elem_type
        g.set_dtype(name, dtype)
        g.set_shape(name, shape)
        g.add_graph_input(name, dtype, shape)

    for graph_output in model_proto.graph.output:
        name = graph_output.name
        g.set_dtype(name, graph_output.type.tensor_type.elem_type)
        g.set_shape(name, static_dims(graph_output))
        output_names.append(name)

    # TODO: this is a hack in case a output name does not follow tensorflow convention
    for node in g.get_nodes():
        for name in node.output:
            g._nodes_by_name[name] = node  # pylint: disable=protected-access

    return g, model_proto.producer_name
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None, opset=None,
                     custom_op_handlers=None, custom_rewriter=None, extra_opset=None,
                     shape_override=None, inputs_as_nchw=None, input_names=None, output_names=None,
                     is_subgraph=False):
    """Convert tensorflow graph to onnx graph.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue
        verbose: print summary stats (deprecated)
        target: list of workarounds applied to help certain platforms
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        input_names: list of input node names in graph, input name format as node_name:port_id
        output_names: list of output node names in graph, output name format as node_name:port_id
        is_subgraph: set to True by the recursive call made for each tensorflow function;
            suppresses the version logging and the function resolution step
    Return:
        onnx graph
    Raises:
        ValueError: if a name in input_names/output_names is not a known output
    """
    if verbose:
        logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__, tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning("Currently installed onnx package %s is too low to support opset %s, "
                           "please upgrade onnx package to avoid potential conversion issue.",
                           utils.get_onnx_version(), opset)

    # NOTE(review): infer_shape receives shape_override before it is defaulted to {}
    # below, so it may be called with None - confirm infer_shape accepts that.
    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = tensorflow_to_onnx(tf_graph, shape_override)
    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            # recursive call; verbose is forced to False and is_subgraph to True
            fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                  custom_op_handlers, custom_rewriter,
                                  extra_opset, shape_override, inputs_as_nchw,
                                  f_inputs_names, f_output_names, is_subgraph=True)
            fg.graph_name = func.name
            fg.func_inputs = f_inputs_names
            set_function(func.name, fg)

    io_to_check = []
    if input_names:
        io_to_check.extend(input_names)
    if output_names:
        io_to_check.extend(output_names)

    if io_to_check:
        # check output existence in case user passed in wrong output ids
        non_exists = set(io_to_check) - set(output_shapes.keys())
        if non_exists:
            logger.error("\nFailed to convert: inputs/outputs specified do not exist, make sure your passed"
                         "in format: input/output_node_name:port_id. Problematical inputs/outputs are: %s \n",
                         non_exists)
            raise ValueError("Inputs/Outputs Not Found")

    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, output_names, is_subgraph=is_subgraph)

    # create ops mapping for the desired opsets
    ops_mapping = handler.tf_op.create_mapping(g.opset, g.extra_opset)

    # apply custom ops on top of the assembled opset. We can either complement the opset
    # or override existing ops with a custom op.
    if custom_op_handlers is not None:
        # below is a bit tricky since there are a few api's:
        # 1. the future way we want custom ops to be registered with the @tf_op decorator. Those handlers will be
        #     registered via the decorator on load of the module ... nothing is required here.
        # 2. the old custom op api: a dictionary of {name: (func, args[])
        #     We deal with this by using a compat_handler that wraps the old handler with a new style handler.
        #     This is temporary to give people time to move to the new api and after tf2onnx-1.5 we want to remove this
        custom_opset = {}
        for k, v in custom_op_handlers.items():
            # FIXME: remove this after tf2onnx-1.5
            def compat_handler(ctx, node, **kwargs):
                # wrap old handler; func and args arrive through kwargs, not through the closure
                name = node.name
                args = kwargs["args"]
                func = kwargs["func"]
                return func(ctx, node, name, args)

            args = v[1]
            kwargs = {"func": v[0]}
            if args:
                # the first element of args is passed on as the onnx_op, the rest stay as args
                onnx_op = args[0]
                kwargs["onnx_op"] = onnx_op
                args = args[1:]
            kwargs["args"] = args
            new_handler = handler.tf_op(k,
                                        domain=constants.TENSORFLOW_OPSET.domain,
                                        kwargs=kwargs)
            new_handler.register_compat_handler(compat_handler, 1)
            custom_opset[k] = (compat_handler, kwargs)
        ops_mapping.update(custom_opset)

    if inputs_as_nchw:
        transpose_inputs(g, inputs_as_nchw)

    # pre-processing graph rewrites
    # bi-directional re-writer should be placed after single directional re-writer
    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_gemm,
                 rewrite_random_uniform, rewrite_random_uniform_fold_const,
                 rewrite_random_normal, rewrite_dropout, rewrite_eye,
                 rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
                 rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
                 rewrite_single_direction_gru, rewrite_bi_direction_gru,
                 rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
                 rewrite_biasadd_with_conv2d,
                 ]

    if custom_rewriter is not None:
        rewriters.extend(custom_rewriter)

    run_rewriters(g, rewriters, continue_on_error)

    # some nodes may already copied into inner Graph, so remove them from main Graph.
    g.delete_unused_nodes(output_names)
    topological_sort(g, continue_on_error)

    mapped_op, unmapped_op, exceptions = tensorflow_onnx_mapping(g, ops_mapping)
    if unmapped_op:
        logger.error("Unsupported ops: %s", unmapped_op)
    if exceptions and not continue_on_error:
        # only the first collected exception is re-raised
        raise exceptions[0]

    # post-processing rewriters
    late_rewriters = []
    if constants.TARGET_RS5 in target:
        late_rewriters.append(rewrite_incomplete_type_support_rs5)
    if constants.TARGET_RS6 in target:
        late_rewriters.append(rewrite_incomplete_type_support_rs6)
    if late_rewriters:
        run_rewriters(g, late_rewriters, continue_on_error)

    # onnx requires topological sorting
    topological_sort(g, continue_on_error)

    g.update_proto()

    logger.verbose(
        "Summay Stats:\n"
        "\ttensorflow ops: {}\n"
        "\ttensorflow attr: {}\n"
        "\tonnx mapped: {}\n"
        "\tonnx unmapped: {}".format(op_cnt, attr_cnt, mapped_op, unmapped_op))

    return g
def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=None, shape_override=None,
                    ignore_default=None, use_default=None):
    """Creates an onnx graph from the provided tfjs nodes

    Args:
        nodes: list of tfjs node dicts; the code reads their 'name', 'op', 'attr' and 'input' keys
        weights: mapping from node name to the array holding a Const node's value
        func: optional tfjs function; when given, dtypes, input shapes, inputs, outputs and
            the graph name are taken from read_tfjs_function(func) and a subgraph is built
        graph_inputs: optional list of input tensor names; defaults to the ':0' output of
            every placeholder node
        graph_outputs: optional list of output tensor names; defaults to the outputs of all
            nodes that no other node consumes
        shape_override: optional dict of tensor name -> shape that takes precedence over
            computed shapes
        ignore_default: node names of PlaceholderWithDefault ops to turn into Placeholder ops
        use_default: node names of PlaceholderWithDefault ops to turn into Identity ops
    Returns:
        Graph built from the converted nodes
    """
    if shape_override is None:
        shape_override = {}
    onnx_nodes = []
    output_shapes = {}
    tf_dtypes = {}
    # node name -> (op type, tf attributes, input dtypes); handed to resolve_output below
    op_info = {}
    graph_name = 'tfjs_model'
    func_name = None

    def update_shapes(new_shapes):
        # accepts a dict or an iterable of (name, shape) pairs; shape_override wins
        if isinstance(new_shapes, dict):
            new_shapes = new_shapes.items()
        for k, v in new_shapes:
            output_shapes[k] = shape_override.get(k, v)

    if func is not None:
        tf_dtypes, fn_input_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(func)
        update_shapes(fn_input_shapes)
        graph_name = func_name
        # every function input becomes a Placeholder node named after its tensor
        for inp in graph_inputs:
            onnx_nodes.append(helper.make_node("Placeholder", [], outputs=[inp], name=inp))

    if graph_inputs is None:
        placeholder_ops = ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
        graph_inputs = [n['name'] + ':0' for n in nodes if n['op'] in placeholder_ops]

    for node in nodes:
        if node['op'] == "NextIteration":
            # NextIteration nodes can violate the topological sort with cyclic dependencies, so we do them first.
            node_name = node['name']
            output_name = node_name + ':0'
            output_shapes[output_name] = None
            tf_dtypes[output_name] = read_tfjs_attr(node['attr']['T'], tf_dtypes=True)
            op_info[node_name] = (node['op'], {'dtype': tf_dtypes[output_name]}, [tf_dtypes[output_name]])

    for node in nodes:
        op_type = node['op']
        node_name = node['name']
        if op_type == "Const":
            np_arr = weights[node_name]
            out_name = node_name + ':0'
            tf_dtype = read_tfjs_attr(node['attr']['dtype'], tf_dtypes=True)
            onnx_dtype = tf_utils.map_tf_dtype(tf_dtype)
            # The dtype of a Const in tfjs can differ from that of the weight used to get its value
            np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
            onnx_tensor = numpy_helper.from_array(np_arr.astype(np_dtype), out_name)
            onnx_node = helper.make_node("Const", [], outputs=[out_name], name=node_name, value=onnx_tensor)
            onnx_nodes.append(onnx_node)
            output_shapes[out_name] = shape_override.get(out_name, list(np_arr.shape))
            tf_dtypes[out_name] = tf_dtype
            op_info[node_name] = (op_type, {'dtype': tf_dtypes[out_name]}, [])
            continue
        tf_attr = {}
        onnx_attr = {}
        fix_string_attr(node)
        node_def = tfjs_node_to_tf_node_def(node)
        for k, v in node.get('attr', {}).items():
            # tf_attr keeps every attribute; onnx_attr skips the ignored ones and renames DstT
            tf_attr[k] = read_tfjs_attr(v, tf_dtypes=True)
            if k in tf_utils.TF_IGNORED_NODE_ATTRS:
                continue
            if k == 'DstT':
                k = 'to'
            onnx_attr[k] = read_tfjs_attr(v)
        if op_type == "FusedDepthwiseConv2dNative":
            # This op isn't in tensorflow but can be converted to a TF op
            op_type = "_FusedDepthwiseConv2dNative"
            err_msg = "explicit_paddings for supported for _FusedDepthwiseConv2dNative"
            # NOTE(review): indexes 'explicit_paddings' unconditionally, so a node
            # without that attribute raises KeyError here - confirm it is always present.
            utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
            del tf_attr['explicit_paddings']
            del onnx_attr['explicit_paddings']
            del node_def.attr['explicit_paddings']
            node_def.op = op_type

        # inputs starting with '^' are dropped (presumably control dependencies - TODO confirm)
        input_names = [inp for inp in node.get('input', []) if not inp.startswith('^')]
        input_names = [resolve_output(inp, op_info, func_name) for inp in input_names]
        inp_dtypes = [tf_dtypes[inp] for inp in input_names]
        inp_shapes = [output_shapes[inp] for inp in input_names]
        # constant value of each input when its producing node has an entry in weights, else None
        inp_consts = [weights.get(inp.split(':')[0]) for inp in input_names]
        out_dtypes = get_output_dtypes(op_type, tf_attr, inp_dtypes)
        out_shapes = get_output_shapes(node_def, inp_dtypes, inp_shapes, inp_consts)
        op_info[node_name] = (op_type, tf_attr, inp_dtypes)

        output_names = [node_name + ":" + str(i) for i in range(len(out_dtypes))]
        tf_dtypes.update(zip(output_names, out_dtypes))
        update_shapes(zip(output_names, out_shapes))

        if op_type == "PlaceholderWithDefault":
            remove = False
            if ignore_default and node_name in ignore_default:
                # keep it as a real graph input and drop the default value input
                op_type = 'Placeholder'
                input_names = []
            elif use_default and node_name in use_default:
                remove = True
            elif node_name.endswith('keras_learning_phase'):
                logger.warning("Removing optional input %s that appears to be a keras learning phase parameter. "
                               "Use --ignore_default to force this into an input.", node_name)
                remove = True
            if remove:
                # pass the default through and take the node off the graph input list
                op_type = 'Identity'
                graph_inputs = [inp for inp in graph_inputs if inp != node_name + ":0"]

        onnx_node = helper.make_node(op_type, input_names, output_names, name=node_name, **onnx_attr)
        onnx_nodes.append(onnx_node)

    dtypes = {k: tf_utils.map_tf_dtype(v) for k, v in tf_dtypes.items()}
    if graph_outputs is None:
        # default outputs: every output of a node that is not consumed by any other node
        output_to_node = {out: node.name for node in onnx_nodes for out in node.output}
        node_to_outputs = {node.name: list(node.output) for node in onnx_nodes}
        used_nodes = set(output_to_node[out] for node in onnx_nodes for out in node.input)
        unused_nodes = [node for node in onnx_nodes if node.name not in used_nodes]
        graph_outputs = [out for node in unused_nodes for out in node_to_outputs[node.name]]

    graph_outputs_mapped = [resolve_output(out, op_info, func_name) for out in graph_outputs]

    g = Graph(onnx_nodes, output_shapes, dtypes, input_names=graph_inputs, output_names=graph_outputs_mapped,
              is_subgraph=func is not None, graph_name=graph_name)
    # rename the resolved output tensors back to the names the caller asked for
    g.rename_tensors(dict(zip(graph_outputs_mapped, graph_outputs)))
    return g
def process_tf_graph(graph, continue_on_error=False, verbose=False, target=None, opset=0):
    """Convert tensorflow graph to onnx graph.

    The graph is rewritten by a fixed list of rewriters, mapped to onnx ops and
    topologically sorted; summary counters are printed when verbose is set.
    """
    effective_target = DEFAULT_TARGET if target is None else target

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(graph)
    g = Graph(onnx_nodes, output_shapes, dtypes, effective_target, opset)

    # rewrites
    rewrite_passes = (
        rewrite_flatten,
        rewrite_random_uniform,
        rewrite_random_normal,
        rewrite_dropout,
        rewrite_transpose,
    )
    current_ops = g.get_nodes()
    for rewrite_pass in rewrite_passes:
        current_ops = rewrite_pass(g, current_ops)
        g.set_nodes(current_ops)
    g.topological_sort(g.get_nodes())

    mapped_op, unmapped_op = tensorflow_onnx_mapping(g, continue_on_error)
    g.topological_sort(g.get_nodes())
    g.update_proto()

    if verbose:
        summary = (
            ("tensorflow ops: {}", op_cnt),
            ("tensorflow attr: {}", attr_cnt),
            ("onnx mapped: {}", mapped_op),
            ("onnx unmapped: {}", unmapped_op),
        )
        for template, value in summary:
            print(template.format(value))
    return g
def rewrite(self):
    """Build the body graph of every Scan node that has recorded body graph info.

    For each such Scan node the nodes of its body are collected from self.g, moved into
    a new Graph, given model inputs and Identity-wrapped outputs, and attached to the
    Scan node as its "body" attribute. The collected nodes are then removed from the
    outer node list.

    Returns:
        the node list of self.g with the body graph nodes removed
    Raises:
        ValueError: if a node to remove is neither in the node list nor an initializer
    """
    log.debug("enter custom rnn late rewriter")
    nodes = self.g.get_nodes()
    nodes_to_remove = []
    for scan_node in nodes:
        if scan_node.type != "Scan":
            continue
        log.debug("late write for scan node %s", scan_node.name)
        num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
        if not BodyGraphDict.has_body_graph_info(scan_node.name):
            continue

        # pop: the body graph info of a scan node is consumed here
        body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
        onnx_nodes, _ = LoopRewriterBase.find_subgraph(body_graph_meta, self.g)
        nodes_to_remove.extend(onnx_nodes)

        log.debug("start creating body graph for scan node %s ", scan_node.name)
        body_graph_initializers = {}
        const_nodes = [n for n in onnx_nodes if n.type in ("Const", "ConstV2")]
        for n in const_nodes:
            # when set nodes, Const should be removed, they need be replaced as initializers.
            body_graph_initializers[n.output[0]] = self.g.initializers[n.output[0]]
            onnx_nodes.remove(n)

        # de-duplicate the remaining nodes before collecting their underlying ops
        onnx_nodes = set(onnx_nodes)

        ops = []
        for op in onnx_nodes:
            onnx_op = op.op
            ops.append(onnx_op)

        # the body graph is created with the outer graph's shape and dtype tables
        body_g = Graph(ops, output_shapes=self.g._output_shapes, dtypes=self.g._dtypes)
        body_g._initializers = body_graph_initializers

        log.debug("start preparing body graph inputs nodes")
        temp_nodes = body_g.get_nodes()
        i = 0
        input_count = len(body_graph_meta.input_ids)
        for input_name, init_input_id in zip(body_graph_meta.input_ids, body_graph_meta.initial_input_ids):
            shape = body_g.get_shape(input_name)
            dtype = body_g.get_dtype(input_name)
            if shape is None:
                # fall back to the shape of the matching initial input in the outer graph
                shape = self.g.get_shape(init_input_id)
                # the last num_scan_inputs inputs are the scan inputs
                if i >= input_count - num_scan_inputs:
                    loop_input_shape = list(shape)[2:]  # delete [1, time,]
                else:
                    loop_input_shape = list(shape)
            else:
                loop_input_shape = list(shape)

            onnx_input_shape = utils.make_onnx_shape(loop_input_shape)
            val = helper.make_tensor_value_info(input_name, dtype, onnx_input_shape)
            body_g.add_model_input(input_name, val)
            i += 1

        log.debug("start preparing body graph outputs nodes")
        new_output_names = []
        for o in body_graph_meta.output_ids:
            # insert identity node, since sometimes we need output same output_id as state_output
            # and scan_out, but ONNX don't allow the same output_id appeared more than once as
            # output node.
            identity_name = utils.make_name("Identity")
            identity_output = utils.port_name(identity_name)
            node = Node(helper.make_node("Identity", [o], [identity_output], name=identity_name), body_g)
            body_g.set_dtype(identity_output, body_g.get_dtype(o))
            body_g.copy_shape(o, identity_output)
            new_output_names.append(identity_output)
            temp_nodes.append(node)

        body_g.set_nodes(temp_nodes)
        body_g.topological_sort(body_g.get_nodes())

        log.debug("start make graph based on body graph nodes")
        body_g.output_names = new_output_names
        graph = body_g.make_graph("scan body graph")
        scan_node.set_attr("body", graph)

    # remove nodes in body graph from g
    for n in set(nodes_to_remove):
        if n in nodes:
            nodes.remove(n)
        elif self.g.is_initializer(n.output[0]):
            del self.g.initializers[n.output[0]]
        else:
            raise ValueError("error when removing nodes")

    return nodes
def _create_empty_graph(self, inputs, shapes, dtypes):
    """Return a node-less Graph using the configured target/opset with the given graph inputs.

    inputs, shapes and dtypes are consumed in lockstep; one graph input is added per triple.
    """
    empty_graph = Graph([], target=self.config.target, opset=self.config.opset)
    for input_spec in zip(inputs, shapes, dtypes):
        input_name, input_shape, input_dtype = input_spec
        empty_graph.add_graph_input(input_name, input_dtype, input_shape)
    return empty_graph
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None,
                     extra_opset=None, shape_override=None, inputs_as_nchw=None,
                     input_names=None, output_names=None, ignore_default=None, use_default=None,
                     is_subgraph=False, const_node_values=None, tensors_to_rename=None,
                     initialized_tables=None, tflite_path=None, dequantize=False):
    """Convert tensorflow graph to onnx graph.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue
        verbose: print summary stats (deprecated)
        target: list of workarounds applied to help certain platforms
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        input_names: list of input node names in graph, input name format as node_name:port_id. Optional.
        output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite.
        ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
        use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
        is_subgraph: set to True by the recursive call made for each tensorflow function;
            suppresses version logging, function resolution and tensor renaming
        const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
        tensors_to_rename: an optional dict (string->string) mapping tensor names to new names
        initialized_tables: mapping from table shared_names to tuple of keys and values of table
        tflite_path: Path to a tflite file to convert. If used, pass None to tf_graph
        dequantize: only forwarded to process_parsed_graph on the tflite path
    Return:
        onnx graph
    Raises:
        ValueError: if a name in input_names/output_names is not a known output
    """
    # NOTE: process_parsed_graph and Graph are always given tensors post-rename.
    # process_tf_graph (this function) gets tensors pre-rename.
    if verbose:
        logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__, tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning("Currently installed onnx package %s is too low to support opset %s, "
                           "please upgrade onnx package to avoid potential conversion issue.",
                           utils.get_onnx_version(), opset)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    def check_io(input_names, output_names, output_shapes):
        # Raise ValueError when a requested input/output name is not a key of output_shapes.
        io_to_check = []
        if input_names:
            io_to_check.extend(input_names)
        if output_names:
            io_to_check.extend(output_names)
        if io_to_check:
            # check output existence in case user passed in wrong output ids
            non_exists = set(io_to_check) - set(output_shapes.keys())
            if non_exists:
                logger.error("\nFailed to convert: inputs/outputs specified do not exist, make sure your passed"
                             "in format: input/output_node_name:port_id. Problematic inputs/outputs are: %s \n",
                             non_exists)
                raise ValueError("Inputs/Outputs Not Found")

    def rename_tensors_in_dict(d):
        # Return a copy of d with keys mapped through tensors_to_rename (d itself if no mapping).
        if tensors_to_rename is None:
            return d
        return {tensors_to_rename.get(k, k): v for k, v in d.items()}

    def rename_tensors_in_list(tensors):
        # Return a renamed copy of the list; None and "no mapping" pass through unchanged.
        if tensors_to_rename is None or tensors is None:
            return tensors
        return [tensors_to_rename.get(t, t) for t in tensors]

    def rename_tensors_in_nodes(onnx_nodes):
        # Rename node inputs/outputs in place via slice assignment.
        if tensors_to_rename is None:
            return
        for n in onnx_nodes:
            n.input[:] = rename_tensors_in_list(n.input)
            n.output[:] = rename_tensors_in_list(n.output)

    if tflite_path is not None:
        # tflite path: tf_graph is not used; every graph in the file is converted and
        # this branch returns on its own.
        tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
        main_g = None
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
        for i, tfl_graph in enumerate(tflite_graphs):
            # the last graph in the model is treated as the main graph
            is_main_g = i == len(tflite_graphs) - 1
            prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
            tensor_shapes_from_interpreter = None
            if is_main_g:
                tensor_shapes_from_interpreter = tensor_shapes
            onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
                parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
            g_inputs = f_inputs
            g_outputs = f_outputs
            if is_main_g:
                # Override IO in main graph
                check_io(input_names, output_names, output_shapes)
                if input_names is not None:
                    g_inputs = input_names
                if output_names is not None:
                    g_outputs = output_names
            rename_tensors_in_nodes(onnx_nodes)
            g_inputs = rename_tensors_in_list(g_inputs)
            g_outputs = rename_tensors_in_list(g_outputs)
            output_shapes = rename_tensors_in_dict(output_shapes)
            dtypes = rename_tensors_in_dict(dtypes)
            # NOTE(review): the caller's is_subgraph is passed here, not "not is_main_g" - confirm intended.
            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs, is_subgraph)
            fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
                                      g_outputs, {}, {}, {}, op_cnt, attr_cnt, is_tflite=True, dequantize=dequantize)
            fg.graph_name = graph_name
            if is_main_g:
                main_g = fg
            else:
                set_function(graph_name, fg)

        return main_g

    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
        tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            # recursive call; ignore_default/use_default are not forwarded to function graphs
            fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                  custom_op_handlers, custom_rewriter,
                                  extra_opset, shape_override, inputs_as_nchw,
                                  f_inputs_names, f_output_names, is_subgraph=True,
                                  const_node_values=const_node_values, tensors_to_rename=tensors_to_rename,
                                  initialized_tables=initialized_tables)
            fg.graph_name = func.name
            set_function(func.name, fg)

    # IO names are validated against the pre-rename output names
    check_io(input_names, output_names, output_shapes)

    if not is_subgraph:
        rename_tensors_in_nodes(onnx_nodes)
        input_names = rename_tensors_in_list(input_names)
        output_names = rename_tensors_in_list(output_names)
        output_shapes = rename_tensors_in_dict(output_shapes)
        dtypes = rename_tensors_in_dict(dtypes)
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names, is_subgraph)
    g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
                             output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)
    return g