def from_function(func, input_names, output_names, large_model=False):
    """Freeze a concrete TF function and run tf_optimize on the result.

    Args:
        func: the function to freeze (handed to the convert_variables_to_constants_* helpers).
        input_names: graph input tensor names; filtered by ``inputs_without_resource``
            before being passed to ``tf_optimize``.
        output_names: graph output tensor names, passed to ``tf_optimize``.
        large_model: when True, return the result of
            ``convert_variables_to_constants_large_model(func)`` directly, skipping
            the import/optimize step.

    Returns:
        The optimized graph definition produced by ``tf_optimize`` (or, when
        ``large_model`` is True, whatever ``convert_variables_to_constants_large_model``
        returns).

    Raises:
        ValueError: propagated from ``convert_variables_to_constants_v2`` unless the
            message contains "incompatible with expected resource", in which case
            the large-model freezing path plus ``fix_freezing_errors`` is used instead.
    """
    if large_model:
        return convert_variables_to_constants_large_model(func)

    try:
        if get_tf_version() < LooseVersion("2.2"):
            frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
        else:
            # aggressive_inlining is only requested on TF >= 2.2 (see the version gate).
            frozen_func = convert_variables_to_constants_v2(
                func, lower_control_flow=False, aggressive_inlining=True)
    except ValueError as e:
        if "incompatible with expected resource" not in str(e):
            # Not the freezing error we know how to repair: re-raise unchanged.
            raise
        logger.warning("TF freezing failed. Attempting to fix freezing errors.")
        large_model_graph = convert_variables_to_constants_large_model(func)
        graph_def = fix_freezing_errors(large_model_graph)
    else:
        graph_def = frozen_func.graph.as_graph_def(add_shapes=True)

    # Import into a private graph so the caller's default graph is left untouched.
    with tf.Graph().as_default() as tf_graph:
        with tf_session(graph=tf_graph) as sess:
            tf.import_graph_def(graph_def, name='')
            input_names = inputs_without_resource(sess, input_names)
            graph_def = tf_optimize(input_names, output_names, graph_def)
    return graph_def
def __init__(self):
    """Capture host facts and read test settings from TF2ONNX_TEST_* environment variables."""
    env = os.environ

    # Facts about the running environment.
    self.platform = sys.platform
    self.tf_version = tf_utils.get_tf_version()

    # Settings that the environment may override.
    self.opset = int(env.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
    default_target = ",".join(constants.DEFAULT_TARGET)
    self.target = env.get("TF2ONNX_TEST_TARGET", default_target).split(',')
    self.backend = env.get("TF2ONNX_TEST_BACKEND", "onnxruntime")

    # Resolved after self.backend is set, in case the lookup depends on it.
    self.backend_version = self._get_backend_version()
    self.log_level = logging.WARNING
    self.temp_dir = utils.get_temp_directory()
def tf_reload_graph(tf_graph):
    """Round-trip ``tf_graph`` through a GraphDef and return the freshly imported graph.

    Re-importing the serialized graph lets TensorFlow's C++ shape inference run
    on every node. On TF < 1.8 this only emits a debug hint, because graph
    construction there goes through the Python API.
    """
    if get_tf_version() < LooseVersion("1.8"):
        logger.debug("On TF < 1.8, graph is constructed by python API, "
                     "which doesn't invoke shape inference, please set "
                     "TF_C_API_GRAPH_CONSTRUCTION=1 to enable it")

    serialized = tf_graph.as_graph_def(add_shapes=True)
    reloaded = tf.Graph()
    with reloaded.as_default():
        tf.import_graph_def(serialized, name="")
    return reloaded
def __init__(self):
    """Capture host facts and read test settings and skip switches from the environment."""
    env = os.environ

    def _flag(var_name):
        # Opt-in switch: only the value "TRUE" (any letter case) turns it on.
        return env.get(var_name, "FALSE").upper() == "TRUE"

    # Facts about the running environment.
    self.platform = sys.platform
    self.tf_version = tf_utils.get_tf_version()

    # Settings that the environment may override.
    self.opset = int(env.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
    default_target = ",".join(constants.DEFAULT_TARGET)
    self.target = env.get("TF2ONNX_TEST_TARGET", default_target).split(',')
    self.backend = env.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
    self.skip_tflite_tests = _flag("TF2ONNX_SKIP_TFLITE_TESTS")
    self.skip_tf_tests = _flag("TF2ONNX_SKIP_TF_TESTS")
    self.run_tfl_consistency_test = _flag("TF2ONNX_RUN_TFL_CONSISTENCY_TEST")

    # Resolved after the settings above are assigned, in case it reads them.
    self.backend_version = self._get_backend_version()
    self.log_level = logging.WARNING
    self.temp_dir = utils.get_temp_directory()
def from_function(func, input_names, output_names, large_model=False):
    """Freeze a concrete TF function and run tf_optimize on the result.

    Args:
        func: the function to freeze (handed to the convert_variables_to_constants_* helpers).
        input_names: graph input tensor names; filtered by ``inputs_without_resource``
            before being passed to ``tf_optimize``.
        output_names: graph output tensor names, passed to ``tf_optimize``.
        large_model: when True, return the result of
            ``convert_variables_to_constants_large_model(func)`` directly, skipping
            the import/optimize step.

    Returns:
        The optimized graph definition produced by ``tf_optimize`` (or, when
        ``large_model`` is True, whatever ``convert_variables_to_constants_large_model``
        returns).
    """
    if large_model:
        return convert_variables_to_constants_large_model(func)

    if get_tf_version() < LooseVersion("2.2"):
        frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
    else:
        # aggressive_inlining is only requested on TF >= 2.2 (see the version gate).
        frozen_func = convert_variables_to_constants_v2(
            func, lower_control_flow=False, aggressive_inlining=True)
    graph_def = frozen_func.graph.as_graph_def(add_shapes=True)

    # Import into a private graph instead of resetting and reusing the process-wide
    # default graph, so callers' default graph is not clobbered.
    with tf.Graph().as_default() as tf_graph:
        with tf_session(graph=tf_graph) as sess:
            tf.import_graph_def(graph_def, name='')
            input_names = inputs_without_resource(sess, input_names)
            graph_def = tf_optimize(input_names, output_names, graph_def)
    return graph_def
def infer_shape(tf_graph, shape_override):
    """Run shape inference on ``tf_graph``, applying ``shape_override`` first.

    Args:
        tf_graph: the TF graph to analyse.
        shape_override: optional mapping of tensor name -> shape. Each shape is
            set on its tensor, and the graph is then reloaded once.

    Returns:
        The graph returned by ``infer_shape_for_graph``. If
        ``check_shape_for_tf_graph`` still reports outputs without shapes, the
        graph returned by ``infer_shape_for_graph_legacy`` is returned instead.
    """
    if shape_override:
        logger.info("Apply shape override:")
        for tensor_name, new_shape in shape_override.items():
            logger.info("\tSet %s shape to %s", tensor_name, new_shape)
            tf_graph.get_tensor_by_name(tensor_name).set_shape(new_shape)
        tf_graph = tf_reload_graph(tf_graph)

    tf_graph = infer_shape_for_graph(tf_graph)

    unresolved = check_shape_for_tf_graph(tf_graph)
    if not unresolved:
        return tf_graph

    # Only report the unresolved outputs on TF > 1.5.0; the legacy pass runs regardless.
    if get_tf_version() > LooseVersion("1.5.0"):
        for op_name, outputs in unresolved.items():
            logger.warning("Cannot infer shape for %s: %s", op_name, ",".join(outputs))
    return infer_shape_for_graph_legacy(tf_graph)
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None,
                     extra_opset=None, shape_override=None, inputs_as_nchw=None,
                     input_names=None, output_names=None, is_subgraph=False):
    """Convert tensorflow graph to onnx graph.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue
        verbose: print summary stats (deprecated)
        target: list of workarounds applied to help certain platforms
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers in the legacy
            ``{op_type: (func, args)}`` form; each is wrapped by a compat handler
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        input_names: list of input node names in graph, input name format as node_name:port_id
        output_names: list of output node names in graph, output name format as node_name:port_id
        is_subgraph: True for the recursive call on a TF function body; skips the
            version logging and function resolution done for the top-level graph
    Return:
        onnx graph
    Raises:
        ValueError: if a name in input_names/output_names is not a tensor of the graph.
        The first op-mapping exception, when mapping fails and continue_on_error is False.
    """
    if verbose:
        logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(),
                    tf2onnx.__version__, tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning("Currently installed onnx package %s is too low to support opset %s, "
                           "please upgrade onnx package to avoid potential conversion issue.",
                           utils.get_onnx_version(), opset)

    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = tensorflow_to_onnx(tf_graph, shape_override)

    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                  custom_op_handlers, custom_rewriter,
                                  extra_opset, shape_override, inputs_as_nchw,
                                  f_inputs_names, f_output_names, is_subgraph=True)
            fg.graph_name = func.name
            fg.func_inputs = f_inputs_names
            set_function(func.name, fg)

    io_to_check = []
    if input_names:
        io_to_check.extend(input_names)
    if output_names:
        io_to_check.extend(output_names)

    if io_to_check:
        # check output existence in case user passed in wrong output ids
        non_exists = set(io_to_check) - set(output_shapes.keys())
        if non_exists:
            logger.error("\nFailed to convert: inputs/outputs specified do not exist, make sure you passed "
                         "them in format: input/output_node_name:port_id. "
                         "Problematic inputs/outputs are: %s \n", non_exists)
            raise ValueError("Inputs/Outputs Not Found")

    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, output_names, is_subgraph=is_subgraph)

    # create ops mapping for the desired opsets
    ops_mapping = handler.tf_op.create_mapping(g.opset, g.extra_opset)

    # apply custom ops on top of the assembled opset. We can either complement the opset
    # or override existing ops with a custom op.
    if custom_op_handlers is not None:
        # below is a bit tricky since there are a few api's:
        # 1. the future way we want custom ops to be registered with the @tf_op decorator. Those handlers will be
        #     registered via the decorator on load of the module ... nothing is required here.
        # 2. the old custom op api: a dictionary of {name: (func, args[])
        #     We deal with this by using a compat_handler that wraps to old handler with a new style handler.
        #     This is temporary to give people time to move to the new api and after tf2onnx-1.5 we want to remove this
        # FIXME: remove this after tf2onnx-1.5

        def compat_handler(ctx, node, **kwargs):
            # Stateless adapter (reads only its kwargs), so one definition serves every op:
            # call the legacy handler as func(ctx, node, name, args).
            name = node.name
            args = kwargs["args"]
            func = kwargs["func"]
            return func(ctx, node, name, args)

        custom_opset = {}
        for k, v in custom_op_handlers.items():
            args = v[1]
            kwargs = {"func": v[0]}
            if args:
                # legacy convention: the first arg is the onnx op name
                kwargs["onnx_op"] = args[0]
                args = args[1:]
            kwargs["args"] = args
            new_handler = handler.tf_op(k,
                                        domain=constants.TENSORFLOW_OPSET.domain,
                                        kwargs=kwargs)
            new_handler.register_compat_handler(compat_handler, 1)
            custom_opset[k] = (compat_handler, kwargs)
        ops_mapping.update(custom_opset)

    if inputs_as_nchw:
        transpose_inputs(g, inputs_as_nchw)

    # pre-processing graph rewrites
    # bi-directional re-writer should be placed after single directional re-writer
    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_gemm,
                 rewrite_random_uniform, rewrite_random_uniform_fold_const,
                 rewrite_random_normal, rewrite_dropout, rewrite_eye,
                 rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
                 rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
                 rewrite_single_direction_gru, rewrite_bi_direction_gru,
                 rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
                 rewrite_biasadd_with_conv2d,
                 ]

    if custom_rewriter is not None:
        rewriters.extend(custom_rewriter)

    run_rewriters(g, rewriters, continue_on_error)

    # some nodes may already copied into inner Graph, so remove them from main Graph.
    g.delete_unused_nodes(output_names)
    topological_sort(g, continue_on_error)

    mapped_op, unmapped_op, exceptions = tensorflow_onnx_mapping(g, ops_mapping)
    if unmapped_op:
        logger.error("Unsupported ops: %s", unmapped_op)
    if exceptions and not continue_on_error:
        raise exceptions[0]

    # post-processing rewriters
    late_rewriters = []
    if constants.TARGET_RS5 in target:
        late_rewriters.append(rewrite_incomplete_type_support_rs5)
    if constants.TARGET_RS6 in target:
        late_rewriters.append(rewrite_incomplete_type_support_rs6)
    if late_rewriters:
        run_rewriters(g, late_rewriters, continue_on_error)

    # onnx requires topological sorting
    topological_sort(g, continue_on_error)

    g.update_proto()

    logger.verbose(
        "Summary Stats:\n"
        "\ttensorflow ops: {}\n"
        "\ttensorflow attr: {}\n"
        "\tonnx mapped: {}\n"
        "\tonnx unmapped: {}".format(op_cnt, attr_cnt, mapped_op, unmapped_op))

    return g
def main():
    """Run the tests described in the YAML config and report the outcome.

    Returns:
        0 when only listing tests (``args.list``); otherwise the number of executed
        tests that raised or returned a falsy result.
    """
    global PERFITER
    args = get_args()
    logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
    if args.debug:
        utils.set_debug_mode(True)

    Test.cache_dir = args.cache
    Test.target = args.target
    tests = load_tests_from_yaml(args.config)
    if args.list:
        logger.info(sorted(tests.keys()))
        return 0

    test_keys = args.tests.split(",") if args.tests else list(tests.keys())

    failed = count = 0
    PERFITER = args.perfiter
    separator = "==================================="
    for test_name in test_keys:
        logger.info(separator)
        case = tests[test_name]

        # Skip rules apply only when the user did not name tests explicitly.
        if args.tests is None:
            if case.disabled and not args.include_disabled:
                logger.info("Skip %s: disabled", test_name)
                continue
            supported, reason = case.check_opset_constraints(args.opset, args.extra_opset)
            if not supported:
                logger.info("Skip %s: %s", test_name, reason)
                continue
            if case.tf_min_version and tf_utils.get_tf_version() < LooseVersion(str(case.tf_min_version)):
                logger.info("Skip %s: %s %s", test_name, "Min TF version needed:", case.tf_min_version)
                continue

        count += 1
        try:
            logger.info("Running %s", test_name)
            ret = case.run_test(test_name, backend=args.backend, onnx_file=args.onnx_file,
                                opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
                                fold_const=args.fold_const)
        except Exception:
            logger.error("Failed to run %s", test_name, exc_info=True)
            ret = None
        finally:
            # Keep intermediate artifacts around only in debug mode.
            if not utils.is_debug_mode():
                utils.delete_directory(TEMP_DIR)
        if not ret:
            failed += 1

    logger.info(separator)
    logger.info("RESULT: %s failed of %s, backend=%s", failed, count, args.backend)

    if args.perf:
        with open(args.perf, "w") as f:
            f.write("test,tensorflow,onnx\n")
            for test_name in test_keys:
                case = tests[test_name]
                if not case.perf:
                    continue
                # Report perf in ms per inference
                tf_ms = case.tf_runtime * 1000 / PERFITER
                onnx_ms = case.onnx_runtime * 1000 / PERFITER
                f.write("{},{},{}\n".format(test_name, tf_ms, onnx_ms))
    return failed
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None,
                     extra_opset=None, shape_override=None, inputs_as_nchw=None,
                     input_names=None, output_names=None, ignore_default=None, use_default=None,
                     is_subgraph=False, const_node_values=None, tensors_to_rename=None,
                     initialized_tables=None, tflite_path=None, dequantize=False):
    """Convert tensorflow graph to onnx graph.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue
        verbose: print summary stats (deprecated)
        target: list of workarounds applied to help certain platforms
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        input_names: list of input node names in graph, input name format as node_name:port_id. Optional.
        output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite.
        ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
        use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
        is_subgraph: True for the recursive call on a TF function body; skips version logging,
            function resolution and tensor renaming
        const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
        tensors_to_rename: an optional dict (string->string) mapping tensor names to new names
        initialized_tables: mapping from table shared_names to tuple of keys and values of table
        tflite_path: Path to a tflite file to convert. If used, pass None to tf_graph
        dequantize: only used on the tflite path, where it is forwarded to
            process_parsed_graph (presumably requests removal of quantization; confirm there)
    Return:
        onnx graph
    Raises:
        ValueError: if a name in input_names/output_names is not a tensor of the graph.
    """
    # NOTE: process_parsed_graph and Graph are always given tensors post-rename.
    # process_tf_graph (this function) gets tensors pre-rename.
    if verbose:
        logger.warning(
            "Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead."
        )
    del verbose

    opset = utils.find_opset(opset)
    if not is_subgraph:
        logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                    get_tf_version(), utils.get_onnx_version(),
                    tf2onnx.__version__, tf2onnx.version.git_version[:6])
        logger.info("Using opset <onnx, %s>", opset)
        if opset > schemas.get_max_supported_opset_version():
            logger.warning(
                "Currently installed onnx package %s is too low to support opset %s, "
                "please upgrade onnx package to avoid potential conversion issue.",
                utils.get_onnx_version(), opset)

    if shape_override is None:
        shape_override = {}
    if inputs_as_nchw is None:
        inputs_as_nchw = []
    if target is None:
        target = constants.DEFAULT_TARGET

    def check_io(input_names, output_names, output_shapes):
        # Raise if any requested input/output name is not a known tensor.
        io_to_check = []
        if input_names:
            io_to_check.extend(input_names)
        if output_names:
            io_to_check.extend(output_names)
        if io_to_check:
            # check output existence in case user passed in wrong output ids
            non_exists = set(io_to_check) - set(output_shapes.keys())
            if non_exists:
                logger.error(
                    "\nFailed to convert: inputs/outputs specified do not exist, make sure you passed "
                    "them in format: input/output_node_name:port_id. "
                    "Problematic inputs/outputs are: %s \n", non_exists)
                raise ValueError("Inputs/Outputs Not Found")

    def rename_tensors_in_dict(d):
        # Return a copy of d with keys renamed via tensors_to_rename (d itself if no mapping).
        if tensors_to_rename is None:
            return d
        return {tensors_to_rename.get(k, k): v for k, v in d.items()}

    def rename_tensors_in_list(tensors):
        # Return a renamed copy of tensors; None and "no mapping" pass through unchanged.
        if tensors_to_rename is None or tensors is None:
            return tensors
        return [tensors_to_rename.get(t, t) for t in tensors]

    def rename_tensors_in_nodes(onnx_nodes):
        # Rename node inputs/outputs in place.
        if tensors_to_rename is None:
            return
        for n in onnx_nodes:
            n.input[:] = rename_tensors_in_list(n.input)
            n.output[:] = rename_tensors_in_list(n.output)

    if tflite_path is not None:
        tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
        main_g = None
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
        for i, tfl_graph in enumerate(tflite_graphs):
            # The last graph in the list is treated as the main graph; the others
            # are prefixed with their own name and registered as functions.
            is_main_g = i == len(tflite_graphs) - 1
            prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
            tensor_shapes_from_interpreter = None
            if is_main_g:
                tensor_shapes_from_interpreter = tensor_shapes
            onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
                parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
            g_inputs = f_inputs
            g_outputs = f_outputs
            if is_main_g:
                # Override IO in main graph
                check_io(input_names, output_names, output_shapes)
                if input_names is not None:
                    g_inputs = input_names
                if output_names is not None:
                    g_outputs = output_names
            rename_tensors_in_nodes(onnx_nodes)
            g_inputs = rename_tensors_in_list(g_inputs)
            g_outputs = rename_tensors_in_list(g_outputs)
            output_shapes = rename_tensors_in_dict(output_shapes)
            dtypes = rename_tensors_in_dict(dtypes)
            # NOTE(review): non-main tflite graphs also receive the outer is_subgraph value
            # here (normally False) - confirm this is intended for function graphs.
            g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs, is_subgraph)
            fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
                                      target, g_outputs, {}, {}, {}, op_cnt, attr_cnt,
                                      is_tflite=True, dequantize=dequantize)
            fg.graph_name = graph_name
            if is_main_g:
                main_g = fg
            else:
                set_function(graph_name, fg)

        return main_g

    is_func = is_function(tf_graph)
    if not is_func:
        tf_graph = infer_shape(tf_graph, shape_override)

    outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
        tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)

    if not is_subgraph:
        # make tf2onnx internal subgraphs from the tensorflow subgraphs
        ordered_func = resolve_functions(tf_graph)
        for func in ordered_func:
            f_inputs_names = [t.name for t in func.inputs]
            f_output_names = [t.name for t in func.outputs]
            fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                  custom_op_handlers, custom_rewriter,
                                  extra_opset, shape_override, inputs_as_nchw,
                                  f_inputs_names, f_output_names, is_subgraph=True,
                                  const_node_values=const_node_values, tensors_to_rename=tensors_to_rename,
                                  initialized_tables=initialized_tables)
            fg.graph_name = func.name
            set_function(func.name, fg)

    check_io(input_names, output_names, output_shapes)

    if not is_subgraph:
        # Renaming happens once, at the top level, after all checks use the original names.
        rename_tensors_in_nodes(onnx_nodes)
        input_names = rename_tensors_in_list(input_names)
        output_names = rename_tensors_in_list(output_names)
        output_shapes = rename_tensors_in_dict(output_shapes)
        dtypes = rename_tensors_in_dict(dtypes)
        inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)

    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names, is_subgraph)
    g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
                             output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)
    return g
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None,
                     extra_opset=None, shape_override=None, inputs_as_nchw=None,
                     input_names=None, output_names=None, ignore_default=None, use_default=None,
                     is_subgraph=False, const_node_values=None, tensors_to_rename=None,
                     initialized_tables=None, tflite_path=None, dequantize=False, tfjs_path=None):
    """Convert tensorflow graph to onnx graph.

    Exactly one source is used: ``tflite_path`` if given, else ``tfjs_path`` if given,
    else ``tf_graph``.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue
        verbose: print summary stats (deprecated)
        target: list of workarounds applied to help certain platforms
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        input_names: list of input node names in graph, input name format as node_name:port_id. Optional.
        output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite.
        ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
        use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
        is_subgraph: accepted for backward compatibility; not used by this implementation
        const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
        tensors_to_rename: an optional dict (string->string) mapping tensor names to new names
        initialized_tables: mapping from table shared_names to tuple of keys and values of table
        tflite_path: Path to a tflite file to convert. If used, pass None to tf_graph
        dequantize: forwarded to process_graphs (presumably requests removal of
            quantization; confirm there)
        tfjs_path: Path to a tfjs model to convert; used only when tflite_path is None.
            If used, pass None to tf_graph
    Return:
        onnx graph
    """
    # NOTE: process_parsed_graph and Graph are always given tensors post-rename.
    # process_tf_graph (this function) gets tensors pre-rename.
    if verbose:
        logger.warning(
            "Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead."
        )
    del verbose

    opset = utils.find_opset(opset)
    logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
                get_tf_version(), utils.get_onnx_version(),
                tf2onnx.__version__, tf2onnx.version.git_version[:6])
    logger.info("Using opset <onnx, %s>", opset)
    if opset > schemas.get_max_supported_opset_version():
        logger.warning(
            "Currently installed onnx package %s is too low to support opset %s, "
            "please upgrade onnx package to avoid potential conversion issue.",
            utils.get_onnx_version(), opset)

    # Drop any functions registered by a previous conversion before building new ones.
    clear_functions()
    if inputs_as_nchw is None:
        inputs_as_nchw = []

    is_tflite = False
    if tflite_path is not None:
        main_g, subgraphs = graphs_from_tflite(tflite_path, input_names, output_names)
        is_tflite = True
    elif tfjs_path is not None:
        main_g, subgraphs = graphs_from_tfjs(tfjs_path, input_names, output_names, shape_override,
                                             ignore_default, use_default)
    else:
        main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override,
                                           const_node_values, ignore_default, use_default)

    for graph in [main_g] + subgraphs:
        graph.set_config(target, opset, extra_opset)

    onnx_graph = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error,
                                custom_rewriter, initialized_tables, tensors_to_rename, is_tflite, dequantize)
    return onnx_graph