def test_pytorch_parallel() -> None:
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_tensor_auto_tuning(config, True)
    config = conf.set_perform_initial_validation(config, True)

    exp_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1)
    exp.assert_performed_initial_validation(exp_id)

    # Check on record/batch counts we emitted in logs.
    validation_size = 10000
    global_batch_size = config["hyperparameters"]["global_batch_size"]
    num_workers = config.get("resources", {}).get("slots_per_trial", 1)
    scheduling_unit = config.get("scheduling_unit", 100)
    per_slot_batch_size = global_batch_size // num_workers
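    # Round up so the final, partially filled validation batch is counted.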
    exp_val_batches = (validation_size +
                       (per_slot_batch_size - 1)) // per_slot_batch_size
    patterns = [
        # Expect two copies of matching training reports.
        f"trained: {scheduling_unit * global_batch_size} records.*in {scheduling_unit} batches",
        f"trained: {scheduling_unit * global_batch_size} records.*in {scheduling_unit} batches",
        f"validated: {validation_size} records.*in {exp_val_batches} batches",
    ]
    trial_id = exp.experiment_trials(exp_id)[0].trial.id
    exp.assert_patterns_in_trial_logs(trial_id, patterns)
Example #2
def test_job_queue_adjust_weight() -> None:
    config = conf.tutorials_path("mnist_pytorch/const.yaml")
    model = conf.tutorials_path("mnist_pytorch")
    for _ in range(2):
        exp.create_experiment(config, model)

    jobs = JobInfo()
    ok = jobs.refresh_until_populated()
    assert ok

    ordered_ids = jobs.get_ids()
    subprocess.run(["det", "job", "update", ordered_ids[0], "--weight", "10"])

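    # The weight update is applied asynchronously, so give it a moment before re-reading the queue.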
    sleep(2)
    jobs.refresh()
    new_weight = jobs.get_job_weight(ordered_ids[0])
    assert new_weight == "10"

    subprocess.run(
        ["det", "job", "update-batch", f"{ordered_ids[1]}.weight=10"])

    sleep(2)
    jobs.refresh()
    new_weight = jobs.get_job_weight(ordered_ids[1])
    assert new_weight == "10"
Example #3
def test_fashion_mnist_tf_keras_distributed() -> None:
    config = conf.load_config(
        conf.tutorials_path("fashion_mnist_tf_keras/distributed.yaml"))
    config = conf.set_max_length(config, {"batches": 200})

    exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("fashion_mnist_tf_keras"), 1)
Example #4
def test_streaming_observability_metrics_apis(
    framework_base_experiment: str, framework_timings_enabled: bool
) -> None:
    # TODO: refactor tests to not use cli singleton auth.
    certs.cli_cert = certs.default_load(conf.make_master_url())
    authentication.cli_auth = authentication.Authentication(conf.make_master_url(), try_reauth=True)

    config_path = conf.tutorials_path(f"../{framework_base_experiment}/const.yaml")
    model_def_path = conf.tutorials_path(f"../{framework_base_experiment}")

    config_obj = conf.load_config(config_path)
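    # Enable profiling so the profiler metric endpoints queried below have data to stream.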
    config_obj = conf.set_profiling_enabled(config_obj)
    with tempfile.NamedTemporaryFile() as tf:
        with open(tf.name, "w") as f:
            yaml.dump(config_obj, f)
        experiment_id = exp.create_experiment(
            tf.name,
            model_def_path,
        )

    exp.wait_for_experiment_state(experiment_id, "COMPLETED")
    trials = exp.experiment_trials(experiment_id)
    trial_id = trials[0]["id"]

    gpu_enabled = conf.GPU_ENABLED

    request_profiling_metric_labels(trial_id, framework_timings_enabled, gpu_enabled)
    if gpu_enabled:
        request_profiling_system_metrics(trial_id, "gpu_util")
    if framework_timings_enabled:
        request_profiling_pytorch_timing_metrics(trial_id, "train_batch")
Example #5
def test_mnist_pytorch_distributed() -> None:
    config = conf.load_config(
        conf.tutorials_path("mnist_pytorch/distributed.yaml"))
    config = conf.set_max_length(config, {"batches": 200})

    exp.run_basic_test_with_temp_config(config,
                                        conf.tutorials_path("mnist_pytorch"),
                                        1)
Example #6
def test_pytorch_const_with_amp() -> None:
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config = conf.set_max_length(config, {"batches": 200})
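    # "O1" is the Apex AMP optimization level for standard automatic mixed precision.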
    config = conf.set_amp_level(config, "O1")

    exp.run_basic_test_with_temp_config(config,
                                        conf.tutorials_path("mnist_pytorch"),
                                        1)
Example #7
def test_pytorch_const_native_parallel() -> None:
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, True)
    config = conf.set_max_length(config, {"batches": 200})

    exp.run_basic_test_with_temp_config(config,
                                        conf.tutorials_path("mnist_pytorch"),
                                        1)
Example #8
def test_core_api_tutorials(stage: str, ntrials: int, expect_workloads: bool,
                            expect_checkpoints: bool) -> None:
    exp.run_basic_test(
        conf.tutorials_path(f"core_api/{stage}.yaml"),
        conf.tutorials_path("core_api"),
        ntrials,
        expect_workloads=expect_workloads,
        expect_checkpoints=expect_checkpoints,
    )
Example #9
def test_tutorial() -> None:
    exp_id1 = create_native_experiment(conf.tutorials_path("native-tf-keras"),
                                       ["python", "tf_keras_native.py"])
    experiment.wait_for_experiment_state(
        exp_id1, "COMPLETED", max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS)
    exp_id2 = create_native_experiment(
        conf.tutorials_path("native-tf-keras"),
        ["python", "tf_keras_native_hparam_search.py"])
    experiment.wait_for_experiment_state(
        exp_id2, "COMPLETED", max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS)
Example #10
def test_tf_keras_mnist_parallel() -> None:
    config = conf.load_config(
        conf.tutorials_path("fashion_mnist_tf_keras/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_max_length(config, {"batches": 200})

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("fashion_mnist_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Example #11
def test_pytorch_parallel() -> None:
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_tensor_auto_tuning(config, True)
    config = conf.set_perform_initial_validation(config, True)

    exp_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1, has_zeroth_step=True)
    exp.assert_performed_initial_validation(exp_id)
Example #12
def test_s3_no_creds(secrets: Dict[str, str]) -> None:
    pytest.skip("Temporarily skipping this until we find a more secure way of testing this.")
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config["checkpoint_storage"] = exp.s3_checkpoint_config_no_creds()
    config.setdefault("environment", {})
    config["environment"].setdefault("environment_variables", [])
    config["environment"]["environment_variables"] += [
        f"AWS_ACCESS_KEY_ID={secrets['INTEGRATIONS_S3_ACCESS_KEY']}",
        f"AWS_SECRET_ACCESS_KEY={secrets['INTEGRATIONS_S3_SECRET_KEY']}",
    ]
    exp.run_basic_test_with_temp_config(config, conf.tutorials_path("mnist_pytorch"), 1)
Example #13
def test_pytorch_const_parallel(aggregation_frequency: int, use_amp: bool) -> None:
    if use_amp and aggregation_frequency > 1:
        pytest.skip("Mixed precision is not support with aggregation frequency > 1.")

    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_aggregation_frequency(config, aggregation_frequency)
    if use_amp:
        config = conf.set_amp_level(config, "O1")

    exp.run_basic_test_with_temp_config(config, conf.tutorials_path("mnist_pytorch"), 1)
Example #14
def test_cancel_one_active_experiment_ready() -> None:
    experiment_id = exp.create_experiment(
        conf.tutorials_path("mnist_pytorch/const.yaml"),
        conf.tutorials_path("mnist_pytorch"),
    )

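    # Wait until the experiment has completed at least one workload before cancelling it.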
    while not exp.experiment_has_completed_workload(experiment_id):
        time.sleep(1)

    exp.cancel_single(experiment_id, should_have_trial=True)
    exp.assert_performed_final_checkpoint(experiment_id)
Example #15
def test_tf_keras_mnist_parallel(
        collect_trial_profiles: Callable[[int], None]) -> None:
    config = conf.load_config(
        conf.tutorials_path("fashion_mnist_tf_keras/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_profiling_enabled(config)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("fashion_mnist_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
    collect_trial_profiles(trials[0].trial.id)
Example #16
def test_pytorch_11_const(aggregation_frequency: int, using_k8s: bool) -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"))
    config = conf.set_aggregation_frequency(config, aggregation_frequency)

    if using_k8s:
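        # Minimal pod spec: adds a CI label and mounts an emptyDir volume into the trial container.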
        pod_spec = {
            "metadata": {
                "labels": {
                    "ci": "testing"
                }
            },
            "spec": {
                "containers": [{
                    "name":
                    "determined-container",
                    "volumeMounts": [{
                        "name": "temp1",
                        "mountPath": "/random"
                    }],
                }],
                "volumes": [{
                    "name": "temp1",
                    "emptyDir": {}
                }],
            },
        }
        config = conf.set_pod_spec(config, pod_spec)

    exp.run_basic_test_with_temp_config(config,
                                        conf.tutorials_path("mnist_pytorch"),
                                        1)
Example #17
def test_create_test_mode() -> None:
    # test-mode should succeed with a valid experiment.
    command = [
        "det",
        "-m",
        conf.make_master_url(),
        "experiment",
        "create",
        "--test-mode",
        conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"),
        conf.tutorials_path("mnist_pytorch"),
    ]
    output = subprocess.check_output(command, universal_newlines=True)
    assert "Model definition test succeeded" in output

    # test-mode should fail when an error is introduced into the trial
    # implementation.
    command = [
        "det",
        "-m",
        conf.make_master_url(),
        "experiment",
        "create",
        "--test-mode",
        conf.fixtures_path("trial_error/const.yaml"),
        conf.fixtures_path("trial_error"),
    ]
    with pytest.raises(subprocess.CalledProcessError):
        subprocess.check_call(command)
Example #18
def test_streaming_metrics_api() -> None:
    auth.initialize_session(conf.make_master_url(), try_reauth=True)

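    # One pool worker for each of the seven streaming endpoints exercised below.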
    pool = mp.pool.ThreadPool(processes=7)

    experiment_id = exp.create_experiment(
        conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"),
        conf.tutorials_path("mnist_pytorch"),
    )
    # To fully test the streaming APIs, the requests need to start running immediately after the
    # experiment is created and then stay open until the experiment completes. To accomplish this
    # with all of the API calls on a single experiment, we spawn them all in threads.

    # The HP importance portion of this test is commented out until the feature is enabled by
    # default.

    metric_names_thread = pool.apply_async(request_metric_names,
                                           (experiment_id, ))
    train_metric_batches_thread = pool.apply_async(
        request_train_metric_batches, (experiment_id, ))
    valid_metric_batches_thread = pool.apply_async(
        request_valid_metric_batches, (experiment_id, ))
    train_trials_snapshot_thread = pool.apply_async(
        request_train_trials_snapshot, (experiment_id, ))
    valid_trials_snapshot_thread = pool.apply_async(
        request_valid_trials_snapshot, (experiment_id, ))
    train_trials_sample_thread = pool.apply_async(request_train_trials_sample,
                                                  (experiment_id, ))
    valid_trials_sample_thread = pool.apply_async(request_valid_trials_sample,
                                                  (experiment_id, ))

    metric_names_results = metric_names_thread.get()
    train_metric_batches_results = train_metric_batches_thread.get()
    valid_metric_batches_results = valid_metric_batches_thread.get()
    train_trials_snapshot_results = train_trials_snapshot_thread.get()
    valid_trials_snapshot_results = valid_trials_snapshot_thread.get()
    train_trials_sample_results = train_trials_sample_thread.get()
    valid_trials_sample_results = valid_trials_sample_thread.get()

    if metric_names_results is not None:
        pytest.fail("metric-names: %s. Results: %s" % metric_names_results)
    if train_metric_batches_results is not None:
        pytest.fail("metric-batches (training): %s. Results: %s" %
                    train_metric_batches_results)
    if valid_metric_batches_results is not None:
        pytest.fail("metric-batches (validation): %s. Results: %s" %
                    valid_metric_batches_results)
    if train_trials_snapshot_results is not None:
        pytest.fail("trials-snapshot (training): %s. Results: %s" %
                    train_trials_snapshot_results)
    if valid_trials_snapshot_results is not None:
        pytest.fail("trials-snapshot (validation): %s. Results: %s" %
                    valid_trials_snapshot_results)
    if train_trials_sample_results is not None:
        pytest.fail("trials-sample (training): %s. Results: %s" %
                    train_trials_sample_results)
    if valid_trials_sample_results is not None:
        pytest.fail("trials-sample (validation): %s. Results: %s" %
                    valid_trials_sample_results)
Example #19
def test_pytorch_load() -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"))

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1)

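    # Load the experiment's best checkpoint onto the CPU via the Python SDK.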
    (Determined(conf.make_master_url()).get_experiment(
        experiment_id).top_checkpoint().load(map_location="cpu"))
Example #20
def test_mnist_pytorch_accuracy() -> None:
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1)

    trials = exp.experiment_trials(experiment_id)
    trial_metrics = exp.trial_metrics(trials[0]["id"])

    validation_accuracies = [
        step["validation"]["metrics"]["validation_metrics"]["accuracy"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]

    target_accuracy = 0.97
    assert max(validation_accuracies) > target_accuracy, (
        "mnist_pytorch did not reach minimum target accuracy {} in {} steps."
        " full validation accuracy history: {}".format(
            target_accuracy, len(trial_metrics["steps"]), validation_accuracies))
def test_imagenet_pytorch() -> None:
    config = conf.load_config(
        conf.tutorials_path("imagenet_pytorch/const_cifar.yaml"))
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("imagenet_pytorch"), 1)

    trials = exp.experiment_trials(experiment_id)
    trial_metrics = exp.trial_metrics(trials[0].trial.id)

    validation_loss = [
        step["validation"]["metrics"]["validation_metrics"]["val_loss"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]

    target_loss = 1.55
    assert max(validation_loss) < target_loss, (
        "imagenet_pytorch did not reach minimum target loss {} in {} steps."
        " full validation accuracy history: {}".format(
            target_loss, len(trial_metrics["steps"]), validation_loss))
Example #22
def test_fashion_mnist_tf_keras() -> None:
    config = conf.load_config(
        conf.tutorials_path("fashion_mnist_tf_keras/const.yaml"))
    config = conf.set_random_seed(config, 1591110586)
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("fashion_mnist_tf_keras"), 1)

    trials = exp.experiment_trials(experiment_id)
    trial_metrics = exp.trial_metrics(trials[0]["id"])

    validation_accuracies = [
        step["validation"]["metrics"]["validation_metrics"]["val_accuracy"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]

    target_accuracy = 0.85
    assert max(validation_accuracies) > target_accuracy, (
        "fashion_mnist_tf_keras did not reach minimum target accuracy {} in {} steps."
        " full validation accuracy history: {}".format(
            target_accuracy, len(trial_metrics["steps"]), validation_accuracies))
Example #23
def test_gang_scheduling() -> None:
    total_slots = os.getenv("TOTAL_SLOTS")
    if total_slots is None:
        pytest.skip("test requires a static cluster and TOTAL_SLOTS set in the environment")

    config = conf.load_config(conf.tutorials_path("mnist_pytorch/distributed.yaml"))
    config = conf.set_slots_per_trial(config, int(total_slots))
    model = conf.tutorials_path("mnist_pytorch")

    def submit_job() -> None:
        ret_value = exp.run_basic_test_with_temp_config(config, model, 1)
        print(ret_value)

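    # Each experiment requests every slot, so the two submissions must be gang-scheduled in turn.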
    threads = [threading.Thread(target=submit_job) for _ in range(2)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
def test_pytorch_load(collect_trial_profiles: Callable[[int], None]) -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"))
    config = conf.set_profiling_enabled(config)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1)

    (Determined(conf.make_master_url()).get_experiment(
        experiment_id).top_checkpoint().load(map_location="cpu"))
    trial_id = exp.experiment_trials(experiment_id)[0].trial.id
    collect_trial_profiles(trial_id)
def test_pytorch_const_warm_start() -> None:
    """
    Test that specifying an earlier trial checkpoint to warm-start from
    correctly populates the later trials' `warm_start_checkpoint_id` fields.
    """
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config = conf.set_max_length(config, {"batches": 200})

    experiment_id1 = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1)

    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial.trial.id

    assert len(first_trial.workloads) == 4
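    # The first trial's final checkpoint is the one the warm-started trials should reference.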
    checkpoints = exp.workloads_with_checkpoint(first_trial.workloads)
    first_checkpoint = checkpoints[-1]
    first_checkpoint_uuid = first_checkpoint.uuid

    config_obj = conf.load_config(
        conf.tutorials_path("mnist_pytorch/const.yaml"))

    # Change the search method to random, and add a source trial ID to warm
    # start from.
    config_obj["searcher"]["source_trial_id"] = first_trial_id
    config_obj["searcher"]["name"] = "random"
    config_obj["searcher"]["max_length"] = {"batches": 100}
    config_obj["searcher"]["max_trials"] = 3

    experiment_id2 = exp.run_basic_test_with_temp_config(
        config_obj, conf.tutorials_path("mnist_pytorch"), 3)

    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 3
    for t in trials:
        assert t.trial.warmStartCheckpointUuid == first_checkpoint_uuid
Example #26
def test_model_registry() -> None:
    exp_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"),
        conf.tutorials_path("mnist_pytorch"),
        None,
    )

    d = Determined(conf.make_master_url())

    mnist = d.create_model("mnist", "simple computer vision model")
    assert mnist.metadata == {}

    mnist.add_metadata({"testing": "metadata"})
    db_model = d.get_model("mnist")
    # Make sure the model metadata is correct and correctly saved to the db.
    assert mnist.metadata == db_model.metadata
    assert mnist.metadata == {"testing": "metadata"}

    mnist.add_metadata({"some_key": "some_value"})
    db_model = d.get_model("mnist")
    assert mnist.metadata == db_model.metadata
    assert mnist.metadata == {"testing": "metadata", "some_key": "some_value"}

    mnist.add_metadata({"testing": "override"})
    db_model = d.get_model("mnist")
    assert mnist.metadata == db_model.metadata
    assert mnist.metadata == {"testing": "override", "some_key": "some_value"}

    mnist.remove_metadata(["some_key"])
    db_model = d.get_model("mnist")
    assert mnist.metadata == db_model.metadata
    assert mnist.metadata == {"testing": "override"}

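    # Register the experiment's best checkpoint as the first version of the "mnist" model.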
    checkpoint = d.get_experiment(exp_id).top_checkpoint()
    model_version = mnist.register_version(checkpoint.uuid)

    assert model_version.model_version == 1

    latest_version = mnist.get_version()
    assert latest_version is not None
    assert latest_version.uuid == checkpoint.uuid

    d.create_model("transformer", "all you need is attention")
    d.create_model("object-detection", "a bounding box model")

    models = d.get_models(sort_by=ModelSortBy.NAME)
    assert [m.name
            for m in models] == ["mnist", "object-detection", "transformer"]
Example #27
def test_hp_importance_api() -> None:
    auth.initialize_session(conf.make_master_url(), try_reauth=True)

    pool = mp.pool.ThreadPool(processes=1)

    experiment_id = exp.create_experiment(
        conf.fixtures_path("mnist_pytorch/random.yaml"),
        conf.tutorials_path("mnist_pytorch"),
    )

    hp_importance_thread = pool.apply_async(request_hp_importance, (experiment_id,))

    hp_importance_results = hp_importance_thread.get()

    if hp_importance_results is not None:
        pytest.fail("hyperparameter-importance: %s. Results: %s" % hp_importance_results)
def test_pytorch_11_const(
        aggregation_frequency: int, using_k8s: bool,
        collect_trial_profiles: Callable[[int], None]) -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"))
    config = conf.set_aggregation_frequency(config, aggregation_frequency)
    config = conf.set_profiling_enabled(config)

    if using_k8s:
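        # Minimal pod spec: adds a CI label and mounts an emptyDir volume into the trial container.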
        pod_spec = {
            "metadata": {
                "labels": {
                    "ci": "testing"
                }
            },
            "spec": {
                "containers": [{
                    "name":
                    "determined-container",
                    "volumeMounts": [{
                        "name": "temp1",
                        "mountPath": "/random"
                    }],
                }],
                "volumes": [{
                    "name": "temp1",
                    "emptyDir": {}
                }],
            },
        }
        config = conf.set_pod_spec(config, pod_spec)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1)
    trial_id = exp.experiment_trials(experiment_id)[0].trial.id
    collect_trial_profiles(trial_id)
Example #29
def test_end_to_end_adaptive() -> None:
    exp_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"),
        conf.tutorials_path("mnist_pytorch"),
        None,
    )

    # Check that validation accuracy looks sane (more than 93% on MNIST).
    trials = exp.experiment_trials(exp_id)
    best = None
    for trial in trials:
        assert len(trial["steps"])
        last_step = trial["steps"][-1]
        accuracy = last_step["validation"]["metrics"]["validation_metrics"][
            "accuracy"]
        if not best or accuracy > best:
            best = accuracy

    assert best is not None
    assert best > 0.93

    # Check that ExperimentReference returns a sorted order of top checkpoints
    # without gaps. The top 2 checkpoints should be the first 2 of the top k
    # checkpoints if sorting is stable.
    d = Determined(conf.make_master_url())
    exp_ref = d.get_experiment(exp_id)

    top_2 = exp_ref.top_n_checkpoints(2)
    top_k = exp_ref.top_n_checkpoints(len(trials))

    top_2_uuids = [c.uuid for c in top_2]
    top_k_uuids = [c.uuid for c in top_k]

    assert top_2_uuids == top_k_uuids[:2]

    # Check that metrics are truly in sorted order.
    metrics = [
        c.validation["metrics"]["validationMetrics"]["validation_loss"]
        for c in top_k
    ]

    assert metrics == sorted(metrics)

    # Check that changing smaller is better reverses the checkpoint ordering.
    top_k_reversed = exp_ref.top_n_checkpoints(len(trials),
                                               sort_by="validation_loss",
                                               smaller_is_better=False)
    top_k_reversed_uuids = [c.uuid for c in top_k_reversed]

    assert top_k_uuids == top_k_reversed_uuids[::-1]

    checkpoint = top_k[0]
    checkpoint.add_metadata({"testing": "metadata"})
    db_check = d.get_checkpoint(checkpoint.uuid)
    # Make sure the checkpoint metadata is correct and correctly saved to the db.
    assert checkpoint.metadata == {"testing": "metadata"}
    assert checkpoint.metadata == db_check.metadata

    checkpoint.add_metadata({"some_key": "some_value"})
    db_check = d.get_checkpoint(checkpoint.uuid)
    assert checkpoint.metadata == {
        "testing": "metadata",
        "some_key": "some_value"
    }
    assert checkpoint.metadata == db_check.metadata

    checkpoint.add_metadata({"testing": "override"})
    db_check = d.get_checkpoint(checkpoint.uuid)
    assert checkpoint.metadata == {
        "testing": "override",
        "some_key": "some_value"
    }
    assert checkpoint.metadata == db_check.metadata

    checkpoint.remove_metadata(["some_key"])
    db_check = d.get_checkpoint(checkpoint.uuid)
    assert checkpoint.metadata == {"testing": "override"}
    assert checkpoint.metadata == db_check.metadata
Example #30
def test_tutorial_dtrain() -> None:
    exp_id = create_native_experiment(conf.tutorials_path("native-tf-keras"),
                                      ["python", "tf_keras_native_dtrain.py"])
    experiment.wait_for_experiment_state(
        exp_id, "COMPLETED", max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS)