def testFindCheckpointDir(self):
    """find_checkpoint_dir resolves a nested path back to the checkpoint root.

    Also verifies that querying a directory *above* the root raises
    FileNotFoundError.
    """
    checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt")
    os.makedirs(checkpoint_path)
    found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    # assertEquals is a deprecated alias (removed in Python 3.12);
    # use the canonical assertEqual.
    self.assertEqual(self.checkpoint_dir, found_dir)
    with self.assertRaises(FileNotFoundError):
        parent = os.path.dirname(found_dir)
        TrainableUtil.find_checkpoint_dir(parent)
def save_checkpoint(self, checkpoint):
    """Record *checkpoint* as the latest one and flag it as fresh.

    A string checkpoint must point inside a directory created via
    ``make_checkpoint_dir``; otherwise ``FileNotFoundError`` propagates.
    """
    path_based = isinstance(checkpoint, str)
    if path_based:
        try:
            TrainableUtil.find_checkpoint_dir(checkpoint)
        except FileNotFoundError:
            logger.error("Checkpoint must be created with path given from "
                         "make_checkpoint_dir.")
            raise
    self._last_checkpoint = checkpoint
    self._fresh_checkpoint = True
def set_checkpoint(self, checkpoint, is_new=True):
    """Sets the checkpoint to be returned upon get_checkpoint.

    If this is a "new" checkpoint, it will notify Tune
    (via has_new_checkpoint). Otherwise, it will NOT notify Tune.
    """
    path_based = isinstance(checkpoint, str)
    if path_based:
        try:
            TrainableUtil.find_checkpoint_dir(checkpoint)
        except FileNotFoundError:
            logger.error("Checkpoint must be created with path given from "
                         "make_checkpoint_dir.")
            raise
    self._last_checkpoint = checkpoint
    if is_new:
        self._fresh_checkpoint = True
def save_checkpoint(self, checkpoint_dir: str = ""):
    """Return the latest reporter checkpoint (dict) or its parent directory.

    Raises:
        ValueError: If a checkpoint_dir argument is supplied, or if the
            reporter checkpoint is neither a str nor a dict.
    """
    if checkpoint_dir:
        raise ValueError(
            "Checkpoint dir should not be used with function API.")

    checkpoint = self._status_reporter.get_checkpoint()

    # NOTE: the falsy check must come before the dict check so that an
    # empty dict is also treated as "no checkpoint".
    if not checkpoint:
        # We drop a marker here to indicate that the checkpoint is empty
        return FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)

    if isinstance(checkpoint, dict):
        # In-memory checkpoints are handed back verbatim.
        return checkpoint

    if isinstance(checkpoint, str):
        chkpt_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        # When the trainable is restored, a temporary checkpoint
        # is created. However, when saved, it should become permanent.
        # Ideally, there are no save calls upon a temporary
        # checkpoint, but certain schedulers might.
        if FuncCheckpointUtil.is_temp_checkpoint_dir(chkpt_dir):
            chkpt_dir = FuncCheckpointUtil.create_perm_checkpoint(
                checkpoint_dir=chkpt_dir,
                logdir=self.logdir,
                step=self.training_iteration,
            )
        return chkpt_dir

    raise ValueError("Provided checkpoint was expected to have "
                     "type (str, dict). Got {}.".format(
                         type(checkpoint)))
def delete_checkpoint(self, checkpoint_path):
    """Deletes checkpoint from both local and remote storage.

    Args:
        checkpoint_path (str): Local path to checkpoint.
    """
    # BUG FIX: the original called super().delete_checkpoint() first, which
    # removes the local directory — the subsequent find_checkpoint_dir()
    # then raised FileNotFoundError during garbage collection (issue #6697).
    # Resolve the directory (and delete the remote copy) *before* deleting
    # locally, and tolerate an already-missing path, consistent with the
    # guarded variant of this method elsewhere in the codebase.
    try:
        local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    except FileNotFoundError:
        logger.warning(
            "Trial %s: checkpoint path not found during "
            "garbage collection. See issue #6697.", self.trial_id)
    else:
        self.storage_client.delete(self._storage_path(local_dirpath))
    super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
def save(self, checkpoint_path=None) -> str:
    """Save a checkpoint, sync its directory to the cloud, and return its path.

    Raises:
        ValueError: If an explicit checkpoint_path is supplied (not
            supported with the function API).
    """
    if checkpoint_path:
        raise ValueError("Checkpoint path should not be used with function API.")
    saved_path = self.save_checkpoint()
    containing_dir = TrainableUtil.find_checkpoint_dir(saved_path)
    self._maybe_save_to_cloud(containing_dir)
    return saved_path
def save_all_states_remote(self, trial_state):
    """Save all of AdaptDL's job state and return it as an in-memory object."""
    checkpoint = save_all_states()
    chkpt_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
    processed_path = TrainableUtil.process_checkpoint(checkpoint,
                                                      chkpt_dir,
                                                      trial_state)
    checkpoint_obj = TrainableUtil.checkpoint_to_object(processed_path)
    # The on-disk copy is no longer needed once serialized into memory.
    shutil.rmtree(processed_path)
    return checkpoint_obj
def testConvertTempToPermanent(self):
    """create_perm_checkpoint promotes a temp checkpoint dir to permanent."""
    temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    perm_dir = FuncCheckpointUtil.create_perm_checkpoint(
        temp_dir, self.logdir, step=4)
    # The promoted directory is itself a resolvable checkpoint root.
    assert perm_dir == TrainableUtil.find_checkpoint_dir(perm_dir)
    assert os.path.exists(perm_dir)
    assert not FuncCheckpointUtil.is_temp_checkpoint_dir(perm_dir)
    # A freshly made temp dir must not collide with the promoted one.
    fresh_temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    assert fresh_temp_dir != perm_dir
def delete_checkpoint(self, checkpoint_path):
    """Deletes checkpoint from both local and remote storage.

    Args:
        checkpoint_path (str): Local path to checkpoint.
    """
    try:
        remote_root = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    except FileNotFoundError:
        # The local directory may already be gone; skip the remote delete.
        logger.warning(
            "Trial %s: checkpoint path not found during "
            "garbage collection. See issue #6697.", self.trial_id)
    else:
        self.storage_client.delete(self._storage_path(remote_root))
    super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
def save(self, checkpoint_path=None) -> str:
    """Persist the latest reporter checkpoint and return its processed path.

    Handles three reporter-checkpoint shapes: falsy (a null marker dir is
    created), dict (a new checkpoint dir is made for it), and str (its
    containing dir is located and, if temporary, promoted to permanent).

    Raises:
        ValueError: If checkpoint_path is supplied, or if the reporter
            checkpoint is neither a str nor a dict.
    """
    if checkpoint_path:
        raise ValueError("Checkpoint path should not be used with function API.")
    checkpoint = self._status_reporter.get_checkpoint()
    state = self.get_state()
    if not checkpoint:
        # No checkpoint yet: zero the progress counters in the saved state.
        state.update(iteration=0, timesteps_total=0, episodes_total=0)
        # We drop a marker here to indicate that the checkpoint is empty
        checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
        parent_dir = checkpoint
    elif isinstance(checkpoint, dict):
        parent_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration
        )
    elif isinstance(checkpoint, str):
        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        # When the trainable is restored, a temporary checkpoint
        # is created. However, when saved, it should become permanent.
        # Ideally, there are no save calls upon a temporary
        # checkpoint, but certain schedulers might.
        if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
            # Capture the checkpoint's location relative to its dir BEFORE
            # promotion, so the path can be re-rooted under the new
            # permanent directory afterwards.
            relative_path = os.path.relpath(checkpoint, parent_dir)
            parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                checkpoint_dir=parent_dir,
                logdir=self.logdir,
                step=self.training_iteration,
            )
            checkpoint = os.path.abspath(os.path.join(parent_dir, relative_path))
    else:
        raise ValueError(
            "Provided checkpoint was expected to have "
            "type (str, dict). Got {}.".format(type(checkpoint))
        )
    checkpoint_path = TrainableUtil.process_checkpoint(
        checkpoint, parent_dir, state
    )
    # Hook for subclasses, then sync the whole checkpoint dir to the cloud.
    self._postprocess_checkpoint(checkpoint_path)
    self._maybe_save_to_cloud(parent_dir)
    return checkpoint_path
def save(self, checkpoint_path=None):
    """Process the latest reporter checkpoint into a directory and return its path.

    Raises:
        ValueError: If an explicit checkpoint_path is supplied (not
            supported with the function API).
    """
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")
    checkpoint = self._status_reporter.get_checkpoint()
    state = self.get_state()
    if not checkpoint:
        # Nothing reported yet: reset counters and use the default dir.
        state.update(iteration=0, timesteps_total=0, episodes_total=0)
        target_dir = self.create_default_checkpoint_dir()
    elif isinstance(checkpoint, dict):
        target_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration)
    else:
        target_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
    return TrainableUtil.process_checkpoint(checkpoint, target_dir, state)
def delete(checkpoint):
    """Requests checkpoint deletion asynchronously.

    Args:
        checkpoint (Checkpoint): Checkpoint to delete.
    """
    # Only persistent checkpoints with a recorded value can be deleted.
    if checkpoint.storage != Checkpoint.PERSISTENT or not checkpoint.value:
        return
    logger.debug("Trial %s: Deleting checkpoint %s", trial_id,
                 checkpoint.value)
    checkpoint_path = checkpoint.value
    # Delete local copy, if any exists.
    if os.path.exists(checkpoint_path):
        try:
            shutil.rmtree(
                TrainableUtil.find_checkpoint_dir(checkpoint_path))
        except FileNotFoundError:
            logger.warning("Checkpoint dir not found during deletion.")
    # TODO(ujvl): Batch remote deletes.
    runner.delete_checkpoint.remote(checkpoint.value)