Example #1
0
 def test_write_iterables(self):
     """
     Writing an iterator of byte values to a MemoryFile overwrites data at the cursor,
     and after seeking to the end of the buffer such writes are appended.
     """
     x_byte = B'X'[0]
     memfile = MemoryFile()
     memfile.write(B'FOO BAR BAR FOO FOO')
     # Overwrite the seven bytes "BAR BAR" starting at offset 4.
     memfile.seek(4)
     memfile.write(itertools.repeat(x_byte, 7))
     self.assertEqual(memfile.getbuffer(), B'FOO XXXXXXX FOO FOO')
     # Jump to the end of the buffer and append.
     memfile.seekset(len(memfile))
     memfile.write(B' ')
     memfile.write(itertools.repeat(x_byte, 4))
     self.assertEqual(memfile.getbuffer(), B'FOO XXXXXXX FOO FOO XXXX')
Example #2
0
 def _generate_bytes(self, data: ByteString):
     """
     Yield the chunks produced by ``self.action(data)``. When ``self.squeeze`` is set,
     the chunks are instead concatenated into one buffer, which is yielded once.
     """
     if self.squeeze:
         joined = MemoryFile(bytearray())
         for chunk in self.action(data):
             joined.write(chunk)
         yield joined.getbuffer()
     else:
         yield from self.action(data)
Example #3
0
def decompile_buffer(buffer: ByteString, file_name: str) -> ByteString:
    """
    Attempt to decompile the Python bytecode contained in ``buffer``.

    The buffer is loaded with xdis. Each resulting code object is passed to decompyle3 and,
    if that raises, to uncompyle6.

    - If decompilation produced output, that source code is returned, encoded with
      ``xtpyi.codec``.
    - Otherwise, a single printable region is carved from the buffer. If fewer than 0x20
      bytes of the buffer lie outside that region, the carved region is returned.
    - As a last resort, the collected error messages and a disassembly of every code object
      are returned, each line written as a ``#`` comment.
    """
    import contextlib
    code_objects = {}
    # Suppress anything written to stderr while the module is being loaded.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stderr(devnull):
        version, timestamp, magic_int, codez, is_pypy, _, _ = \
            xtpyi._xdis.load.load_module_from_file_object(MemoryFile(buffer), file_name, code_objects)
    if not isinstance(codez, list):
        codez = [codez]
    errors = ''
    python = ''
    for code in codez:
        for name, engine in {
                'decompyle3': xtpyi._decompyle3,
                'uncompyle6': xtpyi._uncompyle6,
        }.items():
            with io.StringIO(newline='') as output, NoLogging(
                    NoLogging.Mode.ALL):
                try:
                    engine.main.decompile(
                        version,
                        code,
                        output,
                        timestamp=timestamp,
                        code_objects=code_objects,
                        is_pypy=is_pypy,
                        magic_int=magic_int,
                    )
                except Exception as E:
                    # Record the failure as comment lines. splitlines() without keepends
                    # ensures every line gets exactly one '#' prefix and one newline.
                    lines = (F'Error while decompiling with {name}:', *str(E).splitlines())
                    errors += ''.join(F'# {line}\n' for line in lines)
                else:
                    # NOTE(review): with several code objects, only the output of the last
                    # successfully decompiled one is kept; confirm whether outputs should
                    # be concatenated instead.
                    python = output.getvalue()
                    break
    if python:
        return python.encode(xtpyi.codec)
    embedded = bytes(buffer | carve('printable', single=True))
    if len(buffer) - len(embedded) < 0x20:
        return embedded
    disassembly = MemoryFile()
    with io.TextIOWrapper(disassembly, xtpyi.codec, newline='\n') as output:
        output.write(errors)
        output.write('# Generating Disassembly:\n\n')
        for code in codez:
            instructions = list(xtpyi._xdis.std.Bytecode(code))
            # default guards against a code object without any instructions.
            width_offset = max((len(str(i.offset)) for i in instructions), default=1)
            for i in instructions:
                opname = i.opname.replace('_', '.').lower()
                offset = F'{i.offset:0{width_offset}d}'
                output.write(F'# {offset:>5} {opname:<25} {i.argrepr}\n')
        output.write('\n')
    return disassembly.getbuffer()
Example #4
0
 def test_string_builder(self):
     """
     A MemoryFile is writable and accumulates consecutive writes. Seeking one byte
     backwards allows the final character to be overwritten.
     """
     memfile = MemoryFile()
     self.assertTrue(memfile.writable())
     for part in (B'The binary refinery ', B'refines the finer binaries.'):
         memfile.write(part)
     # Replace the trailing period with an exclamation mark.
     memfile.seekrel(-1)
     memfile.write(B'!')
     expected = B'The binary refinery refines the finer binaries!'
     self.assertEqual(memfile.getbuffer(), expected)
Example #5
0
 def _decompress_mszip(self,
                       reader: StructReader,
                       writer: MemoryFile,
                       target: Optional[int] = None):
     """
     Decompress one MSZIP chunk from ``reader`` and append the result to ``writer``.

     A chunk starts with the two-byte signature ``CK``, followed by raw deflate data
     (negative ``wbits``). The data already in ``writer`` is passed to zlib as the preset
     dictionary, so back-references can reach into earlier chunks. Everything remaining
     in ``reader`` is consumed. The ``target`` argument is accepted but not used here.
     """
     signature = bytes(reader.read(2))
     if signature != B'CK':
         raise ValueError(
             F'chunk did not begin with CK header, got {signature!r} instead')
     history = writer.getbuffer()
     inflater = zlib.decompressobj(-zlib.MAX_WBITS, zdict=history)
     payload = reader.read()
     writer.write(inflater.decompress(payload))
     writer.write(inflater.flush())
Example #6
0
class blz(Unit):
    """
    BriefLZ compression and decompression. The compression algorithm uses a pure Python suffix tree
    implementation: It requires a lot of time & memory.
    """
    def _begin(self, data):
        # Set up the input reader and an empty output buffer; returns self for chaining.
        self._src = StructReader(memoryview(data))
        self._dst = MemoryFile(bytearray())
        return self

    def _reset(self):
        # Rewind the input and truncate the output at offset 0 so decoding can be retried.
        self._src.seek(0)
        self._dst.seek(0)
        self._dst.truncate()
        return self

    def _decompress(self):
        """
        Decompress a standard BriefLZ stream.

        The input starts with a 24-byte header of six big-endian 32-bit fields: signature,
        version, compressed size, compressed CRC32, decompressed size and decompressed CRC32.
        An invalid signature or version raises ValueError. A short input or a CRC mismatch
        is only logged as a warning, and a CRC field of zero skips the corresponding check.
        """
        (
            signature,
            version,
            src_count,
            src_crc32,
            dst_count,
            dst_crc32,
        ) = self._src.read_struct('>6L')
        if signature != 0x626C7A1A:
            raise ValueError(F'Invalid BriefLZ signature: {signature:08X}, should be 626C7A1A.')
        if version > 10:
            raise ValueError(F'Invalid version number {version}, should be at most 10.')
        self.log_debug(F'signature: 0x{signature:08X} V{version}')
        self.log_debug(F'src count: 0x{src_count:08X}')
        self.log_debug(F'src crc32: 0x{src_crc32:08X}')
        self.log_debug(F'dst count: 0x{dst_count:08X}')
        self.log_debug(F'dst crc32: 0x{dst_crc32:08X}')
        src = self._src.getbuffer()
        src = src[24:24 + src_count]
        if len(src) < src_count:
            self.log_warn(F'Only {len(src)} bytes in buffer, but header announced a length of {src_count}.')
        if src_crc32:
            check = zlib.crc32(src)
            if check != src_crc32:
                self.log_warn(F'Invalid source data CRC {check:08X}, should be {src_crc32:08X}.')
        dst = self._decompress_chunk(dst_count)
        if dst_crc32:
            check = zlib.crc32(dst)
            if check != dst_crc32:
                self.log_warn(F'Invalid result data CRC {check:08X}, should be {dst_crc32:08X}.')
        return dst

    def _decompress_modded(self):
        """
        Decompress the modified BriefLZ variant.

        After skipping 8 bytes, two 64-bit fields give the total decompressed size and the
        chunk size. These are followed by consecutive chunks, all decoded into the same
        output buffer.
        """
        self._src.seekrel(8)
        total_size = self._src.u64()
        chunk_size = self._src.u64()
        remaining = total_size
        self.log_debug(F'total size: 0x{total_size:016X}')
        self.log_debug(F'chunk size: 0x{chunk_size:016X}')
        while remaining > chunk_size:
            self._decompress_chunk(chunk_size)
            remaining -= chunk_size
        return self._decompress_chunk(remaining)

    def _decompress_chunk(self, size=None):
        """
        Decode one BriefLZ chunk from the input and append it to the output buffer.

        A chunk starts with one literal byte. After that comes a bit stream of 16-bit
        little-endian tag words, interleaved with literal bytes and match offset bytes.

        If ``size`` is truthy, decoding stops once at least ``size`` bytes were produced for
        this chunk. Otherwise it continues until the input runs out. Running out of input
        or referencing data before the start of the output raises RefineryPartialResult
        carrying the output so far. Overshooting ``size`` raises RuntimeError.

        Returns the complete output buffer.
        """
        bitcount = 0
        bitstore = 0
        decompressed = 1

        def readbit():
            # Return the next tag bit; each 16-bit word is consumed from its top bit down.
            nonlocal bitcount, bitstore
            if not bitcount:
                bitstore = int.from_bytes(self._src.read_exactly(2), 'little')
                bitcount = 0xF
            else:
                bitcount = bitcount - 1
            return (bitstore >> bitcount) & 1

        def readint():
            # Variable-length integer >= 2: the leading 1 bit is implicit. It is followed by
            # (data bit, continuation bit) pairs until a continuation bit of 0.
            result = 2 + readbit()
            while readbit():
                result <<= 1
                result += readbit()
            return result

        try:
            # The first literal is read inside the try block so that a truncated input is
            # reported as a partial result, like every other read of this chunk.
            self._dst.write(self._src.read_exactly(1))
            while not size or decompressed < size:
                if readbit():
                    length = readint() + 2
                    sector = readint() - 2
                    offset = self._src.read_byte() + 1
                    delta = offset + 0x100 * sector
                    available = self._dst.tell()
                    if delta not in range(available + 1):
                        raise RefineryPartialResult(
                            F'Requested rewind by 0x{delta:08X} bytes with only 0x{available:08X} bytes in output buffer.',
                            partial=self._dst.getvalue())
                    quotient, remainder = divmod(length, delta)
                    # Copy the referenced bytes out before writing. A match longer than its
                    # distance repeats the referenced block. The view is released explicitly
                    # so the output buffer may be resized by the following write.
                    with memoryview(self._dst.getbuffer()) as history:
                        replay = bytes(history[-delta:] if quotient else history[-delta:length - delta])
                    replay = quotient * replay + replay[:remainder]
                    self._dst.write(replay)
                    decompressed += length
                else:
                    self._dst.write(self._src.read_exactly(1))
                    decompressed += 1
        except EOF as E:
            raise RefineryPartialResult(str(E), partial=self._dst.getbuffer())
        dst = self._dst.getbuffer()
        # The loop only ends normally once decompressed >= size. The only inconsistency
        # left to detect is a final match that overshoots the announced size.
        if size and decompressed > size:
            raise RuntimeError('Decompressed buffer contained more bytes than expected.')
        return dst

    def _compress(self):
        """
        Compress the input into a standard BriefLZ stream and return the output buffer.

        Matches are located with a suffix tree built over the whole input. Tag bits are
        accumulated in an integer and emitted as 16-bit little-endian words. They are
        interleaved with the literal and offset bytes collected in a side buffer. The
        24-byte header is filled in at the end, once sizes and CRC32 values are known.
        """
        from refinery.lib.suffixtree import SuffixTree

        self.log_info('computing suffix tree')
        tree = SuffixTree(self._src.getbuffer())

        bitstore = 0  # The bit stream to be written
        bitcount = 0  # The number of bits in the bit stream
        buffer = MemoryFile(bytearray())  # literal/offset bytes not yet written to the output

        # Write empty header and first byte of source
        self._dst.write(bytearray(24))
        self._dst.write(self._src.read_exactly(1))

        def writeint(n: int) -> None:
            """
            Append n to the bit stream in the variable-length format that readint in
            _decompress_chunk consumes. The leading 1 bit is implicit, and each remaining
            bit is followed by a continuation bit that is 0 only after the last one.
            Raises ValueError if n has fewer than two bits.
            """
            nonlocal bitstore, bitcount
            nbits = n.bit_length()
            if nbits < 2:
                raise ValueError
            # The highest bit is implicitly assumed:
            n ^= 1 << (nbits - 1)
            remaining = nbits - 2
            while remaining:
                remaining -= 1
                bitstore <<= 2
                bitcount += 2
                bitstore |= ((n >> remaining) & 3) | 1
            bitstore <<= 2
            bitcount += 2
            bitstore |= (n & 1) << 1

        src = self._src.getbuffer()
        remaining = len(src) - 1
        self.log_info('compressing data')

        while True:
            cursor = len(src) - remaining
            rest = src[cursor:]
            if bitcount >= 0x10:
                block_count, bitcount = divmod(bitcount, 0x10)
                info_channel = bitstore >> bitcount
                bitstore = info_channel << bitcount ^ bitstore
                # The decompressor will read bits from top to bottom, and each 16 bit block has to be
                # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
                # variable, and by encoding it as a big endian integer, the stream is in the correct
                # order. However, we need to swap adjacent bytes to achieve little endian encoding for
                # each of the blocks:
                info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
                for k in range(block_count):
                    k0 = 2 * k + 0
                    k1 = 2 * k + 1
                    info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
                info_channel = memoryview(info_channel)
                data_channel = memoryview(buffer.getbuffer())
                self._dst.write(info_channel[:2])
                self._dst.write(data_channel[:-1])
                self._dst.write(info_channel[2:])
                data_channel = bytes(data_channel[-1:])
                buffer.truncate(0)
                store = buffer if bitcount else self._dst
                store.write(data_channel)
            if remaining + bitcount < 0x10:
                # Final flush: pad the last tag word and emit any pending and remaining bytes.
                pending = buffer.getbuffer()
                if rest or pending:
                    bitstore <<= 0x10 - bitcount
                    self._dst.write(bitstore.to_bytes(2, 'little'))
                    self._dst.write(pending)
                    self._dst.write(rest)
                elif bitcount:
                    raise RuntimeError('Bitbuffer Overflow')
                break
            # Walk the suffix tree along the upcoming input to find a match that starts
            # before the cursor.
            node = tree.root
            length = 0
            offset = 0
            sector = None
            while node.children and length < len(rest):
                for child in node.children.values():
                    if tree.data[child.start] == rest[length]:
                        node = child
                        break
                if node.start >= cursor:
                    break
                offset = node.start - length
                length = node.end + 1 - offset
            length = min(remaining, length)
            # Matches shorter than 4 bytes are emitted as literals.
            if length >= 4:
                sector, offset = divmod(cursor - offset - 1, 0x100)
            bitcount += 1
            bitstore <<= 1
            if sector is None:
                buffer.write(rest[:1])
                remaining -= 1
                continue
            bitstore |= 1
            buffer.write(bytes((offset,)))
            writeint(length - 2)
            writeint(sector + 2)
            remaining -= length

        self._dst.seek(24)
        dst = self._dst.peek()
        self._dst.seek(0)
        self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
        return self._dst.getbuffer()

    def process(self, data):
        """
        Try standard BriefLZ first and fall back to the modified variant on a ValueError.

        If the fallback also fails, a partial result from the first attempt is raised in
        preference to the fallback's error.
        """
        self._begin(data)
        partial = None
        try:
            return self._decompress()
        except ValueError as error:
            if isinstance(error, RefineryPartialResult):
                partial = error
            self.log_warn(F'Reverting to modified BriefLZ after decompression error: {error!s}')
            self._reset()

        try:
            return self._decompress_modded()
        except RefineryPartialResult:
            raise
        except Exception as error:
            if partial is None:
                raise
            raise partial from error

    def reverse(self, data):
        # Compress data into a standard BriefLZ stream.
        return self._begin(data)._compress()
Example #7
0
    def _compress(self):
        """
        Compress the input into a standard BriefLZ stream and return the output buffer.

        Matches are located with a suffix tree built over the whole input. Tag bits are
        accumulated in an integer and emitted as 16-bit little-endian words. They are
        interleaved with the literal and offset bytes collected in a side buffer. The
        24-byte header is filled in at the end, once sizes and CRC32 values are known.
        """
        from refinery.lib.suffixtree import SuffixTree

        self.log_info('computing suffix tree')
        tree = SuffixTree(self._src.getbuffer())

        bitstore = 0  # The bit stream to be written
        bitcount = 0  # The number of bits in the bit stream
        buffer = MemoryFile(bytearray())  # literal/offset bytes not yet written to the output

        # Write empty header and first byte of source
        self._dst.write(bytearray(24))
        self._dst.write(self._src.read_exactly(1))

        def writeint(n: int) -> None:
            """
            Append n to the bit stream in a variable-length format. The leading 1 bit is
            implicit, and each remaining bit is followed by a continuation bit that is 0
            only after the last one. Raises ValueError if n has fewer than two bits.
            """
            nonlocal bitstore, bitcount
            nbits = n.bit_length()
            if nbits < 2:
                raise ValueError
            # The highest bit is implicitly assumed:
            n ^= 1 << (nbits - 1)
            remaining = nbits - 2
            while remaining:
                remaining -= 1
                bitstore <<= 2
                bitcount += 2
                bitstore |= ((n >> remaining) & 3) | 1
            bitstore <<= 2
            bitcount += 2
            bitstore |= (n & 1) << 1

        src = self._src.getbuffer()
        remaining = len(src) - 1
        self.log_info('compressing data')

        while True:
            cursor = len(src) - remaining
            rest = src[cursor:]
            if bitcount >= 0x10:
                block_count, bitcount = divmod(bitcount, 0x10)
                info_channel = bitstore >> bitcount
                bitstore = info_channel << bitcount ^ bitstore
                # The decompressor will read bits from top to bottom, and each 16 bit block has to be
                # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
                # variable, and by encoding it as a big endian integer, the stream is in the correct
                # order. However, we need to swap adjacent bytes to achieve little endian encoding for
                # each of the blocks:
                info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
                for k in range(block_count):
                    k0 = 2 * k + 0
                    k1 = 2 * k + 1
                    info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
                info_channel = memoryview(info_channel)
                data_channel = memoryview(buffer.getbuffer())
                self._dst.write(info_channel[:2])
                self._dst.write(data_channel[:-1])
                self._dst.write(info_channel[2:])
                data_channel = bytes(data_channel[-1:])
                buffer.truncate(0)
                store = buffer if bitcount else self._dst
                store.write(data_channel)
            if remaining + bitcount < 0x10:
                # Final flush: pad the last tag word and emit any pending and remaining bytes.
                pending = buffer.getbuffer()
                if rest or pending:
                    bitstore <<= 0x10 - bitcount
                    self._dst.write(bitstore.to_bytes(2, 'little'))
                    self._dst.write(pending)
                    self._dst.write(rest)
                elif bitcount:
                    raise RuntimeError('Bitbuffer Overflow')
                break
            # Walk the suffix tree along the upcoming input to find a match that starts
            # before the cursor.
            node = tree.root
            length = 0
            offset = 0
            sector = None
            while node.children and length < len(rest):
                for child in node.children.values():
                    if tree.data[child.start] == rest[length]:
                        node = child
                        break
                if node.start >= cursor:
                    break
                offset = node.start - length
                length = node.end + 1 - offset
            length = min(remaining, length)
            # Matches shorter than 4 bytes are emitted as literals.
            if length >= 4:
                sector, offset = divmod(cursor - offset - 1, 0x100)
            bitcount += 1
            bitstore <<= 1
            if sector is None:
                buffer.write(rest[:1])
                remaining -= 1
                continue
            bitstore |= 1
            buffer.write(bytes((offset,)))
            writeint(length - 2)
            writeint(sector + 2)
            remaining -= length

        self._dst.seek(24)
        dst = self._dst.peek()
        self._dst.seek(0)
        self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
        return self._dst.getbuffer()
Example #8
0
    def _decompress(self,
                    writer: MemoryFile,
                    reader_: StructReader[bytearray],
                    size: Optional[int] = None):
        """
        Decode the compressed bit stream in ``reader_`` and write the output to ``writer``.

        Control data is read through ``BitBufferedReader(reader_, 32)``. Decoding stops at
        the end marker: a short-match token whose index is 0 and whose length is 2. If
        ``size`` is truthy, it also stops once ``len(writer) >= size``. Nothing is returned;
        all output goes to ``writer``.

        NOTE(review): ``literal_bits`` and ``literal_offset`` start as None and are only
        assigned by the literal-configuration branch. A literal token decoded before that
        branch runs would call ``flags.read(None)``. Presumably valid streams configure
        literals first; confirm against the format.
        """
        # Match index; kept across tokens so that a match with high == 2 reuses it.
        index = 1
        # Number of extra index bits read for long matches; reassigned by the control token
        # with new_index == 0 and match_length != 2.
        base = 8
        literal_bits = None
        literal_offset = None
        flags = BitBufferedReader(reader_, 32)

        while True:
            if size and len(writer) >= size:
                break
            if flags.next():
                # Literal token: value is read with literal_bits, shifted by literal_offset,
                # and truncated to one byte.
                b = flags.read(literal_bits) + literal_offset
                b = b & 0xFF
                writer.write_byte(b)
                continue
            if flags.next():
                # Long match: high == 2 keeps the previous index; otherwise a new index is
                # built from (high - 3) and `base` further bits of the stream.
                high = flags.variable_length_integer()
                if (high == 2):
                    match_length = flags.variable_length_integer()
                else:
                    index = ((high - 3) << base) + flags.read(base)
                    match_length = flags.variable_length_integer()
                    # Index-dependent length adjustment (indices 128..0x27E get none).
                    if index >= 0x10000:
                        match_length += 3
                    elif index >= 0x37FF:
                        match_length += 2
                    elif index >= 0x27F:
                        match_length += 1
                    elif index <= 127:
                        match_length += 4
                writer.replay(index, match_length)
                continue
            if not flags.next():
                # Short match: index from read(7), length 2 + read(2). Index 0 is a control
                # code: length 2 ends the stream, any other length reads a new `base`.
                new_index = flags.read(7)
                match_length = 2 + flags.read(2)
                if new_index == 0:
                    if match_length == 2:
                        break
                    base = flags.read(match_length + 1)
                else:
                    index = new_index
                    writer.replay(index, match_length)
                continue
            # One-byte phrase: read(4) - 1 == 0 writes a zero byte. A positive value v
            # repeats the byte v positions before the end of the output. A negative value
            # (read(4) == 0) selects literal reconfiguration or a raw block.
            one_byte_phrase_value = flags.read(4) - 1
            if one_byte_phrase_value == 0:
                writer.write_byte(0)
            elif one_byte_phrase_value > 0:
                b = writer.getbuffer()[-one_byte_phrase_value]
                writer.write_byte(b)
            else:
                if not flags.next():
                    # Literal reconfiguration: literal_bits becomes 7 + next flag (presumably
                    # 7 or 8); when it is not 8, an offset is read with read(8).
                    literal_bits = 7 + flags.next()
                    literal_offset = 0
                    if literal_bits != 8:
                        literal_offset = flags.read(8)
                    continue
                # Raw block: copy 0x100 values read with read(8), repeating while the next
                # flag is set.
                while True:
                    for _ in range(0x100):
                        b = flags.read(8)
                        writer.write_byte(b)
                    if not flags.next():
                        break
Example #9
0
    def decompress_stream(self, data: ByteString, LZOv1: bool = False) -> bytearray:
        """
        An implementation of LZO decompression. We use the article
        "[LZO stream format as understood by Linux's LZO decompressor](https://www.kernel.org/doc/html/latest/staging/lzo.html)"
        as a reference since no proper specification is available.

        The variable ``state`` holds the number of literals copied by the previous
        instruction: 0-3, or 4 and above after a literal run. The meaning of instructions
        0x00-0x0F depends on it.

        When ``LZOv1`` is set, the zero-run encoding and the banned-sequence check of
        bitstream version 1 are applied. The version header itself is not parsed here.

        Decoding ends at the end-of-stream marker, and the output buffer is returned.
        Malformed input raises LZOError.
        """
        def integer() -> int:
            # Length extension: every zero byte adds 255, the first non-zero byte ends it.
            length = 0
            while True:
                byte = src.read_byte()
                if byte:
                    return length + byte
                length += 0xFF
                if length > 0x100000:
                    raise LZOError('Too many zeros in integer encoding.')

        def literal(count):
            dst.write(src.read_bytes(count))

        def copy(distance: int, length: int):
            # Copy `length` bytes starting `distance` bytes before the end of the output;
            # when the source overlaps the destination, the referenced block repeats.
            if distance > len(dst):
                raise LZOError(F'Distance {distance} > bufsize {len(dst)}')
            buffer = dst.getbuffer()
            if distance > length:
                start = len(buffer) - distance
                end = start + length
                dst.write(buffer[start:end])
            else:
                block = buffer[-distance:]
                while len(block) < length:
                    block += block[:length - len(block)]
                if len(block) > length:
                    block[length:] = ()
                dst.write(block)

        src = StructReader(memoryview(data))
        dst = MemoryFile()

        state = 0
        first = src.read_byte()

        if first == 0x10:
            raise LZOError('Invalid first stream byte 0x10.')
        elif first <= 0x11:
            # 0x00-0x11 follow the regular instruction encoding.
            src.seekrel(-1)
        elif first <= 0x15:
            # 0x12-0x15: copy first - 0x11 (1-4) literals.
            state = first - 0x11
            literal(state)
        else:
            # 0x16 and above: a literal run of first - 0x11 bytes.
            state = 4
            literal(first - 0x11)

        while True:
            instruction = src.read_byte()
            if instruction < 0x10:
                if state == 0:
                    # Long literal run; the tail below copies `state` (>= 4) literals.
                    length = instruction or integer() + 15
                    state = length + 3
                    if state < 4:
                        raise LZOError('Literal encoding is too short.')
                else:
                    D = (instruction & 0b1100) >> 2
                    H = src.read_byte()
                    distance = (H << 2) + D + 1
                    # The PREVIOUS state decides the variant: after 1-3 literals this is
                    # a 2-byte copy within 1kB, after a literal run a 3-byte copy from
                    # 2-3kB. It must be tested before `state` is replaced below.
                    if state >= 4:
                        distance += 0x800
                        length = 3
                    else:
                        length = 2
                    state = instruction & 0b0011
                    copy(distance, length)
            elif instruction < 0x20:
                L = instruction & 0b0111
                H = instruction & 0b1000
                zero_run = False
                if LZOv1 and H:
                    # In LZOv1, an instruction 0x18-0x1F whose following LE16 has all
                    # distance bits set encodes a run of zeros. It has no length
                    # extension, so peek before reading one.
                    argument = src.u16()
                    zero_run = argument & 0xFFFC == 0xFFFC
                    if not zero_run:
                        src.seekrel(-2)
                if zero_run:
                    X = src.read_byte()
                    count = ((X << 3) | L) + 4
                    self.log_debug(F'Writing run of {count} zero bytes according to LZOv1.')
                    dst.write(B'\0' * count)
                    state = argument & 3
                else:
                    length = L or integer() + 7
                    argument = src.u16()
                    state = argument & 3
                    distance = (H << 11) + (argument >> 2)
                    if not distance:
                        return dst.getbuffer()
                    # Full copy parameters: distance 16384 + (H << 14) + D, length 2 + L.
                    distance += 0x4000
                    length += 2
                    if LZOv1 and distance & 0x803F == 0x803F and length in range(261, 265):
                        raise LZOError('Compressed data contains sequence that is banned in LZOv1.')
                    copy(distance, length)
            elif instruction < 0x40:
                L = instruction & 0b11111
                length = L or integer() + 31
                argument = src.u16()
                state = argument & 3
                distance = (argument >> 2) + 1
                copy(distance, length + 2)
            else:
                if instruction < 0x80:
                    length = 3 + ((instruction >> 5) & 1)
                else:
                    length = 5 + ((instruction >> 5) & 3)
                H = src.read_byte()
                D = (instruction & 0b11100) >> 2
                state = instruction & 3
                distance = (H << 3) + D + 1
                copy(distance, length)
            if state:
                literal(state)