Example #1
    def update_state(self, context: GANEncoderContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Args:
            context (:py:class:`ashpy.contexts.GANEncoderContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        """
        for real_xy, noise in context.dataset:
            real_x, real_y = real_xy

            # A conditional generator takes [noise, condition]; an unconditional one takes only the noise.
            g_inputs = noise
            if len(context.generator_model.inputs) == 2:
                g_inputs = [noise, real_y]

            fake = context.generator_model(
                g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN)

            loss = context.encoder_loss(
                context,
                fake=fake,
                real=real_x,
                condition=real_y,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )

            # Run the metric update through the distribution strategy so the
            # per-batch loss is accumulated correctly in distributed settings.
            self._distribute_strategy.experimental_run(
                lambda: self._metric.update_state(loss))
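
For reference, self._metric wraps a Keras metric (for loss-based AshPy metrics, typically a tf.keras.metrics.Mean), so each batch contributes one loss value to a running average. Below is a minimal, self-contained sketch of that accumulation pattern without a distribution strategy; the loss values are illustrative stand-ins for what context.encoder_loss would return.

import tensorflow as tf

# A running mean, mirroring what update_state() feeds for every batch in context.dataset.
mean_metric = tf.keras.metrics.Mean(name="encoder_loss")

# Illustrative per-batch loss values; in AshPy these come from context.encoder_loss(...).
for encoder_loss_value in [0.72, 0.65, 0.61]:
    mean_metric.update_state(encoder_loss_value)

print(mean_metric.result().numpy())  # average of the values seen so far
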
Example #2
File: gan.py Project: softdzx/ashpy
    def _log_fn(self, context: GANEncoderContext):
        """
        Log the output of the generator to TensorBoard.

        Logs G(E(x)).

        Args:
            context (:py:class:`ashpy.contexts.GANEncoderContext`): The current context.

        """
        if context.log_eval_mode == LogEvalMode.TEST:
            generator_of_encoder = context.generator_model(
                context.encoder_model(context.encoder_inputs, training=False),
                training=False,
            )
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            generator_of_encoder = context.generator_of_encoder
        else:
            raise ValueError("Invalid LogEvalMode")

        # TensorBoard 2.0 does not support float images in [-1, 1],
        # only in [0, 1].
        if generator_of_encoder.dtype == tf.float32:
            # The assumption is that the images are in [-1, 1]; there is no robust way to verify this here.
            generator_of_encoder = (generator_of_encoder + 1.0) / 2

        log("generator_of_encoder", generator_of_encoder, context.global_step)
Example #3
    def call(self, context: GANEncoderContext, *, real: tf.Tensor,
             training: bool, **kwargs):
        """
        Compute the Encoder BCE.

        Args:
            context (:py:class:`ashpy.contexts.GANEncoderContext`): GAN Context
                with Encoder support.
            real (:py:class:`tf.Tensor`): Real images.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        # E(x): encode the real images.
        encode = context.encoder_model(real, training=training)
        # D([x, E(x)]): run the discriminator on the (image, encoding) pair.
        d_real = context.discriminator_model([real, encode], training=training)
        # Binary cross-entropy against an all-zeros target, one value per example.
        return self._fn(tf.zeros_like(d_real), d_real)
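
Here self._fn is a binary cross-entropy comparing the discriminator output on the [real, E(real)] pair against an all-zeros target. A minimal sketch of that computation with tf.keras.losses.BinaryCrossentropy follows; the logits batch is illustrative, from_logits=True is an assumption about the discriminator output, and reduction is disabled to keep one loss value per example, as the docstring describes.

import tensorflow as tf

# Illustrative discriminator outputs for a batch of [real, E(real)] pairs.
d_real = tf.constant([[2.3], [-0.7], [1.1]])

# Per-example BCE against an all-zeros target, matching
# self._fn(tf.zeros_like(d_real), d_real) in the example above.
bce = tf.keras.losses.BinaryCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
per_example_loss = bce(tf.zeros_like(d_real), d_real)
print(per_example_loss.numpy())  # one loss value per example
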