Example #1
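The examples below are excerpted without their import preambles. A plausible preamble for Example #1, shown as a sketch: the torch/torchvision imports are standard, while the transforms.* and datasets.* module names are assumptions about the surrounding repository layout (later examples pull in their own dataset modules the same way).

import torchvision.transforms as standard_transforms
from torch.utils.data import DataLoader

# Assumed repo-local modules: joint geometric transforms applied to both
# image and mask, plus image-only extended transforms.
import transforms.joint_transforms as joint_transforms
import transforms.transforms as extended_transforms

# Assumed dataset modules; each exposes the dataset classes and the
# ignore_label / num_classes attributes referenced below.
from datasets import cityscapes, mapillary, ade20k, kitti, camvid, null_loader
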
def setup_loaders(args):
    """
    Setup Data Loaders[Currently supports Cityscapes, Mapillary and ADE20kin]
    input: argument passed by the user
    return:  training data loader, validation data loader loader,  train_set
    """

    if args.dataset == 'cityscapes':
        args.dataset_cls = cityscapes
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'mapillary':
        args.dataset_cls = mapillary
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'ade20k':
        args.dataset_cls = ade20k
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'kitti':
        args.dataset_cls = kitti
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'camvid':
        args.dataset_cls = camvid
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'null_loader':
        args.dataset_cls = null_loader
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    # Readjust batch size to mini-batch size for apex
    if args.apex:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 1

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric image transformations
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

    # Image appearance transformations
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]

    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()
    
    if args.jointwtborder:
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label, 
            args.dataset_cls.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'cityscapes':
        city_mode = 'train' ## Can be trainval
        city_quality = 'fine'
        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = args.dataset_cls.CityScapesUniform(
                city_quality, city_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes)
        else:
            train_set = args.dataset_cls.CityScapes(
                city_quality, city_mode, 0, 
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv)

        val_set = args.dataset_cls.CityScapes('fine', 'val', 0, 
                                              transform=val_input_transform,
                                              target_transform=target_transform,
                                              cv_split=args.cv)
    elif args.dataset == 'mapillary':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size, ignore_index=args.dataset_cls.ignore_label)]
        train_set = args.dataset_cls.Mapillary(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.Mapillary(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'ade20k':
        eval_size = 384
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]

        train_set = args.dataset_cls.ade20k(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.ade20k(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'kitti':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #         joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
            
        train_set = args.dataset_cls.KITTI(
            'semantic', 'train', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        val_set = args.dataset_cls.KITTI(
            'semantic', 'trainval', 0, 
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)
    elif args.dataset == 'camvid':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #         joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
            
        train_set = args.dataset_cls.CAMVID(
            'semantic', 'trainval', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        val_set = args.dataset_cls.CAMVID(
            'semantic', 'test', 0, 
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)

    elif args.dataset == 'null_loader':
        train_set = args.dataset_cls.null_loader(args.crop_size)
        val_set = args.dataset_cls.null_loader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))
    
    if args.apex:
        from datasets.sampler import DistributedSampler
        train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False)
        val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)

    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers,
                              shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)

    return train_loader, val_loader, train_set
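
A minimal usage sketch for Example #1. Every field on the Namespace below is one the function actually reads; the values are illustrative only.

from argparse import Namespace

args = Namespace(
    dataset='cityscapes', ngpu=1, bs_mult=2, bs_mult_val=1,
    apex=False, test_mode=False, cv=0, maxSkip=0,
    crop_size=768, pre_size=None, scale_min=0.5, scale_max=2.0,
    color_aug=0.25, bblur=False, gblur=True,
    jointwtborder=False, dump_augmentation_images=False,
    class_uniform_pct=0.0, class_uniform_tile=1024,
    coarse_boost_classes=None)

train_loader, val_loader, train_set = setup_loaders(args)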
Example #2
def setup_loaders(args):
    """
    Setup Data Loaders[Currently supports Cityscapes, Mapillary and ADE20kin]
    input: argument passed by the user
    return:  training data loader, validation data loader loader,  train_set
    """

    args.train_batch_size = args.bs_mult * args.ngpu
    if args.bs_mult_val > 0:
        args.val_batch_size = args.bs_mult_val * args.ngpu
    else:
        args.val_batch_size = args.bs_mult * args.ngpu

    # Readjust batch size to mini-batch size for syncbn
    if args.syncbn:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    args.num_workers = 8  # 1 * args.ngpu
    if args.test_mode:
        args.num_workers = 1

    train_sets = []
    val_sets = []
    val_dataset_names = []

    if 'cityscapes' in args.dataset:
        dataset = cityscapes
        city_mode = args.city_mode  #'train' ## Can be trainval
        city_quality = 'fine'
        train_joint_transform_list, train_joint_transform = get_train_joint_transform(
            args, dataset)
        train_input_transform, val_input_transform = get_input_transforms(
            args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(
            args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None

            train_set = dataset.CityScapesUniform(
                city_quality,
                city_mode,
                args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = dataset.CityScapes(
                city_quality,
                city_mode,
                0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                image_in=args.image_in)

        val_set = dataset.CityScapes('fine',
                                     'val',
                                     0,
                                     transform=val_input_transform,
                                     target_transform=target_transform,
                                     cv_split=args.cv,
                                     image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('cityscapes')

    if 'bdd100k' in args.dataset:
        dataset = bdd100k
        bdd_mode = 'train'  ## Can be trainval
        train_joint_transform_list, train_joint_transform = get_train_joint_transform(
            args, dataset)
        train_input_transform, val_input_transform = get_input_transforms(
            args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(
            args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None

            train_set = dataset.BDD100KUniform(
                bdd_mode,
                args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = dataset.BDD100K(
                bdd_mode,
                0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                image_in=args.image_in)

        val_set = dataset.BDD100K('val',
                                  0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('bdd100k')

    if 'gtav' in args.dataset:
        dataset = gtav
        gtav_mode = 'train'  ## Can be trainval
        train_joint_transform_list, train_joint_transform = get_train_joint_transform(
            args, dataset)
        train_input_transform, val_input_transform = get_input_transforms(
            args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(
            args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None

            train_set = dataset.GTAVUniform(
                gtav_mode,
                args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = gtav.GTAV(
                gtav_mode,
                0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                image_in=args.image_in)

        val_set = gtav.GTAV('val',
                            0,
                            transform=val_input_transform,
                            target_transform=target_transform,
                            cv_split=args.cv,
                            image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('gtav')

    if 'synthia' in args.dataset:
        dataset = synthia
        synthia_mode = 'train'  ## Can be trainval
        train_joint_transform_list, train_joint_transform = get_train_joint_transform(
            args, dataset)
        train_input_transform, val_input_transform = get_input_transforms(
            args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(
            args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None

            train_set = dataset.SynthiaUniform(
                synthia_mode,
                args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = dataset.Synthia(
                synthia_mode,
                0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                image_in=args.image_in)

        val_set = dataset.Synthia('val',
                                  0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('synthia')

    if 'mapillary' in args.dataset:
        dataset = mapillary
        train_joint_transform_list, train_joint_transform = get_train_joint_transform(
            args, dataset)
        train_input_transform, val_input_transform = get_input_transforms(
            args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(
            args, dataset)

        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)
        ]

        train_set = dataset.Mapillary(
            'semantic',
            'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            target_aux_transform=target_aux_train_transform,
            image_in=args.image_in,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = dataset.Mapillary(
            'semantic',
            'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            image_in=args.image_in,
            test=False)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('mapillary')

    if 'null_loader' in args.dataset:
        train_set = nullloader.nullloader(args.crop_size)
        val_set = nullloader.nullloader(args.crop_size)

        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('null_loader')

    if len(train_sets) == 0:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    if len(train_sets) != len(args.dataset):
        raise Exception(
            'Something went wrong. Please check your dataset names are valid')

    # Define new train data set that has all the train sets
    # Define new val data set that has all the val sets
    val_loaders = {}
    if len(args.dataset) != 1:
        if args.image_uniform_sampling:
            train_set = ConcatDataset(train_sets)
        else:
            train_set = multi_loader.DomainUniformConcatDataset(
                args, train_sets)

    for i, val_set in enumerate(val_sets):
        if args.syncbn:
            val_sampler = DistributedSampler(val_set,
                                             pad=False,
                                             permutation=False,
                                             consecutive_sample=False)
        else:
            val_sampler = None
        val_loader = DataLoader(val_set,
                                batch_size=args.val_batch_size,
                                num_workers=args.num_workers // 2,
                                shuffle=False,
                                drop_last=False,
                                sampler=val_sampler)
        val_loaders[val_dataset_names[i]] = val_loader

    if args.syncbn:
        train_sampler = DistributedSampler(train_set,
                                           pad=True,
                                           permutation=True,
                                           consecutive_sample=False)
    else:
        train_sampler = None

    train_loader = DataLoader(train_set,
                              batch_size=args.train_batch_size,
                              num_workers=args.num_workers,
                              shuffle=(train_sampler is None),
                              drop_last=True,
                              sampler=train_sampler)

    extra_val_loader = {}
    for val_dataset in args.val_dataset:
        extra_val_loader[val_dataset] = create_extra_val_loader(
            args, val_dataset, val_input_transform, target_transform,
            val_sampler)

    covstat_val_loader = {}
    for val_dataset in args.covstat_val_dataset:
        covstat_val_loader[val_dataset] = create_covstat_val_loader(
            args, val_dataset, val_input_transform, target_transform,
            val_sampler)

    return train_loader, val_loaders, train_set, extra_val_loader, covstat_val_loader
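
Unlike Example #1, this variant expects args.dataset to be a list of names and returns the validation loaders as a dict keyed by dataset name. A sketch of consuming that return value (evaluate is a hypothetical stand-in for the caller's evaluation routine):

train_loader, val_loaders, train_set, extra_val_loader, covstat_val_loader = \
    setup_loaders(args)

for name, loader in val_loaders.items():
    print('validating on {}: {} batches'.format(name, len(loader)))
    # evaluate(model, loader)  # hypothetical evaluation hook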
Example #3
def setup_loaders(args):
    """
    Setup Data Loaders[Currently supports Cityscapes, Mapillary and ADE20kin]
    input: argument passed by the user
    return:  training data loader, validation data loader loader,  train_set
    """

    # TODO add error checking to make sure class exists
    logx.msg(f'dataset = {args.dataset}')

    mod = importlib.import_module('datasets.{}'.format(args.dataset))
    dataset_cls = getattr(mod, 'Loader')

    logx.msg(f'ignore_label = {dataset_cls.ignore_label}')

    update_dataset_cfg(num_classes=dataset_cls.num_classes,
                       ignore_label=dataset_cls.ignore_label)

    ######################################################################
    # Define transformations, augmentations
    ######################################################################

    # Joint transformations that must happen on both image and mask
    if ',' in args.crop_size:
        args.crop_size = [int(x) for x in args.crop_size.split(',')]
    else:
        args.crop_size = int(args.crop_size)
    train_joint_transform_list = [
        # TODO FIXME: move these hparams into cfg
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           full_size=args.full_crop_training,
                                           pre_size=args.pre_size)]
    train_joint_transform_list.append(
        joint_transforms.RandomHorizontallyFlip())

    if args.rand_augment is not None:
        N, M = [int(i) for i in args.rand_augment.split(',')]
        assert isinstance(N, int) and isinstance(M, int), \
            f'Either N {N} or M {M} not integer'
        train_joint_transform_list.append(RandAugment(N, M))

    ######################################################################
    # Image only augmentations
    ######################################################################
    train_input_transform = []

    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]
    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]

    mean_std = (cfg.DATASET.MEAN, cfg.DATASET.STD)
    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()

    if args.jointwtborder:
        target_train_transform = \
            extended_transforms.RelaxedBoundaryLossToTensor()
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.eval == 'folder':
        val_joint_transform_list = None
    elif 'mapillary' in args.dataset:
        if args.pre_size is None:
            eval_size = 2177
        else:
            eval_size = args.pre_size
        if cfg.DATASET.MAPILLARY_CROP_VAL:
            val_joint_transform_list = [
                joint_transforms.ResizeHeight(eval_size),
                joint_transforms.CenterCropPad(eval_size)]
        else:
            val_joint_transform_list = [
                joint_transforms.Scale(eval_size)]
    else:
        val_joint_transform_list = None

    if args.eval is None or args.eval == 'val':
        val_name = 'val'
    elif args.eval == 'trn':
        val_name = 'train'
    elif args.eval == 'folder':
        val_name = 'folder'
    else:
        raise ValueError('unknown eval mode {}'.format(args.eval))

    ######################################################################
    # Create loaders
    ######################################################################
    val_set = dataset_cls(
        mode=val_name,
        joint_transform_list=val_joint_transform_list,
        img_transform=val_input_transform,
        label_transform=target_transform,
        eval_folder=args.eval_folder)

    update_dataset_inst(dataset_inst=val_set)

    if args.apex:
        from datasets.sampler import DistributedSampler
        val_sampler = DistributedSampler(val_set, pad=False, permutation=False,
                                         consecutive_sample=False)
    else:
        val_sampler = None

    val_loader = DataLoader(val_set, batch_size=args.bs_val,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)

    if args.eval is not None:
        # Don't create train dataloader if eval
        train_set = None
        train_loader = None
    else:
        train_set = dataset_cls(
            mode='train',
            joint_transform_list=train_joint_transform_list,
            img_transform=train_input_transform,
            label_transform=target_train_transform)

        if args.apex:
            from datasets.sampler import DistributedSampler
            train_sampler = DistributedSampler(train_set, pad=True,
                                               permutation=True,
                                               consecutive_sample=False)
            train_batch_size = args.bs_trn
        else:
            train_sampler = None
            train_batch_size = args.bs_trn * args.ngpu

        train_loader = DataLoader(train_set, batch_size=train_batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=(train_sampler is None),
                                  drop_last=True, sampler=train_sampler)

    return train_loader, val_loader, train_set
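
Example #3 resolves the dataset via importlib, so any datasets/<name>.py just has to export a Loader class matching the attributes and constructor calls above. A skeleton of that implied contract (a sketch; the attribute values are placeholders, not the repo's actual class):

from torch.utils.data import Dataset

class Loader(Dataset):
    """Implied contract for datasets.<name>.Loader in Example #3."""
    num_classes = 19    # read as dataset_cls.num_classes
    ignore_label = 255  # read as dataset_cls.ignore_label

    def __init__(self, mode, joint_transform_list=None, img_transform=None,
                 label_transform=None, eval_folder=None):
        # mode is 'train', 'val', or 'folder' (see val_name above)
        self.mode = mode

    def __len__(self):
        return 0

    def __getitem__(self, idx):
        raise IndexError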
Example #4
def create_extra_val_loader(args, dataset, val_input_transform,
                            target_transform, val_sampler):
    """
    Create extra validation loader
    Args:
        args: input config arguments
        dataset: dataset class object
        val_input_transform: validation input transforms
        target_transform: target transforms
        val_sampler: validation sampler

    return: validation loaders
    """
    if dataset == 'cityscapes':
        val_set = cityscapes.CityScapes('fine',
                                        'val',
                                        0,
                                        transform=val_input_transform,
                                        target_transform=target_transform,
                                        cv_split=args.cv,
                                        image_in=args.image_in)
    elif dataset == 'bdd100k':
        val_set = bdd100k.BDD100K('val',
                                  0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
    elif dataset == 'gtav':
        val_set = gtav.GTAV('val',
                            0,
                            transform=val_input_transform,
                            target_transform=target_transform,
                            cv_split=args.cv,
                            image_in=args.image_in)
    elif dataset == 'synthia':
        val_set = synthia.Synthia('val',
                                  0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
    elif dataset == 'mapillary':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)
        ]
        val_set = mapillary.Mapillary(
            'semantic',
            'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif dataset == 'null_loader':
        val_set = nullloader.nullloader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(dataset))

    if args.syncbn:
        from datasets.sampler import DistributedSampler
        val_sampler = DistributedSampler(val_set,
                                         pad=False,
                                         permutation=False,
                                         consecutive_sample=False)

    else:
        val_sampler = None

    val_loader = DataLoader(val_set,
                            batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2,
                            shuffle=False,
                            drop_last=False,
                            sampler=val_sampler)
    return val_loader


def setup_loaders(args):
    """
    Set up data loaders. Currently supports the OmniAudio_noBG_Paralleltask
    variants.
    input: arguments passed by the user
    return: training data loader, validation data loader, train_set
    """
    if args.dataset == 'OmniAudio_noBG_Paralleltask':
        args.dataset_cls = OmniAudio_noBG_Paralleltask
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth':
        args.dataset_cls = OmniAudio_noBG_Paralleltask_depth
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth_noSeman':
        args.dataset_cls = OmniAudio_noBG_Paralleltask_depth_noSeman
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    # Readjust batch size to mini-batch size for apex
    if args.apex:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 1

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric image transformations
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

    # Image appearance transformations
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]

    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()
    
    if args.jointwtborder: 
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label, 
            args.dataset_cls.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'OmniAudio_noBG_Paralleltask':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)

    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)


    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth_noSeman':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
        
    elif args.dataset == 'null_loader':
        train_set = args.dataset_cls.null_loader(args.crop_size)
        val_set = args.dataset_cls.null_loader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))
    
    if args.apex:
        from datasets.sampler import DistributedSampler
        train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False)
        val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)

    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers,
                              shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)

    return train_loader, val_loader, train_set
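
The three OmniAudio branches at the top of Example #4 differ only in the module bound to args.dataset_cls (note also that the first dispatch chain raises for 'null_loader' even though a null_loader branch exists further down). A hedged refactoring sketch that collapses them into a table lookup; the OmniAudio_* names are the same modules the example already references:

# Hypothetical refactoring: map dataset name -> module, share one branch.
OMNIAUDIO_VARIANTS = {
    'OmniAudio_noBG_Paralleltask': OmniAudio_noBG_Paralleltask,
    'OmniAudio_noBG_Paralleltask_depth': OmniAudio_noBG_Paralleltask_depth,
    'OmniAudio_noBG_Paralleltask_depth_noSeman':
        OmniAudio_noBG_Paralleltask_depth_noSeman,
}

if args.dataset in OMNIAUDIO_VARIANTS:
    args.dataset_cls = OMNIAUDIO_VARIANTS[args.dataset]
    args.train_batch_size = 4  # was args.bs_mult * args.ngpu
    args.val_batch_size = 4
else:
    raise Exception('Dataset {} is not supported'.format(args.dataset))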