def test_query_with_one_field_not_supported(self, sql, expected_output, error):
  """Checks a query over a batch mixing a nested-list column and an int32 column.

  When `error` is truthy, building the query must raise RuntimeError;
  otherwise the query is executed and its slices compared with
  `expected_output`.
  """
  nested_int64 = pa.list_(pa.list_(pa.int64()))
  f1 = pa.array([[[10, 100]], [[20, 200]], None, [[30, 300]]],
                type=nested_int64)
  f2 = pa.array([1, 2, None, 3], type=pa.int32())
  batch = pa.RecordBatch.from_arrays([f1, f2], ['f1', 'f2'])

  if error:
    # NOTE: the pattern is a regular expression, so the trailing '?' makes the
    # final 'n' optional instead of matching a literal question mark.
    with self.assertRaisesRegex(
        RuntimeError, 'Are you querying any unsupported column?'):
      sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
    return

  slice_query = sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
  self.assertEqual(slice_query.Execute(batch), expected_output)
def _GenerateQueries(
    schema: pa.Schema) -> List[sql_util.RecordBatchSQLSliceQuery]:
  """Compiles every statement in `self._sqls` against `schema`.

  `self` is not a parameter; it is taken from the enclosing scope.

  Args:
    schema: Arrow schema handed to each RecordBatchSQLSliceQuery.

  Returns:
    One RecordBatchSQLSliceQuery per statement, in `self._sqls` order.

  Raises:
    RuntimeError: if constructing a query fails; the original exception is
      attached as the cause.
  """
  queries = []
  for sql in self._sqls:
    try:
      compiled = sql_util.RecordBatchSQLSliceQuery(sql, schema)
    except Exception as e:
      raise RuntimeError(f'Failed to parse sql:\n\n{sql}') from e
    else:
      queries.append(compiled)
  return queries
def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
  """Validates a slice SQL query by compiling it against `schema`.

  The schema proto is serialized and passed to
  `example_coder.ExamplesToRecordBatchDecoder` to obtain an Arrow schema. The
  query is formatted with `slicing_util.format_slice_sql_query` and used to
  construct a `sql_util.RecordBatchSQLSliceQuery`. The constructed query is
  discarded; only whether construction succeeds matters.

  Args:
    sql_query: The slice SQL query to validate.
    schema: Schema proto from which the Arrow schema is derived.

  Raises:
    ValueError: if constructing the query raises any exception. The original
      exception is attached as `__cause__`.
  """
  arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
      schema.SerializeToString()).ArrowSchema()
  formatted_query = slicing_util.format_slice_sql_query(sql_query)
  try:
    sql_util.RecordBatchSQLSliceQuery(formatted_query, arrow_schema)
  except Exception as e:  # pylint: disable=broad-except
    # Chain explicitly so the traceback marks the underlying failure as the
    # direct cause of the ValueError.
    raise ValueError('One of the slice SQL query %s raised an exception: %s.' %
                     (sql_query, repr(e))) from e
def test_query_with_all_fields_not_supported(self, sql, expected_output):
  """Runs `sql` over a batch whose single column is a list of lists of int64."""
  nested_int64 = pa.list_(pa.list_(pa.int64()))
  f1 = pa.array([[[10, 100]], [[20, 200]], None, [[30, 300]]],
                type=nested_int64)
  batch = pa.RecordBatch.from_arrays([f1], ['f1'])

  slice_query = sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
  self.assertEqual(slice_query.Execute(batch), expected_output)
def test_query_primitive_arrays(self, sql, expected_output):
  """Runs `sql` over a batch of two primitive (non-list) integer columns."""
  columns = [
      pa.array([1, 2, None, 3], type=pa.int64()),
      pa.array([10, 20, None, 30], type=pa.int32()),
  ]
  batch = pa.RecordBatch.from_arrays(columns, ['f1', 'f2'])

  slice_query = sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
  self.assertEqual(slice_query.Execute(batch), expected_output)
def test_query_with_all_supported_types(self):
  """Queries one list column per element type and checks the stringified slices.

  The batch has two rows and eight list-typed columns (int32, int64, float32,
  float64, string, large_string, binary, large_binary).
  """
  columns = [
      pa.array([[1], [2]], type=pa.list_(pa.int32())),
      pa.array([[10], [20]], type=pa.list_(pa.int64())),
      pa.array([[1.1], [2.2]], type=pa.list_(pa.float32())),
      pa.array([[10.1], [20.2]], type=pa.list_(pa.float64())),
      pa.array([['a'], ['b']], type=pa.list_(pa.string())),
      pa.array([['a+'], ['b+']], type=pa.list_(pa.large_string())),
      pa.array([[b'a_bytes'], [b'b_bytes']], type=pa.list_(pa.binary())),
      pa.array([[b'a_bytes+'], [b'b_bytes+']],
               type=pa.list_(pa.large_binary())),
  ]
  column_names = [
      'int32_list',
      'int64_list',
      'float32_list',
      'float64_list',
      'string_list',
      'large_string_list',
      'binary_list',
      'large_binary_list',
  ]
  batch = pa.RecordBatch.from_arrays(columns, column_names)

  # Adjacent literals concatenate to the exact statement text.
  sql = (
      ' SELECT ARRAY( SELECT STRUCT(int32_list, int64_list, float32_list,'
      ' float64_list, string_list, large_string_list, binary_list,'
      ' large_binary_list) FROM example.int32_list, example.int64_list,'
      ' example.float32_list, example.float64_list, example.string_list,'
      ' example.large_string_list, example.binary_list,'
      ' example.large_binary_list ) as slice_key FROM Examples as example;'
  )

  first_row_slice = [
      ('int32_list', '1'),
      ('int64_list', '10'),
      ('float32_list', '1.1'),
      ('float64_list', '10.1'),
      ('string_list', 'a'),
      ('large_string_list', 'a+'),
      ('binary_list', 'a_bytes'),
      ('large_binary_list', 'a_bytes+'),
  ]
  second_row_slice = [
      ('int32_list', '2'),
      ('int64_list', '20'),
      ('float32_list', '2.2'),
      ('float64_list', '20.2'),
      ('string_list', 'b'),
      ('large_string_list', 'b+'),
      ('binary_list', 'b_bytes'),
      ('large_binary_list', 'b_bytes+'),
  ]

  slice_query = sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
  actual = slice_query.Execute(batch)
  self.assertEqual(actual, [[first_row_slice], [second_row_slice]])
def test_query_with_empty_input(self):
  """Executing a query on a zero-row batch returns an empty list."""
  empty_f1 = pa.array([], type=pa.int64())
  batch = pa.RecordBatch.from_arrays([empty_f1], ['f1'])
  sql = 'SELECT ARRAY(SELECT STRUCT(f1)) as slice_key FROM Examples as example;'

  slice_query = sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
  self.assertEqual(slice_query.Execute(batch), [])
def test_query_with_invalid_statement(self, sql, error):
  """Constructing a query from `sql` must raise RuntimeError matching `error`."""
  f1 = pa.array([[1, 2, 3], [4], None, [5], [], [6], [None], [7]],
                type=pa.list_(pa.int64()))
  f2 = pa.array([[10, 20, 30], [40], None, None, [], [], [None], [None]],
                type=pa.list_(pa.int32()))
  batch = pa.RecordBatch.from_arrays([f1, f2], ['f1', 'f2'])

  with self.assertRaisesRegex(RuntimeError, error):
    sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
def test_query_list_arrays(self, sql, expected_output):
  """Runs `sql` over list<int64> and list<int32> columns.

  The columns include null rows, empty lists and lists holding a null element.
  """
  int64_lists = pa.array(
      [[1, 2, 3], [4], None, [5], [], [6], [None], [7]],
      type=pa.list_(pa.int64()))
  int32_lists = pa.array(
      [[10, 20, 30], [40], None, None, [], [], [None], [None]],
      type=pa.list_(pa.int32()))
  batch = pa.RecordBatch.from_arrays([int64_lists, int32_lists], ['f1', 'f2'])

  slice_query = sql_util.RecordBatchSQLSliceQuery(sql, batch.schema)
  self.assertEqual(slice_query.Execute(batch), expected_output)
def test_query_with_unexpected_record_batch_schema(self):
  """Executing on a batch whose schema differs from the compiled one fails.

  The query is built against an int64 `f1` column and then executed on a batch
  whose `f1` column is int32.
  """
  compiled_for = pa.RecordBatch.from_arrays(
      [pa.array([1, 2, 3], type=pa.int64())], ['f1'])
  mismatched = pa.RecordBatch.from_arrays(
      [pa.array([4, 5, 6], type=pa.int32())], ['f1'])
  sql = 'SELECT ARRAY(SELECT STRUCT(f1)) as slice_key FROM Examples as example;'

  slice_query = sql_util.RecordBatchSQLSliceQuery(sql, compiled_for.schema)
  with self.assertRaisesRegex(RuntimeError, 'Unexpected RecordBatch schema.'):
    slice_query.Execute(mismatched)
def _generate_queries(
    schema: pa.Schema) -> List[sql_util.RecordBatchSQLSliceQuery]:
  """Builds one RecordBatchSQLSliceQuery per statement in `self._sqls`.

  `self` is not a parameter; it is taken from the enclosing scope.

  Args:
    schema: Arrow schema handed to each RecordBatchSQLSliceQuery.

  Returns:
    The constructed queries, in `self._sqls` order.
  """
  queries = []
  for sql in self._sqls:
    queries.append(sql_util.RecordBatchSQLSliceQuery(sql, schema))
  return queries