Example #1
    def test_works_with_get_collection(self):
        """Tests that gradient penalty works inside other scopes."""
        # We ran the discriminator once in the setup, so there should be an op
        # already in the collection.
        self.assertEqual(
            1,
            len(
                ops.get_collection('fake_update_ops',
                                   self._kwargs['discriminator_scope'].name)))

        # Make sure the op is added to the collection even if it's in a name scope.
        with ops.name_scope('loss'):
            tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
        self.assertEqual(
            2,
            len(
                ops.get_collection('fake_update_ops',
                                   self._kwargs['discriminator_scope'].name)))

        # Make sure the op is added to the collection even if it's in a variable
        # scope.
        with variable_scope.variable_scope('loss_vscope'):
            tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
        self.assertEqual(
            3,
            len(
                ops.get_collection('fake_update_ops',
                                   self._kwargs['discriminator_scope'].name)))
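The test above hinges on the fake discriminator adding one op to the 'fake_update_ops' collection every time it runs, so each call to `wasserstein_gradient_penalty` (which re-runs the discriminator on interpolated data) grows the collection by one. A minimal sketch of such a discriminator, assuming TF 1.x graph collections; the function and names below are illustrative, not the actual test fixture:

    import tensorflow as tf

    def fake_discriminator_fn(data, unused_conditioning):
        # Every run leaves one op in the custom collection; the op's name
        # carries the enclosing scope, which is what the tests filter on
        # via get_collection's second argument.
        tf.add_to_collection('fake_update_ops', tf.identity(data))
        # One variable, so scope-reuse behavior is exercised as well.
        kernel = tf.get_variable('kernel', shape=[], dtype=tf.float32)
        return kernel * data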
Example #2
  def test_works_with_get_collection(self):
    """Tests that gradient penalty works inside other scopes."""
    # We ran the discriminator once in the setup, so there should be an op
    # already in the collection.
    self.assertEqual(1, len(ops.get_collection(
        'fake_update_ops', self._kwargs['discriminator_scope'].name)))

    # Make sure the op is added to the collection even if it's in a name scope.
    with ops.name_scope('loss'):
      tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
    self.assertEqual(2, len(ops.get_collection(
        'fake_update_ops', self._kwargs['discriminator_scope'].name)))

    # Make sure the op is added to the collection even if it's in a variable
    # scope.
    with variable_scope.variable_scope('loss_vscope'):
      tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
    self.assertEqual(3, len(ops.get_collection(
        'fake_update_ops', self._kwargs['discriminator_scope'].name)))
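Every example here reads its arguments from `self._kwargs`. A plausible reconstruction of that fixture, assuming the TF 1.x test harness; the data values and class name are hypothetical, and `fake_discriminator_fn` is the sketch from Example #1:

    class GradientPenaltyTest(tf.test.TestCase):  # Hypothetical name.

        def setUp(self):
            super(GradientPenaltyTest, self).setUp()
            self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]]
            self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]]
            generated_data = tf.constant(self._generated_data_np, tf.float32)
            real_data = tf.constant(self._real_data_np, tf.float32)
            # Build the discriminator variables once, so the penalty can
            # re-enter the scope with reuse=True later.
            with tf.variable_scope('discriminator') as dis_scope:
                fake_discriminator_fn(real_data, None)
            self._kwargs = {
                'generated_data': generated_data,
                'real_data': real_data,
                'generator_inputs': None,
                'discriminator_fn': fake_discriminator_fn,
                'discriminator_scope': dis_scope,
            }
            # self._expected_loss follows from these values; omitted here.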
Example #3
    def grad_loss(self, real, fake):
        def tfgan_dis(gen_data, not_used):
            return self.dis.build(gen_data)

        loss = gan_losses.wasserstein_gradient_penalty(
            real_data=real,
            generated_data=fake,
            generator_inputs=real,
            discriminator_fn=tfgan_dis,
            discriminator_scope=self.dis.name,
            one_sided=self.one_sided,
            add_summaries=self.tb_verbose)
        return self.grad_lambda * loss
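Example #3 exposes the `one_sided` switch. In the standard WGAN-GP formulation the penalty pulls each per-sample gradient norm toward the target from both sides, while the one-sided variant only penalizes norms above the target. A NumPy sketch of the distinction (my reading of the option, not TF-GAN's exact code):

    import numpy as np

    def penalty(grad_norms, target=1.0, one_sided=False):
        # Deviation of each per-sample gradient norm from the target.
        deviation = grad_norms / target - 1.0
        if one_sided:
            # Only norms above the target are penalized.
            deviation = np.maximum(deviation, 0.0)
        return np.mean(deviation ** 2)

    penalty(np.array([0.5, 1.0, 2.0]))                  # Two-sided: ~0.417
    penalty(np.array([0.5, 1.0, 2.0]), one_sided=True)  # One-sided: ~0.333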
Example #4
    def test_loss_with_placeholder(self):
        generated_data = array_ops.placeholder(dtypes.float32,
                                               shape=(None, None))
        real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

        loss = tfgan_losses.wasserstein_gradient_penalty(
            generated_data, real_data, self._kwargs['generator_inputs'],
            self._kwargs['discriminator_fn'],
            self._kwargs['discriminator_scope'])
        self.assertEqual(generated_data.dtype, loss.dtype)

        with self.test_session() as sess:
            variables.global_variables_initializer().run()
            loss = sess.run(loss,
                            feed_dict={
                                generated_data: self._generated_data_np,
                                real_data: self._real_data_np,
                            })
            self.assertAlmostEqual(self._expected_loss, loss, 5)
Example #5
  def test_loss_with_placeholder(self):
    generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
    real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

    loss = tfgan_losses.wasserstein_gradient_penalty(
        generated_data,
        real_data,
        self._kwargs['generator_inputs'],
        self._kwargs['discriminator_fn'],
        self._kwargs['discriminator_scope'])
    self.assertEqual(generated_data.dtype, loss.dtype)

    with self.test_session() as sess:
      variables.global_variables_initializer().run()
      loss = sess.run(loss,
                      feed_dict={
                          generated_data: self._generated_data_np,
                          real_data: self._real_data_np,
                      })
      self.assertAlmostEqual(self._expected_loss, loss, 5)
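Outside the test harness, the same placeholder pattern works against a real (if tiny) discriminator. A self-contained sketch, assuming TF 1.x with `tf.contrib.gan` available; the one-unit linear discriminator is illustrative:

    import numpy as np
    import tensorflow as tf
    tfgan = tf.contrib.gan

    def discriminator_fn(data, unused_conditioning):
        # One dense unit keeps the gradient norm easy to reason about.
        return tf.layers.dense(data, 1, name='dense')

    real = tf.placeholder(tf.float32, shape=(None, 2))
    fake = tf.placeholder(tf.float32, shape=(None, 2))

    # Create the discriminator variables once; the penalty re-enters this
    # scope with reuse=True when it runs the discriminator internally.
    with tf.variable_scope('discriminator') as dis_scope:
        discriminator_fn(real, None)

    loss = tfgan.losses.wasserstein_gradient_penalty(
        real, fake, None, discriminator_fn, dis_scope)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(loss, feed_dict={
            real: np.random.randn(8, 2),
            fake: np.random.randn(8, 2)}))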
Example #6
    def test_loss_with_gradient_norm_target(self):
        """Test loss value with non default gradient norm target."""
        generated_data = array_ops.placeholder(dtypes.float32,
                                               shape=(None, None))
        real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

        loss = tfgan_losses.wasserstein_gradient_penalty(
            generated_data,
            real_data,
            self._kwargs['generator_inputs'],
            self._kwargs['discriminator_fn'],
            self._kwargs['discriminator_scope'],
            target=2.0)

        with self.test_session() as sess:
            variables.global_variables_initializer().run()
            loss = sess.run(loss,
                            feed_dict={
                                generated_data: self._generated_data_np,
                                real_data: self._real_data_np,
                            })
            self.assertAlmostEqual(1.0, loss, 5)
Example #7
  def test_loss_with_gradient_norm_target(self):
    """Test loss value with non default gradient norm target."""
    generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
    real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))

    loss = tfgan_losses.wasserstein_gradient_penalty(
        generated_data,
        real_data,
        self._kwargs['generator_inputs'],
        self._kwargs['discriminator_fn'],
        self._kwargs['discriminator_scope'],
        target=2.0)

    with self.cached_session() as sess:
      variables.global_variables_initializer().run()
      loss = sess.run(
          loss,
          feed_dict={
              generated_data: self._generated_data_np,
              real_data: self._real_data_np,
          })
      self.assertAlmostEqual(1.0, loss, 5)
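The `target` argument rescales the norm the penalty pulls toward: with `target=2.0`, a gradient of norm 4 incurs the same penalty as a gradient of norm 2 would under the default `target=1.0`. A quick numeric check using the `penalty` sketch from Example #3 (again assuming the (‖g‖/target − 1)² form):

    penalty(np.array([4.0]), target=2.0)  # (4/2 - 1)**2 = 1.0
    penalty(np.array([2.0]), target=1.0)  # (2/1 - 1)**2 = 1.0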
Example #8
  def test_reuses_scope(self):
    """Test that gradient penalty reuses discriminator scope."""
    num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
    tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
    self.assertEqual(
        num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
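This test passes because the penalty re-enters `discriminator_scope` with `reuse=True`, so re-running the discriminator creates no new variables. The scope mechanics in isolation (names here are illustrative, assuming TF 1.x variable scopes):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 3))

    with tf.variable_scope('dis') as dis_scope:
        tf.layers.dense(x, 1, name='dense')  # Creates dis/dense/{kernel,bias}.

    with tf.variable_scope(dis_scope, reuse=True):
        tf.layers.dense(x, 1, name='dense')  # Reuses the same two variables.

    print(len(tf.global_variables()))  # 2, not 4.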