Пример #1
0
def discriminator(input_net, class_num):
    """Discriminator Module.

    Piece everything together and reshape the output source tensor: a shared
    hidden representation feeds a source head and a class head, and the
    source head's output is flattened and averaged to one value per image.

    PyTorch Version:
    https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L63

    Notes:
    The PyTorch version runs the reduce_mean operation later, in its solver:
    https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/solver.py#L245

    Args:
      input_net: Tensor of shape (batch_size, h, w, c) as batch of images.
      class_num: (int) number of domains to be predicted.

    Returns:
      output_src: Tensor of shape (batch_size) where each value is a logit
        representing whether the image is real or fake.
      output_cls: Tensor of shape (batch_size, class_num) where each value is
        a logit representing whether the image is in the associated domain.
    """

    # Use the tf.compat.v1 endpoints: the bare `tf.variable_scope` and
    # `tf.layers` names are not exposed by the TensorFlow 2.x namespace.
    with tf.compat.v1.variable_scope('discriminator'):

        # Shared feature extractor used by both output heads.
        hidden = layers.discriminator_input_hidden(input_net)

        # Source head: collapse everything but the batch axis, then average
        # so that every image is described by a single logit.
        output_src = layers.discriminator_output_source(hidden)
        output_src = tf.compat.v1.layers.flatten(output_src)
        output_src = tf.reduce_mean(input_tensor=output_src, axis=1)

        # Class head: one logit per domain.
        output_cls = layers.discriminator_output_class(hidden, class_num)

    return output_src, output_cls
Пример #2
0
    def __call__(self, input_net, class_num):
        """Compute the source logits and the class output for `input_net`.

        The source branch is built under the 'discriminator' variable scope
        from the project `layers` helpers; the class branch is delegated to
        `self.keras_model`.

        Args:
          input_net: Batch of input images (tensor).
          class_num: Not used by this implementation; kept in the signature.

        Returns:
          A tuple `(output_src, output_cls)`: the source output averaged over
          axis 1 after flattening, and whatever `self.keras_model` returns.
        """
        with tf.compat.v1.variable_scope('discriminator'):
            source_features = layers.discriminator_input_hidden(
                input_net, scope='discriminator_input_hidden_source')
            source_map = layers.discriminator_output_source(source_features)
            source_flat = tf.compat.v1.layers.flatten(source_map)
            # One averaged value per example.
            output_src = tf.reduce_mean(input_tensor=source_flat, axis=1)

        # Shift and rescale before classification; presumably this maps
        # inputs from [-1, 1] to [0, 1] -- TODO confirm against the model.
        rescaled_input = (input_net + 1.0) / 2.0
        output_cls = self.keras_model(rescaled_input)

        return output_src, output_cls
Пример #3
0
    def test_discriminator_output_source(self):
        """Check the source head maps an (n, h, w, c) input to (n, h, w, 1)."""

        n = 2
        h = 2
        w = 2
        c = 2048

        input_tensor = tf.random.uniform((n, h, w, c))
        output_tensor = layers.discriminator_output_source(input_tensor)

        with self.cached_session() as sess:
            # Use the compat.v1 endpoint: the bare
            # `tf.global_variables_initializer` name is not exposed by the
            # TensorFlow 2.x namespace.
            sess.run(tf.compat.v1.global_variables_initializer())
            output = sess.run(output_tensor)
            # Spatial dimensions are preserved; channels collapse to one.
            self.assertTupleEqual((n, h, w, 1), output.shape)
Пример #4
0
        def _custom_discriminator(input_net, class_num):
            """Discriminator variant with non-trainable hidden and class layers.

            The shared hidden layers and the class head are built with
            `trainable=False`; the source head is built without that argument.

            Args:
              input_net: Batch of input images (tensor).
              class_num: Number of domains passed to the class head.

            Returns:
              A tuple `(output_src, output_cls)`: the source output averaged
              over axis 1 after flattening, and the class head output.
            """
            with tf.compat.v1.variable_scope('discriminator'):
                shared_features = layers.discriminator_input_hidden(
                    input_net, trainable=False)

                source_map = layers.discriminator_output_source(
                    shared_features)
                source_flat = tf.compat.v1.layers.flatten(source_map)
                # One averaged value per example.
                output_src = tf.reduce_mean(input_tensor=source_flat, axis=1)

                output_cls = layers.discriminator_output_class(
                    shared_features, class_num, trainable=False)

            return output_src, output_cls