Esempio n. 1
0
    def _RawRecordToRecordBatchInternal(self,
                                        batch_size: Optional[int] = None
                                        ) -> beam.PTransform:
        """Builds the PTransform that decodes raw CSV lines into RecordBatches.

        Args:
          batch_size: forwarded to `csv_decoder.CSVToRecordBatch` as its
            `desired_batch_size` argument.

        Returns:
          A PTransform instance created from `_PTransformFn` via
          `beam.ptransform_fn`.
        """

        # NOTE(review): the input hint is List[bytes] rather than bytes;
        # confirm it matches the element type produced by the raw-record
        # source this transform is applied to.
        @beam.typehints.with_input_types(List[bytes])
        @beam.typehints.with_output_types(pa.RecordBatch)
        def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
            """Returns RecordBatch of csv lines."""
            # Every decoder option comes from this TFXIO instance, except the
            # batch size requested by the caller.
            csv_to_record_batch = csv_decoder.CSVToRecordBatch(
                column_names=self._column_names,
                delimiter=self._delimiter,
                skip_blank_lines=self._skip_blank_lines,
                schema=self._schema,
                desired_batch_size=batch_size,
                multivalent_columns=self._multivalent_columns,
                secondary_delimiter=self._secondary_delimiter,
                produce_large_types=self._can_produce_large_types,
                raw_record_column_name=self._raw_record_column_name)
            return raw_records_pcoll | "CSVToRecordBatch" >> csv_to_record_batch

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 2
0
    def RawRecordBeamSource(self) -> beam.PTransform:
        """Returns a PTransform that produces a PCollection[bytes].

        Combined with RawRecordToRecordBatch(), this lets a pipeline obtain
        both the raw-record PCollection and the RecordBatch PCollection from a
        single read of the source. For instance, this reads the files twice:

          record_batch = pipeline | tfxio.BeamSource()
          raw_record = pipeline | tfxio.RawRecordBeamSource()

        whereas this reads them only once:

          raw_record = pipeline | tfxio.RawRecordBeamSource()
          record_batch = raw_record | tfxio.RawRecordToRecordBatch()
        """

        @beam.typehints.with_input_types(Any)
        @beam.typehints.with_output_types(bytes)
        def _PTransformFn(pcoll_or_pipeline: Any):
            # Step 1: read the raw records with the format-specific source.
            raw_records = (
                pcoll_or_pipeline
                | "ReadRawRecords" >> self._RawRecordBeamSourceInternal())
            # Step 2: attach raw-record telemetry for this logical/physical
            # format pair.
            profile_raw_records = telemetry.ProfileRawRecords(
                self._telemetry_descriptors, self._logical_format,
                self._physical_format)
            return raw_records | "CollectRawRecordTelemetry" >> profile_raw_records

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 3
0
    def _SerializedExamplesSource(self) -> beam.PTransform:
        """Returns a pass-through PTransform typed as bytes -> bytes."""

        @beam.typehints.with_input_types(bytes)
        @beam.typehints.with_output_types(bytes)
        def _PTransformFn(serialized_examples: beam.pvalue.PCollection):
            # Identity: the incoming PCollection is handed back unchanged.
            return serialized_examples

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 4
0
  def RawRecordToRecordBatch(self,
                             batch_size: Optional[int] = None
                            ) -> beam.PTransform:
    """Returns a PTransform that converts raw records to Arrow RecordBatches.

    The input PCollection is expected to come from self.RawRecordBeamSource()
    (see that method's documentation).

    Args:
      batch_size: when not None, the produced `pa.RecordBatch`es have this
        size; when None, Beam tunes the size automatically. Forwarded to
        `self._RawRecordToRecordBatchInternal`.
    """

    @beam.typehints.with_input_types(bytes)
    @beam.typehints.with_output_types(pa.RecordBatch)
    def _PTransformFn(pcoll: beam.pvalue.PCollection):
      # Decode with the format-specific implementation first ...
      record_batches = (
          pcoll
          | "RawRecordToRecordBatch" >>
          self._RawRecordToRecordBatchInternal(batch_size))
      # ... then profile the decoded RecordBatches for telemetry.
      profile_batches = telemetry.ProfileRecordBatches(
          self._telemetry_descriptors, self._logical_format,
          self._physical_format)
      return record_batches | "CollectRecordBatchTelemetry" >> profile_batches

    transform_factory = beam.ptransform_fn(_PTransformFn)
    return transform_factory()
Esempio n. 5
0
    def _RawRecordBeamSourceInternal(self) -> beam.PTransform:
        """Returns a pass-through PTransform typed as bytes -> bytes."""

        @beam.typehints.with_input_types(bytes)
        @beam.typehints.with_output_types(bytes)
        def _PTransformFn(raw_records: beam.pvalue.PCollection):
            # Identity: the raw records are forwarded unchanged.
            return raw_records

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 6
0
  def testOverridableRecordBasedTFXIO(self):
    """Checks that user-supplied raw-record source/decoder transforms are used.

    Writes two TFRecords, plugs custom `raw_record_beam_source` and
    `raw_record_to_record_batch` transforms into OverridableRecordBasedTFXIO
    and expects one single-row RecordBatch per record.
    """
    tmp_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
    file1 = os.path.join(tmp_dir, "tfrecord1")
    file1_records = [b"aa", b"bb"]
    _WriteTfRecord(file1, file1_records)

    def _CheckRecords(actual, expected):
      # Beam does not guarantee element order in a PCollection, and zip()
      # would silently ignore missing or extra batches; compare the batches
      # as a multiset instead (assertCountEqual supports unhashable dicts).
      self.assertCountEqual([a.to_pydict() for a in actual], expected)

    @beam.typehints.with_input_types(Any)
    @beam.typehints.with_output_types(bytes)
    def _RawRecordBeamSource(pipeline: Any):
      return pipeline | beam.io.ReadFromTFRecord(file1 + "*")

    @beam.typehints.with_input_types(bytes)
    @beam.typehints.with_output_types(pa.RecordBatch)
    def _RawRecordsToRecordBatch(pcoll, batch_size):
      batch_size = 1 if not batch_size else batch_size

      class _CreateRBDoFn(beam.DoFn):

        def process(self, examples):
          return [
              pa.RecordBatch.from_arrays([pa.array(examples)], ["column_name"])
          ]

      # The first positional argument of BatchElements is only the *minimum*
      # batch size (the maximum defaults to 10000); pin both bounds so each
      # batch holds at most `batch_size` records and the expected output
      # does not depend on the runner's batching heuristics.
      return (pcoll
              | beam.BatchElements(min_batch_size=batch_size,
                                   max_batch_size=batch_size)
              | beam.ParDo(_CreateRBDoFn()))

    tfxio = record_based_tfxio.OverridableRecordBasedTFXIO(
        telemetry_descriptors=None,
        logical_format="tfrecord",
        physical_format="tf_example",
        raw_record_beam_source=beam.ptransform_fn(_RawRecordBeamSource),
        raw_record_to_record_batch=beam.ptransform_fn(_RawRecordsToRecordBatch))

    expected = [{"column_name": [b"aa"]}, {"column_name": [b"bb"]}]
    with beam.Pipeline() as p:
      record_pcoll = p | tfxio.BeamSource()
      beam_test_util.assert_that(
          record_pcoll,
          lambda actual: _CheckRecords(actual, expected))
Esempio n. 7
0
    def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:
        """Returns a PTransform that reads raw records and decodes them.

        Chains `self.RawRecordBeamSource()` and
        `self.RawRecordToRecordBatch(batch_size)`. The input is typed `Any`
        and is passed unchanged to the raw-record source.

        Args:
          batch_size: forwarded to `self.RawRecordToRecordBatch`.
        """

        @beam.typehints.with_input_types(Any)
        @beam.typehints.with_output_types(pa.RecordBatch)
        def _PTransformFn(pcoll_or_pipeline: Any):
            """Converts raw records to RecordBatches."""
            raw_records = (pcoll_or_pipeline
                           | "RawRecordBeamSource" >> self.RawRecordBeamSource())
            to_record_batch = self.RawRecordToRecordBatch(batch_size)
            return raw_records | "RawRecordToRecordBatch" >> to_record_batch

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 8
0
    def _RawRecordBeamSourceInternal(self) -> beam.PTransform:
        """Returns a root PTransform that reads raw TFRecord payloads as bytes.

        Reads every file matching `self._file_pattern` using a `BytesCoder`,
        so each output element is one record's undecoded bytes.
        """

        @beam.typehints.with_input_types(beam.Pipeline)
        @beam.typehints.with_output_types(bytes)
        def _PTransformFn(pipeline: beam.Pipeline):
            # Annotation matches the declared input type hint: this is a root
            # read applied to a pipeline, not to a PCollection.
            return pipeline | "ReadFromTFRecord" >> beam.io.ReadFromTFRecord(
                self._file_pattern,
                coder=beam.coders.BytesCoder(),
                # TODO(b/114938612): Eventually remove this override.
                # validate=False skips checking that matching files exist when
                # the pipeline is constructed.
                validate=False)

        return beam.ptransform_fn(_PTransformFn)()
Esempio n. 9
0
    def _RawRecordToRecordBatchInternal(self,
                                        batch_size: Optional[int] = None
                                        ) -> beam.PTransform:
        """Batches raw records and converts each batch with `_BatchedRecordsToArrow`.

        Args:
          batch_size: passed to `batch_util.GetBatchElementsKwargs` to derive
            the keyword arguments of `beam.BatchElements`.
        """

        @beam.typehints.with_input_types(bytes)
        @beam.typehints.with_output_types(pa.RecordBatch)
        def _PTransformFn(raw_record_pcoll: beam.pvalue.PCollection):
            batching_kwargs = batch_util.GetBatchElementsKwargs(batch_size)
            batches = raw_record_pcoll | "Batch" >> beam.BatchElements(
                **batching_kwargs)
            # `_BatchedRecordsToArrow` receives each batch plus the configured
            # raw-record column name as its extra argument.
            return batches | "ToRecordBatch" >> beam.Map(
                _BatchedRecordsToArrow, self.raw_record_column_name)

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 10
0
    def _RawRecordToRecordBatchInternal(
            self, batch_size: Optional[int]) -> beam.PTransform:
        """Batches raw records and decodes each batch via `_RecordsToRecordBatch`.

        Args:
          batch_size: passed to `batch_util.GetBatchElementsKwargs` to derive
            the keyword arguments of `beam.BatchElements`.
        """

        @beam.typehints.with_input_types(bytes)
        @beam.typehints.with_output_types(pa.RecordBatch)
        def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
            batched = raw_records_pcoll | "BatchElements" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            # The DoFn is given the saved decoder path, the raw-record column
            # name and the `_can_produce_large_types` flag.
            records_to_batch = _RecordsToRecordBatch(
                self._saved_decoder_path,
                self.raw_record_column_name,
                self._can_produce_large_types)
            return batched | "Decode" >> beam.ParDo(records_to_batch)

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 11
0
    def _split_generators(self, dl_manager, pipeline):
        """Returns the 'train' and 'test' splits built from shared examples.

        Both splits consume the same PCollection of 1000 examples produced by
        `_gen_example`, and differ only in the `num_examples` passed to
        `self._generate_examples`.
        """
        del dl_manager  # Unused: examples come from `beam.Create` below.

        indices = pipeline | beam.Create(range(1000))
        examples = indices | beam.Map(_gen_example)

        # Wrapping `_generate_examples` in ptransform_fn lets each split apply
        # it under its own `'label' >>` prefix, so the pipeline does not end up
        # with duplicated PTransform node names.
        generate_examples = beam.ptransform_fn(self._generate_examples)
        train_split = examples | 'train' >> generate_examples(num_examples=1000)
        test_split = examples | 'test' >> generate_examples(num_examples=725)
        return {'train': train_split, 'test': test_split}
Esempio n. 12
0
    def _RawRecordToRecordBatchInternal(self,
                                        batch_size: Optional[int] = None
                                        ) -> beam.PTransform:
        """Batches serialized examples and decodes each batch to a RecordBatch.

        Args:
          batch_size: passed to `batch_util.GetBatchElementsKwargs` to derive
            the keyword arguments of `beam.BatchElements`.
        """

        # NOTE(review): the inner function is named `ptransform_fn`, unlike the
        # `_PTransformFn` convention elsewhere. It is left as-is because
        # ptransform_fn derives the default transform label from this name.
        @beam.typehints.with_input_types(bytes)
        @beam.typehints.with_output_types(pa.RecordBatch)
        def ptransform_fn(raw_records_pcoll: beam.pvalue.PCollection):
            batched = raw_records_pcoll | "Batch" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            # The decoding DoFn gets the decoding schema and the raw-record
            # column name.
            decode_fn = _DecodeBatchExamplesDoFn(self._GetSchemaForDecoding(),
                                                 self.raw_record_column_name)
            return batched | "Decode" >> beam.ParDo(decode_fn)

        transform_factory = beam.ptransform_fn(ptransform_fn)
        return transform_factory()
Esempio n. 13
0
    def _RawRecordToRecordBatchInternal(
            self, batch_size: Optional[int]) -> beam.PTransform:
        """Batches raw records and decodes each batch via `_RecordsToRecordBatch`.

        Args:
          batch_size: passed to `batch_util.GetBatchElementsKwargs` to derive
            the keyword arguments of `beam.BatchElements`.
        """

        @beam.typehints.with_input_types(bytes)
        @beam.typehints.with_output_types(pa.RecordBatch)
        def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
            batched = raw_records_pcoll | "BatchElements" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            decoder_path = self._saved_decoder_path
            descriptors = self.telemetry_descriptors
            # A `shared.Shared()` handle is passed only when the singleton
            # decoder option is enabled; otherwise None.
            if self._use_singleton_decoder:
                shared_handle = shared.Shared()
            else:
                shared_handle = None
            records_to_batch = _RecordsToRecordBatch(
                decoder_path, descriptors, shared_handle,
                self.raw_record_column_name, self._record_index_column_name)
            return batched | "Decode" >> beam.ParDo(records_to_batch)

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 14
0
    def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:
        """Returns a PTransform that reads Parquet files as Arrow RecordBatches.

        Args:
          batch_size: passed to `self._TableToRecordBatch` as its extra
            argument for every table read from the Parquet source.
        """

        @beam.typehints.with_input_types(Union[beam.PCollection,
                                               beam.Pipeline])
        @beam.typehints.with_output_types(pa.RecordBatch)
        def _PTransformFn(pcoll_or_pipeline: Any):
            """Reads Parquet tables and converts to RecordBatches."""
            parquet_read = beam.io.ReadFromParquetBatched(
                file_pattern=self._file_pattern,
                min_bundle_size=self._min_bundle_size,
                validate=self._validate,
                columns=self._column_names)
            tables = pcoll_or_pipeline | "ParquetBeamSource" >> parquet_read
            record_batches = tables | "ToRecordBatch" >> beam.FlatMap(
                self._TableToRecordBatch, batch_size)
            # Parquet is used as both the logical and the physical format for
            # telemetry.
            profile_batches = telemetry.ProfileRecordBatches(
                self._telemetry_descriptors, _PARQUET_FORMAT, _PARQUET_FORMAT)
            return record_batches | "CollectRecordBatchTelemetry" >> profile_batches

        transform_factory = beam.ptransform_fn(_PTransformFn)
        return transform_factory()
Esempio n. 15
0
    def normalize_legacy_split_generators(
        self,
        split_generators: Union[Dict[str, SplitGenerator],
                                List[SplitGeneratorLegacy]],
        generator_fn: Callable[..., Any],
        is_beam: bool,
    ) -> Dict[str, SplitGenerator]:
        """Normalizes the legacy split API into `{split_name: generator}`.

        Converts the legacy `List[tfds.core.SplitGenerator]` structure into the
        newer `{'split_name': generator}` dict. A dict input is returned as-is.
        This shim can be removed once every dataset has been updated.

        Args:
          split_generators: Either the legacy list or the new dict.
          generator_fn: The `GeneratorBasedBuilder._generate_examples`
            function; it is called with each legacy split's `gen_kwargs`.
          is_beam: `True` when using the legacy `tfds.core.BeamBasedBuilder`;
            `generator_fn` is then wrapped with `beam.ptransform_fn`, so each
            resulting value is a `beam.PTransform`.

        Returns:
          The split generators keyed by split name.

        Raises:
          TypeError: if `split_generators` is neither a dict nor a list.
        """
        if isinstance(split_generators, dict):  # Already the new structure.
            return split_generators
        if not isinstance(split_generators, list):
            raise TypeError(
                f'Invalid `_split_generators` returned value: {split_generators}'
            )
        # Legacy list structure from here on.
        if is_beam:  # Legacy `tfds.core.BeamBasedBuilder`
            beam = lazy_imports_lib.lazy_imports.apache_beam
            # Calling the wrapped function creates the `beam.PTransform`.
            generator_fn = beam.ptransform_fn(generator_fn)
        return {
            legacy_split.name: generator_fn(**legacy_split.gen_kwargs)
            for legacy_split in split_generators
        }
Esempio n. 16
0
 def _RawRecordBeamSourceInternal(self) -> beam.PTransform:
     """Returns an identity PTransform whose type hints are bytes -> bytes."""
     passthrough = beam.ptransform_fn(lambda pcoll: pcoll)()
     return passthrough.with_input_types(bytes).with_output_types(bytes)
Esempio n. 17
0
 def get_models(self):
     """Returns a `beam.ptransform_fn` factory emitting every stored model.

     When the produced transform is applied to a PCollection, it ignores that
     PCollection's elements and creates, on the same pipeline, a PCollection
     of all values in `self._models_by_id`.

     NOTE(review): this returns the ptransform_fn factory itself, not a
     PTransform instance (there is no trailing `()`), so callers presumably
     invoke it once more, e.g. `pcoll | stub.get_models()()`; confirm.
     """
     return beam.ptransform_fn(
         lambda pcoll: pcoll.pipeline | beam.Create(
             model for model in self._models_by_id.values()))
Esempio n. 18
0
 def _CSVSource(self) -> beam.PTransform:
     """Returns an identity PTransform whose type hints are bytes -> bytes."""
     passthrough = beam.ptransform_fn(lambda pcoll: pcoll)()
     return passthrough.with_input_types(bytes).with_output_types(bytes)