def space_loss(imgs1, imgs2, image_space=True, lpips_model=None):
    """Multi-term similarity loss between two batches of tensors.

    Combines MSE (on values, means and stds), KL divergence of softmaxed
    activations, cosine distance of the flattened tensors, and — when
    ``image_space`` is True — SSIM and LPIPS image losses.

    Args:
        imgs1: reference tensor (treated as the "true" distribution for KL).
        imgs2: tensor compared against ``imgs1``.
        image_space: when True, also compute the SSIM and LPIPS terms
            (inputs are then assumed to be image batches).
        lpips_model: callable LPIPS network used for the LPIPS term. If
            None, that term is 0 (previously this case crashed with a
            ``TypeError`` when ``image_space`` was True).

    Returns:
        (loss, info) where ``loss`` is the scalar sum of all terms and
        ``info`` is ``[[mse, mse_mean, mse_std], kl, cosine, ssim, lpips]``
        as plain Python numbers.
    """
    loss_mse = torch.nn.MSELoss()
    loss_kl = torch.nn.KLDivLoss()

    # MSE on raw values plus first/second moments.
    loss_imgs_mse_1 = loss_mse(imgs1, imgs2)
    loss_imgs_mse_2 = loss_mse(imgs1.mean(), imgs2.mean())
    loss_imgs_mse_3 = loss_mse(imgs1.std(), imgs2.std())
    loss_imgs_mse = loss_imgs_mse_1 + loss_imgs_mse_2 + loss_imgs_mse_3

    # Explicit softmax dim replicating the deprecated implicit default
    # (dim=0 for rank 0/1/3, dim=1 otherwise) so results are unchanged.
    softmax_dim = 0 if imgs1.dim() in (0, 1, 3) else 1
    imgs1_kl = torch.nn.functional.softmax(imgs1, dim=softmax_dim)
    imgs2_kl = torch.nn.functional.softmax(imgs2, dim=softmax_dim)
    # D_kl(True=imgs1 || Fake=imgs2); clamp NaN -> 0 and Inf -> 1 so a
    # degenerate batch cannot poison the total loss.
    loss_imgs_kl = loss_kl(torch.log(imgs2_kl), imgs1_kl)
    loss_imgs_kl = torch.where(torch.isnan(loss_imgs_kl),
                               torch.full_like(loss_imgs_kl, 0),
                               loss_imgs_kl)
    loss_imgs_kl = torch.where(torch.isinf(loss_imgs_kl),
                               torch.full_like(loss_imgs_kl, 1),
                               loss_imgs_kl)

    # Cosine distance between the flattened tensors, in [0, 2]
    # (0: same direction, 2: opposite direction).
    imgs1_cos = imgs1.view(-1)
    imgs2_cos = imgs2.view(-1)
    loss_imgs_cosine = 1 - imgs1_cos.dot(imgs2_cos) / (
        torch.sqrt(imgs1_cos.dot(imgs1_cos)) *
        torch.sqrt(imgs2_cos.dot(imgs2_cos)))

    if image_space:
        # The SSIM module is only needed (and only valid) for image batches,
        # so it is constructed lazily here.
        ssim_loss = pytorch_ssim.SSIM()
        loss_imgs_ssim = 1 - ssim_loss(imgs1, imgs2)
    else:
        loss_imgs_ssim = torch.tensor(0.0)

    if image_space and lpips_model is not None:
        loss_imgs_lpips = lpips_model(imgs1, imgs2).mean()
    else:
        loss_imgs_lpips = torch.tensor(0.0)

    loss_imgs = (loss_imgs_mse + loss_imgs_kl + loss_imgs_cosine +
                 loss_imgs_ssim + loss_imgs_lpips)
    loss_info = [[
        loss_imgs_mse_1.item(),
        loss_imgs_mse_2.item(),
        loss_imgs_mse_3.item()
    ],
                 loss_imgs_kl.item(),
                 loss_imgs_cosine.item(),
                 loss_imgs_ssim.item(),
                 loss_imgs_lpips.item()]
    return loss_imgs, loss_info
def train(avg_tensor=None, coefs=0, tensor_writer=None):
    """Train the encoder E to invert the StyleGAN generator Gs.

    Each iteration samples latents, generates images with Gs, encodes them
    back with E, regenerates, and optimizes E with separate backward passes
    on Grad-CAM++ mask losses, guided-backprop gradient losses, image-space
    losses and latent-space (w / const) losses. Scalars are logged to
    TensorBoard every step; sample images and checkpoints are written
    periodically.

    Args:
        avg_tensor: average-latent buffer assigned to the mapping net Gm.
        coefs: truncation coefficients forwarded to Gm.
        tensor_writer: TensorBoard SummaryWriter used for scalar logging.
    """
    # 32->512 layer_count=8 / 64->256 layer_count=7
    Gs = Generator(startf=64, maxf=512, layer_count=7, latent_size=512,
                   channels=3)
    Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))
    # num_layers: 14->256 / 16->512 / 18->1024
    Gm = Mapping(num_layers=14, mapping_layers=8, latent_size=512,
                 dlatent_size=512, mapping_fmaps=512)
    Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.buffer1 = avg_tensor
    E = BE.BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    E.load_state_dict(
        torch.load(
            '/_yucheng/myStyle/myStyle-v1/EAE-car-cat/result/EB_cat_cosine_v2/E_model_ep80000.pth'
        ))
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    writer = tensor_writer
    E_optimizer = LREQAdam([
        {
            'params': E.parameters()
        },
    ], lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    loss_mse = torch.nn.MSELoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
    loss_kl = torch.nn.KLDivLoss()
    ssim_loss = pytorch_ssim.SSIM()
    batch_size = 3
    const1 = const_.repeat(batch_size, 1, 1, 1)
    # Use the last Conv2d layer of VGG16 as the Grad-CAM++ target layer.
    vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
    final_layer = None
    for name, m in vgg16.named_modules():
        if isinstance(m, nn.Conv2d):
            final_layer = name
    grad_cam_plus_plus = GradCamPlusPlus(vgg16, final_layer)
    gbp = GuidedBackPropagation(vgg16)
    it_d = 0
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  #[32, 512]
        with torch.no_grad():  # generate the images and latent codes here
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  #[batch_size,18,512]
            imgs1 = Gs.forward(w1, 6)  # 7->512 / 6->256
            const2, w2 = E(imgs1.cuda())
            imgs2 = Gs.forward(w2, 6)
        E_optimizer.zero_grad()

        #Image Space
        # NOTE(review): imgs1/imgs2/w2/const2 are produced inside
        # torch.no_grad(), so losses built directly on them carry no graph
        # back to E; verify the backward()/step() calls below actually
        # update the encoder as intended.
        mask_1 = grad_cam_plus_plus(imgs1, None)  #[c,1,h,w]
        mask_2 = grad_cam_plus_plus(imgs2, None)
        #imgs1.retain_grad()
        #imgs2.retain_grad()
        imgs1_ = imgs1.detach().clone()
        imgs1_.requires_grad = True
        imgs2_ = imgs2.detach().clone()
        imgs2_.requires_grad = True
        grad1 = gbp(imgs1_)  # [n,c,h,w]
        grad2 = gbp(imgs2_)

        #Mask_Cam: losses between the two Grad-CAM++ masks.
        mask_1 = mask_1.cuda().float()
        mask_1.requires_grad = True
        mask_2 = mask_2.cuda().float()
        mask_2.requires_grad = True
        loss_mask_mse_1 = loss_mse(mask_1, mask_2)
        loss_mask_mse_2 = loss_mse(mask_1.mean(), mask_2.mean())
        loss_mask_mse_3 = loss_mse(mask_1.std(), mask_2.std())
        loss_mask_mse = loss_mask_mse_1 + loss_mask_mse_2 + loss_mask_mse_3
        ssim_value = pytorch_ssim.ssim(mask_1, mask_2)  # while ssim_value<0.999:
        loss_mask_ssim = 1 - ssim_loss(mask_1, mask_2)
        loss_mask_lpips = loss_lpips(mask_1, mask_2).mean()
        mask1_kl, mask2_kl = torch.nn.functional.softmax(
            mask_1), torch.nn.functional.softmax(mask_2)
        loss_kl_mask = loss_kl(torch.log(mask2_kl),
                               mask1_kl)  #D_kl(True=y1_imgs||Fake=y2_imgs)
        # FIX: the NaN/Inf clamping previously read the undefined name
        # loss_kl_w here (NameError on the first iteration).
        loss_kl_mask = torch.where(torch.isnan(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 0),
                                   loss_kl_mask)
        loss_kl_mask = torch.where(torch.isinf(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 1),
                                   loss_kl_mask)
        # FIX: previously used the undefined names mask1/mask2.
        mask1_cos = mask_1.view(-1)
        mask2_cos = mask_2.view(-1)
        loss_cosine_mask = 1 - mask1_cos.dot(mask2_cos) / (
            torch.sqrt(mask1_cos.dot(mask1_cos)) *
            torch.sqrt(mask2_cos.dot(mask2_cos))
        )  # cosine distance in [0,2]: 0 = same direction, 2 = opposite
        loss_mask = (loss_mask_mse + loss_mask_ssim + loss_mask_lpips +
                     loss_kl_mask + loss_cosine_mask)
        E_optimizer.zero_grad()
        loss_mask.backward(retain_graph=True)
        E_optimizer.step()

        #Grad: losses between the guided-backprop gradient maps.
        grad1 = grad1.cuda().float()
        grad1.requires_grad = True
        grad2 = grad2.cuda().float()
        grad2.requires_grad = True
        loss_grad_mse = loss_mse(grad1, grad2)
        E_optimizer.zero_grad()
        loss_grad_mse.backward(retain_graph=True)
        E_optimizer.step()
        ssim_value = pytorch_ssim.ssim(grad1, grad2)  # while ssim_value<0.999:
        loss_grad_ssim = 1 - ssim_loss(grad1, grad2)
        E_optimizer.zero_grad()
        loss_grad_ssim.backward(retain_graph=True)
        E_optimizer.step()
        loss_grad_lpips = loss_lpips(grad1, grad2).mean()
        E_optimizer.zero_grad()
        loss_grad_lpips.backward(retain_graph=True)
        E_optimizer.step()

        # NOTE(review): the block below looks like an accidental copy-paste
        # of the mask-loss block — loss_grad_mse_3 uses mask_1/mask_2, the
        # mask losses are recomputed and backpropagated a second time, and
        # loss_grad_mse_1/loss_grad_mse_2 are never used. Kept as in the
        # original (with the same undefined-name fixes); confirm intent.
        grad1 = grad1.cuda().float()
        grad1.requires_grad = True
        grad2 = grad2.cuda().float()
        grad2.requires_grad = True
        loss_grad_mse_1 = loss_mse(grad1, grad2)
        loss_grad_mse_2 = loss_mse(grad1.mean(), grad2.mean())
        loss_grad_mse_3 = loss_mse(mask_1.std(), mask_2.std())
        loss_mask_mse = loss_mask_mse_1 + loss_mask_mse_2 + loss_mask_mse_3
        ssim_value = pytorch_ssim.ssim(mask_1, mask_2)  # while ssim_value<0.999:
        loss_mask_ssim = 1 - ssim_loss(mask_1, mask_2)
        loss_mask_lpips = loss_lpips(mask_1, mask_2).mean()
        mask1_kl, mask2_kl = torch.nn.functional.softmax(
            mask_1), torch.nn.functional.softmax(mask_2)
        loss_kl_mask = loss_kl(torch.log(mask2_kl),
                               mask1_kl)  #D_kl(True=y1_imgs||Fake=y2_imgs)
        # FIX: same undefined-name bug (loss_kl_w) as in the first copy.
        loss_kl_mask = torch.where(torch.isnan(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 0),
                                   loss_kl_mask)
        loss_kl_mask = torch.where(torch.isinf(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 1),
                                   loss_kl_mask)
        mask1_cos = mask_1.view(-1)
        mask2_cos = mask_2.view(-1)
        loss_cosine_mask = 1 - mask1_cos.dot(mask2_cos) / (
            torch.sqrt(mask1_cos.dot(mask1_cos)) *
            torch.sqrt(mask2_cos.dot(mask2_cos)))
        loss_mask = (loss_mask_mse + loss_mask_ssim + loss_mask_lpips +
                     loss_kl_mask + loss_cosine_mask)
        E_optimizer.zero_grad()
        loss_mask.backward(retain_graph=True)
        E_optimizer.step()

        #Image: pixel-level losses between the two generated image batches.
        loss_img_mse = loss_mse(imgs1, imgs2)
        E_optimizer.zero_grad()
        loss_img_mse.backward(retain_graph=True)
        E_optimizer.step()
        ssim_value = pytorch_ssim.ssim(imgs1, imgs2)  # while ssim_value<0.999:
        loss_img_ssim = 1 - ssim_loss(imgs1, imgs2)
        E_optimizer.zero_grad()
        loss_img_ssim.backward(retain_graph=True)
        E_optimizer.step()
        loss_img_lpips = loss_lpips(imgs1, imgs2).mean()
        E_optimizer.zero_grad()
        loss_img_lpips.backward(retain_graph=True)
        E_optimizer.step()

        #Latent Space
        # W: losses between the original and re-encoded w codes.
        loss_w = loss_mse(w1, w2)
        loss_w_m = loss_mse(w1.mean(), w2.mean())  # early on fluctuates between ~10 and ~0.0001
        loss_w_s = loss_mse(w1.std(), w2.std())  # later fluctuates between large and small
        w1_kl, w2_kl = torch.nn.functional.softmax(
            w1), torch.nn.functional.softmax(w2)
        loss_kl_w = loss_kl(torch.log(w2_kl),
                            w1_kl)  #D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),
                                torch.full_like(loss_kl_w, 0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),
                                torch.full_like(loss_kl_w, 1), loss_kl_w)
        w1_cos = w1.view(-1)
        w2_cos = w2.view(-1)
        loss_cosine_w = 1 - w1_cos.dot(w2_cos) / (
            torch.sqrt(w1_cos.dot(w1_cos)) * torch.sqrt(w2_cos.dot(w2_cos))
        )  # cosine distance in [0,2]: 0 = same direction, 2 = opposite
        # C: without this const penalty the gradient cannot descend quickly
        # at the start and training may fail to converge; scaling it (x0.1)
        # greatly improved results.
        loss_c = loss_mse(const1, const2)
        loss_c_m = loss_mse(const1.mean(), const2.mean())
        loss_c_s = loss_mse(const1.std(), const2.std())
        y1, y2 = torch.nn.functional.softmax(
            const1), torch.nn.functional.softmax(const2)
        loss_kl_c = loss_kl(torch.log(y2), y1)
        loss_kl_c = torch.where(torch.isnan(loss_kl_c),
                                torch.full_like(loss_kl_c, 0), loss_kl_c)
        loss_kl_c = torch.where(torch.isinf(loss_kl_c),
                                torch.full_like(loss_kl_c, 1), loss_kl_c)
        c_cos1 = const1.view(-1)
        c_cos2 = const2.view(-1)
        loss_cosine_c = 1 - c_cos1.dot(c_cos2) / (
            torch.sqrt(c_cos1.dot(c_cos1)) * torch.sqrt(c_cos2.dot(c_cos2)))
        loss_ls_all = loss_w+loss_w_m+loss_w_s+loss_kl_w+loss_cosine_w+\
            loss_c+loss_c_m+loss_c_s+loss_kl_c+loss_cosine_c
        loss_ls_all.backward(retain_graph=True)
        E_optimizer.step()

        # Console logging (every iteration).
        print('i_' + str(epoch))
        print('---------ImageSpace--------')
        print('loss_mask_mse:' + str(loss_mask_mse.item()) +
              '--loss_mask_ssim:' + str(loss_mask_ssim.item()) +
              '--loss_mask_lpips:' + str(loss_mask_lpips.item()))
        print('loss_grad_mse:' + str(loss_grad_mse.item()) +
              '--loss_grad_ssim:' + str(loss_grad_ssim.item()) +
              '--loss_grad_lpips:' + str(loss_grad_lpips.item()))
        print('loss_img_mse:' + str(loss_img_mse.item()) +
              '--loss_img_ssim:' + str(loss_img_ssim.item()) +
              '--loss_img_lpips:' + str(loss_img_lpips.item()))
        print('---------LatentSpace--------')
        print('loss_w:' + str(loss_w.item()) + '--loss_w_m:' +
              str(loss_w_m.item()) + '--loss_w_s:' + str(loss_w_s.item()))
        print('loss_kl_w:' + str(loss_kl_w.item()) + '--loss_cosine_w:' +
              str(loss_cosine_w.item()))
        print('loss_c:' + str(loss_c.item()) + '--loss_c_m:' +
              str(loss_c_m.item()) + '--loss_c_s:' + str(loss_c_s.item()))
        print('loss_kl_c:' + str(loss_kl_c.item()) + '--loss_cosine_c:' +
              str(loss_cosine_c.item()))

        # TensorBoard scalar logging.
        it_d += 1
        writer.add_scalar('loss_mask_mse', loss_mask_mse, global_step=it_d)
        writer.add_scalar('loss_mask_ssim', loss_mask_ssim, global_step=it_d)
        writer.add_scalar('loss_mask_lpips', loss_mask_lpips,
                          global_step=it_d)
        writer.add_scalar('loss_grad_mse', loss_grad_mse, global_step=it_d)
        writer.add_scalar('loss_grad_ssim', loss_grad_ssim, global_step=it_d)
        writer.add_scalar('loss_grad_lpips', loss_grad_lpips,
                          global_step=it_d)
        writer.add_scalar('loss_img_mse', loss_img_mse, global_step=it_d)
        writer.add_scalar('loss_img_ssim', loss_img_ssim, global_step=it_d)
        writer.add_scalar('loss_img_lpips', loss_img_lpips, global_step=it_d)
        writer.add_scalar('loss_w', loss_w, global_step=it_d)
        writer.add_scalar('loss_w_m', loss_w_m, global_step=it_d)
        writer.add_scalar('loss_w_s', loss_w_s, global_step=it_d)
        writer.add_scalar('loss_kl_w', loss_kl_w, global_step=it_d)
        writer.add_scalar('loss_cosine_w', loss_cosine_w, global_step=it_d)
        writer.add_scalar('loss_c', loss_c, global_step=it_d)
        writer.add_scalar('loss_c_m', loss_c_m, global_step=it_d)
        writer.add_scalar('loss_c_s', loss_c_s, global_step=it_d)
        writer.add_scalar('loss_kl_c', loss_kl_c, global_step=it_d)
        writer.add_scalar('loss_cosine_c', loss_cosine_c, global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_mse': loss_mask_mse,
            'loss_grad_mse': loss_grad_mse,
            'loss_img_mse': loss_img_mse
        }, global_step=it_d)
        # FIX: 'loss_mask_ssim' previously logged loss_mask_mse.
        writer.add_scalars('Image_Space', {
            'loss_mask_ssim': loss_mask_ssim,
            'loss_grad_ssim': loss_grad_ssim,
            'loss_img_ssim': loss_img_ssim
        }, global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_lpips': loss_mask_lpips,
            'loss_grad_lpips': loss_grad_lpips,
            'loss_img_lpips': loss_img_lpips
        }, global_step=it_d)
        writer.add_scalars('Latent Space W', {
            'loss_w': loss_w,
            'loss_w_m': loss_w_m,
            'loss_w_s': loss_w_s,
            'loss_kl_w': loss_kl_w,
            'loss_cosine_w': loss_cosine_w
        }, global_step=it_d)
        writer.add_scalars('Latent Space C', {
            'loss_c': loss_c,
            'loss_c_m': loss_c_m,
            'loss_c_s': loss_c_s,
            'loss_kl_c': loss_kl_c,
            'loss_cosine_c': loss_cosine_c
        }, global_step=it_d)

        # Periodic sample dumps: reconstructions, CAM visualizations,
        # guided-backprop maps, and a plain-text loss log.
        if epoch % 100 == 0:
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row])) * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.png' %
                                         (epoch),
                                         nrow=n_row)  # nrow=3
            heatmap1, cam1 = mask2cam(mask_1, imgs1)
            heatmap2, cam2 = mask2cam(mask_2, imgs2)
            # FIX: previously concatenated heatmap1 with itself (cf. cam).
            heatmap = torch.cat((heatmap1, heatmap2))
            cam = torch.cat((cam1, cam2))
            grads = torch.cat((grad1, grad2))
            grads = grads.data.cpu().numpy()  # [n,c,h,w]
            # NOTE(review): np.max(np.min(grads), 0) passes the global min
            # (a scalar) with axis=0 — presumably max(min, 0) was intended
            # for the usual min-max visualization rescale; confirm.
            grads -= np.max(np.min(grads), 0)
            grads /= np.max(grads)
            torchvision.utils.save_image(
                torch.tensor(heatmap),
                resultPath_grad_cam + '/heatmap_%d.png' % (epoch))
            torchvision.utils.save_image(
                torch.tensor(cam),
                resultPath_grad_cam + '/cam_%d.png' % (epoch))
            torchvision.utils.save_image(
                torch.tensor(grads),
                resultPath_grad_cam + '/gb_%d.png' % (epoch))
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch), file=f)
                print('---------ImageSpace--------', file=f)
                print('loss_mask_mse:' + str(loss_mask_mse.item()) +
                      '--loss_mask_ssim:' + str(loss_mask_ssim.item()) +
                      '--loss_mask_lpips:' + str(loss_mask_lpips.item()),
                      file=f)
                print('loss_grad_mse:' + str(loss_grad_mse.item()) +
                      '--loss_grad_ssim:' + str(loss_grad_ssim.item()) +
                      '--loss_grad_lpips:' + str(loss_grad_lpips.item()),
                      file=f)
                print('loss_img_mse:' + str(loss_img_mse.item()) +
                      '--loss_img_ssim:' + str(loss_img_ssim.item()) +
                      '--loss_img_lpips:' + str(loss_img_lpips.item()),
                      file=f)
                print('---------LatentSpace--------', file=f)
                print('loss_w:' + str(loss_w.item()) + '--loss_w_m:' +
                      str(loss_w_m.item()) + '--loss_w_s:' +
                      str(loss_w_s.item()),
                      file=f)
                print('loss_kl_w:' + str(loss_kl_w.item()) +
                      '--loss_cosine_w:' + str(loss_cosine_w.item()),
                      file=f)
                print('loss_c:' + str(loss_c.item()) + '--loss_c_m:' +
                      str(loss_c_m.item()) + '--loss_c_s:' +
                      str(loss_c_s.item()),
                      file=f)
                print('loss_kl_c:' + str(loss_kl_c.item()) +
                      '--loss_cosine_c:' + str(loss_cosine_c.item()),
                      file=f)
        if epoch % 5000 == 0:
            torch.save(E.state_dict(),
                       resultPath1_2 + '/E_model_ep%d.pth' % epoch)