def test_align_with_features(self):
    """Aligning the task template with concrete features resolves the label schema."""
    task = ImageClassification(image_column="input_image", label_column="input_label")
    # Before alignment, the schema only records the feature *type*, not the names.
    self.assertEqual(task.label_schema["labels"], ClassLabel)
    aligned = task.align_with_features(
        Features({"input_label": ClassLabel(names=self.labels)})
    )
    # After alignment, the concrete class names are filled in.
    self.assertEqual(aligned.label_schema["labels"], ClassLabel(names=self.labels))
def _split_generators(self, dl_manager):
    """Download/extract the configured data files and build one SplitGenerator per split.

    When ``features`` are not explicitly configured and ``drop_labels`` is False,
    class labels are inferred from the name of each image's parent directory
    (both for plain files and for files inside downloaded archives).

    Args:
        dl_manager: the datasets DownloadManager used to download/extract files.

    Returns:
        list[datasets.SplitGenerator]: one generator per split.

    Raises:
        ValueError: if ``self.config.data_files`` is empty.
    """
    if not self.config.data_files:
        raise ValueError(
            f"At least one data file must be specified, but got data_files={self.config.data_files}"
        )
    # Only infer labels when the user did not supply features and wants labels kept.
    capture_labels = not self.config.drop_labels and self.config.features is None
    if capture_labels:
        labels = set()

        def capture_labels_for_split(files_or_archives, downloaded_files_or_dirs):
            # Collect the parent-directory name of every image file as a label.
            if len(downloaded_files_or_dirs) == 0:
                return
            # The files are separated from the archives at this point, so check the first sample
            # to see if it's a file or a directory and iterate accordingly
            if os.path.isfile(downloaded_files_or_dirs[0]):
                files, downloaded_files = files_or_archives, downloaded_files_or_dirs
                for file, downloaded_file in zip(files, downloaded_files):
                    file, downloaded_file = str(file), str(downloaded_file)
                    _, file_ext = os.path.splitext(file)
                    if file_ext.lower() in self.IMAGE_EXTENSIONS:
                        labels.add(os.path.basename(os.path.dirname(file)))
            else:
                archives, downloaded_dirs = files_or_archives, downloaded_files_or_dirs
                for archive, downloaded_dir in zip(archives, downloaded_dirs):
                    # BUG FIX: str(downloaded_dir) was previously assigned to a variable
                    # named `downloaded_file`, leaving `downloaded_dir` un-normalized.
                    archive, downloaded_dir = str(archive), str(downloaded_dir)
                    for downloaded_dir_file in dl_manager.iter_files(downloaded_dir):
                        _, downloaded_dir_file_ext = os.path.splitext(downloaded_dir_file)
                        # BUG FIX: lowercase the extension so the match is case-insensitive,
                        # consistent with the plain-file branch above.
                        if downloaded_dir_file_ext.lower() in self.IMAGE_EXTENSIONS:
                            labels.add(os.path.basename(os.path.dirname(downloaded_dir_file)))

        logger.info("Inferring labels from data files...")

    data_files = self.config.data_files
    splits = []
    if isinstance(data_files, (str, list, tuple)):
        # Unnamed data files: everything goes into a single TRAIN split.
        files = data_files
        if isinstance(files, str):
            files = [files]
        files, archives = self._split_files_and_archives(files)
        downloaded_files = dl_manager.download(files)
        downloaded_dirs = dl_manager.download_and_extract(archives)
        if capture_labels:
            capture_labels_for_split(files, downloaded_files)
            capture_labels_for_split(archives, downloaded_dirs)
        splits.append(
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    # Plain files keep their original path; archive members get None
                    # as the original path and a lazy file iterator instead.
                    "files": [
                        (file, downloaded_file)
                        for file, downloaded_file in zip(files, downloaded_files)
                    ]
                    + [
                        (None, dl_manager.iter_files(downloaded_dir))
                        for downloaded_dir in downloaded_dirs
                    ]
                },
            )
        )
    else:
        # A mapping of split name -> files: one SplitGenerator per entry.
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            files, archives = self._split_files_and_archives(files)
            downloaded_files = dl_manager.download(files)
            downloaded_dirs = dl_manager.download_and_extract(archives)
            if capture_labels:
                capture_labels_for_split(files, downloaded_files)
                capture_labels_for_split(archives, downloaded_dirs)
            splits.append(
                datasets.SplitGenerator(
                    name=split_name,
                    gen_kwargs={
                        "files": [
                            (file, downloaded_file)
                            for file, downloaded_file in zip(files, downloaded_files)
                        ]
                        + [
                            (None, dl_manager.iter_files(downloaded_dir))
                            for downloaded_dir in downloaded_dirs
                        ]
                    },
                )
            )

    # Normally we would do this in _info, but we need to know the labels before building the features
    if capture_labels:
        # NOTE: `capture_labels` already implies `not drop_labels`, so the inner
        # check is defensive; kept to preserve the original structure.
        if not self.config.drop_labels:
            self.info.features = datasets.Features(
                {
                    "image": datasets.Image(),
                    "label": datasets.ClassLabel(names=sorted(labels)),
                }
            )
            task_template = ImageClassification(image_column="image", label_column="label")
            task_template = task_template.align_with_features(self.info.features)
            self.info.task_templates = [task_template]
        else:
            self.info.features = datasets.Features({"image": datasets.Image()})
    return splits
def _split_generators(self, dl_manager):
    """Download/extract the data files, inferring labels and/or collecting metadata files.

    Performs an early analysis pass when either class labels must be inferred
    (``features`` unset and ``drop_labels`` False) or metadata files must be
    located (``drop_metadata`` False). Metadata files (``self.METADATA_FILENAME``)
    are validated to share one schema containing a string ``file_name`` key.

    Args:
        dl_manager: the datasets DownloadManager used to download/extract files.

    Returns:
        list[datasets.SplitGenerator]: one generator per configured split.

    Raises:
        ValueError: if no data files are configured, if metadata files disagree
            on their features, if ``file_name`` is missing or not a string, or
            if metadata keys collide with the built-in features.
    """
    if not self.config.data_files:
        raise ValueError(
            f"At least one data file must be specified, but got data_files={self.config.data_files}"
        )
    # Do an early pass if:
    # * `features` are not specified, to infer the class labels
    # * `drop_metadata` is False, to find the metadata files
    do_analyze = (
        self.config.features is None and not self.config.drop_labels
    ) or not self.config.drop_metadata
    if do_analyze:
        labels = set()
        metadata_files = collections.defaultdict(list)

        def analyze(files_or_archives, downloaded_files_or_dirs, split):
            # Record image labels (parent dir names) and metadata files for `split`.
            if len(downloaded_files_or_dirs) == 0:
                return
            # The files are separated from the archives at this point, so check the first sample
            # to see if it's a file or a directory and iterate accordingly
            if os.path.isfile(downloaded_files_or_dirs[0]):
                original_files, downloaded_files = files_or_archives, downloaded_files_or_dirs
                for original_file, downloaded_file in zip(original_files, downloaded_files):
                    original_file, downloaded_file = str(original_file), str(downloaded_file)
                    _, original_file_ext = os.path.splitext(original_file)
                    if original_file_ext.lower() in self.IMAGE_EXTENSIONS:
                        labels.add(os.path.basename(os.path.dirname(original_file)))
                    elif os.path.basename(original_file) == self.METADATA_FILENAME:
                        metadata_files[split].append((original_file, downloaded_file))
                    else:
                        original_file_name = os.path.basename(original_file)
                        logger.debug(
                            f"The file '{original_file_name}' was ignored: it is not an image, and is not {self.METADATA_FILENAME} either."
                        )
            else:
                archives, downloaded_dirs = files_or_archives, downloaded_files_or_dirs
                for archive, downloaded_dir in zip(archives, downloaded_dirs):
                    archive, downloaded_dir = str(archive), str(downloaded_dir)
                    for downloaded_dir_file in dl_manager.iter_files(downloaded_dir):
                        _, downloaded_dir_file_ext = os.path.splitext(downloaded_dir_file)
                        # BUG FIX: lowercase the extension so the match is case-insensitive,
                        # consistent with the plain-file branch above.
                        if downloaded_dir_file_ext.lower() in self.IMAGE_EXTENSIONS:
                            labels.add(os.path.basename(os.path.dirname(downloaded_dir_file)))
                        elif os.path.basename(downloaded_dir_file) == self.METADATA_FILENAME:
                            # Archive members have no original path, hence None.
                            metadata_files[split].append((None, downloaded_dir_file))
                        else:
                            archive_file_name = os.path.basename(archive)
                            original_file_name = os.path.basename(downloaded_dir_file)
                            logger.debug(
                                f"The file '{original_file_name}' from the archive '{archive_file_name}' was ignored: it is not an image, and is not {self.METADATA_FILENAME} either."
                            )

        if not self.config.drop_labels:
            logger.info("Inferring labels from data files...")
        if not self.config.drop_metadata:
            logger.info("Analyzing metadata files...")

    data_files = self.config.data_files
    splits = []
    for split_name, files in data_files.items():
        if isinstance(files, str):
            files = [files]
        files, archives = self._split_files_and_archives(files)
        downloaded_files = dl_manager.download(files)
        downloaded_dirs = dl_manager.download_and_extract(archives)
        if do_analyze:
            analyze(files, downloaded_files, split_name)
            analyze(archives, downloaded_dirs, split_name)
        splits.append(
            datasets.SplitGenerator(
                name=split_name,
                gen_kwargs={
                    # Plain files keep their original path; archive members get None
                    # as the original path and a lazy file iterator instead.
                    "files": [
                        (file, downloaded_file)
                        for file, downloaded_file in zip(files, downloaded_files)
                    ]
                    + [
                        (None, dl_manager.iter_files(downloaded_dir))
                        for downloaded_dir in downloaded_dirs
                    ],
                    "metadata_files": metadata_files if not self.config.drop_metadata else None,
                    "split_name": split_name,
                },
            )
        )

    if not self.config.drop_metadata and metadata_files:
        # Verify that:
        # * all metadata files have the same set of features
        # * the `file_name` key is one of the metadata keys and is of type string
        features_per_metadata_file: List[Tuple[str, datasets.Features]] = []
        for _, downloaded_metadata_file in itertools.chain.from_iterable(metadata_files.values()):
            with open(downloaded_metadata_file, "rb") as f:
                pa_metadata_table = paj.read_json(f)
            features_per_metadata_file.append(
                (
                    downloaded_metadata_file,
                    datasets.Features.from_arrow_schema(pa_metadata_table.schema),
                )
            )
        for downloaded_metadata_file, metadata_features in features_per_metadata_file:
            if metadata_features != features_per_metadata_file[0][1]:
                # BUG FIX: interpolate the reference *features* ([0][1]) in the message,
                # not the (path, features) tuple, so the error compares like with like.
                raise ValueError(
                    f"Metadata files {downloaded_metadata_file} and {features_per_metadata_file[0][0]} have different features: {features_per_metadata_file[0][1]} != {metadata_features}"
                )
        metadata_features = features_per_metadata_file[0][1]
        if "file_name" not in metadata_features:
            raise ValueError("`file_name` must be present as dictionary key in metadata files")
        if metadata_features["file_name"] != datasets.Value("string"):
            raise ValueError("`file_name` key must be a string")
        # `file_name` is only used to join metadata to images; it is not a feature itself.
        del metadata_features["file_name"]
    else:
        metadata_features = None

    # Normally, we would do this in _info, but we need to know the labels and/or metadata
    # before building the features
    if self.config.features is None:
        if not self.config.drop_labels and not metadata_files:
            self.info.features = datasets.Features(
                {
                    "image": datasets.Image(),
                    "label": datasets.ClassLabel(names=sorted(labels)),
                }
            )
            task_template = ImageClassification(image_column="image", label_column="label")
            task_template = task_template.align_with_features(self.info.features)
            self.info.task_templates = [task_template]
        else:
            self.info.features = datasets.Features({"image": datasets.Image()})
        if not self.config.drop_metadata and metadata_files:
            # Verify that there are no duplicated keys when compared to the existing features ("image", optionally "label")
            duplicated_keys = set(self.info.features) & set(metadata_features)
            if duplicated_keys:
                raise ValueError(
                    f"Metadata feature keys {list(duplicated_keys)} are already present as the image features"
                )
            self.info.features.update(metadata_features)

    return splits