Example 1
def train(hparams):
  """Trains a StarGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """

  # Create the log_dir if it does not exist.
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Shard the model across the parameter servers.
  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):

    # Create the input dataset.
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images, labels = data_provider.provide_data('train', hparams.batch_size,
                                                  hparams.patch_size)

    # Define the model.
    with tf.compat.v1.name_scope('model'):
      model = _define_model(images, labels)

    # Add image summary.
    tfgan.eval.add_stargan_image_summaries(
        model, num_images=3 * hparams.batch_size, display_diffs=True)

    # Define the model loss.
    loss = tfgan.stargan_loss(model)

    # Define the train ops.
    with tf.compat.v1.name_scope('train_ops'):
      train_ops = _define_train_ops(model, loss, hparams.generator_lr,
                                    hparams.discriminator_lr,
                                    hparams.adam_beta1, hparams.adam_beta2,
                                    hparams.max_number_of_steps)

    # Define the train steps.
    train_steps = _define_train_step(hparams.gen_disc_step_ratio)

    # Define a status message.
    status_message = tf.strings.join(
        [
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
        name='status_message')

    # Train the model.
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=hparams.tf_master,
        is_chief=hparams.task == 0)
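
The train() function above takes all of its configuration from a single hparams object. As a rough sketch of how it might be invoked, assuming a hypothetical HParams named tuple with the fields the function reads (the real definition and defaults ship with the TF-GAN StarGAN example and may differ):

import collections

# Hypothetical container; field names mirror the attributes train() reads above.
HParams = collections.namedtuple('HParams', [
    'batch_size', 'patch_size', 'train_log_dir', 'generator_lr',
    'discriminator_lr', 'max_number_of_steps', 'adam_beta1', 'adam_beta2',
    'gen_disc_step_ratio', 'tf_master', 'ps_replicas', 'task'
])

# Illustrative values only, not the example's published defaults.
hparams = HParams(
    batch_size=6, patch_size=128, train_log_dir='/tmp/stargan/',
    generator_lr=1e-4, discriminator_lr=1e-4, max_number_of_steps=1000000,
    adam_beta1=0.5, adam_beta2=0.999, gen_disc_step_ratio=0.2,
    tf_master='', ps_replicas=0, task=0)

train(hparams)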
Example 2
  def test_define_train_ops(self):
    hparams = self.hparams._replace(
        batch_size=2, generator_lr=0.1, discriminator_lr=0.01)

    # Build a tiny batch of dummy images and one-hot domain labels.
    images_shape = [hparams.batch_size, 4, 4, 3]
    images = tf.zeros(images_shape, dtype=tf.float32)
    labels = tf.one_hot([0] * hparams.batch_size, 2)

    # Define the model and its loss, then build the train ops under test.
    model = train_lib._define_model(images, labels)
    loss = tfgan.stargan_loss(model)
    train_ops = train_lib._define_train_ops(model, loss, hparams.generator_lr,
                                            hparams.discriminator_lr,
                                            hparams.adam_beta1,
                                            hparams.adam_beta2,
                                            hparams.max_number_of_steps)

    self.assertIsInstance(train_ops, tfgan.GANTrainOps)
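
tfgan.GANTrainOps is a named tuple; if the installed TF-GAN build exposes the usual fields, the test could also assert that each op was actually created. A minimal sketch, with the field names assumed rather than taken from this snippet:

    # Assumed field names; confirm against your TF-GAN version's GANTrainOps.
    self.assertIsNotNone(train_ops.generator_train_op)
    self.assertIsNotNone(train_ops.discriminator_train_op)
    self.assertIsNotNone(train_ops.global_step_inc_op)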
Example 3
  def test_stargan(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual graph-mode utilities work in eager mode.
      return
    model = create_gan_model_fn()
    model_loss = tfgan.stargan_loss(model)

    self.assertIsInstance(model_loss, tfgan.GANLoss)

    with self.cached_session() as sess:

      # Initialize model variables before evaluating the losses.
      sess.run(tf.compat.v1.global_variables_initializer())

      # Both losses should evaluate to scalar values.
      gen_loss, disc_loss = sess.run(
          [model_loss.generator_loss, model_loss.discriminator_loss])

      self.assertTrue(np.isscalar(gen_loss))
      self.assertTrue(np.isscalar(disc_loss))
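
The StarGAN loss combines adversarial, reconstruction, and domain-classification terms, and tfgan.stargan_loss exposes weights for them. A hedged sketch of overriding the relative weights, assuming keyword names such as gradient_penalty_weight, reconstruction_loss_weight, and classification_loss_weight (check the installed version's tfgan.stargan_loss signature before relying on them):

# Keyword names and values are assumptions for illustration only.
weighted_loss = tfgan.stargan_loss(
    model,
    gradient_penalty_weight=10.0,
    reconstruction_loss_weight=10.0,
    classification_loss_weight=1.0)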