def test_tensor_pool_adjusted_model_no_pool(self):
    """Check `_tensor_pool_adjusted_model` when `None` is passed as the pool."""
    original = create_gan_model()
    adjusted = train._tensor_pool_adjusted_model(original, None)

    # The returned model must expose the very same
    # `discriminator_gen_outputs` object as the model that was passed in.
    actual_outputs = adjusted.discriminator_gen_outputs
    expected_outputs = original.discriminator_gen_outputs
    self.assertIs(actual_outputs, expected_outputs)
def test_tensor_pool_adjusted_model_gan(self):
    """Check `_tensor_pool_adjusted_model` on a GAN model.

    The model is adjusted twice: first with `None` in place of a pool
    function, then with the function returned by `get_tensor_pool_fn`.
    """
    model = create_gan_model()

    # With `None` as the pool function the discriminator outputs on
    # generated data must be the identical object.
    new_model = train._tensor_pool_adjusted_model(model, None)
    # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
    # `assertLen` reports the collection contents on failure and matches
    # the form used by the other tests in this file.
    self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2)
    self.assertIs(new_model.discriminator_gen_outputs,
                  model.discriminator_gen_outputs)

    # With a pool function the outputs must be a different object.
    pool_size = 5
    pool_fn = get_tensor_pool_fn(pool_size=pool_size)
    new_model = train._tensor_pool_adjusted_model(model, pool_fn)
    self.assertIsNot(new_model.discriminator_gen_outputs,
                     model.discriminator_gen_outputs)

    # Check values.
    self._check_tensor_pool_adjusted_model_outputs(
        model.discriminator_gen_outputs,
        new_model.discriminator_gen_outputs,
        pool_size)
def test_tensor_pool_adjusted_model_gan(self):
    """Exercise `_tensor_pool_adjusted_model` on a GAN model.

    Runs the adjustment once with `None` as the pool function and once
    with the function produced by `get_tensor_pool_fn`.
    """
    base_model = create_gan_model()

    # First pass: `None` pool function, outputs must be the same object.
    unpooled_model = train._tensor_pool_adjusted_model(base_model, None)
    # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
    graph_variables = ops.get_collection(ops.GraphKeys.VARIABLES)
    self.assertLen(graph_variables, 2)
    self.assertIs(unpooled_model.discriminator_gen_outputs,
                  base_model.discriminator_gen_outputs)

    # Second pass: real pool function, outputs must be a new object.
    pool_size = 5
    pool_fn = get_tensor_pool_fn(pool_size=pool_size)
    pooled_model = train._tensor_pool_adjusted_model(base_model, pool_fn)
    self.assertIsNot(pooled_model.discriminator_gen_outputs,
                     base_model.discriminator_gen_outputs)

    # Check values.
    self._check_tensor_pool_adjusted_model_outputs(
        base_model.discriminator_gen_outputs,
        pooled_model.discriminator_gen_outputs,
        pool_size)
def _make_new_model_and_check(self, model, pool_size):
    """Build a pool-adjusted copy of `model` and run basic checks on it.

    Args:
      model: The model handed to `train._tensor_pool_adjusted_model`.
      pool_size: Value forwarded as `pool_size` to
        `random_tensor_pool.tensor_pool`.

    Returns:
      The model returned by `train._tensor_pool_adjusted_model`.
    """

    # A named function rather than a lambda bound to a name (PEP 8).
    def pool_fn(x):
        return random_tensor_pool.tensor_pool(x, pool_size=pool_size)

    new_model = train._tensor_pool_adjusted_model(model, pool_fn)
    # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
    # `assertLen` reports the collection contents on failure and matches
    # the form used by the other tests in this file.
    self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2)
    self.assertIsNot(new_model.discriminator_gen_outputs,
                     model.discriminator_gen_outputs)

    return new_model
def test_tensor_pool_adjusted_model_acgan(self):
    """Exercise `_tensor_pool_adjusted_model` on an ACGAN model."""
    base_model = create_acgan_model()
    pool_size = 5
    pool_fn = get_tensor_pool_fn(pool_size=pool_size)
    pooled_model = train._tensor_pool_adjusted_model(base_model, pool_fn)

    # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
    graph_variables = ops.get_collection(ops.GraphKeys.VARIABLES)
    self.assertLen(graph_variables, 2)

    # Both the discriminator outputs and the classification logits on
    # generated data must be different objects from the originals.
    self.assertIsNot(pooled_model.discriminator_gen_outputs,
                     base_model.discriminator_gen_outputs)
    self.assertIsNot(pooled_model.discriminator_gen_classification_logits,
                     base_model.discriminator_gen_classification_logits)

    # Check values.
    self._check_tensor_pool_adjusted_model_outputs(
        base_model.discriminator_gen_outputs,
        pooled_model.discriminator_gen_outputs,
        pool_size)
def test_tensor_pool_adjusted_model_acgan(self):
    """Check `_tensor_pool_adjusted_model` on an ACGAN model.

    The model is adjusted with the function returned by
    `get_tensor_pool_fn`; both the discriminator outputs and the
    classification logits on generated data are expected to be replaced.
    """
    model = create_acgan_model()
    pool_size = 5
    pool_fn = get_tensor_pool_fn(pool_size=pool_size)
    new_model = train._tensor_pool_adjusted_model(model, pool_fn)

    # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
    # `assertLen` reports the collection contents on failure and matches
    # the form used by the other tests in this file.
    self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2)
    self.assertIsNot(new_model.discriminator_gen_outputs,
                     model.discriminator_gen_outputs)
    self.assertIsNot(new_model.discriminator_gen_classification_logits,
                     model.discriminator_gen_classification_logits)

    # Check values.
    self._check_tensor_pool_adjusted_model_outputs(
        model.discriminator_gen_outputs,
        new_model.discriminator_gen_outputs,
        pool_size)