# Code example #1
    def define_global_discriminator(self, generator_raw,
                                    global_discriminator_raw):
        """Build and compile the WGAN-GP training graph for the global discriminator.

        Wires the (frozen-from-this-model's-perspective) generator output into
        the global discriminator and adds the gradient-penalty branch required
        by WGAN-GP (Gulrajani et al.).

        Args:
            generator_raw: wrapper object whose ``.model`` maps
                ``[images, masks]`` to inpainted images.
            global_discriminator_raw: wrapper object whose ``.model`` maps a
                full image to a critic score.

        Returns:
            A compiled ``Model`` taking ``[real_samples, generator_inputs,
            generator_masks]`` and producing three outputs: critic score on
            real samples, critic score on composited fake samples, and critic
            score on the randomly interpolated samples (used only for the
            gradient penalty).
        """
        image_shape = (self.img_height, self.img_width, self.num_channels)
        generator_inputs = Input(shape=image_shape)
        generator_masks = Input(shape=image_shape)
        real_samples = Input(shape=image_shape)

        # Generate fake samples, then composite them with the unmasked
        # regions of the input: input * (1 - mask) + generated * mask.
        fake_samples = generator_raw.model([generator_inputs, generator_masks])
        fake_samples = Lambda(make_comp_sample)(
            [generator_inputs, fake_samples, generator_masks])

        discriminator_output_from_fake_samples = global_discriminator_raw.model(
            fake_samples)
        discriminator_output_from_real_samples = global_discriminator_raw.model(
            real_samples)

        # We then run these samples through the discriminator as well. Note
        # that we never really use the discriminator output for these samples
        # - we're only running them to get the gradient norm for the gradient
        # penalty loss.
        averaged_samples = custom_layers.RandomWeightedAverage()(
            [real_samples, fake_samples])
        averaged_samples_outputs = global_discriminator_raw.model(
            averaged_samples)

        # The gradient penalty loss function requires the input averaged
        # samples to get gradients. However, Keras loss functions can only
        # have two arguments, y_true and y_pred. We get around this by making
        # a partial() of the function with the averaged samples here.
        partial_gp_loss = partial(
            gradient_penalty_loss,
            averaged_samples=averaged_samples,
            gradient_penalty_weight=self.gradient_penalty_loss_weight)
        # Functions need names or Keras will throw an error.
        partial_gp_loss.__name__ = 'gradient_penalty'

        global_discriminator_model = Model(
            inputs=[real_samples, generator_inputs, generator_masks],
            outputs=[
                discriminator_output_from_real_samples,
                discriminator_output_from_fake_samples,
                averaged_samples_outputs
            ])
        # We use the Adam parameters from Gulrajani et al. We use the
        # Wasserstein loss for both the real and generated samples, and the
        # gradient penalty loss for the averaged samples.
        global_discriminator_model.compile(
            optimizer=self.discriminator_optimizer,
            loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss])

        return global_discriminator_model
# Code example #2
    def define_local_discriminator(self, generator_raw,
                                   local_discriminator_raw):
        """Build and compile the WGAN-GP training graph for the local discriminator.

        Mirrors ``define_global_discriminator`` but feeds the discriminator
        ``[samples, masks]`` pairs so it can focus on the masked region.

        NOTE(review): unlike the global path, the raw generator output is
        passed to the discriminator WITHOUT compositing it over the unmasked
        input regions — this appears deliberate, confirm against training code.

        Args:
            generator_raw: wrapper object whose ``.model`` maps
                ``[images, masks]`` to inpainted images.
            local_discriminator_raw: wrapper object whose ``.model`` maps
                ``[image, mask]`` to a critic score.

        Returns:
            A compiled ``Model`` taking ``[real_samples, generator_inputs,
            generator_masks]`` and producing three outputs: critic score on
            real samples, critic score on fake samples, and critic score on
            the randomly interpolated samples (gradient-penalty branch only).
        """
        image_shape = (self.img_height, self.img_width, self.num_channels)
        generator_inputs = Input(shape=image_shape)
        generator_masks = Input(shape=image_shape)
        real_samples = Input(shape=image_shape)

        fake_samples = generator_raw.model([generator_inputs, generator_masks])

        discriminator_output_from_fake_samples = local_discriminator_raw.model(
            [fake_samples, generator_masks])
        discriminator_output_from_real_samples = local_discriminator_raw.model(
            [real_samples, generator_masks])

        # The interpolated samples are only used to compute the gradient norm
        # for the gradient penalty loss; their critic output is otherwise
        # unused.
        averaged_samples = custom_layers.RandomWeightedAverage()(
            [real_samples, fake_samples])
        averaged_samples_output = local_discriminator_raw.model(
            [averaged_samples, generator_masks])

        # Keras losses only receive (y_true, y_pred); bind the interpolated
        # samples and the penalty weight via partial() so the gradient-penalty
        # loss can access them.
        partial_gp_loss = partial(
            gradient_penalty_loss,
            averaged_samples=averaged_samples,
            gradient_penalty_weight=self.gradient_penalty_loss_weight)
        # Functions need names or Keras will throw an error.
        partial_gp_loss.__name__ = 'gradient_penalty'

        local_discriminator_model = Model(
            inputs=[real_samples, generator_inputs, generator_masks],
            outputs=[
                discriminator_output_from_real_samples,
                discriminator_output_from_fake_samples, averaged_samples_output
            ])

        # Wasserstein loss on the real and fake critic outputs; gradient
        # penalty on the interpolated branch.
        local_discriminator_model.compile(
            optimizer=self.discriminator_optimizer,
            loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss])
        return local_discriminator_model