def _prepare_split(self, split_generator, pipeline):
  beam = lazy_imports_lib.lazy_imports.apache_beam

  if not tf.io.gfile.exists(self._data_dir):
    tf.io.gfile.makedirs(self._data_dir)

  split_name = split_generator.split_info.name
  output_prefix = naming.filename_prefix_for_split(self.name, split_name)
  output_prefix = os.path.join(self._data_dir, output_prefix)

  # To write examples to disk:
  fname = "{}-{}.tfrecord".format(self.name, split_name)
  fpath = os.path.join(self._data_dir, fname)
  beam_writer = tfrecords_writer.BeamWriter(
      self._example_specs, fpath, hash_salt=split_name)
  self._beam_writers[split_name] = beam_writer

  encode_example = self.info.features.encode_example

  # Note: We need to wrap the pipeline in a PTransform to avoid re-using the
  # same label names for each split.
  @beam.ptransform_fn
  def _build_pcollection(pipeline):
    """PTransform which builds a single split."""
    # Encode the PCollection.
    pcoll_examples = self._build_pcollection(
        pipeline, **split_generator.gen_kwargs)
    pcoll_examples |= "Encode" >> beam.Map(
        lambda key_ex: (key_ex[0], encode_example(key_ex[1])))
    return beam_writer.write_from_pcollection(pcoll_examples)

  # Add the PCollection to the pipeline.
  _ = pipeline | split_name >> _build_pcollection()  # pylint: disable=no-value-for-parameter
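# A minimal sketch (not from the source) of the per-dataset
# `_build_pcollection` that `_prepare_split` above invokes through
# `split_generator.gen_kwargs`. The `filepaths` kwarg and the `'text'`
# feature key are hypothetical; the only contract assumed is that the
# PTransform emits `(key, example)` pairs, which `_prepare_split` then
# encodes with `features.encode_example` and hands to the `BeamWriter`.
def _build_pcollection(self, pipeline, filepaths):
  beam = lazy_imports_lib.lazy_imports.apache_beam

  def _process_file(path):
    # Use the file path as a deterministic, unique key.
    return path, {'text': tf.io.gfile.GFile(path).read()}

  return (
      pipeline
      | 'CreateInputs' >> beam.Create(filepaths)
      | 'BuildExamples' >> beam.Map(_process_file)
  )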
def _write(self,
           to_write,
           path,
           salt='',
           disable_shuffling=False,
           file_format=file_adapters.DEFAULT_FILE_FORMAT):
  beam = lazy_imports_lib.lazy_imports.apache_beam
  writer = tfrecords_writer.BeamWriter(
      'some spec',
      path,
      salt,
      disable_shuffling=disable_shuffling,
      file_format=file_format,
  )
  # Here we need to disable type check as `beam.Create` is not capable of
  # inferring the type of the PCollection elements.
  options = beam.options.pipeline_options.PipelineOptions(
      pipeline_type_check=False)
  with beam.Pipeline(options=options) as pipeline:

    @beam.ptransform_fn
    def _build_pcollection(pipeline):
      pcollection = pipeline | 'Start' >> beam.Create(to_write)
      return writer.write_from_pcollection(pcollection)

    _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
  return writer.finalize()
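# A hedged usage sketch for the helper above (the test name, payloads, and
# output path are made up): `to_write` holds `(key, serialized_example)`
# pairs, and `finalize()` returns the shard lengths and total byte size that
# a test can assert on.
def test_write_sketch(self):
  to_write = [(1, b'a'), (2, b'b'), (3, b'c')]
  path = os.path.join(self.tmp_dir, 'foo-train.tfrecord')
  shard_lengths, total_size = self._write(to_write=to_write, path=path)
  self.assertEqual(sum(shard_lengths), 3)
  self.assertGreater(total_size, 0)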
def _write(self,
           to_write,
           dataset_name: str = 'foo',
           split: str = 'train',
           salt='',
           disable_shuffling=False,
           file_format=file_adapters.DEFAULT_FILE_FORMAT):
  filetype_suffix = file_adapters.ADAPTER_FOR_FORMAT[file_format].FILE_SUFFIX
  filename_template = naming.ShardedFileTemplate(
      dataset_name=dataset_name,
      split=split,
      filetype_suffix=filetype_suffix,
      data_dir=self.tmp_dir)
  beam = lazy_imports_lib.lazy_imports.apache_beam
  writer = tfrecords_writer.BeamWriter(
      example_specs='some spec',
      filename_template=filename_template,
      hash_salt=salt,
      disable_shuffling=disable_shuffling,
      file_format=file_format,
  )
  # Here we need to disable type check as `beam.Create` is not capable of
  # inferring the type of the PCollection elements.
  options = beam.options.pipeline_options.PipelineOptions(
      pipeline_type_check=False)
  with beam.Pipeline(options=options) as pipeline:

    @beam.ptransform_fn
    def _build_pcollection(pipeline):
      pcollection = pipeline | 'Start' >> beam.Create(to_write)
      return writer.write_from_pcollection(pcollection)

    _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
  return writer.finalize()
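# The same sketch against the newer, template-based signature (dataset name,
# split, and payloads are made up): shard paths are now derived from
# `naming.ShardedFileTemplate` rather than passed in explicitly, e.g.
# `<tmp_dir>/dummy_dataset-test.tfrecord-00000-of-00001`.
def test_write_with_template_sketch(self):
  to_write = [(1, b'a'), (2, b'b')]
  shard_lengths, total_size = self._write(
      to_write=to_write, dataset_name='dummy_dataset', split='test')
  self.assertEqual(sum(shard_lengths), 2)
  self.assertGreater(total_size, 0)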
def _build_from_pcollection(
    self,
    split_name: str,
    generator: 'beam.PCollection[KeyExample]',
    filename_template: naming.ShardedFileTemplate,
    disable_shuffling: bool,
) -> _SplitInfoFuture:
  """Split generator for `beam.PCollection`."""
  # TODO(tfds): Should try to add support to `max_examples_per_split`.
  beam = lazy_imports_lib.lazy_imports.apache_beam
  beam_writer = tfrecords_writer.BeamWriter(
      example_specs=self._features.get_serialized_info(),
      filename_template=filename_template,
      hash_salt=split_name,
      disable_shuffling=disable_shuffling,
      file_format=self._file_format,
  )

  def _encode_example(key_ex, encode_fn=self._features.encode_example):
    # We do not access self._features in this function to avoid pickling the
    # entire class.
    return key_ex[0], encode_fn(key_ex[1])

  # Note: We need to wrap the pipeline in a PTransform to avoid errors due
  # to duplicated `>> beam_label`s.
  @beam.ptransform_fn
  def _encode_pcollection(pipeline):
    """PTransform which builds a single split."""
    pcoll_examples = pipeline | 'Encode' >> beam.Map(_encode_example)
    return beam_writer.write_from_pcollection(pcoll_examples)

  # Add the PCollection to the pipeline.
  _ = generator | f'{split_name}_write' >> _encode_pcollection()  # pylint: disable=no-value-for-parameter

  def _resolve_future():
    if self._in_contextmanager:
      raise AssertionError(
          '`future.result()` should be called after the '
          '`maybe_beam_pipeline` contextmanager.')
    logging.info('Retrieving split info for %s...', split_name)
    shard_lengths, total_size = beam_writer.finalize()
    return splits_lib.SplitInfo(
        name=split_name,
        shard_lengths=shard_lengths,
        num_bytes=total_size,
        filename_template=filename_template,
    )

  return _SplitInfoFuture(_resolve_future)
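# `_SplitInfoFuture` is not defined in this excerpt. A minimal sketch
# consistent with its use above: a thin wrapper that defers the callback
# until `result()` is called, i.e. after the Beam pipeline has actually run
# and `beam_writer.finalize()` can succeed.
from typing import Callable


class _SplitInfoFuture:
  """Future containing the `tfds.core.SplitInfo` result."""

  def __init__(self, callback: Callable[[], splits_lib.SplitInfo]):
    self._callback = callback

  def result(self) -> splits_lib.SplitInfo:
    # Invokes `_resolve_future` above: finalizes the writer and builds the
    # `SplitInfo` once shard lengths are known.
    return self._callback()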