def liuboff(encoder: nn.Module):
    # MAFL test error: mean per-landmark L2 distance between predicted and
    # ground-truth landmarks, normalized by the inter-ocular distance.
    sum_loss = 0
    for batch in LazyLoader.mafl().test_loader:
        data = batch['data'].to(device)
        landmarks = batch["meta"]["keypts_normalized"].cuda().type(dtype=torch.float32)
        landmarks[landmarks > 1] = 0.99999

        pred_measure = UniformMeasure2DFactory.from_heatmap(encoder(data))
        target = UniformMeasure2D01(torch.clamp(landmarks, max=1))

        # landmarks 0 and 1 are the two eye centers in the 5-point annotation
        eye_dist = landmarks[:, 1] - landmarks[:, 0]
        eye_dist = eye_dist.pow(2).sum(dim=1).sqrt()

        sum_loss += ((pred_measure.coord - target.coord).pow(2).sum(dim=2).sqrt().mean(dim=1) / eye_dist).sum().item()

    return sum_loss / len(LazyLoader.mafl().test_dataset)
def liuboffMAFL(encoder: nn.Module):
    # Same evaluation loop as liuboff, but the encoder returns coordinates
    # directly and the error is a Wasserstein-1 distance between the
    # predicted and target landmark measures.
    sum_loss = 0
    for batch in LazyLoader.mafl().test_loader:
        data = batch['data'].cuda()
        landmarks = batch["meta"]["keypts_normalized"].cuda()
        landmarks[landmarks > 1] = 0.99999

        pred_measure = UniformMeasure2D01(encoder(data)["coords"])
        target = UniformMeasure2D01(torch.clamp(landmarks, max=1))

        eye_dist = landmarks[:, 1] - landmarks[:, 0]
        eye_dist = eye_dist.pow(2).sum(dim=1).sqrt()

        sum_loss += (handmadew1(pred_measure, target, 0.005) / eye_dist).sum().item()

    return sum_loss / len(LazyLoader.mafl().test_dataset)
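# A minimal, self-contained sketch of the inter-ocular NME that liuboff /
# liuboffMAFL compute above, on random dummy tensors. The function name and
# shapes are illustrative only (5 landmarks as in MAFL), and it assumes
# landmarks 0 and 1 are the two eyes, as in the loops above.
def _nme_sketch():
    pred = torch.rand(4, 5, 2)    # 4 faces, 5 predicted landmarks in [0, 1]
    target = torch.rand(4, 5, 2)  # matching ground-truth landmarks
    eye_dist = (target[:, 1] - target[:, 0]).pow(2).sum(dim=1).sqrt()  # inter-ocular distance
    per_image = (pred - target).pow(2).sum(dim=2).sqrt().mean(dim=1)   # mean L2 error per face
    return (per_image / eye_dist).mean().item()                        # normalized mean error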
def train(generator, decoder, discriminator, encoder_HG, style_encoder, device, starting_model_number):
    latent_size = 512
    batch_size = 12
    sample_z = torch.randn(8, latent_size, device=device)
    MAFL.batch_size = batch_size
    MAFL.test_batch_size = 64
    Celeba.batch_size = batch_size

    test_img = next(LazyLoader.mafl().loader_train_inf)["data"][:8].cuda()

    loss_st: StyleGANLoss = StyleGANLoss(discriminator)
    model = CondStyleGanModel(generator, loss_st, (0.001, 0.0015))

    style_opt = optim.Adam(style_encoder.parameters(), lr=5e-4, betas=(0.9, 0.99))
    cont_opt = optim.Adam(encoder_HG.parameters(), lr=2e-5, betas=(0.5, 0.97))

    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(albumentations.Compose([
            ResizeMask(h=256, w=256),
            albumentations.ElasticTransform(p=0.7, alpha=150, alpha_affine=1, sigma=10),
            albumentations.ShiftScaleRotate(p=0.7, rotate_limit=15),
            ResizeMask(h=64, w=64),
            NormalizeMask(dim=(0, 1, 2))
        ])),
        ToTensor(device),
    ])

    # Equivariance regularizer: encoder predictions must follow geometric transforms.
    R_t = DualTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img:
        stariy_hm_loss(encoder_HG(trans_dict['image']), trans_dict['mask'])
    )

    # Style-consistency regularizer: the style code must survive the same transforms.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img, ltnt:
        L1("R_s")(ltnt, style_encoder(trans_dict['image']))
    )

    barycenter: UniformMeasure2D01 = UniformMeasure2DFactory.load(
        f"{Paths.default.models()}/face_barycenter_5").cuda().batch_repeat(batch_size)

    R_b = BarycenterRegularizer.__call__(barycenter, 1.0, 2.0, 4.0)

    tuner = GoldTuner([0.37, 1.55, 0.9393, 0.1264, 1.7687, 0.8648, 1.8609],
                      device=device, rule_eps=0.01 / 2, radius=0.1, active=True)

    heatmaper = ToGaussHeatMap(64, 1.0)

    # Render the barycenter once and log it to tensorboard for reference.
    sparse_bc = heatmaper.forward(barycenter.coord * 63)
    sparse_bc = nn.Upsample(scale_factor=4)(sparse_bc).sum(dim=1, keepdim=True).repeat(1, 3, 1, 1) * \
        torch.tensor([1.0, 1.0, 0.0], device=device).view(1, 3, 1, 1)
    sparse_bc = (sparse_bc - sparse_bc.min()) / (sparse_bc.max() - sparse_bc.min())
    send_images_to_tensorboard(writer, sparse_bc, "BC", 0, normalize=False, range=(0, 1))

    trainer_gan = gan_trainer(model, generator, decoder, encoder_HG, style_encoder,
                              R_s, style_opt, heatmaper, g_transforms)
    content_trainer = content_trainer_with_gan(cont_opt, tuner, heatmaper, encoder_HG,
                                               R_b, R_t, model, generator, g_transforms)
    supervise_trainer = content_trainer_supervised(cont_opt, encoder_HG, LazyLoader.mafl().loader_train_inf)

    for i in range(100000):
        counter.update(i)
        requires_grad(encoder_HG, False)  # REMOVE BEFORE TRAINING

        # Every 5th step train on MAFL, otherwise on CelebA.
        real_img = next(LazyLoader.mafl().loader_train_inf)["data"].to(device) \
            if i % 5 == 0 else next(LazyLoader.celeba().loader).to(device)

        img_content = encoder_HG(real_img)
        pred_measures: UniformMeasure2D01 = UniformMeasure2DFactory.from_heatmap(img_content)
        sparse_hm = heatmaper.forward(pred_measures.coord * 63).detach()

        trainer_gan(i, real_img, pred_measures.detach(), sparse_hm.detach(), apply_g=False)
        supervise_trainer()

        if i % 4 == 0:
            trainer_gan(i, real_img, pred_measures.detach(), sparse_hm.detach(), apply_g=True)
            content_trainer(real_img)

        if i % 100 == 0:
            coefs = json.load(open("../parameters/content_loss.json"))
            print(i, coefs)
            with torch.no_grad():
                content_test = encoder_HG(test_img)
                pred_measures_test: UniformMeasure2D01 = UniformMeasure2DFactory.from_heatmap(content_test)
                heatmaper_256 = ToGaussHeatMap(256, 2.0)
                sparse_hm_test = heatmaper.forward(pred_measures_test.coord * 63)
                sparse_hm_test_1 = heatmaper_256.forward(pred_measures_test.coord * 255)
                latent_test = style_encoder(test_img)

                sparse_mask = sparse_hm_test_1.sum(dim=1, keepdim=True)
                sparse_mask[sparse_mask < 0.0003] = 0
                iwm = imgs_with_mask(test_img, sparse_mask)
                send_images_to_tensorboard(writer, iwm, "REAL", i)

                fake_img, _ = generator(sparse_hm_test, [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures_test.toImage(256))
                send_images_to_tensorboard(writer, iwm, "FAKE", i)

                restored = decoder(sparse_hm_test, latent_test)
                iwm = imgs_with_mask(restored, pred_measures_test.toImage(256))
                send_images_to_tensorboard(writer, iwm, "RESTORED", i)

                content_test_256 = nn.Upsample(scale_factor=4)(sparse_hm_test).sum(dim=1, keepdim=True).repeat(1, 3, 1, 1) * \
                    torch.tensor([1.0, 1.0, 0.0], device=device).view(1, 3, 1, 1)
                content_test_256 = (content_test_256 - content_test_256.min()) / (content_test_256.max() - content_test_256.min())
                send_images_to_tensorboard(writer, content_test_256, "HM", i, normalize=False, range=(0, 1))

        if i % 50 == 0:
            test_loss = liuboff(encoder_HG)
            tuner.update(test_loss)
            writer.add_scalar("liuboff", test_loss, i)

        if i % 10000 == 0 and i > 0:
            torch.save(
                {
                    'g': generator.module.state_dict(),
                    'd': discriminator.module.state_dict(),
                    'c': encoder_HG.module.state_dict(),
                    "s": style_encoder.state_dict()
                },
                f'{Paths.default.models()}/stylegan2_MAFL_{str(i + starting_model_number).zfill(6)}.pt',
            )
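# A hedged sketch of restoring a checkpoint written by train() above.
# load_checkpoint is not part of the original code; it only mirrors the
# 'g'/'d'/'c'/'s' keys and the .module wrappers used in torch.save above.
def load_checkpoint(path, generator, discriminator, encoder_HG, style_encoder):
    weights = torch.load(path, map_location="cpu")
    generator.module.load_state_dict(weights['g'])
    discriminator.module.load_state_dict(weights['d'])
    encoder_HG.module.load_state_dict(weights['c'])
    style_encoder.load_state_dict(weights["s"])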
def train(generator, decoder, discriminator, encoder_HG, style_encoder, device, starting_model_number):
    latent_size = 512
    batch_size = 8
    sample_z = torch.randn(8, latent_size, device=device)
    Celeba.batch_size = batch_size
    W300DatasetLoader.batch_size = batch_size
    W300DatasetLoader.test_batch_size = 32

    test_img = next(LazyLoader.mafl().loader_train_inf)["data"][:8].cuda()

    model = CondStyleGanModel(generator, StyleGANLoss(discriminator), (0.001 / 4, 0.0015 / 4))

    style_opt = optim.Adam(style_encoder.parameters(), lr=5e-4, betas=(0.9, 0.99))
    cont_opt = optim.Adam(encoder_HG.parameters(), lr=3e-5, betas=(0.5, 0.97))

    g_transforms: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(albumentations.Compose([
            albumentations.ElasticTransform(p=0.7, alpha=150, alpha_affine=1, sigma=10),
            albumentations.ShiftScaleRotate(p=0.9, rotate_limit=15),
        ])),
        ToTensor(device),
    ])

    g_transforms_without_norm: albumentations.DualTransform = albumentations.Compose([
        ToNumpy(),
        NumpyBatch(albumentations.Compose([
            albumentations.ElasticTransform(p=0.3, alpha=150, alpha_affine=1, sigma=10),
            albumentations.ShiftScaleRotate(p=0.7, rotate_limit=15),
        ])),
        ToTensor(device),
    ])

    # Equivariance regularizer on predicted coordinates.
    R_t = DualTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img: coord_hm_loss(
            encoder_HG(trans_dict['image'])["coords"], trans_dict['mask'])
    )

    # Style-consistency regularizer.
    R_s = UnoTransformRegularizer.__call__(
        g_transforms,
        lambda trans_dict, img, ltnt:
        WR.L1("R_s")(ltnt, style_encoder(trans_dict['image']))
    )

    barycenter: UniformMeasure2D01 = UniformMeasure2DFactory.load(
        f"{Paths.default.models()}/face_barycenter_5").cuda().batch_repeat(batch_size)

    R_b = BarycenterRegularizer.__call__(barycenter, 1.0, 2.0, 4.0)

    tuner = GoldTuner([0.37, 2.78, 0.58, 1.43, 3.23], device=device,
                      rule_eps=0.001, radius=0.3, active=False)

    trainer_gan = gan_trainer(model, generator, decoder, encoder_HG, style_encoder,
                              R_s, style_opt, g_transforms)
    content_trainer = content_trainer_with_gan(cont_opt, tuner, encoder_HG, R_b, R_t,
                                               model, generator, g_transforms, decoder, style_encoder)

    # `encoder_ema` is assumed to exist in the original code (an EMA weight
    # accumulator); a minimal stand-in with the expected interface is the
    # EMASketch class defined below.
    encoder_ema = EMASketch(encoder_HG.module)

    for i in range(100000):
        WR.counter.update(i)
        requires_grad(encoder_HG, False)

        real_img = next(LazyLoader.mafl().loader_train_inf)["data"].to(device)

        encoded = encoder_HG(real_img)
        internal_content = encoded["skeleton"].detach()

        # trainer_gan(i, real_img, internal_content)
        # content_trainer(real_img)
        train_content(cont_opt, R_b, R_t, real_img, model, encoder_HG, decoder, generator, style_encoder)

        # Keep an exponential moving average of the encoder weights and
        # periodically write it back.
        encoder_ema.accumulate(encoder_HG.module, i, 0.97)
        if i % 50 == 0 and i > 0:
            encoder_ema.write_to(encoder_HG.module)

        if i % 100 == 0:
            coefs = json.load(open("../parameters/content_loss.json"))
            print(i, coefs)
            with torch.no_grad():
                encoded_test = encoder_HG(test_img)
                pred_measures_test: UniformMeasure2D01 = UniformMeasure2D01(encoded_test["coords"])
                heatmaper_256 = ToGaussHeatMap(256, 1.0)
                # coordinates live in [0, 1]; scale to the 256 px grid before rendering
                sparse_hm_test_1 = heatmaper_256.forward(pred_measures_test.coord * 255)
                latent_test = style_encoder(test_img)

                sparse_mask = sparse_hm_test_1.sum(dim=1, keepdim=True)
                sparse_mask[sparse_mask < 0.0003] = 0
                iwm = imgs_with_mask(test_img, sparse_mask)
                send_images_to_tensorboard(WR.writer, iwm, "REAL", i)

                fake_img, _ = generator(encoded_test["skeleton"], [sample_z])
                iwm = imgs_with_mask(fake_img, pred_measures_test.toImage(256))
                send_images_to_tensorboard(WR.writer, iwm, "FAKE", i)

                restored = decoder(encoded_test["skeleton"], latent_test)
                iwm = imgs_with_mask(restored, pred_measures_test.toImage(256))
                send_images_to_tensorboard(WR.writer, iwm, "RESTORED", i)

                content_test_256 = encoded_test["skeleton"].repeat(1, 3, 1, 1) * \
                    torch.tensor([1.0, 1.0, 0.0], device=device).view(1, 3, 1, 1)
                content_test_256 = (content_test_256 - content_test_256.min()) / (content_test_256.max() - content_test_256.min())
                send_images_to_tensorboard(WR.writer, content_test_256, "HM", i, normalize=False, range=(0, 1))

        if i % 50 == 0:
            test_loss = liuboffMAFL(encoder_HG)
            print("liuboff", test_loss)
            tuner.update(test_loss)
            WR.writer.add_scalar("liuboff", test_loss, i)
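# A sketch of the EMA accumulator whose interface the training loop above
# relies on (encoder_ema.accumulate(model, step, decay) / write_to(model)).
# The real accumulator lives elsewhere in the repo; this class only
# illustrates the assumed contract and is not the original implementation.
class EMASketch:
    def __init__(self, model: nn.Module):
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    def accumulate(self, model: nn.Module, step: int, decay: float):
        # Blend current weights into the running average; `step` is accepted
        # to match the call site above but unused in this simplified version.
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(decay).add_(v.detach(), alpha=1 - decay)
            else:
                self.shadow[k] = v.detach().clone()

    def write_to(self, model: nn.Module):
        # Overwrite the live weights with the averaged ones.
        model.load_state_dict(self.shadow)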