def _restore_from_v1_saved_model(
    restored_function: function.ConcreteFunction, saved_model_dir: str
) -> Tuple[function.ConcreteFunction, Mapping[str, Any],
           Mapping[str, common_types.TensorType]]:
    """Restores an exported TF1 compat SavedModel.

    Args:
      restored_function: ConcreteFunction that was already loaded for the
        model; its graph is handed to the pyfunc re-registration helper.
      saved_model_dir: Directory of the SavedModel whose MetaGraphDef and
        transform signature are read.

    Returns:
      A `(function, signature_inputs, structured_outputs)` tuple. When the
      pyfunc helper returns no GraphDef, `function` is `restored_function`
      itself; otherwise it is a function re-wrapped from the returned GraphDef
      and pruned to the original output structure.
    """
    parsed_model = saved_model_loader.parse_saved_model(saved_model_dir)
    meta_graph = saved_model_loader.choose_meta_graph_def_and_raise(
        parsed_model)
    transform_signature = meta_graph.signature_def[
        constants.TRANSFORM_SIGNATURE]

    # Re-register pyfuncs, if any.
    source_graph = restored_function.graph
    rewritten_graph_def = pyfunc_helper.register_pyfuncs_from_saved_transform(
        source_graph, meta_graph, loaded_in_tf2=True)

    if rewritten_graph_def is not None:
        # The helper produced a GraphDef: rebuild the function around it,
        # using the same input/output tensor names as the original graph.
        input_names = [tensor.name for tensor in source_graph.inputs]
        output_names = [tensor.name for tensor in source_graph.outputs]
        rewrapped = wrap_function.function_from_graph_def(
            rewritten_graph_def, input_names, output_names)
        # Give the flat outputs the nested structure of the original function.
        packed_outputs = tf.nest.pack_sequence_as(
            restored_function.structured_outputs,
            rewrapped.outputs,
            expand_composites=True)
        rewrapped = rewrapped.prune(rewrapped.inputs, packed_outputs)
        return (rewrapped, transform_signature.inputs,
                rewrapped.structured_outputs)

    # Nothing to rewrap: hand the restored function back untouched.
    return (restored_function, transform_signature.inputs,
            restored_function.structured_outputs)
def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
    """Constructs a concrete function from the `output_graph_def`.

    Args:
      func: ConcreteFunction
      output_graph_def: GraphDef proto.
      converted_input_indices: Set of integers of input indices that were
        converted to constants. The indices refer to positions in
        `func.graph.internal_captures`.

    Returns:
      ConcreteFunction. Its inputs are the inputs of `func` that were not
      converted, in the same relative order as in `func.inputs`.
    """
    # Create a ConcreteFunction from the new GraphDef.
    input_tensors = func.graph.internal_captures

    # Filter by object identity while walking `func.inputs` in order. Building
    # the list through `set(...).difference(...)` would make the input order of
    # the new function depend on arbitrary set iteration order.
    seen_ids = {id(input_tensors[index]) for index in converted_input_indices}
    not_converted_inputs = []
    for tensor in func.inputs:
        if id(tensor) not in seen_ids:
            # Also record kept tensors so duplicates collapse, as a set would.
            seen_ids.add(id(tensor))
            not_converted_inputs.append(tensor)

    not_converted_inputs_map = {
        tensor.name: tensor for tensor in not_converted_inputs
    }

    new_input_names = [tensor.name for tensor in not_converted_inputs]
    new_output_names = [tensor.name for tensor in func.outputs]
    new_func = wrap_function.function_from_graph_def(
        output_graph_def, new_input_names, new_output_names)

    # Manually propagate shape for input tensors where the shape is not
    # correctly propagated. Scalars shapes are lost when wrapping the function.
    for input_tensor in new_func.inputs:
        input_tensor.set_shape(
            not_converted_inputs_map[input_tensor.name].shape)
    return new_func
def __init__(self, model_pb_file, l2norm=True, scale=255.0, target=None,
             size=None, preserve_aspect_ratio=False, mean=None, std=None):
    """Load a serialized graph and store the preprocessing settings.

    Args:
        model_pb_file: tensorflow graph file (read as a binary GraphDef).
        l2norm: whether use l2-norm or not. Default: True
        scale: the scale to divide from raw pixels. Default: 255.0
        target: stored on the instance unchanged; its meaning is not
            visible in this constructor (TODO: document).
        size: size vector for model input.
        preserve_aspect_ratio: preserve the aspect ratio or not.
            Default: False
        mean: mean vector.
        std: std vector.
    """
    super(CustomModule, self).__init__()

    graph_def = tf.compat.v1.GraphDef()
    # Use a context manager so the file handle is closed deterministically.
    with open(model_pb_file, "rb") as pb_file:
        graph_def.ParseFromString(pb_file.read())

    # The graph is expected to expose these fixed tensor names.
    self.model_func = wrap_function.function_from_graph_def(
        graph_def, inputs="input1:0", outputs="output1:0")

    self.l2norm = l2norm
    self.scale = scale
    self.target = target
    self.size = size
    self.preserve_aspect_ratio = preserve_aspect_ratio
    self.mean = mean
    self.std = std
def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format.

    Loads the SavedModel, freezes the selected signature's variables into
    constants, exports a MetaGraphDef for Grappler, runs the conversion and
    stores the re-wrapped result in `self._converted_func`.
    """
    assert context.executing_eagerly()

    self._saved_model = load.load(
        self._input_saved_model_dir, self._input_saved_model_tags)
    signature_func = self._saved_model.signatures[
        self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        signature_func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_tensors = frozen_func.inputs + frozen_func.outputs
    fetch_collection.node_list.value.extend(
        [tensor.name for tensor in fetch_tensors])
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._run_conversion()

    # Re-wrap the converted graph with the frozen function's tensor names.
    input_names = [tensor.name for tensor in frozen_func.inputs]
    output_names = [tensor.name for tensor in frozen_func.outputs]
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def, input_names, output_names)
def __enter__(self):
    """Reset the default graph and wrap the model into a callable function.

    The function is built from `self._model.handle` using the tensor names of
    the model's input and output specs, and its `_signature` is overwritten
    with one `tf.TensorSpec` per model input.

    Returns:
        This object, for use as the `with` target.
    """
    tf.compat.v1.reset_default_graph()

    model = self._model
    wrapped = wrap_function.function_from_graph_def(
        model.handle,
        [spec.name for spec in model.inputs.values()],
        [spec.name for spec in model.outputs.values()],
    )
    self._concrete_func = wrapped

    # Each spec is named after its key in the inputs mapping.
    signature = []
    for name, spec in self._model.inputs.items():
        signature.append(
            tf.TensorSpec(shape=spec.shape, dtype=spec.dtype, name=name))
    self._concrete_func._signature = signature
    return self
def test_function_from_graph_def(self):
    """A function revived from a GraphDef computes the same `x + 1`."""

    @def_function.function
    def make_graph_def(x):
        return x + 1.

    input_spec = tensor_spec.TensorSpec([None, 2], dtypes.float32)
    concrete = make_graph_def.get_concrete_function(input_spec)
    source_graph = concrete.graph

    # Revive the function using the single input and output tensor names.
    revived = wrap_function.function_from_graph_def(
        source_graph.as_graph_def(),
        inputs=source_graph.inputs[0].name,
        outputs=source_graph.outputs[0].name)

    result = revived(constant_op.constant(1.))
    self.assertEqual(2., result.numpy())
def __init__(self, g1, scale1, target1, size1, mean1, std1, p1,
             g2, scale2, target2, size2, mean2, std2, p2, l2norm=True):
    """Load two serialized graphs and store per-model settings.

    Args:
        g1: path to the first serialized GraphDef file.
        scale1, target1, size1, mean1, std1, p1: settings for the first
            model, stored on the instance unchanged.
        g2: path to the second serialized GraphDef file.
        scale2, target2, size2, mean2, std2, p2: settings for the second
            model, stored on the instance unchanged.
        l2norm: stored on the instance unchanged. Default: True
    """
    super(CustomModule, self).__init__()

    def _load_graph_def(path):
        """Parse the binary GraphDef at `path`, closing the file afterwards."""
        graph_def = tf.compat.v1.GraphDef()
        with open(path, "rb") as pb_file:
            graph_def.ParseFromString(pb_file.read())
        return graph_def

    graph_def1 = _load_graph_def(g1)
    graph_def2 = _load_graph_def(g2)

    # Both graphs are expected to expose the same fixed tensor names.
    self.model_func1 = wrap_function.function_from_graph_def(
        graph_def1, inputs="input1:0", outputs="output1:0")
    self.model_func2 = wrap_function.function_from_graph_def(
        graph_def2, inputs="input1:0", outputs="output1:0")

    self.l2norm = l2norm

    self.scale1 = scale1
    self.target1 = target1
    self.size1 = size1
    self.mean1 = mean1
    self.std1 = std1
    self.p1 = p1

    self.scale2 = scale2
    self.target2 = target2
    self.size2 = size2
    self.mean2 = mean2
    self.std2 = std2
    self.p2 = p2
def convert(self):
    """Convert the input SavedModel in 2.0 format.

    Returns:
      The TF-TRT converted Function: a `def_function.function` wrapper that
      forwards its arguments to `self._converted_func`.
    """
    assert not self._converted

    self._saved_model = load.load(
        self._input_saved_model_dir, self._input_saved_model_tags)
    original_func = self._saved_model.signatures[
        self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        original_func)
    grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(
        [tensor.name
         for tensor in frozen_func.inputs + frozen_func.outputs])
    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
    input_names = [tensor.name for tensor in frozen_func.inputs]
    output_names = [tensor.name for tensor in frozen_func.outputs]
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def, input_names, output_names)

    # Reconstruct the output signatures using the ones from original model.
    converted_graph = self._converted_func.graph
    converted_graph.structured_outputs = nest.pack_sequence_as(
        original_func.graph.structured_outputs,
        converted_graph.structured_outputs)

    self._converted = True

    # Wrap the converted ConcreteFunction in a Function so it can accept numpy
    # arrays as input.
    @def_function.function
    def wrapper_func(*args, **kwargs):
        return self._converted_func(*args, **kwargs)

    return wrapper_func
def __init__(self, group, l2norm=True):
    """Load one serialized graph per group entry and collect its settings.

    Args:
        group: iterable of 8-item entries
            `(graph_path, scale, target, size, mean, std, p, w)`. The graph
            at `graph_path` is loaded and wrapped; the remaining items are
            appended unchanged to `scales`, `targets`, `sizes`, `means`,
            `stds`, `ps` and `ws` respectively.
        l2norm: stored on the instance unchanged. Default: True
    """
    super(CustomModule, self).__init__()
    self.l2norm = l2norm

    # Parallel lists: index i across all of them describes the i-th model.
    self.scales = []
    self.targets = []
    self.sizes = []
    self.means = []
    self.stds = []
    self.ps = []
    self.ws = []
    self.model_funcs = []

    for item in group:
        g1, scale1, target1, size1, mean1, std1, p1, s1 = item

        graph_def1 = tf.compat.v1.GraphDef()
        # Close each file as soon as it is read instead of leaking a handle
        # per group entry.
        with open(g1, "rb") as pb_file:
            graph_def1.ParseFromString(pb_file.read())

        self.model_funcs.append(
            wrap_function.function_from_graph_def(
                graph_def1, inputs="input1:0", outputs="output1:0"))
        self.scales.append(scale1)
        self.targets.append(target1)
        self.sizes.append(size1)
        self.means.append(mean1)
        self.stds.append(std1)
        self.ps.append(p1)
        self.ws.append(s1)
def convert(self):
    """Convert the input SavedModel in 2.0 format.

    Returns:
      The TF-TRT converted Function: a `def_function.function` wrapper that
      forwards its arguments to `self._converted_func`.
    """
    assert not self._converted

    self._saved_model = load.load(
        self._input_saved_model_dir, self._input_saved_model_tags)
    signature_key = self._input_saved_model_signature_key
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        self._saved_model.signatures[signature_key])
    frozen_graph = frozen_func.graph
    grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_graph.as_graph_def(), graph=frozen_graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_names = fetch_collection.node_list.value
    for tensor in frozen_func.inputs + frozen_func.outputs:
        fetch_names.append(tensor.name)
    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def,
        [tensor.name for tensor in frozen_func.inputs],
        [tensor.name for tensor in frozen_func.outputs])

    self._converted = True

    # Wrap the converted ConcreteFunction in a Function so it can accept numpy
    # arrays as input.
    @def_function.function
    def wrapper_func(*args, **kwargs):
        return self._converted_func(*args, **kwargs)

    return wrapper_func
def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format.

    The loaded model, the Grappler MetaGraphDef and the converted function
    are all stored on the instance; nothing is returned.
    """
    assert context.executing_eagerly()

    self._saved_model = load.load(
        self._input_saved_model_dir, self._input_saved_model_tags)
    signature_key = self._input_saved_model_signature_key
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        self._saved_model.signatures[signature_key])
    frozen_graph = frozen_func.graph
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_graph.as_graph_def(), graph=frozen_graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_names = fetch_collection.node_list.value
    for tensor in frozen_func.inputs + frozen_func.outputs:
        fetch_names.append(tensor.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._run_conversion()

    # Rebuild a callable from the converted graph, keeping the frozen
    # function's input and output tensor names.
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def,
        [tensor.name for tensor in frozen_func.inputs],
        [tensor.name for tensor in frozen_func.outputs])
def convert(self, calibration_input_fn=None):
    """Convert the input SavedModel in 2.0 format.

    Args:
      calibration_input_fn: a generator function that yields input data as a
        list or tuple, which will be used to execute the converted signature
        for calibration. All the returned input data should have the same
        shape. Example:
        ```
        def input_fn():
          yield input1, input2, input3
        ```

    Raises:
      ValueError: if the input combination is invalid, i.e. calibration is
        needed but no `calibration_input_fn` is given, or the other way round.

    Returns:
      NOTE(review): the body visible here ends without a `return`; the
      converted function is stored in `self._converted_func`. Confirm whether
      callers expect the TF-TRT converted Function to be returned.
    """
    assert not self._converted

    # The input function must be supplied exactly when calibration is needed.
    if self._need_calibration:
        if not calibration_input_fn:
            raise ValueError(
                "Should specify calibration_input_fn because INT8 "
                "calibration is needed")
    elif calibration_input_fn:
        raise ValueError(
            "Should not specify calibration_input_fn because INT8 "
            "calibration is not needed")

    self._saved_model = load.load(
        self._input_saved_model_dir, self._input_saved_model_tags)
    original_func = self._saved_model.signatures[
        self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        original_func)
    grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(
        [tensor.name
         for tensor in frozen_func.inputs + frozen_func.outputs])
    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def,
        [tensor.name for tensor in frozen_func.inputs],
        [tensor.name for tensor in frozen_func.outputs])

    # Reconstruct the output signatures using the ones from original model.
    converted_graph = self._converted_func.graph
    converted_graph.structured_outputs = nest.pack_sequence_as(
        original_func.graph.structured_outputs,
        converted_graph.structured_outputs)

    if self._need_calibration:
        # Execute the converted function once per calibration batch.
        for batch in calibration_input_fn():
            tensors = [ops.convert_to_tensor(value) for value in batch]
            self._converted_func(*tensors)

        def _save_calibration_table(node):
            """Copy the engine's calibration data into the node attribute."""
            engine_name = _get_canonical_engine_name(node.name)
            calibration_table = gen_trt_ops.get_calibration_data_op(
                engine_name)
            node.attr["calibration_data"].s = calibration_table.numpy()

        self._for_each_trt_node(
            self._converted_graph_def, _save_calibration_table)

        # Rebuild the function since calibration has changed the graph.
        previous_func = self._converted_func
        calibrated_func = wrap_function.function_from_graph_def(
            self._converted_graph_def,
            [tensor.name for tensor in previous_func.inputs],
            [tensor.name for tensor in previous_func.outputs])
        calibrated_func.graph.structured_outputs = nest.pack_sequence_as(
            previous_func.graph.structured_outputs,
            calibrated_func.graph.structured_outputs)
        self._converted_func = calibrated_func

    self._converted = True
def convert_variables_to_constants_v2(func):
    """Replaces all the variables in a graph with constants of the same values.

    TensorFlow 2.0 function for converting all Variable ops into Const ops
    holding the same values. This makes it possible to describe the network
    fully with a single GraphDef file, and allows the removal of a lot of ops
    related to loading and saving the variables. This function runs Grappler's
    function inlining optimization in order to return a single subgraph.

    The current implementation only works for graphs that do not contain any
    control flow or embedding related ops.

    Args:
      func: ConcreteFunction.

    Returns:
      ConcreteFunction containing a simplified version of the original. The
      inputs that were not converted keep the relative order they have in
      `func.inputs`.

    Raises:
      ValueError: If the input chain of a ReadVariableOp (followed through
        Identity ops) does not end at a Placeholder op.
    """
    # TODO(nupurgarg): Replace ResourceGather with Gather.
    # TODO(nupurgarg): Change attr for Variables in control flow and functions.
    graph_def = _run_inline_graph_optimization(func)

    def get_name(name):
        """Strip the `:<output index>` suffix from a tensor name."""
        return name.split(":")[0]

    # Identify the ReadVariableOps.
    map_name_to_node = {get_name(node.name): node for node in graph_def.node}

    # TODO(b/125838789): Use `func.graph.captures`.
    # Get mapping from input name to variable value.
    tensor_data = {}
    map_name_to_handle = {}
    # The captured inputs are the trailing entries of `func.inputs`. Guard the
    # zero case explicitly: `inputs[-0:]` would select every input.
    num_captured = len(func.captured_inputs)
    input_tensors = func.inputs[-num_captured:] if num_captured else []
    for var in func.graph.variables:
        index = func.captured_inputs.index(var.handle)
        tensor_name = get_name(input_tensors[index].name)
        tensor_data[tensor_name] = var.numpy()
        map_name_to_handle[tensor_name] = var.handle

    # Get mapping from input name to value for non-variable placeholders.
    map_name_to_value = {}
    for name_tensor, value_tensor in zip(input_tensors, func.captured_inputs):
        tensor_name = get_name(name_tensor.name)
        if tensor_name not in map_name_to_handle:
            map_name_to_value[tensor_name] = value_tensor

    resource_identities = {}
    placeholders = {}
    converted_input_indices = set()
    for node in graph_def.node:
        if node.name in map_name_to_value:
            # Get the dtype and data for the Placeholders whose values are
            # stored as Tensors. This is the case for values that were
            # originally Const ops.
            tensor = map_name_to_value[node.name]
            placeholders[node.name] = {
                "dtype": node.attr["dtype"],
                "data": tensor.numpy(),
            }
            converted_input_indices.add(
                func.captured_inputs.index(map_name_to_value[node.name]))
        if node.op == "ReadVariableOp":
            # Get name of Placeholder op associated with ReadVariableOp. There
            # can be an Identity in between the ReadVariableOp and Placeholder.
            # Store the Identity ops with the associated dtypes.
            input_name = get_name(node.input[0])
            while map_name_to_node[input_name].op == "Identity":
                resource_identities[input_name] = node.attr["dtype"]
                input_name = get_name(map_name_to_node[input_name].input[0])
            if map_name_to_node[input_name].op != "Placeholder":
                raise ValueError(
                    "Cannot find the Placeholder op that is an input "
                    "to the ReadVariableOp.")
            # Build a map of Placeholder ops that are inputs to ReadVariableOps
            # to the variable's dtype and data.
            placeholders[input_name] = {
                "dtype": node.attr["dtype"],
                "data": tensor_data[input_name],
            }
            converted_input_indices.add(
                func.captured_inputs.index(map_name_to_handle[input_name]))

    # Reconstruct the graph with constants in place of variables.
    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in graph_def.node:
        output_node = output_graph_def.node.add()
        if input_node.name in placeholders:
            # Convert Placeholder ops to Const ops.
            dtype = placeholders[input_node.name]["dtype"]
            data = placeholders[input_node.name]["data"]
            output_node.op = "Const"
            output_node.name = input_node.name
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].tensor.CopyFrom(
                tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape))
            how_many_converted += 1
        elif input_node.name in resource_identities:
            # Change the dtype for Identity ops that are inputs to
            # ReadVariableOps.
            output_node.CopyFrom(input_node)
            output_node.attr["T"].CopyFrom(
                resource_identities[input_node.name])
        elif input_node.op == "ReadVariableOp":
            # Convert ReadVariableOps into Identity ops.
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            output_node.CopyFrom(input_node)

    logging.info("Converted %d variables to const ops.", how_many_converted)

    # Create a ConcreteFunction from the new GraphDef.
    # Filter by object identity while walking `func.inputs` in order. Building
    # the list through `set(...).difference(...)` would make the input order of
    # the new function depend on arbitrary set iteration order.
    seen_ids = {id(input_tensors[index]) for index in converted_input_indices}
    not_converted_inputs = []
    for tensor in func.inputs:
        if id(tensor) not in seen_ids:
            # Also record kept tensors so duplicates collapse, as a set would.
            seen_ids.add(id(tensor))
            not_converted_inputs.append(tensor)
    not_converted_inputs_map = {
        tensor.name: tensor for tensor in not_converted_inputs
    }

    new_input_names = [tensor.name for tensor in not_converted_inputs]
    new_output_names = [tensor.name for tensor in func.outputs]
    new_func = wrap_function.function_from_graph_def(
        output_graph_def, new_input_names, new_output_names)

    # Manually propagate shape for input tensors where the shape is not
    # correctly propagated. Scalars shapes are lost when wrapping the function.
    for input_tensor in new_func.inputs:
        input_tensor.set_shape(
            not_converted_inputs_map[input_tensor.name].shape)
    return new_func
def convert_variables_to_constants_v2(func):
    """Replaces all the variables in a graph with constants of the same values.

    TensorFlow 2.0 function for converting all Variable ops into Const ops
    holding the same values. This makes it possible to describe the network
    fully with a single GraphDef file, and allows the removal of a lot of ops
    related to loading and saving the variables. This function runs Grappler's
    function inlining optimization in order to return a single subgraph.

    The current implementation only works for graphs that do not contain any
    control flow or embedding related ops.

    Args:
      func: ConcreteFunction.

    Returns:
      ConcreteFunction containing a simplified version of the original. The
      inputs that were not converted keep the relative order they have in
      `func.inputs`.

    Raises:
      ValueError: If the input chain of a ReadVariableOp (followed through
        Identity ops) does not end at a Placeholder op.
    """
    # TODO(nupurgarg): Replace ResourceGather with Gather.
    # TODO(nupurgarg): Change attr for Variables in control flow and functions.
    graph_def = _run_inline_graph_optimization(func)

    def get_name(name):
        """Strip the `:<output index>` suffix from a tensor name."""
        return name.split(":")[0]

    # Identify the ReadVariableOps.
    map_name_to_node = {get_name(node.name): node for node in graph_def.node}

    # TODO(b/125838789): Use `func.graph.captures`.
    # Get mapping from input name to variable value.
    captured_inputs = func.captured_inputs
    # The captured inputs are the trailing entries of `func.inputs`. Guard the
    # zero case explicitly: `inputs[-0:]` would select every input.
    if captured_inputs:
        input_tensors = func.inputs[-len(captured_inputs):]
    else:
        input_tensors = []

    tensor_data = {}
    map_name_to_handle = {}
    for var in func.graph.variables:
        index = func.captured_inputs.index(var.handle)
        tensor_name = get_name(input_tensors[index].name)
        tensor_data[tensor_name] = var.numpy()
        map_name_to_handle[tensor_name] = var.handle

    # Get mapping from input name to value for non-variable placeholders.
    map_name_to_value = {}
    for name_tensor, value_tensor in zip(input_tensors, func.captured_inputs):
        tensor_name = get_name(name_tensor.name)
        if tensor_name not in map_name_to_handle:
            map_name_to_value[tensor_name] = value_tensor

    resource_identities = {}
    placeholders = {}
    converted_input_indices = set()
    for node in graph_def.node:
        if node.name in map_name_to_value:
            # Get the dtype and data for the Placeholders whose values are
            # stored as Tensors. This is the case for values that were
            # originally Const ops.
            tensor = map_name_to_value[node.name]
            placeholders[node.name] = {
                "dtype": node.attr["dtype"],
                "data": tensor.numpy(),
            }
            converted_input_indices.add(
                func.captured_inputs.index(map_name_to_value[node.name]))
        if node.op == "ReadVariableOp":
            # Get name of Placeholder op associated with ReadVariableOp. There
            # can be an Identity in between the ReadVariableOp and Placeholder.
            # Store the Identity ops with the associated dtypes.
            input_name = get_name(node.input[0])
            while map_name_to_node[input_name].op == "Identity":
                resource_identities[input_name] = node.attr["dtype"]
                input_name = get_name(map_name_to_node[input_name].input[0])
            if map_name_to_node[input_name].op != "Placeholder":
                raise ValueError(
                    "Cannot find the Placeholder op that is an input "
                    "to the ReadVariableOp.")
            # Build a map of Placeholder ops that are inputs to ReadVariableOps
            # to the variable's dtype and data.
            placeholders[input_name] = {
                "dtype": node.attr["dtype"],
                "data": tensor_data[input_name],
            }
            converted_input_indices.add(
                func.captured_inputs.index(map_name_to_handle[input_name]))

    # Reconstruct the graph with constants in place of variables.
    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in graph_def.node:
        output_node = output_graph_def.node.add()
        if input_node.name in placeholders:
            # Convert Placeholder ops to Const ops.
            dtype = placeholders[input_node.name]["dtype"]
            data = placeholders[input_node.name]["data"]
            output_node.op = "Const"
            output_node.name = input_node.name
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].tensor.CopyFrom(
                tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape))
            how_many_converted += 1
        elif input_node.name in resource_identities:
            # Change the dtype for Identity ops that are inputs to
            # ReadVariableOps.
            output_node.CopyFrom(input_node)
            output_node.attr["T"].CopyFrom(
                resource_identities[input_node.name])
        elif input_node.op == "ReadVariableOp":
            # Convert ReadVariableOps into Identity ops.
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            output_node.CopyFrom(input_node)

    logging.info("Converted %d variables to const ops.", how_many_converted)

    # Create a ConcreteFunction from the new GraphDef.
    # Select the surviving inputs by object identity, in `func.inputs` order.
    # A `set(...).difference(...)` here would hand the new function its inputs
    # in arbitrary set iteration order.
    excluded_ids = {
        id(input_tensors[index]) for index in converted_input_indices
    }
    not_converted_inputs = []
    for tensor in func.inputs:
        if id(tensor) in excluded_ids:
            continue
        # Exclude it from now on so a repeated tensor is kept only once.
        excluded_ids.add(id(tensor))
        not_converted_inputs.append(tensor)
    not_converted_inputs_map = {
        tensor.name: tensor for tensor in not_converted_inputs
    }

    new_input_names = [tensor.name for tensor in not_converted_inputs]
    new_output_names = [tensor.name for tensor in func.outputs]
    new_func = wrap_function.function_from_graph_def(
        output_graph_def, new_input_names, new_output_names)

    # Manually propagate shape for input tensors where the shape is not
    # correctly propagated. Scalars shapes are lost when wrapping the function.
    for input_tensor in new_func.inputs:
        input_tensor.set_shape(
            not_converted_inputs_map[input_tensor.name].shape)
    return new_func
def save(self, output_saved_model_dir):
    """Save the converted SavedModel.

    Serializes every TRT engine found in the converted graph (and in its
    function library) into a temporary directory, attaches the resulting
    resources to the loaded model, and saves the model with the converted
    function replacing the original signature.

    Args:
      output_saved_model_dir: directory to saved the converted SavedModel.
    """
    assert self._converted

    # Serialize the TRT engines in the cache if any, and create trackable
    # resource to track them.
    engine_asset_dir = tempfile.mkdtemp()
    resource_map = {}

    def _serialize_and_track_engine(node):
        """Serialize TRT engines in the cache and track them."""
        engine_name = _get_canonical_engine_name(node.name)
        # Don't dump the same cache twice.
        if engine_name in resource_map:
            return

        engine_path = os.path.join(
            engine_asset_dir, "trt-serialized-engine." + engine_name)
        if self._need_calibration:
            calibration_table = gen_trt_ops.get_calibration_data_op(
                engine_name)
            node.attr["calibration_data"].s = calibration_table.numpy()

        try:
            gen_trt_ops.serialize_trt_resource(
                resource_name=engine_name,
                filename=engine_path,
                delete_resource=True)
        except errors.NotFoundError:
            # If user haven't run the function to populate the engine, it's
            # fine, and we don't need to track any serialized TRT engines.
            return
        else:
            # TODO(laigd): add an option for the user to choose the device.
            resource_map[engine_name] = _TRTEngineResource(
                engine_name, engine_path,
                self._conversion_params.maximum_cached_engines)

    # Visit the top-level nodes first, then the nodes of each library function.
    converted_graph_def = self._converted_graph_def
    node_groups = [converted_graph_def.node]
    node_groups.extend(
        library_func.node_def
        for library_func in converted_graph_def.library.function)
    for nodes in node_groups:
        for node in nodes:
            if node.op == _TRT_ENGINE_OP_NAME:
                _serialize_and_track_engine(node)

    self._saved_model.trt_engine_resources = resource_map

    # Rebuild the function since calibration may change the graph.
    previous_func = self._converted_func
    func_to_save = wrap_function.function_from_graph_def(
        self._converted_graph_def,
        [tensor.name for tensor in previous_func.inputs],
        [tensor.name for tensor in previous_func.outputs])
    func_to_save.graph.structured_outputs = nest.pack_sequence_as(
        previous_func.graph.structured_outputs,
        func_to_save.graph.structured_outputs)

    # Rewrite the signature map using the optimized ConcreteFunction.
    signatures = dict(self._saved_model.signatures.items())
    signatures[self._input_saved_model_signature_key] = func_to_save
    save.save(self._saved_model, output_saved_model_dir, signatures)