예제 #1
0
파일: network.py 프로젝트: zhouyonglong/gan
def generator(inputs, targets):
    """Assemble the full generator graph.

    Chains the conditioning, down-sampling, bottleneck and up-sampling
    stages inside the 'generator' variable scope.

    PyTorch reference implementation:
    https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L22

    Args:
      inputs: Tensor of shape (batch_size, h, w, c) holding the
        images/information to be transformed.
      targets: Tensor of shape (batch_size, num_domains) describing the
        target domain the images/information should be transformed to.

    Returns:
      Tensor of shape (batch_size, h, w, c), i.e. the same shape as `inputs`.
    """
    with tf.variable_scope('generator'):
        # Merge the target-domain information into the image tensor first.
        conditioned = ops.condition_input_with_pixel_padding(inputs, targets)

        encoded = layers.generator_down_sample(conditioned)
        transformed = layers.generator_bottleneck(encoded)

        # The up-sampling stage is told the channel count of the original
        # input (presumably so the output matches it -- confirm in layers).
        output_channels = inputs.shape[-1]
        return layers.generator_up_sample(transformed, output_channels)
예제 #2
0
    def test_generator_down_sample(self):
        """Check the output shape produced by `layers.generator_down_sample`.

        A random (2, 128, 128, 6) input is expected to come out with both
        spatial dimensions divided by 4 and with 256 channels.
        """
        batch_size = 2
        height = width = 128
        # NOTE(review): presumably 3 image channels plus 3 domain-label
        # channels added by the conditioning step -- confirm against caller.
        channels = 3 + 3

        images = tf.random.uniform((batch_size, height, width, channels))
        down_sampled = layers.generator_down_sample(images)
        expected_shape = (batch_size, height // 4, width // 4, 256)

        with self.cached_session() as sess:
            # Variables created by the layer must be initialised before the
            # output tensor can be evaluated.
            sess.run(tf.global_variables_initializer())
            result = sess.run(down_sampled)
            self.assertTupleEqual(expected_shape, result.shape)