def restore(self, trial: Trial) -> None:
    """Restores training state from a given model checkpoint.

    Args:
        trial: The trial to be restored.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    checkpoint = trial.checkpoint
    if checkpoint.dir_or_data is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial))

    checkpoint_dir = checkpoint.dir_or_data
    node_ip = checkpoint.node_ip

    if checkpoint.storage_mode == CheckpointStorage.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        with self._change_working_directory(trial):
            trial.runner.restore_from_object.remote(checkpoint_dir)
    else:
        logger.debug(
            "Trial %s: Attempting restore from %s", trial, checkpoint_dir)
        if (trial.uses_cloud_checkpointing
                or not trial.sync_on_checkpoint
                or not os.path.exists(checkpoint_dir)):
            # If using cloud checkpointing, the trial will fetch the
            # checkpoint from cloud storage. If not syncing to the driver,
            # assume the trial has access to the checkpoint on the local
            # filesystem.
            with self._change_working_directory(trial):
                remote = trial.runner.restore.remote(checkpoint_dir, node_ip)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where no cloud checkpoints are provided.
            logger.debug("Trial %s: Reading checkpoint into memory", trial)
            checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_dir)
            obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
            with self._change_working_directory(trial):
                remote = trial.runner.restore_from_object.remote(obj)
        else:
            raise _AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` for remote "
                "storage-based restoration.")

        # Track the in-flight restore so its result can be processed later.
        self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT, trial)
        trial.restoring_from = checkpoint
def _setup_remote_runner(self, trial):
    trial.init_logdir()
    # We checkpoint metadata here to try mitigating logdir duplication.
    self._trials_to_cache.add(trial)
    logger_creator = partial(noop_logger_creator, logdir=trial.logdir)

    if len(self._cached_actor_pg) > 0:
        assert self._reuse_actors
        existing_runner, pg = self._cached_actor_pg.popleft()
        logger.debug(
            f"Trial {trial}: Reusing cached runner {existing_runner}")

        trial.set_runner(existing_runner)
        if pg:
            self._pg_manager.assign_cached_pg(pg, trial)

        if not self.reset_trial(
                trial, trial.config, trial.experiment_tag, logger_creator):
            raise _AbortTrialExecution(
                "Trainable runner reuse requires reset_config() to be "
                "implemented and return True.")
        return existing_runner

    trainable_cls = trial.get_trainable_cls()
    if not trainable_cls:
        raise _AbortTrialExecution(
            f"Invalid trainable: {trial.trainable_name}. If you passed "
            f"a string, make sure the trainable was registered before.")
    _actor_cls = _class_cache.get(trainable_cls)

    if not self._pg_manager.has_ready(trial):
        return None

    full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls)
    # Clear the Trial's location (to be updated later on result)
    # since we don't know where the remote runner is placed.
    trial.set_location(_Location())
    logger.debug("Trial %s: Setting up new remote runner.", trial)
    # Logging for trials is handled centrally by TrialRunner, so
    # configure the remote runner to use a noop-logger.
    trial_config = copy.deepcopy(trial.config)
    trial_config[TRIAL_INFO] = _TrialInfo(trial)
    stdout_file, stderr_file = trial.log_to_file
    trial_config[STDOUT_FILE] = stdout_file
    trial_config[STDERR_FILE] = stderr_file
    kwargs = {
        "config": trial_config,
        "logger_creator": logger_creator,
    }

    if trial.uses_cloud_checkpointing:
        # We keep these kwargs separate for backwards compatibility
        # with trainables that don't provide these keyword arguments.
        kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
        kwargs["custom_syncer"] = trial.custom_syncer

        # Throw a meaningful error if the trainable does not use the
        # new API.
        sig = inspect.signature(trial.get_trainable_cls())
        try:
            sig.bind_partial(**kwargs)
        except Exception as e:
            raise RuntimeError(
                "Your trainable class does not accept a "
                "`remote_checkpoint_dir` or `custom_syncer` argument "
                "in its constructor, but you've passed an "
                "`upload_dir` to your SyncConfig. Without accepting "
                "these parameters and passing them to the base trainable "
                "constructor in the init call, cloud checkpointing is "
                "effectively disabled. To resolve this issue, add the "
                "parameters to your trainable class constructor or "
                "disable cloud checkpointing by setting `upload_dir=None`."
            ) from e

    # Instantiate the remote trainable actor in the trial's working directory.
    with self._change_working_directory(trial):
        return full_actor_class.remote(**kwargs)
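# The RuntimeError above asks users to accept `remote_checkpoint_dir` and
# `custom_syncer` in a custom trainable constructor and forward them to the
# base class. A minimal sketch of such a constructor, under those assumptions,
# is shown below; the class name `MyTrainable` is hypothetical and the exact
# base-class signature depends on the Ray version in use. Because the
# constructor takes `**kwargs`, the `sig.bind_partial(**kwargs)` check above
# also succeeds for it.
#
#     from ray.tune import Trainable
#
#     class MyTrainable(Trainable):
#         def __init__(self, config=None, logger_creator=None, **kwargs):
#             # Forward executor-provided kwargs (e.g. `remote_checkpoint_dir`,
#             # `custom_syncer`) so cloud checkpointing is not silently disabled.
#             super().__init__(
#                 config=config, logger_creator=logger_creator, **kwargs)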