def test_align_with_features(self):
    """Check that aligning the template with features specialises its label schema.

    Before alignment the ``"labels"`` entry of ``label_schema`` is the bare
    ``ClassLabel`` class; after ``align_with_features`` it must equal a
    ``ClassLabel`` built from ``self.labels``.
    """
    template = ImageClassification(
        image_column="input_image",
        label_column="input_label",
    )
    # Unaligned template: the schema only references the ClassLabel type.
    self.assertEqual(template.label_schema["labels"], ClassLabel)

    dataset_features = Features(
        {"input_label": ClassLabel(names=self.labels)})
    aligned_template = template.align_with_features(dataset_features)

    # Aligned template: the schema carries the concrete label names.
    expected_labels = ClassLabel(names=self.labels)
    self.assertEqual(aligned_template.label_schema["labels"], expected_labels)
Example #2
0
    def _split_generators(self, dl_manager):
        if not self.config.data_files:
            raise ValueError(
                f"At least one data file must be specified, but got data_files={self.config.data_files}"
            )

        capture_labels = not self.config.drop_labels and self.config.features is None
        if capture_labels:
            labels = set()

            def capture_labels_for_split(files_or_archives,
                                         downloaded_files_or_dirs):
                if len(downloaded_files_or_dirs) == 0:
                    return
                # The files are separated from the archives at this point, so check the first sample
                # to see if it's a file or a directory and iterate accordingly
                if os.path.isfile(downloaded_files_or_dirs[0]):
                    files, downloaded_files = files_or_archives, downloaded_files_or_dirs
                    for file, downloaded_file in zip(files, downloaded_files):
                        file, downloaded_file = str(file), str(downloaded_file)
                        _, file_ext = os.path.splitext(file)
                        if file_ext.lower() in self.IMAGE_EXTENSIONS:
                            labels.add(os.path.basename(os.path.dirname(file)))
                else:
                    archives, downloaded_dirs = files_or_archives, downloaded_files_or_dirs
                    for archive, downloaded_dir in zip(archives,
                                                       downloaded_dirs):
                        archive, downloaded_file = str(archive), str(
                            downloaded_dir)
                        for downloaded_dir_file in dl_manager.iter_files(
                                downloaded_dir):
                            _, downloaded_dir_file_ext = os.path.splitext(
                                downloaded_dir_file)
                            if downloaded_dir_file_ext in self.IMAGE_EXTENSIONS:
                                labels.add(
                                    os.path.basename(
                                        os.path.dirname(downloaded_dir_file)))

            logger.info("Inferring labels from data files...")

        data_files = self.config.data_files
        splits = []
        if isinstance(data_files, (str, list, tuple)):
            files = data_files
            if isinstance(files, str):
                files = [files]
            files, archives = self._split_files_and_archives(files)
            downloaded_files = dl_manager.download(files)
            downloaded_dirs = dl_manager.download_and_extract(archives)
            if capture_labels:
                capture_labels_for_split(files, downloaded_files)
                capture_labels_for_split(archives, downloaded_dirs)
            splits.append(
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={
                        "files": [(file, downloaded_file)
                                  for file, downloaded_file in zip(
                                      files, downloaded_files)] +
                        [(None, dl_manager.iter_files(downloaded_dir))
                         for downloaded_dir in downloaded_dirs]
                    },
                ))
        else:
            for split_name, files in data_files.items():
                if isinstance(files, str):
                    files = [files]
                files, archives = self._split_files_and_archives(files)
                downloaded_files = dl_manager.download(files)
                downloaded_dirs = dl_manager.download_and_extract(archives)
                if capture_labels:
                    capture_labels_for_split(files, downloaded_files)
                    capture_labels_for_split(archives, downloaded_dirs)
                splits.append(
                    datasets.SplitGenerator(
                        name=split_name,
                        gen_kwargs={
                            "files": [(file, downloaded_file)
                                      for file, downloaded_file in zip(
                                          files, downloaded_files)] +
                            [(None, dl_manager.iter_files(downloaded_dir))
                             for downloaded_dir in downloaded_dirs]
                        },
                    ))

        # Normally we would do this in _info, but we need to know the labels before building the features
        if capture_labels:
            if not self.config.drop_labels:
                self.info.features = datasets.Features({
                    "image":
                    datasets.Image(),
                    "label":
                    datasets.ClassLabel(names=sorted(labels))
                })
                task_template = ImageClassification(image_column="image",
                                                    label_column="label")
                task_template = task_template.align_with_features(
                    self.info.features)
                self.info.task_templates = [task_template]
            else:
                self.info.features = datasets.Features(
                    {"image": datasets.Image()})

        return splits
Example #3
0
    def _split_generators(self, dl_manager):
        if not self.config.data_files:
            raise ValueError(
                f"At least one data file must be specified, but got data_files={self.config.data_files}"
            )

        # Do an early pass if:
        # * `features` are not specified, to infer the class labels
        # * `drop_metadata` is False, to find the metadata files
        do_analyze = (
            self.config.features is None
            and not self.config.drop_labels) or not self.config.drop_metadata
        if do_analyze:
            labels = set()
            metadata_files = collections.defaultdict(list)

            def analyze(files_or_archives, downloaded_files_or_dirs, split):
                if len(downloaded_files_or_dirs) == 0:
                    return
                # The files are separated from the archives at this point, so check the first sample
                # to see if it's a file or a directory and iterate accordingly
                if os.path.isfile(downloaded_files_or_dirs[0]):
                    original_files, downloaded_files = files_or_archives, downloaded_files_or_dirs
                    for original_file, downloaded_file in zip(
                            original_files, downloaded_files):
                        original_file, downloaded_file = str(
                            original_file), str(downloaded_file)
                        _, original_file_ext = os.path.splitext(original_file)
                        if original_file_ext.lower() in self.IMAGE_EXTENSIONS:
                            labels.add(
                                os.path.basename(
                                    os.path.dirname(original_file)))
                        elif os.path.basename(
                                original_file) == self.METADATA_FILENAME:
                            metadata_files[split].append(
                                (original_file, downloaded_file))
                        else:
                            original_file_name = os.path.basename(
                                original_file)
                            logger.debug(
                                f"The file '{original_file_name}' was ignored: it is not an image, and is not {self.METADATA_FILENAME} either."
                            )
                else:
                    archives, downloaded_dirs = files_or_archives, downloaded_files_or_dirs
                    for archive, downloaded_dir in zip(archives,
                                                       downloaded_dirs):
                        archive, downloaded_dir = str(archive), str(
                            downloaded_dir)
                        for downloaded_dir_file in dl_manager.iter_files(
                                downloaded_dir):
                            _, downloaded_dir_file_ext = os.path.splitext(
                                downloaded_dir_file)
                            if downloaded_dir_file_ext in self.IMAGE_EXTENSIONS:
                                labels.add(
                                    os.path.basename(
                                        os.path.dirname(downloaded_dir_file)))
                            elif os.path.basename(downloaded_dir_file
                                                  ) == self.METADATA_FILENAME:
                                metadata_files[split].append(
                                    (None, downloaded_dir_file))
                            else:
                                archive_file_name = os.path.basename(archive)
                                original_file_name = os.path.basename(
                                    downloaded_dir_file)
                                logger.debug(
                                    f"The file '{original_file_name}' from the archive '{archive_file_name}' was ignored: it is not an image, and is not {self.METADATA_FILENAME} either."
                                )

            if not self.config.drop_labels:
                logger.info("Inferring labels from data files...")
            if not self.config.drop_metadata:
                logger.info("Analyzing metadata files...")

        data_files = self.config.data_files
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            files, archives = self._split_files_and_archives(files)
            downloaded_files = dl_manager.download(files)
            downloaded_dirs = dl_manager.download_and_extract(archives)
            if do_analyze:
                analyze(files, downloaded_files, split_name)
                analyze(archives, downloaded_dirs, split_name)
            splits.append(
                datasets.SplitGenerator(
                    name=split_name,
                    gen_kwargs={
                        "files": [(file, downloaded_file)
                                  for file, downloaded_file in zip(
                                      files, downloaded_files)] +
                        [(None, dl_manager.iter_files(downloaded_dir))
                         for downloaded_dir in downloaded_dirs],
                        "metadata_files":
                        metadata_files
                        if not self.config.drop_metadata else None,
                        "split_name":
                        split_name,
                    },
                ))

        if not self.config.drop_metadata and metadata_files:
            # Verify that:
            # * all metadata files have the same set of features
            # * the `file_name` key is one of the metadata keys and is of type string
            features_per_metadata_file: List[Tuple[str,
                                                   datasets.Features]] = []
            for _, downloaded_metadata_file in itertools.chain.from_iterable(
                    metadata_files.values()):
                with open(downloaded_metadata_file, "rb") as f:
                    pa_metadata_table = paj.read_json(f)
                features_per_metadata_file.append(
                    (downloaded_metadata_file,
                     datasets.Features.from_arrow_schema(
                         pa_metadata_table.schema)))
            for downloaded_metadata_file, metadata_features in features_per_metadata_file:
                if metadata_features != features_per_metadata_file[0][1]:
                    raise ValueError(
                        f"Metadata files {downloaded_metadata_file} and {features_per_metadata_file[0][0]} have different features: {features_per_metadata_file[0]} != {metadata_features}"
                    )
            metadata_features = features_per_metadata_file[0][1]
            if "file_name" not in metadata_features:
                raise ValueError(
                    "`file_name` must be present as dictionary key in metadata files"
                )
            if metadata_features["file_name"] != datasets.Value("string"):
                raise ValueError("`file_name` key must be a string")
            del metadata_features["file_name"]
        else:
            metadata_features = None

        # Normally, we would do this in _info, but we need to know the labels and/or metadata
        # before building the features
        if self.config.features is None:
            if not self.config.drop_labels and not metadata_files:
                self.info.features = datasets.Features({
                    "image":
                    datasets.Image(),
                    "label":
                    datasets.ClassLabel(names=sorted(labels))
                })
                task_template = ImageClassification(image_column="image",
                                                    label_column="label")
                task_template = task_template.align_with_features(
                    self.info.features)
                self.info.task_templates = [task_template]
            else:
                self.info.features = datasets.Features(
                    {"image": datasets.Image()})

            if not self.config.drop_metadata and metadata_files:
                # Verify that there are no duplicated keys when compared to the existing features ("image", optionally "label")
                duplicated_keys = set(
                    self.info.features) & set(metadata_features)
                if duplicated_keys:
                    raise ValueError(
                        f"Metadata feature keys {list(duplicated_keys)} are already present as the image features"
                    )
                self.info.features.update(metadata_features)

        return splits