示例#1
0
def test_gan_mnist_pytorch_const() -> None:
    """Run the ``gan_mnist_pytorch`` example using its ``const.yaml`` config.

    The config is loaded, its max length is set to ``{"batches": 200}``, and
    the result is passed to ``exp.run_basic_test_with_temp_config`` together
    with the example directory.
    """
    yaml_path = conf.gan_examples_path("gan_mnist_pytorch/const.yaml")
    base_config = conf.load_config(yaml_path)
    short_config = conf.set_max_length(base_config, {"batches": 200})

    model_dir = conf.gan_examples_path("gan_mnist_pytorch")
    # NOTE(review): the trailing 1 is presumably the expected number of
    # trials -- confirm against run_basic_test_with_temp_config.
    exp.run_basic_test_with_temp_config(short_config, model_dir, 1)
def test_pl_mnist_gan() -> None:
    """Run the ``gan_mnist_pl`` example using its ``const.yaml`` config.

    The loaded config has its max length set to ``{"batches": 200}`` and is
    then passed through ``conf.set_tf2_image`` before the basic test runs.
    """
    example_name = "gan_mnist_pl"
    yaml_path = conf.gan_examples_path(example_name + "/const.yaml")

    loaded = conf.load_config(yaml_path)
    shortened = conf.set_max_length(loaded, {"batches": 200})
    # NOTE(review): set_tf2_image presumably selects a TF2 container image;
    # verify that is the intended image for this example.
    final_config = conf.set_tf2_image(shortened)

    model_dir = conf.gan_examples_path(example_name)
    exp.run_basic_test_with_temp_config(final_config, model_dir, 1)
def test_pix2pix_facades_const() -> None:
    """Run the ``pix2pix_tf_keras`` example using its ``const.yaml`` config.

    Only the max length is overridden (``{"batches": 200}``) before the
    config is handed to ``exp.run_basic_test_with_temp_config``.
    """
    yaml_path = conf.gan_examples_path("pix2pix_tf_keras/const.yaml")
    trimmed_config = conf.set_max_length(
        conf.load_config(yaml_path),
        {"batches": 200},
    )

    example_dir = conf.gan_examples_path("pix2pix_tf_keras")
    exp.run_basic_test_with_temp_config(trimmed_config, example_dir, 1)
示例#4
0
def run_tf_keras_dcgan_example() -> None:
    """Run the ``dcgan_tf_keras`` example with an 8-slot trial.

    Starting from the example's ``const.yaml``, the config is adjusted in
    this order: max length ``{"batches": 200}``, min validation period
    ``{"batches": 200}``, 8 slots per trial, then ``conf.set_tf2_image``.
    """
    yaml_path = conf.gan_examples_path("dcgan_tf_keras/const.yaml")

    # Each helper returns the config to feed into the next one.
    loaded = conf.load_config(yaml_path)
    with_length = conf.set_max_length(loaded, {"batches": 200})
    with_validation = conf.set_min_validation_period(with_length, {"batches": 200})
    with_slots = conf.set_slots_per_trial(with_validation, 8)
    final_config = conf.set_tf2_image(with_slots)

    example_dir = conf.gan_examples_path("dcgan_tf_keras")
    exp.run_basic_test_with_temp_config(final_config, example_dir, 1)
示例#5
0
def test_pytorch_gan_parallel() -> None:
    """Run ``gan_mnist_pytorch`` with 8 slots and load its latest checkpoint.

    After the experiment finishes, the first trial's latest checkpoint is
    selected through the ``Determined`` client and loaded with
    ``map_location="cpu"``. The loaded object is discarded; the call only
    has to complete without raising.
    """
    yaml_path = conf.gan_examples_path("gan_mnist_pytorch/const.yaml")
    loaded = conf.load_config(yaml_path)
    shortened = conf.set_max_length(loaded, {"batches": 200})
    parallel_config = conf.set_slots_per_trial(shortened, 8)

    example_dir = conf.gan_examples_path("gan_mnist_pytorch")
    exp_id = exp.run_basic_test_with_temp_config(parallel_config, example_dir, 1)
    trial_list = exp.experiment_trials(exp_id)

    # Build the client before touching trial_list so evaluation order is
    # the same as a single chained expression.
    client = Determined(conf.make_master_url())
    first_trial = client.get_trial(trial_list[0]["id"])
    newest_checkpoint = first_trial.select_checkpoint(latest=True)
    newest_checkpoint.load(map_location="cpu")
示例#6
0
def run_tf_keras_dcgan_example(
        collect_trial_profiles: Callable[[int], None]) -> None:
    """Run the ``dcgan_tf_keras`` example with profiling enabled.

    Starting from the example's ``const.yaml``, the config is adjusted in
    this order: max length ``{"batches": 200}``, min validation period
    ``{"batches": 200}``, 8 slots per trial, ``conf.set_tf2_image``, then
    ``conf.set_profiling_enabled``.

    Args:
        collect_trial_profiles: Callback invoked with the id of the
            experiment's first trial once the experiment has run.
    """
    yaml_path = conf.gan_examples_path("dcgan_tf_keras/const.yaml")

    # Each helper returns the config to feed into the next one.
    loaded = conf.load_config(yaml_path)
    with_length = conf.set_max_length(loaded, {"batches": 200})
    with_validation = conf.set_min_validation_period(with_length, {"batches": 200})
    with_slots = conf.set_slots_per_trial(with_validation, 8)
    with_image = conf.set_tf2_image(with_slots)
    profiled_config = conf.set_profiling_enabled(with_image)

    example_dir = conf.gan_examples_path("dcgan_tf_keras")
    experiment_id = exp.run_basic_test_with_temp_config(
        profiled_config, example_dir, 1)

    first_trial = exp.experiment_trials(experiment_id)[0]
    collect_trial_profiles(first_trial.trial.id)