Example 1
def build_discriminator_model(discriminator: Model,
                              resolution: int,
                              discriminator_lr: float,
                              loss_exponent: int,
                              channels: int = 3) -> Model:
    set_model_trainable(discriminator, True)

    real_samples = Input((resolution, resolution, channels))
    generated_samples = Input((resolution, resolution, channels))
    k = Input((1, ))

    discriminated_real_samples = discriminator(real_samples)
    discriminated_generated_samples = discriminator(generated_samples)

    ln_real = ln_loss(real_samples, discriminated_real_samples, loss_exponent)
    ln_generated = ln_loss(generated_samples, discriminated_generated_samples,
                           loss_exponent)
    discriminator_model = Model([real_samples, generated_samples, k],
                                [ln_real, ln_generated],
                                name='discriminator_model')

    discriminator_model.compile(optimizer=Adam(discriminator_lr),
                                loss=discriminator_loss(
                                    k, ln_real, ln_generated))
    return discriminator_model
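
All of the examples call a `set_model_trainable` helper that is not shown. A minimal sketch consistent with how it is used (freezing or unfreezing an entire sub-model before the surrounding model is compiled) might look like this; the exact implementation in the original codebase may differ:

from keras.models import Model


def set_model_trainable(model: Model, trainable: bool) -> None:
    # Toggle the trainable flag on the model itself and on every layer,
    # so that the next compile() call picks up the new setting.
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable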
Example 2
def build_generator_model(generator: Model, discriminator: Model,
                          latent_dim: int, generator_lr: float,
                          loss_exponent: int) -> Model:
    set_model_trainable(discriminator, False)

    input_noise = Input((latent_dim, ))
    generated_samples = generator(input_noise)
    discriminated_generated_samples = discriminator(generated_samples)

    ln_generated = ln_loss(generated_samples, discriminated_generated_samples,
                           loss_exponent)
    generator_model = Model(input_noise, ln_generated, name='generator_model')
    generator_model.compile(optimizer=Adam(generator_lr),
                            loss=generator_loss(ln_generated))
    return generator_model
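
`ln_loss`, `discriminator_loss`, and `generator_loss` (used in Examples 1 and 2) are not defined in the snippets. Assuming a BEGAN-style setup, in which the discriminator is an autoencoder and the losses are built from the per-sample L_n reconstruction error balanced by the term k, hypothetical definitions could be:

from keras import backend as K


def ln_loss(samples, reconstructed_samples, exponent: int):
    # Per-sample mean L_n reconstruction error between the input images and
    # the autoencoding discriminator's output.
    return K.mean(K.pow(K.abs(samples - reconstructed_samples), exponent),
                  axis=[1, 2, 3])


def discriminator_loss(k, ln_real, ln_generated):
    # BEGAN-style discriminator objective: reconstruct real samples well while
    # reconstructing generated samples poorly, balanced by the k input.
    def loss(y_true, y_pred):
        return ln_real - K.squeeze(k, axis=-1) * ln_generated
    return loss


def generator_loss(ln_generated):
    # The generator minimizes the reconstruction error of its own samples;
    # y_true and y_pred are ignored in favour of the captured tensor.
    def loss(y_true, y_pred):
        return ln_generated
    return loss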
Example 3
def build_generator_model(generator: Model, critic: Model, latent_dim: int,
                          generator_lr: float) -> Model:
    set_model_trainable(generator, True)
    set_model_trainable(critic, False)

    noise_samples = Input((latent_dim, ))
    generated_samples = generator(noise_samples)

    generated_criticized = critic(generated_samples)

    generator_model = Model(noise_samples,
                            generated_criticized,
                            name='generator_model')
    generator_model.compile(optimizer=Adam(generator_lr, beta_1=0, beta_2=0.9),
                            loss=wasserstein_loss)
    return generator_model
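
`wasserstein_loss` is the usual Keras formulation of the WGAN objective: the mean of the element-wise product of the labels (typically +1 for real and -1 for generated samples, or the reverse) and the critic scores. The definition assumed here is presumably:

from keras import backend as K


def wasserstein_loss(y_true, y_pred):
    # With +/-1 labels, minimizing this approximates the Wasserstein distance
    # between the real and generated distributions (up to sign conventions).
    return K.mean(y_true * y_pred)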
Example 4
def build_encoder_decoder_models(encoder: Model, decoder_generator: Model,
                                 critic: Model, resolution: int,
                                 latent_dim: int, channels: int, gamma: float,
                                 vae_lr: float) -> Tuple[Model, Model]:
    set_model_trainable(encoder, True)
    set_model_trainable(decoder_generator, True)
    set_model_trainable(critic, False)

    real_samples = Input((resolution, resolution, channels))
    noise_samples = Input((latent_dim, ))

    generated_samples = decoder_generator(noise_samples)
    generated_criticized = critic(generated_samples)

    real_criticized = critic(real_samples)
    z_mean, z_log_var = encoder(real_samples)

    sampled_z = Lambda(sampling)([z_mean, z_log_var])
    decoded_samples = decoder_generator(sampled_z)
    decoded_criticized = critic(decoded_samples)

    # The same criticized tensor is passed for all three outputs: kl_loss and
    # mse_loss compute their values from tensors captured in their closures,
    # so the first and third outputs act only as placeholders, while the
    # Wasserstein term uses the criticized generated samples directly.
    vae_model = Model(
        [real_samples, noise_samples],
        [generated_criticized, generated_criticized, generated_criticized])
    vae_model.compile(
        optimizer=Adam(lr=vae_lr, beta_1=0.5, beta_2=0.9),
        loss=[
            kl_loss(z_mean, z_log_var), wasserstein_loss,
            mse_loss(real_criticized, decoded_criticized)
        ],
        loss_weights=[1 / 3.0, gamma * 1 / 3.0, (1 - gamma) * 1 / 3.0])

    generator = Model(noise_samples, generated_samples)
    return vae_model, generator
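
Example 4 additionally relies on `sampling`, `kl_loss`, and `mse_loss`, none of which are shown. Assuming the standard VAE reparameterization trick, the closed-form KL divergence against a unit Gaussian, and a reconstruction error measured in the critic's output space, plausible (hypothetical) definitions are:

from keras import backend as K


def sampling(args):
    # Reparameterization trick: z = mean + exp(log_var / 2) * epsilon.
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=K.shape(z_mean))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


def kl_loss(z_mean, z_log_var):
    # KL divergence between the encoder's Gaussian and a unit Gaussian;
    # y_true and y_pred are ignored in favour of the captured tensors.
    def loss(y_true, y_pred):
        return -0.5 * K.mean(
            K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1))
    return loss


def mse_loss(real_criticized, decoded_criticized):
    # Reconstruction error between real and decoded samples, measured on the
    # critic's outputs rather than in pixel space.
    def loss(y_true, y_pred):
        return K.mean(K.square(real_criticized - decoded_criticized))
    return loss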
Example 5
def build_critic_model(generator: Model,
                       critic: Model,
                       latent_dim: int,
                       resolution: int,
                       classes_n: int,
                       batch_size: int,
                       critic_lr: float,
                       gradient_penalty_weight: int,
                       channels: int = 3) -> Model:
    set_model_trainable(generator, False)
    set_model_trainable(critic, True)

    noise_samples = Input((latent_dim, ))
    class_samples = Input((classes_n, ))

    real_samples = Input((resolution, resolution, channels))

    generated_samples = generator([noise_samples, class_samples])
    generated_criticized = critic([generated_samples, class_samples])
    real_criticized = critic([real_samples, class_samples])

    averaged_samples = RandomWeightedAverage(batch_size)(
        [real_samples, generated_samples])
    averaged_criticized = critic([averaged_samples, class_samples])

    partial_gp_loss = partial(gradient_penalty_loss,
                              averaged_samples=averaged_samples,
                              gradient_penalty_weight=gradient_penalty_weight)
    # functools.partial objects have no __name__, which Keras expects when
    # naming and reporting losses, so one is assigned explicitly.
    partial_gp_loss.__name__ = 'gradient_penalty'

    critic_model = Model(
        [real_samples, noise_samples, class_samples],
        [real_criticized, generated_criticized, averaged_criticized],
        name='critic_model')

    critic_model.compile(
        optimizer=Adam(critic_lr, beta_1=0.5, beta_2=0.9),
        loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss])
    return critic_model
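
`RandomWeightedAverage` and `gradient_penalty_loss` follow the standard improved-WGAN (gradient penalty) recipe: interpolate between real and generated samples, then penalize critic gradients whose norm deviates from 1 at those points. A sketch consistent with how they are called above (not the original implementation) is:

from keras import backend as K
from keras.layers import Layer


class RandomWeightedAverage(Layer):
    # Produces a random point on the segment between each real/generated pair;
    # the gradient penalty is evaluated at these interpolated samples.
    def __init__(self, batch_size: int, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs):
        real_samples, generated_samples = inputs
        weights = K.random_uniform((self.batch_size, 1, 1, 1))
        return weights * real_samples + (1 - weights) * generated_samples


def gradient_penalty_loss(y_true, y_pred, averaged_samples,
                          gradient_penalty_weight):
    # WGAN-GP penalty: push the L2 norm of the critic's gradient with respect
    # to the interpolated samples towards 1.
    gradients = K.gradients(y_pred, averaged_samples)[0]
    gradients_sqr_sum = K.sum(K.square(gradients),
                              axis=list(range(1, len(gradients.shape))))
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    return K.mean(gradient_penalty_weight * K.square(1 - gradient_l2_norm))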