def step(self):
    """Runs one step of the trial event loop.

    Callers should typically run this method repeatedly in a loop. They
    may inspect or modify the runner's state in between calls to step().

    One step does, in order: notify the executor that a step begins;
    then either start the trial chosen by ``_get_next_trial``, process one
    event from the running trials, or tell the executor that no trial is
    available; then checkpoint the experiment (failures are logged, not
    raised), advance the iteration counter, serve pending server
    requests, and notify the executor that the step ended.

    Raises:
        TuneError: If called after all trials have finished.
    """
    if self.is_finished():
        raise TuneError("Called step when all trials finished?")
    with warn_if_slow("on_step_begin"):
        self.trial_executor.on_step_begin(self)
    next_trial = self._get_next_trial()  # blocking
    if next_trial is not None:
        with warn_if_slow("start_trial"):
            self.trial_executor.start_trial(next_trial)
    elif self.trial_executor.get_running_trials():
        self._process_events()  # blocking
    else:
        self.trial_executor.on_no_available_trials(self)

    # Experiment checkpointing is best-effort: any exception is logged
    # with its traceback and the event loop keeps going.
    try:
        with warn_if_slow("experiment_checkpoint"):
            self.checkpoint()
    except Exception:
        logger.exception("Trial Runner checkpointing failed.")
    self._iteration += 1

    if self._server:
        with warn_if_slow("server"):
            self._process_requests()

        # Stop serving requests once every trial has finished.
        if self.is_finished():
            self._server.shutdown()
    with warn_if_slow("on_step_end"):
        self.trial_executor.on_step_end(self)
def _process_events(self):
    """Handles exactly one pending event from the trial executor.

    A trial the executor reports as failed (stale) takes priority and is
    routed through the failure path. Otherwise this blocks until some
    trial has a ready future and dispatches it according to whether the
    trial is restoring, saving, or training. A slow save on a trial with
    ``sync_on_checkpoint`` enabled logs a hint about disabling that sync.
    """
    failed_trial = self.trial_executor.get_next_failed_trial()
    if failed_trial:
        error_msg = (
            "{} (IP: {}) detected as stale. This is likely because the "
            "node was lost").format(failed_trial, failed_trial.node_ip)
        logger.info(error_msg)
        with warn_if_slow("process_failed_trial"):
            self._process_trial_failure(failed_trial, error_msg=error_msg)
        return

    # TODO(ujvl): Consider combining get_next_available_trial and
    #  fetch_result functionality so that we don't timeout on fetch.
    ready_trial = self.trial_executor.get_next_available_trial()  # blocking

    if ready_trial.is_restoring:
        with warn_if_slow("process_trial_restore"):
            self._process_trial_restore(ready_trial)
        return

    if not ready_trial.is_saving:
        with warn_if_slow("process_trial"):
            self._process_trial(ready_trial)
        return

    with warn_if_slow("process_trial_save") as save_timing:
        self._process_trial_save(ready_trial)
    if save_timing.too_slow and ready_trial.sync_on_checkpoint:
        # TODO(ujvl): Suggest using DurableTrainable once
        #  API has converged.
        logger.warning(
            "Consider turning off forced head-worker trial "
            "checkpoint syncs by setting sync_on_checkpoint=False"
            ". Note that this may result in faulty trial "
            "restoration if a failure occurs while the checkpoint "
            "is being synced from the worker to the head node.")
def _process_trial(self, trial):
    """Processes a trial result.

    Fetches the trial's latest result and makes a scheduling decision
    regarding its next action. If a checkpoint is taken, the decided
    action is cached and acted on only after the checkpoint is later
    processed (see `_process_trial_save`). Otherwise the decision is
    acted on immediately.

    If multiple results are received (e.g. because of buffering), all
    results are processed and the final action is determined. STOP
    takes precedence over PAUSE, which takes precedence over CONTINUE.

    Args:
        trial (Trial): Trial with a result ready to be processed.

    Raises:
        Exception: Any error raised while fetching or processing results
            is re-raised when ``self._fail_fast == TrialRunner.RAISE``;
            otherwise it is logged and handed to
            ``_process_trial_failure``.
    """
    try:
        results = self.trial_executor.fetch_result(trial)
        with warn_if_slow(
                "process_trial_results",
                message="Processing trial results took {duration:.3f} s, "
                "which may be a performance bottleneck. Please consider "
                "reporting results less frequently to Ray Tune."):
            for i, result in enumerate(results):
                with warn_if_slow("process_trial_result"):
                    decision = self._process_trial_result(trial, result)
                if decision is None:
                    # If we didn't get a decision, this means a
                    # non-training future (e.g. a save) was scheduled.
                    # We do not allow processing more results then.
                    # The unprocessed results are those after index `i`.
                    num_left = len(results) - i - 1
                    if num_left > 0:
                        raise RuntimeError(
                            f"Trial {trial} has a non-training future "
                            f"scheduled but {num_left} results "
                            f"left to process. This should never "
                            f"happen - please file an issue at "
                            f"https://github.com/ray-project/ray/issues")
                elif decision == TrialScheduler.STOP:
                    # If the decision is to stop the trial,
                    # ignore all results that came after that.
                    break
    except Exception:
        error_msg = "Trial %s: Error processing event." % trial
        if self._fail_fast == TrialRunner.RAISE:
            logger.error(error_msg)
            raise
        else:
            logger.exception(error_msg)
            self._process_trial_failure(trial, traceback.format_exc())
def reset_trial(self, trial, new_config, new_experiment_tag):
    """Tries to invoke `Trainable.reset_config()` to reset trial.

    The trial's experiment tag and config are assigned before the remote
    call is made, so they hold the new values even if the reset fails.

    Args:
        trial (Trial): Trial to be reset.
        new_config (dict): New configuration for Trial trainable.
        new_experiment_tag (str): New experiment name for trial.

    Returns:
        True if `reset_config` is successful else False.
    """
    trial.experiment_tag = new_experiment_tag
    trial.config = new_config
    trainable = trial.runner
    with self._change_working_directory(trial):
        with warn_if_slow("reset_config"):
            try:
                # Pass the timeout by keyword: newer Ray releases declare
                # `timeout` keyword-only in `ray.get`.
                reset_val = ray.get(
                    trainable.reset_config.remote(new_config),
                    timeout=DEFAULT_GET_TIMEOUT)
            except RayTimeoutError:
                logger.exception("Trial %s: reset_config timed out.", trial)
                return False
    return reset_val
def reset_trial(self,
                trial,
                new_config,
                new_experiment_tag,
                logger_creator=None):
    """Attempts an in-place reset of a trial through `Trainable.reset()`.

    The trial's experiment tag and config are assigned first. The trial's
    actor is then asked to reset itself with the new config and the
    given logger factory.

    Args:
        trial (Trial): Trial to be reset.
        new_config (dict): New configuration for Trial trainable.
        new_experiment_tag (str): New experiment name for trial.
        logger_creator (Callable[[Dict], Logger]): A function that
            instantiates a logger on the actor process.

    Returns:
        The value returned by the remote `reset` call (True if
        `reset_config` is successful), or False if the call timed out.
    """
    trial.experiment_tag = new_experiment_tag
    trial.config = new_config
    actor = trial.runner
    with self._change_working_directory(trial), warn_if_slow("reset"):
        try:
            return ray.get(
                actor.reset.remote(new_config, logger_creator),
                timeout=DEFAULT_GET_TIMEOUT)
        except GetTimeoutError:
            logger.exception("Trial %s: reset timed out.", trial)
            return False
def checkpoint(self, force=False):
    """Saves execution state to `self._local_checkpoint_dir`.

    The current session checkpoint (started when this runner was created)
    is overwritten, and the search algorithm state is saved with it.
    Throttling depends on `self._checkpoint_period` unless `force` is set.
    The slow-operation warning is disabled when the checkpoint manager
    reports auto-checkpointing as enabled.

    Args:
        force (bool): Forces a checkpoint despite checkpoint_period.
    """
    slow_message = (
        "Checkpointing the experiment state took "
        "{duration:.3f} s, which may be a performance "
        "bottleneck. Please ensure the "
        "`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
        "something significantly higher than this duration "
        "to ensure compute time is mostly spent on the main "
        "training loop.")
    manager = self._checkpoint_manager
    with warn_if_slow(
            "experiment_checkpoint",
            message=slow_message,
            disable=manager.auto_checkpoint_enabled):
        manager.checkpoint(
            checkpoint_file=self.checkpoint_file,
            trial_runner=self,
            trial_executor=self.trial_executor,
            search_alg=self._search_alg,
            force=force)
def _process_events(self, timeout: Optional[float] = None):
    """Handles one pending event from the trial executor.

    A trial reported as failed (stale) is processed first. Otherwise this
    waits for a trial with a ready future (bounded by ``timeout``) and
    dispatches on its state: restoring, saving, or training. After such
    an event, any decision queued in ``self._queued_trial_decisions`` for
    the trial is executed, unless a decision is still cached for the
    trial because a save is in flight.

    Args:
        timeout: Passed to ``get_next_available_trial``; ``None`` is
            forwarded unchanged.
    """
    with warn_if_slow("get_next_failed_trial"):
        failed_trial = self.trial_executor.get_next_failed_trial()
    if failed_trial:
        error_msg = (
            "{} (IP: {}) detected as stale. This is likely because the "
            "node was lost").format(failed_trial, failed_trial.node_ip)
        logger.info(error_msg)
        with warn_if_slow("process_failed_trial"):
            self._process_trial_failure(failed_trial, error_msg=error_msg)
    else:
        # TODO(ujvl): Consider combining get_next_available_trial and
        #  fetch_result functionality so that we don't timeout on fetch.
        trial = self.trial_executor.get_next_available_trial(
            timeout=timeout)  # blocking
        if not trial:
            # No trial was returned (e.g. the wait hit `timeout`).
            return
        if trial.is_restoring:
            with warn_if_slow("process_trial_restore"):
                self._process_trial_restore(trial)
            with warn_if_slow("callbacks.on_trial_restore"):
                self._callbacks.on_trial_restore(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial)
        elif trial.is_saving:
            with warn_if_slow("process_trial_save") as _profile:
                self._process_trial_save(trial)
            with warn_if_slow("callbacks.on_trial_save"):
                self._callbacks.on_trial_save(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial)
            if _profile.too_slow and trial.sync_on_checkpoint:
                # TODO(ujvl): Suggest using DurableTrainable once
                #  API has converged.
                msg = (
                    "Consider turning off forced head-worker trial "
                    "checkpoint syncs by setting sync_on_checkpoint=False"
                    ". Note that this may result in faulty trial "
                    "restoration if a failure occurs while the checkpoint "
                    "is being synced from the worker to the head node.")
                # Warn (once) only when the trial's host differs from the
                # IP reported by get_node_ip_address().
                if trial.location.hostname and (
                        trial.location.hostname != get_node_ip_address()):
                    if log_once("tune_head_worker_checkpoint"):
                        logger.warning(msg)
        else:
            with warn_if_slow("process_trial"):
                self._process_trial(trial)

        # `self._queued_trial_decisions` now contains a final decision
        # based on all results.
        # Cached decisions are keyed by trial ID, so membership must be
        # tested with `trial.trial_id`. Testing the Trial object never
        # matches, which would execute a queued decision while a save is
        # still pending.
        if trial.trial_id not in self._cached_trial_decisions:
            final_decision = self._queued_trial_decisions.pop(
                trial.trial_id, None)
            if final_decision:
                self._execute_action(trial, final_decision)
def _start_trial(trial: Trial) -> bool:
    """Starts ``trial`` and, on success, fires ``on_trial_start`` callbacks.

    NOTE(review): ``self`` is not a parameter here; it is resolved from an
    enclosing scope (the same helper appears nested inside ``step``).

    Returns:
        True if the executor reported the trial as started, else False.
    """
    with warn_if_slow("start_trial"):
        if not self.trial_executor.start_trial(trial):
            return False
        self._callbacks.on_trial_start(
            iteration=self._iteration, trials=self._trials, trial=trial)
        return True
def _process_events(self):
    """Processes a single executor event.

    Stale (failed) trials are handled first. Otherwise this blocks until
    a trial has a ready future and processes it as either a restore or a
    training result.
    """
    failed_trial = self.trial_executor.get_next_failed_trial()
    if not failed_trial:
        # TODO(ujvl): Consider combining get_next_available_trial and
        #  fetch_result functionality so that we don't timeout on fetch.
        ready = self.trial_executor.get_next_available_trial()  # blocking
        if not ready.is_restoring:
            with warn_if_slow("process_trial"):
                self._process_trial(ready)
        else:
            with warn_if_slow("process_trial_restore"):
                self._process_trial_restore(ready)
        return

    error_msg = (
        "{} (IP: {}) detected as stale. This is likely because the "
        "node was lost").format(failed_trial, failed_trial.node_ip)
    logger.info(error_msg)
    with warn_if_slow("process_failed_trial"):
        self._process_trial_failure(failed_trial, error_msg=error_msg)
def _get_next_trial(self):
    """Replenishes the trial queue and asks the scheduler for a trial.

    The queue update blocks only when every known trial has finished but
    the search algorithm is not yet finished.

    Returns:
        The trial chosen by the scheduler, which may be None.
    """
    all_finished = all(t.is_finished() for t in self._trials)
    self._update_trial_queue(
        blocking=all_finished and not self._search_alg.is_finished())
    with warn_if_slow("choose_trial_to_run"):
        return self._scheduler_alg.choose_trial_to_run(self)
def add_trial(self, trial):
    """Registers a new trial with this TrialRunner.

    Trials may be added at any time. The trial is appended to the
    runner's trial list, the scheduler is notified, and the executor is
    asked to try checkpointing the trial's metadata.

    Args:
        trial (Trial): Trial to queue.
    """
    known_trials = self._trials
    known_trials.append(trial)

    scheduler = self._scheduler_alg
    with warn_if_slow("scheduler.on_trial_add"):
        scheduler.on_trial_add(self, trial)

    self.trial_executor.try_checkpoint_metadata(trial)
def _notify_trainable_of_new_resources_if_needed(self, trial: Trial):
    """Pushes a trial's updated resources to its trainable when flagged.

    Does nothing unless ``trial.has_new_resources`` is set. The flag is
    cleared before the remote ``_update_resources`` call is made. If that
    call times out, the error is logged and not raised.

    Args:
        trial: Trial whose runner should receive its
            ``placement_group_factory``.
    """
    if not trial.has_new_resources:
        return

    actor = trial.runner
    trial.has_new_resources = False
    with self._change_working_directory(trial):
        with warn_if_slow("update_resources"):
            try:
                update_ref = actor._update_resources.remote(
                    trial.placement_group_factory)
                ray.get(update_ref, timeout=DEFAULT_GET_TIMEOUT)
            except GetTimeoutError:
                logger.exception(
                    "Trial %s: updating resources timed out.", trial)
def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
    """Saves the trial's state to a checkpoint.

    For MEMORY storage the actor's ``save_to_object`` is invoked without
    waiting. For any other storage the actor's ``save`` is invoked and
    waited on with ``ray.get``. The checkpoint is then handed to
    ``trial.on_checkpoint``.

    Args:
        trial (Trial): The state of this trial to be saved.
        storage (str): Where to store the checkpoint. Defaults to
            PERSISTENT.
        result (dict): The state of this trial as a dictionary to be
            saved. If result is falsy (e.g. None), the trial's last
            result will be used.

    Returns:
        The checkpoint's ``value`` (a future for MEMORY storage), or None
        if ``trial.on_checkpoint`` raised.
    """
    if not result:
        result = trial.last_result

    with self._change_working_directory(trial):
        if storage != Checkpoint.MEMORY:
            with warn_if_slow("save_checkpoint_to_storage"):
                # TODO(ujvl): Make this asynchronous.
                saved = ray.get(trial.runner.save.remote())
                checkpoint = Checkpoint(storage, saved, result)
        else:
            saved = trial.runner.save_to_object.remote()
            checkpoint = Checkpoint(storage, saved, result)

    with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as timing:
        try:
            trial.on_checkpoint(checkpoint)
        except Exception:
            logger.exception("Trial %s: Error handling checkpoint %s",
                             trial, checkpoint.value)
            return None

    if timing.too_slow and trial.sync_on_checkpoint:
        logger.warning(
            "Consider turning off forced head-worker trial checkpoint "
            "syncs by setting sync_on_checkpoint=False. Note that this "
            "might result in faulty trial restoration for some worker "
            "failure modes.")
    return checkpoint.value
def add_trial(self, trial):
    """Registers a new trial with this TrialRunner.

    Trials may be added at any time. If the trial uses placement groups,
    the runner's pending-trial cap is set to
    ``TUNE_MAX_PENDING_TRIALS_PG``. The trial is then appended, the
    scheduler is notified, and the executor tries to checkpoint the
    trial's metadata.

    Args:
        trial (Trial): Trial to queue.
    """
    if trial.uses_placement_groups:
        self._max_pending_trials = TUNE_MAX_PENDING_TRIALS_PG

    known_trials = self._trials
    known_trials.append(trial)

    scheduler = self._scheduler_alg
    with warn_if_slow("scheduler.on_trial_add"):
        scheduler.on_trial_add(self, trial)

    self.trial_executor.try_checkpoint_metadata(trial)
def _get_next_trial(self):
    """Replenishes the trial queue if needed and picks a trial to run.

    A new trial is requested from the queue only when no trial is
    PENDING, or when every trial has finished while the search algorithm
    has not (in which case the update blocks).

    Returns:
        The trial chosen by the scheduler, which may be None.
    """
    wait_for_trial = (all(t.is_finished() for t in self._trials)
                      and not self._search_alg.is_finished())
    has_pending = any(t.status == Trial.PENDING for t in self._trials)
    if wait_for_trial or not has_pending:
        self._update_trial_queue(blocking=wait_for_trial)

    with warn_if_slow("choose_trial_to_run"):
        chosen = self._scheduler_alg.choose_trial_to_run(self)
        logger.debug("Running trial {}".format(chosen))
    return chosen
def fetch_result(self, trial):
    """Fetches one result of the running trials.

    The trial's pending future is removed from ``self._running`` before
    it is waited on.

    Returns:
        Result of the most recent trial training run.

    Raises:
        ValueError: If ``trial`` has no pending future in ``self._running``.
    """
    matches = self._find_item(self._running, trial)
    if not matches:
        raise ValueError("Trial was not running.")

    future = matches[0]
    self._running.pop(future)
    with warn_if_slow("fetch_result"):
        result = ray.get(future, timeout=DEFAULT_GET_TIMEOUT)

    # For local mode
    return result.unwrap() if isinstance(result, _LocalWrapper) else result
def reset_trial(
    self,
    trial: Trial,
    new_config: Dict,
    new_experiment_tag: str,
    logger_creator: Optional[Callable[[Dict], "ray.tune.Logger"]] = None,
) -> bool:
    """Asks the trial's actor to reset itself through `Trainable.reset()`.

    The trial's experiment tag and config are updated first. The config
    sent to the actor is a deep copy of ``new_config`` extended with
    ``TRIAL_INFO``, ``STDOUT_FILE`` and ``STDERR_FILE`` entries, so the
    caller's ``new_config`` is left untouched.

    Args:
        trial: Trial to be reset.
        new_config: New configuration for Trial trainable.
        new_experiment_tag: New experiment name for trial.
        logger_creator: Function that instantiates a logger on the
            actor process.

    Returns:
        True if `reset_config` is successful else False.
    """
    trial.set_experiment_tag(new_experiment_tag)
    trial.set_config(new_config)
    actor = trial.runner

    # Pass magic variables on a private copy of the config.
    reset_config = copy.deepcopy(new_config)
    reset_config[TRIAL_INFO] = TrialInfo(trial)
    reset_config[STDOUT_FILE], reset_config[STDERR_FILE] = trial.log_to_file

    with self._change_working_directory(trial):
        with warn_if_slow("reset"):
            try:
                return ray.get(
                    actor.reset.remote(reset_config, logger_creator),
                    timeout=DEFAULT_GET_TIMEOUT,
                )
            except GetTimeoutError:
                logger.exception("Trial %s: reset timed out.", trial)
                return False
def _requeue_trial(self, trial):
    """Notifies the TrialScheduler of an error and requeues the trial.

    The trial is reported to the scheduler via ``on_trial_error``, marked
    PENDING, moved to the end of the trial list, and re-added to the
    scheduler. This does not notify the SearchAlgorithm because the
    function evaluation is still in progress.
    """
    scheduler = self._scheduler_alg
    scheduler.on_trial_error(self, trial)
    self.trial_executor.set_status(trial, Trial.PENDING)

    # TODO(rliaw): Right now, this pushes the trial to the end of queue
    # because restoration can be expensive. However, this is not
    # ideal since it just hides the issue - a better fix would
    # be to use an actor table to detect the IP of the Trainable
    # and rsync the files there.
    # See https://github.com/ray-project/ray/issues/5168
    trials = self._trials
    del trials[trials.index(trial)]
    trials.append(trial)

    with warn_if_slow("scheduler.on_trial_add"):
        scheduler.on_trial_add(self, trial)
def _process_trial_result(self, trial, result):
    """Processes one training result reported by ``trial``.

    Informs the stopper, scheduler, search algorithm and callbacks,
    updates the trial's last result (unless the result is a duplicate),
    and triggers a checkpoint if needed.

    Args:
        trial (Trial): Trial the result belongs to.
        result (dict): Result reported by the trial. It is mutated in
            place: ``trial_id`` is added, and ``done`` may be set.

    Returns:
        The scheduling decision, which is also queued via
        ``_queue_decision``. Returns None if the trial is saving after the
        checkpoint check; the decision is then cached in
        ``self._cached_trial_decisions`` instead.
    """
    result.update(trial_id=trial.trial_id)
    is_duplicate = RESULT_DUPLICATE in result
    # Read before a duplicate result is swapped for `trial.last_result`.
    force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
    # TrialScheduler and SearchAlgorithm still receive a
    # notification because there may be special handling for
    # the `on_trial_complete` hook.
    if is_duplicate:
        logger.debug("Trial finished without logging 'done'.")
        result = trial.last_result
        result.update(done=True)

    self._total_time += result.get(TIME_THIS_ITER_S, 0)

    flat_result = flatten_dict(result)
    self._validate_result_metrics(flat_result)

    if self._stopper(trial.trial_id,
                     result) or trial.should_stop(flat_result):
        # NOTE(review): `flat_result` was computed before this update, so
        # (if flatten_dict returns a copy) it will not carry done=True —
        # confirm that is intended.
        result.update(done=True)

        # Hook into scheduler
        self._scheduler_alg.on_trial_complete(self, trial, flat_result)
        self._search_alg.on_trial_complete(
            trial.trial_id, result=flat_result)

        # If this is not a duplicate result, the callbacks should
        # be informed about the result.
        if not is_duplicate:
            with warn_if_slow("callbacks.on_trial_result"):
                self._callbacks.on_trial_result(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial,
                    result=result.copy())

        self._callbacks.on_trial_complete(
            iteration=self._iteration, trials=self._trials, trial=trial)
        decision = TrialScheduler.STOP
    else:
        with warn_if_slow("scheduler.on_trial_result"):
            decision = self._scheduler_alg.on_trial_result(
                self, trial, flat_result)
        if decision == TrialScheduler.STOP:
            result.update(done=True)
        with warn_if_slow("search_alg.on_trial_result"):
            self._search_alg.on_trial_result(trial.trial_id, flat_result)
        with warn_if_slow("callbacks.on_trial_result"):
            self._callbacks.on_trial_result(
                iteration=self._iteration,
                trials=self._trials,
                trial=trial,
                result=result.copy())
        # Completion hooks fire only after the result hooks above.
        if decision == TrialScheduler.STOP:
            with warn_if_slow("search_alg.on_trial_complete"):
                self._search_alg.on_trial_complete(
                    trial.trial_id, result=flat_result)
            with warn_if_slow("callbacks.on_trial_complete"):
                self._callbacks.on_trial_complete(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial)

    if not is_duplicate:
        trial.update_last_result(
            result, terminate=(decision == TrialScheduler.STOP))

    # Checkpoints to disk. This should be checked even if
    # the scheduler decision is STOP or PAUSE. Note that
    # PAUSE only checkpoints to memory and does not update
    # the global checkpoint state.
    self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

    if trial.is_saving:
        # Cache decision to execute on after the save is processed.
        # This prevents changing the trial's state or kicking off
        # another training step prematurely.
        self._cached_trial_decisions[trial.trial_id] = decision
        return None
    else:
        self._queue_decision(trial, decision)
        return decision
def _process_trial(self, trial):
    """Processes a trial result.

    Fetches the trial's latest result, lets the stopper, scheduler and
    search algorithm decide on the next action, updates the trial's last
    result, checkpoints if needed, and then acts on the decision
    (continue, pause, or export and stop).

    Any exception raised along the way, including an unknown scheduling
    decision, is logged and routed to ``_process_trial_failure``.

    Args:
        trial (Trial): Trial with a result ready to be processed.
    """
    try:
        result = self.trial_executor.fetch_result(trial)

        is_duplicate = RESULT_DUPLICATE in result
        # TrialScheduler and SearchAlgorithm still receive a
        # notification because there may be special handling for
        # the `on_trial_complete` hook.
        if is_duplicate:
            logger.debug("Trial finished without logging 'done'.")
            result = trial.last_result
            result.update(done=True)

        self._total_time += result.get(TIME_THIS_ITER_S, 0)

        flat_result = flatten_dict(result)
        if self._stopper(trial.trial_id,
                         result) or trial.should_stop(flat_result):
            # Hook into scheduler
            self._scheduler_alg.on_trial_complete(self, trial, flat_result)
            self._search_alg.on_trial_complete(
                trial.trial_id, result=flat_result)
            decision = TrialScheduler.STOP
        else:
            with warn_if_slow("scheduler.on_trial_result"):
                decision = self._scheduler_alg.on_trial_result(
                    self, trial, flat_result)
            with warn_if_slow("search_alg.on_trial_result"):
                self._search_alg.on_trial_result(trial.trial_id,
                                                 flat_result)
            if decision == TrialScheduler.STOP:
                with warn_if_slow("search_alg.on_trial_complete"):
                    self._search_alg.on_trial_complete(
                        trial.trial_id,
                        result=flat_result,
                        early_terminated=True)

        if not is_duplicate:
            trial.update_last_result(
                result, terminate=(decision == TrialScheduler.STOP))

        # Checkpoints to disk. This should be checked even if
        # the scheduler decision is STOP or PAUSE. Note that
        # PAUSE only checkpoints to memory and does not update
        # the global checkpoint state.
        self._checkpoint_trial_if_needed(
            trial, force=result.get(SHOULD_CHECKPOINT, False))

        if decision == TrialScheduler.CONTINUE:
            self.trial_executor.continue_training(trial)
        elif decision == TrialScheduler.PAUSE:
            self.trial_executor.pause_trial(trial)
        elif decision == TrialScheduler.STOP:
            self.trial_executor.export_trial_if_needed(trial)
            self.trial_executor.stop_trial(trial)
        else:
            # An explicit raise (rather than `assert False`) still fires
            # when Python runs with -O; it is handled by the except below.
            raise ValueError("Invalid scheduling decision: {}".format(
                decision))
    except Exception:
        logger.exception("Trial %s: Error processing event.", trial)
        self._process_trial_failure(trial, traceback.format_exc())
def _save(self, _=None):
    """Returns the state fetched from the first worker process.

    All models are synchronized, so the state of the first process is
    saved on behalf of every replica. The positional argument is unused.
    """
    trial_label = f"{self._trial_info.trial_name}({self.iteration})"
    self.logger.debug(f"_save: {trial_label}")
    with warn_if_slow("ImagenetExperiment.get_state.remote"):
        state_ref = self.procs[0].get_state.remote()
        state = ray.get(state_ref)
    return state
def step(self):
    """Runs one step of the trial event loop.

    Callers should typically run this method repeatedly in a loop. They
    may inspect or modify the runner's state in between calls to step().

    One step: notifies the executor and callbacks that a step begins;
    picks the next trial and tops up pending trials up to
    ``self._max_pending_trials``; updates staged placement groups; tries
    to start a trial; otherwise processes one event (or reports that no
    trial is available); may stop the experiment; checkpoints (failures
    are logged as warnings); advances the iteration counter; serves stop
    requests; and notifies the executor and callbacks that the step ended.

    Raises:
        TuneError: If called after all trials have finished.
    """
    # NOTE(review): reset each step; presumably set by
    # `_update_trial_queue` (reached via `_get_next_trial`) — confirm.
    self._updated_queue = False

    if self.is_finished():
        raise TuneError("Called step when all trials finished?")
    with warn_if_slow("on_step_begin"):
        self.trial_executor.on_step_begin(self)
    with warn_if_slow("callbacks.on_step_begin"):
        self._callbacks.on_step_begin(
            iteration=self._iteration, trials=self._trials)

    # This will contain the next trial to start
    next_trial = self._get_next_trial()  # blocking

    # Create pending trials. If the queue was updated before, only
    # continue updating if this was successful (next_trial is not None)
    if not self._updated_queue or (self._updated_queue and next_trial):
        num_pending_trials = len(
            [t for t in self._trials if t.status == Trial.PENDING])
        while num_pending_trials < self._max_pending_trials:
            # Stop as soon as the queue cannot supply another trial.
            if not self._update_trial_queue(blocking=False):
                break
            num_pending_trials += 1

    # Update status of staged placement groups
    self.trial_executor.stage_and_update_status(self._trials)

    def _start_trial(trial: Trial) -> bool:
        """Helper function to start trial and call callbacks"""
        with warn_if_slow("start_trial"):
            if self.trial_executor.start_trial(trial):
                self._callbacks.on_trial_start(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial)
                return True
            return False

    # Events are only processed when no trial was started this step.
    may_handle_events = True
    if next_trial is not None:
        if _start_trial(next_trial):
            may_handle_events = False
        elif next_trial.status != Trial.ERROR:
            # Only try to start another trial if previous trial startup
            # did not error (e.g. it just didn't start because its
            # placement group is not ready, yet).
            next_trial = self.trial_executor.get_staged_trial()
            if next_trial is not None:
                if _start_trial(next_trial):
                    may_handle_events = False

    if may_handle_events:
        if self.trial_executor.get_running_trials():
            timeout = None
            # Bound the wait while the executor is in its staging grace
            # period instead of blocking indefinitely.
            if self.trial_executor.in_staging_grace_period():
                timeout = 0.1
            self._process_events(timeout=timeout)  # blocking
        else:
            self.trial_executor.on_no_available_trials(self)

    self._stop_experiment_if_needed()

    # Experiment checkpointing is best-effort: failures are logged as a
    # warning (message only, no traceback) and the loop continues.
    try:
        self.checkpoint()
    except Exception as e:
        logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
    self._iteration += 1

    if self._server:
        with warn_if_slow("server"):
            self._process_stop_requests()

        # Stop serving requests once every trial has finished.
        if self.is_finished():
            self._server.shutdown()
    with warn_if_slow("on_step_end"):
        self.trial_executor.on_step_end(self)
    with warn_if_slow("callbacks.on_step_end"):
        self._callbacks.on_step_end(
            iteration=self._iteration, trials=self._trials)
def _process_trial(self, trial):
    """Processes a trial result.

    Fetches the trial's latest result and makes a scheduling decision
    regarding its next action. If a checkpoint is taken, the decided
    action is cached and acted on only after the checkpoint is later
    processed (see `_process_trial_save`). Otherwise the decision is
    acted on immediately.

    Any exception is logged and routed to ``_process_trial_failure``.

    Args:
        trial (Trial): Trial with a result ready to be processed.
    """
    try:
        result = self.trial_executor.fetch_result(trial)

        is_duplicate = RESULT_DUPLICATE in result
        # Read from the fetched result, before a duplicate is replaced by
        # `trial.last_result` below.
        force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
        # TrialScheduler and SearchAlgorithm still receive a
        # notification because there may be special handling for
        # the `on_trial_complete` hook.
        if is_duplicate:
            logger.debug("Trial finished without logging 'done'.")
            result = trial.last_result
            result.update(done=True)

        self._total_time += result.get(TIME_THIS_ITER_S, 0)

        flat_result = flatten_dict(result)
        if self._stopper(trial.trial_id,
                         result) or trial.should_stop(flat_result):
            # Hook into scheduler
            self._scheduler_alg.on_trial_complete(self, trial, flat_result)
            self._search_alg.on_trial_complete(
                trial.trial_id, result=flat_result)
            decision = TrialScheduler.STOP
        else:
            with warn_if_slow("scheduler.on_trial_result"):
                decision = self._scheduler_alg.on_trial_result(
                    self, trial, flat_result)
            with warn_if_slow("search_alg.on_trial_result"):
                self._search_alg.on_trial_result(trial.trial_id,
                                                 flat_result)
            if decision == TrialScheduler.STOP:
                with warn_if_slow("search_alg.on_trial_complete"):
                    self._search_alg.on_trial_complete(
                        trial.trial_id, result=flat_result)

        if not is_duplicate:
            trial.update_last_result(
                result, terminate=(decision == TrialScheduler.STOP))

        # Checkpoints to disk. This should be checked even if
        # the scheduler decision is STOP or PAUSE. Note that
        # PAUSE only checkpoints to memory and does not update
        # the global checkpoint state.
        self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

        if trial.is_saving:
            # Cache decision to execute on after the save is processed.
            # This prevents changing the trial's state or kicking off
            # another training step prematurely.
            self._cached_trial_decisions[trial.trial_id] = decision
        else:
            self._execute_action(trial, decision)
    except Exception:
        logger.exception("Trial %s: Error processing event.", trial)
        self._process_trial_failure(trial, traceback.format_exc())
def step(self):
    """Runs one step of the trial event loop.

    Callers should typically run this method repeatedly in a loop. They
    may inspect or modify the runner's state in between calls to step().

    One step: notifies the executor and callbacks that a step begins;
    picks the next trial and tops up pending trials up to
    ``self._max_pending_trials``; updates staged placement groups; tries
    to start a trial; otherwise processes one event (or reports that no
    trial is available); may stop the experiment; checkpoints (failures
    are logged as warnings); advances the iteration counter; serves stop
    requests; and notifies the executor and callbacks that the step ended.

    Raises:
        TuneError: If called after all trials have finished.
    """
    if self.is_finished():
        raise TuneError("Called step when all trials finished?")
    with warn_if_slow("on_step_begin"):
        self.trial_executor.on_step_begin(self)
    with warn_if_slow("callbacks.on_step_begin"):
        self._callbacks.on_step_begin(
            iteration=self._iteration, trials=self._trials)

    # This will contain the next trial to start
    next_trial = self._get_next_trial()  # blocking

    # Create pending trials
    num_pending_trials = len(
        [t for t in self._trials if t.status == Trial.PENDING])
    while num_pending_trials < self._max_pending_trials:
        # Stop as soon as the queue cannot supply another trial.
        if not self._update_trial_queue(blocking=False):
            break
        num_pending_trials += 1

    # Update status of staged placement groups
    self.trial_executor.stage_and_update_status(self._trials)

    def _start_trial(trial: Trial) -> bool:
        """Helper function to start trial and call callbacks"""
        with warn_if_slow("start_trial"):
            if self.trial_executor.start_trial(trial):
                self._callbacks.on_trial_start(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial)
                return True
            return False

    # Events are only processed when no trial was started this step.
    may_handle_events = True
    if next_trial is not None:
        if _start_trial(next_trial):
            may_handle_events = False
        else:
            # The chosen trial did not start; fall back to a trial whose
            # placement group is already staged, if any.
            # NOTE(review): this fallback also runs when the chosen trial
            # errored on startup — confirm that is intended.
            next_trial = self.trial_executor.get_staged_trial()
            if next_trial is not None:
                if _start_trial(next_trial):
                    may_handle_events = False

    if may_handle_events:
        if self.trial_executor.get_running_trials():
            timeout = None
            # Bound the wait while the executor is in its staging grace
            # period instead of blocking indefinitely.
            if self.trial_executor.in_staging_grace_period():
                timeout = 0.1
            self._process_events(timeout=timeout)  # blocking
        else:
            self.trial_executor.on_no_available_trials(self)

    self._stop_experiment_if_needed()

    # Experiment checkpointing is best-effort: failures are logged as a
    # warning (message only, no traceback) and the loop continues.
    try:
        with warn_if_slow(
                "experiment_checkpoint",
                message="Checkpointing the experiment state took "
                "{duration:.3f} s, which may be a performance "
                "bottleneck. Please ensure the "
                "`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
                "something significantly higher than this duration "
                "to ensure compute time is mostly spent on the main "
                "training loop."):
            self.checkpoint()
    except Exception as e:
        logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
    self._iteration += 1

    if self._server:
        with warn_if_slow("server"):
            self._process_stop_requests()

        # Stop serving requests once every trial has finished.
        if self.is_finished():
            self._server.shutdown()
    with warn_if_slow("on_step_end"):
        self.trial_executor.on_step_end(self)
    with warn_if_slow("callbacks.on_step_end"):
        self._callbacks.on_step_end(
            iteration=self._iteration, trials=self._trials)