from typing import List, Tuple

import torch

# Imports assumed from the mseg_semantic codebase (module paths match the
# inline imports in Example #4 below; the dataset module is an assumption):
from mseg_semantic.utils import dataset, transform
from mseg_semantic.utils.normalization_utils import get_imagenet_mean_std


def create_test_loader(
    args,
    use_batched_inference: bool = False
) -> Tuple[torch.utils.data.dataloader.DataLoader, List[Tuple[str, str]]]:
    """Create a Pytorch dataloader from a dataroot and list of relative paths.
	
	Args:
	    args: CfgNode object
	    use_batched_inference: whether to process images in batch mode
	
	Returns:
	    test_loader
	    data_list: list of 2-tuples (relative rgb path, relative label path)
	"""
    preprocess_imgs_in_loader = use_batched_inference

    if preprocess_imgs_in_loader:
        # resize and normalize images in advance
        mean, std = get_imagenet_mean_std()
        test_transform = transform.Compose([
            transform.ResizeShort(args.base_size),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
    else:
        # Skip resizing/normalization in the loader; they are done per image on
        # the fly (e.g. with OpenCV) during inference instead.
        test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)

    # Select a contiguous shard [index_start, index_end) of the file list so
    # that evaluation can be split across multiple processes.
    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    data_list = test_data.data_list

    # limit batch size to 1 if not performing batched inference
    batch_size = args.batch_size_val if use_batched_inference else 1

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    return test_loader, data_list
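A minimal usage sketch (hypothetical; not from the mseg_semantic repo). The args
object only needs the attributes read above, so a SimpleNamespace stands in for
the usual YACS CfgNode, and all paths below are placeholders:

from types import SimpleNamespace

args = SimpleNamespace(
    split='val',
    data_root='/path/to/dataset',               # placeholder path
    test_list='/path/to/dataset/val_list.txt',  # placeholder path
    index_start=0,
    index_step=0,        # 0 evaluates the entire file list
    batch_size_val=8,    # used only when use_batched_inference=True
    workers=4,
    base_size=1080,      # short-side resize target for batched preprocessing
)

test_loader, data_list = create_test_loader(args, use_batched_inference=True)
for batch in test_loader:
    ...  # run inference on each preprocessed batch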
Example #2
    def create_test_loader(self):
        """
			Create a Pytorch dataloader from a dataroot and list of 
			relative paths.
		"""
        test_transform = transform.Compose([transform.ToTensor()])
        test_data = dataset.SemData(split=self.args.split,
                                    data_root=self.args.data_root,
                                    data_list=self.args.test_list,
                                    transform=test_transform)

        index_start = self.args.index_start
        if self.args.index_step == 0:
            index_end = len(test_data.data_list)
        else:
            index_end = min(index_start + self.args.index_step,
                            len(test_data.data_list))
        test_data.data_list = test_data.data_list[index_start:index_end]
        self.data_list = test_data.data_list
        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=1,
            shuffle=False,
            num_workers=self.args.workers,
            pin_memory=True)
        return test_loader
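A minimal sketch (hypothetical) of the enclosing class this method assumes; the
actual evaluator class is not shown in this listing:

class Evaluator:  # hypothetical name, standing in for the real enclosing class
    def __init__(self, args):
        self.args = args        # CfgNode with split, data_root, test_list, etc.
        self.data_list = None   # populated as a side effect of create_test_loader()

    # create_test_loader() defined exactly as above ...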
Example #3
def create_test_loader(
    args
) -> Tuple[torch.utils.data.dataloader.DataLoader, List[Tuple[str, str]]]:
    """
		Create a Pytorch dataloader from a dataroot and list of 
		relative paths.

		Args:

		Returns:
		-	test_loader
		-	data_list: list of 2-tuples (relative rgb path, relative label path)
	"""
    test_transform = transform.Compose([transform.ToTensor()])
    test_data = dataset.SemData(split=args.split,
                                data_root=args.data_root,
                                data_list=args.test_list,
                                transform=test_transform)

    index_start = args.index_start
    if args.index_step == 0:
        index_end = len(test_data.data_list)
    else:
        index_end = min(index_start + args.index_step,
                        len(test_data.data_list))
    test_data.data_list = test_data.data_list[index_start:index_end]
    data_list = test_data.data_list
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    return test_loader, data_list
Example #4
def get_train_transform_list(args, split: str):
    """ Return the input data transform for training (w/ data augmentations)
        Args:
        -   args:
        -   split

        Return:
        -   List of transforms
    """
    from mseg_semantic.utils.normalization_utils import get_imagenet_mean_std
    from mseg_semantic.utils import transform

    mean, std = get_imagenet_mean_std()
    if split == 'train':
        transform_list = [
            transform.ResizeShort(args.short_size),
            transform.RandScale([args.scale_min, args.scale_max]),
            transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
            transform.RandomGaussianBlur(),
            transform.RandomHorizontalFlip(),
            transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ]
    elif split == 'val':
        transform_list = [
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ]
    else:
        raise ValueError(f'Unknown split: {split}')

    # ToFlatLabel (assumed to be imported from elsewhere in the mseg_semantic
    # codebase) remaps each dataset's labels into the flat universal taxonomy;
    # args.tc is presumably the taxonomy converter.
    if len(args.dataset) > 1 and args.universal:
        transform_list += [ToFlatLabel(args.tc, args.dataset_name)]
    elif args.universal:
        transform_list += [ToFlatLabel(args.tc, args.dataset[0])]

    return transform.Compose(transform_list)
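A minimal usage sketch (hypothetical field values; not from the repo). It
assumes the composed transform operates jointly on an image and its label map,
as in semseg-style transform modules, so geometric augmentations stay aligned
between the two:

import numpy as np
from types import SimpleNamespace

# Hypothetical hyperparameters mirroring the fields read above.
args = SimpleNamespace(
    short_size=1080, scale_min=0.5, scale_max=2.0,
    rotate_min=-10, rotate_max=10,
    train_h=713, train_w=713, ignore_label=255,
    dataset=['ade20k-150'],   # placeholder dataset name
    universal=False,          # skips ToFlatLabel, so args.tc is not needed
)

train_transform = get_train_transform_list(args, split='train')

# Dummy inputs standing in for a cv2-loaded image and its label map.
image = np.random.rand(512, 683, 3).astype(np.float32) * 255
label = np.random.randint(0, 150, (512, 683)).astype(np.uint8)

# Each transform is assumed to be applied to the (image, label) pair, with
# ToTensor/Normalize producing network-ready tensors at the end.
image_t, label_t = train_transform(image, label)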