def generator(inputs, targets):
  """Builds the StarGAN generator graph.

  Mirrors the reference PyTorch generator:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L22

  The target-domain labels are attached to the input. The result then goes
  through a down-sampling encoder, a bottleneck, and an up-sampling decoder.
  All variables are created under the 'generator' variable scope.

  Args:
    inputs: Tensor of shape (batch_size, h, w, c) holding the
      images/information to transform.
    targets: Tensor of shape (batch_size, num_domains) describing the domain
      each example should be transformed into.

  Returns:
    Tensor with the same shape as `inputs`, (batch_size, h, w, c).
  """
  with tf.variable_scope('generator'):
    # Attach the target labels to the image as extra per-pixel channels so
    # every convolution downstream can see the requested domain.
    conditioned = ops.condition_input_with_pixel_padding(inputs, targets)

    features = layers.generator_down_sample(conditioned)
    features = layers.generator_bottleneck(features)

    # NOTE(review): presumably the decoder uses this as its output channel
    # count so the result matches `inputs`. It is the static last dimension,
    # so confirm it is always known at graph-construction time.
    output_channels = inputs.shape[-1]
    result = layers.generator_up_sample(features, output_channels)

  return result
def test_condition_input_with_pixel_padding(self):
  """Checks that labels are appended to every pixel as trailing channels.

  The conditioned output must keep the input's batch and spatial size and
  gain `num_label` extra channels. For every example, the trailing channels
  at every pixel must equal that example's label vector exactly.
  """
  n = 2
  h = 128
  w = h
  c = 3
  num_label = 5

  input_tensor = tf.random.uniform((n, h, w, c))
  label_tensor = tf.random.uniform((n, num_label))
  output_tensor = ops.condition_input_with_pixel_padding(
      input_tensor, label_tensor)

  # Expected padding is each label broadcast over the (h, w) grid. It is
  # built in the same graph and fetched in the same sess.run as
  # output_tensor, so both see the same random draw of label_tensor.
  expected_padding_tensor = tf.tile(
      tf.reshape(label_tensor, (n, 1, 1, num_label)), (1, h, w, 1))

  with self.cached_session() as sess:
    expected_padding, outputs = sess.run(
        [expected_padding_tensor, output_tensor])

  self.assertTupleEqual((n, h, w, c + num_label), outputs.shape)
  # One vectorized exact comparison replaces n * h * w per-pixel
  # assertListEqual calls, and reports a useful diff on failure.
  self.assertAllEqual(expected_padding, outputs[..., c:])