Exemplo n.º 1
0
def test_pytorch_11_const(aggregation_frequency: int, using_k8s: bool) -> None:
    """Run the ``const-pytorch11`` MNIST fixture end to end.

    Loads ``mnist_pytorch/const-pytorch11.yaml`` from the fixtures directory
    and applies the requested aggregation frequency. When ``using_k8s`` is
    set, it attaches a pod spec. It then launches the experiment against the
    ``mnist_pytorch`` tutorial directory.

    NOTE(review): a second ``def test_pytorch_11_const`` appears later in
    this file and rebinds the same name, so this definition is shadowed once
    the module is imported. Confirm which version is meant to survive.

    Args:
        aggregation_frequency: Passed to ``conf.set_aggregation_frequency``
            (presumably supplied by pytest parametrization; confirm).
        using_k8s: When true, the config gets a pod spec carrying a
            ``ci: testing`` label and an ``emptyDir`` volume named ``temp1``
            mounted at ``/random`` in ``determined-container``.
    """
    fixture_file = conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml")
    experiment_config = conf.load_config(fixture_file)
    experiment_config = conf.set_aggregation_frequency(
        experiment_config, aggregation_frequency)

    if using_k8s:
        # The same volume name links the container mount to the pod volume.
        scratch_volume = "temp1"
        determined_container = {
            "name": "determined-container",
            "volumeMounts": [
                {"name": scratch_volume, "mountPath": "/random"},
            ],
        }
        pod_spec = {
            "metadata": {"labels": {"ci": "testing"}},
            "spec": {
                "containers": [determined_container],
                "volumes": [{"name": scratch_volume, "emptyDir": {}}],
            },
        }
        experiment_config = conf.set_pod_spec(experiment_config, pod_spec)

    # The trailing ``1`` is forwarded unchanged to the helper (presumably the
    # expected number of trials; confirm against its definition).
    model_dir = conf.tutorials_path("mnist_pytorch")
    exp.run_basic_test_with_temp_config(experiment_config, model_dir, 1)
def test_pytorch_11_const(
        aggregation_frequency: int, using_k8s: bool,
        collect_trial_profiles: Callable[[int], None]) -> None:
    """Run the ``const-pytorch11`` MNIST fixture with profiling enabled.

    Loads ``mnist_pytorch/const-pytorch11.yaml`` from the fixtures directory,
    applies the requested aggregation frequency and turns profiling on. When
    ``using_k8s`` is set, it attaches a pod spec. It then launches the
    experiment against the ``mnist_pytorch`` tutorial directory. Finally, it
    hands the id of the experiment's first trial to
    ``collect_trial_profiles``.

    Args:
        aggregation_frequency: Passed to ``conf.set_aggregation_frequency``
            (presumably supplied by pytest parametrization; confirm).
        using_k8s: When true, the config gets a pod spec carrying a
            ``ci: testing`` label and an ``emptyDir`` volume named ``temp1``
            mounted at ``/random`` in ``determined-container``.
        collect_trial_profiles: Called once with the first trial's id
            (presumably a fixture that gathers profiler output; confirm).
    """
    fixture_file = conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml")
    experiment_config = conf.load_config(fixture_file)
    experiment_config = conf.set_aggregation_frequency(
        experiment_config, aggregation_frequency)
    experiment_config = conf.set_profiling_enabled(experiment_config)

    if using_k8s:
        # The same volume name links the container mount to the pod volume.
        scratch_volume = "temp1"
        determined_container = {
            "name": "determined-container",
            "volumeMounts": [
                {"name": scratch_volume, "mountPath": "/random"},
            ],
        }
        pod_spec = {
            "metadata": {"labels": {"ci": "testing"}},
            "spec": {
                "containers": [determined_container],
                "volumes": [{"name": scratch_volume, "emptyDir": {}}],
            },
        }
        experiment_config = conf.set_pod_spec(experiment_config, pod_spec)

    # The trailing ``1`` is forwarded unchanged to the helper (presumably the
    # expected number of trials; confirm against its definition).
    model_dir = conf.tutorials_path("mnist_pytorch")
    experiment_id = exp.run_basic_test_with_temp_config(
        experiment_config, model_dir, 1)

    first_trial = exp.experiment_trials(experiment_id)[0]
    collect_trial_profiles(first_trial.trial.id)