Ejemplo n.º 1
0
def iterate_over_test_chunk(pyarrow_type,
                            column_meta,
                            source_data_generator,
                            expected_data_transformer=None):
    """Round-trip random nullable data through an Arrow stream and verify it.

    Writes 9 record batches of 10 rows each (one column per entry of
    ``pyarrow_type``) to an in-memory Arrow stream, reads the stream back
    row by row through ``PyArrowIterator`` and asserts that every value
    matches what was written.

    Each cell is ``None`` or a freshly generated value, chosen by one random
    bit. A column is regenerated until it holds at least one non-``None``
    value.

    Args:
        pyarrow_type: Sequence of pyarrow data types, one per column.
        column_meta: Sequence of per-column field metadata; must have the
            same length as ``pyarrow_type``.
        source_data_generator: Zero-argument callable returning one cell
            value accepted by the corresponding pyarrow type.
        expected_data_transformer: Optional callable applied to every
            non-``None`` generated value to obtain the value expected back
            from the iterator. ``None`` cells are expected back as ``None``.

    Raises:
        AssertionError: If the two input sequences differ in length, if a
            value read back differs from the expected one, or if the number
            of rows read is not ``batch_count * batch_row_count``.
    """
    stream = BytesIO()

    assert len(pyarrow_type) == len(column_meta)

    column_size = len(pyarrow_type)
    batch_row_count = 10
    batch_count = 9

    # The same names are used by the schema and by every batch, so build
    # them once instead of once per batch.
    column_names = ["column_{}".format(idx) for idx in range(column_size)]
    fields = [
        # Third positional argument of pyarrow.field is `nullable`.
        pyarrow.field(name, arrow_type, True, meta)
        for name, arrow_type, meta in zip(column_names, pyarrow_type,
                                          column_meta)
    ]
    schema = pyarrow.schema(fields)

    expected_data = []
    writer = RecordBatchStreamWriter(stream, schema)

    for _batch in range(batch_count):
        column_arrays = []
        py_arrays = []
        for arrow_type in pyarrow_type:
            # Regenerate the column until it has at least one real value.
            while True:
                column_data = [
                    None if random.getrandbits(1) else source_data_generator()
                    for _row in range(batch_row_count)
                ]
                if any(value is not None for value in column_data):
                    break
            column_arrays.append(column_data)
            py_arrays.append(pyarrow.array(column_data, type=arrow_type))

        if expected_data_transformer:
            # The Arrow arrays above were built from the raw values; only
            # the expectation is transformed.
            column_arrays = [
                [
                    expected_data_transformer(value)
                    if value is not None else None
                    for value in column
                ]
                for column in column_arrays
            ]
        expected_data.append(column_arrays)

        rb = RecordBatch.from_arrays(py_arrays, column_names)
        writer.write_batch(rb)

    writer.close()

    # Seek the stream to the beginning so that we can read from it.
    stream.seek(0)
    context = ArrowConverterContext()
    # NOTE(review): the meaning of the None/False/False arguments is not
    # visible from this function -- confirm against PyArrowIterator.
    it = PyArrowIterator(None, stream, context, False, False)
    it.init(ROW_UNIT)

    count = 0
    while True:
        # Keep the try body minimal: only next() signals end-of-stream.
        try:
            val = next(it)
        except StopIteration:
            break
        # Integer arithmetic maps the flat row counter to (batch, row).
        batch_index, row_index = divmod(count, batch_row_count)
        for col in range(column_size):
            expected = expected_data[batch_index][col][row_index]
            assert val[col] == expected, (
                "row {} column {}: got {!r}, expected {!r}".format(
                    count, col, val[col], expected))
        count += 1

    assert count == (batch_count * batch_row_count)
Ejemplo n.º 2
0
 def __init__(self, data):
     """Store a ``PyArrowIterator`` over ``data`` as ``self.result_data``.

     The iterator gets an ``ArrowConverterContext`` whose session
     parameters set ``TIMEZONE`` to ``America/Los_Angeles``.
     """
     converter_context = ArrowConverterContext(
         {"TIMEZONE": "America/Los_Angeles"})
     # NOTE(review): the meaning of the None/False/False arguments is not
     # visible here -- confirm against PyArrowIterator.
     self.result_data = PyArrowIterator(
         None, data, converter_context, False, False)