def test_read_batches_larger_than_a_table_added():
    """Add a single table assembled from 10 one-row record batches, then read out two batches of 4 rows.

    Each output batch has to combine rows coming from four different one-row input record batches.
    After two full batches (values 0..7) only 2 rows remain, which is less than the batch size, so the
    batcher must report that no more batches can be read.
    """
    ten_batches_each_with_one_record = [_new_record_batch([i]) for i in range(10)]
    table_0_10 = pa.Table.from_batches(ten_batches_each_with_one_record)

    batcher = BatchingTableQueue(4)
    # A freshly constructed batcher holds no data (same precondition the sibling tests check).
    assert batcher.empty()

    batcher.put(table_0_10)

    for first_value in (0, 4):
        assert not batcher.empty()
        next_batch = batcher.get()

        assert 4 == next_batch.num_rows
        expected_values = list(range(first_value, first_value + 4))
        np.testing.assert_equal(compat_column_data(next_batch.column(0)).to_pylist(), expected_values)
        np.testing.assert_equal(compat_column_data(next_batch.column(1)).to_pylist(), expected_values)

    # Only rows 8 and 9 are left: not enough for another batch of 4.
    assert batcher.empty()
def test_two_tables_of_10_added_reading_5_batches_of_4():
    """Put two 10-row tables (values 0..9 and 10..19) into the batcher and drain it as five batches of 4.

    The third batch (values 8..11) takes rows from both tables. The batcher must be non-empty before every
    read and must report empty only after the fifth, final batch has been read.
    """
    first_table = pa.Table.from_batches([_new_record_batch(range(0, 10))])
    second_table = pa.Table.from_batches([_new_record_batch(range(10, 20))])

    batcher = BatchingTableQueue(4)
    assert batcher.empty()

    for table in (first_table, second_table):
        batcher.put(table)

    batch_count = 5
    for batch_index in range(batch_count):
        assert not batcher.empty()
        next_batch = batcher.get()

        # Data must remain after every batch except the last one.
        is_last_batch = batch_index == batch_count - 1
        assert is_last_batch == bool(batcher.empty())

        assert 4 == next_batch.num_rows
        start = 4 * batch_index
        expected_values = list(range(start, start + 4))
        for column_index in (0, 1):
            np.testing.assert_equal(
                compat_column_data(next_batch.column(column_index)).to_pylist(),
                expected_values)
def test_random_table_size_and_random_batch_sizes():
    """Interleave randomly sized writes with a random number of reads and check rows come out in order.

    On each of ``write_iter_count`` iterations a table with a random number of rows (possibly zero) that
    continues the running value sequence is put into the batcher. Then between 1 and
    ``input_table_size // batch_size - 1`` reads are attempted. Every value read back must continue the
    sequence of values written so far.

    A fresh seed is drawn per run, so coverage stays randomized. The seed is included in every assertion
    message so that a failing run can be reproduced.
    """
    batch_size = 5
    input_table_size = 50
    write_iter_count = 1000

    seed = np.random.randint(0, 2 ** 31 - 1)
    rng = np.random.RandomState(seed)
    failure_message = 'seed={}'.format(seed)

    batcher = BatchingTableQueue(batch_size)
    write_seq = 0
    read_seq = 0

    for _ in range(write_iter_count):
        next_batch_size = rng.randint(0, input_table_size)
        new_batch = _new_record_batch(list(range(write_seq, write_seq + next_batch_size)))
        write_seq += next_batch_size

        batcher.put(pa.Table.from_batches([new_batch]))

        next_read = rng.randint(1, input_table_size // batch_size)
        for _ in range(next_read):
            if batcher.empty():
                # Nothing is added until the next put(), so further reads in this round are pointless.
                break
            read_batch = batcher.get()
            # Compare plain Python values: iterating pyarrow data yields pyarrow scalars, which in recent
            # pyarrow versions do not compare equal to Python ints.
            for value in compat_column_data(read_batch.column(0)).to_pylist():
                assert value == read_seq, failure_message
                read_seq += 1

    assert read_seq > 0, failure_message
# Example #4
# 0
    def read_next(self, workers_pool, schema, ngram):
        """Fetch the next result table from ``workers_pool`` and convert it into a namedtuple of numpy arrays.

        :param workers_pool: pool whose ``get_results()`` returns the next arrow table to convert.
        :param schema: schema providing per-field shapes (``schema.fields[name].shape``) and the namedtuple
            factory (``schema.make_namedtuple``).
        :param ngram: must be falsy; ngrams are not supported by this reader.
        :return: the namedtuple built by ``schema.make_namedtuple`` with one numpy array per table column.
        :raises StopIteration: if ``workers_pool.get_results()`` raises ``EmptyResultError``.
        :raises RuntimeError: if a list-typed column cannot be stacked (and, for multi-dimensional fields,
            reshaped) into a single array, e.g. because its lists have different lengths.
        """
        assert not ngram, 'ArrowReader does not support ngrams for now'

        # Only fetching results signals exhaustion; keep the guarded region to that single call.
        try:
            result_table = workers_pool.get_results()
        except EmptyResultError:
            raise StopIteration

        # Convert arrow table columns into numpy. Strings are handled differently since to_pandas() returns
        # numpy array of dtype=object.
        result_dict = {}
        for column_name, column in compat_table_columns_gen(result_table):
            # Access the column's data through the pyarrow compatibility helper (as the rest of the file does)
            # instead of `column.data`.
            column_data = compat_column_data(column)

            # `to_pandas` works slower when called on the entire chunked data rather than directly on a chunk.
            # One chunk is expected since the reader worker reads one rowgroup at a time, but the check must be
            # made on *this* column: converting only `chunks[0]` of a multi-chunk column would drop rows.
            if column_data.num_chunks == 1:
                column_as_pandas = column_data.chunks[0].to_pandas()
            else:
                column_as_pandas = column_data.to_pandas()

            # pyarrow < 0.15.0 would always return a numpy array. Starting 0.15 we get pandas series, hence we
            # convert it into numpy array
            if isinstance(column_as_pandas, pd.Series):
                column_as_numpy = column_as_pandas.values
            else:
                column_as_numpy = column_as_pandas

            if pa.types.is_string(column.type):
                # `np.str_` is the unicode string scalar type (`np.unicode_` was an alias removed in NumPy 2.0).
                result_dict[column_name] = column_as_numpy.astype(np.str_)
            elif pa.types.is_list(column.type):
                # Assuming all lists are of the same length, hence we can collate them into a matrix
                list_of_lists = column_as_numpy
                try:
                    col_data = np.vstack(list_of_lists.tolist())
                    shape = schema.fields[column_name].shape
                    if len(shape) > 1:
                        col_data = col_data.reshape((len(list_of_lists), ) + shape)
                    result_dict[column_name] = col_data

                except ValueError:
                    # Report each distinct length once instead of one entry per row.
                    distinct_lengths = sorted({value.shape[0] for value in list_of_lists})
                    raise RuntimeError(
                        'Length of all values in column \'{}\' are expected to be the same length. '
                        'Got the following set of lengths: \'{}\''.format(
                            column_name, ', '.join(str(length) for length in distinct_lengths)))
            else:
                result_dict[column_name] = column_as_numpy

        return schema.make_namedtuple(**result_dict)
def test_single_table_of_10_rows_added_and_2_batches_of_4_read():
    """Put one 10-row table (a single record batch with values 0..9) into the batcher and read two batches of 4.

    After values 0..7 have been read only 2 rows remain, fewer than the batch size, so the batcher must then
    report that it is empty.
    """
    # Two columns, each holding the sequence 0..9, packed in one record batch.
    table_0_10 = pa.Table.from_batches([_new_record_batch(range(0, 10))])

    batcher = BatchingTableQueue(4)
    assert batcher.empty()

    batcher.put(table_0_10)

    def check_next_batch(first_value):
        # The batcher must have a full batch available whose both columns hold first_value..first_value+3.
        assert not batcher.empty()
        batch = batcher.get()

        assert 4 == batch.num_rows
        expected = list(range(first_value, first_value + 4))
        for column_index in (0, 1):
            np.testing.assert_equal(compat_column_data(batch.column(column_index)).to_pylist(), expected)

    check_next_batch(0)
    check_next_batch(4)

    # The two leftover rows cannot form another batch.
    assert batcher.empty()