def test_pytorch_cifar10_parallel() -> None:
    """Train the CIFAR-10 PyTorch example on 8 slots for 2 steps, then load it.

    The latest checkpoint of the first trial must deserialize into a
    ``torch.nn.Module``.
    """
    example_dir = "trial/cifar10_cnn_pytorch"
    config = conf.load_config(conf.official_examples_path(example_dir + "/const.yaml"))
    config = conf.set_max_steps(config, 2)
    config = conf.set_slots_per_trial(config, 8)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path(example_dir), 1)

    first_trial_id = exp.experiment_trials(experiment_id)[0]["id"]
    client = Determined(conf.make_master_url())
    latest_checkpoint = client.get_trial(first_trial_id).select_checkpoint(latest=True)
    model = latest_checkpoint.load()
    assert isinstance(model, torch.nn.Module)
Exemple #2
0
def test_tf_keras_native_parallel(tf2: bool) -> None:
    """Run the tf.keras CIFAR-10 example with native parallelism on 8 slots.

    Args:
        tf2: when true use the TF2 image, otherwise the TF1 image.
    """
    config = conf.load_config(conf.official_examples_path("cifar10_cnn_tf_keras/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, True)
    config = conf.set_max_steps(config, 2)
    if tf2:
        config = conf.set_tf2_image(config)
    else:
        config = conf.set_tf1_image(config)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_tf_keras"), 1
    )
    # Exactly one trial is expected from this experiment.
    assert len(exp.experiment_trials(experiment_id)) == 1
Exemple #3
0
def test_pytorch_gan_parallel() -> None:
    """Train the MNIST GAN PyTorch example on 8 slots for 200 batches.

    Passes when the first trial's latest checkpoint loads onto the CPU
    without raising; the loaded object itself is not inspected.
    """
    config = conf.load_config(conf.official_examples_path("trial/mnist_gan_pytorch/const.yaml"))
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_slots_per_trial(config, 8)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/mnist_gan_pytorch"), 1)
    first_trial_id = exp.experiment_trials(experiment_id)[0]["id"]

    trial_ref = Determined(conf.make_master_url()).get_trial(first_trial_id)
    trial_ref.select_checkpoint(latest=True).load(map_location="cpu")
Exemple #4
0
def test_pytorch_cifar10_const() -> None:
    """Train the CIFAR-10 PyTorch example for 2 steps and load its latest checkpoint on CPU.

    The checkpoint must deserialize into a ``torch.nn.Module``.
    """
    config = conf.load_config(conf.official_examples_path("cifar10_cnn_pytorch/const.yaml"))
    config = conf.set_max_steps(config, 2)

    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_pytorch"), 1)

    # NOTE(review): this test reads the trial id as an attribute (``.id``) while
    # most tests here index ``["id"]``; confirm which API version it targets.
    first_trial = exp.experiment_trials(experiment_id)[0]
    checkpoint = (
        Determined(conf.make_master_url())
        .get_trial(first_trial.id)
        .select_checkpoint(latest=True)
    )
    model = checkpoint.load(map_location=torch.device("cpu"))
    assert isinstance(model, torch.nn.Module)
Exemple #5
0
def test_tensorpack_parallel(aggregation_frequency: int) -> None:
    """Train the Tensorpack MNIST example on 8 slots without native parallelism.

    Args:
        aggregation_frequency: forwarded to ``conf.set_aggregation_frequency``.
    """
    config = conf.load_config(conf.official_examples_path("trial/mnist_tp/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_length(config, {"batches": 32})
    config = conf.set_aggregation_frequency(config, aggregation_frequency)

    model_def_dir = conf.official_examples_path("trial/mnist_tp")
    experiment_id = exp.run_basic_test_with_temp_config(config, model_def_dir, 1)
    assert len(exp.experiment_trials(experiment_id)) == 1
Exemple #6
0
def test_tf_keras_parallel(aggregation_frequency: int, tf2: bool) -> None:
    """Train the tf.keras CIFAR-10 example on 8 slots without native parallelism.

    Args:
        aggregation_frequency: forwarded to ``conf.set_aggregation_frequency``.
        tf2: when true use the TF2 image, otherwise the TF1 image.
    """
    config = conf.load_config(conf.official_examples_path("trial/cifar10_cnn_tf_keras/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_aggregation_frequency(config, aggregation_frequency)
    if tf2:
        config = conf.set_tf2_image(config)
    else:
        config = conf.set_tf1_image(config)

    model_def_dir = conf.official_examples_path("trial/cifar10_cnn_tf_keras")
    experiment_id = exp.run_basic_test_with_temp_config(config, model_def_dir, 1)
    assert len(exp.experiment_trials(experiment_id)) == 1
Exemple #7
0
def test_s3_no_creds(secrets: Dict[str, str]) -> None:
    """Run MNIST PyTorch with S3 checkpoint storage whose credentials come from env vars.

    The test is unconditionally skipped for now; the remainder is kept so it
    can be re-enabled later.

    Args:
        secrets: mapping that must provide ``INTEGRATIONS_S3_ACCESS_KEY`` and
            ``INTEGRATIONS_S3_SECRET_KEY``.
    """
    pytest.skip("Temporarily skipping this until we find a more secure way of testing this.")
    config = conf.load_config(conf.official_examples_path("trial/mnist_pytorch/const.yaml"))
    config["checkpoint_storage"] = exp.s3_checkpoint_config_no_creds()

    # Make sure the nested environment-variable list exists, then append the
    # AWS credentials to it.
    environment = config.setdefault("environment", {})
    environment.setdefault("environment_variables", [])
    environment["environment_variables"] += [
        f"AWS_ACCESS_KEY_ID={secrets['INTEGRATIONS_S3_ACCESS_KEY']}",
        f"AWS_SECRET_ACCESS_KEY={secrets['INTEGRATIONS_S3_SECRET_KEY']}",
    ]

    model_def_dir = conf.official_examples_path("trial/mnist_pytorch")
    exp.run_basic_test_with_temp_config(config, model_def_dir, 1)
Exemple #8
0
def test_pytorch_const_parallel(aggregation_frequency: int, use_amp: bool) -> None:
    """Train MNIST PyTorch on 8 slots for 2 steps without native parallelism.

    Args:
        aggregation_frequency: forwarded to ``conf.set_aggregation_frequency``.
        use_amp: when true, set the AMP optimization level to ``"O1"``.
    """
    config = conf.load_config(conf.official_examples_path("mnist_pytorch/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_steps(config, 2)
    config = conf.set_aggregation_frequency(config, aggregation_frequency)
    config = conf.set_amp_level(config, "O1") if use_amp else config

    exp.run_basic_test_with_temp_config(config, conf.official_examples_path("mnist_pytorch"), 1)
Exemple #9
0
def test_pytorch_const_parallel(aggregation_frequency: int, use_amp: bool) -> None:
    """Train MNIST PyTorch (trial API) on 8 slots for 200 batches without native parallelism.

    Skipped when mixed precision is combined with an aggregation frequency
    greater than 1, which this test treats as unsupported.

    Args:
        aggregation_frequency: forwarded to ``conf.set_aggregation_frequency``.
        use_amp: when true, set the AMP optimization level to ``"O1"``.
    """
    if use_amp and aggregation_frequency > 1:
        pytest.skip("Mixed precision is not supported with aggregation frequency > 1.")

    config = conf.load_config(conf.official_examples_path("trial/mnist_pytorch/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_aggregation_frequency(config, aggregation_frequency)
    if use_amp:
        config = conf.set_amp_level(config, "O1")

    exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/mnist_pytorch"), 1
    )
Exemple #10
0
def test_pytorch_parallel() -> None:
    """Train MNIST PyTorch on 8 slots with tensor auto-tuning and an initial validation.

    Afterwards, assert that the experiment performed the initial (zeroth-step)
    validation.
    """
    config = conf.load_config(conf.official_examples_path("trial/mnist_pytorch/const.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, False)
    config = conf.set_max_length(config, {"batches": 200})
    config = conf.set_tensor_auto_tuning(config, True)
    config = conf.set_perform_initial_validation(config, True)

    model_def_dir = conf.official_examples_path("trial/mnist_pytorch")
    exp_id = exp.run_basic_test_with_temp_config(config, model_def_dir, 1, has_zeroth_step=True)
    exp.assert_performed_initial_validation(exp_id)
Exemple #11
0
def test_create_test_mode() -> None:
    """Check ``det experiment create --test-mode`` on a valid and on a broken model.

    The valid MNIST PyTorch definition must print a success message; the
    ``trial_error`` fixture must make the CLI exit with a non-zero status.
    """
    # A valid experiment should pass the model-definition test.
    valid_command = [
        "det", "-m", conf.make_master_url(), "experiment", "create", "--test-mode",
    ] + [
        conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"),
        conf.official_examples_path("trial/mnist_pytorch"),
    ]
    valid_output = subprocess.check_output(valid_command, universal_newlines=True)
    assert "Model definition test succeeded" in valid_output

    # An error introduced into the trial implementation should make test mode fail.
    broken_command = [
        "det", "-m", conf.make_master_url(), "experiment", "create", "--test-mode",
    ] + [
        conf.fixtures_path("trial_error/const.yaml"),
        conf.fixtures_path("trial_error"),
    ]
    with pytest.raises(subprocess.CalledProcessError):
        subprocess.check_call(broken_command)
Exemple #12
0
def test_pytorch_11_const(aggregation_frequency: int) -> None:
    """Run the PyTorch 1.1 MNIST fixture config against the ``mnist_pytorch`` example.

    Args:
        aggregation_frequency: forwarded to ``conf.set_aggregation_frequency``.
    """
    fixture_path = conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml")
    config = conf.set_aggregation_frequency(conf.load_config(fixture_path), aggregation_frequency)
    exp.run_basic_test_with_temp_config(config, conf.official_examples_path("mnist_pytorch"), 1)
def test_mnist_estimator_const(tf2: bool) -> None:
    """Run the single-trial MNIST Estimator fixture and sanity-check its metrics.

    Requires exactly one trial with one validated step, an exact accuracy
    value when not running on GPU, and 100 batch metrics with positive loss
    in every training step.

    Args:
        tf2: when true use the TF2 image, otherwise the TF1 image.
    """
    config = conf.load_config(conf.fixtures_path("mnist_estimator/single.yaml"))
    if tf2:
        config = conf.set_tf2_image(config)
    else:
        config = conf.set_tf1_image(config)
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_estimator"), 1)

    trials = exp.experiment_trials(experiment_id)
    assert len(trials) == 1
    only_trial = trials[0]

    # Validation metrics: one step, and it must carry validation results.
    trial_steps = only_trial["steps"]
    assert len(trial_steps) == 1
    validated_step = trial_steps[0]
    assert "validation" in validated_step
    v_metrics = validated_step["validation"]["metrics"]["validation_metrics"]

    # GPU training is non-deterministic; on CPU the result is reproducible.
    if not cluster.running_on_gpu():
        assert v_metrics["accuracy"] == 0.9125999808311462

    # Training metrics: each step reports 100 batches, every one with positive loss.
    for metric_step in exp.trial_metrics(only_trial["id"])["steps"]:
        batch_metrics = metric_step["metrics"]["batch_metrics"]
        assert len(batch_metrics) == 100
        for batch_metric in batch_metrics:
            assert batch_metric["loss"] > 0
Exemple #14
0
def test_pytorch_11_const(aggregation_frequency: int, using_k8s: bool) -> None:
    """Run the PyTorch 1.1 MNIST fixture, optionally with a custom Kubernetes pod spec.

    Args:
        aggregation_frequency: forwarded to ``conf.set_aggregation_frequency``.
        using_k8s: when true, attach a pod spec that labels the pod and mounts
            an ``emptyDir`` volume at ``/random``.
    """
    config = conf.load_config(conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"))
    config = conf.set_aggregation_frequency(config, aggregation_frequency)

    if using_k8s:
        volume_name = "temp1"
        container = {"volumeMounts": [{"name": volume_name, "mountPath": "/random"}]}
        pod_spec = {
            "metadata": {"labels": {"ci": "testing"}},
            "spec": {
                "containers": [container],
                "volumes": [{"name": volume_name, "emptyDir": {}}],
            },
        }
        config = conf.set_pod_spec(config, pod_spec)

    exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/mnist_pytorch"), 1)
def test_model_registry() -> None:
    """Exercise model-registry metadata edits, version registration and sorted listing."""
    exp_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"),
        conf.official_examples_path("trial/mnist_pytorch"),
        None,
    )

    client = Determined(conf.make_master_url())

    mnist = client.create_model("mnist", "simple computer vision model")
    assert mnist.metadata == {}

    # Apply each metadata edit in order and check the resulting metadata.
    metadata_edits = [
        (mnist.add_metadata, {"testing": "metadata"}, {"testing": "metadata"}),
        (
            mnist.add_metadata,
            {"some_key": "some_value"},
            {"testing": "metadata", "some_key": "some_value"},
        ),
        (
            mnist.add_metadata,
            {"testing": "override"},
            {"testing": "override", "some_key": "some_value"},
        ),
        (mnist.remove_metadata, ["some_key"], {"testing": "override"}),
    ]
    for apply_edit, payload, expected_metadata in metadata_edits:
        apply_edit(payload)
        assert mnist.metadata == expected_metadata

    checkpoint = client.get_experiment(exp_id).top_checkpoint()
    assert mnist.register_version(checkpoint) == 1
    assert mnist.get_version().uuid == checkpoint.uuid

    for model_name, description in (
        ("transformer", "all you need is attention"),
        ("object-detection", "a bounding box model"),
    ):
        client.create_model(model_name, description)

    listed_names = [model.name for model in client.get_models(sort_by=ModelSortBy.NAME)]
    assert listed_names == ["mnist", "object-detection", "transformer"]
Exemple #16
0
def test_end_to_end_adaptive() -> None:
    """Run the adaptive MNIST PyTorch search end to end and check results and checkpoint APIs.

    Checks that:
      * the best final-step validation accuracy across trials exceeds 0.93;
      * ``top_n_checkpoints(2)`` is a prefix of ``top_n_checkpoints(len(trials))``;
      * the top checkpoints are in ascending ``validation_loss`` order;
      * ``smaller_is_better=False`` reverses that ordering;
      * checkpoint metadata can be added, overridden and removed.
    """
    exp_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"),
        conf.official_examples_path("trial/mnist_pytorch"),
        None,
    )

    # Check that validation accuracy looks sane (more than 93% on MNIST).
    trials = exp.experiment_trials(exp_id)
    best = None
    for trial in trials:
        assert len(trial["steps"])
        last_step = trial["steps"][-1]
        accuracy = last_step["validation"]["metrics"]["validation_metrics"]["accuracy"]
        # Compare against None explicitly: an accuracy of 0.0 is falsy and must
        # not be mistaken for "no best value recorded yet".
        if best is None or accuracy > best:
            best = accuracy

    assert best is not None
    assert best > 0.93

    # Check that ExperimentReference returns a sorted order of top checkpoints
    # without gaps. The top 2 checkpoints should be the first 2 of the top k
    # checkpoints if sorting is stable.
    d = Determined(conf.make_master_url())
    exp_ref = d.get_experiment(exp_id)

    top_2 = exp_ref.top_n_checkpoints(2)
    top_k = exp_ref.top_n_checkpoints(len(trials))

    top_2_uuids = [c.uuid for c in top_2]
    top_k_uuids = [c.uuid for c in top_k]

    assert top_2_uuids == top_k_uuids[:2]

    # Check that metrics are truly in sorted order.
    metrics = [c.validation["metrics"]["validation_metrics"]["validation_loss"] for c in top_k]

    assert metrics == sorted(metrics)

    # Check that changing smaller_is_better reverses the checkpoint ordering.
    top_k_reversed = exp_ref.top_n_checkpoints(
        len(trials), sort_by="validation_loss", smaller_is_better=False
    )
    top_k_reversed_uuids = [c.uuid for c in top_k_reversed]

    assert top_k_uuids == top_k_reversed_uuids[::-1]

    # Checkpoint metadata: add, add another key, override, then remove.
    checkpoint = top_k[0]
    checkpoint.add_metadata({"testing": "metadata"})
    assert checkpoint.metadata == {"testing": "metadata"}

    checkpoint.add_metadata({"some_key": "some_value"})
    assert checkpoint.metadata == {"testing": "metadata", "some_key": "some_value"}

    checkpoint.add_metadata({"testing": "override"})
    assert checkpoint.metadata == {"testing": "override", "some_key": "some_value"}

    checkpoint.remove_metadata(["some_key"])
    assert checkpoint.metadata == {"testing": "override"}
def test_mnist_estimator_adaptive(tf2: bool) -> None:
    """Run the adaptive MNIST Estimator fixture.

    Args:
        tf2: when true use the TF2 image, otherwise the TF1 image.
    """
    # Only test tf1 here, because a tf2 test would add no extra coverage.
    # NOTE(review): the function still honours tf2; presumably the test
    # parametrization only supplies tf2=False — confirm.
    config = conf.load_config(conf.fixtures_path("mnist_estimator/adaptive.yaml"))
    choose_image = conf.set_tf2_image if tf2 else conf.set_tf1_image
    config = choose_image(config)

    exp.run_basic_test_with_temp_config(config, conf.official_examples_path("mnist_estimator"), None)
Exemple #18
0
def test_pytorch_load() -> None:
    """Train MNIST PyTorch with the PyTorch 1.1 fixture and load its top checkpoint on CPU.

    Passes when loading does not raise; the loaded object is not inspected.
    """
    config = conf.load_config(conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"))
    model_def_dir = conf.official_examples_path("trial/mnist_pytorch")
    experiment_id = exp.run_basic_test_with_temp_config(config, model_def_dir, 1)

    experiment_ref = Determined(conf.make_master_url()).get_experiment(experiment_id)
    experiment_ref.top_checkpoint().load(map_location="cpu")
Exemple #19
0
def test_mnist_estimator_load() -> None:
    """Train the MNIST Estimator (TF1 image) and load the first trial's top checkpoint.

    The loaded model must be an ``AutoTrackable``.
    """
    base_config = conf.load_config(conf.fixtures_path("mnist_estimator/single.yaml"))
    config = conf.set_tf1_image(base_config)
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/mnist_estimator"), 1
    )

    first_trial_id = exp.experiment_trials(experiment_id)[0]["id"]
    best_checkpoint = Determined(conf.make_master_url()).get_trial(first_trial_id).top_checkpoint()
    model = best_checkpoint.load()
    assert isinstance(model, AutoTrackable)
Exemple #20
0
def test_pytorch_load() -> None:
    """Train MNIST PyTorch with the PyTorch 1.1 fixture and load its top checkpoint on CPU.

    The checkpoint must deserialize into a ``torch.nn.Module``.
    """
    config = conf.load_config(conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"))
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("mnist_pytorch"), 1)

    best_checkpoint = Determined(conf.make_master_url()).get_experiment(experiment_id).top_checkpoint()
    model = best_checkpoint.load(map_location=torch.device("cpu"))
    assert isinstance(model, torch.nn.Module)
def test_mnist_pytorch_accuracy() -> None:
    """Train the MNIST PyTorch example and require validation accuracy above 0.97.

    Raises AssertionError if no step was validated or the best validation
    accuracy does not exceed the target.
    """
    config = conf.load_config(
        conf.official_examples_path("trial/mnist_pytorch/const.yaml"))
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/mnist_pytorch"), 1)

    trials = exp.experiment_trials(experiment_id)
    trial_metrics = exp.trial_metrics(trials[0]["id"])

    # These are accuracies (higher is better), not errors.
    validation_accuracies = [
        step["validation"]["metrics"]["validation_metrics"]["accuracy"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]
    # Fail with a clear message instead of max() raising ValueError on an empty list.
    assert validation_accuracies, "mnist_pytorch reported no validation metrics."

    target_accuracy = 0.97
    assert max(validation_accuracies) > target_accuracy, (
        "mnist_pytorch did not reach minimum target accuracy {} in {} steps."
        " full validation error history: {}".format(
            target_accuracy, len(trial_metrics["steps"]), validation_accuracies))
def test_fashion_mnist_tf_keras() -> None:
    """Train the Fashion-MNIST tf.keras example with a fixed seed; require val accuracy above 0.85.

    Raises AssertionError if no step was validated or the best validation
    accuracy does not exceed the target.
    """
    config = conf.load_config(
        conf.official_examples_path("trial/fashion_mnist_tf_keras/const.yaml"))
    config = conf.set_random_seed(config, 1591110586)
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/fashion_mnist_tf_keras"), 1)

    trials = exp.experiment_trials(experiment_id)
    trial_metrics = exp.trial_metrics(trials[0]["id"])

    # These are accuracies (higher is better), not errors.
    validation_accuracies = [
        step["validation"]["metrics"]["validation_metrics"]["val_accuracy"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]
    # Fail with a clear message instead of max() raising ValueError on an empty list.
    assert validation_accuracies, "fashion_mnist_tf_keras reported no validation metrics."

    accuracy = 0.85
    assert max(validation_accuracies) > accuracy, (
        "fashion_mnist_tf_keras did not reach minimum target accuracy {} in {} steps."
        " full validation error history: {}".format(
            accuracy, len(trial_metrics["steps"]), validation_accuracies))
Exemple #23
0
def test_cifar10_pytorch_accuracy() -> None:
    """Train the CIFAR-10 PyTorch example and require validation accuracy above 0.745.

    Raises AssertionError if no step was validated or the best validation
    accuracy does not exceed the target.
    """
    config = conf.load_config(conf.official_examples_path("cifar10_cnn_pytorch/const.yaml"))
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("cifar10_cnn_pytorch"), 1
    )

    trials = exp.experiment_trials(experiment_id)
    # NOTE(review): this test uses attribute access (``.id``, ``.steps``,
    # ``.validation``) rather than dict indexing; confirm the targeted API version.
    trial_metrics = exp.trial_metrics(trials[0].id)

    # These are accuracies (higher is better), not errors.
    validation_accuracies = [
        step.validation.metrics["validation_metrics"]["validation_accuracy"]
        for step in trial_metrics.steps
        if step.validation
    ]
    # Fail with a clear message instead of max() raising ValueError on an empty list.
    assert validation_accuracies, "cifar10_cnn_pytorch reported no validation metrics."

    target_accuracy = 0.745
    assert max(validation_accuracies) > target_accuracy, (
        "cifar10_cnn_pytorch did not reach minimum target accuracy {} in {} steps."
        " full validation error history: {}".format(
            target_accuracy, len(trial_metrics.steps), validation_accuracies
        )
    )
def test_object_detection_accuracy() -> None:
    """Train the object-detection PyTorch example with a fixed seed; require val IoU above 0.42.

    Raises AssertionError if no step was validated or the best validation
    average IoU does not exceed the target.
    """
    config = conf.load_config(
        conf.official_examples_path(
            "trial/object_detection_pytorch/const.yaml"))
    config = conf.set_random_seed(config, 1590497309)
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/object_detection_pytorch"),
        1)

    trials = exp.experiment_trials(experiment_id)
    trial_metrics = exp.trial_metrics(trials[0]["id"])

    # Average IoU values (higher is better), not errors.
    validation_ious = [
        step["validation"]["metrics"]["validation_metrics"]["val_avg_iou"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]
    # Fail with a clear message instead of max() raising ValueError on an empty list.
    assert validation_ious, "object_detection_pytorch reported no validation metrics."

    target_iou = 0.42
    assert max(validation_ious) > target_iou, (
        "object_detection_pytorch did not reach minimum target accuracy {} in {} steps."
        " full validation error history: {}".format(
            target_iou, len(trial_metrics["steps"]), validation_ious))
Exemple #25
0
def test_pytorch_const_warm_start() -> None:
    """
    Check that trials warm-started from an earlier trial record that trial's
    final checkpoint in their `warm_start_checkpoint_id` field.
    """
    # Source experiment: one two-step trial whose last checkpoint is the warm-start source.
    source_config = conf.set_max_steps(
        conf.load_config(conf.official_examples_path("mnist_pytorch/const.yaml")), 2)
    source_experiment_id = exp.run_basic_test_with_temp_config(
        source_config, conf.official_examples_path("mnist_pytorch"), 1)

    source_trials = exp.experiment_trials(source_experiment_id)
    assert len(source_trials) == 1

    source_trial = source_trials[0]
    source_trial_id = source_trial["id"]

    assert len(source_trial["steps"]) == 2
    source_checkpoint_id = source_trial["steps"][-1]["checkpoint"]["id"]

    # Warm-started experiment: a random search of three one-step trials that
    # start from the source trial.
    warm_config = conf.load_config(conf.official_examples_path("mnist_pytorch/const.yaml"))
    searcher = warm_config["searcher"]
    searcher["source_trial_id"] = source_trial_id
    searcher["name"] = "random"
    searcher["max_steps"] = 1
    searcher["max_trials"] = 3

    warm_experiment_id = exp.run_basic_test_with_temp_config(
        warm_config, conf.official_examples_path("mnist_pytorch"), 3)

    warm_trials = exp.experiment_trials(warm_experiment_id)
    assert len(warm_trials) == 3
    for warm_trial in warm_trials:
        assert warm_trial["warm_start_checkpoint_id"] == source_checkpoint_id
Exemple #26
0
def test_mnist_estimmator_const_parallel(native_parallel: bool, tf2: bool) -> None:
    """Train the multi-slot MNIST Estimator fixture on 8 slots for 2 steps.

    Skipped for the TF2 + native-parallel combination.

    Args:
        native_parallel: forwarded to ``conf.set_native_parallel``.
        tf2: when true use the TF2 image, otherwise the TF1 image.
    """
    if tf2 and native_parallel:
        pytest.skip("TF2 native parallel training is not currently supported.")

    config = conf.load_config(conf.fixtures_path("mnist_estimator/single-multi-slot.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, native_parallel)
    config = conf.set_max_steps(config, 2)
    config = (conf.set_tf2_image if tf2 else conf.set_tf1_image)(config)

    exp.run_basic_test_with_temp_config(config, conf.official_examples_path("mnist_estimator"), 1)
def test_mnist_tp_accuracy() -> None:
    """Train the Tensorpack MNIST example and require validation accuracy above 0.95.

    Raises AssertionError if no step was validated or the best validation
    accuracy does not exceed the target.
    """
    config = conf.load_config(
        conf.official_examples_path("trial/mnist_tp/const.yaml"))
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.official_examples_path("trial/mnist_tp"), 1)

    trials = exp.experiment_trials(experiment_id)
    trial_metrics = exp.trial_metrics(trials[0]["id"])

    # TODO (DET-3082): The validation metric names were modified by our trial reporting
    # from accuracy to val_accuracy.  We should probably remove the added prefix so
    # the metric name is as specified.
    # These are accuracies (higher is better), not errors.
    validation_accuracies = [
        step["validation"]["metrics"]["validation_metrics"]["val_accuracy"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]
    # Fail with a clear message instead of max() raising ValueError on an empty list.
    assert validation_accuracies, "mnist_tp reported no validation metrics."

    target_accuracy = 0.95
    assert max(validation_accuracies) > target_accuracy, (
        "mnist_tp did not reach minimum target accuracy {} in {} steps."
        " full validation error history: {}".format(
            target_accuracy, len(trial_metrics["steps"]), validation_accuracies))
Exemple #28
0
def test_mnist_estimmator_const_parallel(native_parallel: bool, tf2: bool) -> None:
    """Train the multi-slot MNIST Estimator fixture on 8 slots for 200 batches.

    An initial validation is requested and then asserted. Skipped for the
    TF2 + native-parallel combination.

    Args:
        native_parallel: forwarded to ``conf.set_native_parallel``.
        tf2: when true use the TF2 image, otherwise the TF1 image.
    """
    if tf2 and native_parallel:
        pytest.skip("TF2 native parallel training is not currently supported.")

    config = conf.load_config(conf.fixtures_path("mnist_estimator/single-multi-slot.yaml"))
    config = conf.set_slots_per_trial(config, 8)
    config = conf.set_native_parallel(config, native_parallel)
    config = conf.set_max_length(config, {"batches": 200})
    config = (conf.set_tf2_image if tf2 else conf.set_tf1_image)(config)
    config = conf.set_perform_initial_validation(config, True)

    model_def_dir = conf.official_examples_path("trial/mnist_estimator")
    exp_id = exp.run_basic_test_with_temp_config(config, model_def_dir, 1, has_zeroth_step=True)
    exp.assert_performed_initial_validation(exp_id)
Exemple #29
0
def test_invalid_experiment() -> None:
    """Creating an experiment from the invalid fixture config must exit non-zero."""
    invalid_config_path = conf.fixtures_path("invalid_experiment/const.yaml")
    model_def_dir = conf.official_examples_path("mnist_tf")
    result = exp.maybe_create_experiment(invalid_config_path, model_def_dir)
    assert result.returncode != 0
Exemple #30
0
class NativeImplementations:
    """Catalog of native-API example launches used by the native tests.

    Each attribute is a ``NativeImplementation`` built from a working directory
    (``cwd``), a ``python`` command line, an experiment ``configuration`` dict,
    and the expected number of trials and steps per trial. Every entry here
    uses a ``single`` searcher with ``max_steps`` 1 and expects one trial with
    one step.
    """

    # PyTorch MNIST launched through the example's trial_impl.py.
    # NOTE(review): unlike the other entries, no "batches_per_step" is set
    # here — confirm that relying on the default is intended.
    PytorchMNISTCNNSingleGeneric = NativeImplementation(
        cwd=conf.official_examples_path("native/native_mnist_pytorch"),
        command=[
            "python",
            conf.official_examples_path("native/native_mnist_pytorch/trial_impl.py"),
        ],
        configuration={
            "checkpoint_storage": experiment.shared_fs_checkpoint_config(),
            "searcher": {"name": "single", "max_steps": 1, "metric": "validation_error"},
            "max_restarts": 0,
        },
        num_expected_steps_per_trial=1,
        num_expected_trials=1,
    )
    # TF Estimator MNIST launched through the example's native_impl.py.
    TFEstimatorMNISTCNNSingle = NativeImplementation(
        cwd=conf.official_examples_path("native/native_mnist_estimator"),
        command=[
            "python",
            conf.official_examples_path("native/native_mnist_estimator/native_impl.py"),
        ],
        configuration={
            "batches_per_step": 4,
            "checkpoint_storage": experiment.shared_fs_checkpoint_config(),
            "searcher": {"name": "single", "max_steps": 1, "metric": "accuracy"},
            "max_restarts": 0,
        },
        num_expected_steps_per_trial=1,
        num_expected_trials=1,
    )

    # TF Estimator MNIST launched through the example's trial_impl.py.
    TFEstimatorMNISTCNNSingleGeneric = NativeImplementation(
        cwd=conf.official_examples_path("native/native_mnist_estimator"),
        command=[
            "python",
            conf.official_examples_path("native/native_mnist_estimator/trial_impl.py"),
        ],
        configuration={
            "batches_per_step": 4,
            "checkpoint_storage": experiment.shared_fs_checkpoint_config(),
            "searcher": {"name": "single", "max_steps": 1, "metric": "accuracy"},
            "max_restarts": 0,
        },
        num_expected_steps_per_trial=1,
        num_expected_trials=1,
    )

    # Train a single tf.keras model using fit().
    # The "--use-fit" flag passed to native_impl.py selects this path.
    TFKerasMNISTCNNSingleFit = NativeImplementation(
        cwd=conf.official_examples_path("native/native_fashion_mnist_tf_keras"),
        command=[
            "python",
            conf.official_examples_path("native/native_fashion_mnist_tf_keras/native_impl.py"),
            "--use-fit",
        ],
        configuration={
            "batches_per_step": 4,
            "checkpoint_storage": experiment.shared_fs_checkpoint_config(),
            "searcher": {"name": "single", "max_steps": 1, "metric": "val_accuracy"},
            "max_restarts": 2,
        },
        num_expected_steps_per_trial=1,
        num_expected_trials=1,
    )

    # Train a single tf.keras model using fit_generator().
    # Same script as above without "--use-fit"; presumably native_impl.py
    # defaults to fit_generator() — confirm against the example.
    TFKerasMNISTCNNSingleFitGenerator = NativeImplementation(
        cwd=conf.official_examples_path("native/native_fashion_mnist_tf_keras"),
        command=[
            "python",
            conf.official_examples_path("native/native_fashion_mnist_tf_keras/native_impl.py"),
        ],
        configuration={
            "batches_per_step": 4,
            "checkpoint_storage": experiment.shared_fs_checkpoint_config(),
            "searcher": {"name": "single", "max_steps": 1, "metric": "val_accuracy"},
            "max_restarts": 2,
        },
        num_expected_steps_per_trial=1,
        num_expected_trials=1,
    )

    # tf.keras Fashion-MNIST launched through the example's trial_impl.py.
    TFKerasMNISTCNNSingleGeneric = NativeImplementation(
        cwd=conf.official_examples_path("native/native_fashion_mnist_tf_keras"),
        command=[
            "python",
            conf.official_examples_path("native/native_fashion_mnist_tf_keras/trial_impl.py"),
        ],
        configuration={
            "batches_per_step": 4,
            "checkpoint_storage": experiment.shared_fs_checkpoint_config(),
            "searcher": {"name": "single", "max_steps": 1, "metric": "val_accuracy"},
            "max_restarts": 2,
        },
        num_expected_steps_per_trial=1,
        num_expected_trials=1,
    )