def test_stargan_model_generator_output(self):
    """Checks the shapes of the input, generated and reconstructed images.

    Each of the three image tensors produced by `tfgan.stargan_model` is
    evaluated and asserted to have shape
    (batch_size * num_domains, img_size, img_size, c_size).
    """
    # None of the usual utilities work in eager.
    if tf.executing_eagerly():
        return

    batch_size, img_size, c_size, num_domains = 2, 16, 3, 5
    images, labels = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = tfgan.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=images,
        input_data_domain_label=labels)

    expected_shape = (batch_size * num_domains, img_size, img_size, c_size)
    with self.cached_session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        evaluated = sess.run(
            [model.input_data, model.generated_data, model.reconstructed_data])
        # Order matches the fetch list: input, generated, reconstructed.
        for array in evaluated:
            self.assertTupleEqual(expected_shape, array.shape)
def test_stargan_model_output_type(self):
    """Checks the types of the fields on the returned `StarGANModel`.

    Verifies that `tfgan.stargan_model` returns a `tfgan.StarGANModel` and
    checks the following fields:
      * variable collections are lists;
      * both scopes are `tf.compat.v1.VariableScope` instances;
      * the stored generator/discriminator functions are callable.
    """
    # None of the usual utilities work in eager.
    if tf.executing_eagerly():
        return

    batch_size = 2
    img_size = 16
    c_size = 3
    num_domains = 5
    input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = tfgan.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=input_tensor,
        input_data_domain_label=label_tensor)

    self.assertIsInstance(model, tfgan.StarGANModel)
    self.assertIsInstance(model.discriminator_variables, list)
    self.assertIsInstance(model.generator_variables, list)
    self.assertIsInstance(model.discriminator_scope,
                          tf.compat.v1.VariableScope)
    # Previously `assertTrue(scope, VariableScope)`, which passed the type as
    # the failure *message* and only checked truthiness, not the type.
    self.assertIsInstance(model.generator_scope, tf.compat.v1.VariableScope)
    self.assertTrue(callable(model.discriminator_fn))
    self.assertTrue(callable(model.generator_fn))
def _define_model(images, labels):
    """Builds a StarGAN model from `network.generator`/`network.discriminator`.

    Args:
      images: `Tensor` or list of `Tensor` of shape (N, H, W, C).
      labels: `Tensor` or list of `Tensor` of shape (N, num_domains).

    Returns:
      The `StarGANModel` namedtuple produced by `tfgan.stargan_model`.
    """
    stargan = tfgan.stargan_model(
        generator_fn=network.generator,
        discriminator_fn=network.discriminator,
        input_data=images,
        input_data_domain_label=labels,
    )
    return stargan
def test_stargan_model_discriminator_output(self):
    """Checks the shapes of the discriminator outputs and domain labels.

    Source predictions must be rank-1 with length batch_size * num_domains.
    The domain label and domain prediction tensors must have shape
    (batch_size * num_domains, num_domains).
    """
    # None of the usual utilities work in eager.
    if tf.executing_eagerly():
        return

    batch_size, img_size, c_size, num_domains = 2, 16, 3, 5
    images, labels = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = tfgan.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=images,
        input_data_domain_label=labels)

    total = batch_size * num_domains
    with self.cached_session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        # Real/fake source predictions: one scalar per example.
        source_preds = sess.run([
            model.discriminator_input_data_source_predication,
            model.discriminator_generated_data_source_predication
        ])
        for pred in source_preds:
            self.assertEqual(1, len(pred.shape))
            self.assertEqual(total, pred.shape[0])

        # Domain labels/targets and the discriminator's domain predictions.
        domain_arrays = sess.run([
            model.input_data_domain_label,
            model.discriminator_input_data_domain_predication,
            model.generated_data_domain_target,
            model.discriminator_generated_data_domain_predication
        ])
        for arr in domain_arrays:
            self.assertTupleEqual((total, num_domains), arr.shape)
def create_callable_stargan_model():
    """Builds a tiny StarGAN model from callable generator/discriminator objects.

    Uses `StarGANGenerator()` / `StarGANDiscriminator()` instances as the
    network functions. The inputs are an all-ones image tensor of shape
    [1, 2, 2, 3] and an all-ones domain label tensor of shape [1, 2].
    """
    generator = StarGANGenerator()
    discriminator = StarGANDiscriminator()
    images = tf.ones([1, 2, 2, 3])
    domain_labels = tf.ones([1, 2])
    return tfgan.stargan_model(
        generator_fn=generator,
        discriminator_fn=discriminator,
        input_data=images,
        input_data_domain_label=domain_labels)
def create_stargan_model():
    """Builds a tiny StarGAN model from the plain generator/discriminator functions.

    Uses `stargan_generator_model` / `stargan_discriminator_model`. The inputs
    are an all-ones image tensor of shape [1, 2, 2, 3] and an all-ones domain
    label tensor of shape [1, 2].
    """
    images = tf.ones([1, 2, 2, 3])
    domain_labels = tf.ones([1, 2])
    return tfgan.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=images,
        input_data_domain_label=domain_labels)