def _RawRecordToRecordBatchInternal(self,
                                    batch_size: Optional[int] = None
                                   ) -> beam.PTransform:

  @beam.typehints.with_input_types(List[bytes])
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
    """Returns RecordBatch of csv lines."""
    # Decode raw csv lines to record batches.
    record_batches = (
        raw_records_pcoll
        | "CSVToRecordBatch" >> csv_decoder.CSVToRecordBatch(
            column_names=self._column_names,
            delimiter=self._delimiter,
            skip_blank_lines=self._skip_blank_lines,
            schema=self._schema,
            desired_batch_size=batch_size,
            multivalent_columns=self._multivalent_columns,
            secondary_delimiter=self._secondary_delimiter,
            produce_large_types=self._can_produce_large_types,
            raw_record_column_name=self._raw_record_column_name))
    return record_batches

  return beam.ptransform_fn(_PTransformFn)()

def RawRecordBeamSource(self) -> beam.PTransform:
  """Returns a PTransform that produces a PCollection[bytes].

  Used together with RawRecordToRecordBatch(), it allows getting both the
  PCollection of the raw records and the PCollection of the RecordBatch from
  the same source. For example:

  record_batch = pipeline | tfxio.BeamSource()
  raw_record = pipeline | tfxio.RawRecordBeamSource()

  would result in the files being read twice, while the following would only
  read once:

  raw_record = pipeline | tfxio.RawRecordBeamSource()
  record_batch = raw_record | tfxio.RawRecordToRecordBatch()
  """

  @beam.typehints.with_input_types(Any)
  @beam.typehints.with_output_types(bytes)
  def _PTransformFn(pcoll_or_pipeline: Any):
    return (pcoll_or_pipeline
            | "ReadRawRecords" >> self._RawRecordBeamSourceInternal()
            | "CollectRawRecordTelemetry" >> telemetry.ProfileRawRecords(
                self._telemetry_descriptors, self._logical_format,
                self._physical_format))

  return beam.ptransform_fn(_PTransformFn)()

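# Usage sketch (not from the snippet above): the read-once pattern described
# in the docstring. `tfxio` is assumed to be an already-constructed
# RecordBasedTFXIO instance; the variable names here are illustrative only.
import apache_beam as beam

with beam.Pipeline() as p:
  raw_records = p | "ReadOnce" >> tfxio.RawRecordBeamSource()
  record_batches = (
      raw_records | "ToRecordBatch" >> tfxio.RawRecordToRecordBatch())
  # Both the raw records and the RecordBatches now come from a single read
  # of the underlying files.
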
def _SerializedExamplesSource(self) -> beam.PTransform:

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(bytes)
  def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
    return raw_records_pcoll

  return beam.ptransform_fn(_PTransformFn)()

def RawRecordToRecordBatch(self,
                           batch_size: Optional[int] = None
                          ) -> beam.PTransform:
  """Returns a PTransform that converts raw records to Arrow RecordBatches.

  The input PCollection must be from self.RawRecordBeamSource() (also see
  the documentation for that method).

  Args:
    batch_size: if not None, the `pa.RecordBatch` produced will be of the
      specified size. Otherwise it's automatically tuned by Beam.
  """

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(pcoll: beam.pvalue.PCollection):
    return (pcoll
            | "RawRecordToRecordBatch" >>
            self._RawRecordToRecordBatchInternal(batch_size)
            | "CollectRecordBatchTelemetry" >>
            telemetry.ProfileRecordBatches(self._telemetry_descriptors,
                                           self._logical_format,
                                           self._physical_format))

  return beam.ptransform_fn(_PTransformFn)()

def _RawRecordBeamSourceInternal(self) -> beam.PTransform:

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(bytes)
  def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
    return raw_records_pcoll

  return beam.ptransform_fn(_PTransformFn)()

def testOverridableRecordBasedTFXIO(self):
  tmp_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
  file1 = os.path.join(tmp_dir, "tfrecord1")
  file1_records = [b"aa", b"bb"]
  _WriteTfRecord(file1, file1_records)

  def _CheckRecords(actual, expected):
    for a, e in zip(actual, expected):
      self.assertDictEqual(a.to_pydict(), e)

  @beam.typehints.with_input_types(Any)
  @beam.typehints.with_output_types(bytes)
  def _RawRecordBeamSource(pipeline: Any):
    return pipeline | beam.io.ReadFromTFRecord(file1 + "*")

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _RawRecordsToRecordBatch(pcoll, batch_size):
    batch_size = 1 if not batch_size else batch_size

    class _CreateRBDoFn(beam.DoFn):

      def process(self, examples):
        return [
            pa.RecordBatch.from_arrays([pa.array(examples)], ["column_name"])
        ]

    return (pcoll
            | beam.BatchElements(batch_size)
            | beam.ParDo(_CreateRBDoFn()))

  tfxio = record_based_tfxio.OverridableRecordBasedTFXIO(
      telemetry_descriptors=None,
      logical_format="tfrecord",
      physical_format="tf_example",
      raw_record_beam_source=beam.ptransform_fn(_RawRecordBeamSource),
      raw_record_to_record_batch=beam.ptransform_fn(_RawRecordsToRecordBatch))

  expected = [{"column_name": [b"aa"]}, {"column_name": [b"bb"]}]
  with beam.Pipeline() as p:
    record_pcoll = p | tfxio.BeamSource()
    beam_test_util.assert_that(
        record_pcoll, lambda actual: _CheckRecords(actual, expected))

def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:

  @beam.typehints.with_input_types(Any)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(pcoll_or_pipeline: Any):
    """Converts raw records to RecordBatches."""
    return (
        pcoll_or_pipeline
        | "RawRecordBeamSource" >> self.RawRecordBeamSource()
        | "RawRecordToRecordBatch" >> self.RawRecordToRecordBatch(batch_size))

  return beam.ptransform_fn(_PTransformFn)()

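# Usage sketch (illustrative; assumes an `apache_beam` import as `beam`, a
# pipeline `p`, and a constructed `tfxio` instance): the composed BeamSource
# reads raw records and converts them to RecordBatches in one step. When
# batch_size is None, Beam tunes the batch size automatically.
record_batches = p | "TFXIOBeamSource" >> tfxio.BeamSource(batch_size=1000)
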
def _RawRecordBeamSourceInternal(self) -> beam.PTransform:

  @beam.typehints.with_input_types(beam.Pipeline)
  @beam.typehints.with_output_types(bytes)
  def _PTransformFn(pipeline: beam.pvalue.PCollection):
    return pipeline | "ReadFromTFRecord" >> beam.io.ReadFromTFRecord(
        self._file_pattern,
        coder=beam.coders.BytesCoder(),
        # TODO(b/114938612): Eventually remove this override.
        validate=False)

  return beam.ptransform_fn(_PTransformFn)()

def _RawRecordToRecordBatchInternal(self,
                                    batch_size: Optional[int] = None
                                   ) -> beam.PTransform:

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(raw_record_pcoll: beam.pvalue.PCollection):
    return (raw_record_pcoll
            | "Batch" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            | "ToRecordBatch" >> beam.Map(_BatchedRecordsToArrow,
                                          self.raw_record_column_name))

  return beam.ptransform_fn(_PTransformFn)()

def _RawRecordToRecordBatchInternal(
    self, batch_size: Optional[int]) -> beam.PTransform:

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
    return (raw_records_pcoll
            | "BatchElements" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            | "Decode" >> beam.ParDo(
                _RecordsToRecordBatch(self._saved_decoder_path,
                                      self.raw_record_column_name,
                                      self._can_produce_large_types)))

  return beam.ptransform_fn(_PTransformFn)()

def _split_generators(self, dl_manager, pipeline):
  del dl_manager
  examples = (pipeline | beam.Create(range(1000)) | beam.Map(_gen_example))

  # Wrap the generator inside a ptransform_fn so that `'label' >> ` can be
  # added, avoiding duplicated PTransform node names.
  generate_examples = beam.ptransform_fn(self._generate_examples)

  return {
      'train': examples | 'train' >> generate_examples(num_examples=1000),
      'test': examples | 'test' >> generate_examples(num_examples=725),
  }

def _RawRecordToRecordBatchInternal(self,
                                    batch_size: Optional[int] = None
                                   ) -> beam.PTransform:

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def ptransform_fn(raw_records_pcoll: beam.pvalue.PCollection):
    return (raw_records_pcoll
            | "Batch" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            | "Decode" >> beam.ParDo(
                _DecodeBatchExamplesDoFn(self._GetSchemaForDecoding(),
                                         self.raw_record_column_name)))

  return beam.ptransform_fn(ptransform_fn)()

def _RawRecordToRecordBatchInternal(
    self, batch_size: Optional[int]) -> beam.PTransform:

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
    return (
        raw_records_pcoll
        | "BatchElements" >> beam.BatchElements(
            **batch_util.GetBatchElementsKwargs(batch_size))
        | "Decode" >> beam.ParDo(
            _RecordsToRecordBatch(
                self._saved_decoder_path, self.telemetry_descriptors,
                shared.Shared() if self._use_singleton_decoder else None,
                self.raw_record_column_name, self._record_index_column_name)))

  return beam.ptransform_fn(_PTransformFn)()

def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:

  @beam.typehints.with_input_types(Union[beam.PCollection, beam.Pipeline])
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(pcoll_or_pipeline: Any):
    """Reads Parquet tables and converts to RecordBatches."""
    return (pcoll_or_pipeline
            | "ParquetBeamSource" >> beam.io.ReadFromParquetBatched(
                file_pattern=self._file_pattern,
                min_bundle_size=self._min_bundle_size,
                validate=self._validate,
                columns=self._column_names)
            | "ToRecordBatch" >> beam.FlatMap(self._TableToRecordBatch,
                                              batch_size)
            | "CollectRecordBatchTelemetry" >> telemetry.ProfileRecordBatches(
                self._telemetry_descriptors, _PARQUET_FORMAT, _PARQUET_FORMAT))

  return beam.ptransform_fn(_PTransformFn)()

def normalize_legacy_split_generators(
    self,
    split_generators: Union[Dict[str, SplitGenerator],
                            List[SplitGeneratorLegacy]],
    generator_fn: Callable[..., Any],
    is_beam: bool,
) -> Dict[str, SplitGenerator]:
  """Normalizes the legacy split API into the new dict[split_name, generator].

  This function converts the legacy `List[tfds.core.SplitGenerator]` into the
  new `{'split_name': generator}` structure. Could be removed once all
  datasets are updated.

  Args:
    split_generators: Either legacy or new split_generators.
    generator_fn: The `GeneratorBasedBuilder._generate_examples` function.
    is_beam: `True` if using the legacy `tfds.core.BeamBasedBuilder`.

  Returns:
    split_generators: New split generator structure.
  """
  if isinstance(split_generators, dict):  # New structure
    return split_generators
  if isinstance(split_generators, list):  # Legacy structure
    if is_beam:  # Legacy `tfds.core.BeamBasedBuilder`
      beam = lazy_imports_lib.lazy_imports.apache_beam
      generator_fn = beam.ptransform_fn(generator_fn)
      return {
          s.name: generator_fn(**s.gen_kwargs)  # Create the `beam.PTransform`
          for s in split_generators
      }
    else:
      return {
          split_generator.name: generator_fn(**split_generator.gen_kwargs)
          for split_generator in split_generators
      }
  else:
    raise TypeError(
        f'Invalid `_split_generators` returned value: {split_generators}')

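# Illustrative sketch (assumed `tfds` import and gen_kwargs; not from the
# snippet above) of the two shapes normalize_legacy_split_generators accepts
# from a builder's `_split_generators`:
def _split_generators_legacy(self, dl_manager):
  # Legacy: a list of `tfds.core.SplitGenerator`s whose gen_kwargs are later
  # forwarded to `_generate_examples`.
  return [
      tfds.core.SplitGenerator(name='train', gen_kwargs={'split': 'train'}),
      tfds.core.SplitGenerator(name='test', gen_kwargs={'split': 'test'}),
  ]

def _split_generators_new(self, dl_manager):
  # New: a dict mapping split name directly to a generator (or, for Beam
  # builders, a `beam.PTransform`).
  return {
      'train': self._generate_examples(split='train'),
      'test': self._generate_examples(split='test'),
  }
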
def _RawRecordBeamSourceInternal(self) -> beam.PTransform:
  return (beam.ptransform_fn(lambda x: x)()
          .with_input_types(bytes)
          .with_output_types(bytes))

def get_models(self):
  """PTransform that outputs a PCollection with all models in the stub."""
  return beam.ptransform_fn(lambda pcoll: (
      pcoll.pipeline
      | beam.Create(m for m in self._models_by_id.values())))

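# Usage sketch (assumed names; `stub` exposes get_models() above). Because
# get_models() returns the ptransform_fn-wrapped callable, it is called once
# more to obtain the actual PTransform; the input PCollection is only used to
# reach the pipeline object.
with beam.Pipeline() as p:
  seed = p | "Seed" >> beam.Create([None])
  models = seed | "GetModels" >> stub.get_models()()
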
def _CSVSource(self) -> beam.PTransform:
  return (beam.ptransform_fn(lambda x: x)()
          .with_input_types(bytes)
          .with_output_types(bytes))