def train(generator, discriminator, encoder, style_encoder, device, starting_model_number,
          checkpoint_dir='/trinity/home/n.buzun/PycharmProjects/saved'):
    """Run the conditional StyleGAN training loop on CelebA batches.

    Every iteration performs one discriminator step and one generator step
    through ``CondStyleGanModel``. Every 5th iteration (i > 0) additionally
    updates the generator and ``style_encoder`` with content, restoration,
    style-latent and R_s regularization losses. Every 100 iterations the
    elapsed time is printed and preview images are sent to TensorBoard;
    every 10000 iterations (i > 0) a checkpoint is written.

    Relies on module-level ``writer`` (passed to
    ``send_images_to_tensorboard``) and ``counter`` objects that must exist
    before this function is called.

    Args:
        generator: Conditional generator, called as
            ``generator(content, noise[, return_latents=True])`` and exposing
            ``generator.module.decode(content, style)``. The ``.module`` access
            suggests an ``nn.DataParallel`` wrapper — confirm.
        discriminator: Discriminator wrapped by ``StyleGANLoss``; saved in
            checkpoints under key ``'d'``.
        encoder: Content encoder whose outputs condition the generator. It is
            frozen with ``requires_grad(encoder, False)`` before the loop.
        style_encoder: Maps images to style latents; optimized by its own
            Adam optimizer together with the generator every 5th step.
        device: Torch device for loaded batches and sampled noise.
        starting_model_number: Offset added to the iteration index in
            checkpoint file names (e.g. when resuming training).
        checkpoint_dir: Directory checkpoints are written to. Defaults to the
            previously hard-coded location, so existing callers are unaffected.
    """
    batch = 32
    Celeba.batch_size = batch
    latent_size = 512

    model = CondStyleGanModel(generator, StyleGANLoss(discriminator), (0.001, 0.0015))
    style_opt = optim.Adam(style_encoder.parameters(), lr=5e-4, betas=(0.5, 0.97))

    # Augmentations used by the style-consistency regularizer R_s below.
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(albumentations.ElasticTransform(p=0.8, alpha=150, alpha_affine=1, sigma=10)),
        NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
        ToTensor(device),
    ])
    # R_s applies an L1 penalty between a latent and the style encoder's
    # prediction on the transformed image (trans_dict['image']).
    R_s = UnoTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img, ltnt: L1("R_s")(ltnt, style_encoder(trans_dict['image'])),
    )

    # Fixed inputs reused for every TensorBoard preview.
    sample_z = torch.randn(batch, latent_size, device=device)
    test_img = next(LazyLoader.celeba().loader).to(device)
    print(test_img.shape)
    # The encoder is not frozen yet at this point, so a plain forward pass
    # would record an autograd graph that test_cond keeps alive for the whole
    # run. test_cond is only ever consumed under no_grad, so compute it
    # without autograd. Its measure is also built once here because
    # test_cond never changes.
    with torch.no_grad():
        test_cond = encoder(test_img)
        coords, p = heatmap_to_measure(test_cond)
        test_mes = ProbabilityMeasure(p, coords)
    requires_grad(encoder, False)  # REMOVE BEFORE TRAINING

    t_start = time.time()
    for i in range(100000):
        counter.update(i)

        # --- Adversarial step: discriminator, then generator. ---
        real_img = next(LazyLoader.celeba().loader).to(device)
        img_content = encoder(real_img).detach()
        noise = mixing_noise(batch, latent_size, 0.9, device)
        fake, _ = generator(img_content, noise)

        model.discriminator_train([real_img], [fake.detach()], img_content)
        writable("Generator loss", model.generator_loss)([real_img], [fake], [], img_content) \
            .minimize_step(model.optimizer.opt_min)

        # --- Every 5th step: content / restoration / style-latent losses. ---
        if i % 5 == 0 and i > 0:
            noise = mixing_noise(batch, latent_size, 0.9, device)
            img_content = encoder(real_img).detach()
            fake, fake_latent = generator(img_content, noise, return_latents=True)
            # Only latent slots 0 and 13 are regressed by style_encoder —
            # TODO confirm why these indices.
            fake_latent_test = fake_latent[:, [0, 13], :].detach()
            fake_latent_pred = style_encoder(fake)
            fake_content_pred = encoder(fake)
            half = batch // 2
            restored = generator.module.decode(img_content[:half], style_encoder(real_img[:half]))
            (
                HMLoss("BCE content gan", 5000)(fake_content_pred, img_content)
                + L1("L1 restored")(restored, real_img[:half]) * 50
                + L1("L1 style gan")(fake_latent_pred, fake_latent_test) * 30
                + R_s(fake.detach(), fake_latent_pred) * 50
            ).minimize_step(model.optimizer.opt_min, style_opt)

        # --- Every 100 steps: timing report and TensorBoard previews. ---
        if i % 100 == 0:
            now = time.time()
            print(i, now - t_start)
            t_start = now  # one timestamp, so consecutive intervals have no gaps
            with torch.no_grad():
                fake_img, _ = generator(test_cond, [sample_z])
                iwm = imgs_with_mask(fake_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)
                iwm = imgs_with_mask(test_img, test_mes.toImage(256))
                send_images_to_tensorboard(writer, iwm, "REAL", i)
                restored = generator.module.decode(test_cond, style_encoder(test_img))
                send_images_to_tensorboard(writer, restored, "RESTORED", i)

        # --- Every 10000 steps: checkpoint. ---
        if i % 10000 == 0 and i > 0:
            model_number = str(starting_model_number + i).zfill(6)
            torch.save(
                {
                    'g': generator.state_dict(),
                    'd': discriminator.state_dict(),
                    'style': style_encoder.state_dict(),
                },
                f'{checkpoint_dir}/stylegan2_w300_{model_number}.pt',
            )
def apply_to_mask(self, img: ProbabilityMeasure, **params):
    """Return ``img`` rendered through its ``toImage`` method at ``self.size``.

    ``params`` is accepted but unused (presumably required by the calling
    transform interface — confirm).
    """
    render = img.toImage
    return render(self.size)
# HMLoss("BCE content gan", 5000)(fake_content_pred, img_content.detach()) + # Loss(nn.L1Loss()(restored, real_img[:W300DatasetLoader.batch_size//2]) * 50) + # Loss(nn.L1Loss()(fake_latent_pred, fake_latent_test) * 25) + # R_s(fake.detach(), fake_latent_pred) * 50 # ).minimize_step( # model.optimizer.opt_min, # style_opt, # ) # img_content = encoder_HG(real_img) # fake, fake_latent = generator(img_content, noise, return_latents=True) # fake_content_pred = encoder_HG(fake) # # # disc_influence = model.loss.generator_loss(real=None, fake=[real_img, img_content]) * 2 # (HMLoss("BCE content gan", 1)(fake_content_pred, img_content.detach()) + # disc_influence).minimize_step(enc_opt) if i % 50 == 0 and i > 0: with torch.no_grad(): test_loss = test(encoder_HG) print(test_loss) # tuner.update(test_loss) coord, p = heatmap_to_measure(encoder_HG(w300_test_image)) pred_measure = ProbabilityMeasure(p, coord) iwm = imgs_with_mask(w300_test_image, pred_measure.toImage(256)) send_images_to_tensorboard(writer, iwm, "W300_test_image", i) writer.add_scalar("test_loss", test_loss, i) # torch.save(enc.state_dict(), f"/home/ibespalov/pomoika/hg2_e{epoch}.pt")