# Imports assumed for these Keras 2 examples. utils, RandomWeightedAverage,
# gradient_penalty_loss, sampling, and vae_loss are project-local helpers;
# hedged sketches of plausible implementations appear after the examples
# that use them.
from functools import partial

from keras import backend as K
from keras.layers import Concatenate, Input, Lambda
from keras.models import Model
from keras.optimizers import Adam, RMSprop


def build_generator_model(generator, critic, generator_lr, latent_dim):
    utils.set_model_trainable(generator, True)
    utils.set_model_trainable(critic, False)

    noise_samples = Input((latent_dim,))

    generated_samples = generator(noise_samples)
    generated_criticized = critic(generated_samples)

    generator_model = Model(noise_samples, generated_criticized,
                            name='generator_model')
    generator_model.compile(loss=utils.wasserstein_loss, optimizer=RMSprop(generator_lr))
    return generator_model
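utils.set_model_trainable is not shown in the listing; a minimal sketch of the usual pattern (assumed, not the project's actual code):

def set_model_trainable(model, trainable):
    # Freeze or unfreeze a whole sub-model before the enclosing model is
    # compiled; only the unfrozen player of the GAN pair gets updated.
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable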

def build_critic_model(generator, critic, critic_lr, latent_dim, timesteps):
    utils.set_model_trainable(generator, False)
    utils.set_model_trainable(critic, True)

    noise_samples = Input((latent_dim,))
    real_samples = Input((timesteps,))

    generated_samples = generator(noise_samples)
    generated_criticized = critic(generated_samples)
    real_criticized = critic(real_samples)

    critic_model = Model([real_samples, noise_samples],
                         [real_criticized, generated_criticized],
                         name='critic_model')
    critic_model.compile(loss=[utils.wasserstein_loss, utils.wasserstein_loss], optimizer=RMSprop(critic_lr),
                         loss_weights=[1 / 2, 1 / 2])
    return critic_model
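A minimal usage sketch, not part of the original listing. It assumes utils.wasserstein_loss is the usual K.mean(y_true * y_pred), so opposite labels drive the critic's real and generated scores apart; real_batch is a placeholder array of shape (batch_size, timesteps).

import numpy as np

def wgan_train_step(critic_model, generator_model, real_batch, latent_dim,
                    n_critic=5):
    # One WGAN update: several critic steps per generator step.
    batch_size = real_batch.shape[0]
    real_y = np.ones((batch_size, 1))
    fake_y = -real_y

    for _ in range(n_critic):
        noise = np.random.normal(size=(batch_size, latent_dim))
        critic_model.train_on_batch([real_batch, noise], [real_y, fake_y])

    # The generator trains its fakes toward the label the critic's loss
    # associates with real samples.
    noise = np.random.normal(size=(batch_size, latent_dim))
    generator_model.train_on_batch(noise, real_y)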
Example #3
def build_generator_model(generator, critic, latent_dim, timesteps,
                          use_packing, packing_degree, batch_size,
                          generator_lr):
    utils.set_model_trainable(generator, True)
    utils.set_model_trainable(critic, False)

    noise_samples = Input((latent_dim, ))
    generated_samples = generator(noise_samples)

    if use_packing:
        generated_samples = Lambda(lambda x: K.reshape(x, (
            batch_size, timesteps, 1)))(generated_samples)
        supporting_noise_samples = Input((latent_dim, packing_degree))

        reshaped_supporting_noise_samples = Lambda(
            lambda x: K.reshape(x, (batch_size * packing_degree, latent_dim)))(
                supporting_noise_samples)

        supporting_generated_samples = generator(
            reshaped_supporting_noise_samples)
        supporting_generated_samples = Lambda(
            lambda x: K.reshape(x, (batch_size, timesteps, packing_degree)))(
                supporting_generated_samples)
        merged_generated_samples = Concatenate(-1)(
            [generated_samples, supporting_generated_samples])

        generated_criticized = critic(merged_generated_samples)

        generator_model = Model([noise_samples, supporting_noise_samples],
                                generated_criticized, name='generator_model')
        generator_model.compile(optimizer=Adam(generator_lr,
                                               beta_1=0,
                                               beta_2=0.9),
                                loss=utils.wasserstein_loss)
    else:
        generated_criticized = critic(generated_samples)

        generator_model = Model([noise_samples], generated_criticized,
                                name='generator_model')
        generator_model.compile(optimizer=Adam(generator_lr,
                                               beta_1=0,
                                               beta_2=0.9),
                                loss=utils.wasserstein_loss)
    return generator_model
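With use_packing enabled this follows the PacGAN idea: the critic scores the main sample stacked with packing_degree supporting samples along the last axis, so the model takes a second noise input. A hypothetical call (generator_model built above; sizes are placeholders):

import numpy as np

batch_size, latent_dim, packing_degree = 64, 100, 2
noise = np.random.normal(size=(batch_size, latent_dim))
supporting_noise = np.random.normal(
    size=(batch_size, latent_dim, packing_degree))
real_y = np.ones((batch_size, 1))  # same label convention as the sketch above
generator_model.train_on_batch([noise, supporting_noise], real_y)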
def build_critic_model(encoder, decoder_generator, critic, latent_dim,
                       timesteps, batch_size, critic_lr,
                       gradient_penalty_weight):
    utils.set_model_trainable(encoder, False)
    utils.set_model_trainable(decoder_generator, False)
    utils.set_model_trainable(critic, True)

    noise_samples = Input((latent_dim, ))
    real_samples = Input((timesteps, ))

    generated_samples = decoder_generator(noise_samples)
    generated_criticized = critic(generated_samples)
    real_criticized = critic(real_samples)

    averaged_samples = RandomWeightedAverage(batch_size)(
        [real_samples, generated_samples])
    averaged_criticized = critic(averaged_samples)

    partial_gp_loss = partial(gradient_penalty_loss,
                              averaged_samples=averaged_samples,
                              gradient_penalty_weight=gradient_penalty_weight)
    partial_gp_loss.__name__ = 'gradient_penalty'

    critic_model = Model(
        [real_samples, noise_samples],
        [real_criticized, generated_criticized, averaged_criticized],
        name='critic_model')

    critic_model.compile(
        optimizer=Adam(critic_lr, beta_1=0, beta_2=0.9),
        loss=[utils.wasserstein_loss, utils.wasserstein_loss, partial_gp_loss],
        loss_weights=[1 / 3, 1 / 3, 1 / 3])
    return critic_model
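The listing relies on two helpers it does not define. Below is one common Keras 2 formulation of both, assuming the standard WGAN-GP penalty of Gulrajani et al. (2017); the project's actual versions may differ.

import numpy as np
from keras import backend as K
from keras.layers.merge import _Merge

class RandomWeightedAverage(_Merge):
    # Uniform random interpolation between real and generated samples,
    # the points at which WGAN-GP evaluates the gradient penalty.
    def __init__(self, batch_size, **kwargs):
        super(RandomWeightedAverage, self).__init__(**kwargs)
        self.batch_size = batch_size

    def _merge_function(self, inputs):
        weights = K.random_uniform((self.batch_size, 1))
        return weights * inputs[0] + (1 - weights) * inputs[1]

def gradient_penalty_loss(y_true, y_pred, averaged_samples,
                          gradient_penalty_weight):
    # Penalize deviation of the critic's gradient norm from 1 at the
    # interpolated points: weight * mean((||grad|| - 1)^2).
    # y_true is a dummy; the penalty depends only on y_pred and the inputs.
    gradients = K.gradients(y_pred, averaged_samples)[0]
    gradients_sqr = K.square(gradients)
    gradients_sqr_sum = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    return gradient_penalty_weight * K.mean(K.square(gradient_l2_norm - 1))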
def build_vae_model(encoder, decoder_generator, critic, latent_dim, timesteps,
                    gamma, vae_lr):
    utils.set_model_trainable(encoder, True)
    utils.set_model_trainable(decoder_generator, True)
    utils.set_model_trainable(critic, False)

    real_samples = Input((timesteps, ))
    noise_samples = Input((latent_dim, ))

    generated_samples = decoder_generator(noise_samples)
    generated_criticized = critic(generated_samples)

    z_mean, z_log_var = encoder(real_samples)

    sampled_z = Lambda(sampling)([z_mean, z_log_var])
    decoded_inputs = decoder_generator(sampled_z)

    real_criticized = critic(real_samples)
    decoded_criticized = critic(decoded_inputs)

    # decoded_criticized is assumed as the second output feeding vae_loss;
    # the reconstruction path is otherwise never exposed as a model output.
    vae_model = Model([real_samples, noise_samples],
                      [generated_criticized, decoded_criticized])
    vae_model.compile(optimizer=Adam(lr=vae_lr, beta_1=0, beta_2=0.9),
                      loss=[
                          utils.wasserstein_loss,
                          vae_loss(z_mean, z_log_var, real_criticized,
                                   decoded_criticized)
                      ],
                      loss_weights=[gamma, (1 - gamma)])

    generator_model = Model(noise_samples, generated_samples)
    return vae_model, generator_model
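Two more helpers are referenced but not defined. sampling is the standard reparameterization trick; the vae_loss sketch is only one plausible formulation (a KL term plus a reconstruction term in the critic's output space, as its closure arguments suggest), not necessarily the project's.

from keras import backend as K

def sampling(args):
    # Reparameterization trick: z = mean + exp(log_var / 2) * epsilon keeps
    # the random draw differentiable w.r.t. the encoder outputs.
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=K.shape(z_mean))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

def vae_loss(z_mean, z_log_var, real_criticized, decoded_criticized):
    # Assumed formulation: the returned closure ignores y_true/y_pred and
    # combines the encoder's KL divergence with a critic-space
    # reconstruction error between real and decoded samples.
    def loss(y_true, y_pred):
        kl = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) -
                           K.exp(z_log_var))
        reconstruction = K.mean(
            K.square(real_criticized - decoded_criticized))
        return kl + reconstruction
    return loss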
Example #6
def build_critic_model(generator, critic, latent_dim, timesteps, use_packing,
                       packing_degree, batch_size, critic_lr,
                       gradient_penalty_weight):
    utils.set_model_trainable(generator, False)
    utils.set_model_trainable(critic, True)

    noise_samples = Input((latent_dim, ))
    real_samples = Input((timesteps, ))

    if use_packing:
        supporting_noise_samples = Input((latent_dim, packing_degree))
        supporting_real_samples = Input((timesteps, packing_degree))

        reshaped_supporting_noise_samples = Lambda(
            lambda x: K.reshape(x, (batch_size * packing_degree, latent_dim)))(
                supporting_noise_samples)
        generated_samples = generator(noise_samples)
        supporting_generated_samples = generator(
            reshaped_supporting_noise_samples)

        expanded_generated_samples = Lambda(lambda x: K.reshape(
            x, (batch_size, timesteps, 1)))(generated_samples)
        expanded_generated_supporting_samples = Lambda(
            lambda x: K.reshape(x, (batch_size, timesteps, packing_degree)))(
                supporting_generated_samples)

        merged_generated_samples = Concatenate(-1)([
            expanded_generated_samples, expanded_generated_supporting_samples
        ])

        generated_criticized = critic(merged_generated_samples)

        expanded_real_samples = Lambda(
            lambda x: K.reshape(x, (batch_size, timesteps, 1)))(real_samples)
        merged_real_samples = Concatenate(-1)(
            [expanded_real_samples, supporting_real_samples])

        real_criticized = critic(merged_real_samples)

        averaged_samples = RandomWeightedAverage(batch_size)(
            [real_samples, generated_samples])

        expanded_averaged_samples = Lambda(lambda x: K.reshape(
            x, (batch_size, timesteps, 1)))(averaged_samples)

        expanded_supporting_real_samples = Lambda(lambda x: K.reshape(
            x, (batch_size * packing_degree, timesteps)))(
                supporting_real_samples)
        averaged_support_samples = RandomWeightedAverage(
            batch_size * packing_degree)([
                expanded_supporting_real_samples, supporting_generated_samples
            ])

        averaged_support_samples = Lambda(
            lambda x: K.reshape(x, (batch_size, timesteps, packing_degree)))(
                averaged_support_samples)

        merged_averaged_samples = Concatenate(-1)(
            [expanded_averaged_samples, averaged_support_samples])

        averaged_criticized = critic(merged_averaged_samples)

        partial_gp_loss = partial(
            gradient_penalty_loss,
            averaged_samples=merged_averaged_samples,
            gradient_penalty_weight=gradient_penalty_weight)
        partial_gp_loss.__name__ = 'gradient_penalty'

        critic_model = Model(
            [real_samples, noise_samples, supporting_real_samples,
             supporting_noise_samples],
            [real_criticized, generated_criticized, averaged_criticized],
            name='critic_model')

        critic_model.compile(optimizer=Adam(critic_lr, beta_1=0, beta_2=0.9),
                             loss=[
                                 utils.wasserstein_loss,
                                 utils.wasserstein_loss, partial_gp_loss
                             ],
                             loss_weights=[1 / 3, 1 / 3, 1 / 3])
    else:
        generated_samples = generator(noise_samples)
        generated_criticized = critic(generated_samples)
        real_criticized = critic(real_samples)

        averaged_samples = RandomWeightedAverage(batch_size)(
            [real_samples, generated_samples])
        averaged_criticized = critic(averaged_samples)

        partial_gp_loss = partial(
            gradient_penalty_loss,
            averaged_samples=averaged_samples,
            gradient_penalty_weight=gradient_penalty_weight)
        partial_gp_loss.__name__ = 'gradient_penalty'

        critic_model = Model(
            [real_samples, noise_samples],
            [real_criticized, generated_criticized, averaged_criticized],
            name='critic_model')

        critic_model.compile(optimizer=Adam(critic_lr, beta_1=0, beta_2=0.9),
                             loss=[
                                 utils.wasserstein_loss,
                                 utils.wasserstein_loss, partial_gp_loss
                             ],
                             loss_weights=[1 / 3, 1 / 3, 1 / 3])
    return critic_model
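A hypothetical training call for the packed critic; real_batch and supporting_real are placeholder arrays, and the gradient-penalty output's target is a dummy that the loss ignores.

import numpy as np

real_y = np.ones((batch_size, 1))
fake_y = -real_y
dummy_y = np.zeros((batch_size, 1))  # gradient_penalty_loss ignores y_true

noise = np.random.normal(size=(batch_size, latent_dim))
supporting_noise = np.random.normal(
    size=(batch_size, latent_dim, packing_degree))
# real_batch: (batch_size, timesteps); supporting_real: the packed partner
# samples, shape (batch_size, timesteps, packing_degree)
critic_model.train_on_batch(
    [real_batch, noise, supporting_real, supporting_noise],
    [real_y, fake_y, dummy_y])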