Exemple #1
0
    def process(self, data):
        """
        Reassemble the TCP streams contained in a packet capture and emit their payloads.

        Without the `merge` option, every packet payload is emitted as its own chunk, labelled
        with the metadata from `Conversation.src_to_dst` and a running `stream` counter. With
        `merge`, the payloads of one conversation are concatenated per direction and emitted
        as at most two chunks: one labelled via `src_to_dst` and one via `dst_to_src`.

        Raises:
            RuntimeError: in merge mode, if a stream belongs to neither direction of the
                current conversation.
        """
        pcapkit = self._pcapkit
        logging.getLogger('pcapkit').disabled = True
        merge = self.args.merge

        with VirtualFileSystem() as fs:
            vf = VirtualFile(fs, data, 'pcap')
            extraction = pcapkit.extract(
                fin=vf.path, engine='scapy', store=False, nofile=True, extension=False, tcp=True, strict=True)
            tcp: list = list(extraction.reassembly.tcp)

        count, convo = 0, None
        src_buffer = MemoryFile()
        dst_buffer = MemoryFile()

        def flush():
            # Emit and reset the merged payloads buffered for the conversation `convo`.
            if src_buffer.tell():
                yield self.labelled(src_buffer.getvalue(), **convo.src_to_dst())
                src_buffer.truncate(0)
            if dst_buffer.tell():
                yield self.labelled(dst_buffer.getvalue(), **convo.dst_to_src())
                dst_buffer.truncate(0)

        for stream in tcp:
            this_convo = Conversation.FromID(stream.id)
            if this_convo != convo:
                # A new conversation starts; the buffers still hold the previous one.
                if count and merge:
                    yield from flush()
                count = count + 1
                convo = this_convo
            for packet in stream.packets:
                if not merge:
                    yield self.labelled(packet.data, **this_convo.src_to_dst(), stream=count)
                elif this_convo.src == convo.src:
                    src_buffer.write(packet.data)
                elif this_convo.dst == convo.src:
                    # NOTE(review): presumably Conversation equality ignores direction, so
                    # reverse streams reach this branch without starting a new conversation.
                    dst_buffer.write(packet.data)
                else:
                    raise RuntimeError(F'direction of packet {convo!s} in conversation {count} is unknown')

        # The last conversation has no successor that would trigger the flush above.
        if count and merge:
            yield from flush()
Exemple #2
0
 def _generate_bytes(self, data: ByteString):
     """
     Yield the results of `self.action` applied to `data`. When `self.squeeze` is set, all
     results are concatenated and yielded once as a single buffer instead.
     """
     if self.squeeze:
         joined = MemoryFile(bytearray())
         for piece in self.action(data):
             joined.write(piece)
         yield joined.getbuffer()
     else:
         yield from self.action(data)
Exemple #3
0
 def _decompress_mszip(self,
                       reader: StructReader,
                       writer: MemoryFile,
                       target: Optional[int] = None):
     """
     Decompress one MSZIP block from `reader` and append the output to `writer`.

     The block has to begin with the two byte signature `CK`, followed by raw deflate data.
     Everything previously written to `writer` is supplied as the preset dictionary. The
     `target` argument is not used by this implementation.

     Raises:
         ValueError: if the block does not start with `CK`.
     """
     signature = bytes(reader.read(2))
     if signature != B'CK':
         raise ValueError(
             F'chunk did not begin with CK header, got {signature!r} instead')
     # negative window bits select a raw deflate stream without zlib header and trailer
     inflater = zlib.decompressobj(-zlib.MAX_WBITS, zdict=writer.getbuffer())
     payload = inflater.decompress(reader.read())
     writer.write(payload)
     writer.write(inflater.flush())
Exemple #4
0
 def test_write_iterables(self):
     # Writing an iterable of byte values must behave like writing the equivalent bytes,
     # both when overwriting in the middle and when appending at the end.
     mf = MemoryFile()
     mf.write(B'FOO BAR BAR FOO FOO')
     mf.seek(4)
     x = B'X'[0]
     mf.write(itertools.repeat(x, 7))
     self.assertEqual(mf.getbuffer(), B'FOO XXXXXXX FOO FOO')
     mf.seekset(len(mf))
     mf.write(B' ')
     mf.write(itertools.repeat(x, 4))
     self.assertEqual(mf.getbuffer(), B'FOO XXXXXXX FOO FOO XXXX')
Exemple #5
0
 def _generate_chunks(self, parent: Chunk):
     """
     Yield the results of `self.action` applied to `parent`.

     Without `self.squeeze`, every result is shallow-copied and then made to inherit from
     `parent`. With `self.squeeze`, the first result inherits from `parent`, all further
     results are appended to it in place and combined into it via `&=`, and only that single
     merged chunk is yielded. Nothing is yielded if the action produces no results.
     """
     if self.squeeze:
         results = self.action(parent)
         try:
             merged = next(results)
         except StopIteration:
             return
         merged.inherit(parent)
         sink = MemoryFile(merged)
         # position the cursor at the end so that further writes append to the first result
         sink.seek(len(merged))
         for item in results:
             merged &= item
             sink.write(item)
         yield merged
     else:
         for item in self.action(parent):
             yield copy.copy(item).inherit(parent)
Exemple #6
0
 def test_string_builder(self):
     # Sequential writes append to the buffer; seeking back one byte allows overwriting it.
     sb = MemoryFile()
     self.assertTrue(sb.writable())
     for part in (B'The binary refinery ', B'refines the finer binaries.'):
         sb.write(part)
     sb.seekrel(-1)
     sb.write(B'!')
     expected = B'The binary refinery refines the finer binaries!'
     self.assertEqual(sb.getbuffer(), expected)
Exemple #7
0
 def _decompress_xpress(self,
                        reader: StructReader,
                        writer: MemoryFile,
                        target: Optional[int] = None) -> None:
     """
     Decompress XPRESS-compressed data from `reader` and write the output to `writer`.

     Parameters:
         reader: the compressed input; decoding runs until `reader.eof` is set.
         writer: receives the decompressed bytes. Back-references are resolved by calling
             `writer.replay`, so they may refer to output already present in `writer`.
         target: optional number of bytes to produce, counted from the current position of
             `writer`. Decoding stops before the next token once this size is reached; the
             last match may overshoot it.

     Returns:
         Nothing; all output is written to `writer`.

     Raises:
         RuntimeError: if an explicit 16/32-bit match length is smaller than 22.
     """
     if target is not None:
         # turn the relative target into an absolute position within the writer
         target += writer.tell()
     # flags.next() is consulted once per token: falsy selects a literal byte, truthy a match
     flags = BitBufferedReader(reader)
     # two 4-bit length extensions share one byte; the high nibble is kept for the next one
     nibble_cache = None
     while not reader.eof:
         if target is not None and writer.tell() >= target:
             return
         if not flags.next():
             writer.write(reader.read(1))
             continue
         # 16-bit match token: upper 13 bits are offset - 1, lower 3 bits the length field
         offset, length = divmod(reader.u16(), 8)
         offset += 1
         if length == 7:
             # length field saturated: read a 4-bit extension
             length = nibble_cache
             if length is None:
                 length_pair = reader.u8()
                 nibble_cache = length_pair >> 4
                 length = length_pair & 0xF
             else:
                 nibble_cache = None
             if length == 15:
                 # nibble saturated: read an 8-bit extension
                 length = reader.u8()
                 if length == 0xFF:
                     # byte saturated: explicit 16-bit length, or a 32-bit length if the 16-bit
                     # value is zero; it already includes the 15 + 7 that are added back below
                     length = reader.u16() or reader.u32()
                     length -= 22
                     if length < 0:
                         raise RuntimeError(
                             F'Invalid match length of {length} for long delta sequence'
                         )
                 length += 15
             length += 7
         # stored match lengths are biased by 3
         length += 3
         writer.replay(offset, length)
Exemple #8
0
    def format(self,
               spec: str,
               codec: str,
               args: Union[list, tuple],
               symb: dict,
               binary: bool,
               fixup: bool = True,
               used: Optional[set] = None,
               escaped: bool = False) -> Union[str, ByteString]:
        """
        Formats a string using Python-like string formatting syntax. The formatter for `binary`
        mode is different; each formatting is documented in one of the following two proxy methods:

        - `refinery.lib.meta.LazyMetaOracle.format_str`
        - `refinery.lib.meta.LazyMetaOracle.format_bin`

        Parameters:
            spec: the format specification.
            codec: the text encoding used to convert between strings and bytes.
            args: positional arguments, referenced by empty fields or by integer fields.
            symb: additional named values; consulted before the entries of this mapping.
            binary: produce a bytes-like result instead of a string.
            fixup: wrap the values of `args`, this mapping, and `symb` in place with
                `ByteStringWrapper` before formatting, except for values that already are a
                `CustomStringRepresentation`.
            used: if given, the indices and names of all consumed arguments are added to it.
            escaped: in text mode, resolve backslash escape sequences in the literal text.

        Raises:
            LookupError: if an empty field is used but no positional arguments were given.
            KeyError: if a field cannot be resolved and this object is not in ghost mode.
        """
        from refinery.lib.argformats import multibin, ParserError, PythonExpression
        # prevents circular import

        symb = symb or {}

        if used is None:
            # The caller does not track consumed arguments; collect them into a throwaway set.
            used = set()

        if args is None:
            args = ()
        elif not isinstance(args, (list, tuple)):
            args = list(args)

        if fixup:
            # A TypeError from the assignment (e.g. for a tuple) or from the wrapper leaves the
            # respective value unchanged.
            for (store, it) in (
                (args, enumerate(args)),
                (self, self.items()),
                (symb, symb.items()),
            ):
                for key, value in it:
                    with contextlib.suppress(TypeError):
                        if isinstance(value, CustomStringRepresentation):
                            continue
                        store[key] = ByteStringWrapper(value, codec)

        def decode(buffer, encoding: str, errors: str = 'strict') -> str:
            # isbuffer() can accept buffer objects such as memoryview which lack a decode method.
            if not hasattr(buffer, 'decode') and isbuffer(buffer):
                buffer = bytes(buffer)
            return buffer.decode(encoding, errors)

        formatter = string.Formatter()
        autoindex = 0

        if binary:
            stream = MemoryFile()

            def putstr(s: str):
                stream.write(s.encode(codec))
        else:
            stream = StringIO()
            putstr = stream.write

        with stream:
            for prefix, field, modifier, conversion in formatter.parse(spec):
                output = value = None
                if prefix:
                    if binary:
                        prefix = prefix.encode(codec)
                    elif escaped:
                        prefix = prefix.encode('raw-unicode-escape').decode(
                            'unicode-escape')
                    stream.write(prefix)
                if field is None:
                    continue
                if not field:
                    if not args:
                        raise LookupError(
                            'no positional arguments given to formatter')
                    value = args[autoindex]
                    used.add(autoindex)
                    # the automatic index does not advance past the last argument
                    if autoindex < len(args) - 1:
                        autoindex += 1
                if binary and conversion:
                    # in binary mode, a conversion marks the field itself as a literal value
                    conversion = conversion.lower()
                    if conversion == 'h':
                        value = bytes.fromhex(field)
                    elif conversion == 'q':
                        value = unquote_to_bytes(field)
                    elif conversion == 's':
                        value = field.encode(codec)
                    elif conversion == 'u':
                        value = field.encode('utf-16le')
                    elif conversion == 'a':
                        value = field.encode('latin1')
                    elif conversion == 'e':
                        value = field.encode(codec).decode(
                            'unicode-escape').encode('latin1')
                elif field in symb:
                    value = symb[field]
                    used.add(field)
                if value is None:
                    with contextlib.suppress(ValueError, IndexError):
                        index = int(field, 0)
                        value = args[index]
                        used.add(index)
                if value is None:
                    with contextlib.suppress(KeyError):
                        value = self[field]
                        used.add(field)
                if value is None:
                    try:
                        expression = PythonExpression(field, *self, *symb)
                        value = expression(self, **symb)
                    except ParserError:
                        if not self.ghost:
                            raise KeyError(field)
                        # ghost mode: reproduce the unresolved field verbatim
                        putstr(F'{{{field}')
                        if conversion:
                            putstr(F'!{conversion}')
                        if modifier:
                            putstr(F':{modifier}')
                        putstr('}')
                        continue
                if binary:
                    modifier = modifier.strip()
                    if modifier:
                        expression = self.format(modifier, codec, args, symb,
                                                 True, False, used)
                        output = multibin(expression.decode(codec),
                                          reverse=True,
                                          seed=value)
                    elif isbuffer(value):
                        output = value
                    elif not isinstance(value, int):
                        with contextlib.suppress(TypeError):
                            output = bytes(value)
                if output is None:
                    converter = {
                        'a': ascii,
                        's': str,
                        'r': repr,
                        'H': lambda b: b.hex().upper(),
                        'h': lambda b: b.hex(),
                        'u': lambda b: decode(b, 'utf-16le'),
                        'e': lambda b: repr(bytes(b)).lstrip('bBrR')[1:-1],
                        'q': lambda b: quote_from_bytes(bytes(b))
                    }.get(conversion)
                    if converter:
                        output = converter(value)
                    elif modifier:
                        output = value
                    elif isinstance(value, CustomStringRepresentation):
                        output = str(value)
                    elif isbuffer(value):
                        output = decode(value, 'utf8', 'replace')
                    else:
                        output = value
                    output = output.__format__(modifier)
                    if binary:
                        output = output.encode(codec)
                stream.write(output)
            return stream.getvalue()
Exemple #9
0
class blz(Unit):
    """
    BriefLZ compression and decompression. The compression algorithm uses a pure Python suffix tree
    implementation: It requires a lot of time & memory.
    """
    def _begin(self, data):
        # Prepare a reader over the input and a fresh, empty output buffer.
        self._src = StructReader(memoryview(data))
        self._dst = MemoryFile(bytearray())
        return self

    def _reset(self):
        # Rewind the input and discard all output produced so far.
        self._src.seek(0)
        self._dst.seek(0)
        self._dst.truncate()
        return self

    def _decompress(self):
        """
        Decompress data in the standard BriefLZ container format: a 24 byte big-endian header
        (signature, version, compressed size and CRC32, decompressed size and CRC32) followed
        by one compressed chunk. CRC mismatches and a short input are logged as warnings.

        Raises:
            ValueError: if the signature is wrong or the version number exceeds 10.
        """
        (
            signature,
            version,
            src_count,
            src_crc32,
            dst_count,
            dst_crc32,
        ) = self._src.read_struct('>6L')
        if signature != 0x626C7A1A:
            raise ValueError(F'Invalid BriefLZ signature: {signature:08X}, should be 626C7A1A.')
        if version > 10:
            raise ValueError(F'Invalid version number {version}, should be at most 10.')
        self.log_debug(F'signature: 0x{signature:08X} V{version}')
        self.log_debug(F'src count: 0x{src_count:08X}')
        self.log_debug(F'src crc32: 0x{src_crc32:08X}')
        self.log_debug(F'dst count: 0x{dst_count:08X}')
        self.log_debug(F'dst crc32: 0x{dst_crc32:08X}')
        src = self._src.getbuffer()
        src = src[24:24 + src_count]
        if len(src) < src_count:
            self.log_warn(F'Only {len(src)} bytes in buffer, but header announced a length of {src_count}.')
        # a zero CRC field is treated as "no checksum"
        if src_crc32:
            check = zlib.crc32(src)
            if check != src_crc32:
                self.log_warn(F'Invalid source data CRC {check:08X}, should be {src_crc32:08X}.')
        dst = self._decompress_chunk(dst_count)
        if not dst_crc32:
            return dst
        check = zlib.crc32(dst)
        if check != dst_crc32:
            self.log_warn(F'Invalid result data CRC {check:08X}, should be {dst_crc32:08X}.')
        return dst

    def _decompress_modded(self):
        """
        Decompress a modified BriefLZ format: eight bytes are skipped, followed by two 64-bit
        integers holding the total decompressed size and the chunk size. The data then consists
        of consecutive chunks that decompress to `chunk size` bytes each, except for the last
        one, which holds the remainder.
        """
        self._src.seekrel(8)
        total_size = self._src.u64()
        chunk_size = self._src.u64()
        remaining = total_size
        self.log_debug(F'total size: 0x{total_size:016X}')
        self.log_debug(F'chunk size: 0x{chunk_size:016X}')
        while remaining > chunk_size:
            self._decompress_chunk(chunk_size)
            remaining -= chunk_size
        return self._decompress_chunk(remaining)

    def _decompress_chunk(self, size=None):
        """
        Decompress one chunk from the input, append it to the output buffer, and return the
        entire output buffer.

        The first byte of a chunk is a literal. Afterwards, tag bits are taken from 16-bit
        little-endian words, most significant bit first: a clear bit announces one literal
        byte, a set bit a back-reference made of two gamma-coded integers (length and offset
        sector) and one offset byte. If `size` is falsy, decoding continues until the input is
        exhausted, which is reported as a partial result.

        Raises:
            RefineryPartialResult: if the input ends early or a back-reference reaches before
                the start of the output.
            RuntimeError: if a back-reference produces more than `size` bytes.
        """
        bitcount = 0
        bitstore = 0
        decompressed = 1

        def readbit():
            nonlocal bitcount, bitstore
            if not bitcount:
                bitstore = int.from_bytes(self._src.read_exactly(2), 'little')
                bitcount = 0xF
            else:
                bitcount = bitcount - 1
            return (bitstore >> bitcount) & 1

        def readint():
            # Gamma code: implicit leading one, then data bits each followed by a continue bit.
            result = 2 + readbit()
            while readbit():
                result <<= 1
                result += readbit()
            return result

        self._dst.write(self._src.read_exactly(1))

        try:
            while not size or decompressed < size:
                if readbit():
                    length = readint() + 2
                    sector = readint() - 2
                    offset = self._src.read_byte() + 1
                    delta = offset + 0x100 * sector
                    available = self._dst.tell()
                    if delta not in range(available + 1):
                        raise RefineryPartialResult(
                            F'Requested rewind by 0x{delta:08X} bytes with only 0x{available:08X} bytes in output buffer.',
                            partial=self._dst.getvalue())
                    # A match longer than its distance repeats the last `delta` output bytes.
                    quotient, remainder = divmod(length, delta)
                    replay = memoryview(self._dst.getbuffer())
                    replay = bytes(replay[-delta:] if quotient else replay[-delta:length - delta])
                    replay = quotient * replay + replay[:remainder]
                    self._dst.write(replay)
                    decompressed += length
                else:
                    self._dst.write(self._src.read_exactly(1))
                    decompressed += 1
        except EOF as E:
            raise RefineryPartialResult(str(E), partial=self._dst.getbuffer())
        # The loop only ends normally once `decompressed >= size`; a final back-reference can
        # still overshoot the requested size.
        if decompressed > size:
            raise RuntimeError('Decompressed buffer contained more bytes than expected.')
        return self._dst.getbuffer()

    def _compress(self):
        """
        Compress the input into the standard BriefLZ container format and return the output
        buffer. Back-references are located with a suffix tree built over the whole input. A
        24 byte placeholder header is written first and filled in at the end with the
        signature, version 1, and the sizes and CRC32 checksums of the compressed data and of
        the input.
        """
        from refinery.lib.suffixtree import SuffixTree

        self.log_info('computing suffix tree')
        tree = SuffixTree(self._src.getbuffer())

        bitstore = 0  # The bit stream to be written
        bitcount = 0  # The number of bits in the bit stream
        buffer = MemoryFile(bytearray())  # data bytes that belong to the pending tag bits

        # Write empty header and first byte of source
        self._dst.write(bytearray(24))
        self._dst.write(self._src.read_exactly(1))

        def writeint(n: int) -> None:
            """
            Write an integer of at least 2 to the bit stream. Its most significant bit is
            implied; every remaining bit is emitted as a data bit followed by a continuation
            bit, which is set for all but the last pair.
            """
            nonlocal bitstore, bitcount
            nbits = n.bit_length()
            if nbits < 2:
                raise ValueError
            # The highest bit is implicitly assumed:
            n ^= 1 << (nbits - 1)
            remaining = nbits - 2
            while remaining:
                remaining -= 1
                bitstore <<= 2
                bitcount += 2
                bitstore |= ((n >> remaining) & 3) | 1
            bitstore <<= 2
            bitcount += 2
            bitstore |= (n & 1) << 1

        src = self._src.getbuffer()
        remaining = len(src) - 1
        self.log_info('compressing data')

        while True:
            cursor = len(src) - remaining
            rest = src[cursor:]
            if bitcount >= 0x10:
                block_count, bitcount = divmod(bitcount, 0x10)
                info_channel = bitstore >> bitcount
                bitstore = info_channel << bitcount ^ bitstore
                # The decompressor will read bits from top to bottom, and each 16 bit block has to be
                # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
                # variable, and by encoding it as a big endian integer, the stream is in the correct
                # order. However, we need to swap adjacent bytes to achieve little endian encoding for
                # each of the blocks:
                info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
                for k in range(block_count):
                    k0 = 2 * k + 0
                    k1 = 2 * k + 1
                    info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
                info_channel = memoryview(info_channel)
                data_channel = memoryview(buffer.getbuffer())
                self._dst.write(info_channel[:2])
                self._dst.write(data_channel[:-1])
                self._dst.write(info_channel[2:])
                data_channel = bytes(data_channel[-1:])
                buffer.truncate(0)
                store = buffer if bitcount else self._dst
                store.write(data_channel)
            if remaining + bitcount < 0x10:
                # The remaining input fits into one final, zero-padded tag word as literals.
                buffer = buffer.getbuffer()
                if rest or buffer:
                    bitstore <<= 0x10 - bitcount
                    self._dst.write(bitstore.to_bytes(2, 'little'))
                    self._dst.write(buffer)
                    self._dst.write(rest)
                elif bitcount:
                    raise RuntimeError('Bitbuffer Overflow')
                break
            # Descend the suffix tree along `rest`; presumably this finds the longest prefix of
            # `rest` that also occurs at an earlier position `offset`.
            node = tree.root
            length = 0
            offset = 0
            sector = None
            while node.children and length < len(rest):
                for child in node.children.values():
                    if tree.data[child.start] == rest[length]:
                        node = child
                        break
                if node.start >= cursor:
                    break
                offset = node.start - length
                length = node.end + 1 - offset
            length = min(remaining, length)
            # only matches of at least four bytes are encoded as back-references
            if length >= 4:
                sector, offset = divmod(cursor - offset - 1, 0x100)
            bitcount += 1
            bitstore <<= 1
            if sector is None:
                buffer.write(rest[:1])
                remaining -= 1
                continue
            bitstore |= 1
            buffer.write(bytes((offset,)))
            writeint(length - 2)
            writeint(sector + 2)
            remaining -= length

        self._dst.seek(24)
        dst = self._dst.peek()
        self._dst.seek(0)
        self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
        return self._dst.getbuffer()

    def process(self, data):
        """
        Decompress BriefLZ data. The standard container format is tried first; if that fails
        with a `ValueError`, decompression restarts using the modified format. If the second
        attempt fails with anything but a partial result, the partial result of the first
        attempt is raised instead, if there was one.
        """
        self._begin(data)
        partial = None
        try:
            return self._decompress()
        except ValueError as error:
            if isinstance(error, RefineryPartialResult):
                partial = error
            self.log_warn(F'Reverting to modified BriefLZ after decompression error: {error!s}')
            self._reset()

        try:
            return self._decompress_modded()
        except RefineryPartialResult:
            raise
        except Exception as error:
            if partial is None:
                raise
            raise partial from error

    def reverse(self, data):
        return self._begin(data)._compress()
Exemple #10
0
    def _compress(self):
        """
        Compress the input into the standard BriefLZ container format and return the output
        buffer. Back-references are located with a suffix tree built over the whole input. A
        24 byte placeholder header is written first and filled in at the end with the
        signature, version 1, and the sizes and CRC32 checksums of the compressed data and of
        the input.
        """
        from refinery.lib.suffixtree import SuffixTree

        self.log_info('computing suffix tree')
        tree = SuffixTree(self._src.getbuffer())

        bitstore = 0  # The bit stream to be written
        bitcount = 0  # The number of bits in the bit stream
        buffer = MemoryFile(bytearray())  # data bytes that belong to the pending tag bits

        # Write empty header and first byte of source
        self._dst.write(bytearray(24))
        self._dst.write(self._src.read_exactly(1))

        def writeint(n: int) -> None:
            """
            Write an integer of at least 2 to the bit stream. Its most significant bit is
            implied; every remaining bit is emitted as a data bit followed by a continuation
            bit, which is set for all but the last pair.
            """
            nonlocal bitstore, bitcount
            nbits = n.bit_length()
            if nbits < 2:
                raise ValueError
            # The highest bit is implicitly assumed:
            n ^= 1 << (nbits - 1)
            remaining = nbits - 2
            while remaining:
                remaining -= 1
                bitstore <<= 2
                bitcount += 2
                bitstore |= ((n >> remaining) & 3) | 1
            bitstore <<= 2
            bitcount += 2
            bitstore |= (n & 1) << 1

        src = self._src.getbuffer()
        remaining = len(src) - 1
        self.log_info('compressing data')

        while True:
            cursor = len(src) - remaining
            rest = src[cursor:]
            if bitcount >= 0x10:
                block_count, bitcount = divmod(bitcount, 0x10)
                info_channel = bitstore >> bitcount
                bitstore = info_channel << bitcount ^ bitstore
                # The decompressor will read bits from top to bottom, and each 16 bit block has to be
                # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
                # variable, and by encoding it as a big endian integer, the stream is in the correct
                # order. However, we need to swap adjacent bytes to achieve little endian encoding for
                # each of the blocks:
                info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
                for k in range(block_count):
                    k0 = 2 * k + 0
                    k1 = 2 * k + 1
                    info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
                info_channel = memoryview(info_channel)
                data_channel = memoryview(buffer.getbuffer())
                self._dst.write(info_channel[:2])
                self._dst.write(data_channel[:-1])
                self._dst.write(info_channel[2:])
                data_channel = bytes(data_channel[-1:])
                buffer.truncate(0)
                store = buffer if bitcount else self._dst
                store.write(data_channel)
            if remaining + bitcount < 0x10:
                # The remaining input fits into one final, zero-padded tag word as literals.
                buffer = buffer.getbuffer()
                if rest or buffer:
                    bitstore <<= 0x10 - bitcount
                    self._dst.write(bitstore.to_bytes(2, 'little'))
                    self._dst.write(buffer)
                    self._dst.write(rest)
                elif bitcount:
                    raise RuntimeError('Bitbuffer Overflow')
                break
            # Descend the suffix tree along `rest`; presumably this finds the longest prefix of
            # `rest` that also occurs at an earlier position `offset`.
            node = tree.root
            length = 0
            offset = 0
            sector = None
            while node.children and length < len(rest):
                for child in node.children.values():
                    if tree.data[child.start] == rest[length]:
                        node = child
                        break
                if node.start >= cursor:
                    break
                offset = node.start - length
                length = node.end + 1 - offset
            length = min(remaining, length)
            # only matches of at least four bytes are encoded as back-references
            if length >= 4:
                sector, offset = divmod(cursor - offset - 1, 0x100)
            bitcount += 1
            bitstore <<= 1
            if sector is None:
                buffer.write(rest[:1])
                remaining -= 1
                continue
            bitstore |= 1
            buffer.write(bytes((offset,)))
            writeint(length - 2)
            writeint(sector + 2)
            remaining -= length

        self._dst.seek(24)
        dst = self._dst.peek()
        self._dst.seek(0)
        self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
        return self._dst.getbuffer()
Exemple #11
0
    def decompress_stream(self, data: ByteString, LZOv1: bool = False) -> bytearray:
        """
        An implementation of LZO decompression. We use the article
        "[LZO stream format as understood by Linux's LZO decompressor](https://www.kernel.org/doc/html/latest/staging/lzo.html)"
        as a reference since no proper specification is available.

        Parameters:
            data: the compressed stream.
            LZOv1: enable the bitstream version 1 extensions: zero runs encoded with the
                reserved distance 0xBFFF, and rejection of the copy sequences banned in v1.

        Returns:
            The decompressed data, once the end-of-stream instruction has been decoded.

        Raises:
            LZOError: on malformed input.
        """
        def integer() -> int:
            # Variable-length integer: each zero byte adds 255, the first nonzero byte ends it.
            length = 0
            while True:
                byte = src.read_byte()
                if byte:
                    return length + byte
                length += 0xFF
                if length > 0x100000:
                    raise LZOError('Too many zeros in integer encoding.')

        def literal(count):
            # Copy `count` bytes verbatim from the input to the output.
            dst.write(src.read_bytes(count))

        def copy(distance: int, length: int):
            # Copy `length` bytes starting `distance` bytes before the end of the output; an
            # overlapping copy (distance <= length) repeats the referenced pattern.
            if distance > len(dst):
                raise LZOError(F'Distance {distance} > bufsize {len(dst)}')
            buffer = dst.getbuffer()
            if distance > length:
                start = len(buffer) - distance
                end = start + length
                dst.write(buffer[start:end])
            else:
                block = buffer[-distance:]
                while len(block) < length:
                    block += block[:length - len(block)]
                if len(block) > length:
                    block[length:] = ()
                dst.write(block)

        src = StructReader(memoryview(data))
        dst = MemoryFile()

        # `state` is the number of literals copied by the previous instruction, where any value
        # of 4 or more means "at least four"; it selects the meaning of instructions below 0x10.
        state = 0
        first = src.read_byte()

        if first == 0x10:
            raise LZOError('Invalid first stream byte 0x10.')
        elif first <= 0x11:
            # the stream starts with a regular instruction
            src.seekrel(-1)
        elif first <= 0x15:
            # 0x12..0x15: copy first - 17 = 1..4 literals
            state = first - 0x11
            literal(state)
        else:
            state = 4
            literal(first - 0x11)

        while True:
            instruction = src.read_byte()
            if instruction < 0x10:
                if state == 0:
                    # literal run of at least four bytes
                    length = instruction or integer() + 15
                    state = length + 3
                    if state < 4:
                        raise LZOError('Literal encoding is too short.')
                else:
                    # The copy form depends on how many literals were copied just before, so
                    # that count must be read before `state` is overwritten.
                    previous_state = state
                    state = instruction & 0b0011
                    D = (instruction & 0b1100) >> 2
                    H = src.read_byte()
                    distance = (H << 2) + D + 1
                    if previous_state >= 4:
                        distance += 0x800
                        length = 3
                    else:
                        length = 2
                    copy(distance, length)
            elif instruction < 0x20:
                L = instruction & 0b0111
                H = instruction & 0b1000
                length = L or integer() + 7
                argument = src.u16()
                state = argument & 3
                # H is either 0 or 8, so this is (H bit << 14) + D, without the 16 kB base
                distance = (H << 11) + (argument >> 2)
                if not distance:
                    return dst.getbuffer()
                # The stream format states the v1 conditions in terms of the actual distance
                # (16384 + ...) and the actual copy length (2 + ...).
                actual_distance = distance + 0x4000
                actual_length = length + 2
                if LZOv1 and (actual_distance & 0x803F) == 0x803F and actual_length in range(261, 265):
                    raise LZOError('Compressed data contains sequence that is banned in LZOv1.')
                if LZOv1 and actual_distance == 0xBFFF:
                    # NOTE(review): the reference describes the zero run as followed by exactly
                    # one extra byte X; when L == 0, extended length bytes have already been
                    # consumed above. Confirm against a real LZOv1 stream.
                    X = src.read_byte()
                    count = ((X << 3) | L) + 4
                    self.log_debug(F'Writing run of {count} zero bytes according to LZOv1.')
                    dst.write(B'\0' * count)
                else:
                    copy(actual_distance, actual_length)
            elif instruction < 0x40:
                # copy from within 16 kB
                L = instruction & 0b11111
                length = L or integer() + 31
                argument = src.u16()
                state = argument & 3
                distance = (argument >> 2) + 1
                copy(distance, length + 2)
            else:
                # short copy of 3 to 8 bytes from within 2 kB
                if instruction < 0x80:
                    length = 3 + ((instruction >> 5) & 1)
                else:
                    length = 5 + ((instruction >> 5) & 3)
                H = src.read_byte()
                D = (instruction & 0b11100) >> 2
                state = instruction & 3
                distance = (H << 3) + D + 1
                copy(distance, length)
            if state:
                # copy the literals announced by the instruction that was just decoded
                literal(state)