def test_csv_to_recordbatch_schema_features_subset_of_column_names(self):
  input_lines = ['1,2.0,hello', '5,12.34,world']
  column_names = ['int_feature', 'float_feature', 'str_feature']
  schema = text_format.Parse(
      """feature { name: "int_feature" type: INT }""",
      schema_pb2.Schema())
  self.assertEqual(
      csv_decoder.GetArrowSchema(column_names, schema),
      pa.schema([pa.field('int_feature', pa.large_list(pa.int64()))]))

  def _check_record_batches(record_batches):
    self.assertLen(record_batches, 1)
    self.assertTrue(record_batches[0].equals(
        pa.RecordBatch.from_arrays(
            [pa.array([[1], [5]], pa.large_list(pa.int64()))],
            ['int_feature'])))

  with beam.Pipeline() as p:
    record_batches = (
        p
        | 'CreatingPColl' >> beam.Create(input_lines, reshuffle=False)
        | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
            column_names=column_names,
            delimiter=',',
            desired_batch_size=1000,
            schema=schema))
    beam_test_util.assert_that(
        record_batches, _check_record_batches, label='check_record_batches')
def test_invalid_raw_record_column_name(self):
  input_lines = ['1,2.0,hello', '5,12.34']
  schema = text_format.Parse(
      """
      feature { name: "int_feature" type: INT }
      feature { name: "float_feature" type: FLOAT }
      feature { name: "str_feature" type: BYTES }
      """, schema_pb2.Schema())
  column_names = ['int_feature', 'float_feature', 'str_feature']
  with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
      ValueError,
      'raw_record_column_name.* is already an existing column.*'):
    with beam.Pipeline() as p:
      result = (
          p
          | beam.Create(input_lines, reshuffle=False)
          | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
              column_names=column_names,
              desired_batch_size=1000,
              raw_record_column_name='int_feature'))
      beam_test_util.assert_that(result, lambda _: None)

  with self.assertRaisesRegex(
      ValueError,
      'raw_record_column_name.* is already an existing column.*'):
    csv_decoder.GetArrowSchema(
        column_names, schema, raw_record_column_name='int_feature')
def _ArrowSchemaNoRawRecordColumn(self) -> pa.Schema:
  """Returns the Arrow schema derived from the TFMD schema, without the raw-record column."""
  if not self._schema:
    raise ValueError("TFMD schema not provided. Unable to derive an "
                     "Arrow schema")
  return csv_decoder.GetArrowSchema(
      self._column_names, self._schema,
      large_types=self._can_produce_large_types)
def ArrowSchema(self) -> pa.Schema:
  """Returns the Arrow schema, inferring it if no TFMD schema was provided."""
  if self._schema is None:
    return self._InferArrowSchema()
  # If the column names are not passed, we default to all column names in the
  # schema.
  columns = self._column_names or [f.name for f in self._schema.feature]
  return csv_decoder.GetArrowSchema(columns, self._schema)
def test_get_arrow_schema_schema_feature_not_subset_of_column_names(self):
  schema = text_format.Parse(
      """
      feature { name: "f1" type: INT }
      feature { name: "f2" type: INT }
      """, schema_pb2.Schema())
  column_names = ['f1']
  with self.assertRaisesRegex(
      ValueError, 'Schema features are not a subset of column names'):
    csv_decoder.GetArrowSchema(column_names, schema)
def test_get_arrow_schema_column_names_invalid(self):
  schema = text_format.Parse(
      """
      feature {
        name: "f1"
        type: INT
        value_count { min: 0 max: 2 }
      }
      feature {
        name: "f2"
        type: INT
        value_count { min: 0 max: 2 }
      }
      """, schema_pb2.Schema())
  column_names = ['f1']
  with self.assertRaisesRegex(
      ValueError, 'Column Names.* does not match schema features.*'):
    csv_decoder.GetArrowSchema(column_names, schema)
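# Checks that every RecordBatch in `actual` carries the Arrow schema derived
# from the column names, TFMD schema, and raw-record column name.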
def _check_arrow_schema(actual):
  for record_batch in actual:
    expected_arrow_schema = csv_decoder.GetArrowSchema(
        column_names, schema, raw_record_column_name)
    self.assertEqual(record_batch.schema, expected_arrow_schema)
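# Variant of the check above that also passes `produce_large_types` and only
# inspects the first RecordBatch.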
def _check_arrow_schema(actual):
  if actual:
    expected_arrow_schema = csv_decoder.GetArrowSchema(
        column_names, schema, raw_record_column_name, produce_large_types)
    self.assertEqual(actual[0].schema, expected_arrow_schema)