Example #1
@classmethod
def from_json(cls, value: utils.JsonValue) -> 'DatasetSource':
    """Imports from JSON."""
    if isinstance(value, str):  # Single-file dataset ('.../my_dataset.py')
        path = epath.Path(value)
        return cls(root_path=path.parent, filenames=[path.name])
    elif isinstance(value, dict):  # Multi-file dataset
        return cls(
            root_path=epath.Path(value['root_path']),
            filenames=value['filenames'],
        )
    else:
        raise ValueError(f'Invalid input: {value}')
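For illustration, a minimal usage sketch; `DatasetSource` (a class with `root_path` and `filenames` fields) and the paths below are assumptions, not part of the original snippet:

# Hypothetical usage; the paths are illustrative only.
single = DatasetSource.from_json('/datasets/my_dataset.py')
assert single.filenames == ['my_dataset.py']
multi = DatasetSource.from_json({
    'root_path': '/datasets/my_dataset',
    'filenames': ['my_dataset.py', 'checksums.tsv'],
})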
Example #2
def download_gcs_dataset(dataset_name: epath.PathLike,
                         local_dataset_dir: epath.PathLike,
                         max_simultaneous_downloads: int = 25,
                         root_dir: Optional[str] = None):
    """Downloads prepared GCS dataset to local dataset directory."""
    if root_dir:
        gcs_folder = epath.Path(root_dir) / dataset_name
    else:
        gcs_folder = epath.Path(GCS_ROOT_DIR) / GCS_DATASETS_DIR / dataset_name

    download_gcs_folder(gcs_folder=gcs_folder,
                        local_folder=local_dataset_dir,
                        max_simultaneous_downloads=max_simultaneous_downloads)
Example #3
def __init__(self, name, url=None, content=None):
    url = url or f'http://foo-bar.ch/{name}'
    content = content or f'content of {name}'
    self.url = url
    self.url_info = checksums_lib.UrlInfo(
        size=len(content),
        checksum=_sha256(content),
        filename=name,
    )
    self.file_name = resource_lib.get_dl_fname(url, self.url_info.checksum)
    self.file_path = epath.Path(f'/dl_dir/{self.file_name}')
    self.url_name = resource_lib.get_dl_fname(url, _sha256(url))
    self.url_path = epath.Path(f'/dl_dir/{self.url_name}')
Example #4
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    puddles = arenas.get_arena(FLAGS.arena_name)
    pw = puddle_world.PuddleWorld(puddles=puddles,
                                  goal_position=geometry.Point((1.0, 1.0)))

    dpw = pw_utils.DiscretizedPuddleWorld(pw, FLAGS.num_bins)
    num_states = dpw.num_states

    # We want to avoid rounding errors when calculating shards,
    # so we use fractions.
    percent_work_to_complete = fractions.Fraction(1, FLAGS.num_shards)
    start_idx = int(FLAGS.shard_idx * percent_work_to_complete * num_states)
    end_idx = int(
        (FLAGS.shard_idx + 1) * percent_work_to_complete * num_states)
    logging.info('start idx: %d, end idx: %d', start_idx, end_idx)

    result_matrices = []

    # TODO(joshgreaves): utils has helpful functions for generating rollouts.
    for start_state in range(start_idx, end_idx):
        logging.info('Starting iteration %d', start_state)
        result_matrix = np.zeros(
            (FLAGS.num_rollouts_per_start_state, num_states), dtype=np.float32)

        for i in range(FLAGS.num_rollouts_per_start_state):
            current_gamma = 1.0

            s = dpw.sample_state_in_bin(start_state)

            for _ in range(FLAGS.rollout_length):
                action = random.randrange(pw_utils.NUM_ACTIONS)
                transition = dpw.transition(s, action)
                s = transition.next_state

                result_matrix[i, s.bin_idx] += current_gamma
                current_gamma *= FLAGS.gamma

        result_matrices.append(np.mean(result_matrix, axis=0))

    # Before saving, make sure the path exists.
    output_dir = epath.Path(FLAGS.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if FLAGS.shard_idx == 0:
        # Write some metadata to make analysis easier at the end.
        metadata = {
            'arena_name': FLAGS.arena_name,
            'num_bins': FLAGS.num_bins,
            'num_shards': FLAGS.num_shards,
        }
        json_file_path = output_dir / 'metadata.json'
        with json_file_path.open('w') as f:
            json.dump(metadata, f)

    file_path = output_dir / f'sr_{start_idx}-{end_idx}.np'
    with file_path.open('wb') as f:
        np.save(f, np.stack(result_matrices, axis=0))
Example #5
    def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
        try:
            for future in self._commit_futures:
                for f in future:
                    f.result()

            current_process = jax.process_index()
            logging.info(
                'Commit to storage layer has completed by process: %s',
                current_process)

            # All processes wait at the barrier; once every process has
            # arrived, the barrier is satisfied. Otherwise, the wait times out.
            self._client.wait_at_barrier(self._final_ckpt_dir,
                                         self._timeout_in_ms)
            logging.info('Finished waiting at barrier for process %s',
                         current_process)

            if current_process == 0:
                logging.info('Renaming %s to %s', temp_checkpoint_dir,
                             final_checkpoint_dir)
                epath.Path(temp_checkpoint_dir).rename(final_checkpoint_dir)
                logging.info('Finished saving GDA checkpoint to `%s`.',
                             final_checkpoint_dir)
                self._client.key_value_set(_get_key(self._final_ckpt_dir),
                                           _CHECKPOINT_SUCCESS)
        except Exception as e:
            self._exception = e
Example #6
def download_kaggle_data(
    competition_or_dataset: str,
    download_dir: epath.PathLike,
) -> epath.Path:
    """Downloads the kaggle data to the output_dir.

  Args:
    competition_or_dataset: Name of the kaggle competition/dataset.
    download_dir: Path to the TFDS downloads dir.

  Returns:
    Path to the dir where the kaggle data was downloaded.
  """
    kaggle_dir = _kaggle_dir_name(competition_or_dataset)
    download_path = epath.Path(download_dir) / kaggle_dir
    # If the dataset has already been downloaded, return the path to it.
    if download_path.is_dir():
        logging.info(
            'Dataset %s already downloaded: reusing %s.',
            competition_or_dataset,
            download_path,
        )
        return download_path
    # Otherwise, download the dataset.
    with utils.incomplete_dir(download_path) as tmp_data_dir:
        logging.info(
            'Downloading %s into %s...',
            competition_or_dataset,
            tmp_data_dir,
        )
        _download_competition_or_dataset(competition_or_dataset, tmp_data_dir)
    return download_path
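A hedged usage sketch; the competition name and download directory below are illustrative, not from the original code:

# Hypothetical call; 'titanic' stands in for any Kaggle competition/dataset name.
data_path = download_kaggle_data('titanic', '/tmp/tfds_downloads')
# data_path is an epath.Path pointing inside /tmp/tfds_downloads.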
Example #7
File: naming.py Project: suvarnak/datasets
def filepattern_for_dataset_split(
    *,
    dataset_name: str,
    split: str,
    data_dir: str,
    filetype_suffix: Optional[str] = None,
    num_shards: Optional[int] = None,
) -> str:
  """Returns the file pattern for the given dataset.

  TODO(tfds): remove this once ShardedFileTemplate is used directly

  Args:
    dataset_name: Name of the dataset
    split: Name of the requested split
    data_dir: The base folder that contains the dataset.
    filetype_suffix: Optional suffix, e.g. tfrecord
    num_shards: Optional argument. If specified, will return file@num_shards
      notation, otherwise file*.
  """
  template = ShardedFileTemplate(
      data_dir=epath.Path(data_dir),
      dataset_name=dataset_name,
      split=split,
      filetype_suffix=filetype_suffix)
  return os.fspath(template.sharded_filepaths_pattern(num_shards=num_shards))
Example #8
def _compute_split_statistics_beam(
    *,
    split_files: _SplitFilesDict,
    out_dir: epath.PathLike,
    filename_template: naming.ShardedFileTemplate,
) -> List[split_lib.SplitInfo]:
    """Compute statistics."""
    out_dir = epath.Path(out_dir)

    assert out_dir.exists(), f'{out_dir} does not exist'

    beam = lazy_imports_lib.lazy_imports.apache_beam

    # Launch the beam pipeline computation
    runner = None
    # Create the global pipeline object common for all splits
    # Disable type hints as they don't work with typing.Protocol
    beam_options = beam.options.pipeline_options.PipelineOptions()
    beam_options.view_as(
        beam.options.pipeline_options.TypeOptions).pipeline_type_check = False
    with beam.Pipeline(runner=runner, options=beam_options) as pipeline:
        for split_name, file_infos in split_files.items():
            _ = pipeline | split_name >> _process_split(  # pylint: disable=no-value-for-parameter
                filename_template=filename_template,
                out_dir=out_dir,
                file_infos=file_infos,  # pytype: disable=missing-parameter
            )

    # After the files have been computed
    return [
        _split_info_from_path(
            filename_template.replace(data_dir=out_dir, split=split))
        for split in split_files
    ]
Example #9
    def test_manually_downloaded(self):
        """One file is manually downloaded, one not."""
        a, b = [Artifact(i) for i in 'ab']

        a_manual_path = '/manual_dir/a'
        # File a is manually downloaded
        self.fs.add_file(a_manual_path)
        self.fs.add_file(b.file_path)

        self.dl_results[b.url] = b.url_info
        manager = self._get_manager(
            register_checksums=False,  # Register with manual download not supported
            url_infos={art.url: art.url_info for art in (a, b)},
        )
        downloads = manager.download({
            'manual': a.url,
            'download': b.url,
        })
        expected = {
            'manual': epath.Path(a_manual_path),
            'download': b.file_path,
        }
        self.assertEqual(downloads, expected)
Example #10
def test_sharded_file_template_no_template_incomplete():
    builder_dir = epath.Path('/my/path')
    template_without_split = naming.ShardedFileTemplate(
        data_dir=builder_dir,
        dataset_name='imagenet',
        filetype_suffix='riegeli')
    with pytest.raises(KeyError):
        template_without_split.sharded_filepath(shard_index=12, num_shards=100)
Example #11
    def __init__(self, path: epath.PathLike):
        """Contructor.

    Args:
      path: Path to the register files containing the mapping namespace ->
        data_dir
    """
        self._path: epath.Path = epath.Path(path)
Example #12
def test_incomplete_file(tmp_path: pathlib.Path):
    tmp_path = epath.Path(tmp_path)
    filepath = tmp_path / 'test.txt'
    with py_utils.incomplete_file(filepath) as tmp_filepath:
        tmp_filepath.write_text('content')
        assert not filepath.exists()
    assert filepath.read_text() == 'content'
    assert not tmp_filepath.exists()  # Tmp file is deleted
Example #13
File: naming.py Project: suvarnak/datasets
def __post_init__(self):
  self.data_dir = epath.Path(self.data_dir)
  if self.split is not None and not self.split:
    raise ValueError(f'Split must be a non-empty string: {self}')
  if self.filetype_suffix is not None and not self.filetype_suffix:
    raise ValueError(f'Filetype suffix must be a non-empty string: {self}')
  if not self.template:
    self.template = DEFAULT_FILENAME_TEMPLATE
Example #14
  def __init__(self, path: epath.PathLike):
    """Contructor.

    Args:
      path: Path to the register files containing the list of dataset sources,
        forwarded to `_PackageIndex`
    """
    self._path = epath.Path(path)
Example #15
def test_sharded_file_template_no_template():
    builder_dir = epath.Path('/my/path')
    template = naming.ShardedFileTemplate(data_dir=builder_dir,
                                          dataset_name='imagenet',
                                          filetype_suffix='riegeli',
                                          split='test')
    assert os.fspath(template.sharded_filepath(
        shard_index=12,
        num_shards=100)) == '/my/path/imagenet-test.riegeli-00012-of-00100'
Example #16
def main(args: argparse.Namespace):
    catalog_dir = args.catalog_dir or os.path.join(
        tfds.core.utils.tfds_write_path(),
        'docs',
        'community_catalog',
    )

    catalog_dir = epath.Path(catalog_dir)
    build_and_save_community_catalog(catalog_dir=catalog_dir)
Example #17
def _default_cache_dir() -> epath.Path:
  """Returns the default cache directory."""
  if 'TFDS_CACHE_DIR' in os.environ:
    path = os.environ['TFDS_CACHE_DIR']
  elif 'XDG_CACHE_HOME' in os.environ:
    path = os.path.join(os.environ['XDG_CACHE_HOME'], 'tensorflow_datasets')
  else:
    path = os.path.join('~', '.cache', 'tensorflow_datasets')
  return epath.Path(path).expanduser()
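A sketch of the resulting lookup order; the environment values here are illustrative assumptions:

import os
os.environ['TFDS_CACHE_DIR'] = '/tmp/my_tfds_cache'  # Highest-priority override
assert str(_default_cache_dir()) == '/tmp/my_tfds_cache'
del os.environ['TFDS_CACHE_DIR']  # Falls back to XDG_CACHE_HOME, then ~/.cache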
Example #18
def main(args: argparse.Namespace):
    catalog_dir = args.catalog_dir or os.path.join(
        tfds.core.utils.tfds_write_path(),
        'docs',
        'community_catalog',
    )

    options = _Options(catalog_dir=epath.Path(catalog_dir),
                       local_cache=args.local_cache or None)
    build_and_save_community_catalog(options=options)
Example #19
def test_get_default_config_name_permission_error():
  # Raise populated error message in case of PermissionError
  builder_dir = epath.Path('builder/dir')
  error_msg = f'Permission error when accessing: {builder_dir}'
  with _assert_raises(error_msg):
    with mock.patch.object(
        registered, 'imported_builder_cls', side_effect=PermissionError):
      actual_default_config_name = read_only_builder._get_default_config_name(
          builder_dir=builder_dir, name='name')
      assert actual_default_config_name is None
Example #20
    def dummy_data(cls) -> epath.Path:  # pylint: disable=no-self-argument
        """Path to the `dummy_data/` directory."""
        if cls is DatasetBuilderTestCase:  # Required for build_api_docs
            return None  # pytype: disable=bad-return-type

        dummy_data_expected = cls.DATASET_CLASS.code_path.parent / "dummy_data"
        fake_example_dir = epath.Path(test_utils.fake_examples_dir())
        if cls.EXAMPLE_DIR is not None:
            dummy_data_found = epath.Path(cls.EXAMPLE_DIR)
            dummy_data_expected = dummy_data_found  # Dir to display in the error
        elif dummy_data_expected.exists():
            dummy_data_found = dummy_data_expected
        else:
            dummy_data_found = fake_example_dir / cls.DATASET_CLASS.name

        if not dummy_data_found.exists():
            err_msg = f"Dummy data not found in: {dummy_data_expected}"
            raise ValueError(err_msg)
        return dummy_data_found
Example #21
def gcs_path(*relative_path: epath.PathLike) -> epath.Path:
    """Returns the GCS URI path.

  Args:
    *relative_path: Optional relative path components within the bucket.

  Returns:
    path: The GCS uri.
  """
    return epath.Path(GCS_ROOT_DIR).joinpath(*relative_path)
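For illustration (GCS_ROOT_DIR is defined in the surrounding module; the value below is an assumption):

# If GCS_ROOT_DIR were 'gs://tfds-data', this would return
# epath.Path('gs://tfds-data/datasets/mnist').
mnist_path = gcs_path('datasets', 'mnist')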
Example #22
def test_sharded_file_template_shard_index():
  builder_dir = epath.Path('/my/path')
  template = naming.ShardedFileTemplate(
      template='data/mnist-train.tfrecord-{SHARD_INDEX}', data_dir=builder_dir)
  assert os.fspath(template.sharded_filepath(
      shard_index=12,
      num_shards=100)) == '/my/path/data/mnist-train.tfrecord-00012'
  assert os.fspath(template.sharded_filepaths_pattern()
                  ) == '/my/path/data/mnist-train.tfrecord*'
  assert os.fspath(template.sharded_filepaths_pattern(
      num_shards=100)) == '/my/path/data/mnist-train.tfrecord@100'
Example #23
def test_sharded_file_template_template_and_properties():
  builder_dir = epath.Path('/my/path')
  template = naming.ShardedFileTemplate(
      template='data/mnist-{SPLIT}.{FILEFORMAT}-{SHARD_INDEX}',
      data_dir=builder_dir,
      # dataset_name is ignored because the template doesn't have {DATASET}
      dataset_name='imagenet',
      filetype_suffix='riegeli',
      split='test')
  assert os.fspath(template.sharded_filepath(
      shard_index=12,
      num_shards=100)) == '/my/path/data/mnist-test.riegeli-00012'
Example #24
def test_sharded_file_template_sharded_filepath_shard_x_of_y_more_digits():
  builder_dir = epath.Path('/my/path')
  template = naming.ShardedFileTemplate(
      template='data/{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}',
      data_dir=builder_dir,
      dataset_name='mnist',
      filetype_suffix='tfrecord',
      split='train',
  )
  assert os.fspath(
      template.sharded_filepath(shard_index=12, num_shards=1234567)
  ) == '/my/path/data/mnist-train.tfrecord-0000012-of-1234567'
Example #25
def mock_cache_path(new_cache_dir: epath.PathLike) -> Iterator[None]:
  """Mock which overwrite the cache path."""
  new_dir = epath.Path(new_cache_dir)

  # Use `__wrapped__` to access the original function wrapped inside
  # `functools.lru_cache`
  new_cache_path = utils.memoize()(cache.cache_path.__wrapped__)
  new_module_path = utils.memoize()(cache.module_path.__wrapped__)
  with mock.patch.object(cache, '_default_cache_dir', return_value=new_dir), \
       mock.patch.object(cache, 'cache_path', new_cache_path), \
       mock.patch.object(cache, 'module_path', new_module_path):
    yield
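A usage sketch, assuming the original module wraps this function with @contextlib.contextmanager (implied by the yield and the Iterator[None] return type):

# Hypothetical test: redirect the TFDS cache into a pytest tmp dir.
def test_with_mocked_cache(tmp_path):
    with mock_cache_path(tmp_path):
        ...  # cache.cache_path() and cache.module_path() now resolve under tmp_path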
Example #26
def create_synthetic_experiment(config):
  """Creates a synthetic experiment with finite matrices."""
  key = jax.random.PRNGKey(config.seed)
  phi_key, psi_key, key = jax.random.split(key, 3)

  Phi = jax.random.normal(phi_key, (config.S, config.d), dtype=jnp.float32)

  if config.use_mnist:
    Psi = get_mnist_data()
  else:
    Psi = jax.random.normal(psi_key, (config.S, config.T), dtype=jnp.float32)
    if config.rescale_psi == 'linear':
      Psi = generate_psi_linear(Psi)
    elif config.rescale_psi == 'exp':
      Psi = generate_psi_exp(Psi)

  sample_states = functools.partial(
      sample_discrete_states,
      num_states=config.S,
      sample_with_replacement=config.sample_with_replacement)
  eval_states = jnp.arange(config.S)

  compute_phi = lambda phi, states: phi[states, :]
  params = Phi

  def compute_psi(states, tasks=None):
    if tasks is None:
      return Psi[states, :]
    return Psi[states, tasks]

  if config.svd_path:
    logging.info('Loading SVD from %s', config.svd_path)
    with epath.Path(config.svd_path).open('rb') as f:
      left_svd = np.load(f)
      optimal_subspace = left_svd[:, :config.d]
  else:
    Psi = compute_psi(eval_states, None)
    optimal_subspace = compute_optimal_subspace(Psi, config.d)

  return SyntheticExperiment(
      compute_phi=compute_phi,
      compute_psi=compute_psi,
      sample_states=sample_states,
      eval_states=eval_states,
      optimal_subspace=optimal_subspace,
      params=params,
      key=key
  )
Example #27
def test_sharded_file_template_sharded_filepath_shard_x_of_y():
  builder_dir = epath.Path('/my/path')
  template_explicit = naming.ShardedFileTemplate(
      template='data/mnist-train.tfrecord-{SHARD_INDEX}-of-{NUM_SHARDS}',
      data_dir=builder_dir)
  assert os.fspath(
      template_explicit.sharded_filepath(shard_index=12, num_shards=100)
  ) == '/my/path/data/mnist-train.tfrecord-00012-of-00100'

  template = naming.ShardedFileTemplate(
      template='data/mnist-train.tfrecord-{SHARD_X_OF_Y}', data_dir=builder_dir)
  assert os.fspath(template.sharded_filepath(
      shard_index=12,
      num_shards=100)) == '/my/path/data/mnist-train.tfrecord-00012-of-00100'
Example #28
def _checksum_paths() -> Dict[str, epath.Path]:
    """Returns dict {'dataset_name': 'path/to/checksums/file'}."""
    dataset2path = {}
    for dir_path in _CHECKSUM_DIRS:
        if isinstance(dir_path, str):
            dir_path = epath.Path(dir_path)
        if not dir_path.exists():
            continue  # Skip missing checksum directories instead of failing on iterdir()
        for file_path in dir_path.iterdir():
            if not file_path.name.endswith(_CHECKSUM_SUFFIX):
                continue
            dataset_name = file_path.name[:-len(_CHECKSUM_SUFFIX)]
            dataset2path[dataset_name] = file_path
    return dataset2path
Example #29
def _read_indices(path):
  """Returns (files_name, list of index in each file).

  Args:
    path: path to index, omitting suffix.
  """
  paths = sorted(tf.io.gfile.glob('%s-*-of-*_index.json' % path))
  all_indices = []
  for path in paths:
    json_str = epath.Path(path).read_text()
    # Parse the serialized index back into a dict.
    shard_index = json.loads(json_str)
    all_indices.append(list(shard_index['index']))
  return [os.path.basename(p) for p in paths], all_indices
Example #30
    def __init__(self, path: epath.PathLike):
        """Contructor.

    Args:
      path: Remote location of the package index (file containing the list of
        dataset packages)
    """
        super().__init__()
        self._remote_path: epath.Path = epath.Path(path)
        self._cached_path: epath.Path = (cache.cache_path() /
                                         'community-datasets-list.jsonl')

        # Pre-load the index from the cache
        if self._cached_path.exists():
            self._refresh_from_content(self._cached_path.read_text())