def test_fail_on_batch(self):
    """Completing a batch trial through AxClient raises NotImplementedError."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    # Build a batch trial directly on the underlying experiment; the Service
    # API only supports one-arm trials.
    arms = [Arm(parameters={"x": 0, "y": 1}) for _ in range(2)]
    batch_trial = client.experiment.new_batch_trial(
        generator_run=GeneratorRun(arms=arms)
    )
    with self.assertRaises(NotImplementedError):
        client.complete_trial(batch_trial.index, 0)
def test_raw_data_format_with_fidelities(self):
    """Fidelity-annotated raw data (list of (fidelity-dict, metrics-dict)
    pairs) is accepted by `complete_trial`."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
        ],
        minimize=True,
    )
    for _ in range(6):
        suggested, index = client.get_next_trial()
        x1 = suggested.get("x1")
        x2 = suggested.get("x2")
        # Report the objective at two fidelities: half of x2, then x2 itself.
        evaluations = [
            ({"x2": x2 / 2.0}, {"objective": (branin(x1, x2 / 2.0), 0.0)}),
            ({"x2": x2}, {"objective": (branin(x1, x2), 0.0)}),
        ]
        client.complete_trial(index, raw_data=evaluations)
def sweep_over_batches(
    self,
    ax_client: AxClient,
    batch_of_trials: BatchOfTrialType,
) -> None:
    """Launch all trials in ``batch_of_trials`` in chunks of at most
    ``self.max_batch_size`` and report each job's return value back to Ax.

    Args:
        ax_client: Client used to record each trial's result.
        batch_of_trials: Trials (override lists plus trial indices) to run.
    """
    # Invariants established by setup before sweeping starts.
    assert self.launcher is not None
    assert self.job_idx is not None

    for batch in self.chunks(batch_of_trials, self.max_batch_size):
        overrides = [trial.overrides for trial in batch]
        self.validate_batch_is_legal(overrides)
        rets = self.launcher.launch(
            job_overrides=overrides, initial_job_idx=self.job_idx
        )
        # Advance the global job counter so the next chunk gets fresh indices.
        self.job_idx += len(rets)
        # Pair each launched job with its originating trial (idiomatic `zip`
        # instead of indexing by position with range(len(...))).
        for trial, ret in zip(batch, rets):
            ax_client.complete_trial(
                trial_index=trial.trial_index, raw_data=ret.return_value
            )
def test_sqa_storage(self):
    """End-to-end SQLAlchemy storage round-trip: run trials, reload the
    experiment from the DB, then check experiment-overwrite semantics."""
    init_test_engine_and_session_factory(force_init=True)
    config = SQAConfig()
    encoder = Encoder(config=config)
    decoder = Decoder(config=config)
    db_settings = DBSettings(encoder=encoder, decoder=decoder)
    ax_client = AxClient(db_settings=db_settings)
    ax_client.create_experiment(
        name="test_experiment",
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    for _ in range(5):
        parameters, trial_index = ax_client.get_next_trial()
        ax_client.complete_trial(
            trial_index=trial_index, raw_data=branin(*parameters.values())
        )
    gs = ax_client.generation_strategy
    # A fresh client backed by the same DB should restore the experiment and
    # an equal generation strategy.
    ax_client = AxClient(db_settings=db_settings)
    ax_client.load_experiment_from_database("test_experiment")
    self.assertEqual(gs, ax_client.generation_strategy)
    with self.assertRaises(ValueError):
        # Overwriting existing experiment.
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
            ],
            minimize=True,
        )
    # Overwriting existing experiment with overwrite flag.
    ax_client.create_experiment(
        name="test_experiment",
        parameters=[{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]}],
        overwrite_existing_experiment=True,
    )
    # There should be no trials, as we just put in a fresh experiment.
    self.assertEqual(len(ax_client.experiment.trials), 0)
def test_fixed_random_seed_reproducibility(self):
    """Two clients created with the same random seed generate identical
    trial parameterizations."""

    def _run_five_trials():
        # Fresh client, fixed seed, five Sobol trials; return all arms.
        client = AxClient(random_seed=239)
        client.create_experiment(
            parameters=[
                {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
                {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
            ],
        )
        for _ in range(5):
            suggested, index = client.get_next_trial()
            client.complete_trial(
                index, branin(suggested.get("x"), suggested.get("y"))
            )
        return [t.arm.parameters for t in client.experiment.trials.values()]

    self.assertEqual(_run_five_trials(), _run_five_trials())
def test_attach_trial_ttl_seconds(self):
    """Attached trials with a TTL fail automatically once the TTL elapses,
    after which they can no longer accept data; a TTL trial completed in
    time behaves like any other trial."""
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    params, idx = ax_client.attach_trial(
        parameters={"x": 0.0, "y": 1.0}, ttl_seconds=1
    )
    self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
    time.sleep(1)  # Wait for TTL to elapse.
    self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
    # Also make sure we can no longer complete the trial as it is failed.
    with self.assertRaisesRegex(
        ValueError, ".* has been marked FAILED, so it no longer expects data."
    ):
        ax_client.complete_trial(trial_index=idx, raw_data=5)
    # A second TTL trial, completed before expiry, is used normally.
    params2, idx2 = ax_client.attach_trial(
        parameters={"x": 0.0, "y": 1.0}, ttl_seconds=1
    )
    ax_client.complete_trial(trial_index=idx2, raw_data=5)
    self.assertEqual(ax_client.get_best_parameters()[0], params2)
    self.assertEqual(
        ax_client.get_trial_parameters(trial_index=idx2), {"x": 0, "y": 1}
    )
def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None:
    """Test that Sobol+GPEI is used if no GenerationStrategy is provided."""
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            # pyre-fixme[6]: expected union that should include
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        objective_name="a",
        minimize=True,
    )
    # Default strategy for an all-continuous search space: Sobol then GPEI.
    self.assertEqual(
        [s.model for s in not_none(ax_client.generation_strategy)._steps],
        [Models.SOBOL, Models.GPEI],
    )
    # No data yet, so the optimization trace is unavailable.
    with self.assertRaisesRegex(ValueError, ".* no trials"):
        ax_client.get_optimization_trace(objective_optimum=branin.fmin)
    for i in range(6):
        parameterization, trial_index = ax_client.get_next_trial()
        x, y = parameterization.get("x"), parameterization.get("y")
        ax_client.complete_trial(
            trial_index,
            raw_data={
                "a": (
                    checked_cast(
                        float,
                        branin(checked_cast(float, x), checked_cast(float, y)),
                    ),
                    0.0,
                )
            },
            sample_size=i,
        )
    # After six completed trials, the strategy should have moved on to GPEI.
    self.assertEqual(ax_client.generation_strategy.model._model_key, "GPEI")
    ax_client.get_optimization_trace(objective_optimum=branin.fmin)
    ax_client.get_contour_plot()
    ax_client.get_feature_importances()
    trials_df = ax_client.get_trials_data_frame()
    self.assertIn("x", trials_df)
    self.assertIn("y", trials_df)
    self.assertIn("a", trials_df)
    self.assertEqual(len(trials_df), 6)
def test_attach_trial(self):
    """A manually attached trial can be completed and become the best point."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    attached_params, attached_idx = client.attach_trial(
        parameters={"x1": 0, "x2": 1}
    )
    client.complete_trial(trial_index=attached_idx, raw_data=5)
    self.assertEqual(client.get_best_parameters()[0], attached_params)
def test_attach_trial_numpy(self):
    """Numpy scalar raw data (np.int32) is accepted by `complete_trial`."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    attached_params, attached_idx = client.attach_trial(
        parameters={"x": 0.0, "y": 1.0}
    )
    client.complete_trial(trial_index=attached_idx, raw_data=np.int32(5))
    self.assertEqual(client.get_best_parameters()[0], attached_params)
def test_raw_data_format_with_map_results(self):
    """Map-style raw data (intermediate observations keyed by a map key) is
    accepted when the experiment supports intermediate data."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 1.0]},
        ],
        minimize=True,
        support_intermediate_data=True,
    )
    for _ in range(6):
        suggested, index = client.get_next_trial()
        x = suggested.get("x")
        y = suggested.get("y")
        # Two progressive observations, keyed on "y" as the map key.
        client.complete_trial(
            index,
            raw_data=[
                ({"y": y / 2.0}, {"objective": (branin(x, y / 2.0), 0.0)}),
                ({"y": y}, {"objective": (branin(x, y), 0.0)}),
            ],
        )
def test_start_and_end_time_in_trial_completion(self):
    """Start/end timestamps passed via metadata surface in fetched data."""
    start_time = current_timestamp_in_millis()
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    _, index = client.get_next_trial()
    completion_metadata = {
        "start_time": start_time,
        "end_time": current_timestamp_in_millis(),
    }
    client.complete_trial(
        trial_index=index, raw_data=1.0, metadata=completion_metadata
    )
    df = client.experiment.fetch_data().df
    self.assertGreater(df["end_time"][0], df["start_time"][0])
def test_attach_trial_and_get_trial_parameters(self):
    """Attached parameters round-trip through `get_trial_parameters`;
    unknown indices and malformed attach arguments raise ValueError."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    attached_params, attached_idx = client.attach_trial(
        parameters={"x": 0.0, "y": 1.0}
    )
    client.complete_trial(trial_index=attached_idx, raw_data=5)
    self.assertEqual(client.get_best_parameters()[0], attached_params)
    self.assertEqual(
        client.get_trial_parameters(trial_index=attached_idx), {"x": 0, "y": 1}
    )
    with self.assertRaises(ValueError):
        # No trial #10 in experiment.
        client.get_trial_parameters(trial_index=10)
    with self.assertRaisesRegex(ValueError, ".* is of type"):
        client.attach_trial({"x": 1, "y": 2})
def test_abandon_trial(self):
    """Abandoned trials add no data, and terminal trials can't be abandoned."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    # An abandoned trial adds no data.
    _, first_idx = client.get_next_trial()
    client.abandon_trial(trial_index=first_idx)
    self.assertEqual(len(client.experiment.fetch_data().df.index), 0)
    # Can't update a completed trial.
    _, second_idx = client.get_next_trial()
    client.complete_trial(
        trial_index=second_idx, raw_data={"objective": (0, 0.0)}
    )
    with self.assertRaisesRegex(ValueError, ".* in a terminal state."):
        client.abandon_trial(trial_index=second_idx)
def test_trial_completion(self):
    """Trials can be completed with dict, tuple, or scalar raw data; the best
    point tracks the lowest observed objective (minimization)."""
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    params, idx = ax_client.get_next_trial()
    # Dict raw-data format: {metric_name: (mean, sem)}.
    ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
    self.assertEqual(ax_client.get_best_parameters()[0], params)
    params2, idx2 = ax_client.get_next_trial()
    # Tuple raw-data format: (mean, sem) for the inferred objective.
    ax_client.complete_trial(trial_index=idx2, raw_data=(-1, 0.0))
    self.assertEqual(ax_client.get_best_parameters()[0], params2)
    params3, idx3 = ax_client.get_next_trial()
    # Scalar raw data plus metadata; metadata lands in the trial's run_metadata.
    ax_client.complete_trial(
        trial_index=idx3, raw_data=-2, metadata={"dummy": "test"}
    )
    self.assertEqual(ax_client.get_best_parameters()[0], params3)
    self.assertEqual(
        ax_client.experiment.trials.get(2).run_metadata.get("dummy"), "test"
    )
    best_trial_values = ax_client.get_best_parameters()[1]
    self.assertEqual(best_trial_values[0], {"objective": -2.0})
    # The objective's covariance entry is reported as NaN here.
    self.assertTrue(math.isnan(best_trial_values[1]["objective"]["objective"]))
def test_interruption(self) -> None:
    """Client state survives a JSON snapshot/restore mid-optimization."""
    ax_client = AxClient()
    ax_client.create_experiment(
        name="test",
        parameters=[
            # pyre-fixme[6]: expected union that should include
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        objective_name="branin",
        minimize=True,
    )
    for i in range(6):
        parameterization, trial_index = ax_client.get_next_trial()
        self.assertFalse(  # There should be non-complete trials.
            all(t.status.is_terminal for t in ax_client.experiment.trials.values())
        )
        x, y = parameterization.get("x"), parameterization.get("y")
        ax_client.complete_trial(
            trial_index,
            raw_data=checked_cast(
                float, branin(checked_cast(float, x), checked_cast(float, y))
            ),
        )
        # Simulate an interruption: serialize to JSON, then restore into a
        # brand-new client and verify all trials carried over.
        old_client = ax_client
        serialized = ax_client.to_json_snapshot()
        ax_client = AxClient.from_json_snapshot(serialized)
        self.assertEqual(len(ax_client.experiment.trials.keys()), i + 1)
        self.assertIsNot(ax_client, old_client)
        self.assertTrue(  # There should be no non-complete trials.
            all(t.status.is_terminal for t in ax_client.experiment.trials.values())
        )
def run_trials_using_recommended_parallelism(
    ax_client: AxClient,
    recommended_parallelism: List[Tuple[int, int]],
    total_trials: int,
) -> int:
    """Run ``total_trials`` trials through ``ax_client``, never exceeding
    each recommended parallelism setting, and return how many trials were
    left unrun (0 when everything went well)."""
    remaining = total_trials
    for num_trials, parallelism in recommended_parallelism:
        if num_trials == -1:
            # -1 means "all remaining trials" for this parallelism level.
            num_trials = remaining
        for _ in range(ceil(num_trials / parallelism)):
            in_flight = []
            # Never request more trials than are still outstanding.
            if parallelism > remaining:
                parallelism = remaining
            for _ in range(parallelism):
                suggested, index = ax_client.get_next_trial()
                in_flight.append((suggested, index))
                remaining -= 1
            # Drain the in-flight batch, completing each trial with Branin.
            while in_flight:
                suggested, index = in_flight.pop()
                ax_client.complete_trial(
                    index, branin(suggested["x"], suggested["y"])
                )
    # If all went well and no errors were raised, this should be 0.
    return remaining
def test_storage_error_handling(self, mock_save_fails):
    """Check that if `suppress_storage_errors` is True, AxClient won't
    visibly fail if encountered storage errors.
    """
    init_test_engine_and_session_factory(force_init=True)
    sqa_config = SQAConfig()
    db_settings = DBSettings(
        encoder=Encoder(config=sqa_config), decoder=Decoder(config=sqa_config)
    )
    client = AxClient(db_settings=db_settings, suppress_storage_errors=True)
    client.create_experiment(
        name="test_experiment",
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    # Storage calls fail (mocked), but the optimization loop runs unaffected.
    for _ in range(3):
        suggested, index = client.get_next_trial()
        client.complete_trial(
            trial_index=index, raw_data=branin(*suggested.values())
        )
def test_init_position_saved(self):
    """Sobol's `init_position` must be captured in JSON snapshots so that a
    restored client regenerates the exact same candidate sequence."""
    ax_client = AxClient(random_seed=239)
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        name="sobol_init_position_test",
    )
    for _ in range(4):
        # For each generated trial, snapshot the client before generating it,
        # then recreate client, regenerate the trial and compare the trial
        # generated before and after snapshotting. If the state of Sobol is
        # recorded correctly, the newly generated trial will be the same as
        # the one generated before the snapshotting.
        serialized = ax_client.to_json_snapshot()
        params, idx = ax_client.get_next_trial()
        ax_client = AxClient.from_json_snapshot(serialized)
        with self.subTest(ax=ax_client, params=params, idx=idx):
            new_params, new_idx = ax_client.get_next_trial()
            self.assertEqual(params, new_params)
            self.assertEqual(idx, new_idx)
            # init_position advances by one per generated trial.
            self.assertEqual(
                ax_client.experiment.trials[idx]._generator_run
                ._model_state_after_gen["init_position"],
                idx + 1,
            )
        ax_client.complete_trial(idx, branin(params.get("x"), params.get("y")))
def test_ttl_trial(self):
    """Trials generated with a TTL auto-fail once the TTL elapses and then
    refuse further data; a TTL trial completed in time works normally."""
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    # A ttl trial that ends adds no data.
    params, idx = ax_client.get_next_trial(ttl_seconds=1)
    self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running)
    time.sleep(1)  # Wait for TTL to elapse.
    self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed)
    # Also make sure we can no longer complete the trial as it is failed.
    with self.assertRaisesRegex(
        ValueError, ".* has been marked FAILED, so it no longer expects data."
    ):
        ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
    # A fresh TTL trial completed before expiry is used for the best point.
    params2, idy = ax_client.get_next_trial(ttl_seconds=1)
    ax_client.complete_trial(trial_index=idy, raw_data=(-1, 0.0))
    self.assertEqual(ax_client.get_best_parameters()[0], params2)
def test_sqa_storage(self):
    """Experiment and generation strategy survive a DB save/load round-trip."""
    init_test_engine_and_session_factory(force_init=True)
    sqa_config = SQAConfig()
    db_settings = DBSettings(
        encoder=Encoder(config=sqa_config), decoder=Decoder(config=sqa_config)
    )
    client = AxClient(db_settings=db_settings)
    client.create_experiment(
        name="test_experiment",
        parameters=[
            {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    for _ in range(5):
        suggested, index = client.get_next_trial()
        client.complete_trial(
            trial_index=index, raw_data=branin(*suggested.values())
        )
    strategy_before_reload = client.generation_strategy
    # Reload from the database into a fresh client and compare.
    client = AxClient(db_settings=db_settings)
    client.load_experiment_from_database("test_experiment")
    self.assertEqual(strategy_before_reload, client.generation_strategy)
def test_raw_data_format(self):
    """Plain (mean, sem) tuples are accepted as raw data; strings are
    rejected with a ValueError."""
    client = AxClient()
    client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    for _ in range(6):
        suggested, index = client.get_next_trial()
        evaluation = (branin(suggested.get("x"), suggested.get("y")), 0.0)
        client.complete_trial(index, raw_data=evaluation)
    # `index` refers to the last completed trial after the loop.
    with self.assertRaisesRegex(ValueError, "Raw data has an invalid type"):
        client.update_trial_data(index, raw_data="invalid_data")
def test_trial_completion(self):
    """Completion/update semantics: updating before completion fails, double
    completion fails, and `update_trial_data` may only add new metrics."""
    ax_client = AxClient()
    ax_client.create_experiment(
        parameters=[
            {"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
            {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
        ],
        minimize=True,
    )
    params, idx = ax_client.get_next_trial()
    # Can't update before completing.
    with self.assertRaisesRegex(ValueError, ".* not yet"):
        ax_client.update_trial_data(
            trial_index=idx, raw_data={"objective": (0, 0.0)}
        )
    ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
    # Cannot complete a trial twice, should use `update_trial_data`.
    with self.assertRaisesRegex(ValueError, ".* already been completed"):
        ax_client.complete_trial(trial_index=idx, raw_data={"objective": (0, 0.0)})
    # Cannot update trial data with observation for a metric it already has.
    with self.assertRaisesRegex(ValueError, ".* contained an observation"):
        ax_client.update_trial_data(
            trial_index=idx, raw_data={"objective": (0, 0.0)}
        )
    # Same as above, except objective name should be getting inferred.
    with self.assertRaisesRegex(ValueError, ".* contained an observation"):
        ax_client.update_trial_data(trial_index=idx, raw_data=1.0)
    # Updating with a brand-new metric is allowed and lands in the data.
    ax_client.update_trial_data(trial_index=idx, raw_data={"m1": (1, 0.0)})
    metrics_in_data = ax_client.experiment.fetch_data().df["metric_name"].values
    self.assertIn("m1", metrics_in_data)
    self.assertIn("objective", metrics_in_data)
    self.assertEqual(ax_client.get_best_parameters()[0], params)
    params2, idy = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=idy, raw_data=(-1, 0.0))
    self.assertEqual(ax_client.get_best_parameters()[0], params2)
    params3, idx3 = ax_client.get_next_trial()
    # Scalar raw data plus metadata; metadata lands in run_metadata.
    ax_client.complete_trial(
        trial_index=idx3, raw_data=-2, metadata={"dummy": "test"}
    )
    self.assertEqual(ax_client.get_best_parameters()[0], params3)
    self.assertEqual(
        ax_client.experiment.trials.get(2).run_metadata.get("dummy"), "test"
    )
    best_trial_values = ax_client.get_best_parameters()[1]
    self.assertEqual(best_trial_values[0], {"objective": -2.0})
    # The objective's covariance entry is reported as NaN here.
    self.assertTrue(math.isnan(best_trial_values[1]["objective"]["objective"]))
def _benchmark_replication_Service_API(
    problem: SimpleBenchmarkProblem,
    method: GenerationStrategy,
    num_trials: int,
    experiment_name: str,
    batch_size: int = 1,
    raise_all_exceptions: bool = False,
    benchmark_trial: FunctionType = benchmark_trial,
    verbose_logging: bool = True,
    # Number of trials that need to fail for a replication to be considered failed.
    failed_trials_tolerated: int = 5,
    async_benchmark_options: Optional[AsyncBenchmarkOptions] = None,
) -> Tuple[Experiment, List[Exception]]:
    """Run a benchmark replication via the Service API because the problem was
    set up in a simplified way, without the use of Ax classes like
    `OptimizationConfig` or `SearchSpace`.

    Returns the completed experiment and the list of tolerated exceptions;
    raises RuntimeError once more than `failed_trials_tolerated` trials fail.
    """
    if async_benchmark_options is not None:
        raise NonRetryableBenchmarkingError(
            "`async_benchmark_options` not supported when using the Service API."
        )
    exceptions = []
    if batch_size == 1:
        ax_client = AxClient(
            generation_strategy=method, verbose_logging=verbose_logging
        )
    else:  # pragma: no cover, TODO[T53975770]
        assert batch_size > 1, "Batch size of 1 or greater is expected."
        raise NotImplementedError(
            "Batched benchmarking on `SimpleBenchmarkProblem`-s not yet implemented."
        )
    ax_client.create_experiment(
        name=experiment_name,
        parameters=problem.domain_as_ax_client_parameters(),
        minimize=problem.minimize,
        objective_name=problem.name,
    )
    parameter_names = list(ax_client.experiment.search_space.parameters.keys())
    assert num_trials > 0
    for _ in range(num_trials):
        parameterization, idx = ax_client.get_next_trial()
        # The evaluation function takes parameter values as an ordered array.
        param_values = np.array(
            [parameterization.get(x) for x in parameter_names]
        )
        try:
            mean, sem = benchmark_trial(
                parameterization=param_values, evaluation_function=problem.f
            )
            # If problem indicates a noise level and is using a synthetic callable,
            # add normal noise to the measurement of the mean.
            if problem.uses_synthetic_function and problem.noise_sd != 0.0:
                noise = np.random.randn() * problem.noise_sd
                # SEM is bumped by the noise SD (not the sampled noise) so the
                # model is told the observation is noisy.
                sem = (sem or 0.0) + problem.noise_sd
                logger.info(
                    f"Adding noise of {noise} to the measurement mean ({mean})."
                    f"Problem noise SD setting: {problem.noise_sd}."
                )
                mean = mean + noise
            ax_client.complete_trial(trial_index=idx, raw_data=(mean, sem))
        except Exception as err:  # TODO[T53975770]: test
            if raise_all_exceptions:
                raise
            exceptions.append(err)
        if len(exceptions) > failed_trials_tolerated:
            raise RuntimeError(  # TODO[T53975770]: test
                f"More than {failed_trials_tolerated} failed for {experiment_name}."
            )
    return ax_client.experiment, exceptions
def ml_run(self, run_id=None):
    """Ax-driven hyperparameter search over optimizer params (lr, beta_1,
    beta_2), logging per-trial metrics and the best trial to mlflow.

    Args:
        run_id: Parent mlflow run id, forwarded to each spawned model task.
    """
    seed_randomness(self.random_seed)
    mlflow.log_params(flatten(get_params_of_task(self)))
    total_training_time = 0
    # should land to 'optimizer_props'
    params_space = [
        {
            'name': 'lr',
            'type': 'range',
            'bounds': [1e-6, 0.008],
            # 'value_type': 'float',
            'log_scale': True,
        },
        {
            'name': 'beta_1',
            'type': 'range',
            'bounds': [.0, 0.9999],
            'value_type': 'float',
            # 'log_scale': True,
        },
        {
            'name': 'beta_2',
            'type': 'range',
            'bounds': [.0, 0.9999],
            'value_type': 'float',
            # 'log_scale': True,
        }
    ]
    # TODO: make reproducibility of search
    # without it we will get each time new params
    # for example we can use:
    # ax.storage.sqa_store.structs.DBSettings
    # DBSettings(url="sqlite://<path-to-file>")
    # to store experiments
    ax = AxClient(
        # can't use that feature yet.
        # got error
        # NotImplementedError:
        # Saving and loading experiment in `AxClient` functionality currently under development.
        # db_settings=DBSettings(url=self.output()['ax_settings'].path)
    )
    # FIXME: temporal solution while ax doesn't have api to (re-)store state
    class_name = get_class_name_as_snake(self)
    ax.create_experiment(
        name=f'{class_name}_experiment',
        parameters=params_space,
        objective_name='score',
        minimize=should_minimize(self.metric),
        # parameter_constraints=['x1 + x2 <= 2.0'],  # Optional.
        # outcome_constraints=['l2norm <= 1.25'],  # Optional.
    )
    trial_index = 0
    experiment = self._get_ax_experiment()
    if experiment:
        print('AX: restore experiment')
        print('AX: num_trials:', experiment.num_trials)
        ax._experiment = experiment
        # NOTE(review): resuming at num_trials - 1 assumes the last stored
        # trial is unfinished — confirm against `get_last_unfinished_params`.
        trial_index = experiment.num_trials - 1
    model_task = get_model_task_by_name(self.model_name)
    while trial_index < self.max_runs:
        print(f'AX: Running trial {trial_index + 1}/{self.max_runs}...')
        # get last unfinished trial
        parameters = get_last_unfinished_params(ax)
        if parameters is None:
            print('AX: generate new Trial')
            parameters, trial_index = ax.get_next_trial()
            # good time to store experiment (with new Trial)
            with self.output()['ax_experiment'].open('w') as f:
                print('AX: store experiment: ', ax.experiment)
                pickle.dump(ax.experiment, f)
        print('AX: parameters', parameters)
        # now is time to evaluate model
        model_result = yield model_task(
            parent_run_id=run_id,
            random_seed=self.random_seed,
            # TODO: actually we should be able to pass even nested params
            # **parameters,
            optimizer_props=parameters)
        # TODO: store run_id in Trial
        model_run_id = self.get_run_id_from_result(model_result)
        with model_result['metrics'].open('r') as f:
            # Fix: bare `yaml.load(f)` is deprecated (PyYAML >= 5.1) and a
            # TypeError on PyYAML >= 6; these files are plain data, so use
            # `safe_load`.
            model_metrics = yaml.safe_load(f)
        model_score_mean = model_metrics[self.metric]['val']
        # TODO: we might know it :/
        model_score_error = 0.0
        total_training_time += model_metrics['train_time']['total']
        with model_result['params'].open('r') as f:
            model_params = yaml.safe_load(f)
        print('AX: complete trial:', trial_index)
        ax.complete_trial(
            trial_index=trial_index,
            raw_data={'score': (model_score_mean, model_score_error)},
            metadata={
                'metrics': model_metrics,
                'params': model_params,
                'run_id': model_run_id,
            })
    best_parameters, _ = ax.get_best_parameters()
    mlflow.log_metric('train_time.total', total_training_time)
    print('best params', best_parameters)
    # NOTE(review): `experiment` is None when no prior experiment was
    # restored — `get_best_trial(experiment, ...)` presumably should use the
    # live `ax.experiment`; verify before relying on this branch.
    best_trial = get_best_trial(experiment, self.metric)
    mlflow.log_metrics(flatten(best_trial.run_metadata['metrics']))
    mlflow.log_params(flatten(best_trial.run_metadata['params']))
class IterativePrune:
    """Iteratively prunes a LightningMNISTClassifier: Ax suggests the initial
    pruning amount, mlflow tracks each pruning trial."""

    def __init__(self):
        self.parser_args = None  # Parsed argparse namespace (get_parser_args).
        self.ax_client = None  # AxClient, created by initialize_ax_client.
        self.base_model_path = "base_model"
        self.pruning_amount = None  # Running prune fraction across iterations.

    def run_mnist_model(self, base=False):
        """Train + test the MNIST model; when `base` is True, start a named
        mlflow run first. Saves the trained model and returns the trainer."""
        parser_dict = vars(self.parser_args)
        if base:
            mlflow.start_run(run_name="BaseModel")
        mlflow.pytorch.autolog()
        dm = MNISTDataModule(**parser_dict)
        dm.setup(stage="fit")
        model = LightningMNISTClassifier(**parser_dict)
        trainer = pl.Trainer.from_argparse_args(self.parser_args)
        trainer.fit(model, dm)
        trainer.test()
        # Replace any previously saved model with the freshly trained one.
        if os.path.exists(self.base_model_path):
            shutil.rmtree(self.base_model_path)
        mlflow.pytorch.save_model(trainer.get_model(), self.base_model_path)
        return trainer

    def load_base_model(self):
        """Load the saved base model's torch module from the saved artifact."""
        path = Path(_download_artifact_from_uri(self.base_model_path))
        model_file_path = os.path.join(path, "data/model.pth")
        return torch.load(model_file_path)

    def initialize_ax_client(self):
        """Create the Ax experiment optimizing test_accuracy over the initial
        pruning amount in [0.05, 0.15]."""
        self.ax_client = AxClient()
        self.ax_client.create_experiment(
            parameters=[{
                "name": "amount",
                "type": "range",
                "bounds": [0.05, 0.15],
                "value_type": "float"
            }],
            objective_name="test_accuracy",
        )

    @staticmethod
    def prune_and_save_model(model, amount):
        """L1-unstructured-prune every Conv2d/Linear weight by `amount`, then
        round-trip the state dict through disk.

        Returns the pruned state dict (despite the reused local name `model`).
        """
        for _, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d) or isinstance(
                    module, torch.nn.Linear):
                prune.l1_unstructured(module, name="weight", amount=amount)
                # Make pruning permanent (removes the reparameterization).
                prune.remove(module, "weight")
        mlflow.pytorch.save_state_dict(model.state_dict(), ".")
        model = torch.load("state_dict.pth")
        os.remove("state_dict.pth")
        return model

    @staticmethod
    def count_model_parameters(model):
        """Tabulate nonzero trainable parameters per module.

        Returns a (PrettyTable, total nonzero parameter count) pair.
        """
        table = PrettyTable(["Modules", "Parameters"])
        total_params = 0
        for name, parameter in model.named_parameters():
            if not parameter.requires_grad:
                continue
            # Count only nonzero entries so pruning shows up in the total.
            param = parameter.nonzero(as_tuple=False).size(0)
            table.add_row([name, param])
            total_params += param
        return table, total_params

    @staticmethod
    def write_prune_summary(summary, params):
        """Write the parameter summary to a temp file and log it as an mlflow
        artifact; the temp dir is removed even if logging fails."""
        tempdir = tempfile.mkdtemp()
        try:
            summary_file = os.path.join(tempdir, "pruned_model_summary.txt")
            params = "Total Trainable Parameters :" + str(params)
            with open(summary_file, "w") as f:
                f.write(str(summary))
                f.write("\n")
                f.write(str(params))
            try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
        finally:
            shutil.rmtree(tempdir)

    def iterative_prune(self, model, parametrization):
        """Prune `model` one step further, retrain, and return test accuracy.

        First call uses the Ax-suggested amount; later calls add a fixed 0.15
        to the running pruning amount.
        """
        if not self.pruning_amount:
            self.pruning_amount = parametrization.get("amount")
        else:
            self.pruning_amount += 0.15
        mlflow.log_metric("PRUNING PERCENTAGE", self.pruning_amount)
        pruned_model = self.prune_and_save_model(model, self.pruning_amount)
        model.load_state_dict(copy.deepcopy(pruned_model))
        summary, params = self.count_model_parameters(model)
        self.write_prune_summary(summary, params)
        trainer = self.run_mnist_model()
        metrics = trainer.callback_metrics
        test_accuracy = metrics.get("avg_test_acc")
        return test_accuracy

    def initiate_pruning_process(self, model):
        """Run `total_trials` Ax trials, each as a nested mlflow run that
        prunes + retrains, reporting test accuracy back to Ax."""
        total_trials = int(vars(self.parser_args)["total_trials"])
        trial_index = None
        for i in range(total_trials):
            parameters, trial_index = self.ax_client.get_next_trial()
            print(
                "***************************************************************************"
            )
            print("Running Trial {}".format(i + 1))
            print(
                "***************************************************************************"
            )
            with mlflow.start_run(nested=True, run_name="Iteration" + str(i)):
                mlflow.set_tags({"AX_TRIAL": i})
                # calling the model
                test_accuracy = self.iterative_prune(model, parameters)
                # completion of trial
                self.ax_client.complete_trial(trial_index=trial_index,
                                              raw_data=test_accuracy.item())
        # Ending the Base run
        mlflow.end_run()

    def get_parser_args(self):
        """Build the argument parser (Trainer args + model args +
        --total_trials) and store the parsed namespace on `self.parser_args`."""
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parent_parser=parser)
        parser = LightningMNISTClassifier.add_model_specific_args(
            parent_parser=parser)
        parser.add_argument(
            "--total_trials",
            default=3,
            help=
            "Number of AX trials to be run for the optimization experiment",
        )
        self.parser_args = parser.parse_args()
class AxSearch(Searcher):
    """Uses `Ax <https://ax.dev/>`_ to optimize hyperparameters.

    Ax is a platform for understanding, managing, deploying, and
    automating adaptive experiments. Ax provides an easy to use
    interface with BoTorch, a flexible, modern library for Bayesian
    optimization in PyTorch. More information can be found in https://ax.dev/.

    To use this search algorithm, you must install Ax and sqlalchemy:

    .. code-block:: bash

        $ pip install ax-platform sqlalchemy

    Parameters:
        space (list[dict]): Parameters in the experiment search space.
            Required elements in the dictionaries are: "name" (name of
            this parameter, string), "type" (type of the parameter: "range",
            "fixed", or "choice", string), "bounds" for range parameters
            (list of two values, lower bound first), "values" for choice
            parameters (list of values), and "value" for fixed parameters
            (single value).
        metric (str): Name of the metric used as objective in this
            experiment. This metric must be present in `raw_data` argument
            to `log_data`. This metric must also be present in the dict
            reported/returned by the Trainable. If None but a mode was
            passed, the `ray.tune.result.DEFAULT_METRIC` will be used
            per default.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute. Defaults to "max".
        points_to_evaluate (list): Initial parameter suggestions to be run
            first. This is for when you already have some good parameters
            you want to run first to help the algorithm make better
            suggestions for future parameters. Needs to be a list of dicts
            containing the configurations.
        parameter_constraints (list[str]): Parameter constraints, such as
            "x3 >= x4" or "x3 + x4 >= 2".
        outcome_constraints (list[str]): Outcome constraints of form
            "metric_name >= bound", like "m1 <= 3."
        ax_client (AxClient): Optional AxClient instance. If this is set, do
            not pass any values to these parameters: `space`, `metric`,
            `parameter_constraints`, `outcome_constraints`.
        use_early_stopped_trials: Deprecated.
        max_concurrent (int): Deprecated.

    Tune automatically converts search spaces to Ax's format:

    .. code-block:: python

        from ray import tune
        from ray.tune.suggest.ax import AxSearch

        config = {
            "x1": tune.uniform(0.0, 1.0),
            "x2": tune.uniform(0.0, 1.0)
        }

        def easy_objective(config):
            for i in range(100):
                intermediate_result = config["x1"] + config["x2"] * i
                tune.report(score=intermediate_result)

        ax_search = AxSearch(metric="score")
        tune.run(
            easy_objective,
            config=config,
            search_alg=ax_search)

    If you would like to pass the search space manually, the code would
    look like this:

    .. code-block:: python

        from ray import tune
        from ray.tune.suggest.ax import AxSearch

        parameters = [
            {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
        ]

        def easy_objective(config):
            for i in range(100):
                intermediate_result = config["x1"] + config["x2"] * i
                tune.report(score=intermediate_result)

        ax_search = AxSearch(space=parameters, metric="score")
        tune.run(easy_objective, search_alg=ax_search)

    """

    def __init__(self,
                 space: Optional[Union[Dict, List[Dict]]] = None,
                 metric: Optional[str] = None,
                 mode: Optional[str] = None,
                 points_to_evaluate: Optional[List[Dict]] = None,
                 parameter_constraints: Optional[List] = None,
                 outcome_constraints: Optional[List] = None,
                 ax_client: Optional[AxClient] = None,
                 use_early_stopped_trials: Optional[bool] = None,
                 max_concurrent: Optional[int] = None):
        # `ax` is the module, imported guardedly at file top; None if missing.
        assert ax is not None, """Ax must be installed! You can install AxSearch with the command: `pip install ax-platform sqlalchemy`."""
        if mode:
            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
        super(AxSearch, self).__init__(metric=metric,
                                       mode=mode,
                                       max_concurrent=max_concurrent,
                                       use_early_stopped_trials=use_early_stopped_trials)
        self._ax = ax_client
        # A Tune-style dict search space is converted to Ax's list-of-dicts
        # form; unresolved domain/grid vars only trigger a warning.
        if isinstance(space, dict) and space:
            resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
            if domain_vars or grid_vars:
                logger.warning(
                    UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self)))
            space = self.convert_search_space(space)
        self._space = space
        self._parameter_constraints = parameter_constraints
        self._outcome_constraints = outcome_constraints
        # Deep-copied so popping suggestions doesn't mutate the caller's list.
        self._points_to_evaluate = copy.deepcopy(points_to_evaluate)
        self.max_concurrent = max_concurrent
        self._objective_name = metric
        self._parameters = []
        # Maps Tune trial_id -> Ax trial_index for result reporting.
        self._live_trial_mapping = {}
        # Experiment can only be set up once we have a client or a space;
        # otherwise set_search_properties() will do it later.
        if self._ax or self._space:
            self._setup_experiment()

    def _setup_experiment(self):
        """Create (or adopt) the Ax experiment backing this searcher."""
        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC
        if not self._ax:
            self._ax = AxClient()
        # AxClient.experiment raises ValueError when none was created yet.
        try:
            exp = self._ax.experiment
            has_experiment = True
        except ValueError:
            has_experiment = False
        if not has_experiment:
            if not self._space:
                raise ValueError(
                    "You have to create an Ax experiment by calling "
                    "`AxClient.create_experiment()`, or you should pass an "
                    "Ax search space as the `space` parameter to `AxSearch`, "
                    "or pass a `config` dict to `tune.run()`.")
            self._ax.create_experiment(
                parameters=self._space,
                objective_name=self._metric,
                parameter_constraints=self._parameter_constraints,
                outcome_constraints=self._outcome_constraints,
                minimize=self._mode != "max")
        else:
            # A pre-configured client must not be combined with searcher-side
            # space/constraint arguments.
            if any([
                    self._space, self._parameter_constraints,
                    self._outcome_constraints
            ]):
                raise ValueError(
                    "If you create the Ax experiment yourself, do not pass "
                    "values for these parameters to `AxSearch`: {}.".format([
                        "space", "parameter_constraints", "outcome_constraints"
                    ]))
        exp = self._ax.experiment
        # Mirror the experiment's objective/parameters locally.
        self._objective_name = exp.optimization_config.objective.metric.name
        self._parameters = list(exp.parameters)
        if self._ax._enforce_sequential_optimization:
            logger.warning("Detected sequential enforcement. Be sure to use "
                           "a ConcurrencyLimiter.")

    def set_search_properties(self, metric: Optional[str],
                              mode: Optional[str], config: Dict):
        """Late configuration from `tune.run()`; returns False if already set up."""
        if self._ax:
            return False
        space = self.convert_search_space(config)
        self._space = space
        if metric:
            self._metric = metric
        if mode:
            self._mode = mode
        self._setup_experiment()
        return True

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Return the next parameter configuration, or None at concurrency cap."""
        if not self._ax:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(cls=self.__class__.__name__,
                                              space="space"))
        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(cls=self.__class__.__name__,
                                             metric=self._metric,
                                             mode=self._mode))
        if self.max_concurrent:
            if len(self._live_trial_mapping) >= self.max_concurrent:
                return None
        # User-supplied initial points are attached to Ax verbatim before
        # the model-generated suggestions start.
        if self._points_to_evaluate:
            config = self._points_to_evaluate.pop(0)
            parameters, trial_index = self._ax.attach_trial(config)
        else:
            parameters, trial_index = self._ax.get_next_trial()
        self._live_trial_mapping[trial_id] = trial_index
        return unflatten_dict(parameters)

    def on_trial_complete(self, trial_id, result=None, error=False):
        """Notification for the completion of trial.

        Data of form key value dictionary of metric names and values.
        """
        if result:
            self._process_result(trial_id, result)
        # Free the slot regardless of success/failure.
        self._live_trial_mapping.pop(trial_id)

    def _process_result(self, trial_id, result):
        """Report objective and outcome-constraint metrics back to Ax."""
        ax_trial_index = self._live_trial_mapping[trial_id]
        # Each metric is reported as (mean, sem); sem=0.0 means "noiseless".
        metric_dict = {
            self._objective_name: (result[self._objective_name], 0.0)
        }
        outcome_names = [
            oc.metric.name for oc in
            self._ax.experiment.optimization_config.outcome_constraints
        ]
        metric_dict.update({on: (result[on], 0.0) for on in outcome_names})
        self._ax.complete_trial(trial_index=ax_trial_index,
                                raw_data=metric_dict)

    @staticmethod
    def convert_search_space(spec: Dict):
        """Convert a Tune search-space dict into Ax parameter definitions."""
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to an Ax search space.")

        def resolve_value(par, domain):
            # Translate one Tune Domain into one Ax parameter dict.
            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                logger.warning("AxSearch does not support quantization. "
                               "Dropped quantization.")
                sampler = sampler.sampler

            if isinstance(domain, Float):
                if isinstance(sampler, LogUniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "float",
                        "log_scale": True
                    }
                elif isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "float",
                        "log_scale": False
                    }
            elif isinstance(domain, Integer):
                if isinstance(sampler, LogUniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "int",
                        "log_scale": True
                    }
                elif isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "range",
                        "bounds": [domain.lower, domain.upper],
                        "value_type": "int",
                        "log_scale": False
                    }
            elif isinstance(domain, Categorical):
                if isinstance(sampler, Uniform):
                    return {
                        "name": par,
                        "type": "choice",
                        "values": domain.categories
                    }

            raise ValueError("AxSearch does not support parameters of type "
                             "`{}` with samplers of type `{}`".format(
                                 type(domain).__name__,
                                 type(domain.sampler).__name__))

        # Fixed vars
        fixed_values = [{
            "name": "/".join(path),
            "type": "fixed",
            "value": val
        } for path, val in resolved_vars]

        # Parameter name is e.g. "a/b/c" for nested dicts
        resolved_values = [
            resolve_value("/".join(path), domain)
            for path, domain in domain_vars
        ]

        return fixed_values + resolved_values
ax = AxClient() N_TRIALS = int(os.environ.get("N_TRIALS", 5)) ax.create_experiment( name="GaussianProcessRegression-%s" % isonow(), parameters=PARAMETERS, objective_name="mean_square_error", minimize=True) for _ in range(N_TRIALS): print(f"[ax-service-loop] Trial {_+1} of {N_TRIALS}") parameters, trial_index = ax.get_next_trial() ax.complete_trial( trial_index=trial_index, raw_data= evaluate(parameters) ) print(parameters) print("") print("[ax-service-loop] Training complete!") best_parameters, metrics = ax.get_best_parameters() print(f"[ax-service-loop] Sending data to db.") print(f"[ax-service-loop] Best parameters found: {best_parameters}") DB_URL = os.environ.get("DB_URL", "mysql://*****:*****@localhost/axdb") from sqlalchemy import create_engine engine = axst.sqa_store.db.create_mysql_engine_from_url(url=DB_URL) conn = engine.connect() axst.sqa_store.db.init_engine_and_session_factory(url=DB_URL) table_names = engine.table_names()
def test_sqa_storage(self): init_test_engine_and_session_factory(force_init=True) config = SQAConfig() encoder = Encoder(config=config) decoder = Decoder(config=config) db_settings = DBSettings(encoder=encoder, decoder=decoder) ax_client = AxClient(db_settings=db_settings) ax_client.create_experiment( name="test_experiment", parameters=[ { "name": "x", "type": "range", "bounds": [-5.0, 10.0] }, { "name": "y", "type": "range", "bounds": [0.0, 15.0] }, ], minimize=True, ) for _ in range(5): parameters, trial_index = ax_client.get_next_trial() ax_client.complete_trial(trial_index=trial_index, raw_data=branin(*parameters.values())) gs = ax_client.generation_strategy ax_client = AxClient(db_settings=db_settings) ax_client.load_experiment_from_database("test_experiment") # Trial #4 was completed after the last time the generation strategy # generated candidates, so pre-save generation strategy was not # "aware" of completion of trial #4. Post-restoration generation # strategy is aware of it, however, since it gets restored with most # up-to-date experiment data. Do adding trial #4 to the seen completed # trials of pre-storage GS to check their equality otherwise. gs._seen_trial_indices_by_status[TrialStatus.COMPLETED].add(4) self.assertEqual(gs, ax_client.generation_strategy) with self.assertRaises(ValueError): # Overwriting existing experiment. ax_client.create_experiment( name="test_experiment", parameters=[ { "name": "x", "type": "range", "bounds": [-5.0, 10.0] }, { "name": "y", "type": "range", "bounds": [0.0, 15.0] }, ], minimize=True, ) with self.assertRaises(ValueError): # Overwriting existing experiment with overwrite flag with present # DB settings. This should fail as we no longer allow overwriting # experiments stored in the DB. 
ax_client.create_experiment( name="test_experiment", parameters=[{ "name": "x", "type": "range", "bounds": [-5.0, 10.0] }], overwrite_existing_experiment=True, ) # Original experiment should still be in DB and not have been overwritten. self.assertEqual(len(ax_client.experiment.trials), 5)
def test_overwrite(self): init_test_engine_and_session_factory(force_init=True) ax_client = AxClient() ax_client.create_experiment( name="test_experiment", parameters=[ { "name": "x", "type": "range", "bounds": [-5.0, 10.0] }, { "name": "y", "type": "range", "bounds": [0.0, 15.0] }, ], minimize=True, ) # Log a trial parameters, trial_index = ax_client.get_next_trial() ax_client.complete_trial(trial_index=trial_index, raw_data=branin(*parameters.values())) with self.assertRaises(ValueError): # Overwriting existing experiment. ax_client.create_experiment( name="test_experiment", parameters=[ { "name": "x", "type": "range", "bounds": [-5.0, 10.0] }, { "name": "y", "type": "range", "bounds": [0.0, 15.0] }, ], minimize=True, ) # Overwriting existing experiment with overwrite flag. ax_client.create_experiment( name="test_experiment", parameters=[ { "name": "x1", "type": "range", "bounds": [-5.0, 10.0] }, { "name": "x2", "type": "range", "bounds": [0.0, 15.0] }, ], overwrite_existing_experiment=True, ) # There should be no trials, as we just put in a fresh experiment. self.assertEqual(len(ax_client.experiment.trials), 0) # Log a trial parameters, trial_index = ax_client.get_next_trial() self.assertIn("x1", parameters.keys()) self.assertIn("x2", parameters.keys()) ax_client.complete_trial(trial_index=trial_index, raw_data=branin(*parameters.values()))
class AxSearchJob(AutoSearchJob):
    """Job for hyperparameter search using [ax](https://ax.dev/)."""

    def __init__(self, config: Config, dataset, parent_job=None):
        super().__init__(config, dataset, parent_job)
        self.num_trials = self.config.get("ax_search.num_trials")
        self.num_sobol_trials = self.config.get("ax_search.num_sobol_trials")
        # Created in _prepare(); excluded from pickling (see __getstate__).
        self.ax_client: AxClient = None

        if self.__class__ == AxSearchJob:
            for f in Job.job_created_hooks:
                f(self)

    # Overridden such that instances of search job can be pickled to workers
    def __getstate__(self):
        state = super(AxSearchJob, self).__getstate__()
        # AxClient holds unpicklable state; it is rebuilt via _prepare().
        del state["ax_client"]
        return state

    def _prepare(self):
        """Create the AxClient/experiment and fast-forward Sobol on resume."""
        super()._prepare()
        if self.num_sobol_trials > 0:
            # Explicit Sobol -> GPEI strategy with a fixed number of Sobol
            # trials and a fixed seed, so a resumed job regenerates the exact
            # same initial arms.
            # BEGIN: from /ax/service/utils/dispatch.py
            generation_strategy = GenerationStrategy(
                name="Sobol+GPEI",
                steps=[
                    GenerationStep(
                        model=Models.SOBOL,
                        num_trials=self.num_sobol_trials,
                        min_trials_observed=ceil(self.num_sobol_trials / 2),
                        enforce_num_trials=True,
                        model_kwargs={
                            "seed": self.config.get("ax_search.sobol_seed")
                        },
                    ),
                    GenerationStep(model=Models.GPEI,
                                   num_trials=-1,
                                   max_parallelism=3),
                ],
            )
            # END: from /ax/service/utils/dispatch.py
            self.ax_client = AxClient(generation_strategy=generation_strategy)
            choose_generation_strategy_kwargs = dict()
        else:
            self.ax_client = AxClient()
            # set random_seed that will be used by auto created sobol search from ax
            # note that here the argument is called "random_seed" not "seed"
            choose_generation_strategy_kwargs = {
                "random_seed": self.config.get("ax_search.sobol_seed")
            }
        self.ax_client.create_experiment(
            name=self.job_id,
            parameters=self.config.get("ax_search.parameters"),
            objective_name="metric_value",
            minimize=not self.config.get("valid.metric_max"),
            parameter_constraints=self.config.get(
                "ax_search.parameter_constraints"),
            choose_generation_strategy_kwargs=choose_generation_strategy_kwargs,
        )
        self.config.log("ax search initialized with {}".format(
            self.ax_client.generation_strategy))

        # Make sure sobol models are resumed correctly
        # NOTE(review): relies on Ax private API (_curr, _set_current_model);
        # verify against the pinned Ax version on upgrade.
        if self.ax_client.generation_strategy._curr.model == Models.SOBOL:
            self.ax_client.generation_strategy._set_current_model(
                experiment=self.ax_client.experiment, data=None)

            # Regenerate and drop SOBOL arms already generated. Since we fixed the seed,
            # we will skip exactly the arms already generated in the job being resumed.
            num_generated = len(self.parameters)
            if num_generated > 0:
                num_sobol_generated = min(
                    self.ax_client.generation_strategy._curr.num_trials,
                    num_generated)
                for i in range(num_sobol_generated):
                    # gen() advances the Sobol sequence; the run itself is
                    # intentionally discarded.
                    generator_run = self.ax_client.generation_strategy.gen(
                        experiment=self.ax_client.experiment)
                    # self.config.log("Skipped parameters: {}".format(generator_run.arms))
                self.config.log(
                    "Skipped {} of {} Sobol trials due to prior data.".format(
                        num_sobol_generated,
                        self.ax_client.generation_strategy._curr.num_trials,
                    ))

    def register_trial(self, parameters=None):
        """Ask Ax for a new trial (or attach given parameters); may return (None, None) args on failure."""
        trial_id = None
        try:
            if parameters is None:
                parameters, trial_id = self.ax_client.get_next_trial()
            else:
                # Resumed/user-specified parameters are attached as-is.
                _, trial_id = self.ax_client.attach_trial(parameters)
        except Exception as e:
            # Ax can refuse to generate (e.g. waiting on running trials);
            # caller retries later. trial_id stays None in that case.
            self.config.log(
                "Cannot generate trial parameters. Will try again after a " +
                "running trial has completed. message was: {}".format(e))
        return parameters, trial_id

    def register_trial_result(self, trial_id, parameters, trace_entry):
        """Report a finished trial to Ax; a missing trace marks the trial failed."""
        if trace_entry is None:
            self.ax_client.log_trial_failure(trial_index=trial_id)
        else:
            self.ax_client.complete_trial(trial_index=trial_id,
                                          raw_data=trace_entry["metric_value"])

    def get_best_parameters(self):
        """Return (best_parameters, best metric_value) from Ax."""
        best_parameters, values = self.ax_client.get_best_parameters()
        # values is (means, covariances); index 0 holds the metric means.
        return best_parameters, float(values[0]["metric_value"])