def update_state(self, context: GANContext) -> None:
    """
    Update the internal state of the metric, using the information from the context object.

    Args:
        context (:py:class:`ashpy.contexts.GANContext`): An AshPy Context Object
            that carries all the information the Metric needs.

    """
    for real_xy, noise in context.dataset:
        real_x, real_y = real_xy

        g_inputs = noise
        if len(context.generator_model.inputs) == 2:
            g_inputs = [noise, real_y]

        fake = context.generator_model(
            g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN
        )

        loss = context.discriminator_loss(
            context,
            fake=fake,
            real=real_x,
            condition=real_y,
            training=context.log_eval_mode == LogEvalMode.TRAIN,
        )

        self._distribute_strategy.experimental_run_v2(
            lambda: self._metric.update_state(loss)
        )
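# A minimal sketch of the accumulation this loop drives, assuming the
# underlying self._metric is a tf.keras.metrics.Mean: each per-batch
# discriminator loss is folded into a running mean over the whole dataset.
import tensorflow as tf

metric = tf.keras.metrics.Mean(name="d_loss")  # stand-in for self._metric
for batch_loss in (0.7, 0.5, 0.6):  # stand-ins for the per-batch losses
    metric.update_state(batch_loss)
print(metric.result().numpy())  # running mean: 0.6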
def get_discriminator_inputs(
    context: GANContext,
    fake_or_real: tf.Tensor,
    condition: tf.Tensor,
    training: bool,
) -> Union[tf.Tensor, List[tf.Tensor]]:
    """
    Return the discriminator inputs. If needed, it uses the encoder.

    The current implementation uses the number of inputs to determine
    whether the discriminator is conditioned or not.

    Args:
        context (:py:class:`ashpy.contexts.gan.GANContext`): Context for GAN models.
        fake_or_real (:py:class:`tf.Tensor`): Discriminator input tensor,
            it can be fake (generated) or real.
        condition (:py:class:`tf.Tensor`): Discriminator condition
            (it can also be generator noise).
        training (bool): Whether we are in the training phase or not.

    Returns:
        The discriminator inputs.

    """
    num_inputs = len(context.discriminator_model.inputs)

    # Handle the encoder case: the encoding of fake_or_real is fed to the
    # discriminator alongside the tensor itself.
    if isinstance(context, GANEncoderContext):
        if num_inputs == 2:
            d_inputs = [
                fake_or_real,
                context.encoder_model(fake_or_real, training=training),
            ]
        elif num_inputs == 3:
            d_inputs = [
                fake_or_real,
                context.encoder_model(fake_or_real, training=training),
                condition,
            ]
        else:
            raise ValueError(
                f"Context has encoder_model, but the discriminator takes only {num_inputs} inputs"
            )
    else:
        if num_inputs == 2:
            d_inputs = [fake_or_real, condition]
        else:
            d_inputs = fake_or_real
    return d_inputs
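# Hedged illustration of the input-count dispatch above: a two-input Keras
# discriminator (model and shapes are made up for this sketch) would receive
# [fake_or_real, condition], while a single-input one gets the tensor alone.
import tensorflow as tf

image = tf.keras.Input(shape=(28, 28, 1))
label = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Concatenate()([tf.keras.layers.Flatten()(image), label])
conditioned_d = tf.keras.Model([image, label], tf.keras.layers.Dense(1)(x))

print(len(conditioned_d.inputs))  # 2 -> d_inputs = [fake_or_real, condition]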
def call(
    self,
    context: GANContext,
    *,
    fake: tf.Tensor,
    real: tf.Tensor,
    condition: tf.Tensor,
    training: bool,
    **kwargs,
) -> tf.Tensor:
    """
    Configure the discriminator inputs and call `loss_fn`.

    Args:
        context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
        fake (:py:class:`tf.Tensor`): Fake data.
        real (:py:class:`tf.Tensor`): Real data.
        condition (:py:class:`tf.Tensor`): Generator conditioning.
        training (bool): If training or evaluation.

    Returns:
        :py:class:`tf.Tensor`: The loss for each example.

    """
    fake_inputs = self.get_discriminator_inputs(
        context, fake_or_real=fake, condition=condition, training=training
    )
    real_inputs = self.get_discriminator_inputs(
        context, fake_or_real=real, condition=condition, training=training
    )

    _, features_fake = context.discriminator_model(
        fake_inputs, training=training, return_features=True
    )
    _, features_real = context.discriminator_model(
        real_inputs, training=training, return_features=True
    )

    # For each feature, compute the L1 distance between real and fake.
    # Every call to self._fn should return [batch_size, 1], the mean L1.
    feature_loss = [
        self._fn(feat_real_i, feat_fake_i)
        for feat_real_i, feat_fake_i in zip(features_real, features_fake)
    ]
    mae = tf.add_n(feature_loss)
    return mae
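# Worked sketch of the feature-matching sum: assuming self._fn reduces to a
# mean L1 per example, tf.add_n adds the per-level losses element-wise, so
# the result keeps one value per example. Shapes here are arbitrary.
import tensorflow as tf

def mean_l1(real, fake):
    # Mean absolute error over every axis but the batch one.
    return tf.reduce_mean(tf.abs(real - fake), axis=list(range(1, real.shape.rank)))

features_real = [tf.ones((4, 8, 8, 16)), tf.ones((4, 4, 4, 32))]
features_fake = [tf.zeros((4, 8, 8, 16)), tf.zeros((4, 4, 4, 32))]
per_level = [mean_l1(r, f) for r, f in zip(features_real, features_fake)]
print(tf.add_n(per_level).numpy())  # [2. 2. 2. 2.] -> one loss per example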
def update_state(self, context: GANContext) -> None:
    """
    Update the internal state of the metric, using the information from the context object.

    Args:
        context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
            that carries all the information the Metric needs.

    """
    updater = lambda value: lambda: self._metric.update_state(value)
    for real_xy, noise in context.dataset:
        _, real_y = real_xy

        g_inputs = noise
        if len(context.generator_model.inputs) == 2:
            g_inputs = [noise, real_y]

        fake = context.generator_model(
            g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN
        )

        # Split the generated batch and score the two halves against
        # each other with multiscale SSIM.
        img1, img2 = self.split_batch(fake)

        ssim_multiscale = tf.image.ssim_multiscale(
            img1,
            img2,
            max_val=self.max_val,
            power_factors=self.power_factors,
            filter_sigma=self.filter_sigma,
            filter_size=self.filter_size,
            k1=self.k1,
            k2=self.k2,
        )

        self._distribute_strategy.experimental_run_v2(updater(ssim_multiscale))
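# Minimal sketch of the tf.image.ssim_multiscale call above with default
# parameters; split_batch presumably halves the generated batch so the two
# halves can be scored against each other. Batch and image sizes are made up.
import tensorflow as tf

img1 = tf.random.uniform((2, 256, 256, 3), maxval=1.0)
img2 = tf.random.uniform((2, 256, 256, 3), maxval=1.0)
scores = tf.image.ssim_multiscale(img1, img2, max_val=1.0)
print(scores.shape)  # (2,): one MS-SSIM score per image pair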
def call(
    self,
    context: GANContext,
    *,
    fake: tf.Tensor,
    real: tf.Tensor,
    condition: tf.Tensor,
    training: bool,
    **kwargs,
):
    r"""
    Set up the discriminator inputs and call `loss_fn`.

    Args:
        context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
        fake (:py:class:`tf.Tensor`): Fake images corresponding to the condition G(c).
        real (:py:class:`tf.Tensor`): Real images corresponding to the condition x(c).
        condition (:py:class:`tf.Tensor`): Condition for the generator and discriminator.
        training (bool): If training or evaluation.

    Returns:
        :py:class:`tf.Tensor`: The loss for each example.

    """
    fake_inputs = self.get_discriminator_inputs(
        context, fake_or_real=fake, condition=condition, training=training
    )
    real_inputs = self.get_discriminator_inputs(
        context, fake_or_real=real, condition=condition, training=training
    )

    d_fake = context.discriminator_model(fake_inputs, training=training)
    d_real = context.discriminator_model(real_inputs, training=training)

    # Multiscale discriminator: sum the per-scale losses, each reduced
    # over its own spatial axes.
    if isinstance(d_fake, list):
        value = tf.add_n(
            [
                tf.reduce_mean(self._fn(d_real_i, d_fake_i), axis=[1, 2])
                for d_real_i, d_fake_i in zip(d_real, d_fake)
            ]
        )
        return value

    value = self._fn(d_real, d_fake)
    # If the discriminator output has no spatial axes, pad the loss with two
    # unit axes so that the reduction over [1, 2] below is a no-op.
    value = tf.cond(
        tf.equal(tf.rank(d_fake), tf.constant(4)),
        lambda: value,
        lambda: tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1),
    )
    return tf.reduce_mean(value, axis=[1, 2])
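# Sketch of the rank fix-up at the end of call: the final reduction averages
# over the spatial axes [1, 2], so a non-spatial per-example loss is padded
# with two unit axes first, which makes that reduction leave it unchanged.
import tensorflow as tf

value = tf.constant([0.3, 0.7])  # rank 1: one loss value per example
padded = tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1)  # (2, 1, 1)
print(tf.reduce_mean(padded, axis=[1, 2]).numpy())  # [0.3 0.7], unchanged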
def update_state(self, context: GANContext) -> None:
    """
    Update the internal state of the metric, using the information from the context object.

    Args:
        context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
            that carries all the information the Metric needs.

    """
    updater = lambda value: lambda: self._metric.update_state(value)
    for real_xy, noise in context.dataset:
        real_x, real_y = real_xy

        g_inputs = noise
        if len(context.generator_model.inputs) == 2:
            g_inputs = [noise, real_y]

        fake = context.generator_model(
            g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN
        )

        # Check that the dataset resolution matches the one this metric
        # was configured with.
        resolution = real_x.shape[1]
        if resolution != self.resolution:
            raise ValueError(
                f"Image resolution ({resolution}) does not match "
                f"the resolution passed to the metric ({self.resolution})."
            )

        scores = sliced_wasserstein_distance(
            real_x,
            fake,
            resolution_min=self.resolution_min,
            patches_per_image=self.patches_per_image,
            use_svd=self.use_svd,
            patch_size=self.patch_size,
            random_projection_dim=self.random_projection_dim,
            random_sampling_count=self.random_sampling_count,
        )

        # Update the per-resolution children metrics (real and fake scores)
        # and accumulate the fake scores for the main metric.
        fake_scores = []
        for i, couple in enumerate(scores):
            self.children_real_fake[i][0].update_state(context, couple[0])
            self.children_real_fake[i][1].update_state(context, couple[1])
            fake_scores.append(tf.expand_dims(couple[1], axis=0))

        fake_scores = tf.concat(fake_scores, axis=0)

        self._distribute_strategy.experimental_run_v2(updater(fake_scores))
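# Hedged sketch of the idea behind sliced_wasserstein_distance (not AshPy's
# implementation): project both sample sets onto a random unit direction;
# with equal sample counts, the 1-D Wasserstein-1 distance is the mean
# absolute difference of the sorted projections.
import tensorflow as tf

real = tf.random.normal((128, 64))        # stand-ins for flattened patches
fake = tf.random.normal((128, 64)) + 0.5  # a shifted "fake" distribution

direction = tf.nn.l2_normalize(tf.random.normal((64, 1)), axis=0)
proj_real = tf.sort(tf.squeeze(real @ direction, axis=-1))
proj_fake = tf.sort(tf.squeeze(fake @ direction, axis=-1))
print(tf.reduce_mean(tf.abs(proj_real - proj_fake)).numpy())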
def call(
    self,
    context: GANContext,
    *,
    fake: tf.Tensor,
    condition: tf.Tensor,
    training: bool,
    **kwargs,
) -> tf.Tensor:
    r"""
    Configure the discriminator inputs and call `loss_fn`.

    Args:
        context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
        fake (:py:class:`tf.Tensor`): Fake images.
        condition (:py:class:`tf.Tensor`): Generator conditioning.
        training (bool): If training or evaluation.

    Returns:
        :py:class:`tf.Tensor`: The loss for each example.

    """
    fake_inputs = self.get_discriminator_inputs(
        context=context, fake_or_real=fake, condition=condition, training=training
    )

    d_fake = context.discriminator_model(fake_inputs, training=training)

    # Support for Multiscale discriminator
    # TODO: Improve
    if isinstance(d_fake, list):
        value = tf.add_n(
            [
                tf.reduce_mean(
                    self._fn(tf.ones_like(d_fake_i), d_fake_i), axis=[1, 2]
                )
                for d_fake_i in d_fake
            ]
        )
        return value

    value = self._fn(tf.ones_like(d_fake), d_fake)
    value = tf.cond(
        tf.equal(tf.rank(d_fake), tf.constant(4)),
        lambda: value,
        lambda: tf.expand_dims(tf.expand_dims(value, axis=-1), axis=-1),
    )
    return tf.reduce_mean(value, axis=[1, 2])
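# Sketch of the adversarial target above, assuming self._fn behaves like a
# per-example binary cross-entropy: tf.ones_like(d_fake) labels every fake
# sample as "real", so the generator is rewarded when it fools the
# discriminator.
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
d_fake = tf.constant([[2.0], [-2.0]])  # made-up discriminator logits
print(bce(tf.ones_like(d_fake), d_fake).numpy())  # ~[0.13 2.13]: fooled vs. caught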
def update_state(self, context: GANContext) -> None:
    """
    Update the internal state of the metric, using the information from the context object.

    Args:
        context (:py:class:`ashpy.contexts.gan.GANContext`): An AshPy Context Object
            that carries all the information the Metric needs.

    """
    updater = lambda value: lambda: self._metric.update_state(value)

    # Generate the images with the AshPy Context's generator
    for real_xy, noise in context.dataset:
        _, real_y = real_xy

        g_inputs = noise
        if len(context.generator_model.inputs) == 2:
            g_inputs = [noise, real_y]

        fake = context.generator_model(
            g_inputs, training=context.log_eval_mode == LogEvalMode.TRAIN
        )

        # Rescale the images from [-1, 1] to [0, 1]
        fake = (fake + 1.0) / 2.0

        # Resize the images to 299x299, the InceptionV3 input size
        fake = tf.image.resize(fake, (299, 299))

        try:
            fake = tf.image.grayscale_to_rgb(fake)
        except ValueError:
            # Images are already RGB
            pass

        # Calculate the Inception Score for this batch
        inception_score_per_batch = self.inception_score(fake)

        # Update the Mean metric created for this context
        self._distribute_strategy.experimental_run_v2(
            updater(inception_score_per_batch)
        )
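# Hedged sketch of the Inception Score formula, IS = exp(E[KL(p(y|x) || p(y))]),
# on made-up class probabilities; the real metric obtains p(y|x) from an
# Inception network run on the generated images.
import tensorflow as tf

p_yx = tf.constant([[0.9, 0.1], [0.1, 0.9]])       # confident and diverse
p_y = tf.reduce_mean(p_yx, axis=0, keepdims=True)  # marginal label distribution
kl = tf.reduce_sum(p_yx * (tf.math.log(p_yx) - tf.math.log(p_y)), axis=1)
print(tf.exp(tf.reduce_mean(kl)).numpy())  # ~1.44; higher is better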
def _log_fn(self, context: GANContext) -> None:
    """
    Log output of the generator to TensorBoard.

    Args:
        context (:py:class:`ashpy.contexts.gan.GANContext`): current context.

    """
    if context.log_eval_mode == LogEvalMode.TEST:
        out = context.generator_model(context.generator_inputs, training=False)
    elif context.log_eval_mode == LogEvalMode.TRAIN:
        out = context.fake_samples
    else:
        raise ValueError("Invalid LogEvalMode")

    # TensorBoard 2.0 does not support float images in [-1, 1], only in [0, 1].
    if out.dtype == tf.float32:
        # The assumption is that float images are in [-1, 1] (there is no easy
        # way to check); rescale them to [0, 1].
        out = (out + 1.0) / 2
    log("generator", out, context.global_step)
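# Minimal sketch of what the log helper presumably wraps: writing an image
# batch already rescaled to [0, 1] with tf.summary.image. Path and step are
# made up.
import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/ashpy-demo")
images = tf.random.uniform((4, 64, 64, 3))  # a batch in [0, 1]
with writer.as_default():
    tf.summary.image("generator", images, step=0, max_outputs=4)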