Example #1
File: main.py Project: JennyGao00/FCRN
import os

import torch
import torch.utils.data

# Project-local modules; the exact import paths below are assumed from the repo layout.
from dataloaders import kitti_dataloader, nyu_dataloader
from dataloaders.path import Path


def create_loader(args):
    traindir = os.path.join(Path.db_root_dir(args.dataset), 'train')
    if os.path.exists(traindir):
        print('Train dataset "{}" exists!'.format(traindir))
    else:
        print('Train dataset "{}" does not exist!'.format(traindir))
        exit(-1)

    valdir = os.path.join(Path.db_root_dir(args.dataset), 'val')
    if os.path.exists(valdir):  # check valdir here, not traindir
        print('Val dataset "{}" exists!'.format(valdir))
    else:
        print('Val dataset "{}" does not exist!'.format(valdir))
        exit(-1)

    if args.dataset == 'kitti':
        train_set = kitti_dataloader.KITTIDataset(traindir, type='train')
        val_set = kitti_dataloader.KITTIDataset(valdir, type='val')

        # sample 3200 pictures for validation from the val set
        weights = [1] * len(val_set)  # uniform weights over every validation sample
        print('weights:', len(weights))
        sampler = torch.utils.data.WeightedRandomSampler(weights,
                                                         num_samples=3200)
    elif args.dataset == 'nyu':
        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')
    else:
        print('No dataset named "{}"'.format(args.dataset))
        exit(-1)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    if args.dataset == 'kitti':
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=args.batch_size,
                                                 sampler=sampler,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_set,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)

    return train_loader, val_loader
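
A minimal sketch of how create_loader might be driven from a training script. The attribute names (dataset, batch_size, workers) are taken from the example above; the argparse wiring itself is hypothetical.

import argparse

# Hypothetical CLI wiring (sketch): only the attributes create_loader reads are defined.
parser = argparse.ArgumentParser(description='FCRN training (sketch)')
parser.add_argument('--dataset', default='nyu', choices=['kitti', 'nyu'])
parser.add_argument('--batch-size', dest='batch_size', type=int, default=8)
parser.add_argument('--workers', type=int, default=4)
args = parser.parse_args()

train_loader, val_loader = create_loader(args)
print('train batches: {}, val batches: {}'.format(len(train_loader), len(val_loader)))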
Example #2
def get_sets(args, traindir, valdir, testdir):
    sampler = None
    test_set = None  # only the 'panoptic' branch builds a test set

    if args.dataset == 'kitti':
        train_set = kitti_dataloader.KITTIDataset(traindir, type='train')
        val_set = kitti_dataloader.KITTIDataset(valdir, type='val')
        # sample 3200 pictures for validation from the val set
        weights = [1] * len(val_set)  # uniform weights over every validation sample
        print('weights:', len(weights))
        sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=3200)
    elif args.dataset == 'nyu':
        train_set = nyu_dataloader.NYUDataset(traindir, type='train')
        val_set = nyu_dataloader.NYUDataset(valdir, type='val')
    elif args.dataset == 'panoptic':
        train_set = panoptic_dataloader.PANOPTICDataset(traindir, type='train')
        val_set = panoptic_dataloader.PANOPTICDataset(valdir, type='val')
        test_set = panoptic_dataloader.PANOPTICDataset(testdir, type='test')
    else:
        print('No dataset named "{}"'.format(args.dataset))
        exit(-1)
    return sampler, train_set, val_set, test_set
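
A hedged sketch of how the values returned by get_sets might be wrapped into DataLoaders, mirroring the construction in Example #1; traindir, valdir, testdir, and args are assumed to come from the caller.

sampler, train_set, val_set, test_set = get_sets(args, traindir, valdir, testdir)

train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.workers,
                                           pin_memory=True)

# The weighted sampler is only built for 'kitti'; fall back to sequential order otherwise.
val_loader = torch.utils.data.DataLoader(val_set,
                                         batch_size=args.batch_size if sampler is not None else 1,
                                         sampler=sampler,
                                         num_workers=args.workers,
                                         pin_memory=True)

# A test set is only available for the 'panoptic' dataset.
test_loader = None
if test_set is not None:
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)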
Example #3
def create_loader(args):
    if args.dataset == 'kitti':
        kitti_root = Path.db_root_dir(args.dataset)
        if os.path.exists(kitti_root):
            print('kitti dataset "{}" exists!'.format(kitti_root))
        else:
            print('kitti dataset "{}" does not exist!'.format(kitti_root))
            exit(-1)

        train_set = kitti_dataloader.KITTIDataset(
            kitti_root, type='train', model=args.model)
        val_set = kitti_dataloader.KITTIDataset(
            kitti_root, type='test', model=args.model)

    elif args.dataset == 'nyu':
        traindir = os.path.join(Path.db_root_dir(args.dataset), 'train')
        if os.path.exists(traindir):
            print('Train dataset "{}" exists!'.format(traindir))
        else:
            print('Train dataset "{}" does not exist!'.format(traindir))
            exit(-1)

        valdir = os.path.join(Path.db_root_dir(args.dataset), 'val')
        if os.path.exists(valdir):
            print('Val dataset "{}" exists!'.format(valdir))
        else:
            print('Val dataset "{}" does not exist!'.format(valdir))
            exit(-1)

        train_set = nyu_dataloader.NYUDataset(
            traindir, type='train', model=args.model)
        val_set = nyu_dataloader.NYUDataset(
            valdir, type='val', model=args.model)

    elif args.dataset == 'saved_images':
        if not os.path.exists(args.save_image_dir):
            print('Saved image directory "{}" does not exist!'.format(args.save_image_dir))
            exit(-1)

        train_set = folder_loader.FolderDataset(
            args.save_image_dir, model=args.model)
        val_set = folder_loader.FolderDataset(
            args.save_image_dir, model=args.model)
    else:
        print('No dataset named "{}"'.format(args.dataset))
        exit(-1)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

    if args.dataset == 'kitti':
        if args.save_image_dir is not None:
            val_loader = torch.utils.data.DataLoader(
                val_set, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
        else:
            val_loader = torch.utils.data.DataLoader(
                val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    else:
        val_loader = torch.utils.data.DataLoader(
            val_set, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)

    return train_loader, val_loader
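
As a brief hedged sketch, this variant could be driven in its 'saved_images' mode as follows; save_image_dir and model are attribute names taken from the example above, while the argparse defaults are assumptions.

import argparse

# Hypothetical CLI wiring for the saved-images path; the default values are placeholders.
parser = argparse.ArgumentParser(description='FCRN on saved images (sketch)')
parser.add_argument('--dataset', default='saved_images',
                    choices=['kitti', 'nyu', 'saved_images'])
parser.add_argument('--save-image-dir', dest='save_image_dir', default='./saved_images')
parser.add_argument('--model', default='resnet')  # model identifier is an assumption
parser.add_argument('--batch-size', dest='batch_size', type=int, default=1)
parser.add_argument('--workers', type=int, default=4)
args = parser.parse_args()

train_loader, val_loader = create_loader(args)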