Example #1
0
    def __init__(self, path, start=0, length=0):
        """Reset all reader state to defaults, then initialize from *path*.

        ``path``, ``start`` and ``length`` are passed through unchanged to
        ``self._initialize``.
        """
        # All boolean flags start out cleared.
        self._block_compressed = self._decompress = self._sync_seen = False

        # Classes, codec and metadata are unknown at this point.
        self._value_class = self._key_class = None
        self._codec = self._metadata = None

        # The record buffer is created before _initialize() is invoked.
        record_buffer = DataInputBuffer()
        self._record = record_buffer
        self._initialize(path, start, length)
Example #2
0
 def decompressInputStream(self, data):
     """Decompress *data* and wrap the result in a ``DataInputBuffer``."""
     decompressed = self.decompress(data)
     return DataInputBuffer(decompressed)
Example #3
0
class Reader(object):
    """Reads key/value records from a file with a versioned header.

    The header (parsed in ``_initialize``) carries the key and value class
    names, compression flags, an optional codec name, metadata and a sync
    marker.  Records are then read either one at a time (plain or
    record-compressed) or block by block (block-compressed).
    """

    # Index of the next record inside the currently buffered compressed
    # block; ``None`` means "no block buffered".  Defined at class level so
    # the attribute always exists, even before the first block is read.
    _block_index = None

    def __init__(self, path, start=0, length=0):
        """Open *path* and parse its header.

        ``start`` and ``length`` are forwarded to ``_initialize``; a
        ``length`` of 0 means "up to the end of the stream".
        """
        self._block_compressed = False
        self._decompress = False
        self._sync_seen = False

        self._value_class = None
        self._key_class = None
        self._codec = None

        self._metadata = None

        # Default the sync marker so the ``is not None`` guard in
        # _readRecordLength() is meaningful even if the header has none.
        self._sync = None
        self._block_index = None

        self._record = DataInputBuffer()
        self._initialize(path, start, length)

    def getStream(self, path):
        """Return the input stream used to read *path*."""
        return DataInputStream(FileInputStream(path))

    def close(self):
        """Close the underlying stream."""
        self._stream.close()

    def getCompressionCodec(self):
        """Return the codec set up from the header, or None if uncompressed."""
        return self._codec

    def getKeyClass(self):
        """Return the key class, resolving it from its name on first use."""
        if not self._key_class:
            self._key_class = hadoopClassFromName(self._key_class_name)
        return self._key_class

    def getKeyClassName(self):
        """Return the name of the key class."""
        return hadoopClassName(self.getKeyClass())

    def getValueClass(self):
        """Return the value class, resolving it from its name on first use."""
        if not self._value_class:
            self._value_class = hadoopClassFromName(self._value_class_name)
        return self._value_class

    def getValueClassName(self):
        """Return the name of the value class."""
        return hadoopClassName(self.getValueClass())

    def getPosition(self):
        """Return the current position of the underlying stream."""
        return self._stream.getPos()

    def getMetadata(self):
        """Return the metadata object read from the header."""
        return self._metadata

    def isBlockCompressed(self):
        """Return True if the header flagged the file as block-compressed."""
        return self._block_compressed

    def isCompressed(self):
        """Return True if the header flagged the values as compressed."""
        return self._decompress

    def nextRawKey(self):
        """Return the next key as a ``DataInputBuffer``, or None at the end.

        In record mode the remainder of the record (the value bytes) is
        stored in ``self._record`` for ``nextRawValue``.  In block mode a
        whole block is read and buffered as the tuple
        ``(records, keys_len, keys, values_len, values)``.

        Raises:
            IOError: if a sync marker does not match the one in the header.
        """
        if not self._block_compressed:
            record_length = self._readRecordLength()
            if record_length < 0:
                return None

            key_length = self._stream.readInt()
            key = DataInputBuffer(self._stream.read(key_length))
            self._record.reset(self._stream.read(record_length - key_length))
            return key

        # Serve the next key from the buffered block while it has records.
        if self._block_index is not None and \
           self._block_index < self._record[0]:
            self._sync_seen = False
            records, keys_len, keys, values_len, values = self._record
            key_length = readVInt(keys_len)
            self._block_index += 1
            return DataInputBuffer(keys.read(key_length))

        if self._stream.getPos() >= self._end:
            return None

        # Read Sync
        self._stream.readInt()  # sync escape (original note: -1); not checked
        sync_check = self._stream.read(SYNC_HASH_SIZE)
        if sync_check != self._sync:
            raise IOError("File is corrupt")
        self._sync_seen = True

        def _readBuffer():
            # Each buffer is a vint length followed by compressed bytes.
            length = readVInt(self._stream)
            buf = self._stream.read(length)
            return self._codec.decompressInputStream(buf)

        records = readVInt(self._stream)
        keys_len = _readBuffer()
        keys = _readBuffer()

        values_len = _readBuffer()
        values = _readBuffer()

        self._record = (records, keys_len, keys, values_len, values)
        self._block_index = 1

        key_length = readVInt(keys_len)
        return DataInputBuffer(keys.read(key_length))

    def nextKey(self, key):
        """Read the next key into *key*; return False when no key is left."""
        buf = self.nextRawKey()
        # nextRawKey() signals the end with None.  Testing truthiness would
        # also stop on a buffer object that happens to evaluate as falsy.
        if buf is None:
            return False
        key.readFields(buf)
        return True

    def nextRawValue(self):
        """Return a buffer holding the value that belongs to the last key."""
        if not self._block_compressed:
            if self._decompress:
                compress_data = self._record.read(self._record.size())
                return self._codec.decompressInputStream(compress_data)
            else:
                return self._record
        else:
            records, keys_len, keys, values_len, values = self._record
            value_length = readVInt(values_len)
            return DataInputBuffer(values.read(value_length))

    def next(self, key, value):
        """Read the next key/value pair; return False when no pair is left."""
        more = self.nextKey(key)
        if more:
            self._getCurrentValue(value)
        return more

    def seek(self, position):
        """Move the stream to *position* and drop any buffered block."""
        self._stream.seek(position)
        if self._block_compressed:
            # Kept for backward compatibility; nothing in this class reads
            # these two attributes.
            self._no_buffered_keys = 0
            self._values_decompressed = True
            # Invalidate the buffered block so nextRawKey() reads from the
            # new position instead of returning keys of the old block.
            self._block_index = None

    def sync(self, position):
        """Seek to the first sync marker found after *position*.

        Positions too close to the end seek to the end; positions inside
        the header seek to the end of the header.
        """
        if (position + SYNC_SIZE) > self._end:
            self.seek(self._end)
            return

        if position < self._header_end:
            # Go through seek() so a buffered block is discarded here too.
            self.seek(self._header_end)
            self._sync_seen = True
            return

        self.seek(position + 4)
        sync_check = list(self._stream.read(SYNC_HASH_SIZE))

        # sync_check is used as a circular window over the last
        # SYNC_HASH_SIZE bytes read; i marks the logical start of the window.
        i = 0
        while self._stream.getPos() < self._end:
            j = 0
            while j < SYNC_HASH_SIZE:
                if self._sync[j] != sync_check[(i + j) % SYNC_HASH_SIZE]:
                    break
                j += 1

            if j == SYNC_HASH_SIZE:
                # Marker matched: step back so the stream is positioned
                # SYNC_SIZE bytes before the current position.
                self._stream.seek(self._stream.getPos() - SYNC_SIZE)
                return

            sync_check[i % SYNC_HASH_SIZE] = chr(self._stream.readByte())

            i += 1

    def syncSeen(self):
        """Return True if the last read operation passed a sync marker."""
        return self._sync_seen

    def _initialize(self, path, start, length):
        """Open the stream for *path* and parse the file header.

        ``start`` is accepted but not used by this method.

        Raises:
            VersionPrefixException: the file does not start with the
                expected version prefix.
            VersionMismatchException: the file version is newer than the
                supported one.
            NotImplementedError: the file version predates block compression.
        """
        self._stream = self.getStream(path)

        if length == 0:
            self._end = self._stream.getPos() + self._stream.length()
        else:
            self._end = self._stream.getPos() + length

        # Parse Header
        version_block = self._stream.read(len(VERSION))

        if not version_block.startswith(VERSION_PREFIX):
            raise VersionPrefixException(VERSION_PREFIX,
                                         version_block[0:len(VERSION_PREFIX)])

        self._version = version_block[len(VERSION_PREFIX)]
        if self._version > VERSION[len(VERSION_PREFIX)]:
            raise VersionMismatchException(VERSION[len(VERSION_PREFIX)],
                                           self._version)

        if self._version < BLOCK_COMPRESS_VERSION:
            # Same as below, but with UTF8 Deprecated Class
            raise NotImplementedError
        else:
            self._key_class_name = Text.readString(self._stream)
            self._value_class_name = Text.readString(self._stream)

        if ord(self._version) > 2:
            self._decompress = self._stream.readBoolean()
        else:
            self._decompress = False

        if self._version >= BLOCK_COMPRESS_VERSION:
            self._block_compressed = self._stream.readBoolean()
        else:
            self._block_compressed = False

        # setup compression codec
        if self._decompress:
            if self._version >= CUSTOM_COMPRESS_VERSION:
                codec_class = Text.readString(self._stream)
                self._codec = CodecPool().getDecompressor(codec_class)
            else:
                self._codec = CodecPool().getDecompressor()

        self._metadata = Metadata()
        if self._version >= VERSION_WITH_METADATA:
            self._metadata.readFields(self._stream)

        # Compare the numeric version: the version is a one-character
        # string, and comparing it with an int directly is either always
        # true (Python 2) or a TypeError (Python 3).
        if ord(self._version) > 1:
            self._sync = self._stream.read(SYNC_HASH_SIZE)
            self._header_end = self._stream.getPos()

    def _readRecordLength(self):
        """Return the length of the next record, or -1 at the end.

        A length equal to ``SYNC_ESCAPE`` announces a sync marker, which is
        verified and skipped before the real length is read.

        Raises:
            IOError: if the sync marker does not match the header's marker.
        """
        if self._stream.getPos() >= self._end:
            return -1

        length = self._stream.readInt()
        if ord(self._version) > 1 and self._sync is not None \
           and length == SYNC_ESCAPE:
            sync_check = self._stream.read(SYNC_HASH_SIZE)
            if sync_check != self._sync:
                raise IOError("File is corrupt!")

            self._sync_seen = True
            if self._stream.getPos() >= self._end:
                return -1

            length = self._stream.readInt()
        else:
            self._sync_seen = False

        return length

    def _getCurrentValue(self, value):
        """Deserialize the current value into *value* on a best-effort basis.

        Ordinary errors are logged and suppressed, so ``next()`` keeps its
        historical non-raising behaviour.  ``KeyboardInterrupt`` and
        ``SystemExit`` are no longer swallowed.
        """
        import logging

        try:
            stream = self.nextRawValue()
            value.readFields(stream)
            if not self._block_compressed:
                # In record mode the value must consume the whole record.
                assert self._record.size() == 0
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to read the current value", exc_info=True)