Example #1
0
  def test_tensor_pool_adjusted_model_no_pool(self):
    """Without a pool function, the discriminator outputs pass through as-is.

    `_tensor_pool_adjusted_model` is called with `None` in place of a pool
    function, and the returned model must expose the very same
    `discriminator_gen_outputs` object as the input model.
    """
    gan_model = create_gan_model()
    adjusted = train._tensor_pool_adjusted_model(gan_model, None)

    # Identity (not just equality): no new tensor should have been created.
    self.assertIs(adjusted.discriminator_gen_outputs,
                  gan_model.discriminator_gen_outputs)

  def test_tensor_pool_adjusted_model_gan(self):
    """Exercise `_tensor_pool_adjusted_model` on a GAN model, pooled and not.

    First, with no pool function, the discriminator outputs must be the same
    object as the input model's. Then, with a pool function of size 5, they
    must be a different object, and their values are checked by the shared
    `_check_tensor_pool_adjusted_model_outputs` helper.
    """
    gan_model = create_gan_model()

    unpooled = train._tensor_pool_adjusted_model(gan_model, None)
    # Expected variables: 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
    variable_count = len(ops.get_collection(ops.GraphKeys.VARIABLES))
    self.assertEqual(2, variable_count)
    self.assertIs(unpooled.discriminator_gen_outputs,
                  gan_model.discriminator_gen_outputs)

    pool_size = 5
    tensor_pool_fn = get_tensor_pool_fn(pool_size=pool_size)
    pooled = train._tensor_pool_adjusted_model(gan_model, tensor_pool_fn)
    self.assertIsNot(pooled.discriminator_gen_outputs,
                     gan_model.discriminator_gen_outputs)
    # Delegate the value comparison to the shared helper.
    self._check_tensor_pool_adjusted_model_outputs(
        gan_model.discriminator_gen_outputs,
        pooled.discriminator_gen_outputs,
        pool_size)
Example #3
0
    def test_tensor_pool_adjusted_model_gan(self):
        """Exercise `_tensor_pool_adjusted_model` on a GAN model, pooled and not.

        With no pool function the discriminator outputs must be passed through
        as the same object; with a pool function of size 5 they must be a new
        object, whose values are then checked by
        `_check_tensor_pool_adjusted_model_outputs`.
        """
        gan_model = create_gan_model()

        unpooled = train._tensor_pool_adjusted_model(gan_model, None)
        # Expected variables: 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
        all_variables = ops.get_collection(ops.GraphKeys.VARIABLES)
        self.assertLen(all_variables, 2)
        self.assertIs(unpooled.discriminator_gen_outputs,
                      gan_model.discriminator_gen_outputs)

        pool_size = 5
        tensor_pool_fn = get_tensor_pool_fn(pool_size=pool_size)
        pooled = train._tensor_pool_adjusted_model(gan_model, tensor_pool_fn)
        self.assertIsNot(pooled.discriminator_gen_outputs,
                         gan_model.discriminator_gen_outputs)
        # Delegate the value comparison to the shared helper.
        self._check_tensor_pool_adjusted_model_outputs(
            gan_model.discriminator_gen_outputs,
            pooled.discriminator_gen_outputs,
            pool_size)
Example #4
0
  def _make_new_model_and_check(self, model, pool_size):
    """Wrap `model` with a random tensor pool and sanity-check the result.

    Args:
      model: Model passed to `train._tensor_pool_adjusted_model`.
      pool_size: Pool size forwarded to `random_tensor_pool.tensor_pool`.

    Returns:
      The model returned by `train._tensor_pool_adjusted_model`.
    """
    def pool_fn(x):
      return random_tensor_pool.tensor_pool(x, pool_size=pool_size)

    adjusted_model = train._tensor_pool_adjusted_model(model, pool_fn)
    # Only 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' should exist.
    self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
    # Pooling must replace the discriminator outputs with a new object.
    self.assertIsNot(adjusted_model.discriminator_gen_outputs,
                     model.discriminator_gen_outputs)
    return adjusted_model
Example #5
0
  def _make_new_model_and_check(self, model, pool_size):
    """Wrap `model` with a random tensor pool and sanity-check the result.

    Args:
      model: Model passed to `train._tensor_pool_adjusted_model`.
      pool_size: Pool size forwarded to `random_tensor_pool.tensor_pool`.

    Returns:
      The model returned by `train._tensor_pool_adjusted_model`.
    """
    def pool_fn(x):
      return random_tensor_pool.tensor_pool(x, pool_size=pool_size)

    adjusted_model = train._tensor_pool_adjusted_model(model, pool_fn)
    # Only 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' should exist.
    self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
    # Pooling must replace the discriminator outputs with a new object.
    self.assertIsNot(adjusted_model.discriminator_gen_outputs,
                     model.discriminator_gen_outputs)
    return adjusted_model
Example #6
0
  def test_tensor_pool_adjusted_model_acgan(self):
    """Exercise `_tensor_pool_adjusted_model` on an ACGAN model with a pool.

    With a pool function of size 5, both `discriminator_gen_outputs` and
    `discriminator_gen_classification_logits` must be new objects. Only the
    former is value-checked via `_check_tensor_pool_adjusted_model_outputs`.
    """
    acgan_model = create_acgan_model()

    pool_size = 5
    tensor_pool_fn = get_tensor_pool_fn(pool_size=pool_size)
    pooled = train._tensor_pool_adjusted_model(acgan_model, tensor_pool_fn)
    # Expected variables: 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
    self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2)
    self.assertIsNot(pooled.discriminator_gen_outputs,
                     acgan_model.discriminator_gen_outputs)
    self.assertIsNot(pooled.discriminator_gen_classification_logits,
                     acgan_model.discriminator_gen_classification_logits)
    # Value check covers the discriminator outputs only.
    self._check_tensor_pool_adjusted_model_outputs(
        acgan_model.discriminator_gen_outputs,
        pooled.discriminator_gen_outputs,
        pool_size)
Example #7
0
    def test_tensor_pool_adjusted_model_acgan(self):
        """Exercise `_tensor_pool_adjusted_model` on an ACGAN model with a pool.

        With a pool function of size 5, both `discriminator_gen_outputs` and
        `discriminator_gen_classification_logits` must be new objects. Only
        the former is value-checked via
        `_check_tensor_pool_adjusted_model_outputs`.
        """
        acgan_model = create_acgan_model()

        pool_size = 5
        tensor_pool_fn = get_tensor_pool_fn(pool_size=pool_size)
        pooled = train._tensor_pool_adjusted_model(acgan_model, tensor_pool_fn)
        # Expected variables: 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
        variable_count = len(ops.get_collection(ops.GraphKeys.VARIABLES))
        self.assertEqual(2, variable_count)
        self.assertIsNot(pooled.discriminator_gen_outputs,
                         acgan_model.discriminator_gen_outputs)
        self.assertIsNot(pooled.discriminator_gen_classification_logits,
                         acgan_model.discriminator_gen_classification_logits)
        # Value check covers the discriminator outputs only.
        self._check_tensor_pool_adjusted_model_outputs(
            acgan_model.discriminator_gen_outputs,
            pooled.discriminator_gen_outputs,
            pool_size)