Ejemplo n.º 1
0
    def update_state(self, context: GANContext) -> None:
        """
        Accumulate the discriminator loss over the context dataset into the wrapped metric.

        For every ``(real_xy, noise)`` element of ``context.dataset`` the generator is run
        (conditioned on the labels when it has two inputs), ``context.discriminator_loss``
        is evaluated on the resulting (fake, real) pair, and the loss is handed to
        ``self._metric.update_state`` through ``self._distribute_strategy``.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): An AshPy Context Object that carries
                all the information the Metric needs.

        """
        for real_xy, noise in context.dataset:
            images, labels = real_xy
            is_training = context.log_eval_mode == LogEvalMode.TRAIN

            # A two-input generator also receives the labels as its condition.
            if len(context.generator_model.inputs) == 2:
                generator_inputs = [noise, labels]
            else:
                generator_inputs = noise

            generated = context.generator_model(generator_inputs, training=is_training)
            batch_loss = context.discriminator_loss(
                context,
                fake=generated,
                real=images,
                condition=labels,
                training=is_training,
            )

            def _update(value=batch_loss):
                # ``value`` is bound at definition time, not looked up late.
                return self._metric.update_state(value)

            self._distribute_strategy.experimental_run_v2(_update)
Ejemplo n.º 2
0
    def get_discriminator_inputs(
        context: GANContext,
        fake_or_real: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
    ) -> Union[tf.Tensor, List[tf.Tensor]]:
        """
        Return the discriminator inputs. If needed it uses the encoder.

        The current implementation uses the number of inputs to determine
        whether the discriminator is conditioned or not.

        With a :py:class:`GANEncoderContext` the inputs are
        ``[fake_or_real, encoder(fake_or_real)]`` for a 2-input discriminator and
        ``[fake_or_real, encoder(fake_or_real), condition]`` for a 3-input one.
        Without an encoder, a 2-input discriminator receives
        ``[fake_or_real, condition]``; any other arity receives ``fake_or_real`` alone.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): Context for GAN models.
            fake_or_real (:py:class:`tf.Tensor`): Discriminator input tensor,
                it can be fake (generated) or real.
            condition (:py:class:`tf.Tensor`): Discriminator condition
                (it can also be generator noise).
            training (:py:class:`bool`): whether is training phase or not

        Returns:
            The discriminator inputs.

        Raises:
            ValueError: if the context has an encoder but the discriminator does not
                take exactly 2 or 3 inputs.

        """
        num_inputs = len(context.discriminator_model.inputs)

        # Handle Encoder
        if isinstance(context, GANEncoderContext):
            # Validate before running the encoder so a misconfigured model fails
            # without doing any work.
            if num_inputs not in (2, 3):
                raise ValueError(
                    f"Context has encoder_model, but discriminator has {num_inputs} "
                    "inputs (expected 2 or 3)"
                )
            encoded = context.encoder_model(fake_or_real, training=training)
            d_inputs = [fake_or_real, encoded]
            if num_inputs == 3:
                d_inputs.append(condition)
            return d_inputs

        # No encoder: a 2-input discriminator is treated as conditioned.
        if num_inputs == 2:
            return [fake_or_real, condition]
        return fake_or_real
Ejemplo n.º 3
0
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        real: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ) -> tf.Tensor:
        """
        Compute a feature-matching loss between real and fake discriminator features.

        Both ``fake`` and ``real`` are turned into discriminator inputs via
        ``get_discriminator_inputs``. The discriminator is run with
        ``return_features=True`` on each, and ``self._fn`` is applied pairwise to
        the (real, fake) feature tensors. The per-feature results are summed.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (): Fake data.
            real (): Real data.
            condition (): Generator conditioning.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        # Build inputs for fake first, then real (same order as the model calls below).
        fake_inputs, real_inputs = (
            self.get_discriminator_inputs(
                context, fake_or_real=sample, condition=condition, training=training
            )
            for sample in (fake, real)
        )

        _, fake_features = context.discriminator_model(
            fake_inputs, training=training, return_features=True
        )
        _, real_features = context.discriminator_model(
            real_inputs, training=training, return_features=True
        )

        # One term per matching feature pair; the comment in the original intent is
        # that each call to ``self._fn`` yields a per-example L1 distance.
        per_feature = [
            self._fn(real_feat, fake_feat)
            for real_feat, fake_feat in zip(real_features, fake_features)
        ]
        return tf.add_n(per_feature)
Ejemplo n.º 4
0
    def update_state(self, context: GANContext) -> None:
        """
        Accumulate the multiscale SSIM of generated images into the wrapped metric.

        For every ``(real_xy, noise)`` element of ``context.dataset`` a fake batch is
        generated (conditioned on the labels when the generator has two inputs),
        split into two groups with ``self.split_batch``, and compared using
        ``tf.image.ssim_multiscale`` with the metric's configured parameters. The
        resulting scores are fed to ``self._metric`` via the distribution strategy.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        """

        def make_update(score):
            # Capture ``score`` eagerly in a zero-argument thunk for the strategy.
            return lambda: self._metric.update_state(score)

        for real_xy, noise in context.dataset:
            _, labels = real_xy

            conditioned = len(context.generator_model.inputs) == 2
            generator_inputs = [noise, labels] if conditioned else noise

            generated = context.generator_model(
                generator_inputs,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )

            first_group, second_group = self.split_batch(generated)

            score = tf.image.ssim_multiscale(
                first_group,
                second_group,
                max_val=self.max_val,
                power_factors=self.power_factors,
                filter_sigma=self.filter_sigma,
                filter_size=self.filter_size,
                k1=self.k1,
                k2=self.k2,
            )

            self._distribute_strategy.experimental_run_v2(make_update(score))
Ejemplo n.º 5
0
Archivo: gan.py Proyecto: softdzx/ashpy
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        real: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ):
        r"""
        Run the discriminator on fake and real inputs and reduce ``self._fn`` per example.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake images corresponding to the condition G(c).
            real (:py:class:`tf.Tensor`): Real images corresponding to the condition x(c).
            condition (:py:class:`tf.Tensor`): Condition for the generator and discriminator.
            training (bool): if training or evaluation

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context, fake_or_real=fake, condition=condition, training=training
        )
        real_inputs = self.get_discriminator_inputs(
            context, fake_or_real=real, condition=condition, training=training
        )

        d_fake = context.discriminator_model(fake_inputs, training=training)
        d_real = context.discriminator_model(real_inputs, training=training)

        if not isinstance(d_fake, list):
            loss_map = self._fn(d_real, d_fake)
            # A rank-4 discriminator output already leaves spatial axes 1 and 2 in the
            # loss; otherwise two trailing singleton axes are added so the same
            # reduction below applies.
            spatial = tf.cond(
                tf.equal(tf.rank(d_fake), tf.constant(4)),
                lambda: loss_map,
                lambda: tf.expand_dims(tf.expand_dims(loss_map, axis=-1), axis=-1),
            )
            return tf.reduce_mean(spatial, axis=[1, 2])

        # List output (e.g. a multiscale discriminator): average each output over
        # axes 1 and 2, then sum across outputs.
        return tf.add_n([
            tf.reduce_mean(self._fn(real_out, fake_out), axis=[1, 2])
            for real_out, fake_out in zip(d_real, d_fake)
        ])
Ejemplo n.º 6
0
    def update_state(self, context: GANContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        For each ``(real_xy, noise)`` element of ``context.dataset``:

        1. The real batch resolution is validated.
        2. A fake batch is generated.
        3. ``sliced_wasserstein_distance`` is computed between real and fake images.
        4. For each returned score pair, the first element is forwarded to
           ``self.children_real_fake[i][0]`` and the second to
           ``self.children_real_fake[i][1]``.
        5. The second elements are stacked and fed to ``self._metric``.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        Raises:
            ValueError: if ``real_x.shape[1]`` of a real batch differs from
                ``self.resolution``.

        """

        def updater(value):
            # Bind ``value`` now; the returned thunk is executed by the strategy.
            return lambda: self._metric.update_state(value)

        for real_xy, noise in context.dataset:
            real_x, real_y = real_xy

            # Check the resolution before the generator forward pass so that a
            # mismatched batch fails fast instead of after a wasted generation.
            resolution = real_x.shape[1]
            if resolution != self.resolution:
                raise ValueError(
                    "Image resolution is not the same as the input resolution."
                )

            g_inputs = noise
            if len(context.generator_model.inputs) == 2:
                g_inputs = [noise, real_y]

            fake = context.generator_model(
                g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN)

            scores = sliced_wasserstein_distance(
                real_x,
                fake,
                resolution_min=self.resolution_min,
                patches_per_image=self.patches_per_image,
                use_svd=self.use_svd,
                patch_size=self.patch_size,
                random_projection_dim=self.random_projection_dim,
                random_sampling_count=self.random_sampling_count,
            )

            fake_scores = []

            for i, couple in enumerate(scores):
                self.children_real_fake[i][0].update_state(context, couple[0])
                self.children_real_fake[i][1].update_state(context, couple[1])
                fake_scores.append(tf.expand_dims(couple[1], axis=0))

            fake_scores = tf.concat(fake_scores, axis=0)

            self._distribute_strategy.experimental_run_v2(updater(fake_scores))
Ejemplo n.º 7
0
    def call(
        self,
        context: GANContext,
        *,
        fake: tf.Tensor,
        condition: tf.Tensor,
        training: bool,
        **kwargs,
    ) -> tf.Tensor:
        r"""
        Score the fake samples with the discriminator against an all-ones target.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake images.
            condition (:py:class:`tf.Tensor`): Generator conditioning.
            training (bool): If training or evaluation.

        Returns:
            :py:class:`tf.Tensor`: The loss for each example.

        """
        fake_inputs = self.get_discriminator_inputs(
            context=context, fake_or_real=fake, condition=condition, training=training
        )
        d_fake = context.discriminator_model(fake_inputs, training=training)

        def against_ones(prediction):
            # Target is a tensor of ones with the same shape as the prediction.
            return self._fn(tf.ones_like(prediction), prediction)

        # Support for Multiscale discriminator (list output).
        # TODO: Improve
        if isinstance(d_fake, list):
            return tf.add_n([
                tf.reduce_mean(against_ones(scale_out), axis=[1, 2])
                for scale_out in d_fake
            ])

        loss_map = against_ones(d_fake)
        # Non rank-4 outputs get two trailing singleton axes so axes 1 and 2 exist.
        spatial = tf.cond(
            tf.equal(tf.rank(d_fake), tf.constant(4)),
            lambda: loss_map,
            lambda: tf.expand_dims(tf.expand_dims(loss_map, axis=-1), axis=-1),
        )
        return tf.reduce_mean(spatial, axis=[1, 2])
Ejemplo n.º 8
0
Archivo: gan.py Proyecto: softdzx/ashpy
    def update_state(self, context: GANContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        For every ``(real_xy, noise)`` element of ``context.dataset`` a fake batch is
        generated and prepared for the inception network:

        1. Linearly mapped with ``(x + 1) / 2``.
        2. Resized to 299x299.
        3. Converted from grayscale to RGB when ``tf.image.grayscale_to_rgb`` accepts it.

        ``self.inception_score`` is then applied, and the per-batch score is fed to
        ``self._metric``.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context
                holding all the information the Metric needs.

        """

        def updater(value):
            # Bind ``value`` now; the returned thunk is executed by the strategy.
            return lambda: self._metric.update_state(value)

        # Generate the images created with the AshPy Context's generator
        for real_xy, noise in context.dataset:
            _, real_y = real_xy

            g_inputs = noise
            if len(context.generator_model.inputs) == 2:
                g_inputs = [noise, real_y]

            fake = context.generator_model(
                g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN)

            # Rescale images to [0, 1].
            # NOTE(review): this assumes the generator output lies in [-1, 1]
            # (e.g. a tanh output layer); confirm for the models in use.
            fake = (fake + 1.0) / 2.0

            # Resize images to 299x299
            fake = tf.image.resize(fake, (299, 299))

            try:
                fake = tf.image.grayscale_to_rgb(fake)
            except ValueError:
                # Images are already RGB
                pass

            # Calculate the inception score
            inception_score_per_batch = self.inception_score(fake)

            # Update the Mean metric created for this context
            self._distribute_strategy.experimental_run_v2(
                updater(inception_score_per_batch))
Ejemplo n.º 9
0
Archivo: gan.py Proyecto: softdzx/ashpy
    def _log_fn(self, context: GANContext) -> None:
        """
        Log the generator output for the current context under the "generator" tag.

        In ``LogEvalMode.TEST`` the generator is run on ``context.generator_inputs`` with
        ``training=False``. In ``LogEvalMode.TRAIN`` ``context.fake_samples`` is logged
        instead. Any other mode raises.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): current context.

        Raises:
            ValueError: if ``context.log_eval_mode`` is neither TEST nor TRAIN.

        """
        mode = context.log_eval_mode
        if mode == LogEvalMode.TEST:
            images = context.generator_model(context.generator_inputs, training=False)
        elif mode == LogEvalMode.TRAIN:
            images = context.fake_samples
        else:
            raise ValueError("Invalid LogEvalMode")

        # TensorBoard 2.0 only renders float images in [0, 1], not [-1, 1].
        # The generator output is assumed (not checked here) to lie in [-1, 1].
        if images.dtype == tf.float32:
            images = (images + 1.0) / 2

        log("generator", images, context.global_step)