Beispiel #1
0
 def process(self, data):
     """
     Decompress an LZ-style stream.

     Each control byte provides flags that are consumed from the least significant bit
     upward. A clear flag means one literal byte follows. A set flag means a back-reference
     follows: a 6-bit length (plus ``_MATCH_MIN``) and a 10-bit backwards distance, read in
     big-endian bit order. Decoding stops once the input is exhausted, even in the middle of
     a flag group. Copies may overlap the bytes they produce.
     """
     output = bytearray()
     reader = StructReader(data)
     while not reader.eof:
         control = reader.read_byte()
         for bit in range(8):
             if reader.eof:
                 break
             if not control >> bit & 1:
                 output.append(reader.read_byte())
                 continue
             if not output:
                 raise ValueError('copy requested against empty buffer')
             with reader.be:
                 remaining = reader.read_integer(6) + _MATCH_MIN
                 distance = reader.read_integer(10)
             if distance == 0 or distance > len(output):
                 raise RuntimeError(F'invalid match offset at position {reader.tell()}')
             cursor = len(output) - distance
             # Copy in pieces: when the match overlaps the end of the output, each piece
             # re-reads bytes that the previous piece just appended.
             while remaining > 0:
                 piece = output[cursor:cursor + remaining]
                 output.extend(piece)
                 cursor += len(piece)
                 remaining -= len(piece)
     return output
Beispiel #2
0
 def __init__(self, reader: StructReader):
     """
     Parse one table-of-contents entry from ``reader``.

     The reader is switched to big-endian mode (this change persists after the call).
     The entry is laid out as follows:
     - four 32-bit fields: entry size, data offset, compressed size and uncompressed size;
     - one compression-flag byte;
     - one type byte;
     - a NUL-terminated name that fills the remainder of the declared entry size.

     Raises:
         RuntimeError: if the name region has a negative size (declared entry size smaller
             than the fixed fields), is larger than 0x1000 bytes, or the decoded name
             contains non-printable characters.
     """
     reader.bigendian = True
     entry_start_offset = reader.tell()
     self.size_of_entry = reader.i32()
     self.offset = reader.i32()
     self.size_of_compressed_data = reader.i32()
     # NOTE: the misspelled attribute name is kept because callers may rely on it.
     self.size_od_uncompressed_data = reader.i32()
     self.is_compressed = bool(reader.read_byte())
     entry_type = bytes(reader.read(1))
     # Whatever remains of the declared entry size after the fixed fields holds the name.
     name_length = self.size_of_entry - reader.tell() + entry_start_offset
     if name_length < 0:
         raise RuntimeError(
             F'Refusing to process TOC entry with negative name size {name_length}.'
         )
     if name_length > 0x1000:
         raise RuntimeError(
             F'Refusing to process TOC entry with name of size {name_length}.'
         )
     name, *_ = bytes(reader.read(name_length)).partition(B'\0')
     # 'backslashreplace' turns undecodable bytes into escape sequences, so this cannot fail.
     name = name.decode('utf8', 'backslashreplace')
     # Whitespace is stripped by the split; every remaining character must be printable.
     if not all(part.isprintable() for part in re.split('\\s*', name)):
         raise RuntimeError(
             'Refusing to process TOC entry with non-printable name.')
     # Entries without a name get a random unique one.
     name = name or str(uuid.uuid4())
     if entry_type == B'Z':
         entry_type = B'z'
     try:
         self.type = PiType(entry_type)
     except ValueError:
         xtpyi.logger.error(F'unknown type {entry_type!r} in field {name}')
         self.type = PiType.UNKNOWN
     self.name = name
Beispiel #3
0
 def __init__(self, reader: StructReader):
     """
     Decode one JVM bytecode instruction starting at the current position of ``reader``.

     Attributes set:
     - ``code``: the opcode.
     - ``arguments``: the decoded operands.
     - ``table``: for ``lookupswitch``/``tableswitch``, the jump table with the default
       target stored under the key ``None``; otherwise ``None``.
     - ``raw``: the bytes of the complete instruction.

     The operands of both switch instructions are signed 32-bit integers, so they are read
     with signed struct formats. Unsigned formats would misread negative case keys, negative
     ``low`` bounds and backward jump offsets as huge positive values.
     """
     # NOTE(review): StreamDetour presumably restores the reader position on exit; the
     # bytes consumed while parsing are then read again below to fill ``raw``.
     with StreamDetour(reader):
         self.code = opc(reader.read_byte())
         self.table: Optional[Dict[int, int]] = None
         try:
             fmt = self.OPC_ARGMAP[self.code]
         except KeyError:
             # Opcodes without an entry in the argument map have no fixed operands.
             self.arguments = []
         else:
             self.arguments = list(reader.read_struct(fmt))
         if self.code == opc.newarray:
             self.arguments = [JvBaseType(reader.read_byte())]
         elif self.code in self.OPC_CONSTPOOL:
             # Resolve the 1-based constant pool index when a pool is available;
             # otherwise keep the raw index.
             try:
                 self.arguments[0] = self.pool[self.arguments[0] - 1]
             except (AttributeError, IndexError):
                 pass
         elif self.code == opc.lookupswitch:
             reader.byte_align(blocksize=4)
             default, npairs = reader.read_struct('lL')
             pairs = reader.read_struct(F'{npairs*2}l')
             # Group the flat operand sequence into (match, offset) pairs.
             self.table = dict(zip(*([iter(pairs)] * 2)))
             self.table[None] = default
         elif self.code == opc.tableswitch:
             reader.byte_align(blocksize=4)
             default, low, high = reader.read_struct('lll')
             assert low <= high
             offsets = reader.read_struct(F'{high-low+1}l')
             self.table = {k + low: offset for k, offset in enumerate(offsets)}
             self.table[None] = default
         elif self.code == opc.wide:
             # wide: the following opcode takes a 16-bit index; iinc also takes a signed
             # 16-bit increment.
             argop = opc(reader.get_byte())
             self.arguments = (argop, reader.u16())
             if argop == opc.iinc:
                 self.arguments += reader.i16(),
             else:
                 assert argop in (
                     opc.iload, opc.istore,
                     opc.fload, opc.fstore,
                     opc.aload, opc.astore,
                     opc.lload, opc.lstore,
                     opc.dload, opc.dstore,
                     opc.ret)
         offset = reader.tell()
     self.raw = bytes(reader.read(offset - reader.tell()))
Beispiel #4
0
 def test_bitreader_be(self):
     # Fields in reading order: five single bits, one byte, a 10-bit integer, four set
     # flags and an 82-bit integer. Three spare bits remain, too few for a u16.
     pattern = 0b01010_10011101_0100100001_1111_0111101010000101010101010010010111100000101001010101100000001110010111110100111000_101
     full_bytes, spare_bits = divmod(pattern.bit_length(), 8)
     # bit_length() does not count the leading zero bit, hence 7 bits beyond full bytes.
     self.assertEqual(spare_bits, 7)
     reader = StructReader(memoryview(pattern.to_bytes(full_bytes + 1, 'big')))
     with reader.be:
         for expected in (0, 1, 0, 1, 0):
             self.assertEqual(reader.read_bit(), expected)
         self.assertEqual(reader.read_byte(), 0b10011101)
         self.assertEqual(reader.read_integer(10), 0b100100001)
         self.assertTrue(all(reader.read_flags(4)))
         self.assertEqual(
             reader.read_integer(82),
             0b0111101010000101010101010010010111100000101001010101100000001110010111110100111000
         )
         self.assertRaises(EOF, reader.u16)
Beispiel #5
0
 def test_bitreader_structured(self):
     values = (
         0b1100101,  # noqa
         -0x1337,  # noqa
         0xDEFACED,  # noqa
         0xC0CAC01A,  # noqa
         -0o1337,  # noqa
         2076.171875,  # noqa
         math.pi  # noqa
     )
     reader = StructReader(struct.pack('<bhiLqfd', *values))
     # Reading a nibble leaves the cursor in the middle of a byte ...
     self.assertEqual(reader.read_nibble(), 0b101)
     # ... where byte-level reads are rejected.
     self.assertRaises(reader.Unaligned, lambda: reader.read_exactly(2))
     reader.seek(0)
     integer_reads = (reader.read_byte, reader.i16, reader.i32, reader.u32, reader.i64)
     for read, expected in zip(integer_reads, values):
         self.assertEqual(read(), expected)
     for spec, expected in zip('fd', values[5:]):
         self.assertAlmostEqual(reader.read_struct(spec, True), expected)
     self.assertTrue(reader.eof)
Beispiel #6
0
class blz(Unit):
    """
    BriefLZ compression and decompression. The compression algorithm uses a pure Python suffix tree
    implementation: It requires a lot of time & memory.
    """
    def _begin(self, data):
        """
        Prepare a run: wrap ``data`` in a StructReader and create an empty output buffer.
        Returns ``self`` so that calls can be chained.
        """
        self._src = StructReader(memoryview(data))
        self._dst = MemoryFile(bytearray())
        return self

    def _reset(self):
        """
        Rewind the input and discard all output so that another decoding attempt can start
        from scratch. Returns ``self``.
        """
        self._src.seek(0)
        self._dst.seek(0)
        self._dst.truncate()
        return self

    def _decompress(self):
        """
        Decompress data in the standard BriefLZ container.

        The 24-byte header consists of six big-endian 32-bit fields: signature, version,
        compressed size, compressed CRC32, decompressed size and decompressed CRC32. A CRC
        field of zero disables the corresponding check.

        Raises ValueError for a wrong signature or a version number above 10. Size and CRC
        mismatches are only logged as warnings.
        """
        (
            signature,
            version,
            src_count,
            src_crc32,
            dst_count,
            dst_crc32,
        ) = self._src.read_struct('>6L')
        if signature != 0x626C7A1A:
            raise ValueError(F'Invalid BriefLZ signature: {signature:08X}, should be 626C7A1A.')
        if version > 10:
            raise ValueError(F'Invalid version number {version}, should be at most 10.')
        self.log_debug(F'signature: 0x{signature:08X} V{version}')
        self.log_debug(F'src count: 0x{src_count:08X}')
        self.log_debug(F'src crc32: 0x{src_crc32:08X}')
        self.log_debug(F'dst count: 0x{dst_count:08X}')
        self.log_debug(F'dst crc32: 0x{dst_crc32:08X}')
        src = self._src.getbuffer()
        src = src[24:24 + src_count]
        if len(src) < src_count:
            self.log_warn(F'Only {len(src)} bytes in buffer, but header announced a length of {src_count}.')
        if src_crc32:
            check = zlib.crc32(src)
            if check != src_crc32:
                self.log_warn(F'Invalid source data CRC {check:08X}, should be {src_crc32:08X}.')
        dst = self._decompress_chunk(dst_count)
        if not dst_crc32:
            return dst
        check = zlib.crc32(dst)
        if check != dst_crc32:
            self.log_warn(F'Invalid result data CRC {check:08X}, should be {dst_crc32:08X}.')
        return dst

    def _decompress_modded(self):
        """
        Decompress the modified, chunked layout.

        The first 8 bytes are skipped. Next come the total output size and the chunk size as
        64-bit integers (in the reader's default byte order). The remaining data is decoded as
        consecutive chunks into the same output buffer.
        """
        self._src.seekrel(8)
        total_size = self._src.u64()
        chunk_size = self._src.u64()
        remaining = total_size
        self.log_debug(F'total size: 0x{total_size:016X}')
        self.log_debug(F'chunk size: 0x{chunk_size:016X}')
        while remaining > chunk_size:
            self._decompress_chunk(chunk_size)
            remaining -= chunk_size
        return self._decompress_chunk(remaining)

    def _decompress_chunk(self, size=None):
        """
        Decode one block of BriefLZ data from the current input position and append the
        result to the output buffer.

        Block layout:
        - The first byte is always a literal.
        - Flag bits are then taken from 16-bit little-endian tag words, most significant bit
          first.
        - A 0 flag copies one literal byte.
        - A 1 flag introduces a back-reference. Its length and offset are built from two
          variable-length integers and one offset byte.

        When ``size`` is truthy, decoding stops once this block has produced ``size`` bytes.
        Otherwise it runs until the input is exhausted, which ends in a RefineryPartialResult.

        Returns the entire output buffer, not just this block.

        Raises:
            RefineryPartialResult: on an impossible back-reference, premature end of input,
                or when fewer than ``size`` bytes were produced.
            RuntimeError: when more than ``size`` bytes were produced.
        """
        bitcount = 0
        bitstore = 0
        # Accounts for the leading literal byte written before the loop.
        decompressed = 1

        def readbit():
            # Reload a 16-bit tag word whenever the previous one is used up.
            nonlocal bitcount, bitstore
            if not bitcount:
                bitstore = int.from_bytes(self._src.read_exactly(2), 'little')
                bitcount = 0xF
            else:
                bitcount = bitcount - 1
            return (bitstore >> bitcount) & 1

        def readint():
            # Implicit leading 1 bit, then (data bit, continue bit) pairs until a continue bit
            # of 0; the smallest encodable value is therefore 2.
            result = 2 + readbit()
            while readbit():
                result <<= 1
                result += readbit()
            return result

        self._dst.write(self._src.read_exactly(1))

        try:
            while not size or decompressed < size:
                if readbit():
                    length = readint() + 2
                    sector = readint() - 2
                    offset = self._src.read_byte() + 1
                    # Distance back into the output: the low byte comes from the input stream,
                    # the higher part from the variable-length "sector" value.
                    delta = offset + 0x100 * sector
                    available = self._dst.tell()
                    if delta not in range(available + 1):
                        raise RefineryPartialResult(
                            F'Requested rewind by 0x{delta:08X} bytes with only 0x{available:08X} bytes in output buffer.',
                            partial=self._dst.getvalue())
                    # Matches longer than their distance overlap their own output, so the
                    # last ``delta`` bytes are repeated as often as needed.
                    quotient, remainder = divmod(length, delta)
                    replay = memoryview(self._dst.getbuffer())
                    replay = bytes(replay[-delta:] if quotient else replay[-delta:length - delta])
                    replay = quotient * replay + replay[:remainder]
                    self._dst.write(replay)
                    decompressed += length
                else:
                    self._dst.write(self._src.read_exactly(1))
                    decompressed += 1
        except EOF as E:
            raise RefineryPartialResult(str(E), partial=self._dst.getbuffer()) from E
        dst = self._dst.getbuffer()
        if decompressed < size:
            raise RefineryPartialResult(
                F'Attempted to decompress {size} bytes, got only {len(dst)}.', dst)
        if decompressed > size:
            raise RuntimeError('Decompressed buffer contained more bytes than expected.')
        return dst

    def _compress(self):
        """
        Compress the input using a suffix tree to locate earlier occurrences.

        The output starts with a 24-byte header: big-endian signature 0x626C7A1A, version 1,
        the length and CRC32 of the compressed payload, and the length and CRC32 of the input.
        The compressed payload follows, and the output buffer is returned.
        """
        from refinery.lib.suffixtree import SuffixTree

        self.log_info('computing suffix tree')
        tree = SuffixTree(self._src.getbuffer())

        bitstore = 0  # The bit stream to be written
        bitcount = 0  # The number of bits in the bit stream
        buffer = MemoryFile(bytearray())

        # Write empty header and first byte of source
        self._dst.write(bytearray(24))
        self._dst.write(self._src.read_exactly(1))

        def writeint(n: int) -> None:
            """
            Write an integer to the bit stream.
            """
            nonlocal bitstore, bitcount
            nbits = n.bit_length()
            if nbits < 2:
                raise ValueError
            # The highest bit is implicitly assumed:
            n ^= 1 << (nbits - 1)
            remaining = nbits - 2
            while remaining:
                remaining -= 1
                bitstore <<= 2
                bitcount += 2
                # Takes the data bit as the upper bit of the pair; the lower (continue) bit is
                # forced to 1.
                bitstore |= ((n >> remaining) & 3) | 1
            bitstore <<= 2
            bitcount += 2
            # Last data bit with a continue bit of 0.
            bitstore |= (n & 1) << 1

        src = self._src.getbuffer()
        remaining = len(src) - 1
        self.log_info('compressing data')

        while True:
            cursor = len(src) - remaining
            rest = src[cursor:]
            if bitcount >= 0x10:
                block_count, bitcount = divmod(bitcount, 0x10)
                info_channel = bitstore >> bitcount
                bitstore = info_channel << bitcount ^ bitstore
                # The decompressor will read bits from top to bottom, and each 16 bit block has to be
                # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
                # variable, and by encoding it as a big endian integer, the stream is in the correct
                # order. However, we need to swap adjacent bytes to achieve little endian encoding for
                # each of the blocks:
                info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
                for k in range(block_count):
                    k0 = 2 * k + 0
                    k1 = 2 * k + 1
                    info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
                info_channel = memoryview(info_channel)
                data_channel = memoryview(buffer.getbuffer())
                self._dst.write(info_channel[:2])
                self._dst.write(data_channel[:-1])
                self._dst.write(info_channel[2:])
                data_channel = bytes(data_channel[-1:])
                buffer.truncate(0)
                store = buffer if bitcount else self._dst
                store.write(data_channel)
            if remaining + bitcount < 0x10:
                buffer = buffer.getbuffer()
                if rest or buffer:
                    bitstore <<= 0x10 - bitcount
                    self._dst.write(bitstore.to_bytes(2, 'little'))
                    self._dst.write(buffer)
                    self._dst.write(rest)
                elif bitcount:
                    raise RuntimeError('Bitbuffer Overflow')
                break
            # Walk down the suffix tree along ``rest`` for as long as the matched node starts
            # before the cursor, i.e. refers to data that was already emitted.
            node = tree.root
            length = 0
            offset = 0
            sector = None
            while node.children and length < len(rest):
                for child in node.children.values():
                    if tree.data[child.start] == rest[length]:
                        node = child
                        break
                if node.start >= cursor:
                    break
                offset = node.start - length
                length = node.end + 1 - offset
            length = min(remaining, length)
            # Only matches of at least 4 bytes are encoded as back-references.
            if length >= 4:
                sector, offset = divmod(cursor - offset - 1, 0x100)
            bitcount += 1
            bitstore <<= 1
            if sector is None:
                buffer.write(rest[:1])
                remaining -= 1
                continue
            bitstore |= 1
            buffer.write(bytes((offset,)))
            writeint(length - 2)
            writeint(sector + 2)
            remaining -= length

        self._dst.seek(24)
        dst = self._dst.peek()
        self._dst.seek(0)
        self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
        return self._dst.getbuffer()

    def process(self, data):
        """
        Decompress ``data``.

        The standard BriefLZ container is tried first. If that raises a ValueError (partial
        results are caught here as well), the modified chunked layout is tried instead. If
        the second attempt fails with anything other than a partial result, the partial
        result from the first attempt is raised, if there was one.
        """
        self._begin(data)
        partial = None
        try:
            return self._decompress()
        except ValueError as error:
            if isinstance(error, RefineryPartialResult):
                partial = error
            self.log_warn(F'Reverting to modified BriefLZ after decompression error: {error!s}')
            self._reset()

        try:
            return self._decompress_modded()
        except RefineryPartialResult:
            raise
        except Exception as error:
            if not partial:
                raise
            raise partial from error

    def reverse(self, data):
        """
        Compress ``data`` into the standard BriefLZ container.
        """
        return self._begin(data)._compress()
Beispiel #7
0
    def _decompress_xpress_huffman(self,
                                   reader: StructReader,
                                   writer: MemoryFile,
                                   target: Optional[int] = None,
                                   max_chunk_size: int = 0x10000) -> None:
        """
        Decompress XPRESS Huffman data from ``reader`` and append the output to ``writer``.

        The input is processed chunk by chunk. Each chunk starts with a Huffman table of
        ``XPRESS_NUM_SYMBOLS`` 4-bit code lengths. Symbols are then decoded until the output
        of the chunk reaches ``max_chunk_size`` bytes. A chunk that overshoots is logged, and
        the limit moves to the actual position.

        Symbols below ``XPRESS_NUM_CHARS`` are literal bytes. Every other symbol is a match:
        its low nibble is the length code and its next nibble is the bit count of the offset.

        Args:
            reader: compressed input.
            writer: output buffer; decoding appends at its current position.
            target: if given, decoding returns as soon as the output position is exactly
                ``target`` bytes past the starting position.
            max_chunk_size: output size limit per Huffman table.

        Raises:
            IndexError: if too few input bytes remain to read a Huffman table.
        """
        limit = writer.tell()
        if target is not None:
            target += limit

        while not reader.eof:

            # The table holds one 4-bit code length per symbol, i.e. two symbols per byte.
            if reader.remaining_bytes < XPRESS_NUM_SYMBOLS // 2:
                raise IndexError(
                    F'There are only {reader.remaining_bytes} bytes remaining in the input buffer,'
                    F' but at least {XPRESS_NUM_SYMBOLS//2} are required to read a Huffman table.'
                )

            table = bytearray(
                reader.read_integer(4) for _ in range(XPRESS_NUM_SYMBOLS))
            table = make_huffman_decode_table(table, XPRESS_TABLEBITS,
                                              XPRESS_MAX_CODEWORD_LEN)
            limit = limit + max_chunk_size
            flags = BitBufferedReader(reader, 16)

            while True:
                position = writer.tell()
                if position == target:
                    if reader.remaining_bytes:
                        self.log_info(
                            F'chunk decompressed with {reader.remaining_bytes} bytes remaining in input buffer'
                        )
                    return
                if position >= limit:
                    if position > limit:
                        limit = position
                        self.log_info(
                            F'decompression of one chunk generated more than the limit of {max_chunk_size} bytes'
                        )
                    # NOTE(review): collect() presumably re-synchronizes the byte reader with
                    # the bit buffer before the next table is read; confirm in BitBufferedReader.
                    flags.collect()
                    break
                try:
                    sym = flags.huffman_symbol(table, XPRESS_TABLEBITS,
                                               XPRESS_MAX_CODEWORD_LEN)
                except EOFError:
                    self.log_debug('end of file while reading huffman symbol')
                    break
                if sym < XPRESS_NUM_CHARS:
                    writer.write_byte(sym)
                    continue
                length = sym & 0xF
                offsetlog = (sym >> 4) & 0xF
                # Extended length bytes are read directly from the byte stream.
                flags.collect()
                if reader.eof:
                    break
                offset = (1 << offsetlog) | flags.read(offsetlog)
                if length == 0xF:
                    # Length code 15 is extended by one byte; 0xFF there means a 16-bit
                    # length follows, and a 16-bit value of zero means a 32-bit one follows.
                    nudge = reader.read_byte()
                    if nudge < 0xFF:
                        length += nudge
                    else:
                        length = reader.u16() or reader.u32()
                length += XPRESS_MIN_MATCH_LEN
                # NOTE(review): replay presumably copies ``length`` bytes starting ``offset``
                # bytes back in the output.
                writer.replay(offset, length)
Beispiel #8
0
    def decompress_stream(self, data: ByteString, LZOv1: bool = False) -> bytearray:
        """
        An implementation of LZO decompression. We use the article
        "[LZO stream format as understood by Linux's LZO decompressor](https://www.kernel.org/doc/html/latest/staging/lzo.html)"
        as a reference since no proper specification is available.

        The variable ``state`` holds the number of literals copied by the previous
        instruction; any value of 4 or more means "at least four". Instructions 0x00-0x0F
        are interpreted according to this state, so it has to be inspected before the
        current instruction overwrites it.

        Distances and lengths in the 0x10-0x1F instruction are compared against the
        article's definitions, i.e. the full distance including the 16384 base and the
        full length including the base of 2.

        Set ``LZOv1`` to decode version 1 zero runs and to reject copy sequences that
        version 1 forbids. Raises LZOError on malformed input.
        """
        def integer() -> int:
            # Length extension: every zero byte adds 255; the first non-zero byte ends it.
            length = 0
            while True:
                byte = src.read_byte()
                if byte:
                    return length + byte
                length += 0xFF
                if length > 0x100000:
                    raise LZOError('Too many zeros in integer encoding.')

        def literal(count):
            dst.write(src.read_bytes(count))

        def copy(distance: int, length: int):
            if distance > len(dst):
                raise LZOError(F'Distance {distance} > bufsize {len(dst)}')
            buffer = dst.getbuffer()
            if distance > length:
                start = len(buffer) - distance
                end = start + length
                dst.write(buffer[start:end])
            else:
                # Overlapping copy: repeat the last ``distance`` bytes up to ``length``.
                block = buffer[-distance:]
                while len(block) < length:
                    block += block[:length - len(block)]
                if len(block) > length:
                    block[length:] = ()
                dst.write(block)

        src = StructReader(memoryview(data))
        dst = MemoryFile()

        state = 0
        first = src.read_byte()

        if first == 0x10:
            raise LZOError('Invalid first stream byte 0x10.')
        elif first <= 0x11:
            # A regular instruction: rewind so that the main loop decodes it.
            src.seekrel(-1)
        elif first <= 0x15:
            # Bytes 0x12..0x15 copy first-17 = 1..4 literals.
            state = first - 0x11
            literal(state)
        else:
            state = 4
            literal(first - 0x11)

        while True:
            instruction = src.read_byte()
            if instruction < 0x10:
                if state == 0:
                    length = instruction or integer() + 15
                    state = length + 3
                    if state < 4:
                        raise LZOError('Literal encoding is too short.')
                else:
                    # The previous state selects the copy variant, so keep it before
                    # replacing it with this instruction's S bits.
                    previous_state = state
                    state = instruction & 0b0011
                    D = (instruction & 0b1100) >> 2
                    H = src.read_byte()
                    distance = (H << 2) + D + 1
                    if previous_state >= 4:
                        # 3-byte copy from 2..3kB distance: (H << 2) + D + 2049.
                        distance += 0x800
                        length = 3
                    else:
                        length = 2
                    copy(distance, length)
            elif instruction < 0x20:
                L = instruction & 0b0111
                H = instruction & 0b1000
                length = L or integer() + 7
                argument = src.u16()
                state = argument & 3
                # Full distance is 16384 + (H << 14) + D; H here is bit 3, so H << 11 == bit << 14.
                distance = 0x4000 + (H << 11) + (argument >> 2)
                if distance == 0x4000:
                    return dst.getbuffer()
                if LZOv1 and distance & 0x803F == 0x803F and length + 2 in range(261, 265):
                    raise LZOError('Compressed data contains sequence that is banned in LZOv1.')
                if LZOv1 and distance == 0xBFFF:
                    X = src.read_byte()
                    count = ((X << 3) | L) + 4
                    self.log_debug(F'Writing run of {count} zero bytes according to LZOv1.')
                    dst.write(B'\0' * count)
                else:
                    copy(distance, length + 2)
            elif instruction < 0x40:
                L = instruction & 0b11111
                length = L or integer() + 31
                argument = src.u16()
                state = argument & 3
                distance = (argument >> 2) + 1
                copy(distance, length + 2)
            else:
                if instruction < 0x80:
                    length = 3 + ((instruction >> 5) & 1)
                else:
                    length = 5 + ((instruction >> 5) & 3)
                H = src.read_byte()
                D = (instruction & 0b11100) >> 2
                state = instruction & 3
                distance = (H << 3) + D + 1
                copy(distance, length)
            if state:
                literal(state)
Beispiel #9
0
 def __init__(self, reader: StructReader):
     """
     Parse a method handle entry: one byte giving the reference kind, followed by a 16-bit
     reference value (presumably a constant pool index; verify against the caller).
     """
     kind_byte = reader.read_byte()
     self.kind = JvMethodHandleRefKind(kind_byte)
     self.reference = reader.u16()