def test_table_unsafe_casting():
    """Casting a table with lossy float values fails safely and works with ``safe=False``."""
    column_names = tuple('abcd')
    words = ['ab', 'bc', 'cd', 'de', 'ef']

    source = pa.Table.from_arrays(
        [
            pa.array(range(5), type=pa.int64()),
            pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
            pa.array([1.1, 2.2, 3.3, 4.4, 5.5], type=pa.float64()),
            pa.array(words, type=pa.string()),
        ],
        names=column_names,
    )
    truncated = pa.Table.from_arrays(
        [
            pa.array(range(5), type=pa.int32()),
            pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
            pa.array([1, 2, 3, 4, 5], type=pa.int64()),
            pa.array(words, type=pa.string()),
        ],
        names=column_names,
    )
    narrowed_schema = pa.schema([
        pa.field('a', pa.int32()),
        pa.field('b', pa.int16()),
        pa.field('c', pa.int64()),
        pa.field('d', pa.string()),
    ])

    # The default (safe) cast must refuse to drop the fractional parts.
    with pytest.raises(pa.ArrowInvalid,
                       match='Floating point value truncated'):
        source.cast(narrowed_schema)

    # Opting out of safety truncates the floats instead of raising.
    result = source.cast(narrowed_schema, safe=False)
    assert result.equals(truncated)
def test_struct_array_field():
    """Struct children can be fetched by index, negative index or name."""
    struct_type = pa.struct([pa.field('x', pa.int16()),
                             pa.field('y', pa.float32())])
    arr = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=struct_type)

    x_by_pos = arr.field(0)
    y_by_pos = arr.field(1)
    x_by_neg = arr.field(-2)
    y_by_neg = arr.field(-1)
    x_by_name = arr.field('x')
    y_by_name = arr.field('y')

    assert isinstance(x_by_pos, pa.lib.Int16Array)
    assert isinstance(y_by_neg, pa.lib.FloatArray)
    assert x_by_pos.equals(pa.array([1, 3, 5], type=pa.int16()))
    assert y_by_pos.equals(pa.array([2.5, 4.5, 6.5], type=pa.float32()))
    assert x_by_pos.equals(x_by_neg)
    assert x_by_pos.equals(x_by_name)
    assert y_by_pos.equals(y_by_neg)
    assert y_by_pos.equals(y_by_name)

    # Each kind of bad lookup key maps to its own exception type.
    rejected = [
        (TypeError, [None, pa.int16()]),
        (IndexError, [3, -3]),
        (KeyError, ['z', '']),
    ]
    for expected_error, bad_keys in rejected:
        for bad_key in bad_keys:
            with pytest.raises(expected_error):
                arr.field(bad_key)
def make_recordbatch(length):
    """Return a RecordBatch with two random int16 columns, 'f0' and 'f1'.

    Parameters
    ----------
    length : int
        Number of rows generated for each column.

    Returns
    -------
    batch : pyarrow.RecordBatch
    """
    column_names = ('f0', 'f1')
    # One independent random draw per column, in column order.
    columns = [
        pa.array(np.random.randint(0, 255, size=length, dtype=np.int16))
        for _ in column_names
    ]
    schema = pa.schema([pa.field(name, pa.int16()) for name in column_names])
    return pa.RecordBatch.from_arrays(columns, schema)
def test_table_safe_casting():
    """A table whose values fit the target types casts cleanly in safe mode."""
    column_names = tuple('abcd')
    words = ['ab', 'bc', 'cd', 'de', 'ef']

    source = pa.Table.from_arrays(
        [
            pa.array(range(5), type=pa.int64()),
            pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
            pa.array([1.0, 2.0, 3.0, 4.0, 5.0], type=pa.float64()),
            pa.array(words, type=pa.string()),
        ],
        names=column_names,
    )
    wanted = pa.Table.from_arrays(
        [
            pa.array(range(5), type=pa.int32()),
            pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
            pa.array([1, 2, 3, 4, 5], type=pa.int64()),
            pa.array(words, type=pa.string()),
        ],
        names=column_names,
    )
    narrowed_schema = pa.schema([
        pa.field('a', pa.int32()),
        pa.field('b', pa.int16()),
        pa.field('c', pa.int64()),
        pa.field('d', pa.string()),
    ])

    result = source.cast(narrowed_schema)
    assert result.equals(wanted)
def test_column_pickle():
    """A Column survives a pickle round trip with its chunks and field intact."""
    chunks = pa.chunked_array([[1, 2], [5, 6, 7]], type=pa.int16())
    annotated_field = pa.field("ints", pa.int16()).add_metadata(
        {b"foo": b"bar"})
    original = pa.column(annotated_field, chunks)

    restored = pickle.loads(pickle.dumps(original))

    assert restored.equals(original)
    assert restored.data.num_chunks == 2
    assert restored.field == annotated_field
def test_convert_options():
    """Exercise the attributes and constructor keywords of ConvertOptions."""
    cls = ConvertOptions
    opts = cls()

    # Boolean flags: check the default, flip it, check again.
    assert opts.check_utf8 is True
    opts.check_utf8 = False
    assert opts.check_utf8 is False

    assert opts.strings_can_be_null is False
    opts.strings_can_be_null = True
    assert opts.strings_can_be_null is True

    assert opts.column_types == {}
    schema = pa.schema([('a', pa.int32()), ('b', pa.string())])
    accepted_column_types = [
        # Pass column_types as mapping
        ({'b': pa.int16(), 'c': pa.float32()},
         {'b': pa.int16(), 'c': pa.float32()}),
        ({'v': 'int16', 'w': 'null'},
         {'v': pa.int16(), 'w': pa.null()}),
        # Pass column_types as schema
        (schema, {'a': pa.int32(), 'b': pa.string()}),
        # Pass column_types as sequence
        ([('x', pa.binary())], {'x': pa.binary()}),
    ]
    for assigned, normalized in accepted_column_types:
        opts.column_types = assigned
        assert opts.column_types == normalized

    with pytest.raises(TypeError, match='DataType expected'):
        opts.column_types = {'a': None}
    with pytest.raises(TypeError):
        opts.column_types = 0

    assert isinstance(opts.null_values, list)
    assert '' in opts.null_values
    assert 'N/A' in opts.null_values
    opts.null_values = ['xxx', 'yyy']
    assert opts.null_values == ['xxx', 'yyy']

    for attr in ('true_values', 'false_values'):
        assert isinstance(getattr(opts, attr), list)
        setattr(opts, attr, ['xxx', 'yyy'])
        assert getattr(opts, attr) == ['xxx', 'yyy']

    # Everything can also be supplied through the constructor.
    opts = cls(check_utf8=False,
               column_types={'a': pa.null()},
               null_values=['N', 'nn'],
               true_values=['T', 'tt'],
               false_values=['F', 'ff'],
               strings_can_be_null=True)
    assert opts.check_utf8 is False
    assert opts.column_types == {'a': pa.null()}
    assert opts.null_values == ['N', 'nn']
    assert opts.false_values == ['F', 'ff']
    assert opts.true_values == ['T', 'tt']
    assert opts.strings_can_be_null is True
def test_column_flatten():
    """Flattening a struct Column yields one 'parent.child' column per field."""
    struct_type = pa.struct([pa.field('x', pa.int16()),
                             pa.field('y', pa.float32())])

    scenarios = [
        # (struct rows, expected 'x' values, expected 'y' values)
        ([(1, 2.5), (3, 4.5), (5, 6.5)], [1, 3, 5], [2.5, 4.5, 6.5]),
        # Empty column
        ([], [], []),
    ]
    for rows, x_values, y_values in scenarios:
        struct_col = pa.Column.from_array(
            'foo', pa.array(rows, type=struct_type))
        x_col, y_col = struct_col.flatten()
        assert x_col == pa.column(
            'foo.x', pa.array(x_values, type=pa.int16()))
        assert y_col == pa.column(
            'foo.y', pa.array(y_values, type=pa.float32()))
def test_cast_integers_unsafe():
    """Unsafe integer casts must give the values NumPy produces."""
    # We let NumPy do the unsafe casting
    # Each case: (input data, input type, expected data, output type).
    wrapping_cases = []
    wrapping_cases.append((np.array([50000], dtype='i4'), 'int32',
                           np.array([50000], dtype='i2'), pa.int16()))
    wrapping_cases.append((np.array([70000], dtype='i4'), 'int32',
                           np.array([70000], dtype='u2'), pa.uint16()))
    wrapping_cases.append((np.array([-1], dtype='i4'), 'int32',
                           np.array([-1], dtype='u2'), pa.uint16()))
    wrapping_cases.append((np.array([50000], dtype='u2'), pa.uint16(),
                           np.array([50000], dtype='i2'), pa.int16()))

    for cast_case in wrapping_cases:
        _check_cast_case(cast_case, safe=False)
def test_empty_cast():
    """Cast empty arrays between every pair of common types."""
    common_types = [
        pa.null(),
        pa.bool_(),
        pa.int8(),
        pa.int16(),
        pa.int32(),
        pa.int64(),
        pa.uint8(),
        pa.uint16(),
        pa.uint32(),
        pa.uint64(),
        pa.float16(),
        pa.float32(),
        pa.float64(),
        pa.date32(),
        pa.date64(),
        pa.binary(),
        pa.binary(length=4),
        pa.string(),
    ]

    for source_type, target_type in itertools.product(common_types, repeat=2):
        empty = pa.array([], type=source_type)
        try:
            # ARROW-4766: Ensure that supported types conversion don't segfault
            # on empty arrays of common types
            empty.cast(target_type)
        except pa.lib.ArrowNotImplementedError:
            # Unsupported pairs are fine; only a crash would be a failure.
            pass
def test_buffers_primitive():
    """Inspect the raw buffers behind primitive and binary arrays.

    Covers an int16 array with a null, a slice of it, a NumPy-backed int8
    array without nulls, and a binary array (bitmap, offsets and values).
    """
    a = pa.array([1, 2, None, 4], type=pa.int16())
    buffers = a.buffers()
    assert len(buffers) == 2
    null_bitmap = buffers[0].to_pybytes()
    assert 1 <= len(null_bitmap) <= 64  # XXX this is varying
    assert bytearray(null_bitmap)[0] == 0b00001011

    # Slicing does not affect the buffers but the offset
    a_sliced = a[1:]
    buffers = a_sliced.buffers()
    # This comparison was previously a bare expression and checked nothing.
    assert a_sliced.offset == 1
    assert len(buffers) == 2
    null_bitmap = buffers[0].to_pybytes()
    assert 1 <= len(null_bitmap) <= 64  # XXX this is varying
    assert bytearray(null_bitmap)[0] == 0b00001011

    # 'xx' skips the two bytes of the null slot in the values buffer.
    assert struct.unpack('hhxxh', buffers[1].to_pybytes()) == (1, 2, 4)

    a = pa.array(np.int8([4, 5, 6]))
    buffers = a.buffers()
    assert len(buffers) == 2
    # No null bitmap from Numpy int array
    assert buffers[0] is None
    assert struct.unpack('3b', buffers[1].to_pybytes()) == (4, 5, 6)

    a = pa.array([b'foo!', None, b'bar!!'])
    buffers = a.buffers()
    assert len(buffers) == 3
    null_bitmap = buffers[0].to_pybytes()
    assert bytearray(null_bitmap)[0] == 0b00000101
    offsets = buffers[1].to_pybytes()
    assert struct.unpack('4i', offsets) == (0, 4, 4, 9)
    values = buffers[2].to_pybytes()
    assert values == b'foo!bar!!'
def test_buffers_nested():
    """Inspect the flattened buffer lists of list and struct arrays."""
    list_arr = pa.array([[1, 2], None, [3, None, 4, 5]],
                        type=pa.list_(pa.int64()))
    list_buffers = list_arr.buffers()
    assert len(list_buffers) == 4

    # The parent buffers
    parent_validity = list_buffers[0].to_pybytes()
    assert bytearray(parent_validity)[0] == 0b00000101
    parent_offsets = list_buffers[1].to_pybytes()
    assert struct.unpack('4i', parent_offsets) == (0, 2, 2, 6)

    # The child buffers
    child_validity = list_buffers[2].to_pybytes()
    assert bytearray(child_validity)[0] == 0b00110111
    child_values = list_buffers[3].to_pybytes()
    # '8x' skips the 8-byte slot of the null child element.
    assert struct.unpack('qqq8xqq', child_values) == (1, 2, 3, 4, 5)

    struct_type = pa.struct([pa.field('a', pa.int8()),
                             pa.field('b', pa.int16())])
    struct_arr = pa.array([(42, None), None, (None, 43)], type=struct_type)
    struct_buffers = struct_arr.buffers()
    assert len(struct_buffers) == 5

    # The parent buffer
    struct_validity = struct_buffers[0].to_pybytes()
    assert bytearray(struct_validity)[0] == 0b00000101

    # The child buffers: 'a'
    a_validity = struct_buffers[1].to_pybytes()
    assert bytearray(a_validity)[0] == 0b00000001
    a_values = struct_buffers[2].to_pybytes()
    assert struct.unpack('bxx', a_values) == (42,)

    # The child buffers: 'b'
    b_validity = struct_buffers[3].to_pybytes()
    assert bytearray(b_validity)[0] == 0b00000100
    b_values = struct_buffers[4].to_pybytes()
    assert struct.unpack('4xh', b_values) == (43,)
def _from_jvm_int_type(jvm_type):
    """
    Convert a JVM int type to its Python equivalent.

    Parameters
    ----------
    jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Int

    Returns
    -------
    typ: pyarrow.DataType
        ``None`` is returned when the bit width is not 8, 16, 32 or 64.
    """
    if jvm_type.isSigned:
        candidates = ((8, pa.int8), (16, pa.int16),
                      (32, pa.int32), (64, pa.int64))
    else:
        candidates = ((8, pa.uint8), (16, pa.uint16),
                      (32, pa.uint32), (64, pa.uint64))

    width = jvm_type.bitWidth
    for bits, make_type in candidates:
        if width == bits:
            return make_type()
    return None
def test_type_to_pandas_dtype():
    """Each Arrow type reports the expected NumPy dtype for pandas conversion."""
    nanosecond_datetime = np.dtype('datetime64[ns]')

    expectations = [
        (pa.null(), np.float64),
        (pa.bool_(), np.bool_),
        (pa.int8(), np.int8),
        (pa.int16(), np.int16),
        (pa.int32(), np.int32),
        (pa.int64(), np.int64),
        (pa.uint8(), np.uint8),
        (pa.uint16(), np.uint16),
        (pa.uint32(), np.uint32),
        (pa.uint64(), np.uint64),
        (pa.float16(), np.float16),
        (pa.float32(), np.float32),
        (pa.float64(), np.float64),
    ]
    # All date and timestamp types share the nanosecond datetime dtype.
    expectations += [
        (temporal, nanosecond_datetime)
        for temporal in (pa.date32(), pa.date64(), pa.timestamp('ms'))
    ]
    # Variable-size and nested types are expected to map to object.
    expectations += [
        (boxed, np.object_)
        for boxed in (pa.binary(), pa.binary(12), pa.string(),
                      pa.list_(pa.int8()))
    ]

    for arrow_type, numpy_type in expectations:
        assert arrow_type.to_pandas_dtype() == numpy_type
def test_recordbatch_basics():
    """Check RecordBatch size, dict conversion, bounds and explicit schema."""
    columns = [
        pa.array(range(5)),
        pa.array([-10, -5, 0, 5, 10]),
    ]

    named_batch = pa.RecordBatch.from_arrays(columns, ['c0', 'c1'])
    assert not named_batch.schema.metadata

    assert len(named_batch) == 5
    assert named_batch.num_rows == 5
    assert named_batch.num_columns == len(columns)

    expected_dict = OrderedDict([
        ('c0', [0, 1, 2, 3, 4]),
        ('c1', [-10, -5, 0, 5, 10]),
    ])
    assert named_batch.to_pydict() == expected_dict

    with pytest.raises(IndexError):
        # bounds checking
        named_batch[2]

    # Schema passed explicitly
    explicit_schema = pa.schema(
        [pa.field('c0', pa.int16()), pa.field('c1', pa.int32())],
        metadata={b'foo': b'bar'})
    typed_batch = pa.RecordBatch.from_arrays(columns, explicit_schema)
    assert typed_batch.schema == explicit_schema
def test_cast_from_null():
    """A null array casts to many types; dictionary and union are rejected."""
    in_data = [None] * 3
    in_type = pa.null()

    nested_struct = pa.struct([
        pa.field('a', pa.int32()),
        pa.field('b', pa.list_(pa.int8())),
        pa.field('c', pa.string()),
    ])
    castable_types = [
        pa.null(),
        pa.uint8(),
        pa.float16(),
        pa.utf8(),
        pa.binary(),
        pa.binary(10),
        pa.list_(pa.int16()),
        pa.decimal128(19, 4),
        pa.timestamp('us'),
        pa.timestamp('us', tz='UTC'),
        pa.timestamp('us', tz='Europe/Paris'),
        nested_struct,
    ]
    for target in castable_types:
        # The expected output is the same all-null data in the new type.
        _check_cast_case((in_data, in_type, in_data, target))

    uncastable_types = [
        pa.dictionary(pa.int32(), pa.string()),
        pa.union([pa.field('a', pa.binary(10)),
                  pa.field('b', pa.string())],
                 mode=pa.lib.UnionMode_DENSE),
        pa.union([pa.field('a', pa.binary(10)),
                  pa.field('b', pa.string())],
                 mode=pa.lib.UnionMode_SPARSE),
    ]
    null_arr = pa.array(in_data, type=pa.null())
    for target in uncastable_types:
        with pytest.raises(NotImplementedError):
            null_arr.cast(target)
def test_orcfile_empty():
    """Reading the empty ORC example gives zero rows but the full schema."""
    from pyarrow import orc

    reader = orc.ORCFile(path_for_orc_example('TestOrcFile.emptyFile'))
    table = reader.read()
    assert table.num_rows == 0

    # The same (int1, string1) struct appears in three places in the schema.
    inner_struct = pa.struct([
        ('int1', pa.int32()),
        ('string1', pa.string()),
    ])
    inner_list = pa.list_(inner_struct)
    map_entry = pa.struct([
        ('key', pa.string()),
        ('value', inner_struct),
    ])

    expected_schema = pa.schema([
        ('boolean1', pa.bool_()),
        ('byte1', pa.int8()),
        ('short1', pa.int16()),
        ('int1', pa.int32()),
        ('long1', pa.int64()),
        ('float1', pa.float32()),
        ('double1', pa.float64()),
        ('bytes1', pa.binary()),
        ('string1', pa.string()),
        ('middle', pa.struct([('list', inner_list)])),
        ('list', inner_list),
        ('map', pa.list_(map_entry)),
    ])
    assert table.schema == expected_schema
def test_table_flatten():
    """Table.flatten() expands each struct column by exactly one level."""
    point_type = pa.struct([pa.field('x', pa.int16()),
                            pa.field('y', pa.float32())])
    wrapper_type = pa.struct([pa.field('nest', point_type)])

    points = pa.array([(1, 2.5), (3, 4.5)], type=point_type)
    wrapped = pa.array([((11, 12.5),), ((13, 14.5),)], type=wrapper_type)
    flags = pa.array([False, True], type=pa.bool_())

    nested_table = pa.Table.from_arrays([points, wrapped, flags],
                                        names=['a', 'b', 'c'])
    flat_table = nested_table.flatten()
    flat_table._validate()

    # 'b.nest' is still a struct: only one nesting level is removed.
    expected_columns = [
        pa.array([1, 3], type=pa.int16()),
        pa.array([2.5, 4.5], type=pa.float32()),
        pa.array([(11, 12.5), (13, 14.5)], type=point_type),
        flags,
    ]
    expected = pa.Table.from_arrays(
        expected_columns, names=['a.x', 'a.y', 'b.nest', 'c'])
    assert flat_table.equals(expected)
def test_type_schema_pickling():
    """Types, a field and a schema built from them all round-trip via pickle."""
    union_children = [pa.field('a', pa.int8()), pa.field('b', pa.int16())]
    picklables = [
        pa.int8(),
        pa.string(),
        pa.binary(),
        pa.binary(10),
        pa.list_(pa.string()),
        pa.struct([
            pa.field('a', 'int8'),
            pa.field('b', 'string'),
        ]),
        pa.union(list(union_children), pa.lib.UnionMode_SPARSE),
        pa.union(list(union_children), pa.lib.UnionMode_DENSE),
        pa.time32('s'),
        pa.time64('us'),
        pa.date32(),
        pa.date64(),
        pa.timestamp('ms'),
        pa.timestamp('ns'),
        pa.decimal128(12, 2),
        pa.field('a', 'string', metadata={b'foo': b'bar'}),
    ]

    for item in picklables:
        restored = pickle.loads(pickle.dumps(item))
        assert item == restored

    # Wrap bare types in generated fields; reuse entries that are fields.
    schema_fields = [
        item if isinstance(item, pa.Field)
        else pa.field('_f{}'.format(position), item)
        for position, item in enumerate(picklables)
    ]
    schema = pa.schema(schema_fields, metadata={b'foo': b'bar'})
    restored = pickle.loads(pickle.dumps(schema))
    assert schema == restored
def dataframe_with_arrays(include_index=False):
    """
    Dataframe with numpy array columns of every possible primitive type.

    Parameters
    ----------
    include_index: bool, default False
        When true, an int64 '__index_level_0__' field is appended to the
        returned schema (the dataframe itself is unchanged).

    Returns
    -------
    df: pandas.DataFrame
    schema: pyarrow.Schema
        Arrow schema definition that is in line with the constructed df.
    """
    numeric_kinds = [('i1', pa.int8()), ('i2', pa.int16()),
                     ('i4', pa.int32()), ('i8', pa.int64()),
                     ('u1', pa.uint8()), ('u2', pa.uint16()),
                     ('u4', pa.uint32()), ('u8', pa.uint64()),
                     ('f4', pa.float32()), ('f8', pa.float64())]

    columns = OrderedDict()
    schema_fields = []

    # One list-typed column per numeric dtype; the third row is a null.
    for type_code, arrow_type in numeric_kinds:
        schema_fields.append(pa.field(type_code, pa.list_(arrow_type)))
        columns[type_code] = [
            None if size is None else np.arange(size, dtype=type_code)
            for size in (10, 5, None, 1)
        ]

    schema_fields.append(pa.field('str', pa.list_(pa.string())))
    columns['str'] = [
        np.array([u"1", u"ä"], dtype="object"),
        None,
        np.array([u"1"], dtype="object"),
        np.array([u"1", u"2", u"3"], dtype="object"),
    ]

    schema_fields.append(
        pa.field('datetime64', pa.list_(pa.timestamp('ms'))))
    columns['datetime64'] = [
        np.array(['2007-07-13T01:23:34.123456789',
                  None,
                  '2010-08-13T05:46:57.437699912'],
                 dtype='datetime64[ms]'),
        None,
        None,
        np.array(['2007-07-13T02',
                  None,
                  '2010-08-13T05:46:57.437699912'],
                 dtype='datetime64[ms]'),
    ]

    if include_index:
        schema_fields.append(pa.field('__index_level_0__', pa.int64()))

    return pd.DataFrame(columns), pa.schema(schema_fields)
def test_bit_width():
    """Fixed-width types expose ``bit_width``; variable-width ones raise."""
    fixed_width = [
        (pa.bool_(), 1),
        (pa.int8(), 8),
        (pa.uint32(), 32),
        (pa.float16(), 16),
        (pa.decimal128(19, 4), 128),
        (pa.binary(42), 42 * 8),
    ]
    for data_type, expected_bits in fixed_width:
        assert data_type.bit_width == expected_bits

    variable_width = [pa.binary(), pa.string(), pa.list_(pa.int16())]
    for data_type in variable_width:
        with pytest.raises(ValueError, match="fixed width"):
            data_type.bit_width
def test_field_equals():
    """Field.equals() as exercised across name, type, nullability, metadata."""
    meta1 = {b'foo': b'bar'}
    meta2 = {b'bizz': b'bazz'}

    f1 = pa.field('a', pa.int8(), nullable=True)
    f2 = pa.field('a', pa.int8(), nullable=True)
    f3 = pa.field('a', pa.int8(), nullable=False)
    f4 = pa.field('a', pa.int16(), nullable=False)
    f5 = pa.field('b', pa.int16(), nullable=False)
    f6 = pa.field('a', pa.int8(), nullable=True, metadata=meta1)
    f7 = pa.field('a', pa.int8(), nullable=True, metadata=meta1)
    f8 = pa.field('a', pa.int8(), nullable=True, metadata=meta2)

    matching_pairs = [(f1, f2), (f6, f7)]
    for left, right in matching_pairs:
        assert left.equals(right)

    differing_pairs = [
        (f1, f3),
        (f1, f4),
        (f3, f4),
        (f1, f6),
        (f4, f5),
        (f7, f8),
    ]
    for left, right in differing_pairs:
        assert not left.equals(right)
# Builds a two-field schema whose first field is a dictionary type (int16 indices, string values) and compares repr(schema) with a literal string.
# NOTE(review): the line breaks inside the triple-quoted expected string are not visible in this flattened view -- confirm its exact layout before reformatting this function.
# NOTE(review): a second function with this same name is defined later in this file; within one module the later definition would replace this one -- confirm whether both are meant to run.
def test_schema_repr_with_dictionaries(): fields = [ pa.field('one', pa.dictionary(pa.int16(), pa.string())), pa.field('two', pa.int32()) ] sch = pa.schema(fields) expected = ( """\ one: dictionary<values=string, indices=int16, ordered=0> two: int32""") assert repr(sch) == expected
def test_type_for_alias():
    """Every supported string alias resolves to the matching Arrow type."""
    # Aliases that name the same type are grouped together.
    alias_groups = [
        (('i1', 'int8'), pa.int8()),
        (('i2', 'int16'), pa.int16()),
        (('i4', 'int32'), pa.int32()),
        (('i8', 'int64'), pa.int64()),
        (('u1', 'uint8'), pa.uint8()),
        (('u2', 'uint16'), pa.uint16()),
        (('u4', 'uint32'), pa.uint32()),
        (('u8', 'uint64'), pa.uint64()),
        (('f4', 'float32'), pa.float32()),
        (('f8', 'float64'), pa.float64()),
        (('date32',), pa.date32()),
        (('date64',), pa.date64()),
        (('string', 'str'), pa.string()),
        (('binary',), pa.binary()),
        (('time32[s]',), pa.time32('s')),
        (('time32[ms]',), pa.time32('ms')),
        (('time64[us]',), pa.time64('us')),
        (('time64[ns]',), pa.time64('ns')),
        (('timestamp[s]',), pa.timestamp('s')),
        (('timestamp[ms]',), pa.timestamp('ms')),
        (('timestamp[us]',), pa.timestamp('us')),
        (('timestamp[ns]',), pa.timestamp('ns')),
    ]

    for aliases, expected_type in alias_groups:
        for alias in aliases:
            assert pa.type_for_alias(alias) == expected_type
def test_struct_array_flatten():
    """StructArray.flatten() honours parent nulls, child nulls and slicing."""
    struct_type = pa.struct([pa.field('x', pa.int16()),
                             pa.field('y', pa.float32())])

    # No nulls at all; also the only case where child types are checked.
    arr = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=struct_type)
    xs, ys = arr.flatten()
    assert xs.type == pa.int16()
    assert ys.type == pa.float32()
    assert xs.to_pylist() == [1, 3, 5]
    assert ys.to_pylist() == [2.5, 4.5, 6.5]
    xs, ys = arr[1:].flatten()
    assert xs.to_pylist() == [3, 5]
    assert ys.to_pylist() == [4.5, 6.5]

    # (rows, flattened xs, flattened ys, xs after [1:], ys after [1:])
    null_scenarios = [
        # Null struct entry
        ([(1, 2.5), None, (3, 4.5)],
         [1, None, 3], [2.5, None, 4.5],
         [None, 3], [None, 4.5]),
        # Nulls inside the children only
        ([(1, None), (2, 3.5), (None, 4.5)],
         [1, 2, None], [None, 3.5, 4.5],
         [2, None], [3.5, 4.5]),
        # Both kinds of null together
        ([(1, None), None, (None, 2.5)],
         [1, None, None], [None, None, 2.5],
         [None, None], [None, 2.5]),
    ]
    for rows, full_x, full_y, tail_x, tail_y in null_scenarios:
        arr = pa.array(rows, type=struct_type)
        xs, ys = arr.flatten()
        assert xs.to_pylist() == full_x
        assert ys.to_pylist() == full_y
        xs, ys = arr[1:].flatten()
        assert xs.to_pylist() == tail_x
        assert ys.to_pylist() == tail_y
# Builds a schema whose first field is a dictionary type created from an int16 index type and a string array of values, then compares repr(schema) with a literal string.
# NOTE(review): the line breaks and any indentation inside the triple-quoted expected string (notably before 'dictionary:') are not visible in this flattened view -- confirm the exact layout before reformatting this function.
# NOTE(review): an earlier function in this file has the same name; within one module this definition would replace it -- confirm whether both are meant to run.
def test_schema_repr_with_dictionaries(): dct = pa.array(['foo', 'bar', 'baz'], type=pa.string()) fields = [ pa.field('one', pa.dictionary(pa.int16(), dct)), pa.field('two', pa.int32()) ] sch = pa.schema(fields) expected = ( """\ one: dictionary<values=string, indices=int16, ordered=0> dictionary: ["foo", "bar", "baz"] two: int32""") assert repr(sch) == expected
def test_struct_value_subscripting(self):
    """Struct scalars can be subscripted by field name."""
    struct_type = pa.struct([pa.field('x', pa.int16()),
                             pa.field('y', pa.float32())])
    rows = [(1, 2.5), (3, 4.5), (5, 6.5)]
    arr = pa.array(rows, type=struct_type)

    for position, (x_value, y_value) in enumerate(rows):
        assert arr[position]['x'] == x_value
        assert arr[position]['y'] == y_value

    # Out-of-range row index
    with pytest.raises(IndexError):
        arr[4]['non-existent']

    # Valid row, unknown field name
    with pytest.raises(KeyError):
        arr[0]['non-existent']
def test_is_integer():
    """The integer predicates tell signed, unsigned and non-integer apart."""
    signed = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
    unsigned = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]

    for int_type in signed + unsigned:
        assert types.is_integer(int_type)

    for int_type in signed:
        assert types.is_signed_integer(int_type)
        assert not types.is_unsigned_integer(int_type)

    for int_type in unsigned:
        assert types.is_unsigned_integer(int_type)
        assert not types.is_signed_integer(int_type)

    # A floating point type is not an integer of any kind.
    assert not types.is_integer(pa.float32())
    assert not types.is_signed_integer(pa.float32())
def test_array_from_buffers():
    """Array.from_buffers() builds int16 arrays and rejects bad input."""
    int16 = pa.int16()
    data_buffer = pa.py_buffer(np.int16([4, 5, 6, 7]))
    # Bit i of the bitmap marks element i as valid; here bit 1 is unset.
    validity_buffer = pa.py_buffer(np.uint8([0b00001101]))

    with_nulls = pa.Array.from_buffers(int16, 4,
                                       [validity_buffer, data_buffer])
    assert with_nulls.type == pa.int16()
    assert with_nulls.to_pylist() == [4, None, 6, 7]

    without_bitmap = pa.Array.from_buffers(int16, 4, [None, data_buffer])
    assert without_bitmap.type == pa.int16()
    assert without_bitmap.to_pylist() == [4, 5, 6, 7]

    shifted = pa.Array.from_buffers(int16, 3,
                                    [validity_buffer, data_buffer],
                                    offset=1)
    assert shifted.type == pa.int16()
    assert shifted.to_pylist() == [None, 6, 7]

    # Strings are not buffers.
    with pytest.raises(TypeError):
        pa.Array.from_buffers(int16, 3, [u'', u''], offset=1)

    # A nested (list) type is rejected here.
    with pytest.raises(NotImplementedError):
        pa.Array.from_buffers(pa.list_(pa.int16()), 4,
                              [None, data_buffer])
def test_table_cast_to_incompatible_schema():
    """Table.cast() rejects schemas whose field names differ from the table's."""
    table = pa.Table.from_arrays(
        [
            pa.array(range(5)),
            pa.array([-10, -5, 0, 5, 10]),
        ],
        names=tuple('ab'),
    )

    wrong_case_schema = pa.schema([
        pa.field('A', pa.int32()),
        pa.field('b', pa.int16()),
    ])
    missing_field_schema = pa.schema([
        pa.field('a', pa.int32()),
    ])

    message = ("Target schema's field names are not matching the table's "
               "field names:.*")
    for bad_schema in (wrong_case_schema, missing_field_schema):
        with pytest.raises(ValueError, match=message):
            table.cast(bad_schema)
def test_integer_no_nulls(self):
    """Round-trip a DataFrame of random integers for every integer width."""
    integer_kinds = [('i1', A.int8()), ('i2', A.int16()),
                     ('i4', A.int32()), ('i8', A.int64()),
                     ('u1', A.uint8()), ('u2', A.uint16()),
                     ('u4', A.uint32()), ('u8', A.uint64())]
    row_count = 100
    # The upper bound passed to randint is capped at the int64 maximum.
    int64_max = np.iinfo('i8').max

    columns = {}
    schema_fields = []
    for type_code, arrow_type in integer_kinds:
        limits = np.iinfo(type_code)
        samples = np.random.randint(limits.min,
                                    min(limits.max, int64_max),
                                    size=row_count)
        columns[type_code] = samples.astype(type_code)
        schema_fields.append(A.Field.from_py(type_code, arrow_type))

    frame = pd.DataFrame(columns)
    expected = A.Schema.from_fields(schema_fields)
    self._check_pandas_roundtrip(frame, expected_schema=expected)
import pandas as pd import pyarrow as pa from pandas.core.dtypes.common import infer_dtype_from_object from pandas.core.dtypes.dtypes import CategoricalDtype, CategoricalDtypeType import cudf from cudf._lib.scalar import Scalar _NA_REP = "<NA>" _np_pa_dtypes = { np.float64: pa.float64(), np.float32: pa.float32(), np.int64: pa.int64(), np.longlong: pa.int64(), np.int32: pa.int32(), np.int16: pa.int16(), np.int8: pa.int8(), np.bool_: pa.int8(), np.uint64: pa.uint64(), np.uint32: pa.uint32(), np.uint16: pa.uint16(), np.uint8: pa.uint8(), np.datetime64: pa.date64(), np.object_: pa.string(), np.str_: pa.string(), } cudf_dtypes_to_pandas_dtypes = { np.dtype("uint8"): pd.UInt8Dtype(), np.dtype("uint16"): pd.UInt16Dtype(), np.dtype("uint32"): pd.UInt32Dtype(),
def test_convert_options():
    """Exercise the attributes and constructor keywords of ConvertOptions."""
    cls = ConvertOptions
    opts = cls()

    # Simple attributes go through the shared helper: each keyword lists
    # the values passed for that attribute.
    check_options_class(
        cls,
        check_utf8=[True, False],
        strings_can_be_null=[False, True],
        include_columns=[[], ['def', 'abc']],
        include_missing_columns=[False, True],
        auto_dict_encode=[False, True],
        timestamp_parsers=[[], [ISO8601, '%y-%m']])

    assert opts.auto_dict_max_cardinality > 0
    opts.auto_dict_max_cardinality = 99999
    assert opts.auto_dict_max_cardinality == 99999

    assert opts.column_types == {}
    schema = pa.schema([('a', pa.int32()), ('b', pa.string())])
    accepted_column_types = [
        # Pass column_types as mapping
        ({'b': pa.int16(), 'c': pa.float32()},
         {'b': pa.int16(), 'c': pa.float32()}),
        ({'v': 'int16', 'w': 'null'},
         {'v': pa.int16(), 'w': pa.null()}),
        # Pass column_types as schema
        (schema, {'a': pa.int32(), 'b': pa.string()}),
        # Pass column_types as sequence
        ([('x', pa.binary())], {'x': pa.binary()}),
    ]
    for assigned, normalized in accepted_column_types:
        opts.column_types = assigned
        assert opts.column_types == normalized

    with pytest.raises(TypeError, match='DataType expected'):
        opts.column_types = {'a': None}
    with pytest.raises(TypeError):
        opts.column_types = 0

    assert isinstance(opts.null_values, list)
    assert '' in opts.null_values
    assert 'N/A' in opts.null_values
    opts.null_values = ['xxx', 'yyy']
    assert opts.null_values == ['xxx', 'yyy']

    for attr in ('true_values', 'false_values'):
        assert isinstance(getattr(opts, attr), list)
        setattr(opts, attr, ['xxx', 'yyy'])
        assert getattr(opts, attr) == ['xxx', 'yyy']

    assert opts.timestamp_parsers == []
    opts.timestamp_parsers = [ISO8601]
    assert opts.timestamp_parsers == [ISO8601]

    # Everything can also be supplied through the constructor.
    opts = cls(column_types={'a': pa.null()},
               null_values=['N', 'nn'],
               true_values=['T', 'tt'],
               false_values=['F', 'ff'],
               auto_dict_max_cardinality=999,
               timestamp_parsers=[ISO8601, '%Y-%m-%d'])
    assert opts.column_types == {'a': pa.null()}
    assert opts.null_values == ['N', 'nn']
    assert opts.false_values == ['F', 'ff']
    assert opts.true_values == ['T', 'tt']
    assert opts.auto_dict_max_cardinality == 999
    assert opts.timestamp_parsers == [ISO8601, '%Y-%m-%d']
cur.execute(query)

# Derive PyArrow schema from query result set
# Each description entry is indexed positionally below: c[0] is used as the
# field name, c[1] as the Python type, c[4]/c[5] as precision/scale and c[6]
# as the nullable flag.
# NOTE(review): presumably DB-API 2.0 description tuples -- confirm the driver.
fields = []
for c in cur.description:
    if args.debug is True:
        print(c)
    ct = c[1]
    pr = c[4]
    sc = c[5]
    if ct is int:
        # Pick the integer width from the reported precision.
        if pr == 3:
            fields.append(pa.field(c[0], pa.int8(), nullable=c[6]))
        elif pr == 5:
            fields.append(pa.field(c[0], pa.int16(), nullable=c[6]))
        elif pr == 10:
            fields.append(pa.field(c[0], pa.int32(), nullable=c[6]))
        else:
            fields.append(pa.field(c[0], pa.int64(), nullable=c[6]))
    elif ct is decimal.Decimal:
        fields.append(
            pa.field(c[0], pa.decimal128(c[4], c[5]), nullable=c[6]))
    elif ct is float:
        # The nullable flag was previously wrapped in a one-element list,
        # which is always truthy; pass the flag itself like the other
        # branches do.
        # NOTE(review): precision 53 is mapped to float32 and everything else
        # to float64 -- this looks inverted (53 is the double-precision
        # mantissa width); confirm against the driver before changing.
        if pr == 53:
            fields.append(pa.field(c[0], pa.float32(), nullable=c[6]))
        else:
            fields.append(pa.field(c[0], pa.float64(), nullable=c[6]))
    elif ct is str:
        fields.append(pa.field(c[0], pa.string(), nullable=c[6]))
    elif ct is bytearray:
        fields.append(pa.field(c[0], pa.binary(), nullable=c[6]))
def to_arrow_type(dt: DataType) -> "pa.DataType":
    """
    Convert Spark data type to pyarrow type

    Raises TypeError for types with no mapping here: arrays or maps that
    contain struct or timestamp elements, nested structs, maps when pyarrow
    is older than 2.0.0, and any other unrecognised Spark type.
    """
    from distutils.version import LooseVersion
    import pyarrow as pa

    # Types are compared exactly (no isinstance), as in every branch below.
    kind = type(dt)
    not_nestable = (StructType, TimestampType)
    unsupported_message = "Unsupported type in conversion to Arrow: " + str(dt)

    if kind == BooleanType:
        return pa.bool_()
    if kind == ByteType:
        return pa.int8()
    if kind == ShortType:
        return pa.int16()
    if kind == IntegerType:
        return pa.int32()
    if kind == LongType:
        return pa.int64()
    if kind == FloatType:
        return pa.float32()
    if kind == DoubleType:
        return pa.float64()
    if kind == DecimalType:
        return pa.decimal128(dt.precision, dt.scale)
    if kind == StringType:
        return pa.string()
    if kind == BinaryType:
        return pa.binary()
    if kind == DateType:
        return pa.date32()
    if kind == TimestampType:
        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
        return pa.timestamp('us', tz='UTC')
    if kind == TimestampNTZType:
        return pa.timestamp('us', tz=None)

    if kind == ArrayType:
        if type(dt.elementType) in not_nestable:
            raise TypeError(unsupported_message)
        return pa.list_(to_arrow_type(dt.elementType))

    if kind == MapType:
        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
            raise TypeError(
                "MapType is only supported with pyarrow 2.0.0 and above")
        if (type(dt.keyType) in not_nestable
                or type(dt.valueType) in not_nestable):
            raise TypeError(unsupported_message)
        return pa.map_(to_arrow_type(dt.keyType),
                       to_arrow_type(dt.valueType))

    if kind == StructType:
        if any(type(member.dataType) == StructType for member in dt):
            raise TypeError(
                "Nested StructType not supported in conversion to Arrow")
        members = [
            pa.field(member.name, to_arrow_type(member.dataType),
                     nullable=member.nullable)
            for member in dt
        ]
        return pa.struct(members)

    if kind == NullType:
        return pa.null()

    raise TypeError(unsupported_message)
('bool', [True, False, False, True, True]), ('uint8', np.arange(5)), ('int8', np.arange(5)), ('uint16', np.arange(5)), ('int16', np.arange(5)), ('uint32', np.arange(5)), ('int32', np.arange(5)), ('uint64', np.arange(5, 10)), ('int64', np.arange(5, 10)), ('float', np.arange(0, 0.5, 0.1)), ('double', np.arange(0, 0.5, 0.1)), ('string', ['a', 'b', None, 'ddd', 'ee']), ('binary', [b'a', b'b', b'c', b'ddd', b'ee']), (pa.binary(3), [b'abc', b'bcd', b'cde', b'def', b'efg']), (pa.list_(pa.int8()), [[1, 2], [3, 4], [5, 6], None, [9, 16]]), (pa.large_list(pa.int16()), [[1], [2, 3, 4], [5, 6], None, [9, 16]]), (pa.struct([('a', pa.int8()), ('b', pa.int8())]), [ {'a': 1, 'b': 2}, None, {'a': 3, 'b': 4}, None, {'a': 5, 'b': 6}]), ] exported_functions = [ func for (name, func) in sorted(pc.__dict__.items()) if hasattr(func, '__arrow_compute_function__')] exported_option_classes = [ cls for (name, cls) in sorted(pc.__dict__.items()) if (isinstance(cls, type) and cls is not pc.FunctionOptions and issubclass(cls, pc.FunctionOptions))]
# This ensures that we neither rely on the exact mechanics on how to construct # them using Java code as well as enables us to define them as parameters # without to invoke the JVM. # # The specifications were created using: # # om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')() # field = … # Code to instantiate the field # jvm_spec = om.writeValueAsString(field) @pytest.mark.parametrize( 'pa_type,jvm_spec', [ (pa.null(), '{"name":"null"}'), (pa.bool_(), '{"name":"bool"}'), (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'), (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'), (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'), (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'), (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'), (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'), (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'), (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'), (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'), (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'), (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'), (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'), (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'), (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'), (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
def test_iterate_over_decimal_chunk():
    """Check chunk iteration for a FIXED column of random precision and scale.

    A precision in [1, 38] and a scale in [0, precision] are drawn at random.
    Precisions of up to 19 digits are stored in the narrowest signed Arrow
    integer type, the stored value being the unscaled integer; wider
    precisions use ``decimal128``.  Two identical columns are passed to
    ``iterate_over_test_chunk`` together with a value generator and the
    transform that turns a generated value into the value expected on read.
    """
    # random.seed() only accepts None, int, float, str, bytes and bytearray
    # on Python 3.11+; passing the datetime object itself raises TypeError,
    # so seed with its epoch timestamp (a float) instead.
    random.seed(datetime.datetime.now().timestamp())
    precision = random.randint(1, 38)
    scale = random.randint(0, precision)

    # Storage type plus the inclusive range random integers are drawn from.
    # ``int_bounds`` is None when the column is a genuine decimal128.
    if precision <= 2:
        datatype, int_bounds = pyarrow.int8(), (-(2 ** 7), 2 ** 7 - 1)
    elif precision <= 4:
        datatype, int_bounds = pyarrow.int16(), (-(2 ** 15), 2 ** 15 - 1)
    elif precision <= 9:
        datatype, int_bounds = pyarrow.int32(), (-(2 ** 31), 2 ** 31 - 1)
    elif precision <= 19:
        datatype, int_bounds = pyarrow.int64(), (-(2 ** 63), 2 ** 63 - 1)
    else:
        datatype, int_bounds = pyarrow.decimal128(precision, scale), None

    def decimal_generator():
        """Return one random value that fits ``datatype`` and ``precision``."""
        if int_bounds is None:
            digits = [str(random.randint(0, 9)) for _ in range(precision)]
            if scale:
                # A negative index puts the point ``scale`` digits from the end.
                digits.insert(-scale, ".")
            return decimal.Decimal("".join(digits))
        raw = random.randint(*int_bounds)
        # Keep at most ``precision`` digits; the minus sign of a negative
        # number occupies one extra character in its string form.
        kept = precision if raw >= 0 else precision + 1
        return int(str(raw)[:kept])

    def expected_data_transform_decimal(data, precision=precision, scale=scale):
        """Map a generated value to the value expected from the iterator."""
        if precision <= 19:
            # Integer-backed columns hold the unscaled value.
            return decimal.Decimal(data).scaleb(-scale)
        return data

    column_meta = {
        "logicalType": "FIXED",
        "precision": str(precision),
        "scale": str(scale),
    }
    iterate_over_test_chunk(
        [datatype, datatype],
        [column_meta, column_meta],
        decimal_generator,
        expected_data_transform_decimal,
    )
def test_fields_weakrefable(): field = pa.field('a', pa.int32()) wr = weakref.ref(field) assert wr() is not None del field assert wr() is None @pytest.mark.parametrize('t,check_func', [(pa.date32(), types.is_date32), (pa.date64(), types.is_date64), (pa.time32('s'), types.is_time32), (pa.time64('ns'), types.is_time64), (pa.int8(), types.is_int8), (pa.int16(), types.is_int16), (pa.int32(), types.is_int32), (pa.int64(), types.is_int64), (pa.uint8(), types.is_uint8), (pa.uint16(), types.is_uint16), (pa.uint32(), types.is_uint32), (pa.uint64(), types.is_uint64), (pa.float16(), types.is_float16), (pa.float32(), types.is_float32), (pa.float64(), types.is_float64)]) def test_exact_primitive_types(t, check_func): assert check_func(t) def test_type_id(): # enum values are not exposed publicly
def test_sql(parameters, db_type):
    """Write and read back the ``test_sql`` table through ``wr.db``.

    The table is written with an engine obtained from the catalog and read
    back twice: once in full with that engine, and once in chunks of one
    row with an engine built from explicit connection parameters and an
    explicit Arrow ``dtype`` mapping.  For every ``db_type`` other than
    ``"redshift"`` a three-row frame is then written with its index and
    read back with ``read_sql_table``.
    """
    frame = get_df()
    if db_type == "redshift":
        frame.drop(["binary"], axis=1, inplace=True)

    catalog_engine = wr.catalog.get_engine(
        connection=f"aws-data-wrangler-{db_type}")
    wr.db.to_sql(
        df=frame,
        con=catalog_engine,
        name="test_sql",
        schema=parameters[db_type]["schema"],
        if_exists="replace",
        index=False,
        index_label=None,
        chunksize=None,
        method=None,
        dtype={"iint32": sqlalchemy.types.Integer},
    )

    # Full read through the catalog engine.
    select_all = f"SELECT * FROM {parameters[db_type]['schema']}.test_sql"
    full_read = wr.db.read_sql_query(sql=select_all, con=catalog_engine)
    ensure_data_types(full_read, has_list=False)

    # Chunked read through an engine built from explicit parameters.
    connection = parameters[db_type]
    direct_engine = wr.db.get_engine(
        db_type=db_type,
        host=connection["host"],
        port=connection["port"],
        database=connection["database"],
        user=parameters["user"],
        password=parameters["password"],
    )
    arrow_types = {
        "iint8": pa.int8(),
        "iint16": pa.int16(),
        "iint32": pa.int32(),
        "iint64": pa.int64(),
        "float": pa.float32(),
        "double": pa.float64(),
        "decimal": pa.decimal128(3, 2),
        "string_object": pa.string(),
        "string": pa.string(),
        "date": pa.date32(),
        "timestamp": pa.timestamp(unit="ns"),
        "binary": pa.binary(),
        "category": pa.float64(),
    }
    chunks = wr.db.read_sql_query(
        sql=select_all, con=direct_engine, chunksize=1, dtype=arrow_types)
    for chunk in chunks:
        ensure_data_types(chunk, has_list=False)

    # The index round trip below is not exercised for redshift.
    if db_type == "redshift":
        return

    account_id = boto3.client("sts").get_caller_identity().get("Account")
    account_engine = wr.catalog.get_engine(
        connection=f"aws-data-wrangler-{db_type}", catalog_id=account_id)
    wr.db.to_sql(
        df=pd.DataFrame({"col0": [1, 2, 3]}, dtype="Int32"),
        con=account_engine,
        name="test_sql",
        schema=parameters[db_type]["schema"],
        if_exists="replace",
        index=True,
        index_label="index",
    )
    # Only postgresql is given an explicit schema for read_sql_table.
    table_schema = (
        parameters[db_type]["schema"] if db_type == "postgresql" else None)
    indexed = wr.db.read_sql_table(
        con=account_engine, table="test_sql", schema=table_schema,
        index_col="index")
    assert len(indexed.index) == 3
    assert len(indexed.columns) == 1
result = arr.cast('i8') assert result.equals(expected) def test_simple_type_construction(): result = pa.lib.TimestampType() with pytest.raises(TypeError): str(result) @pytest.mark.parametrize( ('type', 'expected'), [(pa.null(), 'empty'), (pa.bool_(), 'bool'), (pa.int8(), 'int8'), (pa.int16(), 'int16'), (pa.int32(), 'int32'), (pa.int64(), 'int64'), (pa.uint8(), 'uint8'), (pa.uint16(), 'uint16'), (pa.uint32(), 'uint32'), (pa.uint64(), 'uint64'), (pa.float16(), 'float16'), (pa.float32(), 'float32'), (pa.float64(), 'float64'), (pa.date32(), 'date'), (pa.date64(), 'date'), (pa.binary(), 'bytes'), (pa.binary(length=4), 'bytes'), (pa.string(), 'unicode'), (pa.list_(pa.list_(pa.int16())), 'list[list[int16]]'), (pa.decimal128(18, 3), 'decimal'), (pa.timestamp('ms'), 'datetime'), (pa.timestamp('us', 'UTC'), 'datetimetz'), (pa.time32('s'), 'time'), (pa.time64('us'), 'time')]) def test_logical_type(type, expected): assert get_logical_type(type) == expected def test_array_uint64_from_py_over_range(): arr = pa.array([2**63], type=pa.uint64())
) from pandas.core.arrays import ExtensionArray from pandas.core.dtypes.dtypes import ExtensionDtype from ._algorithms import all_op, any_op, extract_isnull_bytemap _python_type_map = { pa.null().id: six.text_type, pa.bool_().id: bool, pa.int8().id: int, pa.uint8().id: int, pa.int16().id: int, pa.uint16().id: int, pa.int32().id: int, pa.uint32().id: int, pa.int64().id: int, pa.uint64().id: int, pa.float16().id: float, pa.float32().id: float,
import os
import logging

import pandas as pd
import pyarrow as pa

from ztf_dr.utils.jobs import run_jobs
from ztf_dr.utils.s3 import s3_uri_bucket, s3_filename_difference, get_s3_path_to_files

# Column name -> Arrow type; the mapping LC_SCHEMA is built from.
LC_FIELDS = {
    'objectid': pa.int64(),
    'filterid': pa.int8(),
    'fieldid': pa.int16(),
    'rcid': pa.int8(),
    'objra': pa.float32(),
    'objdec': pa.float32(),
    'nepochs': pa.int64(),
    'hmjd': pa.list_(pa.float64()),
    'mag': pa.list_(pa.float32()),
    'magerr': pa.list_(pa.float32()),
    'clrcoeff': pa.list_(pa.float32()),
    'catflags': pa.list_(pa.int32())
}
# Arrow schema passed to the parquet writer in parse_field.
LC_SCHEMA = pa.schema(LC_FIELDS)


def parse_field(field_path: str, output_path: str) -> int:
    """Rewrite one parquet file using ``LC_SCHEMA``.

    The file at ``field_path`` is loaded into a DataFrame and written to
    ``output_path`` with ``LC_SCHEMA`` passed as the ``schema`` argument.

    Returns:
        Always ``1``.
    """
    pd.read_parquet(field_path).to_parquet(output_path, schema=LC_SCHEMA)
    return 1
import pytest import weakref import numpy as np import pyarrow as pa @pytest.mark.parametrize(['value', 'ty', 'klass', 'deprecated'], [ (False, None, pa.BooleanScalar, pa.BooleanValue), (True, None, pa.BooleanScalar, pa.BooleanValue), (1, None, pa.Int64Scalar, pa.Int64Value), (-1, None, pa.Int64Scalar, pa.Int64Value), (1, pa.int8(), pa.Int8Scalar, pa.Int8Value), (1, pa.uint8(), pa.UInt8Scalar, pa.UInt8Value), (1, pa.int16(), pa.Int16Scalar, pa.Int16Value), (1, pa.uint16(), pa.UInt16Scalar, pa.UInt16Value), (1, pa.int32(), pa.Int32Scalar, pa.Int32Value), (1, pa.uint32(), pa.UInt32Scalar, pa.UInt32Value), (1, pa.int64(), pa.Int64Scalar, pa.Int64Value), (1, pa.uint64(), pa.UInt64Scalar, pa.UInt64Value), (1.0, None, pa.DoubleScalar, pa.DoubleValue), (np.float16(1.0), pa.float16(), pa.HalfFloatScalar, pa.HalfFloatValue), (1.0, pa.float32(), pa.FloatScalar, pa.FloatValue), (decimal.Decimal("1.123"), None, pa.Decimal128Scalar, pa.Decimal128Value), (decimal.Decimal("1.1234567890123456789012345678901234567890"), None, pa.Decimal256Scalar, pa.Decimal256Value), ("string", None, pa.StringScalar, pa.StringValue), (b"bytes", None, pa.BinaryScalar, pa.BinaryValue), ("largestring", pa.large_string(), pa.LargeStringScalar, pa.LargeStringValue),
from mojap_metadata import Metadata from mojap_metadata.converters.arrow_converter import ( ArrowConverter, _extract_bracket_params, ) import pyarrow as pa from mojap_metadata.converters import BaseConverterOptions @pytest.mark.parametrize( argnames="meta_type,arrow_type", argvalues=[ ("bool_", pa.bool_()), ("int8", pa.int8()), ("int16", pa.int16()), ("int32", pa.int32()), ("int64", pa.int64()), ("uint8", pa.uint8()), ("uint16", pa.uint16()), ("uint32", pa.uint32()), ("uint64", pa.uint64()), ("float16", pa.float16()), ("float32", pa.float32()), ("float64", pa.float64()), ("decimal128(38,1)", pa.decimal128(38, 1)), ("decimal128(1,2)", pa.decimal128(1, 2)), ("time32(s)", pa.time32("s")), ("time32(ms)", pa.time32("ms")), ("time64(us)", pa.time64("us")), ("time64(ns)", pa.time64("ns")),
class MisraGriesSketchTest(parameterized.TestCase):
  """Tests for adding values to, merging and serializing MisraGriesSketch."""

  @parameterized.named_parameters(
      ("binary", [b"a", b"a", b"b", b"c", None], pa.binary()),
      ("large_binary", [b"a", b"a", b"b", b"c"], pa.large_binary()),
      ("string", ["a", "a", "b", "c", None], pa.string()),
      ("large_string", ["a", "a", "b", "c"], pa.large_string()),
  )
  def test_add_binary_like(self, values, binary_like_type):
    """Binary-like arrays, with or without a null, give the same counts."""
    sketch = _create_basic_sketch(pa.array(values, type=binary_like_type))
    estimate = sketch.Estimate()
    estimate.validate(full=True)
    self.assertEqual(estimate.to_pylist(), [
        {"values": b"a", "counts": 2.0},
        {"values": b"b", "counts": 1.0},
        {"values": b"c", "counts": 1.0},
    ])

  @parameterized.named_parameters(
      ("int8", [1, 1, 2, 3, None], pa.int8()),
      ("int16", [1, 1, 2, 3], pa.int16()),
      ("int32", [1, 1, 2, 3, None], pa.int32()),
      ("int64", [1, 1, 2, 3], pa.int64()),
      ("uint8", [1, 1, 2, 3], pa.uint8()),
      ("uint16", [1, None, 1, 2, 3], pa.uint16()),
      ("uint32", [1, 1, 2, 3], pa.uint32()),
      ("uint64", [1, 1, 2, 3, None], pa.uint64()),
  )
  def test_add_integer(self, values, integer_type):
    """Integer arrays are reported as their decimal byte strings."""
    sketch = _create_basic_sketch(pa.array(values, type=integer_type))
    estimate = sketch.Estimate()
    estimate.validate(full=True)
    self.assertEqual(estimate.to_pylist(), [
        {"values": b"1", "counts": 2.0},
        {"values": b"2", "counts": 1.0},
        {"values": b"3", "counts": 1.0},
    ])

  def test_add_weighted_values(self):
    """Float weights are summed per distinct item."""
    items = pa.array(["a", "a", "b", "c"], type=pa.string())
    weights = pa.array([4, 3, 2, 1], type=pa.float32())
    sketch = _create_basic_sketch(items, weights=weights)
    estimate = sketch.Estimate()
    estimate.validate(full=True)
    self.assertEqual(estimate.to_pylist(), [
        {"values": b"a", "counts": 7.0},
        {"values": b"b", "counts": 2.0},
        {"values": b"c", "counts": 1.0},
    ])

  def test_add_invalid_weights(self):
    """An int64 weight array is rejected."""
    items = pa.array(["a", "a", "b", "c"], type=pa.string())
    weights = pa.array([4, 3, 2, 1], type=pa.int64())
    with self.assertRaisesRegex(
        RuntimeError, "Invalid argument: Weight array must be float type."):
      _create_basic_sketch(items, weights=weights)

  def test_add_unsupported_type(self):
    """Adding a boolean array is rejected."""
    bools = pa.array([True, False], pa.bool_())
    sketch = sketches.MisraGriesSketch(_NUM_BUCKETS)
    with self.assertRaisesRegex(RuntimeError, "Unimplemented: bool"):
      sketch.AddValues(bools)

  def test_replace_invalid_utf8(self):
    """Invalid UTF-8 values are counted under the configured placeholder."""
    values1 = pa.array([
        b"a",
        b"\x80",  # invalid
        b"\xC1",  # invalid
    ])
    values2 = pa.array([
        b"\xc0\x80",  # invalid
        b"a",
    ])
    sketch1 = sketches.MisraGriesSketch(
        _NUM_BUCKETS, invalid_utf8_placeholder=b"<BYTES>")
    sketch1.AddValues(values1)
    sketch2 = sketches.MisraGriesSketch(
        _NUM_BUCKETS, invalid_utf8_placeholder=b"<BYTES>")
    sketch2.AddValues(values2)

    # Round-trip both sketches before adding more values and merging.
    blob1 = sketch1.Serialize()
    blob2 = sketch2.Serialize()
    sketch1 = sketches.MisraGriesSketch.Deserialize(blob1)
    sketch2 = sketches.MisraGriesSketch.Deserialize(blob2)

    sketch1.AddValues(values2)
    sketch1.Merge(sketch2)

    actual = sketch1.Estimate()
    actual.validate(full=True)
    self.assertEqual(actual.to_pylist(), [
        {"values": b"<BYTES>", "counts": 4.0},
        {"values": b"a", "counts": 3.0},
    ])

  def test_no_replace_invalid_utf8(self):
    """Without a placeholder an invalid UTF-8 value is reported unchanged."""
    sketch = sketches.MisraGriesSketch(_NUM_BUCKETS)
    sketch.AddValues(pa.array([b"\x80"]))
    actual = sketch.Estimate()
    self.assertEqual(actual.to_pylist(), [
        {"values": b"\x80", "counts": 1.0},
    ])

  def test_large_string_threshold(self):
    """Values longer than the threshold are counted under the placeholder."""
    values1 = pa.array(["a", "bbb", "c", "d", "eeff"])
    values2 = pa.array(["a", "gghh"])
    sketch1 = sketches.MisraGriesSketch(
        _NUM_BUCKETS,
        large_string_threshold=2,
        large_string_placeholder=b"<LARGE>")
    sketch1.AddValues(values1)
    sketch2 = sketches.MisraGriesSketch(
        _NUM_BUCKETS,
        large_string_threshold=2,
        large_string_placeholder=b"<LARGE>")
    sketch2.AddValues(values2)

    # Round-trip both sketches before adding more values and merging.
    blob1 = sketch1.Serialize()
    blob2 = sketch2.Serialize()
    sketch1 = sketches.MisraGriesSketch.Deserialize(blob1)
    sketch2 = sketches.MisraGriesSketch.Deserialize(blob2)

    sketch1.AddValues(values2)
    sketch1.Merge(sketch2)

    actual = sketch1.Estimate()
    actual.validate(full=True)
    self.assertEqual(actual.to_pylist(), [
        {"values": b"<LARGE>", "counts": 4.0},
        {"values": b"a", "counts": 3.0},
        {"values": b"c", "counts": 1.0},
        {"values": b"d", "counts": 1.0},
    ])

  def test_invalid_large_string_replacing_config(self):
    """Giving only one of threshold and placeholder is rejected."""
    # Threshold without placeholder.
    with self.assertRaisesRegex(
        RuntimeError,
        "Must provide both or neither large_string_threshold and "
        "large_string_placeholder"):
      _ = sketches.MisraGriesSketch(_NUM_BUCKETS, large_string_threshold=1024)
    # Placeholder without threshold.
    with self.assertRaisesRegex(
        RuntimeError,
        "Must provide both or neither large_string_threshold and "
        "large_string_placeholder"):
      _ = sketches.MisraGriesSketch(
          _NUM_BUCKETS, large_string_placeholder=b"<L>")

  def test_many_uniques(self):
    """Tail elements with equal counts survive the `AddValues` call."""
    sketch = _create_basic_sketch(pa.array(["a", "b", "c", "a"]), num_buckets=2)
    estimate = sketch.Estimate()
    estimate.validate(full=True)
    # Since "b" and "c" have equal counts and neither token has count > 4/2,
    # any combination is possible.
    all_counts = [
        {"values": b"a", "counts": 2.0},
        {"values": b"b", "counts": 1.0},
        {"values": b"c", "counts": 1.0},
    ]
    acceptable = list(itertools.combinations(all_counts, 2))
    self.assertIn(tuple(estimate.to_pylist()), acceptable)

  def test_merge(self):
    """Merging adds the counts of the two sketches."""
    sketch1 = _create_basic_sketch(pa.array(["a", "b", "c", "a"]))
    sketch2 = _create_basic_sketch(pa.array(["d", "a"]))
    sketch1.Merge(sketch2)
    estimate = sketch1.Estimate()
    estimate.validate(full=True)
    self.assertEqual(estimate.to_pylist(), [
        {"values": b"a", "counts": 3.0},
        {"values": b"b", "counts": 1.0},
        {"values": b"c", "counts": 1.0},
        {"values": b"d", "counts": 1.0},
    ])

  def test_merge_equal_to_kth_weights(self):
    """Tail elements with equal counts survive the `Compress` call."""
    sketch1 = _create_basic_sketch(
        pa.array(["a"] * 5 + ["b"] * 5 + ["c"] * 4 + ["a"] * 4), num_buckets=3)
    sketch2 = _create_basic_sketch(
        pa.array(["d"] * 4 + ["a"] * 2), num_buckets=3)
    sketch1.Merge(sketch2)
    estimate = sketch1.Estimate()
    estimate.validate(full=True)
    # Since "c" and "d" have equal counts, the last entry may be either.
    leading = [
        {"values": b"a", "counts": 11.0},
        {"values": b"b", "counts": 5.0},
    ]
    ending_in_c = leading + [{"values": b"c", "counts": 4.0}]
    ending_in_d = leading + [{"values": b"d", "counts": 4.0}]
    self.assertIn(estimate.to_pylist(), [ending_in_c, ending_in_d])

  def test_merge_with_extra_items(self):
    """Merging sketches that hold extra elements yields `num_buckets` items."""
    # Each of these sketches get more values than `num_buckets`. This will
    # result into removal of less frequent elements from the main buffer and
    # adding them to a buffer of extra elements.
    # Here we're testing that merging of sketches having extra elements is
    # correct and results in a sketch that produces the requested number of
    # elements.
    sketch1 = _create_basic_sketch(
        pa.array(["a"] * 3 + ["b"] * 2 + ["c", "d"]), num_buckets=3)
    sketch2 = _create_basic_sketch(
        pa.array(["e"] * 3 + ["f"] * 2 + ["g", "h"]), num_buckets=3)
    sketch3 = _create_basic_sketch(
        pa.array(["i"] * 2 + ["j", "k", "l"]), num_buckets=3)
    sketch1.Merge(sketch2)
    sketch1.Merge(sketch3)
    estimate = sketch1.Estimate()
    estimate.validate(full=True)

    # Due to large number of unique elements (relative to `num_buckets`), the
    # total estimated count error is 5.
    acceptable = [
        [
            {"values": b"a", "counts": 5.0},
            {"values": b"e", "counts": 5.0},
            {"values": least_frequent_item, "counts": 5.0},
        ]
        for least_frequent_item in [b"b", b"f", b"i"]
    ]
    self.assertIn(estimate.to_pylist(), acceptable)

  def test_picklable(self):
    """A sketch survives a pickle round trip with protocol 2."""
    sketch = _create_basic_sketch(pa.array(["a", "b", "c", "a"]))
    pickled = pickle.dumps(sketch, 2)
    self.assertIsInstance(pickled, bytes)
    unpickled = pickle.loads(pickled)
    self.assertIsInstance(unpickled, sketches.MisraGriesSketch)

    estimate = unpickled.Estimate()
    estimate.validate(full=True)
    self.assertEqual(estimate.to_pylist(), [
        {"values": b"a", "counts": 2.0},
        {"values": b"b", "counts": 1.0},
        {"values": b"c", "counts": 1.0},
    ])

  def test_serialization(self):
    """A sketch survives a Serialize/Deserialize round trip."""
    sketch = _create_basic_sketch(pa.array(["a", "b", "c", "a"]))
    serialized = sketch.Serialize()
    self.assertIsInstance(serialized, bytes)
    deserialized = sketches.MisraGriesSketch.Deserialize(serialized)
    self.assertIsInstance(deserialized, sketches.MisraGriesSketch)

    estimate = deserialized.Estimate()
    estimate.validate(full=True)
    self.assertEqual(estimate.to_pylist(), [
        {"values": b"a", "counts": 2.0},
        {"values": b"b", "counts": 1.0},
        {"values": b"c", "counts": 1.0},
    ])