Example 1
    def test_attach_trial_ttl_seconds(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(
            parameters={"x": 0.0, "y": 1.0}, ttl_seconds=1
        )
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
        time.sleep(1)  # Wait for TTL to elapse.
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        # Also make sure we can no longer complete the trial as it is failed.
        with self.assertRaisesRegex(
            ValueError, ".* has been marked FAILED, so it no longer expects data."
        ):
            ax_client.complete_trial(trial_index=idx, raw_data=5)

        params2, idx2 = ax_client.attach_trial(
            parameters={"x": 0.0, "y": 1.0}, ttl_seconds=1
        )
        ax_client.complete_trial(trial_index=idx2, raw_data=5)
        self.assertEqual(ax_client.get_best_parameters()[0], params2)
        self.assertEqual(
            ax_client.get_trial_parameters(trial_index=idx2), {"x": 0, "y": 1}
        )
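
A minimal standalone sketch of the same TTL behavior outside a test harness (the single-parameter experiment and the 2-second TTL are illustrative assumptions, not from the source):

import time

from ax.service.ax_client import AxClient

ax_client = AxClient()
ax_client.create_experiment(
    parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}],
    minimize=True,
)
# Attach a manually parameterized trial that expires if no data arrives
# within 2 seconds (illustrative TTL).
params, idx = ax_client.attach_trial(parameters={"x": 0.0}, ttl_seconds=2)
time.sleep(3)
# After the TTL elapses the trial reports as failed and, as the test
# above asserts, no longer accepts data.
print(ax_client.experiment.trials[idx].status)
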
Example 2
    def test_attach_trial_and_get_trial_parameters(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={"x": 0.0, "y": 1.0})
        ax_client.complete_trial(trial_index=idx, raw_data=5)
        self.assertEqual(ax_client.get_best_parameters()[0], params)
        self.assertEqual(
            ax_client.get_trial_parameters(trial_index=idx), {"x": 0, "y": 1}
        )
        with self.assertRaises(ValueError):
            # No trial #10 in experiment.
            ax_client.get_trial_parameters(trial_index=10)
        with self.assertRaisesRegex(ValueError, ".* is of type"):
            ax_client.attach_trial({"x": 1, "y": 2})
Example 3
    def test_find_last_trial_with_parameterization(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
            objective_name="a",
        )
        params, trial_idx = ax_client.get_next_trial()
        found_trial_idx = ax_client._find_last_trial_with_parameterization(
            parameterization=params
        )
        self.assertEqual(found_trial_idx, trial_idx)
        # Check that it's indeed the _last_ trial with params that is found.
        _, new_trial_idx = ax_client.attach_trial(parameters=params)
        found_trial_idx = ax_client._find_last_trial_with_parameterization(
            parameterization=params
        )
        self.assertEqual(found_trial_idx, new_trial_idx)
        with self.assertRaisesRegex(ValueError, "No .* matches"):
            found_trial_idx = ax_client._find_last_trial_with_parameterization(
                parameterization={k: v + 1.0 for k, v in params.items()}
            )
Example 4
    def test_attach_trial_numpy(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={"x1": 0, "x2": 1})
        ax_client.complete_trial(trial_index=idx, raw_data=np.int32(5))
        self.assertEqual(ax_client.get_best_parameters()[0], params)
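
Besides bare numbers (including numpy scalars, as above), `complete_trial` also accepts a (mean, SEM) tuple or a per-metric dictionary, as in Example 7 below. A short sketch, assuming `idx`, `idx2`, and `idx3` index three separate running trials and the default objective name "objective":

ax_client.complete_trial(trial_index=idx, raw_data=5)  # bare value
ax_client.complete_trial(trial_index=idx2, raw_data=(5.0, 0.1))  # (mean, SEM)
ax_client.complete_trial(trial_index=idx3, raw_data={"objective": (5.0, 0.1)})
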
Example 5
class AxSearch(Searcher):
    """Uses `Ax <https://ax.dev/>`_ to optimize hyperparameters.

    Ax is a platform for understanding, managing, deploying, and
    automating adaptive experiments. Ax provides an easy-to-use
    interface with BoTorch, a flexible, modern library for Bayesian
    optimization in PyTorch. More information can be found at https://ax.dev/.

    To use this search algorithm, you must install Ax and sqlalchemy:

    .. code-block:: bash

        $ pip install ax-platform sqlalchemy

    Parameters:
        space (list[dict]): Parameters in the experiment search space.
            Required elements in the dictionaries are: "name" (name of
            this parameter, string), "type" (type of the parameter: "range",
            "fixed", or "choice", string), "bounds" for range parameters
            (list of two values, lower bound first), "values" for choice
            parameters (list of values), and "value" for fixed parameters
            (single value).
        metric (str): Name of the metric used as objective in this
            experiment. This metric must be present in `raw_data` argument
            to `log_data`. This metric must also be present in the dict
            reported/returned by the Trainable. If None but a mode was passed,
            `ray.tune.result.DEFAULT_METRIC` is used by default.
        mode (str): One of {min, max}. Determines whether the objective is
            minimized or maximized. Defaults to "max".
        points_to_evaluate (list): Initial parameter suggestions to be run
            first. This is for when you already have some good parameters
            you want to run first to help the algorithm make better suggestions
            for future parameters. Needs to be a list of dicts containing the
            configurations.
        parameter_constraints (list[str]): Parameter constraints, such as
            "x3 >= x4" or "x3 + x4 >= 2".
        outcome_constraints (list[str]): Outcome constraints of the form
            "metric_name >= bound", like "m1 <= 3".
        ax_client (AxClient): Optional AxClient instance. If this is set, do
            not pass any values to these parameters: `space`, `metric`,
            `parameter_constraints`, `outcome_constraints`.
        use_early_stopped_trials: Deprecated.
        max_concurrent (int): Deprecated.

    Tune automatically converts search spaces to Ax's format:

    .. code-block:: python

        from ray import tune
        from ray.tune.suggest.ax import AxSearch

        config = {
            "x1": tune.uniform(0.0, 1.0),
            "x2": tune.uniform(0.0, 1.0)
        }

        def easy_objective(config):
            for i in range(100):
                intermediate_result = config["x1"] + config["x2"] * i
                tune.report(score=intermediate_result)

        ax_search = AxSearch(metric="score")
        tune.run(
            easy_objective,
            config=config,
            search_alg=ax_search)

    If you would like to pass the search space manually, the code would
    look like this:

    .. code-block:: python

        from ray import tune
        from ray.tune.suggest.ax import AxSearch

        parameters = [
            {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
        ]

        def easy_objective(config):
            for i in range(100):
                intermediate_result = config["x1"] + config["x2"] * i
                tune.report(score=intermediate_result)

        ax_search = AxSearch(space=parameters, metric="score")
        tune.run(easy_objective, search_alg=ax_search)
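
    Parameter constraints can be passed through in the same way (a minimal
    sketch; the constraint string is illustrative, not from the source):

    .. code-block:: python

        ax_search = AxSearch(
            space=parameters,
            metric="score",
            parameter_constraints=["x1 + x2 <= 1.0"])
        tune.run(easy_objective, search_alg=ax_search)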

    """
    def __init__(self,
                 space: Optional[Union[Dict, List[Dict]]] = None,
                 metric: Optional[str] = None,
                 mode: Optional[str] = None,
                 points_to_evaluate: Optional[List[Dict]] = None,
                 parameter_constraints: Optional[List] = None,
                 outcome_constraints: Optional[List] = None,
                 ax_client: Optional[AxClient] = None,
                 use_early_stopped_trials: Optional[bool] = None,
                 max_concurrent: Optional[int] = None):
        assert ax is not None, """Ax must be installed!
            You can install AxSearch with the command:
            `pip install ax-platform sqlalchemy`."""
        if mode:
            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."

        super(AxSearch,
              self).__init__(metric=metric,
                             mode=mode,
                             max_concurrent=max_concurrent,
                             use_early_stopped_trials=use_early_stopped_trials)

        self._ax = ax_client

        if isinstance(space, dict) and space:
            resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
            if domain_vars or grid_vars:
                logger.warning(
                    UNRESOLVED_SEARCH_SPACE.format(par="space",
                                                   cls=type(self)))
                space = self.convert_search_space(space)

        self._space = space
        self._parameter_constraints = parameter_constraints
        self._outcome_constraints = outcome_constraints

        self._points_to_evaluate = copy.deepcopy(points_to_evaluate)

        self.max_concurrent = max_concurrent

        self._objective_name = metric
        self._parameters = []
        self._live_trial_mapping = {}

        if self._ax or self._space:
            self._setup_experiment()

    def _setup_experiment(self):
        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC

        if not self._ax:
            self._ax = AxClient()

        try:
            exp = self._ax.experiment
            has_experiment = True
        except ValueError:
            has_experiment = False

        if not has_experiment:
            if not self._space:
                raise ValueError(
                    "You have to create an Ax experiment by calling "
                    "`AxClient.create_experiment()`, or you should pass an "
                    "Ax search space as the `space` parameter to `AxSearch`, "
                    "or pass a `config` dict to `tune.run()`.")
            self._ax.create_experiment(
                parameters=self._space,
                objective_name=self._metric,
                parameter_constraints=self._parameter_constraints,
                outcome_constraints=self._outcome_constraints,
                minimize=self._mode != "max")
        else:
            if any([
                    self._space, self._parameter_constraints,
                    self._outcome_constraints
            ]):
                raise ValueError(
                    "If you create the Ax experiment yourself, do not pass "
                    "values for these parameters to `AxSearch`: {}.".format([
                        "space", "parameter_constraints", "outcome_constraints"
                    ]))

        exp = self._ax.experiment
        self._objective_name = exp.optimization_config.objective.metric.name
        self._parameters = list(exp.parameters)

        if self._ax._enforce_sequential_optimization:
            logger.warning("Detected sequential enforcement. Be sure to use "
                           "a ConcurrencyLimiter.")

    def set_search_properties(self, metric: Optional[str], mode: Optional[str],
                              config: Dict):
        if self._ax:
            return False
        space = self.convert_search_space(config)
        self._space = space
        if metric:
            self._metric = metric
        if mode:
            self._mode = mode

        self._setup_experiment()
        return True

    def suggest(self, trial_id: str) -> Optional[Dict]:
        if not self._ax:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(cls=self.__class__.__name__,
                                              space="space"))

        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__,
                                             metric=self._metric,
                                             mode=self._mode))

        if self.max_concurrent:
            if len(self._live_trial_mapping) >= self.max_concurrent:
                return None

        if self._points_to_evaluate:
            config = self._points_to_evaluate.pop(0)
            parameters, trial_index = self._ax.attach_trial(config)
        else:
            parameters, trial_index = self._ax.get_next_trial()

        self._live_trial_mapping[trial_id] = trial_index
        return unflatten_dict(parameters)

    def on_trial_complete(self, trial_id, result=None, error=False):
        """Notification for the completion of trial.

        Data of form key value dictionary of metric names and values.
        """
        if result:
            self._process_result(trial_id, result)
        self._live_trial_mapping.pop(trial_id)

    def _process_result(self, trial_id, result):
        ax_trial_index = self._live_trial_mapping[trial_id]
        metric_dict = {
            self._objective_name: (result[self._objective_name], 0.0)
        }
        outcome_names = [
            oc.metric.name for oc in
            self._ax.experiment.optimization_config.outcome_constraints
        ]
        metric_dict.update({on: (result[on], 0.0) for on in outcome_names})
        self._ax.complete_trial(trial_index=ax_trial_index,
                                raw_data=metric_dict)

    @staticmethod
    def convert_search_space(spec: Dict):
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to an Ax search space.")

        def resolve_value(par, domain):
            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                logger.warning("AxSearch does not support quantization. "
                               "Dropped quantization.")
                sampler = sampler.sampler

            if isinstance(domain, Float):
                if isinstance(sampler, LogUniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "float",
                        "log_scale": True
                    }
                elif isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "float",
                        "log_scale": False
                    }
            elif isinstance(domain, Integer):
                if isinstance(sampler, LogUniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "int",
                        "log_scale": True
                    }
                elif isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "int",
                        "log_scale": False
                    }
            elif isinstance(domain, Categorical):
                if isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "choice",
                        "values": domain.categories
                    }

            raise ValueError("AxSearch does not support parameters of type "
                             "`{}` with samplers of type `{}`".format(
                                 type(domain).__name__,
                                 type(domain.sampler).__name__))

        # Fixed vars
        fixed_values = [{
            "name": "/".join(path),
            "type": "fixed",
            "value": val
        } for path, val in resolved_vars]

        # Parameter name is e.g. "a/b/c" for nested dicts
        resolved_values = [
            resolve_value("/".join(path), domain)
            for path, domain in domain_vars
        ]

        return fixed_values + resolved_values
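
For illustration, roughly what `convert_search_space` emits for a small Tune config (the parameter names and the expected output are a sketch, not from the source):

from ray import tune

config = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "layers": tune.choice([1, 2, 3]),
}
print(AxSearch.convert_search_space(config))
# Expected, roughly:
# [{"name": "lr", "type": "range", "bounds": [0.0001, 0.1],
#   "value_type": "float", "log_scale": True},
#  {"name": "layers", "type": "choice", "values": [1, 2, 3]}]
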
Example 6
class AxSearchJob(AutoSearchJob):
    """Job for hyperparameter search using [ax](https://ax.dev/)."""
    def __init__(self, config: Config, dataset, parent_job=None):
        super().__init__(config, dataset, parent_job)
        self.num_trials = self.config.get("ax_search.num_trials")
        self.num_sobol_trials = self.config.get("ax_search.num_sobol_trials")
        self.ax_client: AxClient = None

        if self.__class__ == AxSearchJob:
            for f in Job.job_created_hooks:
                f(self)

    # Overridden such that instances of search job can be pickled to workers
    def __getstate__(self):
        state = super(AxSearchJob, self).__getstate__()
        del state["ax_client"]
        return state

    def _prepare(self):
        super()._prepare()
        if self.num_sobol_trials > 0:
            # BEGIN: from /ax/service/utils/dispatch.py
            generation_strategy = GenerationStrategy(
                name="Sobol+GPEI",
                steps=[
                    GenerationStep(
                        model=Models.SOBOL,
                        num_trials=self.num_sobol_trials,
                        min_trials_observed=ceil(self.num_sobol_trials / 2),
                        enforce_num_trials=True,
                        model_kwargs={
                            "seed": self.config.get("ax_search.sobol_seed")
                        },
                    ),
                    GenerationStep(model=Models.GPEI,
                                   num_trials=-1,
                                   max_parallelism=3),
                ],
            )
            # END: from /ax/service/utils/dispatch.py

            self.ax_client = AxClient(generation_strategy=generation_strategy)
            choose_generation_strategy_kwargs = dict()
        else:
            self.ax_client = AxClient()
            # Set the random seed used by the Sobol search that Ax creates
            # automatically; note that the argument here is called
            # "random_seed", not "seed".
            choose_generation_strategy_kwargs = {
                "random_seed": self.config.get("ax_search.sobol_seed")
            }
        self.ax_client.create_experiment(
            name=self.job_id,
            parameters=self.config.get("ax_search.parameters"),
            objective_name="metric_value",
            minimize=not self.config.get("valid.metric_max"),
            parameter_constraints=self.config.get(
                "ax_search.parameter_constraints"),
            choose_generation_strategy_kwargs=choose_generation_strategy_kwargs,
        )
        self.config.log("ax search initialized with {}".format(
            self.ax_client.generation_strategy))

        # Make sure Sobol models are resumed correctly.
        if self.ax_client.generation_strategy._curr.model == Models.SOBOL:
            self.ax_client.generation_strategy._set_current_model(
                experiment=self.ax_client.experiment, data=None)

            # Regenerate and drop SOBOL arms already generated. Since we fixed the seed,
            # we will skip exactly the arms already generated in the job being resumed.
            num_generated = len(self.parameters)
            if num_generated > 0:
                num_sobol_generated = min(
                    self.ax_client.generation_strategy._curr.num_trials,
                    num_generated)
                for _ in range(num_sobol_generated):
                    generator_run = self.ax_client.generation_strategy.gen(
                        experiment=self.ax_client.experiment)
                    # self.config.log("Skipped parameters: {}".format(generator_run.arms))
                self.config.log(
                    "Skipped {} of {} Sobol trials due to prior data.".format(
                        num_sobol_generated,
                        self.ax_client.generation_strategy._curr.num_trials,
                    ))

    def register_trial(self, parameters=None):
        trial_id = None
        try:
            if parameters is None:
                parameters, trial_id = self.ax_client.get_next_trial()
            else:
                _, trial_id = self.ax_client.attach_trial(parameters)
        except Exception as e:
            self.config.log(
                "Cannot generate trial parameters. Will try again after a "
                "running trial has completed. Message was: {}".format(e))
        return parameters, trial_id

    def register_trial_result(self, trial_id, parameters, trace_entry):
        if trace_entry is None:
            self.ax_client.log_trial_failure(trial_index=trial_id)
        else:
            self.ax_client.complete_trial(trial_index=trial_id,
                                          raw_data=trace_entry["metric_value"])

    def get_best_parameters(self):
        best_parameters, values = self.ax_client.get_best_parameters()
        return best_parameters, float(values[0]["metric_value"])
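
The resume logic in `_prepare` relies on Sobol generation being deterministic for a fixed seed: regenerating N arms reproduces exactly the N arms of the interrupted job. A standalone sketch of that property (the search space and seed are illustrative assumptions):

from ax import ParameterType, RangeParameter, SearchSpace
from ax.modelbridge.registry import Models

space = SearchSpace(parameters=[
    RangeParameter(name="x", parameter_type=ParameterType.FLOAT,
                   lower=0.0, upper=1.0),
])

def arms(n, seed):
    sobol = Models.SOBOL(search_space=space, seed=seed)
    return [sobol.gen(1).arms[0].parameters for _ in range(n)]

# Two "runs" with the same seed yield identical arms, which is why
# regenerating and dropping arms skips exactly the prior trials.
assert arms(3, seed=13) == arms(3, seed=13)
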
Example 7
class ParamSpec:
  '''Parameter spec to generate trials.

  This class uses the Adaptive Experimentation Platform (ax.dev)
  under the hood to run BayesOpt for parameter tuning. It relies
  on an SQLite-based storage backend for trial generation and result
  reporting. Each trial is assigned an ID for later lookup.
  '''
  def __init__(self, name: str, db_path: str,
               parameters: Optional[List[TParameterization]] = None,
               objective_name: Optional[str] = None):
    if not os.path.isfile(db_path):
      init_engine_and_session_factory(url=f'sqlite:///{db_path}')
      create_all_tables(get_engine())

    self.name = name
    self.ax = AxClient(enforce_sequential_optimization=False,
                       verbose_logging=False,
                       db_settings=DBSettings(url=f'sqlite:///{db_path}'))

    if self.ax._experiment is None:
      try:
        self.ax.create_experiment(name=name, parameters=parameters,
                                  objective_name=objective_name)
      except ValueError:
        self.ax.load_experiment_from_database(name)

  @property
  def experiment(self):
    return self.ax.experiment

  def generate_trials(self, n: int = 1) \
    -> Generator[Tuple[TParameterization, int], None, None]:
    for _ in range(n):
      yield self.ax.get_next_trial()

  def manual_trials(self, param_lists: Dict[str, List]) \
    -> Generator[Tuple[TParameterization, int], None, None]:
    for trial in exhaust_params(param_lists):
      yield self.ax.attach_trial(trial)

  def complete_trials(self, results: List[dict]):
    '''Complete trials, asynchronous update.

    The input should have the following format:
      [
        {
          "id": <int>,
          "metrics": {
            <str>: <float>
          }
        }
      ]
    '''
    for r in results:
      idx = r.get('id')
      raw_data = {k: (v, 0.0) for k, v in r.get('metrics').items()}
      self.ax.complete_trial(trial_index=idx, raw_data=raw_data)

  def get_trials(self) \
    -> Generator[Tuple[TParameterization, int], None, None]:
    # NOTE(sanyam): Assumes regular Trial with single arm.
    for idx, trial in self.ax.experiment.trials.items():
      for _, arm in trial.arms_by_name.items():
        yield arm.parameters, idx
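
A minimal usage sketch for ParamSpec (the database path, parameter, and metric name are made-up values for this example):

spec = ParamSpec(
    name='demo',
    db_path='/tmp/ax_demo.db',
    parameters=[
        {'name': 'lr', 'type': 'range', 'bounds': [1e-4, 1e-1],
         'log_scale': True},
    ],
    objective_name='accuracy',
)
for params, idx in spec.generate_trials(n=2):
    acc = 1.0 - params['lr']  # stand-in for a real evaluation
    spec.complete_trials([{'id': idx, 'metrics': {'accuracy': acc}}])
for params, idx in spec.get_trials():
    print(idx, params)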