Code Example #1
def train(avg_tensor=None, coefs=0, tensor_writer=None):
    Gs = Generator(startf=64,
                   maxf=512,
                   layer_count=7,
                   latent_size=512,
                   channels=3)  # 32->512 layer_count=8 / 64->256 layer_count=7
    #Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth')) #256*256 cat
    Gs.load_state_dict(
        torch.load(
            './pre-model/bedroom/bedrooms256_Gs_dict.pth'))  # 256*256 bedroom
    Gm = Mapping(num_layers=14,
                 mapping_layers=8,
                 latent_size=512,
                 dlatent_size=512,
                 mapping_fmaps=512)  #num_layers: 14->256 / 16->512 / 18->1024
    #Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.load_state_dict(
        torch.load(
            './pre-model/bedroom/bedrooms256_Gm_dict.pth'))  # 256*256 bedroom
    Gm.buffer1 = avg_tensor
    E = BE.BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    #E.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/EAE-car-cat/result/EB_cat_cosine_v2/E_model_ep80000.pth'))
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    writer = tensor_writer

    E_optimizer = LREQAdam([
        {
            'params': E.parameters()
        },
    ],
                           lr=0.0015,
                           betas=(0.0, 0.99),
                           weight_decay=0)
    # Using this Adam avoids the error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')

    batch_size = 3
    const1 = const_.repeat(batch_size, 1, 1, 1)
    it_d = 0
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  #[32, 512]
        with torch.no_grad():  # generate the target images and latents here without tracking gradients
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  #[batch_size,18,512]
            imgs1 = Gs.forward(w1, 6)  # 7->512 / 6->256

        const2, w2 = E(imgs1.cuda())
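        # const2 and w2 are the encoder's estimates of the generator's constant input and of w1; both are compared against the originals below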

        imgs2 = Gs.forward(w2, 6)

        E_optimizer.zero_grad()

        #Latent_space

        ## c
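        # space_loss appears to return (total_loss, info) with info = [[mse, mse_mean, mse_std], kl, cosine, ssim, lpips], judging from how the info lists are indexed for logging below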
        loss_c, loss_c_info = space_loss(const1, const2, image_space=False)
        E_optimizer.zero_grad()
        loss_c.backward(retain_graph=True)
        E_optimizer.step()

        ## w
        loss_w, loss_w_info = space_loss(w1, w2, image_space=False)
        E_optimizer.zero_grad()
        loss_w.backward(retain_graph=True)
        E_optimizer.step()

        #Image Space

        ## loss1: smallest central crop
        imgs_small_1 = imgs1[:, :, imgs1.shape[2] // 20:-imgs1.shape[2] // 20,
                             imgs1.shape[3] // 20:-imgs1.shape[3] //
                             20].clone()  # w,h
        imgs_small_2 = imgs2[:, :, imgs2.shape[2] // 20:-imgs2.shape[2] // 20,
                             imgs2.shape[3] // 20:-imgs2.shape[3] //
                             20].clone()
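        # both crops drop a border of 1/20 of the height and width on each side, keeping roughly the central 90% of the image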
        loss_small, loss_small_info = space_loss(imgs_small_1,
                                                 imgs_small_2,
                                                 lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_small.backward(retain_graph=True)
        E_optimizer.step()

        # loss2: medium central crop
        imgs_medium_1 = imgs1[:, :, imgs1.shape[2] // 10:-imgs1.shape[2] // 10,
                              imgs1.shape[3] // 10:-imgs1.shape[3] //
                              10].clone()
        imgs_medium_2 = imgs2[:, :, imgs2.shape[2] // 10:-imgs2.shape[2] // 10,
                              imgs2.shape[3] // 10:-imgs2.shape[3] //
                              10].clone()
        loss_medium, loss_medium_info = space_loss(imgs_medium_1,
                                                   imgs_medium_2,
                                                   lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_medium.backward(retain_graph=True)
        E_optimizer.step()

        # loss3: full image
        loss_imgs, loss_imgs_info = space_loss(imgs1,
                                               imgs2,
                                               lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_imgs.backward(retain_graph=True)
        E_optimizer.step()

        print('i_' + str(epoch))
        print(
            '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]'
        )
        print('---------ImageSpace--------')
        print('loss_small_info: %s' % loss_small_info)
        print('loss_medium_info: %s' % loss_medium_info)
        print('loss_imgs_info: %s' % loss_imgs_info)
        print('---------LatentSpace--------')
        print('loss_w_info: %s' % loss_w_info)
        print('loss_c_info: %s' % loss_c_info)

        it_d += 1
        writer.add_scalar('loss_small_mse',
                          loss_small_info[0][0],
                          global_step=it_d)
        writer.add_scalar('loss_small_mse_mean',
                          loss_small_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_small_mse_std',
                          loss_small_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_small_kl',
                          loss_small_info[1],
                          global_step=it_d)
        writer.add_scalar('loss_small_cosine',
                          loss_small_info[2],
                          global_step=it_d)
        writer.add_scalar('loss_small_ssim',
                          loss_small_info[3],
                          global_step=it_d)
        writer.add_scalar('loss_small_lpips',
                          loss_small_info[4],
                          global_step=it_d)

        writer.add_scalar('loss_medium_mse',
                          loss_medium_info[0][0],
                          global_step=it_d)
        writer.add_scalar('loss_medium_mse_mean',
                          loss_medium_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_medium_mse_std',
                          loss_medium_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_medium_kl',
                          loss_medium_info[1],
                          global_step=it_d)
        writer.add_scalar('loss_medium_cosine',
                          loss_medium_info[2],
                          global_step=it_d)
        writer.add_scalar('loss_medium_ssim',
                          loss_medium_info[3],
                          global_step=it_d)
        writer.add_scalar('loss_medium_lpips',
                          loss_medium_info[4],
                          global_step=it_d)

        writer.add_scalar('loss_imgs_mse',
                          loss_imgs_info[0][0],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_mse_mean',
                          loss_imgs_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_mse_std',
                          loss_imgs_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
        writer.add_scalar('loss_imgs_cosine',
                          loss_imgs_info[2],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_ssim',
                          loss_imgs_info[3],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_lpips',
                          loss_imgs_info[4],
                          global_step=it_d)

        writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
        writer.add_scalar('loss_w_mse_mean',
                          loss_w_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_w_mse_std',
                          loss_w_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
        writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
        writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
        writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)

        writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d)
        writer.add_scalar('loss_c_mse_mean',
                          loss_c_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_c_mse_std',
                          loss_c_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d)
        writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d)
        writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d)
        writer.add_scalar('loss_c_lpips', loss_c_info[4], global_step=it_d)

        writer.add_scalars('Image_Space_MSE', {
            'loss_small_mse': loss_small_info[0][0],
            'loss_medium_mse': loss_medium_info[0][0],
            'loss_img_mse': loss_imgs_info[0][0]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_KL', {
            'loss_small_kl': loss_small_info[1],
            'loss_medium_kl': loss_medium_info[1],
            'loss_imgs_kl': loss_imgs_info[1]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_Cosine', {
            'loss_small_cosine': loss_small_info[2],
            'loss_medium_cosine': loss_medium_info[2],
            'loss_imgs_cosine': loss_imgs_info[2]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_SSIM', {
            'loss_small_ssim': loss_small_info[3],
            'loss_medium_ssim': loss_medium_info[3],
            'loss_img_ssim': loss_imgs_info[3]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_Lpips', {
            'loss_small_lpips': loss_small_info[4],
            'loss_medium_lpips': loss_medium_info[4],
            'loss_img_lpips': loss_imgs_info[4]
        },
                           global_step=it_d)
        writer.add_scalars('Latent Space W', {
            'loss_w_mse': loss_w_info[0][0],
            'loss_w_mse_mean': loss_w_info[0][1],
            'loss_w_mse_std': loss_w_info[0][2],
            'loss_w_kl': loss_w_info[1],
            'loss_w_cosine': loss_w_info[2]
        },
                           global_step=it_d)
        writer.add_scalars('Latent Space C', {
            'loss_c_mse': loss_c_info[0][0],
            'loss_c_mse_mean': loss_c_info[0][1],
            'loss_c_mse_std': loss_c_info[0][2],
            'loss_c_kl': loss_c_info[1],
            'loss_c_cosine': loss_c_info[2]
        },
                           global_step=it_d)

        if epoch % 100 == 0:
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row])) * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.jpg' % (epoch),
                                         nrow=n_row)  # nrow=3
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch), file=f)
                print(
                    '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',
                    file=f)
                print('---------ImageSpace--------', file=f)
                print('loss_small_info: %s' % loss_small_info, file=f)
                print('loss_medium_info: %s' % loss_medium_info, file=f)
                print('loss_imgs_info: %s' % loss_imgs_info, file=f)
                print('---------LatentSpace--------', file=f)
                print('loss_w_info: %s' % loss_w_info, file=f)
                print('loss_c_info: %s' % loss_c_info, file=f)
            if epoch % 5000 == 0:
                torch.save(E.state_dict(),
                           resultPath1_2 + '/E_model_ep%d.pth' % epoch)
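
Note: the three image-space terms above compare progressively larger regions of the same image pair. A minimal, self-contained sketch of that centre-crop pattern follows (not from the repository; plain MSE stands in for space_loss, and center_crop/multi_region_mse are illustrative names):

# Sketch (not repository code): the central-crop loss pattern used in Code Example #1,
# with plain MSE standing in for space_loss.
import torch
import torch.nn.functional as F

def center_crop(imgs, border_frac):
    # imgs: [n, c, h, w]; drop a border of border_frac * h / border_frac * w on each side
    h, w = imgs.shape[2], imgs.shape[3]
    dh, dw = int(h * border_frac), int(w * border_frac)
    return imgs[:, :, dh:h - dh, dw:w - dw]

def multi_region_mse(imgs1, imgs2):
    # small crop (1/20 border), medium crop (1/10 border) and the full image,
    # mirroring the three space_loss calls above
    loss_small = F.mse_loss(center_crop(imgs1, 1 / 20), center_crop(imgs2, 1 / 20))
    loss_medium = F.mse_loss(center_crop(imgs1, 1 / 10), center_crop(imgs2, 1 / 10))
    loss_full = F.mse_loss(imgs1, imgs2)
    return loss_small, loss_medium, loss_full

if __name__ == '__main__':
    a, b = torch.rand(3, 3, 256, 256), torch.rand(3, 3, 256, 256)
    print([l.item() for l in multi_region_mse(a, b)])
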
Code Example #2
File: EAE_V2_512.py  Project: disanda/EAE
def train(avg_tensor=None, coefs=0, tensor_writer=None):
    Gs = Generator(
        startf=32, maxf=512, layer_count=8, latent_size=512, channels=3
    )  # cars: startf 32->512 layer_count=8 / cat: startf 64->256 layer_count=7
    #Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))
    Gs.load_state_dict(torch.load('./pre-model/cars/cars512_Gs_dict.pth'))
    Gm = Mapping(num_layers=16,
                 mapping_layers=8,
                 latent_size=512,
                 dlatent_size=512,
                 mapping_fmaps=512)  #num_layers: 14->256 / 16->512 / 18->1024
    #Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.load_state_dict(torch.load('./pre-model/cars/cars512_Gm_dict.pth'))
    Gm.buffer1 = avg_tensor
    E = BE.BE(startf=32, maxf=512, layer_count=8, latent_size=512, channels=3)
    #E.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/EAE-car-cat/pre-model/E_cat_v2_1_ep100000.pth'))
    E.load_state_dict(
        torch.load(
            '/_wmwang/mystyle/myStyle1/EAE/pre-model/E_cars512_ep90000_v1.pth')
    )  # cars
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    writer = tensor_writer

    E_optimizer = LREQAdam([
        {
            'params': E.parameters()
        },
    ],
                           lr=0.0015,
                           betas=(0.0, 0.99),
                           weight_decay=0)
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')

    batch_size = 4
    const1 = const_.repeat(batch_size, 1, 1, 1)

    vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
    final_layer = None
    for name, m in vgg16.named_modules():
        if isinstance(m, nn.Conv2d):
            final_layer = name
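    # final_layer now holds the name of the last Conv2d layer in vgg16; Grad-CAM++ is attached to that layer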
    grad_cam_plus_plus = GradCamPlusPlus(vgg16, final_layer)
    gbp = GuidedBackPropagation(vgg16)

    it_d = 0
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  #[32, 512]
        with torch.no_grad():  # generate the target images and latents here without tracking gradients
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  #[batch_size,18,512]
            imgs1 = Gs.forward(w1, 7)  # 7->512 / 6->256

        const2, w2 = E(imgs1.cuda())

        imgs2 = Gs.forward(w2, 7)

        E_optimizer.zero_grad()

        #Latent Space

        ##--C
        loss_c, loss_c_info = space_loss(const1, const2, image_space=False)
        E_optimizer.zero_grad()
        loss_c.backward(retain_graph=True)
        E_optimizer.step()

        ##--W
        loss_w, loss_w_info = space_loss(w1, w2, image_space=False)
        E_optimizer.zero_grad()
        loss_w.backward(retain_graph=True)
        E_optimizer.step()

        #Image Space
        mask_1 = grad_cam_plus_plus(imgs1, None)  #[c,1,h,w]
        mask_2 = grad_cam_plus_plus(imgs2, None)
        #imgs1.retain_grad()
        #imgs2.retain_grad()
        imgs1_ = imgs1.detach().clone()
        imgs1_.requires_grad = True
        imgs2_ = imgs2.detach().clone()
        imgs2_.requires_grad = True
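        # detached copies that require grad, so guided backpropagation can compute input-space gradients without touching the generator graph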
        grad_1 = gbp(imgs1_)  # [n,c,h,w]
        grad_2 = gbp(imgs2_)
        heatmap_1, cam_1 = mask2cam(mask_1, imgs1)
        heatmap_2, cam_2 = mask2cam(mask_2, imgs2)

        ##--Mask
        mask_1 = mask_1.cuda().float()
        mask_1.requires_grad = True
        mask_2 = mask_2.cuda().float()
        mask_2.requires_grad = True
        loss_mask, loss_mask_info = space_loss(mask_1,
                                               mask_2,
                                               lpips_model=loss_lpips)

        E_optimizer.zero_grad()
        loss_mask.backward(retain_graph=True)
        E_optimizer.step()

        ##--Grad
        grad_1 = grad_1.cuda().float()
        grad_1.requires_grad = True
        grad_2 = grad_2.cuda().float()
        grad_2.requires_grad = True
        loss_grad, loss_grad_info = space_loss(grad_1,
                                               grad_2,
                                               lpips_model=loss_lpips)

        E_optimizer.zero_grad()
        loss_grad.backward(retain_graph=True)
        E_optimizer.step()

        ##--Image
        loss_imgs, loss_imgs_info = space_loss(imgs1,
                                               imgs2,
                                               lpips_model=loss_lpips)
        E_optimizer.zero_grad()
        loss_imgs.backward(retain_graph=True)
        E_optimizer.step()

        ##--Grad_CAM from mask
        cam_1 = cam_1.cuda().float()
        cam_1.requires_grad = True
        cam_2 = cam_2.cuda().float()
        cam_2.requires_grad = True
        loss_Gcam, loss_Gcam_info = space_loss(cam_1,
                                               cam_2,
                                               lpips_model=loss_lpips)

        E_optimizer.zero_grad()
        loss_Gcam.backward(retain_graph=True)
        E_optimizer.step()

        print('i_' + str(epoch))
        print(
            '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]'
        )
        print('---------ImageSpace--------')
        print('loss_mask_info: %s' % loss_mask_info)
        print('loss_grad_info: %s' % loss_grad_info)
        print('loss_imgs_info: %s' % loss_imgs_info)
        print('loss_Gcam_info: %s' % loss_Gcam_info)
        print('---------LatentSpace--------')
        print('loss_w_info: %s' % loss_w_info)
        print('loss_c_info: %s' % loss_c_info)

        it_d += 1
        writer.add_scalar('loss_mask_mse',
                          loss_mask_info[0][0],
                          global_step=it_d)
        writer.add_scalar('loss_mask_mse_mean',
                          loss_mask_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_mask_mse_std',
                          loss_mask_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_mask_kl', loss_mask_info[1], global_step=it_d)
        writer.add_scalar('loss_mask_cosine',
                          loss_mask_info[2],
                          global_step=it_d)
        writer.add_scalar('loss_mask_ssim',
                          loss_mask_info[3],
                          global_step=it_d)
        writer.add_scalar('loss_mask_lpips',
                          loss_mask_info[4],
                          global_step=it_d)

        writer.add_scalar('loss_grad_mse',
                          loss_grad_info[0][0],
                          global_step=it_d)
        writer.add_scalar('loss_grad_mse_mean',
                          loss_grad_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_grad_mse_std',
                          loss_grad_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_grad_kl', loss_grad_info[1], global_step=it_d)
        writer.add_scalar('loss_grad_cosine',
                          loss_grad_info[2],
                          global_step=it_d)
        writer.add_scalar('loss_grad_ssim',
                          loss_grad_info[3],
                          global_step=it_d)
        writer.add_scalar('loss_grad_lpips',
                          loss_grad_info[4],
                          global_step=it_d)

        writer.add_scalar('loss_imgs_mse',
                          loss_imgs_info[0][0],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_mse_mean',
                          loss_imgs_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_mse_std',
                          loss_imgs_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
        writer.add_scalar('loss_imgs_cosine',
                          loss_imgs_info[2],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_ssim',
                          loss_imgs_info[3],
                          global_step=it_d)
        writer.add_scalar('loss_imgs_lpips',
                          loss_imgs_info[4],
                          global_step=it_d)

        writer.add_scalar('loss_Gcam', loss_Gcam_info[0][0], global_step=it_d)
        writer.add_scalar('loss_Gcam_mean',
                          loss_Gcam_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_Gcam_std',
                          loss_Gcam_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_Gcam_kl', loss_Gcam_info[1], global_step=it_d)
        writer.add_scalar('loss_Gcam_cosine',
                          loss_Gcam_info[2],
                          global_step=it_d)
        writer.add_scalar('loss_Gcam_ssim',
                          loss_Gcam_info[3],
                          global_step=it_d)
        writer.add_scalar('loss_Gcam_lpips',
                          loss_Gcam_info[4],
                          global_step=it_d)

        writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
        writer.add_scalar('loss_w_mse_mean',
                          loss_w_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_w_mse_std',
                          loss_w_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
        writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
        writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
        writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)

        writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d)
        writer.add_scalar('loss_c_mse_mean',
                          loss_c_info[0][1],
                          global_step=it_d)
        writer.add_scalar('loss_c_mse_std',
                          loss_c_info[0][2],
                          global_step=it_d)
        writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d)
        writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d)
        writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d)
        writer.add_scalar('loss_c_lpips', loss_c_info[4], global_step=it_d)

        writer.add_scalars('Image_Space_MSE', {
            'loss_mask_mse': loss_mask_info[0][0],
            'loss_grad_mse': loss_grad_info[0][0],
            'loss_img_mse': loss_imgs_info[0][0]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_KL', {
            'loss_mask_kl': loss_mask_info[1],
            'loss_grad_kl': loss_grad_info[1],
            'loss_imgs_kl': loss_imgs_info[1]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_Cosine', {
            'loss_mask_cosine': loss_mask_info[2],
            'loss_grad_cosine': loss_grad_info[2],
            'loss_imgs_cosine': loss_imgs_info[2]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_SSIM', {
            'loss_mask_ssim': loss_mask_info[3],
            'loss_grad_ssim': loss_grad_info[3],
            'loss_img_ssim': loss_imgs_info[3]
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space_Lpips', {
            'loss_mask_lpips': loss_mask_info[4],
            'loss_grad_lpips': loss_grad_info[4],
            'loss_img_lpips': loss_imgs_info[4]
        },
                           global_step=it_d)
        writer.add_scalars('Latent Space W', {
            'loss_w_mse': loss_w_info[0][0],
            'loss_w_mse_mean': loss_w_info[0][1],
            'loss_w_mse_std': loss_w_info[0][2],
            'loss_w_kl': loss_w_info[1],
            'loss_w_cosine': loss_w_info[2]
        },
                           global_step=it_d)
        writer.add_scalars('Latent Space C', {
            'loss_c_mse': loss_c_info[0][0],
            'loss_c_mse_mean': loss_c_info[0][1],
            'loss_c_mse_std': loss_c_info[0][2],
            'loss_c_kl': loss_c_info[1],
            'loss_c_cosine': loss_c_info[2]
        },
                           global_step=it_d)

        if epoch % 100 == 0:
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row])) * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.png' % (epoch),
                                         nrow=n_row)  # nrow=3
            heatmap = torch.cat((heatmap_1, heatmap_2))
            cam = torch.cat((cam_1, cam_2))
            grads = torch.cat((grad_1, grad_2))
            grads = grads.data.cpu().numpy()  # [n,c,h,w]
            grads -= np.max(np.min(grads), 0)
            grads /= np.max(grads)
            torchvision.utils.save_image(torch.tensor(heatmap),
                                         resultPath_grad_cam +
                                         '/heatmap_%d.png' % (epoch),
                                         nrow=n_row)
            torchvision.utils.save_image(torch.tensor(cam),
                                         resultPath_grad_cam + '/cam_%d.png' %
                                         (epoch),
                                         nrow=n_row)
            torchvision.utils.save_image(torch.tensor(grads),
                                         resultPath_grad_cam + '/gb_%d.png' %
                                         (epoch),
                                         nrow=n_row)
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch), file=f)
                print(
                    '[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',
                    file=f)
                print('---------ImageSpace--------', file=f)
                print('loss_mask_info: %s' % loss_mask_info, file=f)
                print('loss_grad_info: %s' % loss_grad_info, file=f)
                print('loss_imgs_info: %s' % loss_imgs_info, file=f)
                print('loss_Gcam_info: %s' % loss_Gcam_info, file=f)
                print('---------LatentSpace--------', file=f)
                print('loss_w_info: %s' % loss_w_info, file=f)
                print('loss_c_info: %s' % loss_c_info, file=f)
            if epoch % 5000 == 0:
                torch.save(E.state_dict(),
                           resultPath1_2 + '/E_model_ep%d.pth' % epoch)
Code Example #3
File: train_mapping.py  Project: disanda/EAE
def train(avg_tensor=None, coefs=0):
    Gs = Generator(startf=16,
                   maxf=512,
                   layer_count=9,
                   latent_size=512,
                   channels=3)
    Gs.load_state_dict(torch.load('./pre-model/Gs_dict.pth'))
    Gm = Mapping(num_layers=18,
                 mapping_layers=8,
                 latent_size=512,
                 dlatent_size=512,
                 mapping_fmaps=512)
    Gm.load_state_dict(torch.load('./pre-model/Gm_dict.pth'))
    Gm.buffer1 = avg_tensor

    Gm1 = Mapping3()
    Gm1.load_state_dict(
        torch.load(
            '/_yucheng/myStyle/myStyle-v1/result/Gm_1&2_V10_6/models/Gm1_model_ep85000.pth'
        ))
    Gm2 = Mapping4()
    #Gm1.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/result/Gm_1&2_V10_3/models/Gm1_model_ep10000.pth'))
    #Gm2 = Mapping2(num_layers=18, mapping_layers=8, latent_size=512, inverse=True)
    #Gm1.load_state_dict(torch.load('./pre-model/Gm1.pth'))
    #Gm2.load_state_dict(torch.load('./pre-model/Gm2.pth'))
    E = BE.BE()
    E.load_state_dict(torch.load(
        '/_yucheng/myStyle/myStyle-v1/result/EB_V10_blob_mse/models/E_model_ep15000.pth'
    ),
                      strict=False)

    Gs.cuda()
    E.cuda()
    Gm.cuda()
    Gm1.cuda()
    Gm2.cuda()

    #Gm_optimizer = LREQAdam([{'params': Gm1.parameters()},], lr=0.0015, betas=(0.0, 0.99), weight_decay=0)
    Gm_optimizer = LREQAdam([
        {
            'params': Gm1.parameters()
        },
        {
            'params': Gm2.parameters()
        },
    ],
                            lr=0.0015,
                            betas=(0.0, 0.99),
                            weight_decay=0)

    loss_mse = torch.nn.MSELoss()
    loss_kl = torch.nn.KLDivLoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')

    batch_size = 3
    for epoch in range(100000):
        set_seed(epoch % 20000)
        z = torch.randn(batch_size, 512).to('cuda')  #[32, 512]
        w1 = Gm(z, coefs_m=coefs)  #[batch_size,18,512]
        imgs1 = Gs.forward(w1, 8)
        const2, w2 = E(imgs1.cuda())
        z2 = Gm2(w2)
        w3 = Gm1(z2)
        imgs2 = Gs.forward(w2, 8)
        imgs3 = Gs.forward(w3, 8)
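        # chain: z --Gm--> w1 --Gs--> imgs1 --E--> w2 --Gm2--> z2 --Gm1--> w3; Gm1 and Gm2 are trained so that (w3, z2) match (w2, z)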
        #loss1
        Gm_optimizer.zero_grad()
        #Gm1_optimizer.zero_grad()
        loss_m1_mse = loss_mse(w2, w3)
        loss_m1_mse_mean = loss_mse(w2.mean(), w3.mean())
        loss_m1_mse_std = loss_mse(w2.std(), w3.std())

        y1_w, y2_w = torch.nn.functional.softmax(
            w2), torch.nn.functional.softmax(w3)
        loss_kl_w = loss_kl(torch.log(y2_w), y1_w)  #D_kl(True=y1_w||Fake=y2_w)
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),
                                torch.full_like(loss_kl_w, 0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),
                                torch.full_like(loss_kl_w, 1), loss_kl_w)

        loss_1 = loss_m1_mse + loss_m1_mse_mean + loss_m1_mse_std + loss_kl_w

        #loss2
        loss_m2_mse = loss_mse(z, z2)
        loss_m2_mse_mean = loss_mse(z.mean(), z2.mean())
        loss_m2_mse_std = loss_mse(z.std(), z2.std())

        y1_z, y2_z = torch.nn.functional.softmax(
            z), torch.nn.functional.softmax(z2)
        loss_kl_z = loss_kl(torch.log(y2_z), y1_z)  #D_kl(True=y1_z||Fake=y2_z)
        loss_kl_z = torch.where(torch.isnan(loss_kl_z),
                                torch.full_like(loss_kl_z, 0), loss_kl_z)
        loss_kl_z = torch.where(torch.isinf(loss_kl_z),
                                torch.full_like(loss_kl_z, 1), loss_kl_z)

        loss_2 = loss_m2_mse + loss_m2_mse_mean + loss_m2_mse_std + loss_kl_z

        #loss3
        loss_m1_mse_img = loss_mse(imgs2, imgs3)

        imgs2_ = F.avg_pool2d(imgs2, 2, 2)
        imgs3_ = F.avg_pool2d(imgs3, 2, 2)

        loss_img_lpips = loss_lpips(imgs2_, imgs3_).mean()

        loss_3 = loss_img_lpips + loss_m1_mse_img * 3

        loss_all = loss_2 + loss_1 * 3 + loss_3 * 15  # z -> w -> x
        loss_all.backward()
        Gm_optimizer.step()

        print('i_' + str(epoch) + '--loss_all__:' + str(loss_all.item()) +
              '--loss_m1_mse:' + str(loss_m1_mse.item()) +
              '--loss_m1_mse_mean:' + str(loss_m1_mse_mean.item()) +
              '--loss_m1_mse_std:' + str(loss_m1_mse_std.item()) +
              '--loss_kl_w:' + str(loss_kl_w.item()))
        print('--loss_m2_mse:' + str(loss_m2_mse.item()) +
              '--loss_m2_mse_mean:' + str(loss_m2_mse_mean.item()) +
              '--loss_m2_mse_std:' + str(loss_m2_mse_std.item()) +
              '--loss_kl_z:' + str(loss_kl_z.item()))
        #print('--loss_m1_mse_img:'+str(loss_m1_mse_img.item())+'--loss_m2_mse_img:'+str(loss_m2_mse_img.item())+'--loss_m3_mse_img:'+str(loss_m3_mse_img.item()))
        print('loss_img_lpips' + str(loss_img_lpips.item()) +
              '--loss_m1_mse_img:' + str(loss_m1_mse_img.item()))
        print('-')

        if epoch % 100 == 0:
            with torch.no_grad():  # assemble the preview images here without tracking gradients
                test_img = torch.cat((imgs1[:3], imgs2[:3]))
                test_img = torch.cat((test_img, imgs3[:3]))
                test_img = test_img * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.jpg' % (epoch),
                                         nrow=3)  # nrow=3
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch) + '--loss_all__:' +
                      str(loss_all.item()) + '--loss_m1_mse:' +
                      str(loss_m1_mse.item()) + '--loss_m1_mse_mean:' +
                      str(loss_m1_mse_mean.item()) + '--loss_m1_mse_std:' +
                      str(loss_m1_mse_std.item()) + '--loss_kl_w:' +
                      str(loss_kl_w.item()),
                      file=f)
                print('--loss_m2_mse:' + str(loss_m2_mse.item()) +
                      '--loss_m2_mse_mean:' + str(loss_m2_mse_mean.item()) +
                      '--loss_m2_mse_std:' + str(loss_m2_mse_std.item()) +
                      '--loss_kl_z:' + str(loss_kl_z.item()),
                      file=f)
                #print('--loss_m1_mse_img:'+str(loss_m1_mse_img.item())+'--loss_m2_mse_img:'+str(loss_m2_mse_img.item())+'--loss_m3_mse_img:'+str(loss_m3_mse_img.item()),file=f)
                print('loss_img_lpips' + str(loss_img_lpips.item()) +
                      '--loss_m1_mse_img:' + str(loss_m1_mse_img.item()),
                      file=f)
            if epoch % 5000 == 0:
                torch.save(Gm1.state_dict(),
                           resultPath1_2 + '/Gm1_model_ep%d.pth' % epoch)
                torch.save(Gm2.state_dict(),
                           resultPath1_2 + '/Gm2_model_ep%d.pth' % epoch)
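
Note: the KL terms above (and in the later examples) all follow the same pattern: softmax both tensors, apply KLDivLoss, then replace NaN/Inf entries so a degenerate batch cannot poison the gradient step. A standalone sketch of just that piece (not repository code; guarded_kl is an illustrative name, and the softmax dim is made explicit here, whereas the code above relies on the implicit default):

# Sketch (not repository code) of the guarded KL term used above.
import torch

def guarded_kl(w_true, w_fake):
    loss_kl = torch.nn.KLDivLoss()  # mean-reduced, as in the scripts above
    y_true = torch.nn.functional.softmax(w_true, dim=-1)
    y_fake = torch.nn.functional.softmax(w_fake, dim=-1)
    kl = loss_kl(torch.log(y_fake), y_true)  # D_kl(True=y_true || Fake=y_fake)
    kl = torch.where(torch.isnan(kl), torch.full_like(kl, 0), kl)
    kl = torch.where(torch.isinf(kl), torch.full_like(kl, 1), kl)
    return kl

if __name__ == '__main__':
    w1, w2 = torch.randn(3, 18, 512), torch.randn(3, 18, 512)
    print(guarded_kl(w1, w2).item())
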
Code Example #4
def train(cfg, logger, gpu_id=0):
    torch.cuda.set_device(gpu_id)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=3)
    model.cuda(gpu_id)
    model.train()

    if gpu_id == 0:
        model_s = Model(
            startf=cfg.MODEL.START_CHANNEL_COUNT,
            layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
            latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS,
            channels=3)
        del model_s.discriminator
        model_s.cuda(gpu_id)
        model_s.eval()
        model_s.requires_grad_(False)
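        # model_s is a smoothed copy of the generator/mapping (no discriminator); it is updated below via model_s.lerp(model, betta) and used for sampling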

    generator = model.generator
    discriminator = model.discriminator
    mapping = model.mapping
    dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    generator_optimizer = LREQAdam([
        {'params': generator.parameters()},
        {'params': mapping.parameters()}], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {'params': discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={'generator': generator_optimizer,'discriminator': discriminator_optimizer},
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS, # []
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE, # 0.1
                                 reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES) # 0.002

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }

    if gpu_id == 0:
        model_dict['generator_s'] = model_s.generator
        model_dict['mapping_s'] = model_s.mapping

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {
                                    'generator_optimizer': generator_optimizer,
                                    'discriminator_optimizer': discriminator_optimizer,
                                    'scheduler': scheduler,
                                    'tracker': tracker
                                },
                                logger=logger,
                                save=gpu_id == 0)

    checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    layer_to_resolution = generator.layer_to_resolution #[4, 8, 16, 32, 64, 128]

    dataset = TFRecordsDataset(cfg, logger, buffer_size_mb=1024)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg, logger, gpu_num=1, dataset_size=len(dataset)) # an object that returns the various training parameters for the current LOD

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [generator_optimizer, discriminator_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
                                                                lod2batch.get_batch_size(),
                                                                lod2batch.get_per_GPU_batch_size(),
                                                                lod2batch.lod,
                                                                2 ** lod2batch.get_lod_power2(),
                                                                2 ** lod2batch.get_lod_power2(),
                                                                lod2batch.get_blend_factor(),
                                                                len(dataset)))

        dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), gpu_id) # the dataset is split into batches, each holding n images

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod) # this call raises an error!

        need_permute = False

        with torch.autograd.profiler.profile(use_cuda=True, enabled=False) as prof:
            for x_orig in tqdm(batches): # x_orig:[-1,c,w,h]

                with torch.no_grad():
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    if lod2batch.in_transition:
                        needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)
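                        # x is now a blend of the current resolution and an upsampled copy of the previous resolution, so the new layer fades in smoothly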
                x.requires_grad = True

                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()

                if gpu_id == 0:
                    betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if gpu_id == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate")
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, discriminator_optimizer, generator_optimizer)
        scheduler.step()

        if gpu_id == 0:
            checkpointer.save("model_tmp")
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg, discriminator_optimizer, generator_optimizer)

    logger.info("Training finish!... save training results")
    if gpu_id == 0:
        checkpointer.save("model_final").wait()
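
Note: during a LOD transition the loop above feeds the discriminator a blend of the current and the previous resolution. A minimal sketch of that blending step (not repository code; blend_resolution is an illustrative name, and needed_resolution/blend_factor stand in for the values LODDriver provides):

# Sketch (not repository code) of the progressive-growing blend in Code Example #4.
import torch
import torch.nn.functional as F

def blend_resolution(x_orig, needed_resolution, blend_factor):
    x_prev = F.avg_pool2d(x_orig, 2, 2)                   # half-resolution copy
    x_prev_2x = F.interpolate(x_prev, needed_resolution)  # back to the target size
    return x_orig * blend_factor + x_prev_2x * (1.0 - blend_factor)

if __name__ == '__main__':
    x = torch.rand(4, 3, 64, 64)
    print(blend_resolution(x, 64, 0.3).shape)  # torch.Size([4, 3, 64, 64])
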
Code Example #5
File: EAE_V2_may_18.py  Project: disanda/EAE
def train(avg_tensor=None, coefs=0, tensor_writer=None):
    Gs = Generator(startf=64,
                   maxf=512,
                   layer_count=7,
                   latent_size=512,
                   channels=3)  # 32->512 layer_count=8 / 64->256 layer_count=7
    Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))
    Gm = Mapping(num_layers=14,
                 mapping_layers=8,
                 latent_size=512,
                 dlatent_size=512,
                 mapping_fmaps=512)  #num_layers: 14->256 / 16->512 / 18->1024
    Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth'))
    Gm.buffer1 = avg_tensor
    E = BE.BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    E.load_state_dict(
        torch.load(
            '/_yucheng/myStyle/myStyle-v1/EAE-car-cat/result/EB_cat_cosine_v2/E_model_ep80000.pth'
        ))
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const
    writer = tensor_writer

    E_optimizer = LREQAdam([
        {
            'params': E.parameters()
        },
    ],
                           lr=0.0015,
                           betas=(0.0, 0.99),
                           weight_decay=0)
    loss_mse = torch.nn.MSELoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
    loss_kl = torch.nn.KLDivLoss()
    ssim_loss = pytorch_ssim.SSIM()

    batch_size = 3
    const1 = const_.repeat(batch_size, 1, 1, 1)

    vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
    final_layer = None
    for name, m in vgg16.named_modules():
        if isinstance(m, nn.Conv2d):
            final_layer = name
    grad_cam_plus_plus = GradCamPlusPlus(vgg16, final_layer)
    gbp = GuidedBackPropagation(vgg16)

    it_d = 0
    for epoch in range(0, 250001):
        set_seed(epoch % 30000)
        latents = torch.randn(batch_size, 512)  #[32, 512]
        with torch.no_grad():  # generate the target images and latents here without tracking gradients
            w1 = Gm(latents, coefs_m=coefs).to('cuda')  #[batch_size,18,512]
            imgs1 = Gs.forward(w1, 6)  # 7->512 / 6->256

        const2, w2 = E(imgs1.cuda())

        imgs2 = Gs.forward(w2, 6)

        E_optimizer.zero_grad()

        #Image Space
        mask_1 = grad_cam_plus_plus(imgs1, None)  #[c,1,h,w]
        mask_2 = grad_cam_plus_plus(imgs2, None)
        #imgs1.retain_grad()
        #imgs2.retain_grad()
        imgs1_ = imgs1.detach().clone()
        imgs1_.requires_grad = True
        imgs2_ = imgs2.detach().clone()
        imgs2_.requires_grad = True
        grad1 = gbp(imgs1_)  # [n,c,h,w]
        grad2 = gbp(imgs2_)

        #Mask_Cam
        mask_1 = mask_1.cuda().float()
        mask_1.requires_grad = True
        mask_2 = mask_2.cuda().float()
        mask_2.requires_grad = True
        loss_mask_mse_1 = loss_mse(mask_1, mask_2)
        loss_mask_mse_2 = loss_mse(mask_1.mean(), mask_2.mean())
        loss_mask_mse_3 = loss_mse(mask_1.std(), mask_2.std())
        loss_mask_mse = loss_mask_mse_1 + loss_mask_mse_2 + loss_mask_mse_3

        ssim_value = pytorch_ssim.ssim(mask_1,
                                       mask_2)  # while ssim_value<0.999:
        loss_mask_ssim = 1 - ssim_loss(mask_1, mask_2)

        loss_mask_lpips = loss_lpips(mask_1, mask_2).mean()

        mask1_kl, mask2_kl = torch.nn.functional.softmax(
            mask_1), torch.nn.functional.softmax(mask_2)
        loss_kl_mask = loss_kl(torch.log(mask2_kl),
                               mask1_kl)  #D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_mask = torch.where(torch.isnan(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 0),
                                   loss_kl_mask)
        loss_kl_mask = torch.where(torch.isinf(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 1),
                                   loss_kl_mask)

        mask1_cos = mask_1.view(-1)
        mask2_cos = mask_2.view(-1)
        loss_cosine_w = 1 - mask1_cos.dot(mask2_cos) / (
            torch.sqrt(mask1_cos.dot(mask1_cos)) *
            torch.sqrt(mask2_cos.dot(mask2_cos)))  # range [-1,1]; -1: opposite directions, 1: same direction

        loss_mask = loss_mask_mse + loss_mask_ssim + loss_mask_lpips + loss_kl_mask + loss_cosine_w
        E_optimizer.zero_grad()
        loss_mask.backward(retain_graph=True)
        E_optimizer.step()

        #Grad
        grad1 = grad1.cuda().float()
        grad1.requires_grad = True
        grad2 = grad2.cuda().float()
        grad2.requires_grad = True
        loss_grad_mse = loss_mse(grad1, grad2)
        E_optimizer.zero_grad()
        loss_grad_mse.backward(retain_graph=True)
        E_optimizer.step()

        ssim_value = pytorch_ssim.ssim(grad1, grad2)  # while ssim_value<0.999:
        loss_grad_ssim = 1 - ssim_loss(grad1, grad2)
        E_optimizer.zero_grad()
        loss_grad_ssim.backward(retain_graph=True)
        E_optimizer.step()

        loss_grad_lpips = loss_lpips(grad1, grad2).mean()
        E_optimizer.zero_grad()
        loss_grad_lpips.backward(retain_graph=True)
        E_optimizer.step()

        grad1 = grad1.cuda().float()
        grad1.requires_grad = True
        grad2 = grad2.cuda().float()
        grad2.requires_grad = True
        loss_grad_mse_1 = loss_mse(grad1, grad2)
        loss_grad_mse_2 = loss_mse(grad1.mean(), grad2.mean())
        loss_grad_mse_3 = loss_mse(mask_1.std(), mask_2.std())
        loss_mask_mse = loss_mask_mse_1 + loss_mask_mse_2 + loss_mask_mse_3

        ssim_value = pytorch_ssim.ssim(mask_1,
                                       mask_2)  # while ssim_value<0.999:
        loss_mask_ssim = 1 - ssim_loss(mask_1, mask_2)

        loss_mask_lpips = loss_lpips(mask_1, mask_2).mean()

        mask1_kl, mask2_kl = torch.nn.functional.softmax(
            mask_1), torch.nn.functional.softmax(mask_2)
        loss_kl_mask = loss_kl(torch.log(mask2_kl),
                               mask1_kl)  #D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_mask = torch.where(torch.isnan(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 0),
                                   loss_kl_mask)
        loss_kl_mask = torch.where(torch.isinf(loss_kl_mask),
                                   torch.full_like(loss_kl_mask, 1),
                                   loss_kl_mask)

        mask1_cos = mask_1.view(-1)
        mask2_cos = mask_2.view(-1)
        loss_cosine_w = 1 - mask1_cos.dot(mask2_cos) / (
            torch.sqrt(mask1_cos.dot(mask1_cos)) *
            torch.sqrt(mask2_cos.dot(mask2_cos)))  # range [-1,1]; -1: opposite directions, 1: same direction

        loss_mask = loss_mask_mse + loss_mask_ssim + loss_mask_lpips + loss_kl_mask + loss_cosine_w
        E_optimizer.zero_grad()
        loss_mask.backward(retain_graph=True)
        E_optimizer.step()

        #Image
        loss_img_mse = loss_mse(imgs1, imgs2)
        E_optimizer.zero_grad()
        loss_img_mse.backward(retain_graph=True)
        E_optimizer.step()

        ssim_value = pytorch_ssim.ssim(imgs1, imgs2)  # while ssim_value<0.999:
        loss_img_ssim = 1 - ssim_loss(imgs1, imgs2)
        E_optimizer.zero_grad()
        loss_img_ssim.backward(retain_graph=True)
        E_optimizer.step()

        loss_img_lpips = loss_lpips(imgs1, imgs2).mean()
        E_optimizer.zero_grad()
        loss_img_lpips.backward(retain_graph=True)
        E_optimizer.step()

        #Latent Space
        # W
        loss_w = loss_mse(w1, w2)
        loss_w_m = loss_mse(w1.mean(), w2.mean())  # early in training this jumps between ~10 and ~0.0001
        loss_w_s = loss_mse(w1.std(), w2.std())  # later in training it also swings between large and small values

        w1_kl, w2_kl = torch.nn.functional.softmax(
            w1), torch.nn.functional.softmax(w2)
        loss_kl_w = loss_kl(torch.log(w2_kl),
                            w1_kl)  #D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),
                                torch.full_like(loss_kl_w, 0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),
                                torch.full_like(loss_kl_w, 1), loss_kl_w)

        w1_cos = w1.view(-1)
        w2_cos = w2.view(-1)
        loss_cosine_w = 1 - w1_cos.dot(w2_cos) / (
            torch.sqrt(w1_cos.dot(w1_cos)) * torch.sqrt(w2_cos.dot(w2_cos))
        )  # range [-1,1]; -1: opposite directions, 1: same direction
        # C
        loss_c = loss_mse(
            const1, const2)  # without this const term the gradient cannot descend quickly at the start and training may fail to converge; after scaling this penalty by 0.1, results improve substantially
        loss_c_m = loss_mse(const1.mean(), const2.mean())
        loss_c_s = loss_mse(const1.std(), const2.std())

        y1, y2 = torch.nn.functional.softmax(
            const1), torch.nn.functional.softmax(const2)
        loss_kl_c = loss_kl(torch.log(y2), y1)
        loss_kl_c = torch.where(torch.isnan(loss_kl_c),
                                torch.full_like(loss_kl_c, 0), loss_kl_c)
        loss_kl_c = torch.where(torch.isinf(loss_kl_c),
                                torch.full_like(loss_kl_c, 1), loss_kl_c)

        c_cos1 = const1.view(-1)
        c_cos2 = const2.view(-1)
        loss_cosine_c = 1 - c_cos1.dot(c_cos2) / (
            torch.sqrt(c_cos1.dot(c_cos1)) * torch.sqrt(c_cos2.dot(c_cos2)))


        loss_ls_all = loss_w+loss_w_m+loss_w_s+loss_kl_w+loss_cosine_w+\
                        loss_c+loss_c_m+loss_c_s+loss_kl_c+loss_cosine_c
        loss_ls_all.backward(retain_graph=True)
        E_optimizer.step()

        print('i_' + str(epoch))
        print('---------ImageSpace--------')
        print('loss_mask_mse:' + str(loss_mask_mse.item()) +
              '--loss_mask_ssim:' + str(loss_mask_ssim.item()) +
              '--loss_mask_lpips:' + str(loss_mask_lpips.item()))
        print('loss_grad_mse:' + str(loss_grad_mse.item()) +
              '--loss_grad_ssim:' + str(loss_grad_ssim.item()) +
              '--loss_grad_lpips:' + str(loss_grad_lpips.item()))
        print('loss_img_mse:' + str(loss_img_mse.item()) + '--loss_img_ssim:' +
              str(loss_img_ssim.item()) + '--loss_img_lpips:' +
              str(loss_img_lpips.item()))
        print('---------LatentSpace--------')
        print('loss_w:' + str(loss_w.item()) + '--loss_w_m:' +
              str(loss_w_m.item()) + '--loss_w_s:' + str(loss_w_s.item()))
        print('loss_kl_w:' + str(loss_kl_w.item()) + '--loss_cosine_w:' +
              str(loss_cosine_w.item()))
        print('loss_c:' + str(loss_c.item()) + '--loss_c_m:' +
              str(loss_c_m.item()) + '--loss_c_s:' + str(loss_c_s.item()))
        print('loss_kl_c:' + str(loss_kl_c.item()) + '--loss_cosine_c:' +
              str(loss_cosine_c.item()))

        it_d += 1
        writer.add_scalar('loss_mask_mse', loss_mask_mse, global_step=it_d)
        writer.add_scalar('loss_mask_ssim', loss_mask_ssim, global_step=it_d)
        writer.add_scalar('loss_mask_lpips', loss_mask_lpips, global_step=it_d)
        writer.add_scalar('loss_grad_mse', loss_grad_mse, global_step=it_d)
        writer.add_scalar('loss_grad_ssim', loss_grad_ssim, global_step=it_d)
        writer.add_scalar('loss_grad_lpips', loss_grad_lpips, global_step=it_d)
        writer.add_scalar('loss_img_mse', loss_img_mse, global_step=it_d)
        writer.add_scalar('loss_img_ssim', loss_img_ssim, global_step=it_d)
        writer.add_scalar('loss_img_lpips', loss_img_lpips, global_step=it_d)
        writer.add_scalar('loss_w', loss_w, global_step=it_d)
        writer.add_scalar('loss_w_m', loss_w_m, global_step=it_d)
        writer.add_scalar('loss_w_s', loss_w_s, global_step=it_d)
        writer.add_scalar('loss_kl_w', loss_kl_w, global_step=it_d)
        writer.add_scalar('loss_cosine_w', loss_cosine_w, global_step=it_d)
        writer.add_scalar('loss_c', loss_c, global_step=it_d)
        writer.add_scalar('loss_c_m', loss_c_m, global_step=it_d)
        writer.add_scalar('loss_c_s', loss_c_s, global_step=it_d)
        writer.add_scalar('loss_kl_c', loss_kl_c, global_step=it_d)
        writer.add_scalar('loss_cosine_c', loss_cosine_c, global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_mse': loss_mask_mse,
            'loss_grad_mse': loss_grad_mse,
            'loss_img_mse': loss_img_mse
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_ssim': loss_mask_ssim,
            'loss_grad_ssim': loss_grad_ssim,
            'loss_img_ssim': loss_img_ssim
        },
                           global_step=it_d)
        writer.add_scalars('Image_Space', {
            'loss_mask_lpips': loss_mask_lpips,
            'loss_grad_lpips': loss_grad_lpips,
            'loss_img_lpips': loss_img_lpips
        },
                           global_step=it_d)
        writer.add_scalars('Latent Space W', {
            'loss_w': loss_w,
            'loss_w_m': loss_w_m,
            'loss_w_s': loss_w_s,
            'loss_kl_w': loss_kl_w,
            'loss_cosine_w': loss_cosine_w
        },
                           global_step=it_d)
        writer.add_scalars('Latent Space C', {
            'loss_c': loss_c,
            'loss_c_m': loss_c_m,
            'loss_c_s': loss_c_s,
            'loss_kl_c': loss_kl_c,
            'loss_cosine_c': loss_cosine_c
        },
                           global_step=it_d)

        if epoch % 100 == 0:
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row], imgs2[:n_row])) * 0.5 + 0.5
            torchvision.utils.save_image(test_img,
                                         resultPath1_1 + '/ep%d.png' % (epoch),
                                         nrow=n_row)  # nrow=3
            heatmap1, cam1 = mask2cam(mask_1, imgs1)
            heatmap2, cam2 = mask2cam(mask_2, imgs2)
            heatmap = torch.cat((heatmap1, heatmap2))
            cam = torch.cat((cam1, cam2))
            grads = torch.cat((grad1, grad2))
            grads = grads.data.cpu().numpy()  # [n,c,h,w]
            grads -= np.max(np.min(grads), 0)
            grads /= np.max(grads)
            torchvision.utils.save_image(
                torch.tensor(heatmap),
                resultPath_grad_cam + '/heatmap_%d.png' % (epoch))
            torchvision.utils.save_image(
                torch.tensor(cam),
                resultPath_grad_cam + '/cam_%d.png' % (epoch))
            torchvision.utils.save_image(
                torch.tensor(grads),
                resultPath_grad_cam + '/gb_%d.png' % (epoch))
            with open(resultPath + '/Loss.txt', 'a+') as f:
                print('i_' + str(epoch), file=f)
                print('---------ImageSpace--------', file=f)
                print('loss_mask_mse:' + str(loss_mask_mse.item()) +
                      '--loss_mask_ssim:' + str(loss_mask_ssim.item()) +
                      '--loss_mask_lpips:' + str(loss_mask_lpips.item()),
                      file=f)
                print('loss_grad_mse:' + str(loss_grad_mse.item()) +
                      '--loss_grad_ssim:' + str(loss_grad_ssim.item()) +
                      '--loss_grad_lpips:' + str(loss_grad_lpips.item()),
                      file=f)
                print('loss_img_mse:' + str(loss_img_mse.item()) +
                      '--loss_img_ssim:' + str(loss_img_ssim.item()) +
                      '--loss_img_lpips:' + str(loss_img_lpips.item()),
                      file=f)
                print('---------LatentSpace--------', file=f)
                print('loss_w:' + str(loss_w.item()) + '--loss_w_m:' +
                      str(loss_w_m.item()) + '--loss_w_s:' +
                      str(loss_w_s.item()),
                      file=f)
                print('loss_kl_w:' + str(loss_kl_w.item()) +
                      '--loss_cosine_w:' + str(loss_cosine_w.item()),
                      file=f)
                print('loss_c:' + str(loss_c.item()) + '--loss_c_m:' +
                      str(loss_c_m.item()) + '--loss_c_s:' +
                      str(loss_c_s.item()),
                      file=f)
                print('loss_kl_c:' + str(loss_kl_c.item()) +
                      '--loss_cosine_c:' + str(loss_cosine_c.item()),
                      file=f)
            if epoch % 5000 == 0:
                torch.save(E.state_dict(),
                           resultPath1_2 + '/E_model_ep%d.pth' % epoch)
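Note: the Grad-CAM block above min-max normalizes the pooled gradients before writing them with torchvision.utils.save_image. A minimal standalone sketch of that normalization is below; the function name normalize_grads and the eps guard are assumptions for illustration, not part of the original script.

import numpy as np

def normalize_grads(grads, eps=1e-8):
    # Hypothetical helper: min-max normalize gradient maps of shape [n, c, h, w]
    # to the [0, 1] range so save_image can render them, mirroring the block above.
    grads = grads - np.min(grads)
    return grads / (np.max(grads) + eps)  # eps (an assumption) avoids division by zero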
Code example #6
File: train_be.py  Project: disanda/EAE
def train(avg_tensor = None, coefs=0):
	Gs = Generator(startf=16, maxf=512, layer_count=9, latent_size=512, channels=3)
	Gs.load_state_dict(torch.load('./pre-model/Gs_dict.pth'))
	Gm = Mapping(num_layers=18, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512)
	Gm.load_state_dict(torch.load('./pre-model/Gm_dict.pth')) 
	#Gm.requires_grad_(False)
	#Gs.requires_grad_(False)
	Gm.buffer1 = avg_tensor
	E = BE.BE()
	E.load_state_dict(torch.load('/_yucheng/myStyle/myStyle-v1/result/EB_V8_newE_noEw_Ebias_2/models/E_model_ep65000.pth'),strict=False)
	# model_dict = E.state_dict()
	# pretrained_dict = torch.load('/_yucheng/myStyle/myStyle-v1/result/EB_V6_3ImgLoss_Res0618_truncW_noUpgradeW/models/E_model_ep15000.pth')
	# for k,v in model_dict.items():
	# 	if 'decode_block.0.noise_weight_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.1.noise_weight_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.2.noise_weight_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.3.noise_weight_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.4.noise_weight_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.0.bias_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.1.bias_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.2.bias_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.3.bias_2' in k:
	# 		pretrained_dict.pop(k)
	# 	if 'decode_block.4.bias_2' in k:
	# 		pretrained_dict.pop(k)

	# model_dict.update(pretrained_dict)
	# E.load_state_dict(model_dict,strict=False)

	Gs.cuda()
	#Gm.cuda()
	E.cuda()
	const_ = Gs.const

	E_optimizer = LREQAdam([{'params': E.parameters()},], lr=0.0015, betas=(0.0, 0.99), weight_decay=0)

	loss_all=0
	loss_mse = torch.nn.MSELoss()
	loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
	loss_kl = torch.nn.KLDivLoss()

	batch_size = 3
	const1 = const_.repeat(batch_size,1,1,1)
	for epoch in range(0,250001):
		set_seed(epoch%25000)
		latents = torch.randn(batch_size, 512) #[32, 512]
		w1 = Gm(latents,coefs_m=coefs).to('cuda') #[batch_size,18,512]
		with torch.no_grad(): # the images and intermediate variables need to be generated here
			imgs1 = Gs.forward(w1,8)

		const2,w2 = E(imgs1.cuda())

		imgs2=Gs.forward(w2,8)

		E_optimizer.zero_grad()
#loss1
		#loss_img_mse = loss_mse(imgs1,imgs2)
		loss_img_mse_c1 = loss_mse(imgs1[:,0],imgs2[:,0])
		loss_img_mse_c2 = loss_mse(imgs1[:,1],imgs2[:,1])
		loss_img_mse_c3 = loss_mse(imgs1[:,2],imgs2[:,2])
		loss_img_mse = max(loss_img_mse_c1,loss_img_mse_c2,loss_img_mse_c3)

		imgs1_ = F.avg_pool2d(imgs1,2,2)
		imgs2_ = F.avg_pool2d(imgs2,2,2)
		imgs1__ = F.avg_pool2d(imgs1_,2,2)
		imgs2__ = F.avg_pool2d(imgs2_,2,2)
		loss_img_lpips = loss_lpips(imgs1__,imgs2__).mean()

		y1_imgs, y2_imgs = torch.nn.functional.softmax(imgs1_),torch.nn.functional.softmax(imgs2_)
		loss_kl_img = loss_kl(torch.log(y2_imgs),y1_imgs) #D_kl(True=y1_imgs||Fake=y2_imgs)
		loss_kl_img = torch.where(torch.isnan(loss_kl_img),torch.full_like(loss_kl_img,0), loss_kl_img)
		loss_kl_img = torch.where(torch.isinf(loss_kl_img),torch.full_like(loss_kl_img,1), loss_kl_img)

		loss_1 = 13*loss_img_mse+ 5*loss_img_lpips + loss_kl_img
		loss_1.backward(retain_graph=True)
		E_optimizer.step()
#loss2
		imgs_column1 = imgs1[:,:,:,128:-128]
		imgs_column2 = imgs2[:,:,:,128:-128]
		loss_img_mse_column = loss_mse(imgs_column1,imgs_column2)

		imgs1_column_down = F.avg_pool2d(imgs_column1,2,2)
		imgs2_column_down = F.avg_pool2d(imgs_column2,2,2)
		loss_img_lpips_column = loss_lpips(imgs1_column_down,imgs2_column_down).mean()

		loss_2 = 19*loss_img_mse_column +7*loss_img_lpips_column
		loss_2.backward(retain_graph=True)
		E_optimizer.step()
#loss3
		imgs_center1 = imgs1[:,:,128:640,256:-256]
		imgs_center2 = imgs2[:,:,128:640,256:-256]
		loss_img_mse_center = loss_mse(imgs_center1,imgs_center2)

		imgs1_c_down = F.avg_pool2d(imgs_center1,2,2)
		imgs2_c_down = F.avg_pool2d(imgs_center2,2,2)
		loss_img_lpips_center = loss_lpips(imgs1_c_down,imgs2_c_down).mean()

		imgs_blob1 = imgs1[:,:,924:,924:]
		imgs_blob2 = imgs2[:,:,924:,924:]
		loss_img_mse_blob = loss_mse(imgs_blob1,imgs_blob2)

		loss_3 = 23*loss_img_mse_center +11*loss_img_lpips_center + loss_img_mse_blob
		loss_3.backward(retain_graph=True)
		#loss_x = loss_1+loss_2+loss_3
		#loss_x.backward(retain_graph=True)
		E_optimizer.step()
#loss4
		loss_c = loss_mse(const1,const2) # without this const loss the gradient cannot drop quickly at the start and training may well fail to converge; scaling this penalty by 0.1 improves results dramatically!
		loss_c_m = loss_mse(const1.mean(),const2.mean())
		loss_c_s = loss_mse(const1.std(),const2.std())

		loss_w = loss_mse(w1,w2)
		loss_w_m = loss_mse(w1.mean(),w2.mean()) # early in training this swings between ~10 and ~0.0001
		loss_w_s = loss_mse(w1.std(),w2.std()) # later in training it still swings between large and small values

		y1, y2 = torch.nn.functional.softmax(const1),torch.nn.functional.softmax(const2)
		loss_kl_c = loss_kl(torch.log(y2),y1)
		loss_kl_c = torch.where(torch.isnan(loss_kl_c),torch.full_like(loss_kl_c,0), loss_kl_c)
		loss_kl_c = torch.where(torch.isinf(loss_kl_c),torch.full_like(loss_kl_c,1), loss_kl_c)

		w1_kl, w2_kl = torch.nn.functional.softmax(w1),torch.nn.functional.softmax(w2)
		loss_kl_w = loss_kl(torch.log(w2_kl),w1_kl) #D_kl(True=y1_imgs||Fake=y2_imgs)
		loss_kl_w = torch.where(torch.isnan(loss_kl_w),torch.full_like(loss_kl_w,0), loss_kl_w)
		loss_kl_w = torch.where(torch.isinf(loss_kl_w),torch.full_like(loss_kl_w,1), loss_kl_w)

		loss_4 = 0.02*loss_c+0.03*loss_c_m+0.03*loss_c_s+0.02*loss_w+0.03*loss_w_m+0.03*loss_w_s+ loss_kl_w  + loss_kl_c
		loss_4.backward(retain_graph=True)
		E_optimizer.step()

		loss_all =  loss_1  + loss_4 + loss_2 + loss_3
		print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item()))
		print('loss_img_mse_column:'+str(loss_img_mse_column.item())+'loss_img_lpips_column:'+str(loss_img_lpips_column.item())\
			+'--loss_img_mse_center:'+str(loss_img_mse_center.item())+'--loss_lpips_center:'+str(loss_img_lpips_center.item()))
		print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())\
			+'--loss_c:'+str(loss_c.item())+'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_img_blob_mse:'+str(loss_img_mse_blob.item()))
		print('-')

		if epoch % 100 == 0:
			test_img = torch.cat((imgs1[:3],imgs2[:3]))*0.5+0.5
			torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d.jpg'%(epoch),nrow=3) # nrow=3
			with open(resultPath+'/Loss.txt', 'a+') as f:
				print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item()),file=f)
				print('loss_img_mse_column:'+str(loss_img_mse_column.item())+'loss_img_lpips_column:'+str(loss_img_lpips_column.item())\
				+'--loss_img_mse_center:'+str(loss_img_mse_center.item())+'--loss_lpips_center:'+str(loss_img_lpips_center.item()),file=f)
				print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())+'--loss_c:'+str(loss_c.item())\
				+'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_img_blob_mse:'+str(loss_img_mse_blob.item()),file=f)
			if epoch % 5000 == 0:
				torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d.pth'%epoch)
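The KL terms in these training loops all follow the same pattern: softmax both tensors, feed log-probabilities of the reconstruction into torch.nn.KLDivLoss, then replace NaN with 0 and Inf with 1 so a degenerate batch cannot derail the optimizer. A minimal sketch of that pattern as a reusable helper follows; the name safe_kl_div, the explicit dim argument, and the eps term are assumptions rather than part of the original code (the scripts above rely on softmax's implicit dim).

import torch

_kl = torch.nn.KLDivLoss()  # same default ('mean') reduction as in the scripts above

def safe_kl_div(fake, real, dim=-1, eps=1e-12):
    # Hypothetical helper mirroring the KL blocks above: D_kl(real || fake)
    # on softmax distributions, with NaN -> 0 and Inf -> 1 replacement.
    p_real = torch.softmax(real, dim=dim)
    p_fake = torch.softmax(fake, dim=dim)
    loss = _kl(torch.log(p_fake + eps), p_real)
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    loss = torch.where(torch.isinf(loss), torch.ones_like(loss), loss)
    return loss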
Code example #7
File: train_car_cat_bedroom.py  Project: disanda/EAE
def train(avg_tensor = None, coefs=0):
    Gs = Generator(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3) # 32->512 layer_count=8 / 64->256 layer_count=7
    Gs.load_state_dict(torch.load('./pre-model/cat/cat256_Gs_dict.pth'))
    Gm = Mapping(num_layers=14, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
    Gm.load_state_dict(torch.load('./pre-model/cat/cat256_Gm_dict.pth')) 
    Gm.buffer1 = avg_tensor
    E = BE.BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
    #E.load_state_dict(torch.load('/_yucheng/myStyle/EAE/result/EB_cars_v1/models/E_model_ep135000.pth'))
    #E.load_state_dict(torch.load('/_yucheng/myStyle/EAE/result/EB_cat_v1/models/E_model_ep165000.pth'))
    Gs.cuda()
    #Gm.cuda()
    E.cuda()
    const_ = Gs.const

    E_optimizer = LREQAdam([{'params': E.parameters()},], lr=0.0015, betas=(0.0, 0.99), weight_decay=0)

    loss_all=0
    loss_mse = torch.nn.MSELoss()
    loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
    loss_kl = torch.nn.KLDivLoss()

    batch_size = 5
    const1 = const_.repeat(batch_size,1,1,1)
    for epoch in range(0,250001):
        set_seed(epoch%30000)
        latents = torch.randn(batch_size, 512) #[32, 512]
        with torch.no_grad(): # the images and intermediate variables need to be generated here
            w1 = Gm(latents,coefs_m=coefs).to('cuda') #[batch_size,18,512]
            imgs1 = Gs.forward(w1,6) # 7->512 / 6->256

        const2,w2 = E(imgs1.cuda())

        imgs2=Gs.forward(w2,6)

        E_optimizer.zero_grad()
#loss1 
        loss_img_mse = loss_mse(imgs1,imgs2)
        # loss_img_mse_c1 = loss_mse(imgs1[:,0],imgs2[:,0])
        # loss_img_mse_c2 = loss_mse(imgs1[:,1],imgs2[:,1])
        # loss_img_mse_c3 = loss_mse(imgs1[:,2],imgs2[:,2])
        # loss_img_mse = max(loss_img_mse_c1,loss_img_mse_c2,loss_img_mse_c3)

        loss_img_lpips = loss_lpips(imgs1,imgs2).mean()

        y1_imgs, y2_imgs = torch.nn.functional.softmax(imgs1),torch.nn.functional.softmax(imgs2)
        loss_kl_img = loss_kl(torch.log(y2_imgs),y1_imgs) #D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_img = torch.where(torch.isnan(loss_kl_img),torch.full_like(loss_kl_img,0), loss_kl_img)
        loss_kl_img = torch.where(torch.isinf(loss_kl_img),torch.full_like(loss_kl_img,1), loss_kl_img)

        loss_1 = 17*loss_img_mse + 5*loss_img_lpips + loss_kl_img
        loss_1.backward(retain_graph=True)
        E_optimizer.step()
#loss2 medium region
        #imgs_column1 = imgs1[:,:,imgs1.shape[2]//20:-imgs1.shape[2]//20,imgs1.shape[3]//20:-imgs1.shape[3]//20] # w,h
        #imgs_column2 = imgs2[:,:,imgs2.shape[2]//20:-imgs2.shape[2]//20,imgs2.shape[3]//20:-imgs2.shape[3]//20]
        #loss_img_mse_column = loss_mse(imgs_column1,imgs_column2)
        #loss_img_lpips_column = loss_lpips(imgs_column1,imgs_column2).mean()

        # loss_2 = 5*loss_img_mse_column + 3*loss_img_lpips_column
        # loss_2.backward(retain_graph=True)
        # E_optimizer.step()
#loss3 smallest region
        #imgs_center1 = imgs1[:,:,imgs1.shape[2]//10:-imgs1.shape[2]//10,imgs1.shape[3]//10:-imgs1.shape[3]//10]
        #imgs_center2 = imgs2[:,:,imgs2.shape[2]//10:-imgs2.shape[2]//10,imgs2.shape[3]//10:-imgs2.shape[3]//10]
        #loss_img_mse_center = loss_mse(imgs_center1,imgs_center2)
        #loss_img_lpips_center = loss_lpips(imgs_center1,imgs_center2).mean()

        # imgs_blob1 = imgs1[:,:,924:,924:]
        # imgs_blob2 = imgs2[:,:,924:,924:]
        # loss_img_mse_blob = loss_mse(imgs_blob1,imgs_blob2)

        #loss_3 = 3*loss_img_mse_center + loss_img_lpips_center #+ loss_img_mse_blob
        #loss_3.backward(retain_graph=True)
        #loss_x = loss_1+loss_2+loss_3
        #loss_x.backward(retain_graph=True)
        #E_optimizer.step()

#loss3_v2, cosine similarity
        i1 = imgs1.view(-1)
        i2 = imgs2.view(-1)
        loss_cosine_i = i1.dot(i2)/(torch.sqrt(i1.dot(i1))*torch.sqrt(i2.dot(i2)))
        #loss_cosine_w = w1.dot(w2)/(torch.sqrt(w1.dot(w1))*torch.sqrt(w2.dot(w2)))

#loss4
        loss_c = loss_mse(const1,const2) # without this const loss the gradient cannot drop quickly at the start and training may well fail to converge; scaling this penalty by 0.1 improves results dramatically!
        loss_c_m = loss_mse(const1.mean(),const2.mean())
        loss_c_s = loss_mse(const1.std(),const2.std())

        loss_w = loss_mse(w1,w2)
        loss_w_m = loss_mse(w1.mean(),w2.mean()) # early in training this swings between ~10 and ~0.0001
        loss_w_s = loss_mse(w1.std(),w2.std()) # later in training it still swings between large and small values

        y1, y2 = torch.nn.functional.softmax(const1),torch.nn.functional.softmax(const2)
        loss_kl_c = loss_kl(torch.log(y2),y1)
        loss_kl_c = torch.where(torch.isnan(loss_kl_c),torch.full_like(loss_kl_c,0), loss_kl_c)
        loss_kl_c = torch.where(torch.isinf(loss_kl_c),torch.full_like(loss_kl_c,1), loss_kl_c)

        w1_kl, w2_kl = torch.nn.functional.softmax(w1),torch.nn.functional.softmax(w2)
        loss_kl_w = loss_kl(torch.log(w2_kl),w1_kl) #D_kl(True=y1_imgs||Fake=y2_imgs)
        loss_kl_w = torch.where(torch.isnan(loss_kl_w),torch.full_like(loss_kl_w,0), loss_kl_w)
        loss_kl_w = torch.where(torch.isinf(loss_kl_w),torch.full_like(loss_kl_w,1), loss_kl_w)


        w1_cos = w1.view(-1)
        w2_cos = w2.view(-1)
        loss_cosine_w = w1_cos.dot(w2_cos)/(torch.sqrt(w1_cos.dot(w1_cos))*torch.sqrt(w2_cos.dot(w2_cos))) # normalize by both ||w1|| and ||w2||


        loss_4 = 0.02*loss_c+0.03*loss_c_m+0.03*loss_c_s+0.02*loss_w+0.03*loss_w_m+0.03*loss_w_s+ loss_kl_w  + loss_kl_c+loss_cosine_i
        loss_4.backward(retain_graph=True)
        E_optimizer.step()

        loss_all =  loss_1  + loss_4 + loss_cosine_i #loss_2 + loss_3
        print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item())+'--loss_cosine_i:'+str(loss_cosine_i.item()))
        #print('loss_img_mse_column:'+str(loss_img_mse_column.item())+'loss_img_lpips_column:'+str(loss_img_lpips_column.item())+'--loss_img_mse_center:'+str(loss_img_mse_center.item())+'--loss_lpips_center:'+str(loss_img_lpips_center.item()))
        print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())+'--loss_c:'+str(loss_c.item())+'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_cosine_w:'+str(loss_cosine_w.item()))
        print('-')

        if epoch % 100 == 0:
            n_row = batch_size
            test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5
            torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d.jpg'%(epoch),nrow=n_row) # nrow=3
            with open(resultPath+'/Loss.txt', 'a+') as f:
                print('i_'+str(epoch)+'--loss_all__:'+str(loss_all.item())+'--loss_mse:'+str(loss_img_mse.item())+'--loss_lpips:'+str(loss_img_lpips.item())+'--loss_kl_img:'+str(loss_kl_img.item())+'--loss_cosine_i:'+str(loss_cosine_i.item()),file=f)
                #print('loss_img_mse_column:'+str(loss_img_mse_column.item())+'loss_img_lpips_column:'+str(loss_img_lpips_column.item())+'--loss_img_mse_center:'+str(loss_img_mse_center.item())+'--loss_lpips_center:'+str(loss_img_lpips_center.item()),file=f)
                print('loss_w:'+str(loss_w.item())+'--loss_w_m:'+str(loss_w_m.item())+'--loss_w_s:'+str(loss_w_s.item())+'--loss_kl_w:'+str(loss_kl_w.item())+'--loss_c:'+str(loss_c.item())+'--loss_c_m:'+str(loss_c_m.item())+'--loss_c_s:'+str(loss_c_s.item())+'--loss_kl_c:'+str(loss_kl_c.item())+'--loss_cosine_w:'+str(loss_cosine_w.item()),file=f)
            if epoch % 5000 == 0:
                torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d.pth'%epoch)
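The cosine terms loss_cosine_i and loss_cosine_w above are built from explicit dot products; torch.nn.functional.cosine_similarity computes the same quantity on flattened tensors and sidesteps mistakes in the hand-written denominator. A minimal equivalent sketch follows; the helper name flat_cosine is hypothetical.

import torch
import torch.nn.functional as F

def flat_cosine(a, b):
    # Hypothetical helper: cosine similarity between two tensors flattened to 1-D,
    # equivalent to v1.dot(v2) / (||v1|| * ||v2||) as used for loss_cosine_i / loss_cosine_w.
    return F.cosine_similarity(a.reshape(1, -1), b.reshape(1, -1), dim=1).squeeze()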