def _make_new_model_and_check(self, model, pool_size):
    """Build a tensor-pool-adjusted copy of `model` and sanity-check it.

    Args:
        model: GAN model handed to `train._tensor_pool_adjusted_model`.
        pool_size: Pool size forwarded to `random_tensor_pool.tensor_pool`.

    Returns:
        The model returned by `train._tensor_pool_adjusted_model`.
    """

    def pool_fn(values):
        return random_tensor_pool.tensor_pool(values, pool_size=pool_size)

    new_model = train._tensor_pool_adjusted_model(model, pool_fn)
    # Exactly two variables are expected in the collection:
    # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'.
    # GLOBAL_VARIABLES is the non-deprecated name for GraphKeys.VARIABLES.
    self.assertEqual(
        2, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
    # The adjusted model must expose a different object for the
    # discriminator outputs on generated data than the input model does.
    self.assertIsNot(new_model.discriminator_gen_outputs,
                     model.discriminator_gen_outputs)
    return new_model
def _make_new_model_and_check(self, model, pool_size):
    """Build a tensor-pool-adjusted copy of `model` and sanity-check it.

    Args:
        model: GAN model handed to `train._tensor_pool_adjusted_model`.
        pool_size: Pool size forwarded to `random_tensor_pool.tensor_pool`.

    Returns:
        The model returned by `train._tensor_pool_adjusted_model`.
    """

    def pool_fn(values):
        return random_tensor_pool.tensor_pool(values, pool_size=pool_size)

    new_model = train._tensor_pool_adjusted_model(model, pool_fn)
    # Exactly two variables are expected in the collection:
    # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'.
    # GLOBAL_VARIABLES is the non-deprecated name for GraphKeys.VARIABLES.
    self.assertEqual(
        2, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
    # The adjusted model must expose a different object for the
    # discriminator outputs on generated data than the input model does.
    self.assertIsNot(new_model.discriminator_gen_outputs,
                     model.discriminator_gen_outputs)
    return new_model
def test_tensor_pool(self, create_gan_model_fn):
    """Check that `train.gan_loss` works with a tensor pool function.

    Args:
        create_gan_model_fn: Zero-argument callable that returns the GAN
            model under test.
    """

    def pool_with_five_slots(values):
        return random_tensor_pool.tensor_pool(values, pool_size=5)

    gan_model = create_gan_model_fn()
    loss = train.gan_loss(gan_model, tensor_pool_fn=pool_with_five_slots)
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Evaluate both losses several times; the test only requires that
    # each evaluation runs without raising.
    # NOTE(review): 10 iterations presumably exceeds pool_size=5 so the
    # pool fills up at least once -- confirm.
    fetches = [loss.generator_loss, loss.discriminator_loss]
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        for _ in range(10):
            sess.run(fetches)
def test_tensor_pool(self, create_gan_model_fn):
    """Check that `train.gan_loss` works with a tensor pool function.

    Args:
        create_gan_model_fn: Zero-argument callable that returns the GAN
            model under test.
    """

    def pool_with_five_slots(values):
        return random_tensor_pool.tensor_pool(values, pool_size=5)

    gan_model = create_gan_model_fn()
    loss = train.gan_loss(gan_model, tensor_pool_fn=pool_with_five_slots)
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Evaluate both losses several times; the test only requires that
    # each evaluation runs without raising.
    # NOTE(review): 10 iterations presumably exceeds pool_size=5 so the
    # pool fills up at least once -- confirm.
    fetches = [loss.generator_loss, loss.discriminator_loss]
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        for _ in range(10):
            sess.run(fetches)
def tensor_pool_fn_impl(input_values):
    """Pool generated data together with its generator inputs.

    `pool_size` is not a parameter; it is read from the enclosing scope.

    Args:
        input_values: A two-element sequence
            `(generated_data, generator_inputs)`. `generator_inputs` is
            concatenated onto a list, so it is presumably a list itself --
            TODO confirm against the caller.

    Returns:
        A pair `(pooled_generated_data, pooled_generator_inputs)` taken
        from the first element and the remaining elements of the pooled
        output, respectively.
    """
    generated_data, generator_inputs = input_values
    # Pool everything in a single call, with the generated data first.
    flat_inputs = [generated_data] + generator_inputs
    pooled = random_tensor_pool.tensor_pool(flat_inputs, pool_size=pool_size)
    pooled_generated_data = pooled[0]
    pooled_generator_inputs = pooled[1:]
    return pooled_generated_data, pooled_generator_inputs
def tensor_pool_fn_impl(input_values):
    """Pass `input_values` through `random_tensor_pool.tensor_pool`.

    `pool_size` is not a parameter; it is read from the enclosing scope.

    Args:
        input_values: Value(s) forwarded unchanged to `tensor_pool`.

    Returns:
        Whatever `random_tensor_pool.tensor_pool` returns for the input.
    """
    pooled_values = random_tensor_pool.tensor_pool(
        input_values, pool_size=pool_size)
    return pooled_values
def tensor_pool_fn_impl(input_values):
    """Pool generated data together with its generator inputs.

    `pool_size` is not a parameter; it is read from the enclosing scope.

    Args:
        input_values: A two-element sequence
            `(generated_data, generator_inputs)`. `generator_inputs` is
            concatenated onto a list, so it is presumably a list itself --
            TODO confirm against the caller.

    Returns:
        A pair `(pooled_generated_data, pooled_generator_inputs)` taken
        from the first element and the remaining elements of the pooled
        output, respectively.
    """
    generated_data, generator_inputs = input_values
    # Pool everything in a single call, with the generated data first.
    flat_inputs = [generated_data] + generator_inputs
    pooled = random_tensor_pool.tensor_pool(flat_inputs, pool_size=pool_size)
    pooled_generated_data = pooled[0]
    pooled_generator_inputs = pooled[1:]
    return pooled_generated_data, pooled_generator_inputs
def tensor_pool_fn_impl(input_values):
    """Pass `input_values` through `random_tensor_pool.tensor_pool`.

    `pool_size` is not a parameter; it is read from the enclosing scope.

    Args:
        input_values: Value(s) forwarded unchanged to `tensor_pool`.

    Returns:
        Whatever `random_tensor_pool.tensor_pool` returns for the input.
    """
    pooled_values = random_tensor_pool.tensor_pool(
        input_values, pool_size=pool_size)
    return pooled_values