def train(avg_tensor=None, coefs=0, tensor_writer=None):
    """Train the encoder E to invert a frozen 256x256 StyleGAN (bedrooms).

    Each step: sample z -> w1 -> imgs1 with the frozen Gm/Gs, encode back to
    (const2, w2) with E, regenerate imgs2, then take a separate optimizer step
    per loss term (const, w, and three image regions of increasing size).

    Args:
        avg_tensor: dlatent average tensor, stored as ``Gm.buffer1``
            (presumably the truncation buffer — TODO confirm in Mapping).
        coefs: truncation coefficients forwarded to ``Gm`` as ``coefs_m``.
        tensor_writer: TensorBoard ``SummaryWriter`` used for scalar logging.
    """
    # 32->512 needs layer_count=8 / 64->256 needs layer_count=7
    Gs = Generator(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    #Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))  # 256*256 cat
    Gs.load_state_dict(torch.load('./pre-model/bedroom/bedrooms256_Gs_dict.pth'))  # 256*256 bedroom

    # num_layers: 14->256 / 16->512 / 18->1024
    Gm = Mapping(num_layers=14, mapping_layers=8, latent_size=512,
                 dlatent_size=512, mapping_fmaps=512)
    #Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.load_state_dict(torch.load('./pre-model/bedroom/bedrooms256_Gm_dict.pth'))  # 256*256 bedroom
    Gm.buffer1 = avg_tensor

    E = BE.BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    #E.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/EAE-car-cat/result/EB_cat_cosine_v2/E_model_ep80000.pth'))

    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    writer = tensor_writer

    # LREQAdam avoids: "RuntimeError: one of the variables needed for gradient
    # computation has been modified by an inplace operation"
    E_optimizer = LREQAdam([{'params': E.parameters()}, ],
                           lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')

    batch_size = 3
    const1 = const_.repeat(batch_size, 1, 1, 1)

    it_d = 0
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  # [batch_size, 512]
        with torch.no_grad():  # generate target images/latents without tracking grads
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  # [batch_size,18,512]
            imgs1 = Gs.forward(w1, 6)  # 7->512 / 6->256
        const2, w2 = E(imgs1.cuda())
        imgs2 = Gs.forward(w2, 6)
        E_optimizer.zero_grad()

        # ----- Latent space -----
        ## c (const feature map)
        loss_c, loss_c_info = space_loss(const1, const2, image_space=False)
        E_optimizer.zero_grad()
        loss_c.backward(retain_graph=True)
        E_optimizer.step()

        ## w (dlatent)
        loss_w, loss_w_info = space_loss(w1, w2, image_space=False)
        E_optimizer.zero_grad()
        loss_w.backward(retain_graph=True)
        E_optimizer.step()

        # ----- Image space -----
        ## loss1: smallest region — center crop dropping 1/20 of each border
        imgs_small_1 = imgs1[:, :,
                             imgs1.shape[2] // 20:-imgs1.shape[2] // 20,
                             imgs1.shape[3] // 20:-imgs1.shape[3] // 20].clone()  # w,h
        imgs_small_2 = imgs2[:, :,
                             imgs2.shape[2] // 20:-imgs2.shape[2] // 20,
                             imgs2.shape[3] // 20:-imgs2.shape[3] // 20].clone()
        loss_small, loss_small_info = space_loss(imgs_small_1, imgs_small_2,
                                                 lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_small.backward(retain_graph=True)
        E_optimizer.step()

        ## loss2: medium region — center crop dropping 1/10 of each border
        imgs_medium_1 = imgs1[:, :,
                              imgs1.shape[2] // 10:-imgs1.shape[2] // 10,
                              imgs1.shape[3] // 10:-imgs1.shape[3] // 10].clone()
        imgs_medium_2 = imgs2[:, :,
                              imgs2.shape[2] // 10:-imgs2.shape[2] // 10,
                              imgs2.shape[3] // 10:-imgs2.shape[3] // 10].clone()
        loss_medium, loss_medium_info = space_loss(imgs_medium_1, imgs_medium_2,
                                                   lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_medium.backward(retain_graph=True)
        E_optimizer.step()

        ## loss3: full image
        loss_imgs, loss_imgs_info = space_loss(imgs1, imgs2, lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_imgs.backward(retain_graph=True)
        E_optimizer.step()

        print('i_' + str(epoch))
        # Header fixed to match the info layout actually used below:
        # info = [[mse, mse_mean, mse_std], kl, cosine, ssim, lpips]
        print(
            '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]'
        )
        print('---------ImageSpace--------')
        print('loss_small_info: %s' % loss_small_info)
        print('loss_medium_info: %s' % loss_medium_info)
        print('loss_imgs_info: %s' % loss_imgs_info)
        print('---------LatentSpace--------')
        print('loss_w_info: %s' % loss_w_info)
        print('loss_c_info: %s' % loss_c_info)

        it_d += 1
        # NOTE: the original tags spelled "samll"; fixed to "small" so all
        # small-region scalars group together in TensorBoard.
        writer.add_scalar('loss_small_mse', loss_small_info[0][0], global_step=it_d)
        writer.add_scalar('loss_small_mse_mean', loss_small_info[0][1], global_step=it_d)
        writer.add_scalar('loss_small_mse_std', loss_small_info[0][2], global_step=it_d)
        writer.add_scalar('loss_small_kl', loss_small_info[1], global_step=it_d)
        writer.add_scalar('loss_small_cosine', loss_small_info[2], global_step=it_d)
        writer.add_scalar('loss_small_ssim', loss_small_info[3], global_step=it_d)
        writer.add_scalar('loss_small_lpips', loss_small_info[4], global_step=it_d)
        writer.add_scalar('loss_medium_mse', loss_medium_info[0][0], global_step=it_d)
        writer.add_scalar('loss_medium_mse_mean', loss_medium_info[0][1], global_step=it_d)
        writer.add_scalar('loss_medium_mse_std', loss_medium_info[0][2], global_step=it_d)
        writer.add_scalar('loss_medium_kl', loss_medium_info[1], global_step=it_d)
        writer.add_scalar('loss_medium_cosine', loss_medium_info[2], global_step=it_d)
        writer.add_scalar('loss_medium_ssim', loss_medium_info[3], global_step=it_d)
        writer.add_scalar('loss_medium_lpips', loss_medium_info[4], global_step=it_d)
        writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d)
        writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d)
        writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d)
        writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
        writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d)
        writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d)
        writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d)
        writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
        writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d)
        writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d)
        writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
        writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
        writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
        writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)
        writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d)
        writer.add_scalar('loss_c_mse_mean', loss_c_info[0][1], global_step=it_d)
        writer.add_scalar('loss_c_mse_std', loss_c_info[0][2], global_step=it_d)
        writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d)
        writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d)
        writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d)
        writer.add_scalar('loss_c_lpips', loss_c_info[4], global_step=it_d)

        writer.add_scalars('Image_Space_MSE', {
            'loss_small_mse': loss_small_info[0][0],
            'loss_medium_mse': loss_medium_info[0][0],
            'loss_img_mse': loss_imgs_info[0][0]
        }, global_step=it_d)
        writer.add_scalars('Image_Space_KL', {
            'loss_small_kl': loss_small_info[1],
            'loss_medium_kl': loss_medium_info[1],
            'loss_imgs_kl': loss_imgs_info[1]
        }, global_step=it_d)
        writer.add_scalars('Image_Space_Cosine', {
            'loss_small_cosine': loss_small_info[2],
            'loss_medium_cosine': loss_medium_info[2],
            'loss_imgs_cosine': loss_imgs_info[2]
        }, global_step=it_d)
        writer.add_scalars('Image_Space_SSIM', {
            'loss_small_ssim': loss_small_info[3],
            'loss_medium_ssim': loss_medium_info[3],
            'loss_img_ssim': loss_imgs_info[3]
        }, global_step=it_d)
        writer.add_scalars('Image_Space_Lpips', {
            'loss_small_lpips': loss_small_info[4],
            'loss_medium_lpips': loss_medium_info[4],
            'loss_img_lpips': loss_imgs_info[4]
        }, global_step=it_d)
        writer.add_scalars('Latent Space W', {
            'loss_w_mse': loss_w_info[0][0],
            'loss_w_mse_mean': loss_w_info[0][1],
            'loss_w_mse_std': loss_w_info[0][2],
            'loss_w_kl': loss_w_info[1],
            'loss_w_cosine': loss_w_info[2]
        }, global_step=it_d)
        writer.add_scalars('Latent Space C', {
            'loss_c_mse': loss_c_info[0][0],
            'loss_c_mse_mean': loss_c_info[0][1],
            'loss_c_mse_std': loss_c_info[0][2],
            'loss_c_kl': loss_c_info[1],
            'loss_c_cosine': loss_c_info[2]
        }, global_step=it_d)

        if epoch % 100 == 0:
            # Images are in [-1, 1]; rescale to [0, 1] for saving.
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row])) * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.jpg' % (epoch),
                                         nrow=n_row)  # nrow=3
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch), file=f)
                print(
                    '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',
                    file=f)
                print('---------ImageSpace--------', file=f)
                print('loss_small_info: %s' % loss_small_info, file=f)
                print('loss_medium_info: %s' % loss_medium_info, file=f)
                print('loss_imgs_info: %s' % loss_imgs_info, file=f)
                print('---------LatentSpace--------', file=f)
                print('loss_w_info: %s' % loss_w_info, file=f)
                print('loss_c_info: %s' % loss_c_info, file=f)
        if epoch % 5000 == 0:
            torch.save(E.state_dict(),
                       resultPath1_2 + '/E_model_ep%d.pth' % epoch)
def train(avg_tensor=None, coefs=0, tensor_writer=None):
    """Fine-tune encoder E on a frozen 512x512 StyleGAN (cars) with
    attention-guided losses (Grad-CAM++ masks, guided backprop grads, CAMs).

    Args:
        avg_tensor: dlatent average tensor, stored as ``Gm.buffer1``
            (presumably the truncation buffer — TODO confirm in Mapping).
        coefs: truncation coefficients forwarded to ``Gm`` as ``coefs_m``.
        tensor_writer: TensorBoard ``SummaryWriter`` used for scalar logging.
    """
    # cars: startf 32->512 layer_count=8 / cat: startf 64->256 layer_count=7
    Gs = Generator(startf=32, maxf=512, layer_count=8, latent_size=512, channels=3)
    #Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))
    Gs.load_state_dict(torch.load('./pre-model/cars/cars512_Gs_dict.pth'))

    # num_layers: 14->256 / 16->512 / 18->1024
    Gm = Mapping(num_layers=16, mapping_layers=8, latent_size=512,
                 dlatent_size=512, mapping_fmaps=512)
    #Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.load_state_dict(torch.load('./pre-model/cars/cars512_Gm_dict.pth'))
    Gm.buffer1 = avg_tensor

    E = BE.BE(startf=32, maxf=512, layer_count=8, latent_size=512, channels=3)
    #E.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/EAE-car-cat/pre-model/E_cat_v2_1_ep100000.pth'))
    E.load_state_dict(
        torch.load(
            '/_wmwang/mystyle/myStyle1/EAE/pre-model/E_cars512_ep90000_v1.pth')
    )  # cars

    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    writer = tensor_writer

    E_optimizer = LREQAdam([{'params': E.parameters()}, ],
                           lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')

    batch_size = 4
    const1 = const_.repeat(batch_size, 1, 1, 1)

    # Attention models: hook Grad-CAM++ on the *last* Conv2d of VGG16.
    vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
    final_layer = None
    for name, m in vgg16.named_modules():
        if isinstance(m, nn.Conv2d):
            final_layer = name  # keeps overwriting; ends at the deepest conv
    grad_cam_plus_plus = GradCamPlusPlus(vgg16, final_layer)
    gbp = GuidedBackPropagation(vgg16)

    it_d = 0
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  # [batch_size, 512]
        with torch.no_grad():  # generate target images/latents without tracking grads
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  # [batch_size,16,512]
            imgs1 = Gs.forward(w1, 7)  # 7->512 / 6->256
        const2, w2 = E(imgs1.cuda())
        imgs2 = Gs.forward(w2, 7)
        E_optimizer.zero_grad()

        # ----- Latent Space -----
        ##--C
        loss_c, loss_c_info = space_loss(const1, const2, image_space=False)
        E_optimizer.zero_grad()
        loss_c.backward(retain_graph=True)
        E_optimizer.step()

        ##--W
        loss_w, loss_w_info = space_loss(w1, w2, image_space=False)
        E_optimizer.zero_grad()
        loss_w.backward(retain_graph=True)
        E_optimizer.step()

        # ----- Image Space -----
        mask_1 = grad_cam_plus_plus(imgs1, None)  # [c,1,h,w]
        mask_2 = grad_cam_plus_plus(imgs2, None)
        #imgs1.retain_grad()
        #imgs2.retain_grad()
        # Guided backprop needs leaf tensors with grads; detach copies.
        imgs1_ = imgs1.detach().clone()
        imgs1_.requires_grad = True
        imgs2_ = imgs2.detach().clone()
        imgs2_.requires_grad = True
        grad_1 = gbp(imgs1_)  # [n,c,h,w]
        grad_2 = gbp(imgs2_)
        heatmap_1, cam_1 = mask2cam(mask_1, imgs1)
        heatmap_2, cam_2 = mask2cam(mask_2, imgs2)

        ##--Mask (Grad-CAM++ attention maps)
        mask_1 = mask_1.cuda().float()
        mask_1.requires_grad = True
        mask_2 = mask_2.cuda().float()
        mask_2.requires_grad = True
        loss_mask, loss_mask_info = space_loss(mask_1, mask_2, lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_mask.backward(retain_graph=True)
        E_optimizer.step()

        ##--Grad (guided backprop saliency)
        grad_1 = grad_1.cuda().float()
        grad_1.requires_grad = True
        grad_2 = grad_2.cuda().float()
        grad_2.requires_grad = True
        loss_grad, loss_grad_info = space_loss(grad_1, grad_2, lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_grad.backward(retain_graph=True)
        E_optimizer.step()

        ##--Image
        loss_imgs, loss_imgs_info = space_loss(imgs1, imgs2, lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_imgs.backward(retain_graph=True)
        E_optimizer.step()

        ##--Grad_CAM from mask
        cam_1 = cam_1.cuda().float()
        cam_1.requires_grad = True
        cam_2 = cam_2.cuda().float()
        cam_2.requires_grad = True
        loss_Gcam, loss_Gcam_info = space_loss(cam_1, cam_2, lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_Gcam.backward(retain_graph=True)
        E_optimizer.step()

        print('i_' + str(epoch))
        # Header fixed to match the info layout actually used below:
        # info = [[mse, mse_mean, mse_std], kl, cosine, ssim, lpips]
        print(
            '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]'
        )
        print('---------ImageSpace--------')
        print('loss_mask_info: %s' % loss_mask_info)
        print('loss_grad_info: %s' % loss_grad_info)
        print('loss_imgs_info: %s' % loss_imgs_info)
        print('loss_Gcam_info: %s' % loss_Gcam_info)
        print('---------LatentSpace--------')
        print('loss_w_info: %s' % loss_w_info)
        print('loss_c_info: %s' % loss_c_info)

        it_d += 1
        writer.add_scalar('loss_mask_mse', loss_mask_info[0][0], global_step=it_d)
        writer.add_scalar('loss_mask_mse_mean', loss_mask_info[0][1], global_step=it_d)
        writer.add_scalar('loss_mask_mse_std', loss_mask_info[0][2], global_step=it_d)
        writer.add_scalar('loss_mask_kl', loss_mask_info[1], global_step=it_d)
        writer.add_scalar('loss_mask_cosine', loss_mask_info[2], global_step=it_d)
        writer.add_scalar('loss_mask_ssim', loss_mask_info[3], global_step=it_d)
        writer.add_scalar('loss_mask_lpips', loss_mask_info[4], global_step=it_d)
        writer.add_scalar('loss_grad_mse', loss_grad_info[0][0], global_step=it_d)
        writer.add_scalar('loss_grad_mse_mean', loss_grad_info[0][1], global_step=it_d)
        writer.add_scalar('loss_grad_mse_std', loss_grad_info[0][2], global_step=it_d)
        writer.add_scalar('loss_grad_kl', loss_grad_info[1], global_step=it_d)
        writer.add_scalar('loss_grad_cosine', loss_grad_info[2], global_step=it_d)
        writer.add_scalar('loss_grad_ssim', loss_grad_info[3], global_step=it_d)
        writer.add_scalar('loss_grad_lpips', loss_grad_info[4], global_step=it_d)
        writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d)
        writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d)
        writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d)
        writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
        writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d)
        writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d)
        writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d)
        writer.add_scalar('loss_Gcam', loss_Gcam_info[0][0], global_step=it_d)
        writer.add_scalar('loss_Gcam_mean', loss_Gcam_info[0][1], global_step=it_d)
        writer.add_scalar('loss_Gcam_std', loss_Gcam_info[0][2], global_step=it_d)
        writer.add_scalar('loss_Gcam_kl', loss_Gcam_info[1], global_step=it_d)
        writer.add_scalar('loss_Gcam_cosine', loss_Gcam_info[2], global_step=it_d)
        writer.add_scalar('loss_Gcam_ssim', loss_Gcam_info[3], global_step=it_d)
        writer.add_scalar('loss_Gcam_lpips', loss_Gcam_info[4], global_step=it_d)
        writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
        writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d)
        writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d)
        writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
        writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
        writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
        writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)
        writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d)
        writer.add_scalar('loss_c_mse_mean', loss_c_info[0][1], global_step=it_d)
        writer.add_scalar('loss_c_mse_std', loss_c_info[0][2], global_step=it_d)
        writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d)
        writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d)
        writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d)
        writer.add_scalar('loss_c_lpips', loss_c_info[4], global_step=it_d)

        writer.add_scalars('Image_Space_MSE', {
            'loss_mask_mse': loss_mask_info[0][0],
            'loss_grad_mse': loss_grad_info[0][0],
            'loss_img_mse': loss_imgs_info[0][0]
        }, global_step=it_d)
        # FIX: keys previously said "cosine" for KL values.
        writer.add_scalars('Image_Space_KL', {
            'loss_mask_kl': loss_mask_info[1],
            'loss_grad_kl': loss_grad_info[1],
            'loss_imgs_kl': loss_imgs_info[1]
        }, global_step=it_d)
        writer.add_scalars('Image_Space_Cosine', {
            'loss_mask_cosine': loss_mask_info[2],
            'loss_grad_cosine': loss_grad_info[2],
            'loss_imgs_cosine': loss_imgs_info[2]
        }, global_step=it_d)
        writer.add_scalars('Image_Space_SSIM', {
            'loss_mask_ssim': loss_mask_info[3],
            'loss_grad_ssim': loss_grad_info[3],
            'loss_img_ssim': loss_imgs_info[3]
        }, global_step=it_d)
        # FIX: removed a duplicate 'Image_Space_Cosine' add_scalars that
        # re-logged index [4] (the lpips values) under cosine labels.
        writer.add_scalars('Image_Space_Lpips', {
            'loss_mask_lpips': loss_mask_info[4],
            'loss_grad_lpips': loss_grad_info[4],
            'loss_img_lpips': loss_imgs_info[4]
        }, global_step=it_d)
        writer.add_scalars('Latent Space W', {
            'loss_w_mse': loss_w_info[0][0],
            'loss_w_mse_mean': loss_w_info[0][1],
            'loss_w_mse_std': loss_w_info[0][2],
            'loss_w_kl': loss_w_info[1],
            'loss_w_cosine': loss_w_info[2]
        }, global_step=it_d)
        # FIX: copy-paste bug — kl/cosine previously read loss_w_info.
        writer.add_scalars('Latent Space C', {
            'loss_c_mse': loss_c_info[0][0],
            'loss_c_mse_mean': loss_c_info[0][1],
            'loss_c_mse_std': loss_c_info[0][2],
            'loss_c_kl': loss_c_info[1],
            'loss_c_cosine': loss_c_info[2]
        }, global_step=it_d)

        if epoch % 100 == 0:
            n_row = batch_size
            # Images are in [-1, 1]; rescale to [0, 1] for saving.
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row])) * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.png' % (epoch),
                                         nrow=n_row)  # nrow=3

            heatmap = torch.cat((heatmap_1, heatmap_2))
            cam = torch.cat((cam_1, cam_2))
            grads = torch.cat((grad_1, grad_2))
            grads = grads.data.cpu().numpy()  # [n,c,h,w]
            # FIX: was np.max(np.min(grads), 0) — that passes axis=0 for a
            # scalar (AxisError). Intent is the elementwise max: shift by the
            # global min only when it is positive, then scale to [0, 1].
            grads -= np.maximum(np.min(grads), 0)
            grads /= np.max(grads)
            torchvision.utils.save_image(torch.tensor(heatmap),
                                         resultPath_grad_cam + '/heatmap_%d.png' % (epoch),
                                         nrow=n_row)
            torchvision.utils.save_image(torch.tensor(cam),
                                         resultPath_grad_cam + '/cam_%d.png' % (epoch),
                                         nrow=n_row)
            torchvision.utils.save_image(torch.tensor(grads),
                                         resultPath_grad_cam + '/gb_%d.png' % (epoch),
                                         nrow=n_row)

            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch), file=f)
                print(
                    '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',
                    file=f)
                print('---------ImageSpace--------', file=f)
                print('loss_mask_info: %s' % loss_mask_info, file=f)
                print('loss_grad_info: %s' % loss_grad_info, file=f)
                print('loss_imgs_info: %s' % loss_imgs_info, file=f)
                print('loss_Gcam_info: %s' % loss_Gcam_info, file=f)
                print('---------LatentSpace--------', file=f)
                print('loss_w_info: %s' % loss_w_info, file=f)
                print('loss_c_info: %s' % loss_c_info, file=f)
        if epoch % 5000 == 0:
            torch.save(E.state_dict(),
                       resultPath1_2 + '/E_model_ep%d.pth' % epoch)
def train(avg_tensor=None, coefs=0):
    """Train the mapping pair Gm1 (z2 -> w3) and Gm2 (w2 -> z2) to invert the
    frozen mapping/encoder cycle of a 1024px StyleGAN.

    Per step: z -> w1 -> imgs1 (frozen Gm/Gs), encode to w2 (frozen E),
    then learn Gm2: w2 -> z2 and Gm1: z2 -> w3, with losses in w-space,
    z-space, and image space (imgs2 vs imgs3). Only Gm1/Gm2 are optimized.

    Args:
        avg_tensor: dlatent average tensor, stored as ``Gm.buffer1``
            (presumably the truncation buffer — TODO confirm in Mapping).
        coefs: truncation coefficients forwarded to ``Gm`` as ``coefs_m``.
    """
    # startf=16 / layer_count=9 -> 1024x1024 generator; num_layers=18 dlatents.
    Gs = Generator(startf=16, maxf=512, layer_count=9, latent_size=512, channels=3)
    Gs.load_state_dict(torch.load('./pre-model/Gs_dict.pth'))

    Gm = Mapping(num_layers=18, mapping_layers=8, latent_size=512,
                 dlatent_size=512, mapping_fmaps=512)
    Gm.load_state_dict(torch.load('./pre-model/Gm_dict.pth'))
    Gm.buffer1 = avg_tensor

    Gm1 = Mapping3()
    Gm1.load_state_dict(
        torch.load(
            '/_yucheng/myStyle/myStyle-v1/result/Gm_1&2_V10_6/models/Gm1_model_ep85000.pth'
        ))
    Gm2 = Mapping4()
    #Gm1.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/result/Gm_1&2_V10_3/models/Gm1_model_ep10000.pth'))
    #Gm2 = Mapping2(num_layers=18, mapping_layers=8, latent_size=512, inverse=True)
    #Gm1.load_state_dict(torch.load('./pre-model/Gm1.pth'))
    #Gm2.load_state_dict(torch.load('./pre-model/Gm2.pth'))

    E = BE.BE()
    # strict=False: checkpoint may miss/extra keys relative to this BE build.
    E.load_state_dict(torch.load(
        '/_yucheng/myStyle/myStyle-v1/result/EB_V10_blob_mse/models/E_model_ep15000.pth'
    ), strict=False)

    Gs.cuda()
    E.cuda()
    Gm.cuda()
    Gm1.cuda()
    Gm2.cuda()

    # Only Gm1 and Gm2 are trained; Gs/Gm/E stay frozen (no optimizer).
    #Gm_optimizer = LREQAdam([{'params': Gm1.parameters()},], lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    Gm_optimizer = LREQAdam([
        {'params': Gm1.parameters()},
        {'params': Gm2.parameters()},
    ], lr=0.0015, betas=(0.0, 0.99), weight_decay=0)

    loss_mse = torch.nn.MSELoss()
    loss_kl = torch.nn.KLDivLoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
    batch_size = 3

    for epoch in range(100000):
        set_seed(epoch % 20000)
        z = torch.randn(batch_size, 512).to('cuda')  # [batch_size, 512]
        w1 = Gm(z, coefs_m=coefs)  # [batch_size,18,512]
        imgs1 = Gs.forward(w1, 8)  # level 8 -> 1024px
        const2, w2 = E(imgs1.cuda())
        z2 = Gm2(w2)   # inverse mapping: w -> z
        w3 = Gm1(z2)   # forward mapping: z -> w
        imgs2 = Gs.forward(w2, 8)
        imgs3 = Gs.forward(w3, 8)

        #loss1: w-space cycle consistency (w2 vs w3)
        Gm_optimizer.zero_grad()
        #Gm1_optimizer.zero_grad()
        loss_m1_mse = loss_mse(w2, w3)
        loss_m1_mse_mean = loss_mse(w2.mean(), w3.mean())
        loss_m1_mse_std = loss_mse(w2.std(), w3.std())
        # NOTE(review): softmax without an explicit dim relies on the
        # deprecated implicit-dim behavior — confirm the intended axis.
        y1_w, y2_w = torch.nn.functional.softmax(
            w2), torch.nn.functional.softmax(w3)
        loss_kl_w = loss_kl(torch.log(y2_w), y1_w)  #D_kl(True=y1_w||Fake=y2_w)
        # Sanitize KL: NaN -> 0, Inf -> 1, so one bad step can't poison training.
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),
                                torch.full_like(loss_kl_w, 0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),
                                torch.full_like(loss_kl_w, 1), loss_kl_w)
        loss_1 = loss_m1_mse + loss_m1_mse_mean + loss_m1_mse_std + loss_kl_w

        #loss2: z-space reconstruction (z vs z2)
        loss_m2_mse = loss_mse(z, z2)
        loss_m2_mse_mean = loss_mse(z.mean(), z2.mean())
        loss_m2_mse_std = loss_mse(z.std(), z2.std())
        y1_z, y2_z = torch.nn.functional.softmax(
            z), torch.nn.functional.softmax(z2)
        loss_kl_z = loss_kl(torch.log(y2_z), y1_z)  #D_kl(True=y1_z||Fake=y2_z)
        loss_kl_z = torch.where(torch.isnan(loss_kl_z),
                                torch.full_like(loss_kl_z, 0), loss_kl_z)
        loss_kl_z = torch.where(torch.isinf(loss_kl_z),
                                torch.full_like(loss_kl_z, 1), loss_kl_z)
        loss_2 = loss_m2_mse + loss_m2_mse_mean + loss_m2_mse_std + loss_kl_z

        #loss3: image-space consistency (imgs2 vs imgs3)
        loss_m1_mse_img = loss_mse(imgs2, imgs3)
        # Downsample 2x before LPIPS to cut cost/memory at 1024px.
        imgs2_ = F.avg_pool2d(imgs2, 2, 2)
        imgs3_ = F.avg_pool2d(imgs3, 2, 2)
        loss_img_lpips = loss_lpips(imgs2_, imgs3_).mean()
        loss_3 = loss_img_lpips + loss_m1_mse_img * 3

        loss_all = loss_2 + loss_1 * 3 + loss_3 * 15  # z -> w -> x
        loss_all.backward()
        Gm_optimizer.step()

        print('i_' + str(epoch) + '--loss_all__:' + str(loss_all.item()) +
              '--loss_m1_mse:' + str(loss_m1_mse.item()) +
              '--loss_m1_mse_mean:' + str(loss_m1_mse_mean.item()) +
              '--loss_m1_mse_std:' + str(loss_m1_mse_std.item()) +
              '--loss_kl_w:' + str(loss_kl_w.item()))
        print('--loss_m2_mse:' + str(loss_m2_mse.item()) +
              '--loss_m2_mse_mean:' + str(loss_m2_mse_mean.item()) +
              '--loss_m2_mse_std:' + str(loss_m2_mse_std.item()) +
              '--loss_kl_z:' + str(loss_kl_z.item()))
        #print('--loss_m1_mse_img:'+str(loss_m1_mse_img.item())+'--loss_m2_mse_img:'+str(loss_m2_mse_img.item())+'--loss_m3_mse_img:'+str(loss_m3_mse_img.item()))
        print('loss_img_lpips' + str(loss_img_lpips.item()) +
              '--loss_m1_mse_img:' + str(loss_m1_mse_img.item()))
        print('-')

        if epoch % 100 == 0:
            with torch.no_grad():  # generate the preview images without grads
                test_img = torch.cat((imgs1[:3], imgs2[:3]))
                test_img = torch.cat((test_img, imgs3[:3]))
                # Images are in [-1, 1]; rescale to [0, 1] for saving.
                test_img = test_img * 0.5 + 0.5
                torchvision.utils.save_image(test_img,
                                             resultPath1_1 + '/ep%d.jpg' % (epoch),
                                             nrow=3)  # nrow=3
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch) + '--loss_all__:' +
                      str(loss_all.item()) + '--loss_m1_mse:' +
                      str(loss_m1_mse.item()) + '--loss_m1_mse_mean:' +
                      str(loss_m1_mse_mean.item()) + '--loss_m1_mse_std:' +
                      str(loss_m1_mse_std.item()) + '--loss_kl_w:' +
                      str(loss_kl_w.item()),
                      file=f)
                print('--loss_m2_mse:' + str(loss_m2_mse.item()) +
                      '--loss_m2_mse_mean:' + str(loss_m2_mse_mean.item()) +
                      '--loss_m2_mse_std:' + str(loss_m2_mse_std.item()) +
                      '--loss_kl_z:' + str(loss_kl_z.item()),
                      file=f)
                #print('--loss_m1_mse_img:'+str(loss_m1_mse_img.item())+'--loss_m2_mse_img:'+str(loss_m2_mse_img.item())+'--loss_m3_mse_img:'+str(loss_m3_mse_img.item()),file=f)
                print('loss_img_lpips' + str(loss_img_lpips.item()) +
                      '--loss_m1_mse_img:' + str(loss_m1_mse_img.item()),
                      file=f)
        if epoch % 5000 == 0:
            torch.save(Gm1.state_dict(),
                       resultPath1_2 + '/Gm1_model_ep%d.pth' % epoch)
            torch.save(Gm2.state_dict(),
                       resultPath1_2 + '/Gm2_model_ep%d.pth' % epoch)
def train(cfg, logger, gpu_id=0):
    """Progressive-growing GAN training loop (single GPU).

    Builds the train model and (on gpu 0) an EMA copy ``model_s``, then runs
    the LOD-driven schedule: per epoch the LOD driver picks batch size and
    resolution, and per batch one discriminator step and one generator step
    are taken, with the EMA copy lerped toward the live model.

    Args:
        cfg: config node (MODEL/TRAIN sections, OUTPUT_DIR).
        logger: logger used for progress reporting.
        gpu_id: CUDA device index; checkpointing/sampling only on gpu 0.
    """
    torch.cuda.set_device(gpu_id)

    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=3)
    model.cuda(gpu_id)
    model.train()

    if gpu_id == 0:
        # EMA ("smoothed") copy used for sampling/checkpoints; generator only.
        model_s = Model(
            startf=cfg.MODEL.START_CHANNEL_COUNT,
            layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
            latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS,
            channels=3)
        del model_s.discriminator  # EMA copy never discriminates
        model_s.cuda(gpu_id)
        model_s.eval()
        model_s.requires_grad_(False)

    generator = model.generator
    discriminator = model.discriminator
    mapping = model.mapping
    dlatent_avg = model.dlatent_avg

    # Route the parameter-count helper's output through our logger.
    count_param_override.print = lambda a: logger.info(a)
    logger.info("Trainable parameters generator:")
    count_parameters(generator)
    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    generator_optimizer = LREQAdam([
        {'params': generator.parameters()},
        {'params': mapping.parameters()}],
        lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)
    discriminator_optimizer = LREQAdam([
        {'params': discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={'generator': generator_optimizer,
                    'discriminator': discriminator_optimizer},
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,  # []
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,  # 0.1
        reference_batch_size=32,
        base_lr=cfg.TRAIN.LEARNING_RATES)  # 0.002

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }
    if gpu_id == 0:
        model_dict['generator_s'] = model_s.generator
        model_dict['mapping_s'] = model_s.mapping

    tracker = LossTracker(cfg.OUTPUT_DIR)
    checkpointer = Checkpointer(cfg, model_dict, {
        'generator_optimizer': generator_optimizer,
        'discriminator_optimizer': discriminator_optimizer,
        'scheduler': scheduler,
        'tracker': tracker
    }, logger=logger, save=gpu_id == 0)
    checkpointer.load()  # resume if a checkpoint exists
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    layer_to_resolution = generator.layer_to_resolution  # [4, 8, 16, 32, 64, 128]
    dataset = TFRecordsDataset(cfg, logger, buffer_size_mb=1024)

    # Fixed latents so sample grids are comparable across epochs.
    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    # Object that yields the various per-LOD training params (batch size,
    # blend factor, resolution, when to save/report).
    lod2batch = lod_driver.LODDriver(cfg, logger, gpu_num=1,
                                     dataset_size=len(dataset))

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [generator_optimizer, discriminator_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
            lod2batch.get_batch_size(),
            lod2batch.get_per_GPU_batch_size(),
            lod2batch.lod,
            2 ** lod2batch.get_lod_power2(),
            2 ** lod2batch.get_lod_power2(),
            lod2batch.get_blend_factor(),
            len(dataset)))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        # The dataset is split into batches of per-GPU size.
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(), gpu_id)
        # NOTE(review): this call has raised an error before — verify.
        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        need_permute = False  # would be True for NHWC-layout input

        with torch.autograd.profiler.profile(use_cuda=True, enabled=False) as prof:
            for x_orig in tqdm(batches):  # x_orig: [-1,c,w,h]
                with torch.no_grad():  # input preprocessing needs no grads
                    # Drop ragged final batches.
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)  # uint8 [0,255] -> [-1,1]

                    blend_factor = lod2batch.get_blend_factor()
                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig
                    if lod2batch.in_transition:
                        # Fade between the previous resolution (upsampled)
                        # and the current one while growing.
                        needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                x.requires_grad = True  # D loss may need grads w.r.t. input (e.g. R1)

                # --- Discriminator step ---
                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()

                if gpu_id == 0:
                    # EMA update of the smoothed model.
                    betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                # --- Generator step ---
                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if gpu_id == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate")
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, x, logger,
                                    model_s, cfg, discriminator_optimizer,
                                    generator_optimizer)

        scheduler.step()

        if gpu_id == 0:
            checkpointer.save("model_tmp")
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                        discriminator_optimizer, generator_optimizer)

    logger.info("Training finish!... save training results")
    if gpu_id == 0:
        checkpointer.save("model_final").wait()
def train(avg_tensor=None, coefs=0, tensor_writer=None):
    """Train encoder E to invert a pretrained 256x256 StyleGAN (cat model).

    Besides image/latent reconstruction, this variant also supervises
    Grad-CAM++ masks and guided-backprop gradient maps obtained from a fixed
    VGG16 classifier.  Each loss term gets its own backward + optimizer step.

    Args:
        avg_tensor: average W latent assigned to ``Gm.buffer1`` (truncation trick).
        coefs: truncation coefficients forwarded to ``Gm`` as ``coefs_m``.
        tensor_writer: TensorBoard ``SummaryWriter`` used for scalar logging.
    """
    # Frozen pretrained StyleGAN parts (256px: layer_count=7, num_layers=14).
    Gs = Generator(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)  # 32->512 layer_count=8 / 64->256 layer_count=7
    Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))
    Gm = Mapping(num_layers=14, mapping_layers=8, latent_size=512,
                 dlatent_size=512, mapping_fmaps=512)  # num_layers: 14->256 / 16->512 / 18->1024
    Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.buffer1 = avg_tensor
    E = BE.BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    # Warm-start the encoder from a previous run.
    E.load_state_dict(
        torch.load(
            '/_yucheng/myStyle/myStyle-v1/EAE-car-cat/result/EB_cat_cosine_v2/E_model_ep80000.pth'
        ))
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    writer = tensor_writer
    # LREQAdam avoids RuntimeError about in-place modification during backward.
    E_optimizer = LREQAdam([{'params': E.parameters()}],
                           lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    loss_mse = torch.nn.MSELoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
    loss_kl = torch.nn.KLDivLoss()
    ssim_loss = pytorch_ssim.SSIM()
    batch_size = 3
    const1 = const_.repeat(batch_size, 1, 1, 1)

    # Fixed VGG16 used only to derive Grad-CAM++ masks and guided-backprop maps.
    # After the loop, final_layer holds the NAME of the last Conv2d module.
    vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
    final_layer = None
    for name, m in vgg16.named_modules():
        if isinstance(m, nn.Conv2d):
            final_layer = name
    grad_cam_plus_plus = GradCamPlusPlus(vgg16, final_layer)
    gbp = GuidedBackPropagation(vgg16)

    it_d = 0
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  # [batch_size, 512]
        with torch.no_grad():
            # Generate the target images; Gs/Gm stay frozen.
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  # [batch_size, 14, 512]
            imgs1 = Gs.forward(w1, 6)  # 7->512 / 6->256
        const2, w2 = E(imgs1.cuda())
        imgs2 = Gs.forward(w2, 6)
        E_optimizer.zero_grad()

        # ---- Image space: Grad-CAM++ masks and guided-backprop maps ----
        mask_1 = grad_cam_plus_plus(imgs1, None)  # [c,1,h,w]
        mask_2 = grad_cam_plus_plus(imgs2, None)
        imgs1_ = imgs1.detach().clone()
        imgs1_.requires_grad = True
        imgs2_ = imgs2.detach().clone()
        imgs2_.requires_grad = True
        grad1 = gbp(imgs1_)  # [n,c,h,w]
        grad2 = gbp(imgs2_)

        # NOTE(review): mask_*/grad_* are re-wrapped as leaf tensors below, so
        # the mask/grad losses back-propagate into those leaves rather than into
        # E's parameters — confirm this is the intended behaviour.

        # Mask (Grad-CAM++) losses.
        mask_1 = mask_1.cuda().float()
        mask_1.requires_grad = True
        mask_2 = mask_2.cuda().float()
        mask_2.requires_grad = True
        loss_mask_mse_1 = loss_mse(mask_1, mask_2)
        loss_mask_mse_2 = loss_mse(mask_1.mean(), mask_2.mean())
        loss_mask_mse_3 = loss_mse(mask_1.std(), mask_2.std())
        loss_mask_mse = loss_mask_mse_1 + loss_mask_mse_2 + loss_mask_mse_3
        loss_mask_ssim = 1 - ssim_loss(mask_1, mask_2)
        loss_mask_lpips = loss_lpips(mask_1, mask_2).mean()
        mask1_kl, mask2_kl = torch.nn.functional.softmax(
            mask_1), torch.nn.functional.softmax(mask_2)
        loss_kl_mask = loss_kl(torch.log(mask2_kl), mask1_kl)  # D_kl(True=mask_1 || Fake=mask_2)
        # FIX: the NaN/Inf guards referenced the undefined name loss_kl_w.
        loss_kl_mask = torch.where(torch.isnan(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 0), loss_kl_mask)
        loss_kl_mask = torch.where(torch.isinf(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 1), loss_kl_mask)
        # FIX: cosine term used undefined names mask1/mask2 and shadowed loss_cosine_w.
        mask1_cos = mask_1.view(-1)
        mask2_cos = mask_2.view(-1)
        loss_cosine_mask = 1 - mask1_cos.dot(mask2_cos) / (
            torch.sqrt(mask1_cos.dot(mask1_cos)) *
            torch.sqrt(mask2_cos.dot(mask2_cos)))  # cos in [-1,1]; -1 opposite, 1 same direction
        loss_mask = loss_mask_mse + loss_mask_ssim + loss_mask_lpips + loss_kl_mask + loss_cosine_mask
        E_optimizer.zero_grad()
        loss_mask.backward(retain_graph=True)
        E_optimizer.step()

        # Guided-backprop gradient losses (one optimizer step per term).
        grad1 = grad1.cuda().float()
        grad1.requires_grad = True
        grad2 = grad2.cuda().float()
        grad2.requires_grad = True
        loss_grad_mse = loss_mse(grad1, grad2)
        E_optimizer.zero_grad()
        loss_grad_mse.backward(retain_graph=True)
        E_optimizer.step()
        loss_grad_ssim = 1 - ssim_loss(grad1, grad2)
        E_optimizer.zero_grad()
        loss_grad_ssim.backward(retain_graph=True)
        E_optimizer.step()
        loss_grad_lpips = loss_lpips(grad1, grad2).mean()
        E_optimizer.zero_grad()
        loss_grad_lpips.backward(retain_graph=True)
        E_optimizer.step()
        # FIX: the original repeated the entire mask-loss block here (a
        # copy-paste that mixed mask_* statistics into the grad losses and
        # re-stepped on the mask loss); replace it with the evidently intended
        # gradient-statistics loss.
        loss_grad_stats = (loss_mse(grad1, grad2) +
                           loss_mse(grad1.mean(), grad2.mean()) +
                           loss_mse(grad1.std(), grad2.std()))
        E_optimizer.zero_grad()
        loss_grad_stats.backward(retain_graph=True)
        E_optimizer.step()

        # Raw image reconstruction losses.
        loss_img_mse = loss_mse(imgs1, imgs2)
        E_optimizer.zero_grad()
        loss_img_mse.backward(retain_graph=True)
        E_optimizer.step()
        loss_img_ssim = 1 - ssim_loss(imgs1, imgs2)
        E_optimizer.zero_grad()
        loss_img_ssim.backward(retain_graph=True)
        E_optimizer.step()
        loss_img_lpips = loss_lpips(imgs1, imgs2).mean()
        E_optimizer.zero_grad()
        loss_img_lpips.backward(retain_graph=True)
        E_optimizer.step()

        # ---- Latent space: W ----
        loss_w = loss_mse(w1, w2)
        loss_w_m = loss_mse(w1.mean(), w2.mean())  # early on oscillates between ~10 and ~1e-4
        loss_w_s = loss_mse(w1.std(), w2.std())    # keeps oscillating later in training
        w1_kl, w2_kl = torch.nn.functional.softmax(
            w1), torch.nn.functional.softmax(w2)
        loss_kl_w = loss_kl(torch.log(w2_kl), w1_kl)  # D_kl(True=w1 || Fake=w2)
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),
                                torch.full_like(loss_kl_w, 0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),
                                torch.full_like(loss_kl_w, 1), loss_kl_w)
        w1_cos = w1.view(-1)
        w2_cos = w2.view(-1)
        loss_cosine_w = 1 - w1_cos.dot(w2_cos) / (
            torch.sqrt(w1_cos.dot(w1_cos)) * torch.sqrt(w2_cos.dot(w2_cos)))

        # ---- Latent space: const ----
        # Without the const penalty the gradient cannot descend quickly at the
        # start and training may fail to converge; down-weighting it helped a lot.
        loss_c = loss_mse(const1, const2)
        loss_c_m = loss_mse(const1.mean(), const2.mean())
        loss_c_s = loss_mse(const1.std(), const2.std())
        y1, y2 = torch.nn.functional.softmax(
            const1), torch.nn.functional.softmax(const2)
        loss_kl_c = loss_kl(torch.log(y2), y1)
        loss_kl_c = torch.where(torch.isnan(loss_kl_c),
                                torch.full_like(loss_kl_c, 0), loss_kl_c)
        loss_kl_c = torch.where(torch.isinf(loss_kl_c),
                                torch.full_like(loss_kl_c, 1), loss_kl_c)
        c_cos1 = const1.view(-1)
        c_cos2 = const2.view(-1)
        loss_cosine_c = 1 - c_cos1.dot(c_cos2) / (
            torch.sqrt(c_cos1.dot(c_cos1)) * torch.sqrt(c_cos2.dot(c_cos2)))
        loss_ls_all = loss_w + loss_w_m + loss_w_s + loss_kl_w + loss_cosine_w + \
            loss_c + loss_c_m + loss_c_s + loss_kl_c + loss_cosine_c
        E_optimizer.zero_grad()  # FIX: was missing — stale image-loss grads leaked into this step
        loss_ls_all.backward(retain_graph=True)
        E_optimizer.step()

        print('i_' + str(epoch))
        print('---------ImageSpace--------')
        print('loss_mask_mse:' + str(loss_mask_mse.item()) +
              '--loss_mask_ssim:' + str(loss_mask_ssim.item()) +
              '--loss_mask_lpips:' + str(loss_mask_lpips.item()))
        print('loss_grad_mse:' + str(loss_grad_mse.item()) +
              '--loss_grad_ssim:' + str(loss_grad_ssim.item()) +
              '--loss_grad_lpips:' + str(loss_grad_lpips.item()))
        print('loss_img_mse:' + str(loss_img_mse.item()) +
              '--loss_img_ssim:' + str(loss_img_ssim.item()) +
              '--loss_img_lpips:' + str(loss_img_lpips.item()))
        print('---------LatentSpace--------')
        print('loss_w:' + str(loss_w.item()) + '--loss_w_m:' +
              str(loss_w_m.item()) + '--loss_w_s:' + str(loss_w_s.item()))
        print('loss_kl_w:' + str(loss_kl_w.item()) + '--loss_cosine_w:' +
              str(loss_cosine_w.item()))
        print('loss_c:' + str(loss_c.item()) + '--loss_c_m:' +
              str(loss_c_m.item()) + '--loss_c_s:' + str(loss_c_s.item()))
        print('loss_kl_c:' + str(loss_kl_c.item()) + '--loss_cosine_c:' +
              str(loss_cosine_c.item()))

        it_d += 1
        writer.add_scalar('loss_mask_mse', loss_mask_mse, global_step=it_d)
        writer.add_scalar('loss_mask_ssim', loss_mask_ssim, global_step=it_d)
        writer.add_scalar('loss_mask_lpips', loss_mask_lpips, global_step=it_d)
        writer.add_scalar('loss_grad_mse', loss_grad_mse, global_step=it_d)
        writer.add_scalar('loss_grad_ssim', loss_grad_ssim, global_step=it_d)
        writer.add_scalar('loss_grad_lpips', loss_grad_lpips, global_step=it_d)
        writer.add_scalar('loss_img_mse', loss_img_mse, global_step=it_d)
        writer.add_scalar('loss_img_ssim', loss_img_ssim, global_step=it_d)
        writer.add_scalar('loss_img_lpips', loss_img_lpips, global_step=it_d)
        writer.add_scalar('loss_w', loss_w, global_step=it_d)
        writer.add_scalar('loss_w_m', loss_w_m, global_step=it_d)
        writer.add_scalar('loss_w_s', loss_w_s, global_step=it_d)
        writer.add_scalar('loss_kl_w', loss_kl_w, global_step=it_d)
        writer.add_scalar('loss_cosine_w', loss_cosine_w, global_step=it_d)
        writer.add_scalar('loss_c', loss_c, global_step=it_d)
        writer.add_scalar('loss_c_m', loss_c_m, global_step=it_d)
        writer.add_scalar('loss_c_s', loss_c_s, global_step=it_d)
        writer.add_scalar('loss_kl_c', loss_kl_c, global_step=it_d)
        writer.add_scalar('loss_cosine_c', loss_cosine_c, global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_mse': loss_mask_mse,
            'loss_grad_mse': loss_grad_mse,
            'loss_img_mse': loss_img_mse
        }, global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_ssim': loss_mask_ssim,  # FIX: previously logged loss_mask_mse
            'loss_grad_ssim': loss_grad_ssim,
            'loss_img_ssim': loss_img_ssim
        }, global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_lpips': loss_mask_lpips,
            'loss_grad_lpips': loss_grad_lpips,
            'loss_img_lpips': loss_img_lpips
        }, global_step=it_d)
        writer.add_scalars('Latent Space W', {
            'loss_w': loss_w,
            'loss_w_m': loss_w_m,
            'loss_w_s': loss_w_s,
            'loss_kl_w': loss_kl_w,
            'loss_cosine_w': loss_cosine_w
        }, global_step=it_d)
        writer.add_scalars('Latent Space C', {
            'loss_c': loss_c,
            'loss_c_m': loss_c_m,
            'loss_c_s': loss_c_s,
            'loss_kl_c': loss_kl_c,
            'loss_cosine_c': loss_cosine_c
        }, global_step=it_d)

        if epoch % 100 == 0:
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row])) * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.png' % (epoch),
                                         nrow=n_row)  # nrow=3
            heatmap1, cam1 = mask2cam(mask_1, imgs1)
            heatmap2, cam2 = mask2cam(mask_2, imgs2)
            heatmap = torch.cat((heatmap1, heatmap2))  # FIX: concatenated heatmap1 twice
            cam = torch.cat((cam1, cam2))
            grads = torch.cat((grad1, grad2))
            grads = grads.data.cpu().numpy()  # [n,c,h,w]
            # FIX: np.max(np.min(grads), 0) raises AxisError on a 0-d array;
            # the intent is min-max normalization to [0, 1] for visualization.
            grads -= np.min(grads)
            grads /= np.max(grads)
            torchvision.utils.save_image(
                torch.tensor(heatmap),
                resultPath_grad_cam + '/heatmap_%d.png' % (epoch))
            torchvision.utils.save_image(
                torch.tensor(cam), resultPath_grad_cam + '/cam_%d.png' % (epoch))
            torchvision.utils.save_image(
                torch.tensor(grads), resultPath_grad_cam + '/gb_%d.png' % (epoch))
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch), file=f)
                print('---------ImageSpace--------', file=f)
                print('loss_mask_mse:' + str(loss_mask_mse.item()) +
                      '--loss_mask_ssim:' + str(loss_mask_ssim.item()) +
                      '--loss_mask_lpips:' + str(loss_mask_lpips.item()),
                      file=f)
                print('loss_grad_mse:' + str(loss_grad_mse.item()) +
                      '--loss_grad_ssim:' + str(loss_grad_ssim.item()) +
                      '--loss_grad_lpips:' + str(loss_grad_lpips.item()),
                      file=f)
                print('loss_img_mse:' + str(loss_img_mse.item()) +
                      '--loss_img_ssim:' + str(loss_img_ssim.item()) +
                      '--loss_img_lpips:' + str(loss_img_lpips.item()),
                      file=f)
                print('---------LatentSpace--------', file=f)
                print('loss_w:' + str(loss_w.item()) + '--loss_w_m:' +
                      str(loss_w_m.item()) + '--loss_w_s:' +
                      str(loss_w_s.item()),
                      file=f)
                print('loss_kl_w:' + str(loss_kl_w.item()) +
                      '--loss_cosine_w:' + str(loss_cosine_w.item()),
                      file=f)
                print('loss_c:' + str(loss_c.item()) + '--loss_c_m:' +
                      str(loss_c_m.item()) + '--loss_c_s:' +
                      str(loss_c_s.item()),
                      file=f)
                print('loss_kl_c:' + str(loss_kl_c.item()) +
                      '--loss_cosine_c:' + str(loss_cosine_c.item()),
                      file=f)
        if epoch % 5000 == 0:
            torch.save(E.state_dict(),
                       resultPath1_2 + '/E_model_ep%d.pth' % epoch)
def train(avg_tensor = None, coefs=0):
    """Train encoder E to invert a pretrained 1024x1024 StyleGAN (Gs/Gm).

    Per iteration: sample z, map to w1, synthesize imgs1 with the frozen
    generator, encode imgs1 back to (const2, w2) with E, re-synthesize imgs2,
    then take four separate backward/step passes on E: full-image loss
    (loss_1), center-column loss (loss_2), center-crop loss (loss_3), and
    latent/const loss (loss_4).

    Args:
        avg_tensor: average W latent assigned to Gm.buffer1 (truncation trick).
        coefs: truncation coefficients forwarded to Gm as coefs_m.
    """
    # Frozen pretrained StyleGAN parts (1024px: layer_count=9, num_layers=18).
    Gs = Generator(startf=16, maxf=512, layer_count=9, latent_size=512, channels=3)
    Gs.load_state_dict(torch.load('./pre-model/Gs_dict.pth'))
    Gm = Mapping(num_layers=18, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512)
    Gm.load_state_dict(torch.load('./pre-model/Gm_dict.pth'))
    #Gm.requires_grad_(False)
    #Gs.requires_grad_(False)
    Gm.buffer1 = avg_tensor
    E = BE.BE()
    # Warm-start E from an earlier run; strict=False tolerates key mismatches.
    E.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/result/EB_V8_newE_noEw_Ebias_2/models/E_model_ep65000.pth'),strict=False)
    # Disabled alternative warm-start that dropped the second-noise/bias keys:
    # model_dict = E.state_dict()
    # pretrained_dict = torch.load('/_yucheng/myStyle/myStyle-v1/result/EB_V6_3ImgLoss_Res0618_truncW_noUpgradeW/models/E_model_ep15000.pth')
    # for k,v in model_dict.items():
    #     if 'decode_block.0.noise_weight_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.1.noise_weight_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.2.noise_weight_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.3.noise_weight_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.4.noise_weight_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.0.bias_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.1.bias_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.2.bias_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.3.bias_2' in k:
    #         pretrained_dict.pop(k)
    #     if 'decode_block.4.bias_2' in k:
    #         pretrained_dict.pop(k)
    # model_dict.update(pretrained_dict)
    # E.load_state_dict(model_dict,strict=False)
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    # LREQAdam: using this Adam variant avoids a RuntimeError about in-place
    # modification of variables needed for gradient computation.
    E_optimizer = LREQAdam([{'params': E.parameters()},], lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    loss_all=0
    loss_mse = torch.nn.MSELoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
    loss_kl = torch.nn.KLDivLoss()
    batch_size = 3
    # The generator's learned constant input, repeated per sample; E must reproduce it.
    const1 = const_.repeat(batch_size,1,1,1)
    for epoch in range(0,250001):
        set_seed(epoch%25000)  # cycle seeds so sampled latents repeat every 25k iters
        latents = torch.randn(batch_size, 512) #[32, 512]
        w1 = Gm(latents,coefs_m=coefs).to('cuda') #[batch_size,18,512]
        with torch.no_grad():
            # Generate the target image here; only E is being trained.
            imgs1 = Gs.forward(w1,8)
        const2,w2 = E(imgs1.cuda())
        imgs2=Gs.forward(w2,8)
        E_optimizer.zero_grad()
        # NOTE(review): only this single zero_grad per iteration — the four
        # backward/step passes below accumulate gradients on top of each other.
        # Confirm that accumulation is intended rather than a missing zero_grad.
        #loss1: full-image losses (per-channel worst-case MSE + downsampled LPIPS + KL)
        #loss_img_mse = loss_mse(imgs1,imgs2)
        loss_img_mse_c1 = loss_mse(imgs1[:,0],imgs2[:,0])
        loss_img_mse_c2 = loss_mse(imgs1[:,1],imgs2[:,1])
        loss_img_mse_c3 = loss_mse(imgs1[:,2],imgs2[:,2])
        # Penalize the worst-reconstructed color channel.
        loss_img_mse = max(loss_img_mse_c1,loss_img_mse_c2,loss_img_mse_c3)
        imgs1_ = F.avg_pool2d(imgs1,2,2)
        imgs2_ = F.avg_pool2d(imgs2,2,2)
        imgs1__ = F.avg_pool2d(imgs1_,2,2)
        imgs2__ = F.avg_pool2d(imgs2_,2,2)
        # LPIPS on 4x-downsampled images to cut memory/compute.
        loss_img_lpips = loss_lpips(imgs1__,imgs2__).mean()
        y1_imgs, y2_imgs = torch.nn.functional.softmax(imgs1_),torch.nn.functional.softmax(imgs2_)
        loss_kl_img = loss_kl(torch.log(y2_imgs),y1_imgs) #D_kl(True=y1_imgs||Fake=y2_imgs)
        # Replace NaN with 0 and Inf with 1 so a degenerate KL cannot poison the step.
        loss_kl_img = torch.where(torch.isnan(loss_kl_img),torch.full_like(loss_kl_img,0), loss_kl_img)
        loss_kl_img = torch.where(torch.isinf(loss_kl_img),torch.full_like(loss_kl_img,1), loss_kl_img)
        loss_1 = 13*loss_img_mse+ 5*loss_img_lpips + loss_kl_img
        loss_1.backward(retain_graph=True)
        E_optimizer.step()
        #loss2: vertical center column of the image
        # NOTE(review): crop indices (128:-128, 128:640, 256:-256, 924:) assume
        # 1024x1024 generator output — confirm against Gs.forward(w, 8).
        imgs_column1 = imgs1[:,:,:,128:-128]
        imgs_column2 = imgs2[:,:,:,128:-128]
        loss_img_mse_column = loss_mse(imgs_column1,imgs_column2)
        imgs1_column_down = F.avg_pool2d(imgs_column1,2,2)
        imgs2_column_down = F.avg_pool2d(imgs_column2,2,2)
        loss_img_lpips_column = loss_lpips(imgs1_column_down,imgs2_column_down).mean()
        loss_2 = 19*loss_img_mse_column +7*loss_img_lpips_column
        loss_2.backward(retain_graph=True)
        E_optimizer.step()
        #loss3: central crop plus a small bottom-right "blob" patch
        imgs_center1 = imgs1[:,:,128:640,256:-256]
        imgs_center2 = imgs2[:,:,128:640,256:-256]
        loss_img_mse_center = loss_mse(imgs_center1,imgs_center2)
        imgs1_c_down = F.avg_pool2d(imgs_center1,2,2)
        imgs2_c_down = F.avg_pool2d(imgs_center2,2,2)
        loss_img_lpips_center = loss_lpips(imgs1_c_down,imgs2_c_down).mean()
        imgs_blob1 = imgs1[:,:,924:,924:]
        imgs_blob2 = imgs2[:,:,924:,924:]
        loss_img_mse_blob = loss_mse(imgs_blob1,imgs_blob2)
        loss_3 = 23*loss_img_mse_center +11*loss_img_lpips_center + loss_img_mse_blob
        loss_3.backward(retain_graph=True)
        #loss_x = loss_1+loss_2+loss_3
        #loss_x.backward(retain_graph=True)
        E_optimizer.step()
        #loss4: latent (w) and constant-input (const) reconstruction
        # Without the const penalty the gradient cannot descend quickly at the
        # start and training may fail to converge; down-weighting it (x0.1)
        # improved results a lot.
        loss_c = loss_mse(const1,const2)
        loss_c_m = loss_mse(const1.mean(),const2.mean())
        loss_c_s = loss_mse(const1.std(),const2.std())
        loss_w = loss_mse(w1,w2)
        loss_w_m = loss_mse(w1.mean(),w2.mean())  # early on oscillates between ~10 and ~1e-4
        loss_w_s = loss_mse(w1.std(),w2.std())    # keeps oscillating later in training
        y1, y2 = torch.nn.functional.softmax(const1),torch.nn.functional.softmax(const2)
        loss_kl_c = loss_kl(torch.log(y2),y1)
        loss_kl_c = torch.where(torch.isnan(loss_kl_c),torch.full_like(loss_kl_c,0), loss_kl_c)
        loss_kl_c = torch.where(torch.isinf(loss_kl_c),torch.full_like(loss_kl_c,1), loss_kl_c)
        w1_kl, w2_kl = torch.nn.functional.softmax(w1),torch.nn.functional.softmax(w2)
        loss_kl_w = loss_kl(torch.log(w2_kl),w1_kl) #D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),torch.full_like(loss_kl_w,0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),torch.full_like(loss_kl_w,1), loss_kl_w)
        loss_4 = 0.02*loss_c+0.03*loss_c_m+0.03*loss_c_s+0.02*loss_w+0.03*loss_w_m+0.03*loss_w_s+ loss_kl_w + loss_kl_c
        loss_4.backward(retain_graph=True)
        E_optimizer.step()
        # Aggregate for logging only; the individual losses were stepped separately.
        loss_all = loss_1 + loss_4 + loss_2 + loss_3
        print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item()))
        print('loss_img_mse_column:'+str(loss_img_mse_column.item())+'loss_img_lpips_column:'+str(loss_img_lpips_column.item())\
            +'--loss_img_mse_center:'+str(loss_img_mse_center.item())+'--loss_lpips_center:'+str(loss_img_lpips_center.item()))
        print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())\
            +'--loss_c:'+str(loss_c.item())+'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_img_blob_mse:'+str(loss_img_mse_blob.item()))
        print('-')
        if epoch % 100 == 0:
            # Save side-by-side real/reconstructed grid, rescaled from [-1,1] to [0,1].
            test_img = torch.cat((imgs1[:3],imgs2[:3]))*0.5+0.5
            torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d.jpg'%(epoch),nrow=3) # nrow=3
            with open(resultPath+'/Loss.txt', 'a+') as f:
                print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item()),file=f)
                print('loss_img_mse_column:'+str(loss_img_mse_column.item())+'loss_img_lpips_column:'+str(loss_img_lpips_column.item())\
                    +'--loss_img_mse_center:'+str(loss_img_mse_center.item())+'--loss_lpips_center:'+str(loss_img_lpips_center.item()),file=f)
                print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())+'--loss_c:'+str(loss_c.item())\
                    +'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_img_blob_mse:'+str(loss_img_mse_blob.item()),file=f)
        if epoch % 5000 == 0:
            torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d.pth'%epoch)
def train(avg_tensor = None, coefs=0):
    """Train encoder E to invert a pretrained 256x256 StyleGAN (cat model).

    Per iteration: sample z -> w1 -> imgs1 with the frozen generator, encode
    imgs1 to (const2, w2) with E, regenerate imgs2, then take two optimizer
    steps: one on the image reconstruction loss (loss_1) and one on the
    latent/const loss plus image-cosine term (loss_4).

    Args:
        avg_tensor: average W latent assigned to Gm.buffer1 (truncation trick).
        coefs: truncation coefficients forwarded to Gm as coefs_m.
    """
    # Frozen pretrained StyleGAN parts (256px: layer_count=7, num_layers=14).
    Gs = Generator(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)  # 32->512 layer_count=8 / 64->256 layer_count=7
    Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))
    Gm = Mapping(num_layers=14, mapping_layers=8, latent_size=512,
                 dlatent_size=512, mapping_fmaps=512)  # num_layers: 14->256 / 16->512 / 18->1024
    Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.buffer1 = avg_tensor
    E = BE.BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    E_optimizer = LREQAdam([{'params': E.parameters()},],
                           lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    loss_all = 0
    loss_mse = torch.nn.MSELoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
    loss_kl = torch.nn.KLDivLoss()
    batch_size = 5
    const1 = const_.repeat(batch_size, 1, 1, 1)
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  # [batch_size, 512]
        with torch.no_grad():
            # Generate the target images; Gs/Gm stay frozen.
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  # [batch_size,14,512]
            imgs1 = Gs.forward(w1, 6)  # 7->512 / 6->256
        const2, w2 = E(imgs1.cuda())
        imgs2 = Gs.forward(w2, 6)
        # loss1: full-image reconstruction (MSE + LPIPS + NaN/Inf-guarded KL).
        loss_img_mse = loss_mse(imgs1, imgs2)
        loss_img_lpips = loss_lpips(imgs1, imgs2).mean()
        y1_imgs, y2_imgs = torch.nn.functional.softmax(imgs1), torch.nn.functional.softmax(imgs2)
        loss_kl_img = loss_kl(torch.log(y2_imgs), y1_imgs)  # D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_img = torch.where(torch.isnan(loss_kl_img),
                                  torch.full_like(loss_kl_img, 0), loss_kl_img)
        loss_kl_img = torch.where(torch.isinf(loss_kl_img),
                                  torch.full_like(loss_kl_img, 1), loss_kl_img)
        loss_1 = 17*loss_img_mse + 5*loss_img_lpips + loss_kl_img
        E_optimizer.zero_grad()
        loss_1.backward(retain_graph=True)
        E_optimizer.step()
        # Cosine dissimilarity between the flattened images.
        # FIX: the original minimized the raw cosine SIMILARITY, which rewards
        # dissimilar reconstructions; use 1 - cos like the sibling variants.
        i1 = imgs1.view(-1)
        i2 = imgs2.view(-1)
        loss_cosine_i = 1 - i1.dot(i2) / (torch.sqrt(i1.dot(i1)) * torch.sqrt(i2.dot(i2)))
        # loss4: latent (w) and constant-input (const) reconstruction.
        # Without the const penalty, early gradients cannot descend quickly and
        # training may fail to converge; down-weighting it helped a lot.
        loss_c = loss_mse(const1, const2)
        loss_c_m = loss_mse(const1.mean(), const2.mean())
        loss_c_s = loss_mse(const1.std(), const2.std())
        loss_w = loss_mse(w1, w2)
        loss_w_m = loss_mse(w1.mean(), w2.mean())  # early on oscillates between ~10 and ~1e-4
        loss_w_s = loss_mse(w1.std(), w2.std())    # keeps oscillating later in training
        y1, y2 = torch.nn.functional.softmax(const1), torch.nn.functional.softmax(const2)
        loss_kl_c = loss_kl(torch.log(y2), y1)
        loss_kl_c = torch.where(torch.isnan(loss_kl_c),
                                torch.full_like(loss_kl_c, 0), loss_kl_c)
        loss_kl_c = torch.where(torch.isinf(loss_kl_c),
                                torch.full_like(loss_kl_c, 1), loss_kl_c)
        w1_kl, w2_kl = torch.nn.functional.softmax(w1), torch.nn.functional.softmax(w2)
        loss_kl_w = loss_kl(torch.log(w2_kl), w1_kl)  # D_kl(True=w1||Fake=w2)
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),
                                torch.full_like(loss_kl_w, 0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),
                                torch.full_like(loss_kl_w, 1), loss_kl_w)
        w1_cos = w1.view(-1)
        w2_cos = w2.view(-1)
        # FIX: denominator normalized by ||w1|| twice; second factor must be ||w2||.
        loss_cosine_w = w1_cos.dot(w2_cos) / (
            torch.sqrt(w1_cos.dot(w1_cos)) * torch.sqrt(w2_cos.dot(w2_cos)))
        loss_4 = 0.02*loss_c + 0.03*loss_c_m + 0.03*loss_c_s + 0.02*loss_w + \
            0.03*loss_w_m + 0.03*loss_w_s + loss_kl_w + loss_kl_c + loss_cosine_i
        E_optimizer.zero_grad()  # FIX: was missing — loss_1's gradients were re-applied by this step
        loss_4.backward(retain_graph=True)
        E_optimizer.step()
        # Aggregate for logging only (loss_cosine_i is already inside loss_4).
        loss_all = loss_1 + loss_4 + loss_cosine_i
        print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item())+'--loss_cosine_i:'+str(loss_cosine_i.item()))
        print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())+'--loss_c:'+str(loss_c.item())+'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_cosine_w:'+str(loss_cosine_w.item()))
        print('-')
        if epoch % 100 == 0:
            # Save side-by-side real/reconstructed grid, rescaled from [-1,1] to [0,1].
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row]))*0.5+0.5
            torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d.jpg'%(epoch), nrow=n_row)  # nrow=3
            with open(resultPath+'/Loss.txt', 'a+') as f:
                print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item())+'--loss_cosine_i:'+str(loss_cosine_i.item()),file=f)
                print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())+'--loss_c:'+str(loss_c.item())+'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_cosine_w:'+str(loss_cosine_w.item()),file=f)
        if epoch % 5000 == 0:
            torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d.pth'%epoch)