Example #1
 def test_reduction_cyclegan(self):
   if tf.executing_eagerly():
     # None of the usual utilities work in eager.
     return
   loss = tfgan.cyclegan_loss(
       create_cyclegan_model(), reduction=tf.compat.v1.losses.Reduction.NONE)
   self.assertIsInstance(loss, tfgan.CycleGANLoss)
   self.assertEqual(2, loss.loss_x2y.discriminator_loss.shape.ndims)
   self.assertEqual(2, loss.loss_x2y.generator_loss.shape.ndims)
   self.assertEqual(2, loss.loss_y2x.discriminator_loss.shape.ndims)
   self.assertEqual(2, loss.loss_y2x.generator_loss.shape.ndims)
Example #2
 def test_cyclegan_output_type(self, get_gan_model_fn):
   if tf.executing_eagerly():
     # None of the usual utilities work in eager.
     return
   loss = tfgan.cyclegan_loss(get_gan_model_fn(), add_summaries=True)
   self.assertIsInstance(loss, tfgan.CycleGANLoss)
   self.assertEqual(0, loss.loss_x2y.discriminator_loss.shape.ndims)
   self.assertEqual(0, loss.loss_x2y.generator_loss.shape.ndims)
   self.assertEqual(0, loss.loss_y2x.discriminator_loss.shape.ndims)
   self.assertEqual(0, loss.loss_y2x.generator_loss.shape.ndims)
   if not tf.executing_eagerly():  # Collections don't work in eager.
     self.assertNotEmpty(
         tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))
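Read together, Examples #1 and #2 show what the reduction argument controls: with tf.compat.v1.losses.Reduction.NONE each of the four component losses keeps rank 2 instead of being reduced, whereas the default reduction collapses every generator and discriminator loss to a scalar. Passing add_summaries=True additionally registers the losses in the TF1 SUMMARIES collection, which is why the second test checks that the collection is non-empty.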
Example #3
File: train_lib.py Project: yyht/gan
def train(hparams):
    """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(
            tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
        with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
            images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                           hparams.image_set_y_file_pattern,
                                           hparams.batch_size,
                                           hparams.patch_size)

        # Define CycleGAN model.
        cyclegan_model = _define_model(images_x, images_y)

        # Define CycleGAN loss.
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=(
                hparams.cycle_consistency_loss_weight),
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.strings.join([
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
                                         name='status_message')
        if not hparams.max_number_of_steps:
            return
        tfgan.gan_train(
            train_ops,
            hparams.train_log_dir,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                tf.estimator.StopAtStepHook(
                    num_steps=hparams.max_number_of_steps),
                tf.estimator.LoggingTensorHook(
                    {'status_message': status_message}, every_n_iter=10)
            ],
            master=hparams.master,
            is_chief=hparams.task == 0)
Example #4
File: train_test.py Project: srkm009/gan
    def test_define_train_ops(self):
        self.hparams = self.hparams._replace(batch_size=2,
                                             generator_lr=0.1,
                                             discriminator_lr=0.01)

        images_shape = [self.hparams.batch_size, 4, 4, 3]
        images_x = tf.zeros(images_shape, dtype=tf.float32)
        images_y = tf.zeros(images_shape, dtype=tf.float32)

        cyclegan_model = train_lib._define_model(images_x, images_y)
        cyclegan_loss = tfgan.cyclegan_loss(cyclegan_model,
                                            cycle_consistency_loss_weight=10.0)

        train_ops = train_lib._define_train_ops(cyclegan_model, cyclegan_loss,
                                                self.hparams)
        self.assertIsInstance(train_ops, tfgan.GANTrainOps)
Example #5
def main(_):
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        with tf.name_scope('inputs'):
            initializer_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps)
            training_input_iter = initializer_hook.input_itr
            images_x, images_y = training_input_iter.get_next()
            # Set batch size for summaries.
            # images_x.set_shape([FLAGS.batch_size, None, None, None])
            # images_y.set_shape([FLAGS.batch_size, None, None, None])

        # Define CycleGAN model.
        cyclegan_model = _define_model(images_x, images_y)

        # Define CycleGAN loss.
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.string_join(
            [
                'Starting train step: ',
                tf.as_string(tf.train.get_or_create_global_step())
            ],
            name='status_message')
        if not FLAGS.max_number_of_steps:
            return
        tfgan.gan_train(
            train_ops,
            FLAGS.train_log_dir,
            save_checkpoint_secs=60*10,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                initializer_hook,
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=10)
            ],
            master=FLAGS.master,
            is_chief=FLAGS.task == 0)
Example #6
 def define_loss(self, model):
     if self._use_identity_loss:
         cyclegan_loss = cyclegan_loss_with_identity(
             model,
             # generator_loss_fn=wasserstein_generator_loss,
             # discriminator_loss_fn=wasserstein_discriminator_loss,
              cycle_consistency_loss_weight=(
                  self._cycle_consistency_loss_weight),
             identity_loss_weight=self._identity_loss_weight,
             tensor_pool_fn=tfgan.features.tensor_pool)
     else:
         # Define CycleGAN loss.
         cyclegan_loss = tfgan.cyclegan_loss(
             model,
             # generator_loss_fn=wasserstein_generator_loss,
             # discriminator_loss_fn=wasserstein_discriminator_loss,
              cycle_consistency_loss_weight=(
                  self._cycle_consistency_loss_weight),
             tensor_pool_fn=tfgan.features.tensor_pool)
     return cyclegan_loss
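Note that cyclegan_loss_with_identity in Example #6 is not part of the TF-GAN API; it is presumably a project-local helper that adds an identity-mapping term (each generator applied to its own target domain should act as a no-op) on top of tfgan.cyclegan_loss. Below is a minimal sketch of what such a wrapper could look like, assuming a graph-mode tfgan.CycleGANModel whose generators are built with variable scopes; the helper name and the weighting scheme are illustrative, not the project's actual implementation.

import tensorflow as tf
import tensorflow_gan as tfgan


def cyclegan_loss_with_identity(model,
                                cycle_consistency_loss_weight=10.0,
                                identity_loss_weight=0.5,
                                **kwargs):
  """Sketch: tfgan.cyclegan_loss plus an L1 identity-mapping penalty."""
  base_loss = tfgan.cyclegan_loss(
      model,
      cycle_consistency_loss_weight=cycle_consistency_loss_weight,
      **kwargs)

  # Re-run each generator on its *target* domain inside its existing variable
  # scope; an identity mapping should leave those images unchanged.
  with tf.compat.v1.variable_scope(model.model_x2y.generator_scope, reuse=True):
    same_y = model.model_x2y.generator_fn(model.model_x2y.real_data)
  with tf.compat.v1.variable_scope(model.model_y2x.generator_scope, reuse=True):
    same_x = model.model_y2x.generator_fn(model.model_y2x.real_data)
  identity_loss = identity_loss_weight * (
      tf.reduce_mean(tf.abs(same_y - model.model_x2y.real_data)) +
      tf.reduce_mean(tf.abs(same_x - model.model_y2x.real_data)))

  # GANLoss and CycleGANLoss are namedtuples, so _replace builds updated
  # copies with the identity term folded into both generator losses.
  return tfgan.CycleGANLoss(
      loss_x2y=base_loss.loss_x2y._replace(
          generator_loss=base_loss.loss_x2y.generator_loss + identity_loss),
      loss_y2x=base_loss.loss_y2x._replace(
          generator_loss=base_loss.loss_y2x.generator_loss + identity_loss))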
Example #7
  def test_cyclegan(self, create_gan_model_fn):
    """Test that CycleGan models work."""
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.cyclegan_loss(model)
    self.assertIsInstance(loss, tfgan.CycleGANLoss)

    # Check values.
    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      (loss_x2y_gen_np, loss_x2y_dis_np, loss_y2x_gen_np,
       loss_y2x_dis_np) = sess.run([
           loss.loss_x2y.generator_loss, loss.loss_x2y.discriminator_loss,
           loss.loss_y2x.generator_loss, loss.loss_y2x.discriminator_loss
       ])

    self.assertGreater(loss_x2y_gen_np, loss_x2y_dis_np)
    self.assertGreater(loss_y2x_gen_np, loss_y2x_dis_np)
    self.assertTrue(np.isscalar(loss_x2y_gen_np))
    self.assertTrue(np.isscalar(loss_x2y_dis_np))
    self.assertTrue(np.isscalar(loss_y2x_gen_np))
    self.assertTrue(np.isscalar(loss_y2x_dis_np))
Example #8
def train(hparams):
  """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)
    
  with open(hparams.train_log_dir + 'train_result.json', 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size, hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
      # Keep every checkpoint written at the requested interval; the Saver
      # would otherwise retain only the 5 most recent.
      max_to_keep = (
          hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1)
      additional_params = {
          'scaffold':
              tf.compat.v1.train.Scaffold(
                  saver=tf.compat.v1.train.Saver(max_to_keep=max_to_keep)),
          'save_checkpoint_secs': None,
          'save_checkpoint_steps': hparams.save_checkpoint_steps,
      }

    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params,
    )
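A note on the design of Example #8: setting save_checkpoint_secs to None while passing save_checkpoint_steps switches checkpointing from the default time-based schedule to a step-based one (the underlying monitored training session rejects having both set), and sizing max_to_keep to max_number_of_steps // save_checkpoint_steps + 1 keeps every checkpoint written during the run rather than only the most recent ones, presumably so intermediate models remain available for later evaluation.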