Example #1
    def __init__(self, task, mode, dump_imgs=False):
        assert task in {'binary', 'type', 'parts'}
        assert mode in {'train', 'val', 'test'}
        self.task = task
        self.mode = mode
        self.dump_imgs = dump_imgs
        self.size = 768

        if mode == 'train':
            idx_range = range(150)
        elif mode == 'val':
            idx_range = range(150, 175)
        else:
            idx_range = range(175, 225)
        #idx_range = range(175) if mode=='train' else range(175,225)
        self.img_paths = []
        self.mask_paths = []
        for i in range(1, 9):
            img_list = [
                os.path.join(root, 'instrument_dataset_%d' % i,
                             'left_frames/frame{:03d}.png'.format(j))
                for j in idx_range
            ]
            self.img_paths.extend(img_list)
            mask_list = [
                os.path.join(root, 'instrument_dataset_%d' % i,
                             'ground_truth/%s_labels' % task,
                             'frame{:03d}.png'.format(j)) for j in idx_range
            ]
            self.mask_paths.extend(mask_list)

        joint_transform_list = [
            joint_transforms.Resize(self.size),
            joint_transforms.RandomHorizontallyFlip()
        ]
        self.train_joint_transform = joint_transforms.Compose(
            joint_transform_list)
        self.val_joint_transform = joint_transforms.Resize(self.size)
        train_input_transform = []
        train_input_transform += [
            extended_transforms.ColorJitter(brightness=0.25,
                                            contrast=0.25,
                                            saturation=0.25,
                                            hue=0.25)
        ]
        train_input_transform += [extended_transforms.RandomGaussianBlur()]
        train_input_transform += [
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ]
        self.train_input_transform = standard_transforms.Compose(
            train_input_transform)

        self.val_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ])
        self.target_transform = extended_transforms.MaskToTensor()
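# For reference, a minimal sketch of what extended_transforms.MaskToTensor
# typically does in codebases like this one (an assumption based on the common
# implementation, not necessarily this repo's exact code): it converts a PIL
# label mask to a LongTensor of class IDs, without the [0, 1] scaling that
# torchvision's ToTensor applies, so pixel values stay usable as class indices.
import numpy as np
import torch

class MaskToTensorSketch:
    def __call__(self, mask):
        # np.array keeps the raw integer labels; .long() matches what
        # cross-entropy-style losses expect as targets.
        return torch.from_numpy(np.array(mask, dtype=np.int32)).long()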
Example #2
def get_target_transforms(args, dataset):
    """
    Get target transforms
    Args:
        args: input config arguments
        dataset: dataset class object

    return: target_transform, target_train_transform, target_aux_train_transform
    """

    target_transform = extended_transforms.MaskToTensor()
    if args.jointwtborder:
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(
            dataset.ignore_label, dataset.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    target_aux_train_transform = extended_transforms.MaskToTensor()

    return target_transform, target_train_transform, target_aux_train_transform
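# A hedged usage sketch for get_target_transforms. The _Args/_Dataset stand-ins
# below are hypothetical; real callers pass the parsed CLI args and the dataset
# object, which expose jointwtborder, ignore_label and num_classes as above.
import numpy as np
from PIL import Image

class _Args:
    jointwtborder = False

class _Dataset:
    ignore_label = 255
    num_classes = 19

target_tf, target_train_tf, target_aux_tf = get_target_transforms(_Args(), _Dataset())
mask = Image.fromarray(np.zeros((4, 4), dtype=np.uint8))  # dummy 4x4 label mask
print(target_tf(mask).dtype)  # torch.int64: raw class IDs, no scaling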
Example #3
def setup_loader():
    """
    Setup Data Loaders
    """
    val_input_transform = transforms.ToTensor()
    target_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'cityscapes':
        args.dataset_cls = cityscapes
        eval_mode_pooling = False
        eval_scales = None
        if args.inference_mode == 'pooling':
            eval_mode_pooling = True
            eval_scales = args.scales
        test_set = args.dataset_cls.CityScapes(args.mode, args.split,
                                               transform=val_input_transform,
                                               target_transform=target_transform,
                                               cv_split=args.cv_split,
                                               eval_mode=eval_mode_pooling,
                                               eval_scales=eval_scales,
                                               eval_flip=not args.no_flip,
                                               )
    elif args.dataset == 'kitti':
        args.dataset_cls = kitti
        test_set = args.dataset_cls.KITTI(args.mode, args.split,
                                         transform=val_input_transform,
                                         target_transform=target_transform,
                                         cv_split=args.cv_split)

    elif args.dataset == 'kitti_trav':
        args.dataset_cls = kitti_trav
        test_set = args.dataset_cls.KITTI_trav(args.mode, args.split,
                                         transform=val_input_transform,
                                         target_transform=target_transform,
                                         cv_split=args.cv_split)

    elif args.dataset == 'kitti_semantic':
        args.dataset_cls = kitti_semantic
        test_set = args.dataset_cls.KITTI_Semantic(args.mode, args.split,
                                         transform=val_input_transform,
                                         target_transform=target_transform,
                                         cv_split=args.cv_split)
    else:
        raise NameError('-------------Not Supported Currently-------------')

#    if args.split_count > 1:
#        test_set.split_dataset(args.split_index, args.split_count)

    batch_size = 1

    test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=1,
                             shuffle=False, pin_memory=False, drop_last=False)

    return test_loader
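# A short sketch of how such a loader is typically consumed at inference time.
# `net` (a trained segmentation model) and the unpacking order of each batch
# are assumptions; the actual tuple layout depends on the dataset's __getitem__.
import torch

test_loader = setup_loader()
net.eval()
with torch.no_grad():
    for images, masks, *extras in test_loader:
        logits = net(images.cuda())
        preds = logits.argmax(dim=1).cpu()
        # ... accumulate metrics by comparing preds against masks here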
Example #4
def setup_loader():
    """
    Setup Data Loaders
    """
    val_input_transform = transforms.ToTensor()
    target_transform = extended_transforms.MaskToTensor()
    val_joint_transform_list = [joint_transforms.Resize(args.resize_scale)]
    if args.dataset == 'iSAID':
        args.dataset_cls = iSAID
        test_set = args.dataset_cls.ISAIDDataset(
            args.mode,
            args.split,
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
    elif args.dataset == 'Posdam':
        args.dataset_cls = Posdam
        test_set = args.dataset_cls.POSDAMDataset(
            args.mode,
            args.split,
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
    elif args.dataset == 'Vaihingen':
        args.dataset_cls = Vaihingen
        test_set = args.dataset_cls.VAIHINGENDataset(
            args.mode,
            args.split,
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
    else:
        raise NameError('-------------Not Supported Currently-------------')

    if args.split_count > 1:
        test_set.split_dataset(args.split_index, args.split_count)

    batch_size = 1
    if args.inference_mode == 'pooling':
        batch_size = args.batch_size

    test_loader = DataLoader(test_set,
                             batch_size=batch_size,
                             num_workers=args.num_workers,
                             shuffle=False,
                             pin_memory=False,
                             drop_last=False)

    return test_loader
Example #5
def setup_loader():
    """
    Setup Data Loaders
    """
    val_input_transform = transforms.ToTensor()
    target_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'cityscapes':
        args.dataset_cls = cityscapes
        eval_scales = None
        if args.inference_mode == 'pooling':
            eval_mode = 'pooling'
            eval_scales = args.scales
        elif args.inference_mode == 'sliding':
            eval_mode = 'sliding'
        else:
            raise Exception(f"Not implemented inference mode: {args.inference_mode}")

        test_set = args.dataset_cls.CityScapes(args.mode, args.split, 0,
                                               transform=val_input_transform,
                                               target_transform=target_transform,
                                               cv_split=0,  # args.cv_split
                                               eval_mode=eval_mode,
                                               eval_scales=eval_scales,
                                               eval_flip=not args.no_flip,
                                               image_in=args.image_in
                                               )
    else:
        raise NameError('-------------Not Supported Currently-------------')

    if args.split_count > 1:
        test_set.split_dataset(args.split_index, args.split_count)

    batch_size = 1

    if args.syncbn:
        from datasets.sampler import DistributedSampler
        test_sampler = DistributedSampler(test_set, pad=False, permutation=False, consecutive_sample=False)
    else:
        test_sampler = None

    test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=args.num_workers,
                             shuffle=False, drop_last=False, sampler=test_sampler)

    return test_loader
Example #6
def get_transforms(args, mode="train"):

    # image mask transform
    if mode == "train":
        joint_transform_list = []
        # joint_transform_list += [joint_transforms.RandomSizeAndCrop(args.crop_size, crop_nopad=False, p=0.5,
        #                                               scale_min=args.scale_min, scale_max=args.scale_max)]

        # TODO add another joint transforms
        joint_transform_list += [joint_transforms.RandomHorizontallyFlip()]  # default probability is 0.5
        joint_transform_list += [joint_transforms.RandomRotate90(p=0.5)]
        joint_transform_list += [
            joint_transforms.RandomZoomIn(sizes=[256, 288, 320],
                                          out_size=256,
                                          p=0.5)
        ]

        # image transform
        input_transform = []
        input_transform += [
            extended_transforms.ColorJitter(brightness=0.25,
                                            contrast=0.10,
                                            saturation=0,
                                            hue=0,
                                            p=0.5)
        ]
        # input_transform += [extended_transforms.RandomGaussianBlur(p=0.5)]

        mean_std = (cfg.DATASET.MEAN, cfg.DATASET.STD)
        input_transform += [
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ]
        input_transform = standard_transforms.Compose(input_transform)

        # label transform
        label_transform = extended_transforms.MaskToTensor()

    else:
        joint_transform_list = None
        input_transform = None
        label_transform = None

    return joint_transform_list, input_transform, label_transform
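# A sketch of how the three return values are usually applied inside a
# dataset's __getitem__. It assumes joint transforms take and return an
# (image, mask) pair, as in the other examples here; `args` is the parsed config.
joint_list, input_tf, label_tf = get_transforms(args, mode="train")

def apply_transforms(img, mask):
    # Geometric ops run on image and mask together so they stay aligned.
    if joint_list is not None:
        for xform in joint_list:
            img, mask = xform(img, mask)
    # Photometric ops and normalization touch the image only.
    if input_tf is not None:
        img = input_tf(img)
    # The mask becomes a LongTensor of class IDs (MaskToTensor).
    if label_tf is not None:
        mask = label_tf(mask)
    return img, mask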
Example #7
    def __getitem__(self, index):
        if len(self.imgs_uniform[index]) == 2:
            img_path, mask_path = self.imgs_uniform[index]
            centroid = None
            class_id = None
        else:
            img_path, mask_path, centroid, class_id = self.imgs_uniform[index]
        img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        mask = np.array(mask)
        mask_copy = mask.copy()
        for k, v in id_to_ignore_or_group.items():
            mask_copy[mask == k] = v
        mask = Image.fromarray(mask_copy.astype(np.uint8))

        # Image Transformations
        if self.joint_transform_list is not None:
            for idx, xform in enumerate(self.joint_transform_list):
                if idx == 0 and centroid is not None:
                    # HACK! Assume the first transform accepts a centroid
                    img, mask = xform(img, mask, centroid)
                else:
                    img, mask = xform(img, mask)

        if self.dump_images:
            outdir = 'dump_imgs_{}'.format(self.mode)
            os.makedirs(outdir, exist_ok=True)
            if centroid is not None:
                dump_img_name = self.id2name[class_id] + '_' + img_name
            else:
                dump_img_name = img_name
            out_img_fn = os.path.join(outdir, dump_img_name + '.png')
            out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
            mask_img = colorize_mask(np.array(mask))
            img.save(out_img_fn)
            mask_img.save(out_msk_fn)

        if self.transform is not None:
            img = self.transform(img)

        rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        img_gt = transforms.Normalize(*rgb_mean_std)(img)
        if self.image_in:
            eps = 1e-5
            rgb_mean_std = ([
                torch.mean(img[0]),
                torch.mean(img[1]),
                torch.mean(img[2])
            ], [
                torch.std(img[0]) + eps,
                torch.std(img[1]) + eps,
                torch.std(img[2]) + eps
            ])
        img = transforms.Normalize(*rgb_mean_std)(img)

        if self.target_aux_transform is not None:
            mask_aux = self.target_aux_transform(mask)
        else:
            mask_aux = torch.tensor([0])
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        mask = extended_transforms.MaskToTensor()(mask)
        return img, mask, img_name, mask_aux
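# colorize_mask is not shown above; a common palette-based implementation
# (an assumption, matching how Cityscapes-style repos usually dump masks):
import numpy as np
from PIL import Image

# Flattened RGB palette [r0, g0, b0, r1, g1, b1, ...]; placeholder values here,
# one triple per class in the real code.
_palette = [128, 64, 128, 244, 35, 232, 70, 70, 70]

def colorize_mask_sketch(mask):
    # 'P' mode stores one byte per pixel plus a palette, ideal for label maps.
    out = Image.fromarray(mask.astype(np.uint8)).convert('P')
    out.putpalette(_palette)
    return out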
Example #8
def setup_loaders(args):
    '''
    input: argument passed by the user
    return: training data loader, validation data loader, train_set
    '''

    if args.dataset == 'cityscapes':
        args.dataset_cls = cityscapes
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 0

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric image transformations
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
 
    #if args.rotate:
    #    train_joint_transform_list += [joint_transforms.RandomRotate(args.rotate)]

    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

    # Image appearance transformations
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]
    else:
        pass

    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()
    
    target_train_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'cityscapes':
        city_mode = 'train' ## Can be trainval
        city_quality = 'fine'
        train_set = args.dataset_cls.CityScapes(
            city_quality, city_mode, 0, 
            joint_transform=train_joint_transform,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            cv_split=args.cv)
        val_set = args.dataset_cls.CityScapes('fine', 'val', 0, 
                                              transform=val_input_transform,
                                              target_transform=target_transform,
                                              cv_split=args.cv)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))
    
    train_sampler = None
    val_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers, shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2, shuffle=False,
                            drop_last=False, sampler=val_sampler)

    return train_loader, val_loader, train_set
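# The returned loaders then feed a standard training loop; a minimal sketch.
# net, optimizer, criterion and args.max_epoch are assumed to exist; criterion
# must accept the LongTensor masks produced by MaskToTensor as targets.
train_loader, val_loader, train_set = setup_loaders(args)
for epoch in range(args.max_epoch):
    net.train()
    for imgs, masks in train_loader:  # tuple layout depends on the dataset
        optimizer.zero_grad()
        loss = criterion(net(imgs.cuda()), masks.cuda())
        loss.backward()
        optimizer.step()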
Example #9
def setup_loaders(args):
    """
    Setup Data Loaders [currently supports Cityscapes, Mapillary and ADE20k]
    input: argument passed by the user
    return: training data loader, validation data loader, train_set
    """

    if args.dataset == 'cityscapes':
        args.dataset_cls = cityscapes
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'mapillary':
        args.dataset_cls = mapillary
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'ade20k':
        args.dataset_cls = ade20k
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'kitti':
        args.dataset_cls = kitti
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'camvid':
        args.dataset_cls = camvid
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'null_loader':
        args.dataset_cls = null_loader
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    # Readjust batch size to mini-batch size for apex
    if args.apex:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 1


    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric image transformations
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

    # Image appearance transformations
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]
    else:
        pass



    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()
    
    if args.jointwtborder:
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label, 
            args.dataset_cls.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'cityscapes':
        city_mode = 'train' ## Can be trainval
        city_quality = 'fine'
        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = args.dataset_cls.CityScapesUniform(
                city_quality, city_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes)
        else:
            train_set = args.dataset_cls.CityScapes(
                city_quality, city_mode, 0, 
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv)

        val_set = args.dataset_cls.CityScapes('fine', 'val', 0, 
                                              transform=val_input_transform,
                                              target_transform=target_transform,
                                              cv_split=args.cv)
    elif args.dataset == 'mapillary':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size, ignore_index=args.dataset_cls.ignore_label)]
        train_set = args.dataset_cls.Mapillary(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.Mapillary(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'ade20k':
        eval_size = 384
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
            
        train_set = args.dataset_cls.ade20k(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.ade20k(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'kitti':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #         joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
            
        train_set = args.dataset_cls.KITTI(
            'semantic', 'train', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        val_set = args.dataset_cls.KITTI(
            'semantic', 'trainval', 0, 
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)
    elif args.dataset == 'camvid':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #         joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
            
        train_set = args.dataset_cls.CAMVID(
            'semantic', 'trainval', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        val_set = args.dataset_cls.CAMVID(
            'semantic', 'test', 0, 
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)

    elif args.dataset == 'null_loader':
        train_set = args.dataset_cls.null_loader(args.crop_size)
        val_set = args.dataset_cls.null_loader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))
    
    if args.apex:
        from datasets.sampler import DistributedSampler
        train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False)
        val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)

    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers, shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2, shuffle=False,
                            drop_last=False, sampler=val_sampler)

    return train_loader, val_loader, train_set
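# When the apex/distributed path is taken, each epoch should shuffle
# differently. A hedged sketch, assuming the custom DistributedSampler mirrors
# torch.utils.data.distributed.DistributedSampler's set_epoch API (this repo's
# sampler may differ):
for epoch in range(args.max_epoch):  # args.max_epoch is assumed
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # reseed the per-epoch permutation
    for batch in train_loader:
        pass  # training step goes here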
Example #10
def setup_loaders(args):
    """
    Setup Data Loaders [currently supports Cityscapes, Mapillary and ADE20k]
    input: argument passed by the user
    return: training data loader, validation data loader, train_set
    """

    # TODO add error checking to make sure class exists
    logx.msg(f'dataset = {args.dataset}')

    mod = importlib.import_module('datasets.{}'.format(args.dataset))
    dataset_cls = getattr(mod, 'Loader')

    logx.msg(f'ignore_label = {dataset_cls.ignore_label}')

    update_dataset_cfg(num_classes=dataset_cls.num_classes,
                       ignore_label=dataset_cls.ignore_label)

    ######################################################################
    # Define transformations, augmentations
    ######################################################################

    # Joint transformations that must happen on both image and mask
    if ',' in args.crop_size:
        args.crop_size = [int(x) for x in args.crop_size.split(',')]
    else:
        args.crop_size = int(args.crop_size)
    train_joint_transform_list = [
        # TODO FIXME: move these hparams into cfg
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           full_size=args.full_crop_training,
                                           pre_size=args.pre_size)]
    train_joint_transform_list.append(
        joint_transforms.RandomHorizontallyFlip())

    if args.rand_augment is not None:
        N, M = [int(i) for i in args.rand_augment.split(',')]
        assert isinstance(N, int) and isinstance(M, int), \
            f'Either N {N} or M {M} not integer'
        train_joint_transform_list.append(RandAugment(N, M))

    ######################################################################
    # Image only augmentations
    ######################################################################
    train_input_transform = []

    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]
    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]

    mean_std = (cfg.DATASET.MEAN, cfg.DATASET.STD)
    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()

    if args.jointwtborder:
        target_train_transform = \
            extended_transforms.RelaxedBoundaryLossToTensor()
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.eval == 'folder':
        val_joint_transform_list = None
    elif 'mapillary' in args.dataset:
        if args.pre_size is None:
            eval_size = 2177
        else:
            eval_size = args.pre_size
        if cfg.DATASET.MAPILLARY_CROP_VAL:
            val_joint_transform_list = [
                joint_transforms.ResizeHeight(eval_size),
                joint_transforms.CenterCropPad(eval_size)]
        else:
            val_joint_transform_list = [
                joint_transforms.Scale(eval_size)]
    else:
        val_joint_transform_list = None

    if args.eval is None or args.eval == 'val':
        val_name = 'val'
    elif args.eval == 'trn':
        val_name = 'train'
    elif args.eval == 'folder':
        val_name = 'folder'
    else:
        raise ValueError('unknown eval mode {}'.format(args.eval))

    ######################################################################
    # Create loaders
    ######################################################################
    val_set = dataset_cls(
        mode=val_name,
        joint_transform_list=val_joint_transform_list,
        img_transform=val_input_transform,
        label_transform=target_transform,
        eval_folder=args.eval_folder)

    update_dataset_inst(dataset_inst=val_set)

    if args.apex:
        from datasets.sampler import DistributedSampler
        val_sampler = DistributedSampler(val_set, pad=False, permutation=False,
                                         consecutive_sample=False)
    else:
        val_sampler = None

    val_loader = DataLoader(val_set, batch_size=args.bs_val,
                            num_workers=args.num_workers // 2,
                            shuffle=False, drop_last=False,
                            sampler=val_sampler)

    if args.eval is not None:
        # Don't create train dataloader if eval
        train_set = None
        train_loader = None
    else:
        train_set = dataset_cls(
            mode='train',
            joint_transform_list=train_joint_transform_list,
            img_transform=train_input_transform,
            label_transform=target_train_transform)

        if args.apex:
            from datasets.sampler import DistributedSampler
            train_sampler = DistributedSampler(train_set, pad=True,
                                               permutation=True,
                                               consecutive_sample=False)
            train_batch_size = args.bs_trn
        else:
            train_sampler = None
            train_batch_size = args.bs_trn * args.ngpu

        train_loader = DataLoader(train_set, batch_size=train_batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=(train_sampler is None),
                                  drop_last=True, sampler=train_sampler)

    return train_loader, val_loader, train_set
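# Note how args.crop_size is parsed above: a bare integer gives a square crop,
# a comma pair gives [height, width]. A mirror of that parsing, for illustration:
def parse_crop_size(s):
    return [int(x) for x in s.split(',')] if ',' in s else int(s)

assert parse_crop_size('768') == 768
assert parse_crop_size('512,1024') == [512, 1024]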
Example #11
    def __getitem__(self, index):
        try:

            data = copy.deepcopy(self.data_list[index])
            # if self.eval_mode:
            #     print(data["img_mask"])
            # im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
            # if self.img_mode == 'RGB':
            #     im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            label = Image.open(data["img_mask"])
            label = np.array(label)

            label = self.id2trainId(label)
            mask = Image.fromarray(label.astype(np.uint8))
            im = Image.open(data['img_path']).convert('RGB')
            if self.eval_mode:
                return self._eval_get_item(im, label, self.eval_scales, self.eval_flip,data)


            # Geometric image transformations
            train_joint_transform_list = [
                joint_transforms.RandomSizeAndCrop(self.crop_size,
                                                   False,
                                                   pre_size=self.pre_size,
                                                   scale_min=self.scale_min,
                                                   scale_max=self.scale_max,
                                                   ignore_index=self.ignore_label),
                joint_transforms.Resize(self.crop_size),
                joint_transforms.RandomHorizontallyFlip()]

            if self.rotate:
                train_joint_transform_list += [joint_transforms.RandomRotate(self.rotate)]

            train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

            ## Image appearance transformations
            train_input_transform = []
            if self.color_aug:
                train_input_transform += [extended_transforms.ColorJitter(
                    brightness=self.color_aug,
                    contrast=self.color_aug,
                    saturation=self.color_aug,
                    hue=self.color_aug)]

            if self.bblur:
                train_input_transform += [extended_transforms.RandomBilateralBlur()]
            elif self.gblur:
                train_input_transform += [extended_transforms.RandomGaussianBlur()]
            else:
                pass
            train_input_transform = transforms.Compose(train_input_transform)
            target_transform = extended_transforms.MaskToTensor()

            target_train_transform = extended_transforms.MaskToTensor()

            # Image Transformations
            # if train_joint_transform is not None:  # train_joint_transform
            img, mask = train_joint_transform(im, mask)
            # if train_input_transform is not None:  # train_input_transform
            img = train_input_transform(img)
            # if target_train_transform is not None:
            mask = target_train_transform(mask)

            if self.dump_images:
                outdir = '/data/SSSSdump_imgs_/'
                os.makedirs(outdir, exist_ok=True)
                out_img_fn = os.path.join(outdir, '{}.png'.format(self.i))
                out_msk_fn = os.path.join(outdir, '{}s_mask.png'.format(self.i))
                print(out_img_fn)
                self.i += 1
                mask_img = colorize_mask(np.array(mask))
                img.save(out_img_fn)
                mask_img.save(out_msk_fn)
            dict2 = {'img': img, 'label': mask}
            data.update(dict2)
            if self.transform:
                data['img'] = self.transform(data['img'])
            # # print(image.shape, label.shape)
            # rdata['img'] = image
            # rdata['label'] = label
            # rdata['img_name'] = ['img_name']
            return data
        except Exception:
            return self.__getitem__(np.random.randint(self.__len__()))
def setup_loaders(args):
    """
    Setup Data Loaders [OmniAudio variants]
    input: argument passed by the user
    return: training data loader, validation data loader, train_set
    """
    if args.dataset == 'OmniAudio_noBG_Paralleltask':
        args.dataset_cls = OmniAudio_noBG_Paralleltask
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth':
        args.dataset_cls = OmniAudio_noBG_Paralleltask_depth
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth_noSeman':
        args.dataset_cls = OmniAudio_noBG_Paralleltask_depth_noSeman
        args.train_batch_size = 4  # args.bs_mult * args.ngpu
        args.val_batch_size = 4
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    # Readjust batch size to mini-batch size for apex
    if args.apex:
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val

    
    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 1


    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric image transformations
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)

    # Image appearance transformations
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]

    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]
    else:
        pass



    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = extended_transforms.MaskToTensor()
    
    if args.jointwtborder: 
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label, 
            args.dataset_cls.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()

    if args.dataset == 'OmniAudio_noBG_Paralleltask':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)

    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)


    elif args.dataset == 'OmniAudio_noBG_Paralleltask_depth_noSeman':
        eval_size = 960
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.OmniAudio(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform)
        val_set = args.dataset_cls.OmniAudio(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform)
        
    elif args.dataset == 'null_loader':
        train_set = args.dataset_cls.null_loader(args.crop_size)
        val_set = args.dataset_cls.null_loader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))
    
    if args.apex:
        from datasets.sampler import DistributedSampler
        train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False)
        val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)

    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers, shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2, shuffle=False,
                            drop_last=False, sampler=val_sampler)

    return train_loader, val_loader, train_set
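# The try/except in __getitem__ above silently resamples on any failure. A
# slightly safer variant of the same pattern (a sketch, not this repo's code)
# logs the error and bounds the retries; _load_item is a hypothetical helper
# wrapping the loading body shown in Example #11.
import logging
import numpy as np

def safe_getitem(self, index, _retries=10):
    try:
        return self._load_item(index)
    except Exception as exc:
        if _retries == 0:
            raise
        logging.warning('failed to load sample %d: %s', index, exc)
        return safe_getitem(self, np.random.randint(len(self)), _retries - 1)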