def discriminator(input_net, class_num):
  """Discriminator Module.

  Piece everything together and reshape the output source tensor.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L63

  Notes:
  The PyTorch Version runs the reduce_mean operation later in their solver:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/solver.py#L245

  Args:
    input_net: Tensor of shape (batch_size, h, w, c) as batch of images.
    class_num: (int) number of domains to be predicted.

  Returns:
    output_src: Tensor of shape (batch_size) where each value is a logit
      representing whether the image is real or fake.
    output_cls: Tensor of shape (batch_size, class_num) where each value is a
      logit representing whether the image is in the associated domain.
  """
  # Use the tf.compat.v1 graph-mode APIs: the bare tf.variable_scope /
  # tf.layers.flatten symbols do not exist in TF2.
  with tf.compat.v1.variable_scope('discriminator'):
    hidden = layers.discriminator_input_hidden(input_net)
    output_src = layers.discriminator_output_source(hidden)
    # Collapse the spatial source map to one logit per image by flattening
    # and averaging across all remaining positions.
    output_src = tf.compat.v1.layers.flatten(output_src)
    output_src = tf.reduce_mean(input_tensor=output_src, axis=1)
    output_cls = layers.discriminator_output_class(hidden, class_num)
  return output_src, output_cls
def __call__(self, input_net, class_num):
  """Build the source head and delegate classification to the Keras model.

  Args:
    input_net: Tensor batch of images fed to both heads. The classifier sees a
      rescaled copy, `(input_net + 1.0) / 2.0`; presumably the images are in
      [-1, 1] and this maps them to [0, 1] (TODO confirm the expected range).
    class_num: Accepted for interface compatibility; this implementation does
      not use it (the class count comes from `self.keras_model`).

  Returns:
    A `(output_src, output_cls)` tuple. `output_src` holds the source map,
    flattened and averaged over axis 1; `output_cls` is whatever
    `self.keras_model` returns for the rescaled images.
  """
  with tf.compat.v1.variable_scope('discriminator'):
    src_features = layers.discriminator_input_hidden(
        input_net, scope='discriminator_input_hidden_source')
    src_logits = tf.compat.v1.layers.flatten(
        layers.discriminator_output_source(src_features))
    src_logits = tf.reduce_mean(input_tensor=src_logits, axis=1)
    rescaled_images = (input_net + 1.0) / 2.0
    cls_logits = self.keras_model(rescaled_images)
  return src_logits, cls_logits
def test_discriminator_output_source(self):
  """Checks the source head maps (n, h, w, c) features to (n, h, w, 1)."""
  n = 2
  h = 2
  w = 2
  c = 2048
  input_tensor = tf.random.uniform((n, h, w, c))
  output_tensor = layers.discriminator_output_source(input_tensor)
  with self.cached_session() as sess:
    # tf.global_variables_initializer is TF1-only; the compat.v1 alias works
    # under both TF1 and TF2 and matches the rest of the file.
    sess.run(tf.compat.v1.global_variables_initializer())
    output = sess.run(output_tensor)
    self.assertTupleEqual((n, h, w, 1), output.shape)
def _custom_discriminator(input_net, class_num):
  """Discriminator variant whose shared trunk and class head get trainable=False.

  Only `discriminator_input_hidden` and `discriminator_output_class` receive
  `trainable=False`; the source head is built with the layer helper's default
  trainability (presumably trainable — confirm against `layers`).

  Args:
    input_net: Tensor batch of images.
    class_num: Number of domains passed to the class head.

  Returns:
    A `(output_src, output_cls)` tuple: the source logits flattened and
    averaged over axis 1, and the class-head output.
  """
  with tf.compat.v1.variable_scope('discriminator'):
    trunk = layers.discriminator_input_hidden(input_net, trainable=False)
    src_map = layers.discriminator_output_source(trunk)
    src_flat = tf.compat.v1.layers.flatten(src_map)
    src_logits = tf.reduce_mean(input_tensor=src_flat, axis=1)
    cls_logits = layers.discriminator_output_class(
        trunk, class_num, trainable=False)
  return src_logits, cls_logits