コード例 #1
0
def get_train_loader(dataset_dict, local_rank):
    """Build the training DataLoader described by ``dataset_dict``.

    Args:
        dataset_dict: Mapping of dataset settings. Keys read here: ``'name'``
            (``'CULane'`` or ``'Tusimple'``), ``'h'``, ``'w'``,
            ``'data_root'``, ``'griding_num'``, ``'row_anchor'``,
            ``'use_aux'``, ``'num_lanes'``, ``'batch_size'`` and
            ``'num_workers'``.
        local_rank: ``-1`` for single-process training (shuffled loader).
            Any other value selects CUDA device
            ``local_rank % torch.cuda.device_count()``, initialises the NCCL
            default process group if it is not initialised yet, and shards
            the data with a ``DistributedSampler``.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``LaneClsDataset``. In the
        distributed case the sampler is available as ``loader.sampler``
        (call ``set_epoch`` on it each epoch to reshuffle).

    Raises:
        NotImplementedError: if ``dataset_dict['name']`` is not supported.
    """
    # Annotation list of each supported dataset, relative to data_root.
    list_files = {
        'CULane': 'list/train_gt.txt',
        'Tusimple': 'train_gt.txt',
    }
    name = dataset_dict['name']
    # Fail fast, before any transform objects are built.
    if name not in list_files:
        raise NotImplementedError('Unsupported dataset: {!r}'.format(name))

    height, width = dataset_dict['h'], dataset_dict['w']
    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask((height, width)),
        mytransforms.MaskToTensor(),
    ])
    # Auxiliary segmentation target at 1/8 of the input size
    # (e.g. 36x100 for a 288x800 input).
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((height // 8, width // 8)),
        mytransforms.MaskToTensor(),
    ])
    img_transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200)
    ])

    # Both datasets use the same constructor: image root + annotation list
    # plus the grid / anchor / lane settings from the config.
    data_root = dataset_dict['data_root']
    train_dataset = LaneClsDataset(data_root,
                                   os.path.join(data_root, list_files[name]),
                                   img_transform=img_transform,
                                   target_transform=target_transform,
                                   simu_transform=simu_transform,
                                   griding_num=dataset_dict['griding_num'],
                                   row_anchor=dataset_dict['row_anchor'],
                                   segment_transform=segment_transform,
                                   use_aux=dataset_dict['use_aux'],
                                   num_lanes=dataset_dict['num_lanes'])

    if local_rank == -1:
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=dataset_dict['batch_size'],
            shuffle=True,
            num_workers=dataset_dict['num_workers'])

    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(local_rank % num_gpus)
    # init_process_group raises if the default group already exists (e.g. on
    # a second call of this function), so only initialise it once.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl')
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=dataset_dict['batch_size'],
        sampler=sampler,
        num_workers=dataset_dict['num_workers'])
コード例 #2
0
def get_train_loader(batch_size, data_root, griding_num, dataset, use_aux,
                     distributed, height=288, width=800, num_workers=4):
    """Build the lane-classification training DataLoader.

    Args:
        batch_size: Samples per batch.
        data_root: Dataset root directory; the annotation list is resolved
            relative to it.
        griding_num: Passed through to ``LaneClsDataset``.
        dataset: ``'CULane'`` or ``'Tusimple'``.
        use_aux: Passed through to ``LaneClsDataset``.
        distributed: Use a ``DistributedSampler`` instead of a
            ``RandomSampler``.
        height: Network input height (default 288).
        width: Network input width (default 800). The auxiliary segmentation
            target is resized to ``(height // 8, width // 8)``, which gives
            36x100 for the defaults.
        num_workers: DataLoader worker processes (default 4).

    Returns:
        ``(train_loader, cls_num_per_lane)``. ``cls_num_per_lane`` is 18 for
        CULane and 56 for Tusimple.

    Raises:
        NotImplementedError: for any other ``dataset`` value.
    """
    # Resolve per-dataset settings first so unknown names fail fast.
    if dataset == 'CULane':
        list_path = os.path.join(data_root, 'list/train_gt.txt')
        row_anchor = culane_row_anchor
        cls_num_per_lane = 18
    elif dataset == 'Tusimple':
        list_path = os.path.join(data_root, 'train_gt.txt')
        row_anchor = tusimple_row_anchor
        cls_num_per_lane = 56
    else:
        raise NotImplementedError('Unsupported dataset: {!r}'.format(dataset))

    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask((height, width)),
        mytransforms.MaskToTensor(),
    ])
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((height // 8, width // 8)),
        mytransforms.MaskToTensor(),
    ])
    img_transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200)
    ])

    train_dataset = LaneClsDataset(data_root,
                                   list_path,
                                   img_transform=img_transform,
                                   target_transform=target_transform,
                                   simu_transform=simu_transform,
                                   segment_transform=segment_transform,
                                   row_anchor=row_anchor,
                                   griding_num=griding_num,
                                   use_aux=use_aux)

    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=sampler,
                                               num_workers=num_workers)

    return train_loader, cls_num_per_lane
コード例 #3
0
def get_val_loader(batch_size, data_root, griding_num, dataset, use_aux,
                   distributed, num_lanes, cfg, num_workers=4):
    """Build the lane-classification validation DataLoader.

    Args:
        batch_size: Samples per batch.
        data_root: Dataset root directory; the annotation list is resolved
            relative to it.
        griding_num: Passed through to the dataset class.
        dataset: One of ``'CULane'``, ``'Bdd100k'``, ``'neolix'``,
            ``'Tusimple'``.
        use_aux: Passed through to the dataset class.
        distributed: Shard the data with an unshuffled ``DistributedSampler``.
        num_lanes: Passed through to the dataset class.
        cfg: Config object. ``cfg.height``, ``cfg.width`` and ``cfg.anchors``
            are read. ``cfg.anchors`` must already be set because no default
            is provided here.
        num_workers: DataLoader worker processes (default 4).

    Returns:
        A ``torch.utils.data.DataLoader`` that iterates in a fixed order.

    Raises:
        NotImplementedError: for an unsupported ``dataset``.
    """
    # Annotation list of each supported dataset, relative to data_root.
    list_files = {
        'CULane': 'list/val_gt.txt',
        'Bdd100k': 'val.txt',
        'neolix': 'val.txt',
        # NOTE(review): TuSimple "validation" reads the training list, so
        # metrics are computed on training data -- confirm this is intended.
        'Tusimple': 'train_gt.txt',
    }
    # Fail fast, before cfg is touched or any transforms are built.
    if dataset not in list_files:
        raise NotImplementedError('Unsupported dataset: {!r}'.format(dataset))

    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask((cfg.height, cfg.width)),
        mytransforms.MaskToTensor(),
    ])
    img_transform = transforms.Compose([
        transforms.Resize((cfg.height, cfg.width)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Auxiliary segmentation target at 1/8 of the input size.
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((cfg.height // 8, cfg.width // 8)),
        mytransforms.MaskToTensor(),
    ])

    # No simulated (random geometric) augmentation during validation.
    dataset_kwargs = dict(img_transform=img_transform,
                          target_transform=target_transform,
                          simu_transform=None,
                          segment_transform=segment_transform,
                          row_anchor=cfg.anchors,
                          griding_num=griding_num,
                          use_aux=use_aux,
                          num_lanes=num_lanes)
    list_path = os.path.join(data_root, list_files[dataset])
    if dataset == 'Bdd100k':
        val_dataset = BddLaneClsDataset(data_root, list_path, mode="val",
                                        **dataset_kwargs)
    else:
        val_dataset = LaneClsDataset(data_root, list_path, **dataset_kwargs)

    # Validation should be reproducible, so do not shuffle. Note that with
    # drop_last=False, DistributedSampler pads the data with repeated samples
    # so that every rank receives the same number of items.
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             sampler=sampler,
                                             num_workers=num_workers)

    return val_loader
コード例 #4
0
def get_train_loader(batch_size, data_root, griding_num, dataset, use_aux,
                     distributed, num_lanes, cfg, num_workers=4):
    """Build the lane-classification training DataLoader.

    Side effect: if ``cfg`` has no ``anchors`` entry, ``cfg.anchors`` is set
    to the dataset's default row anchors (``culane_row_anchor`` for CULane,
    ``tusimple_row_anchor`` otherwise).

    Args:
        batch_size: Samples per batch.
        data_root: Dataset root directory; the annotation list is resolved
            relative to it.
        griding_num: Passed through to the dataset class.
        dataset: One of ``'CULane'``, ``'Bdd100k'``, ``'neolix'``,
            ``'Tusimple'``.
        use_aux: Passed through to the dataset class.
        distributed: Use a ``DistributedSampler`` instead of a
            ``RandomSampler``.
        num_lanes: Passed through to the dataset class.
        cfg: Config object supporting ``"anchors" in cfg``. ``cfg.height``
            and ``cfg.width`` are read, and ``cfg.anchors`` is read or set.
        num_workers: DataLoader worker processes (default 4).

    Returns:
        ``(train_loader, cls_num_per_lane)`` where ``cls_num_per_lane`` is
        ``len(cfg.anchors)``.

    Raises:
        NotImplementedError: for an unsupported ``dataset``.
    """
    # Per-dataset annotation list (relative to data_root) and default anchors.
    # Unknown names fail here, before cfg or any transform is touched.
    if dataset == 'CULane':
        list_name, default_anchors = 'list/train_gt.txt', culane_row_anchor
    elif dataset == 'Bdd100k':
        list_name, default_anchors = 'train.txt', tusimple_row_anchor
    elif dataset == 'neolix':
        list_name, default_anchors = 'train.txt', tusimple_row_anchor
    elif dataset == 'Tusimple':
        list_name, default_anchors = 'train_gt.txt', tusimple_row_anchor
    else:
        raise NotImplementedError('Unsupported dataset: {!r}'.format(dataset))
    if "anchors" not in cfg:
        cfg.anchors = default_anchors

    target_transform = transforms.Compose([
        mytransforms.FreeScaleMask((cfg.height, cfg.width)),
        mytransforms.MaskToTensor(),
    ])
    # Auxiliary segmentation target at 1/8 of the input size.
    segment_transform = transforms.Compose([
        mytransforms.FreeScaleMask((cfg.height // 8, cfg.width // 8)),
        mytransforms.MaskToTensor(),
    ])
    img_transform = transforms.Compose([
        transforms.Resize((cfg.height, cfg.width)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Joint image/label augmentation. No albumentations-based augmentation
    # is applied.
    simu_transform = mytransforms.Compose2([
        mytransforms.RandomRotate(6),
        mytransforms.RandomUDoffsetLABEL(100),
        mytransforms.RandomLROffsetLABEL(200)
    ])

    dataset_kwargs = dict(img_transform=img_transform,
                          target_transform=target_transform,
                          simu_transform=simu_transform,
                          segment_transform=segment_transform,
                          row_anchor=cfg.anchors,
                          griding_num=griding_num,
                          use_aux=use_aux,
                          num_lanes=num_lanes)
    list_path = os.path.join(data_root, list_name)
    if dataset == 'Bdd100k':
        train_dataset = BddLaneClsDataset(data_root, list_path, mode="train",
                                          **dataset_kwargs)
    else:
        train_dataset = LaneClsDataset(data_root, list_path, **dataset_kwargs)

    # Derived from the anchors actually passed to the dataset, so that a
    # user-supplied cfg.anchors cannot disagree with it.
    # NOTE(review): assumes one classification row per row anchor, i.e.
    # len(culane_row_anchor) == 18 and len(tusimple_row_anchor) == 56 for the
    # defaults -- confirm.
    cls_num_per_lane = len(cfg.anchors)

    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=sampler,
                                               num_workers=num_workers)

    return train_loader, cls_num_per_lane