# Example #1
class AsyncExecutor:
    """Async version of ``tune.run(...)``.

    Builds the same ``TrialRunner`` that ``tune.run`` would, but drives it
    from a background daemon thread so the caller can keep polling
    ``get_trials()`` / ``get_results()`` without blocking.  ``stop()`` is
    registered with ``atexit`` so trials are torn down even if the caller
    never stops the executor explicitly.
    """

    def __init__(self,
                 run_or_experiment,
                 name=None,
                 stop=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_to_cloud=None,
                 sync_to_driver=False,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 global_checkpoint_period=10,
                 export_formats=None,
                 max_failures=0,
                 fail_fast=True,
                 restore=None,
                 search_alg=None,
                 scheduler=None,
                 with_server=False,
                 server_port=TuneServer.DEFAULT_PORT,
                 verbose=0,
                 progress_reporter=None,
                 resume=False,
                 queue_trials=False,
                 reuse_actors=False,
                 trial_executor=None,
                 raise_on_failed_trial=True,
                 return_trials=False,
                 ray_auto_init=True,
                 shuffle=False):
        """Set up the experiment, start the background step worker.

        Parameters mirror ``tune.run(...)``; most are forwarded verbatim to
        ``Experiment`` / ``TrialRunner`` / ``RayTrialExecutor``.

        Raises:
            ValueError: if ``fail_fast=True`` while ``max_failures != 0``.
        """
        # Validate incompatible options before doing any work.  (Previously
        # this check ran only after the trial executor and experiments were
        # already constructed, leaving side effects behind on failure.)
        if fail_fast and max_failures != 0:
            raise ValueError("max_failures must be 0 if fail_fast=True.")

        if loggers is None:
            loggers = [JsonLogger, CSVLogger]
        config = _transform_config(config)

        # ``issubclass`` raises TypeError when given a non-class (e.g. a
        # plain training function); treat that as "not a Trainable" and
        # wrap the callable below.
        is_trainable = False
        try:
            if issubclass(run_or_experiment, Trainable):
                is_trainable = True
        except TypeError:
            pass

        if not is_trainable:
            run_or_experiment = wrap_function(run_or_experiment)

        self.trial_executor = trial_executor or RayTrialExecutor(
            queue_trials=queue_trials,
            reuse_actors=reuse_actors,
            ray_auto_init=ray_auto_init)

        experiments = [run_or_experiment]
        self.logger = logging.getLogger(__name__)

        # Normalize each entry into an Experiment object.
        for i, exp in enumerate(experiments):
            if not isinstance(exp, Experiment):
                run_identifier = Experiment.register_if_needed(exp)
                experiments[i] = Experiment(
                    name=name,
                    run=run_identifier,
                    stop=stop,
                    config=config,
                    resources_per_trial=resources_per_trial,
                    num_samples=num_samples,
                    local_dir=local_dir,
                    upload_dir=upload_dir,
                    sync_to_driver=sync_to_driver,
                    trial_name_creator=trial_name_creator,
                    loggers=loggers,
                    checkpoint_freq=checkpoint_freq,
                    checkpoint_at_end=checkpoint_at_end,
                    sync_on_checkpoint=sync_on_checkpoint,
                    keep_checkpoints_num=keep_checkpoints_num,
                    checkpoint_score_attr=checkpoint_score_attr,
                    export_formats=export_formats,
                    max_failures=max_failures,
                    restore=restore)

        self.runner = TrialRunner(
            search_alg=search_alg or BasicVariantGenerator(shuffle=shuffle),
            scheduler=scheduler or FIFOScheduler(),
            local_checkpoint_dir=experiments[0].checkpoint_dir,
            remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
            sync_to_cloud=sync_to_cloud,
            stopper=experiments[0].stopper,
            checkpoint_period=global_checkpoint_period,
            resume=resume,
            launch_web_server=with_server,
            server_port=server_port,
            verbose=verbose > 1,  # comparison is already a bool
            fail_fast=fail_fast,
            trial_executor=self.trial_executor)

        for exp in experiments:
            self.runner.add_experiment(exp)

        # Background worker that steps the runner until finished/stopped.
        # Daemonized so a hung step cannot keep the interpreter alive.
        self._is_worker_stopped = threading.Event()
        self._worker_exc = None  # sys.exc_info() tuple captured by the worker
        self._worker = threading.Thread(target=self.step_worker, daemon=True)
        self._worker.start()

        # Tear trials down at interpreter exit even without an explicit stop().
        atexit.register(self.stop)

    def step_worker(self):
        """Drive ``TrialRunner.step()`` in a loop until finished or stopped.

        Runs on the background thread started by ``__init__``.  Any exception
        from ``step()`` is captured into ``self._worker_exc`` (re-raised later
        by ``get_results()``) and terminates the loop.
        """
        while not self._is_worker_stopped.is_set(
        ) and not self.runner.is_finished():
            try:
                self.runner.step()  # blocking call!
            except Exception:
                # Record the exception *before* signalling the stop event so
                # a reader that observes the event also sees the exc_info;
                # the original order left a brief window where it would not.
                self._worker_exc = sys.exc_info()
                self._is_worker_stopped.set()
                break

    def stop(self, timeout=5):
        """Stop the experiment, force-kill running trials, join the worker.

        Args:
            timeout: seconds to wait for the step worker to exit.

        Raises:
            RuntimeError: if the worker thread is still alive after
                ``timeout`` (e.g. stuck inside a blocking ``step()``).
        """
        self.runner.request_stop_experiment()
        self._is_worker_stopped.set()

        # FORCE KILL, mute all the errors from the dying subprocesses
        for t in self.trial_executor.get_running_trials():
            try:  # TODO ?? ValueError: ray.kill() only supported for actors. Got: .
                ray.kill(t.runner)
            except Exception:
                pass
            self.trial_executor.stop_trial(t, True)
            time.sleep(0.5)  # wait for stdio sync

        self._worker.join(timeout=timeout)
        # A plain ``assert`` here would be stripped under ``python -O``;
        # raise explicitly so a stuck worker is always reported.
        if self._worker.is_alive():
            raise RuntimeError(
                "step worker did not stop within {} seconds".format(timeout))

    def get_trials(self):
        """Return the runner's current list of trials."""
        return self.runner.get_trials()

    def get_results(self):
        """Checkpoint the runner and collect results reported so far.

        Returns:
            Tuple ``(completed_results, n_incompleted)`` where each result is
            a dict with score, training step, logdir and config, and
            ``n_incompleted`` counts trials not yet TERMINATED.

        Raises:
            The exception captured on the worker thread, if any, re-raised
            with its original traceback.
        """
        # Reraise from the worker thread
        if self._worker_exc:
            raise self._worker_exc[1].with_traceback(self._worker_exc[2])

        trials = self.runner.get_trials()
        try:
            self.runner.checkpoint(force=True)
        except Exception:
            # Best-effort: a failed checkpoint should not hide the results.
            self.logger.exception("Trial Runner checkpointing failed.")
        wait_for_sync()

        completed_results = []
        n_incompleted = 0
        for trial in trials:
            # NOTE(review): assumes any trial with reported metrics has both
            # SCORE_NAME and TRAINING_STEP_NAME in metric_analysis — a trial
            # reporting other metrics only would raise KeyError here.
            if len(trial.metric_analysis) > 0:
                score = trial.metric_analysis[SCORE_NAME]['last']
                it = trial.metric_analysis[TRAINING_STEP_NAME]['last']

                result = {
                    SCORE_NAME: score,
                    TRAINING_STEP_NAME: it,
                    'logdir': trial.logdir,
                    'config': trial.config,
                }
                completed_results.append(result)

            # A trial can appear in completed_results *and* be counted here
            # if it reported metrics but has not terminated yet.
            if trial.status != Trial.TERMINATED:
                n_incompleted += 1
                continue

        return completed_results, n_incompleted