Example #1
0
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        real: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ) -> tf.Tensor:
        """
        Configure the discriminator inputs and call `loss_fn` on their features.

        The fake and the real data are each prepared with
        ``get_discriminator_inputs`` and passed through
        ``context.discriminator_model`` with ``return_features=True``.
        ``self._fn`` is applied to every (real, fake) pair of returned features
        and the per-feature results are summed with ``tf.add_n``.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake data.
            real (:py:class:`tf.Tensor`): Real data.
            condition (:py:class:`tf.Tensor`): Generator conditioning.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context, fake_or_real=fake, condition=condition, training=training
        )
        real_inputs = self.get_discriminator_inputs(
            context, fake_or_real=real, condition=condition, training=training
        )

        # Only the features are used; the first returned value (presumably the
        # discriminator logits) is discarded.
        _, features_fake = context.discriminator_model(
            fake_inputs, training=training, return_features=True
        )
        _, features_real = context.discriminator_model(
            real_inputs, training=training, return_features=True
        )

        # Compare real and fake features pairwise and sum the results.
        # tf.add_n needs a non-empty list of same-shaped tensors.
        # NOTE(review): each self._fn call is expected to return a
        # [batch_size, 1] tensor (mean L1 per example); self._fn is configured
        # outside this method — confirm.
        feature_loss = [
            self._fn(feat_real_i, feat_fake_i)
            for feat_real_i, feat_fake_i in zip(features_real, features_fake)
        ]
        return tf.add_n(feature_loss)
Example #2
0
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        real: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ):
        r"""
        Configure the discriminator inputs and call `loss_fn`.

        Fake and real inputs are prepared with ``get_discriminator_inputs`` and
        fed to ``context.discriminator_model``. ``self._fn`` is applied to the
        pair (real outputs, fake outputs) and the result is averaged over axes
        1 and 2. If the discriminator returns a list (multiscale discriminator),
        this is done per scale and the per-scale results are summed.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake images corresponding to the condition G(c).
            real (:py:class:`tf.Tensor`): Real images corresponding to the condition x(c).
            condition (:py:class:`tf.Tensor`): Condition for the generator and discriminator.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context, fake_or_real=fake, condition=condition, training=training
        )
        real_inputs = self.get_discriminator_inputs(
            context, fake_or_real=real, condition=condition, training=training
        )

        d_fake = context.discriminator_model(fake_inputs, training=training)
        d_real = context.discriminator_model(real_inputs, training=training)

        # Multiscale discriminator: one output per scale.
        if isinstance(d_fake, list):
            return tf.add_n(
                [
                    tf.reduce_mean(self._fn(d_real_i, d_fake_i), axis=[1, 2])
                    for d_real_i, d_fake_i in zip(d_real, d_fake)
                ]
            )

        value = self._fn(d_real, d_fake)

        # When the discriminator output is not rank 4, append two unit axes so
        # the reduction over axes [1, 2] below stays valid. The rank is
        # normally known statically, so branch in Python. A tf.cond whose
        # branches return different ranks loses the static output shape in
        # graph mode and is rejected by XLA.
        rank = d_fake.shape.rank
        if rank is None:
            # Static rank unknown: decide at run time.
            value = tf.cond(
                tf.equal(tf.rank(d_fake), tf.constant(4)),
                lambda: value,
                lambda: tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1),
            )
        elif rank != 4:
            value = tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1)
        return tf.reduce_mean(value, axis=[1, 2])
Example #3
0
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ) -> tf.Tensor:
        r"""
        Configure the discriminator inputs and call `loss_fn`.

        The fake input is prepared with ``get_discriminator_inputs`` and fed to
        ``context.discriminator_model``. ``self._fn`` compares the discriminator
        output against a target of ones (``tf.ones_like``) and the result is
        averaged over axes 1 and 2. If the discriminator returns a list
        (multiscale discriminator), this is done per scale and the per-scale
        results are summed.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake images.
            condition (:py:class:`tf.Tensor`): Generator conditioning.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context=context, fake_or_real=fake, condition=condition, training=training
        )

        d_fake = context.discriminator_model(fake_inputs, training=training)

        # Multiscale discriminator: one output per scale.
        # TODO: Improve — unlike the single-output path below, this branch
        # applies no rank handling before reducing over axes [1, 2].
        if isinstance(d_fake, list):
            return tf.add_n(
                [
                    tf.reduce_mean(
                        self._fn(tf.ones_like(d_fake_i), d_fake_i), axis=[1, 2]
                    )
                    for d_fake_i in d_fake
                ]
            )

        value = self._fn(tf.ones_like(d_fake), d_fake)

        # When the discriminator output is not rank 4, append two unit axes so
        # the reduction over axes [1, 2] below stays valid. The rank is
        # normally known statically, so branch in Python. A tf.cond whose
        # branches return different ranks loses the static output shape in
        # graph mode and is rejected by XLA.
        rank = d_fake.shape.rank
        if rank is None:
            # Static rank unknown: decide at run time.
            value = tf.cond(
                tf.equal(tf.rank(d_fake), tf.constant(4)),
                lambda: value,
                lambda: tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1),
            )
        elif rank != 4:
            value = tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1)
        return tf.reduce_mean(value, axis=[1, 2])