Esempio n. 1
0
def run_warm_start_test(implementation: NativeImplementation) -> None:
    """Check that a second experiment can warm start from a trial of a first one.

    Runs ``implementation`` to completion and expects exactly one trial with
    ``implementation.num_expected_steps_per_trial`` steps. It then runs a copy
    of the experiment whose searcher config sets ``source_trial_id`` to that
    trial. Every trial of the second experiment must report the first trial's
    checkpoint as its ``warm_start_checkpoint_id``.

    The caller's ``implementation`` is not modified.
    """
    import copy

    experiment_id1 = create_experiment(implementation)
    experiment.wait_for_experiment_state(
        experiment_id1, "COMPLETED", max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS
    )
    assert experiment.num_active_trials(experiment_id1) == 0

    trials = experiment.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial["id"]
    assert len(first_trial["steps"]) == implementation.num_expected_steps_per_trial

    # Use the checkpoint of the last step that actually has one, rather than
    # steps[0]: the first step may carry no checkpoint, and with several
    # checkpoints the first is not the one a warm start would pick.
    # NOTE(review): assumes warm starting from `source_trial_id` uses the
    # source trial's most recent checkpoint — confirm against searcher docs.
    checkpointed_steps = [
        step for step in first_trial["steps"] if step.get("checkpoint")
    ]
    assert checkpointed_steps, "first trial produced no checkpoints"
    first_checkpoint_id = checkpointed_steps[-1]["checkpoint"]["id"]

    # Add a source trial ID to warm start from. `_asdict()` is only a shallow
    # copy, so the configuration is deep-copied before editing it. Otherwise
    # the nested "searcher" dict of the caller's `implementation` (possibly
    # shared with other tests) would be mutated too.
    second_config = copy.deepcopy(implementation.configuration)
    second_config["searcher"]["source_trial_id"] = first_trial_id
    second_exp = NativeImplementation(
        **{**implementation._asdict(), "configuration": second_config}
    )

    experiment_id2 = create_experiment(second_exp)
    experiment.wait_for_experiment_state(
        experiment_id2, "COMPLETED", max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS
    )
    assert experiment.num_active_trials(experiment_id2) == 0

    # The new trials should have a warm start checkpoint ID.
    trials = experiment.experiment_trials(experiment_id2)
    assert len(trials) == 1
    for trial in trials:
        assert trial["warm_start_checkpoint_id"] == first_checkpoint_id
Esempio n. 2
0
def test_tutorial() -> None:
    """Run each native tf.keras tutorial script in turn and wait for completion.

    The scripts run sequentially. Each one is launched as a native experiment
    from the ``native-tf-keras`` tutorial directory, and the test waits for it
    to reach COMPLETED before launching the next.
    """
    scripts = ("tf_keras_native.py", "tf_keras_native_hparam_search.py")
    for script_name in scripts:
        new_exp_id = create_native_experiment(
            conf.tutorials_path("native-tf-keras"), ["python", script_name]
        )
        experiment.wait_for_experiment_state(
            new_exp_id, "COMPLETED", max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS
        )
Esempio n. 3
0
def test_tutorial_dtrain() -> None:
    """Run the ``tf_keras_native_dtrain.py`` tutorial and wait for completion.

    The script is launched as a native experiment from the ``native-tf-keras``
    tutorial directory.
    """
    launch_command = ["python", "tf_keras_native_dtrain.py"]
    dtrain_exp_id = create_native_experiment(
        conf.tutorials_path("native-tf-keras"), launch_command
    )
    experiment.wait_for_experiment_state(
        dtrain_exp_id, "COMPLETED", max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS
    )