def test_get_best_trials(tmp_dir):
    """Chief serves the oracle; a worker runs 10 trials and checks ranking."""

    def _test_get_best_trials():
        hps = kt.HyperParameters()
        hps.Int("a", 0, 100, default=5)
        hps.Int("b", 0, 100, default=6)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", direction="max"),
            max_trials=10,
            hyperparameters=hps,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in tuner_id:
            # The chief thread blocks here serving oracle requests.
            oracle_chief.start_server(oracle)
        else:
            client = oracle_client.OracleClient(oracle)
            trial_scores = {}
            for score in range(10):
                trial = client.create_trial(tuner_id)
                assert trial.status == "RUNNING"
                assert "a" in trial.hyperparameters.values
                assert "b" in trial.hyperparameters.values
                trial_id = trial.trial_id
                client.update_trial(trial_id, {"score": score})
                client.end_trial(trial_id)
                trial_scores[trial_id] = score
            # BUG FIX: a stray bare `return` previously sat here, making every
            # assertion below unreachable so the test always passed vacuously.
            best_trials = client.get_best_trials(3)
            best_scores = [t.score for t in best_trials]
            assert best_scores == [9, 8, 7]
            # Check that trial_ids are correctly mapped to scores.
            for t in best_trials:
                assert trial_scores[t.trial_id] == t.score

    mock_distribute.mock_distribute(_test_get_best_trials, num_workers=1)
def test_base_tuner_distribution(tmp_dir):
    """Run a SimpleTuner search across several workers against a chief oracle."""
    worker_count = 3
    sync_point = threading.Barrier(worker_count)

    def _test_base_tuner():
        def build_model(hp):
            return hp.Int("a", 1, 100)

        search_oracle = kt.oracles.RandomSearch(
            objective=kt.Objective("score", "max"), max_trials=10
        )
        tuner = SimpleTuner(
            oracle=search_oracle,
            hypermodel=build_model,
            directory=tmp_dir,
        )
        tuner.search()

        # Only workers reach this point; the chief serves until its thread stops.
        assert dist_utils.has_chief_oracle()
        assert not dist_utils.is_chief_oracle()
        assert isinstance(tuner.oracle, kt.distribute.oracle_client.OracleClient)

        sync_point.wait(60)

        # A "model" in this tuner is just its score.
        scores = tuner.get_best_models(10)
        assert len(scores)
        assert scores == sorted(copy.copy(scores), reverse=True)

    mock_distribute.mock_distribute(_test_base_tuner, num_workers=worker_count)
def test_update_trial(tmp_dir):
    """A worker reports a metric via the client and reads it back."""

    def _test_update_trial():
        search_space = kt.HyperParameters()
        search_space.Int("a", 0, 10, default=5)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=search_space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        my_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" not in my_id:
            client = oracle_client.OracleClient(oracle)
            trial = client.create_trial(my_id)
            # No metric has been reported yet.
            assert "score" not in trial.metrics.metrics
            client.update_trial(trial.trial_id, {"score": 1}, step=2)
            refreshed = client.get_trial(trial.trial_id)
            expected = [metrics_tracking.MetricObservation([1], step=2)]
            assert refreshed.metrics.get_history("score") == expected
        else:
            # Chief thread blocks here serving requests.
            oracle_chief.start_server(oracle)

    mock_distribute.mock_distribute(_test_update_trial)
def test_exception_raising():
    """Errors raised in either a worker or the chief propagate to the caller."""

    def fail_on_worker():
        if "worker" in os.environ["KERASTUNER_TUNER_ID"]:
            raise ValueError("Found a worker error")

    with pytest.raises(ValueError, match="Found a worker error"):
        mock_distribute.mock_distribute(fail_on_worker, num_workers=2)

    def fail_on_chief():
        if "chief" in os.environ["KERASTUNER_TUNER_ID"]:
            raise ValueError("Found a chief error")

    with pytest.raises(ValueError, match="Found a chief error"):
        mock_distribute.mock_distribute(fail_on_chief, num_workers=2)
def test_random_search(tmp_dir):
    """End-to-end distributed RandomSearch over a tiny Keras model."""
    # TensorFlow model building and execution is not thread-safe.
    num_workers = 1

    def _test_random_search():
        def build_model(hp):
            layers = [keras.layers.Dense(3, input_shape=(5,))]
            for i in range(hp.Int("num_layers", 1, 3)):
                units = hp.Int("num_units_%i" % i, 1, 3)
                layers.append(keras.layers.Dense(units, activation="relu"))
            layers.append(keras.layers.Dense(1, activation="sigmoid"))
            model = keras.Sequential(layers)
            model.compile("sgd", "binary_crossentropy")
            return model

        x = np.random.uniform(-1, 1, size=(2, 5))
        y = np.ones((2, 1))
        tuner = kt.tuners.RandomSearch(
            hypermodel=build_model,
            objective="val_loss",
            max_trials=10,
            directory=tmp_dir,
        )
        # Only the worker reaches this point; the chief serves until its
        # thread stops.
        assert dist_utils.has_chief_oracle()
        assert not dist_utils.is_chief_oracle()
        assert isinstance(tuner.oracle, kt.distribute.oracle_client.OracleClient)
        tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)
        # Suppress warnings about optimizer state not being restored by
        # tf.keras.
        tf.get_logger().setLevel(logging.ERROR)
        top_two = tuner.oracle.get_best_trials(2)
        assert top_two[0].score <= top_two[1].score
        best_models = tuner.get_best_models(2)
        assert best_models[0].evaluate(x, y) <= best_models[1].evaluate(x, y)

    mock_distribute.mock_distribute(_test_random_search, num_workers)
def test_get_space(tmp_dir):
    """A worker can fetch the oracle's search space through the client."""

    def _test_get_space():
        space = kt.HyperParameters()
        space.Int("a", 0, 10, default=3)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        my_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" not in my_id:
            client = oracle_client.OracleClient(oracle)
            fetched = client.get_space()
            assert fetched.values == {"a": 3}
            assert len(fetched.space) == 1
        else:
            # Chief thread blocks here serving requests.
            oracle_chief.start_server(oracle)

    mock_distribute.mock_distribute(_test_get_space)
def test_mock_distribute(tmp_dir):
    """Each mocked process sees its own env vars and writes its own file."""

    def process_fn():
        assert "KERASTUNER_ORACLE_IP" in os.environ
        # Wait, to test that other threads aren't overriding env vars.
        time.sleep(1)
        assert isinstance(os.environ, mock_distribute.MockEnvVars)
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]
        if "worker" in tuner_id:
            # Give the chief process time to write its value, as we do not
            # join on the chief since it will run a server.
            time.sleep(2)
        with tf.io.gfile.GFile(os.path.join(str(tmp_dir), tuner_id), "w") as f:
            f.write(tuner_id)

    mock_distribute.mock_distribute(process_fn, num_workers=3)

    # Every process (chief + 3 workers) must have written its own id.
    for expected_id in ("chief", "worker0", "worker1", "worker2"):
        path = os.path.join(str(tmp_dir), expected_id)
        with tf.io.gfile.GFile(path, "r") as f:
            assert f.read() == expected_id
def test_update_space(tmp_dir):
    """A worker can push a new search space to the chief oracle."""

    def _test_update_space():
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"), max_trials=10
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        my_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in my_id:
            # Chief thread blocks here serving requests.
            oracle_chief.start_server(oracle)
            return

        client = oracle_client.OracleClient(oracle)
        new_space = kt.HyperParameters()
        new_space.Int("a", 0, 10, default=5)
        new_space.Choice("b", [1, 2, 3])
        client.update_space(new_space)

        fetched = client.get_space()
        assert len(fetched.space) == 2
        assert fetched.values["a"] == 5
        assert fetched.values["b"] == 1

    mock_distribute.mock_distribute(_test_update_space)
def test_end_trial(tmp_dir):
    """Ending a trial with a status makes that status visible on re-fetch."""

    def _test_end_trial():
        space = kt.HyperParameters()
        space.Int("a", 0, 10, default=5)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        my_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in my_id:
            # Chief thread blocks here serving requests.
            oracle_chief.start_server(oracle)
            return

        client = oracle_client.OracleClient(oracle)
        trial_id = client.create_trial(my_id).trial_id
        client.update_trial(trial_id, {"score": 1}, step=2)
        client.end_trial(trial_id, "INVALID")
        assert client.get_trial(trial_id).status == "INVALID"

    mock_distribute.mock_distribute(_test_end_trial)
def test_create_trial(tmp_dir):
    """A worker-created trial is RUNNING and has values inside the space."""

    def _test_create_trial():
        space = kt.HyperParameters()
        space.Int("a", 0, 10, default=5)
        space.Choice("b", [1, 2, 3])
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        my_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in my_id:
            # Chief thread blocks here serving requests.
            oracle_chief.start_server(oracle)
        else:
            client = oracle_client.OracleClient(oracle)
            trial = client.create_trial(my_id)
            assert trial.status == "RUNNING"
            assert 0 <= trial.hyperparameters.get("a") <= 10
            assert trial.hyperparameters.get("b") in {1, 2, 3}

    mock_distribute.mock_distribute(_test_create_trial)