def to_air_checkpoint(self) -> Optional[Checkpoint]:
    from ray.tune.trainable.util import TrainableUtil

    checkpoint_data = self.dir_or_data

    if not checkpoint_data:
        return None

    if isinstance(checkpoint_data, ray.ObjectRef):
        checkpoint_data = ray.get(checkpoint_data)

    if isinstance(checkpoint_data, str):
        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
    elif isinstance(checkpoint_data, bytes):
        with tempfile.TemporaryDirectory() as tmpdir:
            TrainableUtil.create_from_pickle(checkpoint_data, tmpdir)
            # Double wrap in checkpoint so we hold the data in memory and
            # can remove the temp directory
            checkpoint = Checkpoint.from_dict(
                Checkpoint.from_directory(tmpdir).to_dict()
            )
    elif isinstance(checkpoint_data, dict):
        checkpoint = Checkpoint.from_dict(checkpoint_data)
    else:
        raise RuntimeError(
            f"Unknown checkpoint data type: {type(checkpoint_data)}"
        )
    return checkpoint
def save(self, checkpoint_dir: Optional[str] = None) -> str:
    """Saves the current model state to a checkpoint.

    Subclasses should override ``save_checkpoint()`` instead to save state.
    This method dumps additional metadata alongside the saved path.

    If a remote checkpoint dir is given, this will also sync up to remote
    storage.

    Args:
        checkpoint_dir: Optional dir to place the checkpoint.

    Returns:
        str: path that points to xxx.pkl file.

    Note the return path should match up with what is expected of
    `restore()`.
    """
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        checkpoint_dir or self.logdir, index=self.iteration
    )
    checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)
    trainable_state = self.get_state()
    checkpoint_path = TrainableUtil.process_checkpoint(
        checkpoint_dict_or_path,
        parent_dir=checkpoint_dir,
        trainable_state=trainable_state,
    )

    # Maybe sync to cloud
    self._maybe_save_to_cloud(checkpoint_dir)

    return checkpoint_path
def testFindCheckpointDir(self):
    checkpoint_path = os.path.join(self.checkpoint_dir, "0/my/nested/chkpt")
    os.makedirs(checkpoint_path)
    found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    self.assertEqual(self.checkpoint_dir, found_dir)

    with self.assertRaises(FileNotFoundError):
        parent = os.path.dirname(found_dir)
        TrainableUtil.find_checkpoint_dir(parent)
def __call__(self, checkpoint: _TrackedCheckpoint):
    """Requests checkpoint deletion asynchronously.

    Args:
        checkpoint: Checkpoint to delete.
    """
    if not self.runner:
        return

    if (
        checkpoint.storage_mode == CheckpointStorage.PERSISTENT
        and checkpoint.dir_or_data
    ):
        checkpoint_path = checkpoint.dir_or_data

        logger.debug(
            "Trial %s: Deleting checkpoint %s", self.trial_id, checkpoint_path
        )

        # TODO(ujvl): Batch remote deletes.
        # We first delete the remote checkpoint. If it is on the same
        # node as the driver, it will also remove the local copy.
        ray.get(self.runner.delete_checkpoint.remote(checkpoint_path))

        # Delete local copy, if any exists.
        if os.path.exists(checkpoint_path):
            try:
                checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
                shutil.rmtree(checkpoint_dir)
            except FileNotFoundError:
                logger.debug("Local checkpoint dir not found during deletion.")
def _create_checkpoint_dir(
    self, checkpoint_dir: Optional[str] = None
) -> Optional[str]:
    # Create checkpoint_xxxxx directory and drop checkpoint marker
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        checkpoint_dir or self.logdir, index=self.iteration
    )
    return checkpoint_dir
def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
    if not self.uses_cloud_checkpointing:
        return False

    rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
        self.logdir, checkpoint_path
    )
    external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
    local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

    if self.custom_syncer:
        # Only keep for backwards compatibility
        self.custom_syncer.sync_down(remote_dir=external_uri, local_dir=local_dir)
        self.custom_syncer.wait_or_retry()
        return True

    checkpoint = Checkpoint.from_uri(external_uri)
    retry_fn(
        lambda: checkpoint.to_directory(local_dir),
        subprocess.CalledProcessError,
        num_retries=3,
        sleep_time=1,
    )

    return True
def setUp(self):
    self.checkpoint_dir = os.path.join(
        ray._private.utils.get_user_temp_dir(), "tune", "MyTrainable123"
    )
    self.checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        self.checkpoint_dir, "0"
    )
def save_checkpoint(self, tmp_checkpoint_dir: str = ""): checkpoint_path = super().save_checkpoint() parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path) preprocessor = self._merged_config.get("preprocessor", None) if parent_dir and preprocessor: save_preprocessor_to_dir(preprocessor, parent_dir) return checkpoint_path
def restore_from_object(self, obj):
    """Restores training state from a checkpoint object.

    These checkpoints are returned from calls to save_to_object().
    """
    tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
    checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir)

    self.restore(checkpoint_path)
    shutil.rmtree(tmpdir)
def restore(self, trial: Trial) -> None:
    """Restores training state from a given model checkpoint.

    Args:
        trial: The trial to be restored.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    checkpoint = trial.checkpoint
    if checkpoint.dir_or_data is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial)
        )
    checkpoint_dir = checkpoint.dir_or_data
    node_ip = checkpoint.node_ip
    if checkpoint.storage_mode == CheckpointStorage.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        with self._change_working_directory(trial):
            trial.runner.restore_from_object.remote(checkpoint_dir)
    else:
        logger.debug(
            "Trial %s: Attempting restore from %s", trial, checkpoint_dir
        )
        if (
            trial.uses_cloud_checkpointing
            or not trial.sync_on_checkpoint
            or not os.path.exists(checkpoint_dir)
        ):
            # If using cloud checkpointing, trial will get cp from cloud.
            # If not syncing to driver, assume it has access to the cp
            # on the local fs.
            with self._change_working_directory(trial):
                remote = trial.runner.restore.remote(checkpoint_dir, node_ip)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where no cloud checkpoints are provided.
            logger.debug("Trial %s: Reading checkpoint into memory", trial)
            checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_dir)
            obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
            with self._change_working_directory(trial):
                remote = trial.runner.restore_from_object.remote(obj)
        else:
            raise _AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` for remote "
                "storage-based restoration"
            )

        self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT, trial)
        trial.restoring_from = checkpoint
def testConvertTempToPermanent(self):
    checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint(
        checkpoint_dir, self.logdir, step=4
    )
    assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir(
        new_checkpoint_dir
    )
    assert os.path.exists(new_checkpoint_dir)
    assert not FuncCheckpointUtil.is_temp_checkpoint_dir(new_checkpoint_dir)

    tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    assert tmp_checkpoint_dir != new_checkpoint_dir
def save_to_object(self):
    """Saves the current model state to a Python object.

    It also saves to disk but does not return the checkpoint path.

    Returns:
        Object holding checkpoint data.
    """
    tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
    checkpoint_path = self.save(tmpdir)

    # Save all files in subtree and delete the tmpdir.
    obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
    shutil.rmtree(tmpdir)
    return obj
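# Round-trip sketch (illustrative, not from the source): the object returned by
# `save_to_object()` above can be handed to `restore_from_object()` (shown
# earlier) on another worker without going through shared storage.
# `trainable` and `other_trainable` are hypothetical Trainable instances.
obj = trainable.save_to_object()
other_trainable.restore_from_object(obj)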
def write_checkpoint(trial: Trial, index: int):
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(trial.logdir, index=index)
    result = {"training_iteration": index}
    with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
        json.dump(result, f)

    tune_cp = _TrackedCheckpoint(
        dir_or_data=checkpoint_dir,
        storage_mode=CheckpointStorage.PERSISTENT,
        metrics=result,
    )
    trial.saving_to = tune_cp

    return checkpoint_dir
def create_checkpoint(
    preprocessor: Optional[Preprocessor] = None, config: Optional[dict] = None
) -> Checkpoint:
    rl_trainer = RLTrainer(
        algorithm=_DummyAlgo,
        config=config or {},
        preprocessor=preprocessor,
    )
    rl_trainable_cls = rl_trainer.as_trainable()
    rl_trainable = rl_trainable_cls()

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint_file = rl_trainable.save(checkpoint_dir)
        checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_file)
        checkpoint_data = Checkpoint.from_directory(checkpoint_path).to_dict()

    return Checkpoint.from_dict(checkpoint_data)
def testPickleCheckpoint(self):
    for i in range(5):
        path = os.path.join(self.checkpoint_dir, str(i))
        with open(path, "w") as f:
            f.write(str(i))

    checkpoint_path = os.path.join(self.checkpoint_dir, "0")

    data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
    loaded = cloudpickle.loads(data_dict)

    checkpoint_name = os.path.basename(checkpoint_path)
    self.assertEqual(loaded["checkpoint_name"], checkpoint_name)

    for i in range(5):
        path = os.path.join(self.checkpoint_dir, str(i))
        self.assertEqual(loaded["data"][str(i)], open(path, "rb").read())
def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
    """Deletes local copy of checkpoint.

    Args:
        checkpoint_path: Path to checkpoint.
    """
    # Ensure Checkpoints are converted
    if isinstance(checkpoint_path, Checkpoint) and checkpoint_path._local_path:
        checkpoint_path = checkpoint_path._local_path

    try:
        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    except FileNotFoundError:
        # The checkpoint won't exist locally if the
        # trial was rescheduled to another worker.
        logger.debug(
            f"Local checkpoint not found during garbage collection: "
            f"{self.trial_id} - {checkpoint_path}"
        )
        return
    else:
        if self.uses_cloud_checkpointing:
            if self.custom_syncer:
                # Keep for backwards compatibility
                self.custom_syncer.delete(self._storage_path(checkpoint_dir))
                self.custom_syncer.wait_or_retry()
            else:
                checkpoint_uri = self._storage_path(checkpoint_dir)
                retry_fn(
                    lambda: delete_external_checkpoint(checkpoint_uri),
                    subprocess.CalledProcessError,
                    num_retries=3,
                    sleep_time=1,
                )

    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
def to_air_checkpoint(self) -> Optional[Checkpoint]:
    from ray.tune.trainable.util import TrainableUtil

    checkpoint_data = self.dir_or_data

    if not checkpoint_data:
        return None

    if isinstance(checkpoint_data, ray.ObjectRef):
        checkpoint_data = ray.get(checkpoint_data)

    if isinstance(checkpoint_data, str):
        try:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
        except FileNotFoundError:
            if log_once("checkpoint_not_available"):
                logger.error(
                    f"The requested checkpoint is not available on this node, "
                    f"most likely because you are using Ray client or disabled "
                    f"checkpoint synchronization. To avoid this, enable checkpoint "
                    f"synchronization to cloud storage by specifying a "
                    f"`SyncConfig`. The checkpoint may be available on a different "
                    f"node - please check this location on worker nodes: "
                    f"{checkpoint_data}"
                )
            return None
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
    elif isinstance(checkpoint_data, bytes):
        checkpoint = Checkpoint.from_bytes(checkpoint_data)
    elif isinstance(checkpoint_data, dict):
        checkpoint = Checkpoint.from_dict(checkpoint_data)
    else:
        raise RuntimeError(
            f"Unknown checkpoint data type: {type(checkpoint_data)}"
        )
    return checkpoint
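# Usage sketch (illustrative, not from the source): consuming the AIR
# `Checkpoint` returned by `to_air_checkpoint()` above. `tracked_checkpoint`
# stands in for a hypothetical `_TrackedCheckpoint` produced by Tune;
# `as_directory()` is part of the public `ray.air.checkpoint.Checkpoint` API.
air_checkpoint = tracked_checkpoint.to_air_checkpoint()
if air_checkpoint is not None:
    with air_checkpoint.as_directory() as materialized_dir:
        # Load framework-specific state (model weights, preprocessor, ...)
        # from the materialized directory here.
        ...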
def get_trial_checkpoints_paths(
    self, trial: Trial, metric: Optional[str] = None
) -> List[Tuple[str, Number]]:
    """Gets paths and metrics of all persistent checkpoints of a trial.

    Args:
        trial: The log directory of a trial, or a trial instance.
        metric: key for trial info to return, e.g. "mean_accuracy".
            "training_iteration" is used by default if no value was
            passed to ``self.default_metric``.

    Returns:
        List of [path, metric] for all persistent checkpoints of the trial.
    """
    metric = metric or self.default_metric or TRAINING_ITERATION

    if isinstance(trial, str):
        trial_dir = os.path.expanduser(trial)

        # Get checkpoints from logdir.
        chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)

        # Join with trial dataframe to get metrics.
        trial_df = self.trial_dataframes[trial_dir]
        path_metric_df = chkpt_df.merge(
            trial_df, on="training_iteration", how="inner"
        )
        return path_metric_df[["chkpt_path", metric]].values.tolist()
    elif isinstance(trial, Trial):
        checkpoints = trial.get_trial_checkpoints()
        # Support metrics given as paths, e.g.
        # "info/learner/default_policy/policy_loss".
        return [
            (c.dir_or_data, unflattened_lookup(metric, c.metrics))
            for c in checkpoints
        ]
    else:
        raise ValueError("trial should be a string or a Trial instance.")
def save(self, checkpoint_dir: Optional[str] = None) -> str:
    """Saves the current model state to a checkpoint.

    Subclasses should override ``save_checkpoint()`` instead to save state.
    This method dumps additional metadata alongside the saved path.

    If a remote checkpoint dir is given, this will also sync up to remote
    storage.

    Args:
        checkpoint_dir: Optional dir to place the checkpoint.

    Returns:
        The given or created checkpoint directory.

    Note the return path should match up with what is expected of
    `restore()`.
    """
    checkpoint_dir = self._create_checkpoint_dir(checkpoint_dir=checkpoint_dir)

    # User saves checkpoint
    checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)

    if checkpoint_dict_or_path is None:
        # checkpoint_dict_or_path can only be None in class trainables.
        # In that case the default is to use the root checkpoint directory.
        assert checkpoint_dir
        checkpoint_dict_or_path = checkpoint_dir
    elif checkpoint_dir is None:
        # checkpoint_dir is only None in function trainables. In that case,
        # checkpoint_dict_or_path points to the already saved checkpoint dir.
        # This will be considered the root dir.
        assert isinstance(checkpoint_dict_or_path, str)
        checkpoint_dir = checkpoint_dict_or_path

    # Get trainable metadata
    metadata = self.get_state()

    if isinstance(checkpoint_dict_or_path, dict):
        metadata["relative_checkpoint_path"] = ""
        metadata["saved_as_dict"] = True
        Checkpoint.from_dict(checkpoint_dict_or_path).to_directory(checkpoint_dir)
        # Re-drop marker
        TrainableUtil.mark_as_checkpoint_dir(checkpoint_dir)
    else:
        # Make sure the checkpoint dir is contained
        if not checkpoint_dict_or_path.startswith(checkpoint_dir):
            raise ValueError(
                f"The returned checkpoint path must be within the given "
                f"checkpoint dir ({checkpoint_dir}): {checkpoint_dict_or_path}"
            )

        # Get relative path to returned checkpoint
        relative_checkpoint_path = os.path.relpath(
            checkpoint_dict_or_path, checkpoint_dir
        )
        metadata["relative_checkpoint_path"] = relative_checkpoint_path
        metadata["saved_as_dict"] = False

    TrainableUtil.write_metadata(checkpoint_dir, metadata)

    # Maybe sync to cloud
    self._maybe_save_to_cloud(checkpoint_dir)

    return checkpoint_dir
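# Minimal sketch (assuming the standard `ray.tune.Trainable` class API) of the
# `save_checkpoint()` contract that `save()` above relies on: a subclass may
# return either a dict or a path inside `checkpoint_dir`. `MyTrainable` is an
# illustrative name, not part of the source.
from ray import tune


class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.value = 0

    def step(self):
        self.value += 1
        return {"value": self.value}

    def save_checkpoint(self, checkpoint_dir):
        # Returning a dict: `save()` persists it under `checkpoint_dir` and
        # records `saved_as_dict=True` in the checkpoint metadata.
        return {"value": self.value}

    def load_checkpoint(self, checkpoint):
        # Receives back whatever `save_checkpoint()` returned (here, a dict).
        self.value = checkpoint["value"]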
def restore(
    self,
    checkpoint_path: Union[str, Checkpoint],
    checkpoint_node_ip: Optional[str] = None,
):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to save().

    Subclasses should override ``load_checkpoint()`` instead to
    restore state.
    This method restores additional metadata saved with the checkpoint.

    `checkpoint_path` should match with the return from ``save()``.

    `checkpoint_path` can be
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`. Or,
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

    `self.logdir` should generally correspond to `checkpoint_path`, for
    example, `~/ray_results/exp/MyTrainable_abc`.

    `self.remote_checkpoint_dir` in this case is something like
    `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.

    Args:
        checkpoint_path: Path to restore checkpoint from. If this
            path does not exist on the local node, it will be fetched
            from external (cloud) storage if available, or restored
            from a remote node.
        checkpoint_node_ip: If given, try to restore checkpoint from
            this node if it doesn't exist locally or on cloud storage.
    """
    # Ensure Checkpoints are converted
    if isinstance(checkpoint_path, Checkpoint):
        return self._restore_from_checkpoint_obj(checkpoint_path)

    if not self._maybe_load_from_cloud(checkpoint_path) and (
        # If a checkpoint source IP is given
        checkpoint_node_ip
        # And the checkpoint does not currently exist on the local node
        and not os.path.exists(checkpoint_path)
        # And the source IP is different to the current IP
        and checkpoint_node_ip != ray.util.get_node_ip_address()
    ):
        checkpoint = get_checkpoint_from_remote_node(
            checkpoint_path, checkpoint_node_ip
        )
        if checkpoint:
            checkpoint.to_directory(checkpoint_path)

    if not os.path.exists(checkpoint_path):
        raise ValueError(
            f"Could not recover from checkpoint as it does not exist on local "
            f"disk and was not available on cloud storage or another Ray node. "
            f"Got checkpoint path: {checkpoint_path} and IP {checkpoint_node_ip}"
        )

    checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
    metadata = TrainableUtil.load_metadata(checkpoint_dir)

    if metadata["saved_as_dict"]:
        # If data was saved as a dict (e.g. from a class trainable),
        # also pass the dict to `load_checkpoint()`.
        checkpoint_dict = Checkpoint.from_directory(checkpoint_dir).to_dict()
        # If other files were added to the directory after converting from the
        # original dict (e.g. marker files), clean these up
        checkpoint_dict.pop(_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, None)
        to_load = checkpoint_dict
    else:
        # Otherwise, pass the relative checkpoint path
        relative_checkpoint_path = metadata["relative_checkpoint_path"]
        to_load = os.path.join(checkpoint_dir, relative_checkpoint_path)

    # Set metadata
    self._experiment_id = metadata["experiment_id"]
    self._iteration = metadata["iteration"]
    self._timesteps_total = metadata["timesteps_total"]
    self._time_total = metadata["time_total"]
    self._episodes_total = metadata["episodes_total"]

    # Actually load checkpoint
    self.load_checkpoint(to_load)

    self._time_since_restore = 0.0
    self._timesteps_since_restore = 0
    self._iterations_since_restore = 0
    self._restored = True

    logger.info(
        "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_dir
    )
    state = {
        "_iteration": self._iteration,
        "_timesteps_total": self._timesteps_total,
        "_time_total": self._time_total,
        "_episodes_total": self._episodes_total,
    }
    logger.info("Current state after restoring: %s", state)
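# Round-trip sketch (illustrative): `restore()` expects the value returned by
# `save()`, as both docstrings above state. `MyTrainable` refers to the
# hypothetical subclass sketched after `save()`.
trainable = MyTrainable(config={})
trainable.train()
checkpoint_dir = trainable.save()

restored = MyTrainable(config={})
restored.restore(checkpoint_dir)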
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
    # `checkpoint_path` and `logdir` are supplied by pytest fixtures or
    # parametrization defined elsewhere in the test module.
    assert (
        TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path)
        == "checkpoint0"
    )