def setUp(self):
        super(DatasetBuilderTestCase, self).setUp()
        self.patchers = []
        self.builder = self._make_builder()

        example_dir = self.DATASET_CLASS.code_path.parent / "dummy_data"
        fake_example_dir = utils.as_path(test_utils.fake_examples_dir())
        if self.EXAMPLE_DIR is not None:
            self.example_dir = utils.as_path(self.EXAMPLE_DIR)
            example_dir = self.example_dir  # Dir to display in the error
        elif example_dir.exists():
            self.example_dir = example_dir
        else:
            self.example_dir = fake_example_dir / self.builder.name

        if not self.example_dir.exists():
            err_msg = f"Dummy data not found in: {example_dir}"
            raise ValueError(err_msg)
        if self.MOCK_OUT_FORBIDDEN_OS_FUNCTIONS:
            self._mock_out_forbidden_os_functions()

        # Track the urls which are downloaded to validate the checksums
        # The `dl_manager.download` and `dl_manager.download_and_extract` are
        # patched to record the urls in `_download_urls`.
        # Calling `dl_manager.download_checksums` stops the url
        # registration (as checksums are stored remotely).
        # `_test_checksums` validates the recorded urls.
        self._download_urls = set()
        self._stop_record_download = False
Example #2
 def from_json(cls, value: utils.JsonValue) -> 'DatasetSource':
     """Imports from JSON."""
     if isinstance(value, str):  # Single-file dataset ('.../my_dataset.py')
         path = utils.as_path(value)
         return cls(root_path=path.parent, filenames=[path.name])
     elif isinstance(value, dict):  # Multi-file dataset
         return cls(
             root_path=utils.as_path(value['root_path']),
             filenames=value['filenames'],
         )
     else:
         raise ValueError(f'Invalid input: {value}')
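A minimal usage sketch of the two accepted JSON forms, assuming `DatasetSource` is the dataclass this classmethod belongs to (paths are illustrative only):

# Hypothetical usage sketch; paths are illustrative only.
# Single-file form: a plain string pointing at the module file.
src = DatasetSource.from_json('gs://bucket/datasets/my_dataset/my_dataset.py')
# -> root_path='gs://bucket/datasets/my_dataset', filenames=['my_dataset.py']

# Multi-file form: a dict with an explicit root_path and filenames list.
src = DatasetSource.from_json({
    'root_path': 'gs://bucket/datasets/my_dataset',
    'filenames': ['my_dataset.py', 'labels.txt'],
})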
Example #3
 def __init__(self, name, url=None, content=None):
     url = url or f'http://foo-bar.ch/{name}'
     content = content or f'content of {name}'
     self.url = url
     self.url_info = checksums_lib.UrlInfo(
         size=len(content),
         checksum=_sha256(content),
         filename=name,
     )
     self.file_name = resource_lib.get_dl_fname(url, self.url_info.checksum)
     self.file_path = utils.as_path(f'/dl_dir/{self.file_name}')
     self.url_name = resource_lib.get_dl_fname(url, _sha256(url))
     self.url_path = utils.as_path(f'/dl_dir/{self.url_name}')
def _compute_split_statistics_beam(
    *,
    split_files: _SplitFilesDict,
    data_dir: utils.ReadWritePath,
    out_dir: utils.PathLike,
) -> List[split_lib.SplitInfo]:
    """Compute statistics."""
    out_dir = utils.as_path(out_dir)

    assert out_dir.exists(), f'{out_dir} does not exist'

    beam = lazy_imports_lib.lazy_imports.apache_beam

    # Launch the beam pipeline computation
    runner = None
    # Create the global pipeline object common for all splits
    # Disable type_hint as it doesn't work with typing.Protocol
    beam_options = beam.options.pipeline_options.PipelineOptions()
    beam_options.view_as(
        beam.options.pipeline_options.TypeOptions).pipeline_type_check = False
    with beam.Pipeline(runner=runner, options=beam_options) as pipeline:
        for split_name, file_infos in split_files.items():
            _ = pipeline | split_name >> _process_split(  # pylint: disable=no-value-for-parameter
                data_dir=data_dir,
                out_dir=out_dir,
                file_infos=file_infos,  # pytype: disable=missing-parameter
            )

    # After the files have been computed
    return [
        _split_info_from_path(out_dir / _out_filename(split_name))
        for split_name in split_files
    ]
Example #5
def _get_repr_html_ffmpeg(images: List[PilImage]) -> str:
  """Runs ffmpeg to get the mp4 encoded <video> str."""
  # Find the number of digits in len(images) to zero-pad the frame names.
  num_digits = len(str(len(images))) + 1
  with tempfile.TemporaryDirectory() as video_dir:
    for i, img in enumerate(images):
      f = os.path.join(video_dir, f'img{i:0{num_digits}d}.png')
      img.save(f, format='png')

    ffmpeg_args = [
        '-framerate', str(_VISU_FRAMERATE),
        '-i', os.path.join(video_dir, f'img%0{num_digits}d.png'),
        # Using native h264 to encode video stream to H.264 codec
        # The default encoding does not seem to be supported by Chrome.
        '-vcodec', 'h264',
        # When outputting H.264, `-pix_fmt yuv420p` maximizes compatibility
        # with bad video players.
        # Ref: https://trac.ffmpeg.org/wiki/Slideshow
        '-pix_fmt', 'yuv420p',
        # The native encoder cannot encode small images, or the hardware
        # encoder may be busy, which raises
        # "Error: cannot create compression session",
        # so allow software encoding:
        # '-allow_sw', '1',
    ]
    video_path = utils.as_path(video_dir) / 'output.mp4'
    ffmpeg_args.append(os.fspath(video_path))
    utils.ffmpeg_run(ffmpeg_args)
    video_str = utils.get_base64(video_path.read_bytes())
  return (
      f'<video height="{THUMBNAIL_SIZE}" width="175" '
      'controls loop autoplay muted playsinline>'
      f'<source src="data:video/mp4;base64,{video_str}"  type="video/mp4" >'
      '</video>'
  )
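A hedged usage sketch, assuming PIL is installed and the `ffmpeg` binary is on the PATH (the frames below are synthetic):

from PIL import Image

# Build a few dummy frames and embed them as a base64-encoded <video> tag.
frames = [Image.new('RGB', (64, 64), color=(8 * i, 0, 0)) for i in range(16)]
html = _get_repr_html_ffmpeg(frames)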
def download_from_uri(uri: str, dst: utils.ReadWritePath) -> str:
    """Download the remote dataset code locally to the dst path.

  Args:
    uri: Source of the dataset. Can be:
        * A local/GCS path (e.g. `gs://bucket/datasets/my_dataset/`)
        * A github source
    dst: Empty directory into which the source is copied.

  Returns:
    The module name of the package.
  """
    if uri.startswith('github://'):
        raise NotImplementedError('Github sources not supported yet')

    path = utils.as_path(uri)
    if not path.exists():
        raise ValueError(f'Unsupported source: {uri}')

    # Download the main file
    python_module = path / f'{path.name}.py'
    python_module.copy(dst / python_module.name)

    # TODO(tfds): Should also support downloading the extra files (e.g.
    # label.txt, util module, ...)

    # Add the `__init__` file
    (dst / '__init__.py').write_text('')
    return python_module.stem
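A hedged usage sketch for a single-file dataset hosted on GCS (paths are illustrative; the destination must be an existing empty directory):

# Hypothetical usage sketch.
dst = utils.as_path('/tmp/my_dataset_pkg')
dst.mkdir(parents=True, exist_ok=True)
module_name = download_from_uri('gs://bucket/datasets/my_dataset', dst)
# module_name == 'my_dataset'; dst now contains my_dataset.py and __init__.py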
Example #7
def download_kaggle_data(
    competition_or_dataset: str,
    download_dir: utils.PathLike,
) -> utils.ReadWritePath:
  """Downloads the kaggle data to the output_dir.

  Args:
    competition_or_dataset: Name of the kaggle competition/dataset.
    download_dir: Path to the TFDS downloads dir.

  Returns:
    Path to the dir where the kaggle data was downloaded.
  """
  kaggle_dir = _kaggle_dir_name(competition_or_dataset)
  download_path = utils.as_path(download_dir) / kaggle_dir
  # If the dataset has already been downloaded, return the path to it.
  if download_path.is_dir():
    logging.info(
        'Dataset %s already downloaded: reusing %s.',
        competition_or_dataset,
        download_path,
    )
    return download_path
  # Otherwise, download the dataset.
  with utils.incomplete_dir(download_path) as tmp_data_dir:
    logging.info(
        'Downloading %s into %s...',
        competition_or_dataset,
        tmp_data_dir,
    )
    _download_competition_or_dataset(competition_or_dataset, tmp_data_dir)
  return download_path
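A hedged usage sketch, assuming the `kaggle` CLI is installed and authenticated (competition name and paths are illustrative):

dl_path = download_kaggle_data(
    competition_or_dataset='titanic',
    download_dir='/tmp/tfds_downloads',
)
# dl_path is a competition-specific sub-directory of download_dir; calling the
# function again reuses it instead of re-downloading.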
Example #8
    def test_manually_downloaded(self):
        """One file is manually downloaded, one not."""
        a, b = [Artifact(i) for i in 'ab']

        a_manual_path = '/manual_dir/a'
        # File a is manually downloaded
        self.fs.add_file(a_manual_path)
        self.fs.add_file(b.file_path)

        self.dl_results[b.url] = b.url_info
        manager = self._get_manager(
            # Registering checksums with manual download is not supported.
            register_checksums=False,
            url_infos={art.url: art.url_info for art in (a, b)},
        )
        downloads = manager.download({
            'manual': a.url,
            'download': b.url,
        })
        expected = {
            'manual': utils.as_path(a_manual_path),
            'download': b.file_path,
        }
        self.assertEqual(downloads, expected)
Example #9
  def code_path(cls) -> ReadOnlyPath:
    """Returns the path to the file where the Dataset class is located.

    Note: As the code can be run inside a zip file, the returned value is
    a `ReadOnlyPath` by default. Use `tfds.core.utils.to_write_path()` to cast
    the path into a `ReadWritePath`.

    Returns:
      path: pathlib.Path like abstraction
    """
    modules = cls.__module__.split(".")
    if len(modules) >= 2:  # Filter `__main__`, `python my_dataset.py`,...
      # If the dataset can be loaded from a module, use this to support zipapp.
      # Note: `utils.resource_path` will return either `zipfile.Path` (for
      # zipapp) or `pathlib.Path`.
      try:
        path = utils.resource_path(modules[0])
      except TypeError:  # Module is not a package
        pass
      else:
        modules[-1] += ".py"
        return path.joinpath(*modules[1:])
    # Otherwise, fallback to `pathlib.Path`. For non-zipapp, it should be
    # equivalent to the above return.
    return utils.as_path(inspect.getfile(cls))
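A hedged usage sketch; `code_path` is accessed on the class itself (here the MNIST builder shipped with TFDS, assuming the usual module layout):

from tensorflow_datasets.image_classification import mnist

# Resolves to .../tensorflow_datasets/image_classification/mnist.py
# (a zipfile.Path when running from a zipapp, a pathlib.Path otherwise).
print(mnist.MNIST.code_path)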
 def _sync_extract(self, from_path, method, to_path):
     """Returns `to_path` once resource has been extracted there."""
     to_path_tmp = '%s%s_%s' % (to_path, constants.INCOMPLETE_SUFFIX,
                                uuid.uuid4().hex)
     path = None
     dst_path = None  # To avoid undefined variable if exception is raised
     try:
         for path, handle in iter_archive(from_path, method):
             path = tf.compat.as_text(path)
             dst_path = path and os.path.join(to_path_tmp,
                                              path) or to_path_tmp
             _copy(handle, dst_path)
     except BaseException as err:
         msg = 'Error while extracting {} to {} (file: {}) : {}'.format(
             from_path, to_path, path, err)
         # Check if running on windows
         if os.name == 'nt' and dst_path and len(dst_path) > 250:
             msg += (
                 '\n'
                 'On windows, path lengths greater than 260 characters may '
                 'result in an error. See the doc to remove the limitation: '
                 'https://docs.python.org/3/using/windows.html#removing-the-max-path-limitation'
             )
         raise ExtractError(msg)
     # `tf.io.gfile.rename(overwrite=True)` doesn't work for non-empty
     # directories, so delete the destination first if it already exists.
     if tf.io.gfile.exists(to_path):
         tf.io.gfile.rmtree(to_path)
     tf.io.gfile.rename(to_path_tmp, to_path)
     self._pbar_path.update(1)
     return utils.as_path(to_path)
Example #11
def test_incomplete_file(tmp_path: pathlib.Path):
    tmp_path = utils.as_path(tmp_path)
    filepath = tmp_path / 'test.txt'
    with py_utils.incomplete_file(filepath) as tmp_filepath:
        tmp_filepath.write_text('content')
        assert not filepath.exists()
    assert filepath.read_text() == 'content'
    assert not tmp_filepath.exists()  # Tmp file is deleted
Example #12
    def __init__(self, path: utils.PathLike):
        """Contructor.

    Args:
      path: Path to the register files containing the mapping
        namespace -> data_dir
    """
        self._path: utils.ReadOnlyPath = utils.as_path(path)
Example #13
 def _ns2data_dir(self) -> Dict[str, utils.ReadWritePath]:
     """Mapping `namespace` -> `data_dir`."""
     # Lazy-load the namespaces the first time they are requested.
     config = toml.loads(self._path.read_text())
     return {
         namespace: utils.as_path(path)
         for namespace, path in config['Namespaces'].items()
     }
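A hedged sketch of the register file format this property parses: a TOML document with a `[Namespaces]` table mapping namespace to data_dir (content is illustrative):

import toml

config_text = """
[Namespaces]
kaggle = '/path/to/kaggle/datasets'
tensorflow_graphics = 'gs://tensorflow-graphics/datasets'
"""
# toml.loads(...)['Namespaces'] yields the namespace -> path mapping that
# _ns2data_dir then wraps with utils.as_path.
print(toml.loads(config_text)['Namespaces'])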
    def __init__(self, path: utils.PathLike):
        """Contructor.

    Args:
      path: Path to the register files containing the list of dataset sources,
        forwarded to `_PackageIndex`
    """
        self._path = utils.as_path(path)
Example #15
def _default_cache_dir() -> type_utils.ReadWritePath:
    """Returns the default cache directory."""
    if 'TFDS_CACHE_DIR' in os.environ:
        path = os.environ['TFDS_CACHE_DIR']
    elif 'XDG_CACHE_HOME' in os.environ:
        path = os.path.join(os.environ['XDG_CACHE_HOME'],
                            'tensorflow_datasets')
    else:
        path = os.path.join('~', '.cache', 'tensorflow_datasets')
    return utils.as_path(path).expanduser()
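A hedged sketch of the resolution order (environment values are illustrative):

import os

os.environ['TFDS_CACHE_DIR'] = '/tmp/tfds_cache'
print(_default_cache_dir())  # /tmp/tfds_cache

del os.environ['TFDS_CACHE_DIR']
os.environ.pop('XDG_CACHE_HOME', None)
print(_default_cache_dir())  # ~/.cache/tensorflow_datasets, with ~ expanded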
Example #16
def split_infos_from_path(
    path: utils.PathLike,
    split_names: List[str],
) -> List[split_lib.SplitInfo]:
    """Restore the split info from a directory."""
    path = utils.as_path(path)
    return [
        _split_info_from_path(path / _out_filename(split_name))
        for split_name in split_names
    ]
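A hedged usage sketch, re-loading split metadata previously written next to the data (directory and split names are illustrative):

split_infos = split_infos_from_path(
    path='/tmp/split_info_out',
    split_names=['train', 'test'],
)
for info in split_infos:
  print(info.name, info.num_examples, info.num_shards)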
    def dummy_data(cls) -> utils.ReadOnlyPath:  # pylint: disable=no-self-argument
        """Path to the `dummy_data/` directory."""
        if cls is DatasetBuilderTestCase:  # Required for build_api_docs
            return None  # pytype: disable=bad-return-type

        dummy_data_expected = cls.DATASET_CLASS.code_path.parent / "dummy_data"
        fake_example_dir = utils.as_path(test_utils.fake_examples_dir())
        if cls.EXAMPLE_DIR is not None:
            dummy_data_found = utils.as_path(cls.EXAMPLE_DIR)
            dummy_data_expected = dummy_data_found  # Dir to display in the error
        elif dummy_data_expected.exists():
            dummy_data_found = dummy_data_expected
        else:
            dummy_data_found = fake_example_dir / cls.DATASET_CLASS.name

        if not dummy_data_found.exists():
            err_msg = f"Dummy data not found in: {dummy_data_expected}"
            raise ValueError(err_msg)
        return dummy_data_found
def mock_cache_path(new_cache_dir: utils.PathLike) -> Iterator[None]:
    """Mock which overwrite the cache path."""
    new_dir = utils.as_path(new_cache_dir)

    # Use `__wrapped__` to access the original function wrapped inside
    # `functools.lru_cache`
    new_cache_path = utils.memoize()(cache.cache_path.__wrapped__)
    new_module_path = utils.memoize()(cache.module_path.__wrapped__)
    with mock.patch.object(cache, '_default_cache_dir', return_value=new_dir), \
         mock.patch.object(cache, 'cache_path', new_cache_path), \
         mock.patch.object(cache, 'module_path', new_module_path):
        yield
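A hedged usage sketch inside a pytest test, assuming the function is wrapped with `@contextlib.contextmanager` (as the bare `yield` suggests):

def test_with_isolated_cache(tmp_path):
  with mock_cache_path(tmp_path):
    # cache.cache_path() and cache.module_path() now resolve under tmp_path.
    print(cache.cache_path())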
Example #19
def _checksum_paths() -> Dict[str, type_utils.ReadOnlyPath]:
    """Returns dict {'dataset_name': 'path/to/checksums/file'}."""
    dataset2path = {}
    for dir_path in _CHECKSUM_DIRS:
        if isinstance(dir_path, str):
            dir_path = utils.as_path(dir_path)
        if not dir_path.exists():
            continue
        for file_path in dir_path.iterdir():
            if not file_path.name.endswith(_CHECKSUM_SUFFIX):
                continue
            dataset_name = file_path.name[:-len(_CHECKSUM_SUFFIX)]
            dataset2path[dataset_name] = file_path
    return dataset2path
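A hedged usage sketch; the exact dataset names and locations depend on the registered `_CHECKSUM_DIRS`:

for dataset_name, checksums_file in _checksum_paths().items():
  # e.g. 'mnist' -> .../url_checksums/mnist.txt
  print(dataset_name, '->', checksums_file)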
 def _download(url, tmpdir_path, verify):
   del verify
   self.downloaded_urls.append(url)  # Record downloader.download() calls
   # If the name isn't explicitly provided, then it is extracted from the
   # url.
   filename = self.dl_fnames.get(url, os.path.basename(url))
   # Save the file in the tmp_dir
   path = os.path.join(tmpdir_path, filename)
   self.fs.add_file(path)
   dl_result = downloader.DownloadResult(
       path=utils.as_path(path),
       url_info=self.dl_results[url],
   )
   return promise.Promise.resolve(dl_result)
def _read_indices(path):
    """Returns (files_name, list of index in each file).

  Args:
    path: path to index, omitting suffix.
  """
    paths = sorted(tf.io.gfile.glob('%s-*-of-*_index.json' % path))
    all_indices = []
    for path in paths:
        json_str = utils.as_path(path).read_text()
        # parse it back into a proto.
        shard_index = json.loads(json_str)
        all_indices.append(list(shard_index['index']))
    return [os.path.basename(p) for p in paths], all_indices
    def __init__(self, path: utils.PathLike):
        """Contructor.

    Args:
      path: Remote location of the package index (file containing the list of
        dataset packages)
    """
        super().__init__()
        self._remote_path: utils.ReadOnlyPath = utils.as_path(path)
        self._cached_path: utils.ReadOnlyPath = (
            cache.cache_path() / 'community-datasets-list.jsonl')

        # Pre-load the index from the cache
        if self._cached_path.exists():
            self._refresh_from_content(self._cached_path.read_text())
Example #23
 def _sync_file_copy(
     self,
     filepath: str,
     destination_path: str,
 ) -> DownloadResult:
     """Downloads the file through `tf.io.gfile` API."""
     filename = os.path.basename(filepath)
     out_path = os.path.join(destination_path, filename)
     tf.io.gfile.copy(filepath, out_path)
     url_info = checksums_lib.compute_url_info(
         out_path, checksum_cls=self._checksumer_cls)
     self._pbar_dl_size.update_total(url_info.size)
     self._pbar_dl_size.update(url_info.size)
     self._pbar_url.update(1)
     return DownloadResult(path=utils.as_path(out_path), url_info=url_info)
def _get_default_config_name(builder_dir: str, name: str) -> Optional[str]:
    """Returns the default config of the given dataset, None if not found."""
    # Search for the DatasetBuilder generation code
    try:
        cls = registered.imported_builder_cls(name)
        cls = typing.cast(Type[dataset_builder.DatasetBuilder], cls)
    except registered.DatasetNotFoundError:
        pass
    else:
        # If code found, return the default config
        if cls.BUILDER_CONFIGS:
            return cls.BUILDER_CONFIGS[0].name

    # Otherwise, try to load default config from common metadata
    return dataset_builder.load_default_config_name(utils.as_path(builder_dir))
def _download_and_cache(package: DatasetPackage) -> _InstalledPackage:
    """Downloads and installs locally the dataset source.

  This function installs the dataset package in:
  `<module_path>/<namespace>/<ds_name>/<hash>/...`.

  Args:
    package: Package to install.

  Returns:
    installed_dataset: The installed dataset package.
  """
    tmp_dir = utils.as_path(tempfile.mkdtemp())
    try:
        # Download the package in a tmp directory
        dataset_sources_lib.download_from_source(
            package.source,
            tmp_dir,
        )

        # Compute the package hash (to install the dataset in a unique dir)
        package_hash = _compute_dir_hash(tmp_dir)

        # Add package metadata
        installed_package = _InstalledPackage(
            package=package,
            instalation_date=datetime.datetime.now(),
            hash=package_hash,
        )
        package_metadata = json.dumps(installed_package.to_json())
        (tmp_dir / _METADATA_FILENAME).write_text(package_metadata)

        # Rename the package to its final destination
        installation_path = installed_package.installation_path
        if installation_path.exists():  # Package already exists (with same hash)
            # In the future, we should be smarter to allow overwrite.
            raise ValueError(
                f'Package {package} already installed in {installation_path}.')
        installation_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_dir.rename(installation_path)
    finally:
        # Cleanup the tmp directory if it still exists.
        if tmp_dir.exists():
            tmp_dir.rmtree()

    return installed_package
Example #26
def test_write_metadata(
    tmp_path: pathlib.Path,
    file_format,
):
  tmp_path = utils.as_path(tmp_path)

  src_builder = testing.DummyDataset(
      data_dir=tmp_path / 'origin',
      file_format=file_format,
  )
  src_builder.download_and_prepare()

  dst_dir = tmp_path / 'copy'
  dst_dir.mkdir()

  # Copy all the tfrecord files, but not the dataset info
  for f in src_builder.data_path.iterdir():
    if naming.FilenameInfo.is_valid(f.name):
      f.copy(dst_dir / f.name)

  metadata_path = dst_dir / 'dataset_info.json'

  if file_format is None:
    split_infos = list(src_builder.info.splits.values())
  else:
    split_infos = None  # Auto-compute split infos

  assert not metadata_path.exists()
  write_metadata_utils.write_metadata(
      data_dir=dst_dir,
      features=src_builder.info.features,
      split_infos=split_infos,
      description='my test description.')
  assert metadata_path.exists()

  # After the metadata is written, the builder can be restored from the directory
  builder = read_only_builder.builder_from_directory(dst_dir)
  assert builder.name == 'dummy_dataset'
  assert builder.version == '1.0.0'
  assert set(builder.info.splits) == {'train'}
  assert builder.info.splits['train'].num_examples == 3
  assert builder.info.description == 'my test description.'

  # Values are the same
  src_ds = src_builder.as_dataset(split='train')
  ds = builder.as_dataset(split='train')
  assert list(src_ds.as_numpy_iterator()) == list(ds.as_numpy_iterator())
Example #27
 def _write_final_shard(
     self,
     shardid_examples: Tuple[str, List[Tuple[
         int, List[type_utils.KeySerializedExample]]]],
 ):
     """Write all examples from multiple buckets into the same shard."""
     shard_path, examples_by_bucket = shardid_examples
     examples = itertools.chain(
         *[ex[1] for ex in sorted(examples_by_bucket)])
     # Write to a tmp file to avoid a potential race condition if
     # `--xxxxx_enable_backups` is set and multiple workers try to write to
     # the same file.
     with utils.incomplete_file(utils.as_path(shard_path)) as tmp_path:
         record_keys = _write_examples(tmp_path, examples,
                                       self._file_format)
     # If there are no record_keys, skip creating index files.
     if not record_keys:
         return
     _write_index_file(_get_index_path(shard_path), list(record_keys))
Example #28
 def _sync_file_copy(
     self,
     filepath: str,
     destination_path: str,
 ) -> DownloadResult:
     filename = os.path.basename(filepath)
     out_path = os.path.join(destination_path, filename)
     tf.io.gfile.copy(filepath, out_path)
     hexdigest, size = utils.read_checksum_digest(
         out_path, checksum_cls=self._checksumer_cls)
     return DownloadResult(
         path=utils.as_path(out_path),
         url_info=checksums_lib.UrlInfo(
             checksum=hexdigest,
             size=size,
             filename=filename,
         ),
     )
Example #29
def _get_default_config_name(builder_dir: str, name: str) -> Optional[str]:
    """Returns the default config of the given dataset, None if not found."""
    # Search for the DatasetBuilder generation code
    try:
        # Warning: The registered dataset may not match the files (e.g. if
        # the imported dataset has the same name as the generated files while
        # being 2 different datasets)
        cls = registered.imported_builder_cls(name)
        cls = typing.cast(Type[dataset_builder.DatasetBuilder], cls)
    except (registered.DatasetNotFoundError, PermissionError):
        pass
    else:
        # If code found, return the default config
        if cls.BUILDER_CONFIGS:
            return cls.BUILDER_CONFIGS[0].name

    # Otherwise, try to load default config from common metadata
    return dataset_builder.load_default_config_name(utils.as_path(builder_dir))
Example #30
def compute_split_info(
    *,
    data_dir: utils.PathLike,
    out_dir: Optional[utils.PathLike] = None,
) -> List[split_lib.SplitInfo]:
    """Compute the split info on the given files.

  Compute the split info (num shards, num examples,...) metadata required
  by `tfds.folder_dataset.write_metadata`.

  See documentation for usage:
  https://www.tensorflow.org/datasets/external_tfrecord

  Args:
    data_dir: Directory containing the `.tfrecord` files (or similar format)
    out_dir: Output directory in which to save the metadata. It should be
      accessible from the Apache Beam workers. If not set, Apache Beam won't be
      used (only available with some file formats).

  Returns:
    split_infos: The list of `tfds.core.SplitInfo`.
  """
    data_dir = utils.as_path(data_dir)

    # Auto-detect the splits from the files
    split_files = _extract_split_files(data_dir)
    print('Auto-detected splits:')
    for split_name, file_infos in split_files.items():
        print(f' * {split_name}: {file_infos[0].num_shards} shards')

    # Launch the beam pipeline to compute the split info
    if out_dir is not None:
        split_infos = _compute_split_statistics_beam(
            split_files=split_files,
            data_dir=data_dir,
            out_dir=out_dir,
        )
    else:
        raise NotImplementedError('compute_split_info requires the out_dir kwarg.')

    print('Computed split infos: ')
    pprint.pprint(split_infos)

    return split_infos
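A hedged usage sketch following the workflow referenced in the docstring: compute the split infos for externally written `.tfrecord` files, then pass them to `tfds.folder_dataset.write_metadata` (paths and features are illustrative):

import tensorflow_datasets as tfds

split_infos = compute_split_info(
    data_dir='/path/to/my_dataset/tfrecords',
    out_dir='/path/to/my_dataset/tfrecords',
)
tfds.folder_dataset.write_metadata(
    data_dir='/path/to/my_dataset/tfrecords',
    features=tfds.features.FeaturesDict({'image': tfds.features.Image()}),
    split_infos=split_infos,
)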