def from_graph_def(graph_def, name=None, input_names=None, output_names=None, opset=None, custom_ops=None,
                   custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, extra_opset=None,
                   shape_override=None, target=None, large_model=False, tensors_to_rename=None, output_path=None):
    """Returns an ONNX model_proto for a tensorflow graphdef.

    The graphdef is imported into a fresh graph and session on ``/cpu:0``,
    frozen with ``tf_loader.freeze_session``, and optimized with
    ``tf_loader.tf_optimize``. The result is then converted by ``_convert_common``.

    Args:
        graph_def: the graphdef we want to convert
        name: A name for the graph; "unknown" is used when empty/None
        input_names: list of input names (required, must be non-empty)
        output_names: list of output names (required, must be non-empty)
        opset: the opset to be used for the ONNX model, default is the latest
        custom_ops: if a model contains ops not recognized by onnx runtime,
            you can tag these ops with a custom op domain so that the
            runtime can still open the model. Type is a dictionary `{op name: domain}`.
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        target: list of workarounds applied to help certain platforms
        large_model: use the ONNX external tensor storage format
        tensors_to_rename: forwarded unchanged to ``_convert_common``
        output_path: save model to output_path

    Returns:
        An ONNX model_proto and an external_tensor_storage dict.

    Raises:
        ValueError: if ``input_names`` or ``output_names`` is missing or empty.
    """
    if not input_names:
        raise ValueError("input_names needs to be provided")
    if not output_names:
        raise ValueError("output_names needs to be provided")
    if not name:
        name = "unknown"
    initialized_tables = None

    # Previously only ``custom_ops`` reached ``_convert_common`` (under the
    # ``custom_op_handlers`` keyword) and the ``custom_op_handlers`` argument was
    # dropped. When no explicit handlers are given, behave exactly as before.
    # Otherwise merge them over ``custom_ops``, letting explicit handlers win.
    # NOTE(review): ``custom_ops`` values are documented as domain strings;
    # confirm ``_convert_common`` can consume them as handlers.
    if custom_op_handlers is None:
        op_handlers = custom_ops
    else:
        op_handlers = dict(custom_ops or {})
        op_handlers.update(custom_op_handlers)

    # ``custom_rewriter`` used to be ignored. Only forward it when the caller
    # supplied one, so callers that never set it see an unchanged call.
    # NOTE(review): assumes ``_convert_common`` accepts/forwards this keyword
    # (it already receives other graph-processing kwargs) — confirm.
    extra_kwargs = {}
    if custom_rewriter is not None:
        extra_kwargs["custom_rewriter"] = custom_rewriter

    with tf.device("/cpu:0"):
        with tf.Graph().as_default() as tf_graph:
            with tf_loader.tf_session(graph=tf_graph) as sess:
                tf.import_graph_def(graph_def, name='')
                frozen_graph = tf_loader.freeze_session(sess, input_names=input_names, output_names=output_names)
                input_names = tf_loader.inputs_without_resource(sess, input_names)
                # Optimize the *frozen* graph. Passing the raw ``graph_def`` here
                # discarded the freeze_session result entirely.
                frozen_graph = tf_loader.tf_optimize(input_names, output_names, frozen_graph)

    model_proto, external_tensor_storage = _convert_common(
        frozen_graph,
        name=name,
        continue_on_error=True,
        target=target,
        opset=opset,
        custom_op_handlers=op_handlers,
        extra_opset=extra_opset,
        shape_override=shape_override,
        input_names=input_names,
        output_names=output_names,
        inputs_as_nchw=inputs_as_nchw,
        large_model=large_model,
        tensors_to_rename=tensors_to_rename,
        initialized_tables=initialized_tables,
        output_path=output_path,
        **extra_kwargs)

    return model_proto, external_tensor_storage
def load(self, model_path: Union[str, Path], **_) -> Model:
    """Build a frozen, optimized TensorFlow graph from an Estimator factory file.

    The file at ``model_path`` is searched via ``load_from_file`` (module name
    ``"model"``) for two factories: the one named ``GET_MODEL_FN_NAME``, which
    builds the estimator, and the one named ``GET_SERVING_INPUT_RECEIVER_FN``,
    which builds the serving-input-receiver function. Each factory receives the
    subset of ``self._model_args`` selected by ``filter_fn_args``.

    The estimator's PREDICT graph is built and its latest checkpoint is restored.
    The session is frozen, and the frozen GraphDef is optimized with
    ``tf_optimize`` in a separate session. The default graph is reset before and
    after each phase.

    Args:
        model_path: Path (``str`` or ``pathlib.Path``) of the factory file.
        **_: Ignored.

    Returns:
        ``Model`` built from the optimized frozen graph, the precision reported
        by ``_infer_model_precision``, and the input/output tensor-spec dicts.

    Raises:
        RuntimeError: If either factory cannot be found in the file.
    """
    model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path

    model_factory = load_from_file(model_path, "model", GET_MODEL_FN_NAME)
    receiver_factory = load_from_file(model_path, "model", GET_SERVING_INPUT_RECEIVER_FN)

    if model_factory is None:
        raise RuntimeError(f"Could not find {GET_MODEL_FN_NAME} in {model_path}")
    if receiver_factory is None:
        raise RuntimeError(f"Could not find {GET_SERVING_INPUT_RECEIVER_FN} in {model_path}")

    factory_kwargs = filter_fn_args(self._model_args, fn=model_factory)
    receiver_kwargs = filter_fn_args(self._model_args, fn=receiver_factory)

    build_config = create_session_config(allow_growth=True)

    # Phase 1: build the PREDICT graph, restore weights and freeze, all inside
    # one fresh default graph and session.
    tf.compat.v1.reset_default_graph()
    with tf.compat.v1.Session(config=build_config) as session:
        estimator = model_factory(**factory_kwargs)
        make_receiver = receiver_factory(**receiver_kwargs)
        receiver = make_receiver()
        spec = estimator.model_fn(
            features=receiver.features,
            labels=None,
            mode=tf.estimator.ModeKeys.PREDICT,
            config=estimator.config,
        )

        receiver_tensors = receiver.receiver_tensors
        prediction_tensors = spec.predictions

        inputs_dict = {}
        for key, tensor in receiver_tensors.items():
            inputs_dict[key] = tensor2tensor_spec(tensor)
        outputs_dict = {}
        for key, tensor in prediction_tensors.items():
            outputs_dict[key] = tensor2tensor_spec(tensor)

        input_names = [spec_.name for spec_ in inputs_dict.values()]
        output_names = [spec_.name for spec_ in outputs_dict.values()]

        # Prefer the scaffold's saver; fall back to a sharded default saver.
        saver = spec.scaffold.saver
        if not saver:
            saver = tf.compat.v1.train.Saver(sharded=True)
        saver.restore(session, estimator.latest_checkpoint())

        input_names = inputs_without_resource(session, input_names)
        frozen_graph = freeze_session(session, input_names=input_names, output_names=output_names)
        input_names = remove_redundant_inputs(frozen_graph, input_names)
        # NOTE(review): inputs_dict still lists every receiver tensor even though
        # input_names may have been filtered above — confirm Model expects that.

    # Phase 2: optimize the frozen GraphDef in its own clean graph and session,
    # then reset the default graph again.
    tf.compat.v1.reset_default_graph()
    with tf.compat.v1.Session(config=estimator.config.session_config):
        frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
    tf.compat.v1.reset_default_graph()

    precision = _infer_model_precision(frozen_graph, inputs_dict, outputs_dict)
    return Model(frozen_graph, precision, inputs_dict, outputs_dict)