def test():
    dataloader_test = torch.utils.data.DataLoader(dataset_test,
                                                  batch_size=40,
                                                  num_workers=20)

    sum_loss = 0

    for i, (imgs, masks) in enumerate(dataloader_test, 0):
        imgs = imgs.cuda().type(torch.float32)
        pred_measures: ProbabilityMeasure = image2measure(imgs)
        ref_measures: ProbabilityMeasure = fabric.from_mask(
            masks).cuda().padding(args.measure_size)
        ref_loss = Samples_Loss()(pred_measures, ref_measures)
        sum_loss += ref_loss.item()

    return sum_loss
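Samples_Loss above is the project's OT loss between two ProbabilityMeasure objects (batched weighted point clouds with .probability and .coord). As a rough, hedged stand-in -- assuming it wraps an entropic Sinkhorn divergence, e.g. via geomloss, and that its border argument plays the role of the blur scale, neither of which is confirmed by the source -- a minimal sketch could look like this:

import torch
from geomloss import SamplesLoss  # assumed backend, not confirmed by the source


def samples_loss_sketch(p: int = 2, scaling: float = 0.9, border: float = 1e-4):
    # Sinkhorn divergence between weighted point clouds; mapping `border` -> `blur` is an assumption
    return SamplesLoss(loss="sinkhorn", p=p, blur=border, scaling=scaling)

# hypothetical usage with two measures `pred` and `ref`:
# dist = samples_loss_sketch()(pred.probability, pred.coord, ref.probability, ref.coord)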
Example #2
def verka(encoder: nn.Module):
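    # W1 error between predicted landmark measures and the CelebA test annotations,
    # accumulated per batch and weighted by batch size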
    res = []
    for i, (image, lm) in enumerate(LazyLoader.celeba_test(64)):
        content = encoder(image.cuda())
        mes = UniformMeasure2D01(lm.cuda())
        pred_measures: UniformMeasure2D01 = UniformMeasure2DFactory.from_heatmap(content)
        res.append(Samples_Loss(p=1)(mes, pred_measures).item() * image.shape[0])
    # res already holds per-batch loss * batch_size, so divide the sum by the dataset size
    return np.sum(res) / len(LazyLoader.celeba_test(1).dataset)
Example #3
def loss(image: Tensor, mask: ProbabilityMeasure):

    # t1 = time.time()

    # align the predicted measure to the barycenter with a linear OT map (affine part A, translation T)
    with torch.no_grad():
        A, T = LinearTransformOT.forward(mask, barycenter, 100)

    # pull the measure along the translation component T
    t_loss = Samples_Loss(scaling=0.8,
                          border=0.0001)(mask, mask.detach() + T)
    # pull the centered measure toward its affine alignment A
    a_loss = Samples_Loss(scaling=0.8, border=0.0001)(
        mask.centered(), mask.centered().multiply(A).detach())
    # residual OT distance between the aligned measure and the barycenter
    w_loss = Samples_Loss(scaling=0.85, border=0.00001)(
        mask.centered().multiply(A), barycenter.centered().detach())

    # print(time.time() - t1)

    return a_loss * ca + w_loss * cw + t_loss * ct
Example #4
def test(cont_style_encoder, pairs):
    W1 = Samples_Loss(scaling=0.9, p=1)
    err_list = []
    for img, masks in pairs:
        mes: ProbabilityMeasure = MaskToMeasure(
            size=256, padding=140).apply_to_mask(masks).cuda()
        real_img = img.cuda()
        img_content = cont_style_encoder.get_content(real_img).detach()
        err_list.append(W1(content_to_measure(img_content), mes).item())

    print("test:", sum(err_list) / len(err_list))
    return sum(err_list) / len(err_list)
Example #5
def train(args, loader, generator, discriminator, device, cont_style_encoder,
          starting_model_number):
    loader = sample_data(loader)

    pbar = range(args.iter)

    sample_z = torch.randn(8, args.latent, device=device)
    test_img = next(loader)[:8]
    test_img = test_img.cuda()

    # test_pairs = [next(loader) for _ in range(50)]

    loss_st: StyleGANLoss = StyleGANLoss(discriminator)
    model = CondStyleGanModel(generator, loss_st, (0.001, 0.0015))

    style_opt = optim.Adam(cont_style_encoder.enc_style.parameters(),
                           lr=5e-4,
                           betas=(0.5, 0.9))
    cont_opt = optim.Adam(cont_style_encoder.enc_content.parameters(),
                          lr=2e-5,
                          betas=(0.5, 0.9))

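    # Paired geometric augmentation: measure -> mask, joint elastic/affine warp of image and mask, mask -> measure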
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        MeasureToMask(size=256),
        ToNumpy(),
        NumpyBatch(
            albumentations.ElasticTransform(p=0.8,
                                            alpha=150,
                                            alpha_affine=1,
                                            sigma=10)),
        NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
        ToTensor(device),
        MaskToMeasure(size=256, padding=140),
    ])

    W1 = Samples_Loss(scaling=0.85, p=1)
    # W2 = Samples_Loss(scaling=0.85, p=2)

    # g_trans_res_dict = g_transforms(image=test_img, mask=MaskToMeasure(size=256, padding=140).apply_to_mask(test_mask))
    # g_trans_img = g_trans_res_dict['image']
    # g_trans_mask = g_trans_res_dict['mask']
    # iwm = imgs_with_mask(g_trans_img, g_trans_mask.toImage(256), color=[1, 1, 1])
    # send_images_to_tensorboard(writer, iwm, "RT", 0)

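    # R_t: equivariance regularizer -- the encoder's content measure for a warped image should match the warped measure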
    R_t = DualTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img: W1(
            content_to_measure(
                cont_style_encoder.get_content(trans_dict['image'])),
            trans_dict['mask'])  # +
        # W2(content_to_measure(cont_style_encoder.get_content(trans_dict['image'])), trans_dict['mask'])
    )

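    # R_s: style consistency -- the style code should be invariant to the geometric warp of the image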
    R_s = UnoTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img, ltnt: L1("R_s")
        (ltnt, cont_style_encoder.enc_style(trans_dict['image'])))

    fabric = ProbabilityMeasureFabric(256)
    barycenter = fabric.load(f"{Paths.default.models()}/face_barycenter").cuda(
    ).padding(70).transpose().batch_repeat(16)

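    # R_b: barycenter prior -- keeps predicted measures close to the precomputed face barycenter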
    R_b = BarycenterRegularizer.__call__(barycenter)

    # tuner = CoefTuner([4.5, 10.5, 2.5, 0.7, 0.5], device=device)
    #                 [6.5, 7.9, 2.7, 2.06, 5.4, 0.7, 2.04]
    #                  3.3, 10.5,  6.2,  1.14, 10.88,  0.93,  2.6
    #                  4.3, 10.3, 5.9, 0.85, 10.1, 0.27, 4.5
    #                  [4.53, 9.97, 5.5, 0.01, 9.44, 1.05, 4.9
    tuner = GoldTuner([2.53, 40.97, 5.5, 0.01, 5.44, 1.05, 4.9],
                      device=device,
                      rule_eps=0.05,
                      radius=1,
                      active=False)
    gan_tuner = GoldTuner([20, 25, 25],
                          device=device,
                          rule_eps=1,
                          radius=20,
                          active=False)

    # rb_tuner = GoldTuner([0.7, 1.5, 10], device=device, rule_eps=0.02, radius=0.5)

    best_igor = 100

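    # Main loop: discriminator step, generator step with content/style consistency, periodic encoder updates and logging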
    for idx in pbar:
        i = idx + args.start_iter
        counter.update(i)

        if i > args.iter:
            print('Done!')
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        img_content = cont_style_encoder.get_content(real_img)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        img_content_variable = img_content.detach().requires_grad_(True)
        fake, fake_latent = generator(img_content_variable,
                                      noise,
                                      return_latents=True)

        model.discriminator_train([real_img], [fake], img_content)

        # fake_detach = fake.detach()
        fake_latent_test = fake_latent[:, [0, 13], :].detach()
        fake_content_pred = cont_style_encoder.get_content(fake)

        fake_latent_pred = cont_style_encoder.enc_style(fake)

        (writable("Generator loss", model.generator_loss)(
            [real_img], [fake], [fake_latent], img_content_variable) +  # 3e-5
         gan_tuner.sum_losses([
             L1("L1 content gan")(fake_content_pred, img_content.detach()),
             L1("L1 style gan")(fake_latent_pred, fake_latent_test),
             R_s(fake.detach(), fake_latent_pred),
         ])
         # L1("L1 content gan")(fake_content_pred, img_content.detach()) * 50 +  # 3e-7
         # L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 10 +  # 8e-7
         # R_s(fake, barycenter) * 20
         ).minimize_step(model.optimizer.opt_min, style_opt)

        if i % 5 == 0:
            # fake_latent_pred = cont_style_encoder.enc_style(fake_detach)
            # (L1("L1 style gan")(fake_latent_pred, fake_latent_test)).__mul__(2).minimize_step(style_opt)
            img_latent = cont_style_encoder.enc_style(real_img[:16])
            restored = model.generator.module.decode(img_content[:16],
                                                     img_latent[:16])
            pred_measures: ProbabilityMeasure = content_to_measure(
                img_content[:16])

            noise1 = mixing_noise(16, args.latent, args.mixing, device)
            noise2 = mixing_noise(16, args.latent, args.mixing, device)
            fake1, _ = generator(img_content[:16], noise1)
            fake2, _ = generator(img_content[:16], noise2)

            cont_fake1 = cont_style_encoder.get_content(fake1)
            cont_fake2 = cont_style_encoder.get_content(fake2)

            # rb_coefs = rb_tuner.get_coef()
            # R_b = BarycenterRegularizer.__call__(barycenter, rb_coefs[0], rb_coefs[1], rb_coefs[2])
            #TUNER PART
            tuner.sum_losses([
                # writable("Fake-content D", model.loss.generator_loss)(real=None, fake=[fake1, img_content.detach()]),  # 1e-3
                writable("Real-content D", model.loss.generator_loss)
                (real=None, fake=[real_img, img_content]),  # 3e-5
                writable("R_b", R_b.__call__)(real_img[:16],
                                              pred_measures),  # 7e-5
                writable("R_t", R_t.__call__)(real_img[:16],
                                              pred_measures),  # -
                L1("L1 content between fake")(cont_fake1, cont_fake2),  # 1e-6
                L1("L1 image")(restored, real_img[:16]),  # 4e-5
                R_s(real_img[:16], img_latent),
                L1("L1 style restored")(cont_style_encoder.enc_style(restored),
                                        img_latent.detach())
            ]).minimize_step(cont_opt, model.optimizer.opt_min, style_opt)

            ##Without tuner part

            # (
            #         model.loss.generator_loss(real=None, fake=[real_img, img_content]) * 5 +
            #         (R_b + R_t * 0.4)(real_img, pred_measures) * 10 +
            #         L1("L1 content between fake")(cont_fake1, cont_fake2) * 1 +
            #         L1("L1 image")(restored, real_img) * 1
            #         # L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 1
            # ).minimize_step(
            #     cont_opt,
            #     model.optimizer.opt_min
            # )

        if i % 100 == 0:
            print(i)
            with torch.no_grad():

                content, latent = cont_style_encoder(test_img)
                pred_measures: ProbabilityMeasure = content_to_measure(content)
                # ref_measures: ProbabilityMeasure = MaskToMeasure(size=256, padding=140).apply_to_mask(test_mask)
                # iwm = imgs_with_mask(test_img, ref_measures.toImage(256), color=[0, 0, 1])
                iwm = imgs_with_mask(test_img,
                                     pred_measures.toImage(256),
                                     color=[1, 1, 1])
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                fake_img, _ = generator(content, [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)
                restored = model.generator.module.decode(content, latent)
                send_images_to_tensorboard(writer, restored, "RESTORED", i)

        if i % 100 == 0 and i > 0:
            pass
            # with torch.no_grad():
            #     igor = test(cont_style_encoder, test_pairs)
            #     writer.add_scalar("test error", igor, i)
            #     tuner.update(igor)
            #     gan_tuner.update(igor)
            #     # rb_tuner.update(igor)
            #
            # if igor < best_igor:
            #     best_igor = igor
            #     print("best igor")
            #     torch.save(
            #         {
            #             'g': generator.state_dict(),
            #             'd': discriminator.state_dict(),
            #             'enc': cont_style_encoder.state_dict(),
            #         },
            #         f'{Paths.default.nn()}/stylegan2_igor_3.pt',
            #     )

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'enc': cont_style_encoder.state_dict(),
                    # 'g_ema': g_ema.state_dict(),
                    # 'g_optim': g_optim.state_dict(),
                    # 'd_optim': d_optim.state_dict(),
                },
                f'{Paths.default.models()}/stylegan2_invertable_{str(i + starting_model_number).zfill(6)}.pt',
            )
err_pred_list = []
err_pred_list_2 = []
err_bc_list = []

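# Compare both content encoders and the barycenter against clustered ground-truth mask measures on 30 test batches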
for i in range(30):

    test_img, test_mask = next(loader)
    test_img = test_img.cuda()

    content = cont_style_encoder.get_content(test_img)
    pred_measures: ProbabilityMeasure = content_to_measure(content)
    content2 = cont_style_encoder2.get_content(test_img)
    pred_measures2: ProbabilityMeasure = content_to_measure(content2)

    ref_measure = MaskToMeasure(size=256, padding=140, clusterize=True)(image=test_img, mask=test_mask)["mask"].cuda()

    err_pred = Samples_Loss(p=1)(pred_measures, ref_measure).item()
    err_pred_2 = Samples_Loss(p=1)(pred_measures2, ref_measure).item()
    err_bc = Samples_Loss(p=1)(barycenter, ref_measure).item()
    print("pred:", err_pred)
    print("bc:", err_bc)
    err_pred_list.append(err_pred)
    err_pred_list_2.append(err_pred_2)
    err_bc_list.append(err_bc)

# %%


print("pred mean:", sum(err_pred_list) / len(err_pred_list))
print("pred mean 2:", sum(err_pred_list_2) / len(err_pred_list_2))
print("bc mean:", sum(err_bc_list) / len(err_bc_list))
fabric = ProbabilityMeasureFabric(image_size)
barycenter = fabric.load("../examples/face_barycenter").cuda().crop(
    measure_size)
barycenter = fabric.cat([barycenter for b in range(batch_size)])

for i, (imgs, masks) in enumerate(dataloader, 0):
    imgs = imgs.cuda()
    measures: ProbabilityMeasure = fabric.from_coord_tensor(
        masks).cuda().padding(measure_size)

    t1 = time.time()
    with torch.no_grad():
        A, T = LinearTransformOT.forward(measures, barycenter)
    t2 = time.time()
    dist = Samples_Loss().forward(measures, barycenter)
    t3 = time.time()

    print(dist, t2 - t1, t3 - t2)

    # m_lin = measures.centered().multiply(A) + barycenter.mean()
    # plt.scatter(m_lin.coord[0, :, 1].cpu().numpy(), m_lin.coord[0, :, 0].cpu().numpy())
    # plt.scatter(barycenter.coord[0, :, 1].cpu().numpy(), barycenter.coord[0, :, 0].cpu().numpy())
    # plt.show()

    Atest = torch.tensor([[3, 0.2], [0, 1]],
                         device=device,
                         dtype=torch.float32)
    Atest = torch.cat([Atest[None, ]] * batch_size)
    bc_tr = barycenter.random_permute().multiply(Atest) + 0.1
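The last three lines build a synthetic check: bc_tr is the barycenter under a known affine map Atest, plus a shift and a random permutation of its points. A hedged continuation, not present in the source, would feed this pair back into the OT solver and inspect whether the recovered map approximately undoes Atest:

# hypothetical check of LinearTransformOT on the synthetic pair (bc_tr, barycenter)
with torch.no_grad():
    A_rec, T_rec = LinearTransformOT.forward(bc_tr, barycenter)
print(A_rec[0])  # should roughly invert Atest[0], up to the solver's tolerance
print(T_rec[0])  # translation component toward the barycenter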
def optimization_step():
    noise = NormalNoise(n_noise, device)
    measure2image = ResMeasureToImage(args.measure_size * 3 + noise.size(),
                                      args.image_size, ngf).cuda()

    netD = DCDiscriminator(ndf=ndf).cuda()
    gan_model = GANModel(measure2image,
                         HingeLoss(netD).add_generator_loss(nn.L1Loss(), L1),
                         lr=0.0004)

    fabric = ProbabilityMeasureFabric(args.image_size)
    barycenter = fabric.load("barycenter").cuda().padding(args.measure_size)
    print(barycenter.coord.shape)
    barycenter = fabric.cat([barycenter for b in range(args.batch_size)])
    print(barycenter.coord.shape)

    image2measure = ResImageToMeasure(args.measure_size).cuda()
    image2measure_opt = optim.Adam(image2measure.parameters(), lr=0.0002)

    def test():
        dataloader_test = torch.utils.data.DataLoader(dataset_test,
                                                      batch_size=40,
                                                      num_workers=20)

        sum_loss = 0

        for i, (imgs, masks) in enumerate(dataloader_test, 0):
            imgs = imgs.cuda().type(torch.float32)
            pred_measures: ProbabilityMeasure = image2measure(imgs)
            ref_measures: ProbabilityMeasure = fabric.from_mask(
                masks).cuda().padding(args.measure_size)
            ref_loss = Samples_Loss()(pred_measures, ref_measures)
            sum_loss += ref_loss.item()

        return sum_loss

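    # Alternate GAN updates for measure -> image with an OT consistency term pulling predicted measures toward the barycenter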
    for epoch in range(20):

        ot_iters = 100
        print("epoch", epoch)
        test_imgs = None

        for i, imgs in enumerate(dataloader, 0):

            imgs = imgs.cuda().type(torch.float32)
            test_imgs = imgs
            pred_measures: ProbabilityMeasure = image2measure(imgs)
            cond = pred_measures.toChannels()
            n = cond.shape[0]
            barycenter_batch = barycenter.slice(0, n)

            z = noise.sample(n)
            cond = torch.cat((cond, z), dim=1)
            gan_model.train(imgs, cond.detach())

            with torch.no_grad():
                A, T = LinearTransformOT.forward(pred_measures,
                                                 barycenter_batch, ot_iters)

            bc_loss_T = Samples_Loss()(pred_measures,
                                       pred_measures.detach() + T)
            bc_loss_A = Samples_Loss()(
                pred_measures.centered(),
                pred_measures.centered().multiply(A).detach())
            bc_loss_W = Samples_Loss()(pred_measures.centered().multiply(A),
                                       barycenter_batch.centered())
            bc_loss = bc_loss_W * cw + bc_loss_A * ca + bc_loss_T * ct

            fake = measure2image(cond)
            g_loss = gan_model.generator_loss(imgs, fake)
            (g_loss + bc_loss).minimize_step(image2measure_opt)

    return test()
# extra generator term: L1 between coordinates predicted from the fake image and the conditioning measure's coordinates
cond_gan_model.loss += GANLossObject(
    lambda dx, dy: Loss.ZERO(),
    lambda dgz, real, fake: Loss(
        nn.L1Loss()(image2measure(fake[0]).coord,
                    fabric.from_channels(real[1]).coord.detach())
    ) * 10,
    None
)

image2measure = ResImageToMeasure(args.measure_size).cuda()
image2measure_opt = optim.Adam(image2measure.parameters(), lr=0.0003)


R_b = BarycenterRegularizer.__call__(barycenter)
R_t = DualTransformRegularizer.__call__(
    g_transforms, lambda trans_dict:
    Samples_Loss()(image2measure(trans_dict['image']), trans_dict['mask'])
)

deform_array = list(np.linspace(0, 6, 1000))
Whole_Reg = R_t @ deform_array + R_b
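# Whole_Reg presumably weights R_t by deform_array[step] (a 0 -> 6 ramp over 1000 steps) and adds the barycenter prior R_b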

for epoch in range(500):
    # if epoch > 0:
    #     cond_gan_model.optimizer.update_lr(0.5)
        # for i in image2measure_opt.param_groups:
        #     i['lr'] *= 0.5
    print("epoch", epoch)

    for i, (imgs, masks) in enumerate(dataloader, 0):
        if imgs.shape[0] != args.batch_size:
            continue
Example #10
fabric = ProbabilityMeasureFabric(args.image_size)
barycenter = fabric.load("/home/ibespalov/unsupervised_pattern_segmentation/examples/face_barycenter").cuda().padding(args.measure_size).batch_repeat(args.batch_size)

g_transforms: albumentations.DualTransform = albumentations.Compose([
    MeasureToMask(size=256),
    ToNumpy(),
    NumpyBatch(albumentations.ElasticTransform(p=0.5, alpha=150, alpha_affine=1, sigma=10)),
    NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
    ToTensor(device),
    MaskToMeasure(size=256, padding=args.measure_size),
])

R_b = BarycenterRegularizer.__call__(barycenter)
R_t = DualTransformRegularizer.__call__(
    g_transforms, lambda trans_dict:
    Samples_Loss(scaling=0.85, p=1)(content_to_measure(cont_style_encoder(trans_dict['image'])[0]), trans_dict['mask'])
)

R_b.forward = send_to_tensorboard("R_b", counter=counter, writer=writer)(R_b.forward)
R_t.forward = send_to_tensorboard("R_t", counter=counter, writer=writer)(R_t.forward)

deform_array = list(np.linspace(0, 1, 1500))
Whole_Reg = R_t @ deform_array + R_b
l1_loss = nn.L1Loss()


def L1(name: Optional[str], writer: SummaryWriter = writer) -> Callable[[Tensor, Tensor], Loss]:

    if name:
        counter.active[name] = True
Example #11
g_transforms: albumentations.DualTransform = albumentations.Compose([
    MeasureToMask(size=256),
    ToNumpy(),
    NumpyBatch(
        albumentations.ElasticTransform(p=0.5,
                                        alpha=150,
                                        alpha_affine=1,
                                        sigma=10)),
    NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
    ToTensor(device),
    MaskToMeasure(size=256, padding=args.measure_size),
])

R_b = BarycenterRegularizer.__call__(barycenter)
R_t = DualTransformRegularizer.__call__(
    g_transforms,
    lambda trans_dict: Samples_Loss(scaling=0.85, p=1)(content_to_measure(
        cont_style_encoder(trans_dict['image'])[0]), trans_dict['mask']))

R_b.forward = send_to_tensorboard("R_b", counter=counter,
                                  writer=writer)(R_b.forward)
R_t.forward = send_to_tensorboard("R_t", counter=counter,
                                  writer=writer)(R_t.forward)

deform_array = list(np.linspace(0, 1, 1500))
Whole_Reg = R_t @ deform_array + R_b
l1_loss = nn.L1Loss()


def L1(name: Optional[str],
       writer: SummaryWriter = writer) -> Callable[[Tensor, Tensor], Loss]:

    if name:
        counter.active[name] = True
# cond_gan_model.loss += GANLossObject(
#                 lambda dx, dy: Loss.ZERO(),
#                 lambda dgz, real, fake: Loss(
#                     nn.L1Loss()(image2measure(fake[0]).coord, fabric.from_channels(real[1]).coord.detach())
#                 ) * 10,
#                 None
# )

image2measure = ResImageToMeasure(args.measure_size).cuda()
image2measure_opt_strong = optim.Adam(image2measure.parameters(), lr=0.0001)
image2measure_opt = optim.Adam(image2measure.parameters(), lr=0.0003)

R_b = BarycenterRegularizer.__call__(barycenter)
R_t = DualTransformRegularizer.__call__(
    g_transforms, lambda trans_dict: Samples_Loss()
    (image2measure(trans_dict['image']), trans_dict['mask']))

deform_array = list(np.linspace(0, 6, 1000))
Whole_Reg = R_t @ deform_array + R_b

for epoch in range(500):
    # if epoch > 0:
    #     cond_gan_model.optimizer.update_lr(0.5)
    # for i in image2measure_opt.param_groups:
    #     i['lr'] *= 0.5
    print("epoch", epoch)

    for i, (imgs, masks) in enumerate(dataloader, 0):
        if imgs.shape[0] != args.batch_size:
            continue
encoder_HG = HG_softmax2020(num_classes=68, heatmap_size=64)
encoder_HG.load_state_dict(
    torch.load(f"{Paths.default.models()}/hg2_e29.pt", map_location="cpu"))
encoder_HG = encoder_HG.cuda()

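# Fit the barycenter to the encoder's predicted landmark measures by gradient descent on the OT loss,
# projecting its weights back onto the probability simplex after every step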
for iter in range(3000):

    img = next(LazyLoader.celeba().loader).cuda()
    content = encoder_HG(img)
    coord, p = heatmap_to_measure(content)
    mes = ProbabilityMeasure(p, coord)

    barycenter_cat = fabric.cat([barycenter] * batch_size)

    loss = Samples_Loss()(barycenter_cat, mes)

    opt.zero_grad()
    loss.to_tensor().backward()
    opt.step()

    barycenter.probability.data = barycenter.probability.relu().data
    barycenter.probability.data /= barycenter.probability.sum(dim=1,
                                                              keepdim=True)

    if iter % 100 == 0:
        print(iter, loss.item())

        plt.imshow(barycenter.toImage(200)[0][0].detach().cpu().numpy())
        plt.show()