def test_main(self, mock_provide_celeba_test_set, mock_provide_data):
  """Runs `train_lib.train` on tiny all-zero mock inputs.

  The two mock arguments replace the data-providing functions; their return
  values are set below so that training sees `num_domains` domains of
  all-zero images with one-hot labels. `max_number_of_steps=0` is passed in
  the hyperparameters.

  Args:
    mock_provide_celeba_test_set: Mock whose return value is used as the
      CelebA test set.
    mock_provide_data: Mock whose return value is the (images, labels) pair
      of per-domain lists.
  """
  num_domains = 3
  hparams = train_lib.HParams(
      batch_size=1,
      patch_size=8,
      output_dir='/tmp/tfgan_logdir/stargan/',
      generator_lr=1e-4,
      discriminator_lr=1e-4,
      max_number_of_steps=0,
      steps_per_eval=1,
      adam_beta1=0.5,
      adam_beta2=0.999,
      gen_disc_step_ratio=0.2,
      master='',
      ps_tasks=0,
      task=0)

  batch = hparams.batch_size
  side = hparams.patch_size

  # A single all-zero float32 batch of shape [batch, side, side, 3]; the list
  # below repeats the same array object once per domain.
  zero_images = np.zeros([batch, side, side, 3], dtype=np.float32)
  images_per_domain = [zero_images] * num_domains

  # One-hot labels of shape [batch, num_domains], taken from the identity
  # matrix. Slicing only yields `batch` rows when batch <= num_domains.
  one_hot_labels = np.eye(num_domains)[:batch, :]
  labels_per_domain = [one_hot_labels] * num_domains

  mock_provide_data.return_value = (images_per_domain, labels_per_domain)
  mock_provide_celeba_test_set.return_value = np.zeros([3, side, side, 3])

  train_lib.train(hparams, _test_generator, _test_discriminator)
def main(_):
  """Builds `train_lib.HParams` from command-line flags and runs training.

  Args:
    _: Unused positional argument (presumably the argv list supplied by the
      app runner -- confirm against the entry point).
  """
  # Bind every field by name rather than by position: with 13 fields, several
  # of the same type (two learning rates, two Adam betas, three ints), a
  # reordering of HParams would otherwise silently swap values.
  hparams = train_lib.HParams(
      batch_size=FLAGS.batch_size,
      patch_size=FLAGS.patch_size,
      output_dir=FLAGS.output_dir,
      generator_lr=FLAGS.generator_lr,
      discriminator_lr=FLAGS.discriminator_lr,
      max_number_of_steps=FLAGS.max_number_of_steps,
      steps_per_eval=FLAGS.steps_per_eval,
      adam_beta1=FLAGS.adam_beta1,
      adam_beta2=FLAGS.adam_beta2,
      gen_disc_step_ratio=FLAGS.gen_disc_step_ratio,
      master=FLAGS.master,
      ps_tasks=FLAGS.ps_tasks,
      task=FLAGS.task)
  train_lib.train(hparams)
def main(_):
  """Builds `train_lib.HParams` from command-line flags and launches training.

  Args:
    _: Unused positional argument (presumably the argv list supplied by the
      app runner -- confirm against the entry point).
  """
  # HParams is constructed positionally, so this argument order must stay in
  # lockstep with the field order declared by `train_lib.HParams`.
  hparams = train_lib.HParams(
      FLAGS.batch_size,
      FLAGS.patch_size,
      FLAGS.output_dir,
      FLAGS.generator_lr,
      FLAGS.discriminator_lr,
      FLAGS.max_number_of_steps,
      FLAGS.steps_per_eval,
      FLAGS.adam_beta1,
      FLAGS.adam_beta2,
      FLAGS.gen_disc_step_ratio,
      FLAGS.master,
      FLAGS.ps_tasks,
      FLAGS.task,
      FLAGS.tfdata_source,
      FLAGS.tfdata_source_domains,
      FLAGS.download,
      FLAGS.data_dir,
      FLAGS.cls_model,
      FLAGS.cls_checkpoint,
      FLAGS.save_checkpoints_steps,
      FLAGS.keep_checkpoint_max,
      FLAGS.reconstruction_loss_weight,
      FLAGS.self_consistency_loss_weight,
      FLAGS.classification_loss_weight,
      FLAGS.use_color_labels)

  # No generator override by default. For experiments, `network.generator_hack`
  # or `network.generator_smooth` may be passed here instead (NOTE(review):
  # confirm those helpers still exist before relying on them).
  train_lib.train(hparams, override_generator_fn=None)