# SPDX-License-Identifier: Apache-2.0 """ tf2onnx.graph_helper - class to help building graph, such as helping to make complex node """ import numpy as np from tf2onnxnightly import utils, logging # pylint: disable=missing-docstring logger = logging.getLogger(__name__) class GraphBuilder(object): """help to build graph""" def __init__(self, graph): self._g = graph @property def graph(self): return self._g def make_slice(self, kwargs, name=None, shapes=None, dtypes=None, return_node=False): """ slice changes its schema at opset 10: it treats some attributes as dynamic input so this function has to process inputs according to graph's opset version
def main():
    """Command-line entry point: convert a TensorFlow (or TFLite) model to ONNX.

    Parses the command-line arguments, resolves custom-op / extra-opset settings,
    loads the model from whichever source argument was supplied (graphdef,
    checkpoint, saved_model, keras or tflite), runs the conversion under the
    "/cpu:0" device scope and logs a summary of the produced model.

    Raises:
        ValueError: if renaming of inputs/outputs is requested but the model's
            input/output names are not known for the chosen model source
            (the tflite branch below does not populate them).
    """
    args = get_args()
    logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
    if args.debug:
        utils.set_debug_mode(True)

    logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)

    # Work on a copy: appending the tf opset below must not mutate args.extra_opset.
    extra_opset = list(args.extra_opset or [])
    tflite_path = None
    custom_ops = {}
    initialized_tables = None
    tensors_to_rename = {}
    if args.custom_ops:
        using_tf_opset = False
        # Entries are comma separated, each either "op" or "op:domain".
        for op in args.custom_ops.split(","):
            if ":" in op:
                op, domain = op.split(":")
            else:
                # default custom ops for tensorflow-onnx are in the "tf" namespace
                using_tf_opset = True
                domain = constants.TENSORFLOW_OPSET.domain
            custom_ops[op] = (make_default_custom_op_handler(domain), [])
        if using_tf_opset:
            extra_opset.append(constants.TENSORFLOW_OPSET)

    if any(opset.domain == constants.CONTRIB_OPS_DOMAIN for opset in extra_opset):
        try:
            # Imported only for its side effects (presumably registering its TF ops).
            import tensorflow_text  # pylint: disable=import-outside-toplevel,unused-import
        except ModuleNotFoundError:
            logger.warning("tensorflow_text not installed. Model will fail to load if tensorflow_text ops are used.")

    # get the frozen tensorflow model from graphdef, checkpoint or saved_model.
    # If several sources are given, the last matching branch below wins.
    graph_def = None
    inputs = None
    outputs = None
    model_path = None

    if args.graphdef:
        graph_def, inputs, outputs = tf_loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
        model_path = args.graphdef
    if args.checkpoint:
        graph_def, inputs, outputs = tf_loader.from_checkpoint(args.checkpoint, args.inputs, args.outputs)
        model_path = args.checkpoint
    if args.saved_model:
        graph_def, inputs, outputs, initialized_tables, tensors_to_rename = tf_loader.from_saved_model(
            args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function,
            args.large_model, return_initialized_tables=True, return_tensors_to_rename=True)
        model_path = args.saved_model
    if args.keras:
        graph_def, inputs, outputs = tf_loader.from_keras(
            args.keras, args.inputs, args.outputs)
        model_path = args.keras
    if args.tflite:
        # inputs/outputs stay None here; the tflite path is handed to the converter directly.
        tflite_path = args.tflite
        model_path = tflite_path

    if args.verbose:
        logger.info("inputs: %s", inputs)
        logger.info("outputs: %s", outputs)

    def _rename_pairs(kind, model_names, new_names):
        """Pair the model's ``kind`` tensor names with user-supplied new names.

        Raises ValueError when the model names are unknown, and warns when the
        counts differ (zip pairs positionally and drops the unmatched tail).
        """
        if model_names is None:
            raise ValueError(
                "Cannot rename %s: the model's %s names are not known for this model source" % (kind, kind))
        model_names = list(model_names)
        new_names = list(new_names)
        if len(model_names) != len(new_names):
            logger.warning("Got %d new %s names for %d model %s; only the first %d will be renamed",
                           len(new_names), kind, len(model_names), kind,
                           min(len(model_names), len(new_names)))
        return zip(model_names, new_names)

    if args.rename_inputs:
        tensors_to_rename.update(_rename_pairs("inputs", inputs, args.rename_inputs))
    if args.rename_outputs:
        tensors_to_rename.update(_rename_pairs("outputs", outputs, args.rename_outputs))

    with tf.device("/cpu:0"):
        model_proto, _ = _convert_common(
            graph_def,
            name=model_path,
            continue_on_error=args.continue_on_error,
            target=args.target,
            opset=args.opset,
            custom_op_handlers=custom_ops,
            extra_opset=extra_opset,
            shape_override=args.shape_override,
            input_names=inputs,
            output_names=outputs,
            inputs_as_nchw=args.inputs_as_nchw,
            large_model=args.large_model,
            tensors_to_rename=tensors_to_rename,
            ignore_default=args.ignore_default,
            use_default=args.use_default,
            tflite_path=tflite_path,
            dequantize=args.dequantize,
            initialized_tables=initialized_tables,
            output_frozen_graph=args.output_frozen_graph,
            output_path=args.output)

    # write onnx graph
    logger.info("")
    logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)

    logger.info("Model inputs: %s", [n.name for n in model_proto.graph.input])
    logger.info("Model outputs: %s", [n.name for n in model_proto.graph.output])
    if args.output:
        if args.large_model:
            logger.info("Zipped ONNX model is saved at %s. Unzip before opening in onnxruntime.", args.output)
        else:
            logger.info("ONNX model is saved at %s", args.output)
    else:
        logger.info("To export ONNX model to file, please run with `--output` option")
import tensorflow.contrib.rnn # pylint: disable=unused-import except: # pylint: disable=bare-except # not needed for tf-2.0 pass try: import tensorflow_text # pylint: disable=unused-import except ModuleNotFoundError: pass from tf2onnxnightly import tf_loader, logging, optimizer, utils, tf_utils, constants from tf2onnxnightly.tfonnx import process_tf_graph from tf2onnxnightly.tf_loader import tf_session, tf_reset_default_graph from tf2onnxnightly.graph import ExternalTensorStorage logger = logging.getLogger("run_pretrained") TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained") PERFITER = 1000 def get_img(shape, path, dtype, should_scale=True): """Get image as input.""" resize_to = shape[1:3] path = os.path.join(os.path.dirname(os.path.abspath(__file__)), path) img = PIL.Image.open(path) img = img.resize(resize_to, PIL.Image.ANTIALIAS) img_np = np.array(img).astype(dtype) img_np = np.stack([img_np] * shape[0], axis=0).reshape(shape) if should_scale: img_np = img_np / 255