def test_align_with_features(self):
     task = ImageClassification(image_column="input_image",
                                label_column="input_label")
     self.assertEqual(task.label_schema["labels"], ClassLabel)
     task = task.align_with_features(
         Features({"input_label": ClassLabel(names=self.labels)}))
     self.assertEqual(task.label_schema["labels"],
                      ClassLabel(names=self.labels))
Example #2
 def test_column_mapping(self):
     task = ImageClassification(image_column="input_image",
                                label_column="input_label")
     self.assertDictEqual({
         "input_image": "image",
         "input_label": "labels"
     }, task.column_mapping)
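For orientation, here is a minimal standalone sketch of what the two tests above exercise; the column names and the two-entry label list are illustrative assumptions, not values from the test suite:

from datasets import ClassLabel, Features
from datasets.tasks import ImageClassification

labels = ["neg", "pos"]  # hypothetical label set

task = ImageClassification(image_column="photo", label_column="class")
# column_mapping renames the dataset's columns to the template's canonical names.
assert task.column_mapping == {"photo": "image", "class": "labels"}

# align_with_features copies the concrete ClassLabel (with its names) into the label schema.
task = task.align_with_features(Features({"class": ClassLabel(names=labels)}))
assert task.label_schema["labels"] == ClassLabel(names=labels)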
Example #3
 def _info(self):
     if self.config.name == "full_numbers":
          features = datasets.Features({
              "image": datasets.Image(),
              "digits": datasets.Sequence({
                  "bbox": datasets.Sequence(datasets.Value("int32"), length=4),
                  "label": datasets.ClassLabel(num_classes=10),
              }),
          })
     else:
          features = datasets.Features({
              "image": datasets.Image(),
              "label": datasets.ClassLabel(num_classes=10),
          })
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
         features=features,
         supervised_keys=None,
         homepage=_HOMEPAGE,
         license=_LICENSE,
         citation=_CITATION,
         task_templates=[
             ImageClassification(image_column="image", label_column="label")
         ] if self.config.name == "cropped_digits" else None,
     )
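Once a script registers the template this way, a caller can ask the loaded dataset to conform to it. A hedged usage sketch: the dataset name "svhn" is an assumption about where this _info comes from, and Dataset.prepare_for_task is only available in datasets versions that still ship task templates.

from datasets import load_dataset

ds = load_dataset("svhn", "cropped_digits", split="train")
# Renames and casts columns to the canonical ("image", "labels") schema
# declared by the registered ImageClassification template.
ds = ds.prepare_for_task("image-classification")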
Example #4
 def test_value_error_unique_labels(self):
     with self.assertRaises(ValueError):
         # Add duplicate labels
         labels = self.labels + self.labels[:1]
         task = ImageClassification(image_file_path_column="file_paths",
                                    label_column="input_label",
                                    labels=labels)
         self.assertEqual("image-classification", task.task)
Example #5
 def test_from_dict(self):
     input_schema = Features({"image": Image()})
     label_schema = Features({"labels": ClassLabel})
     template_dict = {
         "image_column": "input_image",
         "label_column": "input_label",
     }
     task = ImageClassification.from_dict(template_dict)
     self.assertEqual("image-classification", task.task)
     self.assertEqual(input_schema, task.input_schema)
     self.assertEqual(label_schema, task.label_schema)
Example #6
 def test_from_dict(self):
     input_schema = Features({"image_file_path": Value("string")})
     label_schema = Features(
         {"labels": ClassLabel(names=tuple(self.labels))})
     template_dict = {
         "image_file_path_column": "input_image_file_path",
         "label_column": "input_label",
         "labels": self.labels,
     }
     task = ImageClassification.from_dict(template_dict)
     self.assertEqual("image-classification", task.task)
     self.assertEqual(input_schema, task.input_schema)
     self.assertEqual(label_schema, task.label_schema)
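The two tests above target different library versions (the accepted keys differ), but the from_dict round trip is the same idea. A small sketch assuming the newer image_column/label_column form and made-up column names:

from datasets.tasks import ImageClassification

task = ImageClassification.from_dict({
    "image_column": "photo",
    "label_column": "class",
})
assert task.task == "image-classification"
assert task.column_mapping == {"photo": "image", "class": "labels"}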
Example #7
    def _info(self):

        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(self._get_feature_types()),
            supervised_keys=("image", "label"),
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
            task_templates=[
                ImageClassification(image_column="image", label_column="label")
            ],
        )
Example #8
 def _info(self):
     return ds.DatasetInfo(
         description="",
         citation="",
         homepage="",
         license="",
         features=ds.Features({
             "img": ds.Image(),
             "label": ds.features.ClassLabel(names=_NAMES),
         }),
         supervised_keys=("img", "label"),
         task_templates=ImageClassification(image_column="img",
                                            label_column="label"),
     )
Example #9
 def _info(self):
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
          features=datasets.Features({
              "img": datasets.Image(),
              "label": datasets.features.ClassLabel(names=_NAMES),
          }),
         supervised_keys=("img", "label"),
         homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
         citation=_CITATION,
         task_templates=ImageClassification(image_column="img",
                                            label_column="label"),
     )
Example #10
 def _info(self):
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
         features=datasets.Features(
             {
                 "image_file_path": datasets.Value("string"),
                 "image": datasets.Image(),
                 "labels": datasets.features.ClassLabel(names=_NAMES),
             }
         ),
         supervised_keys=("image", "labels"),
         homepage=_HOMEPAGE,
         citation=_CITATION,
         task_templates=[ImageClassification(image_column="image", label_column="labels")],
     )
Example #11
 def _info(self):
     assert len(IMAGENET2012_CLASSES) == 1000
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
          features=datasets.Features({
              "image": datasets.Image(),
              "label": datasets.ClassLabel(names=list(IMAGENET2012_CLASSES.values())),
          }),
         homepage=_HOMEPAGE,
         citation=_CITATION,
         task_templates=[
             ImageClassification(image_column="image", label_column="label")
         ],
     )
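The ClassLabel feature built here is what maps string labels to integer ids when examples are generated. A tiny illustration with made-up class names:

from datasets import ClassLabel

label_feature = ClassLabel(names=["tench", "goldfish", "great white shark"])
assert label_feature.num_classes == 3
assert label_feature.str2int("goldfish") == 1
assert label_feature.int2str(2) == "great white shark"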
Example #12
 def _info(self):
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
         features=datasets.Features(
             {
                 "img": datasets.Image(),
                 "fine_label": datasets.features.ClassLabel(names=_FINE_LABEL_NAMES),
                 "coarse_label": datasets.features.ClassLabel(names=_COARSE_LABEL_NAMES),
             }
         ),
         supervised_keys=None,  # Probably needs to be fixed.
         homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
         citation=_CITATION,
         task_templates=[
             ImageClassification(image_column="img", label_column="fine_label", labels=_FINE_LABEL_NAMES)
         ],
     )
Example #13
 def _info(self):
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
          features=datasets.Features({
              "image": datasets.Image(),
              "label": datasets.ClassLabel(names=_NAMES),
          }),
         supervised_keys=("image", "label"),
         homepage=_HOMEPAGE,
         citation=_CITATION,
         license=_LICENSE,
         task_templates=[
             ImageClassification(image_column="image", label_column="label")
         ],
     )
Example #14
 def _info(self):
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
          features=datasets.Features({
              "image": datasets.Image(),
              "label": datasets.features.ClassLabel(
                  names=["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]),
          }),
         supervised_keys=("image", "label"),
         homepage="http://yann.lecun.com/exdb/mnist/",
         citation=_CITATION,
         task_templates=[
             ImageClassification(
                 image_column="image",
                 label_column="label",
             )
         ],
     )
Example #15
 def _info(self):
     if self.config.name == "raw":
          features = datasets.Features({
              "key_id": datasets.Value("string"),
              "word": datasets.ClassLabel(names=_NAMES),
              "recognized": datasets.Value("bool"),
              "timestamp": datasets.Value("timestamp[us, tz=UTC]"),
              "countrycode": datasets.Value("string"),
              "drawing": datasets.Sequence({
                  "x": datasets.Sequence(datasets.Value("float32")),
                  "y": datasets.Sequence(datasets.Value("float32")),
                  "t": datasets.Sequence(datasets.Value("int32")),
              }),
          })
      elif self.config.name == "preprocessed_simplified_drawings":
          features = datasets.Features({
              "key_id": datasets.Value("string"),
              "word": datasets.ClassLabel(names=_NAMES),
              "recognized": datasets.Value("bool"),
              "timestamp": datasets.Value("timestamp[us, tz=UTC]"),
              "countrycode": datasets.Value("string"),
              "drawing": datasets.Sequence({
                  "x": datasets.Sequence(datasets.Value("uint8")),
                  "y": datasets.Sequence(datasets.Value("uint8")),
              }),
          })
      elif self.config.name == "preprocessed_bitmaps":
          features = datasets.Features({
              "image": datasets.Image(),
              "label": datasets.ClassLabel(names=_NAMES),
          })
      else:  # sketch_rnn, sketch_rnn_full
          features = datasets.Features({
              "word": datasets.ClassLabel(names=_NAMES),
              "drawing": datasets.Array2D(shape=(None, 3), dtype="int16"),
          })
     return datasets.DatasetInfo(
         description=_DESCRIPTION,
         features=features,
         homepage=_HOMEPAGE,
         license=_LICENSE,
         citation=_CITATION,
         task_templates=[
             ImageClassification(image_column="image", label_column="label")
         ] if self.config.name == "preprocessed_bitmaps" else None,
     )
Example #16
    def _split_generators(self, dl_manager):
        if not self.config.data_files:
            raise ValueError(
                f"At least one data file must be specified, but got data_files={self.config.data_files}"
            )

        capture_labels = not self.config.drop_labels and self.config.features is None
        if capture_labels:
            labels = set()

            def capture_labels_for_split(files_or_archives,
                                         downloaded_files_or_dirs):
                if len(downloaded_files_or_dirs) == 0:
                    return
                # The files are separated from the archives at this point, so check the first sample
                # to see if it's a file or a directory and iterate accordingly
                if os.path.isfile(downloaded_files_or_dirs[0]):
                    files, downloaded_files = files_or_archives, downloaded_files_or_dirs
                    for file, downloaded_file in zip(files, downloaded_files):
                        file, downloaded_file = str(file), str(downloaded_file)
                        _, file_ext = os.path.splitext(file)
                        if file_ext.lower() in self.IMAGE_EXTENSIONS:
                            labels.add(os.path.basename(os.path.dirname(file)))
                else:
                    archives, downloaded_dirs = files_or_archives, downloaded_files_or_dirs
                    for archive, downloaded_dir in zip(archives,
                                                       downloaded_dirs):
                        archive, downloaded_dir = str(archive), str(
                            downloaded_dir)
                        for downloaded_dir_file in dl_manager.iter_files(
                                downloaded_dir):
                            _, downloaded_dir_file_ext = os.path.splitext(
                                downloaded_dir_file)
                            if downloaded_dir_file_ext.lower() in self.IMAGE_EXTENSIONS:
                                labels.add(
                                    os.path.basename(
                                        os.path.dirname(downloaded_dir_file)))

            logger.info("Inferring labels from data files...")

        data_files = self.config.data_files
        splits = []
        if isinstance(data_files, (str, list, tuple)):
            files = data_files
            if isinstance(files, str):
                files = [files]
            files, archives = self._split_files_and_archives(files)
            downloaded_files = dl_manager.download(files)
            downloaded_dirs = dl_manager.download_and_extract(archives)
            if capture_labels:
                capture_labels_for_split(files, downloaded_files)
                capture_labels_for_split(archives, downloaded_dirs)
            splits.append(
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={
                        "files": [(file, downloaded_file)
                                  for file, downloaded_file in zip(
                                      files, downloaded_files)] +
                        [(None, dl_manager.iter_files(downloaded_dir))
                         for downloaded_dir in downloaded_dirs]
                    },
                ))
        else:
            for split_name, files in data_files.items():
                if isinstance(files, str):
                    files = [files]
                files, archives = self._split_files_and_archives(files)
                downloaded_files = dl_manager.download(files)
                downloaded_dirs = dl_manager.download_and_extract(archives)
                if capture_labels:
                    capture_labels_for_split(files, downloaded_files)
                    capture_labels_for_split(archives, downloaded_dirs)
                splits.append(
                    datasets.SplitGenerator(
                        name=split_name,
                        gen_kwargs={
                            "files": [(file, downloaded_file)
                                      for file, downloaded_file in zip(
                                          files, downloaded_files)] +
                            [(None, dl_manager.iter_files(downloaded_dir))
                             for downloaded_dir in downloaded_dirs]
                        },
                    ))

        # Normally we would do this in _info, but we need to know the labels before building the features
        if capture_labels:
            if not self.config.drop_labels:
                self.info.features = datasets.Features({
                    "image": datasets.Image(),
                    "label": datasets.ClassLabel(names=sorted(labels)),
                })
                task_template = ImageClassification(image_column="image",
                                                    label_column="label")
                task_template = task_template.align_with_features(
                    self.info.features)
                self.info.task_templates = [task_template]
            else:
                self.info.features = datasets.Features(
                    {"image": datasets.Image()})

        return splits
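The closing block is the interesting part: labels are inferred from the parent directory names of the image files, sorted, and then baked into both the features and the task template. A short sketch of the resulting info, assuming a layout like train/cat/001.jpg and train/dog/002.jpg:

import datasets
from datasets.tasks import ImageClassification

inferred_labels = {"dog", "cat"}  # assumed to come from the directory names

features = datasets.Features({
    "image": datasets.Image(),
    "label": datasets.ClassLabel(names=sorted(inferred_labels)),
})
template = ImageClassification(image_column="image", label_column="label")
template = template.align_with_features(features)
assert template.label_schema["labels"].names == ["cat", "dog"]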
Example #17
    def _split_generators(self, dl_manager):
        if not self.config.data_files:
            raise ValueError(
                f"At least one data file must be specified, but got data_files={self.config.data_files}"
            )

        # Do an early pass if:
        # * `features` are not specified, to infer the class labels
        # * `drop_metadata` is False, to find the metadata files
        do_analyze = (
            self.config.features is None
            and not self.config.drop_labels) or not self.config.drop_metadata
        if do_analyze:
            labels = set()
            metadata_files = collections.defaultdict(list)

            def analyze(files_or_archives, downloaded_files_or_dirs, split):
                if len(downloaded_files_or_dirs) == 0:
                    return
                # The files are separated from the archives at this point, so check the first sample
                # to see if it's a file or a directory and iterate accordingly
                if os.path.isfile(downloaded_files_or_dirs[0]):
                    original_files, downloaded_files = files_or_archives, downloaded_files_or_dirs
                    for original_file, downloaded_file in zip(
                            original_files, downloaded_files):
                        original_file, downloaded_file = str(
                            original_file), str(downloaded_file)
                        _, original_file_ext = os.path.splitext(original_file)
                        if original_file_ext.lower() in self.IMAGE_EXTENSIONS:
                            labels.add(
                                os.path.basename(
                                    os.path.dirname(original_file)))
                        elif os.path.basename(
                                original_file) == self.METADATA_FILENAME:
                            metadata_files[split].append(
                                (original_file, downloaded_file))
                        else:
                            original_file_name = os.path.basename(
                                original_file)
                            logger.debug(
                                f"The file '{original_file_name}' was ignored: it is not an image, and is not {self.METADATA_FILENAME} either."
                            )
                else:
                    archives, downloaded_dirs = files_or_archives, downloaded_files_or_dirs
                    for archive, downloaded_dir in zip(archives,
                                                       downloaded_dirs):
                        archive, downloaded_dir = str(archive), str(
                            downloaded_dir)
                        for downloaded_dir_file in dl_manager.iter_files(
                                downloaded_dir):
                            _, downloaded_dir_file_ext = os.path.splitext(
                                downloaded_dir_file)
                            if downloaded_dir_file_ext.lower() in self.IMAGE_EXTENSIONS:
                                labels.add(
                                    os.path.basename(
                                        os.path.dirname(downloaded_dir_file)))
                            elif os.path.basename(downloaded_dir_file
                                                  ) == self.METADATA_FILENAME:
                                metadata_files[split].append(
                                    (None, downloaded_dir_file))
                            else:
                                archive_file_name = os.path.basename(archive)
                                original_file_name = os.path.basename(
                                    downloaded_dir_file)
                                logger.debug(
                                    f"The file '{original_file_name}' from the archive '{archive_file_name}' was ignored: it is not an image, and is not {self.METADATA_FILENAME} either."
                                )

            if not self.config.drop_labels:
                logger.info("Inferring labels from data files...")
            if not self.config.drop_metadata:
                logger.info("Analyzing metadata files...")

        data_files = self.config.data_files
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            files, archives = self._split_files_and_archives(files)
            downloaded_files = dl_manager.download(files)
            downloaded_dirs = dl_manager.download_and_extract(archives)
            if do_analyze:
                analyze(files, downloaded_files, split_name)
                analyze(archives, downloaded_dirs, split_name)
            splits.append(
                datasets.SplitGenerator(
                    name=split_name,
                    gen_kwargs={
                        "files": [(file, downloaded_file)
                                  for file, downloaded_file in zip(
                                      files, downloaded_files)] +
                        [(None, dl_manager.iter_files(downloaded_dir))
                         for downloaded_dir in downloaded_dirs],
                        "metadata_files":
                        metadata_files
                        if not self.config.drop_metadata else None,
                        "split_name":
                        split_name,
                    },
                ))

        if not self.config.drop_metadata and metadata_files:
            # Verify that:
            # * all metadata files have the same set of features
            # * the `file_name` key is one of the metadata keys and is of type string
            features_per_metadata_file: List[Tuple[str,
                                                   datasets.Features]] = []
            for _, downloaded_metadata_file in itertools.chain.from_iterable(
                    metadata_files.values()):
                with open(downloaded_metadata_file, "rb") as f:
                    pa_metadata_table = paj.read_json(f)
                features_per_metadata_file.append(
                    (downloaded_metadata_file,
                     datasets.Features.from_arrow_schema(
                         pa_metadata_table.schema)))
            for downloaded_metadata_file, metadata_features in features_per_metadata_file:
                if metadata_features != features_per_metadata_file[0][1]:
                    raise ValueError(
                        f"Metadata files {downloaded_metadata_file} and {features_per_metadata_file[0][0]} have different features: {features_per_metadata_file[0]} != {metadata_features}"
                    )
            metadata_features = features_per_metadata_file[0][1]
            if "file_name" not in metadata_features:
                raise ValueError(
                    "`file_name` must be present as dictionary key in metadata files"
                )
            if metadata_features["file_name"] != datasets.Value("string"):
                raise ValueError("`file_name` key must be a string")
            del metadata_features["file_name"]
        else:
            metadata_features = None

        # Normally, we would do this in _info, but we need to know the labels and/or metadata
        # before building the features
        if self.config.features is None:
            if not self.config.drop_labels and not metadata_files:
                self.info.features = datasets.Features({
                    "image": datasets.Image(),
                    "label": datasets.ClassLabel(names=sorted(labels)),
                })
                task_template = ImageClassification(image_column="image",
                                                    label_column="label")
                task_template = task_template.align_with_features(
                    self.info.features)
                self.info.task_templates = [task_template]
            else:
                self.info.features = datasets.Features(
                    {"image": datasets.Image()})

            if not self.config.drop_metadata and metadata_files:
                # Verify that there are no duplicated keys when compared to the existing features ("image", optionally "label")
                duplicated_keys = set(
                    self.info.features) & set(metadata_features)
                if duplicated_keys:
                    raise ValueError(
                        f"Metadata feature keys {list(duplicated_keys)} are already present as the image features"
                    )
                self.info.features.update(metadata_features)

        return splits
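The metadata checks above require every metadata file to share one schema and to expose a string file_name column. A minimal sketch of a file that would pass them, written as JSON Lines; the file name metadata.jsonl and the extra caption column are assumptions rather than values taken from the loader:

# Write a tiny JSON Lines metadata file with the required string "file_name" key.
metadata_lines = [
    '{"file_name": "0001.jpg", "caption": "a red bicycle leaning against a wall"}',
    '{"file_name": "0002.jpg", "caption": "two dogs playing in the snow"}',
]
with open("metadata.jsonl", "w", encoding="utf-8") as f:
    f.write("\n".join(metadata_lines) + "\n")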