Example no. 1
0
 def test_log_failure(self):
     """A trial logged as failed is marked failed, keeps its metadata, and
     cannot subsequently be completed."""
     client = AxClient()
     search_space = [
         {
             "name": "x",
             "type": "range",
             "bounds": [-5.0, 10.0]
         },
         {
             "name": "y",
             "type": "range",
             "bounds": [0.0, 15.0]
         },
     ]
     client.create_experiment(parameters=search_space, minimize=True)
     _, trial_index = client.get_next_trial()
     client.log_trial_failure(trial_index, metadata={"dummy": "test"})
     failed_trial = client.experiment.trials.get(trial_index)
     self.assertTrue(failed_trial.status.is_failed)
     self.assertEqual(failed_trial.run_metadata.get("dummy"), "test")
     # Completing an already-failed trial must be rejected.
     with self.assertRaisesRegex(ValueError, ".* no longer expects"):
         client.complete_trial(trial_index, {})
Example no. 2
0
 def test_log_failure(self):
     """Logging a trial failure sets the failed status and records the
     metadata passed along with it."""
     client = AxClient()
     client.create_experiment(
         parameters=[
             dict(name="x1", type="range", bounds=[-5.0, 10.0]),
             dict(name="x2", type="range", bounds=[0.0, 15.0]),
         ],
         minimize=True,
     )
     _, trial_index = client.get_next_trial()
     client.log_trial_failure(trial_index, metadata={"dummy": "test"})
     failed_trial = client.experiment.trials.get(trial_index)
     self.assertTrue(failed_trial.status.is_failed)
     self.assertEqual(failed_trial.run_metadata.get("dummy"), "test")
Example no. 3
0
class AxSearchJob(AutoSearchJob):
    """Job for hyperparameter search using [ax](https://ax.dev/)."""
    def __init__(self, config: Config, dataset, parent_job=None):
        super().__init__(config, dataset, parent_job)
        # Total number of trials and how many of them use the Sobol
        # (quasi-random) generator; both read from the job configuration.
        self.num_trials = self.config.get("ax_search.num_trials")
        self.num_sobol_trials = self.config.get("ax_search.num_sobol_trials")
        # Created in _prepare(); dropped from pickled state (see __getstate__).
        self.ax_client: AxClient = None

        # Run creation hooks only for direct instances, not subclasses.
        if self.__class__ == AxSearchJob:
            for f in Job.job_created_hooks:
                f(self)

    # Overridden such that instances of search job can be pickled to workers
    def __getstate__(self):
        """Return the job's state without the ax client, which is removed
        so that the remaining state can be pickled."""
        state = super(AxSearchJob, self).__getstate__()
        del state["ax_client"]
        return state

    def _prepare(self):
        """Create the AxClient and its experiment, and fast-forward the Sobol
        generator past arms already produced by a job being resumed."""
        super()._prepare()
        if self.num_sobol_trials > 0:
            # Explicit strategy: a fixed-length Sobol phase followed by GP/EI.
            # BEGIN: from /ax/service/utils/dispatch.py
            generation_strategy = GenerationStrategy(
                name="Sobol+GPEI",
                steps=[
                    GenerationStep(
                        model=Models.SOBOL,
                        num_trials=self.num_sobol_trials,
                        min_trials_observed=ceil(self.num_sobol_trials / 2),
                        enforce_num_trials=True,
                        # Fixing the seed makes the Sobol sequence
                        # reproducible; the resume logic below relies on this.
                        model_kwargs={
                            "seed": self.config.get("ax_search.sobol_seed")
                        },
                    ),
                    GenerationStep(model=Models.GPEI,
                                   num_trials=-1,
                                   max_parallelism=3),
                ],
            )
            # END: from /ax/service/utils/dispatch.py

            self.ax_client = AxClient(generation_strategy=generation_strategy)
            choose_generation_strategy_kwargs = dict()
        else:
            # Let ax choose the generation strategy itself.
            self.ax_client = AxClient()
            # set random_seed that will be used by auto created sobol search from ax
            # note that here the argument is called "random_seed" not "seed"
            choose_generation_strategy_kwargs = {
                "random_seed": self.config.get("ax_search.sobol_seed")
            }
        self.ax_client.create_experiment(
            name=self.job_id,
            parameters=self.config.get("ax_search.parameters"),
            objective_name="metric_value",
            # Maximize when the validation metric is a "higher is better" one.
            minimize=not self.config.get("valid.metric_max"),
            parameter_constraints=self.config.get(
                "ax_search.parameter_constraints"),
            choose_generation_strategy_kwargs=choose_generation_strategy_kwargs,
        )
        self.config.log("ax search initialized with {}".format(
            self.ax_client.generation_strategy))

        # Make sure sobol models are resumed correctly
        # NOTE(review): the following touches ax private API
        # (generation_strategy._curr, _set_current_model); it may break with
        # newer ax versions — confirm against the pinned ax release.
        if self.ax_client.generation_strategy._curr.model == Models.SOBOL:

            self.ax_client.generation_strategy._set_current_model(
                experiment=self.ax_client.experiment, data=None)

            # Regenerate and drop SOBOL arms already generated. Since we fixed the seed,
            # we will skip exactly the arms already generated in the job being resumed.
            # NOTE(review): self.parameters presumably holds the trials already
            # registered by the resumed job (populated by the parent class) —
            # confirm against AutoSearchJob.
            num_generated = len(self.parameters)
            if num_generated > 0:
                num_sobol_generated = min(
                    self.ax_client.generation_strategy._curr.num_trials,
                    num_generated)
                for i in range(num_sobol_generated):
                    # Each gen() call advances the seeded Sobol sequence by
                    # one arm; the generated arm itself is discarded.
                    generator_run = self.ax_client.generation_strategy.gen(
                        experiment=self.ax_client.experiment)
                    # self.config.log("Skipped parameters: {}".format(generator_run.arms))
                self.config.log(
                    "Skipped {} of {} Sobol trials due to prior data.".format(
                        num_sobol_generated,
                        self.ax_client.generation_strategy._curr.num_trials,
                    ))

    def register_trial(self, parameters=None):
        """Register a trial with ax: generate new parameters when none are
        given, otherwise attach the provided ones.

        Returns a (parameters, trial_id) pair; trial_id is None (and
        parameters may be None) when ax could not produce a trial yet.
        """
        trial_id = None
        try:
            if parameters is None:
                parameters, trial_id = self.ax_client.get_next_trial()
            else:
                _, trial_id = self.ax_client.attach_trial(parameters)
        except Exception as e:
            # Best effort: e.g. ax may refuse to generate while it waits for
            # observations from running trials; the caller retries later.
            self.config.log(
                "Cannot generate trial parameters. Will try again after a " +
                "running trial has completed. message was: {}".format(e))
        return parameters, trial_id

    def register_trial_result(self, trial_id, parameters, trace_entry):
        """Report a trial's outcome to ax; a missing trace entry counts as a
        failed trial."""
        if trace_entry is None:
            self.ax_client.log_trial_failure(trial_index=trial_id)
        else:
            self.ax_client.complete_trial(trial_index=trial_id,
                                          raw_data=trace_entry["metric_value"])

    def get_best_parameters(self):
        """Return the best parameters found and their metric value.

        NOTE(review): values[0] is taken as the dict of predicted metric
        means (per ax's get_best_parameters contract) — confirm for the
        pinned ax version.
        """
        best_parameters, values = self.ax_client.get_best_parameters()
        return best_parameters, float(values[0]["metric_value"])
Example no. 4
0
    # NOTE(review): this is the tail of a trial-evaluation function whose
    # header lies outside this chunk; `opts`, `parameters`, `trial_index`,
    # `ax`, `log_record`, `cmd_line_opts`, and `log` are defined above.
    # Copy the trial's hyperparameters onto the training options.
    opts.max_conv_size = parameters['max_conv_size']
    opts.dense_kernel_size = parameters['dense_kernel_size']
    opts.batch_size = 64  # parameters['batch_size']
    opts.learning_rate = parameters['learning_rate']
    opts.epochs = cmd_line_opts.epochs  # max to run, we also use early stopping

    # run
    start_time = time.time()
    # final_loss = train.train_in_subprocess(opts)
    final_loss = train.train(opts)
    log_record.append(time.time() - start_time)
    log_record.append(final_loss)

    # complete trial
    # A None loss is treated as a failed trial; otherwise report the loss
    # to ax as (mean, sem) with sem fixed at 0.
    if final_loss is None:
        print("ax trial", trial_index, "failed?")
        ax.log_trial_failure(trial_index=trial_index)
    else:
        ax.complete_trial(trial_index=trial_index,
                          raw_data={'final_loss': (final_loss, 0)})
    print("CURRENT_BEST", ax.get_best_parameters())

    # flush log
    log_msg = "\t".join(map(str, log_record))
    print(log_msg, file=log)
    print(log_msg)
    log.flush()

    # save ax state
    ax.save_to_json_file()