class DGNet_Trainer(nn.Module):
    """Trainer for DG-Net style joint person re-id + image generation.

    Holds the structure encoder/decoder (``gen_a``), the appearance/ID
    encoder (``id_a``), a multi-scale discriminator (``dis_a``) and, when
    configured, a list of frozen teacher models used for distillation.
    Domain "b" members alias the domain "a" members (shared weights).

    NOTE(review): this class references module-level helpers defined
    elsewhere in the project (``AdaINGen``, ``MsImageDis``, ``ft_net``,
    ``ft_netAB``, ``PCB``, ``load_config``, ``load_network``, ``to_edge``,
    ``to_gray``, ``RandomErasing``, ``get_scheduler``, ``load_vgg16``,
    ``train_bn``, ``scale2``, ``predict_label``, ``amp``) — their exact
    semantics are assumed from usage here.
    """

    def __init__(self, hyperparameters, gpu_ids=[0]):
        # `gpu_ids` is unused in this body; kept for interface compatibility.
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']  # generator learning rate
        lr_d = hyperparameters['lr_d']  # discriminator learning rate
        ID_class = hyperparameters['ID_class']
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks.
        # We do not need to manually set fp16 in the network for the new apex,
        # so fp16=False is passed here and amp handles mixed precision below.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b (shared weights)

        # ID_stride: stride of the appearance encoder's pooling layer.
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        # Appearance / identity encoder (Ea).
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now
        self.id_b = self.id_a  # same encoder for both domains

        # Multi-scale discriminator: the image is downsampled several times and
        # a prediction is made at each scale; losses are summed over scales.
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b (shared)

        # Load teacher models (comma-separated list of checkpoint names).
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()
            teacher_count = 0
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2
                # Build the teacher backbone and load its weights.
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())  # teachers are frozen in eval mode
                teacher_count += 1
            self.teacher_model = teacher_model
            # Optionally keep BatchNorm layers in training mode in the teachers.
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        # Instance normalization used by the perceptual (VGG) loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel: the structure encoder consumes single-channel input.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training; erasing_p is the erase probability.
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # Randomly erases a rectangular region of the image (data augmentation).
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])

        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1

        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters())  # + list(self.dis_b.parameters()) — b aliases a
        gen_params = list(self.gen_a.parameters())  # + list(self.gen_b.parameters()) — b aliases a
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])

        # ID optimizer: classifier heads get 10x the base learning rate.
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (list(map(id, self.id_a.classifier0.parameters()))
                              + list(map(id, self.id_a.classifier1.parameters()))
                              + list(map(id, self.id_a.classifier2.parameters()))
                              + list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (list(map(id, self.id_a.classifier1.parameters()))
                              + list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)

        # Learning-rate schedulers for each optimizer.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID loss and teacher-distillation loss.
        self.id_criterion = nn.CrossEntropyLoss()
        # NOTE(review): `size_average=False` is the deprecated spelling of
        # reduction='sum'; kept as-is for compatibility with older torch.
        self.criterion_teacher = nn.KLDivLoss(size_average=False)
        # Load VGG model if needed (frozen, used for the perceptual loss).
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False  # save memory

        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer.
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()
            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

    def to_re(self, x):
        """Apply RandomErasing independently to every image of a (N,C,H,W) batch.

        Returns a new CUDA float tensor of the same shape.
        """
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        """L1 reconstruction loss; target is detached so only `input` gets gradients."""
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    def recon_criterion_sqrt(self, input, target):
        """Mean sqrt(|input - target|) loss (epsilon added for a finite gradient at 0)."""
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    def recon_criterion2(self, input, target):
        """Mean squared error reconstruction loss."""
        diff = input - target
        return torch.mean(diff[:] ** 2)

    def recon_cos(self, input, target):
        """Mean cosine *distance* (1 - cosine similarity) between rows of input/target."""
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    def forward(self, x_a, x_b, xp_a, xp_b):
        """One generation pass over two images of different IDs.

        x_a / x_b  : images of two different identities.
        xp_a / xp_b: different photos of the *same* identities as x_a / x_b.

        Returns the swapped and self-reconstructed images plus the structure
        codes (s_*), appearance codes (f_*) and ID predictions (p_*, pp_*).
        NOTE(review): the original author's notes say the structure code is
        (batch, 128, 64, 32) — not verifiable from this file; confirm upstream.
        """
        # Structure codes via Es (`single` converts to grayscale/edge input).
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        # Appearance codes f_* and identity predictions p_* via Ea.
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        # Cross-ID generation: swap structure and appearance.
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # Self-reconstruction: decode each image from its own codes.
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)
        # Positive pairs: same ID, different photo.
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        # Same-ID reconstruction using the positive sample's appearance code.
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # Random Erasing only affects the ID and PID losses: the erased
        # variants overwrite the p_* / pp_* predictions (not the f_* codes).
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID, different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, \
               x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p,
                   x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """One generator/ID-encoder optimization step.

        Takes the tensors produced by :meth:`forward`, the real images and
        their labels (l_a, l_b), builds the full DG-Net objective, and steps
        ``gen_opt`` and ``id_opt``.  ``hyperparameters`` is mutated in place
        by the warm-up schedule below.
        """
        # pp_a / pp_b come from the same person as p_a / p_b.
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
        # Gradient-free copies of the cross-ID images.
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)
        rand_num = random.uniform(0, 1)
        #################################
        # Encode structure of the generated images.
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder so it stays fixed while the input is tuned
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # Encode appearance of the generated images with a frozen copy of Ea.
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        # f_*_recon are appearance codes; p_*_recon are ID predictions on the
        # generated images (code-reconstruction loss Lrecon^code2).
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # Teacher loss: tune the ID model towards the frozen teachers.
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy),
                                            num_class=hyperparameters['ID_class'],
                                            alabel=l_a, slabel=l_b,
                                            teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)
                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy),
                                            num_class=hyperparameters['ID_class'],
                                            alabel=l_b, slabel=l_a,
                                            teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # Normal teacher-student loss:
                # BA -> LabelA(smooth) + LabelB(batchB), i.e. Lprim + Lrecon^code1.
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                # KL divergence between the student (Ea) and teacher
                # distributions, averaged over the batch: the student should
                # match the teacher's soft predictions.
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy),
                                                num_class=hyperparameters['ID_class'],
                                                alabel=l_a, slabel=l_b,
                                                teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy),
                                                num_class=hyperparameters['ID_class'],
                                                alabel=l_b, slabel=l_a,
                                                teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)

                # Branch-b loss (Lfine): here we give a *different* label —
                # index [1] is the fine-grained identity branch of ft_netAB.
                loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher \
                    + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0

        # Auto-encoder image reconstruction (Limg1^recon, Limg2^recon).
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # Feature (code) reconstruction.
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0
        # Cycle images (only built when the cycle loss is active).
        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # ID loss AND tune on the generated image.
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b))
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b)  # + weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        # print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # GAN loss (DataParallel wraps the module when num_gpu > 1).
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)

        # Domain-invariant perceptual loss.
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # Warm-up schedule: gradually ramp selected loss weights (mutates
        # `hyperparameters` in place, clamped at their max values).
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])
        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])

        # Total loss.
        # FIX: the original chained `+ \` line-continuations with inline
        # comments after the backslash, which is a SyntaxError in Python;
        # a parenthesized expression allows the comments legally.
        self.loss_gen_total = (
            hyperparameters['gan_w'] * self.loss_gen_adv_a +                 # GAN loss
            hyperparameters['gan_w'] * self.loss_gen_adv_b +
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a +         # image reconstruction
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a +
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a +         # code reconstruction
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a +
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b +
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b +
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b +
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b +
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a +
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b +
            hyperparameters['id_w'] * self.loss_id +
            hyperparameters['pid_w'] * self.loss_pid +
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id +         # ID loss
            hyperparameters['vgg_w'] * self.loss_gen_vgg_a +
            hyperparameters['vgg_w'] * self.loss_gen_vgg_b +
            hyperparameters['teacher_w'] * self.loss_teacher)                # teacher loss

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()

        print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % (
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher))
class ERGAN_Trainer(nn.Module):
    """Trainer for ERGAN-style translation between two face domains.

    Two AdaIN auto-encoders (``gen_a``/``gen_b``) share a content encoder;
    each domain has its own multi-scale discriminator.  A fixed rectangular
    region (the eyeglass area, [92:144, 48:172] at 224px, scaled by
    ``self.a``) receives extra reconstruction supervision.

    NOTE(review): references project helpers defined elsewhere (``AdaINGen``,
    ``MsImageDis``, ``get_scheduler``, ``weights_init``, ``load_vgg16``,
    ``vgg_preprocess``, ``get_model_list``, ``amp``).
    """

    def __init__(self, hyperparameters):
        super(ERGAN_Trainer, self).__init__()
        lr_G = hyperparameters['lr_G']
        lr_D = hyperparameters['lr_D']
        print(lr_D, lr_G)
        self.fp16 = hyperparameters['fp16']
        # Initiate the networks.
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.gen_b.enc_content = self.gen_a.enc_content  # content encoders share weights
        # self.gen_b.enc_style = self.gen_a.enc_style
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Scale factor mapping the 224px reference coordinates of the glass
        # region onto the configured image size.
        self.a = hyperparameters['gen']['new_size'] / 224
        # Fix the noise used in sampling.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_D, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_G, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Network weight initialization (discriminators re-initialized gaussian).
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        if self.fp16:
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.gen_b = self.gen_b.cuda()
            self.dis_b = self.dis_b.cuda()
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
        # Load VGG model if needed (frozen, for the perceptual loss).
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def _glass_region(self):
        """Return (row_slice, col_slice) of the eyeglass region.

        Both celebA and MeGlass use [92:144, 48:172] at 224px; the bounds are
        scaled by ``self.a`` for the configured image size.
        """
        return (slice(round(self.a * 92), round(self.a * 144)),
                slice(round(self.a * 48), round(self.a * 172)))

    def recon_criterion(self, input, target):
        """L1 reconstruction loss; input is cast to target's dtype/device first."""
        input = input.type_as(target)
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Inference pass: cross-domain translation with the fixed style noise."""
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a, x_b)
        x_ab = self.gen_b.decode(c_a, s_b, x_a)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator optimization step on a pair of real images."""
        self.gen_opt.zero_grad()
        rows, cols = self._glass_region()
        # Zero out the glass region to supervise the *residual* content.
        block_a = x_a.clone()
        block_b = x_b.clone()
        block_a[:, :, rows, cols] = 0
        block_b[:, :, rows, cols] = 0
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime, x_a)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime, x_b)
        # decode (cross domain) with the real (encoded) style codes
        # x_ba_randn = self.gen_a.decode(c_b, s_a, x_b)
        # x_ab_randn = self.gen_b.decode(c_a, s_b, x_a)
        x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b)
        x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a)
        block_ba_real = x_ba_real.clone()
        block_ab_real = x_ab_real.clone()
        block_ba_real[:, :, rows, cols] = 0
        block_ab_real[:, :, rows, cols] = 0
        # encode again
        # c_b_recon, s_a_recon = self.gen_a.encode(x_ba_randn)
        # c_a_recon, s_b_recon = self.gen_b.encode(x_ab_randn)
        c_b_real_recon, s_a_prime_recon = self.gen_a.encode(x_ba_real)
        c_a_real_recon, s_b_prime_recon = self.gen_b.encode(x_ab_real)
        # decode again (only when the cycle loss is active)
        x_aba = self.gen_a.decode(
            c_a_real_recon, s_a_prime, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_real_recon, s_b_prime, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction losses
        # outside-the-glass-region (residual) reconstruction
        self.loss_gen_recon_res_a = self.recon_criterion(block_ab_real, block_a)
        self.loss_gen_recon_res_b = self.recon_criterion(block_ba_real, block_b)
        # inside-the-glass-region reconstruction
        self.loss_gen_recon_x_a_re = self.recon_criterion(
            x_a_recon[:, :, rows, cols], x_a[:, :, rows, cols])
        self.loss_gen_recon_x_b_re = self.recon_criterion(
            x_b_recon[:, :, rows, cols], x_b[:, :, rows, cols])
        # whole-image self-reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        # style / content code reconstruction
        self.loss_gen_recon_s_a_prime = self.recon_criterion(s_a_prime_recon, s_a_prime)
        self.loss_gen_recon_s_b_prime = self.recon_criterion(s_b_prime_recon, s_b_prime)
        self.loss_gen_recon_c_a_real = self.recon_criterion(c_a_real_recon, c_a)
        self.loss_gen_recon_c_b_real = self.recon_criterion(c_b_real_recon, c_b)
        # cycle reconstruction
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a_real = self.dis_a.calc_gen_loss(x_ba_real)
        self.loss_gen_adv_b_real = self.dis_b.calc_gen_loss(x_ab_real)
        # domain-invariant perceptual loss
        # self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba_randn, x_b) if hyperparameters['vgg_w'] > 0 else 0
        # self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab_randn, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # total loss
        # BUG FIX: the original added recon_x_w_res * loss_gen_recon_res_b
        # TWICE and never used loss_gen_recon_res_a; the second term is now
        # the "a" residual loss.
        self.loss_gen_total = (
            hyperparameters['gan_w'] * self.loss_gen_adv_a_real +
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b +
            hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_b_re +
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b_prime +
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b_real +
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b +
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a +
            hyperparameters['recon_x_w_re'] * self.loss_gen_recon_x_a_re +
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a_prime +
            hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a_real +
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a +
            hyperparameters['gan_w'] * self.loss_gen_adv_b_real +
            hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_b +
            hyperparameters['recon_x_w_res'] * self.loss_gen_recon_res_a)

        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()

        # loss_gan = hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b)
        loss_gan_real = hyperparameters['gan_w'] * (self.loss_gen_adv_a_real + self.loss_gen_adv_b_real)
        loss_x = hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b)
        loss_x_re = hyperparameters['recon_x_w_re'] * (
            self.loss_gen_recon_x_a_re + self.loss_gen_recon_x_b_re)
        # loss_x_res = hyperparameters['recon_x_w_res'] * (self.loss_gen_recon_res_a + self.loss_gen_recon_res_b)
        # loss_s = hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b)
        # loss_c = hyperparameters['recon_c_w'] * (self.loss_gen_recon_c_a + self.loss_gen_recon_c_b)
        loss_s_prime = hyperparameters['recon_s_w'] * (
            self.loss_gen_recon_s_a_prime + self.loss_gen_recon_s_b_prime)
        loss_c_real = hyperparameters['recon_c_w'] * (
            self.loss_gen_recon_c_a_real + self.loss_gen_recon_c_b_real)
        loss_x_cyc = hyperparameters['recon_x_cyc_w'] * (
            self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b)
        # loss_vgg = hyperparameters['vgg_w'] * (self.loss_gen_vgg_a + self.loss_gen_vgg_b)
        print(
            '||total:%.2f||gan_real:%.2f||x:%.2f||x_re:%.2f||s_prime:%.4f||c_real:%.2f||x_cyc:%.4f||'
            % (self.loss_gen_total, loss_gan_real, loss_x, loss_x_re,
               loss_s_prime, loss_c_real, loss_x_cyc))

    def compute_vgg_loss(self, vgg, img, target):
        """Instance-normalized VGG feature MSE between img and target."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Generate visualization outputs one image at a time (eval mode)."""
        self.eval()
        x_a_recon, x_b_recon, x_bab, x_ab, x_ba, x_aba = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(
                self.gen_a.decode(c_a, s_a_fake, x_a[i].unsqueeze(0)))
            x_b_recon.append(
                self.gen_b.decode(c_b, s_b_fake, x_b[i].unsqueeze(0)))
            x_ba_tmp = self.gen_a.decode(c_b, s_a_fake, x_b[i].unsqueeze(0))
            x_ab_tmp = self.gen_b.decode(c_a, s_b_fake, x_a[i].unsqueeze(0))
            x_ba.append(x_ba_tmp)
            x_ab.append(x_ab_tmp)
            c_b_recon, _ = self.gen_a.encode(x_ba_tmp)
            c_a_recon, _ = self.gen_b.encode(x_ab_tmp)
            x_aba.append(
                self.gen_a.decode(c_a_recon, s_a_fake, x_a[i].unsqueeze(0)))
            x_bab.append(
                self.gen_b.decode(c_b_recon, s_b_fake, x_b[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ab, x_ba = torch.cat(x_ab), torch.cat(x_ba)
        self.train()
        return x_a, x_a_recon, x_ab, x_ba, x_b, x_aba, x_b, x_b_recon, x_ba, x_ab, x_a, x_bab

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator optimization step (generated images detached)."""
        self.dis_opt.zero_grad()
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba_real = self.gen_a.decode(c_b, s_a_prime, x_b)
        x_ab_real = self.gen_b.decode(c_a, s_b_prime, x_a)
        # D loss
        self.loss_dis_a_real = self.dis_a.calc_dis_loss(
            x_ba_real.detach(), x_a)
        self.loss_dis_b_real = self.dis_b.calc_dis_loss(
            x_ab_real.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * (
            self.loss_dis_a_real + self.loss_dis_b_real)
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
            self.dis_opt.step()
        else:
            self.loss_dis_total.backward()
            self.dis_opt.step()

    def update_learning_rate(self):
        """Step both LR schedulers when they exist."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore generators, discriminators and optimizers; return iteration.

        The iteration count is parsed from the checkpoint filename
        (``..._%08d.pt``).
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations - 375000)  # fine_tune -370000
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations - 375000)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
class DGNet_Trainer(nn.Module):
    """DG-Net trainer: a shared appearance encoder (Ea), structure encoder (Es),
    decoder (G), multi-scale image discriminator (D), and optional teacher
    models for knowledge distillation.  Domains a and b share the same modules
    (gen_b/id_b/dis_b are aliases of gen_a/id_a/dis_a)."""

    def __init__(self, hyperparameters, gpu_ids=[0]):
        # NOTE(review): mutable default `gpu_ids=[0]` is never mutated here, kept as-is.
        super(DGNet_Trainer, self).__init__()
        # Learning rates for generator and discriminator, from the config.
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']
        # Number of identity classes.
        ID_class = hyperparameters['ID_class']
        # Whether mixed precision (NVIDIA apex) is enabled; defaults to off.
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b (shared weights)
        '''
        ft_netAB : Ea
        '''
        # ID_stride: stride of the appearance encoder's pooling layer.
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        # id_a: the appearance encoder Ea (variant selected by ID_style).
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now
        self.id_b = self.id_a  # domain b uses the same appearance encoder
        # Multi-scale discriminator: the input is scored at several scales and
        # the per-scale losses are summed.
        # NOTE(review): original annotation claims the three output maps are
        # [batch, 1, 64, 32], [batch, 1, 32, 16], [batch, 1, 16, 8] -- confirm against MsImageDis.
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b (shared weights)
        # load teachers
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            # Possibly several comma-separated teacher checkpoints.
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()  # filled in the loop below
            teacher_count = 0
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                # Pooling stride of this teacher (default 2).
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2
                # Build the teacher network and load its weights.
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                # Teachers are frozen in eval mode.
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count += 1
            self.teacher_model = teacher_model
            # Optionally put the teachers' batch-norm layers back in train mode.
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        # RGB to one channel
        # Es consumes single-channel input, so `single` converts RGB to
        # either an edge map or a gray-scale image.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)
        # Random Erasing when training
        # erasing_p is the probability of applying random erasing.
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # Erase a random rectangle of the image, filling it with `mean`.
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])
        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters())  #+ list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d,
                                        betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g,
                                        betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        # id params
        # Classifier heads of id_a train with a 10x larger learning rate than the backbone.
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (list(map(id, self.id_a.classifier0.parameters())) +
                              list(map(id, self.id_a.classifier1.parameters())) +
                              list(map(id, self.id_a.classifier2.parameters())) +
                              list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10}],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (list(map(id, self.id_a.classifier1.parameters())) +
                              list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10}],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10}],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        # Learning-rate schedules for discriminator / generator / id encoder.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']
        #ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        # KL divergence against the teacher distribution (primary-feature loss Lprim).
        # NOTE(review): size_average is deprecated in modern PyTorch; equivalent is reduction='sum'.
        self.criterion_teacher = nn.KLDivLoss(size_average=False)
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        # save memory
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()
            # Re-alias the b-side handles to the freshly moved modules.
            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

    def to_re(self, x):
        """Apply random erasing to every image in the batch; returns a new CUDA tensor."""
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])  # overwrite the erased region
        return out

    def recon_criterion(self, input, target):
        """L1 reconstruction loss; gradients do not flow into `target`."""
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    def recon_criterion_sqrt(self, input, target):
        """Mean sqrt-of-absolute-difference; 1e-8 keeps the gradient finite at 0."""
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    def recon_criterion2(self, input, target):
        """Mean squared error reconstruction loss."""
        diff = input - target
        return torch.mean(diff[:]**2)

    def recon_cos(self, input, target):
        """Mean cosine distance (1 - cosine similarity) between input and target."""
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    def forward(self, x_a, x_b, xp_a, xp_b):
        '''
        Four input images per call:
        :param x_a:
        :param xp_a: same identity as x_a
        :param x_b:
        :param xp_b: same identity as x_b
        Why four images?  A full DG-Net step needs three images
        (id1, id2, and a positive of id1), so two independent triplets
        would cost six images.  Feeding the two pairs (id1, id1-positive)
        and (id2, id2-positive) instead yields the two triplets
        (id1, id2, id1-positive) and (id2, id1, id2-positive)
        from only four images, saving two.
        '''
        # gen_a.encode: structure encoder Es; `single` converts to one channel first.
        s_a = self.gen_a.encode(self.single(x_a))  # structure code of a (original note: [batch, 128, 64, 32])
        s_b = self.gen_b.encode(self.single(x_b))  # structure code of b
        # id_a: appearance encoder Ea; returns (appearance code, identity logits).
        f_a, p_a = self.id_a(scale2(x_a))  # appearance code of a
        f_b, p_b = self.id_b(scale2(x_b))  # appearance code of b; p_* are identity predictions
        # gen_*.decode: decoder G mixing a structure code with an appearance code.
        x_ba = self.gen_a.decode(s_b, f_a)  # a-appearance + b-structure
        x_ab = self.gen_b.decode(s_a, f_b)  # b-appearance + a-structure
        x_a_recon = self.gen_a.decode(s_a, f_a)  # self-reconstruction of a
        x_b_recon = self.gen_b.decode(s_b, f_b)  # self-reconstruction of b
        # Appearance codes / identity predictions of the positive samples.
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        # decode the same person: structure of the anchor + appearance of its positive.
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)
        # Random Erasing only effect the ID and PID loss.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            # Re-predict identities on the erased images (codes f_* are kept from above).
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)
        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    '''
    Training once with three images would look like:
        s_a = self.gen_a.encode(self.single(x_a))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        fp_a, pp_a = self.id_a(scale2(xp_a))
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
    and once more with three images:
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        x_ba = self.gen_a.decode(s_b, f_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)
    '''
    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                   x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b,
                   xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """One optimization step for the generator (G, Es) and the appearance
        encoder Ea, combining image/code reconstruction, identity, adversarial,
        perceptual, and teacher-distillation losses.  Equation numbers below
        refer to the DG-Net paper."""
        # ppa, ppb is the same person
        # pp_a / pp_b: identity predictions (via Ea) for the positive images.
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
        # no gradient
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)
        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))  # structure code of x_ab
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))  # structure code of x_ba
        else:
            # copy the encoder (deep copy, frozen in eval mode)
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))
        #################################
        # encode appearance: a frozen deep copy of Ea scores the synthesized images.
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))  # Ea code + id prediction for x_ba
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))  # Ea code + id prediction for x_ab
        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model,
                                            scale2(x_ba_copy),
                                            num_class=hyperparameters['ID_class'],
                                            alabel=l_a,
                                            slabel=l_b,
                                            teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)
                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model,
                                            scale2(x_ab_copy),
                                            num_class=hyperparameters['ID_class'],
                                            alabel=l_b,
                                            slabel=l_a,
                                            teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # The synthesized image goes through Ea to get per-identity probabilities.
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])  # first of the two identity heads
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model,
                                                scale2(x_ba_copy),
                                                num_class=hyperparameters['ID_class'],
                                                alabel=l_a,
                                                slabel=l_b,
                                                teacher_style=hyperparameters['teacher_style'])
                # Teacher-supervised identity loss on x_ba -- Eq. (8).
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)
                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model,
                                                scale2(x_ab_copy),
                                                num_class=hyperparameters['ID_class'],
                                                alabel=l_b,
                                                slabel=l_a,
                                                teacher_style=hyperparameters['teacher_style'])
                # Teacher-supervised identity loss on x_ab -- Eq. (8).
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
                # branch b loss
                # here we give different label
                # The second identity head is trained on its own prediction loss,
                # so Ea emits two vectors: one distilled from the teacher, one
                # supervised directly -- Eq. (9).  Note l_b pairs with x_ba here.
                loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0
        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)  # a's appearance + a's structure -- Eq. (1)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)  # b's appearance + b's structure -- Eq. (1)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)  # a's structure + positive-a's appearance -- Eq. (2)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)  # b's structure + positive-b's appearance -- Eq. (2)
        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0  # structure of x_ab -- Eq. (5)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0  # structure of x_ba -- Eq. (5)
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0  # appearance of x_ba -- Eq. (4)
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0  # appearance of x_ab -- Eq. (4)
        # Cycle images, only decoded when the cycle loss is active.
        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None  # x_ab structure + x_ba appearance
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None  # x_ba structure + x_ab appearance
        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            # Identity loss of x_ba against l_a and x_ab against l_b.
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']  # e.g. teacher_w = 1.0, B_w = 0.2
            # Identity loss of a and b -- Eq. (3); both heads contribute.
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b))
            # Identity loss of the positive samples -- Eq. (3).
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b)
            # Identity loss on the synthesized images -- Eq. (7): x_ba keeps
            # a's appearance, hence label l_a despite b's structure.
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)
        #print(f_a_recon, f_a)
        # Cycle-consistency image losses.
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss -- Eq. (6); under DataParallel the wrapped module is called explicitly.
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)
        # domain-invariant perceptual loss
        # VGG features of synthesized vs. real images, compared after instance norm.
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # Warm-up: gradually raise selected loss weights (mutates the shared
        # hyperparameters dict in place, capped at their max values).
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])
        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        # Backward + step (apex loss scaling when fp16 is enabled).
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % (
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher))

    def compute_vgg_loss(self, vgg, img, target):
        """Domain-invariant perceptual loss: MSE between instance-normalized
        VGG features of `img` and `target`."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def PCB_loss(self, inputs, labels):
        """Average cross-entropy over the PCB part classifiers."""
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        """Generate qualitative results for a batch pair: self-reconstructions,
        appearance swaps, and cycle reconstructions (one image at a time,
        with the trainer temporarily in eval mode)."""
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            #cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()
        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        """One optimization step for the discriminator on detached fakes."""
        self.dis_opt.zero_grad()
        # D loss (lsgan-style loss plus a regularization term from calc_dis_loss).
        if num_gpu > 1:
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        else:
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step all learning-rate schedulers that exist."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore the latest gen/dis/id checkpoints (and optimizer state, if
        present) from `checkpoint_dir`; returns the resumed iteration count
        parsed from the checkpoint filename."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a  # b-side stays an alias of the a-side module
        iterations = int(last_model_name[-11:-3])  # e.g. gen_00123456.pt -> 123456
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers
        # Best-effort: silently skipped when the optimizer checkpoint is
        # missing or incompatible (bare except kept as in the original).
        try:
            state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except:
            pass
        # Reinitilize schedulers
        # NOTE(review): id_scheduler is not re-created here -- confirm this is intended.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        """Save generator, discriminator, id model, and all optimizer states."""
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        # Unwrap DataParallel before saving the discriminator.
        if num_gpu > 1:
            torch.save({'a': self.dis_a.module.state_dict()}, dis_name)
        else:
            torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save({
            'gen': self.gen_opt.state_dict(),
            'id': self.id_opt.state_dict(),
            'dis': self.dis_opt.state_dict()
        }, opt_name)
class DGNet_Trainer(nn.Module): def __init__(self, hyperparameters): super(DGNet_Trainer, self).__init__() lr_g = hyperparameters['lr_g'] lr_d = hyperparameters['lr_d'] ID_class = hyperparameters['ID_class'] if not 'apex' in hyperparameters.keys(): hyperparameters['apex'] = False self.fp16 = hyperparameters['apex'] # Initiate the networks # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False. self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False) # auto-encoder for domain a self.gen_b = self.gen_a # auto-encoder for domain b if not 'ID_stride' in hyperparameters.keys(): hyperparameters['ID_stride'] = 2 if hyperparameters['ID_style'] == 'PCB': self.id_a = PCB(ID_class) elif hyperparameters['ID_style'] == 'AB': self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) else: self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) # return 2048 now self.id_b = self.id_a self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False) # discriminator for domain a self.dis_b = self.dis_a # discriminator for domain b # load teachers if hyperparameters['teacher'] != "": teacher_name = hyperparameters['teacher'] print(teacher_name) teacher_names = teacher_name.split(',') teacher_model = nn.ModuleList() teacher_count = 0 for teacher_name in teacher_names: config_tmp = load_config(teacher_name) if 'stride' in config_tmp: stride = config_tmp['stride'] else: stride = 2 model_tmp = ft_net(ID_class, stride=stride) teacher_model_tmp = load_network(model_tmp, teacher_name) teacher_model_tmp.model.fc = nn.Sequential( ) # remove the original fc layer in ImageNet teacher_model_tmp = teacher_model_tmp.cuda() if self.fp16: teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1") teacher_model.append(teacher_model_tmp.cuda().eval()) teacher_count += 1 self.teacher_model = teacher_model if 
hyperparameters['train_bn']: self.teacher_model = self.teacher_model.apply(train_bn) self.instancenorm = nn.InstanceNorm2d(512, affine=False) display_size = int(hyperparameters['display_size']) # RGB to one channel if hyperparameters['single'] == 'edge': self.single = to_edge else: self.single = to_gray(False) # Random Erasing when training if not 'erasing_p' in hyperparameters.keys(): hyperparameters['erasing_p'] = 0 self.single_re = RandomErasing( probability=hyperparameters['erasing_p'], mean=[0.0, 0.0, 0.0]) if not 'T_w' in hyperparameters.keys(): hyperparameters['T_w'] = 1 # Setup the optimizers beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list( self.dis_a.parameters()) #+ list(self.dis_b.parameters()) gen_params = list( self.gen_a.parameters()) #+ list(self.gen_b.parameters()) self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) # id params if hyperparameters['ID_style'] == 'PCB': ignored_params = ( list(map(id, self.id_a.classifier0.parameters())) + list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters())) + list(map(id, self.id_a.classifier3.parameters()))) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) elif hyperparameters['ID_style'] == 'AB': ignored_params = ( 
list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters()))) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) else: ignored_params = list(map(id, self.id_a.classifier.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) self.id_scheduler = get_scheduler(self.id_opt, hyperparameters) self.id_scheduler.gamma = hyperparameters['gamma2'] #ID Loss self.id_criterion = nn.CrossEntropyLoss() self.criterion_teacher = nn.KLDivLoss(size_average=False) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False # save memory if self.fp16: # Name the FP16_Optimizer instance to replace the existing optimizer assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 
            # --- continuation of __init__: apex/fp16 (opt_level "O1") setup ---
            # Move the networks to GPU before amp.initialize, then re-point the
            # *_b aliases so domains a/b keep sharing the same wrapped modules.
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()
            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

    def to_re(self, x):
        """Apply RandomErasing (self.single_re) independently to each image.

        x: 4-D batch tensor (N, C, H, W). Returns a new CUDA float tensor of
        the same shape where every sample has been passed through single_re.
        """
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        """L1 reconstruction loss; the target is detached (no gradient to it)."""
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    def recon_criterion_sqrt(self, input, target):
        """Square-root L1 loss (sqrt(|diff| + eps)); gradients flow to both sides."""
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    def recon_criterion2(self, input, target):
        """Mean-squared (L2) reconstruction loss."""
        diff = input - target
        return torch.mean(diff[:]**2)

    def recon_cos(self, input, target):
        """Cosine-distance loss: mean of (1 - cosine_similarity)."""
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    def forward(self, x_a, x_b):
        """Swap appearance between two images.

        Encodes structure (gen encode on the single-channel version) and
        appearance (id encoder on the scaled image) for both inputs, then
        decodes the crossed combinations. Runs in eval mode and restores
        train mode before returning.

        Returns (x_ab, x_ba):
          x_ab - decoded from structure of x_a + appearance of x_b
          x_ba - decoded from structure of x_b + appearance of x_a
        """
        self.eval()
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, _ = self.id_a(scale2(x_a))
        f_b, _ = self.id_b(scale2(x_b))
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, l_a, xp_a, x_b, l_b, xp_b, hyperparameters, iteration):  # ppa, ppb is the same person
        """One optimization step for the generator and the ID (appearance) encoder.

        x_a / x_b   : image batches from the two "domains" (different IDs)
        l_a / l_b   : their identity labels
        xp_a / xp_b : positive samples — different photos of the same IDs
        hyperparameters : loss-weight dict; NOTE this method mutates it in
                          place for loss-weight warm-up (see below)
        iteration   : current training iteration, used for the warm-up gates
        """
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
        # encode structure (single-channel input) and appearance (scaled input)
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        # autodecode: reconstruct each image from its own codes
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)
        # encode the same ID, different photo
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        # decode the same person: own structure + appearance from the positive photo
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)
        # cross-identity synthesis — has gradient
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # detached copies — no gradient
        # NOTE(review): Variable(t.data, requires_grad=False) is the legacy
        # idiom; t.detach() is the modern equivalent.
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)
        rand_num = random.uniform(0, 1)
        #################################
        # encode structure of the synthesized images
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder (deep copy, frozen in eval mode)
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))
        #################################
        # encode appearance of the synthesized images with a frozen copy of id_a
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))
        # teacher Loss
        # Tune the ID model: distill the teacher's soft labels on synthesized images
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                # student prediction on the synthesized (detached) image
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy))
                # KL divergence between student log-probs and teacher probs,
                # averaged over the batch
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)
                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy))
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # p_ba_student has two heads; [0] feeds the teacher (primary)
                # loss, [1] feeds the fine-grained branch-b loss below.
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                with torch.no_grad():
                    p_a_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ba_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_a,
                        slabel=l_b,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)
                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(
                        self.teacher_model,
                        scale2(x_ab_copy),
                        num_class=hyperparameters['ID_class'],
                        alabel=l_b,
                        slabel=l_a,
                        teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
                # branch b loss
                # here we give different label: the second head is supervised
                # with the structure-provider's label
                loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a)
                # weighted combination of the distillation and branch-b terms
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
        else:
            self.loss_teacher = 0.0
        # decode again (if needed): cycle reconstruction from recovered codes
        if hyperparameters['use_decoder_again']:
            x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
            x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        else:
            # manual decode path through deep-copied MLPs/decoder
            self.mlp_w_copy = copy.deepcopy(self.gen_a.mlp_w)
            self.mlp_b_copy = copy.deepcopy(self.gen_a.mlp_b)
            self.dec_copy = copy.deepcopy(self.gen_a.dec)  # Error
            ID = f_a_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy)
            x_aba = self.dec_copy(s_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
            ID = f_b_recon
            ID_Style = ID.view(ID.shape[0], ID.shape[1], 1, 1)
            adain_params_w = self.mlp_w_copy(ID_Style)
            adain_params_b = self.mlp_b_copy(ID_Style)
            self.gen_a.assign_adain_params(adain_params_w, adain_params_b, self.dec_copy)
            x_bab = self.dec_copy(s_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        # auto-encoder image reconstruction losses
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)
        # feature (code) reconstruction losses: recovered codes should match
        # the codes that produced the synthesized image
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0
        # Random Erasing only effect the ID and PID loss.
        # (overwrites p_a/p_b/pp_a/pp_b with predictions on erased inputs)
        if hyperparameters['erasing_p'] > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)
        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            # ID loss over both heads for the real images; only head [0] for
            # positives and generated images
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b))
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b)  #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)
        #print(f_a_recon, f_a)
        # cycle reconstruction losses
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss: generator tries to fool the discriminators
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # loss-weight warm-up: mutates the shared hyperparameters dict so the
        # increased weights persist across iterations
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])
        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss: weighted sum of all generator-side terms
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        # backward + step; one amp.scale_loss over both optimizers in fp16 mode
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % (
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher))

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual loss: MSE between instance-normalized VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def PCB_loss(self, inputs, labels):
        """Average cross-entropy over the PCB part predictions."""
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        """Generate visualization outputs for a batch, one image at a time.

        Returns (x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1):
        originals, self-reconstructions, cycle reconstructions, and the two
        cross-identity syntheses, each concatenated over the batch.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            # cycle: re-encode the syntheses and decode back
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()
        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_a, x_b, hyperparameters):
        """One optimization step for the discriminator(s).

        Synthesized images are detached so only the discriminator is updated.
        """
        self.dis_opt.zero_grad()
        # encode
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, _ = self.id_a(scale2(x_a))
        f_b, _ = self.id_b(scale2(x_b))
        # decode (cross domain)
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        # D loss (fake vs. real, plus gradient regularization terms reg_*)
        self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step every learning-rate scheduler that was created."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks and optimizers from the latest checkpoint.

        The iteration count is parsed from the generator checkpoint filename
        (gen_%08d.pt). Returns that iteration count.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers (best-effort: missing/corrupt optimizer.pt is ignored)
        # NOTE(review): the bare except silently swallows any failure here,
        # including unrelated bugs — consider narrowing to (IOError, KeyError)
        # and logging.
        try:
            state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except:
            pass
        # Reinitilize schedulers
        # NOTE(review): only dis/gen schedulers are rebuilt with the resumed
        # iteration; id_scheduler keeps its pre-resume state — confirm intended.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        # (filenames are 1-based: iteration N is saved as %08d of N+1)
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'id': self.id_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
class DGNet_Trainer(nn.Module): #初始化函数 def __init__(self, hyperparameters, gpu_ids=[0]): super(DGNet_Trainer, self).__init__() # 从配置文件获取生成模型的和鉴别模型的学习率 lr_g = hyperparameters['lr_g'] lr_d = hyperparameters['lr_d'] # # ID的类别,这里要注意,不同的数据集都是不一样的,应该是训练数据集的ID数目,非测试集 ID_class = hyperparameters['ID_class'] # 看是否设置使用float16,估计float16可以增加精确度 if not 'apex' in hyperparameters.keys(): hyperparameters['apex'] = False self.fp16 = hyperparameters['apex'] # Initiate the networks # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False. ################################################################################################################ ##这里是定义Es和G # 注意这里包含了两个步骤,Es编码+解码过程,既然解码(论文Figure 2的黄色梯形G)包含到这里了,下面Ea应该不会包含解码过程了 # 因为这里是一个类,如后续gen_a.encode()可以进行编码,gen_b.encode()可以进行解码 self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False) # auto-encoder for domain a self.gen_b = self.gen_a # auto-encoder for domain b ############################################################################################################################################ ############################################################################################################################################ ##这里是定义Ea # ID_stride,外观编码器池化层的stride if not 'ID_stride' in hyperparameters.keys(): hyperparameters['ID_stride'] = 2 # hyperparameters['ID_style']默认为'AB',论文中的Ea编码器 #这里是设置Ea,有三种模型可以选择 #PCB模型,ft_netAB为改造后的resnet50,ft_net为resnet50 if hyperparameters['ID_style'] == 'PCB': self.id_a = PCB(ID_class) elif hyperparameters['ID_style'] == 'AB': # 这是我们执行的模型,注意的是,id_a返回两个x(表示身份),获得f,具体介绍看函数内部 # 我们使用的是ft_netAB,是代码中Ea编码的过程,也就得到 ap code的过程,除了ap code还会得到两个分类结果 # 现在怀疑,该分类结果,可能就是行人重识别的结果 #ID_class表示有ID_class个不同ID的行人 self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) else: self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) 
# return 2048 now # 这里进行的是浅拷贝,所以我认为他们的权重是一起的,可以理解为一个 self.id_b = self.id_a ############################################################################################################################################################ ############################################################################################################################################################ ##这里是定义D # 鉴别器,行人重识别,这里使用的是一个多尺寸的鉴别器,大概就是说,对图片进行几次缩放,并且对每次缩放都会预测,计算总的损失 # 经过网络3个元素,分别大小为[batch_size,1,64,32], [batch_size,1,32,16], [batch_size,1,16,8] self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False) # discriminator for domain a self.dis_b = self.dis_a # discriminator for domain b ############################################################################################################################################################ ############################################################################################################################################################ # load teachers # 加载老师模型 # teacher:老师模型名称。对于DukeMTMC,您可以设置“best - duke” if hyperparameters['teacher'] != "": #teacher_name=best teacher_name = hyperparameters['teacher'] print(teacher_name) #有这个操作,我怀疑是可以加载多个教师模型 teacher_names = teacher_name.split(',') #构建老师模型 teacher_model = nn.ModuleList() teacher_count = 0 # 默认只有一个teacher_name='teacher_name',所以其加载的模型配置文件为项目根目录models/best/opts.yaml模型 for teacher_name in teacher_names: # 加载配置文件models/best/opts.yaml config_tmp = load_config(teacher_name) if 'stride' in config_tmp: #stride=1 stride = config_tmp['stride'] else: stride = 2 # 老师模型加载,老师模型为ft_net为resnet50 model_tmp = ft_net(ID_class, stride=stride) teacher_model_tmp = load_network(model_tmp, teacher_name) # 移除原本的全连接层 teacher_model_tmp.model.fc = nn.Sequential( ) # remove the original fc layer in ImageNet teacher_model_tmp = teacher_model_tmp.cuda() # summary(teacher_model_tmp, (3, 224, 224)) #使用浮点型 if self.fp16: teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1") 
teacher_model.append(teacher_model_tmp.cuda().eval()) teacher_count += 1 self.teacher_model = teacher_model # 选择是否使用bn if hyperparameters['train_bn']: self.teacher_model = self.teacher_model.apply(train_bn) ############################################################################################################################################################ # 实例正则化 self.instancenorm = nn.InstanceNorm2d(512, affine=False) # RGB to one channel # 默认设置signal=gray,Es的输入为灰度图 if hyperparameters['single'] == 'edge': self.single = to_edge else: self.single = to_gray(False) # Random Erasing when training #earsing_p表示随机擦除的概率 if not 'erasing_p' in hyperparameters.keys(): self.erasing_p = 0 else: self.erasing_p = hyperparameters['erasing_p'] #随机擦除矩形区域的一些像素,应该类似于数据增强 self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0]) # 设置T_w为1,T_w为primary feature learning loss的权重系数 if not 'T_w' in hyperparameters.keys(): hyperparameters['T_w'] = 1 ################################################################################################ # Setup the optimizers # 设置优化器参数 beta1 = hyperparameters['beta1'] beta2 = hyperparameters['beta2'] dis_params = list( self.dis_a.parameters()) #+ list(self.dis_b.parameters()) gen_params = list( self.gen_a.parameters()) #+ list(self.gen_b.parameters()) #使用Adams优化器,用Adams训练Es,G,D self.dis_opt = torch.optim.Adam( [p for p in dis_params if p.requires_grad], lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) self.gen_opt = torch.optim.Adam( [p for p in gen_params if p.requires_grad], lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) # id params # 因为ID_style默认为AB,所以这里不执行 if hyperparameters['ID_style'] == 'PCB': ignored_params = ( list(map(id, self.id_a.classifier0.parameters())) + list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters())) + list(map(id, self.id_a.classifier3.parameters()))) base_params = filter(lambda p: id(p) not in 
ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] #Ea 的优化器 self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) # 这里是我们执行的代码 elif hyperparameters['ID_style'] == 'AB': # 忽略的参数,应该是适用于'PCB'或者其他的,但是不适用于'AB'的 ignored_params = ( list(map(id, self.id_a.classifier1.parameters())) + list(map(id, self.id_a.classifier2.parameters()))) # 获得基本的配置参数,如学习率 base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] #对Ea使用SGD self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10 }, { 'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) else: ignored_params = list(map(id, self.id_a.classifier.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) lr2 = hyperparameters['lr2'] self.id_opt = torch.optim.SGD( [{ 'params': base_params, 'lr': lr2 }, { 'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10 }], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) # 选择各个网络的优化 self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) self.id_scheduler = get_scheduler(self.id_opt, hyperparameters) self.id_scheduler.gamma = hyperparameters['gamma2'] #ID Loss #交叉熵损失函数 self.id_criterion = nn.CrossEntropyLoss() # KL散度 self.criterion_teacher = nn.KLDivLoss(size_average=False) # Load VGG model if needed if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: 
self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False # save memory if self.fp16: # Name the FP16_Optimizer instance to replace the existing optimizer assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." self.gen_a = self.gen_a.cuda() self.dis_a = self.dis_a.cuda() self.id_a = self.id_a.cuda() self.gen_b = self.gen_a self.dis_b = self.dis_a self.id_b = self.id_a self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1") self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1") self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1") def to_re(self, x): out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3)) out = out.cuda() for i in range(x.size(0)): out[i, :, :, :] = self.single_re(x[i, :, :, :]) return out # L1 loss,(差的绝对值) def recon_criterion(self, input, target): diff = input - target.detach() return torch.mean(torch.abs(diff[:])) #L1 loss 开根号((差的绝对值后开根号)) def recon_criterion_sqrt(self, input, target): diff = input - target return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8)) # L2 loss def recon_criterion2(self, input, target): diff = input - target return torch.mean(diff[:]**2) # cos loss def recon_cos(self, input, target): cos = torch.nn.CosineSimilarity() cos_dis = 1 - cos(input, target) return torch.mean(cos_dis[:]) # x_a,x_b, xp_a, xp_b[4, 3, 256, 128], # 第一个参数表示bitch size,第二个参数表示输入通道数,第三个参数表示输入图片的高度,第四个参数表示输入图片的宽度 def forward(self, x_a, x_b, xp_a, xp_b): #送入x_a,x_b两张图片(来自训练集不同ID) #通过st编码器,编码成两个stcode,structure code # s_a[batch,128,64,32] # s_b[batch,128,64,32] # single会根据参数设定判断是否转化为灰度图 s_a = self.gen_a.encode(self.single(x_a)) s_b = self.gen_b.encode(self.single(x_b)) # 先把图片进行下采样,图示我们可以看到ap code的体积比st code是要小的,这样会出现一个情况,那么他们是没有办法直接融合的,所以后面有个全链接成把他们统一 # f_a[batch_size,2024*4=8192], p_a[0]=[batch_size, class_num=751], p_a[1]=[batch_size, 
class_num=751] # f_b[batch_size,2024*4=8192], p_b[0]=[batch_size, class_num=751], p_b[1]=[batch_size, class_num=751] # f代表的是经过ap编码器得到的ap code, # p表示对身份的预测(有两个身份预测,也就是p_a了两个元素,这里不好解释), # 前面提到过,ap编码器,不仅负责编码,还要负责身份的预测(行人重识别),也是我们落实项目的关键所在 # 这里是第一个重难点,在论文的翻译中提到过,后续详细讲解 f_a, p_a = self.id_a(scale2(x_a)) f_b, p_b = self.id_b(scale2(x_b)) # 进行解码操作,就是Figure 2中的黄色梯形G操作,这里的x_a,与x_b进行衣服互换,不同ID # s_b[batch,128,64,32] f_a[batch_size,2028,4,1] --> x_ba[batch_size,3,256,128] x_ba = self.gen_a.decode(s_b, f_a) x_ab = self.gen_b.decode(s_a, f_b) #同一张图片进行重构,相当于autoencoder x_a_recon = self.gen_a.decode(s_a, f_a) x_b_recon = self.gen_b.decode(s_b, f_b) fp_a, pp_a = self.id_a(scale2(xp_a)) fp_b, pp_b = self.id_b(scale2(xp_b)) # decode the same person #x_a,xp_a表示同ID的不同图片,以下即表示同ID不同图片的重构 x_a_recon_p = self.gen_a.decode(s_a, fp_a) x_b_recon_p = self.gen_b.decode(s_b, fp_b) # Random Erasing only effect the ID and PID loss. #把图片擦除一些像素,然后进行ap code编码 if self.erasing_p > 0: #先把每一张图片都擦除一些像素 x_a_re = self.to_re(scale2(x_a.clone())) x_b_re = self.to_re(scale2(x_b.clone())) xp_a_re = self.to_re(scale2(xp_a.clone())) xp_b_re = self.to_re(scale2(xp_b.clone())) # 然后经过编码成ap code,暂时不知道作用,感觉应该是数据增强 # 类似于,擦除了图片的一些像素,但是已经能够识别出来这些图片是谁 _, p_a = self.id_a(x_a_re) _, p_b = self.id_b(x_b_re) # encode the same ID different photo _, pp_a = self.id_a(xp_a_re) _, pp_b = self.id_b(xp_b_re) # 混合合成图片:x_ab[images_a的st,images_b的ap] 混合合成图片x_ba[images_b的st,images_a的ap] # s_a[输入图片images_a经过Es编码得到的 st code] s_b[输入图片images_b经过Es编码得到的 st code] # f_a[输入图片images_a经过Ea编码得到的 ap code] f_b[输入图片images_b经过Ea编码得到的 ap code] # p_a[输入图片images_a经过Ea编码进行身份ID的预测] p_b[输入图片images_b经过Ea编码进行身份ID的预测] # pp_a[输入图片pos_a经过Ea编码进行身份ID的预测] pp_b[输入图片pos_b经过Ea编码进行身份ID的预测] # x_a_recon[输入图片images_a(s_a)与自己(f_a)合成的图片,当然和images_a长得一样] # x_b_recon[输入图片images_b(s_b)与自己(f_b)合成的图片,当然和images_b长得一样] # x_a_recon_p[输入图片images_a(s_a)与图片pos_a(fp_a)合成的图片,当然和images_a长得一样] # x_b_recon_p[输入图片images_a(s_a)与图片pos_b(fp_b)合成的图片,当然和images_b长得一样] return x_ab, x_ba, s_a, s_b, 
f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu): """ :param x_ab:[images_a的st,images_b的ap] :param x_ba:[images_b的st,images_a的ap] :param s_a:[输入图片images_a经过Es编码得到的 st code] :param s_b:[输入图片images_b经过Es编码得到的 st code] :param f_a:[输入图片images_a经过Ea编码得到的 ap code] :param f_b:[输入图片images_b经过Ea编码得到的 ap code] :param p_a:[输入图片images_a经过Ea编码进行身份ID的预测] :param p_b:[输入图片images_b经过Ea编码进行身份ID的预测] :param pp_a:[输入图片pos_a经过Ea编码进行身份ID的预测] :param pp_b:[输入图片pos_b经过Ea编码进行身份ID的预测] :param x_a_recon:[输入图片images_a(s_a)与自己(f_a)合成的图片,当然和images_a长得一样] :param x_b_recon:[输入图片images_b(s_b)与自己(f_b)合成的图片,当然和images_b长得一样] :param x_a_recon_p:[输入图片images_a(s_a)与图片pos_a(fp_a)合成的图片,当然和images_a长得一样] :param x_b_recon_p:[输入图片images_b(s_b)与图片pos_b(fp_b)合成的图片,当然和images_b长得一样] :param x_a:images_a :param x_b:images_b :param xp_a:pos_a :param xp_b:pos_b :param l_a:labels_a :param l_b:labels_b :param hyperparameters: :param iteration: :param num_gpu: :return: """ # ppa, ppb is the same person? 
self.gen_opt.zero_grad() #梯度清零 self.id_opt.zero_grad() # no gradient # 对合成x_ba和x_ab分别进行一份拷贝 x_ba_copy = Variable(x_ba.data, requires_grad=False) x_ab_copy = Variable(x_ab.data, requires_grad=False) rand_num = random.uniform(0, 1) ################################# # encode structure # enc_content是类ContentEncoder对象 if hyperparameters['use_encoder_again'] >= rand_num: # encode again (encoder is tuned, input is fixed) # Es编码得到s_a_recon与s_b_recon即st code # 如果是理想模型,s_a_recon=s_a, s_b_recon=s_b s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy)) s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy)) else: # copy the encoder # 这里的是深拷贝 #enc_content_copy=gen_a.enc_content self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content) self.enc_content_copy = self.enc_content_copy.eval() # encode again (encoder is fixed, input is tuned) s_a_recon = self.enc_content_copy(self.single(x_ab)) s_b_recon = self.enc_content_copy(self.single(x_ba)) ################################# # encode appearance #id_a_copy=id_a=Ea self.id_a_copy = copy.deepcopy(self.id_a) self.id_a_copy = self.id_a_copy.eval() if hyperparameters['train_bn']: self.id_a_copy = self.id_a_copy.apply(train_bn) self.id_b_copy = self.id_a_copy # encode again (encoder is fixed, input is tuned) # 对混合生成的图片x_ba,x_ab进行Es编码操作,同时对身份进行鉴别# # f_a_recon,f_b_recon表示的ap code,p_a_recon,p_b_recon表示对身份的鉴别 f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba)) f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab)) # teacher Loss # Tune the ID model log_sm = nn.LogSoftmax(dim=1) #如果使用了教师网络 #默认ID_style为AB if hyperparameters['teacher_w'] > 0 and hyperparameters[ 'teacher'] != "": if hyperparameters['ID_style'] == 'normal': #p_a_student表示x_ba_copy的身份编码,使用的是Ea进行身份编码,也就是使用学生模型进行身份编码 _, p_a_student = self.id_a(scale2(x_ba_copy)) #对p_a_student使用logsoftmax,输出结果为x_ba_copy像某张图片的概率(就是一个分布) p_a_student = log_sm(p_a_student) #使用教师模型对生成图像x_ba_copy进行分类,输出结果为x_ba_copy像某张图片的概率(就是一个分布) p_a_teacher = predict_label( self.teacher_model, 
scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style']) #通过最小化KL散度损失函数,目的是让分布p_a_student与p_a_teacher尽可能的一致 self.loss_teacher = self.criterion_teacher( p_a_student, p_a_teacher) / p_a_student.size(0) #对x_ab_copy进行同样的操作 _, p_b_student = self.id_b(scale2(x_ab_copy)) p_b_student = log_sm(p_b_student) p_b_teacher = predict_label( self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style']) self.loss_teacher += self.criterion_teacher( p_b_student, p_b_teacher) / p_b_student.size(0) ####################################################################################################################################################################################################### # primary feature learning loss ####################################################################################################################################################################################################### # ID_style为AB elif hyperparameters['ID_style'] == 'AB': # normal teacher-student loss # BA -> LabelA(smooth) + LabelB(batchB) # 合成的图片经过身份鉴别器,得到每个ID可能性的概率,注意这里去的是p_ba_student[0],我们知有两个身份预测结果,这里只取了一个 # 并且赋值给了p_a_student,用于和教师模型结合的,共同计算损失 #p_a_student分为两个部分,p_a_student[0]表示L_prim,p_a_student[1]表示L_fine。 _, p_ba_student = self.id_a(scale2(x_ba_copy)) # f_a, s_b p_a_student = log_sm(p_ba_student[0]) with torch.no_grad(): ##使用教师模型对生成图像x_ba_copy进行分类,输出结果为x_ba_copy像某张图片(x_a/x_b)的概率(就是一个分布) p_a_teacher = predict_label( self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style']) # criterion_teacher = nn.KLDivLoss(size_average=False) # 计算离散距离,可以理解为p_a_student与p_a_teacher每个元素的距离之和,然后除以p_a_student.size(0)取平均值 # 就是说学生网络(Ea)的预测越与教师网络结果相同,则是最好的 self.loss_teacher = self.criterion_teacher( p_a_student, p_a_teacher) / p_a_student.size(0) # 
对另一张合成图片进行同样的操作 _, p_ab_student = self.id_b(scale2(x_ab_copy)) # f_b, s_a p_b_student = log_sm(p_ab_student[0]) with torch.no_grad(): p_b_teacher = predict_label( self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style']) self.loss_teacher += self.criterion_teacher( p_b_student, p_b_teacher) / p_b_student.size(0) ######################################################################################################################################################################################################## ######################################################################################################################################################################################################## #fine—grained feature mining loss ######################################################################################################################################################################################################## # branch b loss # here we give different label # p_ba_student[1]表示的是f_fine特征,l_b表示的是images_b,即为生成图像提供st code 的图片 loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion( p_ab_student[1], l_a) ####################################################################################################################################################################################################### # 对两部分损失进行权重调整 self.loss_teacher = hyperparameters[ 'T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B else: self.loss_teacher = 0.0 ## 剩下的就是重构图像之间的损失了 # 前面提到,重构和合成是不一样的,重构是构建出来和原来图片一样的图片 # 所以也就是可以把重构的图片和原来的图像直接计算像素直接的插值 # 但是合成的图片是没有办法的,因为训练数据集是没有合成图片的,所以,没有办法计算像素之间的损失 # ####################################################################################################################################################################################################### # auto-encoder image reconstruction # 同ID图像进行重构时的损失函数 
        # Pixel-wise reconstruction of each input from its own codes.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        # Same-ID, different-photo reconstructions (positive pairs).
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)
        #######################################################################################
        #######################################################################################
        # feature reconstruction
        # For cross-ID synthesis: keep the re-encoded structure code (s) and
        # appearance code (f) of the generated image consistent with the
        # codes that produced it.  Each term is skipped (0) when its weight
        # is disabled.
        self.loss_gen_recon_s_a = self.recon_criterion(
            s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(
            s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(
            f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(
            f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0
        #######################################################################################
        # Decode once more from the reconstructed codes (cycle images),
        # only when the cycle-consistency weight is enabled.
        x_aba = self.gen_a.decode(
            s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            # Part-based (PCB) identity losses, averaged over parts.
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(
                p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        #######################################################################################
        # ID_style == 'AB' is the configuration actually used.
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            # L^s_id: identity loss on the real inputs, primary branch plus
            # a down-weighted fine-grained branch.
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )
            # L^s_id on the positive pair (same ID, different picture);
            # the fine-grained term is deliberately commented out.
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(
                pp_b[0], l_b
            )  #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )
            # L^c_id: identity loss on the generated (reconstructed) images.
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        #######################################################################################
        else:
            # Plain (single-branch) identity classifier.
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(
                p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(
                pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(
                p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        # Cycle-consistency pixel losses (0 when the weight is disabled;
        # x_aba / x_bab are None in that case and never touched).
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        #######################################################################################
        # GAN loss: the generator's adversarial loss.
        #######################################################################################
        if num_gpu > 1:
            # Under DataParallel the helper lives on .module; the wrapped
            # discriminator itself is passed in so its forward pass still
            # runs data-parallel.
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(
                self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(
                self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)

        #######################################################################################
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # Warm-up: after warm_iter, ramp selected loss weights by warm_scale
        # per call, capped at their max values.
        # NOTE(review): this mutates the shared `hyperparameters` dict in
        # place every iteration — callers holding the same dict see the
        # updated weights.
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'],
                                               hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(
                hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        # Separate warm-up schedule for the teacher-loss weight.
        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(
                hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])

        # Total generator loss:
        # 1 teacher loss (primary feature learning + fine-grained mining)
        # + 4 same-ID reconstruction losses + 4 code-reconstruction losses
        # + 3 ID losses + 2 generator adversarial losses
        # (+ cycle and VGG perceptual terms when enabled).
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        if self.fp16:
            # apex mixed precision: scale the loss for both optimizers
            # before backward to avoid fp16 underflow.
            with amp.scale_loss(self.loss_gen_total,
                                [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()  # compute gradients
            self.gen_opt.step()             # update generator parameters
            self.id_opt.step()              # update identity-encoder parameters
        # Per-iteration training log of every weighted loss component.
        print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%(
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher ) )

    def compute_vgg_loss(self, vgg, img, target):
        """Domain-invariant perceptual loss.

        Both images are VGG-preprocessed, passed through `vgg`, and their
        features instance-normalized before taking the mean squared
        difference — so only style-invariant content differences are
        penalized.
        """
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def PCB_loss(self, inputs, labels):
        """Average the ID criterion over the PCB part predictions."""
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        """Generate qualitative samples for visualization.

        For each input pair: same-ID reconstructions, cross-ID swaps
        (x_ba, x_ab) and cycle reconstructions (x_aba, x_bab).
        Temporarily switches to eval mode; restores train mode before
        returning the concatenated batches.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        # Process one image at a time to keep memory usage low.
        for i in range(x_a.size(0)):
            # Structure codes from the gray/edge single-channel image.
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            # Appearance codes from the identity encoder.
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            # Same-ID reconstruction.
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            # Cross-ID synthesis: swap structure codes.
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            # cycle: re-encode the synthesized images and decode again.
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        self.train()
        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        """One optimization step for the discriminator(s).

        Computes the discriminator loss on the (fake, real) pairs
        (x_ba, x_a) and (x_ab, x_b) — fakes detached so generator
        gradients are not computed — then backprops and steps dis_opt.
        """
        self.dis_opt.zero_grad()  # clear stale gradients
        # D loss: sum over both image pairs.
        if num_gpu > 1:
            # DataParallel: helper on .module, wrapped model passed through.
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        else:
            # Single-GPU path.
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(
                self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(
                self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
            'gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            # apex mixed precision: scale before backward.
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()  # compute gradients
        self.dis_opt.step()  # update discriminator parameters

    def update_learning_rate(self):
        """Step every LR scheduler that has been configured."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.id_scheduler is not None:
            self.id_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Resume training from the latest checkpoint in `checkpoint_dir`.

        Loads generator, discriminator and ID-encoder weights (the *_b
        aliases are re-pointed at the shared *_a modules), optionally the
        optimizer states, and re-creates the LR schedulers at the restored
        iteration.  Returns the iteration number parsed from the
        checkpoint filename.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        # Iteration count is encoded in the filename (gen_%08d.pt).
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers
        try:
            state_dict = torch.load(
                os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except:
            # NOTE(review): bare except silently ignores any failure to load
            # optimizer state (missing file, key mismatch, even
            # KeyboardInterrupt) — consider narrowing to
            # (OSError, KeyError, RuntimeError) and logging.
            pass
        # Reinitilize schedulers
        # NOTE(review): id_scheduler is not re-created here, unlike dis/gen —
        # confirm whether that is intentional.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        """Save generator, discriminator, ID-encoder and optimizer states.

        Filenames embed iterations+1; only the shared *_a modules are
        stored since *_b aliases the same objects.
        """
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        # NOTE(review): gen_a/id_a are saved without a .module unwrap while
        # dis_a gets one for num_gpu > 1 — verify gen/id are never wrapped
        # in DataParallel.
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        if num_gpu > 1:
            # Unwrap DataParallel so the checkpoint keys stay portable.
            torch.save({'a': self.dis_a.module.state_dict()}, dis_name)
        else:
            torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'id': self.id_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)