Example #1
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        next_trial = self._get_next_trial()  # blocking
        if next_trial is not None:
            with warn_if_slow("start_trial"):
                self.trial_executor.start_trial(next_trial)
        elif self.trial_executor.get_running_trials():
            self._process_events()  # blocking
        else:
            self.trial_executor.on_no_available_trials(self)

        try:
            with warn_if_slow("experiment_checkpoint"):
                self.checkpoint()
        except Exception:
            logger.exception("Trial Runner checkpointing failed.")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
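
The snippets in these examples all wrap potentially slow operations in `warn_if_slow`. For orientation, here is a minimal, hypothetical sketch of such a timing guard; the helper name, the 0.5 s threshold, the `_Profile` class, and the default message are illustrative assumptions, not Ray Tune's actual implementation.

import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class _Profile:
    """Records whether the timed block exceeded the threshold."""
    too_slow = False


@contextmanager
def warn_if_slow_sketch(name, threshold=0.5, message=None, disable=False):
    # Time the wrapped block; warn if it took longer than `threshold` seconds.
    profile = _Profile()
    start = time.time()
    yield profile
    duration = time.time() - start
    profile.too_slow = duration > threshold
    if profile.too_slow and not disable:
        text = (message or "The `{name}` operation took {duration:.3f} s, "
                "which may be a performance bottleneck.")
        logger.warning(text.format(name=name, duration=duration))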
Example #2
    def _process_events(self):
        failed_trial = self.trial_executor.get_next_failed_trial()
        if failed_trial:
            error_msg = (
                "{} (IP: {}) detected as stale. This is likely because the "
                "node was lost").format(failed_trial, failed_trial.node_ip)
            logger.info(error_msg)
            with warn_if_slow("process_failed_trial"):
                self._process_trial_failure(failed_trial, error_msg=error_msg)
        else:
            # TODO(ujvl): Consider combining get_next_available_trial and
            #  fetch_result functionality so that we don't timeout on fetch.
            trial = self.trial_executor.get_next_available_trial()  # blocking
            if trial.is_restoring:
                with warn_if_slow("process_trial_restore"):
                    self._process_trial_restore(trial)
            elif trial.is_saving:
                with warn_if_slow("process_trial_save") as profile:
                    self._process_trial_save(trial)
                if profile.too_slow and trial.sync_on_checkpoint:
                    # TODO(ujvl): Suggest using DurableTrainable once
                    #  API has converged.
                    logger.warning(
                        "Consider turning off forced head-worker trial "
                        "checkpoint syncs by setting sync_on_checkpoint=False"
                        ". Note that this may result in faulty trial "
                        "restoration if a failure occurs while the checkpoint "
                        "is being synced from the worker to the head node.")
            else:
                with warn_if_slow("process_trial"):
                    self._process_trial(trial)
Example #3
    def _process_trial(self, trial):
        """Processes a trial result.

        Fetches the trial's latest result and makes a scheduling decision
        regarding its next action. If a checkpoint is taken, the decided
        action is cached and acted on only after the checkpoint is later
        processed (see `_process_trial_save`). Otherwise the decision is
        acted on immediately.

        If multiple results are received (e.g. because of buffering), all
        results are processed and the final action is determined. STOP
        takes precedence over PAUSE, which takes precedence over CONTINUE.

        Args:
            trial (Trial): Trial with a result ready to be processed.
        """
        try:
            results = self.trial_executor.fetch_result(trial)
            with warn_if_slow(
                    "process_trial_results",
                    message="Processing trial results took {duration:.3f} s, "
                    "which may be a performance bottleneck. Please consider "
                    "reporting results less frequently to Ray Tune."):
                for i, result in enumerate(results):
                    with warn_if_slow("process_trial_result"):
                        decision = self._process_trial_result(trial, result)
                    if decision is None:
                        # If we didn't get a decision, this means a
                        # non-training future (e.g. a save) was scheduled.
                        # We do not allow processing more results then.
                        if i < len(results) - 1:
                            raise RuntimeError(
                                f"Trial {trial} has a non-training future "
                                f"scheduled but {len(results)-i} results "
                                f"left to process. This should never "
                                f"happen - please file an issue at "
                                f"https://github.com/ray-project/ray/issues")
                    elif decision == TrialScheduler.STOP:
                        # If the decision is to stop the trial,
                        # ignore all results that came after that.
                        break
        except Exception:
            error_msg = "Trial %s: Error processing event." % trial
            if self._fail_fast == TrialRunner.RAISE:
                logger.error(error_msg)
                raise
            else:
                logger.exception(error_msg)
            self._process_trial_failure(trial, traceback.format_exc())
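
The docstring above states that, when multiple buffered results are processed, STOP takes precedence over PAUSE, which takes precedence over CONTINUE. Here is a small, self-contained sketch of that precedence rule; the `resolve_final_decision` helper and the string constants are hypothetical stand-ins for `TrialScheduler.STOP/PAUSE/CONTINUE`.

# The string constants mirror TrialScheduler.STOP/PAUSE/CONTINUE but are
# local assumptions for this sketch.
STOP, PAUSE, CONTINUE = "STOP", "PAUSE", "CONTINUE"
_PRECEDENCE = {CONTINUE: 0, PAUSE: 1, STOP: 2}


def resolve_final_decision(decisions):
    """Returns the highest-precedence decision among buffered results."""
    final = CONTINUE
    for decision in decisions:
        if _PRECEDENCE[decision] > _PRECEDENCE[final]:
            final = decision
    return final


assert resolve_final_decision([CONTINUE, PAUSE, CONTINUE]) == PAUSE
assert resolve_final_decision([PAUSE, STOP, CONTINUE]) == STOP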
Example #4
    def reset_trial(self, trial, new_config, new_experiment_tag):
        """Tries to invoke `Trainable.reset_config()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial
                trainable.
            new_experiment_tag (str): New experiment name
                for trial.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
        trainable = trial.runner
        with self._change_working_directory(trial):
            with warn_if_slow("reset_config"):
                try:
                    reset_val = ray.get(
                        trainable.reset_config.remote(new_config),
                        DEFAULT_GET_TIMEOUT)
                except RayTimeoutError:
                    logger.exception("Trial %s: reset_config timed out.",
                                     trial)
                    return False
        return reset_val
Example #5
    def reset_trial(self,
                    trial,
                    new_config,
                    new_experiment_tag,
                    logger_creator=None):
        """Tries to invoke `Trainable.reset()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial trainable.
            new_experiment_tag (str): New experiment name for trial.
            logger_creator (Callable[[Dict], Logger]): A function that
                instantiates a logger on the actor process.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
        trainable = trial.runner
        with self._change_working_directory(trial):
            with warn_if_slow("reset"):
                try:
                    reset_val = ray.get(
                        trainable.reset.remote(new_config, logger_creator),
                        timeout=DEFAULT_GET_TIMEOUT)
                except GetTimeoutError:
                    logger.exception("Trial %s: reset timed out.", trial)
                    return False
        return reset_val
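
Example #4 and Example #5 follow the same pattern: block on a remote call with `ray.get` and give up after a timeout. Below is a minimal sketch of that pattern, assuming a Ray version where `GetTimeoutError` lives in `ray.exceptions`; the `slow_reset` task and the timeout value are placeholders, not Ray Tune code.

import ray
from ray.exceptions import GetTimeoutError

DEFAULT_GET_TIMEOUT = 60.0  # seconds; placeholder, not Ray Tune's actual value


@ray.remote
def slow_reset(config):
    # Placeholder standing in for Trainable.reset.remote(...) above.
    return True


def get_with_timeout(ref, timeout=DEFAULT_GET_TIMEOUT):
    """Blocks on a remote result but gives up after `timeout` seconds."""
    try:
        return ray.get(ref, timeout=timeout)
    except GetTimeoutError:
        return False  # mirrors the `return False` in the examples above


# Usage (requires a running Ray instance):
# ray.init()
# ok = get_with_timeout(slow_reset.remote({"lr": 0.1}))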
Example #6
    def checkpoint(self, force=False):
        """Saves execution state to `self._local_checkpoint_dir`.

        Overwrites the current session checkpoint, which starts when self
        is instantiated. Throttling depends on `self._checkpoint_period`.

        Also automatically saves the search algorithm to the local
        checkpoint dir.

        Args:
            force (bool): Forces a checkpoint despite checkpoint_period.
        """
        with warn_if_slow(
                "experiment_checkpoint",
                message="Checkpointing the experiment state took "
                "{duration:.3f} s, which may be a performance "
                "bottleneck. Please ensure the "
                "`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
                "something significantly higher than this duration "
                "to ensure compute time is mostly spent on the main "
                "training loop.",
                disable=self._checkpoint_manager.auto_checkpoint_enabled):

            self._checkpoint_manager.checkpoint(
                checkpoint_file=self.checkpoint_file,
                trial_runner=self,
                trial_executor=self.trial_executor,
                search_alg=self._search_alg,
                force=force)
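
The docstring above notes that checkpointing is throttled by `self._checkpoint_period`. Here is a sketch of what such period-based throttling could look like; `ThrottledCheckpointer` and its default period are assumptions, and the actual write is left abstract.

import time


class ThrottledCheckpointer:
    """Hypothetical period-based experiment checkpointing.

    `checkpoint_period` plays the role of `self._checkpoint_period` (or the
    `TUNE_GLOBAL_CHECKPOINT_S` environment variable) in the example above.
    """

    def __init__(self, checkpoint_period=10.0):
        self._checkpoint_period = checkpoint_period
        self._last_checkpoint_time = 0.0

    def checkpoint(self, force=False):
        now = time.time()
        if (not force and
                now - self._last_checkpoint_time < self._checkpoint_period):
            return False  # throttled: too soon since the last write
        self._last_checkpoint_time = now
        self._write_state()
        return True

    def _write_state(self):
        # Placeholder for serializing runner and search-algorithm state.
        pass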
Example #7
    def _process_events(self, timeout: Optional[float] = None):
        with warn_if_slow("get_next_failed_trial"):
            failed_trial = self.trial_executor.get_next_failed_trial()
        if failed_trial:
            error_msg = (
                "{} (IP: {}) detected as stale. This is likely because the "
                "node was lost").format(failed_trial, failed_trial.node_ip)
            logger.info(error_msg)
            with warn_if_slow("process_failed_trial"):
                self._process_trial_failure(failed_trial, error_msg=error_msg)
        else:
            # TODO(ujvl): Consider combining get_next_available_trial and
            #  fetch_result functionality so that we don't timeout on fetch.
            trial = self.trial_executor.get_next_available_trial(
                timeout=timeout)  # blocking
            if not trial:
                return
            if trial.is_restoring:
                with warn_if_slow("process_trial_restore"):
                    self._process_trial_restore(trial)
                with warn_if_slow("callbacks.on_trial_restore"):
                    self._callbacks.on_trial_restore(
                        iteration=self._iteration,
                        trials=self._trials,
                        trial=trial)
            elif trial.is_saving:
                with warn_if_slow("process_trial_save") as _profile:
                    self._process_trial_save(trial)
                with warn_if_slow("callbacks.on_trial_save"):
                    self._callbacks.on_trial_save(
                        iteration=self._iteration,
                        trials=self._trials,
                        trial=trial)
                if _profile.too_slow and trial.sync_on_checkpoint:
                    # TODO(ujvl): Suggest using DurableTrainable once
                    #  API has converged.

                    msg = (
                        "Consider turning off forced head-worker trial "
                        "checkpoint syncs by setting sync_on_checkpoint=False"
                        ". Note that this may result in faulty trial "
                        "restoration if a failure occurs while the checkpoint "
                        "is being synced from the worker to the head node.")

                    if trial.location.hostname and (trial.location.hostname !=
                                                    get_node_ip_address()):
                        if log_once("tune_head_worker_checkpoint"):
                            logger.warning(msg)

            else:
                with warn_if_slow("process_trial"):
                    self._process_trial(trial)

            # `self._queued_trial_decisions` now contains a final decision
            # based on all results
            if trial not in self._cached_trial_decisions:
                final_decision = self._queued_trial_decisions.pop(
                    trial.trial_id, None)
                if final_decision:
                    self._execute_action(trial, final_decision)
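
The `log_once("tune_head_worker_checkpoint")` call above de-duplicates the warning so it is emitted at most once per process. A minimal stand-in for that behavior is sketched below; the real helper in Ray may also support periodic resets.

_seen_log_keys = set()


def log_once_sketch(key):
    """Returns True only the first time `key` is seen in this process."""
    if key in _seen_log_keys:
        return False
    _seen_log_keys.add(key)
    return True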
Example #8
    def _start_trial(trial: Trial) -> bool:
        """Helper function to start trial and call callbacks"""
        with warn_if_slow("start_trial"):
            if self.trial_executor.start_trial(trial):
                self._callbacks.on_trial_start(iteration=self._iteration,
                                               trials=self._trials,
                                               trial=trial)
                return True
            return False

    def _process_events(self):
        failed_trial = self.trial_executor.get_next_failed_trial()
        if failed_trial:
            error_msg = (
                "{} (IP: {}) detected as stale. This is likely because the "
                "node was lost").format(failed_trial, failed_trial.node_ip)
            logger.info(error_msg)
            with warn_if_slow("process_failed_trial"):
                self._process_trial_failure(failed_trial, error_msg=error_msg)
        else:
            # TODO(ujvl): Consider combining get_next_available_trial and
            #  fetch_result functionality so that we don't timeout on fetch.
            trial = self.trial_executor.get_next_available_trial()  # blocking
            if trial.is_restoring:
                with warn_if_slow("process_trial_restore"):
                    self._process_trial_restore(trial)
            else:
                with warn_if_slow("process_trial"):
                    self._process_trial(trial)
Example #10
    def _get_next_trial(self):
        """Replenishes queue.

        Blocks if all trials queued have finished, but search algorithm is
        still not finished.
        """
        trials_done = all(trial.is_finished() for trial in self._trials)
        wait_for_trial = trials_done and not self._search_alg.is_finished()
        self._update_trial_queue(blocking=wait_for_trial)
        with warn_if_slow("choose_trial_to_run"):
            trial = self._scheduler_alg.choose_trial_to_run(self)
        return trial
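
The docstring above says the trial queue is replenished and that the call blocks only while the search algorithm can still produce trials. Below is a speculative sketch of that blocking/non-blocking replenishment, assuming a search algorithm object with hypothetical `next_trial()` and `is_finished()` methods; it is not Ray Tune's `_update_trial_queue`.

import time


def update_trial_queue_sketch(search_alg, trials, blocking=False, timeout=600):
    """Adds at most one new trial from the search algorithm to `trials`.

    When `blocking` is True, polls until a trial arrives, the search
    algorithm reports it is finished, or `timeout` seconds elapse.
    """
    trial = search_alg.next_trial()
    if blocking and trial is None:
        start = time.time()
        while (trial is None and not search_alg.is_finished()
               and time.time() - start < timeout):
            time.sleep(1)
            trial = search_alg.next_trial()
    if trial is not None:
        trials.append(trial)
    return trial is not None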
Example #11
    def add_trial(self, trial):
        """Adds a new trial to this TrialRunner.

        Trials may be added at any time.

        Args:
            trial (Trial): Trial to queue.
        """
        self._trials.append(trial)
        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
        self.trial_executor.try_checkpoint_metadata(trial)
Example #12
    def _notify_trainable_of_new_resources_if_needed(self, trial: Trial):
        if trial.has_new_resources:
            trainable = trial.runner
            trial.has_new_resources = False
            with self._change_working_directory(trial):
                with warn_if_slow("update_resources"):
                    try:
                        ray.get(trainable._update_resources.remote(
                            trial.placement_group_factory),
                                timeout=DEFAULT_GET_TIMEOUT)
                    except GetTimeoutError:
                        logger.exception(
                            "Trial %s: updating resources timed out.", trial)
Example #13
    def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
        """Saves the trial's state to a checkpoint.

        Args:
            trial (Trial): The state of this trial to be saved.
            storage (str): Where to store the checkpoint. Defaults to
                PERSISTENT.
            result (dict): The state of this trial as a dictionary to be saved.
                If result is None, the trial's last result will be used.

        Returns:
             Checkpoint future, or None if an Exception occurs.
        """
        result = result or trial.last_result

        with self._change_working_directory(trial):
            if storage == Checkpoint.MEMORY:
                value = trial.runner.save_to_object.remote()
                checkpoint = Checkpoint(storage, value, result)
            else:
                with warn_if_slow("save_checkpoint_to_storage"):
                    # TODO(ujvl): Make this asynchronous.
                    value = ray.get(trial.runner.save.remote())
                    checkpoint = Checkpoint(storage, value, result)
        with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
            try:
                trial.on_checkpoint(checkpoint)
            except Exception:
                logger.exception("Trial %s: Error handling checkpoint %s",
                                 trial, checkpoint.value)
                return None
        if profile.too_slow and trial.sync_on_checkpoint:
            logger.warning(
                "Consider turning off forced head-worker trial checkpoint "
                "syncs by setting sync_on_checkpoint=False. Note that this "
                "might result in faulty trial restoration for some worker "
                "failure modes.")
        return checkpoint.value
Example #14
    def add_trial(self, trial):
        """Adds a new trial to this TrialRunner.

        Trials may be added at any time.

        Args:
            trial (Trial): Trial to queue.
        """
        if trial.uses_placement_groups:
            self._max_pending_trials = TUNE_MAX_PENDING_TRIALS_PG

        self._trials.append(trial)
        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
        self.trial_executor.try_checkpoint_metadata(trial)
Example #15
    def _get_next_trial(self):
        """Replenishes queue.

        Blocks if all trials queued have finished, but search algorithm is
        still not finished.
        """
        trials_done = all(trial.is_finished() for trial in self._trials)
        wait_for_trial = trials_done and not self._search_alg.is_finished()
        # Only fetch a new trial if we have no pending trial
        if not any(trial.status == Trial.PENDING for trial in self._trials) \
           or wait_for_trial:
            self._update_trial_queue(blocking=wait_for_trial)
        with warn_if_slow("choose_trial_to_run"):
            trial = self._scheduler_alg.choose_trial_to_run(self)
            logger.debug("Running trial {}".format(trial))
        return trial
Example #16
    def fetch_result(self, trial):
        """Fetches one result of the running trials.

        Returns:
            Result of the most recent trial training run.
        """
        trial_future = self._find_item(self._running, trial)
        if not trial_future:
            raise ValueError("Trial was not running.")
        self._running.pop(trial_future[0])
        with warn_if_slow("fetch_result"):
            result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)

        # For local mode
        if isinstance(result, _LocalWrapper):
            result = result.unwrap()
        return result
Example #17
    def reset_trial(
        self,
        trial: Trial,
        new_config: Dict,
        new_experiment_tag: str,
        logger_creator: Optional[Callable[[Dict], "ray.tune.Logger"]] = None,
    ) -> bool:
        """Tries to invoke `Trainable.reset()` to reset trial.

        Args:
            trial: Trial to be reset.
            new_config: New configuration for Trial trainable.
            new_experiment_tag: New experiment name for trial.
            logger_creator: Function that instantiates a logger on the
                actor process.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.set_experiment_tag(new_experiment_tag)
        trial.set_config(new_config)
        trainable = trial.runner

        # Pass magic variables
        extra_config = copy.deepcopy(new_config)
        extra_config[TRIAL_INFO] = TrialInfo(trial)

        stdout_file, stderr_file = trial.log_to_file
        extra_config[STDOUT_FILE] = stdout_file
        extra_config[STDERR_FILE] = stderr_file

        with self._change_working_directory(trial):
            with warn_if_slow("reset"):
                try:
                    reset_val = ray.get(
                        trainable.reset.remote(extra_config, logger_creator),
                        timeout=DEFAULT_GET_TIMEOUT,
                    )
                except GetTimeoutError:
                    logger.exception("Trial %s: reset timed out.", trial)
                    return False
        return reset_val
Example #18
    def _requeue_trial(self, trial):
        """Notification to TrialScheduler and requeue trial.

        This does not notify the SearchAlgorithm because the function
        evaluation is still in progress.

        """
        self._scheduler_alg.on_trial_error(self, trial)
        self.trial_executor.set_status(trial, Trial.PENDING)

        # TODO(rliaw): Right now, this pushes the trial to the end of queue
        # because restoration can be expensive. However, this is not
        # ideal since it just hides the issue - a better fix would
        # be to use an actor table to detect the IP of the Trainable
        # and rsync the files there.
        # See https://github.com/ray-project/ray/issues/5168
        self._trials.pop(self._trials.index(trial))
        self._trials.append(trial)

        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
Example #19
    def _process_trial_result(self, trial, result):
        result.update(trial_id=trial.trial_id)
        is_duplicate = RESULT_DUPLICATE in result
        force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
        # TrialScheduler and SearchAlgorithm still receive a
        # notification because there may be special handling for
        # the `on_trial_complete` hook.
        if is_duplicate:
            logger.debug("Trial finished without logging 'done'.")
            result = trial.last_result
            result.update(done=True)

        self._total_time += result.get(TIME_THIS_ITER_S, 0)

        flat_result = flatten_dict(result)
        self._validate_result_metrics(flat_result)

        if self._stopper(trial.trial_id,
                         result) or trial.should_stop(flat_result):
            result.update(done=True)

            # Hook into scheduler
            self._scheduler_alg.on_trial_complete(self, trial, flat_result)
            self._search_alg.on_trial_complete(trial.trial_id,
                                               result=flat_result)

            # If this is not a duplicate result, the callbacks should
            # be informed about the result.
            if not is_duplicate:
                with warn_if_slow("callbacks.on_trial_result"):
                    self._callbacks.on_trial_result(iteration=self._iteration,
                                                    trials=self._trials,
                                                    trial=trial,
                                                    result=result.copy())

            self._callbacks.on_trial_complete(iteration=self._iteration,
                                              trials=self._trials,
                                              trial=trial)
            decision = TrialScheduler.STOP
        else:
            with warn_if_slow("scheduler.on_trial_result"):
                decision = self._scheduler_alg.on_trial_result(
                    self, trial, flat_result)
            if decision == TrialScheduler.STOP:
                result.update(done=True)
            with warn_if_slow("search_alg.on_trial_result"):
                self._search_alg.on_trial_result(trial.trial_id, flat_result)
            with warn_if_slow("callbacks.on_trial_result"):
                self._callbacks.on_trial_result(iteration=self._iteration,
                                                trials=self._trials,
                                                trial=trial,
                                                result=result.copy())
            if decision == TrialScheduler.STOP:
                with warn_if_slow("search_alg.on_trial_complete"):
                    self._search_alg.on_trial_complete(trial.trial_id,
                                                       result=flat_result)
                with warn_if_slow("callbacks.on_trial_complete"):
                    self._callbacks.on_trial_complete(
                        iteration=self._iteration,
                        trials=self._trials,
                        trial=trial)

        if not is_duplicate:
            trial.update_last_result(
                result, terminate=(decision == TrialScheduler.STOP))

        # Checkpoints to disk. This should be checked even if
        # the scheduler decision is STOP or PAUSE. Note that
        # PAUSE only checkpoints to memory and does not update
        # the global checkpoint state.
        self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

        if trial.is_saving:
            # Cache decision to execute on after the save is processed.
            # This prevents changing the trial's state or kicking off
            # another training step prematurely.
            self._cached_trial_decisions[trial.trial_id] = decision
            return None
        else:
            self._queue_decision(trial, decision)
            return decision

    def _process_trial(self, trial):
        """Processes a trial result."""
        try:
            result = self.trial_executor.fetch_result(trial)

            is_duplicate = RESULT_DUPLICATE in result
            # TrialScheduler and SearchAlgorithm still receive a
            # notification because there may be special handling for
            # the `on_trial_complete` hook.
            if is_duplicate:
                logger.debug("Trial finished without logging 'done'.")
                result = trial.last_result
                result.update(done=True)

            self._total_time += result.get(TIME_THIS_ITER_S, 0)

            flat_result = flatten_dict(result)
            if self._stopper(trial.trial_id,
                             result) or trial.should_stop(flat_result):
                # Hook into scheduler
                self._scheduler_alg.on_trial_complete(self, trial, flat_result)
                self._search_alg.on_trial_complete(trial.trial_id,
                                                   result=flat_result)
                decision = TrialScheduler.STOP
            else:
                with warn_if_slow("scheduler.on_trial_result"):
                    decision = self._scheduler_alg.on_trial_result(
                        self, trial, flat_result)
                with warn_if_slow("search_alg.on_trial_result"):
                    self._search_alg.on_trial_result(trial.trial_id,
                                                     flat_result)
                if decision == TrialScheduler.STOP:
                    with warn_if_slow("search_alg.on_trial_complete"):
                        self._search_alg.on_trial_complete(
                            trial.trial_id,
                            result=flat_result,
                            early_terminated=True)

            if not is_duplicate:
                trial.update_last_result(
                    result, terminate=(decision == TrialScheduler.STOP))

            # Checkpoints to disk. This should be checked even if
            # the scheduler decision is STOP or PAUSE. Note that
            # PAUSE only checkpoints to memory and does not update
            # the global checkpoint state.
            self._checkpoint_trial_if_needed(trial,
                                             force=result.get(
                                                 SHOULD_CHECKPOINT, False))

            if decision == TrialScheduler.CONTINUE:
                self.trial_executor.continue_training(trial)
            elif decision == TrialScheduler.PAUSE:
                self.trial_executor.pause_trial(trial)
            elif decision == TrialScheduler.STOP:
                self.trial_executor.export_trial_if_needed(trial)
                self.trial_executor.stop_trial(trial)
            else:
                assert False, "Invalid scheduling decision: {}".format(
                    decision)
        except Exception:
            logger.exception("Trial %s: Error processing event.", trial)
            self._process_trial_failure(trial, traceback.format_exc())
Example #21
    def _save(self, _=None):
        self.logger.debug(
            f"_save: {self._trial_info.trial_name}({self.iteration})")
        # All models are synchronized. Just save the state of first model
        with warn_if_slow("ImagenetExperiment.get_state.remote"):
            return ray.get(self.procs[0].get_state.remote())
Example #22
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        self._updated_queue = False

        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        with warn_if_slow("callbacks.on_step_begin"):
            self._callbacks.on_step_begin(iteration=self._iteration,
                                          trials=self._trials)

        # This will contain the next trial to start
        next_trial = self._get_next_trial()  # blocking

        # Create pending trials. If the queue was updated before, only
        # continue updating if this was successful (next_trial is not None)
        if not self._updated_queue or (self._updated_queue and next_trial):
            num_pending_trials = len(
                [t for t in self._trials if t.status == Trial.PENDING])
            while num_pending_trials < self._max_pending_trials:
                if not self._update_trial_queue(blocking=False):
                    break
                num_pending_trials += 1

        # Update status of staged placement groups
        self.trial_executor.stage_and_update_status(self._trials)

        def _start_trial(trial: Trial) -> bool:
            """Helper function to start trial and call callbacks"""
            with warn_if_slow("start_trial"):
                if self.trial_executor.start_trial(trial):
                    self._callbacks.on_trial_start(iteration=self._iteration,
                                                   trials=self._trials,
                                                   trial=trial)
                    return True
                return False

        may_handle_events = True
        if next_trial is not None:
            if _start_trial(next_trial):
                may_handle_events = False
            elif next_trial.status != Trial.ERROR:
                # Only try to start another trial if previous trial startup
                # did not error (e.g. it just didn't start because its
                # placement group is not ready, yet).
                next_trial = self.trial_executor.get_staged_trial()
                if next_trial is not None:
                    if _start_trial(next_trial):
                        may_handle_events = False

        if may_handle_events:
            if self.trial_executor.get_running_trials():
                timeout = None
                if self.trial_executor.in_staging_grace_period():
                    timeout = 0.1
                self._process_events(timeout=timeout)  # blocking
            else:
                self.trial_executor.on_no_available_trials(self)

        self._stop_experiment_if_needed()

        try:
            self.checkpoint()
        except Exception as e:
            logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_stop_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
        with warn_if_slow("callbacks.on_step_end"):
            self._callbacks.on_step_end(iteration=self._iteration,
                                        trials=self._trials)
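
As the docstring notes, callers are expected to run `step()` repeatedly. A minimal driver loop over that API is sketched below; error handling and progress reporting are omitted.

def run_until_finished(runner):
    """Minimal driver over the `step()` API shown above.

    Assumes `runner` provides `step()` and `is_finished()`.
    """
    while not runner.is_finished():
        runner.step()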
Example #23
    def _process_trial(self, trial):
        """Processes a trial result.

        Fetches the trial's latest result and makes a scheduling decision
        regarding its next action. If a checkpoint is taken, the decided
        action is cached and acted on only after the checkpoint is later
        processed (see `_process_trial_save`). Otherwise the decision is
        acted on immediately.

        Args:
            trial (Trial): Trial with a result ready to be processed.
        """
        try:
            result = self.trial_executor.fetch_result(trial)

            is_duplicate = RESULT_DUPLICATE in result
            force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
            # TrialScheduler and SearchAlgorithm still receive a
            # notification because there may be special handling for
            # the `on_trial_complete` hook.
            if is_duplicate:
                logger.debug("Trial finished without logging 'done'.")
                result = trial.last_result
                result.update(done=True)

            self._total_time += result.get(TIME_THIS_ITER_S, 0)

            flat_result = flatten_dict(result)
            if self._stopper(trial.trial_id,
                             result) or trial.should_stop(flat_result):
                # Hook into scheduler
                self._scheduler_alg.on_trial_complete(self, trial, flat_result)
                self._search_alg.on_trial_complete(
                    trial.trial_id, result=flat_result)
                decision = TrialScheduler.STOP
            else:
                with warn_if_slow("scheduler.on_trial_result"):
                    decision = self._scheduler_alg.on_trial_result(
                        self, trial, flat_result)
                with warn_if_slow("search_alg.on_trial_result"):
                    self._search_alg.on_trial_result(trial.trial_id,
                                                     flat_result)
                if decision == TrialScheduler.STOP:
                    with warn_if_slow("search_alg.on_trial_complete"):
                        self._search_alg.on_trial_complete(
                            trial.trial_id, result=flat_result)

            if not is_duplicate:
                trial.update_last_result(
                    result, terminate=(decision == TrialScheduler.STOP))

            # Checkpoints to disk. This should be checked even if
            # the scheduler decision is STOP or PAUSE. Note that
            # PAUSE only checkpoints to memory and does not update
            # the global checkpoint state.
            self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

            if trial.is_saving:
                # Cache decision to execute on after the save is processed.
                # This prevents changing the trial's state or kicking off
                # another training step prematurely.
                self._cached_trial_decisions[trial.trial_id] = decision
            else:
                self._execute_action(trial, decision)
        except Exception:
            logger.exception("Trial %s: Error processing event.", trial)
            self._process_trial_failure(trial, traceback.format_exc())
Example #24
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        with warn_if_slow("callbacks.on_step_begin"):
            self._callbacks.on_step_begin(iteration=self._iteration,
                                          trials=self._trials)

        # This will contain the next trial to start
        next_trial = self._get_next_trial()  # blocking

        # Create pending trials
        num_pending_trials = len(
            [t for t in self._trials if t.status == Trial.PENDING])
        while num_pending_trials < self._max_pending_trials:
            if not self._update_trial_queue(blocking=False):
                break
            num_pending_trials += 1

        # Update status of staged placement groups
        self.trial_executor.stage_and_update_status(self._trials)

        def _start_trial(trial: Trial) -> bool:
            """Helper function to start trial and call callbacks"""
            with warn_if_slow("start_trial"):
                if self.trial_executor.start_trial(trial):
                    self._callbacks.on_trial_start(iteration=self._iteration,
                                                   trials=self._trials,
                                                   trial=trial)
                    return True
                return False

        may_handle_events = True
        if next_trial is not None:
            if _start_trial(next_trial):
                may_handle_events = False
            else:
                next_trial = self.trial_executor.get_staged_trial()
                if next_trial is not None:
                    if _start_trial(next_trial):
                        may_handle_events = False

        if may_handle_events:
            if self.trial_executor.get_running_trials():
                timeout = None
                if self.trial_executor.in_staging_grace_period():
                    timeout = 0.1
                self._process_events(timeout=timeout)  # blocking
            else:
                self.trial_executor.on_no_available_trials(self)

        self._stop_experiment_if_needed()

        try:
            with warn_if_slow(
                    "experiment_checkpoint",
                    message="Checkpointing the experiment state took "
                    "{duration:.3f} s, which may be a performance "
                    "bottleneck. Please ensure the "
                    "`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
                    "something significantly higher than this duration "
                    "to ensure compute time is mostly spent on the main "
                    "training loop."):
                self.checkpoint()
        except Exception as e:
            logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_stop_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
        with warn_if_slow("callbacks.on_step_end"):
            self._callbacks.on_step_end(iteration=self._iteration,
                                        trials=self._trials)
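
Several snippets above time callback hooks such as `callbacks.on_step_begin` and `callbacks.on_trial_result`. The sketch below shows an object exposing those hooks; in Ray Tune this would normally subclass `ray.tune.Callback`, which is assumed rather than imported here, and only the keyword arguments used in the snippets above are handled.

import logging

logger = logging.getLogger(__name__)


class StepLoggingCallback:
    """Callback hooks matching the `callbacks.on_*` labels timed above."""

    def on_step_begin(self, iteration, trials, **info):
        logger.debug("step %s: begin with %d trials", iteration, len(trials))

    def on_step_end(self, iteration, trials, **info):
        logger.debug("step %s: end", iteration)

    def on_trial_start(self, iteration, trials, trial, **info):
        logger.debug("trial %s started", trial)

    def on_trial_save(self, iteration, trials, trial, **info):
        logger.debug("trial %s saved a checkpoint", trial)

    def on_trial_restore(self, iteration, trials, trial, **info):
        logger.debug("trial %s restored from a checkpoint", trial)

    def on_trial_result(self, iteration, trials, trial, result, **info):
        logger.debug("trial %s reported a result", trial)

    def on_trial_complete(self, iteration, trials, trial, **info):
        logger.debug("trial %s completed", trial)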