def reset_trial(
    self,
    trial: Trial,
    new_config: Dict,
    new_experiment_tag: str,
    logger_creator: Optional[Callable[[Dict], "ray.tune.Logger"]] = None,
) -> bool:
    """Tries to invoke `Trainable.reset()` to reset trial.

    Args:
        trial: Trial to be reset.
        new_config: New configuration for Trial trainable.
        new_experiment_tag: New experiment name for trial.
        logger_creator: Function that instantiates a logger on the
            actor process.

    Returns:
        True if `reset_config` is successful else False.
    """
    trial.set_experiment_tag(new_experiment_tag)
    trial.set_config(new_config)
    trainable = trial.runner

    # Pass magic variables
    extra_config = copy.deepcopy(new_config)
    extra_config[TRIAL_INFO] = _TrialInfo(trial)

    stdout_file, stderr_file = trial.log_to_file
    extra_config[STDOUT_FILE] = stdout_file
    extra_config[STDERR_FILE] = stderr_file

    with self._change_working_directory(trial):
        with warn_if_slow("reset"):
            try:
                reset_val = ray.get(
                    trainable.reset.remote(extra_config, logger_creator),
                    timeout=DEFAULT_GET_TIMEOUT,
                )
            except GetTimeoutError:
                logger.exception("Trial %s: reset timed out.", trial)
                return False

    return reset_val
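# Hedged illustration (not part of the executor code above): `reset_trial`
# only reports success when the user's trainable overrides
# `Trainable.reset_config()` and returns True. A minimal sketch, assuming a
# hypothetical `MyResettableTrainable` with a single `lr` hyperparameter:
from ray.tune import Trainable


class MyResettableTrainable(Trainable):
    def setup(self, config):
        self.lr = config["lr"]

    def step(self):
        return {"lr": self.lr}

    def reset_config(self, new_config):
        # Re-read hyperparameters in place so a cached actor can be reused
        # for the next trial instead of being torn down and restarted.
        self.lr = new_config["lr"]
        return True  # returning False makes reset_trial() fail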
def _setup_remote_runner(self, trial):
    trial.init_logdir()
    # We checkpoint metadata here to try mitigating logdir duplication
    self._trials_to_cache.add(trial)
    logger_creator = partial(noop_logger_creator, logdir=trial.logdir)

    if len(self._cached_actor_pg) > 0:
        assert self._reuse_actors
        existing_runner, pg = self._cached_actor_pg.popleft()
        logger.debug(f"Trial {trial}: Reusing cached runner {existing_runner}")

        trial.set_runner(existing_runner)
        if pg:
            self._pg_manager.assign_cached_pg(pg, trial)

        if not self.reset_trial(
            trial, trial.config, trial.experiment_tag, logger_creator
        ):
            raise _AbortTrialExecution(
                "Trainable runner reuse requires reset_config() to be "
                "implemented and return True."
            )
        return existing_runner

    trainable_cls = trial.get_trainable_cls()
    if not trainable_cls:
        raise _AbortTrialExecution(
            f"Invalid trainable: {trial.trainable_name}. If you passed "
            f"a string, make sure the trainable was registered before."
        )
    _actor_cls = _class_cache.get(trainable_cls)

    if not self._pg_manager.has_ready(trial):
        return None

    full_actor_class = self._pg_manager.get_full_actor_cls(trial, _actor_cls)
    # Clear the Trial's location (to be updated later on result)
    # since we don't know where the remote runner is placed.
    trial.set_location(_Location())
    logger.debug("Trial %s: Setting up new remote runner.", trial)

    # Logging for trials is handled centrally by TrialRunner, so
    # configure the remote runner to use a noop-logger.
    trial_config = copy.deepcopy(trial.config)
    trial_config[TRIAL_INFO] = _TrialInfo(trial)

    stdout_file, stderr_file = trial.log_to_file
    trial_config[STDOUT_FILE] = stdout_file
    trial_config[STDERR_FILE] = stderr_file

    kwargs = {
        "config": trial_config,
        "logger_creator": logger_creator,
    }

    if trial.uses_cloud_checkpointing:
        # We keep these kwargs separate for backwards compatibility
        # with trainables that don't provide these keyword arguments
        kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
        kwargs["sync_function_tpl"] = trial.sync_function_tpl

        # Throw a meaningful error if the trainable does not use the new API
        sig = inspect.signature(trial.get_trainable_cls())
        try:
            sig.bind_partial(**kwargs)
        except Exception as e:
            raise RuntimeError(
                "Your trainable class does not accept a "
                "`remote_checkpoint_dir` or `sync_function_tpl` argument "
                "in its constructor, but you've passed an "
                "`upload_dir` to your SyncConfig. Without accepting "
                "these parameters and passing them to the base trainable "
                "constructor in the init call, cloud checkpointing is "
                "effectively disabled. To resolve this issue, add the "
                "parameters to your trainable class constructor or "
                "disable cloud checkpointing by setting `upload_dir=None`."
            ) from e

    with self._change_working_directory(trial):
        return full_actor_class.remote(**kwargs)
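# Hedged illustration (not part of the executor code above): the RuntimeError
# raised in `_setup_remote_runner` asks that, when cloud checkpointing is used,
# the trainable's constructor accept `remote_checkpoint_dir` and
# `sync_function_tpl` and forward them to the base class. A minimal sketch;
# `CloudAwareTrainable` is an assumed name, and the base-constructor keyword
# names are inferred from the kwargs passed above.
from ray.tune import Trainable


class CloudAwareTrainable(Trainable):
    def __init__(
        self,
        config=None,
        logger_creator=None,
        remote_checkpoint_dir=None,
        sync_function_tpl=None,
    ):
        # Forwarding the extra kwargs lets the base Trainable set up cloud
        # syncing; swallowing them here would silently disable it.
        super().__init__(
            config=config,
            logger_creator=logger_creator,
            remote_checkpoint_dir=remote_checkpoint_dir,
            sync_function_tpl=sync_function_tpl,
        )

    def setup(self, config):
        self.value = config.get("value", 0)

    def step(self):
        return {"value": self.value}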
def testWandbDecoratorConfig(self):
    config = {"par1": 4, "par2": 9.12345678}
    trial = Trial(
        config,
        0,
        "trial_0",
        "trainable",
        PlacementGroupFactory([{"CPU": 1}]),
        "/tmp",
    )
    trial_info = _TrialInfo(trial)

    @wandb_mixin
    def train_fn(config):
        return 1

    train_fn.__mixins__ = (_MockWandbTrainableMixin,)

    config[TRIAL_INFO] = trial_info

    if WANDB_ENV_VAR in os.environ:
        del os.environ[WANDB_ENV_VAR]

    # Needs at least a project
    with self.assertRaises(ValueError):
        wrapped = wrap_function(train_fn)(config)

    # No API key
    config["wandb"] = {"project": "test_project"}
    with self.assertRaises(ValueError):
        wrapped = wrap_function(train_fn)(config)

    # API Key in config
    config["wandb"] = {"project": "test_project", "api_key": "1234"}
    wrapped = wrap_function(train_fn)(config)
    self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")

    del os.environ[WANDB_ENV_VAR]

    # API Key file
    with tempfile.NamedTemporaryFile("wt") as fp:
        fp.write("5678")
        fp.flush()

        config["wandb"] = {"project": "test_project", "api_key_file": fp.name}

        wrapped = wrap_function(train_fn)(config)
        self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")

    del os.environ[WANDB_ENV_VAR]

    # API Key in env
    os.environ[WANDB_ENV_VAR] = "9012"
    config["wandb"] = {"project": "test_project"}
    wrapped = wrap_function(train_fn)(config)

    # From now on, the API key is in the env variable.

    # Default configuration
    config["wandb"] = {"project": "test_project"}
    config[TRIAL_INFO] = trial_info

    wrapped = wrap_function(train_fn)(config)
    self.assertEqual(wrapped.wandb.kwargs["project"], "test_project")
    self.assertEqual(wrapped.wandb.kwargs["id"], trial.trial_id)
    self.assertEqual(wrapped.wandb.kwargs["name"], trial.trial_name)
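# Hedged usage sketch (not part of the test above): how the key resolution
# exercised by `testWandbDecoratorConfig` looks from a user's perspective.
# `objective` and the metric/parameter names are assumed placeholders; the
# "wandb" sub-dict must contain at least a project, and the API key is taken
# from "api_key", then "api_key_file", then the environment variable referenced
# by WANDB_ENV_VAR.
from ray import tune
from ray.tune.integration.wandb import wandb_mixin


@wandb_mixin
def objective(config):
    # The mixin has already initialized a wandb run using the trial's id and
    # name, as asserted at the end of the test above.
    tune.report(metric=config["par1"])


tune.run(
    objective,
    config={
        "par1": 4,
        "wandb": {"project": "test_project"},
    },
)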