def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
    if not self.uses_cloud_checkpointing:
        return False

    rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
        self.logdir, checkpoint_path
    )
    external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
    local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

    if self.custom_syncer:
        # Only keep for backwards compatibility
        self.custom_syncer.sync_down(remote_dir=external_uri, local_dir=local_dir)
        self.custom_syncer.wait_or_retry()
        return True

    checkpoint = Checkpoint.from_uri(external_uri)
    retry_fn(
        lambda: checkpoint.to_directory(local_dir),
        subprocess.CalledProcessError,
        num_retries=3,
        sleep_time=1,
    )

    return True
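# A minimal sketch of the retry helper assumed by `_maybe_load_from_cloud`
# above. This is not the actual Ray implementation; it only illustrates the
# assumed semantics of `retry_fn(fn, exception_type, num_retries, sleep_time)`:
# call `fn` until it succeeds, sleeping between attempts and giving up after
# `num_retries` failures of the given exception type.
import time
from typing import Callable, Type


def _retry_fn_sketch(
    fn: Callable[[], None],
    exception_type: Type[Exception],
    num_retries: int = 3,
    sleep_time: int = 1,
) -> bool:
    for _ in range(num_retries):
        try:
            fn()
        except exception_type:
            # Back off briefly before the next attempt.
            time.sleep(sleep_time)
        else:
            return True
    # All attempts failed; report failure to the caller.
    return False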
def restore(self, checkpoint_path):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to save().

    Subclasses should override ``load_checkpoint()`` instead to
    restore state.
    This method restores additional metadata saved with the checkpoint.

    `checkpoint_path` should match the return value of ``save()``.
    `checkpoint_path` can be
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`
    or `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.
    `self.logdir` should generally correspond to `checkpoint_path`,
    for example, `~/ray_results/exp/MyTrainable_abc`.
    `self.remote_checkpoint_dir` in this case is something like
    `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.
    """
    if self.uses_cloud_checkpointing:
        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
            self.logdir, checkpoint_path
        )
        self.storage_client.sync_down(
            os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
            os.path.join(self.logdir, rel_checkpoint_dir),
        )
        self.storage_client.wait_or_retry()

    # Ensure TrialCheckpoints are converted
    if isinstance(checkpoint_path, TrialCheckpoint):
        checkpoint_path = checkpoint_path.local_path

    with open(checkpoint_path + ".tune_metadata", "rb") as f:
        metadata = pickle.load(f)
    self._experiment_id = metadata["experiment_id"]
    self._iteration = metadata["iteration"]
    self._timesteps_total = metadata["timesteps_total"]
    self._time_total = metadata["time_total"]
    self._episodes_total = metadata["episodes_total"]
    saved_as_dict = metadata["saved_as_dict"]
    if saved_as_dict:
        with open(checkpoint_path, "rb") as loaded_state:
            checkpoint_dict = pickle.load(loaded_state)
        checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
        self.load_checkpoint(checkpoint_dict)
    else:
        self.load_checkpoint(checkpoint_path)
    self._time_since_restore = 0.0
    self._timesteps_since_restore = 0
    self._iterations_since_restore = 0
    self._restored = True
    logger.info(
        "Restored on %s from checkpoint: %s",
        self.get_current_ip(),
        checkpoint_path,
    )
    state = {
        "_iteration": self._iteration,
        "_timesteps_total": self._timesteps_total,
        "_time_total": self._time_total,
        "_episodes_total": self._episodes_total,
    }
    logger.info("Current state after restoring: %s", state)
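# A minimal sketch of the subclass contract the docstring above describes:
# user code overrides `save_checkpoint()` / `load_checkpoint()` rather than
# `save()` / `restore()`. The hook signatures and the state layout here are
# illustrative assumptions, not copied from the original module.
import os
import pickle

from ray.tune import Trainable


class MySketchTrainable(Trainable):
    def setup(self, config):
        self._counter = 0

    def step(self):
        self._counter += 1
        return {"counter": self._counter}

    def save_checkpoint(self, tmp_checkpoint_dir):
        # Persist trainable state to a file inside the directory Tune
        # provides, and return the path so restore() can hand it back.
        path = os.path.join(tmp_checkpoint_dir, "state.pkl")
        with open(path, "wb") as f:
            pickle.dump({"counter": self._counter}, f)
        return path

    def load_checkpoint(self, checkpoint):
        # Depending on how state was saved, `checkpoint` is either a dict
        # (see the `saved_as_dict` branch in restore() above) or a file path.
        if isinstance(checkpoint, dict):
            self._counter = checkpoint["counter"]
        else:
            with open(checkpoint, "rb") as f:
                self._counter = pickle.load(f)["counter"]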
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
    assert (
        TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path)
        == "checkpoint0/"
    )
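# A possible pytest parametrization for the test above, matching its asserted
# "checkpoint0/" result. The concrete paths are illustrative assumptions, not
# taken from the original test module, and `TrainableUtil` is assumed to be
# imported the same way it is there.
import pytest


@pytest.mark.parametrize(
    "checkpoint_path, logdir",
    [
        ("/tmp/exp/trial/checkpoint0/checkpoint", "/tmp/exp/trial"),
        ("/tmp/exp/trial/checkpoint0/foo/bar", "/tmp/exp/trial"),
    ],
)
def test_find_rel_checkpoint_dir_sketch(checkpoint_path, logdir):
    # Both sample paths live under logdir/checkpoint0/, so the relative
    # checkpoint directory should resolve to "checkpoint0/" in each case.
    assert (
        TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path)
        == "checkpoint0/"
    )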
def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to save().

    Subclasses should override ``load_checkpoint()`` instead to
    restore state.
    This method restores additional metadata saved with the checkpoint.

    `checkpoint_path` should match the return value of ``save()``.
    `checkpoint_path` can be
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`
    or `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.
    `self.logdir` should generally correspond to `checkpoint_path`,
    for example, `~/ray_results/exp/MyTrainable_abc`.
    `self.remote_checkpoint_dir` in this case is something like
    `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.

    Args:
        checkpoint_path: Path to restore checkpoint from. If this
            path does not exist on the local node, it will be fetched
            from external (cloud) storage if available, or restored
            from a remote node.
        checkpoint_node_ip: If given, try to restore checkpoint from
            this node if it doesn't exist locally or on cloud storage.
    """
    # Ensure TrialCheckpoints are converted
    if isinstance(checkpoint_path, TrialCheckpoint):
        checkpoint_path = checkpoint_path.local_path

    if self.uses_cloud_checkpointing:
        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
            self.logdir, checkpoint_path
        )
        external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
        local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

        if self.storage_client:
            # Only keep for backwards compatibility
            self.storage_client.sync_down(external_uri, local_dir)
            self.storage_client.wait_or_retry()
        else:
            checkpoint = Checkpoint.from_uri(external_uri)
            retry_fn(
                lambda: checkpoint.to_directory(local_dir),
                subprocess.CalledProcessError,
                num_retries=3,
                sleep_time=1,
            )
    elif (
        # If a checkpoint source IP is given
        checkpoint_node_ip
        # And the checkpoint does not currently exist on the local node
        and not os.path.exists(checkpoint_path)
        # And the source IP is different to the current IP
        and checkpoint_node_ip != ray.util.get_node_ip_address()
    ):
        checkpoint = get_checkpoint_from_remote_node(
            checkpoint_path, checkpoint_node_ip
        )
        if checkpoint:
            checkpoint.to_directory(checkpoint_path)

    with open(checkpoint_path + ".tune_metadata", "rb") as f:
        metadata = pickle.load(f)
    self._experiment_id = metadata["experiment_id"]
    self._iteration = metadata["iteration"]
    self._timesteps_total = metadata["timesteps_total"]
    self._time_total = metadata["time_total"]
    self._episodes_total = metadata["episodes_total"]
    saved_as_dict = metadata["saved_as_dict"]
    if saved_as_dict:
        with open(checkpoint_path, "rb") as loaded_state:
            checkpoint_dict = pickle.load(loaded_state)
        checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
        self.load_checkpoint(checkpoint_dict)
    else:
        self.load_checkpoint(checkpoint_path)
    self._time_since_restore = 0.0
    self._timesteps_since_restore = 0
    self._iterations_since_restore = 0
    self._restored = True
    logger.info(
        "Restored on %s from checkpoint: %s",
        self.get_current_ip(),
        checkpoint_path,
    )
    state = {
        "_iteration": self._iteration,
        "_timesteps_total": self._timesteps_total,
        "_time_total": self._time_total,
        "_episodes_total": self._episodes_total,
    }
    logger.info("Current state after restoring: %s", state)
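# A minimal usage sketch of the save()/restore() round trip the docstring
# describes. `MySketchTrainable` is the illustrative subclass sketched after
# the first restore() variant above, and driving it directly via train() /
# save() / restore() is an assumption about the public Trainable API, not a
# reproduction of the original module's tests.
trainable = MySketchTrainable(config={})
for _ in range(3):
    trainable.train()
# save() returns the checkpoint path that restore() expects back verbatim.
checkpoint_path = trainable.save()

# Later, or in a fresh process: rebuild the trainable and restore its state.
restored = MySketchTrainable(config={})
restored.restore(checkpoint_path)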