コード例 #1
0
ファイル: main.py プロジェクト: JennyGao00/FCRN
def create_loader(args):
    """Build the training and validation DataLoaders for ``args.dataset``.

    The dataset root comes from ``Path.db_root_dir(args.dataset)`` and must
    contain ``train`` and ``val`` sub-directories. If either is missing, or
    the dataset name is not ``'kitti'``/``'nyu'``, a message is printed and
    the process exits with status -1.

    Args:
        args: parsed options; reads ``dataset``, ``batch_size`` and ``workers``.

    Returns:
        ``(train_loader, val_loader)``. The train loader is shuffled with
        ``args.batch_size``. For KITTI the validation loader draws 3200
        indices from the val set through a uniform ``WeightedRandomSampler``
        (with replacement, the sampler's default); for NYU it walks the whole
        val set in order with batch size 1.
    """
    import sys

    root_dir = Path.db_root_dir(args.dataset)

    traindir = os.path.join(root_dir, 'train')
    if os.path.exists(traindir):
        print('Train dataset "{}" is existed!'.format(traindir))
    else:
        print('Train dataset "{}" is not existed!'.format(traindir))
        sys.exit(-1)

    # Check valdir itself so a missing val split is reported here.
    valdir = os.path.join(root_dir, 'val')
    if os.path.exists(valdir):
        print('Val dataset "{}" is existed!'.format(valdir))
    else:
        print('Val dataset "{}" is not existed!'.format(valdir))
        sys.exit(-1)

    if args.dataset == 'kitti':
        train_set = kitti_dataloader.KITTIDataset(traindir, type='train')
        val_set = kitti_dataloader.KITTIDataset(valdir, type='val')

        # Sample 3200 pictures for validation; all weights are equal, so every
        # val image is equally likely to be drawn.
        weights = [1] * len(val_set)
        print('weights:', len(weights))
        sampler = torch.utils.data.WeightedRandomSampler(weights,
                                                         num_samples=3200)
    elif args.dataset == 'nyu':
        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')
    else:
        print('no dataset named as ', args.dataset)
        sys.exit(-1)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    if args.dataset == 'kitti':
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=args.batch_size,
                                                 sampler=sampler,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    return train_loader, val_loader
コード例 #2
0
    def __init__(self,
                 base_dir=Path.db_root_dir('vocaug'),
                 split='train',
                 transform=None):
        """Index the image list for one split of the dataset under ``base_dir``.

        Args:
            base_dir: dataset root expected to contain ``img/``, ``gt/`` and
                ``list/`` sub-directories. NOTE: the default is computed once,
                when the class body is executed, not on every call.
            split: ``'train'`` (reads ``list/train_aug.txt``) or ``'val'``
                (reads ``list/val.txt``). Any other value prints an error and
                exits the process with status -1.
            transform: optional object stored on ``self.transform``.
        """
        import sys

        super().__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'img')
        self._cat_dir = os.path.join(self._base_dir, 'gt')
        self._list_dir = os.path.join(self._base_dir, 'list')

        self.transform = transform

        if split == 'train':
            list_path = os.path.join(self._list_dir, 'train_aug.txt')
        elif split == 'val':
            list_path = os.path.join(self._list_dir, 'val.txt')
        else:
            print('error in split:', split)
            sys.exit(-1)

        # One entry per line of the list file, surrounding whitespace removed.
        # The `with` block closes the handle instead of leaking it.
        with open(list_path) as list_file:
            self.filenames = [line.strip() for line in list_file]

        # Display stats
        print('Number of images in {}: {:d}'.format(split,
                                                    len(self.filenames)))
コード例 #3
0
ファイル: main.py プロジェクト: mmlab-cv/ICIP-2021-2346
def get_if_exists(type, dataset=None):
    """Return ``<dataset root>/<type>`` if it exists, otherwise exit the process.

    Args:
        type: sub-directory name such as ``'train'`` or ``'val'``. (The name
            shadows the builtin ``type``; it is kept for caller compatibility.)
        dataset: dataset key passed to ``Path.db_root_dir``. When omitted, the
            module-level ``args.dataset`` is used, so a global ``args`` must
            exist in this module in that case.

    Returns:
        The existing directory path. If it is missing, a message is printed
        and the process exits with status -1.
    """
    import sys

    if dataset is None:
        dataset = args.dataset
    dirpath = os.path.join(Path.db_root_dir(dataset), type)

    if not os.path.exists(dirpath):
        print(f"{type} dataset does not exist")
        sys.exit(-1)

    print(f"{type} dataset exists")
    return dirpath
コード例 #4
0
ファイル: __init__.py プロジェクト: yanyan-li/CSPN_monodepth
def create_loader(args, mode='train'):
    """Create one NYU DataLoader for the split named by ``mode``.

    ``mode`` is matched case-insensitively:

    * ``'train'``: shuffled loader over ``<root>/train`` with
      ``args.batch_size``. The last incomplete batch is dropped when
      ``torch.cuda.device_count() > 1``.
    * ``'val'``: ordered loader over ``<root>/val`` with batch size 1.

    Any other mode raises ``NotImplementedError``. A missing split directory
    prints a message and exits the process.
    """
    import os
    from dataloaders.path import Path
    root_dir = Path.db_root_dir(args.dataset)

    split = mode.lower()
    if split == 'train':
        split_dir = os.path.join(root_dir, 'train')
        if not os.path.exists(split_dir):
            print('Train dataset "{}" is not existed!'.format(split_dir))
            exit(-1)
        print('Train dataset "{}" is existed!'.format(split_dir))
    elif split == 'val':
        split_dir = os.path.join(root_dir, 'val')
        if not os.path.exists(split_dir):
            print('Val dataset "{}" is not existed!'.format(split_dir))
            exit(-1)
        print('Val dataset "{}" is existed!'.format(split_dir))
    else:
        raise NotImplementedError

    from dataloaders.nyu_dataloader import nyu_dataloader
    dataset = nyu_dataloader.NYUDataset(split_dir, type=split)

    import torch
    if split == 'train':
        loader_options = dict(
            batch_size=args.batch_size,
            shuffle=True,
            # Drop the ragged final batch only when several GPUs are visible.
            drop_last=torch.cuda.device_count() > 1,
        )
    else:
        loader_options = dict(batch_size=1, shuffle=False)

    return torch.utils.data.DataLoader(dataset,
                                       num_workers=args.workers,
                                       pin_memory=True,
                                       **loader_options)
コード例 #5
0
def create_loader(dataset='kitti'):
    """Build fixed-configuration train/eval DataLoaders for ``dataset``.

    * ``'kitti'``: ``KittiFolder`` train and test splits at size (385, 513),
      batch size 32. Returns ``(train_loader, test_loader)``.
    * anything else: ``NYUDataset`` on ``<root>/train`` and ``<root>/val``,
      batch size 4. Returns ``(train_loader, val_loader)``. If either
      directory is missing, a message is printed and the process exits with
      status -1.

    All loaders are unshuffled, use ``num_workers=0`` and ``pin_memory=True``.
    """
    import sys

    root_dir = Path.db_root_dir(dataset)
    if dataset == 'kitti':
        train_set = KittiFolder(root_dir, mode='train', size=(385, 513))
        test_set = KittiFolder(root_dir, mode='test', size=(385, 513))
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=0,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=32,
                                                  shuffle=False,
                                                  num_workers=0,
                                                  pin_memory=True)
        return train_loader, test_loader

    traindir = os.path.join(root_dir, 'train')
    if os.path.exists(traindir):
        print('Train dataset "{}" is existed!'.format(traindir))
    else:
        print('Train dataset "{}" is not existed!'.format(traindir))
        sys.exit(-1)

    # Check valdir itself so a missing val split is reported here.
    valdir = os.path.join(root_dir, 'val')
    if os.path.exists(valdir):
        print('Val dataset "{}" is existed!'.format(valdir))
    else:
        print('Val dataset "{}" is not existed!'.format(valdir))
        sys.exit(-1)

    train_set = nyu_dataloader.NYUDataset(traindir, type='train')
    val_set = nyu_dataloader.NYUDataset(valdir, type='val')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=4,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=4,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=True)

    return train_loader, val_loader
コード例 #6
0
def create_loader(args):
    """Build ``(train_loader, val_loader)`` for ``args.dataset``.

    Supported datasets:

    * ``'kitti'``: ``KITTIDataset`` rooted at ``Path.db_root_dir('kitti')``;
      the ``'test'`` type serves as the validation set.
    * ``'nyu'``: ``NYUDataset`` on ``<root>/train`` and ``<root>/val``.
    * ``'saved_images'``: ``FolderDataset`` on ``args.save_image_dir`` for
      both loaders.

    Any other name, or a missing directory, prints a message and exits the
    process. The train loader is shuffled with ``args.batch_size``. The
    validation loader is ordered and uses batch size 1, except for KITTI with
    ``args.save_image_dir`` set to None, where it uses ``args.batch_size``.
    """
    def require_dir(path, found_msg, missing_msg):
        # Exit when `path` is absent; print `found_msg` (if any) when present.
        if not os.path.exists(path):
            print(missing_msg.format(path))
            exit(-1)
        if found_msg is not None:
            print(found_msg.format(path))

    name = args.dataset
    if name == 'kitti':
        root = Path.db_root_dir(name)
        require_dir(root,
                    'kitti dataset "{}" exists!',
                    'kitti dataset "{}" doesnt existed!')
        train_set = kitti_dataloader.KITTIDataset(
            root, type='train', model=args.model)
        val_set = kitti_dataloader.KITTIDataset(
            root, type='test', model=args.model)
    elif name == 'nyu':
        root = Path.db_root_dir(name)
        traindir = os.path.join(root, 'train')
        require_dir(traindir,
                    'Train dataset "{}" exits!',
                    'Train dataset "{}" doesnt existed!')
        valdir = os.path.join(root, 'val')
        require_dir(valdir,
                    'Val dataset "{}" exists!',
                    'Val dataset "{}" doesnt existed!')
        train_set = nyu_dataloader.NYUDataset(
            traindir, type='train', model=args.model)
        val_set = nyu_dataloader.NYUDataset(
            valdir, type='val', model=args.model)
    elif name == 'saved_images':
        image_dir = args.save_image_dir
        require_dir(image_dir, None, 'Val dataset "{}" doesnt existed!')
        train_set = folder_loader.FolderDataset(image_dir, model=args.model)
        val_set = folder_loader.FolderDataset(image_dir, model=args.model)
    else:
        print('no dataset named as ', name)
        exit(-1)

    common = dict(num_workers=args.workers, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **common)

    batched_kitti_val = name == 'kitti' and args.save_image_dir is None
    val_batch_size = args.batch_size if batched_kitti_val else 1
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=val_batch_size, shuffle=False, **common)

    return train_loader, val_loader
コード例 #7
0
ファイル: __init__.py プロジェクト: rancheng/CSPN_monodepth
def worker_init_numpy_seed(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    Inside a worker, ``torch.initial_seed()`` returns that worker's seed:
    PyTorch draws a fresh base seed from the main process RNG each time the
    loader is iterated, and adds ``worker_id``. Reducing it modulo 2**32
    (NumPy's seed range) gives every worker a distinct stream and a new
    stream on every pass. ``np.random.seed(worker_id)`` instead replays the
    same random sparse sampling each epoch. Defined at module level so it can
    be pickled when workers are started with the ``spawn`` method.

    Args:
        worker_id: worker index; unused because the per-worker offset is
            already part of ``torch.initial_seed()``.
    """
    import numpy as np
    import torch

    np.random.seed(torch.initial_seed() % 2 ** 32)


def create_loader(args, mode='train'):
    """Create a KITTI DataLoader for ``mode`` (``'train'`` or ``'val'``).

    A sparsifier (``UniformSampling`` or ``SimulatedStereo``) is built from
    ``args.sparsifier``, ``args.num_samples`` and ``args.max_depth`` (a
    negative ``max_depth`` becomes ``np.inf``) and passed to the dataset. It
    stays ``None`` when ``args.sparsifier`` matches neither class's ``name``.

    Args:
        args: parsed options; reads ``dataset``, ``max_depth``, ``sparsifier``,
            ``num_samples``, ``modality``, ``batch_size`` and ``workers``.
        mode: split name, matched case-insensitively.

    Returns:
        For ``'train'``, a shuffled loader with ``args.batch_size`` whose
        workers reseed NumPy via ``worker_init_numpy_seed``. For ``'val'``, an
        ordered loader with batch size 1.

    Raises:
        NotImplementedError: for any other ``mode``. A missing split directory
            prints a message and exits the process instead.
    """
    import sys

    # Data loading code
    print('=> creating ', mode, ' loader ...')
    import os
    from dataloaders.path import Path
    root_dir = Path.db_root_dir(args.dataset)

    # sparsifier is a class for generating random sparse depth input from the ground truth
    import numpy as np
    sparsifier = None
    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
    from dataloaders.nyu_dataloader.dense_to_sparse import UniformSampling
    from dataloaders.nyu_dataloader.dense_to_sparse import SimulatedStereo
    if args.sparsifier == UniformSampling.name:
        sparsifier = UniformSampling(num_samples=args.num_samples,
                                     max_depth=max_depth)
    elif args.sparsifier == SimulatedStereo.name:
        sparsifier = SimulatedStereo(num_samples=args.num_samples,
                                     max_depth=max_depth)

    from dataloaders.kitti_dataloader.kitti_dataloader import KITTIDataset

    import torch
    if mode.lower() == 'train':
        traindir = os.path.join(root_dir, 'train')

        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            sys.exit(-1)
        train_dataset = KITTIDataset(traindir,
                                     type='train',
                                     modality=args.modality,
                                     sparsifier=sparsifier)
        # Reseed NumPy per worker *and* per epoch so the random sparse inputs
        # differ across workers and across epochs.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_numpy_seed)

        return train_loader

    elif mode.lower() == 'val':
        valdir = os.path.join(root_dir, 'val')
        if os.path.exists(valdir):
            print('Val dataset "{}" is existed!'.format(valdir))
        else:
            print('Val dataset "{}" is not existed!'.format(valdir))
            sys.exit(-1)
        val_dataset = KITTIDataset(valdir,
                                   type='val',
                                   modality=args.modality,
                                   sparsifier=sparsifier)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

        return val_loader

    else:
        raise NotImplementedError
コード例 #8
0
ファイル: main.py プロジェクト: DaHaiHuha/DORN_pytorch
def create_loader(args):
    """Build train and evaluation DataLoaders for ``args.dataset``.

    * ``'kitti'``: ``KittiFolder`` train/test splits at size (385, 513).
      Returns ``(train_loader, test_loader)``.
    * ``'hacker'``: ``HackerDataloader`` train/val splits rooted at the
      dataset directory.
    * anything else: ``NYUDataset`` on ``<root>/train`` and ``<root>/val``.
      If either directory is missing, a message is printed and the process
      exits with status -1.

    In every case the training loader is shuffled with ``args.batch_size``,
    and the evaluation loader is ordered with batch size 1. Both use
    ``args.workers`` workers and ``pin_memory=True``.
    """
    import sys

    root_dir = Path.db_root_dir(args.dataset)
    if args.dataset == 'kitti':
        train_set = KittiFolder(root_dir, mode='train', size=(385, 513))
        test_set = KittiFolder(root_dir, mode='test', size=(385, 513))
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)
        return train_loader, test_loader

    if args.dataset == 'hacker':
        # data = 'test.txt' or 'val.txt', transform = None
        train_set = HackerDataloader(root_dir, type='train')
        val_set = HackerDataloader(root_dir, type='val')
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        return train_loader, val_loader

    traindir = os.path.join(root_dir, 'train')
    if os.path.exists(traindir):
        print('Train dataset "{}" is existed!'.format(traindir))
    else:
        print('Train dataset "{}" is not existed!'.format(traindir))
        sys.exit(-1)

    # Check valdir itself so a missing val split is reported here.
    valdir = os.path.join(root_dir, 'val')
    if os.path.exists(valdir):
        print('Val dataset "{}" is existed!'.format(valdir))
    else:
        print('Val dataset "{}" is not existed!'.format(valdir))
        sys.exit(-1)

    train_set = nyu_dataloader.NYUDataset(traindir, type='train')
    val_set = nyu_dataloader.NYUDataset(valdir, type='val')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    return train_loader, val_loader
コード例 #9
0
ファイル: main.py プロジェクト: kopetri/DORN_pytorch
def create_loader(args):
    """Build train and evaluation DataLoaders for ``args.dataset``.

    * ``'kitti'``: ``KittiFolder`` train/test splits at size (385, 513).
      Returns ``(train_loader, test_loader)``.
    * ``'nyu'``: ``NYUDataset`` on ``<root>/train`` and ``<root>/val``.
    * ``'floorplan3d'``: ``Floorplan3DDataset`` on the dataset root itself,
      split by the ``split`` argument and ``args.dataset_type``.

    A missing data directory prints a message and exits the process with
    status -1. In every case the training loader is shuffled with
    ``args.batch_size``, and the evaluation loader is ordered with batch
    size 1.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    import sys

    root_dir = Path.db_root_dir(args.dataset)
    if args.dataset == 'kitti':
        train_set = KittiFolder(root_dir, mode='train', size=(385, 513))
        test_set = KittiFolder(root_dir, mode='test', size=(385, 513))
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)
        return train_loader, test_loader
    elif args.dataset == 'nyu':
        traindir = os.path.join(root_dir, 'train')
        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            sys.exit(-1)

        # Check valdir itself so a missing val split is reported here.
        valdir = os.path.join(root_dir, 'val')
        if os.path.exists(valdir):
            print('Val dataset "{}" is existed!'.format(valdir))
        else:
            print('Val dataset "{}" is not existed!'.format(valdir))
            sys.exit(-1)

        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

        return train_loader, val_loader
    elif args.dataset == 'floorplan3d':
        # Both splits live under the same root; the dataset class selects the
        # split through its `split` argument.
        traindir = root_dir
        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            sys.exit(-1)

        valdir = root_dir
        if os.path.exists(valdir):
            print('Valid dataset "{}" is existed!'.format(valdir))
        else:
            print('Valid dataset "{}" is not existed!'.format(valdir))
            sys.exit(-1)

        train_set = floorplan3d_dataloader.Floorplan3DDataset(
            traindir, dataset_type=args.dataset_type, split='train')
        val_set = floorplan3d_dataloader.Floorplan3DDataset(
            valdir, dataset_type=args.dataset_type, split='val')

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        return train_loader, val_loader
    else:
        raise ValueError("unknown dataset: {!r}".format(args.dataset))