示例#1
0
文件: train.py 项目: zhouyonglong/gan
def main(_):
    """Build the training hyperparameters from `FLAGS` and start training.

    The single positional argument is ignored.
    """
    # Values are passed positionally, so this order must match the field
    # order that `train_lib.HParams` expects.
    flag_values = (
        FLAGS.image_set_x_file_pattern,
        FLAGS.image_set_y_file_pattern,
        FLAGS.batch_size,
        FLAGS.patch_size,
        FLAGS.master,
        FLAGS.train_log_dir,
        FLAGS.generator_lr,
        FLAGS.discriminator_lr,
        FLAGS.max_number_of_steps,
        FLAGS.ps_replicas,
        FLAGS.task,
        FLAGS.cycle_consistency_loss_weight,
    )
    train_lib.train(train_lib.HParams(*flag_values))
示例#2
0
    def testTrainingAndInferenceGraphsAreCompatible(self,
                                                    mock_provide_custom_data,
                                                    unused_mock_gan_train):
        """Verify a training checkpoint can be consumed by the inference demo.

        Training and inference graphs can get out of sync if changes are made
        to one but not the other. This test keeps them in sync: it builds the
        training graph, saves a checkpoint from it, then runs
        `inference_demo.main` against that checkpoint and inspects the
        generated images.

        Args:
          mock_provide_custom_data: Mock whose return value is set here to a
            pair of zero tensors of shape [3, 4, 4, 3].
          unused_mock_gan_train: Mock argument that this test never touches.
        """
        if tf.executing_eagerly():
            # `tfgan.cyclegan_model` doesn't work when executing eagerly.
            return

        # Save the training graph.
        train_sess = tf.Session()
        hparams = train_lib.HParams(image_set_x_file_pattern='/tmp/x/*.jpg',
                                    image_set_y_file_pattern='/tmp/y/*.jpg',
                                    batch_size=3,
                                    patch_size=128,
                                    master='master',
                                    train_log_dir=self._export_dir,
                                    generator_lr=0.02,
                                    discriminator_lr=0.3,
                                    max_number_of_steps=1,
                                    ps_replicas=0,
                                    task=0,
                                    cycle_consistency_loss_weight=2.0)
        mock_provide_custom_data.return_value = (tf.zeros([3, 4, 4, 3]),
                                                 tf.zeros([3, 4, 4, 3]))
        train_lib.train(hparams)
        init_op = tf.global_variables_initializer()
        train_sess.run(init_op)
        train_saver = tf.train.Saver()
        train_saver.save(train_sess, save_path=self._ckpt_path)

        # Create the inference graph, pointing the inference flags at the
        # checkpoint that was just written.
        tf.reset_default_graph()
        FLAGS.patch_dim = hparams.patch_size
        logging.info('dir_path: %s', os.listdir(self._export_dir))
        FLAGS.checkpoint_path = self._ckpt_path
        FLAGS.image_set_x_glob = self._image_glob
        FLAGS.image_set_y_glob = self._image_glob
        FLAGS.generated_x_dir = self._genx_dir
        FLAGS.generated_y_dir = self._geny_dir

        inference_demo.main(None)
        logging.info('gen x: %s', os.listdir(self._genx_dir))

        # Check that the image names match: inputs from set X are expected in
        # the generated-Y directory and vice versa.
        self.assertSetEqual(set(_basenames_from_glob(FLAGS.image_set_x_glob)),
                            set(os.listdir(FLAGS.generated_y_dir)))
        self.assertSetEqual(set(_basenames_from_glob(FLAGS.image_set_y_glob)),
                            set(os.listdir(FLAGS.generated_x_dir)))

        # Check that each image in BOTH output directories looks as expected.
        # (Previously the X directory was listed twice, so the generated-Y
        # images were never inspected.)
        for directory in [FLAGS.generated_x_dir, FLAGS.generated_y_dir]:
            for base_name in os.listdir(directory):
                image_path = os.path.join(directory, base_name)
                self.assertRealisticImage(image_path)
示例#3
0
文件: train_test.py 项目: srkm009/gan
 def setUp(self):
     """Swap in the test networks and prepare default hyperparameters."""
     super(TrainTest, self).setUp()

     # Keep references to the real network builders before overwriting them
     # with `_test_generator` / `_test_discriminator` (presumably so they
     # can be restored later; the restoring code is not in this method).
     networks = train_lib.networks
     self._original_generator = networks.generator
     self._original_discriminator = networks.discriminator
     networks.generator = _test_generator
     networks.discriminator = _test_discriminator

     hparam_values = dict(
         image_set_x_file_pattern=None,
         image_set_y_file_pattern=None,
         batch_size=1,
         patch_size=64,
         master='',
         train_log_dir='/tmp/tfgan_logdir/cyclegan/',
         generator_lr=0.0002,
         discriminator_lr=0.0001,
         max_number_of_steps=500000,
         ps_replicas=0,
         task=0,
         cycle_consistency_loss_weight=10.0,
     )
     self.hparams = train_lib.HParams(**hparam_values)
示例#4
0
    def setUp(self):
        """Swap in the test networks and prepare default hyperparameters."""
        super(TrainTest, self).setUp()

        # Force the TF lazy loading to kick in before mocking these out below:
        # merely touching each attribute is enough, the value is discarded.
        for lazy_attr in ('get_or_create_global_step', 'AdamOptimizer'):
            getattr(tf.train, lazy_attr)

        # Keep references to the real network builders before overwriting
        # them with `_test_generator` / `_test_discriminator` (presumably so
        # they can be restored later; the restoring code is not in this
        # method).
        networks = train_lib.networks
        self._original_generator = networks.generator
        self._original_discriminator = networks.discriminator
        networks.generator = _test_generator
        networks.discriminator = _test_discriminator

        hparam_values = dict(
            image_set_x_file_pattern=None,
            image_set_y_file_pattern=None,
            batch_size=1,
            patch_size=64,
            master='',
            train_log_dir='/tmp/tfgan_logdir/cyclegan/',
            generator_lr=0.0002,
            discriminator_lr=0.0001,
            max_number_of_steps=500000,
            ps_replicas=0,
            task=0,
            cycle_consistency_loss_weight=10.0,
        )
        self.hparams = train_lib.HParams(**hparam_values)