Example #1
def train(generator, discriminator, encoder, style_encoder, device,
          starting_model_number):
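    # NOTE: project-specific imports (torch, time, torch.optim as optim, albumentations,
    # Celeba, CondStyleGanModel, StyleGANLoss, LazyLoader, UnoTransformRegularizer,
    # mixing_noise, heatmap_to_measure, ProbabilityMeasure, imgs_with_mask,
    # send_images_to_tensorboard, writer, counter, ...) are assumed to be available
    # at module level; they are not shown in this snippet.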

    batch = 32
    Celeba.batch_size = batch

    latent_size = 512
    model = CondStyleGanModel(generator, StyleGANLoss(discriminator),
                              (0.001, 0.0015))

    style_opt = optim.Adam(style_encoder.parameters(),
                           lr=5e-4,
                           betas=(0.5, 0.97))

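    # Batched geometric augmentations (elastic deformation + shift/scale/rotate) applied
    # to image batches; used below by the transform-consistency regularizer R_s.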
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(
            albumentations.ElasticTransform(p=0.8,
                                            alpha=150,
                                            alpha_affine=1,
                                            sigma=10)),
        NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
        ToTensor(device)
    ])

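    # R_s: regularizer pushing the style encoder to predict the same latent for an
    # augmented image (L1 between the reference latent and the re-encoded latent).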
    R_s = UnoTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img, ltnt: L1("R_s")(ltnt, style_encoder(trans_dict['image'])))

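    # Fixed noise and a fixed test batch for periodic visualization in TensorBoard.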
    sample_z = torch.randn(batch, latent_size, device=device)
    test_img = next(LazyLoader.celeba().loader).to(device)
    print(test_img.shape)
    test_cond = encoder(test_img)

    requires_grad(encoder, False)  # REMOVE BEFORE TRAINING

    t_start = time.time()

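    # Main adversarial training loop.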
    for i in range(100000):
        counter.update(i)
        real_img = next(LazyLoader.celeba().loader).to(device)

        img_content = encoder(real_img).detach()

        noise = mixing_noise(batch, latent_size, 0.9, device)
        fake, _ = generator(img_content, noise)

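        # Discriminator update on real vs. fake images, conditioned on the encoded content.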
        model.discriminator_train([real_img], [fake.detach()], img_content)

        writable("Generator loss", model.generator_loss)([real_img], [fake], [], img_content)\
            .minimize_step(model.optimizer.opt_min)

        # print("gen train", time.time() - t1)

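        # Every 5th iteration: train the style encoder and enforce consistency
        # (content preservation, image reconstruction, style prediction, R_s).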
        if i % 5 == 0 and i > 0:
            noise = mixing_noise(batch, latent_size, 0.9, device)

            img_content = encoder(real_img).detach()
            fake, fake_latent = generator(img_content,
                                          noise,
                                          return_latents=True)

            fake_latent_test = fake_latent[:, [0, 13], :].detach()
            fake_latent_pred = style_encoder(fake)
            fake_content_pred = encoder(fake)

            restored = generator.module.decode(
                img_content[:batch // 2], style_encoder(real_img[:batch // 2]))
            (HMLoss("BCE content gan", 5000)(fake_content_pred, img_content) +
             L1("L1 restored")(restored, real_img[:batch // 2]) * 50 +
             L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 30 +
             R_s(fake.detach(), fake_latent_pred) * 50).minimize_step(
                 model.optimizer.opt_min, style_opt)

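        # Every 100 iterations: print timing and send fake/real/restored samples
        # with predicted masks to TensorBoard.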
        if i % 100 == 0:
            t_100 = time.time()
            print(i, t_100 - t_start)
            t_start = time.time()
            with torch.no_grad():

                fake_img, _ = generator(test_cond, [sample_z])
                coords, p = heatmap_to_measure(test_cond)
                test_mes = ProbabilityMeasure(p, coords)
                iwm = imgs_with_mask(fake_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)

                iwm = imgs_with_mask(test_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                restored = generator.module.decode(test_cond,
                                                   style_encoder(test_img))
                send_images_to_tensorboard(writer, restored, "RESTORED", i)

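        # Every 10000 iterations: checkpoint generator, discriminator and style encoder.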
        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.state_dict(),
                    'd': discriminator.state_dict(),
                    'style': style_encoder.state_dict()
                    # 'enc': cont_style_encoder.state_dict(),
                },
                f'/trinity/home/n.buzun/PycharmProjects/saved/stylegan2_w300_{str(starting_model_number + i).zfill(6)}.pt',
            )
Example #2
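    # Fragment of an albumentations-style transform: the surrounding class definition is
    # not shown here. The method renders a ProbabilityMeasure mask as an image so it can
    # pass through the augmentation pipeline.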
    def apply_to_mask(self, img: ProbabilityMeasure, **params):
        return img.toImage(self.size)
    #     HMLoss("BCE content gan", 5000)(fake_content_pred, img_content.detach()) +
    #     Loss(nn.L1Loss()(restored, real_img[:W300DatasetLoader.batch_size//2]) * 50) +
    #     Loss(nn.L1Loss()(fake_latent_pred, fake_latent_test) * 25) +
    #     R_s(fake.detach(), fake_latent_pred) * 50
    # ).minimize_step(
    #     model.optimizer.opt_min,
    #     style_opt,
    # )

    # img_content = encoder_HG(real_img)
    # fake, fake_latent = generator(img_content, noise, return_latents=True)
    # fake_content_pred = encoder_HG(fake)
    #
    #
    # disc_influence = model.loss.generator_loss(real=None, fake=[real_img, img_content]) * 2
    # (HMLoss("BCE content gan", 1)(fake_content_pred, img_content.detach()) +
    # disc_influence).minimize_step(enc_opt)

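    # Evaluation fragment: every 50 iterations, compute the test loss of the hourglass
    # encoder (encoder_HG) on W300, and log predicted landmark masks and the scalar loss
    # to TensorBoard. The enclosing loop and variables (i, writer, w300_test_image, test)
    # are defined elsewhere in the original script.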
    if i % 50 == 0 and i > 0:
        with torch.no_grad():
            test_loss = test(encoder_HG)
            print(test_loss)
            # tuner.update(test_loss)
            coord, p = heatmap_to_measure(encoder_HG(w300_test_image))
            pred_measure = ProbabilityMeasure(p, coord)
            iwm = imgs_with_mask(w300_test_image, pred_measure.toImage(256))
            send_images_to_tensorboard(writer, iwm, "W300_test_image", i)
            writer.add_scalar("test_loss", test_loss, i)

    # torch.save(enc.state_dict(), f"/home/ibespalov/pomoika/hg2_e{epoch}.pt")