def update_state(self, context: ClassifierContext) -> None:
    """
    Update the internal state of the metric, using the information from the context object.

    The generator held by the context is run on every element of
    ``context.noise_dataset``; its outputs are rescaled, resized to 299x299,
    converted to RGB when they are grayscale, and fed one image at a time to
    ``self.inception_score``. The resulting mean is pushed into ``self._metric``.

    Args:
        context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
            holding all the information the Metric needs.

    """
    # The generator runs in training mode only when the context logs in TRAIN mode.
    training = context.log_eval_mode == LogEvalMode.TRAIN

    # Generate the images created with the AshPy Context's generator
    generated_images = [
        context.generator_model(noise, training=training)
        for noise in context.noise_dataset  # FIXME: ?
    ]

    # Nothing was generated: there is nothing to score and tf.concat below
    # requires at least one tensor.
    if not generated_images:
        return

    # x * 0.5 + 0.5: presumably maps a [-1, 1] generator output to [0, 1]
    # NOTE(review): confirm the generator output range.
    rescaled_images = [
        (generated_image * 0.5) + 0.5 for generated_image in generated_images
    ]

    # Resize images to 299x299
    resized_images = [
        tf.image.resize(rescaled_image, (299, 299))
        for rescaled_image in rescaled_images
    ]

    try:
        resized_images[:] = [
            tf.image.grayscale_to_rgb(images) for images in resized_images
        ]
    except ValueError:
        # Images are already RGB
        pass

    # Instead of using multiple batches of 'batch_size' each (that causes OOM),
    # unravel the dataset and create batches holding a single image each.
    # Concatenating along the first axis (instead of stacking) yields the same
    # tensor after the reshape, and also works when the last batch is smaller
    # than the others.
    dataset = tf.unstack(
        tf.reshape(tf.concat(resized_images, axis=0), (-1, 1, 299, 299, 3))
    )

    # Calculate the inception score
    mean, _ = self.inception_score(dataset)

    # Update the Mean metric created for this context
    self._distribute_strategy.experimental_run_v2(
        lambda: self._metric.update_state(mean)
    )
def update_state(self, context: ClassifierContext) -> None:
    """
    Update the internal state of the metric, using the information from the context object.

    For every ``(features, labels)`` pair of ``context.dataset`` the loss held
    by the context is evaluated and the resulting value is passed to
    ``self._metric.update_state`` through the distribution strategy.

    Args:
        context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
            holding all the information the Metric needs.

    """

    def make_update(loss_value):
        # Wrap the given loss value in a zero-argument callable, so that each
        # call handed to the strategy carries its own value.
        def run_update():
            return self._metric.update_state(loss_value)

        return run_update

    for features, labels in context.dataset:
        batch_loss = context.loss(
            context,
            features=features,
            labels=labels,
            training=context.log_eval_mode == LogEvalMode.TRAIN,
        )
        self._distribute_strategy.experimental_run_v2(make_update(batch_loss))
def update_state(self, context: ClassifierContext) -> None:
    """
    Update the internal state of the metric, using the information from the context object.

    For every ``(features, labels)`` pair of ``context.dataset`` the classifier
    held by the context is run on the features; its predictions are transformed
    by ``self._processing_predictions["fn"]`` (called with
    ``self._processing_predictions["kwargs"]``) and, together with the labels,
    passed to ``self._metric.update_state`` through the distribution strategy.

    Args:
        context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
            holding all the information the Metric needs.

    """
    # The classifier runs in training mode only when the context logs in TRAIN mode.
    training = context.log_eval_mode == LogEvalMode.TRAIN

    for features, labels in context.dataset:
        predictions = context.classifier_model(features, training=training)

        # Bind the current batch through default arguments: a closure defined
        # in a loop would otherwise look the loop variables up at call time.
        def _update(y_true=labels, y_pred=predictions):
            return self._metric.update_state(
                y_true,
                self._processing_predictions["fn"](
                    y_pred, **self._processing_predictions["kwargs"]
                ),
            )

        self._distribute_strategy.experimental_run_v2(_update)
def call(
    self,
    context: ClassifierContext,
    *,
    features: tf.Tensor,
    labels: tf.Tensor,
    training: bool,
    **kwargs
) -> tf.Tensor:
    r"""
    Compute the classifier loss.

    The classifier held by the context is run on ``features`` and ``self._fn``
    is evaluated on ``labels`` and the predictions. When the result is not a
    rank-4 tensor, two trailing axes of size one are appended; the value is
    then averaged over axes 1 and 2.

    Args:
        context (:py:class:`ashpy.ClassifierContext`): Context for classification.
        features (:py:class:`tf.Tensor`): Inputs for the classifier model.
        labels (:py:class:`tf.Tensor`): Target for the classifier model.
        training (bool): Whether the classifier is called in training mode.
        **kwargs: Additional keyword arguments; not used by this method.

    Returns:
        :py:class:`tf.Tensor`: Loss value, averaged over axes 1 and 2.

    """
    model_output = context.classifier_model(features, training=training)
    raw_loss = self._fn(labels, model_output)

    def keep_as_is():
        # Rank-4 loss: nothing to add.
        return raw_loss

    def append_two_axes():
        # Add two trailing singleton axes so that axes 1 and 2 exist
        # for the reduction below.
        with_one_axis = tf.expand_dims(raw_loss, axis=-1)
        return tf.expand_dims(with_one_axis, axis=-1)

    is_rank_four = tf.equal(tf.rank(raw_loss), tf.constant(4))
    shaped_loss = tf.cond(is_rank_four, keep_as_is, append_two_axes)
    return tf.reduce_mean(shaped_loss, axis=[1, 2])