# Example 1
def train():
    """Top-level training driver: parse options, build the data pipeline,
    then alternate generator/discriminator updates over temporal chunks."""
    args = TrainOptions().parse()
    if args.distributed:
        init_dist()
        # Split the global batch size evenly across the GPUs in use.
        args.batchSize = args.batchSize // len(args.gpu_ids)

    ### setup dataset
    loader = CreateDataLoader(args)
    dataset = loader.load_data()
    use_pose = 'pose' in args.dataset_mode

    ### setup trainer
    trainer = Trainer(args, loader)

    ### setup models
    model, flowNet = create_model(args, trainer.start_epoch)
    # Placeholders, overwritten below when flow supervision is enabled.
    flow_gt = conf_gt = [None] * 2

    for epoch in range(trainer.start_epoch, args.niter + args.niter_decay + 1):
        trainer.start_of_epoch(epoch, model, loader)
        n_frames_total = loader.dataset.n_frames_total
        n_frames_load = args.n_frames_per_gpu
        for _, batch in enumerate(dataset, start=trainer.epoch_iter):
            trainer.start_of_iter()

            if not args.no_flow_gt:
                # Pose datasets estimate flow from label maps; others from images.
                if use_pose:
                    flow_inputs = [batch['tgt_label'], batch['ref_label']]
                else:
                    flow_inputs = [batch['tgt_image'], batch['ref_image']]
                flow_gt, conf_gt = flowNet(flow_inputs, epoch)
            tgt_data = [batch['tgt_label'], batch['tgt_image'], flow_gt, conf_gt]
            ref_data = [batch['ref_label'], batch['ref_image']]
            prev_data = [None, None]

            ############## Forward Pass ######################
            for t in range(0, n_frames_total, n_frames_load):
                inputs_t = (get_data_t(tgt_data, n_frames_load, t)
                            + ref_data + prev_data)

                # Generator step; model also hands back state for the next chunk.
                g_losses, generated, prev_data = model(
                    inputs_t, save_images=trainer.save, mode='generator')
                g_losses = loss_backward(args, g_losses,
                                         model.module.optimizer_G)

                # Discriminator step on the same chunk.
                d_losses = model(inputs_t, mode='discriminator')
                d_losses = loss_backward(args, d_losses,
                                         model.module.optimizer_D)

            loss_dict = dict(
                zip(model.module.lossCollector.loss_names,
                    g_losses + d_losses))

            if trainer.end_of_iter(loss_dict,
                                   generated + tgt_data + ref_data,
                                   model):
                break
        trainer.end_of_epoch(model)
def train():
    """Training driver for the reference/animation-warping variant.

    Parses options, builds the data pipeline, then for each batch runs the
    generator and discriminator over temporal chunks of ``n_frames_per_gpu``
    frames, carrying previous-frame state between chunks and accumulating
    per-chunk outputs in ``prevs`` for logging/visualization.
    """
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        # Split the global batch size evenly across the GPUs in use.
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    ### setup trainer
    trainer = Trainer(opt, data_loader)

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    # Placeholders, overwritten below when flow supervision is enabled.
    flow_gt = conf_gt = [None] * 3

    ref_idx_fix = torch.zeros([opt.batchSize])
    for epoch in tqdm(range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(tqdm(dataset), start=trainer.epoch_iter):
            trainer.start_of_iter()

            if not opt.warp_ani:
                # Animation warping disabled: blank out the related inputs so
                # downstream code still finds the expected keys.
                data.update({'ani_image': None, 'ani_lmark': None,
                             'cropped_images': None, 'cropped_lmarks': None})

            if not opt.no_flow_gt:
                data_list = [data['tgt_mask_images'], data['cropped_images'],
                             data['warping_ref'], data['ani_image']]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [data['tgt_label'], data['tgt_image'], data['tgt_template'],
                         data['cropped_images'], flow_gt, conf_gt]
            data_ref_list = [data['ref_label'], data['ref_image']]
            data_prev = [None, None, None]
            data_ani = [data['warping_ref_lmark'], data['warping_ref'],
                        data['ori_warping_refs'], data['ani_lmark'], data['ani_image']]

            ############## Forward Pass ######################
            # Per-chunk outputs accumulated for end_of_iter logging.
            prevs = {"raw_images": [], "synthesized_images": [],
                     "prev_warp_images": [], "prev_weights": [],
                     "ani_warp_images": [], "ani_weights": [],
                     "ref_warp_images": [], "ref_weights": [],
                     "ref_flows": [], "prev_flows": [], "ani_flows": [],
                     "ani_syn": []}
            for t in range(0, n_frames_total, n_frames_load):
                data_list_t = (get_data_t(data_list, n_frames_load, t) + data_ref_list +
                               get_data_t(data_ani, n_frames_load, t) + data_prev)

                # Generator step; model also returns state for the next chunk.
                g_losses, generated, data_prev, ref_idx = model(
                    data_list_t, save_images=trainer.save, mode='generator',
                    ref_idx_fix=ref_idx_fix)
                g_losses = loss_backward(opt, g_losses, model.module.optimizer_G)

                # Discriminator step on the same chunk.
                d_losses, _ = model(data_list_t, mode='discriminator',
                                    ref_idx_fix=ref_idx_fix)
                d_losses = loss_backward(opt, d_losses, model.module.optimizer_D)

                # Keep this chunk's outputs for visualization.
                store_prev(generated, prevs)

            # NOTE(review): losses logged here come from the LAST chunk only,
            # and assume the chunk loop ran at least once — confirm intended.
            loss_dict = dict(zip(model.module.lossCollector.loss_names,
                                 g_losses + d_losses))

            output_data_list = ([prevs] + [data['ref_image']] + data_ani +
                                data_list + [data['tgt_mask_images']])

            if trainer.end_of_iter(loss_dict, output_data_list, model):
                break

        trainer.end_of_epoch(model)