def setup_loaders(args):
    """
    Set up data loaders.
    Currently supports Cityscapes, Mapillary, ADE20k, KITTI and CamVid.
    input: arguments passed by the user
    return: training data loader, validation data loader, train_set
    """
    if args.dataset == 'cityscapes':
        args.dataset_cls = cityscapes
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'mapillary':
        args.dataset_cls = mapillary
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'ade20k':
        args.dataset_cls = ade20k
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'kitti':
        args.dataset_cls = kitti
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'camvid':
        args.dataset_cls = camvid
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'null_loader':
        args.dataset_cls = null_loader
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    # Readjust batch size to mini-batch size for apex
    if args.apex:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 1

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric image transformations (applied jointly to image and mask)
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

    # Image appearance transformations (image only)
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]

    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()

    if args.jointwtborder:
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(
            args.dataset_cls.ignore_label, args.dataset_cls.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'cityscapes':
        city_mode = 'train'  ## Can be trainval
        city_quality = 'fine'
        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = args.dataset_cls.CityScapesUniform(
                city_quality, city_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes)
        else:
            train_set = args.dataset_cls.CityScapes(
                city_quality, city_mode, 0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv)

        val_set = args.dataset_cls.CityScapes('fine', 'val', 0,
                                              transform=val_input_transform,
                                              target_transform=target_transform,
                                              cv_split=args.cv)
    elif args.dataset == 'mapillary':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size,
                                           ignore_index=args.dataset_cls.ignore_label)]
        train_set = args.dataset_cls.Mapillary(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.Mapillary(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'ade20k':
        eval_size = 384
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.ade20k(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.ade20k(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'kitti':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #     joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
        train_set = args.dataset_cls.KITTI(
            'semantic', 'train', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        val_set = args.dataset_cls.KITTI(
            'semantic', 'trainval', 0,
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)
    elif args.dataset == 'camvid':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #     joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
        train_set = args.dataset_cls.CAMVID(
            'semantic', 'trainval', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        val_set = args.dataset_cls.CAMVID(
            'semantic', 'test', 0,
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)
    elif args.dataset == 'null_loader':
        train_set = args.dataset_cls.null_loader(args.crop_size)
        val_set = args.dataset_cls.null_loader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    if args.apex:
        from datasets.sampler import DistributedSampler
        train_sampler = DistributedSampler(train_set, pad=True,
                                           permutation=True,
                                           consecutive_sample=False)
        val_sampler = DistributedSampler(val_set, pad=False,
                                         permutation=False,
                                         consecutive_sample=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers,
                              shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)

    return train_loader, val_loader, train_set
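# --- Hedged usage sketch (illustrative; not part of the original source) ----
# Shows the expected call pattern for the variant above. The SimpleNamespace
# fields mirror the args attributes setup_loaders() reads; the values are
# assumptions, and the Cityscapes data must exist on disk for batches to load.
def _demo_setup_loaders_single_dataset():
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        dataset='cityscapes', bs_mult=2, bs_mult_val=1, ngpu=1,
        apex=False, test_mode=False, crop_size=768, pre_size=None,
        scale_min=0.5, scale_max=2.0, color_aug=0.25, bblur=False,
        gblur=True, jointwtborder=False, class_uniform_pct=0.0,
        class_uniform_tile=1024, coarse_boost_classes=None,
        dump_augmentation_images=False, cv=0, maxSkip=0)

    train_loader, val_loader, train_set = setup_loaders(demo_args)
    images, masks = next(iter(train_loader))  # one augmented training batch
    print(images.shape, masks.shape)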
def setup_loaders(args):
    """
    Set up data loaders.
    Currently supports Cityscapes, BDD100K, GTAV, Synthia and Mapillary,
    individually or combined for multi-domain training.
    input: arguments passed by the user (args.dataset is a list of names)
    return: training data loader, dict of validation data loaders, train_set,
            dict of extra validation loaders, dict of covariance-statistics
            validation loaders
    """
    args.train_batch_size = args.bs_mult * args.ngpu
    if args.bs_mult_val > 0:
        args.val_batch_size = args.bs_mult_val * args.ngpu
    else:
        args.val_batch_size = args.bs_mult * args.ngpu

    # Readjust batch size to mini-batch size for syncbn
    if args.syncbn:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    args.num_workers = 8  # 1 * args.ngpu
    if args.test_mode:
        args.num_workers = 1

    train_sets = []
    val_sets = []
    val_dataset_names = []

    if 'cityscapes' in args.dataset:
        dataset = cityscapes
        city_mode = args.city_mode  # 'train' ## Can be trainval
        city_quality = 'fine'
        train_joint_transform_list, train_joint_transform = \
            get_train_joint_transform(args, dataset)
        train_input_transform, val_input_transform = \
            get_input_transforms(args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = \
            get_target_transforms(args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = dataset.CityScapesUniform(
                city_quality, city_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = dataset.CityScapes(
                city_quality, city_mode, 0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                image_in=args.image_in)

        val_set = dataset.CityScapes('fine', 'val', 0,
                                     transform=val_input_transform,
                                     target_transform=target_transform,
                                     cv_split=args.cv,
                                     image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('cityscapes')

    if 'bdd100k' in args.dataset:
        dataset = bdd100k
        bdd_mode = 'train'  ## Can be trainval
        train_joint_transform_list, train_joint_transform = \
            get_train_joint_transform(args, dataset)
        train_input_transform, val_input_transform = \
            get_input_transforms(args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = \
            get_target_transforms(args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = dataset.BDD100KUniform(
                bdd_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = dataset.BDD100K(
                bdd_mode, 0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                image_in=args.image_in)

        val_set = dataset.BDD100K('val', 0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('bdd100k')

    if 'gtav' in args.dataset:
        dataset = gtav
        gtav_mode = 'train'  ## Can be trainval
        train_joint_transform_list, train_joint_transform = \
            get_train_joint_transform(args, dataset)
        train_input_transform, val_input_transform = \
            get_input_transforms(args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = \
            get_target_transforms(args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = dataset.GTAVUniform(
                gtav_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = gtav.GTAV(
                gtav_mode, 0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                image_in=args.image_in)

        val_set = gtav.GTAV('val', 0,
                            transform=val_input_transform,
                            target_transform=target_transform,
                            cv_split=args.cv,
                            image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('gtav')

    if 'synthia' in args.dataset:
        dataset = synthia
        synthia_mode = 'train'  ## Can be trainval
        train_joint_transform_list, train_joint_transform = \
            get_train_joint_transform(args, dataset)
        train_input_transform, val_input_transform = \
            get_input_transforms(args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = \
            get_target_transforms(args, dataset)

        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = dataset.SynthiaUniform(
                synthia_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes,
                image_in=args.image_in)
        else:
            train_set = dataset.Synthia(
                synthia_mode, 0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                image_in=args.image_in)

        val_set = dataset.Synthia('val', 0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('synthia')

    if 'mapillary' in args.dataset:
        dataset = mapillary
        train_joint_transform_list, train_joint_transform = \
            get_train_joint_transform(args, dataset)
        train_input_transform, val_input_transform = \
            get_input_transforms(args, dataset)
        target_transform, target_train_transform, target_aux_train_transform = \
            get_target_transforms(args, dataset)

        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]

        train_set = dataset.Mapillary(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            target_aux_transform=target_aux_train_transform,
            image_in=args.image_in,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = dataset.Mapillary(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            image_in=args.image_in,
            test=False)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('mapillary')

    if 'null_loader' in args.dataset:
        train_set = nullloader.nullloader(args.crop_size)
        val_set = nullloader.nullloader(args.crop_size)
        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append('null_loader')

    if len(train_sets) == 0:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    if len(train_sets) != len(args.dataset):
        raise Exception(
            'Something went wrong. Please check that your dataset names are valid')

    # Concatenate the per-domain train sets into a single training set;
    # validation sets stay separate so each domain is evaluated on its own.
    val_loaders = {}
    if len(args.dataset) != 1:
        if args.image_uniform_sampling:
            train_set = ConcatDataset(train_sets)
        else:
            train_set = multi_loader.DomainUniformConcatDataset(args, train_sets)

    for i, val_set in enumerate(val_sets):
        if args.syncbn:
            val_sampler = DistributedSampler(val_set, pad=False,
                                             permutation=False,
                                             consecutive_sample=False)
        else:
            val_sampler = None
        val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                                num_workers=args.num_workers // 2,
                                shuffle=False, drop_last=False,
                                sampler=val_sampler)
        val_loaders[val_dataset_names[i]] = val_loader

    if args.syncbn:
        train_sampler = DistributedSampler(train_set, pad=True,
                                           permutation=True,
                                           consecutive_sample=False)
    else:
        train_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers,
                              shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)

    # val_sampler here is whatever the last val-set iteration produced; the
    # helpers below recompute their own sampler from args.syncbn anyway.
    extra_val_loader = {}
    for val_dataset in args.val_dataset:
        extra_val_loader[val_dataset] = create_extra_val_loader(
            args, val_dataset, val_input_transform, target_transform,
            val_sampler)

    covstat_val_loader = {}
    for val_dataset in args.covstat_val_dataset:
        covstat_val_loader[val_dataset] = create_covstat_val_loader(
            args, val_dataset, val_input_transform, target_transform,
            val_sampler)

    return train_loader, val_loaders, train_set, extra_val_loader, covstat_val_loader
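# --- Hedged usage sketch (illustrative; not part of the original source) ----
# Shows the multi-domain call pattern of the variant above: args.dataset is a
# *list* of source domains and args.val_dataset lists extra (unseen) domains
# evaluated separately. Values are assumptions, and the get_*_transforms
# helpers read further args attributes (crop size, scale range, ...) not
# enumerated here.
def _demo_setup_loaders_multi_domain():
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        dataset=['gtav'],            # source domain(s) used for training
        val_dataset=['cityscapes'],  # unseen domain(s) for extra validation
        covstat_val_dataset=[],
        bs_mult=2, bs_mult_val=1, ngpu=1, syncbn=False, test_mode=False,
        image_in=False, image_uniform_sampling=False, class_uniform_pct=0.0,
        class_uniform_tile=1024, coarse_boost_classes=None,
        dump_augmentation_images=False, cv=0, maxSkip=0, crop_size=768)

    (train_loader, val_loaders, train_set,
     extra_val_loaders, covstat_val_loaders) = setup_loaders(demo_args)
    for name, loader in val_loaders.items():
        print(name, len(loader.dataset))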
def setup_loaders(args):
    """
    Set up data loaders. The dataset module is resolved dynamically from
    args.dataset, so any module under datasets/ that exposes a Loader class
    is supported.
    input: arguments passed by the user
    return: training data loader, validation data loader, train_set
    """
    # TODO add error checking to make sure class exists
    logx.msg(f'dataset = {args.dataset}')
    mod = importlib.import_module('datasets.{}'.format(args.dataset))
    dataset_cls = getattr(mod, 'Loader')
    logx.msg(f'ignore_label = {dataset_cls.ignore_label}')
    update_dataset_cfg(num_classes=dataset_cls.num_classes,
                       ignore_label=dataset_cls.ignore_label)

    ######################################################################
    # Define transformations, augmentations
    ######################################################################

    # Joint transformations that must happen on both image and mask
    if ',' in args.crop_size:
        args.crop_size = [int(x) for x in args.crop_size.split(',')]
    else:
        args.crop_size = int(args.crop_size)
    train_joint_transform_list = [
        # TODO FIXME: move these hparams into cfg
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           full_size=args.full_crop_training,
                                           pre_size=args.pre_size)]
    train_joint_transform_list.append(
        joint_transforms.RandomHorizontallyFlip())

    if args.rand_augment is not None:
        N, M = [int(i) for i in args.rand_augment.split(',')]
        assert isinstance(N, int) and isinstance(M, int), \
            f'Either N {N} or M {M} not integer'
        train_joint_transform_list.append(RandAugment(N, M))

    ######################################################################
    # Image only augmentations
    ######################################################################
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]

    mean_std = (cfg.DATASET.MEAN, cfg.DATASET.STD)
    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()

    if args.jointwtborder:
        target_train_transform = \
            extended_transforms.RelaxedBoundaryLossToTensor()
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.eval == 'folder':
        val_joint_transform_list = None
    elif 'mapillary' in args.dataset:
        if args.pre_size is None:
            eval_size = 2177
        else:
            eval_size = args.pre_size
        if cfg.DATASET.MAPILLARY_CROP_VAL:
            val_joint_transform_list = [
                joint_transforms.ResizeHeight(eval_size),
                joint_transforms.CenterCropPad(eval_size)]
        else:
            val_joint_transform_list = [
                joint_transforms.Scale(eval_size)]
    else:
        val_joint_transform_list = None

    if args.eval is None or args.eval == 'val':
        val_name = 'val'
    elif args.eval == 'trn':
        val_name = 'train'
    elif args.eval == 'folder':
        val_name = 'folder'
    else:
        raise ValueError('unknown eval mode {}'.format(args.eval))

    ######################################################################
    # Create loaders
    ######################################################################
    val_set = dataset_cls(
        mode=val_name,
        joint_transform_list=val_joint_transform_list,
        img_transform=val_input_transform,
        label_transform=target_transform,
        eval_folder=args.eval_folder)

    update_dataset_inst(dataset_inst=val_set)

    if args.apex:
        from datasets.sampler import DistributedSampler
        val_sampler = DistributedSampler(val_set, pad=False,
                                         permutation=False,
                                         consecutive_sample=False)
    else:
        val_sampler = None

    val_loader = DataLoader(val_set, batch_size=args.bs_val,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)

    if args.eval is not None:
        # Don't create train dataloader if eval
        train_set = None
        train_loader = None
    else:
        train_set = dataset_cls(
            mode='train',
            joint_transform_list=train_joint_transform_list,
            img_transform=train_input_transform,
            label_transform=target_train_transform)

        if args.apex:
            from datasets.sampler import DistributedSampler
            train_sampler = DistributedSampler(train_set, pad=True,
                                               permutation=True,
                                               consecutive_sample=False)
            train_batch_size = args.bs_trn
        else:
            train_sampler = None
            train_batch_size = args.bs_trn * args.ngpu

        train_loader = DataLoader(train_set, batch_size=train_batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=(train_sampler is None),
                                  drop_last=True, sampler=train_sampler)

    return train_loader, val_loader, train_set
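# --- Hedged usage sketch (illustrative; not part of the original source) ----
# This variant resolves the dataset module via importlib, so any module in
# datasets/ exposing a Loader class works. Unlike the other variants,
# num_workers is expected to come from the CLI rather than being derived
# from ngpu. Values below are assumptions.
def _demo_setup_loaders_dynamic():
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        dataset='cityscapes', crop_size='1024,2048', scale_min=0.5,
        scale_max=2.0, full_crop_training=False, pre_size=None,
        rand_augment=None, color_aug=0.25, bblur=False, gblur=True,
        jointwtborder=False, eval=None, eval_folder=None, apex=False,
        bs_trn=2, bs_val=1, ngpu=1, num_workers=4)

    train_loader, val_loader, train_set = setup_loaders(demo_args)
    # With demo_args.eval set (e.g. 'val'), train_loader would come back None.
    print(len(val_loader.dataset))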
def create_extra_val_loader(args, dataset, val_input_transform,
                            target_transform, val_sampler):
    """
    Create an extra validation loader.
    Args:
        args: input config arguments
        dataset: dataset name (string)
        val_input_transform: validation input transforms
        target_transform: target transforms
        val_sampler: validation sampler (unconditionally recomputed below
                     from args.syncbn, so the passed value is ignored)
    return: validation loader
    """
    if dataset == 'cityscapes':
        val_set = cityscapes.CityScapes('fine', 'val', 0,
                                        transform=val_input_transform,
                                        target_transform=target_transform,
                                        cv_split=args.cv,
                                        image_in=args.image_in)
    elif dataset == 'bdd100k':
        val_set = bdd100k.BDD100K('val', 0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
    elif dataset == 'gtav':
        val_set = gtav.GTAV('val', 0,
                            transform=val_input_transform,
                            target_transform=target_transform,
                            cv_split=args.cv,
                            image_in=args.image_in)
    elif dataset == 'synthia':
        val_set = synthia.Synthia('val', 0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)
    elif dataset == 'mapillary':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        val_set = mapillary.Mapillary(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif dataset == 'null_loader':
        val_set = nullloader.nullloader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(dataset))

    if args.syncbn:
        from datasets.sampler import DistributedSampler
        val_sampler = DistributedSampler(val_set, pad=False,
                                         permutation=False,
                                         consecutive_sample=False)
    else:
        val_sampler = None

    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)
    return val_loader
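# --- Hedged usage sketch (illustrative; not part of the original source) ----
# create_extra_val_loader() is normally invoked from the multi-domain
# setup_loaders() above; calling it directly only needs the two transforms
# plus args.cv, image_in, syncbn, val_batch_size and num_workers. Since
# val_sampler is recomputed inside the function, passing None is fine.
def _demo_extra_val_loader(args, val_input_transform, target_transform):
    loader = create_extra_val_loader(args, 'bdd100k', val_input_transform,
                                     target_transform, val_sampler=None)
    images, masks = next(iter(loader))  # one validation batch
    print(images.shape, masks.shape)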
def setup_loaders(args):
    """
    Set up data loaders.
    Currently supports the OmniAudio_noBG_Paralleltask variants.
    input: arguments passed by the user
    return: training data loader, validation data loader, train_set
    """
    if args.dataset == 'OmniAudio_noBG_Paralleltask':
        args.dataset_cls = OmniAudio_noBG_Paralleltask
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth':
        args.dataset_cls = OmniAudio_noBG_Paralleltask_depth
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth_noSeman':
        args.dataset_cls = OmniAudio_noBG_Paralleltask_depth_noSeman
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    # Readjust batch size to mini-batch size for apex
    if args.apex:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 1

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric image transformations (applied jointly to image and mask)
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

    # Image appearance transformations (image only)
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]

    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()

    if args.jointwtborder:
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(
            args.dataset_cls.ignore_label, args.dataset_cls.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'OmniAudio_noBG_Paralleltask':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth_noSeman':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
    elif args.dataset == 'null_loader':
        # NOTE: unreachable — 'null_loader' is rejected by the dataset check
        # at the top of this function.
        train_set = args.dataset_cls.null_loader(args.crop_size)
        val_set = args.dataset_cls.null_loader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    if args.apex:
        from datasets.sampler import DistributedSampler
        train_sampler = DistributedSampler(train_set, pad=True,
                                           permutation=True,
                                           consecutive_sample=False)
        val_sampler = DistributedSampler(val_set, pad=False,
                                         permutation=False,
                                         consecutive_sample=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers,
                              shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)

    return train_loader, val_loader, train_set
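# --- Hedged usage sketch (illustrative; not part of the original source) ----
# The OmniAudio variants hard-code a batch size of 4 unless args.apex is set,
# in which case bs_mult / bs_mult_val take over. Values are assumptions.
def _demo_setup_loaders_omniaudio():
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        dataset='OmniAudio_noBG_Paralleltask', bs_mult=2, bs_mult_val=1,
        ngpu=1, apex=False, test_mode=False, crop_size=960, pre_size=None,
        scale_min=0.5, scale_max=2.0, color_aug=0.25, bblur=False,
        gblur=True, jointwtborder=False)

    train_loader, val_loader, train_set = setup_loaders(demo_args)
    print(len(train_loader.dataset), len(val_loader.dataset))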