Example #1
    def restore(self, trial, checkpoint=None):
        """Restores training state from a given model checkpoint.

        This will also sync the trial results to a new location
        if restoring on a different node.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            RayTimeoutError: This error is raised if a remote call to the
                runner times out.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint.value is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        value = checkpoint.value
        if checkpoint.storage == Checkpoint.MEMORY:
            assert not isinstance(value, Checkpoint), type(value)
            trial.runner.restore_from_object.remote(value)
        else:
            logger.info("Trial %s: Attempting restore from %s", trial, value)
            with warn_if_slow("get_current_ip"):
                worker_ip = ray.get(trial.runner.current_ip.remote(),
                                    DEFAULT_GET_TIMEOUT)
            with warn_if_slow("sync_to_new_location"):
                trial.sync_logger_to_new_location(worker_ip)
            with warn_if_slow("restore_from_disk"):
                # TODO(ujvl): Take blocking restores out of the control loop.
                ray.get(trial.runner.restore.remote(value))
        trial.last_result = checkpoint.result
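
Every example on this page wraps potentially slow operations in `warn_if_slow(...)`, which comes from Ray Tune's utilities. Purely as an illustration of the behaviour these snippets rely on (timing the wrapped block, logging a warning past a threshold, and exposing a `too_slow` flag that Example #2 below reads via `as profile`), a minimal stand-in might look like the following sketch; `_SlowProfile` and `DEFAULT_WARN_THRESHOLD_S` are invented names for the sketch, not part of the real API.

import contextlib
import logging
import time

logger = logging.getLogger(__name__)

DEFAULT_WARN_THRESHOLD_S = 0.5  # invented default for this sketch


class _SlowProfile:
    """Records the outcome of one timed block (cf. `profile.too_slow` below)."""

    def __init__(self):
        self.duration = None
        self.too_slow = False


@contextlib.contextmanager
def warn_if_slow(name, threshold=DEFAULT_WARN_THRESHOLD_S):
    """Times the wrapped block and warns if it took longer than `threshold` seconds."""
    profile = _SlowProfile()
    start = time.time()
    try:
        yield profile
    finally:
        profile.duration = time.time() - start
        profile.too_slow = profile.duration > threshold
        if profile.too_slow:
            logger.warning("The `%s` operation took %.3f s, which may be a "
                           "performance bottleneck.", name, profile.duration)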
Example #2
    def save(self, trial, storage=Checkpoint.DISK, result=None):
        """Saves the trial's state to a checkpoint."""
        result = result or trial.last_result

        if storage == Checkpoint.MEMORY:
            value = trial.runner.save_to_object.remote()
            checkpoint = Checkpoint(storage, value, result)
        else:
            with warn_if_slow("save_checkpoint_to_disk"):
                value = ray.get(trial.runner.save.remote())
                checkpoint = Checkpoint(storage, value, result)

        with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
            try:
                trial.on_checkpoint(checkpoint)
            except Exception:
                logger.exception("Trial %s: Error handling checkpoint %s",
                                 trial, checkpoint.value)
                return None
        if profile.too_slow and trial.sync_on_checkpoint:
            logger.warning(
                "Consider turning off forced head-worker trial checkpoint "
                "syncs by setting sync_on_checkpoint=False. Note that this "
                "might result in faulty trial restoration for some worker "
                "failure modes.")
        return checkpoint.value
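
The `Checkpoint` objects built in Example #2 and consumed in Example #1 only need a `storage` flag (`Checkpoint.MEMORY` or `Checkpoint.DISK`), a `value` (an object ref for memory checkpoints, a path for disk checkpoints) and the `result` that produced them. A minimal stand-in with just those fields, assuming nothing beyond what the two examples read and write, could be:

class Checkpoint:
    """Minimal container covering the attributes Examples #1 and #2 use."""

    MEMORY = "memory"  # `value` is an object ref living in the object store
    DISK = "disk"      # `value` is a path written by the trainable

    def __init__(self, storage, value, result=None):
        self.storage = storage
        self.value = value
        self.result = result  # training result associated with this checkpoint

Note that other examples on this page (e.g. #12, #14 and #17) read `checkpoint.last_result` instead of `checkpoint.result`, so the attribute name varies between the Tune versions these snippets were taken from.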
Example #3
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin(self)
        next_trial = self._get_next_trial()  # blocking
        if next_trial is not None:
            with warn_if_slow("start_trial"):
                self.trial_executor.start_trial(next_trial)
        elif self.trial_executor.get_running_trials():
            self._process_events()  # blocking
        else:
            self.trial_executor.on_no_available_trials(self)

        try:
            with warn_if_slow("experiment_checkpoint"):
                self.checkpoint()
        except Exception:
            logger.exception("Trial Runner checkpointing failed.")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(self)
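
As the docstring notes, `step()` is meant to be driven repeatedly from an outer loop. A schematic driver, assuming a runner constructed with a search algorithm and scheduler and populated via `add_trial()` (Examples #22/#23), might look like this; the constructor arguments and the `make_trials()` helper are assumptions of the sketch, not the exact API:

import ray
from ray.tune.schedulers import FIFOScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial_runner import TrialRunner

ray.init()

# Assumed wiring; the exact constructor signature may differ between versions.
runner = TrialRunner(BasicVariantGenerator(), scheduler=FIFOScheduler())
for trial in make_trials():   # hypothetical helper that builds Trial objects
    runner.add_trial(trial)   # see Examples #22/#23

while not runner.is_finished():
    runner.step()             # one unit of work per call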
Example #4
    def _process_trial(self, trial):
        try:
            result = self.trial_executor.fetch_result(trial)

            is_duplicate = RESULT_DUPLICATE in result
            # TrialScheduler and SearchAlgorithm still receive a
            # notification because there may be special handling for
            # the `on_trial_complete` hook.
            if is_duplicate:
                logger.debug("Trial finished without logging 'done'.")
                result = trial.last_result
                result.update(done=True)

            self._total_time += result.get(TIME_THIS_ITER_S, 0)

            flat_result = flatten_dict(result)
            if trial.should_stop(flat_result):
                # Hook into scheduler
                self._scheduler_alg.on_trial_complete(self, trial, flat_result)
                self._search_alg.on_trial_complete(
                    trial.trial_id, result=flat_result)
                decision = TrialScheduler.STOP
            else:
                with warn_if_slow("scheduler.on_trial_result"):
                    decision = self._scheduler_alg.on_trial_result(
                        self, trial, flat_result)
                with warn_if_slow("search_alg.on_trial_result"):
                    self._search_alg.on_trial_result(trial.trial_id,
                                                     flat_result)
                if decision == TrialScheduler.STOP:
                    with warn_if_slow("search_alg.on_trial_complete"):
                        self._search_alg.on_trial_complete(
                            trial.trial_id,
                            result=flat_result,
                            early_terminated=True)

            if not is_duplicate:
                trial.update_last_result(
                    result, terminate=(decision == TrialScheduler.STOP))

            # Checkpoints to disk. This should be checked even if
            # the scheduler decision is STOP or PAUSE. Note that
            # PAUSE only checkpoints to memory and does not update
            # the global checkpoint state.
            self._checkpoint_trial_if_needed(
                trial, force=result.get(SHOULD_CHECKPOINT, False))

            if decision == TrialScheduler.CONTINUE:
                self.trial_executor.continue_training(trial)
            elif decision == TrialScheduler.PAUSE:
                self.trial_executor.pause_trial(trial)
            elif decision == TrialScheduler.STOP:
                self.trial_executor.export_trial_if_needed(trial)
                self.trial_executor.stop_trial(trial)
            else:
                assert False, "Invalid scheduling decision: {}".format(
                    decision)
        except Exception:
            logger.exception("Error processing event.")
            self._process_trial_failure(trial, traceback.format_exc())
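
Example #4 flattens the nested result dict before handing it to the scheduler and the search algorithm, so that nested metrics can be matched by a single flat key. Assuming a `flatten_dict` that joins nested keys with a "/" separator (the convention Tune's helper follows), the transformation looks roughly like this sketch:

# Illustrative only: a tiny flatten_dict joining nested keys with "/".
def flatten_dict(dt, delimiter="/"):
    flat = {}
    for key, value in dt.items():
        if isinstance(value, dict):
            for sub_key, sub_value in flatten_dict(value, delimiter).items():
                flat[key + delimiter + sub_key] = sub_value
        else:
            flat[key] = value
    return flat


result = {"episode_reward_mean": 120.5, "info": {"learner": {"loss": 0.03}}}
assert flatten_dict(result) == {
    "episode_reward_mean": 120.5,
    "info/learner/loss": 0.03,
}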
Example #5
    def _process_trial(self, trial):
        try:
            result = self.trial_executor.fetch_result(trial)
            self._total_time += result[TIME_THIS_ITER_S]

            if trial.should_stop(result):
                # Hook into scheduler
                self._scheduler_alg.on_trial_complete(self, trial, result)
                self._search_alg.on_trial_complete(
                    trial.trial_id, result=result)
                decision = TrialScheduler.STOP
            else:
                with warn_if_slow("scheduler.on_trial_result"):
                    decision = self._scheduler_alg.on_trial_result(
                        self, trial, result)
                with warn_if_slow("search_alg.on_trial_result"):
                    self._search_alg.on_trial_result(trial.trial_id, result)
                if decision == TrialScheduler.STOP:
                    with warn_if_slow("search_alg.on_trial_complete"):
                        self._search_alg.on_trial_complete(
                            trial.trial_id, early_terminated=True)

            # __duplicate__ is a magic keyword used internally to
            # avoid double-logging results when using the Function API.
            # TrialScheduler and SearchAlgorithm still receive a
            # notification because there may be special handling for
            # the `on_trial_complete` hook.
            if "__duplicate__" not in result:
                trial.update_last_result(
                    result, terminate=(decision == TrialScheduler.STOP))

            # Checkpoints to disk. This should be checked even if
            # the scheduler decision is STOP or PAUSE. Note that
            # PAUSE only checkpoints to memory and does not update
            # the global checkpoint state.
            self._checkpoint_trial_if_needed(trial)

            if decision == TrialScheduler.CONTINUE:
                self.trial_executor.continue_training(trial)
            elif decision == TrialScheduler.PAUSE:
                self.trial_executor.pause_trial(trial)
            elif decision == TrialScheduler.STOP:
                self.trial_executor.export_trial_if_needed(trial)
                self.trial_executor.stop_trial(trial)
            else:
                assert False, "Invalid scheduling decision: {}".format(
                    decision)
        except Exception:
            logger.exception("Error processing event.")
            error_msg = traceback.format_exc()
            if trial.status == Trial.RUNNING:
                if trial.should_recover():
                    self._try_recover(trial, error_msg)
                else:
                    self._scheduler_alg.on_trial_error(self, trial)
                    self._search_alg.on_trial_complete(
                        trial.trial_id, error=True)
                    self.trial_executor.stop_trial(
                        trial, error=True, error_msg=error_msg)
Example #6
    def _process_trial(self, trial):
        try:
            result = self.trial_executor.fetch_result(trial)
            self._total_time += result[TIME_THIS_ITER_S]

            if trial.should_stop(result):
                # Hook into scheduler
                self._scheduler_alg.on_trial_complete(self, trial, result)
                self._search_alg.on_trial_complete(trial.trial_id,
                                                   result=result)
                decision = TrialScheduler.STOP

            else:
                with warn_if_slow("scheduler.on_trial_result"):
                    decision = self._scheduler_alg.on_trial_result(
                        self, trial, result)
                with warn_if_slow("search_alg.on_trial_result"):
                    self._search_alg.on_trial_result(trial.trial_id, result)
                if decision == TrialScheduler.STOP:
                    with warn_if_slow("search_alg.on_trial_complete"):
                        self._search_alg.on_trial_complete(
                            trial.trial_id, early_terminated=True)
            trial.update_last_result(
                result, terminate=(decision == TrialScheduler.STOP))

            # Checkpoints to disk. This should be checked even if
            # the scheduler decision is STOP or PAUSE. Note that
            # PAUSE only checkpoints to memory and does not update
            # the global checkpoint state.
            self._checkpoint_trial_if_needed(trial)

            if decision == TrialScheduler.CONTINUE:
                self.trial_executor.continue_training(trial)
            elif decision == TrialScheduler.PAUSE:
                self.trial_executor.pause_trial(trial)
            elif decision == TrialScheduler.STOP:
                self.trial_executor.export_trial_if_needed(trial)
                self.trial_executor.stop_trial(trial)
            else:
                assert False, "Invalid scheduling decision: {}".format(
                    decision)
        except Exception:
            logger.exception("Error processing event.")
            error_msg = traceback.format_exc()
            if trial.status == Trial.RUNNING:
                if trial.should_recover():
                    self._try_recover(trial, error_msg)
                else:
                    self._scheduler_alg.on_trial_error(self, trial)
                    self._search_alg.on_trial_complete(trial.trial_id,
                                                       error=True)
                    self.trial_executor.stop_trial(trial,
                                                   error=True,
                                                   error_msg=error_msg)
Example #7
    def _process_trial(self, trial):
        try:
            result = self.trial_executor.fetch_result(trial)
            self._total_time += result[TIME_THIS_ITER_S]

            if trial.should_stop(result):
                # Hook into scheduler
                self._scheduler_alg.on_trial_complete(self, trial, result)
                self._search_alg.on_trial_complete(
                    trial.trial_id, result=result)
                decision = TrialScheduler.STOP

            else:
                with warn_if_slow("scheduler.on_trial_result"):
                    decision = self._scheduler_alg.on_trial_result(
                        self, trial, result)
                with warn_if_slow("search_alg.on_trial_result"):
                    self._search_alg.on_trial_result(trial.trial_id, result)
                if decision == TrialScheduler.STOP:
                    with warn_if_slow("search_alg.on_trial_complete"):
                        self._search_alg.on_trial_complete(
                            trial.trial_id, early_terminated=True)
            trial.update_last_result(
                result, terminate=(decision == TrialScheduler.STOP))

            # Checkpoints to disk. This should be checked even if
            # the scheduler decision is STOP or PAUSE. Note that
            # PAUSE only checkpoints to memory and does not update
            # the global checkpoint state.
            self._checkpoint_trial_if_needed(trial)

            if decision == TrialScheduler.CONTINUE:
                self.trial_executor.continue_training(trial)
            elif decision == TrialScheduler.PAUSE:
                self.trial_executor.pause_trial(trial)
            elif decision == TrialScheduler.STOP:
                self.trial_executor.export_trial_if_needed(trial)
                self.trial_executor.stop_trial(trial)
            else:
                assert False, "Invalid scheduling decision: {}".format(
                    decision)
        except Exception:
            logger.exception("Error processing event.")
            error_msg = traceback.format_exc()
            if trial.status == Trial.RUNNING:
                if trial.should_recover():
                    self._try_recover(trial, error_msg)
                else:
                    self._scheduler_alg.on_trial_error(self, trial)
                    self._search_alg.on_trial_complete(
                        trial.trial_id, error=True)
                    self.trial_executor.stop_trial(
                        trial, error=True, error_msg=error_msg)
Example #8
    def _process_events(self):
        failed_trial = self.trial_executor.get_next_failed_trial()
        if failed_trial:
            with warn_if_slow("process_failed_trial"):
                self._process_trial_failure(
                    failed_trial,
                    error_msg="{} (ip: {}) detected as stale. This is likely "
                    "because the node was lost".format(failed_trial,
                                                       failed_trial.node_ip))
        else:
            trial = self.trial_executor.get_next_available_trial()  # blocking
            with warn_if_slow("process_trial"):
                self._process_trial(trial)
Example #9
    def _process_events(self):
        failed_trial = self.trial_executor.get_next_failed_trial()
        if failed_trial:
            error_msg = (
                "{} (IP: {}) detected as stale. This is likely because the "
                "node was lost").format(failed_trial, failed_trial.node_ip)
            logger.info(error_msg)
            with warn_if_slow("process_failed_trial"):
                self._process_trial_failure(failed_trial, error_msg=error_msg)
        else:
            trial = self.trial_executor.get_next_available_trial()  # blocking
            with warn_if_slow("process_trial"):
                self._process_trial(trial)
Example #10
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin()
        next_trial = self._get_next_trial()  # blocking
        if next_trial is not None:
            with warn_if_slow("start_trial"):
                self.trial_executor.start_trial(next_trial)
        elif self.trial_executor.get_running_trials():
            self._process_events()  # blocking
        else:
            for trial in self._trials:
                if trial.status == Trial.PENDING:
                    if not self.has_resources(trial.resources):
                        raise TuneError(
                            ("Insufficient cluster resources to launch trial: "
                             "trial requested {} but the cluster has only {}. "
                             "Pass `queue_trials=True` in "
                             "ray.tune.run_experiments() or on the command "
                             "line to queue trials until the cluster scales "
                             "up. {}").format(
                                 trial.resources.summary_string(),
                                 self.trial_executor.resource_string(),
                                 trial._get_trainable_cls().resource_help(
                                     trial.config)))
                elif trial.status == Trial.PAUSED:
                    raise TuneError(
                        "There are paused trials, but no more pending "
                        "trials with sufficient resources.")

        try:
            with warn_if_slow("experiment_checkpoint"):
                self.checkpoint()
        except Exception:
            logger.exception("Trial Runner checkpointing failed.")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end()
Example #11
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin()
        next_trial = self._get_next_trial()  # blocking
        if next_trial is not None:
            with warn_if_slow("start_trial"):
                self.trial_executor.start_trial(next_trial)
        elif self.trial_executor.get_running_trials():
            self._process_events()  # blocking
        else:
            for trial in self._trials:
                if trial.status == Trial.PENDING:
                    if not self.has_resources(trial.resources):
                        raise TuneError(
                            ("Insufficient cluster resources to launch trial: "
                             "trial requested {} but the cluster has only {}. "
                             "Pass `queue_trials=True` in "
                             "ray.tune.run() or on the command "
                             "line to queue trials until the cluster scales "
                             "up. {}").format(
                                 trial.resources.summary_string(),
                                 self.trial_executor.resource_string(),
                                 trial._get_trainable_cls().resource_help(
                                     trial.config)))
                elif trial.status == Trial.PAUSED:
                    raise TuneError(
                        "There are paused trials, but no more pending "
                        "trials with sufficient resources.")

        try:
            with warn_if_slow("experiment_checkpoint"):
                self.checkpoint()
        except Exception:
            logger.exception("Trial Runner checkpointing failed.")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_requests()

            if self.is_finished():
                self._server.shutdown()
        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end()
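
The `TuneError` raised in Examples #10 and #11 points the user at `queue_trials=True`. A hedged usage example of that flag with the `ray.tune.run()` entry point mentioned in Example #11 follows; the trainable and the resource numbers are made up for the sketch:

import ray
from ray import tune


def trainable(config, reporter):
    # Trivial function trainable; reports once and finishes.
    reporter(score=config["x"], done=True)


ray.init(address="auto")  # assumed: connect to an existing, auto-scaling cluster

# With queue_trials=True, a trial whose resource request cannot currently be
# satisfied is queued until the cluster scales up, instead of triggering the
# "Insufficient cluster resources" TuneError shown above.
tune.run(
    trainable,
    config={"x": 1},
    resources_per_trial={"cpu": 8, "gpu": 1},  # deliberately larger than a laptop
    queue_trials=True,
)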
Example #12
    def save(self, trial, storage=Checkpoint.DISK):
        """Saves the trial's state to a checkpoint."""
        trial._checkpoint.storage = storage
        trial._checkpoint.last_result = trial.last_result
        if storage == Checkpoint.MEMORY:
            trial._checkpoint.value = trial.runner.save_to_object.remote()
        else:
            # Keeps only highest performing checkpoints if enabled
            if trial.keep_checkpoints_num:
                try:
                    last_attr_val = trial.last_result[
                        trial.checkpoint_score_attr]
                    if (trial.compare_checkpoints(last_attr_val)
                            and not math.isnan(last_attr_val)):
                        trial.best_checkpoint_attr_value = last_attr_val
                        self._checkpoint_and_erase(trial)
                except KeyError:
                    logger.warning(
                        "Result dict has no key: {}. keep"
                        "_checkpoints_num flag will not work".format(
                            trial.checkpoint_score_attr))
            else:
                with warn_if_slow("save_to_disk"):
                    trial._checkpoint.value = ray.get(
                        trial.runner.save.remote())

        return trial._checkpoint.value
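
Example #12 keeps only the highest-performing checkpoints by comparing `trial.last_result[trial.checkpoint_score_attr]` against the best value seen so far via `trial.compare_checkpoints()`. One plausible shape for that comparison, written here as a standalone sketch that assumes a "min-" prefix on the score attribute flips the direction, is:

import math


def compare_checkpoints(attr_value, best_value, score_attr):
    """Sketch: returns True if `attr_value` should replace `best_value`.

    Assumes a "min-" prefix on `score_attr` means lower is better, otherwise
    higher is better, and that the first valid checkpoint always wins.
    """
    if best_value is None or math.isnan(best_value):
        return True
    if score_attr.startswith("min-"):
        return attr_value < best_value
    return attr_value > best_value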
Example #13
    def reset_trial(self, trial, new_config, new_experiment_tag):
        """Tries to invoke `Trainable.reset_config()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial
                trainable.
            new_experiment_tag (str): New experiment name
                for trial.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
        trainable = trial.runner
        with warn_if_slow("reset_config"):
            try:
                reset_val = ray.get(
                    trainable.reset_config.remote(new_config),
                    DEFAULT_GET_TIMEOUT)
            except RayTimeoutError:
                logger.exception("Trial %s: reset_config timed out.")
                return False
        return reset_val
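
`reset_trial()` in Example #13 only succeeds if the trainable's `reset_config()` returns a truthy value, signalling that the live actor can be reused with the new configuration. A hedged example of a `tune.Trainable` that supports this, using the `_setup`/`_train` method names of the era these snippets come from (the `lr` hyperparameter is made up):

from ray import tune


class MyTrainable(tune.Trainable):
    def _setup(self, config):
        # `lr` is a made-up hyperparameter for this sketch.
        self.lr = config.get("lr", 0.01)
        self.loss = 1.0

    def _train(self):
        # Pretend training: the loss decays proportionally to the learning rate.
        self.loss *= (1.0 - self.lr)
        return {"loss": self.loss}

    def reset_config(self, new_config):
        # Reuse the live actor instead of tearing it down: swap the
        # hyperparameters in place and report success, so that
        # `reset_trial()` (Example #13) returns True.
        self.lr = new_config.get("lr", self.lr)
        return True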
Example #14
    def restore(self, trial, checkpoint=None):
        """Restores training state from a given model checkpoint.

        This will also sync the trial results to a new location
        if restoring on a different node.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial._checkpoint
        if checkpoint is None or checkpoint.value is None:
            return True
        if trial.runner is None:
            logger.error("Unable to restore - no runner.")
            self.set_status(trial, Trial.ERROR)
            return False
        try:
            value = checkpoint.value
            if checkpoint.storage == Checkpoint.MEMORY:
                assert type(value) != Checkpoint, type(value)
                trial.runner.restore_from_object.remote(value)
            else:
                worker_ip = ray.get(trial.runner.current_ip.remote())
                trial.sync_logger_to_new_location(worker_ip)
                with warn_if_slow("restore_from_disk"):
                    ray.get(trial.runner.restore.remote(value))
            trial.last_result = checkpoint.last_result
            return True
        except Exception:
            logger.exception("Error restoring runner for Trial %s.", trial)
            self.set_status(trial, Trial.ERROR)
            return False
Example #15
    def restore(self, trial, checkpoint=None):
        """Restores training state from a given model checkpoint.

        This will also sync the trial results to a new location
        if restoring on a different node.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint is None or checkpoint.value is None:
            return True
        if trial.runner is None:
            logger.error(
                "Trial %s: Unable to restore - no runner. "
                "Setting status to ERROR.", trial)
            self.set_status(trial, Trial.ERROR)
            return False
        try:
            value = checkpoint.value
            if checkpoint.storage == Checkpoint.MEMORY:
                assert type(value) != Checkpoint, type(value)
                trial.runner.restore_from_object.remote(value)
            else:
                logger.info("Trial %s: Attempting restoration from %s", trial,
                            checkpoint.value)
                with warn_if_slow("get_current_ip"):
                    worker_ip = ray.get(trial.runner.current_ip.remote(),
                                        DEFAULT_GET_TIMEOUT)
                with warn_if_slow("sync_to_new_location"):
                    trial.sync_logger_to_new_location(worker_ip)
                with warn_if_slow("restore_from_disk"):
                    ray.get(trial.runner.restore.remote(value),
                            DEFAULT_GET_TIMEOUT)
        except RayTimeoutError:
            logger.exception(
                "Trial %s: Unable to restore - runner task timed "
                "out. Setting status to ERROR", trial)
            self.set_status(trial, Trial.ERROR)
            return False
        except Exception:
            logger.exception(
                "Trial %s: Unable to restore. Setting status to ERROR", trial)
            self.set_status(trial, Trial.ERROR)
            return False

        trial.last_result = checkpoint.result
        return True
Example #16
    def _requeue_trial(self, trial):
        """Notification to TrialScheduler and requeue trial.

        This does not notify the SearchAlgorithm because the function
        evaluation is still in progress.
        """
        self._scheduler_alg.on_trial_error(self, trial)
        self.trial_executor.set_status(trial, Trial.PENDING)
        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
Example #17
    def save(self, trial, storage=Checkpoint.DISK):
        """Saves the trial's state to a checkpoint."""
        trial._checkpoint.storage = storage
        trial._checkpoint.last_result = trial.last_result
        if storage == Checkpoint.MEMORY:
            trial._checkpoint.value = trial.runner.save_to_object.remote()
        else:
            with warn_if_slow("save_to_disk"):
                trial._checkpoint.value = ray.get(trial.runner.save.remote())
        return trial._checkpoint.value
Example #18
    def _process_events(self):
        failed_trial = self.trial_executor.get_next_failed_trial()
        if failed_trial:
            error_msg = (
                "{} (IP: {}) detected as stale. This is likely because the "
                "node was lost").format(failed_trial, failed_trial.node_ip)
            logger.info(error_msg)
            with warn_if_slow("process_failed_trial"):
                self._process_trial_failure(failed_trial, error_msg=error_msg)
        else:
            # TODO(ujvl): Consider combining get_next_available_trial and
            #  fetch_result functionality so that we don't timeout on fetch.
            trial = self.trial_executor.get_next_available_trial()  # blocking
            if trial.is_restoring:
                with warn_if_slow("process_trial_restore"):
                    self._process_trial_restore(trial)
            else:
                with warn_if_slow("process_trial"):
                    self._process_trial(trial)
Example #19
    def _requeue_trial(self, trial):
        """Notification to TrialScheduler and requeue trial.

        This does not notify the SearchAlgorithm because the function
        evaluation is still in progress.
        """
        self._scheduler_alg.on_trial_error(self, trial)
        self.trial_executor.set_status(trial, Trial.PENDING)
        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
Example #20
    def _get_next_trial(self):
        """Replenishes queue.

        Blocks if all trials queued have finished, but search algorithm is
        still not finished.
        """
        trials_done = all(trial.is_finished() for trial in self._trials)
        wait_for_trial = trials_done and not self._search_alg.is_finished()
        self._update_trial_queue(blocking=wait_for_trial)
        with warn_if_slow("choose_trial_to_run"):
            trial = self._scheduler_alg.choose_trial_to_run(self)
        return trial
Example #21
    def _get_next_trial(self):
        """Replenishes queue.

        Blocks if all trials queued have finished, but search algorithm is
        still not finished.
        """
        trials_done = all(trial.is_finished() for trial in self._trials)
        wait_for_trial = trials_done and not self._search_alg.is_finished()
        self._update_trial_queue(blocking=wait_for_trial)
        with warn_if_slow("choose_trial_to_run"):
            trial = self._scheduler_alg.choose_trial_to_run(self)
        return trial
Example #22
    def add_trial(self, trial):
        """Adds a new trial to this TrialRunner.

        Trials may be added at any time.

        Args:
            trial (Trial): Trial to queue.
        """
        trial.set_verbose(self._verbose)
        self._trials.append(trial)
        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
        self.trial_executor.try_checkpoint_metadata(trial)
Example #23
    def add_trial(self, trial):
        """Adds a new trial to this TrialRunner.

        Trials may be added at any time.

        Args:
            trial (Trial): Trial to queue.
        """
        trial.set_verbose(self._verbose)
        self._trials.append(trial)
        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
        self.trial_executor.try_checkpoint_metadata(trial)
Example #24
    def restore(self, trial, checkpoint=None):
        """Restores training state from a given model checkpoint.

        This will also sync the trial results to a new location
        if restoring on a different node.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial._checkpoint
        if checkpoint is None or checkpoint.value is None:
            return True
        if trial.runner is None:
            logger.error("Unable to restore - no runner.")
            self.set_status(trial, Trial.ERROR)
            return False
        try:
            value = checkpoint.value
            if checkpoint.storage == Checkpoint.MEMORY:
                assert type(value) != Checkpoint, type(value)
                trial.runner.restore_from_object.remote(value)
            else:
                # TODO: Somehow, the call to get the current IP on the
                # remote actor can be very slow - a better fix would
                # be to use an actor table to detect the IP of the Trainable
                # and rsync the files there.
                # See https://github.com/ray-project/ray/issues/5168
                with warn_if_slow("get_current_ip"):
                    worker_ip = ray.get(trial.runner.current_ip.remote())
                with warn_if_slow("sync_to_new_location"):
                    trial.sync_logger_to_new_location(worker_ip)
                with warn_if_slow("restore_from_disk"):
                    ray.get(trial.runner.restore.remote(value))
            trial.last_result = checkpoint.last_result
            return True
        except Exception:
            logger.exception("Error restoring runner for Trial %s.", trial)
            self.set_status(trial, Trial.ERROR)
            return False
Example #25
    def _checkpoint_and_erase(self, trial):
        """Checkpoints the model and erases old checkpoints
            if needed.
        Parameters
        ----------
            trial : trial to save
        """

        with warn_if_slow("save_to_disk"):
            trial._checkpoint.value = ray.get(trial.runner.save.remote())

        if len(trial.history) >= trial.keep_checkpoints_num:
            ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1]))
            trial.history.pop()

        trial.history.insert(0, trial._checkpoint.value)
Example #26
    def fetch_result(self, trial):
        """Fetches one result of the running trials.

        Returns:
            Result of the most recent trial training run."""
        trial_future = self._find_item(self._running, trial)
        if not trial_future:
            raise ValueError("Trial was not running.")
        self._running.pop(trial_future[0])
        with warn_if_slow("fetch_result"):
            result = ray.get(trial_future[0])

        # For local mode
        if isinstance(result, _LocalWrapper):
            result = result.unwrap()
        return result
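
The `_LocalWrapper` check in Example #26 exists because, in Ray's local mode, results are produced in-process rather than through the object store, and the executor wraps them so the rest of the loop can treat both cases uniformly. A minimal stand-in with the one method the example calls, purely illustrative:

class _LocalWrapper:
    """Wraps a result computed in-process so it can be handled like a remote one."""

    def __init__(self, result):
        self._result = result

    def unwrap(self):
        """Returns the wrapped result dict."""
        return self._result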
Example #27
    def reset_trial(self, trial, new_config, new_experiment_tag):
        """Tries to invoke `Trainable.reset_config()` to reset trial.

        Args:
            trial (Trial): Trial to be reset.
            new_config (dict): New configuration for Trial
                trainable.
            new_experiment_tag (str): New experiment name
                for trial.

        Returns:
            True if `reset_config` is successful else False.
        """
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
        trainable = trial.runner
        with warn_if_slow("reset_config"):
            reset_val = ray.get(trainable.reset_config.remote(new_config))
        return reset_val
Example #28
    def _requeue_trial(self, trial):
        """Notification to TrialScheduler and requeue trial.

        This does not notify the SearchAlgorithm because the function
        evaluation is still in progress.

        """
        self._scheduler_alg.on_trial_error(self, trial)
        self.trial_executor.set_status(trial, Trial.PENDING)

        # TODO(rliaw): Right now, this pushes the trial to the end of queue
        # because restoration can be expensive. However, this is not
        # ideal since it just hides the issue - a better fix would
        # be to use an actor table to detect the IP of the Trainable
        # and rsync the files there.
        # See https://github.com/ray-project/ray/issues/5168
        self._trials.pop(self._trials.index(trial))
        self._trials.append(trial)

        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self, trial)
Example #29
    def _checkpoint_and_erase(self, subdir, trial):
        """Checkpoints the model and erases old checkpoints
            if needed.

        Parameters
        ----------
            subdir string: either "" or "best"
            trial : trial to save
        """

        with warn_if_slow("save_to_disk"):
            trial._checkpoint.value, folder_path = ray.get(
                trial.runner.save_checkpoint_relative.remote(subdir))

        if trial.prefix[subdir]["limit"]:
            if len(trial.prefix[subdir]
                   ["history"]) == trial.prefix[subdir]["limit"]:
                ray.get(
                    trial.runner.delete_checkpoint.remote(
                        trial.prefix[subdir]["history"][-1]))
                trial.prefix[subdir]["history"].pop()
            trial.prefix[subdir]["history"].insert(0, folder_path)
Example #30
    def _process_events(self):
        trial = self.trial_executor.get_next_available_trial()  # blocking
        with warn_if_slow("process_trial"):
            self._process_trial(trial)
Example #31
    def _process_events(self):
        trial = self.trial_executor.get_next_available_trial()  # blocking
        with warn_if_slow("process_trial"):
            self._process_trial(trial)