def train(generator, decoder, discriminator, encoder_HG, style_encoder, device, starting_model_number):
    """Jointly train the conditional StyleGAN2 and the landmark encoder on MAFL/CelebA.

    Alternates three update steps per iteration: GAN training (``trainer_gan``),
    regularized content-encoder training (``content_trainer``, every 4th step) and
    supervised landmark training (``supervise_trainer``); periodically logs images
    and the "liuboff" test metric to tensorboard and checkpoints every 10000 steps.

    Args:
        generator, decoder, discriminator: StyleGAN2 modules. Saving uses
            ``.module.state_dict()``, so these are assumed DataParallel-wrapped
            — TODO confirm.
        encoder_HG: hourglass content encoder producing per-landmark heatmaps
            (64x64, judged from the ``* 63`` coordinate scaling below).
        style_encoder: maps images to latent style codes.
        device: torch device for training tensors.
        starting_model_number: offset added to the iteration index in checkpoint names.
    """
    latent_size = 512
    batch_size = 12
    # Fixed noise so the periodic "FAKE" tensorboard images are comparable across steps.
    sample_z = torch.randn(8, latent_size, device=device)
    MAFL.batch_size = batch_size
    MAFL.test_batch_size = 64
    Celeba.batch_size = batch_size

    # Fixed batch of 8 images used for all periodic tensorboard snapshots.
    test_img = next(LazyLoader.mafl().loader_train_inf)["data"][:8].cuda()

    loss_st: StyleGANLoss = StyleGANLoss(discriminator)
    model = CondStyleGanModel(generator, loss_st, (0.001, 0.0015))

    style_opt = optim.Adam(style_encoder.parameters(), lr=5e-4, betas=(0.9, 0.99))
    cont_opt = optim.Adam(encoder_HG.parameters(), lr=2e-5, betas=(0.5, 0.97))

    # Geometric augmentation applied jointly to image and heatmap-mask: the mask
    # is upscaled to 256 for the elastic/affine transforms, then resized back to
    # the encoder's 64x64 resolution and renormalized.
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(albumentations.Compose([
            ResizeMask(h=256, w=256),
            albumentations.ElasticTransform(p=0.7, alpha=150, alpha_affine=1, sigma=10),
            albumentations.ShiftScaleRotate(p=0.7, rotate_limit=15),
            ResizeMask(h=64, w=64),
            NormalizeMask(dim=(0, 1, 2))
        ])),
        ToTensor(device),
    ])

    # Equivariance regularizer: the encoder's heatmap for a transformed image
    # should match the equally-transformed heatmap of the original.
    R_t = DualTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img:
        # rt_loss(encoder_HG(trans_dict['image']), trans_dict['mask'])
        stariy_hm_loss(encoder_HG(trans_dict['image']), trans_dict['mask'])
    )

    # Style-invariance regularizer: the style code should be stable under the
    # same geometric transforms.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img, ltnt:
        L1("R_s")(ltnt, style_encoder(trans_dict['image']))
    )

    # Pretrained 5-landmark face barycenter, repeated over the batch; anchors
    # predicted measures via R_b.
    barycenter: UniformMeasure2D01 = UniformMeasure2DFactory.load(
        f"{Paths.default.models()}/face_barycenter_5").cuda().batch_repeat(batch_size)

    R_b = BarycenterRegularizer.__call__(barycenter, 1.0, 2.0, 4.0)
    # Loss-weight tuner; updated from the test metric every 50 iterations below.
    tuner = GoldTuner([0.37, 1.55, 0.9393, 0.1264, 1.7687, 0.8648, 1.8609], device=device, rule_eps=0.01 / 2,
                      radius=0.1, active=True)

    heatmaper = ToGaussHeatMap(64, 1.0)
    # Render the barycenter landmarks as a yellow RGB overlay for tensorboard
    # (coordinates in [0,1] are scaled to the 64-pixel heatmap grid).
    sparse_bc = heatmaper.forward(barycenter.coord * 63)
    sparse_bc = nn.Upsample(scale_factor=4)(sparse_bc).sum(dim=1, keepdim=True).repeat(1, 3, 1, 1) * \
                torch.tensor([1.0, 1.0, 0.0], device=device).view(1, 3, 1, 1)
    # NOTE(review): divides by max rather than (max - min), so this is not an
    # exact min-max normalisation; the same formula appears elsewhere in this
    # file — confirm whether intentional.
    sparse_bc = (sparse_bc - sparse_bc.min()) / sparse_bc.max()
    send_images_to_tensorboard(writer, sparse_bc, "BC", 0, normalize=False, range=(0, 1))

    trainer_gan = gan_trainer(model, generator, decoder, encoder_HG, style_encoder, R_s, style_opt, heatmaper,
                              g_transforms)
    content_trainer = content_trainer_with_gan(cont_opt, tuner, heatmaper, encoder_HG, R_b, R_t, model, generator,
                                               g_transforms)
    supervise_trainer = content_trainer_supervised(cont_opt, encoder_HG, LazyLoader.mafl().loader_train_inf)

    for i in range(100000):
        counter.update(i)

        requires_grad(encoder_HG, False)  # REMOVE BEFORE TRAINING
        # Every 5th batch comes from the (labelled) MAFL loader, the rest from CelebA.
        real_img = next(LazyLoader.mafl().loader_train_inf)["data"].to(device) \
            if i % 5 == 0 else next(LazyLoader.celeba().loader).to(device)

        img_content = encoder_HG(real_img)
        # Convert dense heatmaps to landmark coordinates, then re-render sparse
        # Gaussian heatmaps from those coordinates for the GAN conditioning.
        pred_measures: UniformMeasure2D01 = UniformMeasure2DFactory.from_heatmap(img_content)
        sparse_hm = heatmaper.forward(pred_measures.coord * 63).detach()
        trainer_gan(i, real_img, pred_measures.detach(), sparse_hm.detach(), apply_g=False)
        supervise_trainer()

        if i % 4 == 0:
            # real_img = next(LazyLoader.mafl().loader_train_inf)["data"].to(device)
            trainer_gan(i, real_img, pred_measures.detach(), sparse_hm.detach(), apply_g=True)
            content_trainer(real_img)

        # Periodic visual logging: overlay predicted landmarks on real, fake and
        # decoder-restored images.
        if i % 100 == 0:
            coefs = json.load(open("../parameters/content_loss.json"))
            print(i, coefs)
            with torch.no_grad():
                # pred_measures_test, sparse_hm_test = encoder_HG(test_img)
                content_test = encoder_HG(test_img)
                pred_measures_test: UniformMeasure2D01 = UniformMeasure2DFactory.from_heatmap(content_test)
                heatmaper_256 = ToGaussHeatMap(256, 2.0)
                sparse_hm_test = heatmaper.forward(pred_measures_test.coord * 63)
                sparse_hm_test_1 = heatmaper_256.forward(pred_measures_test.coord * 255)

                latent_test = style_encoder(test_img)

                sparce_mask = sparse_hm_test_1.sum(dim=1, keepdim=True)
                # Threshold away near-zero Gaussian tails before overlaying.
                sparce_mask[sparce_mask < 0.0003] = 0
                iwm = imgs_with_mask(test_img, sparce_mask)
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                fake_img, _ = generator(sparse_hm_test, [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures_test.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)

                restored = decoder(sparse_hm_test, latent_test)
                iwm = imgs_with_mask(restored, pred_measures_test.toImage(256))
                send_images_to_tensorboard(writer, iwm, "RESTORED", i)

                # Heatmap visualisation in yellow, same rendering as the
                # barycenter image above.
                content_test_256 = nn.Upsample(scale_factor=4)(sparse_hm_test).sum(dim=1, keepdim=True).repeat(1, 3, 1,
                                                                                                               1) * \
                                   torch.tensor([1.0, 1.0, 0.0], device=device).view(1, 3, 1, 1)

                content_test_256 = (content_test_256 - content_test_256.min()) / content_test_256.max()
                send_images_to_tensorboard(writer, content_test_256, "HM", i, normalize=False, range=(0, 1))

        # NOTE(review): ``i >= 0`` is always true here — the guard is redundant.
        if i % 50 == 0 and i >= 0:
            test_loss = liuboff(encoder_HG)
            # test_loss = nadbka(encoder_HG)
            tuner.update(test_loss)
            writer.add_scalar("liuboff", test_loss, i)

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'c': encoder_HG.module.state_dict(),
                    "s": style_encoder.state_dict()
                },
                f'{Paths.default.models()}/stylegan2_MAFL_{str(i + starting_model_number).zfill(6)}.pt',
            )
# Exemplo n.º 2 (scraped sample separator — commented out to keep the file parseable)
# 0
def train(args, loader, generator, discriminator, device, cont_style_encoder,
          starting_model_number):
    """Train the conditional StyleGAN2 with a combined content/style encoder.

    Each iteration runs a discriminator step, a generator + style-encoder step,
    and (every 5th iteration) a tuner-weighted multi-loss step updating the
    content encoder as well; images are logged every 100 iterations and a
    checkpoint is written every 10000.

    Args:
        args: namespace providing ``iter``, ``latent``, ``batch``, ``mixing``,
            ``start_iter`` (argparse-style — exact schema defined by the caller).
        loader: raw dataset loader, wrapped by ``sample_data`` into an
            infinite iterator of image batches.
        generator, discriminator: StyleGAN2 modules.
        cont_style_encoder: encoder exposing ``enc_style``, ``enc_content`` and
            ``get_content``; called directly it returns ``(content, latent)``.
        starting_model_number: offset added to the iteration index in checkpoint names.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    # Fixed noise / fixed test batch for comparable tensorboard snapshots.
    sample_z = torch.randn(8, args.latent, device=device)
    test_img = next(loader)[:8]
    test_img = test_img.cuda()

    # test_pairs = [next(loader) for _ in range(50)]

    loss_st: StyleGANLoss = StyleGANLoss(discriminator)
    model = CondStyleGanModel(generator, loss_st, (0.001, 0.0015))

    style_opt = optim.Adam(cont_style_encoder.enc_style.parameters(),
                           lr=5e-4,
                           betas=(0.5, 0.9))
    cont_opt = optim.Adam(cont_style_encoder.enc_content.parameters(),
                          lr=2e-5,
                          betas=(0.5, 0.9))

    # Joint image/measure augmentation: measures are rasterised to 256x256
    # masks, transformed together with the image, then converted back.
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        MeasureToMask(size=256),
        ToNumpy(),
        NumpyBatch(
            albumentations.ElasticTransform(p=0.8,
                                            alpha=150,
                                            alpha_affine=1,
                                            sigma=10)),
        NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
        ToTensor(device),
        MaskToMeasure(size=256, padding=140),
    ])

    # Wasserstein-style transport loss between measures (p=1).
    W1 = Samples_Loss(scaling=0.85, p=1)
    # W2 = Samples_Loss(scaling=0.85, p=2)

    # g_trans_res_dict = g_transforms(image=test_img, mask=MaskToMeasure(size=256, padding=140).apply_to_mask(test_mask))
    # g_trans_img = g_trans_res_dict['image']
    # g_trans_mask = g_trans_res_dict['mask']
    # iwm = imgs_with_mask(g_trans_img, g_trans_mask.toImage(256), color=[1, 1, 1])
    # send_images_to_tensorboard(writer, iwm, "RT", 0)

    # Equivariance regularizer: content extracted from a transformed image
    # should match the transformed measure (W1 transport distance).
    R_t = DualTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img: W1(
            content_to_measure(
                cont_style_encoder.get_content(trans_dict['image'])),
            trans_dict['mask'])  # +
        # W2(content_to_measure(cont_style_encoder.get_content(trans_dict['image'])), trans_dict['mask'])
    )

    # Style-invariance regularizer under the same transforms.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img, ltnt: L1("R_s")
        (ltnt, cont_style_encoder.enc_style(trans_dict['image'])))

    # Pretrained face barycenter measure, padded/transposed and repeated over
    # the batch of 16 used in the tuner step below.
    fabric = ProbabilityMeasureFabric(256)
    barycenter = fabric.load(f"{Paths.default.models()}/face_barycenter").cuda(
    ).padding(70).transpose().batch_repeat(16)

    R_b = BarycenterRegularizer.__call__(barycenter)

    # tuner = CoefTuner([4.5, 10.5, 2.5, 0.7, 0.5], device=device)
    #                 [6.5, 7.9, 2.7, 2.06, 5.4, 0.7, 2.04]
    #                  3.3, 10.5,  6.2,  1.14, 10.88,  0.93,  2.6
    #                  4.3, 10.3, 5.9, 0.85, 10.1, 0.27, 4.5
    #                  [4.53, 9.97, 5.5, 0.01, 9.44, 1.05, 4.9
    # Loss-weight tuners; ``active=False`` means the weights stay fixed
    # (the update calls below are commented out as well).
    tuner = GoldTuner([2.53, 40.97, 5.5, 0.01, 5.44, 1.05, 4.9],
                      device=device,
                      rule_eps=0.05,
                      radius=1,
                      active=False)
    gan_tuner = GoldTuner([20, 25, 25],
                          device=device,
                          rule_eps=1,
                          radius=20,
                          active=False)

    # rb_tuner = GoldTuner([0.7, 1.5, 10], device=device, rule_eps=0.02, radius=0.5)

    best_igor = 100

    for idx in pbar:
        i = idx + args.start_iter
        counter.update(i)

        if i > args.iter:
            print('Done!')
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        img_content = cont_style_encoder.get_content(real_img)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        # Detached copy with grads re-enabled: generator gradients flow into
        # the conditioning without backpropagating into the content encoder.
        img_content_variable = img_content.detach().requires_grad_(True)
        fake, fake_latent = generator(img_content_variable,
                                      noise,
                                      return_latents=True)

        model.discriminator_train([real_img], [fake], img_content)

        # fake_detach = fake.detach()
        # Latents at indices 0 and 13 serve as the style-reconstruction target.
        fake_latent_test = fake_latent[:, [0, 13], :].detach()
        fake_content_pred = cont_style_encoder.get_content(fake)

        fake_latent_pred = cont_style_encoder.enc_style(fake)

        # Generator step: adversarial loss plus tuner-weighted cycle losses
        # (content and style recovered from the fake image).
        (writable("Generator loss", model.generator_loss)(
            [real_img], [fake], [fake_latent], img_content_variable) +  # 3e-5
         gan_tuner.sum_losses([
             L1("L1 content gan")(fake_content_pred, img_content.detach()),
             L1("L1 style gan")(fake_latent_pred, fake_latent_test),
             R_s(fake.detach(), fake_latent_pred),
         ])
         # L1("L1 content gan")(fake_content_pred, img_content.detach()) * 50 +  # 3e-7
         # L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 10 +  # 8e-7
         # R_s(fake, barycenter) * 20
         ).minimize_step(model.optimizer.opt_min, style_opt)

        if i % 5 == 0:
            # fake_latent_pred = cont_style_encoder.enc_style(fake_detach)
            # (L1("L1 style gan")(fake_latent_pred, fake_latent_test)).__mul__(2).minimize_step(style_opt)
            # Encoder/regularizer step on a sub-batch of 16.
            img_latent = cont_style_encoder.enc_style(real_img[:16])
            restored = model.generator.module.decode(img_content[:16],
                                                     img_latent[:16])
            pred_measures: ProbabilityMeasure = content_to_measure(
                img_content[:16])

            # Two independent noise draws from the same content: their decoded
            # contents should agree ("L1 content between fake").
            noise1 = mixing_noise(16, args.latent, args.mixing, device)
            noise2 = mixing_noise(16, args.latent, args.mixing, device)
            fake1, _ = generator(img_content[:16], noise1)
            fake2, _ = generator(img_content[:16], noise2)

            cont_fake1 = cont_style_encoder.get_content(fake1)
            cont_fake2 = cont_style_encoder.get_content(fake2)

            # rb_coefs = rb_tuner.get_coef()
            # R_b = BarycenterRegularizer.__call__(barycenter, rb_coefs[0], rb_coefs[1], rb_coefs[2])
            #TUNER PART
            tuner.sum_losses([
                # writable("Fake-content D", model.loss.generator_loss)(real=None, fake=[fake1, img_content.detach()]),  # 1e-3
                writable("Real-content D", model.loss.generator_loss)
                (real=None, fake=[real_img, img_content]),  # 3e-5
                writable("R_b", R_b.__call__)(real_img[:16],
                                              pred_measures),  # 7e-5
                writable("R_t", R_t.__call__)(real_img[:16],
                                              pred_measures),  # -
                L1("L1 content between fake")(cont_fake1, cont_fake2),  # 1e-6
                L1("L1 image")(restored, real_img[:16]),  # 4e-5
                R_s(real_img[:16], img_latent),
                L1("L1 style restored")(cont_style_encoder.enc_style(restored),
                                        img_latent.detach())
            ]).minimize_step(cont_opt, model.optimizer.opt_min, style_opt)

            ##Without tuner part

            # (
            #         model.loss.generator_loss(real=None, fake=[real_img, img_content]) * 5 +
            #         (R_b + R_t * 0.4)(real_img, pred_measures) * 10 +
            #         L1("L1 content between fake")(cont_fake1, cont_fake2) * 1 +
            #         L1("L1 image")(restored, real_img) * 1
            #         # L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 1
            # ).minimize_step(
            #     cont_opt,
            #     model.optimizer.opt_min
            # )

        # Periodic visual logging: landmarks overlaid on real, generated and
        # restored images.
        if i % 100 == 0:
            print(i)
            with torch.no_grad():

                content, latent = cont_style_encoder(test_img)
                pred_measures: ProbabilityMeasure = content_to_measure(content)
                # ref_measures: ProbabilityMeasure = MaskToMeasure(size=256, padding=140).apply_to_mask(test_mask)
                # iwm = imgs_with_mask(test_img, ref_measures.toImage(256), color=[0, 0, 1])
                iwm = imgs_with_mask(test_img,
                                     pred_measures.toImage(256),
                                     color=[1, 1, 1])
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                fake_img, _ = generator(content, [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)
                restored = model.generator.module.decode(content, latent)
                send_images_to_tensorboard(writer, restored, "RESTORED", i)

        # Disabled test/metric section kept for reference.
        if i % 100 == 0 and i > 0:
            pass
            # with torch.no_grad():
            #     igor = test(cont_style_encoder, test_pairs)
            #     writer.add_scalar("test error", igor, i)
            #     tuner.update(igor)
            #     gan_tuner.update(igor)
            #     # rb_tuner.update(igor)
            #
            # if igor < best_igor:
            #     best_igor = igor
            #     print("best igor")
            #     torch.save(
            #         {
            #             'g': generator.state_dict(),
            #             'd': discriminator.state_dict(),
            #             'enc': cont_style_encoder.state_dict(),
            #         },
            #         f'{Paths.default.nn()}/stylegan2_igor_3.pt',
            #     )

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'enc': cont_style_encoder.state_dict(),
                    # 'g_ema': g_ema.state_dict(),
                    # 'g_optim': g_optim.state_dict(),
                    # 'd_optim': d_optim.state_dict(),
                },
                f'{Paths.default.models()}/stylegan2_invertable_{str(i + starting_model_number).zfill(6)}.pt',
            )
# Exemplo n.º 3 (scraped sample separator — commented out to keep the file parseable)
# 0
def train(generator, discriminator, encoder, style_encoder, device,
          starting_model_number):
    """Train the conditional StyleGAN2 on CelebA with a frozen content encoder.

    The landmark encoder is kept frozen (``requires_grad(encoder, False)``);
    only the GAN and the style encoder are updated. Every 5th iteration adds
    content/style cycle losses; images are logged every 100 iterations and a
    checkpoint is written every 10000.

    Args:
        generator, discriminator: StyleGAN2 modules; ``generator.module.decode``
            implies DataParallel wrapping — TODO confirm.
        encoder: frozen content encoder producing heatmap conditioning.
        style_encoder: maps images to latent style codes (trained here).
        device: torch device for training tensors.
        starting_model_number: offset added to the iteration index in checkpoint names.
    """

    batch = 32
    Celeba.batch_size = batch

    latent_size = 512
    model = CondStyleGanModel(generator, StyleGANLoss(discriminator),
                              (0.001, 0.0015))

    style_opt = optim.Adam(style_encoder.parameters(),
                           lr=5e-4,
                           betas=(0.5, 0.97))

    # Joint image/mask geometric augmentation used by the R_s regularizer.
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(
            albumentations.ElasticTransform(p=0.8,
                                            alpha=150,
                                            alpha_affine=1,
                                            sigma=10)),
        NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
        ToTensor(device)
    ])

    # Style-invariance regularizer: style code should be stable under the
    # geometric transforms above.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img, ltnt: L1("R_s")
        (ltnt, style_encoder(trans_dict['image'])))

    # Fixed noise / fixed test batch and its conditioning for snapshots.
    sample_z = torch.randn(batch, latent_size, device=device)
    test_img = next(LazyLoader.celeba().loader).to(device)
    print(test_img.shape)
    test_cond = encoder(test_img)

    requires_grad(encoder, False)  # REMOVE BEFORE TRAINING

    t_start = time.time()

    for i in range(100000):
        counter.update(i)
        real_img = next(LazyLoader.celeba().loader).to(device)

        img_content = encoder(real_img).detach()

        noise = mixing_noise(batch, latent_size, 0.9, device)
        fake, _ = generator(img_content, noise)

        model.discriminator_train([real_img], [fake.detach()], img_content)

        writable("Generator loss", model.generator_loss)([real_img], [fake], [], img_content)\
            .minimize_step(model.optimizer.opt_min)

        # print("gen train", time.time() - t1)

        # Cycle-consistency step: recover content and style from generated
        # images and reconstruct real images from their own content+style.
        if i % 5 == 0 and i > 0:
            noise = mixing_noise(batch, latent_size, 0.9, device)

            img_content = encoder(real_img).detach()
            fake, fake_latent = generator(img_content,
                                          noise,
                                          return_latents=True)

            # Latents at indices 0 and 13 are the style-reconstruction target.
            fake_latent_test = fake_latent[:, [0, 13], :].detach()
            fake_latent_pred = style_encoder(fake)
            fake_content_pred = encoder(fake)

            # Restoration on half the batch only (memory saving, presumably —
            # TODO confirm).
            restored = generator.module.decode(
                img_content[:batch // 2], style_encoder(real_img[:batch // 2]))
            (HMLoss("BCE content gan", 5000)(fake_content_pred, img_content) +
             L1("L1 restored")(restored, real_img[:batch // 2]) * 50 +
             L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 30 +
             R_s(fake.detach(), fake_latent_pred) * 50).minimize_step(
                 model.optimizer.opt_min, style_opt)

        # Periodic visual logging with per-100-iteration wall-clock timing.
        if i % 100 == 0:
            t_100 = time.time()
            print(i, t_100 - t_start)
            t_start = time.time()
            with torch.no_grad():

                fake_img, _ = generator(test_cond, [sample_z])
                coords, p = heatmap_to_measure(test_cond)
                test_mes = ProbabilityMeasure(p, coords)
                iwm = imgs_with_mask(fake_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)

                iwm = imgs_with_mask(test_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                restored = generator.module.decode(test_cond,
                                                   style_encoder(test_img))
                send_images_to_tensorboard(writer, restored, "RESTORED", i)

        if i % 10000 == 0 and i > 0:
            # NOTE(review): unlike the sibling train() variants this saves
            # ``generator.state_dict()`` (not ``.module``) and writes to a
            # hard-coded personal path instead of Paths.default.models() —
            # confirm both are intentional.
            torch.save(
                {
                    'g': generator.state_dict(),
                    'd': discriminator.state_dict(),
                    'style': style_encoder.state_dict()
                    # 'enc': cont_style_encoder.state_dict(),
                },
                f'/trinity/home/n.buzun/PycharmProjects/saved/stylegan2_w300_{str(starting_model_number + i).zfill(6)}.pt',
            )
def train(generator, decoder, discriminator, encoder_HG, style_encoder, device,
          starting_model_number):
    """Train the conditional StyleGAN2 and 68-landmark encoder on W300 + CelebA.

    Each iteration runs GAN training on CelebA, a content-regularizer step
    (``train_content``), and supervised W300 landmark training; an EMA copy of
    the encoder is maintained via ``encoder_ema`` (taken from the enclosing
    scope — not a parameter). Images and the "liuboff" metric are logged
    periodically; checkpoints every 10000 iterations.

    Args:
        generator, decoder, discriminator: StyleGAN2 modules (``.module`` used
            when saving, so presumably DataParallel-wrapped — TODO confirm).
        encoder_HG: hourglass encoder returning a dict with at least
            "skeleton" and "coords" entries (per its usage below).
        style_encoder: maps images to latent style codes.
        device: torch device for training tensors.
        starting_model_number: offset added to the iteration index in checkpoint names.
    """
    latent_size = 512
    batch_size = 8
    # Fixed noise for reproducible FAKE-image snapshots.
    sample_z = torch.randn(8, latent_size, device=device)
    Celeba.batch_size = batch_size
    W300DatasetLoader.batch_size = batch_size
    W300DatasetLoader.test_batch_size = 16

    # Fixed batch of 8 test images for periodic tensorboard snapshots.
    test_img = next(LazyLoader.w300().loader_train_inf)["data"][:8].cuda()

    # Learning rates are 1/4 of the other variants in this file.
    model = CondStyleGanModel(generator, StyleGANLoss(discriminator),
                              (0.001 / 4, 0.0015 / 4))

    style_opt = optim.Adam(style_encoder.parameters(),
                           lr=5e-4,
                           betas=(0.9, 0.99))
    cont_opt = optim.Adam(encoder_HG.parameters(), lr=3e-5, betas=(0.5, 0.97))

    # Joint image/mask geometric augmentation for the regularizers.
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(
            albumentations.Compose([
                albumentations.ElasticTransform(p=0.7,
                                                alpha=150,
                                                alpha_affine=1,
                                                sigma=10),
                albumentations.ShiftScaleRotate(p=0.9, rotate_limit=15),
            ])),
        ToTensor(device),
    ])

    # Milder variant (lower probabilities, no normalisation); defined here but
    # not referenced below — possibly consumed elsewhere, confirm before removal.
    g_transforms_without_norm: albumentations.DualTransform = albumentations.Compose(
        [
            ToNumpy(),
            NumpyBatch(
                albumentations.Compose([
                    albumentations.ElasticTransform(p=0.3,
                                                    alpha=150,
                                                    alpha_affine=1,
                                                    sigma=10),
                    albumentations.ShiftScaleRotate(p=0.7, rotate_limit=15),
                ])),
            ToTensor(device),
        ])

    # Equivariance regularizer on predicted landmark coordinates.
    R_t = DualTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img: coord_hm_loss(
            encoder_HG(trans_dict['image'])["coords"], trans_dict['mask']))

    # Style-invariance regularizer under the same transforms.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img, ltnt: WR.L1("R_s")
        (ltnt, style_encoder(trans_dict['image'])))

    # Pretrained 68-landmark barycenter, repeated over the batch.
    barycenter: UniformMeasure2D01 = UniformMeasure2DFactory.load(
        f"{Paths.default.models()}/face_barycenter_68").cuda().batch_repeat(
            batch_size)

    R_b = BarycenterRegularizer.__call__(barycenter, 1.0, 2.0, 4.0)

    # Loss-weight tuner; inactive, so weights stay fixed even though
    # tuner.update() is still called below.
    tuner = GoldTuner([0.37, 2.78, 0.58, 1.43, 3.23],
                      device=device,
                      rule_eps=0.001,
                      radius=0.3,
                      active=False)

    trainer_gan = gan_trainer(model, generator, decoder, encoder_HG,
                              style_encoder, R_s, style_opt, g_transforms)
    content_trainer = content_trainer_with_gan(cont_opt, tuner, encoder_HG,
                                               R_b, R_t, model, generator,
                                               g_transforms, decoder,
                                               style_encoder)
    supervise_trainer = content_trainer_supervised(
        cont_opt, encoder_HG,
        LazyLoader.w300().loader_train_inf)

    for i in range(11000):
        WR.counter.update(i)

        requires_grad(encoder_HG, False)
        real_img = next(LazyLoader.celeba().loader).to(device)

        encoded = encoder_HG(real_img)
        internal_content = encoded["skeleton"].detach()

        trainer_gan(i, real_img, internal_content)
        # content_trainer(real_img)
        train_content(cont_opt, R_b, R_t, real_img, model, encoder_HG, decoder,
                      generator, style_encoder)
        supervise_trainer()

        # EMA of the encoder weights; every 50 iterations the EMA weights are
        # written back into the live encoder. ``encoder_ema`` comes from the
        # enclosing scope (not defined in this function) — verify it exists.
        encoder_ema.accumulate(encoder_HG.module, i, 0.98)
        if i % 50 == 0 and i > 0:
            encoder_ema.write_to(encoder_HG.module)

        # Periodic visual logging: landmarks overlaid on real, fake and
        # restored images, plus the skeleton heatmap itself.
        if i % 100 == 0:
            coefs = json.load(open("../parameters/content_loss.json"))
            print(i, coefs)
            with torch.no_grad():

                # pred_measures_test, sparse_hm_test = encoder_HG(test_img)
                encoded_test = encoder_HG(test_img)
                pred_measures_test: UniformMeasure2D01 = UniformMeasure2D01(
                    encoded_test["coords"])
                heatmaper_256 = ToGaussHeatMap(256, 1.0)
                sparse_hm_test_1 = heatmaper_256.forward(
                    pred_measures_test.coord)

                latent_test = style_encoder(test_img)

                sparce_mask = sparse_hm_test_1.sum(dim=1, keepdim=True)
                # Threshold away near-zero Gaussian tails before overlaying.
                sparce_mask[sparce_mask < 0.0003] = 0
                iwm = imgs_with_mask(test_img, sparce_mask)
                send_images_to_tensorboard(WR.writer, iwm, "REAL", i)

                fake_img, _ = generator(encoded_test["skeleton"], [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures_test.toImage(256))
                send_images_to_tensorboard(WR.writer, iwm, "FAKE", i)

                restored = decoder(encoded_test["skeleton"], latent_test)
                iwm = imgs_with_mask(restored, pred_measures_test.toImage(256))
                send_images_to_tensorboard(WR.writer, iwm, "RESTORED", i)

                # Render the skeleton heatmap in yellow for tensorboard.
                content_test_256 = (encoded_test["skeleton"]).repeat(1, 3, 1, 1) * \
                    torch.tensor([1.0, 1.0, 0.0], device=device).view(1, 3, 1, 1)

                # NOTE(review): divides by max, not (max - min) — not an exact
                # min-max normalisation; pattern repeated across this file.
                content_test_256 = (content_test_256 - content_test_256.min()
                                    ) / content_test_256.max()
                send_images_to_tensorboard(WR.writer,
                                           content_test_256,
                                           "HM",
                                           i,
                                           normalize=False,
                                           range=(0, 1))

        # NOTE(review): ``i >= 0`` is always true — the guard is redundant.
        if i % 50 == 0 and i >= 0:
            test_loss = liuboff(encoder_HG)
            print("liuboff", test_loss)
            # test_loss = nadbka(encoder_HG)
            tuner.update(test_loss)
            WR.writer.add_scalar("liuboff", test_loss, i)

        # NOTE(review): the loop runs 11000 iterations, so this fires only
        # once, at i == 10000.
        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'c': encoder_HG.module.state_dict(),
                    "s": style_encoder.state_dict(),
                    "e": encoder_ema.storage_model.state_dict()
                },
                f'{Paths.default.models()}/stylegan2_new_{str(i + starting_model_number).zfill(6)}.pt',
            )
# Script-level dataset/augmentation setup. Relies on names defined elsewhere
# in the original script (``image_size``, ``args``, ``device``) — this chunk is
# not self-contained.
# CelebA images paired with landmark masks, resized and normalised to [-1, 1].
full_dataset = ImageMeasureDataset(
    "/raid/data/celeba",
    "/raid/data/celeba_masks",
    img_transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((image_size, image_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
)


# Joint image/measure augmentation: measure -> 256x256 mask, shift/scale/rotate
# applied to both, mask -> measure again.
g_transforms: albumentations.DualTransform = albumentations.Compose([
    MeasureToMask(size=256),
    ToNumpy(),
    NumpyBatch(albumentations.ShiftScaleRotate(p=1, rotate_limit=20)),
    ToTensor(device),
    MaskToMeasure(size=256, padding=args.measure_size),
])


# Hold out a fixed 1000-sample test split.
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [len(full_dataset) - 1000, 1000])

dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=20)

noise = NormalNoise(args.noise_size, device)
#
# cond_gan_model = stylegan2_cond_transfer(
#     "/home/ibespalov/stylegan2/stylegan2-pytorch/checkpoint/790000.pt",
#     "hinge",
#     0.002,
#     args.measure_size * 3,
def train(generator, decoder, discriminator, encoder_HG, style_encoder, device,
          starting_model_number):
    """Train the conditional StyleGAN2 and landmark encoder on CelebA (unsupervised).

    Each iteration runs GAN training on sparse Gaussian heatmaps rendered from
    the encoder's predicted landmarks; every 3rd iteration also runs the
    tuner-weighted content-encoder step. Images are logged every 100
    iterations, the "verka" metric every 50, and a checkpoint every 10000.

    Args:
        generator, decoder, discriminator: StyleGAN2 modules (``.module`` used
            when saving, so presumably DataParallel-wrapped — TODO confirm).
        encoder_HG: hourglass encoder producing 64x64 heatmaps (judged from the
            ``* 63`` coordinate scaling below).
        style_encoder: maps images to latent style codes.
        device: torch device for training tensors.
        starting_model_number: offset added to the iteration index in checkpoint names.
    """
    latent_size = 512
    batch_size = 24
    # Fixed noise for reproducible FAKE-image snapshots.
    sample_z = torch.randn(8, latent_size, device=device)
    Celeba.batch_size = batch_size
    # Fixed batch of 8 test images for periodic tensorboard snapshots.
    test_img = next(LazyLoader.celeba().loader)[:8].cuda()

    loss_st: StyleGANLoss = StyleGANLoss(discriminator)
    model = CondStyleGanModel(generator, loss_st, (0.001, 0.0015))

    style_opt = optim.Adam(style_encoder.parameters(),
                           lr=5e-4,
                           betas=(0.9, 0.99))
    cont_opt = optim.Adam(encoder_HG.parameters(), lr=4e-5, betas=(0.5, 0.97))

    # Joint image/heatmap augmentation: mask upscaled to 256 for the elastic/
    # affine transforms, resized back to 64x64 and renormalized.
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(
            albumentations.Compose([
                ResizeMask(h=256, w=256),
                albumentations.ElasticTransform(p=0.7,
                                                alpha=150,
                                                alpha_affine=1,
                                                sigma=10),
                albumentations.ShiftScaleRotate(p=0.7, rotate_limit=15),
                ResizeMask(h=64, w=64),
                NormalizeMask(dim=(0, 1, 2))
            ])),
        ToTensor(device),
    ])

    # Equivariance regularizer: encoder heatmap on a transformed image should
    # match the transformed heatmap.
    R_t = DualTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img: hm_svoego_roda_loss(
            encoder_HG(trans_dict['image']), trans_dict['mask'], 1000, 0.3))

    # Style-invariance regularizer under the same transforms.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img, ltnt: L1("R_s")
        (ltnt, style_encoder(trans_dict['image'])))

    # Pretrained 68-landmark barycenter, repeated over the batch.
    barycenter: UniformMeasure2D01 = UniformMeasure2DFactory.load(
        f"{Paths.default.models()}/face_barycenter_68").cuda().batch_repeat(
            batch_size)
    # plt.imshow(barycenter.toImage(256)[0][0].detach().cpu().numpy())
    # plt.show()

    R_b = BarycenterRegularizer.__call__(barycenter, 1.0, 2.0, 3.0)

    #                  4.5, 1.2, 1.12, 1.4, 0.07, 2.2
    #                  1.27, 3.55, 5.88, 3.83, 2.17, 0.22, 1.72
    # Active loss-weight tuner, updated from the "verka" metric below.
    tuner = GoldTuner([2.2112, 2.3467, 3.8438, 3.2202, 2.0494, 0.0260, 5.8378],
                      device=device,
                      rule_eps=0.03,
                      radius=1,
                      active=True)
    # tuner_verka = GoldTuner([3.0, 1.2, 2.0], device=device, rule_eps=0.05, radius=1, active=True)

    best_igor = 100  # NOTE(review): assigned but never used in this function.
    heatmaper = ToGaussHeatMap(64, 1.5)

    trainer_gan = gan_trainer(model, generator, decoder, encoder_HG,
                              style_encoder, R_s, style_opt, heatmaper,
                              g_transforms)
    content_trainer = content_trainer_with_gan(cont_opt, tuner, heatmaper,
                                               encoder_HG, R_b, R_t, model,
                                               generator)

    for i in range(100000):
        counter.update(i)

        requires_grad(encoder_HG, False)  # REMOVE BEFORE TRAINING
        real_img = next(LazyLoader.celeba().loader).to(device)

        # Dense heatmaps -> landmark coordinates -> sparse Gaussian heatmaps
        # used as the GAN conditioning (coords in [0,1] scaled to 64px grid).
        img_content = encoder_HG(real_img).detach()
        pred_measures: UniformMeasure2D01 = UniformMeasure2DFactory.from_heatmap(
            img_content)
        sparce_hm = heatmaper.forward(pred_measures.coord * 63).detach()

        trainer_gan(i, real_img, img_content, sparce_hm)

        # Content-encoder step every 3rd iteration, on a fresh batch.
        if i % 3 == 0:
            real_img = next(LazyLoader.celeba().loader).to(device)
            content_trainer(real_img)

        # Periodic visual logging: landmarks overlaid on real, fake and
        # restored images, plus the summed heatmap itself.
        if i % 100 == 0:
            coefs = json.load(open("../parameters/content_loss_sup.json"))
            print(i, coefs)
            with torch.no_grad():

                content_test = encoder_HG(test_img)
                latent_test = style_encoder(test_img)
                pred_measures = UniformMeasure2DFactory.from_heatmap(
                    content_test)
                sparce_hm = heatmaper.forward(pred_measures.coord *
                                              63).detach()

                iwm = imgs_with_mask(test_img, pred_measures.toImage(256))
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                fake_img, _ = generator(sparce_hm, [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)

                restored = decoder(sparce_hm, latent_test)
                iwm = imgs_with_mask(restored, pred_measures.toImage(256))
                send_images_to_tensorboard(writer, iwm, "RESTORED", i)

                content_test_256 = nn.Upsample(
                    scale_factor=4)(content_test).sum(dim=1, keepdim=True)
                content_test_256 = content_test_256 / content_test_256.max()
                send_images_to_tensorboard(writer, content_test_256, "HM", i)

        # Periodic evaluation feeding the tuner.
        if i % 50 == 0 and i > 0:
            test_loss = verka(encoder_HG)
            tuner.update(test_loss)
            writer.add_scalar("verka", test_loss, i)

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'c': encoder_HG.module.state_dict(),
                    "s": style_encoder.state_dict()
                },
                f'{Paths.default.models()}/stylegan2_new_{str(i + starting_model_number).zfill(6)}.pt',
            )