def _init_models(self):
    """Build the four FD-GAN networks (G, E, Di, Dp) and load/initialize them.

    Stage 1: freshly initialize G and Dp, bootstrap E from a pretrained
    Siamese verification checkpoint, and reuse one row of E's 2-class
    verifier head as Di's single-logit head.
    Stage 2: restore all four networks from per-network checkpoint files.
    Finally wraps every network in DataParallel and moves it to GPU.
    """
    self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
                                     dropout=self.opt.drop, norm_layer=self.norm_layer,
                                     fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
    e_base_model = create(self.opt.arch, cut_at_pooling=True)
    e_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                    num_features=2048, num_classes=2)
    self.net_E = SiameseNet(e_base_model, e_embed_model)
    di_base_model = create(self.opt.arch, cut_at_pooling=True)
    di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                     num_features=2048, num_classes=1)
    self.net_Di = SiameseNet(di_base_model, di_embed_model)

    # 3 RGB channels + 18 pose-heatmap channels for the pose discriminator.
    self.net_Dp = NLayerDiscriminator(3+18, norm_layer=self.norm_layer)

    if self.opt.stage == 1:
        init_weights(self.net_G)
        init_weights(self.net_Dp)
        state_dict = remove_module_key(torch.load(self.opt.netE_pretrain))
        # E gets the full pretrained weights first; the dict is then patched
        # in place for Di's single-class head below.
        self.net_E.load_state_dict(state_dict)
        # FIX: slice with [1:2] (not [1]) so the weight keeps its 2-D
        # [1, num_features] shape; plain [1] drops a dimension and does not
        # match the num_classes=1 classifier parameter.
        state_dict['embed_model.classifier.weight'] = state_dict['embed_model.classifier.weight'][1:2]
        state_dict['embed_model.classifier.bias'] = torch.FloatTensor([state_dict['embed_model.classifier.bias'][1]])
        self.net_Di.load_state_dict(state_dict)
    elif self.opt.stage == 2:
        self._load_state_dict(self.net_E, self.opt.netE_pretrain)
        self._load_state_dict(self.net_G, self.opt.netG_pretrain)
        self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
        self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
    else:
        # FIX: the original `assert('unknown training stage')` asserted a
        # non-empty string, which is always true and never fired.
        raise ValueError('unknown training stage')

    self.net_E = torch.nn.DataParallel(self.net_E).cuda()
    self.net_G = torch.nn.DataParallel(self.net_G).cuda()
    self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
    self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()
def load_model(self, model_arch='resnet50'):
    """Restore a pretrained Siamese re-ID model from ``self.model_path``.

    Builds the backbone + embedding head, loads the checkpoint into the
    full Siamese wrapper, and returns only the backbone (DataParallel,
    on GPU) for feature extraction.
    """
    # Backbone (e.g. resnet50) truncated at the global-pooling layer.
    backbone = models.create(model_arch, cut_at_pooling=True)
    # Element-wise-subtraction embedding head with a 2-way verifier.
    head = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                           num_features=2048, num_classes=2)
    siamese = nn.DataParallel(SiameseNet(backbone, head)).cuda()

    ckpt = load_checkpoint(self.model_path)
    if 'state_dict' in ckpt:
        ckpt = ckpt['state_dict']
    siamese.load_state_dict(ckpt)

    # Loading into `siamese` also populated `backbone`'s parameters.
    return torch.nn.DataParallel(backbone).cuda()
def _init_models(self):
    """Build generator, embedding network and discriminators for ST-ReID training.

    Depending on ``opt.stage`` the networks are freshly initialized (0),
    bootstrapped from a pretrained embedding checkpoint (1), or fully
    restored from per-network files (2). All networks end up wrapped in
    DataParallel on GPU.
    """
    # Noise feature size is fixed to 0 here (the commented line kept the
    # configurable opt.noise_feature_size variant for reference).
    #self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
    self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, 0, pose_nc=self.pose_size,
                                     dropout=self.opt.drop, norm_layer=self.norm_layer,
                                     fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
    # Embedding network: classification (Single) or verification (Siamese) flavor.
    if (self.opt.emb_type == 'Single'):
        self.net_E = SingleNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True,
                               test_bn=False, last_stride=self.opt.last_stride)
    elif (self.opt.emb_type == 'Siamese'):
        self.net_E = SiameseNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True,
                                test_bn=False, last_stride=self.opt.last_stride)
    else:
        raise ValueError('unrecognized model')
    # Identity discriminator: lightweight resnet18 with a single output.
    self.net_Di = SingleNet('resnet18', 1, pretraind=True, use_bn=True, test_bn=False, last_stride=2)
    # Pose discriminator sees RGB (3) + pose heatmap channels.
    self.net_Dp = NLayerDiscriminator(3+self.pose_size, norm_layer=self.norm_layer)
    if self.opt.stage==0:
        # This is for training end-to-end
        init_weights(self.net_G)
        init_weights(self.net_Dp)
    elif self.opt.stage==1:
        # This is for training fixing a baseline model
        init_weights(self.net_G)
        init_weights(self.net_Dp)
        checkpoint = load_checkpoint(self.opt.netE_pretrain)
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        state_dict = remove_module_key(checkpoint)
        self.net_E.load_state_dict(state_dict)
        # NOTE(review): Di bootstrap from the embedding head is disabled here;
        # Di keeps its ImageNet-pretrained resnet18 weights instead.
        #state_dict['classifier.weight'] = state_dict['classifier.weight'][1:2]
        #state_dict['classifier.bias'] = torch.FloatTensor([state_dict['classifier.bias'][1]])
        #self.net_Di.load_state_dict(state_dict)
    elif self.opt.stage==2:
        # This is for training with a provided model
        self._load_state_dict(self.net_E, self.opt.netE_pretrain)
        self._load_state_dict(self.net_G, self.opt.netG_pretrain)
        self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
        self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
    else:
        raise ValueError('unrecognized mode')
    self.net_E = torch.nn.DataParallel(self.net_E).cuda()
    self.net_G = torch.nn.DataParallel(self.net_G).cuda()
    self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
    self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()
def main(args):
    """Train/evaluate a Siamese re-ID verification model end-to-end.

    Seeds RNGs, builds loaders and the model, optionally resumes from a
    checkpoint, then runs SGD training with step learning-rate decay and
    periodic mAP-based checkpointing; finishes with a test on the best model.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
    # print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (256, 128)
    dataset, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers,
                 args.combine_trainval, args.np_ratio)

    # Create model: backbone cut at pooling + element-wise-sub embedding head.
    base_model = models.create(args.arch, cut_at_pooling=True)
    embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                  num_features=2048, num_classes=2)
    model = SiameseNet(base_model, embed_model)
    model = nn.DataParallel(model).cuda()

    # Evaluator; softmax column 0 is used as the pairwise distance score.
    evaluator = CascadeEvaluator(
        torch.nn.DataParallel(base_model).cuda(),
        embed_model,
        embed_dist_fn=lambda x: F.softmax(Variable(x), dim=1).data[:, 0])

    # Load from checkpoint
    best_mAP = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)
        print("Test the loaded model:")
        top1, mAP = evaluator.evaluate(test_loader, dataset.query, dataset.gallery,
                                       rerank_topk=100, dataset=args.dataset)
        best_mAP = mAP
    if args.evaluate:
        return

    # Criterion
    criterion = nn.CrossEntropyLoss().cuda()

    # Optimizer (both groups use lr_mult 1.0; groups kept for per-part tuning)
    param_groups = [
        {'params': model.module.base_model.parameters(), 'lr_mult': 1.0},
        {'params': model.module.embed_model.parameters(), 'lr_mult': 1.0}]
    optimizer = torch.optim.SGD(param_groups, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Trainer
    trainer = SiameseTrainer(model, criterion)

    # Schedule learning rate: decay by 10x every args.step_size epochs.
    def adjust_lr(epoch):
        lr = args.lr * (0.1 ** (epoch // args.step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(0, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer, base_lr=args.lr)
        # Evaluate on the validation split every eval_step epochs
        # (epoch 0 always evaluates, so is_best is defined before use).
        if epoch % args.eval_step==0:
            mAP = evaluator.evaluate(val_loader, dataset.val, dataset.val, top1=False)
            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            save_checkpoint({
                'state_dict': model.state_dict()
            }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
            print('\n * Finished epoch {:3d} mAP: {:5.1%} best: {:5.1%}{}\n'.
                  format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, dataset=args.dataset)
class ST_ReIDNet(object):
    """Pose-guided GAN for person re-ID with a joint embedding network.

    Holds four networks: generator G, embedding net E (Single or Siamese),
    identity discriminator Di (resnet18) and pose discriminator Dp, plus
    their losses, optimizers and LR schedulers. Training stage semantics:
    0 = end-to-end, 1 = fix a pretrained baseline E, 2 = resume all nets.
    """

    def __init__(self, opt, emb_size):
        self.opt = opt
        self.save_dir = os.path.join(opt.checkpoints, opt.name)
        self.norm_layer = get_norm_layer(norm_type=opt.norm)
        self.emb_size = emb_size
        # DanceReID pose maps have 17 keypooint channels, others 18.
        if self.opt.dataset == 'DanceReID':
            self.pose_size = 17
        else:
            self.pose_size = 18
        self._init_models()
        self._init_losses()
        self._init_optimizers()
        print('---------- Networks initialized -------------')

    def _init_models(self):
        """Build G/E/Di/Dp and initialize or restore them per ``opt.stage``."""
        # Noise feature size fixed to 0 for this model variant.
        self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, 0, pose_nc=self.pose_size,
                                         dropout=self.opt.drop, norm_layer=self.norm_layer,
                                         fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
        if self.opt.emb_type == 'Single':
            self.net_E = SingleNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True,
                                   test_bn=False, last_stride=self.opt.last_stride)
        elif self.opt.emb_type == 'Siamese':
            self.net_E = SiameseNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True,
                                    test_bn=False, last_stride=self.opt.last_stride)
        else:
            raise ValueError('unrecognized model')
        # Identity discriminator: lightweight resnet18 with one output logit.
        self.net_Di = SingleNet('resnet18', 1, pretraind=True, use_bn=True, test_bn=False, last_stride=2)
        # Pose discriminator sees RGB (3) + pose heatmap channels.
        self.net_Dp = NLayerDiscriminator(3+self.pose_size, norm_layer=self.norm_layer)
        if self.opt.stage==0:
            # Training end-to-end: only G and Dp start from scratch.
            init_weights(self.net_G)
            init_weights(self.net_Dp)
        elif self.opt.stage==1:
            # Training with a fixed baseline: E is loaded from a checkpoint.
            init_weights(self.net_G)
            init_weights(self.net_Dp)
            checkpoint = load_checkpoint(self.opt.netE_pretrain)
            if 'state_dict' in checkpoint.keys():
                checkpoint = checkpoint['state_dict']
            state_dict = remove_module_key(checkpoint)
            self.net_E.load_state_dict(state_dict)
        elif self.opt.stage==2:
            # Resuming: every network comes from its own checkpoint file.
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        else:
            raise ValueError('unrecognized mode')
        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()

    def reset_model_status(self):
        """Set train/eval modes per stage; BN layers of frozen nets are fixed."""
        if self.opt.stage==0:
            self.net_E.train()
            self.net_G.train()
            self.net_Di.train()
            self.net_Dp.train()
            self.net_Di.apply(set_bn_fix)
        elif self.opt.stage==1:
            self.net_G.train()
            self.net_Dp.train()
            self.net_E.eval()
            self.net_Di.train()
            self.net_Di.apply(set_bn_fix)
        elif self.opt.stage==2:
            self.net_E.train()
            self.net_G.train()
            self.net_Di.train()
            self.net_Dp.train()
            self.net_E.apply(set_bn_fix)
            self.net_Di.apply(set_bn_fix)

    def _load_state_dict(self, net, path):
        """Load ``path`` into ``net``, stripping DataParallel 'module.' prefixes."""
        state_dict = remove_module_key(torch.load(path))
        net.load_state_dict(state_dict)

    def _init_losses(self):
        """Create GAN, triplet, classification and reconstruction criteria."""
        if self.opt.smooth_label:
            self.criterionGAN_D = GANLoss(smooth=True).cuda()
            # With smoothed labels, flip real/fake targets roughly 1/10001 of
            # the time to regularize the discriminators.
            self.rand_list = [True] * 1 + [False] * 10000
        else:
            self.criterionGAN_D = GANLoss(smooth=False).cuda()
            self.rand_list = [False]
        self.criterionGAN_G = GANLoss(smooth=False).cuda()
        if self.opt.soft_margin:
            self.tri_criterion = TripletLoss(margin='soft', batch_hard=True, distractor=True).cuda()
        else:
            # FIX: original read `.cuda` (missing parentheses), which bound the
            # method object instead of moving the loss module to the GPU.
            self.tri_criterion = TripletLoss(margin=self.opt.margin, batch_hard=True, distractor=True).cuda()
        if self.opt.emb_type == 'Single':
            if self.opt.emb_smooth:
                self.class_criterion = CrossEntropyLabelSmooth(self.emb_size, epsilon=0.1).cuda()
            else:
                self.class_criterion = torch.nn.CrossEntropyLoss().cuda()
        elif self.opt.emb_type == 'Siamese':
            self.class_criterion = torch.nn.CrossEntropyLoss().cuda()
        if self.opt.mask:
            self.reco_criterion = MaskedL1loss(use_mask=True).cuda()
        else:
            self.reco_criterion = MaskedL1loss(use_mask=False).cuda()

    def _init_optimizers(self):
        """Create per-network optimizers (stage-dependent LRs) and schedulers."""
        if self.opt.stage==0:
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.1},
                            {'params': self.net_E.module.classifier.parameters(), 'lr_mult': 1.0}]
            self.optimizer_E = torch.optim.SGD(param_groups, lr=self.opt.lr, momentum=0.9, weight_decay=5e-4)
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(), lr=1e-5, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(), lr=4e-5, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(), lr=4e-5, momentum=0.9, weight_decay=1e-4)
        elif self.opt.stage==1:
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.1},
                            {'params': self.net_E.module.classifier.parameters(), 'lr_mult': 0.1}]
            self.optimizer_E = torch.optim.SGD(param_groups, lr=self.opt.lr, momentum=0.9, weight_decay=5e-4)
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(), lr=self.opt.lr*0.1, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(), lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(), lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
        elif self.opt.stage==2:
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.01},
                            {'params': self.net_E.module.classifier.parameters(), 'lr_mult': 0.1}]
            self.optimizer_E = torch.optim.SGD(param_groups, lr=self.opt.lr, momentum=0.9, weight_decay=5e-4)
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(), lr=1e-6, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(), lr=1e-5, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(), lr=1e-5, momentum=0.9, weight_decay=1e-4)
        self.schedulers = []
        self.optimizers = []
        self.optimizers.append(self.optimizer_E)
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_Di)
        self.optimizers.append(self.optimizer_Dp)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, self.opt))

    def set_input(self, input):
        """Unpack an (anchor, positive, negative) triplet batch onto the GPU.

        Also builds the concatenated ground-truth/mask/posemap tensors and
        the binary same/different labels used by the Siamese classifier.
        """
        input_ach, input_pos, input_neg = input
        ids = torch.cat([input_ach['pid'], input_pos['pid'], input_neg['pid']]).long()
        ones = torch.ones_like(input_ach['pid']).fill_(1.0)
        zeros = torch.ones_like(input_ach['pid']).fill_(0.0)
        labels = torch.cat([ones, zeros]).long()
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)
        noise = torch.cat((noise, noise, noise))
        self.ach_gt = input_ach['origin'].cuda()
        self.pos_gt = input_pos['origin'].cuda()
        self.neg_gt = input_neg['origin'].cuda()
        self.ach_img = input_ach['input'].cuda()
        self.pos_img = input_pos['input'].cuda()
        self.neg_img = input_neg['input'].cuda()
        self.ach_pose = input_ach['posemap'].cuda()
        self.pos_pose = input_pos['posemap'].cuda()
        self.neg_pose = input_neg['posemap'].cuda()
        self.ach_mask = input_ach['mask'].cuda()
        self.pos_mask = input_pos['mask'].cuda()
        self.neg_mask = input_neg['mask'].cuda()
        # Full-triplet concatenations (anchor, positive, negative order).
        self.gt = torch.cat([self.ach_gt, self.pos_gt, self.neg_gt])
        self.mask = torch.cat([self.ach_mask, self.pos_mask, self.neg_mask])
        self.posemap = torch.cat([self.ach_pose, self.pos_pose, self.neg_pose])
        # Anchor+positive pairs used for the identity-swap branch.
        self.two_gt = torch.cat([self.ach_gt, self.pos_gt])
        self.two_mask = torch.cat([self.ach_mask, self.pos_mask])
        self.pos_posemap = torch.cat([self.ach_pose, self.pos_pose])
        self.swap_posemap = torch.cat([self.pos_pose, self.ach_pose])
        self.ids = ids.cuda()
        self.labels = labels.cuda()
        self.noise = noise.cuda()

    def forward(self):
        """Embed the triplet with E and synthesize fake/swap images with G."""
        z = Variable(self.noise)
        if self.opt.emb_type == 'Single':
            if self.opt.stage == 1:
                # Frozen-baseline stage: E returns features only.
                self.A_ach = self.net_E(Variable(self.ach_img))
                self.A_pos = self.net_E(Variable(self.pos_img))
                self.A_neg = self.net_E(Variable(self.neg_img))
            else:
                self.A_ach, pred_ach = self.net_E(Variable(self.ach_img))
                self.A_pos, pred_pos = self.net_E(Variable(self.pos_img))
                self.A_neg, pred_neg = self.net_E(Variable(self.neg_img))
                self.id_pred = torch.cat([pred_ach, pred_pos, pred_neg])
        elif self.opt.emb_type == 'Siamese':
            self.A_ach, self.A_pos, pos_pred = self.net_E(torch.cat([Variable(self.ach_img), Variable(self.pos_img)]))
            _, self.A_neg, neg_pred = self.net_E(torch.cat([Variable(self.ach_img), Variable(self.neg_img)]))
            self.id_pred = torch.cat([pos_pred, neg_pred])
        # Reconstructions: each image from its own pose + its own identity.
        self.fake_ach = self.net_G(Variable(self.ach_pose), self.A_ach.view(self.A_ach.size(0), self.A_ach.size(1), 1, 1), None)
        self.fake_pos = self.net_G(Variable(self.pos_pose), self.A_pos.view(self.A_pos.size(0), self.A_pos.size(1), 1, 1), None)
        self.fake_neg = self.net_G(Variable(self.neg_pose), self.A_neg.view(self.A_neg.size(0), self.A_neg.size(1), 1, 1), None)
        # Identity swap: anchor pose with positive identity and vice versa.
        self.swap_ach = self.net_G(Variable(self.ach_pose), self.A_pos.view(self.A_pos.size(0), self.A_pos.size(1), 1, 1), None)
        self.swap_pos = self.net_G(Variable(self.pos_pose), self.A_ach.view(self.A_ach.size(0), self.A_ach.size(1), 1, 1), None)
        self.fake = torch.cat((self.fake_ach, self.fake_pos, self.fake_neg))
        self.swap_fake = torch.cat((self.swap_ach, self.swap_pos))

    def backward_Dp(self):
        """Discriminator step for the pose discriminator Dp."""
        real_pose = torch.cat((Variable(self.posemap), Variable(self.gt)), dim=1)
        fake_pose1 = torch.cat((Variable(self.posemap), self.fake.detach()), dim=1)
        fake_pose2 = torch.cat((Variable(self.pos_posemap), self.swap_fake.detach()), dim=1)
        # Real images paired with the *wrong* pose map also count as fake.
        fake_pose3 = torch.cat((Variable(self.swap_posemap), Variable(self.two_gt)), dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(torch.cat([fake_pose1, fake_pose2, fake_pose3]))
        # Occasionally flip the targets (only when smooth_label is on).
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Dp = loss_D.item()

    def backward_Di(self):
        """Discriminator step for the identity discriminator Di."""
        _, pred_real = self.net_Di(Variable(self.gt))
        _, pred_fake1 = self.net_Di(self.fake.detach())
        _, pred_fake2 = self.net_Di(self.swap_fake.detach())
        pred_fake = torch.cat([pred_fake1, pred_fake2])
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Di = loss_D.item()

    def backward_G(self):
        """Generator + embedding step: reconstruction, GAN, triplet, class losses."""
        loss_r1 = self.reco_criterion(self.fake, Variable(self.mask), Variable(self.gt))
        loss_r2 = self.reco_criterion(self.swap_fake, Variable(self.two_mask), Variable(self.two_gt))
        # Swap-branch reconstruction is weighted more heavily.
        loss_r = 0.7 * loss_r2 + 0.3 * loss_r1
        _, pred_fake_Di = self.net_Di(torch.cat([self.fake, self.swap_fake]))
        fake_pose1 = torch.cat((Variable(self.posemap), self.fake), dim=1)
        fake_pose2 = torch.cat((Variable(self.pos_posemap), self.swap_fake), dim=1)
        fake_pose3 = torch.cat((Variable(self.swap_posemap), Variable(self.two_gt)), dim=1)
        pred_fake_Dp = self.net_Dp(torch.cat([fake_pose1, fake_pose2, fake_pose3]))
        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)
        loss_G = loss_G_GAN_Di * self.opt.lambda_d + \
                 loss_G_GAN_Dp * self.opt.lambda_dp + \
                 loss_r * self.opt.lambda_recon
        loss_G.backward(retain_graph=True)
        # Compute triplet loss on the three embeddings.
        feat = torch.cat([self.A_ach, self.A_pos, self.A_neg])
        loss_t, _, _ = self.tri_criterion(feat, torch.squeeze(self.ids, 1))
        # Classification loss (skipped in stage 1 where E is frozen).
        if self.opt.stage==1:
            loss_c = torch.tensor(0.0)
        else:
            if self.opt.emb_type == 'Single':
                loss_c = self.class_criterion(self.id_pred, torch.squeeze(self.ids, 1))
            elif self.opt.emb_type == 'Siamese':
                loss_c = self.class_criterion(self.id_pred, torch.squeeze(self.labels, 1))
        # NOTE(review): loss_E re-includes loss_G, so gradients along the
        # generator path accumulate twice (once per backward) — confirm this
        # double weighting is intended.
        loss_E = loss_G + \
                 loss_t * self.opt.lambda_tri + \
                 loss_c * self.opt.lambda_class
        loss_E.backward()
        self.loss_E = loss_E.item()
        self.loss_t = loss_t.item()
        self.loss_c = loss_c.item()
        self.loss_G = loss_G.item()
        self.loss_r = loss_r.item()
        self.loss_G_GAN_Di = loss_G_GAN_Di.item()
        self.loss_G_GAN_Dp = loss_G_GAN_Dp.item()

    def optimize_parameters(self):
        """One full optimization step: Di, Dp, then G (and E unless stage 1)."""
        self.forward()
        self.optimizer_Di.zero_grad()
        self.backward_Di()
        self.optimizer_Di.step()
        self.optimizer_Dp.zero_grad()
        self.backward_Dp()
        self.optimizer_Dp.step()
        self.optimizer_G.zero_grad()
        self.optimizer_E.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # In stage 1 the embedding network stays frozen.
        if self.opt.stage!=1:
            self.optimizer_E.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed for logging."""
        return OrderedDict([('E_t', self.loss_t),
                            ('E_c', self.loss_c),
                            ('E', self.loss_E),
                            ('G_r', self.loss_r),
                            ('G_gan_Di', self.loss_G_GAN_Di),
                            ('G_gan_Dp', self.loss_G_GAN_Dp),
                            ('G', self.loss_G),
                            ('D_i', self.loss_Di),
                            ('D_p', self.loss_Dp)])

    def get_current_visuals(self):
        """Return displayable images for anchor/positive and their fakes."""
        ach = util.tensor2im(self.ach_img)
        ach_map = self.ach_pose.sum(1)
        ach_map[ach_map>1] = 1  # clip overlapping heatmap peaks for display
        ach_pose = util.tensor2im(torch.unsqueeze(ach_map, 1))
        ach_mask = util.tensor2im(torch.unsqueeze(self.ach_mask, 1))
        ach_fake1 = util.tensor2im(self.fake_ach.data)
        ach_fake2 = util.tensor2im(self.swap_ach.data)
        pos = util.tensor2im(self.pos_img)
        pos_map = self.pos_pose.sum(1)
        pos_map[pos_map>1] = 1
        pos_pose = util.tensor2im(torch.unsqueeze(pos_map, 1))
        pos_mask = util.tensor2im(torch.unsqueeze(self.pos_mask, 1))
        pos_fake1 = util.tensor2im(self.fake_pos.data)
        pos_fake2 = util.tensor2im(self.swap_pos.data)
        return OrderedDict([('ach', ach), ('ach_pose', ach_pose), ('ach_mask', ach_mask),
                            ('ach_fake1', ach_fake1), ('ach_fake2', ach_fake2),
                            ('pos', pos), ('pos_pose', pos_pose), ('pos_mask', pos_mask),
                            ('pos_fake1', pos_fake1), ('pos_fake2', pos_fake2)])

    def save(self, epoch):
        """Persist all four networks for the given epoch label."""
        self.save_network(self.net_E, 'E', epoch)
        self.save_network(self.net_G, 'G', epoch)
        self.save_network(self.net_Di, 'Di', epoch)
        self.save_network(self.net_Dp, 'Dp', epoch)

    def save_network(self, network, network_label, epoch_label):
        """Write one network's state_dict to ``save_dir/<epoch>_net_<label>.pth``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.state_dict(), save_path)

    def update_learning_rate(self):
        """Advance every LR scheduler by one step."""
        for scheduler in self.schedulers:
            scheduler.step()
class FDGANModel(object):
    """FD-GAN: identity-preserving pose-guided image generation for re-ID.

    Wraps generator G, Siamese embedding/verification net E, identity
    discriminator Di and pose discriminator Dp with their losses,
    optimizers and schedulers. Stage 1 trains G/Di/Dp against a frozen
    pretrained E; stage 2 fine-tunes everything from checkpoints.
    """

    def __init__(self, opt):
        self.opt = opt
        self.save_dir = os.path.join(opt.checkpoints, opt.name)
        self.norm_layer = get_norm_layer(norm_type=opt.norm)
        self._init_models()
        self._init_losses()
        self._init_optimizers()
        print('---------- Networks initialized -------------')
        print_network(self.net_E)
        print_network(self.net_G)
        print_network(self.net_Di)
        print_network(self.net_Dp)
        print('-----------------------------------------------')

    def _init_models(self):
        """Build G/E/Di/Dp and initialize or restore them per ``opt.stage``."""
        self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
                                         dropout=self.opt.drop, norm_layer=self.norm_layer,
                                         fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
        e_base_model = create(self.opt.arch, cut_at_pooling=True)
        e_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                        num_features=2048, num_classes=2)
        self.net_E = SiameseNet(e_base_model, e_embed_model)
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                         num_features=2048, num_classes=1)
        self.net_Di = SiameseNet(di_base_model, di_embed_model)
        # 3 RGB channels + 18 pose-heatmap channels.
        self.net_Dp = NLayerDiscriminator(3+18, norm_layer=self.norm_layer)
        if self.opt.stage==1:
            init_weights(self.net_G)
            init_weights(self.net_Dp)
            state_dict = remove_module_key(torch.load(self.opt.netE_pretrain))
            self.net_E.load_state_dict(state_dict)
            # Reuse row 1 of E's 2-class verifier head for Di's 1-class head.
            # FIX: slice with [1:2] (not [1]) to keep the 2-D [1, num_features]
            # shape expected by the num_classes=1 classifier.
            state_dict['embed_model.classifier.weight'] = state_dict['embed_model.classifier.weight'][1:2]
            state_dict['embed_model.classifier.bias'] = torch.FloatTensor([state_dict['embed_model.classifier.bias'][1]])
            self.net_Di.load_state_dict(state_dict)
        elif self.opt.stage==2:
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        else:
            # FIX: was `assert('unknown training stage')` — an assert on a
            # non-empty string is always true and never fired.
            raise ValueError('unknown training stage')
        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()

    def reset_model_status(self):
        """Set train/eval modes per stage; BN layers of frozen nets are fixed."""
        if self.opt.stage==1:
            self.net_G.train()
            self.net_Dp.train()
            self.net_E.eval()
            self.net_Di.train()
            self.net_Di.apply(set_bn_fix)
        elif self.opt.stage==2:
            self.net_E.train()
            self.net_G.train()
            self.net_Di.train()
            self.net_Dp.train()
            self.net_E.apply(set_bn_fix)
            self.net_Di.apply(set_bn_fix)

    def _load_state_dict(self, net, path):
        """Load ``path`` into ``net``, stripping DataParallel 'module.' prefixes."""
        state_dict = remove_module_key(torch.load(path))
        net.load_state_dict(state_dict)

    def _init_losses(self):
        """Create GAN criteria; optionally enable rare target flipping."""
        if self.opt.smooth_label:
            self.criterionGAN_D = GANLoss(smooth=True).cuda()
            # Flip real/fake targets roughly 1/10001 of the time.
            self.rand_list = [True] * 1 + [False] * 10000
        else:
            self.criterionGAN_D = GANLoss(smooth=False).cuda()
            self.rand_list = [False]
        self.criterionGAN_G = GANLoss(smooth=False).cuda()

    def _init_optimizers(self):
        """Create stage-dependent optimizers for G, Di, Dp and LR schedulers."""
        if self.opt.stage==1:
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=self.opt.lr*0.1, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr*0.01, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
        elif self.opt.stage==2:
            # E is fine-tuned through the G optimizer via param groups.
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.1},
                            {'params': self.net_E.module.embed_model.parameters(), 'lr_mult': 1.0},
                            {'params': self.net_G.parameters(), 'lr_mult': 0.1}]
            self.optimizer_G = torch.optim.Adam(param_groups,
                                                lr=self.opt.lr*0.1, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
        self.schedulers = []
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_Di)
        self.optimizers.append(self.optimizer_Dp)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, self.opt))

    def set_input(self, input):
        """Unpack a pair batch; same-identity pairs share pose map and target."""
        input1, input2 = input
        labels = (input1['pid']==input2['pid']).long()
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)
        # keep the same pose map for persons with the same identity
        mask = labels.view(-1,1,1,1).expand_as(input1['posemap'])
        input2['posemap'] = input1['posemap']*mask.float() + input2['posemap']*(1-mask.float())
        mask = labels.view(-1,1,1,1).expand_as(input1['target'])
        input2['target'] = input1['target']*mask.float() + input2['target']*(1-mask.float())
        origin = torch.cat([input1['origin'], input2['origin']])
        target = torch.cat([input1['target'], input2['target']])
        posemap = torch.cat([input1['posemap'], input2['posemap']])
        noise = torch.cat((noise, noise))
        self.origin = origin.cuda()
        self.target = target.cuda()
        self.posemap = posemap.cuda()
        self.labels = labels.cuda()
        self.noise = noise.cuda()

    def forward(self):
        """Embed both pair halves with E and synthesize fakes with G."""
        A = Variable(self.origin)
        B_map = Variable(self.posemap)
        z = Variable(self.noise)
        bs = A.size(0)
        # First/second half of the batch are the two members of each pair.
        A_id1, A_id2, self.id_score = self.net_E(A[:bs//2], A[bs//2:])
        A_id = torch.cat((A_id1, A_id2))
        self.fake = self.net_G(B_map,
                               A_id.view(A_id.size(0), A_id.size(1), 1, 1),
                               z.view(z.size(0), z.size(1), 1, 1))

    def backward_Dp(self):
        """Discriminator step for the pose discriminator Dp."""
        real_pose = torch.cat((Variable(self.posemap), Variable(self.target)), dim=1)
        fake_pose = torch.cat((Variable(self.posemap), self.fake.detach()), dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(fake_pose)
        # Occasionally flip targets (only when smooth_label is on).
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # FIX: `.item()` replaces the long-removed `.data[0]` indexing on a
        # 0-dim loss tensor (the sibling class in this file already uses it).
        self.loss_Dp = loss_D.item()

    def backward_Di(self):
        """Discriminator step for the identity discriminator Di."""
        _, _, pred_real = self.net_Di(Variable(self.origin), Variable(self.target))
        _, _, pred_fake = self.net_Di(Variable(self.origin), self.fake.detach())
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Di = loss_D.item()

    def backward_G(self):
        """Generator step: verification, reconstruction, same-pose and GAN losses."""
        loss_v = F.cross_entropy(self.id_score, Variable(self.labels).view(-1))
        loss_r = F.l1_loss(self.fake, Variable(self.target))
        fake_1 = self.fake[:self.fake.size(0)//2]
        fake_2 = self.fake[self.fake.size(0)//2:]
        # Same-person loss: the two fakes of a same-identity pair should match.
        loss_sp = F.l1_loss(fake_1[self.labels.view(self.labels.size(0),1,1,1).expand_as(fake_1)==1],
                            fake_2[self.labels.view(self.labels.size(0),1,1,1).expand_as(fake_1)==1])
        _, _, pred_fake_Di = self.net_Di(Variable(self.origin), self.fake)
        pred_fake_Dp = self.net_Dp(torch.cat((Variable(self.posemap), self.fake), dim=1))
        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)
        loss_G = loss_G_GAN_Di + loss_G_GAN_Dp + \
                 loss_r * self.opt.lambda_recon + \
                 loss_v * self.opt.lambda_veri + \
                 loss_sp * self.opt.lambda_sp
        loss_G.backward()
        del self.id_score
        self.loss_G = loss_G.item()
        self.loss_v = loss_v.item()
        self.loss_sp = loss_sp.item()
        self.loss_r = loss_r.item()
        self.loss_G_GAN_Di = loss_G_GAN_Di.item()
        self.loss_G_GAN_Dp = loss_G_GAN_Dp.item()
        # Detach the fakes so later visualization does not hold the graph.
        self.fake = self.fake.data

    def optimize_parameters(self):
        """One full optimization step: Di, then Dp, then G."""
        self.forward()
        self.optimizer_Di.zero_grad()
        self.backward_Di()
        self.optimizer_Di.step()
        self.optimizer_Dp.zero_grad()
        self.backward_Dp()
        self.optimizer_Dp.step()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed for logging."""
        return OrderedDict([('G_v', self.loss_v),
                            ('G_r', self.loss_r),
                            ('G_sp', self.loss_sp),
                            ('G_gan_Di', self.loss_G_GAN_Di),
                            ('G_gan_Dp', self.loss_G_GAN_Dp),
                            ('D_i', self.loss_Di),
                            ('D_p', self.loss_Dp)])

    def get_current_visuals(self):
        """Return displayable input/posemap/fake/target images."""
        # Locals renamed from `input`/`map` to avoid shadowing builtins;
        # the returned keys are unchanged.
        src_im = util.tensor2im(self.origin)
        tgt_im = util.tensor2im(self.target)
        fake_im = util.tensor2im(self.fake)
        pose_map = self.posemap.sum(1)
        pose_map[pose_map>1] = 1  # clip overlapping heatmap peaks for display
        pose_im = util.tensor2im(torch.unsqueeze(pose_map, 1))
        return OrderedDict([('input', src_im), ('posemap', pose_im),
                            ('fake', fake_im), ('target', tgt_im)])

    def save(self, epoch):
        """Persist all four networks for the given epoch label."""
        self.save_network(self.net_E, 'E', epoch)
        self.save_network(self.net_G, 'G', epoch)
        self.save_network(self.net_Di, 'Di', epoch)
        self.save_network(self.net_Dp, 'Dp', epoch)

    def save_network(self, network, network_label, epoch_label):
        """Write one network's state_dict to ``save_dir/<epoch>_net_<label>.pth``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.state_dict(), save_path)

    def update_learning_rate(self):
        """Advance every LR scheduler by one step."""
        for scheduler in self.schedulers:
            scheduler.step()
def main(args):
    """Train/evaluate a re-ID embedding model (SingleNet or SiameseNet).

    Pipeline: seed RNGs -> redirect stdout to a log file -> build data
    loaders -> build model (optionally resume) -> evaluate (and return if
    --evaluate) -> build losses/optimizer/trainer -> train with step-decayed
    LR, checkpointing on best mAP -> final test with the best checkpoint.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
    # print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (256, 128)
    dataset, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers,
                 args.combine_trainval, args.np_ratio, args.emb_type,
                 args.inst_mode, args.eraser)
    # Number of identity classes depends on whether train+val are merged.
    if args.combine_trainval:
        emb_size = dataset.num_trainval_ids
    else:
        emb_size = dataset.num_train_ids

    # Create model
    # NOTE: 'pretraind' is the (misspelled) keyword expected by the model
    # constructors; do not "fix" it here without changing the callees.
    if (args.emb_type == 'Single'):
        model = SingleNet(args.arch, emb_size, pretraind=True,
                          use_bn=args.use_bn, test_bn=args.test_bn,
                          last_stride=args.last_stride)
    elif (args.emb_type == 'Siamese'):
        model = SiameseNet(args.arch, emb_size, pretraind=True,
                           use_bn=args.use_bn, test_bn=args.test_bn,
                           last_stride=args.last_stride)
    else:
        raise ValueError('unrecognized model')
    model = nn.DataParallel(model).cuda()

    # Optionally resume from a checkpoint (accepts raw or wrapped state dict).
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)

    # Evaluator
    # NOTE(review): `model` is already nn.DataParallel here, so this wraps
    # DataParallel inside DataParallel — looks unintended; confirm whether
    # CascadeEvaluator expects the bare module instead.
    evaluator = CascadeEvaluator(
        torch.nn.DataParallel(model).cuda(),
        emb_size=emb_size)

    # Load from checkpoint
    best_mAP = 0
    if args.resume:
        print("Test the loaded model:")
        top1, mAP = evaluator.evaluate(test_loader, dataset.query,
                                       dataset.gallery, dataset=args.dataset)
        best_mAP = mAP
    if args.evaluate:
        return

    # Criterion
    # Triplet loss: 'soft' margin (softplus) or a fixed numeric margin.
    if args.soft_margin:
        tri_criterion = TripletLoss(margin='soft').cuda()
    else:
        tri_criterion = TripletLoss(margin=args.margin).cuda()
    # Classification loss; label smoothing only applies to the Single head.
    if (args.emb_type == 'Single'):
        if args.label_smoothing:
            cla_criterion = CrossEntropyLabelSmooth(emb_size, epsilon=0.1).cuda()
        else:
            cla_criterion = torch.nn.CrossEntropyLoss().cuda()
    elif (args.emb_type == 'Siamese'):
        cla_criterion = torch.nn.CrossEntropyLoss().cuda()

    # Optimizer: backbone trains at 0.1x the classifier LR (see adjust_lr).
    param_groups = [{
        'params': model.module.base_model.parameters(),
        'lr_mult': 0.1
    }, {
        'params': model.module.classifier.parameters(),
        'lr_mult': 1.0
    }]
    if (args.opt_name == 'SGD'):
        optimizer = getattr(torch.optim, args.opt_name)(
            param_groups, lr=args.lr, weight_decay=args.weight_decay,
            momentum=args.momentum)
    else:
        # Optimizers other than SGD (e.g. Adam) take no momentum kwarg.
        optimizer = getattr(torch.optim, args.opt_name)(
            param_groups, lr=args.lr, weight_decay=args.weight_decay)

    # Trainer (emb_type already validated above, so one branch must hit).
    if (args.emb_type == 'Single'):
        trainer = TripletTrainer(model, tri_criterion, cla_criterion,
                                 args.lambda_tri, args.lambda_cla)
    elif (args.emb_type == 'Siamese'):
        trainer = SiameseTrainer(model, tri_criterion, cla_criterion,
                                 args.lambda_tri, args.lambda_cla)

    #TODO:Warmup lr
    # Schedule learning rate: step decay by 10x every `step_size` epochs,
    # scaled per parameter group by its 'lr_mult'.
    def adjust_lr(epoch):
        lr = args.lr * (0.1**(epoch // args.step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(0, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer, base_lr=args.lr)
        if epoch % args.eval_step == 0:
            #mAP = evaluator.evaluate(val_loader, dataset.val, dataset.val, top1=False, dataset=args.dataset)
            mAP = evaluator.evaluate(test_loader, dataset.query,
                                     dataset.gallery, top1=False,
                                     dataset=args.dataset)
            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            save_checkpoint({'state_dict': model.state_dict()}, is_best,
                            fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
            print('\n * Finished epoch {:3d} mAP: {:5.1%} best: {:5.1%}{}\n'.
                  format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    evaluator.evaluate(test_loader, dataset.query, dataset.gallery,
                       dataset=args.dataset)
def _init_models(self):
    """Build the cross-dataset networks, load pretrained weights, and wrap
    each network in DataParallel on the GPU.

    Networks:
      - net_G / tar_net_G: pose-conditioned generators for source / target.
      - net_E: identity encoder, shared across both datasets.
      - net_Di / tar_net_Di: per-domain image discriminators.
      - net_Dp: pose discriminator, shared across both datasets.

    Raises:
        ValueError: if ``opt.stage`` is neither 1 nor 2.
    """
    # Generator for the source domain.
    self.net_G = CustomPoseGenerator(
        self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
        dropout=self.opt.drop, norm_layer=self.norm_layer,
        fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
    # Generator for the target domain (same architecture, separate weights).
    self.tar_net_G = CustomPoseGenerator(
        self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
        dropout=self.opt.drop, norm_layer=self.norm_layer,
        fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)

    # Identity encoder E is shared across datasets.
    e_base_model = create(self.opt.arch, cut_at_pooling=True)
    e_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                    num_features=2048, num_classes=2)
    self.net_E = SiameseNet(e_base_model, e_embed_model)

    # Image discriminator for the source domain.
    di_base_model = create(self.opt.arch, cut_at_pooling=True)
    di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                     num_features=2048, num_classes=1)
    self.net_Di = SiameseNet(di_base_model, di_embed_model)

    # Image discriminator for the target domain.
    di_base_model = create(self.opt.arch, cut_at_pooling=True)
    di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                     num_features=2048, num_classes=1)
    self.tar_net_Di = SiameseNet(di_base_model, di_embed_model)

    # Pose discriminator Dp is shared; input is 3 RGB + 18 pose-map channels.
    self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)

    # Load pretrained weights. The original code had identical bodies for
    # stage 1 and stage 2, so the two branches are merged here.
    if self.opt.stage in (1, 2):
        self._load_state_dict(self.net_E, self.opt.netE_pretrain)
        self._load_state_dict(self.net_G, self.opt.netG_pretrain)
        self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
        self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        self._load_state_dict(self.tar_net_Di, self.opt.tar_netDi_pretrain)
        self._load_state_dict(self.tar_net_G, self.opt.tar_netG_pretrain)
    else:
        # BUG FIX: the original `assert ('unknown training stage')` asserted
        # a non-empty string, which is always true, so an invalid stage was
        # silently accepted. Fail loudly instead.
        raise ValueError('unknown training stage: %r' % (self.opt.stage,))

    # Wrap everything for multi-GPU execution and move to CUDA.
    self.net_E = torch.nn.DataParallel(self.net_E).cuda()
    self.net_G = torch.nn.DataParallel(self.net_G).cuda()
    self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
    self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()
    self.tar_net_G = torch.nn.DataParallel(self.tar_net_G).cuda()
    self.tar_net_Di = torch.nn.DataParallel(self.tar_net_Di).cuda()
class PDANetModel(object):
    """Pose-guided cross-dataset GAN model.

    Holds a source generator (net_G), target generator (tar_net_G), a shared
    identity encoder (net_E), per-domain image discriminators (net_Di /
    tar_net_Di) and a shared pose discriminator (net_Dp), plus their losses,
    optimizers and LR schedulers. ``opt.stage`` selects which networks are
    trainable (stage 1: generators/discriminators only; stage 2: E too).
    """

    def __init__(self, opt):
        self.opt = opt
        # All checkpoints are written under <checkpoints>/<experiment name>.
        self.save_dir = os.path.join(opt.checkpoints, opt.name)
        self.norm_layer = get_norm_layer(norm_type=opt.norm)
        self._init_models()
        self._init_losses()
        self._init_cross_optimizers()

    def _init_models(self):
        """Build all networks, load pretrained weights, wrap in DataParallel."""
        #G For source
        self.net_G = CustomPoseGenerator(
            self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
            dropout=self.opt.drop, norm_layer=self.norm_layer,
            fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
        #G For target
        self.tar_net_G = CustomPoseGenerator(
            self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
            dropout=self.opt.drop, norm_layer=self.norm_layer,
            fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
        #We share same E for cross-dataset
        e_base_model = create(self.opt.arch, cut_at_pooling=True)
        e_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                        use_classifier=True,
                                        num_features=2048,
                                        num_classes=2)
        self.net_E = SiameseNet(e_base_model, e_embed_model)
        #Di For source
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                         use_classifier=True,
                                         num_features=2048,
                                         num_classes=1)
        self.net_Di = SiameseNet(di_base_model, di_embed_model)
        #Di For target
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                         use_classifier=True,
                                         num_features=2048,
                                         num_classes=1)
        self.tar_net_Di = SiameseNet(di_base_model, di_embed_model)
        #We share same Dp for cross-dataset (3 RGB + 18 pose channels)
        self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)
        #Load model
        # NOTE(review): the stage-1 and stage-2 branches are identical, and
        # the `assert` below is always true (non-empty string), so an
        # invalid stage silently falls through to the DataParallel wrapping.
        if self.opt.stage == 1:
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
            self._load_state_dict(self.tar_net_Di, self.opt.tar_netDi_pretrain)
            self._load_state_dict(self.tar_net_G, self.opt.tar_netG_pretrain)
        elif self.opt.stage == 2:
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
            self._load_state_dict(self.tar_net_Di, self.opt.tar_netDi_pretrain)
            self._load_state_dict(self.tar_net_G, self.opt.tar_netG_pretrain)
        else:
            assert ('unknown training stage')
        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()
        self.tar_net_G = torch.nn.DataParallel(self.tar_net_G).cuda()
        self.tar_net_Di = torch.nn.DataParallel(self.tar_net_Di).cuda()

    def reset_model_status(self):
        """Set train/eval mode per stage; BN layers of pretrained nets are
        frozen via set_bn_fix so their running stats don't drift."""
        if self.opt.stage == 1:
            self.net_G.train()
            self.tar_net_G.train()
            self.net_Dp.train()
            # E stays fixed in stage 1.
            self.net_E.eval()
            self.net_Di.train()
            self.net_Di.apply(set_bn_fix)
            self.tar_net_Di.train()
            self.tar_net_Di.apply(set_bn_fix)
        elif self.opt.stage == 2:
            self.net_E.train()
            self.net_G.train()
            self.tar_net_G.train()
            self.net_Di.train()
            self.tar_net_Di.train()
            self.net_Dp.train()
            self.net_E.apply(set_bn_fix)
            self.net_Di.apply(set_bn_fix)
            self.tar_net_Di.apply(set_bn_fix)

    def _load_state_dict(self, net, path):
        """Load a checkpoint from `path`, stripping DataParallel 'module.' keys."""
        state_dict = remove_module_key(torch.load(path))
        net.load_state_dict(state_dict)

    def _init_losses(self):
        """Create triplet/MMD/GAN criteria and the label-flip schedule."""
        self.criterion_Triplet = TripletLoss(margin=self.opt.tri_margin)
        self.criterion_MMD = MMDLoss(sigma_list=[1, 2, 10])
        #Smooth label is a mechanism
        # With smooth labels, rand_list makes the discriminator see flipped
        # real/fake labels roughly 1 time in 10001 (see backward_*_Di/Dp).
        if self.opt.smooth_label:
            self.criterionGAN_D = GANLoss(smooth=True).cuda()
            self.rand_list = [True] * 1 + [False] * 10000
        else:
            self.criterionGAN_D = GANLoss(smooth=False).cuda()
            self.rand_list = [False]
        # Generator loss never uses smoothing.
        self.criterionGAN_G = GANLoss(smooth=False).cuda()

    #For cross dataset optimizers
    def _init_cross_optimizers(self):
        """Create per-stage optimizers and attach an LR scheduler to each."""
        if self.opt.stage == 1:
            # Stage 1: generators via Adam at 0.1x lr; discriminators via
            # SGD (Di at 0.01x lr, Dp at full lr).
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=self.opt.lr * 0.1,
                                                betas=(0.5, 0.999))
            self.optimizer_tar_G = torch.optim.Adam(
                self.tar_net_G.parameters(),
                lr=self.opt.lr * 0.1,
                betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr * 0.01,
                                                momentum=0.9,
                                                weight_decay=1e-4)
            self.optimizer_tar_Di = torch.optim.SGD(
                self.tar_net_Di.parameters(),
                lr=self.opt.lr * 0.01,
                momentum=0.9,
                weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr,
                                                momentum=0.9,
                                                weight_decay=1e-4)
        elif self.opt.stage == 2:
            # Stage 2: E's backbone/embedding join the generator optimizers
            # with per-group lr multipliers.
            param_groups = [{
                'params': self.net_E.module.base_model.parameters(),
                'lr_mult': 0.1
            }, {
                'params': self.net_E.module.embed_model.parameters(),
                'lr_mult': 1.0
            }, {
                'params': self.net_G.parameters(),
                'lr_mult': 0.1
            }]
            param_tar_groups = [{
                'params': self.net_E.module.base_model.parameters(),
                'lr_mult': 0.1
            }, {
                'params': self.net_E.module.embed_model.parameters(),
                'lr_mult': 1.0
            }, {
                'params': self.tar_net_G.parameters(),
                'lr_mult': 0.1
            }]
            self.optimizer_G = torch.optim.Adam(param_groups,
                                                lr=self.opt.lr * 0.1,
                                                betas=(0.5, 0.999))
            self.optimizer_tar_G = torch.optim.Adam(param_tar_groups,
                                                    lr=self.opt.lr * 0.1,
                                                    betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr,
                                                momentum=0.9,
                                                weight_decay=1e-4)
            self.optimizer_tar_Di = torch.optim.SGD(
                self.tar_net_Di.parameters(),
                lr=self.opt.lr,
                momentum=0.9,
                weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr,
                                                momentum=0.9,
                                                weight_decay=1e-4)
        # Separate optimizer for E's backbone, driven by the MMD loss in
        # optimize_cross_parameters().
        # NOTE(review): reconstructed outside the stage branch because
        # optimize_cross_parameters() uses it unconditionally — confirm
        # against the original formatting.
        self.optimizer_E = torch.optim.Adam(
            self.net_E.module.base_model.parameters(),
            lr=self.opt.lr * 0.1,
            betas=(0.5, 0.999))
        self.schedulers = []
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_Di)
        self.optimizers.append(self.optimizer_Dp)
        self.optimizers.append(self.optimizer_tar_G)
        self.optimizers.append(self.optimizer_tar_Di)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, self.opt))

    def set_input(self, input):
        """Single-dataset input: a pair of sample dicts; builds the stacked
        origin/target/posemap/noise tensors and the same-identity labels."""
        input1, input2 = input
        # 1 where the two samples share an identity, else 0.
        labels = (input1['pid'] == input2['pid']).long()
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)
        # keep the same pose map for persons with the same identity
        mask = labels.view(-1, 1, 1, 1).expand_as(input1['posemap'])
        input2['posemap'] = input1['posemap'] * mask.float(
        ) + input2['posemap'] * (1 - mask.float())
        mask = labels.view(-1, 1, 1, 1).expand_as(input1['target'])
        input2['target'] = input1['target'] * mask.float(
        ) + input2['target'] * (1 - mask.float())
        origin = torch.cat([input1['origin'], input2['origin']])
        target = torch.cat([input1['target'], input2['target']])
        posemap = torch.cat([input1['posemap'], input2['posemap']])
        # Duplicate the noise so both halves of the batch share it.
        noise = torch.cat((noise, noise))
        self.origin = origin.cuda()
        self.target = target.cuda()
        self.posemap = posemap.cuda()
        self.labels = labels.cuda()
        self.noise = noise.cuda()

    #For cross-dataset
    def set_inputs(self, target_data, source_data):
        """Cross-dataset input: same preparation as set_input for the source
        pair (prefix s_), plus an unlabeled target pair (prefix t_)."""
        #=============source data===============================
        input1, input2 = source_data
        labels = (input1['pid'] == input2['pid']).long()
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)
        # keep the same pose map for persons with the same identity
        mask = labels.view(-1, 1, 1, 1).expand_as(input1['posemap'])
        input2['posemap'] = input1['posemap'] * mask.float(
        ) + input2['posemap'] * (1 - mask.float())
        mask = labels.view(-1, 1, 1, 1).expand_as(input1['target'])
        input2['target'] = input1['target'] * mask.float(
        ) + input2['target'] * (1 - mask.float())
        origin = torch.cat([input1['origin'], input2['origin']])
        target = torch.cat([input1['target'], input2['target']])
        posemap = torch.cat([input1['posemap'], input2['posemap']])
        noise = torch.cat((noise, noise))
        plabels = torch.cat((input1['pid'].long(),
                             input2['pid'].long()))  # Used for triplet loss
        self.s_origin = origin.cuda()
        self.s_target = target.cuda()
        self.s_posemap = posemap.cuda()
        self.s_labels = labels.cuda()
        self.s_noise = noise.cuda()
        self.s_plabels = plabels.cuda()  # Used for triplet loss
        #=============target data===============================
        input1, input2 = target_data
        # Target pids are not used (no labels in the target domain).
        noise = torch.randn(input1['origin'].size(0),
                            self.opt.noise_feature_size)
        origin = torch.cat([input1['origin'], input2['origin']])
        target = torch.cat([input1['target'], input2['target']])
        posemap = torch.cat([input1['posemap'], input2['posemap']])
        noise = torch.cat((noise, noise))
        self.t_origin = origin.cuda()
        self.t_target = target.cuda()
        self.t_posemap = posemap.cuda()
        self.t_noise = noise.cuda()

    #forward cross
    def forward_cross(self):
        """Run E on both domains, synthesize source / source-to-target /
        target fakes, and stash identity features and scores for backward."""
        #source
        A = Variable(self.s_origin)
        B_map = Variable(self.s_posemap)
        z = Variable(self.s_noise)
        bs = A.size(0)
        A_id1, A_id2, self.s_id_score = self.net_E(A[:bs // 2], A[bs // 2:])
        A_id = torch.cat((A_id1, A_id2))
        self.s_A_id = A_id
        self.s_fake = self.net_G(B_map,
                                 A_id.view(A_id.size(0), A_id.size(1), 1, 1),
                                 z.view(z.size(0), z.size(1), 1, 1))
        #source to target
        # NOTE(review): A_id1 already has bs//2 rows, so this slice keeps
        # all of it — presumably intended just to pair with B_map[:bs//2];
        # confirm.
        A_id1_st = A_id1[:bs // 2]
        z_st = z[:bs // 2]
        self.st_fake = self.tar_net_G(
            B_map[:bs // 2],
            A_id1_st.view(A_id1_st.size(0), A_id1_st.size(1), 1, 1),
            z_st.view(z_st.size(0), z_st.size(1), 1, 1))
        # self.st_fake = self.tar_net_G(B_map, A_id.view(A_id.size(0), A_id.size(1), 1, 1), z.view(z.size(0), z.size(1), 1, 1))
        #target
        A = Variable(self.t_origin)
        B_map = Variable(self.t_posemap)
        z = Variable(self.t_noise)
        bs = A.size(0)
        A_id1, A_id2, self.t_id_score = self.net_E(A[:bs // 2], A[bs // 2:])
        A_id = torch.cat((A_id1, A_id2))
        self.t_A_id = A_id
        self.t_fake = self.tar_net_G(
            B_map, A_id.view(A_id.size(0), A_id.size(1), 1, 1),
            z.view(z.size(0), z.size(1), 1, 1))

    def backward_Dp(self):
        """Pose-discriminator step for the single-dataset path.

        NOTE(review): reads self.posemap/self.target/self.fake, which only
        set_input() (not set_inputs) prepares — confirm this method is used
        on the single-dataset path only.
        """
        real_pose = torch.cat((Variable(self.posemap), Variable(self.target)),
                              dim=1)
        fake_pose = torch.cat((Variable(self.posemap), self.fake.detach()),
                              dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(fake_pose)
        # Occasionally flip real/fake labels (see _init_losses.rand_list).
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # NOTE(review): .data[0] is legacy PyTorch (<0.4); sibling methods
        # use .item().
        self.loss_Dp = loss_D.data[0]

    def backward_Di(self):
        """Image-discriminator step for the single-dataset path (see the
        state-availability note on backward_Dp)."""
        _, _, pred_real = self.net_Di(Variable(self.origin),
                                      Variable(self.target))
        _, _, pred_fake = self.net_Di(Variable(self.origin),
                                      self.fake.detach())
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Di = loss_D.data[0]

    def backward_E_MMD(self):
        """Align source/target identity features with an MMD penalty on E.
        retain_graph=True keeps the graph alive for the later G backward."""
        loss_mmd = self.criterion_MMD(self.s_A_id, self.t_A_id)
        self.loss_mmd = loss_mmd.item()
        loss_mmd = loss_mmd * self.opt.lambda_mmd
        loss_mmd.backward(retain_graph=True)

    def backward_s_G(self):
        """Source-generator step: GAN(Di) + GAN(Dp) + L1 reconstruction +
        verification CE + triplet, weighted by opt.lambda_*."""
        # self.s_plabels
        # self.criterion_Triplet
        # self.s_A_id
        # self.criterion_Triplet = TripletLoss(margin=opt.tri_margin)
        # self.criterion_MMD = MMDLoss(sigma_list=[1, 2, 10])
        loss_tri, prec = self.criterion_Triplet(self.s_A_id, self.s_plabels)
        loss_v = F.cross_entropy(self.s_id_score,
                                 Variable(self.s_labels).view(-1))
        loss_r = F.l1_loss(self.s_fake, Variable(self.s_target))
        # NOTE(review): fake_1/fake_2 are computed but unused here (the
        # same-pose loss of the single-dataset variant was dropped).
        fake_1 = self.s_fake[:self.s_fake.size(0) // 2]
        fake_2 = self.s_fake[self.s_fake.size(0) // 2:]
        _, _, pred_fake_Di = self.net_Di(Variable(self.s_origin), self.s_fake)
        pred_fake_Dp = self.net_Dp(
            torch.cat((Variable(self.s_posemap), self.s_fake), dim=1))
        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)
        loss_G = loss_G_GAN_Di + loss_G_GAN_Dp + \
            loss_r * self.opt.lambda_recon + \
            loss_v * self.opt.lambda_veri + \
            loss_tri * self.opt.lambda_tri
        loss_G.backward()
        # Free the verification score; it must not be reused after backward.
        del self.s_id_score
        self.loss_s_G = loss_G.item()
        self.loss_s_v = loss_v.item()
        self.loss_s_r = loss_r.item()
        self.loss_s_G_GAN_Di = loss_G_GAN_Di.item()
        self.loss_s_G_GAN_Dp = loss_G_GAN_Dp.item()
        self.loss_s_tri = loss_tri.item()

    def backward_t_G(self):
        """Target-generator step: GAN(Di) + GAN(Dp) + L1 reconstruction only
        (no identity supervision in the target domain)."""
        loss_r = F.l1_loss(self.t_fake, Variable(self.t_target))
        # NOTE(review): fake_1/fake_2 unused here as well.
        fake_1 = self.t_fake[:self.t_fake.size(0) // 2]
        fake_2 = self.t_fake[self.t_fake.size(0) // 2:]
        _, _, pred_fake_Di = self.tar_net_Di(Variable(self.t_origin),
                                             self.t_fake)
        pred_fake_Dp = self.net_Dp(
            torch.cat((Variable(self.t_posemap), self.t_fake), dim=1))
        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)
        loss_G = loss_G_GAN_Di + loss_G_GAN_Dp + \
            loss_r * self.opt.lambda_recon
        loss_G.backward()
        del self.t_id_score
        self.loss_t_G = loss_G.item()
        self.loss_t_r = loss_r.item()
        self.loss_t_G_GAN_Di = loss_G_GAN_Di.item()
        self.loss_t_G_GAN_Dp = loss_G_GAN_Dp.item()

    def backward_s_Di(self):
        """Source image-discriminator step (with occasional label flips)."""
        _, _, pred_real = self.net_Di(Variable(self.s_origin),
                                      Variable(self.s_target))
        _, _, pred_fake = self.net_Di(Variable(self.s_origin),
                                      self.s_fake.detach())
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # self.loss_s_Di = loss_D.data[0]
        self.loss_s_Di = loss_D.item()

    def backward_t_Di(self):
        """Target image-discriminator step (with occasional label flips)."""
        _, _, pred_real = self.tar_net_Di(Variable(self.t_origin),
                                          Variable(self.t_target))
        _, _, pred_fake = self.tar_net_Di(Variable(self.t_origin),
                                          self.t_fake.detach())
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # self.loss_t_Di = loss_D.data[0]
        self.loss_t_Di = loss_D.item()

    def backward_cross_Dp(self):
        """Shared pose-discriminator step: one backward pass on source pairs
        and one on target pairs; loss_Dp accumulates both."""
        real_pose = torch.cat(
            (Variable(self.s_posemap), Variable(self.s_target)), dim=1)
        fake_pose = torch.cat((Variable(self.s_posemap),
                               self.s_fake.detach()), dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(fake_pose)
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # self.loss_Dp = loss_D.data[0]
        self.loss_Dp = loss_D.item()
        real_pose = torch.cat(
            (Variable(self.t_posemap), Variable(self.t_target)), dim=1)
        fake_pose = torch.cat((Variable(self.t_posemap),
                               self.t_fake.detach()), dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(fake_pose)
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # self.loss_Dp += loss_D.data[0]
        self.loss_Dp += loss_D.item()

    def optimize_cross_parameters(self):
        """One full cross-dataset iteration: forward, then E(MMD), source Di,
        shared Dp, source G, target Di, target G — each with its optimizer."""
        self.forward_cross()
        self.optimizer_E.zero_grad()
        self.backward_E_MMD()
        self.optimizer_E.step()
        self.optimizer_Di.zero_grad()
        self.backward_s_Di()
        self.optimizer_Di.step()
        self.optimizer_Dp.zero_grad()
        self.backward_cross_Dp()
        self.optimizer_Dp.step()
        self.optimizer_G.zero_grad()
        self.backward_s_G()
        self.optimizer_G.step()
        self.optimizer_tar_Di.zero_grad()
        self.backward_t_Di()
        self.optimizer_tar_Di.step()
        self.optimizer_tar_G.zero_grad()
        self.backward_t_G()
        self.optimizer_tar_G.step()

    def get_current_errors(self):
        """Losses of the single-dataset path.

        NOTE(review): loss_v/loss_r/loss_sp/loss_G_GAN_* are never assigned
        by any method of this class (only loss_s_*/loss_t_* are) — calling
        this without an external single-dataset backward would raise
        AttributeError; confirm intended usage.
        """
        return OrderedDict([('G_v', self.loss_v), ('G_r', self.loss_r),
                            ('G_sp', self.loss_sp),
                            ('G_gan_Di', self.loss_G_GAN_Di),
                            ('G_gan_Dp', self.loss_G_GAN_Dp),
                            ('D_i', self.loss_Di), ('D_p', self.loss_Dp)])

    def get_current_cross_errors(self):
        """Losses of the cross-dataset path, keyed for logging."""
        return OrderedDict([('G_tar_rec', self.loss_t_r),
                            ('G_tar_gan_Di', self.loss_t_G_GAN_Di),
                            ('G_tar_gan_Dp', self.loss_t_G_GAN_Dp),
                            ('tar_D_i', self.loss_t_Di),
                            ('D_p', self.loss_Dp),
                            ('G_src_v', self.loss_s_v),
                            ('G_src_rec', self.loss_s_r),
                            ('G_src_gan_Di', self.loss_s_G_GAN_Di),
                            ('G_src_gan_Dp', self.loss_s_G_GAN_Dp),
                            ('src_D_i', self.loss_s_Di),
                            ('MMD', self.loss_mmd),
                            ('src_tri', self.loss_s_tri)])

    def get_current_visuals(self):
        """Single-dataset visuals (see the state note on get_current_errors).
        The pose map is collapsed over channels and clipped to a 0/1 mask."""
        input = util.tensor2im(self.origin)
        target = util.tensor2im(self.target)
        fake = util.tensor2im(self.fake)
        map = self.posemap.sum(1)
        map[map > 1] = 1
        map = util.tensor2im(torch.unsqueeze(map, 1))
        return OrderedDict([('input', input), ('posemap', map),
                            ('fake', fake), ('target', target)])

    def get_tf_visuals(self):
        """Same as get_current_visuals but batched (tensor2ims) for
        TensorBoard-style logging."""
        input = util.tensor2ims(self.origin)
        target = util.tensor2ims(self.target)
        fake = util.tensor2ims(self.fake)
        map = self.posemap.sum(1)
        map[map > 1] = 1
        map = util.tensor2ims(torch.unsqueeze(map, 1))
        return OrderedDict([('input', input), ('posemap', map),
                            ('fake', fake), ('target', target)])

    # self.t_origin = origin.cuda()
    # self.t_target = target.cuda()
    # self.t_posemap = posemap.cuda()
    # self.t_labels = labels.cuda()
    # self.t_noise = noise.cuda()
    def get_tf_cross_visuals(self):
        """Cross-dataset visuals: source/target inputs, pose masks, fakes,
        and the source-to-target transfer produced in forward_cross."""
        src_input = util.tensor2ims(self.s_origin)
        src_target = util.tensor2ims(self.s_target)
        src_fake = util.tensor2ims(self.s_fake.data)
        src_map = self.s_posemap.sum(1)
        src_map[src_map > 1] = 1
        src_map = util.tensor2ims(torch.unsqueeze(src_map, 1))
        tar_input = util.tensor2ims(self.t_origin)
        tar_target = util.tensor2ims(self.t_target)
        tar_fake = util.tensor2ims(self.t_fake.data)
        tar_map = self.t_posemap.sum(1)
        tar_map[tar_map > 1] = 1
        tar_map = util.tensor2ims(torch.unsqueeze(tar_map, 1))
        src2tgt_fake = util.tensor2ims(self.st_fake.data)
        return OrderedDict([('src_input', src_input),
                            ('src_posemap', src_map),
                            ('src_fake', src_fake),
                            ('src_target', src_target),
                            ('tar_input', tar_input),
                            ('tar_posemap', tar_map),
                            ('tar_fake', tar_fake),
                            ('tar_target', tar_target),
                            ('src2tgt_fake', src2tgt_fake)])

    def save(self, epoch):
        """Checkpoint every network under save_dir, tagged with `epoch`."""
        self.save_network(self.net_E, 'E', epoch)
        self.save_network(self.net_G, 'G', epoch)
        self.save_network(self.net_Di, 'Di', epoch)
        self.save_network(self.net_Dp, 'Dp', epoch)
        self.save_network(self.tar_net_G, 'tar_G', epoch)
        self.save_network(self.tar_net_Di, 'tar_Di', epoch)

    def save_network(self, network, network_label, epoch_label):
        """Write a single network's state_dict to <save_dir>/<epoch>_net_<label>.pth."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.state_dict(), save_path)

    def update_learning_rate(self):
        """Step every LR scheduler once (call once per epoch)."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']

    def forward_test_cross(self):
        """Test-time forward: like forward_cross but also generates full-batch
        source-to-target (st_fake) and target-to-source (ts_fake) transfers.

        NOTE(review): the target section reuses the *source* pose map
        (B_map = Variable(self.s_posemap)) — presumably intentional for
        pose-transfer visualization; confirm.
        """
        #source
        A = Variable(self.s_origin)
        B_map = Variable(self.s_posemap)
        z = Variable(self.s_noise)
        bs = A.size(0)
        A_id1, A_id2, self.s_id_score = self.net_E(A[:bs // 2], A[bs // 2:])
        A_id = torch.cat((A_id1, A_id2))
        self.s_A_id = A_id
        self.s_fake = self.net_G(B_map,
                                 A_id.view(A_id.size(0), A_id.size(1), 1, 1),
                                 z.view(z.size(0), z.size(1), 1, 1))
        #source to target
        # A_id1_st = A_id1[:bs//]
        # z_st = z[:bs//]
        # self.st_fake = self.tar_net_G(B_map[:bs//], A_id1_st.view(A_id1_st.size(0), A_id1_st.size(1), 1, 1), z_st.view(z_st.size(0), z_st.size(1), 1, 1))
        self.st_fake = self.tar_net_G(
            B_map, A_id.view(A_id.size(0), A_id.size(1), 1, 1),
            z.view(z.size(0), z.size(1), 1, 1))
        #target
        A = Variable(self.t_origin)
        B_map = Variable(self.s_posemap)
        z = Variable(self.t_noise)
        bs = A.size(0)
        A_id1, A_id2, self.t_id_score = self.net_E(A[:bs // 2], A[bs // 2:])
        A_id = torch.cat((A_id1, A_id2))
        self.t_A_id = A_id
        self.t_fake = self.tar_net_G(
            B_map, A_id.view(A_id.size(0), A_id.size(1), 1, 1),
            z.view(z.size(0), z.size(1), 1, 1))
        self.ts_fake = self.net_G(B_map,
                                  A_id.view(A_id.size(0), A_id.size(1), 1, 1),
                                  z.view(z.size(0), z.size(1), 1, 1))

    def get_test_cross_visuals(self):
        """Visuals for forward_test_cross, including both transfer directions."""
        src_input = util.tensor2ims(self.s_origin)
        src_target = util.tensor2ims(self.s_target)
        src_fake = util.tensor2ims(self.s_fake.data)
        src_map = self.s_posemap.sum(1)
        src_map[src_map > 1] = 1
        src_map = util.tensor2ims(torch.unsqueeze(src_map, 1))
        tar_input = util.tensor2ims(self.t_origin)
        tar_target = util.tensor2ims(self.t_target)
        tar_fake = util.tensor2ims(self.t_fake.data)
        tar_map = self.t_posemap.sum(1)
        tar_map[tar_map > 1] = 1
        tar_map = util.tensor2ims(torch.unsqueeze(tar_map, 1))
        src2tgt_fake = util.tensor2ims(self.st_fake.data)
        tgt2src_fake = util.tensor2ims(self.ts_fake.data)
        return OrderedDict([
            ('src_input', src_input),
            ('src_posemap', src_map),
            ('src_fake', src_fake),
            ('src_target', src_target),
            ('tar_input', tar_input),
            ('tar_fake', tar_fake),
            ('src2tgt_fake', src2tgt_fake),
            ('tgt2src_fake', tgt2src_fake),
        ])
def _init_models(self):
    """Build generator G, FCN-embedding encoder E (ENet + Sub_model),
    image discriminator Di and pose discriminator Dp; initialize or load
    weights per ``opt.stage``; wrap everything in DataParallel.

    Raises:
        ValueError: if ``opt.stage`` is neither 1 nor 2.
    """
    self.net_G = CustomPoseGenerator(
        self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
        dropout=self.opt.drop, norm_layer=self.norm_layer,
        fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)

    # Encoder: backbone + FCN embedding head.
    # NOTE(review): num_classes=751 is hard-coded (Market-1501's identity
    # count) — confirm this variant is only used on that dataset.
    e_base_model = create(self.opt.arch, cut_at_pooling=True)
    self.e_embed_model = Sub_model(num_features=256, in_features=2048,
                                   num_classes=751, FCN=True, dropout=0.5)
    self.net_E = ENet(e_base_model, self.e_embed_model)

    # Image discriminator (siamese, single-logit head).
    di_base_model = create(self.opt.arch, cut_at_pooling=True)
    di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                     num_features=2048, num_classes=1)
    self.net_Di = SiameseNet(di_base_model, di_embed_model)

    # Pose discriminator: 3 RGB + 18 pose-map channels.
    self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)

    if self.opt.stage == 1:
        # Fresh G/Dp; E and Di are partially seeded from the E checkpoint,
        # keeping only keys that exist in each model's own state dict.
        init_weights(self.net_G)
        init_weights(self.net_Dp)
        checkpoint = torch.load(self.opt.netE_pretrain)
        model_dict = self.net_E.state_dict()
        checkpoint_load = {
            k: v
            for k, v in (checkpoint['state_dict']).items() if k in model_dict
        }
        model_dict.update(checkpoint_load)
        self.net_E.load_state_dict(model_dict)
        #state_dict
        model2_dict = self.net_Di.state_dict()
        checkpoint_load = {
            k: v
            for k, v in (checkpoint['state_dict']).items() if k in model2_dict
        }
        model2_dict.update(checkpoint_load)
        self.net_Di.load_state_dict(model2_dict)
    elif self.opt.stage == 2:
        # Stage 2 resumes all four networks from their own checkpoints.
        self._load_state_dict(self.net_E, self.opt.netE_pretrain)
        self._load_state_dict(self.net_G, self.opt.netG_pretrain)
        self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
        self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
    else:
        # BUG FIX: the original `assert ('unknown training stage')` asserted
        # a non-empty string, which is always true, so an invalid stage was
        # silently accepted. Fail loudly instead.
        raise ValueError('unknown training stage: %r' % (self.opt.stage,))

    # Wrap for multi-GPU execution and move to CUDA.
    self.net_E = torch.nn.DataParallel(self.net_E).cuda()
    self.net_G = torch.nn.DataParallel(self.net_G).cuda()
    self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
    self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()