def basicSetup(self):
    ray.init(num_cpus=4, num_gpus=1)
    port = get_valid_port()
    self.runner = TrialRunner(server_port=port)
    runner = self.runner
    kwargs = {
        "stopping_criterion": {"training_iteration": 3},
        "resources": Resources(cpu=1, gpu=1),
    }
    trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
    for t in trials:
        runner.add_trial(t)
    client = TuneClient("localhost", port)
    return runner, client
def set_status(self, trial: Trial, status: str) -> None:
    """Sets status and checkpoints metadata if needed.

    Only checkpoints metadata if trial status is a terminal condition.
    PENDING, PAUSED, and RUNNING switches have checkpoints taken care of
    in the TrialRunner.

    Args:
        trial: Trial to checkpoint.
        status: Status to set trial to.
    """
    if trial.status == status:
        logger.debug("Trial %s: Status %s unchanged.", trial, trial.status)
    else:
        logger.debug(
            "Trial %s: Changing status from %s to %s.",
            trial,
            trial.status,
            status,
        )
        trial.set_status(status)

    if status in [Trial.TERMINATED, Trial.ERROR]:
        self._trials_to_cache.add(trial)
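# Hypothetical usage sketch (the names `executor` and `trial` are assumed, not
# taken from the surrounding code, and the trial is assumed not to have hit a
# terminal status before): only terminal statuses mark the trial's metadata
# for caching here; a switch to RUNNING does not.
executor.set_status(trial, Trial.RUNNING)
assert trial not in executor._trials_to_cache
executor.set_status(trial, Trial.TERMINATED)
assert trial in executor._trials_to_cache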
def reset_trial(
    self,
    trial: Trial,
    new_config: Dict,
    new_experiment_tag: str,
    logger_creator: Optional[Callable[[Dict], "ray.tune.Logger"]] = None,
) -> bool:
    """Tries to invoke `Trainable.reset()` to reset trial.

    Args:
        trial: Trial to be reset.
        new_config: New configuration for Trial trainable.
        new_experiment_tag: New experiment name for trial.
        logger_creator: Function that instantiates a logger on the
            actor process.

    Returns:
        True if `reset_config` is successful else False.
    """
    trial.set_experiment_tag(new_experiment_tag)
    trial.set_config(new_config)
    trainable = trial.runner

    # Pass magic variables
    extra_config = copy.deepcopy(new_config)
    extra_config[TRIAL_INFO] = _TrialInfo(trial)

    stdout_file, stderr_file = trial.log_to_file
    extra_config[STDOUT_FILE] = stdout_file
    extra_config[STDERR_FILE] = stderr_file

    with self._change_working_directory(trial):
        with warn_if_slow("reset"):
            try:
                reset_val = ray.get(
                    trainable.reset.remote(extra_config, logger_creator),
                    timeout=DEFAULT_GET_TIMEOUT,
                )
            except GetTimeoutError:
                logger.exception("Trial %s: reset timed out.", trial)
                return False
    return reset_val
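# Hypothetical caller sketch (assumes `executor`, `trial`, `new_config`, and
# `new_tag` exist): `reset_trial` returns False when the remote `reset` call
# times out, in which case a caller reusing actors would typically fall back
# to stopping the old actor and starting a fresh one.
if not executor.reset_trial(trial, new_config, new_tag):
    executor._stop_trial(trial)
    executor._start_trial(trial)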
def restore(self, trial: Trial) -> None:
    """Restores training state from a given model checkpoint.

    Args:
        trial: The trial to be restored.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    checkpoint = trial.checkpoint
    if checkpoint.dir_or_data is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial)
        )
    checkpoint_dir = checkpoint.dir_or_data
    node_ip = checkpoint.node_ip
    if checkpoint.storage_mode == CheckpointStorage.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        with self._change_working_directory(trial):
            trial.runner.restore_from_object.remote(checkpoint_dir)
    else:
        logger.debug("Trial %s: Attempting restore from %s", trial, checkpoint_dir)
        if (
            trial.uses_cloud_checkpointing
            or not trial.sync_on_checkpoint
            or not os.path.exists(checkpoint_dir)
        ):
            # If using cloud checkpointing, trial will get cp from cloud.
            # If not syncing to driver, assume it has access to the cp
            # on the local fs.
            with self._change_working_directory(trial):
                remote = trial.runner.restore.remote(checkpoint_dir, node_ip)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where no cloud checkpoints are provided.
            logger.debug("Trial %s: Reading checkpoint into memory", trial)
            checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_dir)
            obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
            with self._change_working_directory(trial):
                remote = trial.runner.restore_from_object.remote(obj)
        else:
            raise _AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` for remote "
                "storage-based restoration"
            )

        self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT, trial)
        trial.restoring_from = checkpoint
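# Hypothetical sketch (assumes `executor` and `trial` exist): restore is a
# no-op when there is no checkpoint data; for non-MEMORY checkpoints the trial
# is marked as restoring until the RESTORING_RESULT future resolves.
executor.restore(trial)
if (
    trial.checkpoint.dir_or_data is not None
    and trial.checkpoint.storage_mode != CheckpointStorage.MEMORY
):
    assert trial.is_restoring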
def save(
    self,
    trial: Trial,
    storage: CheckpointStorage = CheckpointStorage.PERSISTENT,
    result: Optional[Dict] = None,
) -> _TrackedCheckpoint:
    """Saves the trial's state to a checkpoint asynchronously.

    Args:
        trial: The trial to be saved.
        storage: Where to store the checkpoint. Defaults to PERSISTENT.
        result: The state of this trial as a dictionary to be saved.
            If result is None, the trial's last result will be used.

    Returns:
        Checkpoint object, or None if an Exception occurs.
    """
    logger.debug(f"saving trial {trial}")
    result = result or trial.last_result
    with self._change_working_directory(trial):
        if storage == CheckpointStorage.MEMORY:
            value = trial.runner.save_to_object.remote()
            checkpoint = _TrackedCheckpoint(
                dir_or_data=value, storage_mode=storage, metrics=result
            )
            trial.on_checkpoint(checkpoint)
        else:
            value = trial.runner.save.remote()
            checkpoint = _TrackedCheckpoint(
                dir_or_data=value, storage_mode=storage, metrics=result
            )
            trial.saving_to = checkpoint
            self._futures[value] = (_ExecutorEventType.SAVING_RESULT, trial)
    return checkpoint
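# Hypothetical sketch of the asynchronous save contract (assumes `executor`
# and a running `trial`): for PERSISTENT storage the returned checkpoint wraps
# an unresolved Ray ObjectRef and stays attached to `trial.saving_to` until
# the executor processes the SAVING_RESULT future. MEMORY checkpoints, by
# contrast, are handed to `trial.on_checkpoint()` immediately and no future
# is tracked for them.
checkpoint = executor.save(trial, storage=CheckpointStorage.PERSISTENT)
assert trial.saving_to is checkpoint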
def _start_trial(self, trial: Trial) -> bool:
    """Starts trial and restores last result if trial was paused.

    Args:
        trial: The trial to start.

    Returns:
        True if trial was started successfully, False otherwise.

    See `RayTrialExecutor.restore` for possible errors raised.
    """
    self.set_status(trial, Trial.PENDING)
    runner = self._setup_remote_runner(trial)
    if not runner:
        return False
    trial.set_runner(runner)
    self.restore(trial)
    self.set_status(trial, Trial.RUNNING)

    self._staged_trials.discard(trial)

    if not trial.is_restoring:
        self._train(trial)
    return True
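# Minimal sketch (assumes `executor` and a pending `trial`): a False return
# only means no remote runner could be set up; on success the trial is left
# in RUNNING, and training is kicked off unless a restore is still in flight.
if executor._start_trial(trial):
    assert trial.status == Trial.RUNNING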
def testGetTrialsWithFunction(self):
    runner, client = self.basicSetup()
    test_trial = Trial(
        "__fake",
        trial_id="function_trial",
        stopping_criterion={"training_iteration": 3},
        config={"callbacks": {"on_episode_start": lambda x: None}},
    )
    runner.add_trial(test_trial)

    for i in range(3):
        runner.step()
    all_trials = client.get_all_trials()["trials"]
    self.assertEqual(len(all_trials), 3)
    client.get_trial("function_trial")
    runner.step()
    self.assertEqual(len(all_trials), 3)
def _stop_trial(
    self,
    trial: Trial,
    error: bool = False,
    exc: Optional[Union[TuneError, RayTaskError]] = None,
):
    """Stops this trial.

    Stops this trial, releasing all allocated resources. If stopping the
    trial fails, the run will be marked as terminated in error, but no
    exception will be thrown.

    Args:
        error: Whether to mark this trial as terminated in error.
        exc: Optional exception.
    """
    self.set_status(trial, Trial.ERROR if error or exc else Trial.TERMINATED)
    self._trial_just_finished = True
    trial.set_location(_Location())

    try:
        trial.write_error_log(exc=exc)
        if hasattr(trial, "runner") and trial.runner:
            if (
                not error
                and self._reuse_actors
                and (
                    len(self._cached_actor_pg)
                    < (self._cached_actor_pg.maxlen or float("inf"))
                )
            ):
                logger.debug("Reusing actor for %s", trial.runner)
                # Move PG into cache (disassociate from trial)
                pg = self._pg_manager.cache_trial_pg(trial)
                if pg:
                    # True if a placement group was replaced
                    self._cached_actor_pg.append((trial.runner, pg))
                    should_destroy_actor = False
                else:
                    # False if no placement group was replaced. This should
                    # only be the case if there are no more trials with
                    # this placement group factory to run
                    logger.debug(
                        f"Could not cache actor of trial {trial} for "
                        "reuse, as there are no pending trials "
                        "requiring its resources."
                    )
                    should_destroy_actor = True
            else:
                should_destroy_actor = True

            if should_destroy_actor:
                logger.debug("Trial %s: Destroying actor.", trial)

                with self._change_working_directory(trial):
                    future = trial.runner.stop.remote()

                pg = self._pg_manager.remove_from_in_use(trial)
                self._futures[future] = (_ExecutorEventType.STOP_RESULT, pg)
                if self._trial_cleanup:  # force trial cleanup within a deadline
                    self._trial_cleanup.add(future)

            self._staged_trials.discard(trial)

    except Exception:
        logger.exception("Trial %s: Error stopping runner.", trial)
        self.set_status(trial, Trial.ERROR)
    finally:
        trial.set_runner(None)
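# Hypothetical sketch (assumes an `executor` created with actor reuse enabled
# and trials `trial`, `failed_trial`, plus an exception `err`): a clean stop
# may park the actor and its placement group in `_cached_actor_pg` for reuse
# by a later trial, whereas an errored stop always destroys the actor.
executor._stop_trial(trial)                              # actor may be cached
executor._stop_trial(failed_trial, error=True, exc=err)  # actor is destroyed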