def prepare_fixed_decimal(data, schema):
    """Encode a ``decimal.Decimal`` as fixed-size two's-complement bytes.

    Values that are not ``decimal.Decimal`` instances are returned
    unchanged.

    :param data: value to encode.
    :param schema: mapping with a ``'size'`` entry (number of output bytes)
        and an optional ``'scale'`` entry (defaults to 0).
    :returns: ``size`` bytes holding the big-endian two's-complement
        unscaled integer, or ``data`` itself when it is not a Decimal.
    :raises ValueError: if the decimal has more fractional digits than
        ``scale`` allows, or if the unscaled value does not fit in ``size``
        bytes.
    """
    if not isinstance(data, decimal.Decimal):
        return data
    # The Avro specification makes ``scale`` optional, defaulting to 0.
    scale = schema.get('scale', 0)
    size = schema['size']

    # based on https://github.com/apache/avro/pull/82/

    sign, digits, exp = data.as_tuple()

    if -exp > scale:
        raise ValueError('Scale provided in schema does not match the decimal')
    delta = exp + scale
    if delta > 0:
        # Pad with trailing zeros so the digits are expressed at ``scale``.
        digits = digits + (0, ) * delta

    magnitude = 0
    for digit in digits:
        magnitude = (magnitude * 10) + digit

    # Bits needed for the magnitude plus one sign bit.
    if hasattr(magnitude, 'bit_length'):
        bits_req = magnitude.bit_length() + 1
    else:
        # 2.6 support: bin(0) is '0b0', so strip the prefix and any zero.
        bits_req = len(bin(magnitude).lstrip('0b')) + 1

    if bits_req > size * 8:
        # Emitting anything other than exactly ``size`` bytes would corrupt
        # the stream, so refuse values that cannot be represented.
        raise ValueError(
            'Decimal does not fit in the %d bytes of the fixed schema' % size)

    # Work on the signed integer directly: Python's arithmetic right shift
    # sign-extends negatives, which yields two's complement byte by byte and
    # maps negative zero to all-zero bytes.
    unscaled_datum = -magnitude if sign else magnitude

    packed = bytearray(
        (unscaled_datum >> (8 * index)) & 0xff
        for index in range(size - 1, -1, -1))
    return bytes(packed)
class Writer(object):
    """Incrementally write records to ``fo`` in blocks.

    The header is written on construction. Records are serialized into an
    in-memory buffer and emitted as a block (record count, encoded payload,
    sync marker) once the buffer reaches ``sync_interval`` bytes or when
    ``flush`` is called.

    :param fo: output file-like object.
    :param schema: schema used to serialize (and optionally validate)
        records; also stored JSON-encoded in the header metadata.
    :param codec: key into ``BLOCK_WRITERS`` selecting the block encoder.
    :param sync_interval: buffered byte count at which a block is dumped.
    :param metadata: optional mapping of extra header metadata. It is
        copied, so the caller's mapping is never modified.
    :param validator: ``True`` to use the module's ``validate``; otherwise
        a callable taking ``(record, schema)``, or a falsy value to skip
        validation.
    :raises ValueError: if ``codec`` is not a key of ``BLOCK_WRITERS``.
    """

    def __init__(self,
                 fo,
                 schema,
                 codec='null',
                 sync_interval=1000 * SYNC_SIZE,
                 metadata=None,
                 validator=None):
        self.fo = fo
        self.schema = schema
        self.validate_fn = validate if validator is True else validator
        self.sync_marker = urandom(SYNC_SIZE)
        self.io = MemoryIO()
        self.block_count = 0
        # Copy before adding the reserved keys so the caller's dict is not
        # mutated as a side effect of constructing a Writer.
        self.metadata = dict(metadata) if metadata else {}
        self.metadata['avro.codec'] = codec
        self.metadata['avro.schema'] = json.dumps(schema)
        self.sync_interval = sync_interval

        # Resolve the codec before writing anything to ``fo``.
        try:
            self.block_writer = BLOCK_WRITERS[codec]
        except KeyError:
            raise ValueError('unrecognized codec: %r' % codec)

        write_header(self.fo, self.metadata, self.sync_marker)
        acquaint_schema(self.schema)

    def dump(self):
        """Write the buffered records as one block and reset the buffer."""
        write_long(self.fo, self.block_count)
        self.block_writer(self.fo, self.io.getvalue())
        self.fo.write(self.sync_marker)
        # truncate() does not move the stream position, so rewind as well.
        self.io.truncate(0)
        self.io.seek(0, SEEK_SET)
        self.block_count = 0

    def write(self, record):
        """Validate (if enabled) and buffer ``record``.

        Dumps a block once the buffer holds at least ``sync_interval``
        bytes.
        """
        if self.validate_fn:
            self.validate_fn(record, self.schema)
        write_data(self.io, record, self.schema)
        self.block_count += 1
        if self.io.tell() >= self.sync_interval:
            self.dump()

    def flush(self):
        """Dump any pending records, then flush the output file object."""
        if self.io.tell() or self.block_count > 0:
            self.dump()
        self.fo.flush()
def prepare_bytes_decimal(data, schema):
    """Encode a ``decimal.Decimal`` as variable-length two's-complement bytes.

    Values that are not ``decimal.Decimal`` instances are returned
    unchanged.

    :param data: value to encode.
    :param schema: mapping with an optional ``'scale'`` entry (defaults
        to 0).
    :returns: big-endian two's-complement bytes of the unscaled integer,
        sized to hold the magnitude plus a sign bit, or ``data`` itself
        when it is not a Decimal.
    :raises AssertionError: if the decimal has more fractional digits than
        ``scale`` allows (exception type kept for existing callers).
    """
    if not isinstance(data, decimal.Decimal):
        return data
    # The Avro specification makes ``scale`` optional, defaulting to 0.
    scale = schema.get('scale', 0)

    # based on https://github.com/apache/avro/pull/82/

    sign, digits, exp = data.as_tuple()

    if -exp > scale:
        raise AssertionError(
            'Scale provided in schema does not match the decimal')

    delta = exp + scale
    if delta > 0:
        # Pad with trailing zeros so the digits are expressed at ``scale``.
        digits = digits + (0, ) * delta

    magnitude = 0
    for digit in digits:
        magnitude = (magnitude * 10) + digit

    # Bits needed for the magnitude plus one sign bit.
    if hasattr(magnitude, 'bit_length'):
        bits_req = magnitude.bit_length() + 1
    else:
        # 2.6 support: bin(0) is '0b0', so strip the prefix and any zero.
        bits_req = len(bin(magnitude).lstrip('0b')) + 1

    # Round up to whole bytes.
    bytes_req = (bits_req + 7) // 8

    # Work on the signed integer directly: Python's arithmetic right shift
    # sign-extends negatives, which yields two's complement byte by byte and
    # maps negative zero to a zero byte.
    unscaled_datum = -magnitude if sign else magnitude

    packed = bytearray(
        (unscaled_datum >> (8 * index)) & 0xff
        for index in range(bytes_req - 1, -1, -1))
    return bytes(packed)