def load_blobs_from_checkpoint(self, blob_names, epoch):
    """
    Builds a Task that loads only the necessary blobs from a checkpoint of
    the given epoch. The necessary blobs are given in the blob_names
    argument.

    Args:
        blob_names: A list of strings. Each string is the name of a blob.
        epoch: The checkpoint epoch to load from.

    Returns:
        A Task which loads the specified blobs from the checkpoint of the
        given epoch.
    """
    # Resolve the checkpoint location exactly once so that the path written
    # to the log and the path handed to the Load op can never diverge.
    checkpoint_db = db_name(epoch, self._node_name, self._db_prefix)
    # Pass the path as a logger argument so formatting is deferred until a
    # handler actually emits the record.
    logger.info('Load from %s', checkpoint_db)
    with Task() as task:
        ops.Load(
            [],
            blob_names,
            db=checkpoint_db,
            db_type=self._db_type,
            absolute_path=True,
            # NOTE(review): presumably lets the load succeed when some of
            # the requested blobs are missing from the db -- confirm
            # against the Load operator documentation.
            allow_incomplete=True)
    return task
def init(self, nodes=None, retrieve_from_epoch=None, path_prefix=None,
         path_type=None):
    """
    Build the Task that runs once after the job's `init_group` and
    determines which blobs need to be checkpointed.

    When `retrieve_from_epoch` is None the blob names are collected with a
    GetAllBlobNames op. Otherwise they are loaded from the checkpoint db
    that was saved for that epoch.

    Args:
        nodes: Either None or a sequence with exactly one entry.
        retrieve_from_epoch: Epoch of a previously saved checkpoint to read
            the blob names from, or None to collect them instead.
        path_prefix: Forwarded to `db_name` when building the db path.
        path_type: Db type to use for the load; when falsy, the manager's
            own `_db_type` is used.

    Returns:
        The Task whose single output is the blob-names blob.
    """
    assert nodes is None or len(nodes) == 1, (
        'CheckpointManager only supports single node.')

    restore_names = retrieve_from_epoch is not None
    with Task(outputs=[self._blob_names]) as task:
        if restore_names:
            full_db_name = db_name(
                retrieve_from_epoch,
                self._node_name,
                self._db_prefix,
                path_prefix,
            )
            # An explicit path_type wins over the manager's default.
            db_type = path_type if path_type else self._db_type
            logger.info("Initializing checkpoints from = %s" % full_db_name)
            ops.Load(
                [],
                self._blob_names,
                db=full_db_name,
                db_type=db_type,
                absolute_path=True,
            )
        else:
            ops.GetAllBlobNames(
                [],
                self._blob_names,
                include_shared=False,
            )

    # Keep a handle on the task's first output for later use.
    task_outputs = task.outputs()
    self._names_output = task_outputs[0]
    return task
def add_op():
    """
    Emit a Load op that reads `blob_names` from the db currently tracked
    in `self._current_db_name`, using the manager's db type.
    """
    load_options = {
        'db': self._current_db_name,
        'db_type': self._db_type,
        'absolute_path': True,
        'allow_incomplete': True,
    }
    ops.Load([], blob_names, **load_options)
def add_op():
    """
    Emit a Load op that reads every blob returned by `self.blob_list()`
    from the db currently tracked in `self._current_db_name`, using the
    `db_type` captured from the enclosing scope.
    """
    target_blobs = self.blob_list()
    load_options = {
        'db': self._current_db_name,
        'db_type': db_type,
        'absolute_path': True,
        'keep_device': True,
    }
    ops.Load([], target_blobs, **load_options)
def load(self, epoch):
    """
    Build the Task that JobRunner runs when the job is resumed from the
    given epoch.

    The task consists of a single Load op that loads and deserializes the
    blobs listed by `self.blob_list()` from persistent storage.

    Args:
        epoch: The epoch whose checkpoint should be loaded.

    Returns:
        The Task containing the Load op.
    """
    with Task() as task:
        # Evaluated inside the task context, in the same order as the
        # arguments of the Load call they feed.
        target_blobs = self.blob_list()
        checkpoint_db = self._dbname(epoch)
        ops.Load(
            [],
            target_blobs,
            db=checkpoint_db,
            db_type=self._db_type,
            absolute_path=True,
        )
    return task
def load(self, epoch, path_prefix=None):
    """
    Build the Task that JobRunner runs when the job is resumed from the
    given epoch.

    The task consists of a single Load op that loads and deserializes the
    blobs listed by `self.blob_list()` from persistent storage.

    Args:
        epoch: The epoch whose checkpoint should be loaded.
        path_prefix: Forwarded to `self._db_name` when building the db
            path.

    Returns:
        The Task containing the Load op.
    """
    checkpoint_path = self._db_name(epoch, path_prefix)
    logger.info("Loading checkpoints from = %s" % checkpoint_path)

    with Task() as task:
        target_blobs = self.blob_list()
        ops.Load(
            [],
            target_blobs,
            db=checkpoint_path,
            db_type=self._db_type,
            absolute_path=True,
        )
    return task