def test_works_with_get_collection(self):
  """Checks that the gradient penalty registers its update op in any scope.

  Each penalty computation is expected to add one op to the
  'fake_update_ops' collection filtered by the discriminator scope's name,
  whether it runs at top level, under a name scope, or under a variable
  scope.
  """

  def count_update_ops():
    # Re-read the scope name on every call, exactly as the assertions do.
    scope_name = self._kwargs['discriminator_scope'].name
    return len(ops.get_collection('fake_update_ops', scope_name))

  # The discriminator already ran once during setup, so one op is present.
  self.assertEqual(1, count_update_ops())

  # The op must still be collected when the penalty runs in a name scope.
  with ops.name_scope('loss'):
    tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
  self.assertEqual(2, count_update_ops())

  # ...and likewise when it runs inside a variable scope.
  with variable_scope.variable_scope('loss_vscope'):
    tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
  self.assertEqual(3, count_update_ops())
def test_works_with_get_collection(self):
  """Verifies that update ops are collected no matter the enclosing scope.

  Setup already ran the discriminator once, so one op is expected up front.
  Every further penalty computation, whether under a name scope or a
  variable scope, must bring the 'fake_update_ops' count up by one.
  """

  def collected_ops():
    scope_name = self._kwargs['discriminator_scope'].name
    return ops.get_collection('fake_update_ops', scope_name)

  self.assertEqual(1, len(collected_ops()))

  # Each factory builds the context manager that wraps one penalty call.
  scope_factories = (
      lambda: ops.name_scope('loss'),
      lambda: variable_scope.variable_scope('loss_vscope'),
  )
  for expected_count, make_scope in enumerate(scope_factories, start=2):
    with make_scope():
      tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
    self.assertEqual(expected_count, len(collected_ops()))
def grad_loss(self, real, fake):
  """Returns the Wasserstein gradient penalty scaled by `self.grad_lambda`.

  Args:
    real: real samples, passed as both `real_data` and `generator_inputs`.
    fake: generated samples, passed as `generated_data`.

  Returns:
    `self.grad_lambda` times the penalty computed by
    `gan_losses.wasserstein_gradient_penalty`.
  """

  def discriminator_adapter(gen_data, not_used):
    # The penalty calls discriminator_fn with two arguments; only the data
    # is forwarded to this model's discriminator, the second is ignored.
    return self.dis.build(gen_data)

  penalty = gan_losses.wasserstein_gradient_penalty(
      real_data=real,
      generated_data=fake,
      generator_inputs=real,
      discriminator_fn=discriminator_adapter,
      discriminator_scope=self.dis.name,
      one_sided=self.one_sided,
      add_summaries=self.tb_verbose)
  return self.grad_lambda * penalty
def test_loss_with_placeholder(self):
  """Checks the penalty value when both data inputs are fed via placeholders.

  The placeholders are rank-2 with unknown dimensions (shape=(None, None)),
  so the loss is built without static shape information and evaluated by
  feeding numpy arrays.
  """
  generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
  real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

  # Pass the data by keyword so each placeholder lands on the parameter of
  # the same name. The API lists real_data before generated_data, and the
  # previous positional call had them in the opposite order, silently
  # swapping the two roles.
  loss = tfgan_losses.wasserstein_gradient_penalty(
      real_data=real_data,
      generated_data=generated_data,
      generator_inputs=self._kwargs['generator_inputs'],
      discriminator_fn=self._kwargs['discriminator_fn'],
      discriminator_scope=self._kwargs['discriminator_scope'])
  self.assertEqual(generated_data.dtype, loss.dtype)

  with self.test_session() as sess:
    variables.global_variables_initializer().run()
    # Fetch into a separate name rather than rebinding the loss tensor.
    loss_value = sess.run(loss, feed_dict={
        generated_data: self._generated_data_np,
        real_data: self._real_data_np,
    })
    self.assertAlmostEqual(self._expected_loss, loss_value, 5)
def test_loss_with_placeholder(self):
  """Checks the penalty value when both data inputs are fed via placeholders.

  The placeholders are rank-2 with unknown dimensions (shape=(None, None)),
  so the loss is built without static shape information and evaluated by
  feeding numpy arrays.
  """
  generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
  real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

  # Pass the data by keyword so each placeholder lands on the parameter of
  # the same name. The API lists real_data before generated_data, and the
  # previous positional call had them in the opposite order, silently
  # swapping the two roles.
  loss = tfgan_losses.wasserstein_gradient_penalty(
      real_data=real_data,
      generated_data=generated_data,
      generator_inputs=self._kwargs['generator_inputs'],
      discriminator_fn=self._kwargs['discriminator_fn'],
      discriminator_scope=self._kwargs['discriminator_scope'])
  self.assertEqual(generated_data.dtype, loss.dtype)

  with self.test_session() as sess:
    variables.global_variables_initializer().run()
    # Fetch into a separate name rather than rebinding the loss tensor.
    loss_value = sess.run(loss, feed_dict={
        generated_data: self._generated_data_np,
        real_data: self._real_data_np,
    })
    self.assertAlmostEqual(self._expected_loss, loss_value, 5)
def test_loss_with_gradient_norm_target(self):
  """Test loss value with non default gradient norm target.

  With `target=2.0` instead of the default, the fixture data is expected to
  yield a penalty of 1.0.
  """
  generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
  real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

  # Keyword arguments keep each placeholder on its same-named parameter.
  # The API lists real_data before generated_data, and the previous
  # positional call swapped them.
  loss = tfgan_losses.wasserstein_gradient_penalty(
      real_data=real_data,
      generated_data=generated_data,
      generator_inputs=self._kwargs['generator_inputs'],
      discriminator_fn=self._kwargs['discriminator_fn'],
      discriminator_scope=self._kwargs['discriminator_scope'],
      target=2.0)

  with self.test_session() as sess:
    variables.global_variables_initializer().run()
    # Keep the tensor and its evaluated value under distinct names.
    loss_value = sess.run(loss, feed_dict={
        generated_data: self._generated_data_np,
        real_data: self._real_data_np,
    })
    self.assertAlmostEqual(1.0, loss_value, 5)
def test_loss_with_gradient_norm_target(self):
  """Test loss value with non default gradient norm target.

  With `target=2.0` instead of the default, the fixture data is expected to
  yield a penalty of 1.0.
  """
  generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
  real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

  # Keyword arguments keep each placeholder on its same-named parameter.
  # The API lists real_data before generated_data, and the previous
  # positional call swapped them.
  loss = tfgan_losses.wasserstein_gradient_penalty(
      real_data=real_data,
      generated_data=generated_data,
      generator_inputs=self._kwargs['generator_inputs'],
      discriminator_fn=self._kwargs['discriminator_fn'],
      discriminator_scope=self._kwargs['discriminator_scope'],
      target=2.0)

  with self.cached_session() as sess:
    variables.global_variables_initializer().run()
    # Keep the tensor and its evaluated value under distinct names.
    loss_value = sess.run(
        loss,
        feed_dict={
            generated_data: self._generated_data_np,
            real_data: self._real_data_np,
        })
    self.assertAlmostEqual(1.0, loss_value, 5)
def test_reuses_scope(self):
  """Test that gradient penalty reuses discriminator scope.

  Reuse is detected indirectly: computing the penalty must not grow the
  GLOBAL_VARIABLES collection.
  """

  def count_global_variables():
    return len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))

  variables_before = count_global_variables()
  tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
  self.assertEqual(variables_before, count_global_variables())
def test_reuses_scope(self):
  """Test that gradient penalty reuses discriminator scope.

  If the discriminator variables are reused, computing the penalty leaves
  the number of global variables unchanged.
  """
  collection_key = ops.GraphKeys.GLOBAL_VARIABLES
  baseline = len(ops.get_collection(collection_key))

  tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)

  after_penalty = len(ops.get_collection(collection_key))
  self.assertEqual(baseline, after_penalty)