def test_stargan_model_generator_output(self):
  """Input, generated and reconstructed images all share one NHWC shape."""
  batch_size, img_size, c_size, num_domains = 2, 16, 3, 5
  input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
      batch_size, img_size, c_size, num_domains)
  model = train.stargan_model(
      generator_fn=stargan_generator_model,
      discriminator_fn=stargan_discriminator_model,
      input_data=input_tensor,
      input_data_domain_label=label_tensor)

  # Every image tensor fetched below is expected to carry
  # batch_size * num_domains rows.
  expected_shape = (batch_size * num_domains, img_size, img_size, c_size)
  with self.test_session(use_gpu=True) as sess:
    sess.run(variables.global_variables_initializer())
    fetched = sess.run([
        model.input_data, model.generated_data, model.reconstructed_data
    ])
    # Checked in order: input, generated, reconstructed.
    for image_batch in fetched:
      self.assertTupleEqual(expected_shape, image_batch.shape)
def test_stargan_model_discriminator_output(self):
  """Discriminator source/domain predictions and labels have expected shapes."""
  batch_size, img_size, c_size, num_domains = 2, 16, 3, 5
  # Every fetched tensor is expected to have this many rows.
  num_rows = batch_size * num_domains
  input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
      batch_size, img_size, c_size, num_domains)
  model = train.stargan_model(
      generator_fn=stargan_generator_model,
      discriminator_fn=stargan_discriminator_model,
      input_data=input_tensor,
      input_data_domain_label=label_tensor)

  with self.test_session(use_gpu=True) as sess:
    sess.run(variables.global_variables_initializer())

    # Real/fake source predictions: rank-1, one entry per row.
    source_preds = sess.run([
        model.discriminator_input_data_source_predication,
        model.discriminator_generated_data_source_predication
    ])
    for pred in source_preds:
      self.assertEqual(1, len(pred.shape))
      self.assertEqual(num_rows, pred.shape[0])

    # Domain labels/targets and the matching domain predictions:
    # one num_domains-wide vector per row.
    domain_arrays = sess.run([
        model.input_data_domain_label,
        model.discriminator_input_data_domain_predication,
        model.generated_data_domain_target,
        model.discriminator_generated_data_domain_predication
    ])
    for arr in domain_arrays:
      self.assertTupleEqual((num_rows, num_domains), arr.shape)
def test_stargan_model_output_type(self):
  """`train.stargan_model` returns a `StarGANModel` with correctly typed fields."""
  batch_size = 2
  img_size = 16
  c_size = 3
  num_domains = 5
  input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
      batch_size, img_size, c_size, num_domains)
  model = train.stargan_model(
      generator_fn=stargan_generator_model,
      discriminator_fn=stargan_discriminator_model,
      input_data=input_tensor,
      input_data_domain_label=label_tensor)

  self.assertIsInstance(model, namedtuples.StarGANModel)
  self.assertIsInstance(model.discriminator_variables, list)
  self.assertIsInstance(model.generator_variables, list)
  self.assertIsInstance(model.discriminator_scope,
                        variable_scope.VariableScope)
  # Must be assertIsInstance: assertTrue(x, Cls) treats Cls as the failure
  # message and only checks that x is truthy.
  self.assertIsInstance(model.generator_scope, variable_scope.VariableScope)
  self.assertTrue(callable(model.discriminator_fn))
  self.assertTrue(callable(model.generator_fn))
def _make_gan_model(generator_fn, discriminator_fn, input_data,
                    input_data_domain_label, generator_scope, add_summaries,
                    mode):
  """Construct a `StarGANModel`, and optionally pass in `mode`.

  Args:
    generator_fn: Callable building the generator network. If it declares a
      parameter named `mode` (positional or keyword-only), `mode` is bound to
      it with `functools.partial`.
    discriminator_fn: Callable building the discriminator network; `mode` is
      bound the same way as for `generator_fn`.
    input_data: Passed through to `tfgan_train.stargan_model`.
    input_data_domain_label: Passed through to `tfgan_train.stargan_model`.
    generator_scope: Passed as `generator_scope` to
      `tfgan_train.stargan_model`.
    add_summaries: Falsy to skip summaries, otherwise a single summary type or
      a tuple/list of them; each is looked up in `_summary_type_map` and the
      resulting function is called on the model.
    mode: Value bound to the network functions' `mode` parameter, if any.

  Returns:
    The model returned by `tfgan_train.stargan_model`.
  """

  def _bind_mode_if_accepted(fn):
    # Returns `fn` with `mode` bound if it declares a `mode` parameter.
    # getfullargspec is used instead of getargspec: the stdlib getargspec is
    # deprecated (removed in Python 3.11) and raises ValueError for functions
    # with keyword-only arguments or annotations.
    spec = inspect.getfullargspec(fn)
    if 'mode' in spec.args or 'mode' in spec.kwonlyargs:
      return functools.partial(fn, mode=mode)
    return fn

  # If network functions have an argument `mode`, pass mode to it.
  generator_fn = _bind_mode_if_accepted(generator_fn)
  discriminator_fn = _bind_mode_if_accepted(discriminator_fn)

  gan_model = tfgan_train.stargan_model(
      generator_fn,
      discriminator_fn,
      input_data,
      input_data_domain_label,
      generator_scope=generator_scope)

  if add_summaries:
    if not isinstance(add_summaries, (tuple, list)):
      add_summaries = [add_summaries]
    # Reset to the root name scope so summary names are not nested under the
    # caller's current scope.
    with ops.name_scope(None):
      for summary_type in add_summaries:
        _summary_type_map[summary_type](gan_model)

  return gan_model
def create_callable_stargan_model():
  """Build a small StarGAN model from callable generator/discriminator objects."""
  generator = StarGANGenerator()
  discriminator = StarGANDiscriminator()
  # One 2x2 image with 3 channels and a matching 2-wide domain label.
  images = array_ops.ones([1, 2, 2, 3])
  domain_labels = array_ops.ones([1, 2])
  return train.stargan_model(
      generator_fn=generator,
      discriminator_fn=discriminator,
      input_data=images,
      input_data_domain_label=domain_labels)
def create_stargan_model():
  """Build a small StarGAN model from plain generator/discriminator functions."""
  # One 2x2 image with 3 channels and a matching 2-wide domain label.
  images = array_ops.ones([1, 2, 2, 3])
  domain_labels = array_ops.ones([1, 2])
  return train.stargan_model(
      generator_fn=stargan_generator_model,
      discriminator_fn=stargan_discriminator_model,
      input_data=images,
      input_data_domain_label=domain_labels)