Code Example #1
File: load.py Project: jforlines23/tensorflow
  def del_tracking(self):
    """Removes tracked references that are only used when loading the model."""
    # Now that the node object has been fully loaded, and the checkpoint has
    # been restored, the object no longer needs to track objects added from
    # SerializedAttributes. (Note that saving a training checkpoint still
    # functions correctly, because layers and variables are tracked separately
    # by the Layer object.)
    # TODO(kathywu): Instead of outright deleting these nodes (which would
    # make restoring from a different checkpoint tricky), mark them as extra
    # dependencies that are OK to overwrite.
    for node in self.loaded_nodes.values():
      node = node[0]
      if not isinstance(node, base_layer.Layer):
        # Loaded nodes can contain other trackable objects created when
        # loading layers from the config, such as variables.
        continue
      for name in PUBLIC_ATTRIBUTES:
        delete_tracking(node, name)

      if isinstance(node, functional_lib.Functional):
        # Delete the temporary layer dependencies, which were used to restore
        # the checkpointed values. When the model is live, the user can delete
        # or add layers to the model at any time, so these layer dependencies
        # may be obsolete.
        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
        for name in dependencies:
          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
            delete_tracking(node, name)
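
These snippets all come from Keras's SavedModel loading code (the load.py shown above). For context, here is a minimal sketch of user code that exercises this path; the file path and model are illustrative, and TF 2.x with the Keras SavedModel ("tf") format is assumed:

import tensorflow as tf

# Illustrative sketch (TF 2.x assumed): saving in the SavedModel format and loading
# it back goes through the Keras SavedModel loader shown in these examples.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
model.save("/tmp/demo_savedmodel", save_format="tf")

# load_model reconstructs the layers, restores the checkpointed variables, and then
# drops the loader-only tracked references (del_tracking in the example above).
restored = tf.keras.models.load_model("/tmp/demo_savedmodel")
print(restored(tf.zeros([1, 8])).shape)  # (1, 4)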
Code Example #2
    def __init__(self, *args, **kwargs):
        # Maps node id -> (node, revive setter function)
        # Nodes recreated from the config may generate other nodes. This list
        # records all nodes that were generated directly/indirectly from the config,
        # so that they do not get recreated multiple times.
        self._nodes_recreated_from_config = {}
        self._all_nodes_recreated_from_config = (
            object_identity.ObjectIdentityWeakSet())
        # Store all node ids that have already been traversed when tracking nodes
        # that were recreated from the config.
        self._traversed_nodes_from_config = []

        # Maps model id -> (blank model obj, list of child layers or their node ids)
        # This tracks all layers in functional and sequential models. These models
        # are only reconstructed after all of their child layers have been created.
        self.model_layer_dependencies = {}
        self._models_to_reconstruct = []

        super(KerasObjectLoader, self).__init__(*args, **kwargs)

        # Now that the node object has been fully loaded, and the checkpoint has
        # been restored, the object no longer needs to track objects added from
        # SerializedAttributes. (Note that saving a training checkpoint still
        # functions correctly, because layers and variables are tracked separately
        # by the Layer object.)
        # TODO(kathywu): Instead of outright deleting these nodes (which would
        # make restoring from a different checkpoint tricky), mark them as extra
        # dependencies that are OK to overwrite.
        for node in self._nodes:
            if not isinstance(node, base_layer.Layer):
                continue
            for name in PUBLIC_ATTRIBUTES:
                delete_tracking(node, name)
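
The _all_nodes_recreated_from_config collection above is an ObjectIdentityWeakSet, which decides membership by object identity rather than by __eq__/__hash__, and holds its members weakly so it does not keep loaded objects alive. A small stand-alone illustration of the identity behaviour (the Node class here is hypothetical, not part of the source):

from tensorflow.python.util import object_identity

class Node:
  """Hypothetical object whose value equality would collapse distinct instances."""

  def __eq__(self, other):
    return isinstance(other, Node)

  def __hash__(self):
    return 0

a, b = Node(), Node()
tracked = object_identity.ObjectIdentityWeakSet()
tracked.add(a)
tracked.add(b)
assert a in tracked and b in tracked
assert len(tracked) == 2  # both kept: membership is decided by identity, not equality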
Code Example #3
File: load.py Project: jungmin-yoon1/tensorflow
    def _load_all(self):
        """Reconstruct the object graph from the SavedModel."""
        # Load layer and model objects from either config or SavedModel. The objects
        # loaded from config may create variables / other objects during
        # initialization. These are recorded in `_nodes_recreated_from_config`.
        self._layer_nodes = self._load_layers()

        # Load all other nodes and functions.
        super(KerasObjectLoader, self)._load_all()

        # Finish setting up layers and models. See function docstring for more info.
        self._finalize_objects()

        # Now that the node object has been fully loaded, the object no longer needs
        # to track objects added from SerializedAttributes. (Note that saving a
        # training checkpoint still functions correctly, because layers and
        # variables are tracked separately by the Layer object.)
        # TODO(kathywu): Instead of outright deleting these nodes (which would
        # make restoring from a different checkpoint tricky), mark them as extra
        # dependencies that are OK to overwrite.
        for node in self._nodes:
            if not isinstance(node, base_layer.Layer):
                continue
            for name in PUBLIC_ATTRIBUTES:
                delete_tracking(node, name)
Code Example #4
    def _finalize(self):
        # pylint: disable=protected-access
        for node in self._nodes:
            if isinstance(node, RevivedLayer):
                if not isinstance(node, RevivedSequential):
                    if hasattr(node.keras_api,
                               'call_and_return_conditional_losses'):
                        node.call = utils.use_wrapped_call(
                            node,
                            node.keras_api.call_and_return_conditional_losses,
                            return_method=True)
                        node._init_call_fn_args()

        for node in self._nodes:
            if isinstance(node, RevivedModel):
                call_fn = node.keras_api.call_and_return_conditional_losses
                if call_fn.input_signature is None:
                    inputs = infer_inputs_from_restored_call_function(call_fn)
                else:
                    inputs = call_fn.input_signature[0]
                if isinstance(node, RevivedSequential):
                    with trackable.no_automatic_dependency_tracking_scope(
                            node):
                        node._layers = []
                    for layer in node.keras_api.layers:
                        node.add(layer)

                if not node.inputs:
                    # Since this revived object is technically a subclassed model (even if
                    # the original model is functional/sequential), inputs should be set.
                    node._set_inputs(inputs)
            if isinstance(node, RevivedLayer):
                if hasattr(node.keras_api, 'layer_regularization_losses'):
                    losses = getattr(node.keras_api,
                                     'layer_regularization_losses', [])
                else:
                    # Some earlier SavedModels may not have layer_regularization_losses
                    # serialized separately. Fall back to using the regularization_losses
                    # list if it does not exist.
                    losses = node._serialized_attributes.get(
                        'regularization_losses', [])
                for loss in losses:
                    node.add_loss(loss)

                # Use wrapped activity regularizer function if the layer's activity
                # regularizer wasn't created during initialization.
                if node.activity_regularizer is None:
                    node.activity_regularizer = getattr(
                        node.keras_api, 'activity_regularizer_fn', None)

                # Now that the node object has been fully loaded and restored from the
                # checkpoint, the object no longer needs to track objects added from
                # SerializedAttributes. (Note that saving a training checkpoint still
                # functions correctly, because layers and variables are tracked
                # separately by the Layer object.)
                # TODO(kathywu): Instead of outright deleting these nodes (which would
                # make restoring from a different checkpoint tricky), mark them as extra
                # dependencies that are OK to overwrite.
                for name in PUBLIC_ATTRIBUTES:
                    delete_tracking(node, name)
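
_finalize recovers the model's input spec either from the saved call function's input_signature or, when that is None, from the concrete functions already traced (infer_inputs_from_restored_call_function). Related signature information is also visible from user code through the SavedModel's serving signature; a small sketch, with illustrative paths and TF 2.x assumed:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.save("/tmp/sig_demo", save_format="tf")

# The raw SavedModel exposes the traced serving signature, which records the
# shape/dtype the model was built with.
loaded = tf.saved_model.load("/tmp/sig_demo")
serving = loaded.signatures["serving_default"]
print(serving.structured_input_signature)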
Code Example #5
File: load.py Project: zheng568/tensorflow
  def del_tracking(self):
    """Removes tracked references that are only used when loading the model."""
    # Now that the node object has been fully loaded, and the checkpoint has
    # been restored, the object no longer needs to track objects added from
    # SerializedAttributes. (Note that saving a training checkpoint still
    # functions correctly, because layers and variables are tracked separately
    # by the Layer object.)
    # TODO(kathywu): Instead of outright deleting these nodes (which would
    # make restoring from a different checkpoint tricky), mark them as extra
    # dependencies that are OK to overwrite.
    for node in self.loaded_nodes.values():
      node = node[0]
      if not isinstance(node, base_layer.Layer):
        # Loaded nodes can contain other trackable objects created when
        # loading layers from the config, such as variables.
        continue
      for name in PUBLIC_ATTRIBUTES:
        delete_tracking(node, name)
Code Example #6
File: load.py Project: romeokienzler/tensorflow-1
    def _finalize(self):
        # pylint: disable=protected-access

        # Set up call functions for all layers (skip this step for Sequential and
        # Functional models).
        for node in self._nodes:
            if isinstance(node, RevivedLayer):
                node.built = True
                is_graph_network = node._serialized_attributes['metadata'].get(
                    'is_graph_network', False)
                if not (isinstance(node, models_lib.Sequential)
                        or is_graph_network):
                    if hasattr(node.keras_api,
                               'call_and_return_conditional_losses'):
                        node.call = utils.use_wrapped_call(
                            node,
                            node.keras_api.call_and_return_conditional_losses,
                            return_method=True)
                        node._init_call_fn_args()

        for node in self._nodes:
            if isinstance(node, RevivedNetwork):
                call_fn = node.keras_api.call_and_return_conditional_losses
                if call_fn.input_signature is None:
                    inputs = infer_inputs_from_restored_call_function(call_fn)
                else:
                    inputs = call_fn.input_signature[0]

                # Set model inputs and outputs.
                is_graph_network = node._serialized_attributes['metadata'].get(
                    'is_graph_network', False)
                if isinstance(node, models_lib.Sequential):
                    with trackable.no_automatic_dependency_tracking_scope(
                            node):
                        node._layers = []
                    for layer in node.keras_api.layers:
                        node.add(layer)
                elif is_graph_network:
                    # Reconstruct functional model from the config and layers loaded
                    # from the SavedModel.
                    inputs, outputs, _ = network_lib.reconstruct_from_config(
                        node.get_config(),
                        created_layers={
                            layer.name: layer
                            for layer in node.layers
                        })
                    node._init_graph_network(
                        inputs,
                        outputs,
                        name=node._serialized_attributes['metadata']['name'])
                    # Set the metadata attributes once more, since _init_graph_network
                    # resets these attributes.
                    _set_network_attributes_from_metadata(node)
                else:  # Model is subclassed.
                    node._set_inputs(inputs)

            # Add unconditional losses.
            if isinstance(node, RevivedLayer):
                if hasattr(node.keras_api, 'layer_regularization_losses'):
                    losses = getattr(node.keras_api,
                                     'layer_regularization_losses', [])
                else:
                    # Some earlier SavedModels may not have layer_regularization_losses
                    # serialized separately. Fall back to using the regularization_losses
                    # list if it does not exist.
                    losses = node._serialized_attributes.get(
                        'regularization_losses', [])
                for loss in losses:
                    node.add_loss(loss)

                # Use wrapped activity regularizer function if the layer's activity
                # regularizer wasn't created during initialization.
                if node.activity_regularizer is None:
                    node.activity_regularizer = getattr(
                        node.keras_api, 'activity_regularizer_fn', None)

                # Now that the node object has been fully loaded and restored from the
                # checkpoint, the object no longer needs to track objects added from
                # SerializedAttributes. (Note that saving a training checkpoint still
                # functions correctly, because layers and variables are tracked
                # separately by the Layer object.)
                # TODO(kathywu): Instead of outright deleting these nodes (which would
                # make restoring from a different checkpoint tricky), mark them as extra
                # dependencies that are OK to overwrite.
                for name in PUBLIC_ATTRIBUTES:
                    delete_tracking(node, name)
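
The loss-restoring step in _finalize re-attaches serialized regularization losses with add_loss, so they surface again through the standard losses property on the loaded model. A minimal sketch (illustrative paths, TF 2.x assumed):

import tensorflow as tf

# A layer with a kernel regularizer carries an unconditional regularization loss.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(
        4, kernel_regularizer=tf.keras.regularizers.l2(0.01), input_shape=(8,))
])
model.save("/tmp/reg_demo", save_format="tf")

restored = tf.keras.models.load_model("/tmp/reg_demo")
restored(tf.zeros([1, 8]))   # a forward pass works as before
print(restored.losses)       # includes the restored L2 regularization loss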