def setUp(self):
  super(DatasetBuilderTestCase, self).setUp()
  self.patchers = []
  self.builder = self._make_builder()

  example_dir = self.DATASET_CLASS.code_path.parent / "dummy_data"
  fake_example_dir = utils.as_path(test_utils.fake_examples_dir())
  if self.EXAMPLE_DIR is not None:
    self.example_dir = utils.as_path(self.EXAMPLE_DIR)
    example_dir = self.example_dir  # Dir to display in the error
  elif example_dir.exists():
    self.example_dir = example_dir
  else:
    self.example_dir = fake_example_dir / self.builder.name

  if not self.example_dir.exists():
    err_msg = f"Dummy data not found in: {example_dir}"
    raise ValueError(err_msg)

  if self.MOCK_OUT_FORBIDDEN_OS_FUNCTIONS:
    self._mock_out_forbidden_os_functions()

  # Track the urls which are downloaded to validate the checksums.
  # The `dl_manager.download` and `dl_manager.download_and_extract` are
  # patched to record the urls in `_download_urls`.
  # Calling `dl_manager.download_checksums` stops the url registration
  # (as checksums are stored remotely).
  # `_test_checksums` validates the recorded urls.
  self._download_urls = set()
  self._stop_record_download = False
def from_json(cls, value: utils.JsonValue) -> 'DatasetSource':
  """Imports from JSON."""
  if isinstance(value, str):  # Single-file dataset ('.../my_dataset.py')
    path = utils.as_path(value)
    return cls(root_path=path.parent, filenames=[path.name])
  elif isinstance(value, dict):  # Multi-file dataset
    return cls(
        root_path=utils.as_path(value['root_path']),
        filenames=value['filenames'],
    )
  else:
    raise ValueError(f'Invalid input: {value}')
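# Hedged usage sketch (not part of the library code above): the two JSON
# shapes accepted by `DatasetSource.from_json`, matching the `str` and `dict`
# branches. Paths and filenames are illustrative only.
# single_file = DatasetSource.from_json(
#     'gs://bucket/datasets/my_dataset/my_dataset.py')
# multi_file = DatasetSource.from_json({
#     'root_path': 'gs://bucket/datasets/my_dataset',
#     'filenames': ['my_dataset.py', 'label.txt'],
# })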
def __init__(self, name, url=None, content=None):
  url = url or f'http://foo-bar.ch/{name}'
  content = content or f'content of {name}'
  self.url = url
  self.url_info = checksums_lib.UrlInfo(
      size=len(content),
      checksum=_sha256(content),
      filename=name,
  )
  self.file_name = resource_lib.get_dl_fname(url, self.url_info.checksum)
  self.file_path = utils.as_path(f'/dl_dir/{self.file_name}')
  self.url_name = resource_lib.get_dl_fname(url, _sha256(url))
  self.url_path = utils.as_path(f'/dl_dir/{self.url_name}')
def _compute_split_statistics_beam(
    *,
    split_files: _SplitFilesDict,
    data_dir: utils.ReadWritePath,
    out_dir: utils.PathLike,
) -> List[split_lib.SplitInfo]:
  """Compute statistics."""
  out_dir = utils.as_path(out_dir)
  assert out_dir.exists(), f'{out_dir} does not exist'

  beam = lazy_imports_lib.lazy_imports.apache_beam

  # Launch the beam pipeline computation
  runner = None
  # Create the global pipeline object common for all splits
  # Disable type_hint as it doesn't work with typing.Protocol
  beam_options = beam.options.pipeline_options.PipelineOptions()
  beam_options.view_as(
      beam.options.pipeline_options.TypeOptions).pipeline_type_check = False
  with beam.Pipeline(runner=runner, options=beam_options) as pipeline:
    for split_name, file_infos in split_files.items():
      _ = pipeline | split_name >> _process_split(  # pylint: disable=no-value-for-parameter
          data_dir=data_dir,
          out_dir=out_dir,
          file_infos=file_infos,  # pytype: disable=missing-parameter
      )

  # After the files have been computed
  return [
      _split_info_from_path(out_dir / _out_filename(split_name))
      for split_name in split_files
  ]
def _get_repr_html_ffmpeg(images: List[PilImage]) -> str:
  """Runs ffmpeg to get the mp4 encoded <video> str."""
  # Find number of digits in len to give names.
  num_digits = len(str(len(images))) + 1
  with tempfile.TemporaryDirectory() as video_dir:
    for i, img in enumerate(images):
      f = os.path.join(video_dir, f'img{i:0{num_digits}d}.png')
      img.save(f, format='png')
    ffmpeg_args = [
        '-framerate', str(_VISU_FRAMERATE),
        '-i', os.path.join(video_dir, f'img%0{num_digits}d.png'),
        # Using native h264 to encode video stream to H.264 codec
        # Default encoding does not seem to be supported by Chrome.
        '-vcodec', 'h264',
        # When outputting H.264, `-pix_fmt yuv420p` maximizes compatibility
        # with bad video players.
        # Ref: https://trac.ffmpeg.org/wiki/Slideshow
        '-pix_fmt', 'yuv420p',
        # The native encoder cannot encode images of small scale, or the
        # hardware encoder may be busy, which raises
        # Error: cannot create compression session
        # so allow software encoding
        # '-allow_sw', '1',
    ]
    video_path = utils.as_path(video_dir) / 'output.mp4'
    ffmpeg_args.append(os.fspath(video_path))
    utils.ffmpeg_run(ffmpeg_args)
    video_str = utils.get_base64(video_path.read_bytes())
  return (
      f'<video height="{THUMBNAIL_SIZE}" width="175" '
      'controls loop autoplay muted playsinline>'
      f'<source src="data:video/mp4;base64,{video_str}" type="video/mp4" >'
      '</video>'
  )
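# For reference, a hedged illustration of the command assembled above for 12
# frames (so `num_digits == 3`); the actual tmp paths, the `_VISU_FRAMERATE`
# constant, and the binary invoked by `utils.ffmpeg_run` come from the module:
#   ffmpeg -framerate <_VISU_FRAMERATE> -i <tmp>/img%03d.png \
#          -vcodec h264 -pix_fmt yuv420p <tmp>/output.mp4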
def download_from_uri(uri: str, dst: utils.ReadWritePath) -> str:
  """Download the remote dataset code locally to the dst path.

  Args:
    uri: Source of the dataset. Can be:
      * A local/GCS path (e.g. `gs://bucket/datasets/my_dataset/`)
      * A github source
    dst: Empty directory to which the source is copied

  Returns:
    The module name of the package.
  """
  if uri.startswith('github://'):
    raise NotImplementedError('Github sources not supported yet')

  path = utils.as_path(uri)
  if not path.exists():
    raise ValueError(f'Unsupported source: {uri}')

  # Download the main file
  python_module = path / f'{path.name}.py'
  python_module.copy(dst / python_module.name)

  # TODO(tfds): Should also support downloading the extra files (e.g.
  # label.txt, util module,...)

  # Add the `__init__` file
  (dst / '__init__.py').write_text('')
  return python_module.stem
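# Hedged usage sketch (paths are illustrative): copy a single-file dataset
# definition from a local/GCS directory into an empty destination directory.
# module_name = download_from_uri(
#     'gs://bucket/datasets/my_dataset/',
#     dst=utils.as_path('/tmp/my_dataset_pkg'),
# )
# assert module_name == 'my_dataset'  # i.e. `python_module.stem`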
def download_kaggle_data(
    competition_or_dataset: str,
    download_dir: utils.PathLike,
) -> utils.ReadWritePath:
  """Downloads the Kaggle data to the download_dir.

  Args:
    competition_or_dataset: Name of the Kaggle competition/dataset.
    download_dir: Path to the TFDS downloads dir.

  Returns:
    Path to the dir where the Kaggle data was downloaded.
  """
  kaggle_dir = _kaggle_dir_name(competition_or_dataset)
  download_path = utils.as_path(download_dir) / kaggle_dir
  # If the dataset has already been downloaded, return the path to it.
  if download_path.is_dir():
    logging.info(
        'Dataset %s already downloaded: reusing %s.',
        competition_or_dataset,
        download_path,
    )
    return download_path
  # Otherwise, download the dataset.
  with utils.incomplete_dir(download_path) as tmp_data_dir:
    logging.info(
        'Downloading %s into %s...',
        competition_or_dataset,
        tmp_data_dir,
    )
    _download_competition_or_dataset(competition_or_dataset, tmp_data_dir)
  return download_path
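# Hedged usage sketch (competition name and directory are illustrative): the
# cached copy is reused when the target dir already exists, otherwise the
# data is downloaded into an incomplete dir and renamed on success.
# path = download_kaggle_data('titanic', download_dir='/tmp/tfds_downloads')
# # -> /tmp/tfds_downloads/<kaggle_dir>  (per `_kaggle_dir_name`)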
def test_manually_downloaded(self):
  """One file is manually downloaded, one not."""
  a, b = [Artifact(i) for i in 'ab']

  a_manual_path = '/manual_dir/a'  # File a is manually downloaded
  self.fs.add_file(a_manual_path)
  self.fs.add_file(b.file_path)

  self.dl_results[b.url] = b.url_info
  manager = self._get_manager(
      register_checksums=False,  # Register with manual download not supported
      url_infos={art.url: art.url_info for art in (a, b)},
  )
  downloads = manager.download({
      'manual': a.url,
      'download': b.url,
  })
  expected = {
      'manual': utils.as_path(a_manual_path),
      'download': b.file_path,
  }
  self.assertEqual(downloads, expected)
def code_path(cls) -> ReadOnlyPath:
  """Returns the path to the file where the Dataset class is located.

  Note: As the code can be run inside a zip file, the returned value is a
  `ReadOnlyPath` by default. Use `tfds.core.utils.to_write_path()` to cast
  the path into `ReadWritePath`.

  Returns:
    path: pathlib.Path like abstraction
  """
  modules = cls.__module__.split(".")
  if len(modules) >= 2:  # Filter `__main__`, `python my_dataset.py`,...
    # If the dataset can be loaded from a module, use this to support zipapp.
    # Note: `utils.resource_path` will return either `zipfile.Path` (for
    # zipapp) or `pathlib.Path`.
    try:
      path = utils.resource_path(modules[0])
    except TypeError:  # Module is not a package
      pass
    else:
      modules[-1] += ".py"
      return path.joinpath(*modules[1:])
  # Otherwise, fallback to `pathlib.Path`. For non-zipapp, it should be
  # equivalent to the above return.
  return utils.as_path(inspect.getfile(cls))
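# Hedged usage sketch (`MyDataset` is a hypothetical builder class): locate
# files shipped next to the dataset source, as the dummy-data test helpers do.
# dummy_data_dir = MyDataset.code_path.parent / 'dummy_data'
# # Only cast to a writable path outside zipapps, per the docstring:
# writable_path = tfds.core.utils.to_write_path(MyDataset.code_path)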
def _sync_extract(self, from_path, method, to_path):
  """Returns `to_path` once resource has been extracted there."""
  to_path_tmp = '%s%s_%s' % (to_path, constants.INCOMPLETE_SUFFIX,
                             uuid.uuid4().hex)
  path = None
  dst_path = None  # To avoid undefined variable if exception is raised
  try:
    for path, handle in iter_archive(from_path, method):
      path = tf.compat.as_text(path)
      dst_path = path and os.path.join(to_path_tmp, path) or to_path_tmp
      _copy(handle, dst_path)
  except BaseException as err:
    msg = 'Error while extracting {} to {} (file: {}) : {}'.format(
        from_path, to_path, path, err)
    # Check if running on Windows
    if os.name == 'nt' and dst_path and len(dst_path) > 250:
      msg += (
          '\n'
          'On Windows, path lengths greater than 260 characters may '
          'result in an error. See the doc to remove the limitation: '
          'https://docs.python.org/3/using/windows.html#removing-the-max-path-limitation'
      )
    raise ExtractError(msg)
  # `tf.io.gfile.Rename(overwrite=True)` doesn't work for non-empty
  # directories, so delete destination first, if it already exists.
  if tf.io.gfile.exists(to_path):
    tf.io.gfile.rmtree(to_path)
  tf.io.gfile.rename(to_path_tmp, to_path)
  self._pbar_path.update(1)
  return utils.as_path(to_path)
def test_incomplete_file(tmp_path: pathlib.Path):
  tmp_path = utils.as_path(tmp_path)
  filepath = tmp_path / 'test.txt'
  with py_utils.incomplete_file(filepath) as tmp_filepath:
    tmp_filepath.write_text('content')
    assert not filepath.exists()
  assert filepath.read_text() == 'content'
  assert not tmp_filepath.exists()  # Tmp file is deleted
def __init__(self, path: utils.PathLike):
  """Constructor.

  Args:
    path: Path to the register files containing the mapping
      namespace -> data_dir
  """
  self._path: utils.ReadOnlyPath = utils.as_path(path)
def _ns2data_dir(self) -> Dict[str, utils.ReadWritePath]:
  """Mapping `namespace` -> `data_dir`."""
  # Lazy-load the namespaces the first time they are requested.
  config = toml.loads(self._path.read_text())
  return {
      namespace: utils.as_path(path)
      for namespace, path in config['Namespaces'].items()
  }
def __init__(self, path: utils.PathLike):
  """Constructor.

  Args:
    path: Path to the register files containing the list of dataset sources,
      forwarded to `_PackageIndex`
  """
  self._path = utils.as_path(path)
def _default_cache_dir() -> type_utils.ReadWritePath:
  """Returns the default cache directory."""
  if 'TFDS_CACHE_DIR' in os.environ:
    path = os.environ['TFDS_CACHE_DIR']
  elif 'XDG_CACHE_HOME' in os.environ:
    path = os.path.join(os.environ['XDG_CACHE_HOME'], 'tensorflow_datasets')
  else:
    path = os.path.join('~', '.cache', 'tensorflow_datasets')
  return utils.as_path(path).expanduser()
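# Hedged illustration of the precedence implemented above (paths are examples):
#   TFDS_CACHE_DIR=/opt/tfds_cache            -> /opt/tfds_cache
#   only XDG_CACHE_HOME=/home/me/.cache set   -> /home/me/.cache/tensorflow_datasets
#   neither set                               -> ~/.cache/tensorflow_datasets (expanded)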
def split_infos_from_path(
    path: utils.PathLike,
    split_names: List[str],
) -> List[split_lib.SplitInfo]:
  """Restore the split info from a directory."""
  path = utils.as_path(path)
  return [
      _split_info_from_path(path / _out_filename(split_name))
      for split_name in split_names
  ]
def dummy_data(cls) -> utils.ReadOnlyPath:  # pylint: disable=no-self-argument
  """Path to the `dummy_data/` directory."""
  if cls is DatasetBuilderTestCase:  # Required for build_api_docs
    return None  # pytype: disable=bad-return-type

  dummy_data_expected = cls.DATASET_CLASS.code_path.parent / "dummy_data"
  fake_example_dir = utils.as_path(test_utils.fake_examples_dir())
  if cls.EXAMPLE_DIR is not None:
    dummy_data_found = utils.as_path(cls.EXAMPLE_DIR)
    dummy_data_expected = dummy_data_found  # Dir to display in the error
  elif dummy_data_expected.exists():
    dummy_data_found = dummy_data_expected
  else:
    dummy_data_found = fake_example_dir / cls.DATASET_CLASS.name

  if not dummy_data_found.exists():
    err_msg = f"Dummy data not found in: {dummy_data_expected}"
    raise ValueError(err_msg)
  return dummy_data_found
def mock_cache_path(new_cache_dir: utils.PathLike) -> Iterator[None]:
  """Mock which overwrites the cache path."""
  new_dir = utils.as_path(new_cache_dir)
  # Use `__wrapped__` to access the original function wrapped inside
  # `functools.lru_cache`
  new_cache_path = utils.memoize()(cache.cache_path.__wrapped__)
  new_module_path = utils.memoize()(cache.module_path.__wrapped__)
  with mock.patch.object(cache, '_default_cache_dir', return_value=new_dir), \
       mock.patch.object(cache, 'cache_path', new_cache_path), \
       mock.patch.object(cache, 'module_path', new_module_path):
    yield
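# Hedged usage sketch: assuming `mock_cache_path` is exposed as a context
# manager (e.g. wrapped with `contextlib.contextmanager`, not shown here),
# a test would redirect the cache like:
# with mock_cache_path('/tmp/tfds_test_cache'):
#   ...  # cache.cache_path() / cache.module_path() now resolve under the tmp dir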
def _checksum_paths() -> Dict[str, type_utils.ReadOnlyPath]:
  """Returns dict {'dataset_name': 'path/to/checksums/file'}."""
  dataset2path = {}
  for dir_path in _CHECKSUM_DIRS:
    if isinstance(dir_path, str):
      dir_path = utils.as_path(dir_path)
    if not dir_path.exists():
      continue  # Skip checksum dirs which do not exist
    for file_path in dir_path.iterdir():
      if not file_path.name.endswith(_CHECKSUM_SUFFIX):
        continue
      dataset_name = file_path.name[:-len(_CHECKSUM_SUFFIX)]
      dataset2path[dataset_name] = file_path
  return dataset2path
def _download(url, tmpdir_path, verify):
  del verify
  self.downloaded_urls.append(url)  # Record downloader.download() calls
  # If the name isn't explicitly provided, then it is extracted from the
  # url.
  filename = self.dl_fnames.get(url, os.path.basename(url))
  # Save the file in the tmp_dir
  path = os.path.join(tmpdir_path, filename)
  self.fs.add_file(path)
  dl_result = downloader.DownloadResult(
      path=utils.as_path(path),
      url_info=self.dl_results[url],
  )
  return promise.Promise.resolve(dl_result)
def _read_indices(path):
  """Returns (file names, list of indices in each file).

  Args:
    path: path to index, omitting suffix.
  """
  paths = sorted(tf.io.gfile.glob('%s-*-of-*_index.json' % path))
  all_indices = []
  for path in paths:
    json_str = utils.as_path(path).read_text()
    # Parse it back into a proto.
    shard_index = json.loads(json_str)
    all_indices.append(list(shard_index['index']))
  return [os.path.basename(p) for p in paths], all_indices
def __init__(self, path: utils.PathLike):
  """Constructor.

  Args:
    path: Remote location of the package index (file containing the list of
      dataset packages)
  """
  super().__init__()
  self._remote_path: utils.ReadOnlyPath = utils.as_path(path)
  self._cached_path: utils.ReadOnlyPath = (
      cache.cache_path() / 'community-datasets-list.jsonl')

  # Pre-load the index from the cache
  if self._cached_path.exists():
    self._refresh_from_content(self._cached_path.read_text())
def _sync_file_copy(
    self,
    filepath: str,
    destination_path: str,
) -> DownloadResult:
  """Downloads the file through the `tf.io.gfile` API."""
  filename = os.path.basename(filepath)
  out_path = os.path.join(destination_path, filename)
  tf.io.gfile.copy(filepath, out_path)
  url_info = checksums_lib.compute_url_info(
      out_path, checksum_cls=self._checksumer_cls)
  self._pbar_dl_size.update_total(url_info.size)
  self._pbar_dl_size.update(url_info.size)
  self._pbar_url.update(1)
  return DownloadResult(path=utils.as_path(out_path), url_info=url_info)
def _get_default_config_name(builder_dir: str, name: str) -> Optional[str]:
  """Returns the default config of the given dataset, None if not found."""
  # Search for the DatasetBuilder generation code
  try:
    cls = registered.imported_builder_cls(name)
    cls = typing.cast(Type[dataset_builder.DatasetBuilder], cls)
  except registered.DatasetNotFoundError:
    pass
  else:
    # If code found, return the default config
    if cls.BUILDER_CONFIGS:
      return cls.BUILDER_CONFIGS[0].name

  # Otherwise, try to load default config from common metadata
  return dataset_builder.load_default_config_name(utils.as_path(builder_dir))
def _download_and_cache(package: DatasetPackage) -> _InstalledPackage:
  """Downloads and installs the dataset source locally.

  This function installs the dataset package in:
  `<module_path>/<namespace>/<ds_name>/<hash>/...`.

  Args:
    package: Package to install.

  Returns:
    installed_dataset: The installed dataset package.
  """
  tmp_dir = utils.as_path(tempfile.mkdtemp())
  try:
    # Download the package in a tmp directory
    dataset_sources_lib.download_from_source(
        package.source,
        tmp_dir,
    )

    # Compute the package hash (to install the dataset in a unique dir)
    package_hash = _compute_dir_hash(tmp_dir)

    # Add package metadata
    installed_package = _InstalledPackage(
        package=package,
        instalation_date=datetime.datetime.now(),
        hash=package_hash,
    )
    package_metadata = json.dumps(installed_package.to_json())
    (tmp_dir / _METADATA_FILENAME).write_text(package_metadata)

    # Rename the package to its final destination
    installation_path = installed_package.installation_path
    if installation_path.exists():  # Package already exists (with same hash)
      # In the future, we should be smarter to allow overwrite.
      raise ValueError(
          f'Package {package} already installed in {installation_path}.')
    installation_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_dir.rename(installation_path)
  finally:
    # Cleanup the tmp directory if it still exists.
    if tmp_dir.exists():
      tmp_dir.rmtree()

  return installed_package
def test_write_metadata(
    tmp_path: pathlib.Path,
    file_format,
):
  tmp_path = utils.as_path(tmp_path)

  src_builder = testing.DummyDataset(
      data_dir=tmp_path / 'origin',
      file_format=file_format,
  )
  src_builder.download_and_prepare()

  dst_dir = tmp_path / 'copy'
  dst_dir.mkdir()

  # Copy all the tfrecord files, but not the dataset info
  for f in src_builder.data_path.iterdir():
    if naming.FilenameInfo.is_valid(f.name):
      f.copy(dst_dir / f.name)

  metadata_path = dst_dir / 'dataset_info.json'

  if file_format is None:
    split_infos = list(src_builder.info.splits.values())
  else:
    split_infos = None  # Auto-compute split infos

  assert not metadata_path.exists()
  write_metadata_utils.write_metadata(
      data_dir=dst_dir,
      features=src_builder.info.features,
      split_infos=split_infos,
      description='my test description.')
  assert metadata_path.exists()

  # After metadata are written, builder can be restored from the directory
  builder = read_only_builder.builder_from_directory(dst_dir)
  assert builder.name == 'dummy_dataset'
  assert builder.version == '1.0.0'
  assert set(builder.info.splits) == {'train'}
  assert builder.info.splits['train'].num_examples == 3
  assert builder.info.description == 'my test description.'

  # Values are the same
  src_ds = src_builder.as_dataset(split='train')
  ds = builder.as_dataset(split='train')
  assert list(src_ds.as_numpy_iterator()) == list(ds.as_numpy_iterator())
def _write_final_shard(
    self,
    shardid_examples: Tuple[str, List[Tuple[
        int, List[type_utils.KeySerializedExample]]]],
):
  """Write all examples from multiple buckets into the same shard."""
  shard_path, examples_by_bucket = shardid_examples
  examples = itertools.chain(*[ex[1] for ex in sorted(examples_by_bucket)])
  # Write to a tmp file to avoid a potential race condition if
  # `--xxxxx_enable_backups` is set and multiple workers try to write to the
  # same file.
  with utils.incomplete_file(utils.as_path(shard_path)) as tmp_path:
    record_keys = _write_examples(tmp_path, examples, self._file_format)
  # If there are no record_keys, skip creating index files.
  if not record_keys:
    return
  _write_index_file(_get_index_path(shard_path), list(record_keys))
def _sync_file_copy(
    self,
    filepath: str,
    destination_path: str,
) -> DownloadResult:
  filename = os.path.basename(filepath)
  out_path = os.path.join(destination_path, filename)
  tf.io.gfile.copy(filepath, out_path)
  hexdigest, size = utils.read_checksum_digest(
      out_path, checksum_cls=self._checksumer_cls)
  return DownloadResult(
      path=utils.as_path(out_path),
      url_info=checksums_lib.UrlInfo(
          checksum=hexdigest,
          size=size,
          filename=filename,
      ),
  )
def _get_default_config_name(builder_dir: str, name: str) -> Optional[str]:
  """Returns the default config of the given dataset, None if not found."""
  # Search for the DatasetBuilder generation code
  try:
    # Warning: The registered dataset may not match the files (e.g. if
    # the imported dataset has the same name as the generated files while
    # being 2 different datasets)
    cls = registered.imported_builder_cls(name)
    cls = typing.cast(Type[dataset_builder.DatasetBuilder], cls)
  except (registered.DatasetNotFoundError, PermissionError):
    pass
  else:
    # If code found, return the default config
    if cls.BUILDER_CONFIGS:
      return cls.BUILDER_CONFIGS[0].name

  # Otherwise, try to load default config from common metadata
  return dataset_builder.load_default_config_name(utils.as_path(builder_dir))
def compute_split_info(
    *,
    data_dir: utils.PathLike,
    out_dir: Optional[utils.PathLike] = None,
) -> List[split_lib.SplitInfo]:
  """Compute the split info on the given files.

  Compute the split info (num shards, num examples,...) metadata required
  by `tfds.folder_dataset.write_metadata`.

  See documentation for usage:
  https://www.tensorflow.org/datasets/external_tfrecord

  Args:
    data_dir: Directory containing the `.tfrecord` files (or similar format)
    out_dir: Output directory in which to save the metadata. It should be
      available from the apache beam workers. If not set, apache beam won't
      be used (only available with some file formats).

  Returns:
    split_infos: The list of `tfds.core.SplitInfo`.
  """
  data_dir = utils.as_path(data_dir)

  # Auto-detect the splits from the files
  split_files = _extract_split_files(data_dir)
  print('Auto-detected splits:')
  for split_name, file_infos in split_files.items():
    print(f' * {split_name}: {file_infos[0].num_shards} shards')

  # Launch the beam pipeline to compute the split info
  if out_dir is not None:
    split_infos = _compute_split_statistics_beam(
        split_files=split_files,
        data_dir=data_dir,
        out_dir=out_dir,
    )
  else:
    raise NotImplementedError('compute_split_info requires the out_dir kwarg.')

  print('Computed split infos: ')
  pprint.pprint(split_infos)
  return split_infos
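# Hedged usage sketch (paths and `features` are illustrative): compute the
# split metadata for externally written tfrecord files, then persist it with
# `tfds.folder_dataset.write_metadata` as described in
# https://www.tensorflow.org/datasets/external_tfrecord.
# split_infos = compute_split_info(
#     data_dir='/data/my_dataset/1.0.0', out_dir='/data/my_dataset/1.0.0')
# tfds.folder_dataset.write_metadata(
#     data_dir='/data/my_dataset/1.0.0',
#     features=features,
#     split_infos=split_infos,
# )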