def create_train_dataloader(configs):
    """Build the training DataLoader for the KITTI dataset.

    Returns a ``(dataloader, sampler)`` pair; ``sampler`` is a
    DistributedSampler when ``configs.distributed`` is set, else None.
    """
    # At most one lidar-space transform is applied per sample (p=0.66 overall).
    lidar_tf = OneOf([
        Random_Rotation(limit_angle=20., p=1.0),
        Random_Scaling(scaling_range=(0.95, 1.05), p=1.0),
    ], p=0.66)
    # Image-space augmentations; probabilities come from the config.
    aug_tf = Compose([
        Horizontal_Flip(p=configs.hflip_prob),
        Cutout(n_holes=configs.cutout_nholes, ratio=configs.cutout_ratio,
               fill_value=configs.cutout_fill_value, p=configs.cutout_prob),
    ], p=1.)

    dataset = KittiDataset(
        configs.dataset_dir,
        mode='train',
        lidar_transforms=lidar_tf,
        aug_transforms=aug_tf,
        multiscale=configs.multiscale_training,
        num_samples=configs.num_samples,
        mosaic=configs.mosaic,
        random_padding=configs.random_padding,
    )

    sampler = None
    if configs.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    # Shuffle only when no sampler is given; DistributedSampler shuffles itself.
    dataloader = DataLoader(
        dataset,
        batch_size=configs.batch_size,
        shuffle=(sampler is None),
        pin_memory=configs.pin_memory,
        num_workers=configs.num_workers,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )
    return dataloader, sampler
def create_train_val_dataloader(configs):
    """Build the training DataLoader and, unless ``configs.no_val``, a
    validation DataLoader.

    Returns ``(train_dataloader, val_dataloader, train_sampler)`` where
    ``val_dataloader`` is None when validation is disabled.
    """
    train_tf = Compose([
        Random_Crop(max_reduction_percent=0.15, p=1.),
        Random_HFlip(p=0.5),
        Random_Rotate(rotation_angle_limit=15, p=0.5),
    ], p=1.)
    resize_tf = Resize(new_size=tuple(configs.input_size), p=1.0)

    train_infor, val_infor = train_val_data_separation(configs)

    train_set = TTNet_Dataset(train_infor, configs.events_dict, configs.input_size,
                              transform=train_tf, resize=resize_tf,
                              num_samples=configs.num_samples)

    train_sampler = None
    if configs.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)

    # Shuffle only in the non-distributed case; the sampler shuffles otherwise.
    train_loader = DataLoader(train_set, batch_size=configs.batch_size,
                              shuffle=(train_sampler is None),
                              pin_memory=configs.pin_memory,
                              num_workers=configs.num_workers,
                              sampler=train_sampler)

    val_loader = None
    if not configs.no_val:
        # Validation applies no augmentation (transform=None), only resizing.
        val_set = TTNet_Dataset(val_infor, configs.events_dict, configs.input_size,
                                transform=None, resize=resize_tf,
                                num_samples=configs.num_samples)
        val_sampler = None
        if configs.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False)
        val_loader = DataLoader(val_set, batch_size=configs.batch_size, shuffle=False,
                                pin_memory=configs.pin_memory,
                                num_workers=configs.num_workers,
                                sampler=val_sampler)

    return train_loader, val_loader, train_sampler
def create_val_dataloader(configs):
    """Build the validation DataLoader for the KITTI dataset."""
    # NOTE(review): augmentations (hflip/cutout) are applied at validation time,
    # exactly as in the original — presumably the configured probabilities are
    # set to 0 for pure evaluation; confirm against the config defaults.
    aug_tf = Compose([
        Horizontal_Flip(p=configs.hflip_prob),
        Cutout(n_holes=configs.cutout_nholes, ratio=configs.cutout_ratio,
               fill_value=configs.cutout_fill_value, p=configs.cutout_prob),
    ], p=1.)

    dataset = KittiDataset(
        configs.dataset_dir,
        mode='val',
        lidar_transforms=None,
        aug_transforms=aug_tf,
        multiscale=False,
        num_samples=configs.num_samples,
        mosaic=False,
        random_padding=False,
    )

    sampler = None
    if configs.distributed:
        # Keep evaluation order deterministic across processes.
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)

    return DataLoader(dataset, batch_size=configs.batch_size, shuffle=False,
                      pin_memory=configs.pin_memory,
                      num_workers=configs.num_workers,
                      sampler=sampler,
                      collate_fn=dataset.collate_fn)
if __name__ == '__main__':
    # Smoke-test script: build a TTNet dataset with (always-on) augmentations
    # and fetch a single sample by index.
    import cv2
    import matplotlib.pyplot as plt
    from config.config import parse_configs
    from data_process.ttnet_data_utils import get_events_infor, train_val_data_separation
    from data_process.transformation import Compose, Random_Crop, Resize, Random_HFlip, Random_Rotate

    configs = parse_configs()
    # NOTE(review): game_list / dataset_type are unused in the visible part of
    # this script — possibly leftovers or used further below; confirm.
    game_list = ['game_1']
    dataset_type = 'training'
    train_events_infor, val_events_infor = train_val_data_separation(configs)
    print('len(train_events_infor): {}'.format(len(train_events_infor)))
    # Test transformation: every augmentation forced with p=1 so a single
    # sample exercises the whole pipeline.
    transform = Compose([
        Random_Crop(max_reduction_percent=0.15, p=1.),
        Random_HFlip(p=1.),
        Random_Rotate(rotation_angle_limit=15, p=1.)
    ], p=1.)
    resize_transform = Resize(new_size=tuple(configs.input_size), p=1.0)
    ttnet_dataset = TTNet_Dataset(train_events_infor, configs.events_dict, configs.input_size,
                                  transform=transform, resize=resize_transform)
    print('len(ttnet_dataset): {}'.format(len(ttnet_dataset)))
    # Pull one arbitrary example to inspect the returned tuple.
    example_index = 100
    origin_imgs, resized_imgs, org_ball_pos_xy, global_ball_pos_xy, event_class, target_seg = ttnet_dataset.__getitem__(
        example_index)