Beispiel #1
0
class DGNet_Trainer(nn.Module):
    """Joint trainer for the DG-Net generator, appearance (ID) encoder and discriminator.

    Owns the structure auto-encoder (``gen_a``), the appearance/ID encoder
    (``id_a``), the multi-scale discriminator (``dis_a``), their optimizers and
    LR schedulers, and optional teacher models used for distillation.
    ``gen_b`` / ``id_b`` / ``dis_b`` are aliases of the ``*_a`` modules (one
    shared network per role), not separate copies.
    """

    def __init__(self, hyperparameters, gpu_ids=None):
        """Build networks, optimizers, schedulers and loss criteria.

        Args:
            hyperparameters: configuration dict. Besides being read, it is
                mutated: missing 'apex', 'ID_stride' and 'T_w' entries are
                filled in with False, 2 and 1 respectively.
            gpu_ids: not used by this constructor; kept for interface
                compatibility.
        """
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']  # generator learning rate
        lr_d = hyperparameters['lr_d']  # discriminator learning rate
        ID_class = hyperparameters['ID_class']
        hyperparameters.setdefault('apex', False)
        self.fp16 = hyperparameters['apex']

        # Initiate the networks.
        # We do not need to manually set fp16 in the network for the new apex,
        # so fp16=False is passed here.
        # Structure encoder + decoder (gen_a.encode / gen_a.decode). gen_b is
        # the very same object as gen_a.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b

        # ID_stride: stride argument for the 'AB' appearance encoder (default 2).
        hyperparameters.setdefault('ID_stride', 2)
        # Appearance (ID) encoder; the variant is selected by 'ID_style'.
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])  # return 2048 now
        # Alias, not a copy: id_b and id_a are the same module.
        self.id_b = self.id_a

        # Multi-scale discriminator (per the original author: the image is
        # judged at several scales and the per-scale losses are summed).
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # Load the optional teacher models (a comma-separated list of names).
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_model = nn.ModuleList()
            for name in teacher_name.split(','):
                config_tmp = load_config(name)
                stride = config_tmp['stride'] if 'stride' in config_tmp else 2
                # Build the teacher backbone and load its weights.
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model
            # Optionally apply the project helper train_bn to every teacher
            # submodule (presumably re-enables BN training behaviour — confirm).
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        # Instance normalization over 512 channels, without affine parameters.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel: edge map or grayscale input for the structure encoder.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training; erasing_p is the erasing probability.
        self.erasing_p = hyperparameters.get('erasing_p', 0)
        # Randomly erases rectangular pixel regions (data augmentation).
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])

        hyperparameters.setdefault('T_w', 1)

        # Setup the optimizers: Adam for generator and discriminator.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        weight_decay = hyperparameters['weight_decay']
        # dis_b / gen_b alias dis_a / gen_a, so only the *_a parameters are
        # collected (adding the *_b lists would duplicate every parameter).
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2), weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2), weight_decay=weight_decay)

        # ID encoder optimizer: SGD in which the classifier heads use 10x the
        # base learning rate 'lr2' and all other ID-encoder parameters use 'lr2'.
        if hyperparameters['ID_style'] == 'PCB':
            classifiers = [self.id_a.classifier0, self.id_a.classifier1,
                           self.id_a.classifier2, self.id_a.classifier3]
        elif hyperparameters['ID_style'] == 'AB':
            classifiers = [self.id_a.classifier1, self.id_a.classifier2]
        else:
            classifiers = [self.id_a.classifier]
        classifier_param_ids = {id(p) for head in classifiers for p in head.parameters()}
        base_params = [p for p in self.id_a.parameters() if id(p) not in classifier_param_ids]
        lr2 = hyperparameters['lr2']
        param_groups = [{'params': base_params, 'lr': lr2}]
        param_groups += [{'params': head.parameters(), 'lr': lr2 * 10} for head in classifiers]
        self.id_opt = torch.optim.SGD(param_groups, weight_decay=weight_decay, momentum=0.9, nesterov=True)

        # One LR scheduler per optimizer; the ID scheduler's `gamma` attribute
        # is overridden with 'gamma2'.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID loss
        self.id_criterion = nn.CrossEntropyLoss()
        # Summed KL divergence for distillation; gen_update divides it by the
        # first dimension of the student output. reduction='sum' is the
        # non-deprecated equivalent of size_average=False.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')
        # Load a frozen VGG model if the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory: mixed precision via apex amp
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            # Re-establish the aliases after moving the modules.
            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

    def to_re(self, x):
        """Apply ``self.single_re`` (random erasing) to each sample of a 4-D batch.

        Returns a new float32 CUDA tensor with the same size as ``x``. Callers
        pass a clone, presumably because single_re may modify its input in
        place — confirm.
        """
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        """Mean absolute error; ``target`` is detached so no gradient flows into it."""
        diff = input - target.detach()
        return torch.mean(torch.abs(diff))

    def recon_criterion_sqrt(self, input, target):
        """Mean of sqrt(|input - target| + 1e-8); the epsilon keeps sqrt's gradient finite at 0."""
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff) + 1e-8))

    def recon_criterion2(self, input, target):
        """Mean squared error between ``input`` and ``target``."""
        diff = input - target
        return torch.mean(diff ** 2)

    def recon_cos(self, input, target):
        """Mean cosine distance (1 - cosine similarity along dim 1)."""
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis)

    def forward(self, x_a, x_b, xp_a, xp_b):
        """Encode, cross-decode and self-reconstruct two images.

        Args:
            x_a, x_b: input image batches (per the original notes, of
                different identities).
            xp_a, xp_b: images with the same ID as x_a / x_b but a different photo.

        Returns:
            ``(x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
            x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p)`` where
              * s_a / s_b: structure codes of x_a / x_b;
              * f_a / f_b, p_a / p_b: appearance codes and ID predictions of x_a / x_b;
              * pp_a / pp_b: ID predictions of xp_a / xp_b;
              * x_ab = decode(s_a, f_b) and x_ba = decode(s_b, f_a): cross-ID images;
              * x_a_recon = decode(s_a, f_a), x_b_recon = decode(s_b, f_b);
              * x_a_recon_p = decode(s_a, fp_a), x_b_recon_p = decode(s_b, fp_b),
                i.e. structure of x_a / x_b with appearance of xp_a / xp_b.
            When ``erasing_p > 0``, p_* and pp_* come from randomly erased inputs.
        """
        # Structure codes; self.single converts to a one-channel input (edge
        # map or grayscale). Per the original notes roughly
        # (batch, 128, 64, 32) — TODO confirm.
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))

        # f_*: appearance codes from the ID encoder; p_*: ID predictions.
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))

        # Cross-ID generation: structure of one image, appearance of the other.
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)

        # Self reconstruction: structure and appearance from the same image.
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        # xp_a / xp_b share the ID of x_a / x_b (different photo): encode their
        # appearance and decode it with the structure of x_a / x_b.
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # Random Erasing only affects the ID and PID losses: recompute the ID
        # predictions from erased copies of the inputs.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """Run one optimization step of the generator and the ID encoder.

        Takes the outputs of ``forward`` plus the input images (x_a, x_b),
        the same-ID images (xp_a, xp_b) and the labels l_a / l_b. It computes
        the generator-side losses (stored as ``self.loss_*``), back-propagates
        their weighted sum and steps ``gen_opt`` and ``id_opt``. It also
        advances the warm-up schedule by mutating the loss weights in
        ``hyperparameters``, then prints a one-line loss summary.
        """
        # pp_a / pp_b are predictions for another photo of the same person.
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()

        # Gradient-free copies of the synthesized images.
        x_ba_copy = x_ba.detach()
        x_ab_copy = x_ab.detach()

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance with an eval-mode copy of the ID encoder
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned): appearance codes
        # and ID predictions of the cross-ID images (code reconstruction)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        # Stays 0.0 when distillation is off or ID_style has no teacher branch
        # (previously left unset/stale for e.g. 'PCB').
        self.loss_teacher = 0.0
        if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                # Teacher targets need no gradient (as in the 'AB' branch).
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                # KL divergence between student and teacher distributions,
                # summed and divided by p_a_student.size(0): the closer the
                # student (ID encoder) is to the teacher, the lower the loss.
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)

                # branch b loss
                # here we give different label: cross-entropy of the second
                # output ([1]) of the student against the structure-source ID.
                loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B

        # auto-encoder image reconstruction (self and same-ID)
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0

        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b))
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss (unwrap DataParallel-style `.module` when on several GPUs)
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)
        # domain-invariant perceptual loss; 'vgg_w' is optional, as in __init__
        vgg_w = hyperparameters.get('vgg_w', 0)
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # Warm-up: after the given iterations, grow the loss weights (in place
        # in `hyperparameters`) by warm_scale per call, capped by their max_*.
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss (parenthesized so comments are legal; term order unchanged)
        self.loss_gen_total = (
            # GAN losses
            hyperparameters['gan_w'] * self.loss_gen_adv_a
            + hyperparameters['gan_w'] * self.loss_gen_adv_b
            # image reconstruction losses, domain a
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a
            + hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a
            # code reconstruction losses, domain a
            + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a
            + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a
            # the same terms for domain b
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b
            + hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b
            + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b
            + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b
            # cycle reconstruction
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
            # ID losses
            + hyperparameters['id_w'] * self.loss_id
            + hyperparameters['pid_w'] * self.loss_pid
            + hyperparameters['recon_id_w'] * self.loss_gen_recon_id
            # perceptual losses
            + vgg_w * self.loss_gen_vgg_a
            + vgg_w * self.loss_gen_vgg_b
            # teacher (discriminative module) loss
            + hyperparameters['teacher_w'] * self.loss_teacher
        )
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % (
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher))
Beispiel #2
0
class ERGAN_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Build the ERGAN generators, discriminators, optimizers and LR schedulers.

        Args:
            hyperparameters: configuration dict. Keys read here: 'lr_G',
                'lr_D', 'fp16', 'input_dim_a', 'input_dim_b', 'gen' (including
                'style_dim' and 'new_size'), 'dis', 'display_size', 'beta1',
                'beta2', 'weight_decay', 'init', and optionally 'vgg_w' /
                'vgg_model_path'. The whole dict is also passed to get_scheduler.
        """
        super(ERGAN_Trainer, self).__init__()
        lr_G = hyperparameters['lr_G']
        lr_D = hyperparameters['lr_D']
        print(lr_D, lr_G)
        self.fp16 = hyperparameters['fp16']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        # gen_b reuses gen_a's content encoder module (shared weights).
        self.gen_b.enc_content = self.gen_a.enc_content
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Ratio of the configured image size to 224; gen_update multiplies its
        # hard-coded crop-box coordinates by this factor.
        self.a = hyperparameters['gen']['new_size'] / 224
        # fix the noise used in sampling (the fixed styles used by forward)
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # gen_a and gen_b share enc_content, so simply concatenating their
        # parameter lists would list the shared encoder weights twice.
        # torch.optim then warns and applies the Adam update to those tensors
        # twice per step. Deduplicate by identity, keeping first-seen order.
        dis_params = list({id(p): p
                           for net in (self.dis_a, self.dis_b)
                           for p in net.parameters()}.values())
        gen_params = list({id(p): p
                           for net in (self.gen_a, self.gen_b)
                           for p in net.parameters()}.values())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_D,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_G,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization; the discriminators are then
        # re-initialized with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        if self.fp16:
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.gen_b = self.gen_b.cuda()
            self.dis_b = self.dis_b.cuda()
            # NOTE(review): amp.initialize is called twice, and gen_b / dis_b
            # are never passed to it. apex's amp documentation recommends a
            # single call with all models and optimizers — confirm intended.
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")

        # Load a frozen VGG model if the perceptual loss is enabled
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):

        input = input.type_as(target)
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a, x_b)
        x_ab = self.gen_b.decode(c_a, s_b, x_a)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        #mask = torch.ones(x_a.shape).cuda()
        block_a = x_a.clone()
        block_b = x_b.clone()
        block_a[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)] = 0
        block_b[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)] = 0

        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime, x_a)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime, x_b)

        # decode (cross domain)
        # decode random
        # x_ba_randn = self.gen_a.decode(c_b, s_a, x_b)
        # x_ab_randn = self.gen_b.decode(c_a, s_b, x_a)
        # decode real
        x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b)
        x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a)

        block_ba_real = x_ba_real.clone()
        block_ab_real = x_ab_real.clone()
        block_ba_real[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)] = 0
        block_ab_real[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)] = 0
        # encode again
        # c_b_recon, s_a_recon = self.gen_a.encode(x_ba_randn)
        # c_a_recon, s_b_recon = self.gen_b.encode(x_ab_randn)

        c_b_real_recon, s_a_prime_recon = self.gen_a.encode(x_ba_real)
        c_a_real_recon, s_b_prime_recon = self.gen_b.encode(x_ab_real)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_real_recon, s_a_prime,
            x_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_real_recon, s_b_prime,
            x_b) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_res_a = self.recon_criterion(
            block_ab_real, block_a)
        self.loss_gen_recon_res_b = self.recon_criterion(
            block_ba_real, block_b)
        self.loss_gen_recon_x_a_re = self.recon_criterion(
            x_a_recon[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)],
            x_a[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)])
        self.loss_gen_recon_x_b_re = self.recon_criterion(
            x_b_recon[:, :,
                      round(self.a * 92):round(self.a * 144),
                      round(self.a * 48):round(self.a * 172)],
            x_b[:, :,
                round(self.a * 92):round(self.a * 144),
                round(self.a * 48):round(self.a * 172)]
        )  # both celebA and MeGlass: [92:144, 48:172]
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)

        self.loss_gen_recon_s_a_prime = self.recon_criterion(
            s_a_prime_recon, s_a_prime)
        self.loss_gen_recon_s_b_prime = self.recon_criterion(
            s_b_prime_recon, s_b_prime)

        self.loss_gen_recon_c_a_real = self.recon_criterion(
            c_a_real_recon, c_a)
        self.loss_gen_recon_c_b_real = self.recon_criterion(
            c_b_real_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a_real = self.dis_a.calc_gen_loss(x_ba_real)
        self.loss_gen_adv_b_real = self.dis_b.calc_gen_loss(x_ab_real)
        # domain-invariant perceptual loss
        # self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba_randn, x_b) if hyperparameters['vgg_w'] > 0 else 0
        # self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab_randn, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a_real + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_b_re + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b_prime + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b_real + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b+\
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_a_re + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a_prime + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a_real + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b_real +\
                              hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_b + \
                              hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_b

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                self.gen_opt) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()

        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()

        #loss_gan = hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b)
        loss_gan_real = hyperparameters['gan_w'] * (self.loss_gen_adv_a_real +
                                                    self.loss_gen_adv_b_real)
        loss_x = hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a +
                                                 self.loss_gen_recon_x_b)
        loss_x_re = hyperparameters['recon_x_w_re'] * (
            self.loss_gen_recon_x_a_re + self.loss_gen_recon_x_b_re)
        #loss_x_res = hyperparameters['recon_x_w_res'] * (self.loss_gen_recon_res_a + self.loss_gen_recon_res_b)
        #loss_s = hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b)
        #loss_c = hyperparameters['recon_c_w'] * (self.loss_gen_recon_c_a + self.loss_gen_recon_c_b)
        loss_s_prime = hyperparameters['recon_s_w'] * (
            self.loss_gen_recon_s_a_prime + self.loss_gen_recon_s_b_prime)
        loss_c_real = hyperparameters['recon_c_w'] * (
            self.loss_gen_recon_c_a_real + self.loss_gen_recon_c_b_real)
        loss_x_cyc = hyperparameters['recon_x_cyc_w'] * (
            self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b)

        #loss_vgg = hyperparameters['vgg_w'] * (self.loss_gen_vgg_a + self.loss_gen_vgg_b)
        print(
            '||total:%.2f||gan_real:%.2f||x:%.2f||x_re:%.2f||s_prime:%.4f||c_real:%.2f||x_cyc:%.4f||'
            % (self.loss_gen_total, loss_gan_real, loss_x, loss_x_re,
               loss_s_prime, loss_c_real, loss_x_cyc))

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual loss between ``img`` and ``target`` in VGG feature space.

        Both images are passed through ``vgg_preprocess`` and then ``vgg``.
        The two feature maps are normalised with ``self.instancenorm``, and the
        squared difference is averaged over every element.
        """
        preprocessed = [vgg_preprocess(t) for t in (img, target)]
        features = [vgg(t) for t in preprocessed]
        normed_img, normed_target = (self.instancenorm(f) for f in features)
        return ((normed_img - normed_target) ** 2).mean()

    def sample(self, x_a, x_b):
        """Reconstruct and cross-translate each image pair, e.g. for visual grids.

        Every index ``i`` is processed on its own (batch of 1) with
        ``self.gen_a`` / ``self.gen_b``. ``encode`` returns two codes, named
        ``c_*`` / ``s_*`` here. ``decode`` receives the two codes plus the
        corresponding input image.

        * ``x_a_recon`` / ``x_b_recon``: decode an image's own two codes.
        * ``x_ba``: ``c`` of b with ``s`` of a; ``x_ab``: ``c`` of a with ``s`` of b.
        * ``x_aba`` / ``x_bab``: re-encode the swapped image and decode its ``c``
          with the original image's ``s`` (cycle reconstruction).

        The module is switched to eval mode for the duration of the call. It is
        always switched back to training mode, even if encoding or decoding
        raises, so a failed sample does not silently disable training mode.

        Returns:
            A 12-tuple laid out for an image grid:
            ``(x_a, x_a_recon, x_ab, x_ba, x_b, x_aba,
               x_b, x_b_recon, x_ba, x_ab, x_a, x_bab)``.
        """
        self.eval()
        try:
            x_a_recon, x_b_recon, x_bab, x_ab, x_ba, x_aba = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                img_a = x_a[i].unsqueeze(0)
                img_b = x_b[i].unsqueeze(0)
                c_a, s_a_fake = self.gen_a.encode(img_a)
                c_b, s_b_fake = self.gen_b.encode(img_b)

                # Self-reconstruction.
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake, img_a))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake, img_b))

                # Cross translation: swap the second code between the images.
                x_ba_tmp = self.gen_a.decode(c_b, s_a_fake, img_b)
                x_ab_tmp = self.gen_b.decode(c_a, s_b_fake, img_a)
                x_ba.append(x_ba_tmp)
                x_ab.append(x_ab_tmp)

                # Cycle: re-encode the translations and decode back.
                c_b_recon, _ = self.gen_a.encode(x_ba_tmp)
                c_a_recon, _ = self.gen_b.encode(x_ab_tmp)
                x_aba.append(self.gen_a.decode(c_a_recon, s_a_fake, img_a))
                x_bab.append(self.gen_b.decode(c_b_recon, s_b_fake, img_b))

            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
            x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        finally:
            # Restore training mode on both the success and the error path.
            self.train()

        return x_a, x_a_recon, x_ab, x_ba, x_b, x_aba, x_b, x_b_recon, x_ba, x_ab, x_a, x_bab

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of the discriminator.

        Cross translations are generated with ``self.gen_a`` / ``self.gen_b``
        and detached before scoring, so this loss does not back-propagate into
        the generator. The per-domain losses and their sum, scaled by
        ``hyperparameters['gan_w']``, are stored on the instance for logging.
        The backward pass goes through ``amp.scale_loss`` when ``self.fp16`` is
        set.
        """
        self.dis_opt.zero_grad()

        code_a, style_a = self.gen_a.encode(x_a)
        code_b, style_b = self.gen_b.encode(x_b)
        fake_ba = self.gen_a.decode(code_b, style_a, x_b)
        fake_ab = self.gen_b.decode(code_a, style_b, x_a)

        self.loss_dis_a_real = self.dis_a.calc_dis_loss(fake_ba.detach(), x_a)
        self.loss_dis_b_real = self.dis_b.calc_dis_loss(fake_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * (
            self.loss_dis_a_real + self.loss_dis_b_real)

        if not self.fp16:
            self.loss_dis_total.backward()
        else:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step the discriminator scheduler, then the generator scheduler.

        A scheduler attribute set to ``None`` is skipped.
        """
        for attr in ('dis_scheduler', 'gen_scheduler'):
            scheduler = getattr(self, attr)
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore generator, discriminator and optimizer state from a checkpoint dir.

        The newest ``gen`` and ``dis`` checkpoints are located with
        ``get_model_list``. ``optimizer.pt`` is read from ``checkpoint_dir``,
        and both LR schedulers are rebuilt with ``get_scheduler``.

        The iteration number is the trailing run of digits in the generator
        checkpoint's file stem (``gen_%08d.pt`` as written by ``save``). This
        also works for iteration numbers longer than 8 digits.

        ``get_scheduler`` receives ``iterations - offset``, where ``offset`` is
        ``hyperparameters['resume_iter_offset']`` when present and otherwise
        375000, the value that used to be hard-coded here.

        Returns:
            The iteration number parsed from the generator checkpoint name.

        Raises:
            FileNotFoundError: if ``get_model_list`` returns no path for
                ``gen`` or ``dis``.
            ValueError: if no iteration number can be parsed from the name.
        """
        import re

        def _latest(key):
            # Fail with a clear message instead of passing None to torch.load.
            path = get_model_list(checkpoint_dir, key)
            if not path:
                raise FileNotFoundError(
                    'no %s checkpoint found in %s' % (key, checkpoint_dir))
            return path

        # Load generators
        last_model_name = _latest("gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        stem = os.path.splitext(os.path.basename(last_model_name))[0]
        match = re.search(r'(\d+)$', stem)
        if match is None:
            raise ValueError(
                'cannot parse iteration from checkpoint name %r' % last_model_name)
        iterations = int(match.group(1))
        # Load discriminators
        last_model_name = _latest("dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers, shifting the iteration by a configurable
        # offset (fine-tuning runs have used e.g. 370000).
        offset = hyperparameters.get('resume_iter_offset', 375000)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations - offset)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations - offset)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write generator, discriminator and optimizer checkpoints.

        ``gen_%08d.pt`` and ``dis_%08d.pt`` are numbered ``iterations + 1`` and
        each hold ``{'a': ..., 'b': ...}`` state dicts. ``optimizer.pt`` holds
        ``{'gen': ..., 'dis': ...}`` and is overwritten on every call.
        """
        def _dump(filename, **state):
            torch.save(state, os.path.join(snapshot_dir, filename))

        tag = iterations + 1
        _dump('gen_%08d.pt' % tag,
              a=self.gen_a.state_dict(), b=self.gen_b.state_dict())
        _dump('dis_%08d.pt' % tag,
              a=self.dis_a.state_dict(), b=self.dis_b.state_dict())
        _dump('optimizer.pt',
              gen=self.gen_opt.state_dict(), dis=self.dis_opt.state_dict())
Beispiel #3
0
class DGNet_Trainer(nn.Module):
    def __init__(self, hyperparameters, gpu_ids=[0]):
        """Build the generator, discriminator, appearance encoder, losses and optimizers.

        Args:
            hyperparameters: configuration dict.
                * Required keys: ``lr_g``, ``lr_d``, ``ID_class``,
                  ``input_dim_a``, ``gen``, ``ID_style``, ``norm_id``, ``pool``,
                  ``dis``, ``teacher``, ``single``, ``beta1``, ``beta2``,
                  ``weight_decay``, ``lr2`` and ``gamma2``.
                * ``train_bn`` is read only when ``teacher`` is non-empty.
                * ``vgg_model_path`` is read only when ``vgg_w`` > 0.
                * Missing ``apex``, ``ID_stride`` and ``T_w`` entries are
                  written back into the dict as ``False``, ``2`` and ``1``.
                * ``erasing_p`` defaults to 0 and is not written back.
            gpu_ids: not used by this constructor; kept for interface compatibility.
        """
        super(DGNet_Trainer, self).__init__()
        # Learning rates of the generator and the discriminator.
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']

        # Number of identity classes.
        ID_class = hyperparameters['ID_class']

        # Mixed-precision (apex) switch; defaults to False and is written back.
        self.fp16 = hyperparameters.setdefault('apex', False)
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b (same object)

        # Stride of the appearance encoder's pooling stage (default 2, written back).
        hyperparameters.setdefault('ID_stride', 2)
        # id_a: appearance encoder (Ea).
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        self.id_b = self.id_a  # images of domain b use the same encoder object

        # Multi-scale discriminator. Per the original author's notes it scores
        # the image at 3 scales ([B,1,64,32], [B,1,32,16], [B,1,16,8]) and the
        # losses are summed (not verifiable from this file).
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # Load teachers: hyperparameters['teacher'] is a comma-separated list.
        if hyperparameters['teacher'] != "":
            teacher_names = hyperparameters['teacher']
            print(teacher_names)
            teacher_model = nn.ModuleList()
            for teacher_name in teacher_names.split(','):
                config_tmp = load_config(teacher_name)
                # Pooling stride of the teacher backbone (default 2).
                stride = config_tmp['stride'] if 'stride' in config_tmp else 2

                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                # Remove the original fc layer in ImageNet.
                teacher_model_tmp.model.fc = nn.Sequential()
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model

            # `train_bn` is applied to every teacher submodule (defined
            # elsewhere; presumably puts BatchNorm layers back in train mode).
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        # Instance normalisation over 512 channels without affine parameters.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel, used as the input of gen_*.encode (the structure
        # encoder Es): edge map when 'single' == 'edge', grayscale otherwise.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training; erasing_p is the erasing probability.
        self.erasing_p = hyperparameters.get('erasing_p', 0)
        # Erased regions are filled with `mean` (all zeros here).
        self.single_re = RandomErasing(probability=self.erasing_p,
                                       mean=[0.0, 0.0, 0.0])

        hyperparameters.setdefault('T_w', 1)
        # Setup the optimizers. gen_b/dis_b alias gen_a/dis_a, so only the
        # "a" parameters need to be collected.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        # ID optimizer: classifier heads train at 10x lr2, every other
        # parameter of id_a at lr2. The set of heads depends on ID_style.
        if hyperparameters['ID_style'] == 'PCB':
            classifiers = [self.id_a.classifier0, self.id_a.classifier1,
                           self.id_a.classifier2, self.id_a.classifier3]
        elif hyperparameters['ID_style'] == 'AB':
            classifiers = [self.id_a.classifier1, self.id_a.classifier2]
        else:
            classifiers = [self.id_a.classifier]
        # Set membership keeps the split O(n) instead of O(n * m).
        classifier_param_ids = {
            id(p) for head in classifiers for p in head.parameters()}
        base_params = [p for p in self.id_a.parameters()
                       if id(p) not in classifier_param_ids]
        lr2 = hyperparameters['lr2']
        param_groups = [{'params': base_params, 'lr': lr2}]
        param_groups += [{'params': head.parameters(), 'lr': lr2 * 10}
                         for head in classifiers]
        self.id_opt = torch.optim.SGD(
            param_groups,
            weight_decay=hyperparameters['weight_decay'],
            momentum=0.9,
            nesterov=True)

        # Learning-rate schedules for the generator, discriminator and ID model.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        # Own decay factor for the ID model; assumes the object returned by
        # get_scheduler has a `gamma` attribute (e.g. StepLR) — TODO confirm.
        self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        # Teacher-student KL divergence (primary feature loss), summed over all
        # elements. reduction='sum' is the non-deprecated equivalent of
        # size_average=False.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')
        # Load VGG model if needed; it is frozen and kept in eval mode.
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # fp16: move the networks to GPU and wrap each network together with
        # its optimizer via apex amp (opt_level O1).
        if self.fp16:
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    def to_re(self, x):
        """Apply ``self.single_re`` (random erasing) to every sample of a batch.

        Args:
            x: batch tensor; sample ``i`` is ``x[i]``.

        Returns:
            A new float32 tensor with ``x``'s shape, on ``x``'s device, whose
            ``i``-th entry is ``self.single_re(x[i])``.
        """
        # Allocate directly on x's device: no CPU staging buffer plus copy,
        # and the method also works for CPU tensors.
        out = torch.empty(x.size(), dtype=torch.float32, device=x.device)
        for i in range(x.size(0)):
            out[i] = self.single_re(x[i])
        return out

    def recon_criterion(self, input, target):
        """Mean absolute (L1) reconstruction error.

        ``target`` is detached, so gradients flow only into ``input``.
        """
        return (input - target.detach()).abs().mean()

    def recon_criterion_sqrt(self, input, target):
        """Mean of ``sqrt(|input - target| + 1e-8)``; ``target`` is not detached.

        The 1e-8 offset keeps the square root's gradient finite where the
        difference is exactly zero.
        """
        abs_diff = (input - target).abs()
        return torch.sqrt(abs_diff + 1e-8).mean()

    def recon_criterion2(self, input, target):
        """Mean squared error between ``input`` and ``target`` (not detached)."""
        return (input - target).pow(2).mean()

    def recon_cos(self, input, target):
        """Mean cosine distance ``1 - cos(input, target)``, similarity taken along dim 1."""
        similarity = torch.nn.functional.cosine_similarity(
            input, target, dim=1, eps=1e-8)
        return (1 - similarity).mean()

    def forward(self, x_a, x_b, xp_a, xp_b):
        """Encode and decode four images for one DG-Net training step.

        ``x_a`` and ``xp_a`` share an identity, as do ``x_b`` and ``xp_b``.

        Why four images: one DG-Net sample needs three images (id1, id2 and a
        positive image of id1). Four images form two such triplets,
        (x_a, x_b, xp_a) and (x_b, x_a, xp_b). That needs two fewer images than
        two separate three-image passes.

        Codes:
            * Structure codes ``s_*`` come from ``gen_*.encode`` applied to
              ``self.single(...)`` of the image.
            * Appearance codes ``f_*`` and identity predictions ``p_*`` come
              from ``id_*`` applied to ``scale2(...)`` of the image.

        Decoding combines one structure code with one appearance code:
            * ``x_ba = decode(s_b, f_a)`` and ``x_ab = decode(s_a, f_b)``
              (swapped).
            * ``x_a_recon = decode(s_a, f_a)`` and
              ``x_b_recon = decode(s_b, f_b)`` (self-reconstruction).
            * ``x_a_recon_p = decode(s_a, fp_a)`` and
              ``x_b_recon_p = decode(s_b, fp_b)``: appearance is taken from the
              same-identity image.

        When ``self.erasing_p > 0``, only the identity predictions ``p_a``,
        ``p_b``, ``pp_a`` and ``pp_b`` are recomputed, from randomly erased
        copies of the inputs. The appearance codes used for decoding are
        unchanged.

        Returns:
            ``(x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
            x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p)``
        """
        # Shapes noted by the original author (not verifiable here):
        # s_* ~ [B, 128, 64, 32]; f_* ~ [B, 2048*4=8192]; for the AB encoder
        # p_* is a pair of [B, ID_class] predictions; decoded images
        # ~ [B, 3, 256, 128].

        # Structure codes from single-channel versions of the inputs.
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        # Appearance codes and identity predictions.
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))

        # Swap: structure of one image, appearance of the other.
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)

        # Self-reconstruction.
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        # Appearance codes / predictions of the same-identity (positive) images.
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        # Decode the same person: own structure + positive image's appearance.
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # Random Erasing only affects the ID and PID losses.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # Encode the same ID, different photo.
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        return (x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p)

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b,
                   xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        # ppa, ppb is the same person
        # pp_a: 输入图片a经过Ea编码进行身份预测  pp_b:输入图片b经过Ea编码进行身份预测
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()

        # no gradient
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(
                self.single(x_ab_copy))  # 对x_ab经过Es进行编码 得到st code
            s_b_recon = self.gen_a.enc_content(
                self.single(x_ba_copy))  # 对x_ba经过Es进行编码 得到st code
        else:
            # copy the encoder
            # 这里是shencopy
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(
            scale2(x_ba))  # 对合成的图片 x_ba进行Ea编码和身份预测
        f_b_recon, p_b_recon = self.id_b_copy(
            scale2(x_ab))  # 对合成的图片 x_ab进行Ea编码和身份预测

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ba_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_a,
                    slabel=l_b,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ab_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_b,
                    slabel=l_a,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # 合成的图片经过身份鉴别器,得到每个ID可能的概率
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])  # 两个身份预测的第一个预测值
                with torch.no_grad():
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(
                        0)  # 在老师模型监督下,x_ba身份预测损失  # 公式(8)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(
                        0)  # 在老师模型监督下,x_ab身份预测损失   # 公式 (8)

                # branch b loss
                # here we give different label
                # 用Ea的第二个身份预测值计算身份预测损失,
                # 这就相当于是Ea输出两个向量,一个用来计算与老师模型的身份预测损失,另一个用来计算自身身份预测损失
                loss_B = self.id_criterion(
                    p_ba_student[1], l_b) + self.id_criterion(
                        p_ab_student[1], l_a)  # l_b 是b的label    # 公式(9)
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0

        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(
            x_a_recon, x_a)  # x_a_recon,  a 的 ap 和 a 的 st     # 公式 (1)
        self.loss_gen_recon_x_b = self.recon_criterion(
            x_b_recon, x_b)  # x_b_recon,  b 的 ap 和 b 的 st     # 公式 (1)
        self.loss_gen_recon_xp_a = self.recon_criterion(
            x_a_recon_p, x_a)  # x_a_recon_p, a 的 st 和 pos_a 的 ap   # 公式 (2)
        self.loss_gen_recon_xp_b = self.recon_criterion(
            x_b_recon_p, x_b)  # x_b_recon_p, b 的 st 和 pos_b 的 ap   # 公式 (2)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters[
                'recon_s_w'] > 0 else 0  # s_a_recon, 合成图片x_ab 的st   # 公式 (5)
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters[
                'recon_s_w'] > 0 else 0  # s_b_recon, 合成图片x_ba 的st   # 公式 (5)
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters[
                'recon_f_w'] > 0 else 0  # f_a_recon, 合成图片x_ba 的ap   # 公式 (4)
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters[
                'recon_f_w'] > 0 else 0  # f_b_recon, 合成图片x_ab 的ap   # 公式 (4)

        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters[
            'recon_x_cyc_w'] > 0 else None  # x_aba,ab 的 st 与 ba 的 ap
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters[
            'recon_x_cyc_w'] > 0 else None  # x_bab,ba 的 st 与 ab 的 ap

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(
                    p_b_recon, l_b)  # x_ba 与l_a, x_ab 与l_b 的身份预测损失
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters[
                'B_w']  # teather_w = 1.0, B_w = 0.2
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )  # a和b的身份预测损失             # 公式(3)
            self.loss_pid = self.id_criterion(
                pp_a[0], l_a) + self.id_criterion(
                    pp_b[0], l_b)  # pos_a 和 pos_b 的身份预测损失  # 公式(3)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(
                    p_b_recon[0], l_b)  # 不太懂为什么用了b的st   却要判定为a的label 公式(7)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters[
                'recon_x_cyc_w'] > 0 else 0  # x_aba,ab 的 st 与 ba 的 ap
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters[
                'recon_x_cyc_w'] > 0 else 0  # x_bab,ba 的 st 与 ab 的 ap
        # GAN loss
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(
                self.dis_a, x_ba)  # 公式(6)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(
                self.dis_b, x_ab)  # 公式(6)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)

        # domain-invariant perceptual loss
        # 使用vgg,对合成图片和真实图片进行特征提取,然后计算两个特征loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # 每个loss所占的权重
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        # 增大计算效率
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()  # 后向传播
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \
                                                        hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \
                                                        hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \
                                                        hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \
                                                        hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \
                                                        hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \
                                                        hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \
                                                        hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \
                                                        hyperparameters['id_w'] * self.loss_id,\
                                                        hyperparameters['pid_w'] * self.loss_pid,\
hyperparameters['teacher_w'] * self.loss_teacher )  )

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual loss between two image batches in VGG feature space.

        Both batches are passed through ``vgg_preprocess`` and then ``vgg``;
        the feature maps are normalised with ``self.instancenorm`` and compared
        with a mean squared error.

        Args:
            vgg: feature extractor applied to the preprocessed images.
            img: generated image batch.
            target: reference image batch.

        Returns:
            Scalar tensor: mean of the squared difference between the
            instance-normalised feature maps.
        """
        img_in, target_in = vgg_preprocess(img), vgg_preprocess(target)
        img_feat, target_feat = vgg(img_in), vgg(target_in)
        feat_gap = self.instancenorm(img_feat) - self.instancenorm(target_feat)
        return torch.mean(feat_gap ** 2)

    def PCB_loss(self, inputs, labels):
        """Average identity loss over the part-level predictions of a PCB model.

        Args:
            inputs: sequence of per-part predictions; every element is scored
                with ``self.id_criterion`` against the same ``labels``.
            labels: identity labels shared by all parts.

        Returns:
            The arithmetic mean of the per-part ``id_criterion`` values.
        """
        per_part = [self.id_criterion(part, labels) for part in inputs]
        return sum(per_part, 0.0) / len(inputs)

    def sample(self, x_a, x_b):
        """Produce visualisation images for two batches, one sample at a time.

        The trainer is put in eval mode while sampling and switched to train
        mode (unconditionally) before returning. For each index ``k``:

        * structure codes come from ``gen_*.encode(self.single(...))``;
        * appearance codes come from ``id_*(scale2(...))``.

        These codes are combined into self-reconstructions, swapped images
        (structure of one input, appearance of the other) and cycle
        reconstructions obtained by re-encoding the swapped images.

        Returns:
            ``(x_a, x_a_recon, x_aba, x_ab, x_b, x_b_recon, x_bab, x_ba)``.
            Every generated tensor is the per-sample results concatenated
            along dim 0; ``x_a`` and ``x_b`` are returned unchanged.
        """
        self.eval()
        recons_a, recons_b = [], []
        swaps_ba, swaps_ab = [], []
        cycles_aba, cycles_bab = [], []
        for k in range(x_a.size(0)):
            img_a = x_a[k].unsqueeze(0)
            img_b = x_b[k].unsqueeze(0)
            struct_a = self.gen_a.encode(self.single(img_a))
            struct_b = self.gen_b.encode(self.single(img_b))
            look_a, _ = self.id_a(scale2(img_a))
            look_b, _ = self.id_b(scale2(img_b))
            recons_a.append(self.gen_a.decode(struct_a, look_a))
            recons_b.append(self.gen_b.decode(struct_b, look_b))
            # swap: structure of one image, appearance of the other
            swap_ba = self.gen_a.decode(struct_b, look_a)
            swap_ab = self.gen_b.decode(struct_a, look_b)
            swaps_ba.append(swap_ba)
            swaps_ab.append(swap_ab)
            # cycle: re-encode the swapped images and decode them back
            cyc_struct_b = self.gen_a.enc_content(self.single(swap_ba))
            cyc_struct_a = self.gen_b.enc_content(self.single(swap_ab))
            cyc_look_a, _ = self.id_a(scale2(swap_ba))
            cyc_look_b, _ = self.id_b(scale2(swap_ab))
            cycles_aba.append(self.gen_a.decode(cyc_struct_a, cyc_look_a))
            cycles_bab.append(self.gen_b.decode(cyc_struct_b, cyc_look_b))

        x_a_recon, x_b_recon = torch.cat(recons_a), torch.cat(recons_b)
        x_aba, x_bab = torch.cat(cycles_aba), torch.cat(cycles_bab)
        x_ba, x_ab = torch.cat(swaps_ba), torch.cat(swaps_ab)
        self.train()
        return x_a, x_a_recon, x_aba, x_ab, x_b, x_b_recon, x_bab, x_ba

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        """Run one optimisation step of the (shared) discriminator.

        Generated images are detached, so no gradient reaches the generator.
        When ``num_gpu > 1``, ``calc_dis_loss`` is looked up on
        ``self.dis_*.module`` (presumably a ``DataParallel`` wrapper). In both
        cases the wrapper object itself is passed as the first argument.
        The original author notes this is an LSGAN loss.

        Args:
            x_ab: generated images, judged by ``dis_b`` against ``x_b``.
            x_ba: generated images, judged by ``dis_a`` against ``x_a``.
            x_a: real images for discriminator a.
            x_b: real images for discriminator b.
            hyperparameters: config dict; ``'gan_w'`` weights both terms.
            num_gpu: number of GPUs in use.

        Side effects:
            Sets ``self.loss_dis_a``, ``self.loss_dis_b`` and
            ``self.loss_dis_total``, prints the total loss and the summed
            regularisation terms, backpropagates (through apex when
            ``self.fp16``) and steps ``self.dis_opt``.
        """
        self.dis_opt.zero_grad()
        # With several GPUs the loss method lives on the wrapped module.
        if num_gpu > 1:
            impl_a, impl_b = self.dis_a.module, self.dis_b.module
        else:
            impl_a, impl_b = self.dis_a, self.dis_b
        self.loss_dis_a, reg_a = impl_a.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
        self.loss_dis_b, reg_b = impl_b.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        gan_weight = hyperparameters['gan_w']
        self.loss_dis_total = gan_weight * self.loss_dis_a + gan_weight * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))
        if not self.fp16:
            self.loss_dis_total.backward()
        else:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step every configured learning-rate scheduler.

        Schedulers are advanced in the order discriminator, generator, ID
        encoder; any scheduler attribute that is ``None`` is skipped.
        """
        for attr in ('dis_scheduler', 'gen_scheduler', 'id_scheduler'):
            scheduler = getattr(self, attr)
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from the latest checkpoint.

        Loads the newest ``gen``/``dis``/``id`` snapshots that ``get_model_list``
        finds in ``checkpoint_dir``; each is a ``{'a': state_dict}`` dict, the
        format written by ``save``. The shared b-side attributes are re-aliased
        to the a-side modules. Loading ``optimizer.pt`` is best effort: a failure
        is reported and training continues with the current optimizer state.
        Finally, all three schedulers are rebuilt so they continue from the
        recovered iteration.

        Args:
            checkpoint_dir: directory containing the snapshots written by ``save``.
            hyperparameters: config dict forwarded to ``get_scheduler``.

        Returns:
            The iteration number parsed from the generator snapshot file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        # Snapshot names look like 'gen_%08d.pt': the 8 characters before '.pt'
        # are the iteration counter.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers (best effort, but no longer silent and no longer
        # swallowing KeyboardInterrupt / SystemExit like a bare ``except:``).
        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception as err:
            print('Could not restore optimizer state, keeping current optimizers: %s' % err)
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        # The ID scheduler used to be left untouched here, so its schedule
        # restarted from iteration 0 after resuming; rebuild it like the others.
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters,
                                          iterations)
        # NOTE(review): the ID encoder is decayed with its own factor 'gamma2'
        # when configured, as the other trainer variant's constructor does — confirm.
        if self.id_scheduler is not None and 'gamma2' in hyperparameters:
            self.id_scheduler.gamma = hyperparameters['gamma2']
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        """Write generator, discriminator, ID-encoder and optimizer checkpoints.

        Network snapshots are named with ``iterations + 1`` (zero-padded to 8
        digits), and each stores ``{'a': state_dict}``. All optimizer states go
        into a single ``optimizer.pt`` that is overwritten on every call. When
        ``num_gpu > 1``, the discriminator is unwrapped through ``.module``
        before its state dict is taken.
        """
        tag = iterations + 1
        gen_path, dis_path, id_path = (
            os.path.join(snapshot_dir, pattern % tag)
            for pattern in ('gen_%08d.pt', 'dis_%08d.pt', 'id_%08d.pt'))
        opt_path = os.path.join(snapshot_dir, 'optimizer.pt')
        dis_net = self.dis_a.module if num_gpu > 1 else self.dis_a

        torch.save({'a': self.gen_a.state_dict()}, gen_path)
        torch.save({'a': dis_net.state_dict()}, dis_path)
        torch.save({'a': self.id_a.state_dict()}, id_path)
        optimizer_states = {
            'gen': self.gen_opt.state_dict(),
            'id': self.id_opt.state_dict(),
            'dis': self.dis_opt.state_dict(),
        }
        torch.save(optimizer_states, opt_path)
Beispiel #4
0
class DGNet_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Build the DG-Net networks, optimizers, schedulers and losses.

        Creates one generator (``gen_a``), one appearance/ID encoder (``id_a``)
        and one discriminator (``dis_a``). The ``*_b`` attributes are aliases of
        the same modules, so both "domains" share weights. Optional pieces are
        set up when the hyper-parameters ask for them:

        * teacher models (``'teacher'``: comma-separated list);
        * a frozen VGG for the perceptual loss (``'vgg_w' > 0``);
        * apex O1 mixed precision (``'apex'``).

        Args:
            hyperparameters: configuration dict. Missing ``'apex'``,
                ``'ID_stride'``, ``'erasing_p'`` and ``'T_w'`` entries are
                filled in *in place* with ``False``, ``2``, ``0`` and ``1``.
        """
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']  # generator learning rate
        lr_d = hyperparameters['lr_d']  # discriminator learning rate
        ID_class = hyperparameters['ID_class']
        hyperparameters.setdefault('apex', False)
        self.fp16 = hyperparameters['apex']
        # Initiate the networks.
        # The new apex does not need fp16 set inside the networks, hence fp16=False.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # domain b shares the same auto-encoder

        hyperparameters.setdefault('ID_stride', 2)

        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        self.id_b = self.id_a
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b (shared)

        # load teachers
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                stride = config_tmp['stride'] if 'stride' in config_tmp else 2
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                # remove the original fc layer in ImageNet
                teacher_model_tmp.model.fc = nn.Sequential()
                # apex needs the model on the GPU before amp.initialize.
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        # NOTE(review): read (and int-validated) here but not used in this method.
        display_size = int(hyperparameters['display_size'])

        # RGB to one channel
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training
        hyperparameters.setdefault('erasing_p', 0)
        self.single_re = RandomErasing(
            probability=hyperparameters['erasing_p'], mean=[0.0, 0.0, 0.0])

        hyperparameters.setdefault('T_w', 1)
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # gen_b / dis_b alias gen_a / dis_a, so the a-side parameters cover both.
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        # id params
        self.id_opt = self._build_id_optimizer(hyperparameters)

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        # update_learning_rate treats a None scheduler as "no schedule";
        # do not crash on it here.
        if self.id_scheduler is not None:
            self.id_scheduler.gamma = hyperparameters['gamma2']

        #ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        # size_average=False is deprecated; reduction='sum' is its exact replacement.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            # Frozen: used only as a fixed feature extractor.
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        if self.fp16:
            # Move the networks to the GPU and let apex (O1) wrap each network
            # together with its optimizer.
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    def _build_id_optimizer(self, hyperparameters):
        """Return the SGD optimizer for the ID encoder ``self.id_a``.

        The classifier heads, which depend on ``'ID_style'``, train with
        ``10 * lr2``. The remaining (backbone) parameters use ``lr2``. All
        groups share weight decay, momentum 0.9 and Nesterov updates.
        """
        if hyperparameters['ID_style'] == 'PCB':
            head_names = ('classifier0', 'classifier1', 'classifier2',
                          'classifier3')
        elif hyperparameters['ID_style'] == 'AB':
            head_names = ('classifier1', 'classifier2')
        else:
            head_names = ('classifier',)
        heads = [getattr(self.id_a, name) for name in head_names]
        head_param_ids = {id(p) for head in heads for p in head.parameters()}
        base_params = [p for p in self.id_a.parameters()
                       if id(p) not in head_param_ids]
        lr2 = hyperparameters['lr2']
        param_groups = [{'params': base_params, 'lr': lr2}]
        param_groups += [{'params': head.parameters(), 'lr': lr2 * 10}
                         for head in heads]
        return torch.optim.SGD(param_groups,
                               weight_decay=hyperparameters['weight_decay'],
                               momentum=0.9,
                               nesterov=True)

    def to_re(self, x, device='cuda'):
        """Apply ``self.single_re`` (random erasing) to every sample of a batch.

        Args:
            x: batch tensor; ``self.single_re`` is called on each ``x[i]``.
            device: device on which the float32 result is allocated. Defaults
                to ``'cuda'``, the previously hard-coded behaviour.

        Returns:
            A new float32 tensor of the same shape as ``x`` holding the
            per-sample results.
        """
        # Allocate straight on the target device. The old code built an
        # uninitialised CPU tensor and then copied it over with ``.cuda()``.
        out = torch.empty(x.shape, dtype=torch.float32, device=device)
        for i, sample in enumerate(x):
            out[i] = self.single_re(sample)
        return out

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error.

        ``target`` is detached, so gradients flow only into ``input``.
        """
        residual = input - target.detach()
        return residual.abs().mean()

    def recon_criterion_sqrt(self, input, target):
        """Mean of ``sqrt(|input - target| + 1e-8)``.

        The ``1e-8`` keeps the square root's gradient finite when the
        difference is zero.
        """
        abs_gap = (input - target).abs()
        return torch.sqrt(abs_gap + 1e-8).mean()

    def recon_criterion2(self, input, target):
        """Mean squared (L2) error between ``input`` and ``target``."""
        return (input - target).pow(2).mean()

    def recon_cos(self, input, target):
        """Mean cosine distance ``1 - cos(input, target)``.

        The cosine is taken along dim 1 with ``eps=1e-8``, matching the
        defaults of ``torch.nn.CosineSimilarity``.
        """
        similarity = torch.nn.functional.cosine_similarity(
            input, target, dim=1, eps=1e-8)
        return (1 - similarity).mean()

    def forward(self, x_a, x_b):
        """Swap appearance between two image batches (inference helper).

        The module is switched to eval mode and two images are generated:

        * ``x_ab``: structure of ``x_a`` with the appearance of ``x_b``;
        * ``x_ba``: structure of ``x_b`` with the appearance of ``x_a``.

        The module is then switched back to train mode.

        Returns:
            ``(x_ab, x_ba)``.
        """
        self.eval()
        struct_a = self.gen_a.encode(self.single(x_a))
        struct_b = self.gen_b.encode(self.single(x_b))
        appearance_a, _ = self.id_a(scale2(x_a))
        appearance_b, _ = self.id_b(scale2(x_b))
        swapped_ba = self.gen_a.decode(struct_b, appearance_a)
        swapped_ab = self.gen_b.decode(struct_a, appearance_b)
        self.train()
        return swapped_ab, swapped_ba

    def gen_update(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters,
                   iteration):
        # ppa, ppb is the same person
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
        # encode
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        # autodecode
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        # encode the same ID different photo
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))

        # decode the same person
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # has gradient
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # no gradient
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model,
                                            scale2(x_ba_copy))
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model,
                                            scale2(x_ab_copy))
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                with torch.no_grad():
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)

                # branch b loss
                # here we give different label
                loss_B = self.id_criterion(p_ba_student[1],
                                           l_b) + self.id_criterion(
                                               p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0

        # decode again (if needed)
        if hyperparameters['use_decoder_again']:
            x_aba = self.gen_a.decode(
                s_a_recon,
                f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
            x_bab = self.gen_b.decode(
                s_b_recon,
                f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        else:
            self.mlp_w_copy = copy.deepcopy(self.gen_a.mlp_w)
            self.mlp_b_copy = copy.deepcopy(self.gen_a.mlp_b)
            self.dec_copy = copy.deepcopy(self.gen_a.dec)  # Error
            ID = f_a_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b,
                                           self.dec_copy)
            x_aba = self.dec_copy(
                s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

            ID = f_b_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b,
                                           self.dec_copy)
            x_bab = self.dec_copy(
                s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0

        # Random Erasing only effect the ID and PID loss.
        if hyperparameters['erasing_p'] > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(
                pp_b[0], l_b
            )  #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \
                                                        hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \
                                                        hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \
                                                        hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \
                                                        hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \
                                                        hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \
                                                        hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \
                                                        hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \
                                                        hyperparameters['id_w'] * self.loss_id,\
                                                        hyperparameters['pid_w'] * self.loss_pid,\
hyperparameters['teacher_w'] * self.loss_teacher )  )

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual loss between *img* and *target*.

        Both images go through ``vgg_preprocess`` and the ``vgg`` network.
        The resulting feature maps are instance-normalised with
        ``self.instancenorm``, and the mean squared difference is returned.
        """
        img_in, target_in = vgg_preprocess(img), vgg_preprocess(target)
        img_feat, target_feat = vgg(img_in), vgg(target_in)
        norm = self.instancenorm
        feat_diff = norm(img_feat) - norm(target_feat)
        return torch.mean(feat_diff ** 2)

    def PCB_loss(self, inputs, labels):
        """Average the ID loss over the part-wise predictions of a PCB model.

        Args:
            inputs: sequence of per-part predictions, each scored with
                ``self.id_criterion`` against the same *labels*.
            labels: identity labels shared by all parts.

        Returns:
            Sum of the per-part losses divided by the number of parts.
        """
        part_losses = [self.id_criterion(part, labels) for part in inputs]
        return sum(part_losses, 0.0) / len(inputs)

    def sample(self, x_a, x_b):
        """Generate visualisation samples for a pair of image batches.

        Each image pair is processed on its own, in eval mode and without
        autograd. The module's previous train/eval mode is restored
        afterwards, even if an exception is raised.

        Args:
            x_a, x_b: image batches, indexed along dim 0. x_b must have at
                least as many items as x_a.

        Returns:
            (x_a, x_a_recon, x_aba, x_ab, x_b, x_b_recon, x_bab, x_ba), where:
            - x_a_recon = decode(s_a, f_a) and x_b_recon = decode(s_b, f_b);
            - x_ba = decode(s_b, f_a) and x_ab = decode(s_a, f_b);
            - x_aba / x_bab are the cycle reconstructions obtained by
              re-encoding the swapped images.
        """
        was_training = self.training
        self.eval()
        try:
            # Sampling is inference-only: skipping the autograd graph saves
            # memory, and the outputs are only used for display.
            with torch.no_grad():
                x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
                for i in range(x_a.size(0)):
                    img_a = x_a[i].unsqueeze(0)
                    img_b = x_b[i].unsqueeze(0)
                    s_a = self.gen_a.encode(self.single(img_a))
                    s_b = self.gen_b.encode(self.single(img_b))
                    f_a, _ = self.id_a(scale2(img_a))
                    f_b, _ = self.id_b(scale2(img_b))
                    x_a_recon.append(self.gen_a.decode(s_a, f_a))
                    x_b_recon.append(self.gen_b.decode(s_b, f_b))
                    x_ba = self.gen_a.decode(s_b, f_a)
                    x_ab = self.gen_b.decode(s_a, f_b)
                    x_ba1.append(x_ba)
                    x_ab1.append(x_ab)
                    # Cycle: re-encode the swapped images and decode back.
                    s_b_recon = self.gen_a.enc_content(self.single(x_ba))
                    s_a_recon = self.gen_b.enc_content(self.single(x_ab))
                    f_a_recon, _ = self.id_a(scale2(x_ba))
                    f_b_recon, _ = self.id_b(scale2(x_ab))
                    x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
                    x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))

                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
                x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        finally:
            # Put the module back into whatever mode it was in before.
            self.train(was_training)

        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of the (shared) discriminator.

        The method re-generates the cross-domain images
        x_ba = decode(s_b, f_a) and x_ab = decode(s_a, f_b). It scores them
        as fakes against the real x_a / x_b, then steps ``dis_opt`` on
        gan_w * (loss_dis_a + loss_dis_b).

        Args:
            x_a, x_b: real image batches.
            hyperparameters: config dict; only 'gan_w' is read here.

        Side effects: sets self.loss_dis_a, self.loss_dis_b and
        self.loss_dis_total, prints the total loss and the summed
        regularisation term, and updates the discriminator weights.
        """
        self.dis_opt.zero_grad()
        # The generated images reach the discriminator only after .detach(),
        # so no gradient can flow back into Es / Ea / G from this step.
        # Building their autograd graph would only waste memory and time.
        with torch.no_grad():
            # encode
            s_a = self.gen_a.encode(self.single(x_a))
            s_b = self.gen_b.encode(self.single(x_b))
            f_a, _ = self.id_a(scale2(x_a))
            f_b, _ = self.id_b(scale2(x_b))
            # decode (cross domain)
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
        # D loss
        self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = (hyperparameters['gan_w'] * self.loss_dis_a +
                               hyperparameters['gan_w'] * self.loss_dis_b)
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total,
                                self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance every configured learning-rate scheduler by one step.

        The order is discriminator, generator, then ID encoder. Schedulers
        that are ``None`` are skipped.
        """
        schedulers = (self.dis_scheduler, self.gen_scheduler, self.id_scheduler)
        for scheduler in schedulers:
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and LR schedulers from *checkpoint_dir*.

        The latest 'gen', 'dis' and 'id' checkpoints (as picked by
        get_model_list) are loaded first. Then optimizer.pt is loaded on a
        best-effort basis: a failure is reported, not raised. Finally all
        three schedulers, including the ID scheduler, are rebuilt so they
        continue from the restored iteration.

        Args:
            checkpoint_dir: directory containing the files written by save().
            hyperparameters: config passed to get_scheduler. If it contains
                'gamma2', that value is re-applied as the ID scheduler's gamma.

        Returns:
            The iteration number parsed from the generator checkpoint file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        # save() names files '<prefix>_%08d.pt', so the eight characters
        # before '.pt' hold the iteration number.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID encoder
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers. This is best effort: training can continue with
        # freshly initialised optimizer state, but report why instead of
        # failing silently (and never swallow KeyboardInterrupt/SystemExit).
        opt_name = os.path.join(checkpoint_dir, 'optimizer.pt')
        try:
            state_dict = torch.load(opt_name)
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception as err:
            print('Could not restore optimizer state from %s: %r' %
                  (opt_name, err))
        # Reinitialize all schedulers so they resume at `iterations`. The ID
        # scheduler must be rebuilt too, otherwise its LR schedule restarts.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters,
                                          iterations)
        if self.id_scheduler is not None and 'gamma2' in hyperparameters:
            self.id_scheduler.gamma = hyperparameters['gamma2']
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write generator, discriminator, ID-encoder and optimizer checkpoints.

        Each network file stores ``{'a': state_dict}`` and is tagged with
        ``iterations + 1``. The optimizer states go into a single
        optimizer.pt, which is overwritten on every call.
        """
        step = iterations + 1
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % step)
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % step)
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % step)
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        for net, path in ((self.gen_a, gen_name),
                          (self.dis_a, dis_name),
                          (self.id_a, id_name)):
            torch.save({'a': net.state_dict()}, path)
        optimizer_states = {
            'gen': self.gen_opt.state_dict(),
            'id': self.id_opt.state_dict(),
            'dis': self.dis_opt.state_dict(),
        }
        torch.save(optimizer_states, opt_name)
class DGNet_Trainer(nn.Module):
    # Constructor: builds the networks, optimizers, schedulers and losses.
    def __init__(self, hyperparameters, gpu_ids=[0]):
        """Build the DG-Net networks, optimizers, LR schedulers and losses.

        Args:
            hyperparameters: configuration dict. The optional keys 'apex',
                'ID_stride' and 'T_w' are filled in with defaults *in place*,
                so the caller's dict is mutated.
            gpu_ids: not used by this constructor; kept for interface
                compatibility.
        """
        super(DGNet_Trainer, self).__init__()
        # Learning rates of the generator (Es + G) and the discriminator (D).
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']

        # Number of identities in the *training* set (not the test set); it
        # sizes the ID classifiers of the appearance encoder and the teachers.
        ID_class = hyperparameters['ID_class']

        # Mixed-precision (apex) training is off unless explicitly enabled.
        if 'apex' not in hyperparameters:
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']

        # ---- Structure encoder Es + decoder G ------------------------------
        # We do not need to manually set fp16 in the network for the new
        # apex, so fp16=False here. AdaINGen provides both encoding
        # (encode / enc_content) and decoding (decode).
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b (same object)

        # ---- Appearance encoder Ea -----------------------------------------
        # ID_stride is passed as `stride` to ft_netAB (default 2).
        if 'ID_stride' not in hyperparameters:
            hyperparameters['ID_stride'] = 2

        # 'PCB' -> PCB, 'AB' -> ft_netAB, anything else -> ft_net.
        # Calls unpack the result as (appearance code, ID prediction).
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        # Plain aliasing: id_b *is* id_a, so both names share one set of weights.
        self.id_b = self.id_a

        # ---- Discriminator D -----------------------------------------------
        # MsImageDis is described as a multi-scale discriminator (the image is
        # scored at several resolutions); both domains share the instance.
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # ---- Teacher models ------------------------------------------------
        # 'teacher' is a comma-separated list of model names (e.g. 'best-duke'
        # for DukeMTMC); an empty string disables the teachers.
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()

            for teacher_name in teacher_names:
                # Each teacher's own saved config decides its stride (default 2).
                config_tmp = load_config(teacher_name)
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2

                # Teachers are ft_net models with their weights loaded from disk.
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential(
                )  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model
            # Optionally apply train_bn to the teachers (presumably re-enables
            # BatchNorm statistics updates — see train_bn).
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        # Parameter-free instance normalisation (affine=False).
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Input transform for the structure encoder: an edge map when
        # single == 'edge', otherwise grayscale.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random erasing (probability erasing_p, default 0); forward uses it
        # only to recompute the ID predictions.
        self.erasing_p = hyperparameters.get('erasing_p', 0)
        self.single_re = RandomErasing(probability=self.erasing_p,
                                       mean=[0.0, 0.0, 0.0])
        # T_w: weight of the primary feature learning (teacher) loss.
        if 'T_w' not in hyperparameters:
            hyperparameters['T_w'] = 1

        # ---- Optimizers ----------------------------------------------------
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # dis_b / gen_b alias dis_a / gen_a, so one parameter list suffices.
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())
        # Adam for D and for Es + G (trainable parameters only).
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        # Nesterov SGD for Ea: the backbone trains at lr2 and each classifier
        # head at 10 * lr2. Which heads exist depends on ID_style.
        if hyperparameters['ID_style'] == 'PCB':
            id_heads = [self.id_a.classifier0, self.id_a.classifier1,
                        self.id_a.classifier2, self.id_a.classifier3]
        elif hyperparameters['ID_style'] == 'AB':
            id_heads = [self.id_a.classifier1, self.id_a.classifier2]
        else:
            id_heads = [self.id_a.classifier]
        head_param_ids = {id(p) for head in id_heads for p in head.parameters()}
        base_params = [p for p in self.id_a.parameters()
                       if id(p) not in head_param_ids]
        lr2 = hyperparameters['lr2']
        id_param_groups = [{'params': base_params, 'lr': lr2}]
        id_param_groups += [{'params': head.parameters(), 'lr': lr2 * 10}
                            for head in id_heads]
        self.id_opt = torch.optim.SGD(
            id_param_groups,
            weight_decay=hyperparameters['weight_decay'],
            momentum=0.9,
            nesterov=True)

        # ---- LR schedulers -------------------------------------------------
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        # get_scheduler may return None (scheduler users check for that),
        # so only override the decay factor of a real scheduler.
        if self.id_scheduler is not None:
            self.id_scheduler.gamma = hyperparameters['gamma2']

        # ---- Losses --------------------------------------------------------
        # Cross-entropy for identity classification.
        self.id_criterion = nn.CrossEntropyLoss()
        # Summed KL divergence (callers divide by the batch size themselves).
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')

        # Frozen VGG16 for the perceptual loss, loaded only when vgg_w > 0.
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")

    def to_re(self, x):
        """Apply ``self.single_re`` (random erasing) to every image of a batch.

        Args:
            x: 4-D batch; item i is ``x[i, :, :, :]``.

        Returns:
            A float32 tensor of size (x.size(0), x.size(1), x.size(2),
            x.size(3)) on x's device, with ``out[i] = self.single_re(x[i])``.
        """
        # Allocate directly on the input's device. The previous
        # torch.FloatTensor(...).cuda() built an uninitialised CPU tensor
        # only to copy it to the GPU, and hard-coded CUDA.
        out = torch.empty(x.size(0), x.size(1), x.size(2), x.size(3),
                          dtype=torch.float32, device=x.device)
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    # L1 loss: mean absolute difference (the target is detached).
    def recon_criterion(self, input, target):
        """Mean absolute error (L1) between *input* and a detached *target*.

        Because the target is detached, no gradient flows into it.
        """
        return (input - target.detach()).abs().mean()

    # Square-root L1 loss: mean of sqrt(|difference| + 1e-8).
    def recon_criterion_sqrt(self, input, target):
        """Mean of sqrt(|input - target| + 1e-8).

        The small epsilon keeps the square root differentiable at zero.
        """
        return (input - target).abs().add(1e-8).sqrt().mean()

    # L2 loss
    def recon_criterion2(self, input, target):
        """Mean squared error (L2) between *input* and *target*."""
        return (input - target).pow(2).mean()

    # Cosine loss: mean of (1 - cosine similarity) along dim 1.
    def recon_cos(self, input, target):
        """Mean cosine distance, i.e. 1 - cosine similarity along dim 1.

        Uses the same defaults as ``nn.CosineSimilarity()``: dim=1, eps=1e-8.
        """
        similarity = torch.nn.functional.cosine_similarity(input, target,
                                                           dim=1, eps=1e-8)
        return (1 - similarity).mean()

    # x_a, x_b, xp_a, xp_b: image batches, e.g. of shape [4, 3, 256, 128],
    # i.e. [batch size, input channels, image height, image width].
    def forward(self, x_a, x_b, xp_a, xp_b):
        """Encode two image batches, swap their codes and decode.

        x_a / x_b are two image batches (the original notes say they come from
        different training identities). xp_a / xp_b are described as images of
        the same identity as x_a / x_b respectively; this is not checked here.

        Returns a 14-tuple, in this order:
            x_ab: decode(s_a, f_b), i.e. structure of x_a with appearance of x_b.
            x_ba: decode(s_b, f_a), i.e. structure of x_b with appearance of x_a.
            s_a, s_b: structure codes from the Es encoder (gen_a / gen_b).
            f_a, f_b: appearance codes from the Ea encoder (id_a / id_b).
            p_a, p_b: ID predictions for x_a / x_b. When self.erasing_p > 0
                they are recomputed on randomly erased copies.
            pp_a, pp_b: ID predictions for xp_a / xp_b (same erasing rule).
            x_a_recon, x_b_recon: self-reconstructions decode(s_a, f_a) and
                decode(s_b, f_b).
            x_a_recon_p, x_b_recon_p: decode(s_a, fp_a) and decode(s_b, fp_b),
                i.e. structure of x_a / x_b with appearance of xp_a / xp_b.

        NOTE(review): the call order below is significant and kept as-is. If
        Ea contains BatchNorm in training mode, every id_a / id_b call updates
        its running statistics. Random erasing also draws from the RNG.
        """
        # Structure codes from Es. self.single converts the input to a
        # grayscale or edge image, depending on the 'single' config.
        # (Original notes give s_a / s_b as [batch, 128, 64, 32]; TODO confirm.)
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))

        # Appearance code f and identity prediction p from Ea, computed on
        # scale2(...) of the input (presumably a downscaled copy; see scale2).
        # The original notes say the appearance code is smaller than the
        # structure code and is mapped to a compatible form inside the decoder.
        # With ID_style 'AB', gen_update indexes p as a pair (p[0], p[1]).
        # (Original notes: f is [batch, 2048*4 = 8192] and each prediction is
        # [batch, ID_class]; TODO confirm.)
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))

        # Cross decoding with the generator G: swap appearance between x_a
        # and x_b. (Original notes: s_b [batch, 128, 64, 32] plus f_a gives
        # x_ba [batch, 3, 256, 128]; TODO confirm.)
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)

        # Self-reconstruction of each image, like an auto-encoder.
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        # Appearance codes and ID predictions of the positive images.
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))

        # decode the same person
        # Structure of x_a / x_b combined with the appearance of xp_a / xp_b.
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # Random erasing only affects the ID and PID losses: only the ID
        # predictions are replaced. The appearance codes and all decoded
        # images above are unchanged.
        if self.erasing_p > 0:
            # Randomly erased copies of each (scale2-ed) batch.
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))

            # Re-predict identities from the erased images, a form of data
            # augmentation for the ID classifier.
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        # Summary of the returned values:
        # x_ab = [structure of x_a, appearance of x_b]; x_ba = the reverse.
        # x_a_recon / x_b_recon: x_a / x_b rebuilt from their own codes.
        # x_a_recon_p: structure of x_a (s_a) with appearance of xp_a (fp_a).
        # x_b_recon_p: structure of x_b (s_b) with appearance of xp_b (fp_b).

        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b,
                   xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """

        :param x_ab:[images_a的st,images_b的ap]
        :param x_ba:[images_b的st,images_a的ap]
        :param s_a:[输入图片images_a经过Es编码得到的 st code]
        :param s_b:[输入图片images_b经过Es编码得到的 st code]
        :param f_a:[输入图片images_a经过Ea编码得到的 ap code]
        :param f_b:[输入图片images_b经过Ea编码得到的 ap code]
        :param p_a:[输入图片images_a经过Ea编码进行身份ID的预测]
        :param p_b:[输入图片images_b经过Ea编码进行身份ID的预测]
        :param pp_a:[输入图片pos_a经过Ea编码进行身份ID的预测]
        :param pp_b:[输入图片pos_b经过Ea编码进行身份ID的预测]
        :param x_a_recon:[输入图片images_a(s_a)与自己(f_a)合成的图片,当然和images_a长得一样]
        :param x_b_recon:[输入图片images_b(s_b)与自己(f_b)合成的图片,当然和images_b长得一样]
        :param x_a_recon_p:[输入图片images_a(s_a)与图片pos_a(fp_a)合成的图片,当然和images_a长得一样]
        :param x_b_recon_p:[输入图片images_b(s_b)与图片pos_b(fp_b)合成的图片,当然和images_b长得一样]
        :param x_a:images_a
        :param x_b:images_b
        :param xp_a:pos_a
        :param xp_b:pos_b
        :param l_a:labels_a
        :param l_b:labels_b
        :param hyperparameters:
        :param iteration:
        :param num_gpu:
        :return:
        """
        # ppa, ppb is the same person?
        self.gen_opt.zero_grad()  #梯度清零
        self.id_opt.zero_grad()

        # no gradient
        # 对合成x_ba和x_ab分别进行一份拷贝
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        # enc_content是类ContentEncoder对象
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            # Es编码得到s_a_recon与s_b_recon即st code
            # 如果是理想模型,s_a_recon=s_a, s_b_recon=s_b
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            # 这里的是深拷贝
            #enc_content_copy=gen_a.enc_content
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        #id_a_copy=id_a=Ea
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        # 对混合生成的图片x_ba,x_ab进行Es编码操作,同时对身份进行鉴别#
        # f_a_recon,f_b_recon表示的ap code,p_a_recon,p_b_recon表示对身份的鉴别
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        #如果使用了教师网络
        #默认ID_style为AB
        if hyperparameters['teacher_w'] > 0 and hyperparameters[
                'teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                #p_a_student表示x_ba_copy的身份编码,使用的是Ea进行身份编码,也就是使用学生模型进行身份编码
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                #对p_a_student使用logsoftmax,输出结果为x_ba_copy像某张图片的概率(就是一个分布)
                p_a_student = log_sm(p_a_student)
                #使用教师模型对生成图像x_ba_copy进行分类,输出结果为x_ba_copy像某张图片的概率(就是一个分布)
                p_a_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ba_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_a,
                    slabel=l_b,
                    teacher_style=hyperparameters['teacher_style'])
                #通过最小化KL散度损失函数,目的是让分布p_a_student与p_a_teacher尽可能的一致
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                #对x_ab_copy进行同样的操作
                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(
                    self.teacher_model,
                    scale2(x_ab_copy),
                    num_class=hyperparameters['ID_class'],
                    alabel=l_b,
                    slabel=l_a,
                    teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)

            #######################################################################################################################################################################################################
            # primary feature learning loss
            #######################################################################################################################################################################################################
            #  ID_style为AB
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # 合成的图片经过身份鉴别器,得到每个ID可能性的概率,注意这里去的是p_ba_student[0],我们知有两个身份预测结果,这里只取了一个
                # 并且赋值给了p_a_student,用于和教师模型结合的,共同计算损失
                #p_a_student分为两个部分,p_a_student[0]表示L_prim,p_a_student[1]表示L_fine。
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])

                with torch.no_grad():
                    ##使用教师模型对生成图像x_ba_copy进行分类,输出结果为x_ba_copy像某张图片(x_a/x_b)的概率(就是一个分布)
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])

                # criterion_teacher = nn.KLDivLoss(size_average=False)
                # 计算离散距离,可以理解为p_a_student与p_a_teacher每个元素的距离之和,然后除以p_a_student.size(0)取平均值
                # 就是说学生网络(Ea)的预测越与教师网络结果相同,则是最好的
                self.loss_teacher = self.criterion_teacher(
                    p_a_student, p_a_teacher) / p_a_student.size(0)

                # 对另一张合成图片进行同样的操作
                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(
                    p_b_student, p_b_teacher) / p_b_student.size(0)
                ########################################################################################################################################################################################################

                ########################################################################################################################################################################################################
                #fine—grained feature mining loss
                ########################################################################################################################################################################################################
                # branch b loss
                # here we give different label
                # p_ba_student[1]表示的是f_fine特征,l_b表示的是images_b,即为生成图像提供st code 的图片
                loss_B = self.id_criterion(p_ba_student[1],
                                           l_b) + self.id_criterion(
                                               p_ab_student[1], l_a)
                #######################################################################################################################################################################################################

                # 对两部分损失进行权重调整
                self.loss_teacher = hyperparameters[
                    'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0

        ## 剩下的就是重构图像之间的损失了
        # 前面提到,重构和合成是不一样的,重构是构建出来和原来图片一样的图片
        # 所以也就是可以把重构的图片和原来的图像直接计算像素直接的插值
        # 但是合成的图片是没有办法的,因为训练数据集是没有合成图片的,所以,没有办法计算像素之间的损失
        # #######################################################################################################################################################################################################
        # auto-encoder image reconstruction
        # 同ID图像进行重构时的损失函数
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)
        # #######################################################################################################################################################################################################

        #######################################################################################################################################################################################################
        # feature reconstruction
        # 不同ID图像进行图像合成时,为了保证合成图像的st code和ap code与为合成图像提供st code 和 ap code保持一致所使用的损失函数
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0
        # #######################################################################################################################################################################################################

        # 又一次进行图像合成
        x_aba = self.gen_a.decode(
            s_a_recon,
            f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            s_b_recon,
            f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)

        ########################################################################################################################################################################################################
        #   使用的是  ['ID_style']=='AB'
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            #计算的是L^s_id
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )

            #对同ID不同图片计算L^s_id
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(
                pp_b[0], l_b
            )  #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )

            # 对生成图像计算L^C_id
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        ########################################################################################################################################################################################################

        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        ########################################################################################################################################################################################################
        # GAN loss
        #计算生成器G的对抗损失函数
        ########################################################################################################################################################################################################
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(
                self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(
                self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)
        ########################################################################################################################################################################################################

        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss,计算总的loss
        #1个teacher loss+4个同ID图片重构loss+4个不同ID图片合成loss++3个ID loss+2个生成器loss、
        #teacher loss包括了primary feature learning loss和fine_grain mining loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()  #计算梯度
            self.gen_opt.step()  #梯度更新
            self.id_opt.step()  #梯度更新
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \
                                                        hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \
                                                        hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \
                                                        hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \
                                                        hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \
                                                        hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \
                                                        hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \
                                                        hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \
                                                        hyperparameters['id_w'] * self.loss_id,\
                                                        hyperparameters['pid_w'] * self.loss_pid,\
hyperparameters['teacher_w'] * self.loss_teacher )  )

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual (VGG-feature) loss between ``img`` and ``target``.

        Both inputs are passed through ``vgg_preprocess`` and then ``vgg``. The
        resulting feature maps are normalized with ``self.instancenorm``, and the
        mean of their squared difference is returned.

        Args:
            vgg: feature extractor applied to the preprocessed images.
            img: image batch to be scored (a generated image at the call sites
                in ``gen_update``).
            target: reference image batch.

        Returns:
            0-dim tensor: mean squared difference of the normalized features.
        """
        img_in = vgg_preprocess(img)
        target_in = vgg_preprocess(target)
        img_feat, target_feat = vgg(img_in), vgg(target_in)
        diff = self.instancenorm(img_feat) - self.instancenorm(target_feat)
        return torch.mean(diff ** 2)

    def PCB_loss(self, inputs, labels):
        """Average ``self.id_criterion`` over the per-part predictions in ``inputs``.

        Args:
            inputs: non-empty sequence of per-part outputs. An empty sequence
                raises ``ZeroDivisionError``.
            labels: targets passed unchanged to ``id_criterion`` for every part.

        Returns:
            Sum of the per-part losses divided by ``len(inputs)``.
        """
        per_part = (self.id_criterion(part, labels) for part in inputs)
        return sum(per_part, 0.0) / len(inputs)

    def sample(self, x_a, x_b):
        """Generate visualisation images for the paired batches ``x_a`` / ``x_b``.

        Each pair ``(x_a[i], x_b[i])`` is processed as a batch of one:

        * ``s_*`` codes come from ``gen_*.encode(self.single(img))``.
        * ``f_*`` features come from the ID encoder ``id_*(scale2(img))``.
        * Images are produced with ``gen_*.decode(s, f)``.

        The trainer is put in eval mode for the duration of the call. Its previous
        train/eval mode is restored afterwards, even if an exception is raised.

        Returns:
            tuple: ``(x_a, x_a_recon, x_aba, x_ab, x_b, x_b_recon, x_bab, x_ba)``.
            The terms are:

            * ``x_a_recon = decode(s_a, f_a)``, ``x_b_recon = decode(s_b, f_b)``.
            * ``x_ab = decode(s_a, f_b)``, ``x_ba = decode(s_b, f_a)``.
            * ``x_aba`` decodes the code re-encoded (``enc_content``) from
              ``x_ab`` with the ID features of ``x_ba``. ``x_bab`` is the
              symmetric counterpart.

            All generated tensors are concatenated along dim 0.

        Note:
            Gradients are not disabled here; callers presumably wrap this call in
            ``torch.no_grad()`` -- TODO confirm.
        """
        was_training = self.training
        self.eval()
        try:
            x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                img_a = x_a[i].unsqueeze(0)
                img_b = x_b[i].unsqueeze(0)
                s_a = self.gen_a.encode(self.single(img_a))
                s_b = self.gen_b.encode(self.single(img_b))
                f_a, _ = self.id_a(scale2(img_a))
                f_b, _ = self.id_b(scale2(img_b))
                x_a_recon.append(self.gen_a.decode(s_a, f_a))
                x_b_recon.append(self.gen_b.decode(s_b, f_b))
                # Swap ID features between the two images
                x_ba = self.gen_a.decode(s_b, f_a)
                x_ab = self.gen_b.decode(s_a, f_b)
                x_ba1.append(x_ba)
                x_ab1.append(x_ab)
                # Cycle: re-encode the swapped images and decode back
                s_b_recon = self.gen_a.enc_content(self.single(x_ba))
                s_a_recon = self.gen_b.enc_content(self.single(x_ab))
                f_a_recon, _ = self.id_a(scale2(x_ba))
                f_b_recon, _ = self.id_b(scale2(x_ab))
                x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
                x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))

            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
            x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        finally:
            # Restore whatever mode the trainer was in before this call.
            self.train(was_training)

        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        """Run one optimisation step of the discriminator.

        The generated images ``x_ba`` / ``x_ab`` are detached, so no gradient
        reaches the generator. They are scored against the real images ``x_a`` /
        ``x_b`` via ``calc_dis_loss``.

        Both losses are weighted by ``hyperparameters['gan_w']`` and summed into
        ``self.loss_dis_total``. The total is printed and back-propagated, through
        ``amp.scale_loss`` when ``self.fp16`` is set, and ``self.dis_opt`` is
        stepped.

        Args:
            x_ab, x_ba: generated images.
            x_a, x_b: real images.
            hyperparameters: config dict; only ``'gan_w'`` is read here.
            num_gpu: when > 1, ``calc_dis_loss`` is looked up on ``.module`` of
                the discriminators (assumed to be wrapped, e.g. by DataParallel).
                The wrapper itself is still passed as the network argument.

        Side effects:
            Sets ``self.loss_dis_a``, ``self.loss_dis_b`` and ``self.loss_dis_total``.
        """
        self.dis_opt.zero_grad()

        if num_gpu > 1:
            calc_a = self.dis_a.module.calc_dis_loss
            calc_b = self.dis_b.module.calc_dis_loss
        else:
            calc_a = self.dis_a.calc_dis_loss
            calc_b = self.dis_b.calc_dis_loss

        fake_ba, fake_ab = x_ba.detach(), x_ab.detach()
        self.loss_dis_a, reg_a = calc_a(self.dis_a, fake_ba, x_a)
        self.loss_dis_b, reg_b = calc_b(self.dis_b, fake_ab, x_b)

        gan_w = hyperparameters['gan_w']
        self.loss_dis_total = gan_w * self.loss_dis_a + gan_w * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total,
              "Reg: %.4f" % (reg_a + reg_b))

        if not self.fp16:
            self.loss_dis_total.backward()
        else:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance the learning-rate schedulers by one step.

        The schedulers are stepped in the order discriminator, generator, ID
        encoder. Any scheduler that is ``None`` is skipped.
        """
        for attr in ('dis_scheduler', 'gen_scheduler', 'id_scheduler'):
            scheduler = getattr(self, attr)
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore trainer state from the newest snapshots in ``checkpoint_dir``.

        Steps performed:

        1. Load the ``'a'`` entry of the newest ``gen`` / ``dis`` / ``id`` files
           (as found by ``get_model_list``) into ``gen_a`` / ``dis_a`` / ``id_a``.
        2. Re-establish the ``*_b`` aliases.
        3. Try to restore optimizer state from ``optimizer.pt``.
        4. Rebuild all three learning-rate schedulers at the resumed iteration.

        Args:
            checkpoint_dir: directory holding the snapshots written by ``save``.
            hyperparameters: config dict forwarded to ``get_scheduler``.

        Returns:
            int: iteration number parsed from the newest generator checkpoint name.
        """
        # Load generators (gen_b is an alias of gen_a)
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        # Snapshot names end in '%08d.pt' (see save), so the digits are at [-11:-3]
        iterations = int(last_model_name[-11:-3])
        # Load discriminators (dis_b is an alias of dis_a)
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID encoder (id_b is an alias of id_a)
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers. This is best effort: a missing or incompatible
        # optimizer.pt must not abort the resume, but the failure is reported
        # instead of silently ignored (and Ctrl-C is no longer swallowed).
        opt_path = os.path.join(checkpoint_dir, 'optimizer.pt')
        try:
            state_dict = torch.load(opt_path)
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception as err:
            print('Warning: optimizer state not restored from %s: %s' % (opt_path, err))
        # Reinitialize schedulers so they continue from `iterations`.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        # update_learning_rate also steps id_scheduler, so it must be rebuilt
        # too; otherwise it keeps its pre-resume state, out of step with gen/dis.
        # Carry over any gamma that was set on the existing scheduler after
        # it was constructed.
        prev_id_gamma = getattr(getattr(self, 'id_scheduler', None), 'gamma', None)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters,
                                          iterations)
        if self.id_scheduler is not None and prev_id_gamma is not None:
            self.id_scheduler.gamma = prev_id_gamma
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        """Write model weights and optimizer state to ``snapshot_dir``.

        Files written, with ``N = iterations + 1`` zero-padded to 8 digits:

        * ``gen_N.pt``: ``{'a': gen_a.state_dict()}``.
        * ``dis_N.pt``: ``{'a': <discriminator state_dict>}``. When
          ``num_gpu > 1`` this is the state of ``dis_a.module`` rather than of
          the wrapper.
        * ``id_N.pt``: ``{'a': id_a.state_dict()}``.
        * ``optimizer.pt``: ``{'gen', 'id', 'dis'}`` optimizer state dicts. This
          file is overwritten on every call.

        This is the layout ``resume`` reads back.
        """
        step = iterations + 1
        gen_path = os.path.join(snapshot_dir, 'gen_%08d.pt' % step)
        dis_path = os.path.join(snapshot_dir, 'dis_%08d.pt' % step)
        id_path = os.path.join(snapshot_dir, 'id_%08d.pt' % step)
        opt_path = os.path.join(snapshot_dir, 'optimizer.pt')

        torch.save({'a': self.gen_a.state_dict()}, gen_path)
        dis_net = self.dis_a.module if num_gpu > 1 else self.dis_a
        torch.save({'a': dis_net.state_dict()}, dis_path)
        torch.save({'a': self.id_a.state_dict()}, id_path)

        optimizer_states = {
            'gen': self.gen_opt.state_dict(),
            'id': self.id_opt.state_dict(),
            'dis': self.dis_opt.state_dict(),
        }
        torch.save(optimizer_states, opt_path)