# Adversarial-loss registry: maps the value of hparams["adversarial_loss"] to
# (loss function, hparams overrides).  The overrides are written into hparams
# in the listed order *before* the model and the experiment name are built, so
# they are forwarded to DiscoGAN through **hparams and are also seen by
# experiments.experiment_name.
_ADVERSARIAL_LOSSES = {
    "nsgan": (non_saturating_gan_losses, {"n_discriminator_iters": 1}),
    "lsgan": (lsgan_losses, {"n_discriminator_iters": 1}),
    "wgan-gp": (
        wgan_gp_losses,
        # NOTE(review): these unconditionally replace any user-supplied
        # n_discriminator_iters / wgan_gp_lambda values.
        {"n_discriminator_iters": 5, "wgan_gp_lambda": 10.0},
    ),
}

_loss_choice = _ADVERSARIAL_LOSSES.get(hparams["adversarial_loss"])
if _loss_choice is None:
    raise ValueError("Invalid adversarial loss fn.")
adversarial_loss_fn, _loss_overrides = _loss_choice
hparams.update(_loss_overrides)

# Inputs for both domains (A from args.x_*, B from args.y_*): a training
# stream at the configured batch size, a test stream built with
# input_fn(..., 3), and a fixed sample built with first_n(..., 10).
# NOTE(review): presumably 3 is the test batch size and 10 the number of
# static examples; confirm against input_fn / first_n.
# Calls are kept in the original order in case input_fn has side effects
# such as graph-op naming.
_batch_size = hparams["batch_size"]
_a_train = input_fn(args.x_train, _batch_size)
_a_test = input_fn(args.x_test, 3)
_a_test_static = first_n(args.x_test, 10)
_b_train = input_fn(args.y_train, _batch_size)
_b_test = input_fn(args.y_test, 3)
_b_test_static = first_n(args.y_test, 10)

discogan = DiscoGAN(
    a_train=_a_train,
    a_test=_a_test,
    a_test_static=_a_test_static,
    b_train=_b_train,
    b_test=_b_test,
    b_test_static=_b_test_static,
    generator_fn=model_architecture.generator,
    discriminator_fn=model_architecture.discriminator,
    adversarial_loss_fn=adversarial_loss_fn,
    **hparams
)

# Use --model_dir when it is truthy; otherwise derive a run directory under
# ROOT_RUNS_DIR from the (already updated) hparams.
_run_dir = args.model_dir
if not _run_dir:
    _run_dir = experiments.ROOT_RUNS_DIR / experiments.experiment_name(
        "discogan", hparams
    )
experiments.run_experiment(
    model_dir=_run_dir,
    model=discogan,
    n_training_step=args.steps,
)
# NOTE(review): this chunk begins mid-chain; the opening `if` and any earlier
# branches of this adversarial-loss selection are outside this view.
elif hparams["adversarial_loss"] == "wgan-gp":
    # WGAN-GP: select the gradient-penalty losses and overwrite the
    # discriminator iteration count and penalty weight in hparams
    # (any values already present are replaced).
    adversarial_loss_fn = wgan_gp_losses
    hparams["n_discriminator_iters"] = 5
    hparams["wgan_gp_lambda"] = 10.0
else:
    # Any unrecognised hparams["adversarial_loss"] value is rejected.
    raise ValueError("Invalid adversarial loss fn.")
# Build the AttGAN model. The image/attribute inputs (img_train,
# attributes_train, img_test, attributes_test, img_test_static,
# attributes_test_static) are defined outside this chunk. hparams, including
# any values written by the branches above, is passed in as extra keyword
# arguments.
attgan = AttGAN(
    attribute_names=args.considered_attributes,
    img=img_train,
    attributes=attributes_train,
    img_test=img_test,
    attributes_test=attributes_test,
    img_test_static=img_test_static,
    attributes_test_static=attributes_test_static,
    encoder_fn=model_architecture.encoder,
    decoder_fn=model_architecture.decoder,
    # Going by the argument names, the classifier and discriminator share one
    # network part and each has its own private part; confirm in
    # model_architecture.
    classifier_discriminator_shared_fn=model_architecture.classifier_discriminator_shared,
    classifier_private_fn=model_architecture.classifier_private,
    discriminator_private_fn=model_architecture.discriminator_private,
    adversarial_loss_fn=adversarial_loss_fn,
    **hparams)
# Hand the model to the experiments harness. When args.model_dir is falsy,
# the run directory is ROOT_RUNS_DIR / experiment_name("attgan", hparams),
# computed after the hparams updates above. `init_op` is defined outside this
# chunk and is passed through unchanged as custom_init_op.
experiments.run_experiment(
    model_dir=args.model_dir or experiments.ROOT_RUNS_DIR / experiments.experiment_name("attgan", hparams),
    model=attgan,
    n_training_step=args.steps,
    custom_init_op=init_op)