def test_gan_mnist_pytorch_const() -> None:
    """Run the ``gan_mnist_pytorch`` example's const config, capped at 200 batches.

    The experiment is launched from a temporary copy of the config, and the
    run is expected to produce 1 trial.
    """
    example_path = conf.gan_examples_path("gan_mnist_pytorch")
    const_cfg = conf.load_config(
        conf.gan_examples_path("gan_mnist_pytorch/const.yaml")
    )
    # Cap training length so the e2e run stays short.
    const_cfg = conf.set_max_length(const_cfg, {"batches": 200})

    exp.run_basic_test_with_temp_config(const_cfg, example_path, 1)
def test_pl_mnist_gan() -> None:
    """Run the ``gan_mnist_pl`` example's const config for 200 batches on the TF2 image.

    The experiment is launched from a temporary copy of the config, and the
    run is expected to produce 1 trial.
    """
    example_dir = "gan_mnist_pl"
    experiment_config = conf.set_tf2_image(
        conf.set_max_length(
            conf.load_config(conf.gan_examples_path(example_dir + "/const.yaml")),
            {"batches": 200},
        )
    )
    exp.run_basic_test_with_temp_config(
        experiment_config,
        conf.gan_examples_path(example_dir),
        1,
    )
def test_pix2pix_facades_const() -> None:
    """Run the ``pix2pix_tf_keras`` example's const config, capped at 200 batches.

    The experiment is launched from a temporary copy of the config, and the
    run is expected to produce 1 trial.
    """
    loaded = conf.load_config(
        conf.gan_examples_path("pix2pix_tf_keras/const.yaml")
    )
    capped = conf.set_max_length(loaded, {"batches": 200})
    exp.run_basic_test_with_temp_config(
        capped,
        conf.gan_examples_path("pix2pix_tf_keras"),
        1,
    )
def run_tf_keras_dcgan_example() -> None:
    """Run the ``dcgan_tf_keras`` const config on 8 slots with the TF2 image.

    Training is capped at 200 batches, with a 200-batch minimum validation
    period. The run is expected to produce 1 trial.

    NOTE(review): another ``def run_tf_keras_dcgan_example`` (taking a
    ``collect_trial_profiles`` callback) appears later in this module; that
    later definition replaces this one at import time, so this version is
    unreachable by name. Confirm which one is intended.
    """
    dcgan_cfg = conf.load_config(
        conf.gan_examples_path("dcgan_tf_keras/const.yaml")
    )
    dcgan_cfg = conf.set_max_length(dcgan_cfg, {"batches": 200})
    dcgan_cfg = conf.set_min_validation_period(dcgan_cfg, {"batches": 200})
    dcgan_cfg = conf.set_slots_per_trial(dcgan_cfg, 8)
    dcgan_cfg = conf.set_tf2_image(dcgan_cfg)

    exp.run_basic_test_with_temp_config(
        dcgan_cfg,
        conf.gan_examples_path("dcgan_tf_keras"),
        1,
    )
def test_pytorch_gan_parallel() -> None:
    """Run the ``gan_mnist_pytorch`` const config on 8 slots, then reload its checkpoint.

    After a 200-batch run, the latest checkpoint of the first trial is loaded
    with ``map_location="cpu"``. The loaded object is discarded, so the step
    presumably exists only to surface a load failure as an exception.
    """
    parallel_cfg = conf.load_config(
        conf.gan_examples_path("gan_mnist_pytorch/const.yaml")
    )
    parallel_cfg = conf.set_max_length(parallel_cfg, {"batches": 200})
    parallel_cfg = conf.set_slots_per_trial(parallel_cfg, 8)

    experiment_id = exp.run_basic_test_with_temp_config(
        parallel_cfg, conf.gan_examples_path("gan_mnist_pytorch"), 1
    )
    trials = exp.experiment_trials(experiment_id)

    # NOTE(review): trials are indexed as dicts here (``["id"]``), while other
    # code in this module reads ``[0].trial.id`` from the same helper. Confirm
    # which shape ``exp.experiment_trials`` currently returns.
    client = Determined(conf.make_master_url())
    first_trial = client.get_trial(trials[0]["id"])
    latest_checkpoint = first_trial.select_checkpoint(latest=True)
    latest_checkpoint.load(map_location="cpu")
def run_tf_keras_dcgan_example(
    collect_trial_profiles: Callable[[int], None],
) -> None:
    """Run the ``dcgan_tf_keras`` const config with profiling enabled.

    The run uses 8 slots and the TF2 image, with training and minimum
    validation period both set to 200 batches. Afterwards the id of the
    experiment's first trial is passed to ``collect_trial_profiles``.

    Args:
        collect_trial_profiles: Callback invoked once with the first trial's id.
    """
    profiled_cfg = conf.load_config(
        conf.gan_examples_path("dcgan_tf_keras/const.yaml")
    )
    profiled_cfg = conf.set_max_length(profiled_cfg, {"batches": 200})
    profiled_cfg = conf.set_min_validation_period(profiled_cfg, {"batches": 200})
    profiled_cfg = conf.set_slots_per_trial(profiled_cfg, 8)
    profiled_cfg = conf.set_tf2_image(profiled_cfg)
    profiled_cfg = conf.set_profiling_enabled(profiled_cfg)

    experiment_id = exp.run_basic_test_with_temp_config(
        profiled_cfg, conf.gan_examples_path("dcgan_tf_keras"), 1
    )
    first_trial_id = exp.experiment_trials(experiment_id)[0].trial.id
    collect_trial_profiles(first_trial_id)