def test_mnist_estimator_warm_start(tf2: bool) -> None:
    config = conf.load_config(
        conf.fixtures_path("mnist_estimator/single.yaml"))
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)
    experiment_id1 = exp.run_basic_test_with_temp_config(
        config, conf.cv_examples_path("mnist_estimator"), 1)

    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial.trial.id

    assert len(first_trial.workloads) == 3
    checkpoint_workloads = exp.workloads_with_checkpoint(first_trial.workloads)
    first_checkpoint_uuid = checkpoint_workloads[0].uuid

    config_obj = conf.load_config(
        conf.fixtures_path("mnist_estimator/single.yaml"))

    config_obj["searcher"]["source_trial_id"] = first_trial_id
    config_obj = conf.set_tf2_image(config_obj) if tf2 else conf.set_tf1_image(config_obj)

    experiment_id2 = exp.run_basic_test_with_temp_config(
        config_obj, conf.cv_examples_path("mnist_estimator"), 1)

    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 1
    assert trials[0].trial.warmStartCheckpointUuid == first_checkpoint_uuid
Example #2
def test_tf_keras_const_warm_start(
        tf2: bool, collect_trial_profiles: Callable[[int], None]) -> None:
    config = conf.load_config(
        conf.cv_examples_path("cifar10_tf_keras/const.yaml"))
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_min_validation_period(config, {"batches": 1000})
    config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)
    config = conf.set_profiling_enabled(config)

    experiment_id1 = exp.run_basic_test_with_temp_config(
        config, conf.cv_examples_path("cifar10_tf_keras"), 1)
    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial.trial.id

    assert len(first_trial.workloads) == 4
    checkpoints = exp.workloads_with_checkpoint(first_trial.workloads)
    first_checkpoint_uuid = checkpoints[0].uuid

    # Add a source trial ID to warm start from.
    config["searcher"]["source_trial_id"] = first_trial_id

    experiment_id2 = exp.run_basic_test_with_temp_config(
        config, conf.cv_examples_path("cifar10_tf_keras"), 1)

    # The new trial should have a warm start checkpoint UUID.
    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 1
    for t in trials:
        assert t.trial.warmStartCheckpointUuid != ""
        assert t.trial.warmStartCheckpointUuid == first_checkpoint_uuid
    trial_id = trials[0].trial.id
    collect_trial_profiles(trial_id)
Example #3
def _test_rng_restore(fixture: str, metrics: List[str], tf2: Union[None, bool] = None) -> None:
    """
    This test confirms that an experiment can be restarted from a checkpoint
    with the same RNG state. It requires a test fixture that will emit
    random numbers from all of the RNGs used in the relevant framework as
    metrics. The experiment must have a const.yaml, run for at least 3 steps,
    checkpoint every step, and keep the first checkpoint (either by having
    metrics get worse over time, or by configuring the experiment to keep all
    checkpoints).
    """
    config_base = conf.load_config(conf.fixtures_path(fixture + "/const.yaml"))
    config = copy.deepcopy(config_base)
    if tf2 is not None:
        config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)

    experiment = exp.run_basic_test_with_temp_config(
        config,
        conf.fixtures_path(fixture),
        1,
    )

    first_trial = exp.experiment_trials(experiment)[0]

    assert len(first_trial.workloads) >= 4

    first_checkpoint = exp.workloads_with_checkpoint(first_trial.workloads)[0]
    first_checkpoint_uuid = first_checkpoint.uuid

    config = copy.deepcopy(config_base)
    if tf2 is not None:
        config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)
    config["searcher"]["source_checkpoint_uuid"] = first_checkpoint.uuid

    experiment2 = exp.run_basic_test_with_temp_config(config, conf.fixtures_path(fixture), 1)

    second_trial = exp.experiment_trials(experiment2)[0]

    assert len(second_trial.workloads) >= 4
    assert second_trial.trial.warmStartCheckpointUuid == first_checkpoint_uuid
    first_trial_validations = exp.workloads_with_validation(first_trial.workloads)
    second_trial_validations = exp.workloads_with_validation(second_trial.workloads)

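    # The second run warm-started from the first run's first checkpoint, so
    # with identical RNG state its validation k must reproduce the first
    # run's validation k + 1 exactly.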
    for wl in range(0, 2):
        for metric in metrics:
            first_trial_val = first_trial_validations[wl + 1]
            first_metric = first_trial_val.metrics[metric]
            second_trial_val = second_trial_validations[wl]
            second_metric = second_trial_val.metrics[metric]
            assert (
                first_metric == second_metric
            ), f"failures on iteration: {wl} with metric: {metric}"
Example #4
def test_pytorch_const_warm_start() -> None:
    """
    Test that specifying an earlier trial checkpoint to warm-start from
    correctly populates the later trials' `warm_start_checkpoint_id` fields.
    """
    config = conf.load_config(conf.tutorials_path("mnist_pytorch/const.yaml"))
    config = conf.set_max_length(config, {"batches": 200})

    experiment_id1 = exp.run_basic_test_with_temp_config(
        config, conf.tutorials_path("mnist_pytorch"), 1)

    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0]
    first_trial_id = first_trial.trial.id

    assert len(first_trial.workloads) == 4
    checkpoints = exp.workloads_with_checkpoint(first_trial.workloads)
    # Warm start resumes from the source trial's most recent checkpoint.
    last_checkpoint_uuid = checkpoints[-1].uuid

    config_obj = conf.load_config(
        conf.tutorials_path("mnist_pytorch/const.yaml"))

    # Change the search method to random, and add a source trial ID to warm
    # start from.
    config_obj["searcher"]["source_trial_id"] = first_trial_id
    config_obj["searcher"]["name"] = "random"
    config_obj["searcher"]["max_length"] = {"batches": 100}
    config_obj["searcher"]["max_trials"] = 3

    experiment_id2 = exp.run_basic_test_with_temp_config(
        config_obj, conf.tutorials_path("mnist_pytorch"), 3)

    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 3
    for t in trials:
        assert t.trial.warmStartCheckpointUuid == last_checkpoint_uuid
Example #5
def test_noop_single_warm_start() -> None:
    experiment_id1 = exp.run_basic_test(
        conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op"), 1
    )

    trials = exp.experiment_trials(experiment_id1)
    assert len(trials) == 1

    first_trial = trials[0].trial
    first_trial_id = first_trial.id

    first_workloads = trials[0].workloads
    assert len(first_workloads) == 90
    checkpoints = exp.workloads_with_checkpoint(first_workloads)
    assert len(checkpoints) == 30
    first_checkpoint_uuid = checkpoints[0].uuid
    last_checkpoint_uuid = checkpoints[-1].uuid
    last_validation = exp.workloads_with_validation(first_workloads)[-1]
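    # The no_op model decays validation_error by a factor of 0.9 per step,
    # so 30 steps of training end at 0.9 ** 30.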
    assert last_validation.metrics["validation_error"] == pytest.approx(0.9 ** 30)

    config_base = conf.load_config(conf.fixtures_path("no_op/single.yaml"))

    # Test source_trial_id.
    config_obj = copy.deepcopy(config_base)
    # Add a source trial ID to warm start from.
    config_obj["searcher"]["source_trial_id"] = first_trial_id

    experiment_id2 = exp.run_basic_test_with_temp_config(config_obj, conf.fixtures_path("no_op"), 1)

    trials = exp.experiment_trials(experiment_id2)
    assert len(trials) == 1

    second_trial = trials[0]
    assert len(second_trial.workloads) == 90

    # Second trial should have a warm start checkpoint id.
    assert second_trial.trial.warmStartCheckpointUuid == last_checkpoint_uuid

    val_workloads = exp.workloads_with_validation(second_trial.workloads)
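    # Warm-started from step 30 and trained for 30 more steps, the error
    # decays by another factor of 0.9 ** 30, down to 0.9 ** 60.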
    assert val_workloads[-1].metrics["validation_error"] == pytest.approx(0.9 ** 60)

    # Now test source_checkpoint_uuid.
    config_obj = copy.deepcopy(config_base)
    # Add a source checkpoint UUID to warm start from.
    config_obj["searcher"]["source_checkpoint_uuid"] = checkpoints[0].uuid

    with tempfile.NamedTemporaryFile() as tf:
        with open(tf.name, "w") as f:
            yaml.dump(config_obj, f)

        experiment_id3 = exp.run_basic_test(tf.name, conf.fixtures_path("no_op"), 1)

    trials = exp.experiment_trials(experiment_id3)
    assert len(trials) == 1

    third_trial = trials[0]
    assert len(third_trial.workloads) == 90

    assert third_trial.trial.warmStartCheckpointUuid == first_checkpoint_uuid
    validations = exp.workloads_with_validation(third_trial.workloads)
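    # checkpoints[0] was taken after the first step (error 0.9 ** 1), so the
    # warm-started trial's second validation, two steps later, sits at 0.9 ** 3.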
    assert validations[1].metrics["validation_error"] == pytest.approx(0.9 ** 3)
Example #6
def run_gc_checkpoints_test(checkpoint_storage: Dict[str, str]) -> None:
    fixtures = [
        (
            conf.fixtures_path("no_op/gc_checkpoints_decreasing.yaml"),
            {
                (bindings.determinedexperimentv1State.STATE_COMPLETED.value): {
                    800,
                    900,
                    1000,
                },
                (bindings.determinedexperimentv1State.STATE_DELETED.value): {
                    100,
                    200,
                    300,
                    400,
                    500,
                    600,
                    700,
                },
            },
        ),
        (
            conf.fixtures_path("no_op/gc_checkpoints_increasing.yaml"),
            {
                (bindings.determinedexperimentv1State.STATE_COMPLETED.value): {
                    100,
                    200,
                    300,
                    900,
                    1000,
                },
                (bindings.determinedexperimentv1State.STATE_DELETED.value): {
                    400,
                    500,
                    600,
                    700,
                    800,
                },
            },
        ),
    ]

    all_checkpoints: List[Tuple[Any, List[bindings.v1CheckpointWorkload]]] = []
    for base_conf_path, result in fixtures:
        config = conf.load_config(str(base_conf_path))
        config["checkpoint_storage"].update(checkpoint_storage)

        with tempfile.NamedTemporaryFile() as tf:
            with open(tf.name, "w") as f:
                yaml.dump(config, f)

            experiment_id = exp.create_experiment(tf.name,
                                                  conf.fixtures_path("no_op"))

        exp.wait_for_experiment_state(
            experiment_id,
            bindings.determinedexperimentv1State.STATE_COMPLETED)

        # In some configurations, checkpoint GC runs on an auxiliary machine,
        # which may still need to be spun up, so wait for it to run.
        wait_for_gc_to_finish(experiment_id)

        # Checkpoints are not marked as deleted until gc_checkpoint task starts.
        retries = 5
        for retry in range(retries):
            trials = exp.experiment_trials(experiment_id)
            assert len(trials) == 1

            cpoints = exp.workloads_with_checkpoint(trials[0].workloads)
            sorted_checkpoints = sorted(
                cpoints,
                key=lambda ckp: int(ckp.totalBatches),
            )
            assert len(sorted_checkpoints) == 10
            by_state = {}  # type: Dict[str, Set[int]]
            for ckpt in sorted_checkpoints:
                by_state.setdefault(ckpt.state.value,
                                    set()).add(ckpt.totalBatches)

            if by_state == result:
                all_checkpoints.append((config, sorted_checkpoints))
                break

            if retry + 1 == retries:
                assert by_state == result

            time.sleep(1)

    # Check that the actual checkpoint storage (for shared_fs) reflects the
    # deletions. We want to wait for the GC containers to exit, so check
    # repeatedly with a timeout.
    max_checks = 30
    for i in range(max_checks):
        time.sleep(1)
        try:
            storage_states = []
            for config, checkpoints in all_checkpoints:
                checkpoint_config = config["checkpoint_storage"]

                storage_manager = storage.build(checkpoint_config,
                                                container_path=None)
                storage_state = {}  # type: Dict[str, Any]
                for checkpoint in checkpoints:
                    assert checkpoint.uuid is not None
                    storage_id = checkpoint.uuid
                    storage_state[storage_id] = {}
                    if checkpoint.state == bindings.determinedcheckpointv1State.STATE_COMPLETED:
                        storage_state[storage_id]["found"] = False
                        try:
                            with storage_manager.restore_path(storage_id):
                                storage_state[storage_id]["found"] = True
                        except errors.CheckpointNotFound:
                            pass
                    elif checkpoint.state == bindings.determinedcheckpointv1State.STATE_DELETED:
                        storage_state[storage_id] = {
                            "deleted": False,
                            "checkpoint": checkpoint
                        }
                        try:
                            with storage_manager.restore_path(storage_id):
                                pass
                        except errors.CheckpointNotFound:
                            storage_state[storage_id]["deleted"] = True
                # One state dict per experiment, appended only after every
                # checkpoint in it has been probed.
                storage_states.append(storage_state)

            for storage_state in storage_states:
                for state in storage_state.values():
                    if state.get("deleted", None) is False:
                        # default=str: checkpoint bindings objects are not JSON-serializable.
                        json_states = json.dumps(storage_states, default=str)
                        raise AssertionError(
                            f"Some checkpoints were not deleted: JSON:{json_states}"
                        )
                    if state.get("found", None) is False:
                        json_states = json.dumps(storage_states, default=str)
                        raise AssertionError(
                            f"Some checkpoints were not found: JSON:{json_states}"
                        )
        except AssertionError:
            if i == max_checks - 1:
                raise
        else:
            break
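
# Both waits above share one shape: poll a check on a fixed interval,
# swallowing failures until the final attempt. A generic sketch of that
# pattern (the helper below is hypothetical, not part of exp or conf):
def _poll_until_ok(check: Callable[[], None], attempts: int, interval: float = 1.0) -> None:
    for i in range(attempts):
        try:
            check()
            return  # the check passed; stop polling
        except AssertionError:
            if i == attempts - 1:
                raise  # out of retries; surface the last failure
        time.sleep(interval)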