def testFindCheckpointDir(self):
    checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt")
    os.makedirs(checkpoint_path)
    found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    self.assertEqual(self.checkpoint_dir, found_dir)

    with self.assertRaises(FileNotFoundError):
        parent = os.path.dirname(found_dir)
        TrainableUtil.find_checkpoint_dir(parent)

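# Hedged sketch (not the library's implementation): the test above relies on
# TrainableUtil.find_checkpoint_dir walking upward from any nested path until
# it reaches the directory created by make_checkpoint_dir, and raising
# FileNotFoundError above that point. A minimal version of that walk-up logic
# could look like the following; the marker file name is an assumption of this
# sketch.
import os

CHECKPOINT_MARKER = ".is_checkpoint"  # assumed marker dropped by make_checkpoint_dir


def find_checkpoint_dir_sketch(checkpoint_path):
    """Walk up from `checkpoint_path` until a marked checkpoint dir is found."""
    current = os.path.abspath(checkpoint_path)
    while True:
        if os.path.exists(os.path.join(current, CHECKPOINT_MARKER)):
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root without a marker
            raise FileNotFoundError(
                "Checkpoint directory not found for {}".format(checkpoint_path))
        current = parent
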
def save_checkpoint(self, checkpoint):
    if isinstance(checkpoint, str):
        try:
            TrainableUtil.find_checkpoint_dir(checkpoint)
        except FileNotFoundError:
            logger.error("Checkpoint must be created with path given from "
                         "make_checkpoint_dir.")
            raise
    self._last_checkpoint = checkpoint
    self._fresh_checkpoint = True

def save_all_states_remote(self, trial_state):
    """Save all of AdaptDL's job state and return it as an in-memory object."""
    checkpoint = save_all_states()
    parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
    checkpoint_path = TrainableUtil.process_checkpoint(checkpoint, parent_dir,
                                                       trial_state)
    checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
    # Done with the directory; remove it.
    shutil.rmtree(checkpoint_path)
    return checkpoint_obj

def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
    """Gets paths and metrics of all persistent checkpoints of a trial.

    Args:
        trial (Trial): The log directory of a trial, or a trial instance.
        metric (str): key for trial info to return, e.g. "mean_accuracy".
            "training_iteration" is used by default.

    Returns:
        List of [path, metric] for all persistent checkpoints of the trial.
    """
    if isinstance(trial, str):
        trial_dir = os.path.expanduser(trial)
        # Get checkpoints from logdir.
        chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)
        # Join with trial dataframe to get metrics.
        trial_df = self.trial_dataframes[trial_dir]
        path_metric_df = chkpt_df.merge(
            trial_df, on="training_iteration", how="inner")
        return path_metric_df[["chkpt_path", metric]].values.tolist()
    elif isinstance(trial, Trial):
        checkpoints = trial.checkpoint_manager.best_checkpoints()
        return [[c.value, c.result[metric]] for c in checkpoints]
    else:
        raise ValueError("trial should be a string or a Trial instance.")

def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
    """Returns a list of [path, metric] lists for all disk checkpoints of
    a trial.

    Arguments:
        trial (Trial): The log directory of a trial, or a trial instance.
        metric (str): key for trial info to return, e.g. "mean_accuracy".
            "training_iteration" is used by default.
    """
    if isinstance(trial, str):
        trial_dir = os.path.expanduser(trial)
        # Get checkpoints from logdir.
        chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)
        # Join with trial dataframe to get metrics.
        trial_df = self.trial_dataframes[trial_dir]
        path_metric_df = chkpt_df.merge(
            trial_df, on="training_iteration", how="inner")
        return path_metric_df[["chkpt_path", metric]].values.tolist()
    elif isinstance(trial, Trial):
        checkpoints = trial.checkpoint_manager.best_checkpoints()
        # TODO(ujvl): Remove condition once the checkpoint manager is
        # modified to only track PERSISTENT checkpoints.
        return [[c.value, c.result[metric]] for c in checkpoints
                if c.storage == Checkpoint.PERSISTENT]
    else:
        raise ValueError("trial should be a string or a Trial instance.")

def save_checkpoint(self, checkpoint_dir: str = ""): if checkpoint_dir: raise ValueError( "Checkpoint dir should not be used with function API.") checkpoint = self._status_reporter.get_checkpoint() if not checkpoint: # We drop a marker here to indicate that the checkpoint is empty checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir) parent_dir = checkpoint elif isinstance(checkpoint, dict): return checkpoint elif isinstance(checkpoint, str): parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint) # When the trainable is restored, a temporary checkpoint # is created. However, when saved, it should become permanent. # Ideally, there are no save calls upon a temporary # checkpoint, but certain schedulers might. if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir): parent_dir = FuncCheckpointUtil.create_perm_checkpoint( checkpoint_dir=parent_dir, logdir=self.logdir, step=self.training_iteration, ) else: raise ValueError("Provided checkpoint was expected to have " "type (str, dict). Got {}.".format( type(checkpoint))) return parent_dir
def save_all_states():
    """
    Invokes `save_state` on all `State` objects for which `State.skip` is
    True. This function can be used to trigger a global checkpoint and save
    every `State` in the current job.
    """
    if from_ray():
        from ray.tune.trainable import TrainableUtil
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            "/tmp", index=None, override=True)
    else:
        checkpoint_dir = checkpoint_path()
    for state in _STATES_TO_NAMES:
        save_state(state, checkpoint_dir)

    # Prevent corrupting original state files in case the process got killed
    # during state file writing.
    if replica_rank() == 0 and checkpoint_dir is not None:
        tmp_ckpt_dir = _get_tmp_ckpt_dir(checkpoint_dir)
        ckpt_dir = os.path.join(checkpoint_dir,
                                f"{CKPT_DIR_PREFIX}{num_restarts()}")
        os.rename(tmp_ckpt_dir, ckpt_dir)  # atomic, rename(src, dst)
        for dir_name in os.listdir(checkpoint_dir):
            dir_path = os.path.join(checkpoint_dir, dir_name)
            if dir_name.startswith(CKPT_DIR_PREFIX) and dir_path != ckpt_dir:
                shutil.rmtree(dir_path)
    return checkpoint_dir

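# Generic illustration of the write-then-rename pattern used above (a sketch,
# not AdaptDL code): state files are written into a temporary directory first
# and then atomically renamed into place, so a crash mid-write never leaves a
# partially written checkpoint under the final name. Assumes `final_dir` does
# not already exist and lives on the same filesystem as its parent.
import os
import shutil
import tempfile


def atomic_write_dir(final_dir, write_fn):
    parent = os.path.dirname(os.path.abspath(final_dir))
    tmp_dir = tempfile.mkdtemp(dir=parent)
    try:
        write_fn(tmp_dir)              # produce all files in the temporary dir
        os.rename(tmp_dir, final_dir)  # atomic on POSIX within one filesystem
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
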
def mk_temp_checkpoint_dir(logdir):
    """Indicate that the checkpoint is only for restoration."""
    temporary_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        logdir, index="tmp" + uuid.uuid4().hex[:6], override=True)
    open(os.path.join(temporary_checkpoint_dir, TEMP_MARKER), "a").close()
    return temporary_checkpoint_dir

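# Hedged sketch: is_temp_checkpoint_dir, used by the save()/save_checkpoint()
# variants in this section, is assumed to simply test for the TEMP_MARKER file
# dropped above. This helper is illustrative, not the library's implementation;
# the marker name is passed in because its actual value is not shown here.
import os


def is_temp_checkpoint_dir_sketch(checkpoint_dir, temp_marker):
    """Return True if the directory was produced by mk_temp_checkpoint_dir."""
    return os.path.exists(os.path.join(checkpoint_dir, temp_marker))
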
def pause(self, trial_runner):
    """Pause the AdaptDLTrial with a checkpoint.

    We try to remove the PG attached to this trial.
    """
    assert self.runner is not None
    checkpoint_obj = ray.get(
        self.runner.save_all_states.remote(self.runner.get_state.remote()))
    # Serialize to disk.
    temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
        self.logdir)
    checkpoint_path = TrainableUtil.create_from_pickle(checkpoint_obj,
                                                       temp_checkpoint_dir)
    # Trial will be restored from the checkpoint_path when it's resumed.
    self.restore_path = checkpoint_path

    # Clear the allocation. This is a hack to clear the PG associated with
    # the trial. We assign a temporary PG which will get replaced with a
    # real PG once we resume the trial. This is needed because Tune likes
    # to keep the PGs around even for PAUSED trials.
    self.placement_group_factory = PlacementGroupFactory([{"CPU": 0.001}])
    # This forces Tune to garbage-collect unneeded PGs which can then be
    # reused.
    trial_runner.trial_executor._pg_manager.\
        reconcile_placement_groups([self])
    logger.debug(f"PAUSING {self} w/ checkpoint at {checkpoint_path}")

def set_checkpoint(self, checkpoint, is_new=True):
    """Sets the checkpoint to be returned upon get_checkpoint.

    If this is a "new" checkpoint, it will notify Tune
    (via has_new_checkpoint). Otherwise, it will NOT notify Tune.
    """
    if isinstance(checkpoint, str):
        try:
            TrainableUtil.find_checkpoint_dir(checkpoint)
        except FileNotFoundError:
            logger.error("Checkpoint must be created with path given from "
                         "make_checkpoint_dir.")
            raise
    self._last_checkpoint = checkpoint
    if is_new:
        self._fresh_checkpoint = True

def mk_null_checkpoint_dir(logdir):
    """Indicate that the given checkpoint doesn't have state."""
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        logdir, index=-1, override=True)
    open(os.path.join(checkpoint_dir, NULL_MARKER), "a").close()
    return checkpoint_dir

def _restore(self, trial, checkpoint=None,
             block=False) -> Optional[RunningJob]:
    """Restores training state from a given model checkpoint.

    Args:
        trial (Trial): The trial to be restored.
        checkpoint (Checkpoint): The checkpoint to restore from. If None,
            the most recent PERSISTENT checkpoint is used. Defaults to None.
        block (bool): Whether or not to block on restore before returning.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    if checkpoint is None or checkpoint.value is None:
        checkpoint = trial.checkpoint
    if checkpoint.value is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial))
    value = checkpoint.value
    if checkpoint.storage == Checkpoint.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        with _change_working_directory(trial):
            trial.runner.restore_from_object.remote(value)
    else:
        logger.debug("Trial %s: Attempting restore from %s", trial, value)
        if issubclass(trial.get_trainable_cls(), DurableTrainable):
            with _change_working_directory(trial):
                remote = trial.runner.restore.remote(value)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where a DurableTrainable is not provided.
            logger.warning("Trial %s: Reading checkpoint into memory.",
                           trial)
            data_dict = TrainableUtil.pickle_checkpoint(value)
            with _change_working_directory(trial):
                remote = trial.runner.restore_from_object.remote(data_dict)
        else:
            raise AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` and a Trainable "
                "extending `DurableTrainable` for remote storage-based "
                "restoration")
        if block:
            ray.get(remote)
        else:
            trial.restoring_from = checkpoint
            running_job = RunningJob(trial, remote)
            self.jobs_running[remote] = running_job
            return running_job

def save_to_object(self):
    checkpoint_path = self.save()
    data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
    out = io.BytesIO()
    if len(data_dict) > 10e6:  # Getting pretty large.
        logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
    out.write(data_dict)
    return out.getvalue()

def restore_from_object(self, obj):
    self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
        self.logdir)
    checkpoint_path = TrainableUtil.create_from_pickle(
        obj, self.temp_checkpoint_dir)
    self.restore(checkpoint_path)

def save(self, checkpoint_path=None) -> str:
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")

    checkpoint = self._status_reporter.get_checkpoint()
    state = self.get_state()

    if not checkpoint:
        state.update(iteration=0, timesteps_total=0, episodes_total=0)
        # We drop a marker here to indicate that the checkpoint is empty.
        checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
        parent_dir = checkpoint
    elif isinstance(checkpoint, dict):
        parent_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration)
    elif isinstance(checkpoint, str):
        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        # When the trainable is restored, a temporary checkpoint
        # is created. However, when saved, it should become permanent.
        # Ideally, there are no save calls upon a temporary
        # checkpoint, but certain schedulers might.
        if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
            relative_path = os.path.relpath(checkpoint, parent_dir)
            parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                checkpoint_dir=parent_dir,
                logdir=self.logdir,
                step=self.training_iteration,
            )
            checkpoint = os.path.abspath(
                os.path.join(parent_dir, relative_path))
    else:
        raise ValueError("Provided checkpoint was expected to have "
                         "type (str, dict). Got {}.".format(type(checkpoint)))

    checkpoint_path = TrainableUtil.process_checkpoint(
        checkpoint, parent_dir, state)

    self._postprocess_checkpoint(checkpoint_path)
    self._maybe_save_to_cloud(parent_dir)

    return checkpoint_path

def delete_checkpoint(self, checkpoint_path):
    """Deletes checkpoint from both local and remote storage.

    Args:
        checkpoint_path (str): Local path to checkpoint.
    """
    super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
    local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    self.storage_client.delete(self._storage_path(local_dirpath))

def save(self, checkpoint_path=None) -> str:
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")
    checkpoint_path = self.save_checkpoint()
    parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    self._maybe_save_to_cloud(parent_dir)
    return checkpoint_path

def save(self, checkpoint_path=None):
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")
    checkpoint = self._status_reporter.get_checkpoint()
    state = self.get_state()
    if not checkpoint:
        state.update(iteration=0, timesteps_total=0, episodes_total=0)
        parent_dir = self.create_default_checkpoint_dir()
    elif isinstance(checkpoint, dict):
        parent_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration)
    else:
        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
    checkpoint_path = TrainableUtil.process_checkpoint(
        checkpoint, parent_dir, state)
    return checkpoint_path

def restore_from_object(self, obj):
    if self.default_checkpoint_dir is not None and os.path.exists(
            self.default_checkpoint_dir):
        shutil.rmtree(self.default_checkpoint_dir)
        logger.debug("Clearing default checkpoint: %s",
                     self.default_checkpoint_dir)
    checkpoint_dir = self.create_default_checkpoint_dir()
    checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
    self.restore(checkpoint_path)

def testConvertTempToPermanent(self):
    checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint(
        checkpoint_dir, self.logdir, step=4)
    assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir(
        new_checkpoint_dir)
    assert os.path.exists(new_checkpoint_dir)
    assert not FuncCheckpointUtil.is_temp_checkpoint_dir(new_checkpoint_dir)

    tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
        self.logdir)
    assert tmp_checkpoint_dir != new_checkpoint_dir

def create_perm_checkpoint(checkpoint_dir, logdir, step):
    """Copies temporary checkpoint to a permanent checkpoint directory."""
    checkpoint_dir = os.path.abspath(checkpoint_dir)
    temporary_marker = os.path.join(checkpoint_dir, TEMP_MARKER)
    assert os.path.exists(temporary_marker), (
        "Should not be calling this method on a permanent checkpoint.")
    os.remove(temporary_marker)
    perm_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        logdir, index=step, override=True)
    shutil.rmtree(perm_checkpoint_dir)

    shutil.copytree(checkpoint_dir, perm_checkpoint_dir)
    assert not os.path.exists(
        os.path.join(perm_checkpoint_dir, TEMP_MARKER))
    return perm_checkpoint_dir

def delete_checkpoint(self, checkpoint_path):
    """Deletes checkpoint from both local and remote storage.

    Args:
        checkpoint_path (str): Local path to checkpoint.
    """
    try:
        local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    except FileNotFoundError:
        logger.warning(
            "Trial %s: checkpoint path not found during "
            "garbage collection. See issue #6697.", self.trial_id)
    else:
        self.storage_client.delete(self._storage_path(local_dirpath))
    super(DurableTrainable, self).delete_checkpoint(checkpoint_path)

def testPickleCheckpoint(self):
    for i in range(5):
        path = os.path.join(self.checkpoint_dir, str(i))
        with open(path, "w") as f:
            f.write(str(i))

    checkpoint_path = os.path.join(self.checkpoint_dir, "0")

    data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
    loaded = pickle.loads(data_dict)

    checkpoint_name = os.path.basename(checkpoint_path)
    self.assertEqual(loaded["checkpoint_name"], checkpoint_name)

    for i in range(5):
        path = os.path.join(self.checkpoint_dir, str(i))
        self.assertEqual(loaded["data"][str(i)], open(path, "rb").read())

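# Hedged usage sketch (not part of the test suite): pickle_checkpoint() and
# create_from_pickle() are expected to round-trip a checkpoint directory
# through an in-memory bytes object, matching the structure checked by the
# test above. The temporary paths and the Ray version providing this
# TrainableUtil API are assumptions of this sketch.
import os
import tempfile

from ray.tune.trainable import TrainableUtil

with tempfile.TemporaryDirectory() as logdir:
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(logdir, index=0)
    source_path = os.path.join(checkpoint_dir, "weights")
    with open(source_path, "w") as f:
        f.write("fake-model-state")

    data = TrainableUtil.pickle_checkpoint(source_path)  # serialized bytes blob

    with tempfile.TemporaryDirectory() as other_logdir:
        restore_dir = TrainableUtil.make_checkpoint_dir(other_logdir, index=0)
        restored_path = TrainableUtil.create_from_pickle(data, restore_dir)
        assert os.path.exists(restored_path)
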
def delete(checkpoint):
    """Requests checkpoint deletion asynchronously.

    Args:
        checkpoint (Checkpoint): Checkpoint to delete.
    """
    # Note: `trial_id` and `runner` are provided by the enclosing scope in
    # the original source.
    if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value:
        logger.debug("Trial %s: Deleting checkpoint %s", trial_id,
                     checkpoint.value)
        checkpoint_path = checkpoint.value
        # Delete local copy, if any exists.
        if os.path.exists(checkpoint_path):
            try:
                checkpoint_dir = TrainableUtil.find_checkpoint_dir(
                    checkpoint_path)
                shutil.rmtree(checkpoint_dir)
            except FileNotFoundError:
                logger.warning("Checkpoint dir not found during deletion.")

        # TODO(ujvl): Batch remote deletes.
        runner.delete_checkpoint.remote(checkpoint.value)

def restore(self, trial, checkpoint=None):
    """Restores training state from a given model checkpoint.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    if checkpoint is None or checkpoint.value is None:
        checkpoint = trial.checkpoint
    if checkpoint.value is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial))
    value = checkpoint.value
    if checkpoint.storage == Checkpoint.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        trial.runner.restore_from_object.remote(value)
    else:
        logger.debug("Trial %s: Attempting restore from %s", trial, value)
        if issubclass(trial.get_trainable_cls(), DurableTrainable):
            remote = trial.runner.restore.remote(value)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where a DurableTrainable is not provided.
            logger.warning("Trial %s: Reading checkpoint into memory.",
                           trial)
            data_dict = TrainableUtil.pickle_checkpoint(value)
            remote = trial.runner.restore_from_object.remote(data_dict)
        else:
            raise AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` and a Trainable "
                "extending `DurableTrainable` for remote storage-based "
                "restoration")
        self._running[remote] = trial
        trial.restoring_from = checkpoint

def save_to_object(self):
    checkpoint_path = self.save()
    obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
    return obj

def make_checkpoint_dir(self, step):
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        self.logdir, index=step)
    logger.debug("Making checkpoint dir at %s", checkpoint_dir)
    return checkpoint_dir

def load_checkpoint(self, checkpoint_dir: str):
    checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
    x_id = ray.put(checkpoint_obj)
    return self.executor.execute(lambda w: w.restore_from_object(x_id))

def save_checkpoint(self, checkpoint_dir: str) -> str:
    # TODO: optimize if colocated
    save_obj = self.executor.execute_single(lambda w: w.save_to_object())
    checkpoint_path = TrainableUtil.create_from_pickle(
        save_obj, checkpoint_dir)
    return checkpoint_path

def load_checkpoint(self, checkpoint_dir):
    checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
    # ray.get expects a single ObjectRef or a list of ObjectRefs, so collect
    # the per-worker restore calls into a list before waiting on them.
    return ray.get(
        [w.restore_from_object.remote(checkpoint_obj) for w in self.workers])