Example #1
    def _add_attributes_to_object_graph_for_saveable_objects(
            self, checkpoint_factory_map, object_graph_proto, node_ids,
            object_map, call_with_mapped_captures):
        """Create SaveableObjects and corresponding SerializedTensor protos."""
        named_saveable_objects = []
        if self._saveables_cache is None:
            # No SaveableObject caching. Either we're executing eagerly, or building a
            # static save which is specialized to the current Python state.
            feed_additions = None
        else:
            # If we are caching SaveableObjects, we need to build up a feed_dict with
            # functions computing volatile Python state to be saved with the
            # checkpoint.
            feed_additions = {}
        for trackable, factory_data_list in checkpoint_factory_map.items():
            object_proto = object_graph_proto.nodes[node_ids[trackable]]
            if self._saveables_cache is not None:
                object_to_save = _get_mapped_trackable(trackable, object_map)
                cached_attributes = self._saveables_cache.setdefault(
                    object_to_save, {})
            else:
                cached_attributes = None

            for factory_data in factory_data_list:
                attribute = object_proto.attributes.add()
                attribute.name = name = factory_data.name
                attribute.checkpoint_key = key = factory_data.checkpoint_key
                saveable_factory = factory_data.factory

                # See if we can skip saving this checkpoint key.
                saveables = cached_attributes.get(
                    name) if cached_attributes else None
                if saveables is not None:
                    for saveable in saveables:
                        if key not in saveable.name:
                            # The checkpoint key for this SaveableObject is different. We
                            # need to re-create it.
                            saveables = None
                            del cached_attributes[name]
                            break

                if saveables is None:
                    if callable(saveable_factory):
                        maybe_saveable = saveable_object_util.create_saveable_object(
                            saveable_factory, key, call_with_mapped_captures)
                    else:
                        maybe_saveable = saveable_factory
                    if isinstance(maybe_saveable,
                                  saveable_object_lib.SaveableObject):
                        saveables = (maybe_saveable, )
                    else:
                        # Figure out the name-based Saver's name for this variable. If it's
                        # already a SaveableObject we'd just get the checkpoint key back, so
                        # we leave full_name blank.
                        saver_dict = saveable_object_util.op_list_to_dict(
                            [maybe_saveable], convert_variable_to_tensor=False)
                        full_name, = saver_dict.keys()
                        saveables = tuple(
                            saveable_object_util.saveable_objects_for_op(
                                op=maybe_saveable, name=key))
                        for saveable in saveables:
                            saveable.full_name = full_name
                    for saveable in saveables:
                        if key not in saveable.name:
                            raise AssertionError(
                                f"The object {trackable} produced a SaveableObject with name "
                                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                                f" containing '{key}'.")
                    if cached_attributes is not None:
                        cached_attributes[name] = saveables

                for saveable in saveables:
                    if hasattr(saveable, "full_name"):
                        attribute.full_name = saveable.full_name
                    if isinstance(saveable, base.PythonStateSaveable):
                        if feed_additions is None:
                            assert self._saveables_cache is None
                            # If we're not caching saveables, then we're either executing
                            # eagerly or building a static save/restore (e.g. for a
                            # SavedModel). In either case, we should embed the current Python
                            # state in the graph rather than relying on a feed dict.
                            saveable = saveable.freeze()
                        else:
                            saveable_feed_dict = saveable.feed_dict_additions()
                            for new_feed_key in saveable_feed_dict.keys():
                                if new_feed_key in feed_additions:
                                    raise AssertionError(
                                        f"The object {trackable} tried to feed a value for the "
                                        f"Tensor {new_feed_key} when saving, but another object "
                                        "is already feeding a value.")
                            feed_additions.update(saveable_feed_dict)
                    named_saveable_objects.append(saveable)

        return named_saveable_objects, feed_additions
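
The caching branch above keeps, per tracked object, a mapping from attribute name to the SaveableObjects built for it, and it throws an entry away as soon as the current checkpoint key no longer appears in a cached saveable's name. The standalone sketch below restates just that invalidation rule in plain Python; FakeSaveable and lookup_or_invalidate are hypothetical helpers for illustration, not part of TensorFlow.

class FakeSaveable:
    """Stand-in for a SaveableObject; only the `name` field matters here."""

    def __init__(self, name):
        self.name = name


def lookup_or_invalidate(cached_attributes, attr_name, checkpoint_key):
    """Return cached saveables for `attr_name`, dropping stale entries."""
    saveables = cached_attributes.get(attr_name)
    if saveables is not None:
        for saveable in saveables:
            if checkpoint_key not in saveable.name:
                # The key changed (e.g. the object moved in the dependency
                # graph), so the cached SaveableObjects must be rebuilt.
                del cached_attributes[attr_name]
                return None
    return saveables


cache = {"kernel": (FakeSaveable("layer/kernel/.ATTRIBUTES/VARIABLE_VALUE"),)}
assert lookup_or_invalidate(
    cache, "kernel", "layer/kernel/.ATTRIBUTES/VARIABLE_VALUE") is not None
assert lookup_or_invalidate(
    cache, "kernel", "other/kernel/.ATTRIBUTES/VARIABLE_VALUE") is None
assert "kernel" not in cache  # the stale entry was evicted
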
Example #2
    def _add_attributes_to_object_graph(self, trackable_objects,
                                        object_graph_proto, node_ids,
                                        object_names, object_map,
                                        call_with_mapped_captures):
        """Create SaveableObjects and corresponding SerializedTensor protos."""
        named_saveable_objects = []
        if self._saveables_cache is None:
            # No SaveableObject caching. Either we're executing eagerly, or building a
            # static save which is specialized to the current Python state.
            feed_additions = None
        else:
            # If we are caching SaveableObjects, we need to build up a feed_dict with
            # functions computing volatile Python state to be saved with the
            # checkpoint.
            feed_additions = {}
        for checkpoint_id, (trackable, object_proto) in enumerate(
                zip(trackable_objects, object_graph_proto.nodes)):
            assert node_ids[trackable] == checkpoint_id
            object_name = object_names[trackable]
            if object_map is None:
                object_to_save = trackable
            else:
                object_to_save = object_map.get(trackable, trackable)
            if self._saveables_cache is not None:
                cached_attributes = self._saveables_cache.setdefault(
                    object_to_save, {})
            else:
                cached_attributes = None

            for name, saveable_factory in (
                    object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
                attribute = object_proto.attributes.add()
                attribute.name = name
                attribute.checkpoint_key = "%s/%s/%s" % (
                    object_name, _OBJECT_ATTRIBUTES_NAME,
                    _escape_local_name(name))
                if cached_attributes is None:
                    saveables = None
                else:
                    saveables = cached_attributes.get(name, None)
                    if saveables is not None:
                        for saveable in saveables:
                            if attribute.checkpoint_key not in saveable.name:
                                # The checkpoint key for this SaveableObject is different. We
                                # need to re-create it.
                                saveables = None
                                del cached_attributes[name]
                                break
                if saveables is None:
                    if callable(saveable_factory):
                        maybe_saveable = saveable_object_util.create_saveable_object(
                            saveable_factory, attribute.checkpoint_key,
                            call_with_mapped_captures)
                    else:
                        maybe_saveable = saveable_factory
                    if isinstance(maybe_saveable,
                                  saveable_object_lib.SaveableObject):
                        saveables = (maybe_saveable, )
                    else:
                        # Figure out the name-based Saver's name for this variable. If it's
                        # already a SaveableObject we'd just get the checkpoint key back, so
                        # we leave full_name blank.
                        saver_dict = saveable_object_util.op_list_to_dict(
                            [maybe_saveable], convert_variable_to_tensor=False)
                        full_name, = saver_dict.keys()
                        saveables = tuple(
                            saveable_object_util.saveable_objects_for_op(
                                op=maybe_saveable,
                                name=attribute.checkpoint_key))
                        for saveable in saveables:
                            saveable.full_name = full_name
                    for saveable in saveables:
                        if attribute.checkpoint_key not in saveable.name:
                            raise AssertionError((
                                "The object %s produced a SaveableObject with name '%s' for "
                                "attribute '%s'. Expected a name containing '%s'."
                            ) % (trackable, saveable.name, name,
                                 attribute.checkpoint_key))
                    if cached_attributes is not None:
                        cached_attributes[name] = saveables

                optional_restore = None
                for saveable in saveables:
                    if optional_restore is None:
                        optional_restore = saveable.optional_restore
                    else:
                        optional_restore = optional_restore and saveable.optional_restore

                    if hasattr(saveable, "full_name"):
                        attribute.full_name = saveable.full_name
                    if isinstance(saveable, base.PythonStateSaveable):
                        if feed_additions is None:
                            assert self._saveables_cache is None
                            # If we're not caching saveables, then we're either executing
                            # eagerly or building a static save/restore (e.g. for a
                            # SavedModel). In either case, we should embed the current Python
                            # state in the graph rather than relying on a feed dict.
                            saveable = saveable.freeze()
                        else:
                            saveable_feed_dict = saveable.feed_dict_additions()
                            for new_feed_key in saveable_feed_dict.keys():
                                if new_feed_key in feed_additions:
                                    raise AssertionError((
                                        "The object %s tried to feed a value for the Tensor %s "
                                        "when saving, but another object is already feeding a "
                                        "value.") % (trackable, new_feed_key))
                            feed_additions.update(saveable_feed_dict)
                    named_saveable_objects.append(saveable)
                if optional_restore is None:
                    optional_restore = False
                attribute.optional_restore = optional_restore

        return named_saveable_objects, feed_additions
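
Example #2 additionally records an optional_restore flag on each attribute: the attribute is treated as optional only when every SaveableObject produced for it reports optional_restore, and it defaults to False when there are none. The sketch below restates that fold in plain Python; aggregate_optional_restore and _Stub are hypothetical names used only for illustration.

def aggregate_optional_restore(saveables):
    """Fold optional_restore across saveables; an empty list is not optional."""
    optional_restore = None
    for saveable in saveables:
        if optional_restore is None:
            optional_restore = saveable.optional_restore
        else:
            optional_restore = optional_restore and saveable.optional_restore
    return False if optional_restore is None else optional_restore


class _Stub:
    """Hypothetical stand-in exposing only `optional_restore`."""

    def __init__(self, optional_restore):
        self.optional_restore = optional_restore


assert aggregate_optional_restore([_Stub(True), _Stub(True)]) is True
assert aggregate_optional_restore([_Stub(True), _Stub(False)]) is False
assert aggregate_optional_restore([]) is False
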
Example #3
def _add_attributes_to_object_graph_for_saveable_objects(
        checkpoint_factory_map, object_graph_proto, node_ids, object_map,
        call_with_mapped_captures, saveables_cache):
    """Create SaveableObjects and corresponding SerializedTensor protos."""
    named_saveable_objects = []
    if saveables_cache is None:
        # No SaveableObject caching. Either we're executing eagerly, or building a
        # static save which is specialized to the current Python state.
        feed_additions = None
    else:
        # If we are caching SaveableObjects, we need to build up a feed_dict with
        # functions computing volatile Python state to be saved with the
        # checkpoint.
        feed_additions = {}
    for trackable, factory_data_list in checkpoint_factory_map.items():
        object_proto = object_graph_proto.nodes[node_ids[trackable]]
        object_to_save = _get_mapped_trackable(trackable, object_map)
        if saveables_cache is not None:
            cached_attributes = saveables_cache.setdefault(object_to_save, {})
        else:
            cached_attributes = None

        for factory_data in factory_data_list:
            name = factory_data.name
            key = factory_data.checkpoint_key
            saveable_factory = factory_data.factory

            # See if we can skip saving this checkpoint key.
            saveables = cached_attributes.get(
                name) if cached_attributes else None
            if saveables is not None:
                for saveable in saveables:
                    if key not in saveable.name:
                        # The checkpoint key for this SaveableObject is different. We
                        # need to re-create it.
                        saveables = None
                        del cached_attributes[name]
                        break

            if saveables is None:
                if callable(saveable_factory):
                    maybe_saveable = saveable_object_util.create_saveable_object(
                        saveable_factory, key, call_with_mapped_captures)
                else:
                    maybe_saveable = saveable_factory
                if isinstance(maybe_saveable,
                              saveable_object_lib.SaveableObject):
                    saveables = (maybe_saveable, )
                else:
                    saveables = tuple(
                        saveable_object_util.saveable_objects_for_op(
                            op=maybe_saveable, name=key))
                for saveable in saveables:
                    if key not in saveable.name:
                        raise AssertionError(
                            f"The object {trackable} produced a SaveableObject with name "
                            f"'{saveable.name}' for attribute '{name}'. Expected a name"
                            f" containing '{key}'.")
                if cached_attributes is not None:
                    cached_attributes[name] = saveables

            for saveable in saveables:
                if isinstance(saveable, python_state.PythonStateSaveable):
                    if feed_additions is None:
                        assert saveables_cache is None
                        # If we're not caching saveables, then we're either executing
                        # eagerly or building a static save/restore (e.g. for a
                        # SavedModel). In either case, we should embed the current Python
                        # state in the graph rather than relying on a feed dict.
                        saveable = saveable.freeze()
                    else:
                        saveable_feed_dict = saveable.feed_dict_additions()
                        for new_feed_key in saveable_feed_dict.keys():
                            if new_feed_key in feed_additions:
                                raise AssertionError(
                                    f"The object {trackable} tried to feed a value for the "
                                    f"Tensor {new_feed_key} when saving, but another object "
                                    "is already feeding a value.")
                        feed_additions.update(saveable_feed_dict)
                named_saveable_objects.append(saveable)

            # Update the object proto.
            # For updated Trackables that override serialize_to_tensors, add an
            # attribute for each tensor that is serialized.
            # For Trackables that have SaveableObjects or a legacy saveable name,
            # add a single attribute to the proto.
            if (isinstance(saveables[0],
                           saveable_object_util.TrackableSaveable) and
                    saveable_compat.get_saveable_name(object_to_save) is None):
                for local_name, local_key in (
                        saveables[0].get_proto_names_and_checkpoint_keys()):
                    object_proto.attributes.add(
                        name=local_name,
                        checkpoint_key=local_key,
                        full_name=_get_full_name(object_to_save))
            else:
                object_proto.attributes.add(
                    name=name,
                    checkpoint_key=key,
                    full_name=_get_full_name(object_to_save))

    return named_saveable_objects, feed_additions
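
When a SaveableObject also carries Python state (a PythonStateSaveable) and a saveables cache is in use, all three examples merge its feed-dict additions into one shared feed_additions mapping and raise if two objects try to feed the same tensor. The sketch below restates that merge-with-conflict-check; merge_feed_additions and the string keys are hypothetical and only stand in for real tensor handles.

def merge_feed_additions(feed_additions, saveable_feed_dict, trackable):
    """Merge one saveable's feeds, rejecting a second feeder for any tensor."""
    for new_feed_key in saveable_feed_dict:
        if new_feed_key in feed_additions:
            raise AssertionError(
                f"The object {trackable} tried to feed a value for the Tensor "
                f"{new_feed_key} when saving, but another object is already "
                "feeding a value.")
    feed_additions.update(saveable_feed_dict)


feeds = {}
merge_feed_additions(feeds, {"py_state_placeholder:0": b"state-a"}, "object_a")
try:
    merge_feed_additions(feeds, {"py_state_placeholder:0": b"state-b"}, "object_b")
except AssertionError:
    pass  # a second feeder for the same tensor is rejected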