def update_state(self, context: GANEncoderContext) -> None:
    """
    Accumulate the encoder loss over every element of the context dataset.

    Each ``((real_x, real_y), noise)`` element of ``context.dataset`` is turned
    into a fake sample by the generator. The generator also receives ``real_y``
    when its model declares exactly two inputs. ``context.encoder_loss`` is then
    evaluated on that sample, and the value is fed to the wrapped metric through
    the distribution strategy.

    Args:
        context (:py:class:`ashpy.contexts.GANEncoderContext`): An AshPy Context
            Object that carries all the information the Metric needs.

    """
    for (real_x, real_y), noise in context.dataset:
        generator = context.generator_model

        # A generator declaring two inputs is fed the label as its condition.
        if len(generator.inputs) == 2:
            g_inputs = [noise, real_y]
        else:
            g_inputs = noise

        is_training = context.log_eval_mode == LogEvalMode.TRAIN
        fake = generator(g_inputs, training=is_training)
        loss = context.encoder_loss(
            context,
            fake=fake,
            real=real_x,
            condition=real_y,
            training=is_training,
        )

        # The strategy executes the closure during this call, so capturing
        # this iteration's ``loss`` is safe.
        self._distribute_strategy.experimental_run(
            lambda: self._metric.update_state(loss)
        )
def _log_fn(self, context: GANEncoderContext) -> None:
    """
    Log output of the generator to Tensorboard.

    Logs G(E(x)). In ``TEST`` mode the encoder and the generator are run (both
    with ``training=False``) on ``context.encoder_inputs``. In ``TRAIN`` mode the
    already computed ``context.generator_of_encoder`` is reused.

    Args:
        context (:py:class:`ashpy.contexts.GANEncoderContext`): current context.

    Raises:
        ValueError: if ``context.log_eval_mode`` is neither ``LogEvalMode.TEST``
            nor ``LogEvalMode.TRAIN``.

    """
    if context.log_eval_mode == LogEvalMode.TEST:
        generator_of_encoder = context.generator_model(
            context.encoder_model(context.encoder_inputs, training=False),
            training=False,
        )
    elif context.log_eval_mode == LogEvalMode.TRAIN:
        generator_of_encoder = context.generator_of_encoder
    else:
        raise ValueError("Invalid LogEvalMode")

    # Tensorboard 2.0 does not support float images in [-1, 1], only in [0, 1]
    # (tf.summary.image clips floating-point data to that range). This holds
    # for every float width, not just float32. Checking only tf.float32 left
    # float16 (mixed precision) and float64 outputs unscaled and clipped.
    # tf.as_dtype also accepts NumPy dtypes, so this works for either kind of
    # array.
    if tf.as_dtype(generator_of_encoder.dtype).is_floating:
        # NOTE(review): assumes the generator emits images in [-1, 1]
        # (e.g. a tanh output layer). This is not verified here.
        generator_of_encoder = (generator_of_encoder + 1.0) / 2

    log("generator_of_encoder", generator_of_encoder, context.global_step)