def testFindCheckpointDir(self):
    """Check ``find_checkpoint_dir`` on a nested path and on a parent path.

    A path nested below ``self.checkpoint_dir`` must resolve to
    ``self.checkpoint_dir``; resolving from the parent of that directory
    must raise ``FileNotFoundError``.
    """
    nested_path = os.path.join(self.checkpoint_dir, "0/my/nested/chkpt")
    os.makedirs(nested_path)

    resolved_dir = TrainableUtil.find_checkpoint_dir(nested_path)
    self.assertEqual(self.checkpoint_dir, resolved_dir)

    # One level above the resolved directory the lookup is expected to fail.
    with self.assertRaises(FileNotFoundError):
        TrainableUtil.find_checkpoint_dir(os.path.dirname(resolved_dir))
def to_air_checkpoint(self) -> Optional[Checkpoint]:
    """Build a ``Checkpoint`` object from ``self.dir_or_data``.

    An ``ray.ObjectRef`` is first resolved with ``ray.get``. The resulting
    value is then converted depending on its type: a ``str`` is treated as a
    path and resolved with ``TrainableUtil.find_checkpoint_dir``, ``bytes``
    are unpacked into a temporary directory, and a ``dict`` is wrapped
    directly.

    Returns:
        ``None`` if ``self.dir_or_data`` is falsy, otherwise the converted
        ``Checkpoint``.

    Raises:
        RuntimeError: If the data is not a ``str``, ``bytes`` or ``dict``.
    """
    from ray.tune.trainable.util import TrainableUtil

    data = self.dir_or_data
    if not data:
        return None

    if isinstance(data, ray.ObjectRef):
        data = ray.get(data)

    if isinstance(data, str):
        resolved_dir = TrainableUtil.find_checkpoint_dir(data)
        return Checkpoint.from_directory(resolved_dir)

    if isinstance(data, bytes):
        with tempfile.TemporaryDirectory() as scratch_dir:
            TrainableUtil.create_from_pickle(data, scratch_dir)
            # Round-trip through a dict so the data is held in memory and
            # the temporary directory can be removed on exit.
            on_disk = Checkpoint.from_directory(scratch_dir)
            return Checkpoint.from_dict(on_disk.to_dict())

    if isinstance(data, dict):
        return Checkpoint.from_dict(data)

    raise RuntimeError(f"Unknown checkpoint data type: {type(data)}")
def __call__(self, checkpoint: _TrackedCheckpoint):
    """Delete a persistent checkpoint through the runner, then locally.

    Nothing happens when there is no runner, when the checkpoint is not
    stored with ``CheckpointStorage.PERSISTENT``, or when it carries no
    ``dir_or_data``.

    Args:
        checkpoint: Checkpoint to delete.
    """
    if not self.runner:
        return

    is_persistent = checkpoint.storage_mode == CheckpointStorage.PERSISTENT
    if not (is_persistent and checkpoint.dir_or_data):
        return

    target_path = checkpoint.dir_or_data
    logger.debug(
        "Trial %s: Deleting checkpoint %s", self.trial_id, target_path
    )

    # TODO(ujvl): Batch remote deletes.
    # The remote checkpoint is deleted first (this call blocks until the
    # runner is done). If it is on the same node as the driver, that also
    # removes the local copy.
    ray.get(self.runner.delete_checkpoint.remote(target_path))

    # Remove whatever local copy is still around.
    if not os.path.exists(target_path):
        return
    try:
        local_dir = TrainableUtil.find_checkpoint_dir(target_path)
        shutil.rmtree(local_dir)
    except FileNotFoundError:
        logger.debug("Local checkpoint dir not found during deletion.")
def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
    """Save a checkpoint and store the configured preprocessor next to it.

    The parent implementation is called without arguments;
    ``tmp_checkpoint_dir`` is accepted but not used here.

    Returns:
        The checkpoint path returned by the parent ``save_checkpoint``.
    """
    saved_path = super().save_checkpoint()
    checkpoint_root = TrainableUtil.find_checkpoint_dir(saved_path)
    preprocessor = self._merged_config.get("preprocessor", None)

    # Without both a resolved directory and a preprocessor there is
    # nothing extra to write.
    if not (checkpoint_root and preprocessor):
        return saved_path

    save_preprocessor_to_dir(preprocessor, checkpoint_root)
    return saved_path
def restore(self, trial: Trial) -> None:
    """Restores training state from a given model checkpoint.

    Nothing happens if the trial's checkpoint has no ``dir_or_data``.
    In-memory checkpoints are handed to the runner without tracking the
    resulting future. For all other checkpoints the restore future is
    registered in ``self._futures`` and ``trial.restoring_from`` is set.

    Args:
        trial: The trial to be restored.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    checkpoint = trial.checkpoint
    if checkpoint.dir_or_data is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial)
        )

    checkpoint_dir = checkpoint.dir_or_data
    node_ip = checkpoint.node_ip

    if checkpoint.storage_mode == CheckpointStorage.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        with self._change_working_directory(trial):
            trial.runner.restore_from_object.remote(checkpoint_dir)
    else:
        logger.debug(
            "Trial %s: Attempting restore from %s", trial, checkpoint_dir
        )
        if (
            trial.uses_cloud_checkpointing
            or not trial.sync_on_checkpoint
            or not os.path.exists(checkpoint_dir)
        ):
            # If using cloud checkpointing, trial will get cp from cloud.
            # If not syncing to driver, assume it has access to the cp
            # on the local fs.
            with self._change_working_directory(trial):
                remote = trial.runner.restore.remote(checkpoint_dir, node_ip)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where no cloud checkpoints are provided.
            logger.debug("Trial %s: Reading checkpoint into memory", trial)
            checkpoint_path = TrainableUtil.find_checkpoint_dir(
                checkpoint_dir
            )
            obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
            with self._change_working_directory(trial):
                remote = trial.runner.restore_from_object.remote(obj)
        else:
            # NOTE(review): given the two conditions above this branch
            # looks unreachable (a falsy `sync_on_checkpoint` already takes
            # the first branch) - confirm before relying on it.
            # The pieces are joined with explicit spaces so the message
            # does not read "trialrestoration".
            raise _AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` for remote "
                "storage-based restoration"
            )

        # Only restores that produced a future are tracked.
        self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT, trial)
        trial.restoring_from = checkpoint
def testConvertTempToPermanent(self):
    """Check that a temp checkpoint dir can be turned into a permanent one.

    The permanent directory must resolve to itself through
    ``TrainableUtil.find_checkpoint_dir``, exist on disk, not be flagged as
    temporary, and differ from a newly created temp directory.
    """
    temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    perm_dir = FuncCheckpointUtil.create_perm_checkpoint(
        temp_dir, self.logdir, step=4
    )

    resolved_dir = TrainableUtil.find_checkpoint_dir(perm_dir)
    assert perm_dir == resolved_dir
    assert os.path.exists(perm_dir)
    assert not FuncCheckpointUtil.is_temp_checkpoint_dir(perm_dir)

    # A fresh temp directory must not collide with the permanent one.
    fresh_temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    assert fresh_temp_dir != perm_dir
def create_checkpoint(
    preprocessor: Optional[Preprocessor] = None,
    config: Optional[dict] = None,
) -> Checkpoint:
    """Create a dict-backed ``Checkpoint`` from a saved ``_DummyAlgo`` trainable.

    Args:
        preprocessor: Preprocessor passed on to the ``RLTrainer``.
        config: Trainer config; a falsy value is replaced by an empty dict.

    Returns:
        A ``Checkpoint`` built from the dict form of the saved checkpoint
        directory.
    """
    trainer = RLTrainer(
        algorithm=_DummyAlgo,
        config=config or {},
        preprocessor=preprocessor,
    )
    trainable = trainer.as_trainable()()

    # The checkpoint is read into a dict before the temporary directory
    # is cleaned up.
    with tempfile.TemporaryDirectory() as scratch_dir:
        saved_file = trainable.save(scratch_dir)
        checkpoint_root = TrainableUtil.find_checkpoint_dir(saved_file)
        as_dict = Checkpoint.from_directory(checkpoint_root).to_dict()

    return Checkpoint.from_dict(as_dict)
def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
    """Deletes local copy of checkpoint.

    When cloud checkpointing is in use, the corresponding external copy is
    deleted as well (through the custom syncer if one is set, otherwise
    through ``delete_external_checkpoint`` with retries).

    Args:
        checkpoint_path: Path to checkpoint, or a ``Checkpoint`` whose
            ``_local_path`` is used when set.
    """
    # Ensure Checkpoints are converted
    if isinstance(checkpoint_path, Checkpoint) and checkpoint_path._local_path:
        checkpoint_path = checkpoint_path._local_path

    try:
        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    except FileNotFoundError:
        # The checkpoint won't exist locally if the
        # trial was rescheduled to another worker.
        logger.debug(
            f"Local checkpoint not found during garbage collection: "
            f"{self.trial_id} - {checkpoint_path}"
        )
        return

    if self.uses_cloud_checkpointing:
        if self.custom_syncer:
            # Keep for backwards compatibility
            self.custom_syncer.delete(self._storage_path(checkpoint_dir))
            self.custom_syncer.wait_or_retry()
        else:
            checkpoint_uri = self._storage_path(checkpoint_dir)

            def _delete_remote():
                return delete_external_checkpoint(checkpoint_uri)

            retry_fn(
                _delete_remote,
                subprocess.CalledProcessError,
                num_retries=3,
                sleep_time=1,
            )

    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
def to_air_checkpoint(self) -> Optional[Checkpoint]:
    """Build a ``Checkpoint`` object from ``self.dir_or_data``.

    An ``ray.ObjectRef`` is first resolved with ``ray.get``. A ``str`` is
    treated as a path and resolved with
    ``TrainableUtil.find_checkpoint_dir``; ``bytes`` and ``dict`` values are
    wrapped with ``Checkpoint.from_bytes`` / ``Checkpoint.from_dict``.

    Returns:
        ``None`` if ``self.dir_or_data`` is falsy or if the checkpoint
        directory cannot be found (an error is then logged, gated by
        ``log_once``); otherwise the converted ``Checkpoint``.

    Raises:
        RuntimeError: If the data is not a ``str``, ``bytes`` or ``dict``.
    """
    from ray.tune.trainable.util import TrainableUtil

    data = self.dir_or_data
    if not data:
        return None

    if isinstance(data, ray.ObjectRef):
        data = ray.get(data)

    if isinstance(data, str):
        try:
            resolved_dir = TrainableUtil.find_checkpoint_dir(data)
        except FileNotFoundError:
            if log_once("checkpoint_not_available"):
                logger.error(
                    f"The requested checkpoint is not available on this node, "
                    f"most likely because you are using Ray client or disabled "
                    f"checkpoint synchronization. To avoid this, enable checkpoint "
                    f"synchronization to cloud storage by specifying a "
                    f"`SyncConfig`. The checkpoint may be available on a different "
                    f"node - please check this location on worker nodes: "
                    f"{data}"
                )
            return None
        return Checkpoint.from_directory(resolved_dir)

    if isinstance(data, bytes):
        return Checkpoint.from_bytes(data)

    if isinstance(data, dict):
        return Checkpoint.from_dict(data)

    raise RuntimeError(f"Unknown checkpoint data type: {type(data)}")
def restore(
    self,
    checkpoint_path: Union[str, Checkpoint],
    checkpoint_node_ip: Optional[str] = None,
):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to save().

    Subclasses should override ``load_checkpoint()`` instead to
    restore state.
    This method restores additional metadata saved with the checkpoint.

    `checkpoint_path` should match with the return from ``save()``.

    `checkpoint_path` can be
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`.
    Or, `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

    `self.logdir` should generally be corresponding to `checkpoint_path`,
    for example, `~/ray_results/exp/MyTrainable_abc`.

    `self.remote_checkpoint_dir` in this case, is something like,
    `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`

    Args:
        checkpoint_path: Path to restore checkpoint from. If this
            path does not exist on the local node, it will be fetched
            from external (cloud) storage if available, or restored
            from a remote node.
        checkpoint_node_ip: If given, try to restore
            checkpoint from this node if it doesn't exist locally or
            on cloud storage.

    Raises:
        ValueError: If the checkpoint path does not exist locally after
            the cloud and remote-node lookups.
    """
    # Ensure Checkpoints are converted
    if isinstance(checkpoint_path, Checkpoint):
        return self._restore_from_checkpoint_obj(checkpoint_path)

    if not self._maybe_load_from_cloud(checkpoint_path) and (
        # If a checkpoint source IP is given
        checkpoint_node_ip
        # And the checkpoint does not currently exist on the local node.
        # This must test the checkpoint path: testing the IP string as a
        # path is practically always "missing" and would trigger a remote
        # fetch even when the checkpoint is already local.
        and not os.path.exists(checkpoint_path)
        # And the source IP is different to the current IP
        and checkpoint_node_ip != ray.util.get_node_ip_address()
    ):
        checkpoint = get_checkpoint_from_remote_node(
            checkpoint_path, checkpoint_node_ip
        )
        if checkpoint:
            checkpoint.to_directory(checkpoint_path)

    if not os.path.exists(checkpoint_path):
        raise ValueError(
            f"Could not recover from checkpoint as it does not exist on local "
            f"disk and was not available on cloud storage or another Ray node. "
            f"Got checkpoint path: {checkpoint_path} and IP {checkpoint_node_ip}"
        )

    checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    metadata = TrainableUtil.load_metadata(checkpoint_dir)

    if metadata["saved_as_dict"]:
        # If data was saved as a dict (e.g. from a class trainable),
        # also pass the dict to `load_checkpoint()`.
        checkpoint_dict = Checkpoint.from_directory(checkpoint_dir).to_dict()
        # If other files were added to the directory after converting from
        # the original dict (e.g. marker files), clean these up
        checkpoint_dict.pop(_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, None)
        to_load = checkpoint_dict
    else:
        # Otherwise, pass the relative checkpoint path
        relative_checkpoint_path = metadata["relative_checkpoint_path"]
        to_load = os.path.join(checkpoint_dir, relative_checkpoint_path)

    # Set metadata
    self._experiment_id = metadata["experiment_id"]
    self._iteration = metadata["iteration"]
    self._timesteps_total = metadata["timesteps_total"]
    self._time_total = metadata["time_total"]
    self._episodes_total = metadata["episodes_total"]

    # Actually load checkpoint
    self.load_checkpoint(to_load)

    # Counters that are relative to the most recent restore start over.
    self._time_since_restore = 0.0
    self._timesteps_since_restore = 0
    self._iterations_since_restore = 0
    self._restored = True

    logger.info(
        "Restored on %s from checkpoint: %s",
        self.get_current_ip(),
        checkpoint_dir,
    )
    state = {
        "_iteration": self._iteration,
        "_timesteps_total": self._timesteps_total,
        "_time_total": self._time_total,
        "_episodes_total": self._episodes_total,
    }
    logger.info("Current state after restoring: %s", state)