def test_invalid_raw_record_column_name(self):
  input_lines = ['1,2.0,hello', '5,12.34']
  schema = text_format.Parse(
      """
      feature { name: "int_feature" type: INT }
      feature { name: "float_feature" type: FLOAT }
      feature { name: "str_feature" type: BYTES }
      """, schema_pb2.Schema())
  column_names = ['int_feature', 'float_feature', 'str_feature']
  with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
      ValueError, 'raw_record_column_name.* is already an existing column.*'):
    with beam.Pipeline() as p:
      result = (
          p
          | beam.Create(input_lines, reshuffle=False)
          | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
              column_names=column_names,
              desired_batch_size=1000,
              raw_record_column_name='int_feature'))
      beam_test_util.assert_that(result, lambda _: None)
  with self.assertRaisesRegex(
      ValueError, 'raw_record_column_name.* is already an existing column.*'):
    csv_decoder.GetArrowSchema(
        column_names, schema, raw_record_column_name='int_feature')
def test_csv_to_recordbatch_schema_features_subset_of_column_names(self):
  input_lines = ['1,2.0,hello', '5,12.34,world']
  column_names = ['int_feature', 'float_feature', 'str_feature']
  schema = text_format.Parse(
      """feature { name: "int_feature" type: INT }""", schema_pb2.Schema())
  self.assertEqual(
      csv_decoder.GetArrowSchema(column_names, schema),
      pa.schema([pa.field('int_feature', pa.large_list(pa.int64()))]))

  def _check_record_batches(record_batches):
    self.assertLen(record_batches, 1)
    self.assertTrue(record_batches[0].equals(
        pa.RecordBatch.from_arrays(
            [pa.array([[1], [5]], pa.large_list(pa.int64()))],
            ['int_feature'])))

  with beam.Pipeline() as p:
    record_batches = (
        p
        | 'CreatingPColl' >> beam.Create(input_lines, reshuffle=False)
        | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
            column_names=column_names,
            delimiter=',',
            desired_batch_size=1000,
            schema=schema))
    beam_test_util.assert_that(
        record_batches, _check_record_batches, label='check_record_batches')
def expand(self, lines: beam.pvalue.PCollection):
  """Decodes the input CSV records into RecordBatches.

  Args:
    lines: A PCollection of strings representing the lines in the CSV file.

  Returns:
    A PCollection of RecordBatches representing the CSV records.
  """
  return (lines
          | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
              column_names=self._column_names,
              delimiter=self._delimiter,
              skip_blank_lines=self._skip_blank_lines,
              schema=self._schema,
              desired_batch_size=self._desired_batch_size,
              multivalent_columns=self._multivalent_columns,
              secondary_delimiter=self._secondary_delimiter))
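# A minimal usage sketch of the expand() above. Assumptions: the enclosing
# class is a beam.PTransform (called DecodeCSV here, a hypothetical name)
# whose constructor populates the self._* attributes, and 'data.csv' and the
# column names are illustrative.
with beam.Pipeline() as p:
  record_batches = (
      p
      | 'ReadLines' >> beam.io.ReadFromText('data.csv')
      | 'DecodeCSV' >> DecodeCSV(column_names=['int_feature', 'str_feature']))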
def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
  """Returns RecordBatches decoded from the CSV lines."""
  # Decode raw CSV lines into RecordBatches.
  record_batches = (
      raw_records_pcoll
      | "CSVToRecordBatch" >> csv_decoder.CSVToRecordBatch(
          column_names=self._column_names,
          delimiter=self._delimiter,
          skip_blank_lines=self._skip_blank_lines,
          schema=self._schema,
          desired_batch_size=batch_size,
          multivalent_columns=self._multivalent_columns,
          secondary_delimiter=self._secondary_delimiter,
          produce_large_types=self._can_produce_large_types,
          raw_record_column_name=self._raw_record_column_name))
  return record_batches
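# A hedged sketch (not from the original file) of how a nested function like
# _PTransformFn is commonly exposed as a composite transform with Beam's
# ptransform_fn decorator; the function name and fixed options below are
# illustrative, not the library's actual wiring.
@beam.ptransform_fn
def _CSVLinesToRecordBatch(raw_records_pcoll, column_names, batch_size):
  """Decodes CSV lines into RecordBatches with default options."""
  return (raw_records_pcoll
          | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
              column_names=column_names,
              delimiter=',',
              skip_blank_lines=True,
              desired_batch_size=batch_size))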
def test_invalid_schema_type(self):
  input_lines = ['1']
  column_names = ['f1']
  schema = text_format.Parse(
      """
      feature { name: "struct_feature" type: STRUCT }
      """, schema_pb2.Schema())
  with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
      ValueError, '.*Schema contains invalid type: STRUCT.*'):
    with beam.Pipeline() as p:
      result = (
          p
          | beam.Create(input_lines, reshuffle=False)
          | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
              column_names=column_names,
              schema=schema,
              desired_batch_size=1000))
      beam_test_util.assert_that(result, lambda _: None)
def test_invalid_schema_missing_column(self):
  input_lines = ['1,2']
  column_names = ['f1', 'f2']
  schema = text_format.Parse(
      """
      feature {
        name: "f1"
        type: INT
        value_count { min: 0 max: 2 }
      }
      """, schema_pb2.Schema())
  with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
      ValueError, '.*Schema does not contain column.*'):
    with beam.Pipeline() as p:
      result = (
          p
          | beam.Create(input_lines, reshuffle=False)
          | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
              column_names=column_names,
              schema=schema,
              desired_batch_size=1000))
      beam_test_util.assert_that(result, lambda _: None)
def test_parse_csv_lines(self,
                         input_lines,
                         column_names,
                         expected_csv_cells,
                         expected_types,
                         expected_record_batch,
                         skip_blank_lines=False,
                         schema=None,
                         delimiter=',',
                         multivalent_columns=None,
                         secondary_delimiter=None,
                         raw_record_column_name=None):

  def _check_csv_cells(actual):
    for i in range(len(actual)):
      self.assertEqual(expected_csv_cells[i], actual[i][0])
      self.assertEqual(input_lines[i], actual[i][1])

  def _check_types(actual):
    self.assertLen(actual, 1)
    self.assertCountEqual([
        csv_decoder.ColumnInfo(n, t)
        for n, t in zip(column_names, expected_types)
    ], actual[0])

  def _check_record_batches(actual):
    """Compares a list of pa.RecordBatch."""
    if actual:
      self.assertTrue(actual[0].equals(expected_record_batch))
    else:
      self.assertEqual(expected_record_batch, actual)

  def _check_arrow_schema(actual):
    for record_batch in actual:
      expected_arrow_schema = csv_decoder.GetArrowSchema(
          column_names, schema, raw_record_column_name)
      self.assertEqual(record_batch.schema, expected_arrow_schema)

  with beam.Pipeline() as p:
    parsed_csv_cells_and_raw_records = (
        p
        | beam.Create(input_lines, reshuffle=False)
        | beam.ParDo(csv_decoder.ParseCSVLine(delimiter=delimiter)))
    inferred_types = (
        parsed_csv_cells_and_raw_records
        | beam.Keys()
        | beam.CombineGlobally(
            csv_decoder.ColumnTypeInferrer(
                column_names,
                skip_blank_lines=skip_blank_lines,
                multivalent_columns=multivalent_columns,
                secondary_delimiter=secondary_delimiter)))

    beam_test_util.assert_that(
        parsed_csv_cells_and_raw_records,
        _check_csv_cells,
        label='check_parsed_csv_cells')
    beam_test_util.assert_that(
        inferred_types, _check_types, label='check_types')

    record_batches = (
        parsed_csv_cells_and_raw_records
        | beam.BatchElements(min_batch_size=1000)
        | beam.ParDo(
            csv_decoder.BatchedCSVRowsToRecordBatch(
                skip_blank_lines=skip_blank_lines,
                multivalent_columns=multivalent_columns,
                secondary_delimiter=secondary_delimiter,
                raw_record_column_name=raw_record_column_name),
            beam.pvalue.AsSingleton(inferred_types)))
    beam_test_util.assert_that(
        record_batches, _check_record_batches, label='check_record_batches')
    if schema:
      beam_test_util.assert_that(
          record_batches, _check_arrow_schema, label='check_arrow_schema')

  # Testing CSVToRecordBatch.
  with beam.Pipeline() as p:
    record_batches = (
        p
        | 'CreatingPColl' >> beam.Create(input_lines, reshuffle=False)
        | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
            column_names=column_names,
            delimiter=delimiter,
            skip_blank_lines=skip_blank_lines,
            desired_batch_size=1000,
            schema=schema,
            multivalent_columns=multivalent_columns,
            secondary_delimiter=secondary_delimiter,
            raw_record_column_name=raw_record_column_name))
    beam_test_util.assert_that(
        record_batches, _check_record_batches, label='check_record_batches')
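# A minimal standalone sketch of the three-stage decomposition exercised by
# the test above: parse lines into cells, infer column types globally, then
# convert batched rows to Arrow RecordBatches with the inferred types as a
# singleton side input. The input lines and column names are illustrative.
with beam.Pipeline() as p:
  cells = (
      p
      | beam.Create(['1,a', '2,b'], reshuffle=False)
      | beam.ParDo(csv_decoder.ParseCSVLine(delimiter=',')))
  types = (
      cells
      | beam.Keys()
      | beam.CombineGlobally(
          csv_decoder.ColumnTypeInferrer(
              ['int_col', 'str_col'],
              skip_blank_lines=True,
              multivalent_columns=None,
              secondary_delimiter=None)))
  _ = (
      cells
      | beam.BatchElements(min_batch_size=1000)
      | beam.ParDo(
          csv_decoder.BatchedCSVRowsToRecordBatch(skip_blank_lines=True),
          beam.pvalue.AsSingleton(types)))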