def CreateDataset(opt):
    """Instantiate, announce, and initialize the dataset selected by
    ``opt.dataset_mode`` ('temporal', 'face', 'pose', or 'test').

    Each builder defers its import so only the module for the chosen
    dataset is actually loaded.

    Raises:
        ValueError: if ``opt.dataset_mode`` is not a recognized mode.
    """
    def _temporal():
        from data.temporal_dataset import TemporalDataset
        return TemporalDataset()

    def _face():
        from data.face_dataset import FaceDataset
        return FaceDataset()

    def _pose():
        from data.pose_dataset import PoseDataset
        return PoseDataset()

    def _test():
        from data.test_dataset import TestDataset
        return TestDataset()

    builders = {
        'temporal': _temporal,
        'face': _face,
        'pose': _pose,
        'test': _test,
    }
    try:
        dataset = builders[opt.dataset_mode]()
    except KeyError:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
# ---- Beispiel #2 ----
def train_net():
    """Train DeformNet for category-level object pose estimation and
    evaluate it after every epoch.

    Configuration comes from the module-level ``opt`` namespace.  Progress is
    written to a TensorBoard ``FileWriter`` and a text logger under
    ``opt.result_dir``; a checkpoint ``model_XX.pth`` is saved per epoch.
    """
    # set result directory
    if not os.path.exists(opt.result_dir):
        os.makedirs(opt.result_dir)
    tb_writer = tf.summary.FileWriter(opt.result_dir)
    logger = setup_logger('train_log', os.path.join(opt.result_dir, 'log.txt'))
    for key, value in vars(opt).items():
        logger.info(key + ': ' + str(value))
    # NOTE(review): setting CUDA_VISIBLE_DEVICES only takes effect if CUDA has
    # not been initialized earlier in the process — confirm nothing touches
    # the GPU before this point.
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
    # model & loss
    estimator = DeformNet(opt.n_cat, opt.nv_prior)
    estimator.cuda()
    criterion = Loss(opt.corr_wt, opt.cd_wt, opt.entropy_wt, opt.deform_wt)
    if opt.resume_model != '':
        estimator.load_state_dict(torch.load(opt.resume_model))
    # dataset
    train_dataset = PoseDataset(opt.dataset, 'train', opt.data_dir, opt.n_pts,
                                opt.img_size)
    val_dataset = PoseDataset(opt.dataset, 'test', opt.data_dir, opt.n_pts,
                              opt.img_size)
    # start training
    st_time = time.time()
    train_steps = 1500  # fixed number of batches per epoch
    global_step = train_steps * (opt.start_epoch - 1)
    n_decays = len(opt.decay_epoch)
    assert len(opt.decay_rate) == n_decays
    # FIX: decay_count was previously unbound (NameError below) whenever
    # opt.start_epoch did not exceed any decay epoch; default to stage 0.
    decay_count = 0
    for i in range(n_decays):
        if opt.start_epoch > opt.decay_epoch[i]:
            decay_count = i
    # FIX: optimizer/current_lr were created only inside the decay branch and
    # were unbound if the schedule did not fire during the first epoch.
    current_lr = opt.lr
    optimizer = torch.optim.Adam(estimator.parameters(), lr=current_lr)
    train_size = train_steps * opt.batch_size
    indices = []              # rolling pool of shuffled training indices
    page_start = -train_size  # so the first epoch starts at page 0
    for epoch in range(opt.start_epoch, opt.max_epoch + 1):
        # train one epoch
        logger.info('Time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + \
                    ', ' + 'Epoch %02d' % epoch + ', ' + 'Training started'))
        # recreate the optimizer with a decayed learning rate when the
        # schedule crosses the next decay epoch
        if decay_count < len(opt.decay_rate):
            if epoch > opt.decay_epoch[decay_count]:
                current_lr = opt.lr * opt.decay_rate[decay_count]
                optimizer = torch.optim.Adam(estimator.parameters(),
                                             lr=current_lr)
                decay_count += 1
        # sample train subset: consume train_size indices per epoch from the
        # pool, replenishing (reshuffled) whenever it runs low
        page_start += train_size
        len_last = len(indices) - page_start
        if len_last < train_size:
            indices = indices[page_start:]
            if opt.dataset == 'CAMERA+Real':
                # CAMERA : Real = 3 : 1 sampling ratio
                camera_len = train_dataset.subset_len[0]
                real_len = train_dataset.subset_len[1]
                real_indices = list(range(camera_len, camera_len + real_len))
                camera_indices = list(range(camera_len))
                n_repeat = (train_size - len_last) // (4 * real_len) + 1
                data_list = random.sample(camera_indices, 3 * n_repeat *
                                          real_len) + real_indices * n_repeat
                random.shuffle(data_list)
                indices += data_list
            else:
                data_list = list(range(train_dataset.length))
                for i in range((train_size - len_last) //
                               train_dataset.length + 1):
                    random.shuffle(data_list)
                    indices += data_list
            page_start = 0
        train_idx = indices[page_start:(page_start + train_size)]
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=opt.batch_size,
            sampler=train_sampler,
            num_workers=opt.num_workers,
            pin_memory=True)
        estimator.train()
        for i, data in enumerate(train_dataloader, 1):
            points, rgb, choose, cat_id, model, prior, sRT, nocs = data
            points = points.cuda()
            rgb = rgb.cuda()
            choose = choose.cuda()
            cat_id = cat_id.cuda()
            model = model.cuda()
            prior = prior.cuda()
            sRT = sRT.cuda()
            nocs = nocs.cuda()
            assign_mat, deltas = estimator(points, rgb, choose, cat_id, prior)
            loss, corr_loss, cd_loss, entropy_loss, deform_loss = criterion(
                assign_mat, deltas, prior, nocs, model)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            # write results to tensorboard
            # FIX: simple_value expects a plain float — pass .item() (as the
            # logger call below already does), not the CUDA loss tensors.
            summary = tf.Summary(value=[
                tf.Summary.Value(tag='learning_rate', simple_value=current_lr),
                tf.Summary.Value(tag='train_loss', simple_value=loss.item()),
                tf.Summary.Value(tag='corr_loss',
                                 simple_value=corr_loss.item()),
                tf.Summary.Value(tag='cd_loss', simple_value=cd_loss.item()),
                tf.Summary.Value(tag='entropy_loss',
                                 simple_value=entropy_loss.item()),
                tf.Summary.Value(tag='deform_loss',
                                 simple_value=deform_loss.item())
            ])
            tb_writer.add_summary(summary, global_step)
            if i % 10 == 0:
                logger.info(
                    'Batch {0} Loss:{1:f}, corr_loss:{2:f}, cd_loss:{3:f}, entropy_loss:{4:f}, deform_loss:{5:f}'
                    .format(i, loss.item(), corr_loss.item(), cd_loss.item(),
                            entropy_loss.item(), deform_loss.item()))

        logger.info(
            '>>>>>>>>----------Epoch {:02d} train finish---------<<<<<<<<'.
            format(epoch))

        # evaluate one epoch
        logger.info('Time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Epoch %02d' % epoch + ', ' + 'Testing started'))
        val_loss = 0.0
        total_count = np.zeros((opt.n_cat, ), dtype=int)
        strict_success = np.zeros((opt.n_cat, ),
                                  dtype=int)  # 5 degree and 5 cm
        easy_success = np.zeros((opt.n_cat, ), dtype=int)  # 10 degree and 5 cm
        iou_success = np.zeros((opt.n_cat, ),
                               dtype=int)  # relative scale error < 0.1
        # sample validation subset
        val_size = 1500
        val_idx = random.sample(list(range(val_dataset.length)), val_size)
        val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_idx)
        val_dataloader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=1,
            sampler=val_sampler,
            num_workers=opt.num_workers,
            pin_memory=True)
        estimator.eval()
        for i, data in enumerate(val_dataloader, 1):
            points, rgb, choose, cat_id, model, prior, sRT, nocs = data
            points = points.cuda()
            rgb = rgb.cuda()
            choose = choose.cuda()
            cat_id = cat_id.cuda()
            model = model.cuda()
            prior = prior.cuda()
            sRT = sRT.cuda()
            nocs = nocs.cuda()
            # FIX: no gradients are needed at validation time; no_grad avoids
            # building the autograd graph and cuts GPU memory use.
            with torch.no_grad():
                assign_mat, deltas = estimator(points, rgb, choose, cat_id,
                                               prior)
                loss, _, _, _, _ = criterion(assign_mat, deltas, prior, nocs,
                                             model)
            # estimate pose and scale
            inst_shape = prior + deltas
            assign_mat = F.softmax(assign_mat, dim=2)
            nocs_coords = torch.bmm(assign_mat, inst_shape)
            nocs_coords = nocs_coords.detach().cpu().numpy()[0]
            points = points.cpu().numpy()[0]
            # use choose to remove repeated points
            choose = choose.cpu().numpy()[0]
            _, choose = np.unique(choose, return_index=True)
            nocs_coords = nocs_coords[choose, :]
            points = points[choose, :]
            _, _, _, pred_sRT = estimateSimilarityTransform(
                nocs_coords, points)
            # evaluate pose
            cat_id = cat_id.item()
            if pred_sRT is not None:
                sRT = sRT.detach().cpu().numpy()[0]
                R_error, T_error, IoU = compute_sRT_errors(pred_sRT, sRT)
                if R_error < 5 and T_error < 0.05:
                    strict_success[cat_id] += 1
                if R_error < 10 and T_error < 0.05:
                    easy_success[cat_id] += 1
                if IoU < 0.1:
                    iou_success[cat_id] += 1
            total_count[cat_id] += 1
            val_loss += loss.item()
            if i % 100 == 0:
                logger.info('Batch {0} Loss:{1:f}'.format(i, loss.item()))
        # compute accuracy
        # NOTE(review): a category with zero sampled instances divides by zero
        # here (NaN accuracy, numpy warning) — tolerated for logging only.
        strict_acc = 100 * (strict_success / total_count)
        easy_acc = 100 * (easy_success / total_count)
        iou_acc = 100 * (iou_success / total_count)
        for i in range(opt.n_cat):
            logger.info('{} accuracies:'.format(val_dataset.cat_names[i]))
            logger.info('5^o 5cm: {:4f}'.format(strict_acc[i]))
            logger.info('10^o 5cm: {:4f}'.format(easy_acc[i]))
            logger.info('IoU < 0.1: {:4f}'.format(iou_acc[i]))
        strict_acc = np.mean(strict_acc)
        easy_acc = np.mean(easy_acc)
        iou_acc = np.mean(iou_acc)
        val_loss = val_loss / val_size
        summary = tf.Summary(value=[
            tf.Summary.Value(tag='val_loss', simple_value=val_loss),
            tf.Summary.Value(tag='5^o5cm_acc', simple_value=strict_acc),
            tf.Summary.Value(tag='10^o5cm_acc', simple_value=easy_acc),
            tf.Summary.Value(tag='iou_acc', simple_value=iou_acc)
        ])
        tb_writer.add_summary(summary, global_step)
        logger.info('Epoch {0:02d} test average loss: {1:06f}'.format(
            epoch, val_loss))
        logger.info('Overall accuracies:')
        logger.info('5^o 5cm: {:4f} 10^o 5cm: {:4f} IoU: {:4f}'.format(
            strict_acc, easy_acc, iou_acc))
        logger.info(
            '>>>>>>>>----------Epoch {:02d} test finish---------<<<<<<<<'.
            format(epoch))
        # save model after each epoch
        torch.save(estimator.state_dict(),
                   '{0}/model_{1:02d}.pth'.format(opt.result_dir, epoch))
# ---- Beispiel #3 ----
def CreateDataset(opt):
    """Instantiate, announce, and initialize the dataset selected by
    ``opt.dataset_mode``.

    Supported modes: 'temporal', 'face', 'pose', 'test', the Stanford drone
    variants ('stanford', 'stanford_custom', 'stanford_test',
    'stanford_synthetic'), and the KAIST variants ('kaist', 'kaist_test',
    'kaist_test_single').  Imports stay inside the branches/helpers so only
    the module for the chosen dataset is loaded.

    Raises:
        ValueError: if ``opt.dataset_mode`` is not recognized.
    """
    _STANFORD_SIZE = (1400, 1904)

    def _region(scene, video):
        # Region descriptor consumed by StanfordDatasetTemporal.
        return dict(scene=scene, video=video, size=_STANFORD_SIZE,
                    sequences=None, anchors=None)

    def _stanford_inference(regions):
        # Shared construction used by all inference-time Stanford variants.
        from patagona_common.data.datasets import StanfordDatasetTemporal
        return StanfordDatasetTemporal(rootdir=opt.dataroot,
                                       regions=regions,
                                       crop_mode='center',
                                       output_dim=(opt.loadSize,
                                                   opt.loadSize),
                                       n_sequential_frames=opt.n_frames_G,
                                       inference=True)

    mode = opt.dataset_mode
    if mode == 'temporal':
        from data.temporal_dataset import TemporalDataset
        dataset = TemporalDataset()
    elif mode == 'face':
        from data.face_dataset import FaceDataset
        dataset = FaceDataset()
    elif mode == 'pose':
        from data.pose_dataset import PoseDataset
        dataset = PoseDataset()
    elif mode == 'test':
        from data.test_dataset import TestDataset
        dataset = TestDataset()
    elif mode == 'stanford':
        from patagona_common.data.datasets import StanfordDatasetTemporal
        dataset = StanfordDatasetTemporal(rootdir='/home/ubuntu/datasets')
    elif mode == 'stanford_custom':
        dataset = _stanford_inference([_region('deathCircle', 0)])
    elif mode == 'stanford_test':
        dataset = _stanford_inference(
            [_region(opt.test_video_scene, opt.test_video_id)])
    elif mode == 'stanford_synthetic':
        dataset = _stanford_inference(
            [_region('syntheticSequence', video) for video in (1, 2, 3, 4)])
    elif mode == 'kaist':
        from patagona_common.data.datasets import KAISTTemporalDataset
        dataset = KAISTTemporalDataset(
            root_dir='/home/ubuntu/datasets',
            video_sets=['set00', 'set01', 'set02', 'set03', 'set04', 'set05'])
    elif mode == 'kaist_test':
        from patagona_common.data.datasets import KAISTTemporalDataset
        dataset = KAISTTemporalDataset(
            root_dir='/home/ubuntu/datasets',
            video_sets=['set06', 'set07', 'set08', 'set09', 'set10', 'set11'],
            random_crop=False,
            n_seq_frames=opt.n_frames_G,
            output_dim=(640, 512),
            start_frame=opt.start_frame)
    elif mode == 'kaist_test_single':
        from patagona_common.data.datasets import KAISTTemporalDataset
        dataset = KAISTTemporalDataset(
            root_dir='/home/ubuntu/datasets',
            video_sets=['set06', 'set07', 'set08', 'set09', 'set10', 'set11'],
            random_crop=False,
            n_seq_frames=opt.n_frames_G,
            output_dim=(640, 512),
            all_first=True,
            start_frame=opt.start_frame)
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
def train_net():
    """Evaluate a (resumed) DeformNet checkpoint on the test split.

    This is the evaluation-only variant of the training routine: the original
    per-epoch training loop was fully commented out and has been removed.
    Each epoch runs validation, accumulating pose-accuracy counters and
    per-category Chamfer distances of the shape prior and the deformed shape.
    Configuration comes from the module-level ``opt`` namespace.
    """
    # set result directory
    if not os.path.exists(opt.result_dir):
        os.makedirs(opt.result_dir)
    # FileWriter is created for parity with the training variant (it creates
    # the event file on disk) but no summaries are written below.
    tb_writer = tf.summary.FileWriter(opt.result_dir)
    logger = setup_logger('train_log', os.path.join(opt.result_dir, 'log.txt'))
    for key, value in vars(opt).items():
        logger.info(key + ': ' + str(value))
    # NOTE(review): setting CUDA_VISIBLE_DEVICES only takes effect if CUDA has
    # not been initialized earlier in the process — confirm call order.
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
    # model & loss
    estimator = DeformNet(opt.n_cat, opt.nv_prior)
    estimator.cuda()
    criterion = Loss(opt.corr_wt, opt.cd_wt, opt.entropy_wt, opt.deform_wt)
    chamferD = ChamferLoss()
    if opt.resume_model != '':
        estimator.load_state_dict(torch.load(opt.resume_model))
    # dataset (the train split is kept for parity with the training variant,
    # e.g. CAMERA+Real: 253445 images / 1101 models; test: 2754 images)
    train_dataset = PoseDataset(opt.dataset, 'train', opt.data_dir, opt.n_pts,
                                opt.img_size)
    val_dataset = PoseDataset(opt.dataset, 'test', opt.data_dir, opt.n_pts,
                              opt.img_size)
    # start evaluation
    st_time = time.time()
    n_decays = len(opt.decay_epoch)
    assert len(opt.decay_rate) == n_decays  # config sanity check
    for epoch in range(opt.start_epoch, opt.max_epoch + 1):
        # evaluate one epoch; the weights are fixed, so each epoch repeats the
        # same evaluation (up to dataloader worker nondeterminism)
        logger.info('Time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Epoch %02d' % epoch + ', ' + 'Testing started'))
        val_loss = 0.0
        total_count = np.zeros((opt.n_cat, ), dtype=int)
        strict_success = np.zeros((opt.n_cat, ),
                                  dtype=int)  # 5 degree and 5 cm
        easy_success = np.zeros((opt.n_cat, ), dtype=int)  # 10 degree and 5 cm
        iou_success = np.zeros((opt.n_cat, ),
                               dtype=int)  # relative scale error < 0.1
        # full test split, no subsampling
        val_dataloader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=1,
            num_workers=opt.num_workers,
            pin_memory=True)
        estimator.eval()
        # FIX: these accumulators were hard-coded to 6 categories; size them
        # from opt.n_cat like every other per-category buffer here.
        cd_num = torch.zeros(opt.n_cat)
        prior_cd = torch.zeros(opt.n_cat)
        deform_cd = torch.zeros(opt.n_cat)
        n_batches = 0
        for i, data in enumerate(val_dataloader, 1):
            points, rgb, choose, cat_id, model, prior, sRT, nocs = data
            points = points.cuda()
            rgb = rgb.cuda()
            choose = choose.cuda()
            cat_id = cat_id.cuda()
            model = model.cuda()
            prior = prior.cuda()
            sRT = sRT.cuda()
            nocs = nocs.cuda()
            # FIX: no gradients are needed at evaluation time; no_grad avoids
            # building the autograd graph and cuts GPU memory use.
            with torch.no_grad():
                assign_mat, deltas = estimator(points, rgb, choose, cat_id,
                                               prior)
                loss, _, _, _, _ = criterion(assign_mat, deltas, prior, nocs,
                                             model)
                # Chamfer distance of the raw prior vs. the deformed prior
                # against the ground-truth model
                prior_loss, _, _ = chamferD(prior, model)
                deform_loss, _, _ = chamferD(prior + deltas, model)

            idx = cat_id.item()
            cd_num[idx] += 1
            prior_cd[idx] += prior_loss.item()
            deform_cd[idx] += deform_loss.item()

            # estimate pose and scale
            inst_shape = prior + deltas
            assign_mat = F.softmax(assign_mat, dim=2)
            nocs_coords = torch.bmm(assign_mat, inst_shape)
            nocs_coords = nocs_coords.detach().cpu().numpy()[0]
            points = points.cpu().numpy()[0]
            # use choose to remove repeated points
            choose = choose.cpu().numpy()[0]
            _, choose = np.unique(choose, return_index=True)
            nocs_coords = nocs_coords[choose, :]
            points = points[choose, :]
            _, _, _, pred_sRT = estimateSimilarityTransform(
                nocs_coords, points)
            # evaluate pose; counters are accumulated but (as in the original)
            # only the Chamfer metrics are reported below
            cat_id = cat_id.item()
            if pred_sRT is not None:
                sRT = sRT.detach().cpu().numpy()[0]
                R_error, T_error, IoU = compute_sRT_errors(pred_sRT, sRT)
                if R_error < 5 and T_error < 0.05:
                    strict_success[cat_id] += 1
                if R_error < 10 and T_error < 0.05:
                    easy_success[cat_id] += 1
                if IoU < 0.1:
                    iou_success[cat_id] += 1
            total_count[cat_id] += 1
            val_loss += loss.item()
            n_batches = i
            if i % 100 == 0:
                logger.info('Batch {0} Loss:{1:f}'.format(i, loss.item()))
        # per-category mean Chamfer distance (x1000), then the overall mean;
        # a category never seen in the split divides by zero and prints NaN
        deform_cd_metric = (deform_cd / cd_num) * 1000
        print("recon: " +
              " : ".join("{:.2f}".format(v) for v in deform_cd_metric) +
              " : {:.2f}".format(torch.mean(deform_cd_metric)))
        prior_cd_metric = (prior_cd / cd_num) * 1000
        print("prior: " +
              " : ".join("{:.2f}".format(v) for v in prior_cd_metric) +
              " : {:.2f}".format(torch.mean(prior_cd_metric)))
        # FIX: val_loss was accumulated but never reported; log the average
        # for consistency with the training variant of this routine.
        if n_batches:
            logger.info('Epoch {0:02d} test average loss: {1:06f}'.format(
                epoch, val_loss / n_batches))
# ---- Beispiel #5 ----
# Build the geometric-matching network; the architecture and geometric model
# come from the CLI args, and args.fw supplies feature-extractor weights.
model = CNNGeometric(use_cuda=use_cuda,geometric_model=args.geometric_model,arch=args.arch,featext_weights=args.fw)

# Loss selection: plain MSE on the network output, or TransformedGridLoss
# (presumably a loss on grids warped by the predicted transform — confirm
# against its definition).
if args.use_mse_loss:
    print('Using MSE loss...')
    loss = nn.MSELoss()
else:
    print('Using grid loss...')
    loss = TransformedGridLoss(use_cuda=use_cuda,geometric_model=args.geometric_model)


# 'pose' uses PoseDataset with paired source/target images (both normalized);
# any other geometric model falls through to SynthDataset below.
if args.geometric_model == 'pose':
    dataset = PoseDataset(
        geometric_model=args.geometric_model,
        csv_file=os.path.join(args.training_tnf_csv, 'train.csv'),
        training_image_path=args.training_image_path,
        output_size=(240, 240),
        transform=NormalizeImageDict(['source_image','target_image'])
    )
    # Held-out split read from test.csv in the same CSV directory.
    dataset_test = PoseDataset(geometric_model=args.geometric_model,
                            csv_file=os.path.join(args.training_tnf_csv,'test.csv'),
                            training_image_path=args.training_image_path,
                            output_size=(240, 240),
                            transform=NormalizeImageDict(['source_image','target_image']))
# Dataset and dataloader
else:
    dataset = SynthDataset(geometric_model=args.geometric_model,
                       csv_file=os.path.join(args.training_tnf_csv,'train.csv'),
                       training_image_path=args.training_image_path,
                       transform=NormalizeImageDict(['image']))