예제 #1
0
def NYUDepth_loader(data_path, batch_size=32, isTrain=True):
    """Build a DataLoader over the NYU-Depth dataset rooted at *data_path*.

    When *isTrain* is True, returns a shuffled multi-worker loader over the
    'train' split; otherwise a sequential batch-size-1 loader over 'val'.
    Relies on the module-level ``args`` for the worker count.
    """
    if isTrain:
        split_dir = os.path.join(data_path, 'train')
        print('Train file path is ', split_dir)

        if os.path.exists(split_dir):
            print('Train dataset file path is existed!')
        dataset = nyu_dataloader.NYUDataset(split_dir, type='train')
        # Seed each worker from its id so augmentation randomness differs per worker.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=lambda work_id: np.random.seed(work_id))

    split_dir = os.path.join(data_path, 'val')
    print('Test file path is ', split_dir)

    if os.path.exists(split_dir):
        print('Test dataset file path is existed!')
    dataset = nyu_dataloader.NYUDataset(split_dir, type='val')
    # Validation: fixed order, one sample per batch.
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=1,
                                       shuffle=False,
                                       num_workers=args.workers,
                                       pin_memory=True)
예제 #2
0
def NYUDepth_loader(data_path, batch_size=32, isTrain=True):
    """Return a DataLoader for the NYU-Depth 'train' or 'val' split.

    Single-process loading is used deliberately (original author noted
    multi-worker reading failed in their environment).
    """
    split = 'train' if isTrain else 'val'
    split_dir = os.path.join(data_path, split)
    print(split_dir)

    if isTrain:
        if os.path.exists(split_dir):
            print('训练集目录存在')
        dataset = nyu_dataloader.NYUDataset(split_dir, type='train')
        # Shuffle during training; num_workers left at default (0).
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

    if os.path.exists(split_dir):
        print('测试集目录存在')
    dataset = nyu_dataloader.NYUDataset(split_dir, type='val')
    # Evaluation: no shuffling, batch size pinned to 1.
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=1,
                                       shuffle=False)
예제 #3
0
File: main.py  Project: JennyGao00/FCRN
def create_loader(args):
    """Create (train_loader, val_loader) for the dataset named by ``args.dataset``.

    Supports 'kitti' (validation subsampled to 3200 images via a weighted
    random sampler) and 'nyu'. Exits the process if a split directory is
    missing or the dataset name is unknown.
    """
    traindir = os.path.join(Path.db_root_dir(args.dataset), 'train')
    if os.path.exists(traindir):
        print('Train dataset "{}" is existed!'.format(traindir))
    else:
        print('Train dataset "{}" is not existed!'.format(traindir))
        exit(-1)

    valdir = os.path.join(Path.db_root_dir(args.dataset), 'val')
    # Fix: previously tested traindir here, so a missing val split was never caught.
    if os.path.exists(valdir):
        print('Val dataset "{}" is existed!'.format(valdir))
    else:
        print('Val dataset "{}" is not existed!'.format(valdir))
        exit(-1)

    if args.dataset == 'kitti':
        train_set = kitti_dataloader.KITTIDataset(traindir, type='train')
        val_set = kitti_dataloader.KITTIDataset(valdir, type='val')

        # Sample 3200 pictures for validation from the val set (uniform weights).
        weights = [1] * len(val_set)
        print('weights:', len(weights))
        sampler = torch.utils.data.WeightedRandomSampler(weights,
                                                         num_samples=3200)
    elif args.dataset == 'nyu':
        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')
    else:
        print('no dataset named as ', args.dataset)
        exit(-1)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    if args.dataset == 'kitti':
        # KITTI validation uses the subsampling sampler (sampler excludes shuffle).
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=args.batch_size,
                                                 sampler=sampler,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    return train_loader, val_loader
예제 #4
0
def create_loader(dataset='kitti'):
    """Create (train_loader, test/val_loader) for 'kitti' or (default) NYU.

    KITTI uses KittiFolder with a fixed crop size; anything else is treated
    as NYU with 'train'/'val' subdirectories. Exits if a split directory is
    missing.
    """
    root_dir = Path.db_root_dir(dataset)
    if dataset == 'kitti':
        train_set = KittiFolder(root_dir, mode='train', size=(385, 513))
        test_set = KittiFolder(root_dir, mode='test', size=(385, 513))
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=0,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=32,
                                                  shuffle=False,
                                                  num_workers=0,
                                                  pin_memory=True)
        return train_loader, test_loader
    else:
        traindir = os.path.join(root_dir, 'train')
        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            exit(-1)

        valdir = os.path.join(root_dir, 'val')
        # Fix: previously tested traindir here, so a missing val split was never caught.
        if os.path.exists(valdir):
            print('Val dataset "{}" is existed!'.format(valdir))
        else:
            print('Val dataset "{}" is not existed!'.format(valdir))
            exit(-1)

        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=4,
                                                   shuffle=False,
                                                   num_workers=0,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=4,
                                                 shuffle=False,
                                                 num_workers=0,
                                                 pin_memory=True)

        return train_loader, val_loader
def create_loader(args):
    """Create (train_loader, val_loader) for the NYU-style layout under ../data.

    Raises:
        NotImplementedError: for ``args.dataset == 'kitti'`` (not supported yet).
    """
    if args.dataset == 'kitti':
        # Fix: `assert "Not implemented"` is a truthy-string assert that never
        # fires, so the KITTI branch silently returned None. Raise instead.
        raise NotImplementedError('KITTI loading is not implemented')

    print("=> creating data loaders...")
    data_dir = '..'
    valdir = os.path.join(data_dir, 'data', args.dataset, 'val')
    traindir = os.path.join(data_dir, 'data', args.dataset, 'train')

    train_set = nyu_dataloader.NYUDataset(traindir, type='train')
    val_set = nyu_dataloader.NYUDataset(valdir, type='val')

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

    # Validation: fixed order, batch size 1.
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)

    return train_loader, val_loader
예제 #6
0
def get_sets(args, traindir, valdir, testdir):
    """Build datasets for the split directories of ``args.dataset``.

    Returns ``(sampler, train_set, val_set, test_set)``. ``sampler`` is a
    WeightedRandomSampler only for KITTI (3200-image validation subsample),
    otherwise None. ``test_set`` is only populated for 'panoptic', otherwise
    None. Exits the process on an unknown dataset name.
    """
    sampler = None
    # Fix: test_set was only bound in the 'panoptic' branch but returned
    # unconditionally, raising UnboundLocalError for 'kitti' / 'nyu'.
    test_set = None

    if args.dataset == 'kitti':
        train_set = kitti_dataloader.KITTIDataset(traindir, type='train')
        val_set = kitti_dataloader.KITTIDataset(valdir, type='val')
        # Sample 3200 pictures for validation from the val set (uniform weights).
        weights = [1] * len(val_set)
        print('weights:', len(weights))
        sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=3200)
    elif args.dataset == 'nyu':
        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')
    elif args.dataset == 'panoptic':
        train_set = panoptic_dataloader.PANOPTICDataset(traindir, type='train')
        val_set = panoptic_dataloader.PANOPTICDataset(valdir, type='val')
        test_set = panoptic_dataloader.PANOPTICDataset(testdir, type='test')
    else:
        print('no dataset named as ', args.dataset)
        exit(-1)
    return sampler, train_set, val_set, test_set
예제 #7
0
def create_loader(args):
    """Create (train_loader, val/test_loader) for 'kitti', 'hacker', or NYU.

    KITTI uses KittiFolder; 'hacker' uses HackerDataloader; any other name
    falls through to NYU with 'train'/'val' subdirectories. Exits if a split
    directory is missing.
    """
    root_dir = Path.db_root_dir(args.dataset)  # --dataset hacker
    if args.dataset == 'kitti':
        train_set = KittiFolder(root_dir, mode='train', size=(385, 513))
        test_set = KittiFolder(root_dir, mode='test', size=(385, 513))
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)
        return train_loader, test_loader
    elif args.dataset == 'hacker':
        # data = 'test.txt' or 'val.txt', transform = None
        train_set = HackerDataloader(root_dir, type='train')
        val_set = HackerDataloader(root_dir, type='val')
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        return train_loader, val_loader
    else:
        traindir = os.path.join(root_dir, 'train')
        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            exit(-1)

        valdir = os.path.join(root_dir, 'val')
        # Fix: previously tested traindir here, so a missing val split was never caught.
        if os.path.exists(valdir):
            print('Val dataset "{}" is existed!'.format(valdir))
        else:
            print('Val dataset "{}" is not existed!'.format(valdir))
            exit(-1)

        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

        return train_loader, val_loader
예제 #8
0
def create_loader(args):
    """Create (train_loader, val/test_loader) for 'kitti', 'nyu', or 'floorplan3d'.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    root_dir = Path.db_root_dir(args.dataset)
    if args.dataset == 'kitti':
        train_set = KittiFolder(root_dir, mode='train', size=(385, 513))
        test_set = KittiFolder(root_dir, mode='test', size=(385, 513))
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)
        return train_loader, test_loader
    elif args.dataset == 'nyu':
        traindir = os.path.join(root_dir, 'train')
        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            exit(-1)

        valdir = os.path.join(root_dir, 'val')
        # Fix: previously tested traindir here, so a missing val split was never caught.
        if os.path.exists(valdir):
            print('Val dataset "{}" is existed!'.format(valdir))
        else:
            print('Val dataset "{}" is not existed!'.format(valdir))
            exit(-1)

        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

        return train_loader, val_loader
    elif args.dataset == 'floorplan3d':
        # Both splits live directly under root_dir for this dataset.
        traindir = os.path.join(root_dir)
        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            exit(-1)

        valdir = os.path.join(root_dir)
        # Fix: previously tested traindir when validating valdir.
        if os.path.exists(valdir):
            print('Valid dataset "{}" is existed!'.format(valdir))
        else:
            print('Valid dataset "{}" is not existed!'.format(valdir))
            exit(-1)

        train_set = floorplan3d_dataloader.Floorplan3DDataset(
            traindir, dataset_type=args.dataset_type, split='train')
        val_set = floorplan3d_dataloader.Floorplan3DDataset(
            valdir, dataset_type=args.dataset_type, split='val')

        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        return train_loader, val_loader
    else:
        raise ValueError("unknown dataset")