def del_tracking(self):
  """Removes tracked references that are only used when loading the model.

  For every loaded object that is a Keras `Layer`, the dependencies named in
  `PUBLIC_ATTRIBUTES` are untracked. For `Functional` models, the temporary
  `layer-<N>` / `layer_with_weights-<N>` dependencies (used only to restore
  checkpointed values) are untracked as well. Non-layer loaded objects are
  left alone.
  """
  # Now that the node object has been fully loaded, and the checkpoint has
  # been restored, the object no longer needs to track objects added from
  # SerializedAttributes. (Note that saving a training checkpoint still
  # functions correctly, because layers and variables are tracked separately
  # by the Layer object.)
  # TODO(kathywu): Instead of outright deleting these nodes (which would
  # make restoring from a different checkpoint tricky), mark them as extra
  # dependencies that are OK to overwrite.

  # `\d+` means "one or more digits". The character class `[\d+]` would
  # instead match a single digit or a literal "+". Compiled once, outside
  # the loop.
  layer_dependency_pattern = re.compile(r'^layer(_with_weights)?-\d+')

  for entry in self.loaded_nodes.values():
    # The loaded object is the first element of each stored entry.
    node = entry[0]
    if not isinstance(node, base_layer.Layer):
      # Loaded nodes can contain other trackable objects created when
      # loading layers from the config, such as variables.
      continue
    for name in PUBLIC_ATTRIBUTES:
      delete_tracking(node, name)

    if isinstance(node, functional_lib.Functional):
      # Delete the temporary layer dependencies, which were used to restore
      # the checkpointed values. When the model is live, the user can delete
      # or add layers to the model at any time, so these layer dependencies
      # may be obsolete.
      # Snapshot the names first. delete_tracking presumably removes entries
      # from this collection, so iterating it directly could be unsafe.
      dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
      for name in dependencies:
        if layer_dependency_pattern.match(name) is not None:
          delete_tracking(node, name)
def __init__(self, *args, **kwargs):
  """Sets up revival bookkeeping, runs the parent loader, then cleans up.

  All bookkeeping attributes are assigned before delegating to the parent
  constructor; NOTE(review): presumably because the parent constructor
  drives loading and reads them — confirm. After the parent constructor
  returns, dependencies named in `PUBLIC_ATTRIBUTES` are untracked from every
  loaded `Layer`.
  """
  # Node id -> (node, revive setter function) for nodes recreated from config.
  self._nodes_recreated_from_config = {}
  # Recreating a node from its config can create further nodes. Every node
  # produced directly or indirectly from a config is remembered here so that
  # it is not recreated again.
  self._all_nodes_recreated_from_config = (
      object_identity.ObjectIdentityWeakSet())
  # Ids of nodes already visited while tracking config-recreated nodes.
  self._traversed_nodes_from_config = []
  # Model id -> (blank model obj, list of child layers or their node ids).
  # Functional and Sequential models are only rebuilt once all of their
  # child layers exist.
  self.model_layer_dependencies = {}
  self._models_to_reconstruct = []

  super(KerasObjectLoader, self).__init__(*args, **kwargs)

  # Loading and checkpoint restoration are complete, so the objects added
  # from SerializedAttributes no longer need to be tracked. Training
  # checkpoints are still saved correctly because the Layer object tracks
  # its layers and variables separately.
  # TODO(kathywu): Instead of outright deleting these nodes (which would
  # make restoring from a different checkpoint tricky), mark them as extra
  # dependencies that are OK to overwrite.
  revived_layers = (
      candidate for candidate in self._nodes
      if isinstance(candidate, base_layer.Layer))
  for layer in revived_layers:
    for attr_name in PUBLIC_ATTRIBUTES:
      delete_tracking(layer, attr_name)
def _load_all(self):
  """Reconstructs the object graph from the SavedModel.

  Phases, in order:
    1. Load layer and model objects (from config or from the SavedModel).
    2. Load every remaining node and function through the parent loader.
    3. Finish setting up layers and models via `_finalize_objects`.
    4. Untrack the `PUBLIC_ATTRIBUTES` dependencies on every loaded `Layer`.
  """
  # Phase 1. Objects loaded from config may create variables or other
  # objects while initializing; those are recorded in
  # `_nodes_recreated_from_config`.
  self._layer_nodes = self._load_layers()

  # Phase 2.
  super(KerasObjectLoader, self)._load_all()

  # Phase 3. See `_finalize_objects` for details.
  self._finalize_objects()

  # Phase 4. The objects added from SerializedAttributes are no longer
  # needed once loading is finished. Training checkpoints are still saved
  # correctly because the Layer object tracks layers and variables
  # separately.
  # TODO(kathywu): Instead of outright deleting these nodes (which would
  # make restoring from a different checkpoint tricky), mark them as extra
  # dependencies that are OK to overwrite.
  for layer in filter(lambda obj: isinstance(obj, base_layer.Layer),
                      self._nodes):
    for attr_name in PUBLIC_ATTRIBUTES:
      delete_tracking(layer, attr_name)
def _finalize(self): # pylint: disable=protected-access for node in self._nodes: if isinstance(node, RevivedLayer): if not isinstance(node, RevivedSequential): if hasattr(node.keras_api, 'call_and_return_conditional_losses'): node.call = utils.use_wrapped_call( node, node.keras_api.call_and_return_conditional_losses, return_method=True) node._init_call_fn_args() for node in self._nodes: if isinstance(node, RevivedModel): call_fn = node.keras_api.call_and_return_conditional_losses if call_fn.input_signature is None: inputs = infer_inputs_from_restored_call_function(call_fn) else: inputs = call_fn.input_signature[0] if isinstance(node, RevivedSequential): with trackable.no_automatic_dependency_tracking_scope( node): node._layers = [] for layer in node.keras_api.layers: node.add(layer) if not node.inputs: # Since this revived object is technically a subclassed model (even if # the original model is functional/sequential), inputs should be set. node._set_inputs(inputs) if isinstance(node, RevivedLayer): if hasattr(node.keras_api, 'layer_regularization_losses'): losses = getattr(node.keras_api, 'layer_regularization_losses', []) else: # Some earlier SavedModels may not have layer_regularization_losses # serialized separately. Fall back to using the regularization_losses # list if it does not exist. losses = node._serialized_attributes.get( 'regularization_losses', []) for loss in losses: node.add_loss(loss) # Use wrapped activity regularizer function if the layer's activity # regularizer wasn't created during initialization. if node.activity_regularizer is None: node.activity_regularizer = getattr( node.keras_api, 'activity_regularizer_fn', None) # Now that the node object has been fully loaded and restored from the, # checkpoint, the object no longer needs to track objects added from # SerializedAttributes. (Note that saving a training checkpoint still # functions correctly, because layers and variables are tracked # separately by the Layer object.) 
# TODO(kathywu): Instead of outright deleting these nodes (which would # make restoring from a different checkpoint tricky), mark them as extra # dependencies that are OK to overwrite. for name in PUBLIC_ATTRIBUTES: delete_tracking(node, name)
def del_tracking(self):
  """Untracks the load-time-only references held by loaded layers.

  After loading has finished and the checkpoint has been restored, each
  loaded Keras `Layer` stops tracking the dependencies named in
  `PUBLIC_ATTRIBUTES`. Loaded objects that are not layers are skipped.
  """
  # The SerializedAttributes dependencies are only needed while loading.
  # Training checkpoints are still saved correctly afterwards because the
  # Layer object tracks its layers and variables separately.
  # TODO(kathywu): Instead of outright deleting these nodes (which would
  # make restoring from a different checkpoint tricky), mark them as extra
  # dependencies that are OK to overwrite.
  for loaded_entry in self.loaded_nodes.values():
    obj = loaded_entry[0]
    # Non-layer nodes (e.g. variables created while loading layers from
    # their config) are intentionally left untouched.
    if isinstance(obj, base_layer.Layer):
      for attr_name in PUBLIC_ATTRIBUTES:
        delete_tracking(obj, attr_name)
def _finalize(self): # pylint: disable=protected-access # Set up call functions for all layers (skip this step for Sequential and # Functional models). for node in self._nodes: if isinstance(node, RevivedLayer): node.built = True is_graph_network = node._serialized_attributes['metadata'].get( 'is_graph_network', False) if not (isinstance(node, models_lib.Sequential) or is_graph_network): if hasattr(node.keras_api, 'call_and_return_conditional_losses'): node.call = utils.use_wrapped_call( node, node.keras_api.call_and_return_conditional_losses, return_method=True) node._init_call_fn_args() for node in self._nodes: if isinstance(node, RevivedNetwork): call_fn = node.keras_api.call_and_return_conditional_losses if call_fn.input_signature is None: inputs = infer_inputs_from_restored_call_function(call_fn) else: inputs = call_fn.input_signature[0] # Set model inputs and outputs. is_graph_network = node._serialized_attributes['metadata'].get( 'is_graph_network', False) if isinstance(node, models_lib.Sequential): with trackable.no_automatic_dependency_tracking_scope( node): node._layers = [] for layer in node.keras_api.layers: node.add(layer) elif is_graph_network: # Reconstruct functional model from the config and layers loaded # from the SavedModel. inputs, outputs, _ = network_lib.reconstruct_from_config( node.get_config(), created_layers={ layer.name: layer for layer in node.layers }) node._init_graph_network( inputs, outputs, name=node._serialized_attributes['metadata']['name']) # Set the metadata attributes once more, since _init_graph_network # resets these attributes. _set_network_attributes_from_metadata(node) else: # Model is subclassed. node._set_inputs(inputs) # Add unconditional losses. if isinstance(node, RevivedLayer): if hasattr(node.keras_api, 'layer_regularization_losses'): losses = getattr(node.keras_api, 'layer_regularization_losses', []) else: # Some earlier SavedModels may not have layer_regularization_losses # serialized separately. 
Fall back to using the regularization_losses # list if it does not exist. losses = node._serialized_attributes.get( 'regularization_losses', []) for loss in losses: node.add_loss(loss) # Use wrapped activity regularizer function if the layer's activity # regularizer wasn't created during initialization. if node.activity_regularizer is None: node.activity_regularizer = getattr( node.keras_api, 'activity_regularizer_fn', None) # Now that the node object has been fully loaded and restored from the, # checkpoint, the object no longer needs to track objects added from # SerializedAttributes. (Note that saving a training checkpoint still # functions correctly, because layers and variables are tracked # separately by the Layer object.) # TODO(kathywu): Instead of outright deleting these nodes (which would # make restoring from a different checkpoint tricky), mark them as extra # dependencies that are OK to overwrite. for name in PUBLIC_ATTRIBUTES: delete_tracking(node, name)