예제 #1
0
def test():
    """Visualise the pretrained DeformNet on a few real film images.

    Loads the checkpoint hard-coded below into a DataParallel DeformNet and
    runs it over the ``imgshow_test2`` RealDataset (batch size 1, fixed order).
    For each batch it warps a plain background template and a plain cmap
    template with ``iter_mapping`` using the predicted deformation. It saves
    the first sample of each batch as
    ``<output_dir_test>/imgshow/epoch_<batch_idx>/{bg,cmap,ori_gt}.jpg``.
    It stops after batch index 26, i.e. 27 batches.

    Relies on names defined elsewhere in the module/project (``output_dir_test``,
    ``p``, ``RealDataset``, ``construct_plain_bg``, ``iter_mapping``, ...).
    """
    from tqdm import tqdm

    # -----------------------------------  Model Build -------------------------
    model = torch.nn.DataParallel(DeformNet().cuda())
    print("Loading Pretrained model~")
    # Earlier checkpoints that were tried:
    #   /home1/quanquan/code/Film-Recovery/output/train/std120201220-014112-xA836H/cmap_aug_500.pkl
    #   /home1/quanquan/code/Film-Recovery/output/train/extrabg20201223-025124-sJKxHA/model/extrabg_190.pkl
    #   /home1/quanquan/code/Film-Recovery/output/train/extrabg_ab20201224-230459-AvKPR7/model/extrabg_ab_455.pkl
    #   /home1/quanquan/code/Film-Recovery/output/train/iter20201229-221658-CxBn85/model/iter_175.pkl
    pretrained_dict = torch.load("/home1/quanquan/code/Film-Recovery/output/train/iter-7-0.0120201230-032707-wyKn9z/model/iter-7-0.01_390.pkl", map_location=None)
    model.load_state_dict(pretrained_dict['model_state'])

    # ------------------------------------  Load Dataset  -------------------------
    kwargs = {'num_workers': 8, 'pin_memory': True}
    # Other image folders used before:
    #   /home1/quanquan/datasets/real_films/pad_img
    #   /home1/quanquan/datasets/generate/mesh_film_hypo_alpha2/img
    dataset_test = RealDataset("imgshow_test2", load_mod="new_ab", reg_start="pad_gaus_40", reg_end="jpg")
    dataset_loader = DataLoader(dataset_test, batch_size=1, shuffle=False, **kwargs)
    model.eval()

    p(output_dir_test)
    # Pure inference: disable autograd so no graph (and its activations) is
    # kept alive for each forward pass.
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(dataset_loader)):
            ori_gt = data[0].cuda()

            deform, _ = model(ori_gt)

            bg_template, _ = construct_plain_bg(ori_gt.size(0), img_size=256)
            cmap_template, _ = construct_plain_cmap(ori_gt.size(0), img_size=256)
            dewarp_bg_t = iter_mapping(bg_template, deform)
            dewarp_cmap_t = iter_mapping(cmap_template, deform)

            sample_dir = "imgshow/epoch_{}".format(batch_idx)
            print_img_with_reprocess(dewarp_bg_t[0, :, :, :], "bg", fname=tfilename(output_dir_test, sample_dir, "bg.jpg"))
            print_img_with_reprocess(dewarp_cmap_t[0, :, :, :], "cmap", fname=tfilename(output_dir_test, sample_dir, "cmap.jpg"))
            print_img_with_reprocess(ori_gt[0, :, :, :], "ori", fname=tfilename(output_dir_test, sample_dir, "ori_gt.jpg"))

            if batch_idx > 25:
                break
예제 #2
0
def write_imgs_2(img_tuple,
                 epoch,
                 type_tuple=None,
                 name_tuple=None,
                 training=True):
    """Write one epoch's predicted maps (plus ground truth when training).

    ``img_tuple`` must unpack to ``(cmap, uv, ab, bg, nor, dep, bg2, ori_gt)``.
    When ``training`` is true it must also contain
    ``(cmap_gt, uv_gt, ab_gt, bg_gt, nor_gt, dep_gt)``.
    Every file is written under ``<output_dir>/imgshow/epoch_<epoch>/``:

    * each map through ``print_img_with_reprocess``; the ground-truth exr
      maps are written a second time after ``gt_clip`` as ``*_gt2.jpg``;
    * the ``bw`` map from ``uv2backward_trans_3(uv, bg2)`` and the input
      remapped with it (``bw_mapping_single_3``), through ``print_img_auto``.
      When training, the same is done with ground-truth uv/bg (``bw_gt``) and
      with predicted uv plus ground-truth bg (``bw2``).

    ``type_tuple`` and ``name_tuple`` are accepted but ignored.
    """
    print("Writing Images to ", output_dir)
    sub_dir = "imgshow/epoch_{}".format(epoch)

    def dump(img, img_type, file_name):
        # Raw tensors: print_img_with_reprocess undoes the preprocessing.
        print_img_with_reprocess(img,
                                 img_type,
                                 fname=tfilename(output_dir, sub_dir, file_name))

    def dump_auto(img, img_type, file_name):
        # Already-reprocessed arrays go through print_img_auto.
        print_img_auto(img,
                       img_type,
                       fname=tfilename(output_dir, sub_dir, file_name))

    if not training:
        cmap, uv, ab, bg, nor, dep, bg2, ori_gt = img_tuple
    else:
        (cmap, uv, ab, bg, nor, dep, bg2, ori_gt,
         cmap_gt, uv_gt, ab_gt, bg_gt, nor_gt, dep_gt) = img_tuple

    predictions = (
        (uv, "uv", "uv.jpg"),
        (ab, "ab", "ab.jpg"),
        (bg, "bg", "bg.jpg"),
        (bg2, "bg", "bg2.jpg"),
        (cmap, "exr", "cmap.jpg"),
        (nor, "exr", "nor.jpg"),
        (dep, "exr", "dep.jpg"),
        (ori_gt, "ori", "ori_gt.jpg"),
    )
    for img, img_type, file_name in predictions:
        dump(img, img_type, file_name)

    if training:
        ground_truth = (
            (uv_gt, "uv", "uv_gt.jpg"),
            (ab_gt, "ab", "ab_gt.jpg"),
            (bg_gt, "bg", "bg_gt.jpg"),
            (cmap_gt, "exr", "cmap_gt.jpg"),
            (nor_gt, "exr", "nor_gt.jpg"),
            (dep_gt, "exr", "dep_gt.jpg"),
        )
        for img, img_type, file_name in ground_truth:
            dump(img, img_type, file_name)
        # Same exr ground truth again, clipped by gt_clip.
        for img, file_name in ((cmap_gt, "cmap_gt2.jpg"),
                               (nor_gt, "nor_gt2.jpg"),
                               (dep_gt, "dep_gt2.jpg")):
            dump(gt_clip(img), "exr", file_name)

    uv_np = reprocess_auto(uv, "uv")
    bg2_np = reprocess_auto(bg2, "bg")
    ori_np = reprocess_auto(ori_gt, "ori")
    bw = uv2backward_trans_3(uv_np, bg2_np)
    dewarp = bw_mapping_single_3(ori_np, bw)

    if training:
        uv_gt_np = reprocess_auto(uv_gt, "uv")
        bg_gt_np = reprocess_auto(bg_gt, "bg")
        bw_gt = uv2backward_trans_3(uv_gt_np, bg_gt_np)
        bw2 = uv2backward_trans_3(uv_np, bg_gt_np)
        dewarp_gt = bw_mapping_single_3(ori_np, bw_gt)
        dewarp2 = bw_mapping_single_3(ori_np, bw2)

    dump_auto(bw, "bw", "bw.jpg")
    dump_auto(dewarp, "ori", "dewarp.jpg")

    if training:
        for img, img_type, file_name in ((bw_gt, "bw", "bw_gt.jpg"),
                                         (bw2, "bw", "bw2.jpg"),
                                         (dewarp_gt, "ori", "dewarp_gt.jpg"),
                                         (dewarp2, "ori", "dewarp2.jpg")):
            dump_auto(img, img_type, file_name)
예제 #3
0
def train(args,
          model,
          device,
          train_loader,
          optimizer,
          criterion,
          epoch,
          writer,
          output_dir,
          isWriteImage,
          isVal=False,
          test_loader=None):
    """Train ``model`` for one epoch on UV-map and backward-map targets.

    Each batch from ``train_loader`` is indexed as
    ``data[0]`` threeD/cmap input, ``data[1]`` UV target, ``data[2]`` backward-map
    target and ``data[3]`` mask. The model is called on the threeD input and must
    return ``(uv_map, bw_map)``. Wherever ``mask_map_gt > 0`` is false, the UV
    prediction is replaced by the mask value. Both heads are trained with
    ``criterion``.

    Args:
        args: config object; reads ``log_intervals``, ``batch_size`` and
            ``write_summary``.
        model, device, train_loader, optimizer, criterion: usual training objects.
        epoch: current epoch, used for logging, TensorBoard steps and output paths.
        writer: TensorBoard writer; only used when ``args.write_summary``.
        output_dir: root directory for sample images.
        isWriteImage: if true, dump two samples from the batch with index
            ``len(dataset) // batch_size - 1``.
        isVal, test_loader: currently unused (the validation hook is disabled).

    Returns:
        ``get_lr(optimizer)`` after the last optimizer step. If the loader is
        empty, the value read before training.
    """
    model.train()
    # Read up front so an empty loader still returns a learning rate.
    lr = get_lr(optimizer)
    batches_per_epoch = len(train_loader.dataset) // args.batch_size
    for batch_idx, data in enumerate(train_loader):
        threeD_map_gt = data[0].to(device)
        uv_map_gt = data[1].to(device)
        bw_map_gt = data[2].to(device)
        mask_map_gt = data[3].to(device)

        optimizer.zero_grad()
        uv_map, bw_map = model(threeD_map_gt)
        # TODO: should the UV prediction be masked like this?
        uv_map = torch.where(mask_map_gt > 0, uv_map, mask_map_gt)
        loss_uv = criterion(uv_map, uv_map_gt).float()
        loss_bw = criterion(bw_map, bw_map_gt).float()
        # Both losses come from one forward pass. Calling backward() on each
        # separately frees the shared graph after the first call, so the second
        # raises. Back-propagating the sum accumulates the same gradients in one
        # pass.
        (loss_uv + loss_bw).backward()
        optimizer.step()
        lr = get_lr(optimizer)

        if batch_idx % args.log_intervals == 0:
            print(
                '\r Epoch:{}  batch index:{}/{}||lr:{:.8f}||loss_uv:{:.6f}||loss_bw:{:.6f}'
                .format(epoch, batch_idx + 1, batches_per_epoch, lr,
                        loss_uv.item(), loss_bw.item()),
                end=" ")
            if args.write_summary:
                step = epoch * len(train_loader) + batch_idx + 1
                writer.add_scalar('summary/train_uv_loss',
                                  loss_uv.item(),
                                  global_step=step)
                writer.add_scalar('summary/backward_loss',
                                  loss_bw.item(),
                                  global_step=step)
                writer.add_scalar('summary/lrate', lr, global_step=step)

        if isWriteImage and batch_idx == batches_per_epoch - 1:
            print('writing image')
            output_dir1 = tdir(output_dir + '/uvbw_train/',
                               'epoch_{}_batch_{}/'.format(epoch, batch_idx))
            prefix = output_dir1 + 'train/epoch_{}/'.format(epoch)
            # Up to two samples; a short batch must not raise IndexError.
            for k in range(min(2, uv_map.size(0))):
                samples = (
                    # predictions
                    (uv_map[k], "uv", 'pred_uv_ind_{}'),
                    (bw_map[k], "bw", 'pred_bw_ind_{}'),
                    # ground truth
                    (threeD_map_gt[k], "cmap", 'gt_3D_ind_{}'),
                    (uv_map_gt[k], "uv", 'gt_uv_ind_{}'),
                    (bw_map_gt[k], "bw", 'gt_bw_ind_{}'),
                    (mask_map_gt[k], "background", 'gt_back_ind_{}'),
                )
                for img, img_type, stem in samples:
                    print_img_with_reprocess(
                        img,
                        img_type=img_type,
                        fname=tfilename(prefix + stem.format(k) + '.jpg'))

    return lr
예제 #4
0
def write_imgs_2(img_tuple, epoch, type_tuple=None, name_tuple=None):
    """Save predicted and ground-truth cmap / uv / albedo maps for one epoch.

    ``img_tuple`` must unpack to exactly
    ``(cmap, uv, ab, ori_gt, cmap_gt, uv_gt, ab_gt)``. Each tensor is written
    with ``print_img_with_reprocess`` to
    ``<output_dir>/imgshow/epoch_<epoch>/<name>.jpg``.
    ``type_tuple`` and ``name_tuple`` are accepted but ignored.
    """
    (cmap, uv, ab,
     ori_gt, cmap_gt, uv_gt, ab_gt) = img_tuple

    sub_dir = "imgshow/epoch_{}".format(epoch)
    jobs = [
        (cmap, "cmap", "cmap.jpg"),
        (uv, "uv", "uv.jpg"),
        (ab, "ab", "ab.jpg"),
        (ori_gt, "ori", "ori_gt.jpg"),
        (cmap_gt, "cmap", "cmap_gt.jpg"),
        (uv_gt, "uv", "uv_gt.jpg"),
        (ab_gt, "ab", "ab_gt.jpg"),
    ]
    for image, image_type, file_name in jobs:
        print_img_with_reprocess(image,
                                 image_type,
                                 fname=tfilename(output_dir, sub_dir, file_name))
예제 #5
0
def main():
    """Train UnwarpNet's UV head, or run ``test`` once when UVBW_TRAIN is off.

    Configuration comes from ``train_configs.args`` and two flags defined
    outside this function:

    * ``UVBW_TRAIN``: train for ``args.epochs`` epochs, or call ``test`` once.
    * ``DEFORM_TEST``: load the hard-coded checkpoint below first.

    Checkpoints are saved to ``<output_dir>/models/uvbw/<model_name>_<epoch>.pkl``
    when ``args.save_model`` is set.
    """
    # -------------------------------- Model build --------------------------------
    model = models3.UnwarpNet()
    # NOTE(review): the confidence discriminator is built but not yet wired
    # into training (see the epoch > 30 branch below).
    model2 = models3.Conf_Discriminator()

    args = train_configs.args
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        print(" [*] Set cuda: True")
        model = torch.nn.DataParallel(model.cuda())
        model2 = torch.nn.DataParallel(model2.cuda())
    else:
        print(" [*] Set cuda: False")

    # -------------------------------- Datasets -----------------------------------
    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}
    dataset_test = filmDataset_old(npy_dir=args.test_path,
                                   load_mod='test_uvbw_mapping')
    dataset_test_loader = DataLoader(dataset_test,
                                     batch_size=1,
                                     shuffle=True,
                                     **kwargs)
    if UVBW_TRAIN:
        dataset_train = filmDataset_old(npy_dir=args.train_path,
                                        load_mod='uvbw')
        dataset_train_loader = DataLoader(dataset_train,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          **kwargs)

    # -------------------------------- Checkpoint ---------------------------------
    start_epoch = 1
    learning_rate = args.lr
    if DEFORM_TEST:
        # previous: /home1/quanquan/film_code/test_output2/20201018-070501z0alvFmodels/uvbw/tv_constrain_35.pkl
        pre_model = "/home1/quanquan/film_code/test_output2/20201021-094607K3qzNUmodels/uvbw/tv_constrain_69.pkl"
        pretrained_dict = torch.load(pre_model, map_location=None)
        model.load_state_dict(pretrained_dict['model_state'])
        # NOTE(review): start_lr is only reported. The optimizer below still
        # starts from args.lr, and the saved optimizer state is not restored.
        start_lr = pretrained_dict['lr']
        start_epoch = pretrained_dict['epoch']
        print("Start_lr: {} ,  Start_epoch {}".format(start_lr, start_epoch))

    # -------------------------------- Optimisation setup -------------------------
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    output_dir = tdir(args.output_dir, generate_name())
    print("Saving Dir: ", output_dir)

    criterion = torch.nn.MSELoss() if args.use_mse else torch.nn.L1Loss()
    if args.write_summary:
        writer_dir = tdir(
            output_dir, 'summary/' + args.model_name +
            '_start_epoch{}'.format(start_epoch))
        print("Using TensorboardX !")
        writer = SummaryWriter(logdir=writer_dir)
    else:
        writer = 0
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    # Read up front so the checkpoint dict never references an unbound name.
    lr = get_lr(optimizer)

    for epoch in range(start_epoch, args.epochs + 1):
        if not UVBW_TRAIN:
            test(args, model, dataset_test_loader, optimizer, criterion, epoch,
                 writer, output_dir, args.write_image_test)
            print("#" * 22)
            break

        model.train()
        batches_per_epoch = len(dataset_train_loader.dataset) // args.batch_size
        for batch_idx, data in enumerate(dataset_train_loader):
            # Batch layout (per the original variable names): 0 ori, 1 albedo,
            # 2 depth, 3 normal, 4 cmap, 5 uv, 6 df, 7 bg. Only the tensors
            # used below are moved to the device.
            ori_map_gt = data[0].to(device)
            cmap_gt = data[4].to(device)
            uv_map_gt = data[5].to(device)
            df_map_gt = data[6].to(device)
            bg_map_gt = data[7].to(device)

            optimizer.zero_grad()
            (uv_pred, cmap_pred, nor_pred, alb_pred, dep_pred, mask_map,
             nor_from_threeD, dep_from_threeD, nor_from_dep, dep_from_nor,
             df_map) = model(ori_map_gt)

            # TODO: should the loss use uv_pred masked by (mask_map > 0.5)?
            # The original computed that masked tensor but never used it.
            loss_uv = criterion(uv_pred, uv_map_gt).float()
            loss_uv.backward()
            optimizer.step()
            lr = get_lr(optimizer)

            # Start using the confidence discriminator.
            if epoch > 30:
                # The original passed an undefined `dewarp_ori_pred` to model2
                # here and died with NameError. Fail explicitly until a dewarped
                # prediction is actually computed.
                raise NotImplementedError(
                    "Confidence discriminator branch (epoch > 30) needs a "
                    "dewarped prediction, which is not computed yet")

            if batch_idx % args.log_intervals == 0:
                # Only the UV loss exists in this loop; the original format
                # string also referenced an undefined loss_bw.
                print(
                    '\r Epoch:{}  batch index:{}/{}||lr:{:.8f}||loss_uv:{:.6f}'
                    .format(epoch, batch_idx + 1, batches_per_epoch, lr,
                            loss_uv.item()),
                    end=" ")
                if args.write_summary:
                    step = epoch * len(dataset_train_loader) + batch_idx + 1
                    writer.add_scalar('summary/train_uv_loss',
                                      loss_uv.item(),
                                      global_step=step)
                    writer.add_scalar('summary/lrate', lr, global_step=step)

            if batch_idx == batches_per_epoch - 1:
                print('writing image')
                output_dir1 = tdir(
                    output_dir + '/uvbw_train/',
                    'epoch_{}_batch_{}/'.format(epoch, batch_idx))
                prefix = output_dir1 + 'train/epoch_{}/'.format(epoch)
                # NOTE(review): the original referenced uv_map, bw_map,
                # mask_map_gt and threeD_map_gt, none of which exist here.
                # They are mapped to this loop's tensors: predicted uv -> uv_pred,
                # predicted bw -> df_map, bw gt -> df_map_gt,
                # background gt -> bg_map_gt, 3D gt -> cmap_gt. Confirm.
                for k in range(min(2, uv_pred.size(0))):
                    samples = (
                        # predictions
                        (uv_pred[k], "uv", 'pred_uv_ind_{}'),
                        (df_map[k], "bw", 'pred_bw_ind_{}'),
                        # ground truth
                        (cmap_gt[k], "cmap", 'gt_3D_ind_{}'),
                        (uv_map_gt[k], "uv", 'gt_uv_ind_{}'),
                        (df_map_gt[k], "bw", 'gt_bw_ind_{}'),
                        (bg_map_gt[k], "background", 'gt_back_ind_{}'),
                    )
                    for img, img_type, stem in samples:
                        print_img_with_reprocess(
                            img,
                            img_type=img_type,
                            fname=tfilename(prefix + stem.format(k) + '.jpg'))

        scheduler.step()
        if args.save_model:
            state = {
                'epoch': epoch + 1,
                'lr': lr,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict()
            }
            torch.save(
                state,
                tfilename(
                    output_dir,
                    "models/uvbw/{}_{}.pkl".format(args.model_name, epoch)))
예제 #6
0
def main():
    """Fine-tune DeformNet from a partial checkpoint and visualise it every epoch.

    Each epoch proceeds as follows:

    * Train on ``filmDataset_3`` ("extra_bg" mode). The predicted deformation
      warps a plain background template and a plain cmap template. The
      objective is L1(warped bg, bg gt) + L1(warped cmap, cmap gt) + tv_loss.
    * Log averaged losses and write sample images.
    * Every 5th epoch, save a checkpoint (if ``args.save_model``).
    * Run the model on up to 27 real images from ``imgshow_test2``.

    Reads names not defined in this function (module-level in the original
    script): ``writer``, ``output_dir``, ``output_dir_eval``, ``max_epoch``,
    ``modelname``, ``tqdm``, ``p``.
    """
    # -----------------------------------  Model Build -------------------------
    model2 = torch.nn.DataParallel(DeformNet().cuda())
    args = train_configs.args
    isTrain = True

    # -----------------------  Load partial pretrained model  ---------------------
    print("Loading Pretrained model~")
    # earlier: .../std120201220-014112-xA836H/cmap_aug_500.pkl
    # earlier: .../extrabg20201223-025124-sJKxHA/model/extrabg_310.pkl
    checkpoint = torch.load("/home1/quanquan/code/Film-Recovery/output/train/iter20201229-084255-1htiye/model/iter_5.pkl", map_location=None)
    start_epoch = min(checkpoint['epoch'], 100)
    model_dict = model2.state_dict()
    # Keep only checkpoint entries whose keys exist in the current model, then
    # overlay them so keys absent from the checkpoint keep their current values.
    model_dict.update({k: v for k, v in checkpoint['model_state'].items()
                       if k in model_dict})
    model2.load_state_dict(model_dict)

    # ------------------------------------  Load Dataset  -------------------------
    kwargs = {'num_workers': 8, 'pin_memory': True}
    dataset_eval = RealDataset("imgshow_test2", load_mod="new_ab", reg_start="pad_gaus_40", reg_end="jpg")
    dataset_eval_loader = DataLoader(dataset_eval, batch_size=1, shuffle=False, **kwargs)
    dataset_train = filmDataset_3("/home1/quanquan/datasets/generate/mesh_film_hypo_alpha2/", load_mod="extra_bg")
    dataset_train_loader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, **kwargs)

    # ------------------------------------  Optimizer  -------------------------
    optimizer = optim.Adam(model2.parameters(), lr=0.01)
    # NOTE(review): scheduler.step() is not called, so the lr stays at 0.01.
    scheduler = StepLR(optimizer, step_size=2, gamma=args.gamma)
    criterion = torch.nn.L1Loss()

    if args.visualize_para:
        for name, parameters in model2.named_parameters():
            print(name, ':', parameters.size())
    global_step = 0

    # -----------------------------------  Training  ---------------------------
    for epoch in range(start_epoch, max_epoch + 1):
        model2.train()
        # Plain floats: accumulating loss tensors would keep every batch's
        # autograd graph alive until the end of the epoch.
        loss_deform_value = loss_bg_t_value = 0.0
        loss_cmap_t_value = loss_tv_value = 0.0
        num_batches = len(dataset_train_loader)
        print("Output dir:", output_dir)
        for batch_idx, data in enumerate(dataset_train_loader):
            # Batch layout (per original names): 0 ori, 1 ab, 2 dep, 3 nor,
            # 4 cmap, 5 uv, 6 bg. Only the tensors used are moved to the GPU.
            ori_gt = data[0].cuda()
            cmap_gt = data[4].cuda()
            bg_gt = data[6].cuda()

            optimizer.zero_grad()

            deform, _ = model2(ori_gt)
            loss_tv = tv_loss(deform, 0.01)
            bg_template, _ = construct_plain_bg(ori_gt.size(0), img_size=256)
            cmap_template, _ = construct_plain_cmap(ori_gt.size(0), img_size=256)
            dewarp_bg_t = iter_mapping(bg_template, deform)
            dewarp_cmap_t = iter_mapping(cmap_template, deform)
            loss_bg_t = criterion(dewarp_bg_t, bg_gt)
            loss_cmap_t = criterion(dewarp_cmap_t, cmap_gt)
            loss_deform = loss_bg_t + loss_cmap_t + loss_tv
            loss_deform.backward()
            loss_deform_value += loss_deform.item()
            loss_bg_t_value += loss_bg_t.item()
            # The original added the accumulator to itself, so it stayed 0.
            loss_cmap_t_value += loss_cmap_t.item()
            loss_tv_value += float(loss_tv)
            optimizer.step()
            # Advance the step so per-batch lr points do not all land on step 0.
            global_step += 1
            lr = get_lr(optimizer)
            writer.add_scalar('summary/lrate_batch', lr, global_step=global_step)
            print("Epoch[\t{}/{}] \t batch:\t{}/{} \t lr:{} \t loss: {}".format(epoch, max_epoch, batch_idx, num_batches, lr, loss_deform_value / (batch_idx + 1)), end=" ")
            print(f"loss_t: {loss_tv_value / (batch_idx + 1)}, ")

        # ------ Epoch summaries -------
        writer.add_scalar('summary/loss_bg_t', loss_bg_t_value / (batch_idx + 1), global_step=epoch)
        writer.add_scalar('summary/loss_cmap_t', loss_cmap_t_value / (batch_idx + 1), global_step=epoch)
        writer.add_scalar('summary/loss_tv', loss_tv_value / (batch_idx + 1), global_step=epoch)
        epoch_dir = "imgshow/epoch_{}".format(epoch)
        print_img_with_reprocess(dewarp_bg_t[0, :, :, :], "bg", fname=tfilename(output_dir, epoch_dir, "bg.jpg"))
        print_img_with_reprocess(dewarp_cmap_t[0, :, :, :], "cmap", fname=tfilename(output_dir, epoch_dir, "cmap.jpg"))
        print_img_with_reprocess(ori_gt[0, :, :, :], "ori", fname=tfilename(output_dir, epoch_dir, "ori_gt.jpg"))
        print_img_with_reprocess(bg_gt[0, :, :, :], "bg", fname=tfilename(output_dir, epoch_dir, "bg_gt.jpg"))
        print_img_with_reprocess(cmap_gt[0, :, :, :], "exr", fname=tfilename(output_dir, epoch_dir, "cmap_gt.jpg"))

        if isTrain and args.save_model and epoch % 5 == 0:
            state = {'epoch': epoch + 1,
                     'lr': lr,
                     'model_state': model2.state_dict(),
                     'optimizer_state': optimizer.state_dict()
                     }
            torch.save(state, tfilename(output_dir, "model", "{}_{}.pkl".format(modelname, epoch)))

        # -----------  Evaluation  -------------
        model2.eval()
        p(output_dir_eval)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            for batch_idx, data in enumerate(tqdm(dataset_eval_loader)):
                ori_gt = data[0].cuda()

                deform, _ = model2(ori_gt)
                bg_template, _ = construct_plain_bg(ori_gt.size(0), img_size=256)
                cmap_template, _ = construct_plain_cmap(ori_gt.size(0), img_size=256)
                dewarp_bg_t = iter_mapping(bg_template, deform)
                dewarp_cmap_t = iter_mapping(cmap_template, deform)

                sample_dir = "imgshow/epoch_{}".format(batch_idx)
                print_img_with_reprocess(dewarp_bg_t[0, :, :, :], "bg", fname=tfilename(output_dir_eval, sample_dir, "bg.jpg"))
                print_img_with_reprocess(dewarp_cmap_t[0, :, :, :], "cmap", fname=tfilename(output_dir_eval, sample_dir, "cmap.jpg"))
                print_img_with_reprocess(ori_gt[0, :, :, :], "ori", fname=tfilename(output_dir_eval, sample_dir, "ori_gt.jpg"))

                if batch_idx > 25:
                    break