Example #1
 def test_fail_on_batch(self):
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     batch_trial = ax_client.experiment.new_batch_trial(
         generator_run=GeneratorRun(
             arms=[
                 Arm(parameters={"x": 0, "y": 1}),
                 Arm(parameters={"x": 0, "y": 1}),
             ]
         )
     )
     with self.assertRaises(NotImplementedError):
         ax_client.complete_trial(batch_trial.index, 0)
Example #2
 def test_raw_data_format_with_fidelities(self):
     ax = AxClient()
     ax.create_experiment(
         parameters=[
             {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
         ],
         minimize=True,
     )
     for _ in range(6):
         parameterization, trial_index = ax.get_next_trial()
         x1, x2 = parameterization.get("x1"), parameterization.get("x2")
         ax.complete_trial(
             trial_index,
             raw_data=[
                 ({"x2": x2 / 2.0}, {"objective": (branin(x1, x2 / 2.0), 0.0)}),
                 ({"x2": x2}, {"objective": (branin(x1, x2), 0.0)}),
             ],
         )
Example #3
    def sweep_over_batches(
        self,
        ax_client: AxClient,
        batch_of_trials: BatchOfTrialType,
    ) -> None:
        assert self.launcher is not None
        assert self.job_idx is not None

        chunked_batches = self.chunks(batch_of_trials, self.max_batch_size)
        for batch in chunked_batches:
            overrides = [x.overrides for x in batch]
            self.validate_batch_is_legal(overrides)
            rets = self.launcher.launch(job_overrides=overrides,
                                        initial_job_idx=self.job_idx)
            self.job_idx += len(rets)
            for idx in range(len(batch)):
                val = rets[idx].return_value
                ax_client.complete_trial(trial_index=batch[idx].trial_index,
                                         raw_data=val)
Example #4
 def test_sqa_storage(self):
     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)
     ax_client = AxClient(db_settings=db_settings)
     ax_client.create_experiment(
         name="test_experiment",
         parameters=[
             {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     for _ in range(5):
         parameters, trial_index = ax_client.get_next_trial()
         ax_client.complete_trial(
             trial_index=trial_index, raw_data=branin(*parameters.values())
         )
     gs = ax_client.generation_strategy
     ax_client = AxClient(db_settings=db_settings)
     ax_client.load_experiment_from_database("test_experiment")
     self.assertEqual(gs, ax_client.generation_strategy)
     with self.assertRaises(ValueError):
         # Overwriting existing experiment.
         ax_client.create_experiment(
             name="test_experiment",
             parameters=[
                 {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                 {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
             ],
             minimize=True,
         )
     # Overwriting existing experiment with overwrite flag.
     ax_client.create_experiment(
         name="test_experiment",
         parameters=[{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]}],
         overwrite_existing_experiment=True,
     )
     # There should be no trials, as we just put in a fresh experiment.
     self.assertEqual(len(ax_client.experiment.trials), 0)
Example #5
 def test_fixed_random_seed_reproducibility(self):
     ax_client = AxClient(random_seed=239)
     ax_client.create_experiment(parameters=[
         {
             "name": "x",
             "type": "range",
             "bounds": [-5.0, 10.0]
         },
         {
             "name": "y",
             "type": "range",
             "bounds": [0.0, 15.0]
         },
     ])
     for _ in range(5):
         params, idx = ax_client.get_next_trial()
         ax_client.complete_trial(idx,
                                  branin(params.get("x"), params.get("y")))
     trial_parameters_1 = [
         t.arm.parameters for t in ax_client.experiment.trials.values()
     ]
     ax_client = AxClient(random_seed=239)
     ax_client.create_experiment(parameters=[
         {
             "name": "x",
             "type": "range",
             "bounds": [-5.0, 10.0]
         },
         {
             "name": "y",
             "type": "range",
             "bounds": [0.0, 15.0]
         },
     ])
     for _ in range(5):
         params, idx = ax_client.get_next_trial()
         ax_client.complete_trial(idx,
                                  branin(params.get("x"), params.get("y")))
     trial_parameters_2 = [
         t.arm.parameters for t in ax_client.experiment.trials.values()
     ]
     self.assertEqual(trial_parameters_1, trial_parameters_2)
Example #6
    def test_attach_trial_ttl_seconds(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )
        params, idx = ax_client.attach_trial(parameters={
            "x": 0.0,
            "y": 1.0
        },
                                             ttl_seconds=1)
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
        time.sleep(1)  # Wait for TTL to elapse.
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        # Also make sure we can no longer complete the trial as it is failed.
        with self.assertRaisesRegex(
                ValueError,
                ".* has been marked FAILED, so it no longer expects data."):
            ax_client.complete_trial(trial_index=idx, raw_data=5)

        params2, idx2 = ax_client.attach_trial(parameters={
            "x": 0.0,
            "y": 1.0
        },
                                               ttl_seconds=1)
        ax_client.complete_trial(trial_index=idx2, raw_data=5)
        self.assertEqual(ax_client.get_best_parameters()[0], params2)
        self.assertEqual(ax_client.get_trial_parameters(trial_index=idx2), {
            "x": 0,
            "y": 1
        })
Example #7
 def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None:
     """Test that Sobol+GPEI is used if no GenerationStrategy is provided."""
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[  # pyre-fixme[6]: expected union that should include
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ],
         objective_name="a",
         minimize=True,
     )
     self.assertEqual(
         [s.model for s in not_none(ax_client.generation_strategy)._steps],
         [Models.SOBOL, Models.GPEI],
     )
     with self.assertRaisesRegex(ValueError, ".* no trials"):
         ax_client.get_optimization_trace(objective_optimum=branin.fmin)
     for i in range(6):
         parameterization, trial_index = ax_client.get_next_trial()
         x, y = parameterization.get("x"), parameterization.get("y")
         ax_client.complete_trial(
             trial_index,
             raw_data={
                 "a": (
                     checked_cast(
                         float,
                         branin(checked_cast(float, x), checked_cast(float, y)),
                     ),
                     0.0,
                 )
             },
             sample_size=i,
         )
     self.assertEqual(ax_client.generation_strategy.model._model_key, "GPEI")
     ax_client.get_optimization_trace(objective_optimum=branin.fmin)
     ax_client.get_contour_plot()
     ax_client.get_feature_importances()
     trials_df = ax_client.get_trials_data_frame()
     self.assertIn("x", trials_df)
     self.assertIn("y", trials_df)
     self.assertIn("a", trials_df)
     self.assertEqual(len(trials_df), 6)
Example #8
 def test_attach_trial(self):
     ax = AxClient()
     ax.create_experiment(
         parameters=[
             {
                 "name": "x1",
                 "type": "range",
                 "bounds": [-5.0, 10.0]
             },
             {
                 "name": "x2",
                 "type": "range",
                 "bounds": [0.0, 15.0]
             },
         ],
         minimize=True,
     )
     params, idx = ax.attach_trial(parameters={"x1": 0, "x2": 1})
     ax.complete_trial(trial_index=idx, raw_data=5)
     self.assertEqual(ax.get_best_parameters()[0], params)
Example #9
 def test_attach_trial_numpy(self):
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {
                 "name": "x",
                 "type": "range",
                 "bounds": [-5.0, 10.0]
             },
             {
                 "name": "y",
                 "type": "range",
                 "bounds": [0.0, 15.0]
             },
         ],
         minimize=True,
     )
     params, idx = ax_client.attach_trial(parameters={"x": 0.0, "y": 1.0})
     ax_client.complete_trial(trial_index=idx, raw_data=np.int32(5))
     self.assertEqual(ax_client.get_best_parameters()[0], params)
Example #10
 def test_raw_data_format_with_map_results(self):
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 1.0]},
         ],
         minimize=True,
         support_intermediate_data=True,
     )
     for _ in range(6):
         parameterization, trial_index = ax_client.get_next_trial()
         x, y = parameterization.get("x"), parameterization.get("y")
         ax_client.complete_trial(
             trial_index,
             raw_data=[
                 ({"y": y / 2.0}, {"objective": (branin(x, y / 2.0), 0.0)}),
                 ({"y": y}, {"objective": (branin(x, y), 0.0)}),
             ],
         )
Example #11
 def test_start_and_end_time_in_trial_completion(self):
     start_time = current_timestamp_in_millis()
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     params, idx = ax_client.get_next_trial()
     ax_client.complete_trial(
         trial_index=idx,
         raw_data=1.0,
         metadata={
             "start_time": start_time,
             "end_time": current_timestamp_in_millis(),
         },
     )
     dat = ax_client.experiment.fetch_data().df
     self.assertGreater(dat["end_time"][0], dat["start_time"][0])
Example #12
 def test_attach_trial_and_get_trial_parameters(self):
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     params, idx = ax_client.attach_trial(parameters={"x": 0.0, "y": 1.0})
     ax_client.complete_trial(trial_index=idx, raw_data=5)
     self.assertEqual(ax_client.get_best_parameters()[0], params)
     self.assertEqual(
         ax_client.get_trial_parameters(trial_index=idx), {"x": 0, "y": 1}
     )
     with self.assertRaises(ValueError):
         ax_client.get_trial_parameters(
             trial_index=10
         )  # No trial #10 in experiment.
     with self.assertRaisesRegex(ValueError, ".* is of type"):
         ax_client.attach_trial({"x": 1, "y": 2})
Example #13
    def test_abandon_trial(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )

        # An abandoned trial adds no data.
        params, idx = ax_client.get_next_trial()
        ax_client.abandon_trial(trial_index=idx)
        data = ax_client.experiment.fetch_data()
        self.assertEqual(len(data.df.index), 0)

        # Can't abandon a completed trial.
        params2, idx2 = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=idx2, raw_data={"objective": (0, 0.0)})
        with self.assertRaisesRegex(ValueError, ".* in a terminal state."):
            ax_client.abandon_trial(trial_index=idx2)
Example #14
 def test_trial_completion(self):
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     params, idx = ax_client.get_next_trial()
     ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
     self.assertEqual(ax_client.get_best_parameters()[0], params)
     params2, idx2 = ax_client.get_next_trial()
     ax_client.complete_trial(trial_index=idx2, raw_data=(-1, 0.0))
     self.assertEqual(ax_client.get_best_parameters()[0], params2)
     params3, idx3 = ax_client.get_next_trial()
     ax_client.complete_trial(
         trial_index=idx3, raw_data=-2, metadata={"dummy": "test"}
     )
     self.assertEqual(ax_client.get_best_parameters()[0], params3)
     self.assertEqual(
         ax_client.experiment.trials.get(2).run_metadata.get("dummy"), "test"
     )
     best_trial_values = ax_client.get_best_parameters()[1]
     self.assertEqual(best_trial_values[0], {"objective": -2.0})
     self.assertTrue(math.isnan(best_trial_values[1]["objective"]["objective"]))
Example #15
 def test_interruption(self) -> None:
     ax_client = AxClient()
     ax_client.create_experiment(
         name="test",
         parameters=[  # pyre-fixme[6]: expected union that should include
             {
                 "name": "x",
                 "type": "range",
                 "bounds": [-5.0, 10.0]
             },
             {
                 "name": "y",
                 "type": "range",
                 "bounds": [0.0, 15.0]
             },
         ],
         objective_name="branin",
         minimize=True,
     )
     for i in range(6):
         parameterization, trial_index = ax_client.get_next_trial()
         self.assertFalse(  # There should be non-complete trials.
             all(t.status.is_terminal
                 for t in ax_client.experiment.trials.values()))
         x, y = parameterization.get("x"), parameterization.get("y")
         ax_client.complete_trial(
             trial_index,
             raw_data=checked_cast(
                 float,
                 branin(checked_cast(float, x), checked_cast(float, y))),
         )
         old_client = ax_client
         serialized = ax_client.to_json_snapshot()
         ax_client = AxClient.from_json_snapshot(serialized)
         self.assertEqual(len(ax_client.experiment.trials.keys()), i + 1)
         self.assertIsNot(ax_client, old_client)
         self.assertTrue(  # There should be no non-complete trials.
             all(t.status.is_terminal
                 for t in ax_client.experiment.trials.values()))
Example #16
def run_trials_using_recommended_parallelism(
    ax_client: AxClient,
    recommended_parallelism: List[Tuple[int, int]],
    total_trials: int,
) -> int:
    remaining_trials = total_trials
    for num_trials, parallelism_setting in recommended_parallelism:
        if num_trials == -1:
            num_trials = remaining_trials
        for _ in range(ceil(num_trials / parallelism_setting)):
            in_flight_trials = []
            if parallelism_setting > remaining_trials:
                parallelism_setting = remaining_trials
            for _ in range(parallelism_setting):
                params, idx = ax_client.get_next_trial()
                in_flight_trials.append((params, idx))
                remaining_trials -= 1
            for _ in range(parallelism_setting):
                params, idx = in_flight_trials.pop()
                ax_client.complete_trial(idx, branin(params["x"], params["y"]))
    # If all went well and no errors were raised, remaining_trials should be 0.
    return remaining_trials
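
For context, here is a minimal sketch of how the helper above might be driven. It is an assumption rather than part of the original snippet: it relies on `AxClient.get_max_parallelism()` (named `get_recommended_max_parallelism()` in some older Ax releases), which returns `(num_trials, max_parallelism)` tuples in exactly the shape the helper consumes, and on the `branin` test function and the parameter names "x"/"y" that the helper already looks up.

from math import ceil  # needed by run_trials_using_recommended_parallelism
from ax.service.ax_client import AxClient
from ax.utils.measurement.synthetic_functions import branin

ax_client = AxClient()
ax_client.create_experiment(
    parameters=[
        {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
    ],
    minimize=True,
)
# E.g. [(5, 5), (-1, 3)]: 5 trials at parallelism 5, then the rest at parallelism 3.
recommended = ax_client.get_max_parallelism()
leftover = run_trials_using_recommended_parallelism(ax_client, recommended, total_trials=10)
assert leftover == 0  # every requested trial was generated and completed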
Example #17
 def test_storage_error_handling(self, mock_save_fails):
     """Check that if `suppress_storage_errors` is True, AxClient won't
     visibly fail if it encounters storage errors.
     """
     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)
     ax_client = AxClient(db_settings=db_settings, suppress_storage_errors=True)
     ax_client.create_experiment(
         name="test_experiment",
         parameters=[
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     for _ in range(3):
         parameters, trial_index = ax_client.get_next_trial()
         ax_client.complete_trial(
             trial_index=trial_index, raw_data=branin(*parameters.values())
         )
Example #18
 def test_init_position_saved(self):
     ax_client = AxClient(random_seed=239)
     ax_client.create_experiment(
         parameters=[
             {
                 "name": "x",
                 "type": "range",
                 "bounds": [-5.0, 10.0]
             },
             {
                 "name": "y",
                 "type": "range",
                 "bounds": [0.0, 15.0]
             },
         ],
         name="sobol_init_position_test",
     )
     for _ in range(4):
         # For each generated trial, snapshot the client before generating it,
         # then recreate client, regenerate the trial and compare the trial
         # generated before and after snapshotting. If the state of Sobol is
         # recorded correctly, the newly generated trial will be the same as
         # the one generated before the snapshotting.
         serialized = ax_client.to_json_snapshot()
         params, idx = ax_client.get_next_trial()
         ax_client = AxClient.from_json_snapshot(serialized)
         with self.subTest(ax=ax_client, params=params, idx=idx):
             new_params, new_idx = ax_client.get_next_trial()
             self.assertEqual(params, new_params)
             self.assertEqual(idx, new_idx)
             self.assertEqual(
                 ax_client.experiment.trials[idx]._generator_run.
                 _model_state_after_gen["init_position"],
                 idx + 1,
             )
         ax_client.complete_trial(idx,
                                  branin(params.get("x"), params.get("y")))
Example #19
    def test_ttl_trial(self):
        ax_client = AxClient()
        ax_client.create_experiment(
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )

        # A ttl trial that ends adds no data.
        params, idx = ax_client.get_next_trial(ttl_seconds=1)
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
        time.sleep(1)  # Wait for TTL to elapse.
        self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
        # Also make sure we can no longer complete the trial as it is failed.
        with self.assertRaisesRegex(
            ValueError, ".* has been marked FAILED, so it no longer expects data."
        ):
            ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})

        params2, idy = ax_client.get_next_trial(ttl_seconds=1)
        ax_client.complete_trial(trial_index=idy, raw_data=(-1, 0.0))
        self.assertEqual(ax_client.get_best_parameters()[0], params2)
Example #20
 def test_sqa_storage(self):
     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)
     ax = AxClient(db_settings=db_settings)
     ax.create_experiment(
         name="test_experiment",
         parameters=[
             {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     for _ in range(5):
         parameters, trial_index = ax.get_next_trial()
         ax.complete_trial(
             trial_index=trial_index, raw_data=branin(*parameters.values())
         )
     gs = ax.generation_strategy
     ax = AxClient(db_settings=db_settings)
     ax.load_experiment_from_database("test_experiment")
     self.assertEqual(gs, ax.generation_strategy)
Example #21
 def test_raw_data_format(self):
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {
                 "name": "x",
                 "type": "range",
                 "bounds": [-5.0, 10.0]
             },
             {
                 "name": "y",
                 "type": "range",
                 "bounds": [0.0, 15.0]
             },
         ],
         minimize=True,
     )
     for _ in range(6):
         parameterization, trial_index = ax_client.get_next_trial()
         x, y = parameterization.get("x"), parameterization.get("y")
         ax_client.complete_trial(trial_index, raw_data=(branin(x, y), 0.0))
     with self.assertRaisesRegex(ValueError,
                                 "Raw data has an invalid type"):
         ax_client.update_trial_data(trial_index, raw_data="invalid_data")
Example #22
 def test_trial_completion(self):
     ax_client = AxClient()
     ax_client.create_experiment(
         parameters=[
             {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
             {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
         ],
         minimize=True,
     )
     params, idx = ax_client.get_next_trial()
     # Can't update before completing.
     with self.assertRaisesRegex(ValueError, ".* not yet"):
         ax_client.update_trial_data(
             trial_index=idx, raw_data={"objective": (0, 0.0)}
         )
     ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
     # Cannot complete a trial twice, should use `update_trial_data`.
     with self.assertRaisesRegex(ValueError, ".* already been completed"):
         ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
     # Cannot update trial data with observation for a metric it already has.
     with self.assertRaisesRegex(ValueError, ".* contained an observation"):
         ax_client.update_trial_data(
             trial_index=idx, raw_data={"objective": (0, 0.0)}
         )
     # Same as above, except objective name should be getting inferred.
     with self.assertRaisesRegex(ValueError, ".* contained an observation"):
         ax_client.update_trial_data(trial_index=idx, raw_data=1.0)
     ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)})
     metrics_in_data = ax_client.experiment.fetch_data().df["metric_name"].values
     self.assertIn("m1", metrics_in_data)
     self.assertIn("objective", metrics_in_data)
     self.assertEqual(ax_client.get_best_parameters()[0], params)
     params2, idy = ax_client.get_next_trial()
     ax_client.complete_trial(trial_index=idy, raw_data=(-1, 0.0))
     self.assertEqual(ax_client.get_best_parameters()[0], params2)
     params3, idx3 = ax_client.get_next_trial()
     ax_client.complete_trial(
         trial_index=idx3, raw_data=-2, metadata={"dummy": "test"}
     )
     self.assertEqual(ax_client.get_best_parameters()[0], params3)
     self.assertEqual(
         ax_client.experiment.trials.get(2).run_metadata.get("dummy"), "test"
     )
     best_trial_values = ax_client.get_best_parameters()[1]
     self.assertEqual(best_trial_values[0], {"objective": -2.0})
     self.assertTrue(math.isnan(best_trial_values[1]["objective"]["objective"]))
Example #23
def _benchmark_replication_Service_API(
    problem: SimpleBenchmarkProblem,
    method: GenerationStrategy,
    num_trials: int,
    experiment_name: str,
    batch_size: int = 1,
    raise_all_exceptions: bool = False,
    benchmark_trial: FunctionType = benchmark_trial,
    verbose_logging: bool = True,
    # Number of trials that need to fail for a replication to be considered failed.
    failed_trials_tolerated: int = 5,
    async_benchmark_options: Optional[AsyncBenchmarkOptions] = None,
) -> Tuple[Experiment, List[Exception]]:
    """Run a benchmark replication via the Service API because the problem was
    set up in a simplified way, without the use of Ax classes like `OptimizationConfig`
    or `SearchSpace`.
    """
    if async_benchmark_options is not None:
        raise NonRetryableBenchmarkingError(
            "`async_benchmark_options` not supported when using the Service API."
        )

    exceptions = []
    if batch_size == 1:
        ax_client = AxClient(generation_strategy=method,
                             verbose_logging=verbose_logging)
    else:  # pragma: no cover, TODO[T53975770]
        assert batch_size > 1, "Batch size of 1 or greater is expected."
        raise NotImplementedError(
            "Batched benchmarking on `SimpleBenchmarkProblem`-s not yet implemented."
        )
    ax_client.create_experiment(
        name=experiment_name,
        parameters=problem.domain_as_ax_client_parameters(),
        minimize=problem.minimize,
        objective_name=problem.name,
    )
    parameter_names = list(ax_client.experiment.search_space.parameters.keys())
    assert num_trials > 0
    for _ in range(num_trials):
        parameterization, idx = ax_client.get_next_trial()
        param_values = np.array(
            [parameterization.get(x) for x in parameter_names])
        try:
            mean, sem = benchmark_trial(parameterization=param_values,
                                        evaluation_function=problem.f)
            # If problem indicates a noise level and is using a synthetic callable,
            # add normal noise to the measurement of the mean.
            if problem.uses_synthetic_function and problem.noise_sd != 0.0:
                noise = np.random.randn() * problem.noise_sd
                sem = (sem or 0.0) + problem.noise_sd
                logger.info(
                    f"Adding noise of {noise} to the measurement mean ({mean})."
                    f"Problem noise SD setting: {problem.noise_sd}.")
                mean = mean + noise
            ax_client.complete_trial(trial_index=idx, raw_data=(mean, sem))
        except Exception as err:  # TODO[T53975770]: test
            if raise_all_exceptions:
                raise
            exceptions.append(err)
        if len(exceptions) > failed_trials_tolerated:
            raise RuntimeError(  # TODO[T53975770]: test
                f"More than {failed_trials_tolerated} failed for {experiment_name}."
            )
    return ax_client.experiment, exceptions
Example #24
    def ml_run(self, run_id=None):
        seed_randomness(self.random_seed)

        mlflow.log_params(flatten(get_params_of_task(self)))

        total_training_time = 0

        # should land in 'optimizer_props'
        params_space = [
            {
                'name': 'lr',
                'type': 'range',
                'bounds': [1e-6, 0.008],
                # 'value_type': 'float',
                'log_scale': True,
            },
            {
                'name': 'beta_1',
                'type': 'range',
                'bounds': [.0, 0.9999],
                'value_type': 'float',
                # 'log_scale': True,
            },
            {
                'name': 'beta_2',
                'type': 'range',
                'bounds': [.0, 0.9999],
                'value_type': 'float',
                # 'log_scale': True,
            }
        ]

        # TODO: make the search reproducible
        # without it we will get new params each time

        # for example we can use:
        # ax.storage.sqa_store.structs.DBSettings
        # DBSettings(url="sqlite://<path-to-file>")
        # to store experiments
        ax = AxClient(
            # can't use that feature yet.
            # got error
            # NotImplementedError:
            # Saving and loading experiment in `AxClient` functionality currently under development.
            # db_settings=DBSettings(url=self.output()['ax_settings'].path)
        )

        # FIXME: temporary solution while ax doesn't have an API to (re-)store state

        class_name = get_class_name_as_snake(self)
        ax.create_experiment(
            name=f'{class_name}_experiment',
            parameters=params_space,
            objective_name='score',
            minimize=should_minimize(self.metric),
            # parameter_constraints=['x1 + x2 <= 2.0'],  # Optional.
            # outcome_constraints=['l2norm <= 1.25'],  # Optional.
        )

        trial_index = 0
        experiment = self._get_ax_experiment()
        if experiment:
            print('AX: restore experiment')
            print('AX: num_trials:', experiment.num_trials)
            ax._experiment = experiment
            trial_index = experiment.num_trials - 1

        model_task = get_model_task_by_name(self.model_name)

        while trial_index < self.max_runs:
            print(f'AX: Running trial {trial_index + 1}/{self.max_runs}...')

            # get last unfinished trial
            parameters = get_last_unfinished_params(ax)

            if parameters is None:
                print('AX: generate new Trial')
                parameters, trial_index = ax.get_next_trial()

                # good time to store experiment (with new Trial)
                with self.output()['ax_experiment'].open('w') as f:
                    print('AX: store experiment: ', ax.experiment)
                    pickle.dump(ax.experiment, f)

            print('AX: parameters', parameters)

            # now it is time to evaluate the model
            model_result = yield model_task(
                parent_run_id=run_id,
                random_seed=self.random_seed,
                # TODO: actually we should be able to pass even nested params
                # **parameters,
                optimizer_props=parameters)

            # TODO: store run_id in Trial
            model_run_id = self.get_run_id_from_result(model_result)

            with model_result['metrics'].open('r') as f:
                model_metrics = yaml.load(f)
                model_score_mean = model_metrics[self.metric]['val']
                # TODO: we might know it :/
                model_score_error = 0.0
                total_training_time += model_metrics['train_time']['total']

            with model_result['params'].open('r') as f:
                model_params = yaml.load(f)

            print('AX: complete trial:', trial_index)

            ax.complete_trial(
                trial_index=trial_index,
                raw_data={'score': (model_score_mean, model_score_error)},
                metadata={
                    'metrics': model_metrics,
                    'params': model_params,
                    'run_id': model_run_id,
                })

        best_parameters, _ = ax.get_best_parameters()

        mlflow.log_metric('train_time.total', total_training_time)

        print('best params', best_parameters)

        best_trial = get_best_trial(experiment, self.metric)

        mlflow.log_metrics(flatten(best_trial.run_metadata['metrics']))
        mlflow.log_params(flatten(best_trial.run_metadata['params']))
Example #25
class IterativePrune:
    def __init__(self):
        self.parser_args = None
        self.ax_client = None
        self.base_model_path = "base_model"
        self.pruning_amount = None

    def run_mnist_model(self, base=False):
        parser_dict = vars(self.parser_args)
        if base:
            mlflow.start_run(run_name="BaseModel")
        mlflow.pytorch.autolog()
        dm = MNISTDataModule(**parser_dict)
        dm.setup(stage="fit")

        model = LightningMNISTClassifier(**parser_dict)
        trainer = pl.Trainer.from_argparse_args(self.parser_args)
        trainer.fit(model, dm)
        trainer.test()
        if os.path.exists(self.base_model_path):
            shutil.rmtree(self.base_model_path)
        mlflow.pytorch.save_model(trainer.get_model(), self.base_model_path)
        return trainer

    def load_base_model(self):
        path = Path(_download_artifact_from_uri(self.base_model_path))
        model_file_path = os.path.join(path, "data/model.pth")
        return torch.load(model_file_path)

    def initialize_ax_client(self):
        self.ax_client = AxClient()
        self.ax_client.create_experiment(
            parameters=[{
                "name": "amount",
                "type": "range",
                "bounds": [0.05, 0.15],
                "value_type": "float"
            }],
            objective_name="test_accuracy",
        )

    @staticmethod
    def prune_and_save_model(model, amount):

        for _, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d) or isinstance(
                    module, torch.nn.Linear):
                prune.l1_unstructured(module, name="weight", amount=amount)
                prune.remove(module, "weight")

        mlflow.pytorch.save_state_dict(model.state_dict(), ".")
        model = torch.load("state_dict.pth")
        os.remove("state_dict.pth")
        return model

    @staticmethod
    def count_model_parameters(model):
        table = PrettyTable(["Modules", "Parameters"])
        total_params = 0

        for name, parameter in model.named_parameters():

            if not parameter.requires_grad:
                continue

            param = parameter.nonzero(as_tuple=False).size(0)
            table.add_row([name, param])
            total_params += param
        return table, total_params

    @staticmethod
    def write_prune_summary(summary, params):
        tempdir = tempfile.mkdtemp()
        try:
            summary_file = os.path.join(tempdir, "pruned_model_summary.txt")
            params = "Total Trainable Parameters :" + str(params)
            with open(summary_file, "w") as f:
                f.write(str(summary))
                f.write("\n")
                f.write(str(params))

            try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
        finally:
            shutil.rmtree(tempdir)

    def iterative_prune(self, model, parametrization):
        if not self.pruning_amount:
            self.pruning_amount = parametrization.get("amount")
        else:
            self.pruning_amount += 0.15

        mlflow.log_metric("PRUNING PERCENTAGE", self.pruning_amount)
        pruned_model = self.prune_and_save_model(model, self.pruning_amount)
        model.load_state_dict(copy.deepcopy(pruned_model))
        summary, params = self.count_model_parameters(model)
        self.write_prune_summary(summary, params)
        trainer = self.run_mnist_model()
        metrics = trainer.callback_metrics
        test_accuracy = metrics.get("avg_test_acc")
        return test_accuracy

    def initiate_pruning_process(self, model):
        total_trials = int(vars(self.parser_args)["total_trials"])

        trial_index = None
        for i in range(total_trials):
            parameters, trial_index = self.ax_client.get_next_trial()
            print(
                "***************************************************************************"
            )
            print("Running Trial {}".format(i + 1))
            print(
                "***************************************************************************"
            )
            with mlflow.start_run(nested=True, run_name="Iteration" + str(i)):
                mlflow.set_tags({"AX_TRIAL": i})

                # calling the model
                test_accuracy = self.iterative_prune(model, parameters)

                # completion of trial
                self.ax_client.complete_trial(trial_index=trial_index,
                                              raw_data=test_accuracy.item())

        # Ending the Base run
        mlflow.end_run()

    def get_parser_args(self):
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parent_parser=parser)
        parser = LightningMNISTClassifier.add_model_specific_args(
            parent_parser=parser)

        parser.add_argument(
            "--total_trials",
            default=3,
            help=
            "Number of AX trials to be run for the optimization experiment",
        )

        self.parser_args = parser.parse_args()
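
The class above defines every step but the snippet ends before they are wired together. A plausible driver, sketched here as an assumption (the method names come from the class itself, but this wiring is not taken from the original source):

if __name__ == "__main__":
    pruner = IterativePrune()
    pruner.get_parser_args()            # parse Trainer/model args plus --total_trials
    pruner.run_mnist_model(base=True)   # train and save the unpruned base model
    pruner.initialize_ax_client()       # create the Ax experiment over "amount"
    base_model = pruner.load_base_model()
    pruner.initiate_pruning_process(base_model)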
Example #26
class AxSearch(Searcher):
    """Uses `Ax <https://ax.dev/>`_ to optimize hyperparameters.

    Ax is a platform for understanding, managing, deploying, and
    automating adaptive experiments. Ax provides an easy to use
    interface with BoTorch, a flexible, modern library for Bayesian
    optimization in PyTorch. More information can be found at https://ax.dev/.

    To use this search algorithm, you must install Ax and sqlalchemy:

    .. code-block:: bash

        $ pip install ax-platform sqlalchemy

    Parameters:
        space (list[dict]): Parameters in the experiment search space.
            Required elements in the dictionaries are: "name" (name of
            this parameter, string), "type" (type of the parameter: "range",
            "fixed", or "choice", string), "bounds" for range parameters
            (list of two values, lower bound first), "values" for choice
            parameters (list of values), and "value" for fixed parameters
            (single value).
        metric (str): Name of the metric used as objective in this
            experiment. This metric must be present in `raw_data` argument
            to `log_data`. This metric must also be present in the dict
            reported/returned by the Trainable. If None but a mode was passed,
            the `ray.tune.result.DEFAULT_METRIC` will be used per default.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute. Defaults to "max".
        points_to_evaluate (list): Initial parameter suggestions to be run
            first. This is for when you already have some good parameters
            you want to run first to help the algorithm make better suggestions
            for future parameters. Needs to be a list of dicts containing the
            configurations.
        parameter_constraints (list[str]): Parameter constraints, such as
            "x3 >= x4" or "x3 + x4 >= 2".
        outcome_constraints (list[str]): Outcome constraints of form
            "metric_name >= bound", like "m1 <= 3."
        ax_client (AxClient): Optional AxClient instance. If this is set, do
            not pass any values to these parameters: `space`, `metric`,
            `parameter_constraints`, `outcome_constraints`.
        use_early_stopped_trials: Deprecated.
        max_concurrent (int): Deprecated.

    Tune automatically converts search spaces to Ax's format:

    .. code-block:: python

        from ray import tune
        from ray.tune.suggest.ax import AxSearch

        config = {
            "x1": tune.uniform(0.0, 1.0),
            "x2": tune.uniform(0.0, 1.0)
        }

        def easy_objective(config):
            for i in range(100):
                intermediate_result = config["x1"] + config["x2"] * i
                tune.report(score=intermediate_result)

        ax_search = AxSearch(metric="score")
        tune.run(
            easy_objective,
            config=config,
            search_alg=ax_search)

    If you would like to pass the search space manually, the code would
    look like this:

    .. code-block:: python

        from ray import tune
        from ray.tune.suggest.ax import AxSearch

        parameters = [
            {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
        ]

        def easy_objective(config):
            for i in range(100):
                intermediate_result = config["x1"] + config["x2"] * i
                tune.report(score=intermediate_result)

        ax_search = AxSearch(space=parameters, metric="score")
        tune.run(easy_objective, search_alg=ax_search)

    """
    def __init__(self,
                 space: Optional[Union[Dict, List[Dict]]] = None,
                 metric: Optional[str] = None,
                 mode: Optional[str] = None,
                 points_to_evaluate: Optional[List[Dict]] = None,
                 parameter_constraints: Optional[List] = None,
                 outcome_constraints: Optional[List] = None,
                 ax_client: Optional[AxClient] = None,
                 use_early_stopped_trials: Optional[bool] = None,
                 max_concurrent: Optional[int] = None):
        assert ax is not None, """Ax must be installed!
            You can install AxSearch with the command:
            `pip install ax-platform sqlalchemy`."""
        if mode:
            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."

        super(AxSearch,
              self).__init__(metric=metric,
                             mode=mode,
                             max_concurrent=max_concurrent,
                             use_early_stopped_trials=use_early_stopped_trials)

        self._ax = ax_client

        if isinstance(space, dict) and space:
            resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
            if domain_vars or grid_vars:
                logger.warning(
                    UNRESOLVED_SEARCH_SPACE.format(par="space",
                                                   cls=type(self)))
                space = self.convert_search_space(space)

        self._space = space
        self._parameter_constraints = parameter_constraints
        self._outcome_constraints = outcome_constraints

        self._points_to_evaluate = copy.deepcopy(points_to_evaluate)

        self.max_concurrent = max_concurrent

        self._objective_name = metric
        self._parameters = []
        self._live_trial_mapping = {}

        if self._ax or self._space:
            self._setup_experiment()

    def _setup_experiment(self):
        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC

        if not self._ax:
            self._ax = AxClient()

        try:
            exp = self._ax.experiment
            has_experiment = True
        except ValueError:
            has_experiment = False

        if not has_experiment:
            if not self._space:
                raise ValueError(
                    "You have to create an Ax experiment by calling "
                    "`AxClient.create_experiment()`, or you should pass an "
                    "Ax search space as the `space` parameter to `AxSearch`, "
                    "or pass a `config` dict to `tune.run()`.")
            self._ax.create_experiment(
                parameters=self._space,
                objective_name=self._metric,
                parameter_constraints=self._parameter_constraints,
                outcome_constraints=self._outcome_constraints,
                minimize=self._mode != "max")
        else:
            if any([
                    self._space, self._parameter_constraints,
                    self._outcome_constraints
            ]):
                raise ValueError(
                    "If you create the Ax experiment yourself, do not pass "
                    "values for these parameters to `AxSearch`: {}.".format([
                        "space", "parameter_constraints", "outcome_constraints"
                    ]))

        exp = self._ax.experiment
        self._objective_name = exp.optimization_config.objective.metric.name
        self._parameters = list(exp.parameters)

        if self._ax._enforce_sequential_optimization:
            logger.warning("Detected sequential enforcement. Be sure to use "
                           "a ConcurrencyLimiter.")

    def set_search_properties(self, metric: Optional[str], mode: Optional[str],
                              config: Dict):
        if self._ax:
            return False
        space = self.convert_search_space(config)
        self._space = space
        if metric:
            self._metric = metric
        if mode:
            self._mode = mode

        self._setup_experiment()
        return True

    def suggest(self, trial_id: str) -> Optional[Dict]:
        if not self._ax:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(cls=self.__class__.__name__,
                                              space="space"))

        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__,
                                             metric=self._metric,
                                             mode=self._mode))

        if self.max_concurrent:
            if len(self._live_trial_mapping) >= self.max_concurrent:
                return None

        if self._points_to_evaluate:
            config = self._points_to_evaluate.pop(0)
            parameters, trial_index = self._ax.attach_trial(config)
        else:
            parameters, trial_index = self._ax.get_next_trial()

        self._live_trial_mapping[trial_id] = trial_index
        return unflatten_dict(parameters)

    def on_trial_complete(self, trial_id, result=None, error=False):
        """Notification for the completion of trial.

        Data of form key value dictionary of metric names and values.
        """
        if result:
            self._process_result(trial_id, result)
        self._live_trial_mapping.pop(trial_id)

    def _process_result(self, trial_id, result):
        ax_trial_index = self._live_trial_mapping[trial_id]
        metric_dict = {
            self._objective_name: (result[self._objective_name], 0.0)
        }
        outcome_names = [
            oc.metric.name for oc in
            self._ax.experiment.optimization_config.outcome_constraints
        ]
        metric_dict.update({on: (result[on], 0.0) for on in outcome_names})
        self._ax.complete_trial(trial_index=ax_trial_index,
                                raw_data=metric_dict)

    @staticmethod
    def convert_search_space(spec: Dict):
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to an Ax search space.")

        def resolve_value(par, domain):
            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                logger.warning("AxSearch does not support quantization. "
                               "Dropped quantization.")
                sampler = sampler.sampler

            if isinstance(domain, Float):
                if isinstance(sampler, LogUniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "float",
                        "log_scale": True
                    }
                elif isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "float",
                        "log_scale": False
                    }
            elif isinstance(domain, Integer):
                if isinstance(sampler, LogUniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "int",
                        "log_scale": True
                    }
                elif isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "int",
                        "log_scale": False
                    }
            elif isinstance(domain, Categorical):
                if isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "choice",
                        "values": domain.categories
                    }

            raise ValueError("AxSearch does not support parameters of type "
                             "`{}` with samplers of type `{}`".format(
                                 type(domain).__name__,
                                 type(domain.sampler).__name__))

        # Fixed vars
        fixed_values = [{
            "name": "/".join(path),
            "type": "fixed",
            "value": val
        } for path, val in resolved_vars]

        # Parameter name is e.g. "a/b/c" for nested dicts
        resolved_values = [
            resolve_value("/".join(path), domain)
            for path, domain in domain_vars
        ]

        return fixed_values + resolved_values
Example #27
ax = AxClient()

N_TRIALS = int(os.environ.get("N_TRIALS", 5))

ax.create_experiment(
    name="GaussianProcessRegression-%s" % isonow(),
    parameters=PARAMETERS,
    objective_name="mean_square_error",
    minimize=True)

for _ in range(N_TRIALS):
    print(f"[ax-service-loop] Trial {_+1} of {N_TRIALS}")
    parameters, trial_index = ax.get_next_trial()
    ax.complete_trial(
        trial_index=trial_index,
        raw_data=evaluate(parameters)
    )
    print(parameters)
    print("")


print("[ax-service-loop] Training complete!")
best_parameters, metrics = ax.get_best_parameters()
print(f"[ax-service-loop] Sending data to db.")
print(f"[ax-service-loop] Best parameters found: {best_parameters}")
DB_URL = os.environ.get("DB_URL", "mysql://*****:*****@localhost/axdb")
from sqlalchemy import create_engine
engine = axst.sqa_store.db.create_mysql_engine_from_url(url=DB_URL)
conn = engine.connect()
axst.sqa_store.db.init_engine_and_session_factory(url=DB_URL)
table_names = engine.table_names()
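
The script stops right after creating the MySQL engine and session factory, before the "Sending data to db" step actually writes anything. A plausible continuation, sketched under the assumption that Ax's SQLAlchemy-backed store is used directly (the imports below exist in `ax.storage.sqa_store`, but this exact wiring is not from the original script):

from ax.storage.sqa_store.db import create_all_tables
from ax.storage.sqa_store.save import save_experiment

create_all_tables(engine)       # create Ax's tables on first run
save_experiment(ax.experiment)  # persist the experiment and its trials
print(f"[ax-service-loop] Tables present: {table_names}")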
Example #28
 def test_sqa_storage(self):
     init_test_engine_and_session_factory(force_init=True)
     config = SQAConfig()
     encoder = Encoder(config=config)
     decoder = Decoder(config=config)
     db_settings = DBSettings(encoder=encoder, decoder=decoder)
     ax_client = AxClient(db_settings=db_settings)
     ax_client.create_experiment(
         name="test_experiment",
         parameters=[
             {
                 "name": "x",
                 "type": "range",
                 "bounds": [-5.0, 10.0]
             },
             {
                 "name": "y",
                 "type": "range",
                 "bounds": [0.0, 15.0]
             },
         ],
         minimize=True,
     )
     for _ in range(5):
         parameters, trial_index = ax_client.get_next_trial()
         ax_client.complete_trial(trial_index=trial_index,
                                  raw_data=branin(*parameters.values()))
     gs = ax_client.generation_strategy
     ax_client = AxClient(db_settings=db_settings)
     ax_client.load_experiment_from_database("test_experiment")
     # Trial #4 was completed after the last time the generation strategy
     # generated candidates, so pre-save generation strategy was not
     # "aware" of completion of trial #4. Post-restoration generation
     # strategy is aware of it, however, since it gets restored with most
      # up-to-date experiment data. So we add trial #4 to the seen completed
      # trials of the pre-storage GS before checking that the two are otherwise equal.
     gs._seen_trial_indices_by_status[TrialStatus.COMPLETED].add(4)
     self.assertEqual(gs, ax_client.generation_strategy)
     with self.assertRaises(ValueError):
         # Overwriting existing experiment.
         ax_client.create_experiment(
             name="test_experiment",
             parameters=[
                 {
                     "name": "x",
                     "type": "range",
                     "bounds": [-5.0, 10.0]
                 },
                 {
                     "name": "y",
                     "type": "range",
                     "bounds": [0.0, 15.0]
                 },
             ],
             minimize=True,
         )
     with self.assertRaises(ValueError):
         # Overwriting existing experiment with overwrite flag with present
         # DB settings. This should fail as we no longer allow overwriting
         # experiments stored in the DB.
         ax_client.create_experiment(
             name="test_experiment",
             parameters=[{
                 "name": "x",
                 "type": "range",
                 "bounds": [-5.0, 10.0]
             }],
             overwrite_existing_experiment=True,
         )
     # Original experiment should still be in DB and not have been overwritten.
     self.assertEqual(len(ax_client.experiment.trials), 5)
Example #29
    def test_overwrite(self):
        init_test_engine_and_session_factory(force_init=True)
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "y",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            minimize=True,
        )

        # Log a trial
        parameters, trial_index = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=trial_index,
                                 raw_data=branin(*parameters.values()))

        with self.assertRaises(ValueError):
            # Overwriting existing experiment.
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {
                        "name": "x",
                        "type": "range",
                        "bounds": [-5.0, 10.0]
                    },
                    {
                        "name": "y",
                        "type": "range",
                        "bounds": [0.0, 15.0]
                    },
                ],
                minimize=True,
            )
        # Overwriting existing experiment with overwrite flag.
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x1",
                    "type": "range",
                    "bounds": [-5.0, 10.0]
                },
                {
                    "name": "x2",
                    "type": "range",
                    "bounds": [0.0, 15.0]
                },
            ],
            overwrite_existing_experiment=True,
        )
        # There should be no trials, as we just put in a fresh experiment.
        self.assertEqual(len(ax_client.experiment.trials), 0)

        # Log a trial
        parameters, trial_index = ax_client.get_next_trial()
        self.assertIn("x1", parameters.keys())
        self.assertIn("x2", parameters.keys())
        ax_client.complete_trial(trial_index=trial_index,
                                 raw_data=branin(*parameters.values()))
Example #30
class AxSearchJob(AutoSearchJob):
    """Job for hyperparameter search using [ax](https://ax.dev/)."""
    def __init__(self, config: Config, dataset, parent_job=None):
        super().__init__(config, dataset, parent_job)
        self.num_trials = self.config.get("ax_search.num_trials")
        self.num_sobol_trials = self.config.get("ax_search.num_sobol_trials")
        self.ax_client: AxClient = None

        if self.__class__ == AxSearchJob:
            for f in Job.job_created_hooks:
                f(self)

    # Overridden such that instances of search job can be pickled to workers
    def __getstate__(self):
        state = super(AxSearchJob, self).__getstate__()
        del state["ax_client"]
        return state

    def _prepare(self):
        super()._prepare()
        if self.num_sobol_trials > 0:
            # BEGIN: from /ax/service/utils/dispatch.py
            generation_strategy = GenerationStrategy(
                name="Sobol+GPEI",
                steps=[
                    GenerationStep(
                        model=Models.SOBOL,
                        num_trials=self.num_sobol_trials,
                        min_trials_observed=ceil(self.num_sobol_trials / 2),
                        enforce_num_trials=True,
                        model_kwargs={
                            "seed": self.config.get("ax_search.sobol_seed")
                        },
                    ),
                    GenerationStep(model=Models.GPEI,
                                   num_trials=-1,
                                   max_parallelism=3),
                ],
            )
            # END: from /ax/service/utils/dispatch.py

            self.ax_client = AxClient(generation_strategy=generation_strategy)
            choose_generation_strategy_kwargs = dict()
        else:
            self.ax_client = AxClient()
            # set random_seed that will be used by auto created sobol search from ax
            # note that here the argument is called "random_seed" not "seed"
            choose_generation_strategy_kwargs = {
                "random_seed": self.config.get("ax_search.sobol_seed")
            }
        self.ax_client.create_experiment(
            name=self.job_id,
            parameters=self.config.get("ax_search.parameters"),
            objective_name="metric_value",
            minimize=not self.config.get("valid.metric_max"),
            parameter_constraints=self.config.get(
                "ax_search.parameter_constraints"),
            choose_generation_strategy_kwargs=choose_generation_strategy_kwargs,
        )
        self.config.log("ax search initialized with {}".format(
            self.ax_client.generation_strategy))

        # Make sure sobol models are resumed correctly
        if self.ax_client.generation_strategy._curr.model == Models.SOBOL:

            self.ax_client.generation_strategy._set_current_model(
                experiment=self.ax_client.experiment, data=None)

            # Regenerate and drop SOBOL arms already generated. Since we fixed the seed,
            # we will skip exactly the arms already generated in the job being resumed.
            num_generated = len(self.parameters)
            if num_generated > 0:
                num_sobol_generated = min(
                    self.ax_client.generation_strategy._curr.num_trials,
                    num_generated)
                for i in range(num_sobol_generated):
                    generator_run = self.ax_client.generation_strategy.gen(
                        experiment=self.ax_client.experiment)
                    # self.config.log("Skipped parameters: {}".format(generator_run.arms))
                self.config.log(
                    "Skipped {} of {} Sobol trials due to prior data.".format(
                        num_sobol_generated,
                        self.ax_client.generation_strategy._curr.num_trials,
                    ))

    def register_trial(self, parameters=None):
        trial_id = None
        try:
            if parameters is None:
                parameters, trial_id = self.ax_client.get_next_trial()
            else:
                _, trial_id = self.ax_client.attach_trial(parameters)
        except Exception as e:
            self.config.log(
                "Cannot generate trial parameters. Will try again after a " +
                "running trial has completed. message was: {}".format(e))
        return parameters, trial_id

    def register_trial_result(self, trial_id, parameters, trace_entry):
        if trace_entry is None:
            self.ax_client.log_trial_failure(trial_index=trial_id)
        else:
            self.ax_client.complete_trial(trial_index=trial_id,
                                          raw_data=trace_entry["metric_value"])

    def get_best_parameters(self):
        best_parameters, values = self.ax_client.get_best_parameters()
        return best_parameters, float(values[0]["metric_value"])