Example #1
0
  def test_define_train_ops(self):
    """Checks that `_define_train_ops` returns a `tfgan.GANTrainOps`.

    A StarGAN model and its loss are built from a batch of two all-zero
    4x4x3 images with depth-2 one-hot labels, then handed to
    `train_lib._define_train_ops` with the learning rates, Adam betas and
    step budget taken from the (overridden) hparams.
    """
    hp = self.hparams._replace(
        batch_size=2, generator_lr=0.1, discriminator_lr=0.01)
    n = hp.batch_size

    fake_images = tf.zeros([n, 4, 4, 3], dtype=tf.float32)
    one_hot_labels = tf.one_hot([0] * n, 2)

    stargan_model = train_lib._define_model(fake_images, one_hot_labels)
    stargan_loss = tfgan.stargan_loss(stargan_model)
    # Positional order expected by `_define_train_ops` after (model, loss).
    optimizer_args = (hp.generator_lr, hp.discriminator_lr, hp.adam_beta1,
                      hp.adam_beta2, hp.max_number_of_steps)
    ops = train_lib._define_train_ops(stargan_model, stargan_loss,
                                      *optimizer_args)

    self.assertIsInstance(ops, tfgan.GANTrainOps)
Example #2
0
  def test_define_model(self):
    """Checks the structure of the `tfgan.StarGANModel` from `_define_model`.

    Builds the model on a batch of two all-zero 4x4x3 images with depth-2
    one-hot labels, then verifies:
      * the returned object is a `tfgan.StarGANModel`;
      * generated and reconstructed outputs have the input images' shape;
      * the generator/discriminator variable collections are lists;
      * both variable scopes are `tf.compat.v1.VariableScope` instances;
      * the generator/discriminator functions are callable.
    """
    hparams = self.hparams._replace(batch_size=2)
    images_shape = [hparams.batch_size, 4, 4, 3]
    images_np = np.zeros(shape=images_shape)
    images = tf.constant(images_np, dtype=tf.float32)
    labels = tf.one_hot([0] * hparams.batch_size, 2)

    model = train_lib._define_model(images, labels)
    self.assertIsInstance(model, tfgan.StarGANModel)
    self.assertShapeEqual(images_np, model.generated_data)
    self.assertShapeEqual(images_np, model.reconstructed_data)
    self.assertIsInstance(model.discriminator_variables, list)
    self.assertIsInstance(model.generator_variables, list)
    self.assertIsInstance(model.discriminator_scope, tf.compat.v1.VariableScope)
    # Must be assertIsInstance: assertTrue's second argument is only the
    # failure message, so assertTrue(scope, VariableScope) never checks type.
    self.assertIsInstance(model.generator_scope, tf.compat.v1.VariableScope)
    self.assertTrue(callable(model.discriminator_fn))
    self.assertTrue(callable(model.generator_fn))