Пример #1
0
    def gan_train(real_img, skeleton):
        """Run one conditional-GAN step plus reconstruction/latent-regression step.

        Steps, in order:
          1. ``model.discriminator_train`` on (real, fake, condition).
          2. Generator update with ``model.generator_loss`` (logged as
             "Generator brule_loss") via ``model.optimizer.opt_min``.
          3. Joint update of ``model.optimizer.opt_min`` and ``style_opt`` on
             ``coefs["L1 restored"]`` * L1(decoder output, ``real_img``) +
             ``coefs["L1 style gan"]`` * L1(style encoder on the detached fake,
             detached generator latents).

        Loss weights are re-read from ``../parameters/gan_loss.json`` on every
        call; the relative path is resolved against the current working
        directory.

        Args:
            real_img: batch of real images; ``shape[0]`` is used as batch size.
            skeleton: conditioning input for generator and decoder; it is
                detached and re-marked as requiring grad before use.
        """
        B = real_img.shape[0]
        C = 512

        requires_grad(generator, True)
        requires_grad(decoder, True)
        # Detached copy re-marked as requiring grad; presumably so the
        # discriminator can differentiate w.r.t. the condition (e.g. a
        # gradient penalty) -- TODO confirm.
        condition = skeleton.detach().requires_grad_(True)

        # NOTE(review): 0.9 is presumably the style-mixing probability -- confirm.
        noise = mixing_noise(B, C, 0.9, real_img.device)

        fake, fake_latent = generator(condition, noise, return_latents=True)

        # Adversarial step: discriminator first, then the generator.
        model.discriminator_train([real_img], [fake], [condition])

        WR.writable("Generator brule_loss", model.generator_loss)([real_img], [fake], [condition]) \
            .minimize_step(model.optimizer.opt_min)

        # The losses below must not backpropagate into the generator via `fake`.
        fake = fake.detach()

        fake_latent_pred = style_encoder(fake)
        restored = decoder(condition, style_encoder(real_img))
        # Stack the generator's latents along a new dim 1 (detached) as the
        # regression target for the style encoder.
        fake_latent = torch.cat([f[:, None, :] for f in fake_latent],
                                dim=1).detach()

        # `with` closes the handle deterministically instead of leaking it.
        with open("../parameters/gan_loss.json", encoding="utf-8") as coefs_file:
            coefs = json.load(coefs_file)

        (WR.L1("L1 restored")(restored, real_img) * coefs["L1 restored"] +
         WR.L1("L1 style gan")(fake_latent_pred, fake_latent) *
         coefs["L1 style gan"]).minimize_step(model.optimizer.opt_min,
                                              style_opt)
Пример #2
0
    def do_train(real_img):
        """Train ``encoder_HG`` on one batch while decoder and generator are frozen.

        ``encoder_HG`` is made trainable and ``decoder``/``generator`` frozen via
        ``requires_grad``. The weighted sum below is stepped with
        ``model.optimizer.opt_min``. Weights come from
        ``../parameters/content_loss.json``, which is resolved against the
        current working directory and re-read on every call:

        * "L1 image": L1 between ``decoder(skeleton, style_encoder(real_img))``
          and ``real_img``.
        * "fake_content brule_loss": ``coord_hm_loss`` between the measure
          ``encoder_HG`` predicts on the detached generated image and the
          detached heatmap of the measure predicted on ``real_img``.
        * "Fake-content D": ``model.loss.generator_loss`` with ``real=None`` and
          the generated image paired with the detached skeleton.

        Args:
            real_img: batch of real images; ``shape[0]`` is used as batch size.
        """
        B = real_img.shape[0]

        # `with` closes the handle deterministically instead of leaking it.
        with open("../parameters/content_loss.json", encoding="utf-8") as coefs_file:
            coefs = json.load(coefs_file)

        requires_grad(encoder_HG, True)
        requires_grad(decoder, False)
        requires_grad(generator, False)

        encoded = encoder_HG(real_img)
        pred_measures: UniformMeasure2D01 = encoded["mes"]

        # Fixed (detached) target for the content of the generated image.
        heatmap_content = heatmapper.forward(pred_measures.coord).detach()

        restored = decoder(encoded["skeleton"], style_encoder(real_img))

        # NOTE(review): `C` is not defined in this function, so it must come
        # from an enclosing scope (gan_train uses a local C = 512) -- confirm.
        noise = mixing_noise(B, C, 0.9, real_img.device)
        fake, _ = generator(encoded["skeleton"], noise)
        fake_content = encoder_HG(fake.detach())["mes"]

        total_loss = (WR.L1("L1 image")(restored, real_img) * coefs["L1 image"] +
                      WR.writable("fake_content brule_loss", coord_hm_loss)
                      (fake_content, heatmap_content) *
                      coefs["fake_content brule_loss"] +
                      WR.writable("Fake-content D", model.loss.generator_loss)
                      (real=None, fake=[fake, encoded["skeleton"].detach()]) *
                      coefs["Fake-content D"])

        total_loss.minimize_step(model.optimizer.opt_min)
Пример #3
0
def train_content(cont_opt, R_b, R_t, real_img, encoder_HG):
    """Take one ``cont_opt`` step on the weighted R_b + R_t losses of ``encoder_HG``.

    Args:
        cont_opt: optimizer(s) handed to ``minimize_step``.
        R_b: callable applied as ``R_b(real_img, measures)`` to the ``"mes"``
            output of ``encoder_HG``; logged under "R_b".
        R_t: callable applied as ``R_t(real_img, heatmaps)`` to the ``"hm"``
            output of ``encoder_HG``; logged under "R_t".
        real_img: batch of input images.
        encoder_HG: landmark encoder; made trainable via ``requires_grad``.

    The weights ``coefs["R_b"]`` and ``coefs["R_t"]`` are re-read on every call
    from ``../parameters/content_loss.json``. That path is relative to
    ``sys.path[0]``, which is typically the running script's directory.
    """
    requires_grad(encoder_HG, True)

    coefs_path = os.path.join(sys.path[0], "../parameters/content_loss.json")
    # `with` closes the handle deterministically instead of leaking it.
    with open(coefs_path, encoding="utf-8") as coefs_file:
        coefs = json.load(coefs_file)

    encoded = encoder_HG(real_img)
    pred_measures: UniformMeasure2D01 = encoded["mes"]
    heatmap_content = encoded["hm"]

    total_loss = (
        WR.writable("R_b", R_b.__call__)(real_img, pred_measures) * coefs["R_b"] +
        WR.writable("R_t", R_t.__call__)(real_img, heatmap_content) * coefs["R_t"]
    )

    total_loss.minimize_step(cont_opt)
Пример #4
0
    # Fragment of a training-loop body; `fake`, `fake_latent`, `heatmap_sum`,
    # `landmarks`, `coefs` and `i` are defined above this visible excerpt.
    # Style encoder's estimate of the latent that produced `fake`.
    fake_latent_pred = enc_dec.encode_latent(fake)

    # Image-side GAN ("tuda"): discriminator step on detached fakes, then a
    # generator step plus an L1 latent-regression term weighted by
    # coefs["style"], stepping both the GAN minimizer and style_opt.
    gan_model_tuda.discriminator_train([real_img], [fake.detach()])
    (
            gan_model_tuda.generator_loss([real_img], [fake]) +
            l1_loss(fake_latent_pred, fake_latent) * coefs["style"]
    ).minimize_step(gan_model_tuda.optimizer.opt_min, style_opt)

    # Heatmap-side GAN ("obratno"): hg's summed heatmap for the real image vs
    # the landmark heatmap summed over dim 1 (keepdim) and detached; the
    # generator term is weighted by coefs["obratno"].
    hm_pred = hg.forward(real_img)["hm_sum"]
    hm_ref = heatmapper.forward(landmarks).detach().sum(1, keepdim=True)
    gan_model_obratno.discriminator_train([hm_ref], [hm_pred.detach()])
    gan_model_obratno.generator_loss([hm_ref], [hm_pred]).__mul__(coefs["obratno"]) \
        .minimize_step(gan_model_obratno.optimizer.opt_min)

    # Cycle heatmap -> image -> landmarks: hg's measure on an image generated
    # from heatmap_sum is compared with the ground-truth landmarks (mes_loss,
    # weighted by coefs["hm"]), stepping both GANs' minimizers.
    fake2, _ = enc_dec.generate(heatmap_sum)
    WR.writable("cycle", mes_loss.forward)(hg.forward(fake2)["mes"], UniformMeasure2D01(landmarks)).__mul__(coefs["hm"]) \
        .minimize_step(gan_model_tuda.optimizer.opt_min, gan_model_obratno.optimizer.opt_min)

    # Cycle image -> (heatmap, latent) -> image: the latent is encoded from
    # g_transforms(real_img) (presumably an augmentation -- confirm), the
    # heatmap from the untransformed image; weighted by coefs["img"].
    latent = enc_dec.encode_latent(g_transforms(image=real_img)["image"])
    restored = enc_dec.decode(hg.forward(real_img)["hm_sum"], latent)
    WR.writable("cycle2", psp_loss.forward)(real_img, real_img, restored, latent).__mul__(coefs["img"]) \
        .minimize_step(gan_model_tuda.optimizer.opt_min, gan_model_obratno.optimizer.opt_min, style_opt)

    # Per-iteration accumulator update (presumably a weight moving average).
    image_accumulator.step(i)
    hm_accumulator.step(i)

    if i % 10000 == 0 and i > 0:
        torch.save(
            {
                'gi': enc_dec.generator.state_dict(),
                'gh': hg.state_dict(),
                'di': discriminator_img.state_dict(),
Пример #5
0
        # Tail of an enclosing block not visible here: hg's predictions on the
        # real image, detached so they serve as fixed targets below.
        pred = hg.forward(real_img)
        hm_pred = pred["hm_sum"].detach()
        mes_pred = pred["mes"].detach()

    # Generate an image from the predicted heatmap and re-encode its latent.
    fake, fake_latent = enc_dec.generate(hm_pred)
    fake_latent_pred = enc_dec.encode_latent(fake)

    # Image GAN: discriminator step on detached fakes, then generator loss plus
    # L1 latent regression (weighted by coefs["style"]) with style_opt.
    gan_model_tuda.discriminator_train([real_img], [fake.detach()])
    (gan_model_tuda.generator_loss([real_img], [fake]) +
     l1_loss(fake_latent_pred, fake_latent) * coefs["style"]).minimize_step(
         gan_model_tuda.optimizer.opt_min, style_opt)

    # Content losses (R_b / R_t) on hg, stepped with cont_opt.
    train_content(cont_opt, R_b, R_t, real_img, hg)

    # Cycle: hg's measure on an image regenerated from hm_pred should match the
    # detached mes_pred; steps the generator minimizer and cont_opt.
    fake2, _ = enc_dec.generate(hm_pred)
    WR.writable("cycle", mes_loss.forward)(hg.forward(fake2)["mes"], mes_pred).__mul__(coefs["hm"]) \
        .minimize_step(gan_model_tuda.optimizer.opt_min, cont_opt)

    # Cycle image -> (heatmap, latent) -> image; the latent comes from
    # g_transforms(real_img) (presumably an augmentation -- confirm).
    latent = enc_dec.encode_latent(g_transforms(image=real_img)["image"])
    restored = enc_dec.decode(hg.forward(real_img)["hm_sum"], latent)
    WR.writable("cycle2", psp_loss.forward)(real_img, real_img, restored, latent).__mul__(coefs["img"])\
        .minimize_step(gan_model_tuda.optimizer.opt_min, cont_opt, style_opt)

    # Disabled experiment: train hg through the frozen image GAN
    # (weight coefs["ganhg"]).
    # requires_grad(discriminator_img, False)
    # requires_grad(enc_dec.generator, False)
    # fake3, _ = enc_dec.generate(hg.forward(real_img)["hm_sum"])
    # gan_model_tuda.generator_loss([real_img], [fake3]).__mul__(coefs["ganhg"]).minimize_step(cont_opt)
    # requires_grad(discriminator_img, True)
    # requires_grad(enc_dec.generator, True)

    # Per-iteration accumulator update (presumably a weight moving average).
    image_accumulator.step(i)
    hm_accumulator.step(i)
Пример #6
0
    # Fragment of a training-loop body; `real_seg`, `coefs` and `i` are defined
    # above this visible excerpt. Generate an image from the ground-truth
    # segmentation and re-encode its latent.
    fake, fake_latent = enc_dec.generate(real_seg)
    fake_latent_pred = enc_dec.encode_latent(fake)

    # Image-side GAN: discriminator step, then generator loss plus L1 latent
    # regression weighted by coefs["style"].
    gan_model_tuda.discriminator_train([real_img], [fake.detach()])
    (
        gan_model_tuda.generator_loss([real_img], [fake]) +
        l1_loss(fake_latent_pred, fake_latent) * coefs["style"]
    ).minimize_step(gan_model_tuda.optimizer.opt_min, style_opt)

    # Segmentation-side GAN: hg's output on the real image vs real_seg,
    # generator term weighted by coefs["obratno"].
    seg_pred = hg.forward(real_img)
    gan_model_obratno.discriminator_train([real_seg], [seg_pred.detach()])
    gan_model_obratno.generator_loss([real_seg], [seg_pred]).__mul__(coefs["obratno"])\
        .minimize_step(gan_model_obratno.optimizer.opt_min)

    # Cycle seg -> image -> seg: hg on a second image generated from real_seg
    # should reproduce real_seg (our_loss, weighted by coefs["hm"]).
    # NOTE(review): presumably generate() draws fresh noise here -- confirm.
    fake2, _ = enc_dec.generate(real_seg)
    WR.writable("cycle", our_loss.forward)(hg.forward(fake2), real_seg).__mul__(coefs["hm"])\
        .minimize_step(gan_model_tuda.optimizer.opt_min, gan_model_obratno.optimizer.opt_min)

    # Cycle image -> (seg, latent) -> image; the latent is encoded from the
    # image without g_transforms (the transformed variant is commented out).
    latent = enc_dec.encode_latent(real_img)
    # latent = enc_dec.encode_latent(g_transforms(image=real_img)["image"])
    restored = enc_dec.decode(hg.forward(real_img), latent)
    WR.writable("cycle2", psp_loss.forward)(real_img, real_img, restored, latent).__mul__(coefs["img"])\
        .minimize_step(gan_model_tuda.optimizer.opt_min, gan_model_obratno.optimizer.opt_min, style_opt)

    # Per-iteration accumulator update (presumably a weight moving average).
    image_accumulator.step(i)
    hm_accumulator.step(i)

    if i % 10000 == 0 and i > 0:
        torch.save(
            {
                'gi': enc_dec.generator.state_dict(),
                'gh': hg.state_dict(),
Пример #7
0
def hm_svoego_roda_loss(pred, target):
    """Combined heatmap loss ("svoego roda" is Russian for "of its own kind").

    Returns ``Loss`` of:
        10    * BCE(pred, target)
      + 0.005 * MSE between coordinates extracted by ``heatmap_to_measure``
      + 3     * mean absolute difference of the heatmaps.
    ``pred`` must be valid input for ``nn.BCELoss`` (values in [0, 1]).
    """
    pred_coords, _ = heatmap_to_measure(pred)
    target_coords, _ = heatmap_to_measure(target)

    bce_term = nn.BCELoss()(pred, target)
    coord_term = nn.MSELoss()(pred_coords, target_coords)
    l1_term = (pred - target).abs().mean()

    return Loss(bce_term * 10 + coord_term * 0.005 + l1_term * 3)


# Supervised training of hg on the 300-W loader: one heatmap-loss step per
# iteration, with periodic evaluation.
for i in range(100000):

    # Presumably sets the global step used by WR's logging -- confirm.
    WR.counter.update(i)

    batch = next(LazyLoader.w300().loader_train_inf)
    real_img = batch["data"].cuda()
    # NOTE(review): only an upper clamp (max=1) is applied; the human36 loop
    # clamps with min=0, max=1 -- confirm negative coordinates are acceptable.
    landmarks = torch.clamp(batch["meta"]['keypts_normalized'].cuda(), max=1)

    # hg's predicted heatmaps vs heatmaps rendered from the ground-truth
    # landmarks, stepped with hg_opt (logged under the name "cycle").
    WR.writable("cycle", hm_svoego_roda_loss)(hg.forward(real_img)["hm"], heatmapper.forward(landmarks))\
        .minimize_step(hg_opt)

    # Every 100 iterations: print the step and log verka_300w(hg) as "verka"
    # (presumably a validation metric -- confirm).
    if i % 100 == 0:
        print(i)
        with torch.no_grad():

            tl2 = verka_300w(hg)
            writer.add_scalar("verka", tl2, i)

            # sk_pred = hg.forward(test_img)["hm_sum"]
            # send_images_to_tensorboard(writer, test_img + sk_pred, "REAL", i)
Пример #8
0
    # Fragment of a training-loop body; `landmarks`, `real_img` and `i` are
    # defined above. Landmark heatmaps summed over dim 1 (keepdim), detached,
    # used as the generator's conditioning input.
    heatmap_sum = heatmapper.forward(landmarks).sum(1, keepdim=True).detach()

    fake, fake_latent = enc_dec.generate(heatmap_sum)
    fake_latent_pred = enc_dec.encode_latent(fake)

    # Even iterations: the discriminator's real batch is real_img; odd
    # iterations: a batch from the celeba loader.
    # NOTE(review): other loops call next() on `loader_train_inf`; confirm that
    # `LazyLoader.celeba().loader` is an iterator (next() on a plain DataLoader
    # raises TypeError) and that it yields a tensor, since `.cuda()` is called
    # on the result directly.
    real_gan_img = real_img if i % 2 == 0 else next(
        LazyLoader.celeba().loader).cuda()

    # Image GAN step; here the L1 latent-regression term is unweighted.
    gan_model_tuda.discriminator_train([real_gan_img], [fake.detach()])
    (gan_model_tuda.generator_loss([real_gan_img], [fake]) +
     l1_loss(fake_latent_pred, fake_latent)).minimize_step(
         gan_model_tuda.optimizer.opt_min, style_opt)

    # Reconstruct real_img from (heatmap_sum, its own latent); psp_loss is
    # weighted by a hard-coded 20.
    latent = enc_dec.encode_latent(real_img)
    restored = enc_dec.decode(heatmap_sum, latent)
    WR.writable("cycle2", psp_loss.forward)(real_img, real_img, restored, latent).__mul__(20)\
        .minimize_step(gan_model_tuda.optimizer.opt_min, style_opt)

    image_accumulator.step(i)
    # enc_accumulator.step(i)

    # Every 10000 iterations: checkpoint generator, image discriminator and
    # style encoder; the file index is offset by starting_model_number.
    if i % 10000 == 0 and i > 0:
        torch.save(
            {
                'gi': enc_dec.generator.state_dict(),
                'di': discriminator_img.state_dict(),
                's': enc_dec.style_encoder.state_dict()
            },
            f'{Paths.default.models()}/300w_encoder_{str(i + starting_model_number).zfill(6)}.pt',
        )

    if i % 100 == 0:
Пример #9
0
# Disabled: accumulator for the image generator (not used by this script).
# image_accumulator = Accumulator(enc_dec.generator, decay=0.99, write_every=100)
# Accumulator over hg (decay=0.99, write_every=100), stepped once per training
# iteration; presumably a moving average of hg's weights -- confirm.
hm_accumulator = Accumulator(hg, decay=0.99, write_every=100)


for i in range(100000):

    WR.counter.update(i)

    batch = next(LazyLoader.cardio().loader_train_inf)
    real_img = batch["image"].cuda()
    train_landmarks = batch["keypoints"].cuda()

    # Loss weights are re-read every iteration (presumably so they can be tuned
    # while training runs). The `with` closes the handle each time rather than
    # leaving one open file object per iteration to the garbage collector.
    coefs_path = os.path.join(sys.path[0], "../parameters/cycle_loss_2.json")
    with open(coefs_path, encoding="utf-8") as coefs_file:
        coefs = json.load(coefs_file)

    # Supervised landmark loss for hg against the ground-truth keypoints,
    # weighted by coefs["sup"] and stepped with cont_opt.
    (WR.writable("sup", mes_loss.forward)(hg.forward(real_img)["mes"], UniformMeasure2D01(train_landmarks))
     * coefs["sup"]).minimize_step(cont_opt)

    hm_accumulator.step(i)

    # Every 1000 iterations: checkpoint hg; the file index is offset by
    # starting_model_number.
    if i % 1000 == 0 and i > 0:
        torch.save(
            {
                'gh': hg.state_dict(),
            },
            f'{Paths.default.models()}/cardio_brule_sup_{str(i + starting_model_number).zfill(6)}.pt',
        )


    if i % 100 == 0:
        print(i)
        with torch.no_grad():
Пример #10
0
    batch = next(LazyLoader.human36(use_mask=True).loader_train_inf)
    real_img = batch["A"].cuda()
    landmarks = torch.clamp(batch["paired_B"].cuda(), min=0, max=1)
    # Detached landmark heatmaps condition the generator.
    heatmap = heatmapper.forward(landmarks).detach()

    # Loss weights are re-read every iteration (presumably so they can be tuned
    # while training runs). The `with` closes the handle each time rather than
    # leaving one open file object per iteration to the garbage collector.
    coefs_path = os.path.join(sys.path[0], "../parameters/cycle_loss.json")
    with open(coefs_path, encoding="utf-8") as coefs_file:
        coefs = json.load(coefs_file)

    fake, fake_latent = enc_dec.generate(heatmap)

    # Image GAN: discriminator step on detached fakes, then a generator step.
    gan_model_tuda.discriminator_train([real_img], [fake.detach()])
    gan_model_tuda.generator_loss([real_img], [fake]).minimize_step(
        gan_model_tuda.optimizer.opt_min)

    # Cycle heatmap -> image -> landmarks: hg's measure on a regenerated image
    # should match the ground-truth landmarks; only the generator is stepped.
    fake2, _ = enc_dec.generate(heatmap)
    (WR.writable("cycle", mes_loss.forward)(hg.forward(fake2)["mes"], UniformMeasure2D01(landmarks))
     * coefs["hm"]).minimize_step(gan_model_tuda.optimizer.opt_min)

    image_accumulator.step(i)

    # Every 10000 iterations: checkpoint generator, image discriminator and
    # style encoder; the file index is offset by starting_model_number.
    if i % 10000 == 0 and i > 0:
        torch.save(
            {
                'gi': enc_dec.generator.state_dict(),
                'di': discriminator_img.state_dict(),
                's': enc_dec.style_encoder.state_dict()
            },
            f'{Paths.default.models()}/human_gan_{str(i + starting_model_number).zfill(6)}.pt',
        )

    if i % 100 == 0:
        print(i)