Ejemplo n.º 1
0
    def gan_train(real_img, skeleton):
        """Run one adversarial step and one reconstruction/style step.

        Sequence performed (all modules/optimizers come from the enclosing
        scope):

        1. Enable gradients on ``generator`` and ``decoder``.
        2. Generate ``fake`` images from ``skeleton`` and sampled noise.
        3. Update the discriminator via ``model.discriminator_train``.
        4. Minimize the generator loss with ``model.optimizer.opt_min``.
        5. Minimize a weighted sum of two L1 losses with
           ``model.optimizer.opt_min`` and ``style_opt``. The two losses are
           reconstruction of ``real_img`` and the ``style_encoder`` latent
           predicted from the detached fake versus the generator's latents.
           Weights are read from ``../parameters/gan_loss.json``.

        Args:
            real_img: batch of real images; ``shape[0]`` is used as the batch
                size (exact layout assumed to be (B, C, H, W) — TODO confirm).
            skeleton: conditioning input for the generator/decoder; it is
                detached from any upstream graph before use.
        """
        B = real_img.shape[0]
        C = 512  # noise dimension passed to mixing_noise

        requires_grad(generator, True)
        requires_grad(decoder, True)
        # Cut the skeleton off from whatever produced it, but keep it as a
        # differentiable leaf. Presumably this is so the discriminator loss
        # can take gradients w.r.t. the condition — TODO confirm.
        condition = skeleton.detach().requires_grad_(True)

        # 0.9 is passed as the third argument of mixing_noise (presumably the
        # style-mixing probability — verify against its definition).
        noise = mixing_noise(B, C, 0.9, real_img.device)

        fake, fake_latent = generator(condition, noise, return_latents=True)

        model.discriminator_train([real_img], [fake], [condition])

        WR.writable("Generator brule_loss", model.generator_loss)([real_img], [fake], [condition]) \
            .minimize_step(model.optimizer.opt_min)

        # From here on the fake image is only an input/target; no gradient
        # should flow back into the generator through it.
        fake = fake.detach()

        fake_latent_pred = style_encoder(fake)
        restored = decoder(condition, style_encoder(real_img))
        # Insert a new axis at dim 1 in each latent and concatenate there,
        # then detach so the target carries no gradient.
        fake_latent = torch.cat([f[:, None, :] for f in fake_latent],
                                dim=1).detach()

        # Use a context manager so the file handle is closed on every step
        # instead of being left for the garbage collector.
        with open("../parameters/gan_loss.json", encoding="utf-8") as coefs_file:
            coefs = json.load(coefs_file)

        (WR.L1("L1 restored")(restored, real_img) * coefs["L1 restored"] +
         WR.L1("L1 style gan")(fake_latent_pred, fake_latent) *
         coefs["L1 style gan"]).minimize_step(model.optimizer.opt_min,
                                              style_opt)
Ejemplo n.º 2
0
    def do_train(real_img):
        """Run one training step for ``encoder_HG`` with decoder/generator frozen.

        Only ``encoder_HG`` has gradients enabled. The loss minimized with
        ``model.optimizer.opt_min`` is a weighted sum of three terms, with
        weights read from ``../parameters/content_loss.json``:

        * "L1 image": L1 between ``decoder(skeleton, style_encoder(real_img))``
          and ``real_img``.
        * "fake_content brule_loss": ``coord_hm_loss`` between the encoder's
          measures on a generated image and the detached heatmap of the
          encoder's measures on the real image.
        * "Fake-content D": ``model.loss.generator_loss`` on the generated
          image paired with the detached skeleton.

        Args:
            real_img: batch of real images; ``shape[0]`` is used as the batch
                size (exact layout assumed to be (B, C, H, W) — TODO confirm).
        """
        B = real_img.shape[0]

        # Use a context manager so the file handle is closed on every step
        # instead of being left for the garbage collector.
        with open("../parameters/content_loss.json", encoding="utf-8") as coefs_file:
            coefs = json.load(coefs_file)

        requires_grad(encoder_HG, True)
        requires_grad(decoder, False)
        requires_grad(generator, False)

        encoded = encoder_HG(real_img)
        pred_measures: UniformMeasure2D01 = encoded["mes"]

        # Target for the fake-content loss: detached so it acts as a fixed
        # label rather than a second path for gradients.
        heatmap_content = heatmapper.forward(pred_measures.coord).detach()

        restored = decoder(encoded["skeleton"], style_encoder(real_img))

        # NOTE(review): `C` is not assigned in this function; it is resolved
        # from the enclosing scope (gan_train uses a local C = 512). Confirm
        # the enclosing scope defines it with the intended value.
        noise = mixing_noise(B, C, 0.9, real_img.device)
        # The skeleton is not detached here, so the generator (frozen above)
        # still passes gradients from `fake` back to encoder_HG.
        fake, _ = generator(encoded["skeleton"], noise)
        fake_content = encoder_HG(fake.detach())["mes"]

        ll = (WR.L1("L1 image")(restored, real_img) * coefs["L1 image"] +
              WR.writable("fake_content brule_loss", coord_hm_loss)
              (fake_content, heatmap_content) *
              coefs["fake_content brule_loss"] +
              WR.writable("Fake-content D", model.loss.generator_loss)
              (real=None, fake=[fake, encoded["skeleton"].detach()]) *
              coefs["Fake-content D"])

        ll.minimize_step(model.optimizer.opt_min)