def get_train_loader(dataset_dict, local_rank):
    """Build the training DataLoader described by ``dataset_dict``.

    Args:
        dataset_dict: Mapping of loader settings. Keys read here: ``'h'``,
            ``'w'``, ``'name'``, ``'batch_size'``, ``'num_workers'`` and, for
            ``'Tusimple'``, also ``'data_root'``, ``'griding_num'``,
            ``'row_anchor'``, ``'use_aux'`` and ``'num_lanes'``.
        local_rank: ``-1`` for single-process training (shuffled loader).
            Any other value selects CUDA device ``local_rank % device_count``,
            initialises an NCCL process group and uses a
            ``DistributedSampler``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training split.

    Raises:
        NotImplementedError: if ``dataset_dict['name']`` is neither
            ``'CULane'`` nor ``'Tusimple'``.
    """
    height, width = dataset_dict['h'], dataset_dict['w']

    # Label mask rescaled to the network input size.
    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask((height, width)),
        mytransforms.MaskToTensor(),
    ])
    # Auxiliary segmentation mask at 1/8 of the input size
    # (e.g. 36x100 for a 288x800 input).
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((height // 8, width // 8)),
        mytransforms.MaskToTensor(),
    ])
    # Image: resize, convert to tensor, normalise with ImageNet mean/std.
    img_transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Random rotation plus up/down and left/right offsets applied via
    # mytransforms.Compose2.
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200),
    ])

    transform_kwargs = dict(
        img_transform=img_transform,
        target_transform=target_transform,
        simu_transform=simu_transform,
        segment_transform=segment_transform,
    )

    dataset_name = dataset_dict['name']
    if dataset_name == 'CULane':
        # NOTE(review): this passes the whole config dict as the first
        # argument, while the Tusimple branch below passes (root, list_path,
        # griding_num=..., ...). Only one of these can match LaneClsDataset's
        # signature — confirm which branch is current.
        train_dataset = LaneClsDataset(dataset_dict, **transform_kwargs)
    elif dataset_name == 'Tusimple':
        root = dataset_dict['data_root']
        train_dataset = LaneClsDataset(
            root,
            os.path.join(root, 'train_gt.txt'),
            griding_num=dataset_dict['griding_num'],
            row_anchor=dataset_dict['row_anchor'],
            use_aux=dataset_dict['use_aux'],
            num_lanes=dataset_dict['num_lanes'],
            **transform_kwargs)
    else:
        raise NotImplementedError

    if local_rank == -1:
        # Single-process training: let the DataLoader shuffle.
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=dataset_dict['batch_size'],
            shuffle=True,
            num_workers=dataset_dict['num_workers'])

    # Distributed training: bind this process to a GPU, join the NCCL group,
    # then shard the dataset across ranks.
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl')
    shard_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=dataset_dict['batch_size'],
        sampler=shard_sampler,
        num_workers=dataset_dict['num_workers'])
def get_train_loader(batch_size, data_root, griding_num, dataset, use_aux, distributed):
    """Build a 288x800 training DataLoader for CULane or Tusimple.

    Args:
        batch_size: Samples per batch.
        data_root: Dataset root; the list file is resolved relative to it.
        griding_num: Forwarded to ``LaneClsDataset``.
        dataset: ``'CULane'`` or ``'Tusimple'``.
        use_aux: Forwarded to ``LaneClsDataset``.
        distributed: Use ``DistributedSampler`` when true, otherwise
            ``RandomSampler``.

    Returns:
        ``(train_loader, cls_num_per_lane)``; ``cls_num_per_lane`` is 18 for
        CULane and 56 for Tusimple.

    Raises:
        NotImplementedError: for any other ``dataset``.
    """
    # Label mask at the network input size, auxiliary mask at 1/8 of it.
    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask((288, 800)),
        mytransforms.MaskToTensor(),
    ])
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((36, 100)),
        mytransforms.MaskToTensor(),
    ])
    img_transform = transforms.Compose([
        transforms.Resize((288, 800)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200),
    ])

    # Per-dataset list file, row anchors and anchor count.
    if dataset == 'CULane':
        list_name, anchors, cls_num_per_lane = 'list/train_gt.txt', culane_row_anchor, 18
    elif dataset == 'Tusimple':
        list_name, anchors, cls_num_per_lane = 'train_gt.txt', tusimple_row_anchor, 56
    else:
        raise NotImplementedError

    train_dataset = LaneClsDataset(
        data_root,
        os.path.join(data_root, list_name),
        img_transform=img_transform,
        target_transform=target_transform,
        simu_transform=simu_transform,
        segment_transform=segment_transform,
        row_anchor=anchors,
        griding_num=griding_num,
        use_aux=use_aux)

    sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
               if distributed
               else torch.utils.data.RandomSampler(train_dataset))
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
    return loader, cls_num_per_lane
def _setting_or_global(name, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise TypeError(
            "get_train_loader() requires {0!r}: pass it as a keyword argument "
            "or define a module-level {0}".format(name)) from None


def get_train_loader(batch_size, data_root, dataset, distributed,
                     griding_num=None, use_aux=None, row_num=None):
    """Build a 256x512 CULane training DataLoader.

    ``griding_num``, ``use_aux`` and ``row_num`` were previously read as free
    names that the function never defined, so a call raised ``NameError``
    unless matching module globals existed. They are now optional keyword
    arguments. When one is omitted (``None``), the module-level global of the
    same name is used, so callers that relied on those globals are unaffected.

    Args:
        batch_size: Samples per batch.
        data_root: Dataset root; ``list/train_gt.txt`` is resolved under it.
        dataset: Only ``'CULane'`` is supported.
        distributed: Use ``DistributedSampler`` when true, otherwise
            ``RandomSampler``.
        griding_num: Forwarded to ``LaneClsDataset``.
        use_aux: Forwarded to ``LaneClsDataset``.
        row_num: Number of row anchors, spaced evenly from 90 to 255.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training split.

    Raises:
        NotImplementedError: if ``dataset`` is not ``'CULane'``.
        TypeError: if a setting is neither passed nor defined as a module
            global.
    """
    img_transform = transforms.Compose([
        transforms.Resize((256, 512)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200),
    ])

    if dataset == 'CULane':
        row_num = _setting_or_global('row_num', row_num)
        griding_num = _setting_or_global('griding_num', griding_num)
        use_aux = _setting_or_global('use_aux', use_aux)
        # row_num anchors evenly spaced over rows 90..255. These are presumably
        # coordinates in the 256-pixel-high resized image — confirm.
        row_anchor = np.linspace(90, 255, row_num).tolist()
        # NOTE(review): no target_transform is passed here, unlike the other
        # loaders; LaneClsDataset's default applies — confirm intended.
        train_dataset = LaneClsDataset(
            data_root,
            os.path.join(data_root, 'list/train_gt.txt'),
            img_transform=img_transform,
            simu_transform=simu_transform,
            segment_transform=None,
            row_anchor=row_anchor,
            griding_num=griding_num,
            use_aux=use_aux)
    else:
        raise NotImplementedError

    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
    return train_loader
def get_val_loader(batch_size, data_root, griding_num, dataset, use_aux, distributed, num_lanes, cfg):
    """Build the validation DataLoader for ``dataset`` (no augmentation).

    Args:
        batch_size: Samples per batch.
        data_root: Dataset root; the list file is resolved relative to it.
        griding_num: Forwarded to the dataset.
        dataset: ``'CULane'``, ``'Bdd100k'``, ``'neolix'`` or ``'Tusimple'``.
        use_aux: Forwarded to the dataset.
        distributed: Use ``DistributedSampler`` when true, otherwise
            ``RandomSampler``.
        num_lanes: Forwarded to the dataset.
        cfg: Config providing ``height``, ``width`` and ``anchors``
            (``cfg.anchors`` is read only for a supported ``dataset``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the validation list.

    Raises:
        NotImplementedError: for an unsupported ``dataset``.
    """
    input_size = (cfg.height, cfg.width)

    # Label mask at the input size; auxiliary mask at 1/8 of it.
    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask(input_size),
        mytransforms.MaskToTensor(),
    ])
    img_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((cfg.height // 8, cfg.width // 8)),
        mytransforms.MaskToTensor(),
    ])

    # Pick the dataset class, list file and any dataset-specific kwargs.
    extra_kwargs = {}
    if dataset == 'CULane':
        dataset_cls, list_name = LaneClsDataset, 'list/val_gt.txt'
    elif dataset == 'Bdd100k':
        dataset_cls, list_name = BddLaneClsDataset, 'val.txt'
        extra_kwargs['mode'] = "val"
    elif dataset == 'neolix':
        dataset_cls, list_name = LaneClsDataset, 'val.txt'
    elif dataset == 'Tusimple':
        # NOTE(review): validation reads the *training* list here — confirm
        # this is intentional and not a missing val split.
        dataset_cls, list_name = LaneClsDataset, 'train_gt.txt'
    else:
        raise NotImplementedError

    val_dataset = dataset_cls(
        data_root,
        os.path.join(data_root, list_name),
        img_transform=img_transform,
        target_transform=target_transform,
        simu_transform=None,
        segment_transform=segment_transform,
        row_anchor=cfg.anchors,
        griding_num=griding_num,
        use_aux=use_aux,
        num_lanes=num_lanes,
        **extra_kwargs)

    # NOTE(review): RandomSampler shuffles the validation set; harmless for
    # aggregate metrics, but SequentialSampler would make runs reproducible.
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(val_dataset)
    return torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
def _build_album_transforms():
    """Build the optional albumentations augmentation pipeline on demand.

    ``get_train_loader`` used to construct this pipeline on every call and
    then discard it: the code that would consume it
    (``mytransforms.AlbumAug(...)`` inside ``simu_transform``) is commented
    out. To re-enable it, add
    ``mytransforms.AlbumAug(_build_album_transforms())`` to ``simu_transform``.
    NOTE(review): confirm ``AlbumAug`` still exists and that ``HorizontalFlip``
    keeps lane labels consistent with the flipped image before doing so.

    Requires the module-level ``A`` (albumentations) import.
    """
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.OneOf([
            A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2,
                                 val_shift_limit=0.2, p=0.9),
            A.RandomBrightnessContrast(brightness_limit=0.2,
                                       contrast_limit=0.2, p=0.9),
        ], p=0.8),
        A.OneOf([
            A.MotionBlur(p=0.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        A.OneOf([
            A.GaussNoise(),
        ], p=0.2),
        A.OneOf([
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=0.1),
        ], p=0.2),
        A.OneOf([
            A.CLAHE(clip_limit=2),
            A.RandomBrightnessContrast(),
        ], p=0.3),
    ])


def get_train_loader(batch_size, data_root, griding_num, dataset, use_aux,
                     distributed, num_lanes, cfg, num_workers=4):
    """Build the training DataLoader and report the per-lane anchor count.

    Args:
        batch_size: Samples per batch.
        data_root: Dataset root; the list file is resolved relative to it.
        griding_num: Forwarded to the dataset.
        dataset: ``'CULane'``, ``'Bdd100k'``, ``'neolix'`` or ``'Tusimple'``.
        use_aux: Forwarded to the dataset.
        distributed: Use ``DistributedSampler`` when true, otherwise
            ``RandomSampler``.
        num_lanes: Forwarded to the dataset.
        cfg: Config providing ``height`` and ``width``. If ``"anchors"`` is
            not in ``cfg``, ``cfg.anchors`` is set to the dataset's default
            row anchors (``culane_row_anchor`` for CULane,
            ``tusimple_row_anchor`` otherwise). ``cfg`` is mutated.
        num_workers: DataLoader worker processes (default 4, the previous
            hard-coded value).

    Returns:
        ``(train_loader, cls_num_per_lane)``; ``cls_num_per_lane`` is 18 for
        CULane and 56 for the other datasets.

    Raises:
        NotImplementedError: for an unsupported ``dataset``.
    """
    input_size = (cfg.height, cfg.width)

    # Label mask at the input size; auxiliary mask at 1/8 of it.
    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask(input_size),
        mytransforms.MaskToTensor(),
    ])
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((cfg.height // 8, cfg.width // 8)),
        mytransforms.MaskToTensor(),
    ])
    img_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Random rotation plus up/down and left/right offsets applied via
    # mytransforms.Compose2. The albumentations pipeline is deliberately not
    # built here; see _build_album_transforms.
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200),
    ])

    # Per-dataset class, list file and anchor count. The default anchor
    # globals are only looked up inside the matching branch.
    # NOTE(review): cls_num_per_lane stays hard-coded even when cfg already
    # supplies its own anchors — confirm it matches len(cfg.anchors).
    extra_kwargs = {}
    if dataset == 'CULane':
        if "anchors" not in cfg:
            cfg.anchors = culane_row_anchor
        dataset_cls, list_name, cls_num_per_lane = LaneClsDataset, 'list/train_gt.txt', 18
    elif dataset == 'Bdd100k':
        if "anchors" not in cfg:
            cfg.anchors = tusimple_row_anchor
        dataset_cls, list_name, cls_num_per_lane = BddLaneClsDataset, 'train.txt', 56
        extra_kwargs['mode'] = "train"
    elif dataset == 'neolix':
        if "anchors" not in cfg:
            cfg.anchors = tusimple_row_anchor
        dataset_cls, list_name, cls_num_per_lane = LaneClsDataset, 'train.txt', 56
    elif dataset == 'Tusimple':
        if "anchors" not in cfg:
            cfg.anchors = tusimple_row_anchor
        dataset_cls, list_name, cls_num_per_lane = LaneClsDataset, 'train_gt.txt', 56
    else:
        raise NotImplementedError

    train_dataset = dataset_cls(
        data_root,
        os.path.join(data_root, list_name),
        img_transform=img_transform,
        target_transform=target_transform,
        simu_transform=simu_transform,
        segment_transform=segment_transform,
        row_anchor=cfg.anchors,
        griding_num=griding_num,
        use_aux=use_aux,
        num_lanes=num_lanes,
        **extra_kwargs)

    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers)
    return train_loader, cls_num_per_lane