Example #1
    def train_step(lr, hr, generator, discriminator, content):

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Rescale inputs: LR images to [0, 1], HR images to [-1, 1]
            lr = tf.cast(lr, tf.float32) / 255.0
            hr = tf.cast(hr, tf.float32) / 127.5 - 1.0

            sr = generator(lr, training=True)

            sr_output = discriminator(sr, training=True)
            hr_output = discriminator(hr, training=True)

            disc_loss = discriminator_loss(sr_output, hr_output)

            mse_loss = mse_based_loss(sr, hr)
            gen_loss = generator_loss(sr_output)
            cont_loss = content_loss(content, sr, hr)
            perceptual_loss = mse_loss + cont_loss + 1e-3 * gen_loss

        gradients_of_generator = gen_tape.gradient(
            perceptual_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, discriminator.trainable_variables)

        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(
            zip(gradients_of_discriminator, discriminator.trainable_variables))

        return perceptual_loss, disc_loss
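
The step above assumes that discriminator_loss, generator_loss, mse_based_loss, content_loss and the two optimizers are defined elsewhere. A minimal sketch of what they might look like for an SRGAN-style setup, assuming the discriminator returns logits and `content` is a VGG19 feature extractor (hypothetical definitions, not taken from the snippet):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
mse = tf.keras.losses.MeanSquaredError()

def discriminator_loss(sr_output, hr_output):
    # Real HR patches should score 1, generated SR patches 0.
    return (bce(tf.ones_like(hr_output), hr_output) +
            bce(tf.zeros_like(sr_output), sr_output))

def generator_loss(sr_output):
    # Adversarial term: the generator wants SR patches scored as real.
    return bce(tf.ones_like(sr_output), sr_output)

def mse_based_loss(sr, hr):
    return mse(hr, sr)

def content_loss(extractor, sr, hr):
    # Perceptual term on VGG19 features; images are mapped from [-1, 1]
    # back to [0, 255] before the usual VGG preprocessing.
    preprocess = tf.keras.applications.vgg19.preprocess_input
    return mse(extractor(preprocess((hr + 1.0) * 127.5)),
               extractor(preprocess((sr + 1.0) * 127.5)))

# A feature extractor that could be passed in as `content`:
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
content_extractor = tf.keras.Model(vgg.input, vgg.get_layer('block5_conv4').output)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)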
Example #2
def gen_step(input_image, target, LAMBDA):
    with tf.GradientTape() as gen_tape:
        gen_output = generator(input_image, training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
            disc_generated_output, gen_output, target, LAMBDA)
        
    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    return gen_total_loss, gen_gan_loss, gen_l1_loss
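
generator_loss is not shown here; in the standard pix2pix formulation it combines an adversarial term with an L1 term weighted by LAMBDA and returns all three values, which matches the unpacking above. A hedged sketch, assuming the discriminator returns logits:

import tensorflow as tf

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(disc_generated_output, gen_output, target, LAMBDA):
    # Adversarial term: fool the discriminator on the generated output.
    gan_loss = loss_object(tf.ones_like(disc_generated_output),
                           disc_generated_output)
    # L1 term: keep the generated image close to the ground-truth target.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + LAMBDA * l1_loss
    return total_gen_loss, gan_loss, l1_loss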
Example #3
def train_step(molecules_A, molecules_X):
    """
    Generate molecules with the generator, pass a batch of real data and the
    generated molecules through the discriminator, then apply backpropagation.
    """
    z = tf.random.normal([tf.shape(molecules_A)[0], 32])  # latent noise batch

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_A, generated_X = generator(z)

        real_logits = discriminator(molecules_A, molecules_X)
        fake_logits = discriminator(generated_A[0], generated_X[0])
        
        # Gradient penalty (WGAN-GP), cf. eq. (3) in the paper: penalize the gradient
        # of the discriminator w.r.t. a linear combination of real and generated data.
        # This block must stay recorded on disc_tape so the penalty term can
        # contribute to the discriminator update.
        epsilon_adj = tf.random.uniform(tf.shape(molecules_A), 0.0, 1.0,
                                        dtype=molecules_A.dtype)
        epsilon_features = tf.random.uniform(tf.shape(molecules_X), 0.0, 1.0,
                                             dtype=molecules_X.dtype)

        x_hat_adj = epsilon_adj * molecules_A + (1 - epsilon_adj) * generated_A[0]
        x_hat_features = (epsilon_features * molecules_X +
                          (1 - epsilon_features) * generated_X[0])

        with tf.GradientTape() as penalty_tape:
            penalty_tape.watch([x_hat_adj, x_hat_features])
            disc_penalty = discriminator(x_hat_adj, x_hat_features)

        # Gradients of the discriminator output w.r.t. the interpolated inputs, eq. (3)
        grad_adj = penalty_tape.gradient(disc_penalty, [x_hat_adj, x_hat_features])
         


        disc_loss = loss.discriminator_loss(real_logits, fake_logits, grad_adj) 
        gen_loss = loss.generator_loss(fake_logits)


    gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)

    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))

    # Update the running-mean loss trackers (assumed to be tf.keras.metrics.Mean)
    gen_train_loss(gen_loss)
    disc_train_loss(disc_loss)
    return gen_train_loss, disc_train_loss
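
The `loss` module used above is not shown. For a WGAN with gradient penalty, its two functions could look roughly like this (hypothetical sketch; gp_weight is a placeholder):

import tensorflow as tf

def discriminator_loss(real_logits, fake_logits, interpolated_grads, gp_weight=10.0):
    # WGAN critic loss plus the gradient penalty of eq. (3): push the norm of
    # the gradients at the interpolated samples towards 1.
    wgan_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)
    penalty = 0.0
    for grad in interpolated_grads:
        flat = tf.reshape(grad, [tf.shape(grad)[0], -1])
        penalty += tf.reduce_mean(tf.square(tf.norm(flat, axis=1) - 1.0))
    return wgan_loss + gp_weight * penalty

def generator_loss(fake_logits):
    # The generator maximizes the critic's score on generated molecules.
    return -tf.reduce_mean(fake_logits)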
Example #4
    def trainstep(real_human, real_anime, big_anime):
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human,
                                          cycled_human) + cycle_loss(
                                              real_anime, cycled_anime)

            # Total generator loss = adversarial loss + cycle loss
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))
            scaled_gen_anime_loss = generator_to_anime_optimizer.get_scaled_loss(
                total_gen_anime_loss)

            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))
            scaled_gen_human_loss = generator_to_human_optimizer.get_scaled_loss(
                total_gen_human_loss)

            disc_human_loss = discriminator_loss(disc_real_human,
                                                 disc_fake_human)
            scaled_disc_human_loss = discriminator_human_optimizer.get_scaled_loss(
                disc_human_loss)

            disc_anime_loss = discriminator_loss(disc_real_anime,
                                                 disc_fake_anime)
            scaled_disc_anime_loss = discriminator_anime_optimizer.get_scaled_loss(
                disc_anime_loss)

            # Upscaling branch

            fake_anime_upscale = generator_anime_upscale(fake_anime)
            cycle_anime_upscale = generator_anime_upscale(real_anime)
            same_anime_upscale = generator_anime_upscale(same_anime)

            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale)
            disc_cycle_upscale = discriminator_anime_upscale(
                cycle_anime_upscale)
            disc_same_upscale = discriminator_anime_upscale(same_anime_upscale)

            disc_real_big = discriminator_anime_upscale(big_anime)

            gen_upscale_loss = (
                generator_loss(disc_fake_upscale) * 3 +
                generator_loss(disc_cycle_upscale) +
                generator_loss(disc_same_upscale)
                # + mse_loss(big_anime, cycle_anime_upscale)
                + identity_loss(big_anime, cycle_anime_upscale) * 0.5 +
                identity_loss(big_anime, same_anime_upscale) * 0.5)

            scaled_gen_upscale_loss = generator_anime_upscale_optimizer.get_scaled_loss(
                gen_upscale_loss)

            disc_upscale_loss = discriminator_upscale_loss(
                disc_real_big, disc_fake_upscale, disc_cycle_upscale,
                disc_same_upscale)

            scaled_disc_upscale_loss = discriminator_anime_upscale_optimizer.get_scaled_loss(
                disc_upscale_loss)

        # Compute gradients from the scaled losses
        generator_to_anime_gradients = tape.gradient(
            scaled_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            scaled_gen_human_loss, generator_to_human.trainable_variables)

        discriminator_human_gradients = tape.gradient(
            scaled_disc_human_loss, discriminator_human.trainable_variables)
        discriminator_anime_gradients = tape.gradient(
            scaled_disc_anime_loss, discriminator_anime.trainable_variables)

        generator_upscale_gradients = tape.gradient(
            scaled_gen_upscale_loss, generator_anime_upscale.trainable_variables)

        discriminator_upscale_gradients = tape.gradient(
            scaled_disc_upscale_loss, discriminator_anime_upscale.trainable_variables)

        # Unscale the gradients and apply them with the optimizers
        generator_to_anime_gradients = generator_to_anime_optimizer.get_unscaled_gradients(
            generator_to_anime_gradients)
        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))
        generator_to_human_gradients = generator_to_human_optimizer.get_unscaled_gradients(
            generator_to_human_gradients)
        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        discriminator_human_gradients = discriminator_human_optimizer.get_unscaled_gradients(
            discriminator_human_gradients)
        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))

        discriminator_anime_gradients = discriminator_anime_optimizer.get_unscaled_gradients(
            discriminator_anime_gradients)
        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))

        generator_upscale_gradients = generator_anime_upscale_optimizer.get_unscaled_gradients(
            generator_upscale_gradients)
        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        discriminator_upscale_gradients = discriminator_anime_upscale_optimizer.get_unscaled_gradients(
            discriminator_upscale_gradients)
        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))

        return (
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        )
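
get_scaled_loss and get_unscaled_gradients are methods of tf.keras.mixed_precision.LossScaleOptimizer, so the step above assumes every optimizer is wrapped in one. A possible setup (hypothetical; the learning rate and beta_1 values are placeholders):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')

def make_optimizer(learning_rate=2e-4, beta_1=0.5):
    # Wrap a regular optimizer so it can scale losses / unscale gradients.
    return tf.keras.mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(learning_rate, beta_1=beta_1))

generator_to_anime_optimizer = make_optimizer()
generator_to_human_optimizer = make_optimizer()
discriminator_human_optimizer = make_optimizer()
discriminator_anime_optimizer = make_optimizer()
generator_anime_upscale_optimizer = make_optimizer()
discriminator_anime_upscale_optimizer = make_optimizer()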
Example #5
    def trainstep(real_human, real_anime):
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_x(real_human, training=True)
            disc_real_anime = discriminator_y(real_anime, training=True)

            disc_fake_human = discriminator_x(fake_human, training=True)
            disc_fake_anime = discriminator_y(fake_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human,
                                          cycled_human) + cycle_loss(
                                              real_anime, cycled_anime)

            # Total generator loss = adversarial loss + cycle loss
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))
            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            disc_x_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_y_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

        # Calculate the gradients for generator and discriminator
        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)

        discriminator_x_gradients = tape.gradient(
            disc_x_loss, discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(
            disc_y_loss, discriminator_y.trainable_variables)

        # Apply the gradients to the optimizer
        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        discriminator_x_optimizer.apply_gradients(
            zip(discriminator_x_gradients,
                discriminator_x.trainable_variables))

        discriminator_y_optimizer.apply_gradients(
            zip(discriminator_y_gradients,
                discriminator_y.trainable_variables))

        return (fake_anime, cycled_human, fake_human, cycled_anime, same_human,
                same_anime, gen_anime_loss, gen_human_loss, disc_x_loss,
                disc_y_loss, total_gen_anime_loss, total_gen_human_loss)
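
The loss helpers are not part of the snippet. For a CycleGAN-style setup they are commonly defined as below (hedged sketch; LAMBDA = 10 follows the CycleGAN paper but is an assumption here):

import tensorflow as tf

LAMBDA = 10  # cycle-consistency weight
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
    real_loss = loss_obj(tf.ones_like(real), real)
    generated_loss = loss_obj(tf.zeros_like(generated), generated)
    return (real_loss + generated_loss) * 0.5

def generator_loss(generated):
    return loss_obj(tf.ones_like(generated), generated)

def cycle_loss(real_image, cycled_image):
    return LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))

def identity_loss(real_image, same_image):
    return LAMBDA * 0.5 * tf.reduce_mean(tf.abs(real_image - same_image))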
Example #6
    def trainstep(real_human, real_anime, big_anime):
        with tf.GradientTape(persistent=True) as tape:
            # Domain conditioning: concatenate a constant map of +1 ("to anime")
            # or -1 ("to human") to the generator input along the channel axis.
            ones = tf.ones_like(real_human)
            neg_ones = -tf.ones_like(real_human)

            def get_domain_anime(img):
                return tf.concat([img, ones], axis=3)

            def get_domain_human(img):
                return tf.concat([img, neg_ones], axis=3)

            fake_anime = generator(get_domain_anime(real_human), training=True)
            cycled_human = generator(get_domain_human(fake_anime), training=True)

            fake_human = generator(get_domain_human(real_anime), training=True)
            cycled_anime = generator(get_domain_anime(fake_human), training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator(get_domain_human(real_human), training=True)
            same_anime = generator(get_domain_anime(real_anime), training=True)

            disc_real_human, label_real_human = discriminator(real_human, training=True)
            disc_real_anime, label_real_anime = discriminator(real_anime, training=True)

            disc_fake_human, label_fake_human = discriminator(fake_human, training=True)
            disc_fake_anime, label_fake_anime = discriminator(fake_anime, training=True)

            _, label_cycled_human = discriminator(cycled_human, training=True)
            _, label_cycled_anime = discriminator(cycled_anime, training=True)

            _, label_same_human = discriminator(same_human, training=True)
            _, label_same_anime = discriminator(same_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human, cycled_human) + cycle_loss(
                real_anime, cycled_anime
            )

            gen_class_loss = (
                discriminator_loss(label_fake_human, label_fake_anime)
                + discriminator_loss(label_cycled_human, label_cycled_anime)
                + discriminator_loss(label_same_human, label_same_anime)
            )

            # Total generator loss = adversarial loss + cycle loss
            total_gen_loss = (
                gen_anime_loss
                + gen_human_loss 
                + gen_class_loss
                + total_cycle_loss * 0.1
                + identity_loss(real_anime, same_anime)
                + identity_loss(real_human, same_human)
            )

            tf.print("gen_anime_loss",gen_anime_loss)
            tf.print("gen_human_loss",gen_human_loss)
            tf.print("gen_class_loss",gen_class_loss)
            tf.print("total_cycle_loss",total_cycle_loss)
            tf.print("identity_loss(real_anime, same_anime)",identity_loss(real_anime, same_anime))
            tf.print("identity_loss(real_human, same_human)",identity_loss(real_human, same_human))

            scaled_total_gen_loss = generator_optimizer.get_scaled_loss(
                total_gen_loss)

            disc_human_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_anime_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

            # disc_gp_anime = gradient_penalty_star(partial(discriminator, training=True), real_anime,fake_anime )
            # disc_gp_human = gradient_penalty_star(partial(discriminator, training=True), real_human,fake_human )

            disc_loss = disc_human_loss + disc_anime_loss + discriminator_loss(label_real_human,label_real_anime)
            # +disc_gp_anime+disc_gp_human

            scaled_disc_loss = discriminator_optimizer.get_scaled_loss(
                disc_loss
            )

        # Calculate the gradients for generator and discriminator
        generator_gradients = generator_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_loss, generator.trainable_variables))
        discriminator_gradients = discriminator_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_loss, discriminator.trainable_variables))

        generator_optimizer.apply_gradients(
            zip(generator_gradients, generator.trainable_variables)
        )

        discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, discriminator.trainable_variables)
        )

        with tf.GradientTape(persistent=True) as tape:
            real_anime_up = up_G(real_anime)
            fake_anime_up = up_G(fake_anime)

            dis_fake_anime_up = up_D(fake_anime_up)
            dis_real_anime_up = up_D(real_anime_up)
            dis_ori_anime = up_D(big_anime)
            # The generator loss takes discriminator logits, not the image itself.
            gen_up_loss = (generator_loss(dis_fake_anime_up) +
                           generator_loss(dis_real_anime_up) * 0.1)
            dis_up_loss = (discriminator_loss(dis_ori_anime, dis_fake_anime_up) +
                           discriminator_loss(dis_ori_anime, dis_real_anime_up) * 0.1)
            scaled_gen_up_loss = up_G_optim.get_scaled_loss(gen_up_loss)
            scaled_dis_up_loss = up_D_optim.get_scaled_loss(dis_up_loss)

        up_G_gradients = up_G_optim.get_unscaled_gradients(
            tape.gradient(scaled_gen_up_loss, up_G.trainable_variables))
        up_D_gradients = up_D_optim.get_unscaled_gradients(
            tape.gradient(scaled_dis_up_loss, up_D.trainable_variables))

        up_G_optim.apply_gradients(
            zip(up_G_gradients, up_G.trainable_variables)
        )

        up_D_optim.apply_gradients(
            zip(up_D_gradients, up_D.trainable_variables)
        )
            

        return (
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_up,
            real_anime_up,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            gen_up_loss,
            dis_up_loss
        )
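
A sketch of how this step might be driven (hypothetical; EPOCHS and dataset are placeholders, and trainstep is assumed to be in scope):

import tensorflow as tf

train_step = tf.function(trainstep)  # compile the step into a single graph call

EPOCHS = 40
for epoch in range(EPOCHS):
    # The dataset is assumed to yield aligned (real_human, real_anime, big_anime) batches.
    for real_human, real_anime, big_anime in dataset:
        outputs = train_step(real_human, real_anime, big_anime)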
Example #7
        sent_embs = sent_embs.detach()
        real_labels = real_labels.detach()
        fake_labels = fake_labels.detach()
        

        fake_images = g_net(noise, sent_embs)
        
        for i in range(len(d_nets)):
            d_nets[i].zero_grad()
            errD = discriminator_loss(d_nets[i], images[i].detach(),
                                      fake_images[i].detach(), sent_embs,
                                      real_labels, fake_labels)
            errD.backward()
            optimizersD[i].step()
            total_error_d += errD.item()  # accumulate as a plain float
        
        g_net.zero_grad()
        errG_total = generator_loss(d_nets, fake_images, sent_embs, real_labels)
        errG_total.backward()
        optimizerG.step()
        total_error_g += errG_total.item()
    print('total error: g: ', total_error_g, ' d: ', total_error_d)
    save_single_image(fake_images[2], 'fake' + str(epoch) + '.png')
    

torch.save(g_net, 'G_NET.pth')
for i in range(len(d_nets)):
    torch.save(d_nets[i].state_dict(), 'D_NET' + str(i) + '.pth')
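
discriminator_loss and generator_loss are not shown. For a StackGAN/AttnGAN-style multi-scale, text-conditioned setup they might look roughly like this (hypothetical sketch, assuming each discriminator takes an image batch plus the sentence embeddings and returns logits):

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def discriminator_loss(d_net, real_imgs, fake_imgs, sent_embs, real_labels, fake_labels):
    # Real images should be labelled real, generated images fake.
    real_logits = d_net(real_imgs, sent_embs)
    fake_logits = d_net(fake_imgs, sent_embs)
    return bce(real_logits, real_labels) + bce(fake_logits, fake_labels)

def generator_loss(d_nets, fake_images, sent_embs, real_labels):
    # The generator wants every scale's discriminator to label its output real.
    total = 0.0
    for d_net, fake in zip(d_nets, fake_images):
        total = total + bce(d_net(fake, sent_embs), real_labels)
    return total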



Example #8
            real_logits_rot = discriminator(images_rot,
                                            training=True,
                                            predict_rotation=True)
            fake_logits_rot = discriminator(generated_images_rot,
                                            training=True,
                                            predict_rotation=True)

            if second_unpaired:
                generated_images_2 = generator(noise_2, training=True)
                fake_logits_2 = discriminator(generated_images_2,
                                              training=True)
                disc_loss_2 = discriminator_loss(real_logits, fake_logits_2,
                                                 rotation_n,
                                                 real_logits_rot)  # [] CHECK

            gen_loss = generator_loss(fake_logits, rotation_n, fake_logits_rot)
            disc_loss = discriminator_loss(real_logits, fake_logits,
                                           rotation_n, real_logits_rot)

        gradients_of_generator = gen_tape.gradient(gen_loss,
                                                   generator.variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, discriminator.variables)
        # gradients_of_discriminator_rot = disc_rot_tape.gradient(disc_loss_rot, discriminator.variables)

        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, generator.variables))
        discriminator_optimizer.apply_gradients(
            zip(gradients_of_discriminator, discriminator.variables))
        # discriminator_rot_optimizer.apply_gradients(zip(gradients_of_discriminator_rot, discriminator.variables))
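
The rotation-aware losses used above are not shown. For a self-supervised rotation GAN (SS-GAN style) they might look roughly like this (hypothetical sketch; it assumes the rotated batch is ordered as rotation_n equally sized blocks and that the rotation head returns rotation_n-way logits):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def rotation_targets(rot_logits, rotation_n):
    # Label block i of the rotated batch with rotation index i.
    per_rot = tf.shape(rot_logits)[0] // rotation_n
    return tf.repeat(tf.range(rotation_n), per_rot)

def discriminator_loss(real_logits, fake_logits, rotation_n, real_logits_rot, rot_weight=1.0):
    adv = (bce(tf.ones_like(real_logits), real_logits) +
           bce(tf.zeros_like(fake_logits), fake_logits))
    rot = cce(rotation_targets(real_logits_rot, rotation_n), real_logits_rot)
    return adv + rot_weight * rot

def generator_loss(fake_logits, rotation_n, fake_logits_rot, rot_weight=0.2):
    adv = bce(tf.ones_like(fake_logits), fake_logits)
    rot = cce(rotation_targets(fake_logits_rot, rotation_n), fake_logits_rot)
    return adv + rot_weight * rot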