Example #1
0
File: gan.py Project: softdzx/ashpy
    def _log_fn(self, context: GANEncoderContext):
        """
        Send the reconstruction G(E(x)) to Tensorboard.

        In TEST mode the reconstruction is recomputed from
        ``context.encoder_inputs`` with both the encoder and the generator
        called with ``training=False``. In TRAIN mode the value already
        available as ``context.generator_of_encoder`` is logged instead.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANEncoderContext`): current context.

        Raises:
            ValueError: if ``context.log_eval_mode`` is neither
                ``LogEvalMode.TEST`` nor ``LogEvalMode.TRAIN``.

        """
        mode = context.log_eval_mode
        if mode == LogEvalMode.TEST:
            # Re-run the E -> G pipeline with both models in inference mode.
            latent = context.encoder_model(context.encoder_inputs, training=False)
            images = context.generator_model(latent, training=False)
        elif mode == LogEvalMode.TRAIN:
            images = context.generator_of_encoder
        else:
            raise ValueError("Invalid LogEvalMode")

        # Tensorboard 2.0 only displays float images that lie in [0, 1].
        # float32 outputs are assumed (not verified) to be in [-1, 1], so they
        # are shifted and scaled into the supported range.
        if images.dtype == tf.float32:
            shifted = images + 1.0
            images = shifted / 2

        log("generator_of_encoder", images, context.global_step)
Example #2
0
    def call(self, context: GANEncoderContext, *, real: tf.Tensor,
             training: bool, **kwargs):
        """
        Compute the Encoder BCE.

        The real images are encoded, the pair ``[real, E(real)]`` is scored by
        the discriminator, and ``self._fn`` is evaluated with an all-zeros
        target of the same shape as the discriminator output.

        Args:
            context (:py:class:`ashpy.contexts.GANEncoderContext`): GAN Context
                with Encoder support.
            real (:py:class:`tf.Tensor`): Real images (keyword-only).
            training (bool): If training or evaluation; forwarded to both the
                encoder and the discriminator (keyword-only).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        latent = context.encoder_model(real, training=training)
        disc_on_real = context.discriminator_model(
            [real, latent], training=training
        )
        # The encoder is scored against a zero label for the (real, E(real)) pair.
        zero_target = tf.zeros_like(disc_on_real)
        return self._fn(zero_target, disc_on_real)