def _load_layer(self, node_id, identifier, metadata):
    """Revive the layer stored at `node_id` from its SavedUserObject metadata.

    Args:
      node_id: Key of the node; looked up in `self.loaded_nodes`.
      identifier: Object identifier forwarded to the revive helpers.
      metadata: Encoded metadata; decoded here with `json_utils.decode`.

    Returns:
      An `(object, setter)` pair, either the entry already present in
      `self.loaded_nodes` or a newly revived object and its setter.
    """
    metadata = json_utils.decode(metadata)

    if node_id not in self.loaded_nodes:
        # Nothing exists for this node yet. Try to revive it from its config
        # first; when that yields no object, revive it from the SavedModel
        # instead.
        revived, revived_setter = self._revive_from_config(
            identifier, metadata, node_id)
        if revived is None:
            revived, revived_setter = revive_custom_object(
                identifier, metadata)

        # Attach the extra functions/objects saved in the SavedModel. Most of
        # them are ignored, but some are used later in the loading process
        # (e.g. the list of regularization losses, or the training config of
        # compiled models).
        _maybe_add_serialized_attributes(revived, metadata)
        return revived, revived_setter

    # The node was already created.
    existing, existing_setter = self.loaded_nodes[node_id]

    # The revive setter requires the object to have a
    # `_serialized_attributes` property, so make sure it is present.
    _maybe_add_serialized_attributes(existing, metadata)

    config = metadata.get('config')
    if _is_graph_network(existing) and generic_utils.validate_config(config):
        children = self._get_child_layer_node_ids(node_id)
        self.model_layer_dependencies[node_id] = (existing, children)
        if not children:
            self._models_to_reconstruct.append(node_id)
    return existing, existing_setter
def _init_from_metadata(cls, metadata):
    """Build a revived layer from the metadata stored in the SavedModel proto.

    Args:
      metadata: Mapping of saved layer metadata. `name`, `trainable` and
        `expects_training_arg` are required; `dtype`, `batch_input_shape`,
        `config`, `input_spec`, `activity_regularizer`, `_is_feature_layer`
        and `stateful` are applied only when present and not None.

    Returns:
      A `(revived_obj, _revive_setter)` tuple.
    """
    ctor_kwargs = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    # Optional constructor arguments are forwarded only when they were saved.
    for optional_arg in ('dtype', 'batch_input_shape'):
        saved_value = metadata.get(optional_arg)
        if saved_value is not None:
            ctor_kwargs[optional_arg] = saved_value

    revived_obj = cls(**ctor_kwargs)

    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
        # pylint:disable=protected-access
        revived_obj._expects_training_arg = metadata['expects_training_arg']

        saved_config = metadata.get('config')
        if generic_utils.validate_config(saved_config):
            revived_obj._config = saved_config

        serialized_spec = metadata.get('input_spec')
        if serialized_spec is not None:
            revived_obj.input_spec = recursively_deserialize_keras_object(
                serialized_spec,
                module_objects={'InputSpec': input_spec.InputSpec})

        serialized_regularizer = metadata.get('activity_regularizer')
        if serialized_regularizer is not None:
            revived_obj.activity_regularizer = regularizers.deserialize(
                serialized_regularizer)

        # Plain attributes that are copied over verbatim when saved.
        for attr_name in ('_is_feature_layer', 'stateful'):
            attr_value = metadata.get(attr_name)
            if attr_value is not None:
                setattr(revived_obj, attr_name, attr_value)
        # pylint:enable=protected-access

    return revived_obj, _revive_setter
def _revive_graph_network(self, metadata, node_id):
    """Revive a functional or Sequential model as a blank model object.

    Args:
      metadata: Mapping of saved model metadata; reads `config`,
        `class_name` and `is_graph_network`.
      node_id: Node id used to look up the model's child layer node ids and
        to record the model for later reconstruction.

    Returns:
      A blank `Sequential` or `Functional` model, or None when the config
      does not validate, the class name resolves to a registered object, or
      the metadata does not describe a functional/Sequential model.
    """
    network_config = metadata.get('config')
    if not generic_utils.validate_config(network_config):
        return None

    class_name = tf.compat.as_str(metadata['class_name'])
    if generic_utils.get_registered_object(class_name) is not None:
        return None

    # Decide whether the metadata describes a functional or Sequential model.
    is_sequential = class_name == 'Sequential'
    is_functional = class_name == 'Functional'
    if not (metadata.get('is_graph_network', False)
            or is_sequential or is_functional):
        return None

    # Create an empty model for now (it must be initialized to enable setattr
    # tracking and attribute caching). Reconstruction of the network is
    # deferred until all of the model's layers have been revived.
    model_name = network_config['name']
    if is_sequential:
        model = models_lib.Sequential(name=model_name)
    else:
        model = models_lib.Functional(inputs=[], outputs=[], name=model_name)

    # Remember the model together with its layers so that it can be
    # reconstructed later.
    child_layer_ids = self._get_child_layer_node_ids(node_id)
    self.model_layer_dependencies[node_id] = (model, child_layer_ids)
    if not child_layer_ids:
        self._models_to_reconstruct.append(node_id)
    return model
def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible.

    Reviving from config requires that:
      1. The object can be deserialized from its config.
      2. If the object needs to be built, the build input shape can be found.

    Args:
      metadata: Mapping of saved metadata. Reads `class_name`, `config`,
        `shared_object_id`, `must_restore_from_config`, `name`, `trainable`,
        `dtype`, `stateful`, `save_spec` and `build_input_shape`.
      node_id: Node id forwarded to `self._try_build_layer`.

    Returns:
      The revived object, or None when the config does not validate, when
      deserialization raises `ValueError` and `must_restore_from_config` is
      falsy, or when the object could not be built.

    Raises:
      RuntimeError: If deserialization raises `ValueError` and
        `must_restore_from_config` is truthy. The original `ValueError` is
        attached as the cause.
    """
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
        return None

    try:
        obj = layers_module.deserialize(
            generic_utils.serialize_keras_class_and_config(
                class_name, config, shared_object_id=shared_object_id))
    except ValueError as e:
        if not must_restore_from_config:
            return None
        # Chain explicitly so the traceback reports the deserialization
        # failure as the direct cause of this error.
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name)) from e

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
        obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
        obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
        obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
        save_spec = metadata.get('save_spec')
        if save_spec is not None:
            obj._set_save_spec(save_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)
    if not built:
        # If the layer cannot be built, revive a custom layer instead.
        return None
    return obj
def _revive_metric_from_config(self, metadata):
    """Recreate a metric object from the config saved in its metadata.

    Args:
      metadata: Mapping of saved metric metadata; reads `class_name`,
        `config` and `build_input_shape`.

    Returns:
      The deserialized metric, or None when the config does not validate or
      deserialization raises `ValueError`.
    """
    class_name = tf.compat.as_str(metadata['class_name'])
    metric_config = metadata.get('config')
    if not generic_utils.validate_config(metric_config):
        return None

    try:
        metric = metrics.deserialize(
            generic_utils.serialize_keras_class_and_config(
                class_name, metric_config))
    except ValueError:
        return None

    # Build only when a build shape was saved and the metric exposes a
    # `_build` hook.
    saved_shape = metadata.get('build_input_shape')
    if saved_shape is None or not hasattr(metric, '_build'):
        return metric

    metric._build(saved_shape)  # pylint: disable=protected-access
    return metric
def _init_from_metadata(cls, metadata):
    """Build a revived network from the metadata stored in the SavedModel proto.

    Args:
      metadata: Mapping of saved network metadata. `name` and
        `expects_training_arg` are required; `config` and
        `activity_regularizer` are applied only when usable.

    Returns:
      A `(revived_obj, _revive_setter)` tuple.
    """
    network = cls(name=metadata['name'])

    # Attributes revived from SerializedAttributes are stored without
    # automatic dependency tracking. The attributes are the ones listed in
    # CommonEndpoints or "keras_api" for keras-specific attributes.
    with trackable.no_automatic_dependency_tracking_scope(network):
        # pylint:disable=protected-access
        network._expects_training_arg = metadata['expects_training_arg']

        saved_config = metadata.get('config')
        if generic_utils.validate_config(saved_config):
            network._config = saved_config

        serialized_regularizer = metadata.get('activity_regularizer')
        if serialized_regularizer is not None:
            network.activity_regularizer = regularizers.deserialize(
                serialized_regularizer)
        # pylint:enable=protected-access

    return network, _revive_setter  # pylint:disable=protected-access
def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible.

    Reviving from config requires that:
      1. The object can be deserialized from its config.
      2. If the object needs to be built, the build input shape can be found.

    Args:
      metadata: Mapping of saved metadata. Reads `class_name`, `config`,
        `shared_object_id`, `must_restore_from_config`, `name`, `trainable`,
        `dtype`, `stateful`, `full_save_spec` and `build_input_shape`. The
        mapping and the values it holds are not modified.
      node_id: Node id forwarded to `self._try_build_layer`.

    Returns:
      The revived object, or None when the config does not validate, when
      deserialization raises `ValueError` and `must_restore_from_config` is
      falsy, or when the object could not be built.

    Raises:
      RuntimeError: If deserialization raises `TypeError`/`KeyError` and
        `class_name` matches a built-in Keras layer (likely name conflict).
      TypeError, KeyError: Re-raised when no built-in layer matches.
      ValueError: Re-raised when `must_restore_from_config` is truthy.
    """
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
        return None

    try:
        obj = layers_module.deserialize(
            generic_utils.serialize_keras_class_and_config(
                class_name, config, shared_object_id=shared_object_id))
    except (TypeError, KeyError) as e:
        # A name conflict has occurred. The `class_name` is in the Keras native
        # framework; however, the value in the framework is different from the
        # user's class definition which confuses the KerasObjectLoader.
        builtin_layer = layers_module.get_builtin_layer(class_name)
        if builtin_layer:
            raise RuntimeError(
                f'Unable to restore object of class \'{class_name}\' likely due to '
                f'name conflict with built-in Keras class \'{builtin_layer}\'. To '
                'override the built-in Keras definition of the object, decorate '
                'your class with `@keras.utils.register_keras_serializable` and '
                'include that file in your program, or pass your class in a '
                '`keras.utils.CustomObjectScope` that wraps this load call.') from e
        raise
    except ValueError:
        if must_restore_from_config:
            # Bare re-raise keeps the original traceback untouched.
            raise
        return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
        obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
        obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
        obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
        full_save_spec = metadata.get('full_save_spec')
        if full_save_spec is not None:
            args_spec, kwargs_spec = full_save_spec
            # Split off the first positional spec by indexing and slicing
            # rather than `pop(0)`: popping would shorten the list stored
            # inside the caller's metadata, which the caller keeps using
            # after this method returns.
            inputs_spec = args_spec[0]
            remaining_args_spec = args_spec[1:]
            obj._set_save_spec(inputs_spec, remaining_args_spec, kwargs_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)
    if not built:
        # If the layer cannot be built, revive a custom layer instead.
        return None
    return obj