def test_raw_data_format(self):
    """Complete six trials with a bare ``(mean, sem)`` tuple.

    Then check that passing a *list* of such tuples to ``complete_trial`` raises
    a ``ValueError`` whose message mentions an invalid raw-data type.
    """
    search_space = [
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
    ]
    client = AxClient()
    client.create_experiment(parameters=search_space, minimize=True)
    for _ in range(6):
        arm_params, index = client.get_next_trial()
        x1, x2 = arm_params.get("x1"), arm_params.get("x2")
        client.complete_trial(index, raw_data=(branin(x1, x2), 0.0))
    # Reuses the last trial's index and coordinates from the loop above.
    with self.assertRaisesRegex(ValueError, "Raw data has an invalid type"):
        value = branin(x1, x2)
        client.complete_trial(index, raw_data=[(value, 0.0), (value, 0.0)])
def test_raw_data_format_with_fidelities(self):
    """Complete trials with raw data given as a list of ``(features, metrics)`` pairs.

    Each trial reports the ``objective`` metric twice: once with ``y``
    overridden to ``y / 2`` and once at the generated ``y``.
    """
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 1.0]},
        ],
        minimize=True,
    )
    for _ in range(6):
        parameterization, trial_index = ax_client.get_next_trial()
        x = parameterization.get("x")
        y = parameterization.get("y")
        observations = [
            ({"y": y_value}, {"objective": (branin(x, y_value), 0.0)})
            for y_value in (y / 2.0, y)
        ]
        ax_client.complete_trial(trial_index, raw_data=observations)
def test_fixed_random_seed_reproducibility(self):
    """Two clients seeded identically produce identical trial parameterizations.

    Both clients are created with ``random_seed=239`` and driven through the
    same five generate/complete steps on the Branin function.
    """

    def _run_seeded_client():
        # Create a fresh client with the fixed seed, run five Branin trials
        # and return the arm parameters of every trial, in the order
        # `experiment.trials` yields them.
        ax_client = AxClient(random_seed=239)
        ax_client.create_experiment(
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ]
        )
        for _ in range(5):
            params, idx = ax_client.get_next_trial()
            ax_client.complete_trial(idx, branin(params.get("x1"), params.get("x2")))
        return [t.arm.parameters for t in ax_client.experiment.trials.values()]

    trial_parameters_1 = _run_seeded_client()
    trial_parameters_2 = _run_seeded_client()
    self.assertEqual(trial_parameters_1, trial_parameters_2)
def _branin_evaluation_function(parameterization, weight=None): if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") x1, x2 = parameterization["x1"], parameterization["x2"] return { "branin": (branin(x1, x2), 0.0), "constrained_metric": (-branin(x1, x2), 0.0), }
def test_update_running_trial_with_intermediate_data(self):
    """Report intermediate data for a running trial, then complete it.

    While only intermediate updates have been sent, ``fetch_data()`` is
    expected to return no rows; after completion it should return the 3 rows
    of the final report. A client created with
    ``support_intermediate_data=False`` must reject intermediate updates
    with ``ValueError``.
    """
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 1.0]},
        ],
        minimize=True,
        support_intermediate_data=True,
    )
    parameterization, trial_index = ax_client.get_next_trial()
    x, y = parameterization.get("x"), parameterization.get("y")
    # Send two intermediate updates, then complete the trial on the third step.
    # Each step t reports progressions p_t = 0..t.
    for t in range(3):
        raw_data = [
            ({"t": p_t}, {"objective": (branin(x, y) + t, 0.0)})
            for p_t in range(t + 1)
        ]
        if t < 2:
            ax_client.update_running_trial_with_intermediate_data(
                trial_index, raw_data=raw_data
            )
        else:
            ax_client.complete_trial(trial_index, raw_data=raw_data)
        current_data = ax_client.experiment.fetch_data().df
        self.assertEqual(len(current_data), 0 if t < 2 else 3)

    no_intermediate_data_ax_client = AxClient()
    no_intermediate_data_ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 1.0]},
        ],
        minimize=True,
        support_intermediate_data=False,
    )
    parameterization, trial_index = no_intermediate_data_ax_client.get_next_trial()
    # Use this client's own trial rather than values leaked from the loop above.
    x, y = parameterization.get("x"), parameterization.get("y")
    with self.assertRaises(ValueError):
        no_intermediate_data_ax_client.update_running_trial_with_intermediate_data(
            trial_index,
            raw_data=[
                ({"t": p_t}, {"objective": (branin(x, y), 0.0)})
                for p_t in range(3)
            ],
        )
def test_storage_error_handling(self, mock_save_fails):
    """With ``suppress_storage_errors=True``, AxClient must not raise on storage errors.

    ``mock_save_fails`` is presumably injected by a ``mock.patch`` decorator
    (not visible here) that makes DB saves fail. Confirm against the decorator.
    """
    init_test_engine_and_session_factory(force_init=True)
    sqa_config = SQAConfig()
    db_settings = DBSettings(
        encoder=Encoder(config=sqa_config),
        decoder=Decoder(config=sqa_config),
    )
    client = AxClient(db_settings=db_settings, suppress_storage_errors=True)
    client.create_experiment(
        name="test_experiment",
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    for _ in range(3):
        arm_params, index = client.get_next_trial()
        client.complete_trial(trial_index=index, raw_data=branin(*arm_params.values()))
def test_default_generation_strategy(self) -> None:
    """Without an explicit GenerationStrategy, AxClient uses Sobol then GPEI.

    Six Branin trials are then run through the default strategy.
    """
    client = AxClient()
    client.create_experiment(
        name="test_branin",
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        objective_name="branin",
        minimize=True,
    )
    step_models = [step.model for step in client.generation_strategy._steps]
    self.assertEqual(step_models, [Models.SOBOL, Models.GPEI])
    for _ in range(6):
        arm_params, index = client.get_next_trial()
        x1, x2 = arm_params.get("x1"), arm_params.get("x2")
        client.complete_trial(index, raw_data={"branin": (branin(x1, x2), 0.0)})
def _branin_evaluation_function_with_unknown_sem(parameterization, weight=None): if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") x1, x2 = parameterization["x1"], parameterization["x2"] return (branin(x1, x2), None)
def test_interruption(self) -> None:
    """Round-trip the client through a JSON snapshot after every completed trial.

    Each restored client must be a new object, hold every trial created so
    far, and have no non-terminal trials.
    """
    ax_client = AxClient()
    ax_client.create_experiment(
        name="test",
        parameters=[
            # pyre-fixme[6]: expected union that should include
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        objective_name="branin",
        minimize=True,
    )

    def _all_trials_terminal(client) -> bool:
        # True when every trial of `client` has reached a terminal status.
        return all(t.status.is_terminal for t in client.experiment.trials.values())

    for expected_trial_count in range(1, 7):
        parameterization, trial_index = ax_client.get_next_trial()
        # The trial just generated is not finished yet.
        self.assertFalse(_all_trials_terminal(ax_client))
        x1 = checked_cast(float, parameterization.get("x1"))
        x2 = checked_cast(float, parameterization.get("x2"))
        ax_client.complete_trial(
            trial_index, raw_data=checked_cast(float, branin(x1, x2))
        )
        previous_client = ax_client
        ax_client = AxClient.from_json_snapshot(ax_client.to_json_snapshot())
        self.assertEqual(len(ax_client.experiment.trials.keys()), expected_trial_count)
        self.assertIsNot(ax_client, previous_client)
        # After completion and restore, nothing should still be pending.
        self.assertTrue(_all_trials_terminal(ax_client))
def test_init_position_saved(self):
    """Check that Sobol's sequence position survives a JSON snapshot round-trip.

    This is repeated four times. Snapshot the client, generate a trial,
    restore the snapshot, and regenerate. The regenerated trial must match
    the original one, and its generator run must record
    ``init_position == idx + 1`` in its model kwargs.
    """
    ax_client = AxClient(random_seed=239)
    ax_client.create_experiment(
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        name="sobol_init_position_test",
    )
    for _ in range(4):
        # For each generated trial, snapshot the client before generating it,
        # then recreate client, regenerate the trial and compare the trial
        # generated before and after snapshotting. If the state of Sobol is
        # recorded correctly, the newly generated trial will be the same as
        # the one generated before the snapshotting.
        serialized = ax_client.to_json_snapshot()
        params, idx = ax_client.get_next_trial()
        # Drop the client that generated `params`; continue with one restored
        # from the snapshot taken before that generation.
        ax_client = AxClient.from_json_snapshot(serialized)
        with self.subTest(ax=ax_client, params=params, idx=idx):
            new_params, new_idx = ax_client.get_next_trial()
            self.assertEqual(params, new_params)
            self.assertEqual(idx, new_idx)
            self.assertEqual(
                ax_client.experiment.trials[idx]._generator_run._model_kwargs[
                    "init_position"
                ],
                idx + 1,
            )
        # Complete the regenerated trial (same index `idx`, as asserted above)
        # on the restored client before the next snapshot is taken.
        ax_client.complete_trial(idx, branin(params.get("x1"), params.get("x2")))
def test_overwrite(self):
    """Re-creating an experiment under an existing name needs the overwrite flag.

    Without ``overwrite_existing_experiment=True`` it raises ``ValueError``.
    With the flag, the client starts over on a fresh experiment that has no
    trials and the new ``x1``/``x2`` parameters.
    """
    init_test_engine_and_session_factory(force_init=True)

    def _xy_parameters():
        # Fresh copy of the x/y search space for each create_experiment call.
        return [
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ]

    client = AxClient()
    client.create_experiment(
        name="test_experiment", parameters=_xy_parameters(), minimize=True
    )
    # Log a trial
    first_params, first_index = client.get_next_trial()
    client.complete_trial(
        trial_index=first_index, raw_data=branin(*first_params.values())
    )

    with self.assertRaises(ValueError):
        # Overwriting existing experiment.
        client.create_experiment(
            name="test_experiment", parameters=_xy_parameters(), minimize=True
        )

    # Overwriting existing experiment with overwrite flag.
    client.create_experiment(
        name="test_experiment",
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        overwrite_existing_experiment=True,
    )
    # There should be no trials, as we just put in a fresh experiment.
    self.assertEqual(len(client.experiment.trials), 0)

    # Log a trial
    second_params, second_index = client.get_next_trial()
    for expected_name in ("x1", "x2"):
        self.assertIn(expected_name, second_params.keys())
    client.complete_trial(
        trial_index=second_index, raw_data=branin(*second_params.values())
    )
def test_default_generation_strategy(self) -> None:
    """Test that Sobol+GPEI is used if no GenerationStrategy is provided.

    Also checks the following:
    - the optimization trace and the x1/x2 contour plot are unavailable until
      enough trials are completed;
    - the trials data frame contains parameter and metric columns;
    - an all-choice search space gets a Sobol-only strategy.
    """
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            # pyre-fixme[6]: expected union that should include
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        objective_name="branin",
        minimize=True,
    )
    self.assertEqual(
        [s.model for s in not_none(ax_client.generation_strategy)._steps],
        [Models.SOBOL, Models.GPEI],
    )
    # No trials exist yet, so the optimization trace is expected to fail.
    with self.assertRaisesRegex(ValueError, ".* no trials."):
        ax_client.get_optimization_trace(objective_optimum=branin.fmin)
    for i in range(6):
        parameterization, trial_index = ax_client.get_next_trial()
        x1, x2 = parameterization.get("x1"), parameterization.get("x2")
        ax_client.complete_trial(
            trial_index,
            raw_data={
                "branin": (
                    checked_cast(
                        float,
                        branin(checked_cast(float, x1), checked_cast(float, x2)),
                    ),
                    0.0,
                )
            },
            sample_size=i,
        )
        # With only i + 1 <= 5 completed trials, the contour plot is expected
        # to be unavailable.
        if i < 5:
            with self.assertRaisesRegex(ValueError, "Could not obtain contour"):
                ax_client.get_contour_plot(param_x="x1", param_y="x2")
    # After all six trials, the trace and the default contour plot are obtainable.
    ax_client.get_optimization_trace(objective_optimum=branin.fmin)
    ax_client.get_contour_plot()
    self.assertIn("x1", ax_client.get_trials_data_frame())
    self.assertIn("x2", ax_client.get_trials_data_frame())
    self.assertIn("branin", ax_client.get_trials_data_frame())
    self.assertEqual(len(ax_client.get_trials_data_frame()), 6)

    # Test that Sobol is chosen when all parameters are choice.
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            # pyre-fixme[6]: expected union that should include
            {"name": "x1", "type": "choice", "values": [1, 2, 3]},
            {"name": "x2", "type": "choice", "values": [1, 2, 3]},
        ]
    )
    self.assertEqual(
        [s.model for s in not_none(ax_client.generation_strategy)._steps],
        [Models.SOBOL],
    )
    # NOTE(review): the -1 entries presumably mean "unbounded" trial count and
    # parallelism for the single Sobol step. Confirm against the
    # get_recommended_max_parallelism docs.
    self.assertEqual(ax_client.get_recommended_max_parallelism(), [(-1, -1)])
    self.assertTrue(ax_client.get_trials_data_frame().empty)
def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None:
    """For a purely continuous search space, the default strategy is Sobol then GPEI.

    After six completed trials the current model is GPEI, and the trace,
    contour-plot and feature-importance helpers run without error.
    ``_a`` .. ``_d`` are presumably mocks injected by patch decorators that
    are not visible here.
    """
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            # pyre-fixme[6]: expected union that should include
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        objective_name="a",
        minimize=True,
    )
    default_steps = not_none(ax_client.generation_strategy)._steps
    self.assertEqual(
        [step.model for step in default_steps], [Models.SOBOL, Models.GPEI]
    )
    # Nothing has been evaluated yet, so there is no trace to compute.
    with self.assertRaisesRegex(ValueError, ".* no trials"):
        ax_client.get_optimization_trace(objective_optimum=branin.fmin)
    for sample_size in range(6):
        parameterization, trial_index = ax_client.get_next_trial()
        x = checked_cast(float, parameterization.get("x"))
        y = checked_cast(float, parameterization.get("y"))
        mean = checked_cast(float, branin(x, y))
        ax_client.complete_trial(
            trial_index, raw_data={"a": (mean, 0.0)}, sample_size=sample_size
        )
    self.assertEqual(ax_client.generation_strategy.model._model_key, "GPEI")
    ax_client.get_optimization_trace(objective_optimum=branin.fmin)
    ax_client.get_contour_plot()
    ax_client.get_feature_importances()
    trials_df = ax_client.get_trials_data_frame()
    for column in ("x", "y", "a"):
        self.assertIn(column, trials_df)
    self.assertEqual(len(trials_df), 6)
def test_sqa_storage(self):
    """Round-trip an experiment and its generation strategy through SQA storage.

    Also checks that an experiment already stored in the DB cannot be
    re-created under the same name, with or without
    ``overwrite_existing_experiment=True``.
    """
    init_test_engine_and_session_factory(force_init=True)
    config = SQAConfig()
    encoder = Encoder(config=config)
    decoder = Decoder(config=config)
    db_settings = DBSettings(encoder=encoder, decoder=decoder)
    ax_client = AxClient(db_settings=db_settings)
    ax_client.create_experiment(
        name="test_experiment",
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    for _ in range(5):
        parameters, trial_index = ax_client.get_next_trial()
        ax_client.complete_trial(
            trial_index=trial_index, raw_data=branin(*parameters.values())
        )
    gs = ax_client.generation_strategy
    # Load the stored experiment into a fresh client.
    ax_client = AxClient(db_settings=db_settings)
    ax_client.load_experiment_from_database("test_experiment")
    # Trial #4 was completed after the last time the generation strategy
    # generated candidates, so the pre-save generation strategy was not
    # "aware" of the completion of trial #4. The post-restoration generation
    # strategy is aware of it, however, since it gets restored with the most
    # up-to-date experiment data. So add trial #4 to the seen completed
    # trials of the pre-storage GS; otherwise the equality check would fail.
    gs._seen_trial_indices_by_status[TrialStatus.COMPLETED].add(4)
    self.assertEqual(gs, ax_client.generation_strategy)
    with self.assertRaises(ValueError):
        # Overwriting existing experiment.
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
    with self.assertRaises(ValueError):
        # Overwriting existing experiment with overwrite flag with present
        # DB settings. This should fail as we no longer allow overwriting
        # experiments stored in the DB.
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}],
            overwrite_existing_experiment=True,
        )
    # Original experiment should still be in DB and not have been overwritten.
    self.assertEqual(len(ax_client.experiment.trials), 5)
def test_raw_data_format(self):
    """Complete trials with a ``(mean, sem)`` tuple.

    Then check that ``update_trial_data`` rejects a plain string as raw data
    with a ``ValueError`` about an invalid type.
    """
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    for _ in range(6):
        arm_params, index = client.get_next_trial()
        mean = branin(arm_params.get("x"), arm_params.get("y"))
        client.complete_trial(index, raw_data=(mean, 0.0))
    # `index` still refers to the last trial completed in the loop.
    with self.assertRaisesRegex(ValueError, "Raw data has an invalid type"):
        client.update_trial_data(index, raw_data="invalid_data")
def test_sqa_storage(self):
    """A generation strategy reloaded from SQA storage equals the pre-save one.

    Re-creating the stored experiment under the same name must fail, with or
    without ``overwrite_existing_experiment=True``, leaving its 5 trials intact.
    """
    init_test_engine_and_session_factory(force_init=True)
    sqa_config = SQAConfig()
    db_settings = DBSettings(
        encoder=Encoder(config=sqa_config),
        decoder=Decoder(config=sqa_config),
    )

    def _xy_parameters():
        # Fresh copy of the x/y search space for each create_experiment call.
        return [
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ]

    ax_client = AxClient(db_settings=db_settings)
    ax_client.create_experiment(
        name="test_experiment", parameters=_xy_parameters(), minimize=True
    )
    for _ in range(5):
        arm_params, index = ax_client.get_next_trial()
        ax_client.complete_trial(
            trial_index=index, raw_data=branin(*arm_params.values())
        )
    saved_gs = ax_client.generation_strategy

    ax_client = AxClient(db_settings=db_settings)
    ax_client.load_experiment_from_database("test_experiment")
    self.assertEqual(saved_gs, ax_client.generation_strategy)

    with self.assertRaises(ValueError):
        # Overwriting existing experiment.
        ax_client.create_experiment(
            name="test_experiment", parameters=_xy_parameters(), minimize=True
        )
    with self.assertRaises(ValueError):
        # Overwriting existing experiment with overwrite flag with present
        # DB settings. This should fail as we no longer allow overwriting
        # experiments stored in the DB.
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}],
            overwrite_existing_experiment=True,
        )
    # Original experiment should still be in DB and not have been overwritten.
    self.assertEqual(len(ax_client.experiment.trials), 5)
def test_sqa_storage(self):
    """Round-trip a generation strategy through SQA storage, then test overwriting.

    Re-creating the experiment under the same name fails without
    ``overwrite_existing_experiment`` and yields a fresh, trial-less
    experiment with it.
    """
    init_test_engine_and_session_factory(force_init=True)
    sqa_config = SQAConfig()
    db_settings = DBSettings(
        encoder=Encoder(config=sqa_config),
        decoder=Decoder(config=sqa_config),
    )

    def _x1_x2_parameters():
        # Fresh copy of the x1/x2 search space for each create_experiment call.
        return [
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ]

    ax_client = AxClient(db_settings=db_settings)
    ax_client.create_experiment(
        name="test_experiment", parameters=_x1_x2_parameters(), minimize=True
    )
    for _ in range(5):
        arm_params, index = ax_client.get_next_trial()
        ax_client.complete_trial(
            trial_index=index, raw_data=branin(*arm_params.values())
        )
    stored_gs = ax_client.generation_strategy

    ax_client = AxClient(db_settings=db_settings)
    ax_client.load_experiment_from_database("test_experiment")
    self.assertEqual(stored_gs, ax_client.generation_strategy)

    with self.assertRaises(ValueError):
        # Overwriting existing experiment.
        ax_client.create_experiment(
            name="test_experiment", parameters=_x1_x2_parameters(), minimize=True
        )
    # Overwriting existing experiment with overwrite flag.
    ax_client.create_experiment(
        name="test_experiment",
        parameters=[{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]}],
        overwrite_existing_experiment=True,
    )
    # There should be no trials, as we just put in a fresh experiment.
    self.assertEqual(len(ax_client.experiment.trials), 0)
def run_trials_using_recommended_parallelism(
    ax_client: AxClient,
    recommended_parallelism: List[Tuple[int, int]],
    total_trials: int,
) -> int:
    """Run Branin trials in batches following a recommended-parallelism schedule.

    Within a batch, all trials are generated first and only then completed,
    last generated first.

    Args:
        ax_client: Client whose experiment has range parameters ``x`` and ``y``.
        recommended_parallelism: ``(num_trials, max_parallelism)`` pairs, e.g.
            as returned by ``AxClient.get_recommended_max_parallelism``.
            ``num_trials == -1`` means "all trials still remaining".
            ``max_parallelism == -1`` means unlimited parallelism, so the
            whole step runs as a single batch.
        total_trials: Total number of trials to run across all steps.

    Returns:
        The number of trials left unrun; 0 when the schedule covered
        ``total_trials``.

    Raises:
        ValueError: If a parallelism setting is neither positive nor -1.
    """
    remaining_trials = total_trials
    for num_trials, parallelism_setting in recommended_parallelism:
        # A step never runs more trials than it asks for, nor more than remain
        # overall. Capping per step keeps a batch from spilling into the next
        # generation step's budget.
        if num_trials == -1:
            trials_in_step = remaining_trials
        else:
            trials_in_step = min(num_trials, remaining_trials)
        if parallelism_setting == -1:
            # Unlimited parallelism: the whole step is one batch.
            batch_limit = trials_in_step
        elif parallelism_setting > 0:
            batch_limit = parallelism_setting
        else:
            raise ValueError(
                f"Invalid parallelism setting {parallelism_setting}; expected a "
                "positive integer or -1 (unlimited)."
            )
        while trials_in_step > 0:
            batch_size = min(batch_limit, trials_in_step)
            # Generate the whole batch before completing any of it, so these
            # trials are in flight at the same time.
            in_flight_trials = [
                ax_client.get_next_trial() for _ in range(batch_size)
            ]
            trials_in_step -= batch_size
            remaining_trials -= batch_size
            for params, idx in reversed(in_flight_trials):
                ax_client.complete_trial(idx, branin(params["x"], params["y"]))
    # If all went well and no errors were raised, remaining_trials should be 0.
    return remaining_trials
def test_sqa_storage(self):
    """A generation strategy reloaded from SQA storage equals the pre-save one."""
    init_test_engine_and_session_factory(force_init=True)
    sqa_config = SQAConfig()
    db_settings = DBSettings(
        encoder=Encoder(config=sqa_config),
        decoder=Decoder(config=sqa_config),
    )
    writer = AxClient(db_settings=db_settings)
    writer.create_experiment(
        name="test_experiment",
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    for _ in range(5):
        arm_params, index = writer.get_next_trial()
        writer.complete_trial(trial_index=index, raw_data=branin(*arm_params.values()))
    saved_gs = writer.generation_strategy
    reader = AxClient(db_settings=db_settings)
    reader.load_experiment_from_database("test_experiment")
    self.assertEqual(saved_gs, reader.generation_strategy)