コード例 #1
0
def create_train_val_dataloader(configs):
    """Build the training and the validation dataloaders.

    Training samples get a ``OneOf`` augmentation group (rotation with
    ``limit_angle=20.`` or scaling in ``(0.95, 1.05)``, group p=0.6) plus
    ``hflip_prob=0.5``; validation samples have every augmentation-related
    option switched off.

    Args:
        configs: namespace-like object; reads ``dataset_dir``,
            ``multiscale_training``, ``num_samples``, ``mosaic``,
            ``random_padding``, ``distributed``, ``batch_size``,
            ``pin_memory`` and ``num_workers``.

    Returns:
        tuple: ``(train_dataloader, val_dataloader, train_sampler)`` where
        ``train_sampler`` is a ``DistributedSampler`` when
        ``configs.distributed`` is truthy and ``None`` otherwise.
    """
    rotate_or_scale = OneOf(
        [
            Random_Rotation(limit_angle=20., p=1.0),
            Random_Scaling(scaling_range=(0.95, 1.05), p=1.0),
        ],
        p=0.6,
    )
    train_set = KittiDataset(
        configs.dataset_dir,
        split='train',
        mode='train',
        aug_transforms=rotate_or_scale,
        hflip_prob=0.5,
        multiscale=configs.multiscale_training,
        num_samples=configs.num_samples,
        mosaic=configs.mosaic,
        random_padding=configs.random_padding,
    )
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_set)
        if configs.distributed
        else None
    )
    # Loader options shared by both splits.
    shared_opts = dict(
        batch_size=configs.batch_size,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
    )
    # DataLoader forbids shuffle=True together with an explicit sampler.
    train_loader = DataLoader(
        train_set,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        collate_fn=train_set.collate_fn,
        **shared_opts
    )

    val_set = KittiDataset(
        configs.dataset_dir,
        split='val',
        mode='val',
        aug_transforms=None,
        hflip_prob=0.,
        multiscale=False,
        num_samples=configs.num_samples,
        mosaic=False,
        random_padding=False,
    )
    val_sampler = (
        torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False)
        if configs.distributed
        else None
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        sampler=val_sampler,
        collate_fn=val_set.collate_fn,
        **shared_opts
    )

    return train_loader, val_loader, train_sampler
コード例 #2
0
def create_train_dataloader(configs):
    """Build the training dataloader.

    A lidar-level ``OneOf`` group is applied with p=0.75. It holds a global
    rotation (``limit_angle=pi/4``), a scaling in ``(0.95, 1.05)`` and a
    per-box rotation (``limit_angle=pi/10``). No ``aug_transforms`` are used.

    Args:
        configs: namespace-like object passed to ``KittiDataset``; this
            function also reads ``num_samples``, ``distributed``,
            ``batch_size``, ``pin_memory`` and ``num_workers``.

    Returns:
        tuple: ``(train_dataloader, train_sampler)``; ``train_sampler`` is
        ``None`` unless ``configs.distributed`` is truthy.
    """
    lidar_augmentation = OneOf(
        [
            Random_Rotation(limit_angle=np.pi / 4, p=1.0),
            Random_Scaling(scaling_range=(0.95, 1.05), p=1.0),
            Random_Rotate_Individual_Box(limit_angle=np.pi / 10, p=1.0),
        ],
        p=0.75,
    )
    dataset = KittiDataset(
        configs,
        mode='train',
        lidar_aug=lidar_augmentation,
        aug_transforms=None,
        num_samples=configs.num_samples,
    )

    if configs.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = None

    # Shuffle in the DataLoader only when no sampler handles ordering.
    loader = DataLoader(
        dataset,
        batch_size=configs.batch_size,
        shuffle=sampler is None,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )
    return loader, sampler
コード例 #3
0
def create_val_dataloader(configs):
    """Build the validation dataloader.

    The ``KittiDataset`` is created with ``split='val'``, ``mode='val'`` and
    every augmentation-related option switched off. In distributed mode a
    ``DistributedSampler`` with ``shuffle=False`` is used.

    Args:
        configs: namespace-like object; reads ``dataset_dir``,
            ``num_samples``, ``distributed``, ``batch_size``,
            ``pin_memory`` and ``num_workers``.

    Returns:
        DataLoader: the validation dataloader.
    """
    dataset = KittiDataset(
        configs.dataset_dir,
        split='val',
        mode='val',
        aug_transforms=None,
        hflip_prob=0.,
        multiscale=False,
        num_samples=configs.num_samples,
        mosaic=False,
        random_padding=False,
    )

    if configs.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
    else:
        sampler = None

    return DataLoader(
        dataset,
        batch_size=configs.batch_size,
        shuffle=False,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )
コード例 #4
0
def create_train_dataloader(configs):
    """Build the training dataloader.

    Two augmentation stages are handed to ``KittiDataset``:

    * ``lidar_transforms``: ``OneOf`` a rotation (``limit_angle=20.``) or a
      scaling in ``(0.95, 1.05)``, with p=0.66 for the group;
    * ``aug_transforms``: ``Compose`` of a horizontal flip
      (p=``configs.hflip_prob``) and a cutout parameterized by the
      ``configs.cutout_*`` fields.

    Args:
        configs: namespace-like object; besides the fields above it reads
            ``dataset_dir``, ``multiscale_training``, ``num_samples``,
            ``mosaic``, ``random_padding``, ``distributed``, ``batch_size``,
            ``pin_memory`` and ``num_workers``.

    Returns:
        tuple: ``(train_dataloader, train_sampler)``; ``train_sampler`` is
        ``None`` unless ``configs.distributed`` is truthy.
    """
    rotate_or_scale = OneOf(
        [
            Random_Rotation(limit_angle=20., p=1.0),
            Random_Scaling(scaling_range=(0.95, 1.05), p=1.0),
        ],
        p=0.66,
    )
    flip_then_cutout = Compose(
        [
            Horizontal_Flip(p=configs.hflip_prob),
            Cutout(
                n_holes=configs.cutout_nholes,
                ratio=configs.cutout_ratio,
                fill_value=configs.cutout_fill_value,
                p=configs.cutout_prob,
            ),
        ],
        p=1.,
    )

    dataset = KittiDataset(
        configs.dataset_dir,
        mode='train',
        lidar_transforms=rotate_or_scale,
        aug_transforms=flip_then_cutout,
        multiscale=configs.multiscale_training,
        num_samples=configs.num_samples,
        mosaic=configs.mosaic,
        random_padding=configs.random_padding,
    )
    sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset)
        if configs.distributed
        else None
    )
    loader = DataLoader(
        dataset,
        batch_size=configs.batch_size,
        shuffle=sampler is None,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )
    return loader, sampler
コード例 #5
0
def create_train_dataloader(configs, voxel_generator):
    """Build the training dataloader for the voxel-based dataset.

    A lidar-level ``OneOf`` group (rotation with ``limit_angle=pi/4`` or
    scaling in ``(0.95, 1.05)``) is applied with p=0.66. Batches are
    collated with ``merge_batch``.

    Args:
        configs: namespace-like object passed to ``KittiDataset``; this
            function also reads ``hflip_prob``, ``num_samples``,
            ``distributed``, ``batch_size``, ``pin_memory`` and
            ``num_workers``.
        voxel_generator: forwarded positionally to ``KittiDataset``.

    Returns:
        tuple: ``(train_dataloader, train_sampler)``; ``train_sampler`` is
        ``None`` unless ``configs.distributed`` is truthy.
    """
    point_cloud_aug = OneOf(
        [
            Random_Rotation(limit_angle=np.pi / 4, p=1.0),
            Random_Scaling(scaling_range=(0.95, 1.05), p=1.0),
        ],
        p=0.66,
    )
    dataset = KittiDataset(
        configs,
        voxel_generator,
        mode='train',
        lidar_aug=point_cloud_aug,
        hflip_prob=configs.hflip_prob,
        num_samples=configs.num_samples,
    )

    sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset)
        if configs.distributed
        else None
    )
    loader = DataLoader(
        dataset,
        batch_size=configs.batch_size,
        shuffle=sampler is None,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=sampler,
        collate_fn=merge_batch,
    )
    return loader, sampler
コード例 #6
0
def create_anotha_dataloader(configs):
    """Create a test-mode dataloader over a dataset built with ``switcheroo=True``.

    Intended for testing on the training dataset. The ``switcheroo`` flag is
    forwarded to ``KittiDataset``. NOTE(review): presumably it swaps the
    split that test mode reads; confirm against ``KittiDataset``.

    All transforms and augmentation-related options are disabled. Iteration
    order is deterministic: ``shuffle=False`` on the DataLoader and, in
    distributed mode, on the ``DistributedSampler`` as well.

    Args:
        configs: namespace-like object; reads ``dataset_dir``,
            ``num_samples``, ``distributed``, ``batch_size``,
            ``pin_memory`` and ``num_workers``.

    Returns:
        DataLoader: the test dataloader (default collate function).
    """
    # NOTE(review): looks like a debug print; consider logging or removing it.
    print(configs.dataset_dir)
    test_dataset = KittiDataset(configs.dataset_dir,
                                mode='test',
                                lidar_transforms=None,
                                aug_transforms=None,
                                multiscale=False,
                                num_samples=configs.num_samples,
                                mosaic=False,
                                random_padding=False,
                                switcheroo=True)
    test_sampler = None
    if configs.distributed:
        # DistributedSampler shuffles by default; that would silently
        # randomize test order despite shuffle=False on the DataLoader.
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dataset, shuffle=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=configs.batch_size,
                                 shuffle=False,
                                 pin_memory=configs.pin_memory,
                                 num_workers=configs.num_workers,
                                 sampler=test_sampler)

    return test_dataloader
コード例 #7
0
def create_val_dataloader(configs):
    """Build the validation dataloader.

    The dataset is created with ``lidar_aug=None`` and ``hflip_prob=0.``,
    and ``DataLoader``'s default collate function is used. In distributed
    mode a ``DistributedSampler`` with ``shuffle=False`` is used.

    Args:
        configs: namespace-like object passed to ``KittiDataset``; this
            function also reads ``num_samples``, ``distributed``,
            ``batch_size``, ``pin_memory`` and ``num_workers``.

    Returns:
        DataLoader: the validation dataloader.
    """
    dataset = KittiDataset(
        configs,
        mode='val',
        lidar_aug=None,
        hflip_prob=0.,
        num_samples=configs.num_samples,
    )
    if configs.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
    else:
        sampler = None
    return DataLoader(
        dataset,
        batch_size=configs.batch_size,
        shuffle=False,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=sampler,
    )
コード例 #8
0
def create_test_dataset(configs):
    """Create the test-phase dataset (not a dataloader).

    The ``KittiDataset`` is built in ``mode='test'`` with no lidar or image
    transforms, and with multiscale, mosaic and random padding disabled.

    Args:
        configs: namespace-like object; reads ``dataset_dir`` and
            ``num_samples``.

    Returns:
        KittiDataset: the test dataset. The caller wraps it in a
        ``DataLoader`` if batching is needed.
    """
    test_dataset = KittiDataset(configs.dataset_dir,
                                mode='test',
                                lidar_transforms=None,
                                aug_transforms=None,
                                multiscale=False,
                                num_samples=configs.num_samples,
                                mosaic=False,
                                random_padding=False)
    return test_dataset
コード例 #9
0
def create_test_dataloader(configs):
    """Create the dataloader for the testing phase.

    The dataset is built with ``lidar_aug=None`` and ``hflip_prob=0.``.
    Iteration order is deterministic: ``shuffle=False`` on the DataLoader
    and, in distributed mode, on the ``DistributedSampler`` as well.

    Args:
        configs: namespace-like object passed to ``KittiDataset``; this
            function also reads ``num_samples``, ``distributed``,
            ``batch_size``, ``pin_memory`` and ``num_workers``.

    Returns:
        DataLoader: the test dataloader (default collate function).
    """
    test_dataset = KittiDataset(configs, mode='test', lidar_aug=None, hflip_prob=0., num_samples=configs.num_samples)
    test_sampler = None
    if configs.distributed:
        # DistributedSampler defaults to shuffle=True, which would randomize
        # the test order despite shuffle=False on the DataLoader below.
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=configs.batch_size, shuffle=False,
                                 pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=test_sampler)

    return test_dataloader
コード例 #10
0
def create_train_val_dataloader(configs):
    """Build the training and the validation dataloaders.

    The training split uses ``data_aug=True`` and
    ``multiscale=configs.multiscale_training``. The validation split uses
    ``data_aug=False`` and ``multiscale=False``.

    Args:
        configs: namespace-like object; reads ``dataset_dir``,
            ``multiscale_training``, ``num_samples``, ``distributed``,
            ``batch_size``, ``pin_memory`` and ``num_workers``.

    Returns:
        tuple: ``(train_dataloader, val_dataloader, train_sampler)``;
        ``train_sampler`` is ``None`` unless ``configs.distributed`` is
        truthy.
    """
    sampler_cls_path = torch.utils.data.distributed

    train_set = KittiDataset(
        configs.dataset_dir,
        split='train',
        mode='train',
        data_aug=True,
        multiscale=configs.multiscale_training,
        num_samples=configs.num_samples,
    )
    train_sampler = sampler_cls_path.DistributedSampler(train_set) if configs.distributed else None
    train_loader = DataLoader(
        train_set,
        batch_size=configs.batch_size,
        shuffle=train_sampler is None,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=train_sampler,
        collate_fn=train_set.collate_fn,
    )

    val_set = KittiDataset(
        configs.dataset_dir,
        split='val',
        mode='val',
        data_aug=False,
        multiscale=False,
        num_samples=configs.num_samples,
    )
    val_sampler = (
        sampler_cls_path.DistributedSampler(val_set, shuffle=False)
        if configs.distributed
        else None
    )
    val_loader = DataLoader(
        val_set,
        batch_size=configs.batch_size,
        shuffle=False,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=val_sampler,
        collate_fn=val_set.collate_fn,
    )

    return train_loader, val_loader, train_sampler
コード例 #11
0
def create_val_dataloader(configs):
    """Create the dataloader for validation.

    Validation data is loaded without any random augmentation
    (``lidar_transforms=None``, ``aug_transforms=None``), with multiscale,
    mosaic and random padding disabled. In distributed mode a
    ``DistributedSampler`` with ``shuffle=False`` is used.

    Args:
        configs: namespace-like object; reads ``dataset_dir``,
            ``num_samples``, ``distributed``, ``batch_size``,
            ``pin_memory`` and ``num_workers``.

    Returns:
        DataLoader: the validation dataloader.
    """
    val_sampler = None
    # The training-time horizontal flip / cutout pipeline is deliberately not
    # applied here. Random flips and cutout holes on validation inputs would
    # make the validation metrics non-deterministic and biased.
    val_dataset = KittiDataset(configs.dataset_dir, mode='val', lidar_transforms=None,
                               aug_transforms=None, multiscale=False, num_samples=configs.num_samples,
                               mosaic=False, random_padding=False)
    if configs.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=configs.batch_size, shuffle=False,
                                pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=val_sampler,
                                collate_fn=val_dataset.collate_fn)

    return val_dataloader
コード例 #12
0
def create_test_dataloader(configs):
    """Create the dataloader for the testing phase.

    The dataset is built in ``mode='test'`` with ``aug_transforms=None``,
    ``hflip_prob=0.`` and multiscale, mosaic and random padding disabled.
    Iteration order is deterministic: ``shuffle=False`` on the DataLoader
    and, in distributed mode, on the ``DistributedSampler`` as well.

    Args:
        configs: namespace-like object; reads ``dataset_dir``,
            ``num_samples``, ``distributed``, ``batch_size``,
            ``pin_memory`` and ``num_workers``.

    Returns:
        DataLoader: the test dataloader (default collate function).
    """
    test_dataset = KittiDataset(configs.dataset_dir,
                                mode='test',
                                aug_transforms=None,
                                hflip_prob=0.,
                                multiscale=False,
                                num_samples=configs.num_samples,
                                mosaic=False,
                                random_padding=False)
    test_sampler = None
    if configs.distributed:
        # DistributedSampler defaults to shuffle=True; disable it so test
        # order matches the shuffle=False DataLoader below.
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dataset, shuffle=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=configs.batch_size,
                                 shuffle=False,
                                 pin_memory=configs.pin_memory,
                                 num_workers=configs.num_workers,
                                 sampler=test_sampler)

    return test_dataloader
コード例 #13
0
ファイル: kitti_dataloader.py プロジェクト: zhouleidcc/RTM3D
def create_train_dataloader(configs):
    """Build the training dataloader.

    Image-level augmentation is an ``album.Compose`` (p=1.) of random
    brightness/contrast (p=0.5) and Gaussian noise (p=0.5).
    ``DataLoader``'s default collate function is used.

    Args:
        configs: namespace-like object passed to ``KittiDataset``; this
            function also reads ``hflip_prob``, ``use_left_cam_prob``,
            ``num_samples``, ``distributed``, ``batch_size``,
            ``pin_memory`` and ``num_workers``.

    Returns:
        tuple: ``(train_dataloader, train_sampler)``; ``train_sampler`` is
        ``None`` unless ``configs.distributed`` is truthy.
    """
    photometric_aug = album.Compose(
        [
            album.RandomBrightnessContrast(p=0.5),
            album.GaussNoise(p=0.5),
        ],
        p=1.,
    )
    dataset = KittiDataset(
        configs,
        mode='train',
        aug_transforms=photometric_aug,
        hflip_prob=configs.hflip_prob,
        use_left_cam_prob=configs.use_left_cam_prob,
        num_samples=configs.num_samples,
    )

    if configs.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = None

    loader = DataLoader(
        dataset,
        batch_size=configs.batch_size,
        shuffle=sampler is None,
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=sampler,
    )
    return loader, sampler