def _python_properties_internal(self):
  """Collects the Python-side properties of the wrapped layer.

  A fixed set of layer attributes is gathered first. The dict is then
  extended with `get_config(self.obj)`, whose keys override any base key of
  the same name. Finally, `input_spec`, `activity_regularizer` and
  `build_input_shape` entries are added when the layer provides them.

  Returns:
    A dict of layer metadata.
  """
  # TODO(kathywu): Add support for metrics serialization.
  # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
  # the python config serialization has caught up.
  layer = self.obj
  metadata = {
      'class_name': generic_utils.get_registered_name(type(layer)),
      'name': layer.name,
      'trainable': layer.trainable,
      'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
      'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
      'batch_input_shape': getattr(layer, '_batch_input_shape', None),
      'stateful': layer.stateful,
      'must_restore_from_config': layer._must_restore_from_config,  # pylint: disable=protected-access
  }
  metadata.update(get_config(layer))

  input_spec = layer.input_spec
  if input_spec is not None:
    # Layer's input_spec has already been type-checked in the property setter.
    def _serialize_spec(spec):
      # Falsy leaves of the spec structure are recorded as None.
      if spec:
        return generic_utils.serialize_keras_object(spec)
      return None

    metadata['input_spec'] = tf.nest.map_structure(_serialize_spec, input_spec)

  regularizer = layer.activity_regularizer
  # Only regularizers exposing `get_config` can be serialized here.
  if regularizer is not None and hasattr(regularizer, 'get_config'):
    metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
        regularizer)

  build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
  if build_input_shape is not None:
    metadata['build_input_shape'] = build_input_shape
  return metadata
def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata.

    Args:
        model: The model to describe.
        include_optimizer: Whether to add the compile-time training config,
            including the optimizer config, when the model has an optimizer.
        require_config: If True, a `NotImplementedError` raised by
            `model.get_config()` is re-raised. Otherwise the `config` entry
            is left out of `model_config`.

    Returns:
        A dict with `keras_version`, `backend` and `model_config` keys. It
        also has a `training_config` key when the model has an optimizer that
        is not a `TFOptimizer`, the model was compiled, and
        `include_optimizer` is truthy.

    Raises:
        NotImplementedError: If `model.get_config()` is unimplemented and
            `require_config` is True, or if the optimizer is a
            `RestoredOptimizer`.
    """
    # pylint: disable=g-import-not-at-top
    from keras import __version__ as keras_version
    from keras.optimizers.optimizer_v2 import optimizer_v2

    # pylint: enable=g-import-not-at-top

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as err:
        if require_config:
            raise err

    metadata = {
        "keras_version": str(keras_version),
        "backend": backend.backend(),
        "model_config": model_config,
    }

    optimizer = model.optimizer
    if not optimizer or not include_optimizer:
        return metadata

    if isinstance(optimizer, optimizer_v1.TFOptimizer):
        logging.warning(
            "TensorFlow optimizers do not "
            "make it possible to access "
            "optimizer attributes or optimizer state "
            "after instantiation. "
            "As a result, we cannot save the optimizer "
            "as part of the model save file. "
            "You will have to compile your model again after loading it. "
            "Prefer using a Keras optimizer instead "
            "(see keras.io/optimizers)."
        )
        return metadata

    if not model._compile_was_called:  # pylint: disable=protected-access
        return metadata

    training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
    # The optimizer is serialized separately below, as `optimizer_config`.
    training_config.pop("optimizer", None)
    metadata["training_config"] = _serialize_nested_config(training_config)

    if isinstance(optimizer, optimizer_v2.RestoredOptimizer):
        raise NotImplementedError(
            "Optimizers loaded from a SavedModel cannot be saved. "
            "If you are calling `model.save` or `tf.keras.models.save_model`, "
            "please set the `include_optimizer` option to `False`. For "
            "`tf.saved_model.save`, delete the optimizer from the model."
        )

    metadata["training_config"]["optimizer_config"] = {
        "class_name": generic_utils.get_registered_name(optimizer.__class__),
        "config": optimizer.get_config(),
    }
    return metadata
def _python_properties_internal(self):
  """Returns the Python-side properties of `self.obj` as a dict.

  The dict starts with the registered class name, the name and the dtype.
  It is then updated with `layer_serialization.get_serialized(self.obj)`,
  whose keys override the base ones. `build_input_shape` is added only when
  the object has one set.
  """
  target = self.obj
  metadata = {
      'class_name': generic_utils.get_registered_name(type(target)),
      'name': target.name,
      'dtype': target.dtype,
  }
  metadata.update(layer_serialization.get_serialized(target))
  shape = target._build_input_shape  # pylint: disable=protected-access
  if shape is not None:
    metadata['build_input_shape'] = shape
  return metadata
def _python_properties_internal(self):
    """Builds the Python-side property dict for the wrapped object.

    Base entries are the registered class name, `name` and `dtype`. They are
    followed by the output of `layer_serialization.get_serialized`, which can
    override them. `build_input_shape` is appended only when it is not None.
    """
    obj = self.obj
    properties = {
        "class_name": generic_utils.get_registered_name(type(obj)),
        "name": obj.name,
        "dtype": obj.dtype,
    }
    properties.update(layer_serialization.get_serialized(obj))
    build_shape = obj._build_input_shape
    if build_shape is not None:
        properties["build_input_shape"] = build_shape
    return properties
def model_metadata(model, include_optimizer=True, require_config=True):
  """Returns a dictionary containing the model metadata.

  Args:
    model: The model to describe.
    include_optimizer: Whether to add the compile-time training config,
      including the optimizer config, when the model has an optimizer.
    require_config: If True, a `NotImplementedError` from
      `model.get_config()` is re-raised. Otherwise the `config` entry is
      left out of `model_config`.

  Returns:
    A dict with `keras_version`, `backend` and `model_config` keys. It also
    has a `training_config` key when the model has an optimizer that is not a
    `TFOptimizer`, the model was compiled, and `include_optimizer` is truthy.

  Raises:
    NotImplementedError: If `model.get_config()` is unimplemented and
      `require_config` is True, or if the optimizer is a `RestoredOptimizer`.
  """
  # pylint: disable=g-import-not-at-top
  from keras import __version__ as keras_version
  from keras.optimizer_v2 import optimizer_v2
  # pylint: enable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError as err:
    if require_config:
      raise err

  metadata = {
      'keras_version': str(keras_version),
      'backend': K.backend(),
      'model_config': model_config,
  }

  optimizer = model.optimizer
  if not optimizer or not include_optimizer:
    return metadata

  if isinstance(optimizer, optimizer_v1.TFOptimizer):
    logging.warning(
        'TensorFlow optimizers do not '
        'make it possible to access '
        'optimizer attributes or optimizer state '
        'after instantiation. '
        'As a result, we cannot save the optimizer '
        'as part of the model save file. '
        'You will have to compile your model again after loading it. '
        'Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
    return metadata

  if not model._compile_was_called:  # pylint: disable=protected-access
    return metadata

  training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
  # The optimizer is serialized separately below, as `optimizer_config`.
  training_config.pop('optimizer', None)
  metadata['training_config'] = _serialize_nested_config(training_config)

  if isinstance(optimizer, optimizer_v2.RestoredOptimizer):
    raise NotImplementedError(
        'As of now, Optimizers loaded from SavedModel cannot be saved. '
        'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
        ' please set the `include_optimizer` option to `False`. For '
        '`tf.saved_model.save`, delete the optimizer from the model.'
    )

  metadata['training_config']['optimizer_config'] = {
      'class_name': generic_utils.get_registered_name(optimizer.__class__),
      'config': optimizer.get_config(),
  }
  return metadata
def _get_object_registered_name(obj):
    """Returns the registered name for `obj`.

    Plain Python functions (`types.FunctionType`) are looked up directly.
    Any other object is looked up through its class.
    """
    is_function = isinstance(obj, types.FunctionType)
    lookup_target = obj if is_function else obj.__class__
    return generic_utils.get_registered_name(lookup_target)