Example #1
enc_dec = StyleGanAutoEncoder().load_state_dict(weights, style=False).cuda()

discriminator_img = Discriminator(image_size)
discriminator_img.load_state_dict(weights['di'])
discriminator_img = discriminator_img.cuda()

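# ToGaussHeatMap(256, 4) presumably renders landmark coordinates as 256x256
# Gaussian heatmaps (second argument: the Gaussian width); HG_heatmap is the
# hourglass landmark detector, configured here for 200 keypoint classes.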
heatmapper = ToGaussHeatMap(256, 4)
hg = HG_heatmap(heatmapper, num_classes=200)
# hg.load_state_dict(weights['gh'])
hg = hg.cuda()
hm_discriminator = Discriminator(image_size, input_nc=1, channel_multiplier=1)
hm_discriminator.load_state_dict(weights["dh"])
hm_discriminator = hm_discriminator.cuda()

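# Two adversarial models for the cycle: "tuda" (Russian for "forward") trains the
# StyleGAN generator to map heatmaps to images against the image discriminator,
# "obratno" ("back") trains the hourglass network to produce skeleton heatmaps
# against the heatmap discriminator; the tuples are presumably the
# (generator, discriminator) learning rates.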
gan_model_tuda = StyleGanModel[HeatmapToImage](enc_dec.generator,
                                               StyleGANLoss(discriminator_img),
                                               (0.001 / 4, 0.0015 / 4))
gan_model_obratno = StyleGanModel[HG_skeleton](hg,
                                               StyleGANLoss(hm_discriminator, r1=3),
                                               (2e-5, 0.0015 / 4))

style_opt = optim.Adam(enc_dec.style_encoder.parameters(), lr=1e-5)

print(f"board path: {Paths.default.board()}/cardio{int(time.time())}")
writer = SummaryWriter(f"{Paths.default.board()}/cardio{int(time.time())}")
WR.writer = writer

#%%

test_batch = next(LazyLoader.cardio().loader_train_inf)
test_img = test_batch["image"].cuda()
test_landmarks = next(LazyLoader.cardio_landmarks(args.data_path).loader_train_inf).cuda()
test_measure = UniformMeasure2D01(torch.clamp(test_landmarks, max=1))
test_hm = heatmapper.forward(test_measure.coord).sum(1, keepdim=True).detach()
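
# The helper below is a separate training entry point: it wraps the generator and
# discriminator in a CondStyleGanModel and alternates adversarial updates on CelebA
# with content regularization and supervised hourglass updates on W300 landmarks.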
def train(generator, decoder, discriminator, encoder_HG, style_encoder, device,
          starting_model_number):
    latent_size = 512
    batch_size = 8
    sample_z = torch.randn(8, latent_size, device=device)
    Celeba.batch_size = batch_size
    W300DatasetLoader.batch_size = batch_size
    W300DatasetLoader.test_batch_size = 16

    test_img = next(LazyLoader.w300().loader_train_inf)["data"][:8].cuda()

    model = CondStyleGanModel(generator, StyleGANLoss(discriminator),
                              (0.001 / 4, 0.0015 / 4))

    style_opt = optim.Adam(style_encoder.parameters(),
                           lr=5e-4,
                           betas=(0.9, 0.99))
    cont_opt = optim.Adam(encoder_HG.parameters(), lr=3e-5, betas=(0.5, 0.97))

    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(
            albumentations.Compose([
                albumentations.ElasticTransform(p=0.7,
                                                alpha=150,
                                                alpha_affine=1,
                                                sigma=10),
                albumentations.ShiftScaleRotate(p=0.9, rotate_limit=15),
            ])),
        ToTensor(device),
    ])

    g_transforms_without_norm: albumentations.DualTransform = albumentations.Compose(
        [
            ToNumpy(),
            NumpyBatch(
                albumentations.Compose([
                    albumentations.ElasticTransform(p=0.3,
                                                    alpha=150,
                                                    alpha_affine=1,
                                                    sigma=10),
                    albumentations.ShiftScaleRotate(p=0.7, rotate_limit=15),
                ])),
            ToTensor(device),
        ])
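
    # Both pipelines apply the same batched geometric augmentations (elastic
    # deformation plus shift/scale/rotate); the first one feeds the
    # transform-consistency regularizers defined below.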

    R_t = DualTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img: coord_hm_loss(
            encoder_HG(trans_dict['image'])["coords"], trans_dict['mask']))

    R_s = UnoTransformRegularizer.__call__(
        g_transforms, lambda trans_dict, img, ltnt: WR.L1("R_s")
        (ltnt, style_encoder(trans_dict['image'])))

    barycenter: UniformMeasure2D01 = UniformMeasure2DFactory.load(
        f"{Paths.default.models()}/face_barycenter_68").cuda().batch_repeat(
            batch_size)

    R_b = BarycenterRegularizer.__call__(barycenter, 1.0, 2.0, 4.0)
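    # R_t penalizes disagreement between landmarks predicted on an augmented image
    # and the identically augmented target heatmap, R_s keeps the style code stable
    # under the same augmentations, and R_b presumably pulls predictions toward a
    # precomputed landmark barycenter.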

    tuner = GoldTuner([0.37, 2.78, 0.58, 1.43, 3.23],
                      device=device,
                      rule_eps=0.001,
                      radius=0.3,
                      active=False)
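    # GoldTuner holds the five loss coefficients; with active=False they presumably
    # stay fixed at the given initial values.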

    trainer_gan = gan_trainer(model, generator, decoder, encoder_HG,
                              style_encoder, R_s, style_opt, g_transforms)
    content_trainer = content_trainer_with_gan(cont_opt, tuner, encoder_HG,
                                               R_b, R_t, model, generator,
                                               g_transforms, decoder,
                                               style_encoder)
    supervise_trainer = content_trainer_supervised(
        cont_opt, encoder_HG,
        LazyLoader.w300().loader_train_inf)
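    # Three update closures: the adversarial image trainer, the unsupervised content
    # trainer (regularizers plus GAN feedback), and a supervised trainer driven by
    # W300 ground-truth landmarks.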

    for i in range(11000):
        WR.counter.update(i)

        requires_grad(encoder_HG, False)
        real_img = next(LazyLoader.celeba().loader).to(device)

        encoded = encoder_HG(real_img)
        internal_content = encoded["skeleton"].detach()

        trainer_gan(i, real_img, internal_content)
        # content_trainer(real_img)
        train_content(cont_opt, R_b, R_t, real_img, model, encoder_HG, decoder,
                      generator, style_encoder)
        supervise_trainer()

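        # encoder_ema is assumed to be defined outside this excerpt (e.g. an
        # Accumulator as in Example #3); it keeps an exponential moving average of
        # the hourglass weights and writes it back every 50 iterations.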
        encoder_ema.accumulate(encoder_HG.module, i, 0.98)
        if i % 50 == 0 and i > 0:
            encoder_ema.write_to(encoder_HG.module)

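        # Every 100 iterations: reload the loss coefficients from disk and log
        # REAL / FAKE / RESTORED image grids with predicted landmarks, plus the raw
        # skeleton heatmap, to TensorBoard.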
        if i % 100 == 0:
            with open("../parameters/content_loss.json") as f:
                coefs = json.load(f)
            print(i, coefs)
            with torch.no_grad():

                # pred_measures_test, sparse_hm_test = encoder_HG(test_img)
                encoded_test = encoder_HG(test_img)
                pred_measures_test: UniformMeasure2D01 = UniformMeasure2D01(
                    encoded_test["coords"])
                heatmapper_256 = ToGaussHeatMap(256, 1.0)
                sparse_hm_test_1 = heatmapper_256.forward(
                    pred_measures_test.coord)

                latent_test = style_encoder(test_img)

                sparse_mask = sparse_hm_test_1.sum(dim=1, keepdim=True)
                sparse_mask[sparse_mask < 0.0003] = 0
                iwm = imgs_with_mask(test_img, sparse_mask)
                send_images_to_tensorboard(WR.writer, iwm, "REAL", i)

                fake_img, _ = generator(encoded_test["skeleton"], [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures_test.toImage(256))
                send_images_to_tensorboard(WR.writer, iwm, "FAKE", i)

                restored = decoder(encoded_test["skeleton"], latent_test)
                iwm = imgs_with_mask(restored, pred_measures_test.toImage(256))
                send_images_to_tensorboard(WR.writer, iwm, "RESTORED", i)

                content_test_256 = (encoded_test["skeleton"]).repeat(1, 3, 1, 1) * \
                    torch.tensor([1.0, 1.0, 0.0], device=device).view(1, 3, 1, 1)

                # min-max normalize to [0, 1] before logging with normalize=False
                content_test_256 = (content_test_256 - content_test_256.min()) / \
                    (content_test_256.max() - content_test_256.min())
                send_images_to_tensorboard(WR.writer,
                                           content_test_256,
                                           "HM",
                                           i,
                                           normalize=False,
                                           range=(0, 1))

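        # Every 50 iterations: evaluate the detector (liuboff() is an external
        # validation metric), feed the score to the coefficient tuner and log it.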
        if i % 50 == 0 and i >= 0:
            test_loss = liuboff(encoder_HG)
            print("liuboff", test_loss)
            # test_loss = nadbka(encoder_HG)
            tuner.update(test_loss)
            WR.writer.add_scalar("liuboff", test_loss, i)

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'c': encoder_HG.module.state_dict(),
                    "s": style_encoder.state_dict(),
                    "e": encoder_ema.storage_model.state_dict()
                },
                f'{Paths.default.models()}/stylegan2_new_{str(i + starting_model_number).zfill(6)}.pt',
            )
Example #3
generator = Generator(FromStyleConditionalGenerator(image_size, noise_size),
                      n_mlp=8)
discriminator = Discriminator(image_size)

starting_model_number = 290000
weights = torch.load(
    f'{Paths.default.models()}/celeba_gan_256_{str(starting_model_number).zfill(6)}.pt',
    map_location="cpu")

discriminator.load_state_dict(weights['d'])
generator.load_state_dict(weights['g'])

generator = generator.cuda()
discriminator = discriminator.cuda()

gan_model = StyleGanModel(generator, StyleGANLoss(discriminator),
                          (0.001, 0.0015))
gan_accumulator = Accumulator(generator, decay=0.99, write_every=100)
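# Accumulator presumably keeps an exponential moving average of the generator
# weights (decay 0.99) and flushes it back into the model every 100 steps.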

writer = SummaryWriter(f"{Paths.default.board()}/celeba{int(time.time())}")

print("Starting Training Loop...")
starting_model_number = 0

for i in range(300000):

    print(i)

    real_img = next(LazyLoader.celeba().loader).to(device)

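    # Standard StyleGAN2 style mixing: with probability 0.9 this presumably returns
    # two latent vectors instead of one.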
    noise: List[Tensor] = mixing_noise(batch_size, noise_size, 0.9, device)
Example #4
    map_location="cpu")

enc_dec = StyleGanAutoEncoder().load_state_dict(weights, style=False).cuda()

discriminator_img = Discriminator(image_size)
discriminator_img.load_state_dict(weights['di'])
discriminator_img = discriminator_img.cuda()

heatmapper = ToGaussHeatMap(256, 4)
hg = HG_heatmap(heatmapper, num_blocks=1, num_classes=200)
hg.load_state_dict(weights['gh'])
hg = hg.cuda()
cont_opt = optim.Adam(hg.parameters(), lr=2e-5, betas=(0, 0.8))

gan_model_tuda = StyleGanModel[HeatmapToImage](enc_dec.generator,
                                               StyleGANLoss(discriminator_img),
                                               (0.001 / 4, 0.0015 / 4))

style_opt = optim.Adam(enc_dec.style_encoder.parameters(), lr=1e-5)

writer = SummaryWriter(
    f"{Paths.default.board()}/brule1_cardio_{int(time.time())}")
WR.writer = writer

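# Fixed evaluation data from the cardio dataset: one training batch supplies the
# images and keypoints, a held-out test batch is also loaded, and the landmarks are
# clamped into the unit square for UniformMeasure2D01.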
batch = next(LazyLoader.cardio().loader_train_inf)

batch_test = next(iter(LazyLoader.cardio().test_loader))
test_img = batch["image"].cuda()

test_landmarks = batch["keypoints"].cuda()
test_measure = UniformMeasure2D01(torch.clamp(test_landmarks, max=1))
Example #5
weights = torch.load(
    f'{Paths.default.models()}/lmgen_{N}_{str(starting_model_number).zfill(6)}.pt',
    map_location="cpu")

heatmapper = ToGaussHeatMap(256, 4)
hg = nn.Sequential(EqualLinear(100, 256, activation='fused_lrelu'),
                   EqualLinear(256, 256, activation='fused_lrelu'),
                   EqualLinear(256, 256, activation='fused_lrelu'),
                   EqualLinear(256, 256, activation='fused_lrelu'),
                   EqualLinear(256, 136), nn.Sigmoid(), View(68, 2))
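# The landmark generator is a plain MLP: a 100-dimensional noise vector passes
# through four fused-lrelu EqualLinear layers and a final linear map to 136 outputs,
# which a sigmoid and View(68, 2) turn into 68 (x, y) coordinates in [0, 1].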
hg.load_state_dict(weights['gh'])
hg = hg.cuda()
hm_discriminator = Discriminator(image_size, input_nc=1, channel_multiplier=1)
hm_discriminator.load_state_dict(weights["dh"])
hm_discriminator = hm_discriminator.cuda()

gan_model = StyleGanModel[nn.Module](hg, StyleGANLoss(hm_discriminator),
                                     (0.001, 0.0015))

writer = SummaryWriter(f"{Paths.default.board()}/lmgen{int(time.time())}")
WR.writer = writer

test_noise = torch.randn(batch_size, 100, device=device)

hm_accumulator = Accumulator(hg, decay=0.98, write_every=100)

#
# os.mkdir(f"{Paths.default.data()}/w300_gen_{N}")
# os.mkdir(f"{Paths.default.data()}/w300_gen_{N}/lmbc")
#
#
for i in range(7000 // batch_size):
Example #6
heatmapper = ToGaussHeatMap(image_size, 2)
hg = HG_heatmap(heatmapper,
                num_classes=measure_size,
                image_size=image_size,
                num_blocks=4)
hg.load_state_dict(weights['gh'])
hg = hg.cuda()
hm_discriminator = Discriminator(image_size,
                                 input_nc=measure_size,
                                 channel_multiplier=1)
hm_discriminator.load_state_dict(weights["dh"])
hm_discriminator = hm_discriminator.cuda()
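# Unlike the face/cardio setups above, the heatmap discriminator here takes
# measure_size channels (one per keypoint class) rather than a single summed map.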

gan_model_tuda = StyleGanModel[HeatmapToImage](enc_dec.generator,
                                               StyleGANLoss(discriminator_img),
                                               (0.001 / 4, 0.0015 / 4))
gan_model_obratno = StyleGanModel[HG_skeleton](hg,
                                               StyleGANLoss(hm_discriminator),
                                               (2e-5, 0.0015 / 4))

style_opt = optim.Adam(enc_dec.style_encoder.parameters(), lr=1e-5)

writer = SummaryWriter(f"{Paths.default.board()}/human{int(time.time())}")
WR.writer = writer

batch = next(LazyLoader.human36().loader_train_inf)
test_img = batch["A"].cuda()
test_landmarks = torch.clamp(next(
    LazyLoader.human_landmarks(args.data_path).loader_train_inf).cuda(),
                             max=1)
Example #7
    map_location="cpu"
)

enc_dec = StyleGanAutoEncoder().load_state_dict(weights).cuda()

discriminator_img = Discriminator(image_size)
discriminator_img.load_state_dict(weights['di'])
discriminator_img = discriminator_img.cuda()

heatmapper = ToGaussHeatMap(256, 4)
hg = HG_heatmap(heatmapper, num_blocks=1)
hg.load_state_dict(weights['gh'])
hg = hg.cuda()
cont_opt = optim.Adam(hg.parameters(), lr=2e-5, betas=(0, 0.8))

gan_model_tuda = StyleGanModel[HeatmapToImage](enc_dec.generator,
                                               StyleGANLoss(discriminator_img),
                                               (0.001 / 4, 0.0015 / 4))

style_opt = optim.Adam(enc_dec.style_encoder.parameters(), lr=1e-5)

writer = SummaryWriter(f"{Paths.default.board()}/brule1_{int(time.time())}")
WR.writer = writer

batch = next(LazyLoader.w300().loader_train_inf)
test_img = batch["data"].cuda()
test_landmarks = torch.clamp(next(
    LazyLoader.w300_landmarks(args.data_path).loader_train_inf).cuda(),
                             max=1)
test_hm = heatmapper.forward(test_landmarks).sum(1, keepdim=True).detach()
test_noise = mixing_noise(batch_size, 512, 0.9, device)

psp_loss = PSPLoss().cuda()
mes_loss = MesBceWasLoss(heatmapper, bce_coef=10000, was_coef=100)
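# PSPLoss and MesBceWasLoss are the reconstruction-side objectives; the latter
# presumably combines a BCE term on rendered heatmaps (weight 10000) with a
# Wasserstein-style coordinate term (weight 100), using the heatmapper above.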