Code example #1
0
File: model.py  Project: zqh0253/FD-GAN
    def _init_models(self):
        """Build the generator, encoder and both discriminators, load the
        stage-appropriate pretrained weights, and wrap every net in
        DataParallel on the GPU.

        Raises:
            ValueError: if ``self.opt.stage`` is neither 1 nor 2.
        """
        self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
                                dropout=self.opt.drop, norm_layer=self.norm_layer, fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
        # Identity encoder: Siamese verifier with a 2-way (same/different) head.
        e_base_model = create(self.opt.arch, cut_at_pooling=True)
        e_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True, num_features=2048, num_classes=2)
        self.net_E = SiameseNet(e_base_model, e_embed_model)

        # Identity discriminator: same architecture, single-logit head.
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True, num_features=2048, num_classes=1)
        self.net_Di = SiameseNet(di_base_model, di_embed_model)
        # 3 RGB channels + 18 pose-heatmap channels.
        self.net_Dp = NLayerDiscriminator(3+18, norm_layer=self.norm_layer)

        if self.opt.stage==1:
            init_weights(self.net_G)
            init_weights(self.net_Dp)
            state_dict = remove_module_key(torch.load(self.opt.netE_pretrain))
            self.net_E.load_state_dict(state_dict)
            # Reuse the encoder checkpoint for net_Di, whose classifier has
            # num_classes=1: slice with [1:2] to keep a 2-D (1, 2048) weight
            # ([1] alone would yield a 1-D (2048,) tensor that no longer
            # matches the Linear(2048, 1) parameter shape).
            state_dict['embed_model.classifier.weight'] = state_dict['embed_model.classifier.weight'][1:2]
            state_dict['embed_model.classifier.bias'] = torch.FloatTensor([state_dict['embed_model.classifier.bias'][1]])
            self.net_Di.load_state_dict(state_dict)
        elif self.opt.stage==2:
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        else:
            # The original `assert('unknown training stage')` was a no-op:
            # asserting a non-empty string always passes. Fail loudly instead.
            raise ValueError('unknown training stage')

        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()
Code example #2
0
 def load_model(self, model_arch='resnet50'):
     """Rebuild the Siamese verification model, restore its weights from
     ``self.model_path``, and return the CUDA/DataParallel base network
     (only the feature extractor is needed by the caller)."""
     # Backbone cut just before global pooling, plus a pairwise embedding head.
     base_model = models.create(model_arch, cut_at_pooling=True)
     embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                   use_classifier=True,
                                   num_features=2048,
                                   num_classes=2)
     siamese = nn.DataParallel(SiameseNet(base_model, embed_model)).cuda()
     # Checkpoints may be either a raw state dict or a wrapper dict.
     checkpoint = load_checkpoint(self.model_path)
     if 'state_dict' in checkpoint:
         checkpoint = checkpoint['state_dict']
     siamese.load_state_dict(checkpoint)
     return torch.nn.DataParallel(base_model).cuda()
Code example #3
0
    def _init_models(self):
        """Build generator G, encoder E, identity discriminator Di and pose
        discriminator Dp, then initialize or load weights according to
        ``self.opt.stage`` and wrap everything in DataParallel on the GPU.

        Raises:
            ValueError: for an unrecognized ``emb_type`` or ``stage``.
        """
        # Noise size is fixed to 0 here (third positional arg), unlike the
        # original FD-GAN generator which took opt.noise_feature_size.
        #self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
        self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, 0, pose_nc=self.pose_size, dropout=self.opt.drop,
                                         norm_layer=self.norm_layer, fuse_mode=self.opt.fuse_mode,
                                         connect_layers=self.opt.connect_layers)

        # Encoder flavor is selected by option; note 'pretraind' is the
        # (misspelled) keyword these project classes actually accept.
        if (self.opt.emb_type == 'Single'):
            self.net_E = SingleNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True, test_bn=False, last_stride=self.opt.last_stride)
        elif  (self.opt.emb_type == 'Siamese'):  
            self.net_E = SiameseNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True, test_bn=False, last_stride=self.opt.last_stride)
        else:
            raise ValueError('unrecognized model')

        # Identity discriminator: lightweight single-logit ResNet-18.
        self.net_Di = SingleNet('resnet18', 1, pretraind=True, use_bn=True, test_bn=False, last_stride=2)

        # Pose discriminator sees RGB (3) concatenated with pose heatmaps.
        self.net_Dp = NLayerDiscriminator(3+self.pose_size, norm_layer=self.norm_layer)

        if self.opt.stage==0: # This is for training end-to-end
            init_weights(self.net_G)
            init_weights(self.net_Dp)
        elif self.opt.stage==1: # This is for training fixing a baseline model
            init_weights(self.net_G)
            init_weights(self.net_Dp)
            checkpoint = load_checkpoint(self.opt.netE_pretrain)
            
            # Checkpoint may be a raw state dict or a wrapper dict; also strip
            # any 'module.' prefixes left by DataParallel saving.
            if 'state_dict' in checkpoint.keys():
                checkpoint = checkpoint['state_dict']
            state_dict = remove_module_key(checkpoint)

            self.net_E.load_state_dict(state_dict)
            #state_dict['classifier.weight'] = state_dict['classifier.weight'][1:2]
            #state_dict['classifier.bias'] = torch.FloatTensor([state_dict['classifier.bias'][1]])
            #self.net_Di.load_state_dict(state_dict)
        elif self.opt.stage==2: # This is for training with a provided model
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        else:
            raise ValueError('unrecognized mode')

        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()
Code example #4
0
File: baseline.py  Project: zhudi512/PDA-Net
def main(args):
    """Train or evaluate a Siamese re-ID baseline.

    Seeds the RNGs, builds data loaders and the Siamese model, optionally
    restores a checkpoint, then either evaluates once (``args.evaluate``) or
    runs the SGD training loop with step-decayed learning rate, saving the
    best-mAP checkpoint along the way.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
    # print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (256, 128)
    dataset, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers,
                 args.combine_trainval, args.np_ratio)

    # Create model: backbone cut at pooling + pairwise verification head.
    base_model = models.create(args.arch, cut_at_pooling=True)
    embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                      num_features=2048, num_classes=2)
    model = SiameseNet(base_model, embed_model)
    model = nn.DataParallel(model).cuda()

    # Evaluator: ranks by softmax probability of the "same identity" class.
    evaluator = CascadeEvaluator(
        torch.nn.DataParallel(base_model).cuda(),
        embed_model,
        embed_dist_fn=lambda x: F.softmax(Variable(x), dim=1).data[:, 0])

    # Load from checkpoint
    best_mAP = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        # Checkpoint may be a raw state dict or a wrapper dict.
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)

        print("Test the loaded model:")
        # NOTE(review): evaluate() unpacks to (top1, mAP) here but is used as
        # a single mAP value below with top1=False — confirm its return
        # signature in CascadeEvaluator.
        top1, mAP = evaluator.evaluate(test_loader, dataset.query, dataset.gallery, rerank_topk=100, dataset=args.dataset)
        best_mAP = mAP

    if args.evaluate:
        return

    # Criterion
    criterion = nn.CrossEntropyLoss().cuda()
    # Optimizer: both param groups use the same base LR (lr_mult 1.0).
    param_groups = [
        {'params': model.module.base_model.parameters(), 'lr_mult': 1.0},
        {'params': model.module.embed_model.parameters(), 'lr_mult': 1.0}]
    optimizer = torch.optim.SGD(param_groups, args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Trainer
    trainer = SiameseTrainer(model, criterion)

    # Schedule learning rate: decay by 10x every args.step_size epochs.
    def adjust_lr(epoch):
        lr = args.lr * (0.1 ** (epoch // args.step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(0, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer, base_lr=args.lr)

        # Periodic validation; keep the checkpoint with the best val mAP.
        if epoch % args.eval_step==0:
            mAP = evaluator.evaluate(val_loader, dataset.val, dataset.val, top1=False)
            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            save_checkpoint({
                'state_dict': model.state_dict()
            }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

            print('\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.
                  format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    # Final test with the best validation checkpoint.
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, dataset=args.dataset)
Code example #5
0
class ST_ReIDNet(object):
    """FD-GAN-style person re-identification network.

    Holds four sub-networks — identity encoder E, pose-conditioned generator
    G, identity discriminator Di and pose discriminator Dp — plus their
    losses and optimizers, and implements one GAN + re-ID training step over
    (anchor, positive, negative) triplets.
    """

    def __init__(self, opt, emb_size):
        self.opt = opt
        self.save_dir = os.path.join(opt.checkpoints, opt.name)
        self.norm_layer = get_norm_layer(norm_type=opt.norm)
        self.emb_size = emb_size

        # DanceReID annotations use 17 keypoints; other datasets use 18.
        if (self.opt.dataset == 'DanceReID'):
            self.pose_size = 17
        else:
            self.pose_size = 18

        self._init_models()
        self._init_losses()
        self._init_optimizers()

        print('---------- Networks initialized -------------')

    def _init_models(self):
        """Build G, E, Di and Dp, then initialize or load weights according
        to ``self.opt.stage`` and wrap each net in DataParallel on the GPU.

        Raises:
            ValueError: for an unrecognized ``emb_type`` or ``stage``.
        """
        # Noise size fixed to 0 (third positional arg) for this variant.
        self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, 0, pose_nc=self.pose_size, dropout=self.opt.drop,
                                         norm_layer=self.norm_layer, fuse_mode=self.opt.fuse_mode,
                                         connect_layers=self.opt.connect_layers)

        # 'pretraind' is the (misspelled) keyword these project classes accept.
        if (self.opt.emb_type == 'Single'):
            self.net_E = SingleNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True, test_bn=False, last_stride=self.opt.last_stride)
        elif (self.opt.emb_type == 'Siamese'):
            self.net_E = SiameseNet(self.opt.arch, self.emb_size, pretraind=True, use_bn=True, test_bn=False, last_stride=self.opt.last_stride)
        else:
            raise ValueError('unrecognized model')

        # Identity discriminator: lightweight single-logit ResNet-18.
        self.net_Di = SingleNet('resnet18', 1, pretraind=True, use_bn=True, test_bn=False, last_stride=2)

        # Pose discriminator sees RGB (3) concatenated with pose heatmaps.
        self.net_Dp = NLayerDiscriminator(3+self.pose_size, norm_layer=self.norm_layer)

        if self.opt.stage==0: # training end-to-end
            init_weights(self.net_G)
            init_weights(self.net_Dp)
        elif self.opt.stage==1: # training against a fixed baseline encoder
            init_weights(self.net_G)
            init_weights(self.net_Dp)
            checkpoint = load_checkpoint(self.opt.netE_pretrain)

            # Checkpoint may be a raw state dict or a wrapper dict; strip any
            # 'module.' prefixes left by DataParallel saving.
            if 'state_dict' in checkpoint.keys():
                checkpoint = checkpoint['state_dict']
            state_dict = remove_module_key(checkpoint)

            self.net_E.load_state_dict(state_dict)
        elif self.opt.stage==2: # training with a provided model per net
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        else:
            raise ValueError('unrecognized mode')

        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()

    def reset_model_status(self):
        """Set each net's train/eval mode for the current stage; BatchNorm
        layers of the frozen/pretrained nets are pinned via set_bn_fix."""
        if self.opt.stage==0:
            self.net_E.train()
            self.net_G.train()
            self.net_Di.train()
            self.net_Dp.train()
            self.net_Di.apply(set_bn_fix)
        elif self.opt.stage==1:
            self.net_G.train()
            self.net_Dp.train()
            self.net_E.eval()
            self.net_Di.train()
            self.net_Di.apply(set_bn_fix)
        elif self.opt.stage==2:
            self.net_E.train()
            self.net_G.train()
            self.net_Di.train()
            self.net_Dp.train()
            self.net_E.apply(set_bn_fix)
            self.net_Di.apply(set_bn_fix)

    def _load_state_dict(self, net, path):
        """Load weights from ``path`` into ``net``, stripping DataParallel's
        'module.' prefixes first."""
        state_dict = remove_module_key(torch.load(path))
        net.load_state_dict(state_dict)

    def _init_losses(self):
        """Create the GAN, triplet, classification and reconstruction losses."""
        if self.opt.smooth_label:
            self.criterionGAN_D = GANLoss(smooth=True).cuda()
            # Rarely (1 in 10001) flip real/fake labels for D as regularization.
            self.rand_list = [True] * 1 + [False] * 10000
        else:
            self.criterionGAN_D = GANLoss(smooth=False).cuda()
            self.rand_list = [False]
        self.criterionGAN_G = GANLoss(smooth=False).cuda()

        if self.opt.soft_margin:
            self.tri_criterion = TripletLoss(margin='soft', batch_hard=True, distractor=True).cuda()
        else:
            # BUG FIX: the original ended with `.cuda` (no call), assigning the
            # bound method instead of the CUDA-moved loss module; calling
            # self.tri_criterion(...) later would then invoke .cuda() with
            # bogus arguments.
            self.tri_criterion = TripletLoss(margin=self.opt.margin, batch_hard=True, distractor=True).cuda()

        if (self.opt.emb_type == 'Single'):
            if (self.opt.emb_smooth):
                self.class_criterion = CrossEntropyLabelSmooth(self.emb_size, epsilon=0.1).cuda()
            else:
                self.class_criterion = torch.nn.CrossEntropyLoss().cuda()
        elif (self.opt.emb_type == 'Siamese'):
            self.class_criterion = torch.nn.CrossEntropyLoss().cuda()

        if self.opt.mask:
            self.reco_criterion = MaskedL1loss(use_mask=True).cuda()
        else:
            self.reco_criterion = MaskedL1loss(use_mask=False).cuda()

    def _init_optimizers(self):
        """Create per-net optimizers (stage-specific LRs) and LR schedulers."""
        if self.opt.stage==0:
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.1},
                            {'params': self.net_E.module.classifier.parameters(), 'lr_mult': 1.0}]
            self.optimizer_E = torch.optim.SGD(param_groups, lr=self.opt.lr, momentum=0.9, weight_decay=5e-4)
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=1e-5, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=4e-5, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=4e-5, momentum=0.9, weight_decay=1e-4)
        elif self.opt.stage==1:
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.1},
                            {'params': self.net_E.module.classifier.parameters(), 'lr_mult': 0.1}]

            self.optimizer_E = torch.optim.SGD(param_groups, lr=self.opt.lr, momentum=0.9, weight_decay=5e-4)

            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=self.opt.lr*0.1, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
        elif self.opt.stage==2:
            # Fine-tuning stage: much smaller multipliers / learning rates.
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.01},
                            {'params': self.net_E.module.classifier.parameters(), 'lr_mult': 0.1}]

            self.optimizer_E = torch.optim.SGD(param_groups, lr=self.opt.lr, momentum=0.9, weight_decay=5e-4)

            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=1e-6, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=1e-5, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=1e-5, momentum=0.9, weight_decay=1e-4)

        self.schedulers = []
        self.optimizers = []
        self.optimizers.append(self.optimizer_E)
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_Di)
        self.optimizers.append(self.optimizer_Dp)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, self.opt))

    def set_input(self, input):
        """Unpack an (anchor, positive, negative) triplet batch onto the GPU.

        Stores ground-truth images, masked inputs, pose maps, masks, identity
        ids, pair labels (1 for ach/pos, 0 for ach/neg) and a noise buffer as
        attributes used by forward()/backward_*().
        """
        input_ach, input_pos, input_neg = input
        ids = torch.cat([input_ach['pid'], input_pos['pid'], input_neg['pid']]).long()

        ones = torch.ones_like(input_ach['pid']).fill_(1.0)
        zeros = torch.ones_like(input_ach['pid']).fill_(0.0)

        labels = torch.cat([ones, zeros]).long()

        # Same noise replicated three times (one copy per triplet member).
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)
        noise = torch.cat((noise, noise, noise))

        self.ach_gt = input_ach['origin'].cuda()
        self.pos_gt = input_pos['origin'].cuda()
        self.neg_gt = input_neg['origin'].cuda()

        self.ach_img = input_ach['input'].cuda()
        self.pos_img = input_pos['input'].cuda()
        self.neg_img = input_neg['input'].cuda()

        self.ach_pose = input_ach['posemap'].cuda()
        self.pos_pose = input_pos['posemap'].cuda()
        self.neg_pose = input_neg['posemap'].cuda()

        self.ach_mask = input_ach['mask'].cuda()
        self.pos_mask = input_pos['mask'].cuda()
        self.neg_mask = input_neg['mask'].cuda()

        # Concatenated views used by the discriminators / reconstruction loss.
        self.gt = torch.cat([self.ach_gt, self.pos_gt, self.neg_gt])
        self.mask = torch.cat([self.ach_mask, self.pos_mask, self.neg_mask])

        self.posemap = torch.cat([self.ach_pose, self.pos_pose, self.neg_pose])

        self.two_gt = torch.cat([self.ach_gt, self.pos_gt])
        self.two_mask = torch.cat([self.ach_mask, self.pos_mask])

        # Pose pairs for identity-swapped generation (ach<->pos).
        self.pos_posemap = torch.cat([self.ach_pose, self.pos_pose])
        self.swap_posemap = torch.cat([self.pos_pose, self.ach_pose])

        self.ids = ids.cuda()
        self.labels = labels.cuda()
        self.noise = noise.cuda()

    def forward(self):
        """Encode the triplet, then generate self- and swap-reconstructions.

        Produces A_* appearance features, optional id predictions, fake
        images per pose, and identity-swapped fakes for ach/pos.
        """
        # NOTE: self.noise is not consumed here; the generator is called with
        # noise=None (the unused `z = Variable(self.noise)` was removed).
        if (self.opt.emb_type == 'Single'):
            if (self.opt.stage == 1):
                # Frozen-encoder stage: no classifier output.
                self.A_ach = self.net_E(Variable(self.ach_img))
                self.A_pos = self.net_E(Variable(self.pos_img))
                self.A_neg = self.net_E(Variable(self.neg_img))
            else:
                self.A_ach, pred_ach = self.net_E(Variable(self.ach_img))
                self.A_pos, pred_pos = self.net_E(Variable(self.pos_img))
                self.A_neg, pred_neg = self.net_E(Variable(self.neg_img))
                self.id_pred = torch.cat([pred_ach, pred_pos, pred_neg])

        elif (self.opt.emb_type == 'Siamese'):
            self.A_ach, self.A_pos, pos_pred = self.net_E(torch.cat([Variable(self.ach_img), Variable(self.pos_img)]))
            _, self.A_neg, neg_pred = self.net_E(torch.cat([Variable(self.ach_img), Variable(self.neg_img)]))
            self.id_pred = torch.cat([pos_pred, neg_pred])

        # Reconstruct each image from its own pose + its own appearance.
        self.fake_ach = self.net_G(Variable(self.ach_pose),
                 self.A_ach.view(self.A_ach.size(0), self.A_ach.size(1), 1, 1), None)
        self.fake_pos = self.net_G(Variable(self.pos_pose),
                 self.A_pos.view(self.A_pos.size(0), self.A_pos.size(1), 1, 1), None)
        self.fake_neg = self.net_G(Variable(self.neg_pose),
                 self.A_neg.view(self.A_neg.size(0), self.A_neg.size(1), 1, 1), None)

        # Swap: anchor pose with positive appearance and vice versa.
        self.swap_ach = self.net_G(Variable(self.ach_pose),
                 self.A_pos.view(self.A_pos.size(0), self.A_pos.size(1), 1, 1), None)
        self.swap_pos = self.net_G(Variable(self.pos_pose),
                 self.A_ach.view(self.A_ach.size(0), self.A_ach.size(1), 1, 1), None)

        self.fake = torch.cat((self.fake_ach, self.fake_pos, self.fake_neg))

        self.swap_fake = torch.cat((self.swap_ach, self.swap_pos))

    def backward_Dp(self):
        """Discriminator step for the pose discriminator Dp."""
        real_pose = torch.cat((Variable(self.posemap), Variable(self.gt)), dim=1)
        fake_pose1 = torch.cat((Variable(self.posemap), self.fake.detach()), dim=1)
        fake_pose2 = torch.cat((Variable(self.pos_posemap), self.swap_fake.detach()), dim=1)
        # Real images under the *swapped* pose map also count as fakes.
        fake_pose3 = torch.cat((Variable(self.swap_posemap), Variable(self.two_gt)), dim=1)

        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(torch.cat([fake_pose1, fake_pose2, fake_pose3]))
        # Occasionally flip real/fake labels (see _init_losses rand_list).
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Dp = loss_D.item()

    def backward_Di(self):
        """Discriminator step for the identity discriminator Di."""
        _, pred_real = self.net_Di(Variable(self.gt))
        _, pred_fake1 = self.net_Di(self.fake.detach())
        _, pred_fake2 = self.net_Di(self.swap_fake.detach())

        pred_fake = torch.cat([pred_fake1, pred_fake2])

        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Di = loss_D.item()

    def backward_G(self):
        """Generator + encoder step: reconstruction, adversarial, triplet and
        classification losses; stores each scalar for get_current_errors()."""
        loss_r1 = self.reco_criterion(self.fake, Variable(self.mask), Variable(self.gt))
        loss_r2 = self.reco_criterion(self.swap_fake, Variable(self.two_mask), Variable(self.two_gt))

        # Weight swap-reconstruction higher than self-reconstruction.
        loss_r = 0.7 * loss_r2 + 0.3 * loss_r1

        _, pred_fake_Di = self.net_Di(torch.cat([self.fake, self.swap_fake]))

        fake_pose1 = torch.cat((Variable(self.posemap), self.fake), dim=1)
        fake_pose2 = torch.cat((Variable(self.pos_posemap), self.swap_fake), dim=1)
        fake_pose3 = torch.cat((Variable(self.swap_posemap), Variable(self.two_gt)), dim=1)

        pred_fake_Dp = self.net_Dp(torch.cat([fake_pose1, fake_pose2, fake_pose3]))

        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)

        loss_G = loss_G_GAN_Di * self.opt.lambda_d + \
                 loss_G_GAN_Dp * self.opt.lambda_dp  + \
                 loss_r * self.opt.lambda_recon

        # Retain the graph: loss_E below backprops through the same features.
        loss_G.backward(retain_graph=True)

        # Compute triplet loss
        feat = torch.cat([self.A_ach, self.A_pos, self.A_neg])
        loss_t, _ , _ = self.tri_criterion(feat, torch.squeeze(self.ids, 1))
        # Classification loss (none in stage 1: encoder is frozen).
        if (self.opt.stage==1):
            loss_c = torch.tensor(0.0)
        else:
            if (self.opt.emb_type == 'Single'):
                loss_c = self.class_criterion(self.id_pred, torch.squeeze(self.ids, 1))
            elif (self.opt.emb_type == 'Siamese'):
                loss_c = self.class_criterion(self.id_pred, torch.squeeze(self.labels, 1))

        loss_E = loss_G + \
                loss_t * self.opt.lambda_tri+ \
                loss_c * self.opt.lambda_class

        loss_E.backward()

        self.loss_E = loss_E.item()
        self.loss_t = loss_t.item()
        self.loss_c = loss_c.item()

        self.loss_G = loss_G.item()
        self.loss_r = loss_r.item()
        self.loss_G_GAN_Di = loss_G_GAN_Di.item()
        self.loss_G_GAN_Dp = loss_G_GAN_Dp.item()

    def optimize_parameters(self):
        """One full training step: forward, then Di, Dp, and G/E updates."""
        self.forward()

        self.optimizer_Di.zero_grad()
        self.backward_Di()
        self.optimizer_Di.step()

        self.optimizer_Dp.zero_grad()
        self.backward_Dp()
        self.optimizer_Dp.step()

        self.optimizer_G.zero_grad()
        self.optimizer_E.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # Stage 1 keeps the encoder frozen.
        if self.opt.stage!=1:
            self.optimizer_E.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed for logging."""
        return OrderedDict([('E_t', self.loss_t),
                            ('E_c', self.loss_c),
                            ('E', self.loss_E),
                            ('G_r', self.loss_r),
                            ('G_gan_Di', self.loss_G_GAN_Di),
                            ('G_gan_Dp', self.loss_G_GAN_Dp),
                            ('G', self.loss_G),
                            ('D_i', self.loss_Di),
                            ('D_p', self.loss_Dp)
                            ])

    def get_current_visuals(self):
        """Return displayable images (input, pose, mask, both fakes) for the
        anchor and positive samples."""
        ach = util.tensor2im(self.ach_img)

        # Collapse pose heatmaps to a single binary map for visualization.
        ach_map = self.ach_pose.sum(1)
        ach_map[ach_map>1]=1
        ach_pose = util.tensor2im(torch.unsqueeze(ach_map,1))
        ach_mask = util.tensor2im(torch.unsqueeze(self.ach_mask,1))
        ach_fake1 = util.tensor2im(self.fake_ach.data)
        ach_fake2 = util.tensor2im(self.swap_ach.data)

        pos = util.tensor2im(self.pos_img)

        pos_map = self.pos_pose.sum(1)
        pos_map[pos_map>1]=1
        pos_pose = util.tensor2im(torch.unsqueeze(pos_map,1))
        pos_mask = util.tensor2im(torch.unsqueeze(self.pos_mask,1))

        pos_fake1 = util.tensor2im(self.fake_pos.data)
        pos_fake2 = util.tensor2im(self.swap_pos.data)

        return OrderedDict([('ach', ach), ('ach_pose', ach_pose), ('ach_mask', ach_mask), ('ach_fake1', ach_fake1), ('ach_fake2', ach_fake2),
                            ('pos', pos), ('pos_pose', pos_pose), ('pos_mask', pos_mask), ('pos_fake1', pos_fake1), ('pos_fake2', pos_fake2)])

    def save(self, epoch):
        """Save all four sub-networks for the given epoch label."""
        self.save_network(self.net_E, 'E', epoch)
        self.save_network(self.net_G, 'G', epoch)
        self.save_network(self.net_Di, 'Di', epoch)
        self.save_network(self.net_Dp, 'Dp', epoch)

    def save_network(self, network, network_label, epoch_label):
        """Write one network's state dict to '<epoch>_net_<label>.pth'."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.state_dict(), save_path)

    def update_learning_rate(self):
        """Step every LR scheduler once (call at the end of each epoch)."""
        for scheduler in self.schedulers:
            scheduler.step()
Code example #6
0
File: model.py  Project: zqh0253/FD-GAN
class FDGANModel(object):
    """FD-GAN training wrapper.

    Bundles the siamese identity encoder E, the pose-conditioned
    generator G, the identity discriminator Di and the pose
    discriminator Dp, together with their losses, optimizers,
    LR schedulers and checkpoint I/O.
    """

    def __init__(self, opt):
        self.opt = opt
        # Checkpoints are written under <checkpoints>/<experiment name>.
        self.save_dir = os.path.join(opt.checkpoints, opt.name)
        self.norm_layer = get_norm_layer(norm_type=opt.norm)

        self._init_models()
        self._init_losses()
        self._init_optimizers()

        print('---------- Networks initialized -------------')
        print_network(self.net_E)
        print_network(self.net_G)
        print_network(self.net_Di)
        print_network(self.net_Dp)
        print('-----------------------------------------------')

    def _init_models(self):
        """Construct E/G/Di/Dp, load pretrained weights per stage, and
        wrap everything in DataParallel on the GPU.

        Raises:
            ValueError: if ``opt.stage`` is neither 1 nor 2.
        """
        self.net_G = CustomPoseGenerator(self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
                                dropout=self.opt.drop, norm_layer=self.norm_layer, fuse_mode=self.opt.fuse_mode, connect_layers=self.opt.connect_layers)
        e_base_model = create(self.opt.arch, cut_at_pooling=True)
        e_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True, num_features=2048, num_classes=2)
        self.net_E = SiameseNet(e_base_model, e_embed_model)

        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True, num_features=2048, num_classes=1)
        self.net_Di = SiameseNet(di_base_model, di_embed_model)
        # Dp sees the 3-channel image concatenated with the 18-channel pose map.
        self.net_Dp = NLayerDiscriminator(3+18, norm_layer=self.norm_layer)

        if self.opt.stage==1:
            init_weights(self.net_G)
            init_weights(self.net_Dp)
            state_dict = remove_module_key(torch.load(self.opt.netE_pretrain))
            self.net_E.load_state_dict(state_dict)
            # Reuse E's pretrained weights for Di, slicing its 2-way
            # classifier down to the single row/bias at index 1.
            state_dict['embed_model.classifier.weight'] = state_dict['embed_model.classifier.weight'][1]
            state_dict['embed_model.classifier.bias'] = torch.FloatTensor([state_dict['embed_model.classifier.bias'][1]])
            self.net_Di.load_state_dict(state_dict)
        elif self.opt.stage==2:
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        else:
            # Was `assert('unknown training stage')`, which can never fire
            # because a non-empty string is truthy; raise instead.
            raise ValueError('unknown training stage')

        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()

    def reset_model_status(self):
        """Set train/eval flags for the current stage.

        Di's BatchNorm layers (and E's in stage 2) are frozen with
        ``set_bn_fix``; in stage 1 the encoder E stays in eval mode.
        """
        if self.opt.stage==1:
            self.net_G.train()
            self.net_Dp.train()
            self.net_E.eval()
            self.net_Di.train()
            self.net_Di.apply(set_bn_fix)
        elif self.opt.stage==2:
            self.net_E.train()
            self.net_G.train()
            self.net_Di.train()
            self.net_Dp.train()
            self.net_E.apply(set_bn_fix)
            self.net_Di.apply(set_bn_fix)

    def _load_state_dict(self, net, path):
        """Load a checkpoint into *net*, stripping DataParallel's
        ``module.`` key prefixes first."""
        state_dict = remove_module_key(torch.load(path))
        net.load_state_dict(state_dict)

    def _init_losses(self):
        """Create the GAN criteria.

        With smooth labels enabled, ``rand_list`` makes roughly 1 in
        10001 discriminator steps see flipped real/fake labels (see the
        ``random.choice(self.rand_list)`` calls in the backward passes).
        """
        if self.opt.smooth_label:
            self.criterionGAN_D = GANLoss(smooth=True).cuda()
            self.rand_list = [True] * 1 + [False] * 10000
        else:
            self.criterionGAN_D = GANLoss(smooth=False).cuda()
            self.rand_list = [False]
        self.criterionGAN_G = GANLoss(smooth=False).cuda()

    def _init_optimizers(self):
        """Adam for G (plus E's parts in stage 2), SGD for Di/Dp; one LR
        scheduler per optimizer via ``get_scheduler``."""
        if self.opt.stage==1:
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=self.opt.lr*0.1, betas=(0.5, 0.999))
            # NOTE(review): Di uses a 100x smaller LR in stage 1 —
            # presumably because it starts from E's pretrained weights.
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr*0.01, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
        elif self.opt.stage==2:
            param_groups = [{'params': self.net_E.module.base_model.parameters(), 'lr_mult': 0.1},
                            {'params': self.net_E.module.embed_model.parameters(), 'lr_mult': 1.0},
                            {'params': self.net_G.parameters(), 'lr_mult': 0.1}]
            self.optimizer_G = torch.optim.Adam(param_groups,
                                                lr=self.opt.lr*0.1, betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr, momentum=0.9, weight_decay=1e-4)

        self.schedulers = []
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_Di)
        self.optimizers.append(self.optimizer_Dp)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, self.opt))

    def set_input(self, input):
        """Unpack an (input1, input2) batch pair and move it to the GPU.

        ``labels[i]`` is 1 when the i-th samples share an identity; for
        such pairs input2's pose map and target are overwritten with
        input1's, so both halves generate the same person in the same pose.
        """
        input1, input2 = input
        labels = (input1['pid']==input2['pid']).long()
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)

        # keep the same pose map for persons with the same identity
        mask = labels.view(-1,1,1,1).expand_as(input1['posemap'])
        input2['posemap'] = input1['posemap']*mask.float() + input2['posemap']*(1-mask.float())
        mask = labels.view(-1,1,1,1).expand_as(input1['target'])
        input2['target'] = input1['target']*mask.float() + input2['target']*(1-mask.float())

        origin = torch.cat([input1['origin'], input2['origin']])
        target = torch.cat([input1['target'], input2['target']])
        posemap = torch.cat([input1['posemap'], input2['posemap']])
        # The same noise vector is fed to both halves of the batch.
        noise = torch.cat((noise, noise))

        self.origin = origin.cuda()
        self.target = target.cuda()
        self.posemap = posemap.cuda()
        self.labels = labels.cuda()
        self.noise = noise.cuda()

    def forward(self):
        """Encode identities with E, then generate fakes with G."""
        A = Variable(self.origin)
        B_map = Variable(self.posemap)
        z = Variable(self.noise)
        bs = A.size(0)

        # E is siamese: it embeds both batch halves and scores same/diff id.
        A_id1, A_id2, self.id_score = self.net_E(A[:bs//2], A[bs//2:])
        A_id = torch.cat((A_id1, A_id2))
        self.fake = self.net_G(B_map, A_id.view(A_id.size(0), A_id.size(1), 1, 1), z.view(z.size(0), z.size(1), 1, 1))

    def backward_Dp(self):
        """Pose-discriminator step on (posemap, image) pairs."""
        real_pose = torch.cat((Variable(self.posemap), Variable(self.target)),dim=1)
        fake_pose = torch.cat((Variable(self.posemap), self.fake.detach()),dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(fake_pose)

        # Occasionally flip real/fake labels (only when smooth_label is on).
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Dp = loss_D.data[0]

    def backward_Di(self):
        """Identity-discriminator step on (origin, image) pairs."""
        _, _, pred_real = self.net_Di(Variable(self.origin), Variable(self.target))
        _, _, pred_fake = self.net_Di(Variable(self.origin), self.fake.detach())
        # Occasionally flip real/fake labels (only when smooth_label is on).
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        self.loss_Di = loss_D.data[0]

    def backward_G(self):
        """Generator step: verification + reconstruction + same-pair +
        adversarial (Di and Dp) losses, weighted by opt.lambda_*."""
        loss_v = F.cross_entropy(self.id_score, Variable(self.labels).view(-1))
        loss_r = F.l1_loss(self.fake, Variable(self.target))
        fake_1 = self.fake[:self.fake.size(0)//2]
        fake_2 = self.fake[self.fake.size(0)//2:]
        # Same-pose consistency: halves with matching identity should agree.
        loss_sp = F.l1_loss(fake_1[self.labels.view(self.labels.size(0),1,1,1).expand_as(fake_1)==1], 
                            fake_2[self.labels.view(self.labels.size(0),1,1,1).expand_as(fake_1)==1])

        _, _, pred_fake_Di = self.net_Di(Variable(self.origin), self.fake)
        pred_fake_Dp = self.net_Dp(torch.cat((Variable(self.posemap),self.fake),dim=1))
        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)

        loss_G = loss_G_GAN_Di + loss_G_GAN_Dp + \
                loss_r * self.opt.lambda_recon + \
                loss_v * self.opt.lambda_veri + \
                loss_sp * self.opt.lambda_sp
        loss_G.backward()

        del self.id_score
        self.loss_G = loss_G.data[0]
        self.loss_v = loss_v.data[0]
        self.loss_sp = loss_sp.data[0]
        self.loss_r = loss_r.data[0]
        self.loss_G_GAN_Di = loss_G_GAN_Di.data[0]
        self.loss_G_GAN_Dp = loss_G_GAN_Dp.data[0]
        self.fake = self.fake.data

    def optimize_parameters(self):
        """One training iteration: forward, then Di, Dp and G updates."""
        self.forward()

        self.optimizer_Di.zero_grad()
        self.backward_Di()
        self.optimizer_Di.step()

        self.optimizer_Dp.zero_grad()
        self.backward_Dp()
        self.optimizer_Dp.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the scalar losses of the last iteration for logging."""
        return OrderedDict([('G_v', self.loss_v),
                            ('G_r', self.loss_r),
                            ('G_sp', self.loss_sp),
                            ('G_gan_Di', self.loss_G_GAN_Di),
                            ('G_gan_Dp', self.loss_G_GAN_Dp),
                            ('D_i', self.loss_Di),
                            ('D_p', self.loss_Dp)
                            ])

    def get_current_visuals(self):
        """Return input / pose map / fake / target as displayable images."""
        input = util.tensor2im(self.origin)
        target = util.tensor2im(self.target)
        fake = util.tensor2im(self.fake)
        # Collapse pose channels and clamp overlaps; `heat` avoids
        # shadowing the builtin `map`.
        heat = self.posemap.sum(1)
        heat[heat>1]=1
        heat = util.tensor2im(torch.unsqueeze(heat,1))
        return OrderedDict([('input', input), ('posemap', heat), ('fake', fake), ('target', target)])

    def save(self, epoch):
        """Checkpoint all four sub-networks for the given epoch."""
        self.save_network(self.net_E, 'E', epoch)
        self.save_network(self.net_G, 'G', epoch)
        self.save_network(self.net_Di, 'Di', epoch)
        self.save_network(self.net_Dp, 'Dp', epoch)

    def save_network(self, network, network_label, epoch_label):
        """Serialize a network's state dict to <save_dir>/<epoch>_net_<label>.pth."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.state_dict(), save_path)

    def update_learning_rate(self):
        """Advance every LR scheduler by one step (the original also read
        the current LR into an unused local; removed)."""
        for scheduler in self.schedulers:
            scheduler.step()
コード例 #7
0
ファイル: baseline.py プロジェクト: azuxmioy/ST-ReIDNet
def main(args):
    """Train (or only evaluate) a Single/Siamese re-id baseline.

    Builds loaders and the model from *args*, optionally resumes from a
    checkpoint, trains with a step LR schedule, checkpoints the best mAP,
    and finally evaluates the best model on the test set.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (256, 128)
    dataset, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers,
                 args.combine_trainval, args.np_ratio,
                 args.emb_type, args.inst_mode, args.eraser)

    # Number of identity classes the classifier head must cover.
    if args.combine_trainval:
        emb_size = dataset.num_trainval_ids
    else:
        emb_size = dataset.num_train_ids

    # Create model ("pretraind" [sic] is the keyword name in the project API)
    if args.emb_type == 'Single':
        model = SingleNet(args.arch,
                          emb_size,
                          pretraind=True,
                          use_bn=args.use_bn,
                          test_bn=args.test_bn,
                          last_stride=args.last_stride)
    elif args.emb_type == 'Siamese':
        model = SiameseNet(args.arch,
                           emb_size,
                           pretraind=True,
                           use_bn=args.use_bn,
                           test_bn=args.test_bn,
                           last_stride=args.last_stride)
    else:
        raise ValueError('unrecognized model')
    model = nn.DataParallel(model).cuda()

    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)

    # Evaluator.  `model` is already a cuda DataParallel (see above);
    # the original re-wrapped it in a second DataParallel, nesting them.
    evaluator = CascadeEvaluator(model, emb_size=emb_size)

    # Load from checkpoint
    best_mAP = 0
    if args.resume:
        print("Test the loaded model:")
        top1, mAP = evaluator.evaluate(test_loader,
                                       dataset.query,
                                       dataset.gallery,
                                       dataset=args.dataset)
        best_mAP = mAP
    if args.evaluate:
        return

    # Criterion
    if args.soft_margin:
        tri_criterion = TripletLoss(margin='soft').cuda()
    else:
        tri_criterion = TripletLoss(margin=args.margin).cuda()

    if args.emb_type == 'Single':
        if args.label_smoothing:
            cla_criterion = CrossEntropyLabelSmooth(emb_size,
                                                    epsilon=0.1).cuda()
        else:
            cla_criterion = torch.nn.CrossEntropyLoss().cuda()
    elif args.emb_type == 'Siamese':
        cla_criterion = torch.nn.CrossEntropyLoss().cuda()

    # Optimizer: pretrained backbone gets a 10x smaller LR than the head.
    param_groups = [{
        'params': model.module.base_model.parameters(),
        'lr_mult': 0.1
    }, {
        'params': model.module.classifier.parameters(),
        'lr_mult': 1.0
    }]

    if args.opt_name == 'SGD':
        optimizer = getattr(torch.optim,
                            args.opt_name)(param_groups,
                                           lr=args.lr,
                                           weight_decay=args.weight_decay,
                                           momentum=args.momentum)
    else:
        optimizer = getattr(torch.optim,
                            args.opt_name)(param_groups,
                                           lr=args.lr,
                                           weight_decay=args.weight_decay)

    # Trainer
    if args.emb_type == 'Single':
        trainer = TripletTrainer(model, tri_criterion, cla_criterion,
                                 args.lambda_tri, args.lambda_cla)
    elif args.emb_type == 'Siamese':
        trainer = SiameseTrainer(model, tri_criterion, cla_criterion,
                                 args.lambda_tri, args.lambda_cla)

    #TODO:Warmup lr
    # Schedule learning rate: decay 10x every `step_size` epochs,
    # scaled per group by its lr_mult.
    def adjust_lr(epoch):
        lr = args.lr * (0.1**(epoch // args.step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(0, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer, base_lr=args.lr)

        if epoch % args.eval_step == 0:
            mAP = evaluator.evaluate(test_loader,
                                     dataset.query,
                                     dataset.gallery,
                                     top1=False,
                                     dataset=args.dataset)
            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            save_checkpoint({'state_dict': model.state_dict()},
                            is_best,
                            fpath=osp.join(args.logs_dir,
                                           'checkpoint.pth.tar'))

            print('\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.
                  format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    evaluator.evaluate(test_loader,
                       dataset.query,
                       dataset.gallery,
                       dataset=args.dataset)
コード例 #8
0
    def _init_models(self):
        """Build source/target generators and identity discriminators,
        the shared encoder E and pose discriminator Dp; load pretrained
        weights and wrap in DataParallel.

        Raises:
            ValueError: if ``opt.stage`` is neither 1 nor 2.
        """
        #G For source
        self.net_G = CustomPoseGenerator(
            self.opt.pose_feature_size,
            2048,
            self.opt.noise_feature_size,
            dropout=self.opt.drop,
            norm_layer=self.norm_layer,
            fuse_mode=self.opt.fuse_mode,
            connect_layers=self.opt.connect_layers)
        #G For target
        self.tar_net_G = CustomPoseGenerator(
            self.opt.pose_feature_size,
            2048,
            self.opt.noise_feature_size,
            dropout=self.opt.drop,
            norm_layer=self.norm_layer,
            fuse_mode=self.opt.fuse_mode,
            connect_layers=self.opt.connect_layers)

        #We share same E for cross-dataset
        e_base_model = create(self.opt.arch, cut_at_pooling=True)
        e_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                        use_classifier=True,
                                        num_features=2048,
                                        num_classes=2)
        self.net_E = SiameseNet(e_base_model, e_embed_model)

        #Di For source
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                         use_classifier=True,
                                         num_features=2048,
                                         num_classes=1)
        self.net_Di = SiameseNet(di_base_model, di_embed_model)

        #Di For target
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                         use_classifier=True,
                                         num_features=2048,
                                         num_classes=1)
        self.tar_net_Di = SiameseNet(di_base_model, di_embed_model)

        #We share same Dp for cross-dataset
        self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)

        # Load pretrained weights.  Stages 1 and 2 loaded identical
        # checkpoints, so the duplicated branches were merged; the old
        # `assert ('unknown training stage')` could never fire (a
        # non-empty string is truthy), so raise instead.
        if self.opt.stage in (1, 2):
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)

            self._load_state_dict(self.tar_net_Di, self.opt.tar_netDi_pretrain)
            self._load_state_dict(self.tar_net_G, self.opt.tar_netG_pretrain)
        else:
            raise ValueError('unknown training stage')

        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()

        self.tar_net_G = torch.nn.DataParallel(self.tar_net_G).cuda()
        self.tar_net_Di = torch.nn.DataParallel(self.tar_net_Di).cuda()
コード例 #9
0
class PDANetModel(object):
    def __init__(self, opt):
        """Build the cross-dataset model from the options object *opt*."""
        self.opt = opt
        # Checkpoints are written under <checkpoints>/<experiment name>.
        self.save_dir = os.path.join(opt.checkpoints, opt.name)
        self.norm_layer = get_norm_layer(norm_type=opt.norm)

        self._init_models()
        self._init_losses()
        self._init_cross_optimizers()

    def _init_models(self):
        """Build source/target generators and identity discriminators,
        the shared encoder E and pose discriminator Dp; load pretrained
        weights and wrap in DataParallel.

        Raises:
            ValueError: if ``opt.stage`` is neither 1 nor 2.
        """
        #G For source
        self.net_G = CustomPoseGenerator(
            self.opt.pose_feature_size,
            2048,
            self.opt.noise_feature_size,
            dropout=self.opt.drop,
            norm_layer=self.norm_layer,
            fuse_mode=self.opt.fuse_mode,
            connect_layers=self.opt.connect_layers)
        #G For target
        self.tar_net_G = CustomPoseGenerator(
            self.opt.pose_feature_size,
            2048,
            self.opt.noise_feature_size,
            dropout=self.opt.drop,
            norm_layer=self.norm_layer,
            fuse_mode=self.opt.fuse_mode,
            connect_layers=self.opt.connect_layers)

        #We share same E for cross-dataset
        e_base_model = create(self.opt.arch, cut_at_pooling=True)
        e_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                        use_classifier=True,
                                        num_features=2048,
                                        num_classes=2)
        self.net_E = SiameseNet(e_base_model, e_embed_model)

        #Di For source
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                         use_classifier=True,
                                         num_features=2048,
                                         num_classes=1)
        self.net_Di = SiameseNet(di_base_model, di_embed_model)

        #Di For target
        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                         use_classifier=True,
                                         num_features=2048,
                                         num_classes=1)
        self.tar_net_Di = SiameseNet(di_base_model, di_embed_model)

        #We share same Dp for cross-dataset
        self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)

        # Load pretrained weights.  Stages 1 and 2 loaded identical
        # checkpoints, so the duplicated branches were merged; the old
        # `assert ('unknown training stage')` could never fire (a
        # non-empty string is truthy), so raise instead.
        if self.opt.stage in (1, 2):
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)

            self._load_state_dict(self.tar_net_Di, self.opt.tar_netDi_pretrain)
            self._load_state_dict(self.tar_net_G, self.opt.tar_netG_pretrain)
        else:
            raise ValueError('unknown training stage')

        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()

        self.tar_net_G = torch.nn.DataParallel(self.tar_net_G).cuda()
        self.tar_net_Di = torch.nn.DataParallel(self.tar_net_Di).cuda()

    def reset_model_status(self):
        """Switch sub-networks between train/eval mode for the current stage.

        BatchNorm layers of both identity discriminators (and of E in
        stage 2) are frozen via ``set_bn_fix``; in stage 1 the encoder E
        stays in eval mode.
        """
        if self.opt.stage == 1:
            for net in (self.net_G, self.tar_net_G, self.net_Dp,
                        self.net_Di, self.tar_net_Di):
                net.train()
            self.net_E.eval()
            for net in (self.net_Di, self.tar_net_Di):
                net.apply(set_bn_fix)
        elif self.opt.stage == 2:
            for net in (self.net_E, self.net_G, self.tar_net_G,
                        self.net_Di, self.tar_net_Di, self.net_Dp):
                net.train()
            for net in (self.net_E, self.net_Di, self.tar_net_Di):
                net.apply(set_bn_fix)

    def _load_state_dict(self, net, path):
        """Load the checkpoint at *path* into *net*, stripping any
        ``module.`` prefixes left behind by DataParallel."""
        net.load_state_dict(remove_module_key(torch.load(path)))

    def _init_losses(self):
        """Create the triplet/MMD criteria and the GAN losses.

        When smooth labels are enabled, ``rand_list`` makes roughly 1 in
        10001 discriminator steps see flipped real/fake labels.
        """
        self.criterion_Triplet = TripletLoss(margin=self.opt.tri_margin)
        self.criterion_MMD = MMDLoss(sigma_list=[1, 2, 10])
        #Smooth label is a mechanism
        smooth = True if self.opt.smooth_label else False
        self.criterionGAN_D = GANLoss(smooth=smooth).cuda()
        self.rand_list = [True] + [False] * 10000 if smooth else [False]
        self.criterionGAN_G = GANLoss(smooth=False).cuda()

    #For cross dataset optimizers
    def _init_cross_optimizers(self):
        """Create optimizers/schedulers for source and target branches.

        Stage 1: Adam for both generators, SGD for the discriminators
        (identity discriminators at a 100x smaller LR).  Stage 2: Adam
        over per-module param groups (shared E + each G) and SGD for the
        discriminators, plus a separate Adam on E's base model.
        """

        if self.opt.stage == 1:
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=self.opt.lr * 0.1,
                                                betas=(0.5, 0.999))
            self.optimizer_tar_G = torch.optim.Adam(
                self.tar_net_G.parameters(),
                lr=self.opt.lr * 0.1,
                betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr * 0.01,
                                                momentum=0.9,
                                                weight_decay=1e-4)
            self.optimizer_tar_Di = torch.optim.SGD(
                self.tar_net_Di.parameters(),
                lr=self.opt.lr * 0.01,
                momentum=0.9,
                weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr,
                                                momentum=0.9,
                                                weight_decay=1e-4)
        elif self.opt.stage == 2:
            # Both param-group sets share E's modules; they differ only in
            # which generator (source vs target) they include.
            param_groups = [{
                'params': self.net_E.module.base_model.parameters(),
                'lr_mult': 0.1
            }, {
                'params':
                self.net_E.module.embed_model.parameters(),
                'lr_mult':
                1.0
            }, {
                'params': self.net_G.parameters(),
                'lr_mult': 0.1
            }]

            param_tar_groups = [{
                'params':
                self.net_E.module.base_model.parameters(),
                'lr_mult':
                0.1
            }, {
                'params':
                self.net_E.module.embed_model.parameters(),
                'lr_mult':
                1.0
            }, {
                'params': self.tar_net_G.parameters(),
                'lr_mult': 0.1
            }]

            self.optimizer_G = torch.optim.Adam(param_groups,
                                                lr=self.opt.lr * 0.1,
                                                betas=(0.5, 0.999))
            self.optimizer_tar_G = torch.optim.Adam(param_tar_groups,
                                                    lr=self.opt.lr * 0.1,
                                                    betas=(0.5, 0.999))
            self.optimizer_Di = torch.optim.SGD(self.net_Di.parameters(),
                                                lr=self.opt.lr,
                                                momentum=0.9,
                                                weight_decay=1e-4)
            self.optimizer_tar_Di = torch.optim.SGD(
                self.tar_net_Di.parameters(),
                lr=self.opt.lr,
                momentum=0.9,
                weight_decay=1e-4)
            self.optimizer_Dp = torch.optim.SGD(self.net_Dp.parameters(),
                                                lr=self.opt.lr,
                                                momentum=0.9,
                                                weight_decay=1e-4)
            self.optimizer_E = torch.optim.Adam(
                self.net_E.module.base_model.parameters(),
                lr=self.opt.lr * 0.1,
                betas=(0.5, 0.999))

        # NOTE(review): optimizer_E (stage 2 only) is never appended to
        # self.optimizers, so it gets no LR scheduler — confirm intended.
        self.schedulers = []
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_Di)
        self.optimizers.append(self.optimizer_Dp)

        self.optimizers.append(self.optimizer_tar_G)
        self.optimizers.append(self.optimizer_tar_Di)

        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, self.opt))

    def set_input(self, input):
        """Unpack an (input1, input2) batch pair and move it to the GPU.

        ``labels[i]`` is 1 when the i-th samples share an identity; for
        such pairs input2's pose map and target are overwritten with
        input1's, so both halves generate the same person in the same pose.
        """
        input1, input2 = input
        labels = (input1['pid'] == input2['pid']).long()
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)

        # keep the same pose map (and target) for persons with the same identity
        for key in ('posemap', 'target'):
            same = labels.view(-1, 1, 1, 1).expand_as(input1[key]).float()
            input2[key] = input1[key] * same + input2[key] * (1 - same)

        self.origin = torch.cat([input1['origin'], input2['origin']]).cuda()
        self.target = torch.cat([input1['target'], input2['target']]).cuda()
        self.posemap = torch.cat([input1['posemap'], input2['posemap']]).cuda()
        self.labels = labels.cuda()
        # The same noise vector is fed to both halves of the batch.
        self.noise = torch.cat((noise, noise)).cuda()

    #For cross-dataset
    def set_inputs(self, target_data, source_data):
        """Prepare a cross-dataset batch: a labelled source pair plus an
        unlabelled target pair, all moved to the GPU."""
        #=============source data===============================
        s1, s2 = source_data
        labels = (s1['pid'] == s2['pid']).long()
        noise = torch.randn(labels.size(0), self.opt.noise_feature_size)

        # Same-identity pairs share pose map and reconstruction target.
        for key in ('posemap', 'target'):
            same = labels.view(-1, 1, 1, 1).expand_as(s1[key]).float()
            s2[key] = s1[key] * same + s2[key] * (1 - same)

        self.s_origin = torch.cat([s1['origin'], s2['origin']]).cuda()
        self.s_target = torch.cat([s1['target'], s2['target']]).cuda()
        self.s_posemap = torch.cat([s1['posemap'], s2['posemap']]).cuda()
        self.s_labels = labels.cuda()
        self.s_noise = torch.cat((noise, noise)).cuda()
        # Person ids, used for the triplet loss.
        self.s_plabels = torch.cat((s1['pid'].long(), s2['pid'].long())).cuda()

        #=============target data===============================
        t1, t2 = target_data
        noise = torch.randn(t1['origin'].size(0), self.opt.noise_feature_size)

        self.t_origin = torch.cat([t1['origin'], t2['origin']]).cuda()
        self.t_target = torch.cat([t1['target'], t2['target']]).cuda()
        self.t_posemap = torch.cat([t1['posemap'], t2['posemap']]).cuda()
        self.t_noise = torch.cat((noise, noise)).cuda()

    # forward pass for cross-dataset training
    def forward_cross(self):
        """Run the shared encoder and both generators on the current batches."""
        # --- source domain: encode the pair, render with the source G ---
        imgs = Variable(self.s_origin)
        pose = Variable(self.s_posemap)
        z = Variable(self.s_noise)
        half = imgs.size(0) // 2

        id1, id2, self.s_id_score = self.net_E(imgs[:half], imgs[half:])
        self.s_A_id = torch.cat((id1, id2))
        feat = self.s_A_id
        self.s_fake = self.net_G(pose,
                                 feat.view(feat.size(0), feat.size(1), 1, 1),
                                 z.view(z.size(0), z.size(1), 1, 1))

        # --- source identities rendered by the target-domain generator ---
        st_feat = id1[:half]
        st_z = z[:half]
        self.st_fake = self.tar_net_G(
            pose[:half],
            st_feat.view(st_feat.size(0), st_feat.size(1), 1, 1),
            st_z.view(st_z.size(0), st_z.size(1), 1, 1))

        # --- target domain: encode the pair, render with the target G ---
        imgs = Variable(self.t_origin)
        pose = Variable(self.t_posemap)
        z = Variable(self.t_noise)
        half = imgs.size(0) // 2

        id1, id2, self.t_id_score = self.net_E(imgs[:half], imgs[half:])
        self.t_A_id = torch.cat((id1, id2))
        feat = self.t_A_id
        self.t_fake = self.tar_net_G(
            pose, feat.view(feat.size(0), feat.size(1), 1, 1),
            z.view(z.size(0), z.size(1), 1, 1))

    def backward_Dp(self):
        """Pose-discriminator step: real vs. generated image, conditioned on
        the pose map (concatenated along the channel dim).

        With probability driven by ``self.rand_list`` the real/fake labels are
        swapped (label flipping, a common GAN regularizer).
        """
        real_pose = torch.cat((Variable(self.posemap), Variable(self.target)),
                              dim=1)
        fake_pose = torch.cat((Variable(self.posemap), self.fake.detach()),
                              dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(fake_pose)

        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # Fix: `.data[0]` raises IndexError on 0-dim tensors in PyTorch >= 0.5;
        # use .item() like the rest of this class.
        self.loss_Dp = loss_D.item()

    def backward_Di(self):
        """Identity-discriminator step on the single-dataset batch, with
        random real/fake label flipping as GAN regularization."""
        _, _, pred_real = self.net_Di(Variable(self.origin),
                                      Variable(self.target))
        _, _, pred_fake = self.net_Di(Variable(self.origin),
                                      self.fake.detach())
        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        # Fix: `.data[0]` raises IndexError on 0-dim tensors in PyTorch >= 0.5;
        # use .item() like the rest of this class.
        self.loss_Di = loss_D.item()

    def backward_E_MMD(self):
        """Align source and target encoder features via an MMD penalty."""
        mmd = self.criterion_MMD(self.s_A_id, self.t_A_id)
        self.loss_mmd = mmd.item()
        # retain_graph: the encoder features are reused by the generator
        # losses later in the same optimization step.
        (mmd * self.opt.lambda_mmd).backward(retain_graph=True)

    def backward_s_G(self):
        """Generator/encoder update on the source domain.

        Loss = GAN terms from both discriminators
             + L1 reconstruction * lambda_recon
             + pair-verification cross-entropy * lambda_veri
             + triplet loss on encoder features * lambda_tri.
        """
        loss_tri, prec = self.criterion_Triplet(self.s_A_id, self.s_plabels)

        loss_v = F.cross_entropy(self.s_id_score,
                                 Variable(self.s_labels).view(-1))
        loss_r = F.l1_loss(self.s_fake, Variable(self.s_target))

        _, _, pred_fake_Di = self.net_Di(Variable(self.s_origin), self.s_fake)
        pred_fake_Dp = self.net_Dp(
            torch.cat((Variable(self.s_posemap), self.s_fake), dim=1))
        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)

        loss_G = loss_G_GAN_Di + loss_G_GAN_Dp + \
                loss_r * self.opt.lambda_recon + \
                loss_v * self.opt.lambda_veri + \
                loss_tri * self.opt.lambda_tri

        loss_G.backward()

        # Drop the verification scores now that they have been consumed.
        del self.s_id_score
        self.loss_s_G = loss_G.item()
        self.loss_s_v = loss_v.item()
        self.loss_s_r = loss_r.item()
        self.loss_s_G_GAN_Di = loss_G_GAN_Di.item()
        self.loss_s_G_GAN_Dp = loss_G_GAN_Dp.item()
        self.loss_s_tri = loss_tri.item()

    def backward_t_G(self):
        """Generator update on the target domain: GAN terms plus L1
        reconstruction (no identity supervision exists for the target set)."""
        loss_r = F.l1_loss(self.t_fake, Variable(self.t_target))

        _, _, pred_fake_Di = self.tar_net_Di(Variable(self.t_origin),
                                             self.t_fake)
        pred_fake_Dp = self.net_Dp(
            torch.cat((Variable(self.t_posemap), self.t_fake), dim=1))
        loss_G_GAN_Di = self.criterionGAN_G(pred_fake_Di, True)
        loss_G_GAN_Dp = self.criterionGAN_G(pred_fake_Dp, True)

        loss_G = loss_G_GAN_Di + loss_G_GAN_Dp + \
                loss_r * self.opt.lambda_recon

        loss_G.backward()

        # Free the (unused here) encoder scores produced by forward_cross.
        del self.t_id_score
        self.loss_t_G = loss_G.item()
        self.loss_t_r = loss_r.item()
        self.loss_t_G_GAN_Di = loss_G_GAN_Di.item()
        self.loss_t_G_GAN_Dp = loss_G_GAN_Dp.item()

    def backward_s_Di(self):
        """Identity-discriminator step on the source domain."""
        _, _, pred_real = self.net_Di(Variable(self.s_origin),
                                      Variable(self.s_target))
        _, _, pred_fake = self.net_Di(Variable(self.s_origin),
                                      self.s_fake.detach())
        # Occasionally swap real/fake labels as GAN regularization.
        flip = random.choice(self.rand_list)
        as_real, as_fake = ((pred_fake, pred_real) if flip
                            else (pred_real, pred_fake))
        loss_D = (self.criterionGAN_D(as_real, True) +
                  self.criterionGAN_D(as_fake, False)) * 0.5
        loss_D.backward()
        self.loss_s_Di = loss_D.item()

    def backward_t_Di(self):
        """Identity-discriminator step on the target domain (tar_net_Di)."""
        _, _, pred_real = self.tar_net_Di(Variable(self.t_origin),
                                          Variable(self.t_target))
        _, _, pred_fake = self.tar_net_Di(Variable(self.t_origin),
                                          self.t_fake.detach())
        # Occasionally swap real/fake labels as GAN regularization.
        flip = random.choice(self.rand_list)
        as_real, as_fake = ((pred_fake, pred_real) if flip
                            else (pred_real, pred_fake))
        loss_D = (self.criterionGAN_D(as_real, True) +
                  self.criterionGAN_D(as_fake, False)) * 0.5
        loss_D.backward()
        self.loss_t_Di = loss_D.item()

    def _dp_gan_loss(self, posemap, target, fake):
        """One pose-discriminator GAN term (with random label flipping);
        backpropagates and returns the scalar loss value."""
        real_pose = torch.cat((Variable(posemap), Variable(target)), dim=1)
        fake_pose = torch.cat((Variable(posemap), fake.detach()), dim=1)
        pred_real = self.net_Dp(real_pose)
        pred_fake = self.net_Dp(fake_pose)

        if random.choice(self.rand_list):
            loss_D_real = self.criterionGAN_D(pred_fake, True)
            loss_D_fake = self.criterionGAN_D(pred_real, False)
        else:
            loss_D_real = self.criterionGAN_D(pred_real, True)
            loss_D_fake = self.criterionGAN_D(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D.item()

    def backward_cross_Dp(self):
        """Pose-discriminator step over both domains.

        ``net_Dp`` is shared between source and target; the two losses are
        accumulated into ``self.loss_Dp`` for logging. (The two halves were
        previously duplicated inline; factored into ``_dp_gan_loss``.)
        """
        self.loss_Dp = self._dp_gan_loss(self.s_posemap, self.s_target,
                                         self.s_fake)
        self.loss_Dp += self._dp_gan_loss(self.t_posemap, self.t_target,
                                          self.t_fake)

    def optimize_cross_parameters(self):
        """One full cross-dataset optimization step.

        Order matters here: a single shared forward pass feeds every
        backward; the encoder MMD backward retains the graph so the
        generator losses can reuse the encoder features; each discriminator
        is stepped before its paired generator.
        """
        self.forward_cross()

        # Encoder: align source/target features via MMD (retains graph).
        self.optimizer_E.zero_grad()
        self.backward_E_MMD()
        self.optimizer_E.step()

        # Source-domain identity discriminator.
        self.optimizer_Di.zero_grad()
        self.backward_s_Di()
        self.optimizer_Di.step()

        # Shared pose discriminator (both domains).
        self.optimizer_Dp.zero_grad()
        self.backward_cross_Dp()
        self.optimizer_Dp.step()

        # Source-domain generator (+ encoder losses).
        self.optimizer_G.zero_grad()
        self.backward_s_G()
        self.optimizer_G.step()

        # Target-domain identity discriminator.
        self.optimizer_tar_Di.zero_grad()
        self.backward_t_Di()
        self.optimizer_tar_Di.step()

        # Target-domain generator.
        self.optimizer_tar_G.zero_grad()
        self.backward_t_G()
        self.optimizer_tar_G.step()

    def get_current_errors(self):
        """Return the latest single-dataset losses keyed for logging."""
        errors = OrderedDict()
        errors['G_v'] = self.loss_v
        errors['G_r'] = self.loss_r
        errors['G_sp'] = self.loss_sp
        errors['G_gan_Di'] = self.loss_G_GAN_Di
        errors['G_gan_Dp'] = self.loss_G_GAN_Dp
        errors['D_i'] = self.loss_Di
        errors['D_p'] = self.loss_Dp
        return errors

    def get_current_cross_errors(self):
        """Return the latest cross-dataset losses keyed for logging."""
        errors = OrderedDict()
        errors['G_tar_rec'] = self.loss_t_r
        errors['G_tar_gan_Di'] = self.loss_t_G_GAN_Di
        errors['G_tar_gan_Dp'] = self.loss_t_G_GAN_Dp
        errors['tar_D_i'] = self.loss_t_Di
        errors['D_p'] = self.loss_Dp
        errors['G_src_v'] = self.loss_s_v
        errors['G_src_rec'] = self.loss_s_r
        errors['G_src_gan_Di'] = self.loss_s_G_GAN_Di
        errors['G_src_gan_Dp'] = self.loss_s_G_GAN_Dp
        errors['src_D_i'] = self.loss_s_Di
        errors['MMD'] = self.loss_mmd
        errors['src_tri'] = self.loss_s_tri
        return errors

    def get_current_visuals(self):
        """Convert the current batch (input/posemap/fake/target) to images."""
        heat = self.posemap.sum(1)
        heat[heat > 1] = 1  # clamp summed joint channels to a binary mask
        return OrderedDict([
            ('input', util.tensor2im(self.origin)),
            ('posemap', util.tensor2im(torch.unsqueeze(heat, 1))),
            ('fake', util.tensor2im(self.fake)),
            ('target', util.tensor2im(self.target)),
        ])

    def get_tf_visuals(self):
        """Single-dataset visuals for tensorboard (batched image conversion)."""
        heat = self.posemap.sum(1)
        heat[heat > 1] = 1  # clamp summed joint channels to a binary mask
        return OrderedDict([
            ('input', util.tensor2ims(self.origin)),
            ('posemap', util.tensor2ims(torch.unsqueeze(heat, 1))),
            ('fake', util.tensor2ims(self.fake)),
            ('target', util.tensor2ims(self.target)),
        ])

    def get_tf_cross_visuals(self):
        """Tensorboard images for cross-dataset training: source, target and
        source-to-target generations."""
        def heatmap(posemap):
            # Collapse joint channels into a single binary map.
            m = posemap.sum(1)
            m[m > 1] = 1
            return util.tensor2ims(torch.unsqueeze(m, 1))

        return OrderedDict([
            ('src_input', util.tensor2ims(self.s_origin)),
            ('src_posemap', heatmap(self.s_posemap)),
            ('src_fake', util.tensor2ims(self.s_fake.data)),
            ('src_target', util.tensor2ims(self.s_target)),
            ('tar_input', util.tensor2ims(self.t_origin)),
            ('tar_posemap', heatmap(self.t_posemap)),
            ('tar_fake', util.tensor2ims(self.t_fake.data)),
            ('tar_target', util.tensor2ims(self.t_target)),
            ('src2tgt_fake', util.tensor2ims(self.st_fake.data)),
        ])

    def save(self, epoch):
        """Checkpoint every network (shared and target-domain) for `epoch`."""
        for net, label in ((self.net_E, 'E'),
                           (self.net_G, 'G'),
                           (self.net_Di, 'Di'),
                           (self.net_Dp, 'Dp'),
                           (self.tar_net_G, 'tar_G'),
                           (self.tar_net_Di, 'tar_Di')):
            self.save_network(net, label, epoch)

    def save_network(self, network, network_label, epoch_label):
        """Write `network`'s state dict to <save_dir>/<epoch>_net_<label>.pth."""
        filename = '{}_net_{}.pth'.format(epoch_label, network_label)
        torch.save(network.state_dict(),
                   os.path.join(self.save_dir, filename))

    def update_learning_rate(self):
        """Advance every learning-rate scheduler by one step.

        The original body also read the first optimizer's current lr into an
        unused local (`lr`) -- dead leftover, removed.
        """
        for scheduler in self.schedulers:
            scheduler.step()

    def forward_test_cross(self):
        """Test-time forward pass producing all four generations:
        s_fake (source G), st_fake (source ids via target G),
        t_fake (target G), ts_fake (target ids via source G)."""
        # --- source: encode the pair, render with the source generator ---
        A = Variable(self.s_origin)
        B_map = Variable(self.s_posemap)
        z = Variable(self.s_noise)
        bs = A.size(0)

        A_id1, A_id2, self.s_id_score = self.net_E(A[:bs // 2], A[bs // 2:])
        A_id = torch.cat((A_id1, A_id2))
        self.s_A_id = A_id
        self.s_fake = self.net_G(B_map,
                                 A_id.view(A_id.size(0), A_id.size(1), 1, 1),
                                 z.view(z.size(0), z.size(1), 1, 1))

        # Source identities rendered by the target-domain generator
        # (full batch, unlike forward_cross which uses only the first half).
        self.st_fake = self.tar_net_G(
            B_map, A_id.view(A_id.size(0), A_id.size(1), 1, 1),
            z.view(z.size(0), z.size(1), 1, 1))

        # --- target: encode the pair ---
        A = Variable(self.t_origin)
        # NOTE(review): this reuses the *source* pose maps (s_posemap) for the
        # target branch, whereas forward_cross uses t_posemap -- possibly
        # intentional pose transfer at test time, possibly a typo; confirm
        # against the evaluation script before changing.
        B_map = Variable(self.s_posemap)
        z = Variable(self.t_noise)
        bs = A.size(0)

        A_id1, A_id2, self.t_id_score = self.net_E(A[:bs // 2], A[bs // 2:])
        A_id = torch.cat((A_id1, A_id2))
        self.t_A_id = A_id
        self.t_fake = self.tar_net_G(
            B_map, A_id.view(A_id.size(0), A_id.size(1), 1, 1),
            z.view(z.size(0), z.size(1), 1, 1))

        # Target identities rendered by the source-domain generator.
        self.ts_fake = self.net_G(B_map,
                                  A_id.view(A_id.size(0), A_id.size(1), 1, 1),
                                  z.view(z.size(0), z.size(1), 1, 1))

    def get_test_cross_visuals(self):
        """Test-time visuals: both generators applied across domains.

        The original also converted `tar_target` and the target pose heatmap
        but never put them in the returned dict -- that dead work is removed
        (the returned keys are unchanged).
        """
        src_map = self.s_posemap.sum(1)
        src_map[src_map > 1] = 1  # clamp summed joint channels to a binary mask
        src_map = util.tensor2ims(torch.unsqueeze(src_map, 1))

        return OrderedDict([
            ('src_input', util.tensor2ims(self.s_origin)),
            ('src_posemap', src_map),
            ('src_fake', util.tensor2ims(self.s_fake.data)),
            ('src_target', util.tensor2ims(self.s_target)),
            ('tar_input', util.tensor2ims(self.t_origin)),
            ('tar_fake', util.tensor2ims(self.t_fake.data)),
            ('src2tgt_fake', util.tensor2ims(self.st_fake.data)),
            ('tgt2src_fake', util.tensor2ims(self.ts_fake.data)),
        ])
# ===== Code example #10 (scraped page separator; original marker: "コード例 #10", score 0) =====
    def _init_models(self):
        """Build G / E / Di / Dp, load pretrained weights per stage, and wrap
        everything in DataParallel on the GPU.

        Stage 1 warm-starts E and Di from one pretrained checkpoint, copying
        only the keys each network actually has; stage 2 resumes all four
        networks from their own checkpoints.
        """
        self.net_G = CustomPoseGenerator(
            self.opt.pose_feature_size,
            2048,
            self.opt.noise_feature_size,
            dropout=self.opt.drop,
            norm_layer=self.norm_layer,
            fuse_mode=self.opt.fuse_mode,
            connect_layers=self.opt.connect_layers)
        e_base_model = create(self.opt.arch, cut_at_pooling=True)
        # num_classes=751 looks like the Market-1501 identity count -- TODO confirm.
        self.e_embed_model = Sub_model(num_features=256,
                                       in_features=2048,
                                       num_classes=751,
                                       FCN=True,
                                       dropout=0.5)
        self.net_E = ENet(e_base_model, self.e_embed_model)

        di_base_model = create(self.opt.arch, cut_at_pooling=True)
        di_embed_model = EltwiseSubEmbed(use_batch_norm=True,
                                         use_classifier=True,
                                         num_features=2048,
                                         num_classes=1)
        self.net_Di = SiameseNet(di_base_model, di_embed_model)
        # 3 RGB channels + 18 pose-map channels.
        self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)

        def load_matching(net, state_dict):
            # Partial load: copy only checkpoint entries whose keys exist
            # in `net` (the same pattern was duplicated twice inline).
            own = net.state_dict()
            own.update({k: v for k, v in state_dict.items() if k in own})
            net.load_state_dict(own)

        if self.opt.stage == 1:
            init_weights(self.net_G)
            init_weights(self.net_Dp)
            checkpoint = torch.load(self.opt.netE_pretrain)
            load_matching(self.net_E, checkpoint['state_dict'])
            load_matching(self.net_Di, checkpoint['state_dict'])
        elif self.opt.stage == 2:
            self._load_state_dict(self.net_E, self.opt.netE_pretrain)
            self._load_state_dict(self.net_G, self.opt.netG_pretrain)
            self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
            self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
        else:
            # Was `assert('unknown training stage')`: a non-empty string is
            # always truthy, so invalid stages were silently accepted.
            raise ValueError('unknown training stage: %r' % (self.opt.stage,))

        self.net_E = torch.nn.DataParallel(self.net_E).cuda()
        self.net_G = torch.nn.DataParallel(self.net_G).cuda()
        self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
        self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()