Exemplo n.º 1
0
def generator(inputs, targets):
    """Build the generator graph.

    Conditions `inputs` on `targets` with pixel padding, then feeds the result
    through the down-sampling, bottleneck and up-sampling stages. All of these
    stages are created under the 'generator' variable scope.

    PyTorch reference implementation:
    https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L22

    Args:
      inputs: Tensor of shape (batch_size, h, w, c) holding the
        images/information to be transformed.
      targets: Tensor of shape (batch_size, num_domains) describing the target
        domain each example should be transformed into.

    Returns:
      Tensor of shape (batch_size, h, w, c), the same shape as `inputs`.
    """
    # The last dimension of `inputs` is handed to the up-sampling stage,
    # presumably as its output channel count -- TODO confirm in layers.
    out_channels = inputs.shape[-1]

    with tf.variable_scope('generator'):
        conditioned = ops.condition_input_with_pixel_padding(inputs, targets)
        # Inner call runs first, so stages are still built in the order
        # down-sample -> bottleneck -> up-sample.
        features = layers.generator_bottleneck(
            layers.generator_down_sample(conditioned))
        output = layers.generator_up_sample(features, out_channels)

    return output
Exemplo n.º 2
0
Arquivo: ops_test.py Projeto: yyht/gan
  def test_condition_input_with_pixel_padding(self):
    """Checks that labels are appended after the input channels at every pixel.

    `ops.condition_input_with_pixel_padding` must return a tensor of shape
    (n, h, w, c + num_label). At every spatial position, its trailing
    `num_label` channels must equal that example's label vector.
    """
    n = 2
    h = 128
    w = h
    c = 3
    num_label = 5

    input_tensor = tf.random.uniform((n, h, w, c))
    label_tensor = tf.random.uniform((n, num_label))
    output_tensor = ops.condition_input_with_pixel_padding(
        input_tensor, label_tensor)

    with self.cached_session() as sess:
      # Fetch both tensors in one run so the randomly sampled labels we
      # compare against are the same ones that were fed into the op.
      labels, outputs = sess.run([label_tensor, output_tensor])
      self.assertTupleEqual((n, h, w, c + num_label), outputs.shape)
      # Broadcast each example's label to every (i, j) pixel and compare the
      # whole padding slab at once, instead of n * h * w per-pixel
      # Python-level list comparisons. Equality is exact in both forms.
      expected_padding = labels[:, None, None, :].repeat(h, axis=1).repeat(
          w, axis=2)
      self.assertAllEqual(expected_padding, outputs[..., c:])