예제 #1
0
파일: train_test.py 프로젝트: srkm009/gan
    def test_define_model(self, mock_eval):
        """Check that `train_lib._define_model` builds a CycleGAN model.

        Feeds two all-zero image batches of shape (batch_size, 4, 4, 3) and
        asserts that:
          * the result is a `tfgan.CycleGANModel`, and
          * both reconstructions have the same shape as their inputs.

        Args:
          mock_eval: not referenced in the body; presumably supplied by a
            `mock.patch` decorator outside this excerpt -- TODO confirm.
        """
        # Shrink the batch so the graph stays tiny; batch_size drives the
        # leading dimension of both inputs below.
        self.hparams = self.hparams._replace(batch_size=2)
        shape = [self.hparams.batch_size, 4, 4, 3]

        # Keep the numpy arrays around: they are the reference shapes for
        # the assertions at the end.
        x_np = np.zeros(shape=shape)
        y_np = np.zeros(shape=shape)
        x = tf.constant(x_np, dtype=tf.float32)
        y = tf.constant(y_np, dtype=tf.float32)

        model = train_lib._define_model(x, y)
        self.assertIsInstance(model, tfgan.CycleGANModel)

        # Each domain's reconstruction must match that domain's input shape.
        pairs = ((x_np, model.reconstructed_x), (y_np, model.reconstructed_y))
        for expected, reconstructed in pairs:
            self.assertShapeEqual(expected, reconstructed)
예제 #2
0
파일: train_test.py 프로젝트: srkm009/gan
    def test_define_train_ops(self):
        """Check that `train_lib._define_train_ops` returns `GANTrainOps`.

        Builds a CycleGAN model on two all-zero image batches of shape
        (batch_size, 4, 4, 3). It computes the TF-GAN CycleGAN loss with a
        cycle-consistency weight of 10.0. It then asserts that the training
        ops built from the model, the loss and the hparams have the
        expected type.
        """
        # Small batch plus explicit generator/discriminator learning rates.
        # These hparams are forwarded to _define_train_ops below.
        self.hparams = self.hparams._replace(
            batch_size=2, generator_lr=0.1, discriminator_lr=0.01)

        shape = [self.hparams.batch_size, 4, 4, 3]
        x = tf.zeros(shape, dtype=tf.float32)
        y = tf.zeros(shape, dtype=tf.float32)

        model = train_lib._define_model(x, y)
        loss = tfgan.cyclegan_loss(model, cycle_consistency_loss_weight=10.0)

        ops = train_lib._define_train_ops(model, loss, self.hparams)
        self.assertIsInstance(ops, tfgan.GANTrainOps)