Example #1
    def _tune_run(self, config, resources_per_trial):
        """Wrapper to call ``tune.run``. Multiple estimators are generated when
        early stopping is possible, whereas a single estimator is
        generated when early stopping is not possible.

        Args:
            config (dict): Configurations such as hyperparameters to run
                ``tune.run`` on.
            resources_per_trial (dict): Resources to use per trial within Ray.
                Accepted keys are `cpu`, `gpu` and custom resources, and values
                are integers specifying the number of each resource to use.

        Returns:
            analysis (`ExperimentAnalysis`): Object returned by
                `tune.run`.

        """
        trainable = _Trainable
        if self.pipeline_auto_early_stop and check_is_pipeline(
                self.estimator) and self.early_stopping:
            trainable = _PipelineTrainable

        if self.early_stopping is not None:
            config["estimator_ids"] = [
                ray.put(self.estimator) for _ in range(self.n_splits)
            ]
        else:
            config["estimator_ids"] = [ray.put(self.estimator)]

        stopper = MaximumIterationStopper(max_iter=self.max_iters)
        if self.stopper:
            stopper = CombinedStopper(stopper, self.stopper)

        run_args = dict(scheduler=self.early_stopping,
                        reuse_actors=True,
                        verbose=self.verbose,
                        stop=stopper,
                        config=config,
                        fail_fast="raise",
                        resources_per_trial=resources_per_trial,
                        local_dir=os.path.expanduser(self.local_dir),
                        loggers=self.loggers,
                        time_budget_s=self.time_budget_s)

        if isinstance(self.param_grid, list):
            run_args.update(
                dict(search_alg=ListSearcher(self.param_grid),
                     num_samples=self._list_grid_num_samples()))

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore",
                                    message="fail_fast='raise' "
                                    "detected.")
            analysis = tune.run(trainable, **run_args)
        return analysis
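
A minimal sketch of the object-store pattern `_tune_run` relies on: `ray.put` places one copy of the estimator in the Ray object store per split, and trials receive lightweight references through `config` instead of re-pickled estimators. The `ray.put`/`ray.get` calls are the real Ray API; the toy estimator and `n_splits` value are hypothetical.

import ray

ray.init(ignore_reinit_error=True)

# Hypothetical stand-in for self.estimator; any picklable object works.
estimator = {"model": "SGDClassifier", "alpha": 1e-4}

# One object-store copy per cross-validation split, as in the method above.
n_splits = 3
estimator_ids = [ray.put(estimator) for _ in range(n_splits)]

# Inside a trial, a reference is resolved back into a concrete object.
fold_estimator = ray.get(estimator_ids[0])
assert fold_estimator == estimator

ray.shutdown()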
Example #2
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):

        if loggers is not None:
            # Most users won't run into this as `tune.run()` does not pass
            # the argument anymore. However, we will want to inform users
            # if they instantiate their `Experiment` objects themselves.
            raise ValueError(
                "Passing `loggers` to an `Experiment` is deprecated. Use "
                "an `ExperimentLogger` callback instead, e.g. by passing the "
                "`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
                "passing this as part of the `callback` parameter to "
                "`tune.run()`.")

        config = config or {}
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier

        # If the name has been set explicitly, we don't want to create
        # dated directories. The same is true for string run identifiers.
        if int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1 or name \
           or isinstance(run, str):
            self.dir_name = self.name
        else:
            self.dir_name = "{}_{}".format(self.name, date_str())

        if upload_dir:
            self.remote_checkpoint_dir = os.path.join(upload_dir,
                                                      self.dir_name)
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif issubclass(type(stop), Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                TimeoutStopper(time_budget_s))
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "loggers": loggers,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore else None
        }
        self.spec = spec
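
For context, a hedged sketch of the three `stop` forms this constructor normalizes (a dict of criteria, a plain function, or a `Stopper` instance); `my_trainable` is a hypothetical function-style trainable, and the calls assume a Ray version matching this snippet.

from ray import tune
from ray.tune import Stopper

def my_trainable(config):  # hypothetical trainable
    for step in range(100):
        tune.report(mean_loss=1.0 / (step + 1))

# 1. Dict criteria: kept as `stopping_criteria` in the spec.
tune.run(my_trainable, stop={"training_iteration": 10})

# 2. Plain function: wrapped in a FunctionStopper by the constructor.
tune.run(my_trainable, stop=lambda trial_id, result: result["mean_loss"] < 0.05)

# 3. Stopper instance: passed through unchanged.
class LossStopper(Stopper):
    def __call__(self, trial_id, result):
        return result["mean_loss"] < 0.05

    def stop_all(self):
        return False

tune.run(my_trainable, stop=LossStopper())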
Example #3
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):

        config = config or {}
        if callable(run) and detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier
        if upload_dir:
            self.remote_checkpoint_dir = os.path.join(upload_dir, self.name)
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif issubclass(type(stop), Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                TimeoutStopper(time_budget_s))
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        _raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "loggers": loggers,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore else None
        }
        self.spec = spec
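
The `time_budget_s` branch above simply chains stoppers; a minimal sketch of the same composition done by hand with the real `ray.tune.stopper` classes (the iteration and timeout values are arbitrary examples).

from ray.tune.stopper import (CombinedStopper, MaximumIterationStopper,
                              TimeoutStopper)

# Fires when either condition is met: 100 training iterations or a
# 600-second wall-clock budget, mirroring Experiment's internal chaining.
stopper = CombinedStopper(
    MaximumIterationStopper(max_iter=100),
    TimeoutStopper(timeout=600),
)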
Example #4
    def __init__(self,
                 name,
                 run,
                 stop=None,
                 time_budget_s=None,
                 config=None,
                 resources_per_trial=None,
                 num_samples=1,
                 local_dir=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 trial_dirname_creator=None,
                 log_to_file=False,
                 sync_to_driver=None,
                 sync_to_cloud=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=None,
                 export_formats=None,
                 max_failures=0,
                 restore=None):

        config = config or {}
        if callable(run) and not inspect.isclass(run) and \
                detect_checkpoint_function(run):
            if checkpoint_at_end:
                raise ValueError("'checkpoint_at_end' cannot be used with a "
                                 "checkpointable function. You can specify "
                                 "and register checkpoints within "
                                 "your trainable function.")
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function.")
        self._run_identifier = Experiment.register_if_needed(run)
        self.name = name or self._run_identifier

        # If the name has been set explicitly, we don't want to create
        # dated directories. The same is true for string run identifiers.
        if int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1 or name \
           or isinstance(run, str):
            self.dir_name = self.name
        else:
            self.dir_name = "{}_{}".format(self.name, date_str())

        if upload_dir:
            self.remote_checkpoint_dir = os.path.join(upload_dir,
                                                      self.dir_name)
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, list):
            if any(not isinstance(s, Stopper) for s in stop):
                raise ValueError(
                    "If you pass a list as the `stop` argument to "
                    "`tune.run()`, each element must be an instance of "
                    "`tune.stopper.Stopper`.")
            self._stopper = CombinedStopper(*stop)
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif issubclass(type(stop), Stopper):
                self._stopper = stop
            else:
                raise ValueError("Provided stop object must be either a dict, "
                                 "a function, or a subclass of "
                                 "`ray.tune.Stopper`.")
        else:
            raise ValueError("Invalid stop criteria: {}. Must be a "
                             "callable or dict".format(stop))

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(self._stopper,
                                                TimeoutStopper(time_budget_s))
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        _raise_on_durable(self.is_durable_trainable, sync_to_driver,
                          upload_dir)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": os.path.abspath(
                os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR)),
            "upload_dir": upload_dir,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "log_to_file": (stdout_file, stderr_file),
            "sync_to_driver": sync_to_driver,
            "sync_to_cloud": sync_to_cloud,
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "sync_on_checkpoint": sync_on_checkpoint,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore else None
        }
        self.spec = spec
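
This revision additionally accepts a list of `Stopper` instances and folds it into a `CombinedStopper`; a hedged usage sketch (the trainable is hypothetical; the stopper classes are the real `ray.tune.stopper` API of this era).

from ray import tune
from ray.tune.stopper import MaximumIterationStopper, TimeoutStopper

def my_trainable(config):  # hypothetical trainable
    for _ in range(1000):
        tune.report(score=0.0)

# Every element must be a Stopper instance, or Experiment raises ValueError.
tune.run(
    my_trainable,
    stop=[MaximumIterationStopper(max_iter=50), TimeoutStopper(timeout=300)],
)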
Example #5
    def __init__(
        self,
        name,
        run,
        stop=None,
        time_budget_s=None,
        config=None,
        resources_per_trial=None,
        num_samples=1,
        local_dir=None,
        _experiment_checkpoint_dir: Optional[str] = None,
        sync_config=None,
        trial_name_creator=None,
        trial_dirname_creator=None,
        log_to_file=False,
        checkpoint_freq=0,
        checkpoint_at_end=False,
        keep_checkpoints_num=None,
        checkpoint_score_attr=None,
        export_formats=None,
        max_failures=0,
        restore=None,
    ):

        local_dir = _get_local_dir_with_expand_user(local_dir)
        # `_experiment_checkpoint_dir` is for internal use only for better
        # support of Tuner API.
        # If set, it should be a subpath under `local_dir`. Also deduce `dir_name`.
        self._experiment_checkpoint_dir = _experiment_checkpoint_dir
        if _experiment_checkpoint_dir:
            experiment_checkpoint_dir_path = Path(_experiment_checkpoint_dir)
            local_dir_path = Path(local_dir)
            assert local_dir_path in experiment_checkpoint_dir_path.parents
            # `dir_name` is set by `_experiment_checkpoint_dir` indirectly.
            self.dir_name = os.path.relpath(_experiment_checkpoint_dir, local_dir)

        config = config or {}
        sync_config = sync_config or SyncConfig()
        if (
            callable(run)
            and not inspect.isclass(run)
            and detect_checkpoint_function(run)
        ):
            if checkpoint_at_end:
                raise ValueError(
                    "'checkpoint_at_end' cannot be used with a "
                    "checkpointable function. You can specify "
                    "and register checkpoints within "
                    "your trainable function."
                )
            if checkpoint_freq:
                raise ValueError(
                    "'checkpoint_freq' cannot be used with a "
                    "checkpointable function. You can specify checkpoints "
                    "within your trainable function."
                )
        try:
            self._run_identifier = Experiment.register_if_needed(run)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
                raise TuneError(
                    f"The Trainable/training function is too large for grpc resource "
                    f"limit. Check that its definition is not implicitly capturing a "
                    f"large array or other object in scope. "
                    f"Tip: use tune.with_parameters() to put large objects "
                    f"in the Ray object store. \n"
                    f"Original exception: {traceback.format_exc()}"
                )
            else:
                raise e

        self.name = name or self._run_identifier

        if not _experiment_checkpoint_dir:
            self.dir_name = _get_dir_name(run, name, self.name)

        assert self.dir_name

        if sync_config.upload_dir:
            self.remote_checkpoint_dir = os.path.join(
                sync_config.upload_dir, self.dir_name
            )
        else:
            self.remote_checkpoint_dir = None

        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, list):
            bad_stoppers = [s for s in stop if not isinstance(s, Stopper)]
            if bad_stoppers:
                stopper_types = [type(s) for s in stop]
                raise ValueError(
                    "If you pass a list as the `stop` argument to "
                    "`tune.run()`, each element must be an instance of "
                    f"`tune.stopper.Stopper`. Got {stopper_types}."
                )
            self._stopper = CombinedStopper(*stop)
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError(
                    "Provided stop object must be either a dict, "
                    "a function, or a subclass of "
                    f"`ray.tune.Stopper`. Got {type(stop)}."
                )
        else:
            raise ValueError(
                f"Invalid stop criteria: {stop}. Must be a "
                f"callable or dict. Got {type(stop)}."
            )

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(
                    self._stopper, TimeoutStopper(time_budget_s)
                )
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "time_budget_s": time_budget_s,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "local_dir": local_dir,
            "sync_config": sync_config,
            "remote_checkpoint_dir": self.remote_checkpoint_dir,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "log_to_file": (stdout_file, stderr_file),
            "checkpoint_freq": checkpoint_freq,
            "checkpoint_at_end": checkpoint_at_end,
            "keep_checkpoints_num": keep_checkpoints_num,
            "checkpoint_score_attr": checkpoint_score_attr,
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": os.path.abspath(os.path.expanduser(restore))
            if restore
            else None,
        }
        self.spec = spec
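
A minimal sketch of the `tune.with_parameters` workaround that the RESOURCE_EXHAUSTED message recommends: large objects travel through the Ray object store instead of being baked into the serialized trainable. The training function and array here are hypothetical; `tune.with_parameters` is the real API.

import numpy as np
from ray import tune

def train_fn(config, data=None):  # hypothetical trainable
    # `data` arrives via the object store, not the pickled function closure.
    tune.report(mean=float(data.mean()))

large_array = np.zeros((10_000, 1_000))  # too large to capture in a closure

tune.run(tune.with_parameters(train_fn, data=large_array),
         config={"lr": tune.grid_search([0.01, 0.1])})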
Example #6
    def _tune_run(self, config, resources_per_trial):
        """Wrapper to call ``tune.run``. Multiple estimators are generated when
        early stopping is possible, whereas a single estimator is
        generated when early stopping is not possible.

        Args:
            config (dict): Configurations such as hyperparameters to run
                ``tune.run`` on.
            resources_per_trial (dict): Resources to use per trial within Ray.
                Accepted keys are `cpu`, `gpu` and custom resources, and values
                are integers specifying the number of each resource to use.

        Returns:
            analysis (`ExperimentAnalysis`): Object returned by
                `tune.run`.

        """
        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)

        trainable = _Trainable
        if self.pipeline_auto_early_stop and check_is_pipeline(
                self.estimator) and self.early_stopping:
            trainable = _PipelineTrainable

        max_iter = self.max_iters
        if self.early_stopping is not None:
            config["estimator_list"] = [
                clone(self.estimator) for _ in range(self.n_splits)
            ]
            if hasattr(self.early_stopping, "_max_t_attr"):
                # Delegate stopping to schedulers that support it, but keep a
                # (very large) iteration cap as a fallback so the run still
                # terminates eventually.
                max_iter = self.max_iters * 10
        else:
            config["estimator_list"] = [self.estimator]

        stopper = MaximumIterationStopper(max_iter=max_iter)
        if self.stopper:
            stopper = CombinedStopper(stopper, self.stopper)

        run_args = dict(scheduler=self.early_stopping,
                        reuse_actors=True,
                        verbose=self.verbose,
                        stop=stopper,
                        num_samples=self.n_trials,
                        config=config,
                        fail_fast="raise",
                        resources_per_trial=resources_per_trial,
                        local_dir=os.path.expanduser(self.local_dir),
                        loggers=self.loggers,
                        time_budget_s=self.time_budget_s)

        if self.search_optimization == "random":
            if isinstance(self.param_distributions, list):
                search_algo = RandomListSearcher(self.param_distributions)
            else:
                search_algo = BasicVariantGenerator()
            run_args["search_alg"] = search_algo
        else:
            search_space = None
            override_search_space = True
            if self._is_param_distributions_all_tune_domains():
                run_args["config"].update(self.param_distributions)
                override_search_space = False

            search_kwargs = self.search_kwargs.copy()
            search_kwargs.update(metric=self._metric_name, mode="max")

            if self.search_optimization == "bayesian":
                from ray.tune.suggest.skopt import SkOptSearch
                if override_search_space:
                    search_space = self.param_distributions
                search_algo = SkOptSearch(space=search_space, **search_kwargs)
                run_args["search_alg"] = search_algo

            elif self.search_optimization == "bohb":
                from ray.tune.suggest.bohb import TuneBOHB
                if override_search_space:
                    search_space = self._get_bohb_config_space()
                if self.seed:
                    warnings.warn("'seed' is not implemented for BOHB.")
                search_algo = TuneBOHB(space=search_space, **search_kwargs)
                # search_algo = TuneBOHB(
                #     space=search_space, seed=self.seed, **search_kwargs)
                run_args["search_alg"] = search_algo

            elif self.search_optimization == "optuna":
                from ray.tune.suggest.optuna import OptunaSearch
                from optuna.samplers import TPESampler
                sampler = TPESampler(seed=self.seed)
                if override_search_space:
                    search_space = self._get_optuna_params()
                search_algo = OptunaSearch(space=search_space,
                                           sampler=sampler,
                                           **search_kwargs)
                run_args["search_alg"] = search_algo

            elif self.search_optimization == "hyperopt":
                from ray.tune.suggest.hyperopt import HyperOptSearch
                if override_search_space:
                    search_space = self._get_hyperopt_params()
                search_algo = HyperOptSearch(space=search_space,
                                             random_state_seed=self.seed,
                                             **search_kwargs)
                run_args["search_alg"] = search_algo

            else:
                # This should not happen as we validate the input before
                # this method. Still, just to be sure, raise an error here.
                raise ValueError(
                    f"Invalid search optimizer: {self.search_optimization}")

        if isinstance(self.n_jobs, int) and self.n_jobs > 0 \
           and not self.search_optimization == "random":
            search_algo = ConcurrencyLimiter(search_algo,
                                             max_concurrent=self.n_jobs)
            run_args["search_alg"] = search_algo

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore",
                                    message="fail_fast='raise' "
                                    "detected.")
            analysis = tune.run(trainable, **run_args)
        return analysis
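
The final `n_jobs` branch wraps whichever searcher was selected in a `ConcurrencyLimiter`; a hedged standalone sketch of the same wrapping, using the pre-2.0 `ray.tune.suggest` layout this example already imports from (the metric name is illustrative).

from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.hyperopt import HyperOptSearch

# Cap parallel trials at 4, mirroring n_jobs=4 in the method above.
search_algo = ConcurrencyLimiter(
    HyperOptSearch(metric="average_test_score", mode="max"),
    max_concurrent=4)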