Example 1
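    # The TFMD schema names only 'int_feature', a proper subset of the three
    # CSV columns; both GetArrowSchema and CSVToRecordBatch should emit just
    # that one column and drop the rest.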
    def test_csv_to_recordbatch_schema_features_subset_of_column_names(self):
        input_lines = ['1,2.0,hello', '5,12.34,world']
        column_names = ['int_feature', 'float_feature', 'str_feature']
        schema = text_format.Parse(
            """feature { name: "int_feature" type: INT }""",
            schema_pb2.Schema())
        self.assertEqual(
            csv_decoder.GetArrowSchema(column_names, schema),
            pa.schema([pa.field('int_feature', pa.large_list(pa.int64()))]))

        def _check_record_batches(record_batches):
            self.assertLen(record_batches, 1)
            self.assertTrue(record_batches[0].equals(
                pa.RecordBatch.from_arrays(
                    [pa.array([[1], [5]], pa.large_list(pa.int64()))],
                    ['int_feature'])))

        with beam.Pipeline() as p:
            record_batches = (
                p
                | 'CreatingPColl' >> beam.Create(input_lines, reshuffle=False)
                | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
                    column_names=column_names,
                    delimiter=',',
                    desired_batch_size=1000,
                    schema=schema))
            beam_test_util.assert_that(record_batches,
                                       _check_record_batches,
                                       label='check_record_batches')
Example 2
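# Reusing an existing CSV column name as raw_record_column_name must be
# rejected, both by the CSVToRecordBatch PTransform and by GetArrowSchema.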
def test_invalid_raw_record_column_name(self):
    input_lines = ['1,2.0,hello', '5,12.34']
    schema = text_format.Parse(
        """
        feature {
          name: "int_feature"
          type: INT
        }
        feature {
          name: "float_feature"
          type: FLOAT
        }
        feature {
          name: "str_feature"
          type: BYTES
        }
        """, schema_pb2.Schema())
    column_names = ['int_feature', 'float_feature', 'str_feature']
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
            ValueError,
            'raw_record_column_name.* is already an existing column.*'):
        with beam.Pipeline() as p:
            result = (p | beam.Create(input_lines, reshuffle=False)
                      | 'CSVToRecordBatch' >> csv_decoder.CSVToRecordBatch(
                          column_names=column_names,
                          desired_batch_size=1000,
                          raw_record_column_name='int_feature'))
            beam_test_util.assert_that(result, lambda _: None)
    with self.assertRaisesRegex(
            ValueError,
            'raw_record_column_name.* is already an existing column.*'):
        csv_decoder.GetArrowSchema(column_names,
                                   schema,
                                   raw_record_column_name='int_feature')
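For contrast, a raw-record column name that does not collide with an existing
column should be accepted; a minimal sketch (the name 'raw_record' is an
arbitrary, non-colliding choice):

    arrow_schema = csv_decoder.GetArrowSchema(
        column_names, schema, raw_record_column_name='raw_record')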
Example 3
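# Helper that derives the Arrow schema from the TFMD schema, forwarding the
# large-types option; a missing schema is reported as an error.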
def _ArrowSchemaNoRawRecordColumn(self) -> pa.Schema:
    if not self._schema:
        raise ValueError("TFMD schema not provided. Unable to derive an "
                         "Arrow schema")
    return csv_decoder.GetArrowSchema(
        self._column_names,
        self._schema,
        large_types=self._can_produce_large_types)
Example 4
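    # Returns the Arrow schema, inferring one when no TFMD schema was
    # provided; otherwise derives it from the schema's features.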
    def ArrowSchema(self) -> pa.Schema:
        if self._schema is None:
            return self._InferArrowSchema()

        # If the column names are not passed, we default to all column names in the
        # schema.
        columns = self._column_names or [f.name for f in self._schema.feature]

        return csv_decoder.GetArrowSchema(columns, self._schema)
Example 5
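# Schema features that are missing from column_names ('f2' here) make
# GetArrowSchema raise.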
def test_get_arrow_schema_schema_feature_not_subset_of_column_names(self):
  schema = text_format.Parse(
      """
      feature {
        name: "f1"
        type: INT
      }
      feature {
        name: "f2"
        type: INT
      }
      """, schema_pb2.Schema())
  column_names = ['f1']
  with self.assertRaisesRegex(
      ValueError, 'Schema features are not a subset of column names'):
    csv_decoder.GetArrowSchema(column_names, schema)
Example 6
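# The same validation with value_count constraints on the features:
# column_names listing only 'f1' cannot satisfy a schema that also
# declares 'f2'.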
def test_get_arrow_schema_column_names_invalid(self):
    schema = text_format.Parse(
        """
        feature {
          name: "f1"
          type: INT
          value_count {
            min: 0
            max: 2
          }
        }
        feature {
          name: "f2"
          type: INT
          value_count {
            min: 0
            max: 2
          }
        }
        """, schema_pb2.Schema())
    column_names = ['f1']
    with self.assertRaisesRegex(
            ValueError, 'Column Names.* does not match schema features.*'):
        csv_decoder.GetArrowSchema(column_names, schema)
Example 7
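# Beam assert_that callback: every emitted RecordBatch must carry the Arrow
# schema derived from the TFMD schema plus the raw-record column.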
def _check_arrow_schema(actual):
    for record_batch in actual:
        expected_arrow_schema = csv_decoder.GetArrowSchema(
            column_names, schema, raw_record_column_name)
        self.assertEqual(record_batch.schema, expected_arrow_schema)
Example 8
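# Variant of the callback above that also forwards produce_large_types and
# only inspects the first batch.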
def _check_arrow_schema(actual):
    if actual:
        expected_arrow_schema = csv_decoder.GetArrowSchema(
            column_names, schema, raw_record_column_name,
            produce_large_types)
        self.assertEqual(actual[0].schema, expected_arrow_schema)
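All of the snippets above are test methods or nested helpers and omit their
imports. A minimal set consistent with the identifiers they use (the module
paths and aliases here are assumptions inferred from the code, not taken from
the snippets themselves):

    import apache_beam as beam
    import pyarrow as pa
    from apache_beam.testing import util as beam_test_util
    from google.protobuf import text_format
    from tensorflow_metadata.proto.v0 import schema_pb2
    from tfx_bsl.coders import csv_decoder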