Beispiel #1
0
def TSS(root,
        source_image_transform=None,
        target_image_transform=None,
        flow_transform=None,
        co_transform=None,
        split=None):
    """Create the TSS train and test datasets rooted at ``root``.

    ``make_dataset(root, split)`` supplies the train and test sample lists.
    Each list is wrapped in a ``ListDataset`` that reads samples with
    ``TSS_flow_loader`` and is built with ``mask=True`` and ``size=True``.
    Both splits receive exactly the same transforms.

    Args:
        root: Dataset root directory, forwarded to ``make_dataset`` and
            ``ListDataset``.
        source_image_transform: Transform for the source image (both splits).
        target_image_transform: Transform for the target image (both splits).
        flow_transform: Transform for the flow field (both splits).
        co_transform: Joint transform (both splits).
        split: Split specification forwarded to ``make_dataset``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    train_list, test_list = make_dataset(root, split)

    # Everything except the sample list is identical for the two splits.
    common = dict(
        source_image_transform=source_image_transform,
        target_image_transform=target_image_transform,
        flow_transform=flow_transform,
        co_transform=co_transform,
        loader=TSS_flow_loader,
        mask=True,
        size=True,
    )

    train_dataset = ListDataset(root, train_list, **common)
    test_dataset = ListDataset(root, test_list, **common)
    return train_dataset, test_dataset
def KITTI_noc(root,
              source_image_transform=None,
              target_image_transform=None,
              flow_transform=None,
              co_transform=None,
              split=None):
    """Create the KITTI train and test datasets rooted at ``root``.

    The sample lists come from ``make_dataset(root, split, False)``.
    NOTE(review): the third positional ``False`` presumably turns off
    occluded flow (cf. ``occ=True`` in ``kitti_occ_both``); confirm against
    ``make_dataset``'s signature.

    Each list is wrapped in a ``ListDataset`` that reads samples with
    ``KITTI_flow_loader`` and is built with ``mask=True``. Both splits
    receive exactly the same transforms.

    Args:
        root: Dataset root directory.
        source_image_transform: Transform for the source image (both splits).
        target_image_transform: Transform for the target image (both splits).
        flow_transform: Transform for the flow field (both splits).
        co_transform: Joint transform (both splits).
        split: Split specification forwarded to ``make_dataset``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    train_list, test_list = make_dataset(root, split, False)

    # Constructor arguments shared by the train and test datasets.
    common = dict(
        source_image_transform=source_image_transform,
        target_image_transform=target_image_transform,
        flow_transform=flow_transform,
        co_transform=co_transform,
        loader=KITTI_flow_loader,
        mask=True,
    )

    train_dataset = ListDataset(root, train_list, **common)
    test_dataset = ListDataset(root, test_list, **common)
    return train_dataset, test_dataset
def kitti_occ_both(root,
                   source_image_transform=None,
                   target_image_transform=None,
                   flow_transform=None,
                   co_transform=None,
                   test_image_transform=None,
                   split=None):
    """Create combined KITTI 2012 + 2015 train and test datasets.

    ``make_dataset`` is called with ``occ=True`` on
    ``<root>/KITTI_2012/training/`` and then on
    ``<root>/KITTI_2015/training/``. The 2012 and 2015 sample lists are
    concatenated, in that order, for each split. Both datasets use
    ``KITTI_flow_loader`` with ``mask=True``.

    The training split uses the given transforms unchanged. The test split
    always uses a ``CenterCrop((368, 1224))`` joint transform and ignores
    ``co_transform``. Its image transforms are ``test_image_transform`` for
    both source and target when one is given; otherwise they are the
    training ``source_image_transform`` / ``target_image_transform``.

    Args:
        root: Directory that contains the ``KITTI_2012`` and ``KITTI_2015``
            folders.
        source_image_transform: Source-image transform (training split; also
            the test split when ``test_image_transform`` is None).
        target_image_transform: Target-image transform (training split; also
            the test split when ``test_image_transform`` is None).
        flow_transform: Flow transform (both splits).
        co_transform: Joint transform (training split only).
        test_image_transform: Optional image transform that replaces both
            image transforms on the test split.
        split: Split specification forwarded to ``make_dataset``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    # One make_dataset call per KITTI release, 2012 first, then 2015.
    (train_2012, test_2012), (train_2015, test_2015) = [
        make_dataset(os.path.join(root, subset),
                     dataset_name=subset,
                     split=split,
                     occ=True)
        for subset in ('KITTI_2012/training/', 'KITTI_2015/training/')
    ]

    train_dataset = ListDataset(root,
                                train_2012 + train_2015,
                                source_image_transform=source_image_transform,
                                target_image_transform=target_image_transform,
                                flow_transform=flow_transform,
                                co_transform=co_transform,
                                loader=KITTI_flow_loader,
                                mask=True)

    # A test-specific image transform, when provided, overrides both image
    # transforms; otherwise the test split reuses the training ones.
    if test_image_transform is not None:
        eval_source_transform = eval_target_transform = test_image_transform
    else:
        eval_source_transform = source_image_transform
        eval_target_transform = target_image_transform

    test_dataset = ListDataset(
        root,
        test_2012 + test_2015,
        source_image_transform=eval_source_transform,
        target_image_transform=eval_target_transform,
        flow_transform=flow_transform,
        co_transform=co_flow_and_images_transforms.CenterCrop((368, 1224)),
        loader=KITTI_flow_loader,
        mask=True)
    return train_dataset, test_dataset
def fusion_data(root, transform=None, target_transform=None, co_transform=None, split=None):
    """Create train and test ``ListDataset`` objects for the samples under ``root``.

    ``make_dataset(root, split)`` supplies the two sample lists.
    ``transform`` and ``target_transform`` are passed positionally to both
    splits. ``co_transform`` is passed only to the training split; the test
    split receives ``None`` in that position.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    # NOTE(review): unlike the other builders in this file, this one passes
    # transforms positionally. If ListDataset's parameter order is
    # (source_image_transform, target_image_transform, flow_transform,
    # co_transform, ...), then ``co_transform`` lands in the flow_transform
    # slot. Confirm against ListDataset's actual signature.
    train_list, test_list = make_dataset(root, split)
    per_split = ((train_list, co_transform), (test_list, None))
    train_dataset, test_dataset = [
        ListDataset(root, samples, transform, target_transform, joint_transform)
        for samples, joint_transform in per_split
    ]
    return train_dataset, test_dataset