def _test_get_best_trials():
    """Check that the oracle client ranks finished trials by score.

    The role comes from the ``KERASTUNER_TUNER_ID`` environment variable:
    - An id containing "chief" serves the oracle via
      ``oracle_chief.start_server``.
    - Any other id proxies the oracle through ``OracleClient``. It runs ten
      trials reporting scores 0 through 9, then checks that
      ``get_best_trials(3)`` returns the top three scores and that each
      returned trial id maps to the score it actually reported.
    """
    hps = kt.HyperParameters()
    hps.Int("a", 0, 100, default=5)
    hps.Int("b", 0, 100, default=6)
    oracle = randomsearch.RandomSearchOracle(
        objective=kt.Objective("score", direction="max"),
        max_trials=10,
        hyperparameters=hps,
    )
    oracle._set_project_dir(tmp_dir, "untitled")
    tuner_id = os.environ["KERASTUNER_TUNER_ID"]
    if "chief" in tuner_id:
        oracle_chief.start_server(oracle)
    else:
        client = oracle_client.OracleClient(oracle)
        trial_scores = {}
        for score in range(10):
            trial = client.create_trial(tuner_id)
            assert trial.status == "RUNNING"
            assert "a" in trial.hyperparameters.values
            assert "b" in trial.hyperparameters.values
            trial_id = trial.trial_id
            client.update_trial(trial_id, {"score": score})
            client.end_trial(trial_id)
            trial_scores[trial_id] = score

        # No early exit here: the ranking checks below must actually run.
        best_trials = client.get_best_trials(3)
        best_scores = [t.score for t in best_trials]
        # Scores 0..9 were reported with direction="max", so the top three
        # are 9, 8, 7 in that order.
        assert best_scores == [9, 8, 7]
        # Check that trial_ids are correctly mapped to scores.
        for t in best_trials:
            assert trial_scores[t.trial_id] == t.score
def __init__(
    self,
    oracle,
    hypermodel=None,
    directory=None,
    project_name=None,
    logger=None,
    overwrite=False,
):
    """Set up the tuner's project location, oracle, and hypermodel.

    Args:
        oracle: Must be an instance of `oracle_module.Oracle`.
        hypermodel: Optional; resolved through `hm_module.get_hypermodel`.
        directory: Root directory for project files. Defaults to ".".
        project_name: Project name under `directory`. Defaults to
            "untitled_project".
        logger: Stored as `self.logger`.
        overwrite: If True, an existing project directory is removed, and
            the flag is forwarded to the oracle's `_set_project_dir`. If
            False and a saved tuner file already exists, the tuner is
            reloaded from it.

    Raises:
        ValueError: If `oracle` is not an `Oracle` instance.
    """
    # Validate before touching the filesystem, so an invalid argument can
    # never trigger the destructive `rmtree` below.
    if not isinstance(oracle, oracle_module.Oracle):
        raise ValueError(
            "Expected `oracle` argument to be an instance of `Oracle`. "
            f"Received: oracle={oracle} (of type {type(oracle)})."
        )

    # Ops and metadata
    self.directory = directory or "."
    self.project_name = project_name or "untitled_project"
    if overwrite and tf.io.gfile.exists(self.project_dir):
        tf.io.gfile.rmtree(self.project_dir)

    self.oracle = oracle
    self.oracle._set_project_dir(
        self.directory, self.project_name, overwrite=overwrite
    )

    # Run in distributed mode.
    if dist_utils.is_chief_oracle():
        # Blocks forever.
        oracle_chief.start_server(self.oracle)
    elif dist_utils.has_chief_oracle():
        # Proxies requests to the chief oracle.
        self.oracle = oracle_client.OracleClient(self.oracle)

    # To support tuning distribution.
    self.tuner_id = os.environ.get("KERASTUNER_TUNER_ID", "tuner0")

    self.hypermodel = hm_module.get_hypermodel(hypermodel)

    # Logs etc
    self.logger = logger
    self._display = tuner_utils.Display(oracle=self.oracle)

    self._populate_initial_space()

    # Resume from a previously saved tuner state unless overwriting.
    if not overwrite:
        tuner_fname = self._get_tuner_fname()
        if tf.io.gfile.exists(tuner_fname):
            tf.get_logger().info(
                "Reloading Tuner from {}".format(tuner_fname)
            )
            self.reload()
def _test_get_space():
    """Verify that ``OracleClient.get_space`` mirrors the oracle's space.

    A tuner id containing "chief" hosts the oracle server. Any other id
    connects as a client and inspects the hyperparameters it gets back.
    """
    search_space = kt.HyperParameters()
    search_space.Int("a", 0, 10, default=3)
    rs_oracle = randomsearch.RandomSearchOracle(
        objective=kt.Objective("score", "max"),
        max_trials=10,
        hyperparameters=search_space,
    )
    rs_oracle._set_project_dir(tmp_dir, "untitled")

    worker_id = os.environ["KERASTUNER_TUNER_ID"]
    if "chief" in worker_id:
        oracle_chief.start_server(rs_oracle)
        return

    remote_space = oracle_client.OracleClient(rs_oracle).get_space()
    # Only "a" was declared, so it is the sole entry and carries its default.
    assert remote_space.values == {"a": 3}
    assert len(remote_space.space) == 1
def _test_update_space():
    """Verify that a client can push a new search space to the oracle.

    The oracle starts without hyperparameters. A non-chief worker sends a
    space with an ``Int`` and a ``Choice`` through ``update_space``, then
    reads it back with ``get_space``.
    """
    rs_oracle = randomsearch.RandomSearchOracle(
        objective=kt.Objective("score", "max"),
        max_trials=10,
    )
    rs_oracle._set_project_dir(tmp_dir, "untitled")

    worker_id = os.environ["KERASTUNER_TUNER_ID"]
    if "chief" in worker_id:
        oracle_chief.start_server(rs_oracle)
        return

    proxy = oracle_client.OracleClient(rs_oracle)
    new_space = kt.HyperParameters()
    new_space.Int("a", 0, 10, default=5)
    new_space.Choice("b", [1, 2, 3])
    proxy.update_space(new_space)

    fetched = proxy.get_space()
    assert len(fetched.space) == 2
    assert fetched.values["a"] == 5
    # "b" was given no explicit default; the test expects its first choice.
    assert fetched.values["b"] == 1
def _test_end_trial():
    """Verify that ending a trial with an explicit status is persisted.

    A non-chief worker creates a trial, reports one metric value, ends the
    trial as "INVALID", and checks the status seen on a fresh lookup.
    """
    search_space = kt.HyperParameters()
    search_space.Int("a", 0, 10, default=5)
    rs_oracle = randomsearch.RandomSearchOracle(
        objective=kt.Objective("score", "max"),
        max_trials=10,
        hyperparameters=search_space,
    )
    rs_oracle._set_project_dir(tmp_dir, "untitled")

    worker_id = os.environ["KERASTUNER_TUNER_ID"]
    if "chief" in worker_id:
        oracle_chief.start_server(rs_oracle)
        return

    proxy = oracle_client.OracleClient(rs_oracle)
    tid = proxy.create_trial(worker_id).trial_id
    proxy.update_trial(tid, {"score": 1}, step=2)
    proxy.end_trial(tid, "INVALID")

    assert proxy.get_trial(tid).status == "INVALID"
def _test_create_trial():
    """Verify that a client-created trial samples values inside the space.

    A non-chief worker requests a trial and checks that it is running and
    that "a" lies in [0, 10] and "b" is one of the declared choices.
    """
    search_space = kt.HyperParameters()
    search_space.Int("a", 0, 10, default=5)
    search_space.Choice("b", [1, 2, 3])
    rs_oracle = randomsearch.RandomSearchOracle(
        objective=kt.Objective("score", "max"),
        max_trials=10,
        hyperparameters=search_space,
    )
    rs_oracle._set_project_dir(tmp_dir, "untitled")

    worker_id = os.environ["KERASTUNER_TUNER_ID"]
    if "chief" in worker_id:
        oracle_chief.start_server(rs_oracle)
        return

    new_trial = oracle_client.OracleClient(rs_oracle).create_trial(worker_id)
    assert new_trial.status == "RUNNING"

    sampled = new_trial.hyperparameters
    a_value = sampled.get("a")
    assert 0 <= a_value <= 10
    b_value = sampled.get("b")
    assert b_value in {1, 2, 3}
def _test_update_trial():
    """Verify that metric updates sent by a client land in trial history.

    A fresh trial starts with no "score" metric. After one update at
    step 2, the stored history holds exactly that single observation.
    """
    search_space = kt.HyperParameters()
    search_space.Int("a", 0, 10, default=5)
    rs_oracle = randomsearch.RandomSearchOracle(
        objective=kt.Objective("score", "max"),
        max_trials=10,
        hyperparameters=search_space,
    )
    rs_oracle._set_project_dir(tmp_dir, "untitled")

    worker_id = os.environ["KERASTUNER_TUNER_ID"]
    if "chief" in worker_id:
        oracle_chief.start_server(rs_oracle)
        return

    proxy = oracle_client.OracleClient(rs_oracle)
    new_trial = proxy.create_trial(worker_id)
    assert "score" not in new_trial.metrics.metrics

    tid = new_trial.trial_id
    proxy.update_trial(tid, {"score": 1}, step=2)

    history = proxy.get_trial(tid).metrics.get_history("score")
    expected = [metrics_tracking.MetricObservation([1], step=2)]
    assert history == expected