Ejemplo n.º 1
0
    def update_state(self, context: GANContext) -> None:
        """
        Accumulate the discriminator loss computed on every batch of the context dataset.

        Each element of ``context.dataset`` is unpacked as ``((real_x, real_y), noise)``.
        The generator is fed ``noise`` (or ``[noise, real_y]`` when the generator model
        has two inputs), ``context.discriminator_loss`` is evaluated on the resulting
        fake batch against ``real_x``, and the loss value is forwarded to
        ``self._metric`` through the distribution strategy.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): An AshPy Context Object that carries
                all the information the Metric needs.

        """

        def _metric_updater(value):
            # Bind ``value`` now so the strategy receives a zero-argument callable.
            return lambda: self._metric.update_state(value)

        for (real_x, real_y), noise in context.dataset:
            if len(context.generator_model.inputs) == 2:
                generator_inputs = [noise, real_y]
            else:
                generator_inputs = noise

            is_training = context.log_eval_mode == LogEvalMode.TRAIN

            fake_batch = context.generator_model(generator_inputs, training=is_training)
            batch_loss = context.discriminator_loss(
                context,
                fake=fake_batch,
                real=real_x,
                condition=real_y,
                training=is_training,
            )

            self._distribute_strategy.experimental_run_v2(_metric_updater(batch_loss))
Ejemplo n.º 2
0
    def update_state(self, context: GANContext) -> None:
        """
        Accumulate the multi-scale SSIM computed on generated batches.

        For every ``((_, real_y), noise)`` element of ``context.dataset`` a fake batch
        is generated (the generator receives ``[noise, real_y]`` when it has two
        inputs, otherwise just ``noise``). ``self.split_batch`` splits it into two
        image tensors, and ``tf.image.ssim_multiscale`` is computed between them
        using the SSIM parameters stored on the metric. The result updates
        ``self._metric`` through the distribution strategy.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        """

        def make_update(value):
            # Capture ``value`` eagerly; the strategy calls the returned closure with no args.
            return lambda: self._metric.update_state(value)

        for (_, real_y), noise in context.dataset:
            two_inputs = len(context.generator_model.inputs) == 2
            generator_inputs = [noise, real_y] if two_inputs else noise

            generated = context.generator_model(
                generator_inputs,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )

            first_part, second_part = self.split_batch(generated)

            score = tf.image.ssim_multiscale(
                first_part,
                second_part,
                max_val=self.max_val,
                power_factors=self.power_factors,
                filter_sigma=self.filter_sigma,
                filter_size=self.filter_size,
                k1=self.k1,
                k2=self.k2,
            )

            self._distribute_strategy.experimental_run_v2(make_update(score))
Ejemplo n.º 3
0
    def update_state(self, context: GANContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        For every ``((real_x, real_y), noise)`` element of ``context.dataset``:

        1. the real batch resolution (``real_x.shape[1]``) is checked against
           ``self.resolution``;
        2. a fake batch is generated (the generator receives ``[noise, real_y]``
           when it has two inputs, otherwise just ``noise``);
        3. ``sliced_wasserstein_distance`` is computed between ``real_x`` and the
           fake batch;
        4. each returned score pair updates the matching pair of child metrics in
           ``self.children_real_fake``, and the fake scores, concatenated along
           axis 0, update ``self._metric`` through the distribution strategy.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
                that carries all the information the Metric needs.

        Raises:
            ValueError: if ``real_x.shape[1]`` differs from ``self.resolution``.

        """
        # Loop-invariant: the evaluation mode does not change while iterating.
        training = context.log_eval_mode == LogEvalMode.TRAIN

        def make_update(value):
            # Bind ``value`` now so the strategy receives a zero-argument callable.
            return lambda: self._metric.update_state(value)

        for (real_x, real_y), noise in context.dataset:
            # Validate before running the generator: a mismatch then fails without
            # a wasted forward pass (which, in training mode, could also update
            # stateful layers before the error is raised).
            resolution = real_x.shape[1]
            if resolution != self.resolution:
                raise ValueError(
                    "Image resolution is not the same as the input resolution."
                )

            g_inputs = noise
            if len(context.generator_model.inputs) == 2:
                g_inputs = [noise, real_y]

            fake = context.generator_model(g_inputs, training=training)

            scores = sliced_wasserstein_distance(
                real_x,
                fake,
                resolution_min=self.resolution_min,
                patches_per_image=self.patches_per_image,
                use_svd=self.use_svd,
                patch_size=self.patch_size,
                random_projection_dim=self.random_projection_dim,
                random_sampling_count=self.random_sampling_count,
            )

            fake_scores = []
            for i, couple in enumerate(scores):
                # couple[0] is fed to the "real" child and couple[1] to the "fake" child.
                real_child, fake_child = self.children_real_fake[i][0], self.children_real_fake[i][1]
                real_child.update_state(context, couple[0])
                fake_child.update_state(context, couple[1])
                fake_scores.append(tf.expand_dims(couple[1], axis=0))

            fake_scores = tf.concat(fake_scores, axis=0)

            self._distribute_strategy.experimental_run_v2(make_update(fake_scores))
Ejemplo n.º 4
0
Archivo: gan.py Proyecto: softdzx/ashpy
    def update_state(self, context: GANContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        For every ``((_, real_y), noise)`` element of ``context.dataset`` a fake batch
        is generated (the generator receives ``[noise, real_y]`` when it has two
        inputs, otherwise just ``noise``). The batch is then:

        1. rescaled with ``(x + 1) / 2``;
        2. resized to 299x299;
        3. converted to RGB when it has a single channel.

        It is then scored with ``self.inception_score``, and the per-batch score
        updates ``self._metric`` through the distribution strategy.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context
                holding all the information the Metric needs.

        """
        # Loop-invariant: the evaluation mode does not change while iterating.
        training = context.log_eval_mode == LogEvalMode.TRAIN

        def make_update(value):
            # Bind ``value`` now so the strategy receives a zero-argument callable.
            return lambda: self._metric.update_state(value)

        # Generate the images created with the AshPy Context's generator
        for (_, real_y), noise in context.dataset:
            g_inputs = noise
            if len(context.generator_model.inputs) == 2:
                g_inputs = [noise, real_y]

            fake = context.generator_model(g_inputs, training=training)

            # Rescale images to [0, 1].
            # NOTE(review): assumes the generator output lies in [-1, 1]; not verified here.
            fake = (fake + 1.0) / 2.0

            # Resize images to 299x299
            fake = tf.image.resize(fake, (299, 299))

            # Replicate single-channel (grayscale) images to 3 channels; images that
            # already have a different channel count are passed through unchanged.
            if fake.shape[-1] == 1:
                fake = tf.image.grayscale_to_rgb(fake)

            # Calculate the inception score
            inception_score_per_batch = self.inception_score(fake)

            self._distribute_strategy.experimental_run_v2(
                make_update(inception_score_per_batch))
Ejemplo n.º 5
0
Archivo: gan.py Proyecto: softdzx/ashpy
    def _log_fn(self, context: GANContext) -> None:
        """
        Log output of the generator to Tensorboard.

        The images to log depend on the evaluation mode:

        - ``LogEvalMode.TEST``: the generator is run with ``training=False`` on
          ``context.generator_inputs``.
        - ``LogEvalMode.TRAIN``: the already-available ``context.fake_samples`` are
          used.

        Floating-point outputs are mapped with ``(x + 1) / 2`` before being passed
        to ``log`` under the ``"generator"`` tag at ``context.global_step``.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): current context.

        Raises:
            ValueError: if ``context.log_eval_mode`` is neither TEST nor TRAIN.

        """
        if context.log_eval_mode == LogEvalMode.TEST:
            out = context.generator_model(context.generator_inputs,
                                          training=False)
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            out = context.fake_samples
        else:
            raise ValueError("Invalid LogEvalMode")

        # tensorboard 2.0 does not support float images in [-1, 1], only in [0, 1].
        # Rescale every floating dtype (e.g. float16 from mixed-precision
        # generators), not only float32.
        # NOTE(review): assumes float outputs lie in [-1, 1]; this is not checked.
        if tf.as_dtype(out.dtype).is_floating:
            out = (out + 1.0) / 2

        log("generator", out, context.global_step)