def test_parse_tflite_graph(self):
    """Round-trip a small TF function through TFLite and check what parse_tflite_graph reports.

    Builds ``output = 1.1 * matmul(a, b) + 2.3 * c``, converts it to a .tflite file under
    ``self.test_data_directory``, reads it back with read_tflite_model, and checks the op
    counts, attribute counts, graph IO names, and per-input shapes/dtypes returned by
    parse_tflite_graph.
    """
    def func(a, b, c):
        alpha = tf.constant(1.1, dtype=tf.float32)
        beta = tf.constant(2.3, dtype=tf.float32)
        mul1 = tf.multiply(alpha, tf.matmul(a, b))
        mul2 = tf.multiply(beta, c)
        x_ = mul1 + mul2
        return tf.identity(x_, name="output")

    inp_shapes = [[2, 3], [3, 1], [2, 1]]
    inp_dtypes = [tf.float32, tf.float32, tf.float32]
    names = ['a', 'b', 'c']
    names_with_port = ['a:0', 'b:0', 'c:0']
    output_names = ['output']
    output_names_with_port = ['output:0']

    input_tensors = [tf.TensorSpec(shape=s, dtype=d, name=n) for s, d, n in zip(inp_shapes, inp_dtypes, names)]

    concrete_func = tf.function(func, input_signature=tuple(input_tensors))
    concrete_func = concrete_func.get_concrete_function()
    graph_def = from_function(concrete_func, input_names=names_with_port, output_names=output_names_with_port)
    with tf_session() as sess:
        tf.import_graph_def(graph_def, name='')
        sess_inputs = [sess.graph.get_tensor_by_name(k) for k in names_with_port]
        sess_outputs = [sess.graph.get_tensor_by_name(n) for n in output_names_with_port]
        converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, sess_inputs, sess_outputs)

    # Conversion is expensive; run it exactly once and reuse the serialized model.
    tflite_model = converter.convert()
    tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
    os.makedirs(os.path.dirname(tflite_path), exist_ok=True)
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)

    tflite_graphs, opcodes_map, model, tensor_shapes = read_tflite_model(tflite_path)
    self.assertEqual(1, len(tflite_graphs))
    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \
        parse_tflite_graph(tflite_graphs[0], opcodes_map, model, tensor_shapes_override=tensor_shapes)
    self.assertEqual(2, op_cnt['MUL'])
    self.assertEqual(1, op_cnt['ADD'])
    self.assertEqual(1, op_cnt['FULLY_CONNECTED'])

    self.assertEqual(1, attr_cnt['WeightsFormat'])
    self.assertEqual(names, inputs)
    self.assertEqual(output_names, outputs)

    for name, shape, dtype in zip(names, inp_shapes, inp_dtypes):
        self.assertEqual(shape, output_shapes[name])
        self.assertEqual(dtype, dtypes[name])

    # assertGreaterEqual reports both values on failure, unlike assertTrue(a >= b).
    self.assertGreaterEqual(len(onnx_nodes), 4)
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None,
                     extra_opset=None, shape_override=None, inputs_as_nchw=None,
                     input_names=None, output_names=None, ignore_default=None, use_default=None,
                     is_subgraph=False, const_node_values=None, tensors_to_rename=None,
                     initialized_tables=None, tflite_path=None, dequantize=False):
    """Convert tensorflow graph to onnx graph.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue
        verbose: print summary stats (deprecated)
        target: list of workarounds applied to help certain platforms
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        input_names: list of input node names in graph, input name format as node_name:port_id. Optional.
        output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite.
        ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
        use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
        is_subgraph: True when called recursively for a tensorflow function body. Skips the version/opset
            logging, the recursive conversion of nested functions, and tensor renaming on the TF path,
            and is forwarded to Graph.
        const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
        tensors_to_rename: an optional dict (string->string) mapping tensor names to new names
        initialized_tables: mapping from table shared_names to tuple of keys and values of table
        tflite_path: Path to a tflite file to convert. If used, pass None to tf_graph
        dequantize: forwarded to process_parsed_graph on the tflite path only
    Return:
        onnx graph
    Raises:
        ValueError: if any of input_names/output_names is not a known tensor of the graph
    """
    # NOTE: process_parsed_graph and Graph are always given tensors post-rename.
    # process_tf_graph (this function) gets tensors pre-rename.
    if verbose:
        logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__,
                    tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning("Currently installed onnx package %s is too low to support opset %s, "
                           "please upgrade onnx package to avoid potential conversion issue.",
                           utils.get_onnx_version(), opset)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    def check_io(input_names, output_names, output_shapes):
        """Raise ValueError if any requested input/output name is not a key of output_shapes."""
        io_to_check = []
        if input_names:
            io_to_check.extend(input_names)
        if output_names:
            io_to_check.extend(output_names)
        if io_to_check:
            # check output existence in case user passed in wrong output ids
            non_exists = set(io_to_check) - set(output_shapes.keys())
            if non_exists:
                # The two literals are concatenated by the parser, so the trailing space matters.
                logger.error("\nFailed to convert: inputs/outputs specified do not exist, make sure you passed "
                             "in format: input/output_node_name:port_id. Problematic inputs/outputs are: %s \n",
                             non_exists)
                raise ValueError("Inputs/Outputs Not Found")

    def rename_tensors_in_dict(d):
        """Return d with keys mapped through tensors_to_rename (unchanged keys are kept)."""
        if tensors_to_rename is None:
            return d
        return {tensors_to_rename.get(k, k): v for k, v in d.items()}

    def rename_tensors_in_list(tensors):
        """Return tensors mapped through tensors_to_rename; None passes through."""
        if tensors_to_rename is None or tensors is None:
            return tensors
        return [tensors_to_rename.get(t, t) for t in tensors]

    def rename_tensors_in_nodes(onnx_nodes):
        """Rename the inputs and outputs of each node in place."""
        if tensors_to_rename is None:
            return
        for n in onnx_nodes:
            n.input[:] = rename_tensors_in_list(n.input)
            n.output[:] = rename_tensors_in_list(n.output)

    if tflite_path is not None:
        tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
        main_g = None
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
        for i, tfl_graph in enumerate(tflite_graphs):
            # The last graph is treated as the main graph; the others are registered as functions
            # and get their graph name as a tensor-name prefix.
            is_main_g = i == len(tflite_graphs) - 1
            prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
            tensor_shapes_from_interpreter = None
            if is_main_g:
                tensor_shapes_from_interpreter = tensor_shapes
            onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
                parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
            g_inputs = f_inputs
            g_outputs = f_outputs
            if is_main_g:
                # Override IO in main graph
                check_io(input_names, output_names, output_shapes)
                if input_names is not None:
                    g_inputs = input_names
                if output_names is not None:
                    g_outputs = output_names
            rename_tensors_in_nodes(onnx_nodes)
            g_inputs = rename_tensors_in_list(g_inputs)
            g_outputs = rename_tensors_in_list(g_outputs)
            output_shapes = rename_tensors_in_dict(output_shapes)
            dtypes = rename_tensors_in_dict(dtypes)
            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs, is_subgraph)
            fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
                                      target, g_outputs, {}, {}, {}, op_cnt, attr_cnt, is_tflite=True,
                                      dequantize=dequantize)
            fg.graph_name = graph_name
            if is_main_g:
                main_g = fg
            else:
                set_function(graph_name, fg)

        return main_g

    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
        tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            # inputs_as_nchw is still pre-rename here; the recursive call renames on its own terms.
            fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                  custom_op_handlers, custom_rewriter,
                                  extra_opset, shape_override, inputs_as_nchw,
                                  f_inputs_names, f_output_names, is_subgraph=True,
                                  const_node_values=const_node_values, tensors_to_rename=tensors_to_rename,
                                  initialized_tables=initialized_tables)
            fg.graph_name = func.name
            set_function(func.name, fg)

    # Validate against pre-rename names: the user-supplied IO names are pre-rename too.
    check_io(input_names, output_names, output_shapes)

    if not is_subgraph:
        rename_tensors_in_nodes(onnx_nodes)
        input_names = rename_tensors_in_list(input_names)
        output_names = rename_tensors_in_list(output_names)
        output_shapes = rename_tensors_in_dict(output_shapes)
        dtypes = rename_tensors_in_dict(dtypes)
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names, is_subgraph)
    g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
                             output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)
    return g