def define_global_discriminator(self, generator_raw, global_discriminator_raw):
    """Build and compile the WGAN-GP training graph for the global critic.

    The generator output is merged with the original input through
    ``make_comp_sample`` before it is scored. Presumably this computes
    ``inputs * (1 - masks) + fakes * masks``, as an earlier inline comment
    suggested; confirm against ``make_comp_sample``.

    The compiled model scores three things with the same critic:
    real images, the composited fakes, and a random weighted average of
    the two. The third output is never used as a score in its own right.
    It exists only so the gradient-penalty loss has a graph to
    differentiate through.

    NOTE(review): ``generator_raw.model`` is embedded in this graph, and
    nothing here marks it non-trainable. Confirm the caller freezes the
    generator before this compile. Otherwise, training this model may
    also update generator weights.

    Args:
        generator_raw: object whose ``.model`` is called on
            ``[generator_inputs, generator_masks]``.
        global_discriminator_raw: object whose ``.model`` is called on a
            single image-shaped tensor.

    Returns:
        A compiled Keras ``Model``.
        Inputs: ``[real_samples, generator_inputs, generator_masks]``.
        Outputs: ``[real_score, fake_score, averaged_score]``.
        Losses: Wasserstein, Wasserstein and gradient penalty,
        respectively.
    """
    img_shape = (self.img_height, self.img_width, self.num_channels)

    # Input layers are created in the same order as before (inputs, masks,
    # real) so auto-generated layer names do not shift.
    masked_images = Input(shape=img_shape)
    masks = Input(shape=img_shape)  # masks share the image shape
    real_images = Input(shape=img_shape)

    raw_fakes = generator_raw.model([masked_images, masks])
    composited = Lambda(make_comp_sample)([masked_images, raw_fakes, masks])

    # Critic invocation order (fake, real, averaged) is deliberate; keep it.
    fake_score = global_discriminator_raw.model(composited)
    real_score = global_discriminator_raw.model(real_images)

    interpolated = custom_layers.RandomWeightedAverage()([real_images, composited])
    averaged_score = global_discriminator_raw.model(interpolated)

    # A Keras loss only receives (y_true, y_pred). The interpolated tensor
    # and the penalty weight are therefore bound in ahead of time.
    gp_loss = partial(
        gradient_penalty_loss,
        averaged_samples=interpolated,
        gradient_penalty_weight=self.gradient_penalty_loss_weight,
    )
    # Keras insists that every loss callable carries a __name__.
    gp_loss.__name__ = 'gradient_penalty'

    critic_trainer = Model(
        inputs=[real_images, masked_images, masks],
        outputs=[real_score, fake_score, averaged_score],
    )
    critic_trainer.compile(
        optimizer=self.discriminator_optimizer,
        loss=[wasserstein_loss, wasserstein_loss, gp_loss],
    )
    return critic_trainer
def define_local_discriminator(self, generator_raw, local_discriminator_raw):
    """Build and compile the WGAN-GP training graph for the local critic.

    This method mirrors ``define_global_discriminator`` with two
    differences:

    * The critic is called on ``[image, masks]`` pairs instead of a
      single image.
    * The raw generator output is scored directly; no compositing with
      the known pixels is applied here.

    Presumably the mask lets the local critic focus on the filled-in
    region. Confirm this against the local discriminator's definition.

    NOTE(review): ``generator_raw.model`` is embedded in this graph, and
    nothing here marks it non-trainable. Confirm the caller freezes the
    generator before this compile.

    Args:
        generator_raw: object whose ``.model`` is called on
            ``[generator_inputs, generator_masks]``.
        local_discriminator_raw: object whose ``.model`` is called on
            ``[image_tensor, generator_masks]``.

    Returns:
        A compiled Keras ``Model``.
        Inputs: ``[real_samples, generator_inputs, generator_masks]``.
        Outputs: ``[real_score, fake_score, averaged_score]``.
        Losses: Wasserstein, Wasserstein and gradient penalty,
        respectively.
    """
    img_shape = (self.img_height, self.img_width, self.num_channels)

    # Same Input creation order as the original (inputs, masks, real).
    masked_images = Input(shape=img_shape)
    masks = Input(shape=img_shape)
    real_images = Input(shape=img_shape)

    # Raw generator output, deliberately not composited in this variant.
    fakes = generator_raw.model([masked_images, masks])

    # Critic invocation order (fake, real, averaged) is kept as-is.
    fake_score = local_discriminator_raw.model([fakes, masks])
    real_score = local_discriminator_raw.model([real_images, masks])

    interpolated = custom_layers.RandomWeightedAverage()([real_images, fakes])
    # The score itself is unused. The critic is run on the interpolated
    # images so the gradient penalty can be computed from them.
    averaged_score = local_discriminator_raw.model([interpolated, masks])

    # Bind the extra arguments, since Keras losses only get (y_true, y_pred).
    gp_loss = partial(
        gradient_penalty_loss,
        averaged_samples=interpolated,
        gradient_penalty_weight=self.gradient_penalty_loss_weight,
    )
    gp_loss.__name__ = 'gradient_penalty'

    critic_trainer = Model(
        inputs=[real_images, masked_images, masks],
        outputs=[real_score, fake_score, averaged_score],
    )
    critic_trainer.compile(
        optimizer=self.discriminator_optimizer,
        loss=[wasserstein_loss, wasserstein_loss, gp_loss],
    )
    return critic_trainer