コード例 #1
0
ファイル: train_test.py プロジェクト: Aerochip7/gan
  def test_stargan_model_generator_output(self):
    """Checks the shapes of the input, generated and reconstructed images."""
    # The session-based checks below only work in graph mode.
    if tf.executing_eagerly():
      return

    batch_size, img_size, c_size, num_domains = 2, 16, 3, 5

    input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = tfgan.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=input_tensor,
        input_data_domain_label=label_tensor)

    # All three image tensors are expected to share this shape: a leading
    # dimension of batch_size * num_domains, with unchanged image dims.
    expected_shape = (batch_size * num_domains, img_size, img_size, c_size)
    image_tensors = [
        model.input_data,
        model.generated_data,
        model.reconstructed_data,
    ]

    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      # Evaluate all three in a single run, then check each result in order.
      for value in sess.run(image_tensors):
        self.assertTupleEqual(expected_shape, value.shape)
コード例 #2
0
ファイル: train_test.py プロジェクト: Aerochip7/gan
  def test_stargan_model_output_type(self):
    """Checks that stargan_model returns a StarGANModel with typed fields.

    Verifies the model's type, that the variable collections are lists, that
    both scopes are `tf.compat.v1.VariableScope` instances and that the stored
    generator/discriminator functions are callable.
    """
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return

    batch_size = 2
    img_size = 16
    c_size = 3
    num_domains = 5

    input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = tfgan.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=input_tensor,
        input_data_domain_label=label_tensor)

    self.assertIsInstance(model, tfgan.StarGANModel)
    self.assertIsInstance(model.discriminator_variables, list)
    self.assertIsInstance(model.generator_variables, list)
    self.assertIsInstance(model.discriminator_scope, tf.compat.v1.VariableScope)
    # assertIsInstance (not assertTrue): assertTrue's second argument is only
    # the failure message, so it would never check the scope's type.
    self.assertIsInstance(model.generator_scope, tf.compat.v1.VariableScope)
    self.assertTrue(callable(model.discriminator_fn))
    self.assertTrue(callable(model.generator_fn))
コード例 #3
0
ファイル: train_lib.py プロジェクト: zhouyonglong/gan
def _define_model(images, labels):
    """Build the StarGAN model from the project's generator and discriminator.

    Args:
        images: `Tensor` or list of `Tensor` of shape (N, H, W, C).
        labels: `Tensor` or list of `Tensor` of shape (N, num_domains).

    Returns:
        `StarGANModel` namedtuple.
    """
    model = tfgan.stargan_model(
        generator_fn=network.generator,
        discriminator_fn=network.discriminator,
        input_data=images,
        input_data_domain_label=labels,
    )
    return model
コード例 #4
0
ファイル: train_test.py プロジェクト: Aerochip7/gan
  def test_stargan_model_discriminator_output(self):
    """Checks the shapes of discriminator source and domain outputs."""
    # The session-based checks below only work in graph mode.
    if tf.executing_eagerly():
      return

    batch_size, img_size, c_size, num_domains = 2, 16, 3, 5
    total_examples = batch_size * num_domains

    input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
        batch_size, img_size, c_size, num_domains)
    model = tfgan.stargan_model(
        generator_fn=stargan_generator_model,
        discriminator_fn=stargan_discriminator_model,
        input_data=input_tensor,
        input_data_domain_label=label_tensor)

    # Attribute names (including "predication") match the StarGANModel fields.
    source_predictions = [
        model.discriminator_input_data_source_predication,
        model.discriminator_generated_data_source_predication,
    ]
    domain_tensors = [
        model.input_data_domain_label,
        model.discriminator_input_data_domain_predication,
        model.generated_data_domain_target,
        model.discriminator_generated_data_domain_predication,
    ]

    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())

      # Source predictions: expected to be rank 1 with one entry per example.
      for prediction in sess.run(source_predictions):
        self.assertEqual(1, len(prediction.shape))
        self.assertEqual(total_examples, prediction.shape[0])

      # Domain labels/targets/predictions: (examples, num_domains).
      for value in sess.run(domain_tensors):
        self.assertTupleEqual((total_examples, num_domains), value.shape)
コード例 #5
0
ファイル: train_test.py プロジェクト: Aerochip7/gan
def create_callable_stargan_model():
  """Builds a StarGANModel from callable generator/discriminator objects.

  Inputs are constant all-ones tensors: images of shape [1, 2, 2, 3] and
  domain labels of shape [1, 2].
  """
  generator = StarGANGenerator()
  discriminator = StarGANDiscriminator()
  return tfgan.stargan_model(
      generator_fn=generator,
      discriminator_fn=discriminator,
      input_data=tf.ones([1, 2, 2, 3]),
      input_data_domain_label=tf.ones([1, 2]))
コード例 #6
0
ファイル: train_test.py プロジェクト: Aerochip7/gan
def create_stargan_model():
  """Builds a StarGANModel from the plain generator/discriminator functions.

  Inputs are constant all-ones tensors: images of shape [1, 2, 2, 3] and
  domain labels of shape [1, 2].
  """
  return tfgan.stargan_model(
      generator_fn=stargan_generator_model,
      discriminator_fn=stargan_discriminator_model,
      input_data=tf.ones([1, 2, 2, 3]),
      input_data_domain_label=tf.ones([1, 2]))