Example #1
0
def test_pytorch_const_with_amp() -> None:
    """Run the ``mnist_pytorch`` const example with AMP level ``"O1"``.

    The config is loaded from ``mnist_pytorch/const.yaml`` under the official
    examples path, limited to 2 max steps, switched to AMP level ``"O1"``, and
    handed to ``exp.run_basic_test_with_temp_config`` together with the
    example directory.
    """
    yaml_path = conf.official_examples_path("mnist_pytorch/const.yaml")
    amp_config = conf.set_amp_level(
        conf.set_max_steps(conf.load_config(yaml_path), 2),
        "O1",
    )

    example_dir = conf.official_examples_path("mnist_pytorch")
    # NOTE(review): the trailing 1 is presumably the expected number of
    # trials -- confirm against the exp helper's signature.
    exp.run_basic_test_with_temp_config(amp_config, example_dir, 1)
Example #2
0
def test_pytorch_const_with_amp() -> None:
    """Run the ``trial/mnist_pytorch`` const example with AMP level ``"O1"``.

    The config comes from ``trial/mnist_pytorch/const.yaml`` under the official
    examples path. Its max length is set to ``{"batches": 200}`` and its AMP
    level to ``"O1"`` before it is passed to
    ``exp.run_basic_test_with_temp_config`` with the example directory.
    """
    yaml_path = conf.official_examples_path("trial/mnist_pytorch/const.yaml")
    base_config = conf.load_config(yaml_path)
    bounded_config = conf.set_max_length(base_config, {"batches": 200})
    amp_config = conf.set_amp_level(bounded_config, "O1")

    example_dir = conf.official_examples_path("trial/mnist_pytorch")
    # NOTE(review): the trailing 1 is presumably the expected number of
    # trials -- confirm against the exp helper's signature.
    exp.run_basic_test_with_temp_config(amp_config, example_dir, 1)
Example #3
0
def test_pytorch_const_parallel(aggregation_frequency: int, use_amp: bool) -> None:
    """Run the tutorials ``mnist_pytorch`` const example on 8 slots per trial.

    Args:
        aggregation_frequency: Value forwarded to
            ``conf.set_aggregation_frequency``.
        use_amp: When true, the AMP level is set to ``"O1"``.

    The test is skipped when ``use_amp`` is true and
    ``aggregation_frequency`` is greater than 1.
    """
    amp_with_aggregation = use_amp and aggregation_frequency > 1
    if amp_with_aggregation:
        pytest.skip("Mixed precision is not support with aggregation frequency > 1.")

    yaml_path = conf.tutorials_path("mnist_pytorch/const.yaml")
    cfg = conf.load_config(yaml_path)
    cfg = conf.set_slots_per_trial(cfg, 8)
    cfg = conf.set_max_length(cfg, {"batches": 200})
    cfg = conf.set_aggregation_frequency(cfg, aggregation_frequency)
    # The AMP level is only touched when mixed precision was requested.
    cfg = conf.set_amp_level(cfg, "O1") if use_amp else cfg

    example_dir = conf.tutorials_path("mnist_pytorch")
    # NOTE(review): the trailing 1 is presumably the expected number of
    # trials -- confirm against the exp helper's signature.
    exp.run_basic_test_with_temp_config(cfg, example_dir, 1)
Example #4
0
def test_pytorch_const_parallel(aggregation_frequency: int,
                                use_amp: bool) -> None:
    """Run the ``mnist_pytorch`` const example on 8 slots per trial.

    Args:
        aggregation_frequency: Value forwarded to
            ``conf.set_aggregation_frequency``.
        use_amp: When true, the AMP level is set to ``"O1"``.

    The config has native parallel set to ``False`` and max steps set to 2
    before it is passed to ``exp.run_basic_test_with_temp_config``.
    """
    yaml_path = conf.official_examples_path("mnist_pytorch/const.yaml")
    cfg = conf.load_config(yaml_path)
    cfg = conf.set_slots_per_trial(cfg, 8)
    cfg = conf.set_native_parallel(cfg, False)
    cfg = conf.set_max_steps(cfg, 2)
    cfg = conf.set_aggregation_frequency(cfg, aggregation_frequency)
    # The AMP level is only touched when mixed precision was requested.
    cfg = conf.set_amp_level(cfg, "O1") if use_amp else cfg

    example_dir = conf.official_examples_path("mnist_pytorch")
    # NOTE(review): the trailing 1 is presumably the expected number of
    # trials -- confirm against the exp helper's signature.
    exp.run_basic_test_with_temp_config(cfg, example_dir, 1)
Example #5
0
def test_pytorch_const_parallel(aggregation_frequency: int,
                                use_amp: bool) -> None:
    """Run the ``trial/mnist_pytorch`` const example on 8 slots per trial.

    Args:
        aggregation_frequency: Value forwarded to
            ``conf.set_aggregation_frequency``.
        use_amp: When true, the AMP level is set to ``"O1"``.

    The test is skipped when ``use_amp`` is true and
    ``aggregation_frequency`` is greater than 1. Otherwise the config has
    native parallel set to ``False`` and max steps set to 2 before it is
    passed to ``exp.run_basic_test_with_temp_config``.
    """
    amp_with_aggregation = use_amp and aggregation_frequency > 1
    if amp_with_aggregation:
        skip_reason = "Mixed precision is not support with aggregation frequency > 1."
        pytest.skip(skip_reason)

    yaml_path = conf.official_examples_path("trial/mnist_pytorch/const.yaml")
    cfg = conf.load_config(yaml_path)
    cfg = conf.set_slots_per_trial(cfg, 8)
    cfg = conf.set_native_parallel(cfg, False)
    cfg = conf.set_max_steps(cfg, 2)
    cfg = conf.set_aggregation_frequency(cfg, aggregation_frequency)
    # The AMP level is only touched when mixed precision was requested.
    cfg = conf.set_amp_level(cfg, "O1") if use_amp else cfg

    example_dir = conf.official_examples_path("trial/mnist_pytorch")
    # NOTE(review): the trailing 1 is presumably the expected number of
    # trials -- confirm against the exp helper's signature.
    exp.run_basic_test_with_temp_config(cfg, example_dir, 1)