def test():
    """Accumulate the OT loss between predicted and mask-derived reference
    measures over the whole held-out set and return the (unnormalized) sum."""
    eval_loader = torch.utils.data.DataLoader(dataset_test, batch_size=40, num_workers=20)
    total = 0.0
    for imgs, masks in eval_loader:
        imgs = imgs.cuda().type(torch.float32)
        pred: ProbabilityMeasure = image2measure(imgs)
        ref: ProbabilityMeasure = fabric.from_mask(masks).cuda().padding(args.measure_size)
        total += Samples_Loss()(pred, ref).item()
    return total
def verka(encoder: nn.Module):
    """Mean per-image W1 landmark error of `encoder` over the CelebA test set.

    Each batch contributes (batch-mean loss * batch size), so dividing the
    grand total by the dataset size yields a true per-image average.

    FIX: the original took ``np.mean`` over the per-batch totals AND divided
    by the dataset size, which under-reports the error by a factor of the
    number of batches. ``np.sum`` gives the intended weighted mean.
    """
    res = []
    for i, (image, lm) in enumerate(LazyLoader.celeba_test(64)):
        content = encoder(image.cuda())
        mes = UniformMeasure2D01(lm.cuda())
        pred_measures: UniformMeasure2D01 = UniformMeasure2DFactory.from_heatmap(content)
        # weight the batch-mean loss by batch size so a ragged last batch counts correctly
        res.append(Samples_Loss(p=1)(mes, pred_measures).item() * image.shape[0])
    return np.sum(res) / len(LazyLoader.celeba_test(1).dataset)
def loss(image: Tensor, mask: ProbabilityMeasure):
    """Barycenter regularizer: split the transport of `mask` toward the global
    barycenter into translation (T), linear (A) and residual Wasserstein terms,
    weighted by the module-level coefficients ca / cw / ct."""
    with torch.no_grad():
        A, T = LinearTransformOT.forward(mask, barycenter, 100)
    translation_term = Samples_Loss(scaling=0.8, border=0.0001)(mask, mask.detach() + T)
    affine_term = Samples_Loss(scaling=0.8, border=0.0001)(
        mask.centered(), mask.centered().multiply(A).detach())
    residual_term = Samples_Loss(scaling=0.85, border=0.00001)(
        mask.centered().multiply(A), barycenter.centered().detach())
    return affine_term * ca + residual_term * cw + translation_term * ct
def test(cont_style_encoder, pairs):
    """Average W1 distance between the encoder's content measures and the
    ground-truth mask measures over `pairs`; prints and returns the mean."""
    W1 = Samples_Loss(scaling=0.9, p=1)
    to_measure = MaskToMeasure(size=256, padding=140)  # hoisted out of the loop
    errors = []
    for img, masks in pairs:
        mes: ProbabilityMeasure = to_measure.apply_to_mask(masks).cuda()
        real_img = img.cuda()
        img_content = cont_style_encoder.get_content(real_img).detach()
        errors.append(W1(content_to_measure(img_content), mes).item())
    mean_err = sum(errors) / len(errors)
    print("test:", mean_err)
    return mean_err
def train(args, loader, generator, discriminator, device, cont_style_encoder, starting_model_number):
    """Joint StyleGAN + content/style-encoder training loop.

    NOTE(review): formatting reconstructed from a collapsed source line;
    statement order preserved exactly — verify against the original file.
    """
    loader = sample_data(loader)
    pbar = range(args.iter)

    # fixed latent / images used only for periodic tensorboard previews
    sample_z = torch.randn(8, args.latent, device=device)
    test_img = next(loader)[:8]
    test_img = test_img.cuda()

    loss_st: StyleGANLoss = StyleGANLoss(discriminator)
    model = CondStyleGanModel(generator, loss_st, (0.001, 0.0015))

    style_opt = optim.Adam(cont_style_encoder.enc_style.parameters(), lr=5e-4, betas=(0.5, 0.9))
    cont_opt = optim.Adam(cont_style_encoder.enc_content.parameters(), lr=2e-5, betas=(0.5, 0.9))

    # mask-space augmentation: measure -> mask -> elastic/affine warp -> measure
    g_transforms: albumentations.DualTransform = albumentations.Compose([
        MeasureToMask(size=256),
        ToNumpy(),
        NumpyBatch(albumentations.ElasticTransform(p=0.8, alpha=150, alpha_affine=1, sigma=10)),
        NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
        ToTensor(device),
        MaskToMeasure(size=256, padding=140),
    ])

    W1 = Samples_Loss(scaling=0.85, p=1)

    # equivariance of the content encoder under geometric transforms
    R_t = DualTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img: W1(
            content_to_measure(cont_style_encoder.get_content(trans_dict['image'])),
            trans_dict['mask'])
    )
    # invariance of the style code under the same transforms
    R_s = UnoTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img, ltnt: L1("R_s")(ltnt, cont_style_encoder.enc_style(trans_dict['image'])))

    fabric = ProbabilityMeasureFabric(256)
    barycenter = fabric.load(f"{Paths.default.models()}/face_barycenter").cuda() \
        .padding(70).transpose().batch_repeat(16)
    R_b = BarycenterRegularizer.__call__(barycenter)

    # 7 coefficients, one per entry of the tuner.sum_losses list below
    tuner = GoldTuner([2.53, 40.97, 5.5, 0.01, 5.44, 1.05, 4.9],
                      device=device, rule_eps=0.05, radius=1, active=False)
    gan_tuner = GoldTuner([20, 25, 25], device=device, rule_eps=1, radius=20, active=False)

    best_igor = 100

    for idx in pbar:
        i = idx + args.start_iter
        counter.update(i)
        if i > args.iter:
            print('Done!')
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        img_content = cont_style_encoder.get_content(real_img)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        img_content_variable = img_content.detach().requires_grad_(True)
        fake, fake_latent = generator(img_content_variable, noise, return_latents=True)

        model.discriminator_train([real_img], [fake], img_content)

        fake_latent_test = fake_latent[:, [0, 13], :].detach()
        fake_content_pred = cont_style_encoder.get_content(fake)
        fake_latent_pred = cont_style_encoder.enc_style(fake)

        (
            writable("Generator loss", model.generator_loss)(
                [real_img], [fake], [fake_latent], img_content_variable)
            + gan_tuner.sum_losses([
                L1("L1 content gan")(fake_content_pred, img_content.detach()),
                L1("L1 style gan")(fake_latent_pred, fake_latent_test),
                R_s(fake.detach(), fake_latent_pred),
            ])
        ).minimize_step(model.optimizer.opt_min, style_opt)

        if i % 5 == 0:
            # encoder-side regularization step
            img_latent = cont_style_encoder.enc_style(real_img[:16])
            restored = model.generator.module.decode(img_content[:16], img_latent[:16])
            pred_measures: ProbabilityMeasure = content_to_measure(img_content[:16])

            noise1 = mixing_noise(16, args.latent, args.mixing, device)
            noise2 = mixing_noise(16, args.latent, args.mixing, device)
            fake1, _ = generator(img_content[:16], noise1)
            fake2, _ = generator(img_content[:16], noise2)
            cont_fake1 = cont_style_encoder.get_content(fake1)
            cont_fake2 = cont_style_encoder.get_content(fake2)

            tuner.sum_losses([
                writable("Real-content D", model.loss.generator_loss)(
                    real=None, fake=[real_img, img_content]),
                writable("R_b", R_b.__call__)(real_img[:16], pred_measures),
                writable("R_t", R_t.__call__)(real_img[:16], pred_measures),
                L1("L1 content between fake")(cont_fake1, cont_fake2),
                L1("L1 image")(restored, real_img[:16]),
                R_s(real_img[:16], img_latent),
                L1("L1 style restored")(cont_style_encoder.enc_style(restored), img_latent.detach()),
            ]).minimize_step(cont_opt, model.optimizer.opt_min, style_opt)

        if i % 100 == 0:
            # tensorboard previews of real / fake / restored images with landmarks
            print(i)
            with torch.no_grad():
                content, latent = cont_style_encoder(test_img)
                pred_measures: ProbabilityMeasure = content_to_measure(content)
                iwm = imgs_with_mask(test_img, pred_measures.toImage(256), color=[1, 1, 1])
                send_images_to_tensorboard(writer, iwm, "REAL", i)
                fake_img, _ = generator(content, [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)
                restored = model.generator.module.decode(content, latent)
                send_images_to_tensorboard(writer, restored, "RESTORED", i)

        if i % 100 == 0 and i > 0:
            pass  # periodic evaluation / tuner update was disabled in the original

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'enc': cont_style_encoder.state_dict(),
                },
                f'{Paths.default.models()}/stylegan2_invertable_{str(i + starting_model_number).zfill(6)}.pt',
            )
# Compare two content encoders and the barycenter baseline against
# mask-derived reference measures on 30 batches.
# FIX: err_pred_list was appended to below but never initialized in this
# chunk (NameError). NOTE(review): if this is a notebook cell, it may have
# been defined in an earlier cell — confirm before relying on this init.
err_pred_list = []
err_pred_list_2 = []
err_bc_list = []
for i in range(30):
    test_img, test_mask = next(loader)
    test_img = test_img.cuda()
    content = cont_style_encoder.get_content(test_img)
    pred_measures: ProbabilityMeasure = content_to_measure(content)
    content2 = cont_style_encoder2.get_content(test_img)
    pred_measures2: ProbabilityMeasure = content_to_measure(content2)
    ref_measure = MaskToMeasure(size=256, padding=140, clusterize=True)(
        image=test_img, mask=test_mask)["mask"].cuda()
    err_pred = Samples_Loss(p=1)(pred_measures, ref_measure).item()
    err_pred_2 = Samples_Loss(p=1)(pred_measures2, ref_measure).item()
    err_bc = Samples_Loss(p=1)(barycenter, ref_measure).item()
    print("pred:", err_pred)
    print("bc:", err_bc)
    err_pred_list.append(err_pred)
    err_pred_list_2.append(err_pred_2)
    err_bc_list.append(err_bc)
# %%
print("pred mean:", sum(err_pred_list) / len(err_pred_list))
print("pred mean 2:", sum(err_pred_list_2) / len(err_pred_list_2))
print("bc mean:", sum(err_bc_list) / len(err_bc_list))
# Time LinearTransformOT and Samples_Loss against a batched face barycenter.
fabric = ProbabilityMeasureFabric(image_size)
barycenter = fabric.load("../examples/face_barycenter").cuda().crop(measure_size)
barycenter = fabric.cat([barycenter for b in range(batch_size)])

for imgs, masks in dataloader:
    imgs = imgs.cuda()
    measures: ProbabilityMeasure = fabric.from_coord_tensor(masks).cuda().padding(measure_size)
    t1 = time.time()
    with torch.no_grad():
        A, T = LinearTransformOT.forward(measures, barycenter)
    t2 = time.time()
    dist = Samples_Loss().forward(measures, barycenter)
    t3 = time.time()
    # distance, OT solve time, loss eval time
    print(dist, t2 - t1, t3 - t2)

# synthetic linear distortion of the barycenter for inspection
Atest = torch.tensor([[3, 0.2], [0, 1]], device=device, dtype=torch.float32)
Atest = torch.cat([Atest[None, ]] * batch_size)
bc_tr = barycenter.random_permute().multiply(Atest) + 0.1
def optimization_step():
    """Train measure2image (GAN) and image2measure (OT barycenter losses)
    jointly for 20 epochs; returns the held-out test loss."""
    noise = NormalNoise(n_noise, device)
    measure2image = ResMeasureToImage(args.measure_size * 3 + noise.size(),
                                      args.image_size, ngf).cuda()

    netD = DCDiscriminator(ndf=ndf).cuda()
    gan_model = GANModel(measure2image,
                         HingeLoss(netD).add_generator_loss(nn.L1Loss(), L1),
                         lr=0.0004)

    fabric = ProbabilityMeasureFabric(args.image_size)
    barycenter = fabric.load("barycenter").cuda().padding(args.measure_size)
    print(barycenter.coord.shape)
    barycenter = fabric.cat([barycenter for b in range(args.batch_size)])
    print(barycenter.coord.shape)

    image2measure = ResImageToMeasure(args.measure_size).cuda()
    image2measure_opt = optim.Adam(image2measure.parameters(), lr=0.0002)

    def test():
        # held-out loss of the current image2measure
        dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=40, num_workers=20)
        sum_loss = 0
        for i, (imgs, masks) in enumerate(dataloader_test, 0):
            imgs = imgs.cuda().type(torch.float32)
            pred_measures: ProbabilityMeasure = image2measure(imgs)
            ref_measures: ProbabilityMeasure = fabric.from_mask(masks).cuda().padding(args.measure_size)
            ref_loss = Samples_Loss()(pred_measures, ref_measures)
            sum_loss += ref_loss.item()
        return sum_loss

    for epoch in range(20):
        ot_iters = 100
        print("epoch", epoch)
        test_imgs = None
        for i, imgs in enumerate(dataloader, 0):
            imgs = imgs.cuda().type(torch.float32)
            test_imgs = imgs

            pred_measures: ProbabilityMeasure = image2measure(imgs)
            cond = pred_measures.toChannels()
            n = cond.shape[0]
            barycenter_batch = barycenter.slice(0, n)
            # condition the generator on measure channels plus noise
            z = noise.sample(n)
            cond = torch.cat((cond, z), dim=1)
            gan_model.train(imgs, cond.detach())

            with torch.no_grad():
                A, T = LinearTransformOT.forward(pred_measures, barycenter_batch, ot_iters)
            bc_loss_T = Samples_Loss()(pred_measures, pred_measures.detach() + T)
            bc_loss_A = Samples_Loss()(pred_measures.centered(),
                                       pred_measures.centered().multiply(A).detach())
            bc_loss_W = Samples_Loss()(pred_measures.centered().multiply(A),
                                       barycenter_batch.centered())
            bc_loss = bc_loss_W * cw + bc_loss_A * ca + bc_loss_T * ct

            fake = measure2image(cond)
            g_loss = gan_model.generator_loss(imgs, fake)
            (g_loss + bc_loss).minimize_step(image2measure_opt)
    return test()
# Add an L1 coordinate-matching term to the conditional GAN loss and set up
# the image2measure optimizer plus barycenter/transform regularizers.
cond_gan_model.loss += GANLossObject(
    lambda dx, dy: Loss.ZERO(),
    lambda dgz, real, fake: Loss(
        nn.L1Loss()(image2measure(fake[0]).coord,
                    fabric.from_channels(real[1]).coord.detach())
    ) * 10,
    None
)

image2measure = ResImageToMeasure(args.measure_size).cuda()
image2measure_opt = optim.Adam(image2measure.parameters(), lr=0.0003)

R_b = BarycenterRegularizer.__call__(barycenter)
R_t = DualTransformRegularizer.__call__(
    g_transforms,
    lambda trans_dict: Samples_Loss()(image2measure(trans_dict['image']), trans_dict['mask'])
)
# ramp the transform regularizer weight from 0 to 6 over 1000 steps
deform_array = list(np.linspace(0, 6, 1000))
Whole_Reg = R_t @ deform_array + R_b

for epoch in range(500):
    print("epoch", epoch)
    for i, (imgs, masks) in enumerate(dataloader, 0):
        if imgs.shape[0] != args.batch_size:
            continue
        # NOTE(review): the loop body appears truncated at this chunk boundary
# Build the batched barycenter, the measure-space augmentation pipeline, and
# tensorboard-instrumented regularizers.
fabric = ProbabilityMeasureFabric(args.image_size)
barycenter = (
    fabric.load("/home/ibespalov/unsupervised_pattern_segmentation/examples/face_barycenter")
    .cuda()
    .padding(args.measure_size)
    .batch_repeat(args.batch_size)
)

g_transforms: albumentations.DualTransform = albumentations.Compose([
    MeasureToMask(size=256),
    ToNumpy(),
    NumpyBatch(albumentations.ElasticTransform(p=0.5, alpha=150, alpha_affine=1, sigma=10)),
    NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)),
    ToTensor(device),
    MaskToMeasure(size=256, padding=args.measure_size),
])

R_b = BarycenterRegularizer.__call__(barycenter)
R_t = DualTransformRegularizer.__call__(
    g_transforms,
    lambda trans_dict: Samples_Loss(scaling=0.85, p=1)(
        content_to_measure(cont_style_encoder(trans_dict['image'])[0]),
        trans_dict['mask'])
)

# wrap forward passes so each call logs its value to tensorboard
R_b.forward = send_to_tensorboard("R_b", counter=counter, writer=writer)(R_b.forward)
R_t.forward = send_to_tensorboard("R_t", counter=counter, writer=writer)(R_t.forward)

deform_array = list(np.linspace(0, 1, 1500))
Whole_Reg = R_t @ deform_array + R_b

l1_loss = nn.L1Loss()


def L1(name: Optional[str], writer: SummaryWriter = writer) -> Callable[[Tensor, Tensor], Loss]:
    if name:
        counter.active[name] = True
    # NOTE(review): chunk ends here — the remainder of L1's body is not
    # visible in this fragment; preserved as-is rather than guessed.
MeasureToMask(size=256), ToNumpy(), NumpyBatch( albumentations.ElasticTransform(p=0.5, alpha=150, alpha_affine=1, sigma=10)), NumpyBatch(albumentations.ShiftScaleRotate(p=0.5, rotate_limit=10)), ToTensor(device), MaskToMeasure(size=256, padding=args.measure_size), ]) R_b = BarycenterRegularizer.__call__(barycenter) R_t = DualTransformRegularizer.__call__( g_transforms, lambda trans_dict: Samples_Loss(scaling=0.85, p=1)(content_to_measure( cont_style_encoder(trans_dict['image'])[0]), trans_dict['mask'])) R_b.forward = send_to_tensorboard("R_b", counter=counter, writer=writer)(R_b.forward) R_t.forward = send_to_tensorboard("R_t", counter=counter, writer=writer)(R_t.forward) deform_array = list(np.linspace(0, 1, 1500)) Whole_Reg = R_t @ deform_array + R_b l1_loss = nn.L1Loss() def L1(name: Optional[str], writer: SummaryWriter = writer) -> Callable[[Tensor, Tensor], Loss]: if name:
# Variant of the conditional-GAN setup with the extra GANLossObject term
# disabled and a second, lower-LR "strong" optimizer for image2measure.
image2measure = ResImageToMeasure(args.measure_size).cuda()
image2measure_opt_strong = optim.Adam(image2measure.parameters(), lr=0.0001)
image2measure_opt = optim.Adam(image2measure.parameters(), lr=0.0003)

R_b = BarycenterRegularizer.__call__(barycenter)
R_t = DualTransformRegularizer.__call__(
    g_transforms,
    lambda trans_dict: Samples_Loss()(image2measure(trans_dict['image']), trans_dict['mask'])
)
# ramp the transform regularizer weight from 0 to 6 over 1000 steps
deform_array = list(np.linspace(0, 6, 1000))
Whole_Reg = R_t @ deform_array + R_b

for epoch in range(500):
    print("epoch", epoch)
    for i, (imgs, masks) in enumerate(dataloader, 0):
        if imgs.shape[0] != args.batch_size:
            continue
        # NOTE(review): the loop body appears truncated at this chunk boundary
# Fit the barycenter measure to CelebA landmark heatmaps predicted by a
# pretrained hourglass encoder, by gradient descent on the OT loss.
encoder_HG = HG_softmax2020(num_classes=68, heatmap_size=64)
encoder_HG.load_state_dict(
    torch.load(f"{Paths.default.models()}/hg2_e29.pt", map_location="cpu"))
encoder_HG = encoder_HG.cuda()

# FIX: loop variable was named `iter`, shadowing the builtin — renamed.
for step in range(3000):
    img = next(LazyLoader.celeba().loader).cuda()
    content = encoder_HG(img)
    coord, p = heatmap_to_measure(content)
    mes = ProbabilityMeasure(p, coord)
    barycenter_cat = fabric.cat([barycenter] * batch_size)
    loss = Samples_Loss()(barycenter_cat, mes)
    opt.zero_grad()
    loss.to_tensor().backward()
    opt.step()
    # project the weights back onto the probability simplex after the step
    barycenter.probability.data = barycenter.probability.relu().data
    barycenter.probability.data /= barycenter.probability.sum(dim=1, keepdim=True)
    if step % 100 == 0:
        print(step, loss.item())
        plt.imshow(barycenter.toImage(200)[0][0].detach().cpu().numpy())
        plt.show()