Example #1
0
    def test_stargan_model_generator_output(self):
        """Input, generated and reconstructed images all share one shape.

        Every image tensor fetched from the session is expected to have shape
        ``(batch_size * num_domains, img_size, img_size, c_size)``.
        """
        batch_size, img_size, c_size, num_domains = 2, 16, 3, 5

        input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
            batch_size, img_size, c_size, num_domains)
        model = train.stargan_model(
            generator_fn=stargan_generator_model,
            discriminator_fn=stargan_discriminator_model,
            input_data=input_tensor,
            input_data_domain_label=label_tensor)
        expected_shape = (batch_size * num_domains, img_size, img_size, c_size)

        with self.test_session(use_gpu=True) as sess:
            sess.run(variables.global_variables_initializer())
            # Evaluate the three image tensors in a single session call.
            fetched = sess.run([
                model.input_data,
                model.generated_data,
                model.reconstructed_data,
            ])
            for image_batch in fetched:
                self.assertTupleEqual(expected_shape, image_batch.shape)
Example #2
0
  def test_stargan_model_generator_output(self):
    """Input, generated and reconstructed images all share one shape.

    Every image tensor fetched from the session is expected to have shape
    ``(batch_size * num_domains, img_size, img_size, c_size)``.
    """
    batch_size, img_size, c_size, num_domains = 2, 16, 3, 5

    input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = train.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=input_tensor,
        input_data_domain_label=label_tensor)
    expected_shape = (batch_size * num_domains, img_size, img_size, c_size)

    with self.test_session(use_gpu=True) as sess:
      sess.run(variables.global_variables_initializer())
      # Evaluate the three image tensors in a single session call.
      fetched = sess.run([
          model.input_data,
          model.generated_data,
          model.reconstructed_data,
      ])
      for image_batch in fetched:
        self.assertTupleEqual(expected_shape, image_batch.shape)
Example #3
0
    def test_stargan_model_discriminator_output(self):
        """Discriminator predictions and domain tensors have the expected shapes.

        Source predictions must be rank 1 with ``batch_size * num_domains``
        entries. The domain label, domain target and both domain predictions
        must have shape ``(batch_size * num_domains, num_domains)``.
        """
        batch_size, img_size, c_size, num_domains = 2, 16, 3, 5
        total_batch = batch_size * num_domains

        input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
            batch_size, img_size, c_size, num_domains)
        model = train.stargan_model(
            generator_fn=stargan_generator_model,
            discriminator_fn=stargan_discriminator_model,
            input_data=input_tensor,
            input_data_domain_label=label_tensor)

        with self.test_session(use_gpu=True) as sess:
            sess.run(variables.global_variables_initializer())

            # Source predictions: checked for the real input first, then the
            # generated data.
            source_preds = sess.run([
                model.discriminator_input_data_source_predication,
                model.discriminator_generated_data_source_predication,
            ])
            for pred in source_preds:
                self.assertEqual(1, len(pred.shape))
                self.assertEqual(total_batch, pred.shape[0])

            # Domain tensors: label, its prediction, target, its prediction.
            domain_tensors = sess.run([
                model.input_data_domain_label,
                model.discriminator_input_data_domain_predication,
                model.generated_data_domain_target,
                model.discriminator_generated_data_domain_predication,
            ])
            for domain_tensor in domain_tensors:
                self.assertTupleEqual((total_batch, num_domains),
                                      domain_tensor.shape)
Example #4
0
  def test_stargan_model_discriminator_output(self):
    """Discriminator predictions and domain tensors have the expected shapes.

    Source predictions must be rank 1 with ``batch_size * num_domains``
    entries. The domain label, domain target and both domain predictions
    must have shape ``(batch_size * num_domains, num_domains)``.
    """
    batch_size, img_size, c_size, num_domains = 2, 16, 3, 5
    total_batch = batch_size * num_domains

    input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = train.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=input_tensor,
        input_data_domain_label=label_tensor)

    with self.test_session(use_gpu=True) as sess:
      sess.run(variables.global_variables_initializer())

      # Source predictions: checked for the real input first, then the
      # generated data.
      source_preds = sess.run([
          model.discriminator_input_data_source_predication,
          model.discriminator_generated_data_source_predication,
      ])
      for pred in source_preds:
        self.assertEqual(1, len(pred.shape))
        self.assertEqual(total_batch, pred.shape[0])

      # Domain tensors: label, its prediction, target, its prediction.
      domain_tensors = sess.run([
          model.input_data_domain_label,
          model.discriminator_input_data_domain_predication,
          model.generated_data_domain_target,
          model.discriminator_generated_data_domain_predication,
      ])
      for domain_tensor in domain_tensors:
        self.assertTupleEqual((total_batch, num_domains), domain_tensor.shape)
Example #5
0
def _make_gan_model(generator_fn, discriminator_fn, input_data,
                    input_data_domain_label, generator_scope, add_summaries,
                    mode):
    """Construct a `StarGANModel`, and optionally pass in `mode`.

    Args:
        generator_fn: Generator network callable. If it declares a `mode`
            parameter (positional or keyword-only), `mode` is bound to it.
        discriminator_fn: Discriminator network callable; same `mode` rule.
        input_data: Passed through to `tfgan_train.stargan_model`.
        input_data_domain_label: Passed through to `tfgan_train.stargan_model`.
        generator_scope: Variable scope name/object for the generator.
        add_summaries: Falsy to skip summaries, otherwise a single summary
            type or a tuple/list of them; each is a key of `_summary_type_map`.
        mode: Value bound to the network functions' `mode` parameter.

    Returns:
        The model returned by `tfgan_train.stargan_model`.
    """

    def _accepts_mode(fn):
        # `inspect.getargspec` is deprecated, was removed in Python 3.11, and
        # raises ValueError for callables with keyword-only args or
        # annotations. Prefer `getfullargspec`; fall back only on interpreters
        # that lack it.
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is None:
            return 'mode' in inspect.getargspec(fn).args
        spec = getfullargspec(fn)
        return 'mode' in spec.args or 'mode' in spec.kwonlyargs

    # If network functions have an argument `mode`, pass mode to it.
    if _accepts_mode(generator_fn):
        generator_fn = functools.partial(generator_fn, mode=mode)
    if _accepts_mode(discriminator_fn):
        discriminator_fn = functools.partial(discriminator_fn, mode=mode)
    gan_model = tfgan_train.stargan_model(generator_fn,
                                          discriminator_fn,
                                          input_data,
                                          input_data_domain_label,
                                          generator_scope=generator_scope)
    if add_summaries:
        if not isinstance(add_summaries, (tuple, list)):
            add_summaries = [add_summaries]
        # Reset to the root name scope so summary names are not nested under
        # whatever scope the caller is in.
        with ops.name_scope(None):
            for summary_type in add_summaries:
                _summary_type_map[summary_type](gan_model)

    return gan_model
Example #6
0
    def test_stargan_model_output_type(self):
        """`train.stargan_model` returns a StarGANModel with correctly typed fields."""
        batch_size = 2
        img_size = 16
        c_size = 3
        num_domains = 5

        input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
            batch_size, img_size, c_size, num_domains)
        model = train.stargan_model(
            generator_fn=stargan_generator_model,
            discriminator_fn=stargan_discriminator_model,
            input_data=input_tensor,
            input_data_domain_label=label_tensor)

        self.assertIsInstance(model, namedtuples.StarGANModel)
        self.assertIsInstance(model.discriminator_variables, list)
        self.assertIsInstance(model.generator_variables, list)
        self.assertIsInstance(model.discriminator_scope,
                              variable_scope.VariableScope)
        # Previously `assertTrue(scope, VariableScope)`, which treated the type
        # as the failure message and only checked truthiness.
        self.assertIsInstance(model.generator_scope,
                              variable_scope.VariableScope)
        self.assertTrue(callable(model.discriminator_fn))
        self.assertTrue(callable(model.generator_fn))
Example #7
0
  def test_stargan_model_output_type(self):
    """`train.stargan_model` returns a StarGANModel with correctly typed fields."""
    batch_size = 2
    img_size = 16
    c_size = 3
    num_domains = 5

    input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = train.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=input_tensor,
        input_data_domain_label=label_tensor)

    self.assertIsInstance(model, namedtuples.StarGANModel)
    self.assertIsInstance(model.discriminator_variables, list)
    self.assertIsInstance(model.generator_variables, list)
    self.assertIsInstance(model.discriminator_scope,
                          variable_scope.VariableScope)
    # Previously `assertTrue(scope, VariableScope)`, which treated the type as
    # the failure message and only checked truthiness.
    self.assertIsInstance(model.generator_scope, variable_scope.VariableScope)
    self.assertTrue(callable(model.discriminator_fn))
    self.assertTrue(callable(model.generator_fn))
def _make_gan_model(generator_fn, discriminator_fn, input_data,
                    input_data_domain_label, generator_scope, add_summaries,
                    mode):
  """Construct a `StarGANModel`, and optionally pass in `mode`.

  Args:
    generator_fn: Generator network callable. If it declares a `mode`
      parameter (positional or keyword-only), `mode` is bound to it.
    discriminator_fn: Discriminator network callable; same `mode` rule.
    input_data: Passed through to `tfgan_train.stargan_model`.
    input_data_domain_label: Passed through to `tfgan_train.stargan_model`.
    generator_scope: Variable scope name/object for the generator.
    add_summaries: Falsy to skip summaries, otherwise a single summary type or
      a tuple/list of them; each is a key of `_summary_type_map`.
    mode: Value bound to the network functions' `mode` parameter.

  Returns:
    The model returned by `tfgan_train.stargan_model`.
  """

  def _accepts_mode(fn):
    # `inspect.getargspec` is deprecated, was removed in Python 3.11, and
    # raises ValueError for callables with keyword-only args or annotations.
    # Prefer `getfullargspec`; fall back only on interpreters that lack it.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is None:
      return 'mode' in inspect.getargspec(fn).args
    spec = getfullargspec(fn)
    return 'mode' in spec.args or 'mode' in spec.kwonlyargs

  # If network functions have an argument `mode`, pass mode to it.
  if _accepts_mode(generator_fn):
    generator_fn = functools.partial(generator_fn, mode=mode)
  if _accepts_mode(discriminator_fn):
    discriminator_fn = functools.partial(discriminator_fn, mode=mode)
  gan_model = tfgan_train.stargan_model(
      generator_fn,
      discriminator_fn,
      input_data,
      input_data_domain_label,
      generator_scope=generator_scope)
  if add_summaries:
    if not isinstance(add_summaries, (tuple, list)):
      add_summaries = [add_summaries]
    # Reset to the root name scope so summary names are not nested under
    # whatever scope the caller is in.
    with ops.name_scope(None):
      for summary_type in add_summaries:
        _summary_type_map[summary_type](gan_model)

  return gan_model
Example #9
0
def create_callable_stargan_model():
    """Build a StarGAN model from callable generator/discriminator objects.

    The input data is an all-ones tensor of shape [1, 2, 2, 3] and the domain
    label is an all-ones tensor of shape [1, 2].
    """
    generator = StarGANGenerator()
    discriminator = StarGANDiscriminator()
    input_data = array_ops.ones([1, 2, 2, 3])
    domain_label = array_ops.ones([1, 2])
    return train.stargan_model(generator, discriminator, input_data,
                               domain_label)
Example #10
0
def create_stargan_model():
    """Build a StarGAN model from the plain generator/discriminator functions.

    The input data is an all-ones tensor of shape [1, 2, 2, 3] and the domain
    label is an all-ones tensor of shape [1, 2].
    """
    input_data = array_ops.ones([1, 2, 2, 3])
    domain_label = array_ops.ones([1, 2])
    return train.stargan_model(
        stargan_generator_model, stargan_discriminator_model,
        input_data, domain_label)
Example #11
0
def create_callable_stargan_model():
  """Build a StarGAN model from callable generator/discriminator objects.

  The input data is an all-ones tensor of shape [1, 2, 2, 3] and the domain
  label is an all-ones tensor of shape [1, 2].
  """
  generator = StarGANGenerator()
  discriminator = StarGANDiscriminator()
  input_data = array_ops.ones([1, 2, 2, 3])
  domain_label = array_ops.ones([1, 2])
  return train.stargan_model(generator, discriminator, input_data,
                             domain_label)
Example #12
0
def create_stargan_model():
  """Build a StarGAN model from the plain generator/discriminator functions.

  The input data is an all-ones tensor of shape [1, 2, 2, 3] and the domain
  label is an all-ones tensor of shape [1, 2].
  """
  input_data = array_ops.ones([1, 2, 2, 3])
  domain_label = array_ops.ones([1, 2])
  return train.stargan_model(stargan_generator_model,
                             stargan_discriminator_model,
                             input_data,
                             domain_label)