Code example #1
def test_tf_keras_const_warm_start(tf2: bool) -> None:
    config = conf.load_config(
        conf.official_examples_path("cifar10_cnn_tf_keras/const.yaml"))
    config = conf.set_max_steps(config, 2)
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)

    experiment_id1 = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial["id"]

    assert len(first_trial["steps"]) == 2
    first_checkpoint_id = first_trial["steps"][1]["checkpoint"]["id"]

    # Add a source trial ID to warm start from.
    config["searcher"]["source_trial_id"] = first_trial_id

    experiment_id2 = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_tf_keras"), 1)

    # The new trials should have a warm start checkpoint ID.
    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 1
    for trial in trials:
        assert trial["warm_start_checkpoint_id"] == first_checkpoint_id
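Note: tests like the one above take a `tf2: bool` argument; in the project it is presumably supplied through a pytest fixture or parametrization. A minimal sketch of how such a flag could be wired up (the decorator usage below is an assumption, not taken from the project):

import pytest

@pytest.mark.parametrize("tf2", [True, False])  # assumed wiring for the tf2 flag
def test_runs_against_both_tf_images(tf2: bool) -> None:
    # Each parametrized run would then pick the matching Docker image, as the
    # examples here do with:
    #     config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)
    ...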
Code example #2
def test_mnist_estimator_warm_start(tf2: bool) -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_estimator/single.yaml"))
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)
    experiment_id1 = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_estimator"), 1)

    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial.id

    assert len(first_trial.steps) == 1
    first_checkpoint_id = first_trial.steps[0].checkpoint.id

    config_obj = conf.load_config(
        conf.fixtures_path("mnist_estimator/single.yaml"))

    config_obj["searcher"]["source_trial_id"] = first_trial_id
    config_obj = conf.set_tf2_image(config_obj) if tf2 else conf.set_tf1_image(
        config_obj)

    experiment_id2 = exp.run_basic_test_with_temp_config(
        config_obj, conf.official_examples_path("mnist_estimator"), 1)

    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 1
    assert trials[0].warm_start_checkpoint_id == first_checkpoint_id
Code example #3
def test_mnist_estimator_const(tf2: bool) -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_estimator/single.yaml"))
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_estimator"), 1)

    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1

    # Check validation metrics.
    steps = trials[0].steps
    assert len(steps) == 1

    step = steps[0]
    assert "validation" in step

    v_metrics = step.validation.metrics["validation_metrics"]

    # GPU training is non-deterministic, but on CPU we can validate that we
    # reach a consistent result.
    if not cluster.running_on_gpu():
        assert v_metrics["accuracy"] == 0.9125999808311462

    # Check training metrics.
    full_trial_metrics = exp.trial_metrics(trials[0].id)
    for step in full_trial_metrics.steps:
        metrics = step.metrics

        batch_metrics = metrics["batch_metrics"]
        assert len(batch_metrics) == 100

        for batch_metric in batch_metrics:
            assert batch_metric["loss"] > 0
Code example #4
def run_dataset_experiment(
    searcher_max_steps: int,
    batches_per_step: int,
    secrets: Dict[str, str],
    tf2: bool,
    slots_per_trial: int = 1,
    source_trial_id: Optional[str] = None,
) -> List[gql.trials]:
    config = conf.load_config(
        conf.fixtures_path("estimator_dataset/const.yaml"))
    config.setdefault("searcher", {})
    config["searcher"]["max_steps"] = searcher_max_steps
    config["batches_per_step"] = batches_per_step
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)

    if source_trial_id is not None:
        config["searcher"]["source_trial_id"] = source_trial_id

    config.setdefault("resources", {})
    config["resources"]["slots_per_trial"] = slots_per_trial

    if cluster.num_agents() > 1:
        config["checkpoint_storage"] = exp.s3_checkpoint_config(secrets)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.fixtures_path("estimator_dataset"), 1)
    return exp.experiment_trials(experiment_id)
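A sketch of how this helper might be called from a test, assuming a single-agent cluster where the `secrets` dict can stay empty (all argument values below are illustrative):

def test_dataset_experiment_sketch(tf2: bool) -> None:
    # Two steps of one batch each on a single slot; the S3 secrets are only
    # consulted when the cluster has more than one agent.
    trials = run_dataset_experiment(
        searcher_max_steps=2,
        batches_per_step=1,
        secrets={},
        tf2=tf2,
    )
    assert len(trials) == 1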
Code example #5
File: test_noop.py Project: atalwalkar/determined-1
def test_noop_pause() -> None:
    """
    Walk through starting, pausing, and resuming a single no-op experiment.
    """
    experiment_id = exp.create_experiment(
        conf.fixtures_path("no_op/single-medium-train-step.yaml"),
        conf.fixtures_path("no_op"),
        None,
    )
    exp.wait_for_experiment_state(experiment_id, "ACTIVE")

    # Wait for the only trial to get scheduled.
    workload_active = False
    for _ in range(conf.MAX_TASK_SCHEDULED_SECS):
        workload_active = exp.experiment_has_active_workload(experiment_id)
        if workload_active:
            break
        else:
            time.sleep(1)
    check.true(
        workload_active,
        f"The only trial cannot be scheduled within {conf.MAX_TASK_SCHEDULED_SECS} seconds.",
    )

    # Wait for the only trial to show progress, indicating the image is built and running.
    num_steps = 0
    for _ in range(conf.MAX_TRIAL_BUILD_SECS):
        trials = exp.experiment_trials(experiment_id)
        if len(trials) > 0:
            only_trial = trials[0]
            num_steps = len(only_trial["steps"])
            if num_steps > 1:
                break
        time.sleep(1)
    check.true(
        num_steps > 1,
        f"The only trial cannot start training within {conf.MAX_TRIAL_BUILD_SECS} seconds.",
    )

    # Pause the experiment. Note that Determined does not currently differentiate
    # between a "stopping paused" and a "paused" state, so we follow this check
    # up by ensuring the experiment cleared all scheduled workloads.
    exp.pause_experiment(experiment_id)
    exp.wait_for_experiment_state(experiment_id, "PAUSED")

    # Wait at most 20 seconds for the experiment to clear all workloads (each
    # train step should take 5 seconds).
    for _ in range(20):
        workload_active = exp.experiment_has_active_workload(experiment_id)
        if not workload_active:
            break
        else:
            time.sleep(1)
    check.true(
        not workload_active, f"The experiment cannot be paused within 20 seconds.",
    )

    # Resume the experiment and wait for completion.
    exp.activate_experiment(experiment_id)
    exp.wait_for_experiment_state(experiment_id, "COMPLETED")
Code example #6
def test_tensorpack_const() -> None:
    config = conf.load_config(
        conf.official_examples_path("mnist_tp/const.yaml"))

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_tp"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #7
def test_tensorpack_const() -> None:
    config = conf.load_config(
        conf.official_examples_path("mnist_tp/const.yaml"))
    config["checkpoint_storage"] = exp.shared_fs_checkpoint_config()
    config.get("bind_mounts", []).append(exp.root_user_home_bind_mount())

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_tp"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #8
def test_mnist_estimator_load() -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_estimator/single.yaml"))
    config = conf.set_tf1_image(config)
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_estimator"), 1)

    trials = exp.experiment_trials(experiment_id)
    model = Determined().get_trial(trials[0].id).top_checkpoint().load()
    assert isinstance(model, AutoTrackable)
Code example #9
File: test_system.py Project: yilunchen27/determined
def test_log_null_bytes() -> None:
    config_obj = conf.load_config(conf.fixtures_path("no_op/single.yaml"))
    config_obj["hyperparameters"]["write_null"] = True
    config_obj["max_restarts"] = 0
    config_obj["searcher"]["max_steps"] = 1
    experiment_id = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1)

    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
    logs = exp.trial_logs(trials[0].id)
    assert len(logs) > 0
Code example #10
def test_pytorch_cifar10_const() -> None:
    config = conf.load_config(
        conf.official_examples_path("cifar10_cnn_pytorch/const.yaml"))
    config = conf.set_max_steps(config, 2)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_pytorch"), 1)
    trials = exp.experiment_trials(experiment_id)
    nn = Determined().get_trial(
        trials[0].id).select_checkpoint(latest=True).load()
    assert isinstance(nn, torch.nn.Module)
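Once loaded, the checkpoint is an ordinary `torch.nn.Module`, so it can be put into eval mode and run on a dummy batch. A minimal sketch; the CIFAR-10 input shape is an assumption about the example model, not something the test asserts:

import torch


def smoke_test_loaded_module(nn: torch.nn.Module) -> None:
    # `nn` would be the module returned by .load() in the example above.
    nn.eval()
    with torch.no_grad():
        dummy_batch = torch.zeros(1, 3, 32, 32)  # assumed CIFAR-10 NCHW input
        print(nn(dummy_batch).shape)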
Code example #11
def test_tensorpack_native_parallel() -> None:
    config = conf.load_config(
        conf.official_examples_path("mnist_tp/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, True)
    config = conf.set_max_steps(config, 2)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_tp"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #12
def test_tf_keras_mnist_parallel() -> None:
    config = conf.load_config(
        conf.official_examples_path("fashion_mnist_tf_keras/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_steps(config, 2)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("fashion_mnist_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #13
def test_tf_keras_single_gpu(tf2: bool) -> None:
    config = conf.load_config(
        conf.official_examples_path("cifar10_cnn_tf_keras/const.yaml"))
    config = conf.set_slots_per_trial(config, 1)
    config = conf.set_max_steps(config, 2)
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #14
def test_pytorch_cifar10_parallel() -> None:
    config = conf.load_config(
        conf.official_examples_path("cifar10_cnn_pytorch/const.yaml"))
    config = conf.set_max_steps(config, 2)
    config = conf.set_slots_per_trial(config, 8)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_pytorch"), 1)
    trials = exp.experiment_trials(experiment_id)
    nn = (Determined(conf.make_master_url()).get_trial(
        trials[0].id).select_checkpoint(latest=True).load(
            map_location=torch.device("cpu")))
    assert isinstance(nn, torch.nn.Module)
Code example #15
def test_tf_keras_parallel(aggregation_frequency: int, tf2: bool) -> None:
    config = conf.load_config(
        conf.official_examples_path("cifar10_cnn_tf_keras/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_steps(config, 2)
    config = conf.set_aggregation_frequency(config, aggregation_frequency)
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #16
def test_tf_keras_mnist_parallel() -> None:
    config = conf.load_config(
        conf.official_examples_path("fashion_mnist_tf_keras/const.yaml"))
    config["checkpoint_storage"] = exp.shared_fs_checkpoint_config()
    config.get("bind_mounts", []).append(exp.root_user_home_bind_mount())
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_steps(config, 2)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("fashion_mnist_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #17
def test_tf_keras_single_gpu(tf2: bool) -> None:
    config = conf.load_config(
        conf.official_examples_path("cifar10_cnn_tf_keras/const.yaml"))
    config["checkpoint_storage"] = exp.shared_fs_checkpoint_config()
    config.get("bind_mounts", []).append(exp.root_user_home_bind_mount())
    config = conf.set_slots_per_trial(config, 1)
    config = conf.set_max_steps(config, 2)
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #18
def test_pytorch_const_warm_start() -> None:
    """
    Test that specifying an earlier trial checkpoint to warm-start from
    correctly populates the later trials' `warm_start_checkpoint_id` fields.
    """
    config = conf.load_config(
        conf.official_examples_path("mnist_pytorch/const.yaml"))
    config = conf.set_max_steps(config, 2)

    experiment_id1 = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_pytorch"), 1)

    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial["id"]

    assert len(first_trial["steps"]) == 2
    first_checkpoint_id = first_trial["steps"][-1]["checkpoint"]["id"]

    config_obj = conf.load_config(
        conf.official_examples_path("mnist_pytorch/const.yaml"))

    # Change the search method to random, and add a source trial ID to warm
    # start from.
    config_obj["searcher"]["source_trial_id"] = first_trial_id
    config_obj["searcher"]["name"] = "random"
    config_obj["searcher"]["max_steps"] = 1
    config_obj["searcher"]["max_trials"] = 3

    experiment_id2 = exp.run_basic_test_with_temp_config(
        config_obj, conf.official_examples_path("mnist_pytorch"), 3)

    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 3
    for trial in trials:
        assert trial["warm_start_checkpoint_id"] == first_checkpoint_id
Code example #19
def test_tensorpack_parallel(aggregation_frequency: int) -> None:
    config = conf.load_config(
        conf.official_examples_path("mnist_tp/const.yaml"))
    config["checkpoint_storage"] = exp.shared_fs_checkpoint_config()
    config.get("bind_mounts", []).append(exp.root_user_home_bind_mount())
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_steps(config, 2)
    config = conf.set_aggregation_frequency(config, aggregation_frequency)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_tp"), 1)
    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
Code example #20
def test_start_tensorboard_for_multi_experiment(
        tmp_path: Path, secrets: Dict[str, str]) -> None:
    """
    Start 3 random experiments configured with the s3 and shared_fs backends,
    start a TensorBoard instance pointed to the experiments and some select
    trials, and kill the TensorBoard instance.
    """

    with FileTree(
            tmp_path,
        {
            "shared_fs_config.yaml": shared_fs_config(1),
            "s3_config.yaml": s3_config(1, secrets),
            "multi_trial_config.yaml": shared_fs_config(3),
        },
    ) as tree:
        shared_conf_path = tree.joinpath("shared_fs_config.yaml")
        shared_fs_exp_id = exp.run_basic_test(str(shared_conf_path),
                                              conf.fixtures_path("no_op"),
                                              num_trials)

        s3_conf_path = tree.joinpath("s3_config.yaml")
        s3_exp_id = exp.run_basic_test(str(s3_conf_path),
                                       conf.fixtures_path("no_op"), num_trials)

        multi_trial_config = tree.joinpath("multi_trial_config.yaml")
        multi_trial_exp_id = exp.run_basic_test(str(multi_trial_config),
                                                conf.fixtures_path("no_op"), 3)

        trial_ids = [
            str(t["id"]) for t in exp.experiment_trials(multi_trial_exp_id)
        ]

    command = [
        "tensorboard",
        "start",
        str(shared_fs_exp_id),
        str(s3_exp_id),
        "-t",
        *trial_ids,
        "--no-browser",
    ]

    with cmd.interactive_command(*command) as tensorboard:
        for line in tensorboard.stdout:
            if SERVICE_READY in line:
                break
        else:
            raise AssertionError(f"Did not find {SERVICE_READY} in output")
Code example #21
def test_metric_gathering() -> None:
    """
    Confirm that metrics are gathered from the trial the way that we expect.
    """
    experiment_id = exp.run_basic_test(
        conf.fixtures_path("metric_maker/const.yaml"),
        conf.fixtures_path("metric_maker"), 1)

    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1

    # Read the structure of the metrics directly from the config file
    config = conf.load_config(conf.fixtures_path("metric_maker/const.yaml"))

    base_value = config["hyperparameters"]["starting_base_value"]
    gain_per_batch = config["hyperparameters"]["gain_per_batch"]
    training_structure = config["hyperparameters"]["training_structure"]["val"]
    validation_structure = config["hyperparameters"]["validation_structure"][
        "val"]

    batches_per_step = 100

    # Check training metrics.
    full_trial_metrics = exp.trial_metrics(trials[0].id)
    for step in full_trial_metrics.steps:
        metrics = step.metrics
        assert metrics["num_inputs"] == batches_per_step

        actual = metrics["batch_metrics"]
        assert len(actual) == batches_per_step

        first_base_value = base_value + (step.id - 1) * batches_per_step
        batch_values = first_base_value + gain_per_batch * np.arange(
            batches_per_step)
        expected = [
            structure_to_metrics(value, training_structure)
            for value in batch_values
        ]
        assert structure_equal(expected, actual)

    # Check validation metrics.
    for step in trials[0].steps:
        validation = step.validation
        metrics = validation.metrics
        actual = metrics["validation_metrics"]

        value = base_value + step.id * batches_per_step
        expected = structure_to_metrics(value, validation_structure)
        assert structure_equal(expected, actual)
Code example #22
def test_end_to_end_adaptive() -> None:
    exp_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"),
        conf.official_examples_path("mnist_pytorch"),
        None,
    )

    # Check that validation accuracy looks sane (more than 93% on MNIST).
    trials = exp.experiment_trials(exp_id)
    best = None
    for trial in trials:
        assert len(trial.steps)
        last_step = trial.steps[-1]
        accuracy = last_step.validation.metrics["validation_metrics"][
            "accuracy"]
        if not best or accuracy > best:
            best = accuracy

    assert best is not None
    assert best > 0.93

    # Check that ExperimentReference returns a sorted order of top checkpoints
    # without gaps. The top 2 checkpoints should be the first 2 of the top k
    # checkpoints if sorting is stable.
    exp_ref = Determined().get_experiment(exp_id)

    top_2 = exp_ref.top_n_checkpoints(2)
    top_k = exp_ref.top_n_checkpoints(len(trials))

    top_2_uuids = [c.uuid for c in top_2]
    top_k_uuids = [c.uuid for c in top_k]

    assert top_2_uuids == top_k_uuids[:2]

    # Check that metrics are truly in sorted order.
    metrics = [
        c.validation.metrics["validation_metrics"]["validation_loss"]
        for c in top_k
    ]

    assert metrics == sorted(metrics)

    # Check that changing smaller is better reverses the checkpoint ordering.
    top_k_reversed = exp_ref.top_n_checkpoints(len(trials),
                                               sort_by="validation_loss",
                                               smaller_is_better=False)
    top_k_reversed_uuids = [c.uuid for c in top_k_reversed]

    assert top_k_uuids == top_k_reversed_uuids[::-1]
Code example #23
File: test_noop.py Project: atalwalkar/determined-1
def test_noop_single_warm_start() -> None:
    experiment_id1 = exp.run_basic_test(
        conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1
    )

    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial["id"]

    assert len(first_trial["steps"]) == 30
    first_step = first_trial["steps"][0]
    first_checkpoint_id = first_step["checkpoint"]["id"]
    last_step = first_trial["steps"][29]
    last_checkpoint_id = last_step["checkpoint"]["id"]
    assert last_step["validation"]["metrics"]["validation_metrics"][
        "validation_error"
    ] == pytest.approx(0.9 ** 30)

    config_base = conf.load_config(conf.fixtures_path("no_op/single.yaml"))

    # Test source_trial_id.
    config_obj = copy.deepcopy(config_base)
    # Add a source trial ID to warm start from.
    config_obj["searcher"]["source_trial_id"] = first_trial_id

    experiment_id2 = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1)

    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 1

    second_trial = trials[0]
    assert len(second_trial["steps"]) == 30

    # Second trial should have a warm start checkpoint id.
    assert second_trial["warm_start_checkpoint_id"] == last_checkpoint_id

    assert second_trial["steps"][29]["validation"]["metrics"]["validation_metrics"][
        "validation_error"
    ] == pytest.approx(0.9 ** 60)

    # Now test source_checkpoint_uuid.
    config_obj = copy.deepcopy(config_base)
    # Add a source checkpoint UUID to warm start from.
    config_obj["searcher"]["source_checkpoint_uuid"] = first_step["checkpoint"]["uuid"]

    with tempfile.NamedTemporaryFile() as tf:
        with open(tf.name, "w") as f:
            yaml.dump(config_obj, f)

        experiment_id3 = exp.run_basic_test(tf.name, conf.fixtures_path("no_op"), 1)

    trials = exp.experiment_trials(experiment_id3)
    assert len(trials) == 1

    third_trial = trials[0]
    assert len(third_trial["steps"]) == 30

    assert third_trial["warm_start_checkpoint_id"] == first_checkpoint_id

    assert third_trial["steps"][1]["validation"]["metrics"]["validation_metrics"][
        "validation_error"
    ] == pytest.approx(0.9 ** 3)
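The expected values in the assertions above are consistent with the no_op fixture's validation_error decaying by a factor of 0.9 per completed step (a rule inferred from the assertions themselves, not documented here):

print(0.9 ** 30)  # ~0.0424: first trial after 30 steps
print(0.9 ** 60)  # ~0.0018: second trial, 30 inherited steps plus 30 new ones
print(0.9 ** 3)   # ~0.729:  third trial's second step, 1 inherited step plus 2 new ones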
Code example #24
def run_gc_checkpoints_test(checkpoint_storage: Dict[str, str]) -> None:
    fixtures = [
        (
            conf.fixtures_path("no_op/gc_checkpoints_decreasing.yaml"),
            {
                "COMPLETED": {8, 9, 10},
                "DELETED": {1, 2, 3, 4, 5, 6, 7}
            },
        ),
        (
            conf.fixtures_path("no_op/gc_checkpoints_increasing.yaml"),
            {
                "COMPLETED": {1, 2, 3, 9, 10},
                "DELETED": {4, 5, 6, 7, 8}
            },
        ),
    ]

    all_checkpoints = []
    for base_conf_path, result in fixtures:
        config = conf.load_config(str(base_conf_path))
        config["checkpoint_storage"].update(checkpoint_storage)

        with tempfile.NamedTemporaryFile() as tf:
            with open(tf.name, "w") as f:
                yaml.dump(config, f)

            experiment_id = exp.create_experiment(tf.name,
                                                  conf.fixtures_path("no_op"))

        exp.wait_for_experiment_state(experiment_id, "COMPLETED")

        # Checkpoints are not marked as deleted until the gc_checkpoint task starts.
        retries = 5
        for retry in range(retries):
            trials = exp.experiment_trials(experiment_id)
            assert len(trials) == 1

            checkpoints = sorted(
                (step.checkpoint for step in trials[0].steps),
                key=operator.itemgetter("step_id"),
            )
            assert len(checkpoints) == 10
            by_state = {}  # type: Dict[str, Set[int]]
            for checkpoint in checkpoints:
                by_state.setdefault(checkpoint.state,
                                    set()).add(checkpoint.step_id)

            if by_state == result:
                all_checkpoints.append((config, checkpoints))
                break

            if retry + 1 == retries:
                assert by_state == result

            time.sleep(1)

    # Check that the actual checkpoint storage (for shared_fs) reflects the
    # deletions. We want to wait for the GC containers to exit, so check
    # repeatedly with a timeout.
    max_checks = 30
    for check in range(max_checks):
        time.sleep(1)
        try:
            for config, checkpoints in all_checkpoints:
                checkpoint_config = config["checkpoint_storage"]

                if checkpoint_config["type"] == "shared_fs" and (
                        "storage_path" not in checkpoint_config):
                    if "tensorboard_path" in checkpoint_config:
                        checkpoint_config[
                            "storage_path"] = checkpoint_config.get(
                                "tensorboard_path", None)
                    else:
                        checkpoint_config[
                            "storage_path"] = checkpoint_config.get(
                                "checkpoint_path", None)

                    root = os.path.join(checkpoint_config["host_path"],
                                        checkpoint_config["storage_path"])

                    for checkpoint in checkpoints:
                        dirname = os.path.join(root, checkpoint.uuid)
                        if checkpoint.state == "COMPLETED":
                            assert os.path.isdir(dirname)
                        elif checkpoint.state == "DELETED":
                            assert not os.path.exists(dirname)
        except AssertionError:
            if check == max_checks - 1:
                raise
        else:
            break
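A sketch of how this helper might be invoked, using only the checkpoint-storage keys the function itself reads ("type" and "host_path"); the host path below is illustrative:

def test_gc_checkpoints_shared_fs_sketch() -> None:
    run_gc_checkpoints_test(
        checkpoint_storage={
            "type": "shared_fs",
            "host_path": "/tmp/determined-checkpoints",  # illustrative host path
        }
    )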