コード例 #1
0
def test_get_best_trials(tmp_dir):
    """Check that ``OracleClient.get_best_trials`` ranks trials by score.

    The chief process serves a ``RandomSearchOracle`` (objective "score",
    direction "max"). A single worker creates ten trials through the client,
    reports scores 0..9 and ends each one, then requests the three best
    trials and checks both their order and their trial_id -> score mapping.
    """

    def _test_get_best_trials():
        hps = kt.HyperParameters()
        hps.Int("a", 0, 100, default=5)
        hps.Int("b", 0, 100, default=6)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", direction="max"),
            max_trials=10,
            hyperparameters=hps,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in tuner_id:
            # The chief only serves the oracle to the worker.
            oracle_chief.start_server(oracle)
        else:
            client = oracle_client.OracleClient(oracle)
            # Remember which score each trial reported, for the mapping check.
            trial_scores = {}
            for score in range(10):
                trial = client.create_trial(tuner_id)
                assert trial.status == "RUNNING"
                assert "a" in trial.hyperparameters.values
                assert "b" in trial.hyperparameters.values
                trial_id = trial.trial_id
                client.update_trial(trial_id, {"score": score})
                client.end_trial(trial_id)
                trial_scores[trial_id] = score

            best_trials = client.get_best_trials(3)
            best_scores = [t.score for t in best_trials]
            # Direction is "max", so the highest reported scores come first.
            assert best_scores == [9, 8, 7]
            # Check that trial_ids are correctly mapped to scores.
            for t in best_trials:
                assert trial_scores[t.trial_id] == t.score

    mock_distribute.mock_distribute(_test_get_best_trials, num_workers=1)
コード例 #2
0
def test_base_tuner_distribution(tmp_dir):
    """Run a ``SimpleTuner`` search on several workers that share one chief oracle.

    Each worker must end up talking to the chief through an ``OracleClient``.
    After all workers have synchronized, the best "models" (plain scores here)
    must come back ordered from highest to lowest.
    """
    worker_count = 3
    # Workers wait for each other here before reading the shared results.
    sync_point = threading.Barrier(worker_count)

    def _test_base_tuner():
        def build_model(hp):
            return hp.Int("a", 1, 100)

        random_oracle = kt.oracles.RandomSearch(
            objective=kt.Objective("score", "max"),
            max_trials=10,
        )
        tuner = SimpleTuner(
            oracle=random_oracle,
            hypermodel=build_model,
            directory=tmp_dir,
        )
        tuner.search()

        # Only workers reach this point; the chief keeps serving the oracle
        # until its thread stops.
        assert dist_utils.has_chief_oracle()
        assert not dist_utils.is_chief_oracle()
        assert isinstance(tuner.oracle, kt.distribute.oracle_client.OracleClient)

        sync_point.wait(60)

        # Model is just a score.
        best_scores = tuner.get_best_models(10)
        assert len(best_scores)
        # sorted() already returns a new list, so no defensive copy is needed.
        assert best_scores == sorted(best_scores, reverse=True)

    mock_distribute.mock_distribute(_test_base_tuner, num_workers=worker_count)
コード例 #3
0
def test_update_trial(tmp_dir):
    """Metric updates sent through the client are recorded by the chief oracle.

    The worker reports ``score=1`` at step 2 for a fresh trial and expects the
    trial's "score" history to contain exactly that single observation.
    """

    def _test_update_trial():
        space = kt.HyperParameters()
        space.Int("a", 0, 10, default=5)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]

        if "chief" in tuner_id:
            oracle_chief.start_server(oracle)
            return

        client = oracle_client.OracleClient(oracle)
        new_trial = client.create_trial(tuner_id)
        # A brand-new trial has no recorded score yet.
        assert "score" not in new_trial.metrics.metrics
        trial_id = new_trial.trial_id

        client.update_trial(trial_id, {"score": 1}, step=2)
        fetched = client.get_trial(trial_id)
        history = fetched.metrics.get_history("score")
        assert history == [metrics_tracking.MetricObservation([1], step=2)]

    mock_distribute.mock_distribute(_test_update_trial)
コード例 #4
0
def test_exception_raising():
    """An error raised in either a worker or the chief surfaces from mock_distribute."""

    def make_failing_fn(role, message):
        """Build a process function that raises ``ValueError(message)`` only in
        processes whose tuner id contains ``role``."""

        def failing_fn():
            if role in os.environ["KERASTUNER_TUNER_ID"]:
                raise ValueError(message)

        return failing_fn

    cases = (
        ("worker", "Found a worker error"),
        ("chief", "Found a chief error"),
    )
    for role, message in cases:
        with pytest.raises(ValueError, match=message):
            mock_distribute.mock_distribute(
                make_failing_fn(role, message), num_workers=2
            )
コード例 #5
0
def test_random_search(tmp_dir):
    """End-to-end distributed ``RandomSearch`` over a tiny Keras model.

    Checks that the worker's tuner talks to the chief through an
    ``OracleClient``. It then checks that the two best trials and the two best
    models come back ordered by ascending loss (lower is better).
    """
    # TensorFlow model building and execution is not thread-safe.
    workers = 1

    def _test_random_search():
        def build_model(hp):
            layers = keras.layers
            model = keras.Sequential()
            model.add(layers.Dense(3, input_shape=(5,)))
            depth = hp.Int("num_layers", 1, 3)
            for idx in range(depth):
                units = hp.Int("num_units_%i" % idx, 1, 3)
                model.add(layers.Dense(units, activation="relu"))
            model.add(layers.Dense(1, activation="sigmoid"))
            model.compile("sgd", "binary_crossentropy")
            return model

        features = np.random.uniform(-1, 1, size=(2, 5))
        labels = np.ones((2, 1))

        tuner = kt.tuners.RandomSearch(
            hypermodel=build_model,
            objective="val_loss",
            max_trials=10,
            directory=tmp_dir,
        )

        # Only workers reach this point; the chief serves the oracle until
        # its thread stops.
        assert dist_utils.has_chief_oracle()
        assert not dist_utils.is_chief_oracle()
        assert isinstance(tuner.oracle, kt.distribute.oracle_client.OracleClient)

        tuner.search(
            features,
            labels,
            validation_data=(features, labels),
            epochs=1,
            batch_size=2,
        )

        # Suppress warnings about optimizer state not being restored by tf.keras.
        tf.get_logger().setLevel(logging.ERROR)

        best_trials = tuner.oracle.get_best_trials(2)
        assert best_trials[0].score <= best_trials[1].score

        best_models = tuner.get_best_models(2)
        first_loss = best_models[0].evaluate(features, labels)
        second_loss = best_models[1].evaluate(features, labels)
        assert first_loss <= second_loss

    mock_distribute.mock_distribute(_test_random_search, workers)
コード例 #6
0
def test_get_space(tmp_dir):
    """The worker's client sees the chief oracle's search space unchanged."""

    def _test_get_space():
        space = kt.HyperParameters()
        space.Int("a", 0, 10, default=3)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        role = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in role:
            oracle_chief.start_server(oracle)
            return

        remote_space = oracle_client.OracleClient(oracle).get_space()
        assert remote_space.values == {"a": 3}
        assert len(remote_space.space) == 1

    mock_distribute.mock_distribute(_test_get_space)
コード例 #7
0
def test_mock_distribute(tmp_dir):
    """Each simulated process sees its own env vars and writes its own file.

    Every process (chief plus three workers) writes its tuner id to a file
    named after that id. Afterwards each file must contain exactly its own id.
    """

    def process_fn():
        assert "KERASTUNER_ORACLE_IP" in os.environ
        # Pause so that other threads overriding env vars would be caught.
        time.sleep(1)
        assert isinstance(os.environ, mock_distribute.MockEnvVars)
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]
        if "worker" in tuner_id:
            # The chief is not joined (it would run a server), so workers
            # wait longer to give it time to write its file.
            time.sleep(2)
        out_path = os.path.join(str(tmp_dir), tuner_id)
        with tf.io.gfile.GFile(out_path, "w") as handle:
            handle.write(tuner_id)

    mock_distribute.mock_distribute(process_fn, num_workers=3)

    expected_ids = ("chief", "worker0", "worker1", "worker2")
    for tuner_id in expected_ids:
        in_path = os.path.join(str(tmp_dir), tuner_id)
        with tf.io.gfile.GFile(in_path, "r") as handle:
            assert handle.read() == tuner_id
コード例 #8
0
def test_update_space(tmp_dir):
    """Hyperparameters pushed with ``update_space`` come back from ``get_space``."""

    def _test_update_space():
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in tuner_id:
            oracle_chief.start_server(oracle)
            return

        client = oracle_client.OracleClient(oracle)

        new_space = kt.HyperParameters()
        new_space.Int("a", 0, 10, default=5)
        new_space.Choice("b", [1, 2, 3])
        client.update_space(new_space)

        retrieved = client.get_space()
        assert len(retrieved.space) == 2
        assert retrieved.values["a"] == 5
        # "b" has no explicit default; the test expects its first choice.
        assert retrieved.values["b"] == 1

    mock_distribute.mock_distribute(_test_update_space)
コード例 #9
0
def test_end_trial(tmp_dir):
    """Ending a trial with an explicit status stores that status on the chief."""

    def _test_end_trial():
        space = kt.HyperParameters()
        space.Int("a", 0, 10, default=5)
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in tuner_id:
            oracle_chief.start_server(oracle)
            return

        client = oracle_client.OracleClient(oracle)
        trial_id = client.create_trial(tuner_id).trial_id
        client.update_trial(trial_id, {"score": 1}, step=2)
        client.end_trial(trial_id, "INVALID")
        assert client.get_trial(trial_id).status == "INVALID"

    mock_distribute.mock_distribute(_test_end_trial)
コード例 #10
0
def test_create_trial(tmp_dir):
    """A trial created through the client is RUNNING and samples within the space."""

    def _test_create_trial():
        space = kt.HyperParameters()
        space.Int("a", 0, 10, default=5)
        space.Choice("b", [1, 2, 3])
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
            hyperparameters=space,
        )
        oracle._set_project_dir(tmp_dir, "untitled")
        tuner_id = os.environ["KERASTUNER_TUNER_ID"]
        if "chief" in tuner_id:
            oracle_chief.start_server(oracle)
            return

        trial = oracle_client.OracleClient(oracle).create_trial(tuner_id)
        assert trial.status == "RUNNING"
        sampled = trial.hyperparameters
        a_value = sampled.get("a")
        assert 0 <= a_value <= 10
        b_value = sampled.get("b")
        assert b_value in {1, 2, 3}

    mock_distribute.mock_distribute(_test_create_trial)