Ejemplo n.º 1
0
    def finetune(self, ref_labels, ref_images):
        """Fine-tune the generator (and discriminator) on the given reference
        frames before inference, optimizing only a subset of generator layers.
        """
        # Only these generator sub-modules are adapted during fine-tuning.
        finetune_layers = ['fc', 'conv_img', 'up']
        g_params, _ = self.get_train_params(self.netG, finetune_layers)
        self.optimizer_G = self.get_optimizer(g_params, for_discriminator=False)

        train_discriminator = True
        if train_discriminator:
            d_params = list(self.netD.parameters())
            if self.add_face_D:
                d_params += list(self.netDf.parameters())
            self.optimizer_D = self.get_optimizer(d_params,
                                                  for_discriminator=True)

        n_iters = 100
        for step in range(1, n_iters + 1):
            # Sample a random reference frame and jitter it (random_roll) to
            # form the pseudo target pair; add back the time dimension.
            frame = np.random.randint(ref_labels.size(1))
            tgt_label, tgt_image = random_roll(
                [ref_labels[:, frame], ref_images[:, frame]])
            tgt_label = tgt_label.unsqueeze(1)
            tgt_image = tgt_image.unsqueeze(1)

            # Generator update.
            g_losses, generated, prev = self.forward_generator(
                tgt_label, tgt_image, ref_labels, ref_images)
            g_losses = loss_backward(self.opt, g_losses, self.optimizer_G)

            # Discriminator update.
            d_losses = []
            if train_discriminator:
                d_losses = self.forward_discriminator(
                    tgt_label, tgt_image, ref_labels, ref_images)
                d_losses = loss_backward(self.opt, d_losses, self.optimizer_D)

            # Print the non-zero losses every 10 steps.
            if step % 10 == 0:
                loss_dict = dict(
                    zip(self.lossCollector.loss_names, g_losses + d_losses))
                message = '(iters: %d) ' % step
                for name, value in loss_dict.items():
                    if value != 0:
                        message += '%s: %.3f ' % (name, value)
                print(message)
Ejemplo n.º 2
0
def train():
    """Main training entry point: set up data, trainer and models, then run
    alternating generator / discriminator updates over temporal windows.
    """
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        # Split the global batch across the available GPUs.
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)

    ### setup dataset
    loader = CreateDataLoader(opt)
    dataset = loader.load_data()
    is_pose = 'pose' in opt.dataset_mode

    ### setup trainer
    trainer = Trainer(opt, loader)

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    flow_gt = conf_gt = [None] * 2

    for epoch in range(trainer.start_epoch, opt.niter + opt.niter_decay + 1):
        trainer.start_of_epoch(epoch, model, loader)
        n_frames_total = loader.dataset.n_frames_total
        n_frames_load = opt.n_frames_per_gpu
        for iter_idx, batch in enumerate(dataset, start=trainer.epoch_iter):
            trainer.start_of_iter()

            # Ground-truth flow/confidence from the flow network; pose mode
            # computes it from labels, otherwise from images.
            if not opt.no_flow_gt:
                if is_pose:
                    flow_inputs = [batch['tgt_label'], batch['ref_label']]
                else:
                    flow_inputs = [batch['tgt_image'], batch['ref_image']]
                flow_gt, conf_gt = flowNet(flow_inputs, epoch)

            data_list = [
                batch['tgt_label'], batch['tgt_image'], flow_gt, conf_gt
            ]
            data_ref_list = [batch['ref_label'], batch['ref_image']]
            data_prev = [None, None]

            ############## Forward Pass ######################
            for start in range(0, n_frames_total, n_frames_load):
                window = (get_data_t(data_list, n_frames_load, start)
                          + data_ref_list + data_prev)

                # Generator step; data_prev carries state to the next window.
                g_losses, generated, data_prev = model(
                    window, save_images=trainer.save, mode='generator')
                g_losses = loss_backward(opt, g_losses,
                                         model.module.optimizer_G)

                # Discriminator step.
                d_losses = model(window, mode='discriminator')
                d_losses = loss_backward(opt, d_losses,
                                         model.module.optimizer_D)

            loss_dict = dict(
                zip(model.module.lossCollector.loss_names,
                    g_losses + d_losses))

            if trainer.end_of_iter(loss_dict,
                                   generated + data_list + data_ref_list,
                                   model):
                break
        trainer.end_of_epoch(model)
Ejemplo n.º 3
0
    def finetune(self, ref_labels, ref_images, warp_ref_lmark, warp_ref_img,
                 ani_img, ani_lmark):
        """Fine-tune the generator and discriminator on the reference frames.

        Randomly picks one reference frame per iteration as the pseudo target
        and runs a generator + discriminator update against it.

        NOTE(review): ``ani_img`` and ``ani_lmark`` are accepted but never
        used in this body — confirm whether they should be forwarded to
        ``forward_generator`` (compare ``finetune_call_multi``).
        """
        # Only these generator sub-modules are optimized during fine-tuning.
        train_names = ['fc', 'conv_img', 'up']
        params, _ = self.get_train_params(self.netG, train_names)
        self.optimizer_G = self.get_optimizer(params, for_discriminator=False)

        update_D = True
        if update_D:
            params = list(self.netD.parameters())
            self.optimizer_D = self.get_optimizer(params,
                                                  for_discriminator=True)

        iterations = 70
        for it in range(1, iterations + 1):
            # Sample a random reference frame as the target; unsqueeze(1)
            # restores the time dimension expected downstream.
            idx = np.random.randint(ref_labels.size(1))
            tgt_label, tgt_image = ref_labels[:, idx].unsqueeze(
                1), ref_images[:, idx].unsqueeze(1)

            # Generator update. No crop target or flow ground truth is
            # supplied during fine-tuning (tgt_crop_image=None, flow/conf
            # lists of Nones).
            g_losses, generated, prev, _ = self.forward_generator(tgt_label=tgt_label, tgt_image=tgt_image, \
                                                                  tgt_template=1, tgt_crop_image=None, \
                                                                  flow_gt=[None]*3, conf_gt=[None]*3, \
                                                                  ref_labels=ref_labels, ref_images=ref_images, \
                                                                  warp_ref_lmark=warp_ref_lmark.unsqueeze(1), warp_ref_img=warp_ref_img.unsqueeze(1))

            g_losses = loss_backward(self.opt, g_losses, self.optimizer_G)

            # Discriminator update.
            d_losses = []
            if update_D:
                d_losses, _ = self.forward_discriminator(
                    tgt_label,
                    tgt_image,
                    ref_labels,
                    ref_images,
                    warp_ref_lmark=warp_ref_lmark.unsqueeze(1),
                    warp_ref_img=warp_ref_img.unsqueeze(1))
                d_losses = loss_backward(self.opt, d_losses, self.optimizer_D)

            # Print the non-zero losses every 10 iterations.
            if (it % 10) == 0:
                message = '(iters: %d) ' % it
                loss_dict = dict(
                    zip(self.lossCollector.loss_names, g_losses + d_losses))
                for k, v in loss_dict.items():
                    if v != 0: message += '%s: %.3f ' % (k, v)
                print(message)
Ejemplo n.º 4
0
    def forward(self, opt, data, epoch, n_frames_total, n_frames_load,
                save_images):
        """Run one training iteration over a clip, updating generator and
        discriminator per temporal window.

        Returns (generated, flow_gt, conf_gt) from the last window.
        """
        # Ground-truth flow/confidence from the flow network, unless disabled.
        flow_gt, conf_gt = None, None
        if not opt.no_flow_gt:
            flow_inputs = [
                data['tgt_image'], data['cropped_images'], data['warping_ref'],
                data['ani_image']
            ]
            flow_gt, conf_gt = self.flowNet(flow_inputs, epoch)

        data_list = [
            data['tgt_label'], data['tgt_image'], data['cropped_images'],
            flow_gt, conf_gt
        ]
        data_ref_list = [data['ref_label'], data['ref_image']]
        data_prev = [None, None, None]
        data_ani = [
            data['warping_ref_lmark'], data['warping_ref'], data['ani_lmark'],
            data['ani_image']
        ]

        ############## Forward Pass ######################
        for start in range(0, n_frames_total, n_frames_load):
            # Assemble the per-window inputs: target slice + references +
            # animation slice + state carried over from the previous window.
            model_inputs = (self.get_data_t(data_list, n_frames_load, start)
                            + data_ref_list
                            + self.get_data_t(data_ani, n_frames_load, start)
                            + data_prev)

            # Generator step; data_prev carries state to the next window.
            g_losses, generated, data_prev, ref_idx = self.v2vModel(
                model_inputs,
                save_images=save_images,
                mode='generator',
                ref_idx_fix=None)
            g_losses = loss_backward(opt, g_losses,
                                     self.v2vModel.module.optimizer_G)

            # Discriminator step.
            d_losses, _ = self.v2vModel(model_inputs,
                                        mode='discriminator',
                                        ref_idx_fix=None)
            d_losses = loss_backward(opt, d_losses,
                                     self.v2vModel.module.optimizer_D)

        return generated, flow_gt, conf_gt
Ejemplo n.º 5
0
    def finetune_call_multi(self,
                            tgt_label_list,
                            tgt_image_list,
                            ref_label_list,
                            ref_image_list,
                            warp_ref_lmark_list,
                            warp_ref_img_list,
                            ani_lmark_list=None,
                            ani_img_list=None,
                            iterations=0):
        """Fine-tune generator and discriminator on multiple clips.

        Each outer iteration sweeps every clip once: a random target frame is
        sampled per clip, references are (re)sampled down to ``opt.n_shot``,
        and one generator + one discriminator update are performed.

        Clears ``self.opt.finetune`` when done.
        """
        # Only these generator sub-modules are adapted during fine-tuning.
        train_names = ['fc', 'conv_img', 'up']
        params, _ = self.get_train_params(self.netG, train_names)
        self.optimizer_G = self.get_optimizer(params, for_discriminator=False)

        update_D = True
        if update_D:
            params = list(self.netD.parameters())
            self.optimizer_D = self.get_optimizer(params,
                                                  for_discriminator=True)

        for iteration in tqdm(range(iterations)):
            for data_id in tqdm(range(len(tgt_label_list))):
                tgt_labels, tgt_images, ref_labels, ref_images, warp_ref_lmark, warp_ref_img, ani_lmark, ani_img = \
                    encode_input_finetune(self.opt, data_list=[tgt_label_list[data_id], tgt_image_list[data_id], \
                                                               ref_label_list[data_id], ref_image_list[data_id], \
                                                               warp_ref_lmark_list[data_id], warp_ref_img_list[data_id], \
                                                               ani_lmark_list[data_id] if ani_lmark_list is not None else None, \
                                                               ani_img_list[data_id] if ani_img_list is not None else None],
                                                               dummy_bs=0)

                # Sample a random target frame; unsqueeze(1) restores the
                # time dimension expected downstream.
                idx = np.random.randint(tgt_labels.size(1))
                tgt_label, tgt_image = tgt_labels[:, idx].unsqueeze(
                    1), tgt_images[:, idx].unsqueeze(1)

                ref_labels_finetune, ref_images_finetune = ref_labels, ref_images
                warp_ref_lmark_finetune, warp_ref_img_finetune = warp_ref_lmark, warp_ref_img
                ani_lmark_finetune, ani_img_finetune = ani_lmark, ani_img

                # Animation frames must align with the sampled target frame.
                if ani_img is not None:
                    assert ani_img.shape[1] == ani_lmark.shape[
                        1] == tgt_labels.shape[1]
                    ani_img_finetune = ani_img[:, idx].unsqueeze(1)
                    ani_lmark_finetune = ani_lmark[:, idx].unsqueeze(1)
                # Subsample references down to the configured shot count.
                if self.opt.n_shot < ref_labels.shape[1]:
                    idxs = np.random.choice(ref_labels.shape[1],
                                            self.opt.n_shot)
                    ref_labels_finetune = ref_labels[:, idxs]
                    ref_images_finetune = ref_images[:, idxs]
                # Pick a single warping reference; reuse the reference
                # subsampling indices when they exist.
                if warp_ref_lmark.shape[1] > 1:
                    if self.opt.n_shot >= ref_labels.shape[1]:
                        idxs = np.random.choice(ref_labels.shape[1],
                                                self.opt.n_shot)
                    warp_ref_lmark_finetune = warp_ref_lmark[:,
                                                             idxs[0]].unsqueeze(
                                                                 1)
                    warp_ref_img_finetune = warp_ref_img[:,
                                                         idxs[0]].unsqueeze(1)

                # Generator update (no crop target or flow ground truth).
                g_losses, generated, prev, _ = self.forward_generator(tgt_label=tgt_label, tgt_image=tgt_image, \
                                                                        tgt_template=1, tgt_crop_image=None, \
                                                                        flow_gt=[None]*3, conf_gt=[None]*3, \
                                                                        ref_labels=ref_labels_finetune, ref_images=ref_images_finetune, \
                                                                        warp_ref_lmark=warp_ref_lmark_finetune, warp_ref_img=warp_ref_img_finetune, \
                                                                        ani_lmark=ani_lmark_finetune, ani_img=ani_img_finetune)

                g_losses = loss_backward(self.opt, g_losses, self.optimizer_G)

                # Discriminator update.
                d_losses = []
                if update_D:
                    d_losses, _ = self.forward_discriminator(
                        tgt_label,
                        tgt_image,
                        ref_labels_finetune,
                        ref_images_finetune,
                        warp_ref_lmark=warp_ref_lmark_finetune,
                        warp_ref_img=warp_ref_img_finetune)
                    d_losses = loss_backward(self.opt, d_losses,
                                             self.optimizer_D)

            # Print the non-zero losses (from the last clip) every 10
            # iterations. BUGFIX: print() was indented inside the k/v loop,
            # emitting one growing line per loss entry instead of a single
            # summary line — now printed once, matching the other finetune
            # routines.
            if (iteration % 10) == 0:
                message = '(iters: %d) ' % iteration
                loss_dict = dict(
                    zip(self.lossCollector.loss_names, g_losses + d_losses))
                for k, v in loss_dict.items():
                    if v != 0: message += '%s: %.3f ' % (k, v)
                print(message)

        self.opt.finetune = False
def train():
    """Main training loop for the warp/animation variant: sets up data,
    trainer, model and flow network, then runs generator / discriminator
    updates per temporal window while accumulating intermediate outputs
    for visualization.
    """
    opt = TrainOptions().parse()    
    if opt.distributed:
        init_dist()
        # Per-process batch size when running distributed.
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)    

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    ### setup trainer    
    trainer = Trainer(opt, data_loader) 

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    flow_gt = conf_gt = [None] * 3      
    
    # Fixed reference indices shared across iterations.
    ref_idx_fix = torch.zeros([opt.batchSize])
    for epoch in tqdm(range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(tqdm(dataset), start=trainer.epoch_iter):
            trainer.start_of_iter()            

            # Without animation warping, blank out the animation inputs.
            if not opt.warp_ani:
                data.update({'ani_image':None, 'ani_lmark':None, 'cropped_images':None, 'cropped_lmarks':None })

            # Ground-truth flow/confidence from the flow network.
            if not opt.no_flow_gt: 
                data_list = [data['tgt_mask_images'], data['cropped_images'], data['warping_ref'], data['ani_image']]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [data['tgt_label'], data['tgt_image'], data['tgt_template'], data['cropped_images'], flow_gt, conf_gt]
            data_ref_list = [data['ref_label'], data['ref_image']]
            data_prev = [None, None, None]
            data_ani = [data['warping_ref_lmark'], data['warping_ref'], data['ori_warping_refs'], data['ani_lmark'], data['ani_image']]

            ############## Forward Pass ######################
            # Accumulators for per-window intermediate outputs (images,
            # warps, weights, flows) used for visualization at end of iter.
            prevs = {"raw_images":[], "synthesized_images":[], \
                    "prev_warp_images":[], "prev_weights":[], \
                    "ani_warp_images":[], "ani_weights":[], \
                    "ref_warp_images":[], "ref_weights":[], \
                    "ref_flows":[], "prev_flows":[], "ani_flows":[], \
                    "ani_syn":[]}
            for t in range(0, n_frames_total, n_frames_load):
                
                data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + \
                              get_data_t(data_ani, n_frames_load, t) + data_prev

                # Generator update; data_prev carries state between windows.
                g_losses, generated, data_prev, ref_idx = model(data_list_t, save_images=trainer.save, mode='generator', ref_idx_fix=ref_idx_fix)
                g_losses = loss_backward(opt, g_losses, model.module.optimizer_G)

                # Discriminator update.
                d_losses, _ = model(data_list_t, mode='discriminator', ref_idx_fix=ref_idx_fix)
                d_losses = loss_backward(opt, d_losses, model.module.optimizer_D)

                # store previous
                store_prev(generated, prevs)
                        
            loss_dict = dict(zip(model.module.lossCollector.loss_names, g_losses + d_losses))     

            output_data_list = [prevs] + [data['ref_image']] + data_ani + data_list + [data['tgt_mask_images']]

            if trainer.end_of_iter(loss_dict, output_data_list, model):
                break        

            # pdb.set_trace()

        trainer.end_of_epoch(model)