Beispiel #1
0
    def test_main(self, mock_provide_celeba_test_set, mock_provide_data):
        """Run `train_lib.train` once against mocked, all-zero inputs.

        Both arguments are mock objects (presumably injected by patch
        decorators that are not visible here -- TODO confirm). Their
        `return_value`s are replaced with small NumPy arrays before training
        is invoked with `max_number_of_steps=0`.
        """
        domain_count = 3
        hparams = train_lib.HParams(
            batch_size=1,
            patch_size=8,
            output_dir='/tmp/tfgan_logdir/stargan/',
            generator_lr=1e-4,
            discriminator_lr=1e-4,
            max_number_of_steps=0,
            steps_per_eval=1,
            adam_beta1=0.5,
            adam_beta2=0.999,
            gen_disc_step_ratio=0.2,
            master='',
            ps_tasks=0,
            task=0,
        )
        batch, patch = hparams.batch_size, hparams.patch_size

        # Training images: a single zero-filled float32 batch, referenced once
        # per domain (every list entry is the same array object).
        blank_images = np.zeros([batch, patch, patch, 3], dtype=np.float32)
        mock_images = [blank_images for _ in range(domain_count)]

        # Labels: the first `batch` rows of an identity matrix, giving an
        # array of shape [batch, domain_count]; again shared across domains.
        # This only yields `batch` rows while batch <= domain_count.
        identity_rows = np.eye(domain_count)[:batch, :]
        mock_labels = [identity_rows for _ in range(domain_count)]

        mock_provide_data.return_value = (mock_images, mock_labels)

        # Stand-in for the CelebA test set: three zero-filled patches.
        mock_provide_celeba_test_set.return_value = np.zeros(
            [3, patch, patch, 3])

        train_lib.train(hparams, _test_generator, _test_discriminator)
Beispiel #2
0
def main(_):
    """Assemble `HParams` from the command-line flags and start training.

    The single positional argument is ignored.
    """
    # HParams is populated positionally, so the order of these flag names is
    # significant and must be kept in sync with the HParams field order.
    flag_names = (
        'batch_size',
        'patch_size',
        'output_dir',
        'generator_lr',
        'discriminator_lr',
        'max_number_of_steps',
        'steps_per_eval',
        'adam_beta1',
        'adam_beta2',
        'gen_disc_step_ratio',
        'master',
        'ps_tasks',
        'task',
    )
    flag_values = [getattr(FLAGS, flag_name) for flag_name in flag_names]
    train_lib.train(train_lib.HParams(*flag_values))
Beispiel #3
0
def main(_):
    """Assemble `HParams` from the command-line flags and start training.

    The single positional argument is ignored. Training is launched with
    `override_generator_fn=None`.
    """
    # HParams is populated positionally, so the order of these flag names is
    # significant and must be kept in sync with the HParams field order.
    flag_names = (
        'batch_size', 'patch_size', 'output_dir',
        'generator_lr', 'discriminator_lr', 'max_number_of_steps',
        'steps_per_eval', 'adam_beta1', 'adam_beta2',
        'gen_disc_step_ratio', 'master', 'ps_tasks', 'task',
        'tfdata_source', 'tfdata_source_domains', 'download',
        'data_dir', 'cls_model', 'cls_checkpoint',
        'save_checkpoints_steps', 'keep_checkpoint_max',
        'reconstruction_loss_weight', 'self_consistency_loss_weight',
        'classification_loss_weight', 'use_color_labels',
    )
    flag_values = [getattr(FLAGS, flag_name) for flag_name in flag_names]
    hparams = train_lib.HParams(*flag_values)

    # No generator override is applied. Earlier revisions toggled this by
    # hand between `network.generator_hack` and `network.generator_smooth`.
    generator_override = None

    train_lib.train(hparams, override_generator_fn=generator_override)