Example #1
 def __init__(self, reader: StructReader):
     with StreamDetour(reader):
         self.code = opc(reader.read_byte())
         self.table: Optional[Dict[int, int]] = None
         try:
             fmt = self.OPC_ARGMAP[self.code]
         except KeyError:
             self.arguments = []
         else:
             self.arguments = list(reader.read_struct(fmt))
         if self.code == opc.newarray:
             self.arguments = [JvBaseType(reader.read_byte())]
         elif self.code in self.OPC_CONSTPOOL:
             try:
                 self.arguments[0] = self.pool[self.arguments[0] - 1]
             except (AttributeError, IndexError):
                 pass
         elif self.code == opc.lookupswitch:
             reader.byte_align(blocksize=4)
             default, npairs = reader.read_struct('LL')
             pairs = reader.read_struct(F'{npairs*2}L')
             self.table = dict(zip(*([iter(pairs)] * 2)))
             self.table[None] = default
         elif self.code == opc.tableswitch:
             reader.byte_align(blocksize=4)
             default, low, high = reader.read_struct('LLL')
             assert low <= high
             offsets = reader.read_struct(F'{high-low+1}L')
             self.table = {k + low: offset for k, offset in enumerate(offsets)}
             self.table[None] = default
         elif self.code == opc.wide:
             argop = opc(reader.read_byte())
             self.arguments = (argop, reader.u16())
             if argop == opc.iinc:
                 self.arguments += reader.i16(),
             else:
                 assert argop in (
                     opc.iload, opc.istore,
                     opc.fload, opc.fstore,
                     opc.aload, opc.astore,
                     opc.lload, opc.lstore,
                     opc.dload, opc.dstore,
                     opc.ret)
         offset = reader.tell()
     # StreamDetour restores the original cursor position when the block exits, so the
     # offset captured inside it minus the restored position is the size of this
     # instruction; reading that many bytes again yields its raw encoding.
     self.raw = bytes(reader.read(offset - reader.tell()))
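
The lookupswitch branch above folds a flat run of match/offset values into a dictionary by zipping two references to a single iterator. A minimal standalone sketch of just that pairing idiom, with made-up values rather than real bytecode:

    pairs = (10, 0x20, 20, 0x34, 30, 0x48)   # hypothetical match/offset run
    table = dict(zip(*([iter(pairs)] * 2)))  # both zip arguments share one iterator
    assert table == {10: 0x20, 20: 0x34, 30: 0x48}
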
Example #2
 def test_bitreader_le(self):
     data = 0b10010100111010100100001111101_11_00000000_0101010101010010010111100000101001010101100000001110010111110100_111_000_100
     size, remainder = divmod(data.bit_length(), 8)
     self.assertEqual(remainder, 0)
     data = memoryview(data.to_bytes(size, 'little'))
     sr = StructReader(data)
     self.assertEqual(sr.read_integer(3), 0b100)
     self.assertEqual(sr.read_integer(3), 0b000)
     self.assertEqual(sr.read_integer(3), 0b111)
     self.assertEqual(
         sr.u64(),
         0b0101010101010010010111100000101001010101100000001110010111110100)
     self.assertFalse(any(sr.read_flags(8, reverse=True)))
     self.assertEqual(sr.read_bit(), 1)
     self.assertRaises(ValueError, lambda: sr.read_struct(''))
     self.assertEqual(sr.read_bit(), 1)
     self.assertEqual(sr.read_integer(29), 0b10010100111010100100001111101)
     self.assertTrue(sr.eof)
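
A sketch, using only integer arithmetic, of how the underscore-separated fields in the literal above line up with the first three reads in the test; the field consumed first sits in the lowest bits:

    data = 0b10010100111010100100001111101_11_00000000_0101010101010010010111100000101001010101100000001110010111110100_111_000_100
    assert data & 0b111 == 0b100          # first 3-bit field the test reads
    assert (data >> 3) & 0b111 == 0b000   # second 3-bit field
    assert (data >> 6) & 0b111 == 0b111   # third 3-bit field
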
Example #3
 def test_bitreader_structured(self):
     items = (
         0b1100101,  # noqa
         -0x1337,  # noqa
         0xDEFACED,  # noqa
         0xC0CAC01A,  # noqa
         -0o1337,  # noqa
         2076.171875,  # noqa
         math.pi  # noqa
     )
     data = struct.pack('<bhiLqfd', *items)
     sr = StructReader(data)
     self.assertEqual(sr.read_nibble(), 0b101)
     self.assertRaises(sr.Unaligned, lambda: sr.read_exactly(2))
     sr.seek(0)
     self.assertEqual(sr.read_byte(), 0b1100101)
     self.assertEqual(sr.i16(), -0x1337)
     self.assertEqual(sr.i32(), 0xDEFACED)
     self.assertEqual(sr.u32(), 0xC0CAC01A)
     self.assertEqual(sr.i64(), -0o1337)
     self.assertAlmostEqual(sr.read_struct('f', True), 2076.171875)
     self.assertAlmostEqual(sr.read_struct('d', True), math.pi)
     self.assertTrue(sr.eof)
Example #4
class blz(Unit):
    """
    BriefLZ compression and decompression. The compression algorithm uses a pure Python suffix
    tree implementation; it requires a lot of time and memory.
    """
    def _begin(self, data):
        self._src = StructReader(memoryview(data))
        self._dst = MemoryFile(bytearray())
        return self

    def _reset(self):
        self._src.seek(0)
        self._dst.seek(0)
        self._dst.truncate()
        return self

    def _decompress(self):
        (
            signature,
            version,
            src_count,
            src_crc32,
            dst_count,
            dst_crc32,
        ) = self._src.read_struct('>6L')
        if signature != 0x626C7A1A:
            raise ValueError(F'Invalid BriefLZ signature: {signature:08X}, should be 626C7A1A.')
        if version > 10:
            raise ValueError(F'Invalid version number {version}, should be at most 10.')
        self.log_debug(F'signature: 0x{signature:08X} V{version}')
        self.log_debug(F'src count: 0x{src_count:08X}')
        self.log_debug(F'src crc32: 0x{src_crc32:08X}')
        self.log_debug(F'dst count: 0x{dst_count:08X}')
        self.log_debug(F'dst crc32: 0x{dst_crc32:08X}')
        src = self._src.getbuffer()
        src = src[24:24 + src_count]
        if len(src) < src_count:
            self.log_warn(F'Only {len(src)} bytes in buffer, but the header announced a length of {src_count}.')
        if src_crc32:
            check = zlib.crc32(src)
            if check != src_crc32:
                self.log_warn(F'Invalid source data CRC {check:08X}, should be {src_crc32:08X}.')
        dst = self._decompress_chunk(dst_count)
        if not dst_crc32:
            return dst
        check = zlib.crc32(dst)
        if check != dst_crc32:
            self.log_warn(F'Invalid result data CRC {check:08X}, should be {dst_crc32:08X}.')
        return dst

    def _decompress_modded(self):
        self._src.seekrel(8)
        total_size = self._src.u64()
        chunk_size = self._src.u64()
        remaining = total_size
        self.log_debug(F'total size: 0x{total_size:016X}')
        self.log_debug(F'chunk size: 0x{chunk_size:016X}')
        while remaining > chunk_size:
            self._decompress_chunk(chunk_size)
            remaining -= chunk_size
        return self._decompress_chunk(remaining)

    def _decompress_chunk(self, size=None):
        bitcount = 0
        bitstore = 0
        decompressed = 1

        def readbit():
            nonlocal bitcount, bitstore
            if not bitcount:
                bitstore = int.from_bytes(self._src.read_exactly(2), 'little')
                bitcount = 0xF
            else:
                bitcount = bitcount - 1
            return (bitstore >> bitcount) & 1

        def readint():
            result = 2 + readbit()
            while readbit():
                result <<= 1
                result += readbit()
            return result

        self._dst.write(self._src.read_exactly(1))

        try:
            while not size or decompressed < size:
                if readbit():
                    length = readint() + 2
                    sector = readint() - 2
                    offset = self._src.read_byte() + 1
                    delta = offset + 0x100 * sector
                    available = self._dst.tell()
                    if delta not in range(available + 1):
                        raise RefineryPartialResult(
                            F'Requested rewind by 0x{delta:08X} bytes with only 0x{available:08X} bytes in output buffer.',
                            partial=self._dst.getvalue())
                    quotient, remainder = divmod(length, delta)
                    replay = memoryview(self._dst.getbuffer())
                    replay = bytes(replay[-delta:] if quotient else replay[-delta:length - delta])
                    replay = quotient * replay + replay[:remainder]
                    self._dst.write(replay)
                    decompressed += length
                else:
                    self._dst.write(self._src.read_exactly(1))
                    decompressed += 1
        except EOF as E:
            raise RefineryPartialResult(str(E), partial=self._dst.getbuffer())
        dst = self._dst.getbuffer()
        if decompressed < size:
            raise RefineryPartialResult(
                F'Attempted to decompress {size} bytes, got only {len(dst)}.', dst)
        if decompressed > size:
            raise RuntimeError('Decompressed buffer contained more bytes than expected.')
        return dst

    def _compress(self):
        from refinery.lib.suffixtree import SuffixTree

        self.log_info('computing suffix tree')
        tree = SuffixTree(self._src.getbuffer())

        bitstore = 0  # The bit stream to be written
        bitcount = 0  # The number of bits in the bit stream
        buffer = MemoryFile(bytearray())

        # Write empty header and first byte of source
        self._dst.write(bytearray(24))
        self._dst.write(self._src.read_exactly(1))

        def writeint(n: int) -> None:
            """
            Write an integer to the bit stream.
            """
            nonlocal bitstore, bitcount
            nbits = n.bit_length()
            if nbits < 2:
                raise ValueError
            # The highest bit is implicitly assumed:
            n ^= 1 << (nbits - 1)
            remaining = nbits - 2
            while remaining:
                remaining -= 1
                bitstore <<= 2
                bitcount += 2
                bitstore |= ((n >> remaining) & 3) | 1
            bitstore <<= 2
            bitcount += 2
            bitstore |= (n & 1) << 1

        src = self._src.getbuffer()
        remaining = len(src) - 1
        self.log_info('compressing data')

        while True:
            cursor = len(src) - remaining
            rest = src[cursor:]
            if bitcount >= 0x10:
                block_count, bitcount = divmod(bitcount, 0x10)
                info_channel = bitstore >> bitcount
                bitstore = info_channel << bitcount ^ bitstore
                # The decompressor will read bits from top to bottom, and each 16 bit block has to be
                # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
                # variable, and by encoding it as a big endian integer, the stream is in the correct
                # order. However, we need to swap adjacent bytes to achieve little endian encoding for
                # each of the blocks:
                info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
                for k in range(block_count):
                    k0 = 2 * k + 0
                    k1 = 2 * k + 1
                    info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
                info_channel = memoryview(info_channel)
                data_channel = memoryview(buffer.getbuffer())
                self._dst.write(info_channel[:2])
                self._dst.write(data_channel[:-1])
                self._dst.write(info_channel[2:])
                data_channel = bytes(data_channel[-1:])
                buffer.truncate(0)
                store = buffer if bitcount else self._dst
                store.write(data_channel)
            if remaining + bitcount < 0x10:
                buffer = buffer.getbuffer()
                if rest or buffer:
                    bitstore <<= 0x10 - bitcount
                    self._dst.write(bitstore.to_bytes(2, 'little'))
                    self._dst.write(buffer)
                    self._dst.write(rest)
                elif bitcount:
                    raise RuntimeError('Bitbuffer Overflow')
                break
            node = tree.root
            length = 0
            offset = 0
            sector = None
            while node.children and length < len(rest):
                for child in node.children.values():
                    if tree.data[child.start] == rest[length]:
                        node = child
                        break
                if node.start >= cursor:
                    break
                offset = node.start - length
                length = node.end + 1 - offset
            length = min(remaining, length)
            if length >= 4:
                sector, offset = divmod(cursor - offset - 1, 0x100)
            bitcount += 1
            bitstore <<= 1
            if sector is None:
                buffer.write(rest[:1])
                remaining -= 1
                continue
            bitstore |= 1
            buffer.write(bytes((offset,)))
            writeint(length - 2)
            writeint(sector + 2)
            remaining -= length

        self._dst.seek(24)
        dst = self._dst.peek()
        self._dst.seek(0)
        self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
        return self._dst.getbuffer()

    def process(self, data):
        self._begin(data)
        partial = None
        try:
            return self._decompress()
        except ValueError as error:
            if isinstance(error, RefineryPartialResult):
                partial = error
            self.log_warn(F'Reverting to modified BriefLZ after decompression error: {error!s}')
            self._reset()

        try:
            return self._decompress_modded()
        except RefineryPartialResult:
            raise
        except Exception as error:
            if not partial:
                raise
            raise partial from error

    def reverse(self, data):
        return self._begin(data)._compress()
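
The _decompress method above expects a 24-byte header of six big-endian 32-bit fields, and _compress writes the same layout back at the end. A minimal sketch of building such a header with the standard library only; the payloads below are placeholders, not real BriefLZ data:

    import struct
    import zlib

    packed = b'placeholder compressed payload'    # bytes that would follow the header
    plain = b'placeholder decompressed payload'   # bytes that decompression should yield
    header = struct.pack(
        '>6L',
        0x626C7A1A,          # signature: 'blz' followed by 0x1A
        1,                   # version
        len(packed),         # src count: size of the compressed payload
        zlib.crc32(packed),  # src crc32
        len(plain),          # dst count: size of the decompressed output
        zlib.crc32(plain))   # dst crc32
    assert len(header) == 24
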
Example #5
    def process(self, data: bytearray):
        formatter = string.Formatter()
        until = self.args.until
        until = until and PythonExpression(until, all_variables_allowed=True)
        reader = StructReader(memoryview(data))
        mainspec = self.args.spec
        byteorder = mainspec[:1]
        if byteorder in '<!=@>':
            mainspec = mainspec[1:]
        else:
            byteorder = '='

        def fixorder(spec):
            if spec[0] not in '<!=@>':
                spec = byteorder + spec
            return spec

        it = itertools.count() if self.args.multi else (0, )
        for index in it:

            if reader.eof:
                break
            if index >= self.args.count:
                break

            meta = metavars(data, ghost=True)
            meta['index'] = index
            args = []
            last = None
            checkpoint = reader.tell()

            try:
                for prefix, name, spec, conversion in formatter.parse(mainspec):
                    if prefix:
                        args.extend(reader.read_struct(fixorder(prefix)))
                    if name is None:
                        continue
                    if conversion:
                        reader.byte_align(PythonExpression.evaluate(conversion, meta))
                    if spec:
                        spec = meta.format_str(spec, self.codec, args)
                    if spec != '':
                        try:
                            spec = PythonExpression.evaluate(spec, meta)
                        except ParserError:
                            pass
                    if spec == '':
                        last = value = reader.read()
                    elif isinstance(spec, int):
                        last = value = reader.read_bytes(spec)
                    else:
                        value = reader.read_struct(fixorder(spec))
                        if not value:
                            self.log_warn(F'field {name} was empty, ignoring.')
                            continue
                        if len(value) > 1:
                            self.log_info(F'parsing field {name} produced {len(value)} items; keeping it as a tuple')
                        else:
                            value = value[0]

                    args.append(value)

                    if name == _SHARP:
                        raise ValueError('Extracting a field with name # is forbidden.')
                    elif name.isdecimal():
                        index = int(name)
                        limit = len(args) - 1
                        if index > limit:
                            self.log_warn(F'cannot assign index field {name}, the highest index is {limit}')
                        else:
                            args[index] = value
                        continue
                    elif name:
                        meta[name] = value

                if until and not until(meta):
                    self.log_info(F'the expression ({until}) evaluated to zero; aborting.')
                    break

                with StreamDetour(reader, checkpoint) as detour:
                    full = reader.read(detour.cursor - checkpoint)
                if last is None:
                    last = full

                outputs = []

                for template in self.args.outputs:
                    used = set()
                    outputs.append(meta.format(
                        template, self.codec, [full, *args], {_SHARP: last}, True, used=used))
                    for key in used:
                        meta.pop(key, None)

                for output in outputs:
                    chunk = self.labelled(output, **meta)
                    chunk.set_next_batch(index)
                    yield chunk

            except EOF:
                leftover = repr(SizeInt(len(reader) - checkpoint)).strip()
                self.log_info(F'discarding {leftover} left in buffer')
                break
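
The parsing loop above is driven by string.Formatter().parse, which splits the struct spec into (literal prefix, field name, format spec, conversion) tuples. A small standalone illustration with a made-up spec string, not one taken from the unit's documentation:

    import string

    spec = 'H{size:L}{data:{size}}'
    print(list(string.Formatter().parse(spec)))
    # [('H', 'size', 'L', None), ('', 'data', '{size}', None)]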