Code Example #1
    def _prepare_split(self, split_generator, pipeline):
        """Builds this split's examples as a Beam PCollection and registers its writer."""
        beam = lazy_imports_lib.lazy_imports.apache_beam

        if not tf.io.gfile.exists(self._data_dir):
            tf.io.gfile.makedirs(self._data_dir)

        split_name = split_generator.split_info.name
        output_prefix = naming.filename_prefix_for_split(self.name, split_name)
        output_prefix = os.path.join(self._data_dir, output_prefix)

        # To write examples to disk:
        fname = "{}-{}.tfrecord".format(self.name, split_name)
        fpath = os.path.join(self._data_dir, fname)
        beam_writer = tfrecords_writer.BeamWriter(self._example_specs,
                                                  fpath,
                                                  hash_salt=split_name)
        self._beam_writers[split_name] = beam_writer

        encode_example = self.info.features.encode_example

        # Note: We need to wrap the pipeline in a PTransform to avoid re-using
        # the same beam label names across splits.
        @beam.ptransform_fn
        def _build_pcollection(pipeline):
            """PTransformation which build a single split."""
            # Encode the PCollection
            pcoll_examples = self._build_pcollection(
                pipeline, **split_generator.gen_kwargs)
            pcoll_examples |= "Encode" >> beam.Map(
                lambda key_ex: (key_ex[0], encode_example(key_ex[1])))
            return beam_writer.write_from_pcollection(pcoll_examples)

        # Add the PCollection to the pipeline
        _ = pipeline | split_name >> _build_pcollection()  # pylint: disable=no-value-for-parameter
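
For context, the `self._build_pcollection(pipeline, **split_generator.gen_kwargs)` call above dispatches to the per-dataset generator that a `BeamBasedBuilder` subclass defines. Below is a minimal, hypothetical sketch of such a generator; the TSV input path, the `_process_line` helper, and the `'text'` feature are assumptions, and only the `(key, example)` output contract is taken from the call site above.

    def _build_pcollection(self, pipeline, filepath):
        """Hypothetical per-split generator yielding (key, example) pairs."""
        beam = lazy_imports_lib.lazy_imports.apache_beam

        def _process_line(line):
            # The key must be unique within the split, and the example dict must
            # match `self.info.features` so `encode_example` above can serialize it.
            key, text = line.split('\t', 1)
            return key, {'text': text}

        return (pipeline
                | 'ReadLines' >> beam.io.ReadFromText(filepath)
                | 'Process' >> beam.Map(_process_line))
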
Code Example #2
    def _write(self,
               to_write,
               path,
               salt='',
               disable_shuffling=False,
               file_format=file_adapters.DEFAULT_FILE_FORMAT):
        """Runs `to_write` through a `BeamWriter` and returns `writer.finalize()`."""
        beam = lazy_imports_lib.lazy_imports.apache_beam
        writer = tfrecords_writer.BeamWriter(
            'some spec',
            path,
            salt,
            disable_shuffling=disable_shuffling,
            file_format=file_format,
        )
        # Here we need to disable type check as `beam.Create` is not capable of
        # inferring the type of the PCollection elements.
        options = beam.options.pipeline_options.PipelineOptions(
            pipeline_type_check=False)
        with beam.Pipeline(options=options) as pipeline:

            @beam.ptransform_fn
            def _build_pcollection(pipeline):
                pcollection = pipeline | 'Start' >> beam.Create(to_write)
                return writer.write_from_pcollection(pcollection)

            _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
        return writer.finalize()
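
A hypothetical call to this helper from a test case could look like the sketch below; the payload values and assertions are assumptions (the serializer for `'some spec'` is presumed to be stubbed elsewhere in the test fixture), while the `(shard_lengths, total_size)` return shape follows from `writer.finalize()`.

    def test_write_tfrecord(self):
        # Hypothetical (key, example) pairs fed through the Beam pipeline.
        to_write = [(i, b'dummy-value') for i in range(10)]
        path = os.path.join(self.tmp_dir, 'foo-train.tfrecord')
        shard_lengths, total_size = self._write(to_write=to_write, path=path)
        # All ten examples should be accounted for across the produced shards.
        self.assertEqual(sum(shard_lengths), 10)
        self.assertGreater(total_size, 0)
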
Code Example #3
    def _write(self,
               to_write,
               dataset_name: str = 'foo',
               split: str = 'train',
               salt='',
               disable_shuffling=False,
               file_format=file_adapters.DEFAULT_FILE_FORMAT):
        """Runs `to_write` through a `BeamWriter` built from a `ShardedFileTemplate`."""
        filetype_suffix = file_adapters.ADAPTER_FOR_FORMAT[
            file_format].FILE_SUFFIX
        filename_template = naming.ShardedFileTemplate(
            dataset_name=dataset_name,
            split=split,
            filetype_suffix=filetype_suffix,
            data_dir=self.tmp_dir)
        beam = lazy_imports_lib.lazy_imports.apache_beam
        writer = tfrecords_writer.BeamWriter(
            example_specs='some spec',
            filename_template=filename_template,
            hash_salt=salt,
            disable_shuffling=disable_shuffling,
            file_format=file_format,
        )
        # Here we need to disable type check as `beam.Create` is not capable of
        # inferring the type of the PCollection elements.
        options = beam.options.pipeline_options.PipelineOptions(
            pipeline_type_check=False)
        with beam.Pipeline(options=options) as pipeline:

            @beam.ptransform_fn
            def _build_pcollection(pipeline):
                pcollection = pipeline | 'Start' >> beam.Create(to_write)
                return writer.write_from_pcollection(pcollection)

            _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
        return writer.finalize()
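
As with the previous helper, a hypothetical invocation is sketched below; the payloads, the assertions, and the expected shard-name prefix are assumptions inferred from the `dataset_name`, `split`, and `filetype_suffix` passed to `ShardedFileTemplate` above.

    def test_write_with_template(self):
        to_write = [(i, b'dummy-value') for i in range(5)]
        shard_lengths, total_size = self._write(
            to_write=to_write, dataset_name='foo', split='train')
        self.assertEqual(sum(shard_lengths), 5)
        # Shards are expected to be written under `self.tmp_dir` with names derived
        # from the template, e.g. 'foo-train.tfrecord-00000-of-0000N'.
        written = tf.io.gfile.listdir(self.tmp_dir)
        self.assertTrue(any(f.startswith('foo-train.tfrecord') for f in written))
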
Code Example #4
    def _build_from_pcollection(
        self,
        split_name: str,
        generator: 'beam.PCollection[KeyExample]',
        filename_template: naming.ShardedFileTemplate,
        disable_shuffling: bool,
    ) -> _SplitInfoFuture:
        """Split generator for `beam.PCollection`."""
        # TODO(tfds): Should try to add support for `max_examples_per_split`
        beam = lazy_imports_lib.lazy_imports.apache_beam

        beam_writer = tfrecords_writer.BeamWriter(
            example_specs=self._features.get_serialized_info(),
            filename_template=filename_template,
            hash_salt=split_name,
            disable_shuffling=disable_shuffling,
            file_format=self._file_format,
        )

        def _encode_example(key_ex, encode_fn=self._features.encode_example):
            # We do not access self._features in this function to avoid pickling the
            # entire class.
            return key_ex[0], encode_fn(key_ex[1])

        # Note: We need to wrap the pipeline in a PTransform to avoid errors
        # due to duplicated `>> beam_label` names across splits.
        @beam.ptransform_fn
        def _encode_pcollection(pipeline):
            """PTransformation which build a single split."""
            pcoll_examples = pipeline | 'Encode' >> beam.Map(_encode_example)
            return beam_writer.write_from_pcollection(pcoll_examples)

        # Add the PCollection to the pipeline
        _ = generator | f'{split_name}_write' >> _encode_pcollection()  # pylint: disable=no-value-for-parameter

        def _resolve_future():
            if self._in_contextmanager:
                raise AssertionError(
                    '`future.result()` should be called after the '
                    '`maybe_beam_pipeline` contextmanager.')
            logging.info('Retrieving split info for %s...', split_name)
            shard_lengths, total_size = beam_writer.finalize()
            return splits_lib.SplitInfo(
                name=split_name,
                shard_lengths=shard_lengths,
                num_bytes=total_size,
                filename_template=filename_template,
            )

        return _SplitInfoFuture(_resolve_future)
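
The returned `_SplitInfoFuture` is only safe to resolve once the Beam pipeline has actually run, which is what the assertion in `_resolve_future` enforces. A hedged sketch of the caller side, with hypothetical names (`split_builder`, `train_pcollection`, `template`), could look like this; `maybe_beam_pipeline` and `future.result()` are taken from the assertion message above.

    with split_builder.maybe_beam_pipeline() as pipeline_proxy:
        # `train_pcollection` is assumed to be built from `pipeline_proxy` by the
        # dataset's own generation logic before being handed to this method.
        future = split_builder._build_from_pcollection(
            split_name='train',
            generator=train_pcollection,
            filename_template=template,
            disable_shuffling=False,
        )
    # The pipeline executes when the context manager exits; only afterwards can
    # the future be resolved into a concrete `SplitInfo`.
    train_split_info = future.result()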