示例#1
0
    def test_query_with_one_field_not_supported(self, sql, expected_output,
                                                error):
        """Checks `sql` against a batch mixing a nested-list and an int column.

        When `error` is truthy, constructing the query must raise a
        RuntimeError matching the "unsupported column" message and nothing is
        executed. Otherwise the query runs and its slices must equal
        `expected_output`.
        """
        nested_list_col = pa.array(
            [[[10, 100]], [[20, 200]], None, [[30, 300]]],
            type=pa.list_(pa.list_(pa.int64())))
        int_col = pa.array([1, 2, None, 3], type=pa.int32())
        record_batch = pa.RecordBatch.from_arrays(
            [nested_list_col, int_col], ['f1', 'f2'])
        schema = record_batch.schema

        if error:
            # Only query construction is exercised in the failure case.
            with self.assertRaisesRegex(
                    RuntimeError, 'Are you querying any unsupported column?'):
                _ = sql_util.RecordBatchSQLSliceQuery(sql, schema)
            return

        query = sql_util.RecordBatchSQLSliceQuery(sql, schema)
        self.assertEqual(query.Execute(record_batch), expected_output)
 def _GenerateQueries(
     schema: pa.Schema) -> List[sql_util.RecordBatchSQLSliceQuery]:
   """Compiles every SQL string in `self._sqls` against `schema`.

   `self` is not a parameter here; presumably it is captured from an
   enclosing method's scope — confirm against the surrounding code.

   Returns:
     The compiled queries, in the same order as `self._sqls`.

   Raises:
     RuntimeError: if any query fails to compile. The failing SQL text is
       included in the message and the original error is chained.
   """
   compiled_queries = []
   for query_text in self._sqls:
     try:
       compiled = sql_util.RecordBatchSQLSliceQuery(query_text, schema)
     except Exception as e:
       raise RuntimeError(f'Failed to parse sql:\n\n{query_text}') from e
     compiled_queries.append(compiled)
   return compiled_queries
def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
  """Checks that `sql_query` compiles as a slice query against `schema`.

  The schema proto is serialized and turned into an Arrow schema via
  `example_coder.ExamplesToRecordBatchDecoder(...).ArrowSchema()`. The query
  is passed through `slicing_util.format_slice_sql_query`, and the result is
  compiled with `sql_util.RecordBatchSQLSliceQuery`.

  Args:
    sql_query: The slice SQL query to validate.
    schema: The schema_pb2.Schema the query is validated against.

  Raises:
    ValueError: if compiling the formatted query raises any exception. The
      original exception is chained as `__cause__`. Errors from the schema
      conversion or query formatting steps (outside the `try`) propagate
      unchanged.
  """
  arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
      schema.SerializeToString()).ArrowSchema()
  formatted_query = slicing_util.format_slice_sql_query(sql_query)
  try:
    sql_util.RecordBatchSQLSliceQuery(formatted_query, arrow_schema)
  except Exception as e:  # pylint: disable=broad-except
    # Chain explicitly so the underlying compile error is reported as the
    # direct cause rather than as an incidental exception during handling.
    raise ValueError('One of the slice SQL query %s raised an exception: %s.'
                     % (sql_query, repr(e))) from e
示例#4
0
    def test_query_with_all_fields_not_supported(self, sql, expected_output):
        """Runs `sql` over a batch whose only column is a list of int64 lists.

        The slices produced by executing the query must equal
        `expected_output`.
        """
        f1 = pa.array([[[10, 100]], [[20, 200]], None, [[30, 300]]],
                      type=pa.list_(pa.list_(pa.int64())))
        record_batch = pa.RecordBatch.from_arrays([f1], ['f1'])

        query = sql_util.RecordBatchSQLSliceQuery(sql, record_batch.schema)
        self.assertEqual(query.Execute(record_batch), expected_output)
示例#5
0
    def test_query_primitive_arrays(self, sql, expected_output):
        """Runs `sql` over two nullable, non-list integer columns."""
        arrays = [
            pa.array([1, 2, None, 3], type=pa.int64()),
            pa.array([10, 20, None, 30], type=pa.int32()),
        ]
        names = ['f1', 'f2']
        record_batch = pa.RecordBatch.from_arrays(arrays, names)

        slices = sql_util.RecordBatchSQLSliceQuery(
            sql, record_batch.schema).Execute(record_batch)
        self.assertEqual(slices, expected_output)
示例#6
0
 def test_query_with_all_supported_types(self):
     """Slices on one single-value list column per exercised value type.

     The two-row batch has list columns of int32, int64, float32, float64,
     string, large_string, binary and large_binary. Each row must yield one
     slice key whose values are rendered as strings.
     """
     named_columns = [
         ('int32_list', pa.array([[1], [2]], type=pa.list_(pa.int32()))),
         ('int64_list', pa.array([[10], [20]], type=pa.list_(pa.int64()))),
         ('float32_list',
          pa.array([[1.1], [2.2]], type=pa.list_(pa.float32()))),
         ('float64_list',
          pa.array([[10.1], [20.2]], type=pa.list_(pa.float64()))),
         ('string_list',
          pa.array([['a'], ['b']], type=pa.list_(pa.string()))),
         ('large_string_list',
          pa.array([['a+'], ['b+']], type=pa.list_(pa.large_string()))),
         ('binary_list',
          pa.array([[b'a_bytes'], [b'b_bytes']], type=pa.list_(pa.binary()))),
         ('large_binary_list',
          pa.array([[b'a_bytes+'], [b'b_bytes+']],
                   type=pa.list_(pa.large_binary()))),
     ]
     record_batch = pa.RecordBatch.from_arrays(
         [arr for _, arr in named_columns],
         [name for name, _ in named_columns])
     sql = """
   SELECT
     ARRAY(
       SELECT
         STRUCT(int32_list, int64_list,
           float32_list, float64_list,
           string_list, large_string_list,
           binary_list, large_binary_list)
       FROM
         example.int32_list,
         example.int64_list,
         example.float32_list,
         example.float64_list,
         example.string_list,
         example.large_string_list,
         example.binary_list,
         example.large_binary_list
     ) as slice_key
   FROM Examples as example;"""
     expected_slices = [
         [[('int32_list', '1'), ('int64_list', '10'),
           ('float32_list', '1.1'), ('float64_list', '10.1'),
           ('string_list', 'a'), ('large_string_list', 'a+'),
           ('binary_list', 'a_bytes'), ('large_binary_list', 'a_bytes+')]],
         [[('int32_list', '2'), ('int64_list', '20'),
           ('float32_list', '2.2'), ('float64_list', '20.2'),
           ('string_list', 'b'), ('large_string_list', 'b+'),
           ('binary_list', 'b_bytes'), ('large_binary_list', 'b_bytes+')]],
     ]
     slices = sql_util.RecordBatchSQLSliceQuery(
         sql, record_batch.schema).Execute(record_batch)
     self.assertEqual(slices, expected_slices)
示例#7
0
    def test_query_with_empty_input(self):
        """Executing a query over an empty record batch yields no slices."""
        empty_batch = pa.RecordBatch.from_arrays(
            [pa.array([], type=pa.int64())], ['f1'])
        sql = """SELECT ARRAY(SELECT STRUCT(f1)) as slice_key
          FROM Examples as example;"""

        query = sql_util.RecordBatchSQLSliceQuery(sql, empty_batch.schema)
        self.assertEqual(query.Execute(empty_batch), [])
示例#8
0
    def test_query_with_invalid_statement(self, sql, error):
        """Constructing a query from invalid `sql` raises a RuntimeError.

        `error` is the regex that the RuntimeError message must match.
        """
        f1 = pa.array([[1, 2, 3], [4], None, [5], [], [6], [None], [7]],
                      type=pa.list_(pa.int64()))
        f2 = pa.array([[10, 20, 30], [40], None, None, [], [], [None], [None]],
                      type=pa.list_(pa.int32()))
        schema = pa.RecordBatch.from_arrays([f1, f2], ['f1', 'f2']).schema

        with self.assertRaisesRegex(RuntimeError, error):
            sql_util.RecordBatchSQLSliceQuery(sql, schema)
示例#9
0
    def test_query_list_arrays(self, sql, expected_output):
        """Runs `sql` over nullable list<int64> and list<int32> columns."""
        # Rows include populated, null, empty, and null-element lists.
        int64_lists = pa.array(
            [[1, 2, 3], [4], None, [5], [], [6], [None], [7]],
            type=pa.list_(pa.int64()))
        int32_lists = pa.array(
            [[10, 20, 30], [40], None, None, [], [], [None], [None]],
            type=pa.list_(pa.int32()))
        record_batch = pa.RecordBatch.from_arrays(
            [int64_lists, int32_lists], ['f1', 'f2'])

        query = sql_util.RecordBatchSQLSliceQuery(sql, record_batch.schema)
        self.assertEqual(query.Execute(record_batch), expected_output)
示例#10
0
    def test_query_with_unexpected_record_batch_schema(self):
        """A compiled query rejects a batch whose schema differs.

        The query is compiled for an int64 `f1` column and then executed on
        a batch whose `f1` column is int32.
        """
        sql = """SELECT ARRAY(SELECT STRUCT(f1)) as slice_key
          FROM Examples as example;"""
        int64_batch = pa.RecordBatch.from_arrays(
            [pa.array([1, 2, 3], type=pa.int64())], ['f1'])
        int32_batch = pa.RecordBatch.from_arrays(
            [pa.array([4, 5, 6], type=pa.int32())], ['f1'])

        query = sql_util.RecordBatchSQLSliceQuery(sql, int64_batch.schema)
        with self.assertRaisesRegex(RuntimeError,
                                    'Unexpected RecordBatch schema.'):
            query.Execute(int32_batch)
示例#11
0
 def _generate_queries(
         schema: pa.Schema) -> List[sql_util.RecordBatchSQLSliceQuery]:
     """Compiles each SQL string in `self._sqls` for the given Arrow schema.

     `self` is not a parameter here; presumably it is captured from an
     enclosing method's scope — confirm against the surrounding code.

     Args:
       schema: Arrow schema the queries are compiled against.

     Returns:
       One compiled query per entry of `self._sqls`, in the same order.

     Raises:
       RuntimeError: if any SQL string fails to compile. The offending SQL
         is included in the message and the original error is chained.
     """
     queries = []
     for sql in self._sqls:
         try:
             queries.append(sql_util.RecordBatchSQLSliceQuery(sql, schema))
         except Exception as e:  # pylint: disable=broad-except
             # Name the failing query: with several configured SQL strings,
             # the bare compile error does not say which one is broken.
             raise RuntimeError(f'Failed to parse sql:\n\n{sql}') from e
     return queries