def test_define_model(self, mock_eval):
    """Verify `_define_model` returns a CycleGANModel with correctly shaped
    reconstructions.

    A batch of two all-zero 4x4 RGB images is fed in for each domain. The
    test asserts the result type and that `reconstructed_x` /
    `reconstructed_y` have the same shapes as the inputs.

    Args:
      mock_eval: Not referenced in the body. Presumably supplied by a
        `mock.patch` decorator outside this view (TODO confirm).
    """
    self.hparams = self.hparams._replace(batch_size=2)
    shape = [self.hparams.batch_size, 4, 4, 3]

    # NumPy copies are kept so their shapes can serve as the expected values.
    x_np, y_np = (np.zeros(shape=shape) for _ in range(2))
    x_tensor, y_tensor = (tf.constant(arr, dtype=tf.float32)
                          for arr in (x_np, y_np))

    model = train_lib._define_model(x_tensor, y_tensor)

    self.assertIsInstance(model, tfgan.CycleGANModel)
    for expected, produced in ((x_np, model.reconstructed_x),
                               (y_np, model.reconstructed_y)):
        self.assertShapeEqual(expected, produced)
def test_define_train_ops(self):
    """Verify `_define_train_ops` yields a GANTrainOps for a CycleGAN setup.

    The test overrides the batch size and the generator/discriminator
    learning rates in `self.hparams`, then builds a model on zero-valued
    4x4x3 image batches. It computes the CycleGAN loss with a
    cycle-consistency weight of 10.0 and asserts the returned train ops
    have the expected type.
    """
    self.hparams = self.hparams._replace(
        batch_size=2, generator_lr=0.1, discriminator_lr=0.01)
    shape = [self.hparams.batch_size, 4, 4, 3]

    x_batch, y_batch = (tf.zeros(shape, dtype=tf.float32) for _ in range(2))

    model = train_lib._define_model(x_batch, y_batch)
    loss = tfgan.cyclegan_loss(model, cycle_consistency_loss_weight=10.0)
    ops = train_lib._define_train_ops(model, loss, self.hparams)

    self.assertIsInstance(ops, tfgan.GANTrainOps)