コード例 #1
0
    def __init__(self, obj, name, call_with_mapped_captures=None):
        """Builds one SaveSpec per serialized tensor (or tensor slice) of `obj`.

        Args:
          obj: The trackable whose `_serialize_to_tensors` result is wrapped.
          name: Base name; each SaveSpec is named `name` followed by the escaped
            local tensor name, and `name` is also passed to the parent
            constructor.
          call_with_mapped_captures: Optional callable. When it is truthy and the
            save function is a `core.ConcreteFunction`, the save function is
            invoked as `call_with_mapped_captures(save_fn, [])` instead of being
            called directly.
        """
        self._trackable = obj
        self._call_with_mapped_captures = call_with_mapped_captures

        serialize_fn = obj._serialize_to_tensors  # pylint: disable=protected-access
        use_mapped_call = (call_with_mapped_captures
                           and isinstance(serialize_fn, core.ConcreteFunction))
        tensor_dict = (call_with_mapped_captures(serialize_fn, [])
                       if use_mapped_call else serialize_fn())

        self._local_names = []
        self._prefix = saveable_compat.get_saveable_name(self._trackable) or ""
        spec_list = []
        for local_name, value in tensor_dict.items():
            self._local_names.append(local_name)
            full_spec_name = name + trackable_utils.escape_local_name(local_name)
            # A non-dict value is treated as a single entry with an empty slice
            # spec; a dict value maps slice specs to tensors.
            slices = value if isinstance(value, dict) else {"": value}
            spec_list.extend(
                saveable_object.SaveSpec(slice_tensor, slice_spec, full_spec_name)
                for slice_spec, slice_tensor in slices.items())
        super(TrackableSaveable, self).__init__(obj, spec_list, name)
コード例 #2
0
    def _create_serialize_to_tensor_saveable(self, saveable_factories):
        """Creates a saveable using the `_serialize_to_tensors` method.

        In graph mode, a restore op already registered under the saveable name,
        or a saveable already present in the per-trackable saveables cache, is
        reused instead of calling the factory again.

        Args:
          saveable_factories: Mapping of factories; the entry under
            `trackable_utils.SERIALIZE_TO_TENSORS_NAME` is called as
            `factory(name=saveable_name)` when nothing cached is found.

        Returns:
          A tuple `(existing_restore_ops, named_saveables)`:
          `existing_restore_ops` is a list (empty unless a cached restore op was
          found) and `named_saveables` maps the saveable name to the saveable.
        """
        # Extract the saveable name from the checkpoint key. This is used as the
        # cache key and as the name passed to the saveable factory.
        suffix = saveable_compat.get_saveable_name(self.trackable) or ""
        saveable_name = _extract_saveable_name(
            self.object_proto.attributes[0].checkpoint_key) + suffix

        # Query the execution mode once so the cache lookup and the cache write
        # below always agree; querying twice could leave `saveables_cache`
        # unbound if the answer changed in between.
        graph_mode = not context.executing_eagerly()
        saveables_cache = None

        # Try to find the cached saveable (only in graph mode).
        if graph_mode:
            existing_op = self._checkpoint.restore_ops_by_name.get(
                saveable_name, None)
            if existing_op is not None:
                # Return a list, like the other return paths. This assumes
                # `restore_ops_by_name` values are single ops; TODO confirm.
                return [existing_op], {}

            saveables_cache = self._checkpoint.saveables_cache.setdefault(
                self.trackable, {})
            if saveable_name in saveables_cache:
                return [], {saveable_name: saveables_cache[saveable_name]}

        saveable = saveable_factories[
            trackable_utils.SERIALIZE_TO_TENSORS_NAME](name=saveable_name)
        if graph_mode:
            saveables_cache[saveable_name] = saveable
        return [], {saveable_name: saveable}
コード例 #3
0
def extract_saveable_name(trackable, checkpoint_key):
  """Returns the saveable name for `trackable` derived from `checkpoint_key`.

  Args:
    trackable: Object whose legacy saveable name (if any) is looked up via
      `saveable_compat.get_saveable_name`.
    checkpoint_key: Checkpoint key string to derive the name from.

  Returns:
    `checkpoint_key` unchanged when `trackable` has a legacy saveable name.
    Otherwise, the prefix of `checkpoint_key` that runs through the first
    occurrence of `trackable_utils.OBJECT_ATTRIBUTES_NAME` plus the one
    character after it (the "/" in ".ATTRIBUTES/").

  Raises:
    ValueError: If `OBJECT_ATTRIBUTES_NAME` does not occur in `checkpoint_key`
      (raised by `str.index`).
  """
  if saveable_compat.get_saveable_name(trackable) is not None:
    # If there is a legacy saveable name, the saveable name is the checkpoint
    # key.
    return checkpoint_key
  # Keep everything through the attributes marker and its trailing "/".
  # Derive the length from the constant instead of hard-coding 12, so this
  # stays correct if the marker ever changes.
  attributes_name = trackable_utils.OBJECT_ATTRIBUTES_NAME
  end = checkpoint_key.index(attributes_name) + len(attributes_name) + len("/")
  return checkpoint_key[:end]
コード例 #4
0
ファイル: save_util_v1.py プロジェクト: wwjiang007/tensorflow
def get_checkpoint_factories_and_keys(object_names, object_map=None):
    """Collects saveable factories and checkpoint keys for each trackable.

    Args:
      object_names: dict mapping `Trackable` objects to their auto-generated
        string names.
      object_map: dict mapping `Trackable` objects to copied `Trackable`
        objects. The copies come from `Trackable._map_resources()`, which
        copies an object into another graph; typically only resource objects
        (e.g. Variables, Tables) appear here.

    Returns:
      A tuple of (
        dict mapping trackable -> list of _CheckpointFactoryData,
        dict mapping registered saver name -> {object name -> trackable})
    """
    checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
    unmapped_registered_savers = collections.defaultdict(dict)
    for trackable, object_name in object_names.items():
        # `object_to_save` is consulted only for its saving functionality; keys
        # and other bookkeeping use the original `trackable`.
        object_to_save = util.get_mapped_trackable(trackable, object_map)

        saver_name = registration.get_registered_saver_name(object_to_save)
        if saver_name:
            # Record the original trackable (not `object_to_save`): the
            # original is what the object proto is written for.
            unmapped_registered_savers[saver_name][object_name] = trackable
            continue

        factories = saveable_object_util.saveable_objects_from_trackable(
            object_to_save)
        factory_data_list = []
        for attr_name, saveable_factory in factories.items():
            # A legacy saveable name (kept for compatibility during the
            # SaveableObject deprecation) takes precedence as the key suffix.
            key_suffix = (saveable_compat.get_saveable_name(object_to_save)
                          or attr_name)
            checkpoint_key = trackable_utils.checkpoint_key(
                object_name, key_suffix)

            # Unless checkpoint conversion is forced, the factory is named by
            # the key suffix (i.e. the legacy saveable name when one exists).
            if not saveable_compat.force_checkpoint_conversion_enabled():
                attr_name = key_suffix

            factory_data_list.append(
                _CheckpointFactoryData(factory=saveable_factory,
                                       name=attr_name,
                                       checkpoint_key=checkpoint_key))
        checkpoint_factory_map[trackable] = factory_data_list
    return checkpoint_factory_map, unmapped_registered_savers
コード例 #5
0
def trace_save_and_restore(obj):
    """Traces `Trackable` serialize- and restore-from-tensors functions.

    If `obj` has a legacy saveable name, the traced save function prefixes
    every output key with that name, and the traced restore function strips
    the same number of leading characters from each input key before calling
    `obj._restore_from_tensors`. Functions that are already
    `defun.ConcreteFunction` instances are returned as-is.

    Args:
      obj: A `Trackable` object.

    Returns:
      A tuple `(concrete_save, concrete_restore)` of concrete functions. The
      restore function is traced against `concrete_save.structured_outputs`.
    """
    legacy_name = saveable_compat.get_saveable_name(obj)

    obj_save_fn = obj._serialize_to_tensors  # pylint: disable=protected-access
    obj_restore_fn = obj._restore_from_tensors  # pylint: disable=protected-access

    if isinstance(obj_save_fn, defun.ConcreteFunction):
        # Already traced: used as-is, so no legacy-name prefix is added here.
        concrete_save = obj_save_fn
    else:

        @def_function.function
        def save_fn():
            tensor_dict = obj_save_fn()
            if legacy_name:
                # If there is a legacy decorator, append the name to the keys.
                return {
                    f"{legacy_name}{key}": value
                    for key, value in tensor_dict.items()
                }
            return tensor_dict

        concrete_save = save_fn.get_concrete_function()

    if isinstance(obj_restore_fn, defun.ConcreteFunction):
        concrete_restore = obj_restore_fn
    else:

        @def_function.function
        def restore_fn(restored_tensors):
            if legacy_name:
                # Do the opposite operation of save_fn(): drop the legacy-name
                # prefix from each key.
                # NOTE(review): if `_serialize_to_tensors` was already a
                # ConcreteFunction, the save outputs were not prefixed above,
                # yet this still strips `len(legacy_name)` characters. Confirm
                # that combination cannot occur.
                restored_tensors = {
                    key[len(legacy_name):]: value
                    for key, value in restored_tensors.items()
                }
            obj_restore_fn(restored_tensors)

        concrete_restore = restore_fn.get_concrete_function(
            concrete_save.structured_outputs)

    return concrete_save, concrete_restore
コード例 #6
0
    def test_multiple_specs_single_saveable(self):
        """A multi-spec legacy saveable converts to one tensor per spec."""

        class MyTrackable(base.Trackable):

            def __init__(self):
                self.a = variables.Variable(35.0)
                self.b = variables.Variable(40.0)

            def _gather_saveables_for_checkpoint(self):
                return {"foo": lambda name: _MultiSpecSaveable(self, name)}

        trackable = MyTrackable()
        converter = saveable_object_util.SaveableCompatibilityConverter(
            trackable)

        # Serialization yields one entry per SaveSpec.
        tensors = converter._serialize_to_tensors()
        self.assertLen(tensors, 2)
        for key, expected in (("foo-a", 35), ("foo-b", 40)):
            self.assertEqual(expected, self.evaluate(tensors[key]))

        # Restoring writes the values back into the underlying variables.
        converter._restore_from_tensors({"foo-a": 5., "foo-b": 6.})
        self.assertEqual(5, self.evaluate(trackable.a))
        self.assertEqual(6, self.evaluate(trackable.b))

        # The converter must carry the legacy saveable name.
        self.assertEqual("foo", saveable_compat.get_saveable_name(converter))
コード例 #7
0
def _add_attributes_to_object_graph_for_saveable_objects(
        checkpoint_factory_map, object_graph_proto, node_ids, object_map,
        call_with_mapped_captures, saveables_cache):
    """Create SaveableObjects and corresponding SerializedTensor protos.

    For each trackable in `checkpoint_factory_map`, builds the SaveableObjects
    produced by its factories, or reuses them from `saveables_cache`. It
    collects them for saving and appends matching attributes to the
    trackable's node in `object_graph_proto`.

    Args:
      checkpoint_factory_map: Mapping from trackable to a list of factory data
        entries, each exposing `.name`, `.checkpoint_key` and `.factory`.
      object_graph_proto: Object graph proto; attributes are appended to
        `object_graph_proto.nodes[node_ids[trackable]]`.
      node_ids: Mapping from trackable to its index in
        `object_graph_proto.nodes`.
      object_map: Passed to `_get_mapped_trackable` to obtain the object used
        as the cache key and for naming.
      call_with_mapped_captures: Forwarded to
        `saveable_object_util.create_saveable_object` for callable factories.
      saveables_cache: `None` to disable caching. Otherwise a dict keyed by the
        mapped object; each value maps an attribute name to a tuple of
        SaveableObjects.

    Returns:
      A tuple `(named_saveable_objects, feed_additions)`:
        - `named_saveable_objects`: list of SaveableObjects to save.
        - `feed_additions`: `None` when caching is disabled; otherwise a dict
          merged from each `PythonStateSaveable.feed_dict_additions()`.

    Raises:
      AssertionError: If a created SaveableObject's name does not contain its
        checkpoint key, or if two objects try to feed the same tensor.
    """
    named_saveable_objects = []
    if saveables_cache is None:
        # No SaveableObject caching. Either we're executing eagerly, or building a
        # static save which is specialized to the current Python state.
        feed_additions = None
    else:
        # If we are caching SaveableObjects, we need to build up a feed_dict with
        # functions computing volatile Python state to be saved with the
        # checkpoint.
        feed_additions = {}
    for trackable, factory_data_list in checkpoint_factory_map.items():
        # `trackable` locates the proto node. `object_to_save` (the mapped copy,
        # if any) keys the cache and supplies the names written to the proto.
        object_proto = object_graph_proto.nodes[node_ids[trackable]]
        object_to_save = _get_mapped_trackable(trackable, object_map)
        if saveables_cache is not None:
            cached_attributes = saveables_cache.setdefault(object_to_save, {})
        else:
            cached_attributes = None

        for factory_data in factory_data_list:
            name = factory_data.name
            key = factory_data.checkpoint_key
            saveable_factory = factory_data.factory

            # See if we can skip saving this checkpoint key.
            saveables = cached_attributes.get(
                name) if cached_attributes else None
            if saveables is not None:
                for saveable in saveables:
                    if key not in saveable.name:
                        # The checkpoint key for this SaveableObject is different. We
                        # need to re-create it.
                        saveables = None
                        del cached_attributes[name]
                        break

            if saveables is None:
                # Build the saveables: call the factory if it is callable.
                # Otherwise use it directly. Its result is either a single
                # SaveableObject or something `saveable_objects_for_op` expands.
                if callable(saveable_factory):
                    maybe_saveable = saveable_object_util.create_saveable_object(
                        saveable_factory, key, call_with_mapped_captures)
                else:
                    maybe_saveable = saveable_factory
                if isinstance(maybe_saveable,
                              saveable_object_lib.SaveableObject):
                    saveables = (maybe_saveable, )
                else:
                    saveables = tuple(
                        saveable_object_util.saveable_objects_for_op(
                            op=maybe_saveable, name=key))
                # Every saveable's name must embed its checkpoint key; the cache
                # check above relies on the same property.
                for saveable in saveables:
                    if key not in saveable.name:
                        raise AssertionError(
                            f"The object {trackable} produced a SaveableObject with name "
                            f"'{saveable.name}' for attribute '{name}'. Expected a name"
                            f" containing '{key}'.")
                if cached_attributes is not None:
                    cached_attributes[name] = saveables

            for saveable in saveables:
                if isinstance(saveable, python_state.PythonStateSaveable):
                    if feed_additions is None:
                        assert saveables_cache is None
                        # If we're not caching saveables, then we're either executing
                        # eagerly or building a static save/restore (e.g. for a
                        # SavedModel). In either case, we should embed the current Python
                        # state in the graph rather than relying on a feed dict.
                        saveable = saveable.freeze()
                    else:
                        # Collect this saveable's feeds, rejecting any tensor that
                        # another object already feeds.
                        saveable_feed_dict = saveable.feed_dict_additions()
                        for new_feed_key in saveable_feed_dict.keys():
                            if new_feed_key in feed_additions:
                                raise AssertionError(
                                    f"The object {trackable} tried to feed a value for the "
                                    f"Tensor {new_feed_key} when saving, but another object "
                                    "is already feeding a value.")
                        feed_additions.update(saveable_feed_dict)
                named_saveable_objects.append(saveable)

            # Update the object proto.
            # For updated Trackables that override serialize_to_tensors, add an
            # attribute for each tensor that is serialized.
            # For Trackables that have SaveableObjects or a legacy saveable name,
            # add a single attribute to the proto.
            # NOTE(review): `saveables[0]` assumes at least one SaveableObject.
            # An empty tuple from `saveable_objects_for_op` would raise
            # IndexError here; confirm that cannot happen.
            if (isinstance(saveables[0],
                           saveable_object_util.TrackableSaveable) and
                    saveable_compat.get_saveable_name(object_to_save) is None):
                for local_name, local_key in (
                        saveables[0].get_proto_names_and_checkpoint_keys()):
                    object_proto.attributes.add(
                        name=local_name,
                        checkpoint_key=local_key,
                        full_name=_get_full_name(object_to_save))
            else:
                object_proto.attributes.add(
                    name=name,
                    checkpoint_key=key,
                    full_name=_get_full_name(object_to_save))

    return named_saveable_objects, feed_additions