예제 #1
0
    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Build a Task that loads only the requested blobs from the checkpoint
        of the given epoch.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob to load.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task holding a single Load op for the named blobs. The op is
            created with absolute_path=True and allow_incomplete=True.
        """
        # Resolve the checkpoint location once so that the path written to
        # the log and the path handed to the Load op cannot diverge.
        checkpoint_db = db_name(epoch, self._node_name, self._db_prefix)
        # Pass the argument separately so formatting only happens when the
        # record is actually emitted.
        logger.info('Load from %s', checkpoint_db)
        with Task() as task:
            ops.Load([],
                     blob_names,
                     db=checkpoint_db,
                     db_type=self._db_type,
                     absolute_path=True,
                     allow_incomplete=True)
        return task
예제 #2
0
    def init(self,
             nodes=None,
             retrieve_from_epoch=None,
             path_prefix=None,
             path_type=None):
        """
        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.

        Args:
            nodes: Either None or a sequence with exactly one element; any
                other length fails the assertion below.
            retrieve_from_epoch: When None, the blob names are gathered with
                a GetAllBlobNames op (include_shared=False). Otherwise the
                blob names are loaded from the checkpoint db of this epoch.
            path_prefix: Forwarded to db_name() when building the db name to
                load from. Only used when retrieve_from_epoch is not None.
            path_type: When truthy, used as the db type of the Load op in
                place of self._db_type. Only used when retrieve_from_epoch
                is not None.

        Returns:
            The Task that was built. As a side effect, the task's first
            output is stored in self._names_output.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                ops.GetAllBlobNames([], self._blob_names, include_shared=False)
            else:
                full_db_name = db_name(retrieve_from_epoch, self._node_name,
                                       self._db_prefix, path_prefix)
                db_type = path_type or self._db_type
                # Pass the argument separately so formatting only happens
                # when the record is actually emitted.
                logger.info("Initializing checkpoints from = %s",
                            full_db_name)
                ops.Load([],
                         self._blob_names,
                         db=full_db_name,
                         db_type=db_type,
                         absolute_path=True)
        self._names_output = task.outputs()[0]
        return task
예제 #3
0
 def add_op():
     """Add a Load op that reads `blob_names` from the current db."""
     # Options are gathered first so the call itself stays on one line.
     load_options = {
         'db': self._current_db_name,
         'db_type': self._db_type,
         'absolute_path': True,
         'allow_incomplete': True,
     }
     ops.Load([], blob_names, **load_options)
예제 #4
0
 def add_op():
     """Add a Load op that reads every blob in blob_list() from the current db."""
     # blob_list() is evaluated before the remaining arguments, matching
     # the argument order of the direct call.
     blobs = self.blob_list()
     load_options = {
         'db': self._current_db_name,
         'db_type': db_type,
         'absolute_path': True,
         'keep_device': True,
     }
     ops.Load([], blobs, **load_options)
예제 #5
0
 def load(self, epoch):
     """
     Create the Task that JobRunner runs when the job is resumed from the
     given epoch. The task holds a single Load op, which loads and
     deserializes all relevant blobs from persistent storage.
     """
     with Task() as resume_task:
         # Bind the arguments to names first; both are still evaluated
         # inside the Task context, blob list before db name.
         blobs = self.blob_list()
         checkpoint_db = self._dbname(epoch)
         ops.Load(
             [],
             blobs,
             db=checkpoint_db,
             db_type=self._db_type,
             absolute_path=True,
         )
     return resume_task
예제 #6
0
 def load(self, epoch, path_prefix=None):
     """
     Build a Task that will be run by JobRunner when the job is to be
     resumed from a given epoch. This task will run a Load op that will
     load and deserialize all relevant blobs from a persistent storage.

     Args:
         epoch: The checkpoint epoch to load from.
         path_prefix: Forwarded to self._db_name() together with epoch to
             build the name of the db to load from. Defaults to None.

     Returns:
         A Task holding a single Load op for the blobs in blob_list().
     """
     full_db_name = self._db_name(epoch, path_prefix)
     # Pass the argument separately so formatting only happens when the
     # record is actually emitted.
     logger.info("Loading checkpoints from = %s", full_db_name)
     with Task() as task:
         ops.Load([],
                  self.blob_list(),
                  db=full_db_name,
                  db_type=self._db_type,
                  absolute_path=True)
     return task