Example #1
0
    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.

        """
        # Generate the images created with the AshPy Context's generator
        generated_images = [
            context.generator_model(
                noise, training=context.log_eval_mode == LogEvalMode.TRAIN)
            for noise in context.noise_dataset  # FIXME: ?
        ]

        # Map pixel values from [-1, 1] to [0, 1]
        # (assumes a tanh-like generator output range — TODO confirm).
        rescaled_images = [(generated_image * 0.5) + 0.5
                           for generated_image in generated_images]

        # Resize images to 299x299, the input resolution Inception expects.
        resized_images = [
            tf.image.resize(rescaled_image, (299, 299))
            for rescaled_image in rescaled_images
        ]

        # Promote grayscale images to RGB; tf.image.grayscale_to_rgb raises
        # ValueError when the images already have 3 channels, in which case
        # they are kept untouched.
        try:
            resized_images[:] = [
                tf.image.grayscale_to_rgb(images) for images in resized_images
            ]
        except ValueError:
            # Images are already RGB
            pass

        # Instead of using multiple batches of 'batch_size' each (that causes OOM),
        # unravel the dataset and then create small batches, each with one image.
        dataset = tf.unstack(
            tf.reshape(tf.stack(resized_images), (-1, 1, 299, 299, 3)))

        # Calculate the inception score
        mean, _ = self.inception_score(dataset)

        # Update the Mean metric created for this context. Bind `mean` eagerly
        # as a default argument and use experimental_run_v2 for consistency
        # with the other metric implementations in this file.
        self._distribute_strategy.experimental_run_v2(
            lambda value=mean: self._metric.update_state(value))
Example #2
0
    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.

        """

        def updater(value):
            # Bind `value` eagerly so the zero-argument closure handed to the
            # distribution strategy does not late-bind the loop variable.
            return lambda: self._metric.update_state(value)

        for features, labels in context.dataset:
            # Per-batch loss; the training flag mirrors the context's
            # logging/evaluation mode.
            loss = context.loss(
                context,
                features=features,
                labels=labels,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )
            self._distribute_strategy.experimental_run_v2(updater(loss))
Example #3
0
    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.

        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context holding
                all the information the Metric needs.

        """

        def updater(y_true, y_pred):
            # Bind the batch values eagerly so the zero-argument closure handed
            # to the distribution strategy does not late-bind loop variables.
            return lambda: self._metric.update_state(y_true, y_pred)

        for features, labels in context.dataset:
            predictions = context.classifier_model(
                features, training=context.log_eval_mode == LogEvalMode.TRAIN)

            # Post-process the raw predictions with the configured function
            # before feeding them to the underlying metric.
            processed_predictions = self._processing_predictions["fn"](
                predictions, **self._processing_predictions["kwargs"])

            self._distribute_strategy.experimental_run_v2(
                updater(labels, processed_predictions))
Example #4
0
    def call(self, context: ClassifierContext, *, features: tf.Tensor,
             labels: tf.Tensor, training: bool, **kwargs) -> tf.Tensor:
        r"""
        Compute the classifier loss.

        Args:
            context (:py:class:`ashpy.ClassifierContext`): Context for classification.
            features (:py:class:`tf.Tensor`): Inputs for the classifier model.
            labels (:py:class:`tf.Tensor`): Target for the classifier model.
            training (bool): Whether is training or not.
            **kwargs:

        Returns:
            :py:class:`tf.Tensor`: Loss value.

        """
        predictions = context.classifier_model(features, training=training)
        raw_loss = self._fn(labels, predictions)
        # When the per-example loss is not already rank 4, append two trailing
        # singleton axes so the reduction below always averages over axes 1-2.
        expanded_loss = tf.cond(
            tf.equal(tf.rank(raw_loss), tf.constant(4)),
            lambda: raw_loss,
            lambda: tf.expand_dims(tf.expand_dims(raw_loss, axis=-1), axis=-1),
        )
        return tf.reduce_mean(expanded_loss, axis=[1, 2])