def test_beam(tmp_path: pathlib.Path):
  """Checks that `maybe_beam_pipeline` behaves like `beam.Pipeline()`."""
  builder = testing.DummyMnist()
  split_builder = split_builder_lib.SplitBuilder(
      split_dict=builder.info.splits,
      features=builder.info.features,
      beam_options=None,
      beam_runner=None,
      max_examples_per_split=None,
  )
  out_file = tmp_path / 'out.txt'
  with split_builder.maybe_beam_pipeline() as pipeline_proxy:
    transform = (
        beam.Create(range(9))
        | beam.Map(lambda x: x * 10)
        | beam.Map(_inc_placeholder_counter)
        | beam.io.WriteToText(os.fspath(out_file), shard_name_template='')
    )
    _ = split_builder.beam_pipeline | transform
  # The pipeline only runs when the context manager exits, so the result
  # (and its metrics) can be inspected afterwards.
  result = pipeline_proxy.result
  metrics_filter = beam.metrics.MetricsFilter().with_namespaces(
      ['some_namespace'])
  counters = result.metrics().query(metrics_filter)['counters']
  assert counters[0].key.metric.name == 'some_counter'
  assert counters[0].committed == 9
  expected = '\n'.join(str(x * 10) for x in range(9)) + '\n'
  assert out_file.read_text() == expected
def _download_and_prepare(
    self,
    dl_manager: download.DownloadManager,
    download_config: download.DownloadConfig,
) -> None:
  """Generate all splits and returns the computed split infos."""
  split_builder = split_builder_lib.SplitBuilder(
      split_dict=self.info.splits,
      features=self.info.features,
      max_examples_per_split=download_config.max_examples_per_split,
      beam_options=download_config.beam_options,
      beam_runner=download_config.beam_runner,
  )
  # Wrap the generation inside a context manager: when `beam` is used during
  # generation (a pipeline gets created), this is equivalent to
  # `with beam.Pipeline()`; otherwise it is a no-op.
  with split_builder.maybe_beam_pipeline():
    # Forward the pipeline to `self._split_generators` only when its
    # signature declares a `pipeline` kwarg.
    signature = inspect.signature(self._split_generators)
    if "pipeline" in signature.parameters:
      pipeline_kwargs = dict(pipeline=split_builder.beam_pipeline)
    else:
      pipeline_kwargs = {}
    split_generators = self._split_generators(  # pylint: disable=unexpected-keyword-arg
        dl_manager, **pipeline_kwargs
    )
    # TODO(tfds): Could be removed once all datasets are migrated.
    # https://github.com/tensorflow/datasets/issues/2537
    # Legacy mode (eventually convert list[SplitGeneratorLegacy] -> dict)
    split_generators = split_builder.normalize_legacy_split_generators(
        split_generators=split_generators,
        generator_fn=self._generate_examples,
        is_beam=isinstance(self, BeamBasedBuilder),
    )
    # Ensure `all` isn't used as key.
    _check_split_names(split_generators.keys())
    # Kick off generation for every split.
    split_info_futures = [
        split_builder.submit_split_generation(  # pylint: disable=g-complex-comprehension
            split_name=name,
            generator=gen,
            path=self._data_path / f"{self.name}-{name}.tfrecord",
        )
        for name, gen in utils.tqdm(
            split_generators.items(), unit=" splits", leave=False
        )
    ]
  # Finalize the splits (after apache beam completed, if it was used).
  split_infos = [future.result() for future in split_info_futures]
  # Update the info object with the splits.
  # TODO(tfds): Should improve the API.
  split_dict = splits_lib.SplitDict(dataset_name=self.name)
  for split_info in split_infos:
    split_dict.add(split_info)
  self.info.update_splits_if_different(split_dict)
def test_beam(tmp_path: pathlib.Path):
  """Checks that `maybe_beam_pipeline` behaves like `beam.Pipeline()`."""
  builder = testing.DummyMnist()
  split_builder = split_builder_lib.SplitBuilder(
      split_dict=builder.info.splits,
      features=builder.info.features,
      beam_options=None,
      beam_runner=None,
      max_examples_per_split=None,
  )
  out_file = tmp_path / 'out.txt'
  with split_builder.maybe_beam_pipeline():
    transform = (
        beam.Create(range(9))
        | beam.Map(lambda x: x * 10)
        | beam.io.WriteToText(os.fspath(out_file), shard_name_template='')
    )
    _ = split_builder.beam_pipeline | transform
  # The pipeline runs on context exit, so the output file exists here.
  expected = '\n'.join(str(x * 10) for x in range(9)) + '\n'
  assert out_file.read_text() == expected
def _download_and_prepare(
    self,
    dl_manager: download.DownloadManager,
    download_config: download.DownloadConfig,
) -> None:
  """Generate all splits and returns the computed split infos.

  Args:
    dl_manager: Manager used by the split generators to download/extract
      source data.
    download_config: Generation options (max examples, beam options/runner).
  """
  split_builder = split_builder_lib.SplitBuilder(
      split_dict=self.info.splits,
      features=self.info.features,
      max_examples_per_split=download_config.max_examples_per_split,
      beam_options=download_config.beam_options,
      beam_runner=download_config.beam_runner,
      file_format=self._file_format,
  )
  # Wrap the generation inside a context manager.
  # If `beam` is used during generation (when a pipeline gets created),
  # the context manager is equivalent to `with beam.Pipeline()`.
  # Otherwise, this is a no-op.
  # By auto-detecting Beam, the user only has to change `_generate_examples`
  # to go from non-beam to beam dataset:
  # https://www.tensorflow.org/datasets/beam_datasets#instructions
  with split_builder.maybe_beam_pipeline():
    # If the signature has a `pipeline` kwargs, create the pipeline now and
    # forward it to `self._split_generators`
    # We add this magic because the pipeline kwargs is only used by c4 and
    # we do not want to make the API more verbose for a single advanced case.
    signature = inspect.signature(self._split_generators)
    if "pipeline" in signature.parameters.keys():
      optional_pipeline_kwargs = dict(pipeline=split_builder.beam_pipeline)
    else:
      optional_pipeline_kwargs = {}
    split_generators = self._split_generators(  # pylint: disable=unexpected-keyword-arg
        dl_manager, **optional_pipeline_kwargs
    )
    # TODO(tfds): Could be removed once all datasets are migrated.
    # https://github.com/tensorflow/datasets/issues/2537
    # Legacy mode (eventually convert list[SplitGeneratorLegacy] -> dict)
    split_generators = split_builder.normalize_legacy_split_generators(
        split_generators=split_generators,
        generator_fn=self._generate_examples,
        is_beam=isinstance(self, BeamBasedBuilder),
    )
    # Ensure `all` isn't used as key.
    _check_split_names(split_generators.keys())
    # Writer fail if the number of example yield is `0`, so we return here.
    if download_config.max_examples_per_split == 0:
      return
    # Start generating data for all splits
    # The record-file suffix (e.g. tfrecord) comes from the adapter
    # registered for the builder's file format.
    path_suffix = file_adapters.ADAPTER_FOR_FORMAT[
        self._file_format].FILE_SUFFIX
    split_info_futures = [
        split_builder.submit_split_generation(  # pylint: disable=g-complex-comprehension
            split_name=split_name,
            generator=generator,
            path=self.data_path / f"{self.name}-{split_name}.{path_suffix}",
        )
        for split_name, generator in utils.tqdm(
            split_generators.items(), unit=" splits", leave=False
        )
    ]
  # Finalize the splits (after apache beam completed, if it was used)
  split_infos = [future.result() for future in split_info_futures]
  # Update the info object with the splits.
  split_dict = splits_lib.SplitDict(split_infos, dataset_name=self.name)
  self.info.set_splits(split_dict)