def load(path, compile=True):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Keras layers and models found in the SavedModel are revived as Keras
  objects; everything else comes back as a regular trackable object, the same
  as `tf.saved_model.load` would return it.

  Only the weights, losses and call function of Keras objects survive the
  round trip. The original optimizer, compiled loss functions and metrics are
  not restored, although a revived model can be re-compiled from its stored
  training config (see `compile`). This is temporary, and `model.save` will
  soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it. Compilation happens
      only for `RevivedModel` instances that carry a training config;
      otherwise the object is returned uncompiled.

  Returns:
    Object loaded from SavedModel.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints
  loaded = tf_load.load_internal(path, loader_cls=KerasObjectLoader)

  # Compilation is attempted only for revived Keras models, and only on request.
  if not (isinstance(loaded, RevivedModel) and compile):
    return loaded

  # TODO(kathywu): Use compiled objects from SavedModel, instead of
  # creating new objects from the training config.
  training_config = loaded._training_config  # pylint: disable=protected-access
  if training_config is not None:
    loaded.compile(
        **saving_utils.compile_args_from_training_config(training_config))
  return loaded
def __init__(self, saved_model_dir):
  """Init method for SavedModelLoader.

  Loads the SavedModel and resolves the transform function, its graph and its
  structured inputs/outputs.

  Args:
    saved_model_dir: A SavedModel directory providing a transform graph. The
      MetaGraphDef and signature are selected from the SavedModel using keys
      defined in `../constants.py` ('transform' and 'transform_signature',
      respectively).
  """
  import re  # pylint: disable=g-import-not-at-top

  # Compare the TF version numerically. A plain string comparison
  # (`tf.version.VERSION < '2.5'`) is lexicographic and would wrongly classify
  # e.g. '2.10.0' as older than 2.5, sending newer TF releases down the
  # legacy loading path.
  version_match = re.match(r'(\d+)\.(\d+)', tf.version.VERSION)
  if version_match is not None:
    use_legacy_loader = tuple(map(int, version_match.groups())) < (2, 5)
  else:
    # Unparseable version string: keep the original behavior.
    use_legacy_loader = tf.version.VERSION < '2.5'

  if use_legacy_loader:
    self._imported = load.load_internal(saved_model_dir, loader_cls=_Loader)
    # load_internal may return a dict of objects; the root object is the one
    # this loader works with.
    if isinstance(self._imported, dict):
      self._imported = self._imported['root']
  else:
    # TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
    # dropped.
    self._imported = tf.compat.v2.saved_model.load(saved_model_dir)

  # True when the transform is exposed as the named signature
  # `constants.TRANSFORM_SIGNATURE`; otherwise a `transform_fn` attribute is
  # expected on the loaded object.
  self.load_v2_in_compat = (
      constants.TRANSFORM_SIGNATURE in self._imported.signatures)
  if self.load_v2_in_compat:
    self._wrapped = self._imported.signatures[constants.TRANSFORM_SIGNATURE]
    self._func_graph = self._wrapped.graph
    self._structured_inputs = self._get_input_signature_from_v1_saved_model(
        saved_model_dir)
    self._structured_outputs = self._wrapped.structured_outputs
  else:
    # TODO(b/160550490): Remove local import.
    from tensorflow_transform import tf2_utils  # pylint: disable=g-import-not-at-top

    # `transform_fn` is now a ConcreteFunction, but was previously a
    # tf.function; both must be handled to maintain backward compatibility.
    # If it is a tf.function, `input_signature` was specified when exporting
    # it to `SavedModel`, so there should be exactly one concrete function
    # present on loading the `SavedModel`.
    if hasattr(self._imported.transform_fn, 'concrete_functions'):
      concrete_functions = self._imported.transform_fn.concrete_functions
      assert len(concrete_functions) == 1, concrete_functions
      self._wrapped = concrete_functions[0]
    else:
      self._wrapped = self._imported.transform_fn
    self._func_graph = self._wrapped.graph
    self._structured_inputs = (
        tf2_utils.get_structured_inputs_from_func_graph(self._func_graph))
    self._structured_outputs = tf.nest.pack_sequence_as(
        self._func_graph.structured_outputs,
        self._func_graph.outputs,
        expand_composites=True)
  self._output_to_inputs_map = (
      self._get_output_to_inputs_map(self._structured_outputs))
  saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Keras layers and models stored in the SavedModel are revived as Keras
  objects; any other object is returned as a regular trackable object (the
  same result `tf.saved_model.load` gives).

  Currently only the weights, losses and call function of a Keras object are
  retained. The original optimizer, compiled loss functions and metrics are
  not. When `compile` is true, the model is instead re-compiled from the
  training config recorded in its metadata, if one is present. This is
  temporary, and `model.save` will soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it. A warning is logged
      when no training configuration is available to compile from.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints
  model = tf_load.load_internal(
      path, options=options, loader_cls=KerasObjectLoader)

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    metadata = model._serialized_attributes['metadata']
    training_config = metadata.get('training_config', None)
    if training_config is None:
      logging.warning(
          'No training configuration found in save file, so the '
          'model was *not* compiled. Compile it manually.')
    else:
      compile_kwargs = saving_utils.compile_args_from_training_config(
          training_config)
      model.compile(**compile_kwargs)
      saving_utils.try_build_compiled_arguments(model)
  # pylint: enable=protected-access

  # Outside eager mode, force variables and resources to initialize:
  # obtaining the session initializes the variables, and the ops collected
  # under TABLE_INITIALIZERS are run explicitly.
  if not context.executing_eagerly():
    backend.get_session().run(
        ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
# Script: load a Keras SavedModel via KerasObjectLoader, re-compile it from its
# stored training config, then run a prediction.
from tensorflow.python.keras.saving.saved_model.load import KerasObjectLoader
from tensorflow.python.saved_model.load import load_internal
from tensorflow.python.keras.saving.saved_model.load import RevivedModel
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.saved_model import loader_impl

model_path = 'output/saved_model/cls/1599723701'

# The parsed proto is discarded; presumably this call serves as an early check
# that `model_path` holds a readable SavedModel -- TODO confirm.
loader_impl.parse_saved_model(model_path)

model = load_internal(model_path, tags=['serve'], loader_cls=KerasObjectLoader)
if not isinstance(model, RevivedModel):
  raise RuntimeError("Can not load model")

# Without a training config there is nothing to compile from.
training_config = model._training_config  # pylint: disable=protected-access
if training_config is None:
  raise RuntimeError("Model _training_config is None")
compile_kwargs = saving_utils.compile_args_from_training_config(training_config)
model.compile(**compile_kwargs)

# NOTE(review): four empty lists look like placeholder input; replace with data
# matching the model's input signature.
test_data = [[] for _ in range(4)]
model.predict(test_data)