def test_e2e(self, stats_options, expected_stats_pbtxt, expected_schema_pbtxt):
  """Generates statistics end to end and checks the stats and inferred schema.

  Reads `self._input_file` through a TFSequenceExampleRecord TFXIO, writes the
  generated statistics under `self._output_dir`, reloads them, and compares
  both the statistics and the schema inferred from them with the expectations.

  Args:
    stats_options: options passed to `tfdv.GenerateStatistics`.
    expected_stats_pbtxt: text-format DatasetFeatureStatisticsList the
      generated statistics are compared against.
    expected_schema_pbtxt: text-format Schema the inferred schema is compared
      against.
  """
  source = tf_sequence_example_record.TFSequenceExampleRecord(
      self._input_file, ['tfdv', 'test'])
  stats_path = os.path.join(self._output_dir, 'stats')

  # The pipeline runs when the `with` block exits, so the statistics file is
  # only read back after leaving it.
  with beam.Pipeline() as pipeline:
    record_batches = pipeline | 'TFXIORead' >> source.BeamSource()
    generated = (
        record_batches
        | 'GenerateStats' >> tfdv.GenerateStatistics(stats_options))
    _ = generated | 'WriteStats' >> tfdv.WriteStatisticsToTFRecord(stats_path)

  loaded_stats = tfdv.load_statistics(stats_path)
  expected_stats = text_format.Parse(
      expected_stats_pbtxt, statistics_pb2.DatasetFeatureStatisticsList())
  check_stats = test_util.make_dataset_feature_stats_list_proto_equal_fn(
      self, expected_stats)
  check_stats([loaded_stats])

  inferred_schema = tfdv.infer_schema(loaded_stats, infer_feature_shape=True)
  # Drop the legacy field, when the schema has it, before comparing.
  if hasattr(inferred_schema, 'generate_legacy_feature_spec'):
    inferred_schema.ClearField('generate_legacy_feature_spec')

  expected_schema = text_format.Parse(expected_schema_pbtxt, schema_pb2.Schema())
  self._assert_schema_equal(inferred_schema, expected_schema)
def _validate_file_format(
    file_pattern: OneOrMorePatterns,
    file_format: Union[str, List[str]]) -> None:
  """Checks that `file_format` is consistent with `file_pattern` and supported.

  Args:
    file_pattern: the file pattern(s) the caller wants to read.
    file_format: the (non-None) file format(s) given for `file_pattern`.

  Raises:
    ValueError: if the two arguments are not of the same type, or if they are
      lists of different lengths.
    NotImplementedError: if any given format is not 'tfrecords_gzip'.
  """
  given = f'Given: file_pattern={file_pattern}, file_format={file_format}'

  # A single pattern must come with a single format and a list of patterns
  # with a list of formats, hence the exact type comparison.
  if type(file_format) is not type(file_pattern):
    raise ValueError(
        'The type of file_pattern and file_formats should be the same. '
        + given)

  if isinstance(file_format, list):
    if len(file_format) != len(file_pattern):
      raise ValueError(
          'The length of file_pattern and file_formats should be the same. '
          + given)
    unsupported = any(item != 'tfrecords_gzip' for item in file_format)
  else:
    # Not a list: treated as a single format string.
    unsupported = file_format != 'tfrecords_gzip'

  if unsupported:
    raise NotImplementedError(f'{file_format} is not supported yet.')


def make_tfxio(
    file_pattern: OneOrMorePatterns,
    telemetry_descriptors: List[str],
    payload_format: Optional[Union[str, int]],
    data_view_uri: Optional[str] = None,
    schema: Optional[schema_pb2.Schema] = None,
    read_as_raw_records: bool = False,
    raw_record_column_name: Optional[str] = None,
    file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO:
  """Creates a TFXIO instance that reads `file_pattern`.

  Args:
    file_pattern: the file pattern for the TFXIO to access.
    telemetry_descriptors: A set of descriptors that identify the component
      that is instantiating the TFXIO. These will be used to construct the
      namespace to contain metrics for profiling and are therefore expected to
      be identifiers of the component itself and not individual instances of
      source use.
    payload_format: one of the enums from example_gen_pb2.PayloadFormat (may be
      in string or int form). If None, default to FORMAT_TF_EXAMPLE.
    data_view_uri: uri to a DataView artifact. A DataView is needed in order to
      create a TFXIO for certain payload formats.
    schema: TFMD schema. Note: although optional, some payload formats need a
      schema in order for all TFXIO interfaces (e.g. TensorAdapter()) to work.
      Unless you know what you are doing, always supply a schema.
    read_as_raw_records: If True, ignore the payload type of `examples`. Always
      use RawTfRecord TFXIO.
    raw_record_column_name: If provided, the arrow RecordBatch produced by the
      TFXIO will contain a string column of the given name, and the contents of
      that column will be the raw records. Note that not all TFXIO supports
      this option, and an error will be raised in that case. Required if
      read_as_raw_records == True.
    file_format: file format string for each file_pattern. Only
      'tfrecords_gzip' is supported for now. The value is only validated here;
      it is not forwarded to the created TFXIO.

  Returns:
    a TFXIO instance.

  Raises:
    ValueError: if `file_format` and `file_pattern` differ in type or, when
      both are lists, in length.
    NotImplementedError: if `file_format` names an unsupported format, or if
      `payload_format` is not one of the formats handled here.
    AssertionError: if `read_as_raw_records` is set without
      `raw_record_column_name`, or FORMAT_PROTO is requested without
      `data_view_uri`.
  """
  # Honour the documented default instead of handing None to the enum lookup.
  if payload_format is None:
    payload_format = example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE
  if not isinstance(payload_format, int):
    # Enum names are resolved to their integer value.
    payload_format = example_gen_pb2.PayloadFormat.Value(payload_format)

  if file_format is not None:
    _validate_file_format(file_pattern, file_format)

  if read_as_raw_records:
    assert raw_record_column_name is not None, (
        'read_as_raw_records is specified - '
        'must provide raw_record_column_name')
    return raw_tf_record.RawTfRecordTFXIO(
        file_pattern=file_pattern,
        raw_record_column_name=raw_record_column_name,
        telemetry_descriptors=telemetry_descriptors)

  if payload_format == example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE:
    return tf_example_record.TFExampleRecord(
        file_pattern=file_pattern,
        schema=schema,
        raw_record_column_name=raw_record_column_name,
        telemetry_descriptors=telemetry_descriptors)

  if (payload_format ==
      example_gen_pb2.PayloadFormat.FORMAT_TF_SEQUENCE_EXAMPLE):
    return tf_sequence_example_record.TFSequenceExampleRecord(
        file_pattern=file_pattern,
        schema=schema,
        raw_record_column_name=raw_record_column_name,
        telemetry_descriptors=telemetry_descriptors)

  if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PROTO:
    assert data_view_uri is not None, (
        'Accessing FORMAT_PROTO requires a DataView to parse the proto.')
    return record_to_tensor_tfxio.TFRecordToTensorTFXIO(
        file_pattern=file_pattern,
        saved_decoder_path=data_view_uri,
        telemetry_descriptors=telemetry_descriptors,
        raw_record_column_name=raw_record_column_name)

  raise NotImplementedError(
      'Unsupport payload format: {}'.format(payload_format))
def _MakeTFXIO(self, schema, raw_record_column_name=None):
  """Builds a TFSequenceExampleRecord TFXIO over `self._example_file`.

  Args:
    schema: schema handed to the TFXIO.
    raw_record_column_name: optional raw record column name handed to the
      TFXIO; defaults to None.

  Returns:
    The constructed TFSequenceExampleRecord instance.
  """
  record_cls = tf_sequence_example_record.TFSequenceExampleRecord
  return record_cls(
      self._example_file,
      telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
      raw_record_column_name=raw_record_column_name,
      schema=schema)