示例#1
0
    def __init__(
            self,
            root: str,
            split: str = "train",
            version: str = "2017",
            resize: Tuple[int] = (300, 300),
            augmentations: Callable[[Image, FloatTensor, LongTensor], Tuple[Image, FloatTensor, LongTensor]] = None,
            keep_crowd: bool = False
    ):
        """Set up a COCO detection dataset located under ``root``.

        Args:
            root: dataset root; ``~`` is expanded. Images are looked up in
                ``<root>/<split><version>`` and annotations in
                ``<root>/annotations/instances_<split><version>.json``.
            split: ``"train"`` or ``"val"``.
            version: suffix appended to the split name (not validated here).
            resize: forwarded unchanged to the parent constructor.
            augmentations: forwarded unchanged to the parent constructor.
            keep_crowd: stored on the instance as ``self.keep_crowd``.
        """
        super().__init__(resize, augmentations)

        self.root = os.path.expanduser(root)
        verify_str_arg(split, "split", ("train", "val"))
        self.keep_crowd = keep_crowd
        self.logger = log_utils.get_master_logger("CocoDetection")

        # Folder name and annotation file share the same "<split><version>" tag,
        # e.g. coco_root/train2017 and coco_root/annotations/instances_train2017.json
        subset_tag = f"{split}{version}"
        self.image_folder = os.path.join(self.root, subset_tag)
        annotation_fp = os.path.join(self.root, "annotations", f"instances_{subset_tag}.json")

        self.logger.info("Parsing COCO %s dataset...", split)
        self._init_dataset(annotation_fp)
        self.logger.info("Parsing COCO %s dataset done", split)
    def verify_mode_type(self, split, image_mode, image_type):
        """Validate ``split`` and ``image_type`` for the given image mode.

        ``image_mode`` is first normalised through ``self.verify_mode``. A short
        alias for ``image_type`` (e.g. ``"semantic"``) is then replaced by its
        file-suffix form (e.g. ``"_labelIds.png"``) before validation.

        Returns:
            The tuple ``(image_mode, image_type)`` after normalisation.
        """
        image_mode = self.verify_mode(image_mode)

        # Flat sequence of (suffix, alias) pairs shared by both annotation modes.
        annotation_types = ("_instanceIds.png", "instance", "_labelIds.png",
                            "semantic", "_color.png", "color", "_polygons.json",
                            "polygon")
        if image_mode == "gtFine":
            valid_splits = ("train", "test", "val")
            valid_types = annotation_types
        elif image_mode == "gtCoarse":
            valid_splits = ("train", "train_extra", "val")
            valid_types = annotation_types
        elif image_mode == "leftImg8bit":
            valid_splits = ("train", "train_extra", "test", "val")
            valid_types = ("_leftImg8bit.png", None)

        # Even positions hold the suffix, odd positions the alias that maps to it.
        for suffix, alias in zip(valid_types[0::2], valid_types[1::2]):
            if image_type == alias:
                image_type = suffix
                break

        split_msg = ("Unknown value '{}' for argument split if image_mode is '{}'. "
                     "Valid values are {{{}}}.").format(
                         split, image_mode, iterable_to_str(valid_splits))
        verify_str_arg(split, "split", valid_splits, split_msg)

        type_msg = (
            "Unknown value '{}' for argument image_type if image_mode is '{}'. "
            "Valid values are {{{}}}.").format(
                image_type, image_mode, iterable_to_str(valid_types))
        verify_str_arg(image_type, "image_type", valid_types, type_msg)
        return image_mode, image_type
示例#3
0
    def __init__(self, root, split='train', transform=None, target_transform=None, transforms=None):
        """Index the images and label files of one split of the dataset.

        Args:
            root: dataset root; must contain ``img``/``lbl`` folders (or
                ``img.zip``/``lbl.zip`` archives, which are extracted) and
                ``list_eval_partition.txt``.
            split: ``"train"``, ``"test"`` or ``"val"``.
            transform, target_transform, transforms: forwarded to the parent
                constructor.

        Raises:
            ValueError: if ``split`` is not one of the valid values
                (raised by ``verify_str_arg``).
            RuntimeError: if neither the folders nor the archives are present.
        """
        super(Deepfashion, self).__init__(root, transforms, transform, target_transform)
        self.images_dir = os.path.join(self.root, 'img')
        self.targets_dir = os.path.join(self.root, 'lbl')
        self.split = split
        self.images = []
        self.targets = []

        # This dataset has no "mode" argument; the previous message template
        # referenced an undefined `mode` name and raised NameError on every call.
        valid_splits = ("train", "test", "val")
        msg = "Unknown value '{}' for argument split. Valid values are {{{}}}."
        msg = msg.format(split, iterable_to_str(valid_splits))
        verify_str_arg(split, "split", valid_splits, msg)

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

            image_dir_zip = os.path.join(self.root, 'img.zip')
            target_dir_zip = os.path.join(self.root, 'lbl.zip')

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                                   ' specified "split" and "mode" are inside the "root" directory')

        # The partition file is tab separated; its first line is skipped.
        data_list = pd.read_csv(os.path.join(self.root, 'list_eval_partition.txt'), sep='\t', skiprows=1)
        data_list = data_list[data_list['evaluation_status'] == self.split]
        for image_path in data_list['image_name']:
            # Label path mirrors the image path with the first component replaced by "lbl".
            target_path = 'lbl/' + '/'.join(image_path.split('/')[1:])
            self.images.append(image_path)
            self.targets.append(target_path)
示例#4
0
    def __init__(
        self,
        root: str,
        split: str,
        image_set: str,
        view: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        """Validate the configuration and locate the dataset files.

        Args:
            root: parent directory; the dataset lives in ``root/self.base_folder``.
            split: ``"10fold"``, ``"train"`` or ``"test"`` (case-insensitive).
            image_set: a key of ``self.file_dict`` (case-insensitive).
            view: ``"people"`` or ``"pairs"`` (case-insensitive).
            transform: forwarded to the parent constructor.
            target_transform: forwarded to the parent constructor.
            download: when true, ``self.download()`` is called before the
                integrity check.

        Raises:
            RuntimeError: if ``self._check_integrity()`` fails.
        """
        dataset_root = os.path.join(root, self.base_folder)
        super().__init__(dataset_root,
                         transform=transform,
                         target_transform=target_transform)

        image_set_key = image_set.lower()
        self.image_set = verify_str_arg(image_set_key, "image_set",
                                        self.file_dict.keys())
        # Each file_dict entry is (image sub-directory, archive file name, md5).
        images_subdir, self.filename, self.md5 = self.file_dict[self.image_set]

        self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
        self.split = verify_str_arg(split.lower(), "split",
                                    ["10fold", "train", "test"])
        annot_suffix = self.annot_file[self.split]
        self.labels_file = f"{self.view}{annot_suffix}.txt"
        self.data: List[Any] = []

        if download:
            self.download()

        files_ok = self._check_integrity()
        if not files_ok:
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self.images_dir = os.path.join(self.root, images_subdir)
示例#5
0
    def __init__(self, root, split='train', mode='fine', target_type='instance',
                 transform=None, target_transform=None, transforms=None):
        """Validate arguments, unpack archives if needed and index all files.

        Args:
            root: dataset root directory.
            split: ``train``/``test``/``val`` for fine mode,
                ``train``/``train_extra``/``val`` for coarse mode.
            mode: ``"fine"`` or ``"coarse"``.
            target_type: one of ``instance``, ``semantic``, ``polygon``,
                ``color``, or a list of them.
            transform, target_transform, transforms: forwarded to the parent
                constructor.

        Raises:
            RuntimeError: if the image/target folders are missing and the
                corresponding zip archives are not both present.
        """
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.target_type = target_type
        self.split = split
        self.images = []
        self.targets = []

        verify_str_arg(mode, "mode", ("fine", "coarse"))
        valid_modes = (("train", "test", "val") if mode == "fine"
                       else ("train", "train_extra", "val"))
        split_msg = ("Unknown value '{}' for argument split if mode is '{}'. "
                     "Valid values are {{{}}}.").format(
                         split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, split_msg)

        # Normalise a single target type to a one-element list, then check each.
        if not isinstance(target_type, list):
            self.target_type = [target_type]
        for requested in self.target_type:
            verify_str_arg(requested, "target_type",
                           ("instance", "semantic", "polygon", "color"))

        if not (os.path.isdir(self.images_dir) and os.path.isdir(self.targets_dir)):
            # Folders are missing: fall back to extracting the archives.
            image_suffix = '_trainextra.zip' if split == 'train_extra' else '_trainvaltest.zip'
            image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format(image_suffix))

            if self.mode == 'gtFine':
                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip'))
            elif self.mode == 'gtCoarse':
                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip'))

            if not (os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip)):
                raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                                   ' specified "split" and "mode" are inside the "root" directory')
            extract_archive(from_path=image_dir_zip, to_path=self.root)
            extract_archive(from_path=target_dir_zip, to_path=self.root)

        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                # Every requested target type maps to the same path: the image
                # file name placed inside the city's target directory.
                per_type_paths = [os.path.join(target_dir, file_name)
                                  for _ in self.target_type]

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(per_type_paths)
示例#6
0
 def verify_mode(self, image_mode):
     """Map an image-mode alias to its canonical name and validate it.

     ``valid_modes`` is a flat sequence of (canonical, alias) pairs. The
     result is checked against the whole sequence, so canonical names and
     aliases are both accepted as input.
     """
     valid_modes = ("gt_color", "color",
                    "gt_labelIds", "semantic",
                    "gt_labelTrainIds", "semantic_train",
                    "RGB", "RGB",)
     # Even positions are canonical names, odd positions their aliases.
     for canonical, alias in zip(valid_modes[0::2], valid_modes[1::2]):
         if image_mode == alias:
             image_mode = canonical
             break
     verify_str_arg(image_mode, "image_mode", valid_modes)
     return image_mode
示例#7
0
    def _verify_classes(self, classes):
        """Validate ``classes`` and expand it to a list of LSUN class names.

        Args:
            classes: either one of ``'train'``, ``'val'``, ``'test'`` or an
                iterable of ``"<category>_<postfix>"`` strings.

        Returns:
            ``['test']`` for ``'test'``; every category suffixed with the split
            for ``'train'``/``'val'``; otherwise the given iterable as a list.

        Raises:
            ValueError: if ``classes`` is neither a valid string nor an
                iterable, or (via ``verify_str_arg``) if an element is not a
                string or names an unknown category/postfix.
        """
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
                      'conference_room', 'dining_room', 'kitchen',
                      'living_room', 'restaurant', 'tower']
        dset_opts = ['train', 'val', 'test']

        try:
            verify_str_arg(classes, "classes", dset_opts)
            if classes == 'test':
                classes = [classes]
            else:
                classes = [c + '_' + classes for c in categories]
        except ValueError:
            if not isinstance(classes, Iterable):
                msg = ("Expected type str or Iterable for argument classes, "
                       "but got type {}.")
                raise ValueError(msg.format(type(classes)))

            classes = list(classes)
            # The two templates need distinct names: reusing one name inside
            # the loop replaced the single-placeholder type message with the
            # three-placeholder one, so the second element raised IndexError.
            msg_fmtstr_type = ("Expected type str for elements in argument classes, "
                               "but got type {}.")
            msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
            for c in classes:
                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
                # Split on the last underscore: categories may contain underscores.
                c_short = c.split('_')
                category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]

                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
                verify_str_arg(category, valid_values=categories, custom_msg=msg)

                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)

        return classes
 def verify_mode(self, image_mode):
     """Map an image-mode alias to its canonical name and validate it.

     ``valid_modes`` is a flat sequence of (canonical, alias) pairs; ``None``
     is the alias of ``"leftImg8bit"``, so passing ``None`` selects it.
     """
     valid_modes = (
         "gtFine",
         "gtFine",
         "gtCoarse",
         "gtCoarse",
         "leftImg8bit",
         None,
     )
     # Even positions are canonical names, odd positions their aliases.
     for canonical, alias in zip(valid_modes[0::2], valid_modes[1::2]):
         if image_mode == alias:
             image_mode = canonical
             break
     verify_str_arg(image_mode, "image_mode", valid_modes)
     return image_mode
    def __init__(self,
                 root,
                 base_folder,
                 split='train',
                 transform=None,
                 target_transform=None,
                 download=False,
                 color_distortion=False,
                 col=False):
        """Create the dataset rooted at ``root/base_folder`` and index one split.

        Args:
            root: parent directory; created if it does not exist.
            base_folder: dataset folder name inside ``root``.
            split: ``"train"`` or ``"val"``.
            transform: forwarded to the parent constructor.
            target_transform: forwarded to the parent constructor.
            download: accepted but not read in this constructor.
            color_distortion: stored as ``self.color_distortion``.
            col: stored as ``self.col``.
        """
        super(TinyImageNet, self).__init__(root,
                                           transform=transform,
                                           target_transform=target_transform)

        os.makedirs(root, exist_ok=True)

        # Plain configuration attributes.
        self.root = root
        self.base_folder = base_folder
        self.color_distortion = color_distortion
        self.col = col
        self.loader = default_loader

        self.dataset_path = os.path.join(root, self.base_folder)
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Only the class -> index mapping from wnids.txt is needed here.
        wnids_file = os.path.join(self.dataset_path, 'wnids.txt')
        _unused_classes, class_to_idx = find_classes(wnids_file)

        self.data = make_dataset(self.root, self.base_folder, self.split,
                                 class_to_idx)
    def __init__(self, root, split='train', download=False, **kwargs):
        """Initialise the dataset; meant to be used like torchvision.datasets.ImageNet.

        Args:
            root: dataset root directory; ``~`` is expanded.
            split: ``"train"`` or ``"val"``.
            download: accepted but not read anywhere in this constructor.
            **kwargs: forwarded to ``torchvision.datasets.ImageFolder.__init__``.
        """
        root = self.root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Obtain the wnid -> class-names mapping, trying three sources in turn:
        # the meta file as currently configured, a meta file under ~/data/,
        # and finally a freshly parsed set of archives.
        try:
            wnid_to_classes = load_meta_file(self.root)[0]
        except RuntimeError:
            # NOTE: this rebinds a module-level attribute of torchvision, so the
            # new META_FILE value stays in effect for the rest of the process.
            torchvision.datasets.imagenet.META_FILE = os.path.join(
                os.path.expanduser('~/data/'), 'meta.bin')
            try:
                wnid_to_classes = load_meta_file(self.root)[0]
            except RuntimeError:
                self.parse_archives()
                wnid_to_classes = load_meta_file(self.root)[0]

        # The parent is initialised on the split folder rather than on root;
        # self.root is re-assigned afterwards (presumably because the parent
        # constructor overwrites it -- TODO confirm).
        torchvision.datasets.ImageFolder.__init__(self, self.split_folder,
                                                  **kwargs)
        self.root = root

        # Keep the folder-derived names as wnids, then replace classes with the
        # entries looked up in the meta mapping.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        # Each entry of self.classes is iterated, so every name in it maps to
        # that entry's index.
        self.class_to_idx = {
            cls: idx
            for idx, clss in enumerate(self.classes) for cls in clss
        }
        """Scrub class names to be a single string."""
        # For tuple entries only the first name is kept; class_to_idx above was
        # built before this step and still contains every name.
        scrubbed_names = []
        for name in self.classes:
            if isinstance(name, tuple):
                scrubbed_names.append(name[0])
            else:
                scrubbed_names.append(name)
        self.classes = scrubbed_names
示例#11
0
    def __init__(self, root, split='train', download=None, **kwargs):
        """Parse the archives under ``root`` and index one split.

        Args:
            root: dataset root directory; ``~`` is expanded.
            split: ``"train"`` or ``"val"``.
            download: deprecated. ``True`` raises, ``False`` only warns,
                ``None`` (the default) is silent.
            **kwargs: forwarded to the parent constructor.

        Raises:
            RuntimeError: if ``download`` is ``True``.
        """
        if download is True:
            raise RuntimeError("The dataset is no longer publicly accessible. You need to "
                               "download the archives externally and place them in the "
                               "root directory.")
        if download is False:
            warnings.warn("The use of the download flag is deprecated, since the "
                          "dataset is no longer publicly accessible.",
                          RuntimeWarning)

        root = os.path.expanduser(root)
        self.root = root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        meta = load_meta_file(self.root)
        wnid_to_classes = meta[0]

        # The parent is initialised on the split folder; self.root is
        # re-assigned to the dataset root afterwards.
        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        self.root = root

        # Folder-derived names become wnids; classes become the looked-up entries.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]

        # Every name inside an entry maps to that entry's index.
        name_to_idx = {}
        for idx, names in enumerate(self.classes):
            for name in names:
                name_to_idx[name] = idx
        self.class_to_idx = name_to_idx
示例#12
0
    def __init__(
        self,
        root,
        split="train",
        target_type="attr",
        transform=None,
        download=False,
    ):
        """Select the file names belonging to one split of the dataset.

        Args:
            root: dataset root, forwarded to the parent constructor.
            split: ``"train"``, ``"valid"``, ``"test"`` or ``"all"``
                (case-insensitive).
            target_type: a target type or a list of them; always stored as a list.
            transform: forwarded to the parent constructor.
            download: when true, ``self.download()`` is called first.

        Raises:
            RuntimeError: if ``target_type`` is empty while a target transform
                is set, or if ``self._check_integrity()`` fails.
        """
        # Imported lazily so pandas is only needed when the dataset is built.
        import pandas

        super(CelebAHQ, self).__init__(root,
                                       transform=transform,
                                       target_transform=None)

        # Stored exactly as passed, i.e. before lower-casing and validation.
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError(
                "target_transform is specified but target_type is empty")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted." +
                               " You can use download=True to download it")

        # Numeric codes compared against the partition file; None selects everything.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split = split_map[verify_str_arg(split.lower(), "split",
                                         ("train", "valid", "test", "all"))]

        fn = partial(os.path.join, self.root, self.base_folder)
        # Whitespace-separated, no header; the first column becomes the index.
        splits = pandas.read_csv(
            fn("list_eval_partition.txt"),
            delim_whitespace=True,
            header=None,
            index_col=0,
        )
        # Mapping table with a header row; only "idx" and "orig_idx" are read.
        index = pandas.read_csv(
            fn("CelebAMask-HQ", "CelebA-HQ-to-CelebA-mapping.txt"),
            delim_whitespace=True,
            header=0,
            usecols=["idx", "orig_idx"],
        )

        # For every mapping row, fetch the partition row at position orig_idx.
        splits = index["orig_idx"].apply(lambda i: splits.iloc[i])
        index = index["idx"]

        # Column 1 is the remaining partition column after index_col=0.
        mask = slice(None) if split is None else (splits[1] == split)

        self.filename = index[mask].apply(lambda s: str(s) + ".jpg").values
示例#13
0
    def __init__(self,
                 root,
                 base_folder,
                 split='train',
                 transform=None,
                 target_transform=None,
                 download=False):
        """Index one split of the dataset found under ``root``.

        Args:
            root: parent directory; created if it does not exist.
            base_folder: accepted but not read in this constructor.
            split: ``"train"`` or ``"val"``.
            transform: forwarded to the parent constructor.
            target_transform: forwarded to the parent constructor.
            download: accepted but not read in this constructor.
        """
        super(TinyImageNet, self).__init__(root,
                                           transform=transform,
                                           target_transform=target_transform)

        os.makedirs(root, exist_ok=True)
        # NOTE(review): the `base_folder` argument is never used; the path is
        # built from `self.base_folder`, which must already exist on the
        # class or instance -- confirm this is intended.
        self.dataset_path = os.path.join(root, self.base_folder)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # NOTE: no integrity check, download or archive extraction happens
        # here, so the `download` argument currently has no effect.

        wnids_file = os.path.join(self.dataset_path, 'wnids.txt')
        _unused_classes, class_to_idx = find_classes(wnids_file)

        self.data = make_dataset(self.root, self.base_folder, self.split,
                                 class_to_idx)
    def __init__(self, root='./data/ImageNet', split='val', **kwargs):
        """Build the sample list from the label, mapping and ground-truth files.

        Args:
            root: dataset root directory; ``~`` is expanded.
            split: ``"train"`` or ``"val"``.
            **kwargs: forwarded to the parent constructor.
        """
        root = self.root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # The parent constructor is later pointed at this folder, so create it
        # if missing (os.mkdir creates a single level only).
        if not os.path.exists(self.split_folder):
            os.mkdir(self.split_folder)

        wnid_to_classes = self.load_label_file()
        mapped_wnid_to_idx, mapped_idx_to_wnid = self.load_mapping()
        targets = self.load_ground_truth()
        # Translate each ground-truth index to its wnid via the loaded mapping.
        target_wnids = [mapped_idx_to_wnid[idx] for idx in targets]
        self.wnids = list(wnid_to_classes.keys())
        self.classes = list(wnid_to_classes.values())

        # Class indices follow the alphabetical order of the wnids.
        alphbetical_wnid_to_idx = {wnid: i for i, wnid in enumerate(sorted(self.wnids))}
        imgs = self.parse_image_tar(target_wnids, alphbetical_wnid_to_idx)

        super(ImageNetLoader, self).__init__(self.split_folder, **kwargs)

        # Everything below replaces attributes after the parent constructor ran.
        # classes is flattened: one entry per individual class name.
        self.classes = [cls for wnid, clss in wnid_to_classes.items() for cls in clss]
        self.wnid_to_idx = alphbetical_wnid_to_idx

        # Every class name of a wnid maps to that wnid's alphabetical index.
        self.class_to_idx = {cls: idx
                             for wnid, idx in alphbetical_wnid_to_idx.items() if wnid in wnid_to_classes
                             for cls in wnid_to_classes[wnid]}

        self.samples = imgs
        # NOTE(review): targets holds the raw ground-truth values, not the
        # alphabetical indices used to build imgs -- confirm this is intended.
        self.targets = targets
        self.imgs = imgs
示例#15
0
 def _verify_str_arg(
     value: str,
     arg: Optional[str] = None,
     valid_values: Optional[Collection[str]] = None,
     *,
     custom_msg: Optional[str] = None,
 ) -> str:
     """Forward all arguments to ``verify_str_arg`` and return its result.

     The only difference from calling ``verify_str_arg`` directly is that
     ``custom_msg`` is keyword-only here.
     """
     return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)
示例#16
0
文件: pix2pix.py 项目: mwufi/test
    def __init__(self,
                 root,
                 dataset_name='facades',
                 split='train',
                 download=True,
                 **kwargs):
        """Validate the configuration, optionally download, then build the dataset.

        Args:
            root: dataset root; ``~`` is expanded and the result is joined
                onto ``'.'``.
            dataset_name: a key of ``Pix2Pix_Datasets``.
            split: ``'train'`` or ``'val'``.
            download: when true, ``self.download()`` is called before the
                parent constructor.
            **kwargs: forwarded to the parent constructor.

        Raises:
            ValueError: if ``dataset_name`` or ``split`` is not a valid value
                (raised by ``verify_str_arg``).
        """
        root = os.path.expanduser(root)
        self.root = os.path.join('.', root)

        # verify_str_arg takes (value, arg_name, valid_values). The valid
        # values used to be passed in the arg_name slot, which left
        # valid_values unset and skipped the membership check entirely.
        self.dataset_name = verify_str_arg(dataset_name, "dataset_name",
                                           tuple(Pix2Pix_Datasets.keys()))
        self.split = verify_str_arg(split, "split", ('train', 'val'))

        if download:
            self.download()

        super(Pix2Pix, self).__init__(self.root, **kwargs)
示例#17
0
 def __init__(self,
              root,
              split="train",
              include_sar=False,
              transforms=None):
     """Pick the city list for ``split`` and collect the matching files.

     Args:
         root: dataset root; converted to ``pathlib.Path`` for the parent.
         split: must be contained in ``self.splits``.
         include_sar: when true, files are matched by ``*sar.tif`` and
             ``"sar"`` is added to ``self.datatypes``; otherwise files are
             matched by ``*rgb.jp2``.
         transforms: forwarded to the parent constructor.
     """
     super().__init__(pathlib.Path(root), transforms=transforms)
     verify_str_arg(split, "split", self.splits)

     # city_names is only assigned for these two splits; any other value
     # accepted by self.splits leaves it unset.
     if split == "train":
         self.city_names = self.train_list
     elif split == "test":
         self.city_names = self.test_list

     self.datatypes = ["rgb", "dem", "seg"]
     file_pattern = "*sar.tif" if include_sar else "*rgb.jp2"
     self.file_list = self._get_file_list(file_pattern)
     # "sar" is appended only after the file list has been built.
     if include_sar:
         self.datatypes.append("sar")
示例#18
0
文件: mnist.py 项目: wbw520/scouter
 def __init__(self, root, what=None, compat=True, train=True, **kwargs):
     """Select a subset and point both data-file attributes at it.

     Args:
         root: forwarded to the parent constructor.
         what: a key of ``self.subsets``; when ``None`` it is derived from
             ``train`` (``'train'`` or ``'test'``).
         compat: stored as ``self.compat``.
         train: forwarded to the parent constructor.
         **kwargs: forwarded to the parent constructor.
     """
     if what is None:
         if train:
             what = 'train'
         else:
             what = 'test'
     subset_names = tuple(self.subsets.keys())
     self.what = verify_str_arg(what, "what", subset_names)
     self.compat = compat
     # Training and test file names both refer to the selected subset's file.
     self.data_file = what + '.pt'
     self.training_file = self.test_file = self.data_file
     super(QMNIST, self).__init__(root, train, **kwargs)
示例#19
0
    def __init__(self,
                 root: str,
                 split: str = "trainval",
                 version: str = "2007",
                 resize: Optional[Tuple[int]] = (300, 300),
                 augmentations: Callable[[Image, Dict[str, Any]],
                                         Tuple[Image, Dict[str, Any]]] = None,
                 keep_difficult: bool = False,
                 make_partial: List[int] = None):
        """Read the split file of a VOC release and parse the dataset.

        Args:
            root: directory containing the ``VOC2007`` / ``VOC2012`` folders.
            split: one of ``train``, ``val``, ``test`` or ``trainval``. Note
                that VOC2012 ships without a ``test`` split.
            version: ``2007`` or ``2012``.
            resize: target size for all images; ``None`` disables resizing.
            augmentations: forwarded unchanged to the parent constructor.
            keep_difficult: stored as ``self.keep_difficult``.
            make_partial: keep only objects of the given classes (with respect
                to ``self.CLASSES``); ``None`` keeps all objects. Intended for
                multitask setups where tasks cover different class subsets.

        Raises:
            FileNotFoundError: if the split file does not exist.
        """
        super().__init__(resize, augmentations)
        self.keep_difficult = keep_difficult

        verify_str_arg(version, "version", ("2007", "2012"))
        verify_str_arg(split, "split", ("train", "val", "test", "trainval"))

        self.logger = log_utils.get_master_logger("VOCDetection")
        self.version = version
        self.split = split

        # Dataset folder of the requested release, e.g. <root>/VOC2007
        release_dir = os.path.join(root, f"VOC{version}")
        self.root = os.path.expanduser(release_dir)

        # One image identifier per line in ImageSets/Main/<split>.txt
        split_fp = os.path.join(self.root, "ImageSets", "Main", f"{split}.txt")
        if not os.path.isfile(split_fp):
            raise FileNotFoundError(
                f"`{split_fp}` is not found, note that there is no `test.txt` for VOC-2012"
            )
        with open(split_fp, "r") as split_file:
            self.file_names = [line.strip() for line in split_file]

        self.logger.info("Parsing VOC%s %s dataset...", version, split)
        self._init_dataset(make_partial)
        self.logger.info("Parsing VOC%s %s dataset done", version, split)
示例#20
0
    def __init__(self, 
                 root,
                 split="train",
                 target_type="attr",
                 transform=None,
                 target_transform=None,
                 download=False
                 ):
        """Load the CSV annotation tables and keep the rows of one split.

        Args:
            root: directory containing the ``list_*.csv`` files.
            split: ``"train"``, ``"valid"``, ``"test"`` or ``"all"``
                (case-insensitive).
            target_type: a target type or a list of them; always stored as a list.
            transform: stored as ``self.transform``.
            target_transform: stored as ``self.target_transform``.
            download: when true, ``self.download_from_kaggle()`` is called first.

        Raises:
            RuntimeError: if ``target_type`` is empty while ``target_transform``
                is given.
        """

        self.root = root
        # Stored exactly as passed, i.e. before lower-casing and validation.
        self.split = split
        self.target_type = target_type
        self.transform = transform
        self.target_transform = target_transform

        # Normalise target_type to a list (overrides the assignment above).
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError('target_transform is specified but target_type is empty')

        if download:
            self.download_from_kaggle()

        # Numeric codes compared against the 'partition' column; None selects everything.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        fn = partial(os.path.join, self.root)
        # All tables are read with a header row and the first column as index.
        splits = pd.read_csv(fn("list_eval_partition.csv"), delim_whitespace=False, header=0, index_col=0)
        # This file is not available in Kaggle
        # identity = pd.read_csv(fn("identity_CelebA.csv"), delim_whitespace=True, header=None, index_col=0)
        bbox = pd.read_csv(fn("list_bbox_celeba.csv"), delim_whitespace=False, header=0, index_col=0)
        landmarks_align = pd.read_csv(fn("list_landmarks_align_celeba.csv"), delim_whitespace=False, header=0, index_col=0)
        attr = pd.read_csv(fn("list_attr_celeba.csv"), delim_whitespace=False, header=0, index_col=0)

        mask = slice(None) if split_ is None else (splits['partition'] == split_)

        # The same mask is applied to every table
        # (assumes all tables share the partition table's index -- TODO confirm).
        self.filename = splits[mask].index.values
        # self.identity = torch.as_tensor(identity[mask].values)
        self.bbox = torch.as_tensor(bbox[mask].values)
        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)
示例#21
0
    def __init__(self, root, split='train', download=False, **kwargs):
        """Validate the split, optionally download, then index the split folder.

        Args:
            root: dataset root; ``~`` is expanded and stored as ``self.data_root``.
            split: must be contained in ``self.splits``.
            download: when true, ``self.download()`` is called first.
            **kwargs: forwarded to the parent constructor.

        Raises:
            RuntimeError: if ``self._check_exists()`` reports the data missing.
        """
        self.data_root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", self.splits)

        if download:
            self.download()

        # The existence check runs after the optional download.
        dataset_present = self._check_exists()
        if not dataset_present:
            raise RuntimeError('Dataset not found.'
                               ' You can use download=True to download it')

        # The parent is initialised on the split folder.
        super().__init__(self.split_folder, **kwargs)
示例#22
0
    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        target_type: Union[List[str], str] = "semantic",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        """Validate the arguments and collect image and label files.

        Args:
            root: dataset root directory.
            split: ``train``/``test``/``val`` for fine mode,
                ``train``/``train_extra``/``val`` for coarse mode.
            mode: ``"fine"`` or ``"coarse"``.
            target_type: one of ``instance``, ``semantic``, ``polygon``,
                ``color``, or a list of them.
            transform, target_transform, transforms: forwarded to the parent
                constructor.
        """
        super(Cityscapes, self).__init__(root, transforms, transform,
                                         target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.target_type = target_type
        self.split = split
        self.images = []
        self.targets = []

        verify_str_arg(mode, "mode", ("fine", "coarse"))
        valid_modes = (("train", "test", "val") if mode == "fine"
                       else ("train", "train_extra", "val"))
        split_msg = ("Unknown value '{}' for argument split if mode is '{}'. "
                     "Valid values are {{{}}}.").format(
                         split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, split_msg)

        # Normalise a single target type to a one-element list, then check each.
        if not isinstance(target_type, list):
            self.target_type = [target_type]
        for requested in self.target_type:
            verify_str_arg(requested, "target_type",
                           ("instance", "semantic", "polygon", "color"))

        # File discovery is delegated to _get_files for both images and labels.
        self.images = self._get_files('image', self.split)
        self.targets = self._get_files('label', self.split)
示例#23
0
    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        """Collect image and annotation paths from the per-class split files.

        Args:
            root: dataset root directory.
            year: key into ``DATASET_YEAR_DICT``.
            image_set: ``"train"``, ``"trainval"`` or ``"val"``.
            download: when true, the archive is downloaded and extracted first.
            transform, target_transform, transforms: forwarded to the parent
                constructor.

        Raises:
            RuntimeError: if the dataset folder does not exist.
        """
        super(VOCClassification, self).__init__(root, transforms, transform,
                                                target_transform)
        self.year = year
        year_info = DATASET_YEAR_DICT[year]
        self.url = year_info['url']
        self.filename = year_info['filename']
        self.md5 = year_info['md5']
        self.image_set = verify_str_arg(image_set, "image_set",
                                        ("train", "trainval", "val"))

        voc_root = os.path.join(self.root, year_info['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        # One "<class>_<image_set>.txt" file per class in VOC2012_CLASSES.
        splits_dir = os.path.join(voc_root, 'ImageSets/Main')
        set_name = image_set.rstrip('\n')
        split_fs = [os.path.join(splits_dir, c + '_' + set_name + '.txt')
                    for c in VOC2012_CLASSES]

        # Each line holds "<file name> <integer>"; both columns are kept, in
        # file order, across all class files.
        file_names = []
        self.object_nums = []
        for split_f in split_fs:
            with open(split_f, "r") as split_file:
                for line in split_file:
                    file_name, object_num = line.strip().split()
                    file_names.append(file_name)
                    self.object_nums.append(int(object_num))

        self.images = [os.path.join(image_dir, name + ".jpg") for name in file_names]
        self.annotations = [os.path.join(annotation_dir, name + ".xml")
                            for name in file_names]

        # For the val set only every sixth entry is kept
        # (object_nums is left unsliced).
        if image_set == 'val':
            self.images = self.images[::6]
            self.annotations = self.annotations[::6]
        assert (len(self.images) == len(self.annotations))
示例#24
0
    def __init__(self, root, split='train', folds=None, transform=None,
                 target_transform=None, download=False):
        """Load the arrays for the requested split into ``self.data``/``self.labels``.

        Args:
            root: dataset root, forwarded to the parent constructor.
            split: must be contained in ``self.splits``.
            folds: validated by ``self._verify_folds`` and passed to the fold
                loader for the splits that include training data.
            transform: forwarded to the parent constructor.
            target_transform: forwarded to the parent constructor.
            download: when true, ``self.download()`` is called first.

        Raises:
            RuntimeError: if ``self._check_integrity()`` fails.
        """
        super(STL10, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.split = verify_str_arg(split, "split", self.splits)
        self.folds = self._verify_folds(folds)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                'Dataset not found or corrupted. '
                'You can use download=True to download it')

        # now load the picked numpy arrays
        # train_list[0] / train_list[1] are passed as the data / label files,
        # train_list[2] as the unlabeled data file.
        if self.split == 'train':
            self.data, self.labels = self.__loadfile(
                self.train_list[0][0], self.train_list[1][0])
            self.__load_folds(folds)

        elif self.split == 'train+unlabeled':
            self.data, self.labels = self.__loadfile(
                self.train_list[0][0], self.train_list[1][0])
            self.__load_folds(folds)
            unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
            self.data = np.concatenate((self.data, unlabeled_data))
            # Unlabeled samples get the label -1.
            self.labels = np.concatenate(
                (self.labels, np.asarray([-1] * unlabeled_data.shape[0])))

        elif self.split == 'train+test+unlabeled':
            self.data, self.labels = self.__loadfile(
                self.train_list[0][0], self.train_list[1][0])
            self.__load_folds(folds)
            self.data_test, self.labels_test = self.__loadfile(
                self.test_list[0][0], self.test_list[1][0])
            unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
            # Concatenation order: train, test, unlabeled (labels follow the same order).
            self.data = np.concatenate((self.data, self.data_test, unlabeled_data))
            self.labels = np.concatenate(
                (self.labels, self.labels_test, np.asarray([-1] * unlabeled_data.shape[0])))

        elif self.split == 'unlabeled':
            self.data, _ = self.__loadfile(self.train_list[2][0])
            self.labels = np.asarray([-1] * self.data.shape[0])
        else:  # self.split == 'test':
            self.data, self.labels = self.__loadfile(
                self.test_list[0][0], self.test_list[1][0])

        # Class names are optional: self.classes is only set when the file exists.
        class_file = os.path.join(
            self.root, self.base_folder, self.class_names_file)
        if os.path.isfile(class_file):
            with open(class_file) as f:
                self.classes = f.read().splitlines()
示例#25
0
    def __init__(self,
                 root,
                 token='',
                 split='train',
                 download=False,
                 **kwargs):
        """Set up the ImageNet 2012 dataset located under ``root``.

        Args:
            root: Dataset root directory; ``~`` is expanded.
            token: Access token interpolated into the archive URLs. Must be
                non-empty when ``download`` is true.
            split: Either ``"train"`` or ``"val"``.
            download: If true, ``self.download()`` is called before the
                meta file is loaded.
            **kwargs: Forwarded unchanged to the parent class constructor.

        Raises:
            ValueError: If ``download`` is requested with an empty ``token``.
        """
        expanded_root = os.path.expanduser(root)
        self.root = expanded_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        train_url = (
            'http://www.image-net.org/challenges/LSVRC/2012/{}/ILSVRC2012_img_train.tar'
            .format(token))
        val_url = (
            'http://www.image-net.org/challenges/LSVRC/2012/{}/ILSVRC2012_img_val.tar'
            .format(token))
        devkit_url = (
            'http://www.image-net.org/challenges/LSVRC/2012/{}/ILSVRC2012_devkit_t12.tar.gz'
            .format(token))

        # One entry per downloadable archive: its URL and expected checksum.
        self.archive_dict = {
            'train': {
                'url': train_url,
                'md5': '1d675b47d978889d74fa0da5fadfb00e',
            },
            'val': {
                'url': val_url,
                'md5': '29b22e2961454d5413ddabcf34fc5622',
            },
            'devkit': {
                'url': devkit_url,
                'md5': 'fa75699e90414af021442c21a62c3abf',
            },
        }

        if download:
            if len(token) == 0:
                raise ValueError(
                    "ImageNet token is empty. Please obtain permission token from the official website."
                )

            self.download()

        meta = self._load_meta_file()
        wnid_to_classes = meta[0]

        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        # The parent constructor was handed the split folder, not the dataset
        # root, so point ``root`` back at the dataset root afterwards.
        self.root = expanded_root

        # Keep the WordNet-id based names the parent produced, then replace
        # ``classes`` / ``class_to_idx`` with the names from the meta file.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]

        name_to_idx = {}
        for idx, names in enumerate(self.classes):
            # Each entry of ``classes`` is itself an iterable of names; every
            # one of them maps to the same class index.
            for name in names:
                name_to_idx[name] = idx
        self.class_to_idx = name_to_idx
示例#26
0
 def verify_mode(self, image_mode):
     """Translate an image-mode alias to its canonical name and validate it.

     Args:
         image_mode: Either a canonical mode name or one of its aliases.

     Returns:
         The canonical mode name when ``image_mode`` equals an alias,
         otherwise ``image_mode`` unchanged (after validation).
     """
     # (canonical name, accepted alias) pairs, in lookup order.
     mode_aliases = (
         ("gtFine", "gtFine"),
         ("gtCoarse", "gtCoarse"),
         ("leftImg8bit", "clear"),
         ("leftImg8bit_foggy", "foggy"),
         ("leftImg8bit_foggyDBF", "foggyDBF"),
         ("leftImg8bit_transmittance", "transmittance"),
         ("leftImg8bit_transmittanceDBF", "transmittanceDBF"),
     )
     # Flat tuple of every accepted spelling, handed to the validator.
     valid_modes = tuple(name for pair in mode_aliases for name in pair)

     for canonical, alias in mode_aliases:
         if image_mode == alias:
             image_mode = canonical
             break

     verify_str_arg(image_mode, "image_mode", valid_modes)
     return image_mode
示例#27
0
 def verify_mode_split(self, split, image_mode):
     """Check that ``split`` is allowed for the given ``image_mode``.

     Args:
         split: Name of the requested dataset split.
         image_mode: Image mode; it is first normalised through
             ``self.verify_mode``.

     Returns:
         The image mode as returned by ``self.verify_mode``.

     Raises:
         ValueError: If the normalised mode is not one of the modes this
             method knows the valid splits for. ``verify_str_arg`` is
             expected to raise as well when ``split`` is not valid for the
             mode (assumption: confirm against its implementation).
     """
     image_mode = self.verify_mode(image_mode)

     if image_mode in ("gt_color", "gt_labelIds", "gt_labelTrainIds"):
         valid_splits = ("testv1", "testv2")
     elif image_mode == "RGB":
         valid_splits = ("light", "medium", "testv1", "testv2")
     else:
         # Without this branch ``valid_splits`` would be unbound below and
         # the caller would see an opaque UnboundLocalError.
         raise ValueError(
             "Unsupported image_mode '{}': no valid splits are defined "
             "for it.".format(image_mode))

     msg = ("Unknown value '{}' for argument split if image_mode is '{}'. "
            "Valid values are {{{}}}.")
     msg = msg.format(split, image_mode, iterable_to_str(valid_splits))
     verify_str_arg(split, "split", valid_splits, msg)

     return image_mode
    def __init__(self,
                 root,
                 target_type="category",
                 transform=None,
                 target_transform=None,
                 download=False):
        """Set up the Caltech101 dataset stored under ``<root>/caltech101``.

        Args:
            root: Parent directory; the dataset lives in its ``caltech101``
                sub-directory, which is created if missing.
            target_type: ``"category"``, ``"annotation"``, or a list of
                those values.
            transform: Passed to the parent constructor.
            target_transform: Passed to the parent constructor.
            download: If true, ``self.download()`` is called first.

        Raises:
            RuntimeError: If ``self._check_integrity()`` reports failure.
        """
        dataset_root = os.path.join(root, 'caltech101')
        super(Caltech101, self).__init__(dataset_root,
                                         transform=transform,
                                         target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        # Normalise a single target type to a one-element list.
        requested_types = target_type if isinstance(target_type, list) else [target_type]
        allowed_types = ("category", "annotation")
        self.target_type = [
            verify_str_arg(kind, "target_type", allowed_types)
            for kind in requested_types
        ]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        categories_dir = os.path.join(self.root, "101_ObjectCategories")
        self.categories = sorted(os.listdir(categories_dir))
        # "BACKGROUND_Google" is not a real class, so it is dropped.
        self.categories.remove("BACKGROUND_Google")

        # Category names in "101_ObjectCategories" and "Annotations" do not
        # always match; this maps the exceptions. Any name not listed here is
        # used as-is.
        annotation_names = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2"
        }
        self.annotation_categories = [
            annotation_names.get(name, name) for name in self.categories
        ]

        # ``index`` holds the 1-based position of each item within its
        # category folder; ``y`` holds the category index of each item.
        self.index = []
        self.y = []
        for label, category in enumerate(self.categories):
            category_dir = os.path.join(self.root, "101_ObjectCategories",
                                        category)
            count = len(os.listdir(category_dir))
            self.index.extend(range(1, count + 1))
            self.y.extend([label] * count)
示例#29
0
    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        """Set up a Pascal VOC detection dataset.

        Args:
            root: Dataset root directory, passed to the parent constructor.
            year: Key into ``DATASET_YEAR_DICT``. ``"2007"`` combined with
                ``image_set="test"`` selects the ``"2007-test"`` entry.
            image_set: ``"train"``, ``"trainval"`` or ``"val"``; ``"test"``
                is additionally accepted for the ``"2007-test"`` entry.
            download: If true, the archive is downloaded and extracted via
                ``download_extract``.
            transform: Passed to the parent constructor.
            target_transform: Passed to the parent constructor.
            transforms: Passed to the parent constructor.

        Raises:
            RuntimeError: If the dataset directory does not exist.
        """
        super(VOCDetection, self).__init__(root, transforms, transform,
                                           target_transform)
        # ``self.year`` keeps the year exactly as the caller passed it; the
        # local ``year`` may be remapped below.
        self.year = year
        if year == "2007" and image_set == "test":
            year = "2007-test"

        release = DATASET_YEAR_DICT[year]
        self.url = release['url']
        self.filename = release['filename']
        self.md5 = release['md5']

        allowed_sets = ["train", "trainval", "val"]
        if year == "2007-test":
            allowed_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", allowed_sets)

        voc_root = os.path.join(self.root, DATASET_YEAR_DICT[year]['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        # The split file lists one sample name per line.
        split_file = os.path.join(voc_root, 'ImageSets/Main',
                                  image_set.rstrip('\n') + '.txt')
        with open(split_file, "r") as handle:
            sample_names = [line.strip() for line in handle]

        image_paths = []
        annotation_paths = []
        for sample in sample_names:
            image_paths.append(os.path.join(image_dir, sample + ".jpg"))
            annotation_paths.append(
                os.path.join(annotation_dir, sample + ".xml"))
        self.images = image_paths
        self.annotations = annotation_paths

        self.images, self.annotations = self.filter_img(
            self.images, self.annotations)
        assert len(self.images) == len(self.annotations)
示例#30
0
    def __init__(
        self, root, split="train", transform=None, target_transform=None, download=False, in_memory=False, **kwargs
    ):
        """Set up the Tiny ImageNet dataset.

        Args:
            root: Dataset root directory; ``~`` is expanded.
            split: ``"train"``, ``"val"`` or ``"test"``.
            transform: Stored on the instance and passed to the parent.
            target_transform: Stored on the instance and passed to the parent.
            download: If true, ``self._download()`` is called first.
            in_memory: If true, every image is read up front via
                ``self.read_image``.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            RuntimeError: If ``self._check_integrity()`` reports failure.
        """
        super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)
        self.root = os.path.expanduser(root)
        self.root_final = os.path.join(self.root, self.base_folder, "tiny-imagenet-200")

        # Training set, validation set or test set.
        allowed_splits = ("train", "val", "test")
        self.split = verify_str_arg(split, "split", allowed_splits)
        self.split_dir = os.path.join(self.root_final, self.split)

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted." " You can use download=True to download it")

        image_pattern = os.path.join(self.split_dir, "**", "*.%s" % EXTENSION)
        self.image_paths = sorted(glob.iglob(image_pattern, recursive=True))

        self.transform = transform
        self.target_transform = target_transform
        self.in_memory = in_memory

        self.classes = {}  # file name -> label number
        self.images = []  # filled only for in-memory processing

        # Build the class label -> number mapping from the class list file.
        class_list_path = os.path.join(self.root_final, CLASS_LIST_FILE)
        with open(class_list_path, "r") as fp:
            self.classes_names = sorted(line.strip() for line in fp)
        self.class_to_idx = {}
        for number, label in enumerate(self.classes_names):
            self.class_to_idx[label] = number

        if self.split == "train":
            # Training file names are derived from the class label and a
            # running counter.
            for class_text, number in self.class_to_idx.items():
                self.classes.update(
                    ("%s_%d.%s" % (class_text, cnt, EXTENSION), number)
                    for cnt in range(NUM_IMAGES_PER_CLASS)
                )
        elif self.split == "val":
            # Validation labels come from a tab-separated annotation file:
            # first column is the file name, second the class label.
            annotation_path = os.path.join(self.split_dir, VAL_ANNOTATION_FILE)
            with open(annotation_path, "r") as fp:
                for line in fp:
                    fields = line.split("\t")
                    self.classes[fields[0]] = self.class_to_idx[fields[1]]

        # Read all images into memory up front to minimize disk IO overhead.
        if self.in_memory:
            self.images = [self.read_image(path) for path in self.image_paths]