示例#1
0
    def test_gradient_penalty_loss_with_wrong_input_types_raises(self):
        """Passing a tuple for one input and a bare tensor for the other fails.

        `real_data` is given as a one-element tuple of tensors while
        `generated_data` is a single tensor. The loss is expected to reject
        this mix with a TypeError.
        """
        empty_discriminator = tf.keras.Sequential()
        expected_message = ('should either both be a tf.Tensor '
                            'or both a sequence of tf.Tensor')

        with self.assertRaisesRegex(TypeError, expected_message):
            real_as_sequence = (tf.ones((1, )), )
            generated_as_tensor = tf.ones((1, ))
            losses.gradient_penalty_loss(real_data=real_as_sequence,
                                         generated_data=generated_as_tensor,
                                         discriminator=empty_discriminator)
示例#2
0
    def test_gradient_penalty_loss_with_unequal_number_of_elements_raises(
            self):
        """Sequences of different lengths for real/generated data fail.

        `real_data` holds one tensor and `generated_data` holds two. The loss
        is expected to reject the length mismatch with a ValueError.
        """
        empty_discriminator = tf.keras.Sequential()
        expected_message = (
            'number of elements in real_data and generated_data are '
            'expected to be equal')

        with self.assertRaisesRegex(ValueError, expected_message):
            one_real = (tf.ones((1, )), )
            two_generated = (tf.ones((1, )), tf.ones((1, )))
            losses.gradient_penalty_loss(real_data=one_real,
                                         generated_data=two_generated,
                                         discriminator=empty_discriminator)
示例#3
0
    def test_gradient_penalty_loss_positive(self):
        """The gradient penalty of a small dense discriminator is >= 0.

        Despite the method name, this checks non-negativity
        (`assertAllGreaterEqual` against 0.0), not strict positivity.
        """
        # Flatten each 5x5 sample to 25 features, then score it with a single
        # dense unit.
        discriminator = tf.keras.Sequential([
            tf.keras.layers.Reshape((25, )),
            tf.keras.layers.Dense(units=1),
        ])
        sample_shape = (1, 5, 5)
        real = tf.ones(shape=sample_shape)
        generated = tf.ones(shape=sample_shape)

        penalty = losses.gradient_penalty_loss(real_data=real,
                                               generated_data=generated,
                                               discriminator=discriminator)

        self.assertAllGreaterEqual(penalty, 0.0)
示例#4
0
    def test_gradient_penalty_shape_correct(self):
        """The loss returns one value per batch element.

        With a batch of 3 inputs, the returned penalty must have shape (3,).
        """
        discriminator = tf.keras.Sequential([
            tf.keras.layers.Reshape((25, )),
            tf.keras.layers.Dense(units=1),
        ])
        batch_size = 3
        real = tf.ones(shape=(batch_size, 5, 5))
        generated = tf.ones(shape=(batch_size, 5, 5))

        penalty = losses.gradient_penalty_loss(real_data=real,
                                               generated_data=generated,
                                               discriminator=discriminator)

        self.assertAllEqual(tf.shape(penalty), (batch_size, ))
示例#5
0
        def gradient_penalty_fn(weights):
            """Gradient penalty of a linear discriminator with given weights.

            Builds a discriminator that flattens each input to 25 features and
            multiplies by `weights`. It then evaluates the gradient penalty on
            `real_data` / `generated_data`, which are read from the enclosing
            scope (not visible here).

            Args:
              weights: matrix right-multiplied onto the flattened input;
                presumably of shape (25, k) — confirm against the caller.

            Returns:
              Whatever `losses.gradient_penalty_loss` returns for these inputs.
            """
            discriminator = tf.keras.Sequential([
                tf.keras.layers.Reshape((25, )),
                # A Lambda layer stands in for a Dense layer so the kernel can
                # be supplied directly (e.g. as a numpy array from the Jacobian
                # checker) rather than being an internal trainable variable.
                tf.keras.layers.Lambda(
                    lambda flat_input: tf.linalg.matmul(flat_input, weights)),
            ])

            real_tensor = tf.convert_to_tensor(real_data)
            generated_tensor = tf.convert_to_tensor(generated_data)
            return losses.gradient_penalty_loss(
                real_data=real_tensor,
                generated_data=generated_tensor,
                discriminator=discriminator)
示例#6
0
    def test_gradient_penalty_loss_lambda_for_zero_gradient(self):
        """With a zero-gradient discriminator the penalty is ~`weight`.

        The discriminator is a dense layer whose kernel and bias are
        initialized to zeros, so its output does not depend on the input. The
        test expects the returned penalty to be close to `weight`. Presumably
        the loss has the form weight * (||grad|| - 1)^2, which gives 1 *
        weight at zero gradient — confirm against the `losses` implementation.
        """
        discriminator = tf.keras.Sequential()
        discriminator.add(tf.keras.layers.Reshape((4, )))
        # Generates a dense layer that is initialized with all zeros.
        # This leads to a network that has zero gradient for any input.
        discriminator.add(
            tf.keras.layers.Dense(units=1,
                                  kernel_initializer='zeros',
                                  bias_initializer='zeros'))
        real_data = tf.ones(shape=(1, 2, 2))
        generated_data = tf.ones(shape=(1, 2, 2))
        weight = 1.0

        gradient_penalty = losses.gradient_penalty_loss(
            real_data=real_data,
            generated_data=generated_data,
            discriminator=discriminator,
            weight=weight)

        # Tolerance is large due to the eps that is added in the gradient
        # penalty loss for numerical stability at 0.
        self.assertAllClose(gradient_penalty, (weight, ), atol=0.001)