def generator(inputs, targets):
  """Builds the full StarGAN generator graph.

  Chains input conditioning, down-sampling, the bottleneck and up-sampling
  inside the 'generator' variable scope.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L22

  Args:
    inputs: Tensor of shape (batch_size, h, w, c) holding the
      images/information to transform.
    targets: Tensor of shape (batch_size, num_domains) holding the target
      domain each input should be transformed to.

  Returns:
    Tensor of shape (batch_size, h, w, c), matching `inputs`.
  """
  with tf.variable_scope('generator'):
    # Attach the target-domain information to the input before encoding.
    conditioned = ops.condition_input_with_pixel_padding(inputs, targets)
    encoded = layers.generator_down_sample(conditioned)
    transformed = layers.generator_bottleneck(encoded)
    # The up-sampling stage is given the input's last dimension so the output
    # has as many channels as `inputs`.
    return layers.generator_up_sample(transformed, inputs.shape[-1])
def test_generator_down_sample(self):
  """Checks the output shape of `layers.generator_down_sample`.

  Feeds a random (2, 128, 128, 6) tensor and expects each spatial dimension
  to be divided by 4 and the last dimension to be 256.
  """
  batch_size = 2
  height = width = 128
  # NOTE(review): presumably 3 image channels plus 3 domain-label channels --
  # confirm against the conditioning op.
  channels = 3 + 3
  images = tf.random.uniform((batch_size, height, width, channels))
  down_sampled = layers.generator_down_sample(images)
  expected_shape = (batch_size, height // 4, width // 4, 256)
  with self.cached_session() as sess:
    # Variables created by the layer must be initialized before evaluation.
    sess.run(tf.global_variables_initializer())
    result = sess.run(down_sampled)
    self.assertTupleEqual(expected_shape, result.shape)