def process(self, data):
    """
    Extract reassembled TCP payloads from a PCAP file.

    The input is exposed as a virtual file and handed to pcapkit (scapy engine) with TCP
    reassembly enabled. If `self.args.merge` is false, one labelled chunk is emitted per
    reassembled packet, together with a running `stream` counter. Otherwise, the payloads
    of each conversation are concatenated per direction and emitted as (up to) two chunks:
    one for source-to-destination and one for destination-to-source traffic.
    """
    pcapkit = self._pcapkit
    # Silence the pcapkit logger for the duration of the process.
    logging.getLogger('pcapkit').disabled = True
    merge = self.args.merge

    with VirtualFileSystem() as fs:
        vf = VirtualFile(fs, data, 'pcap')
        extraction = pcapkit.extract(
            fin=vf.path, engine='scapy', store=False, nofile=True, extension=False, tcp=True, strict=True)
        # Materialize the reassembled streams while the virtual file still exists.
        tcp: list = list(extraction.reassembly.tcp)

    count, convo = 0, None
    src_buffer = MemoryFile()
    dst_buffer = MemoryFile()

    def commit():
        # Emit and reset whatever has been collected for the current conversation.
        if src_buffer.tell():
            yield self.labelled(src_buffer.getvalue(), **convo.src_to_dst())
            src_buffer.truncate(0)
        if dst_buffer.tell():
            yield self.labelled(dst_buffer.getvalue(), **convo.dst_to_src())
            dst_buffer.truncate(0)

    for stream in tcp:
        this_convo = Conversation.FromID(stream.id)
        if this_convo != convo:
            if count and merge:
                yield from commit()
            count = count + 1
            convo = this_convo
        for packet in stream.packets:
            if not merge:
                yield self.labelled(packet.data, **this_convo.src_to_dst(), stream=count)
            elif this_convo.src == convo.src:
                src_buffer.write(packet.data)
            elif this_convo.dst == convo.src:
                dst_buffer.write(packet.data)
            else:
                raise RuntimeError(F'direction of packet {convo!s} in conversation {count} is unknown')

    # The loop above only flushes when a *new* conversation begins; without this final
    # flush, the data of the last conversation would be silently dropped in merge mode.
    if count and merge:
        yield from commit()
def _generate_bytes(self, data: ByteString):
    """
    Generate the results of `self.action` for the given input. When `self.squeeze` is set,
    all results are written into a single in-memory buffer, and that one buffer is the only
    item generated; otherwise every result is passed through unchanged.
    """
    if self.squeeze:
        merged = MemoryFile(bytearray())
        for part in self.action(data):
            merged.write(part)
        yield merged.getbuffer()
    else:
        yield from self.action(data)
def _decompress_mszip(self, reader: StructReader, writer: MemoryFile, target: Optional[int] = None):
    """
    Decompress a single MSZIP chunk from `reader` into `writer`. The chunk has to start with
    the two bytes `CK`, followed by a raw deflate stream. Everything that is already in the
    output buffer is passed to zlib as the preset dictionary. The `target` argument is
    accepted but not used by this method.

    Raises a `ValueError` if the chunk does not start with the `CK` marker.
    """
    magic = bytes(reader.read(2))
    if magic != B'CK':
        raise ValueError(
            F'chunk did not begin with CK header, got {magic!r} instead')
    # Negative window bits select a raw deflate stream without zlib header or trailer.
    inflater = zlib.decompressobj(wbits=-zlib.MAX_WBITS, zdict=writer.getbuffer())
    body = inflater.decompress(reader.read())
    writer.write(body)
    tail = inflater.flush()
    writer.write(tail)
def test_write_iterables(self):
    """
    Writing an iterable of integers to a MemoryFile overwrites existing bytes at the cursor
    and extends the buffer when the cursor is at its end.
    """
    letter = B'X'[0]
    stream = MemoryFile()
    stream.write(B'FOO BAR BAR FOO FOO')
    # Overwrite seven bytes in the middle of the buffer.
    stream.seek(4)
    stream.write(itertools.repeat(letter, 7))
    self.assertEqual(stream.getbuffer(), B'FOO XXXXXXX FOO FOO')
    # Move to the end and append.
    end = len(stream)
    stream.seekset(end)
    stream.write(B' ')
    stream.write(itertools.repeat(letter, 4))
    self.assertEqual(stream.getbuffer(), B'FOO XXXXXXX FOO FOO XXXX')
def _generate_chunks(self, parent: Chunk):
    """
    Generate the results of `self.action` for the given parent chunk. Without squeezing,
    a shallow copy of every result inherits from the parent and is generated. When
    `self.squeeze` is set, the first result inherits from the parent, every following
    result is combined into it via `&=` and appended to its data, and only that first
    result is generated. Nothing is generated if the action produces no results.
    """
    if self.squeeze:
        results = self.action(parent)
        nothing = object()
        head = next(results, nothing)
        if head is nothing:
            return
        head.inherit(parent)
        # Open the first result as a file and position the cursor at its end for appending.
        sink = MemoryFile(head)
        sink.seek(len(head))
        for chunk in results:
            head &= chunk
            sink.write(chunk)
        yield head
        return
    for chunk in self.action(parent):
        yield copy.copy(chunk).inherit(parent)
def test_string_builder(self):
    """
    A fresh MemoryFile is writable, consecutive writes are concatenated, and a relative
    seek followed by a write replaces the last byte.
    """
    stream = MemoryFile()
    self.assertTrue(stream.writable())
    for part in (B'The binary refinery ', B'refines the finer binaries.'):
        stream.write(part)
    # Step back over the trailing period and replace it.
    stream.seekrel(-1)
    stream.write(B'!')
    expected = B'The binary refinery refines the finer binaries!'
    self.assertEqual(stream.getbuffer(), expected)
def _decompress_xpress(self, reader: StructReader, writer: MemoryFile, target: Optional[int] = None) -> bytearray:
    """
    Decode flag-controlled literals and back-references from `reader` into `writer`.

    For each flag bit, a zero copies one literal byte. Otherwise, a 16-bit token is read:
    its lower 3 bits are the base length and the remaining bits plus one are the offset. A
    base length of 7 is extended by a nibble; nibbles are read pairwise from one byte, the
    low nibble is used first and the high nibble is kept for the next extension. A nibble
    of 15 is extended further by an 8-bit, 16-bit, or 32-bit value. If `target` is given,
    decoding stops once that many bytes were written by this call.

    NOTE(review): The return annotation says `bytearray`, but every path returns `None`;
    the output is only available through `writer`.
    """
    # Convert the relative target into an absolute position in the output.
    limit = None if target is None else target + writer.tell()
    bits = BitBufferedReader(reader)
    spare_nibble = None
    while not reader.eof:
        if limit is not None and writer.tell() >= limit:
            return
        if not bits.next():
            writer.write(reader.read(1))
            continue
        token = reader.u16()
        offset = (token >> 3) + 1
        length = token & 7
        if length == 7:
            if spare_nibble is None:
                pair = reader.u8()
                spare_nibble = pair >> 4
                length = pair & 0xF
            else:
                length, spare_nibble = spare_nibble, None
            if length == 15:
                length = reader.u8()
                if length == 0xFF:
                    # A 16-bit value of zero means that a 32-bit length follows.
                    length = reader.u16() or reader.u32()
                    length -= 22
                    if length < 0:
                        raise RuntimeError(
                            F'Invalid match length of {length} for long delta sequence'
                        )
                length += 15
            length += 7
        writer.replay(offset, length + 3)
def format(
    self,
    spec: str,
    codec: str,
    args: Union[list, tuple],
    symb: dict,
    binary: bool,
    fixup: bool = True,
    used: Optional[set] = None,
    escaped: bool = False
) -> Union[str, ByteString]:
    """
    Formats a string using Python-like string formatting syntax. The formatter for `binary`
    mode is different; each formatting is documented in one of the following two proxy methods:

    - `refinery.lib.meta.LazyMetaOracle.format_str`
    - `refinery.lib.meta.LazyMetaOracle.format_bin`

    The arguments are as follows:

    - `spec`: The format string; it is parsed using `string.Formatter.parse`.
    - `codec`: The codec used to encode text whenever `binary` is set. It is also passed to
      the wrappers that are created when `fixup` is set.
    - `args`: Positional arguments. `None` is treated as an empty tuple, and anything that is
      neither a list nor a tuple is converted to a list.
    - `symb`: Additional named values which are looked up by field name.
    - `binary`: Whether the output is assembled as a byte string rather than as text.
    - `fixup`: Whether the values in `args`, in this object, and in `symb` are replaced by
      `ByteStringWrapper` instances before formatting.
    - `used`: If given, every positional index and field name that was resolved is added.
    - `escaped`: In text mode, literal parts of `spec` are passed through the codecs
      `raw-unicode-escape` and `unicode-escape`.

    A field is resolved in this order: automatic positional index for empty field names, a
    literal conversion in binary mode or a lookup in `symb`, an explicit positional index, a
    lookup in this object, and finally evaluation as a `PythonExpression`.

    Raises a `LookupError` if an empty field name occurs without positional arguments. Raises
    a `KeyError` if a field cannot be resolved and `self.ghost` is not set; if it is set, the
    unresolved field is written back to the output unchanged.
    """
    from refinery.lib.argformats import multibin, ParserError, PythonExpression  # prevents circular import

    symb = symb or {}

    if used is None:
        # Stand-in for the caller's set which discards everything that is added to it.
        class dummy:
            def add(self, _):
                pass
        used = dummy()

    if args is None:
        args = ()
    elif not isinstance(args, (list, tuple)):
        args = list(args)

    if fixup:
        for (store, it) in (
            (args, enumerate(args)),
            (self, self.items()),
            (symb, symb.items()),
        ):
            for key, value in it:
                # A TypeError from the wrapper or from the item assignment (for example when
                # args is a tuple) leaves the original value in place.
                with contextlib.suppress(TypeError):
                    if isinstance(value, CustomStringRepresentation):
                        continue
                    store[key] = ByteStringWrapper(value, codec)

    formatter = string.Formatter()
    autoindex = 0

    if binary:
        stream = MemoryFile()

        def putstr(s: str):
            stream.write(s.encode(codec))
    else:
        stream = StringIO()
        putstr = stream.write

    with stream:
        for prefix, field, modifier, conversion in formatter.parse(spec):
            output = value = None
            if prefix:
                if binary:
                    prefix = prefix.encode(codec)
                elif escaped:
                    prefix = prefix.encode('raw-unicode-escape').decode(
                        'unicode-escape')
                stream.write(prefix)
            if field is None:
                # There is only literal text left, no replacement field follows.
                continue
            if not field:
                if not args:
                    raise LookupError(
                        'no positional arguments given to formatter')
                value = args[autoindex]
                used.add(autoindex)
                # The automatic index stops advancing at the last positional argument.
                if autoindex < len(args) - 1:
                    autoindex += 1
            if binary and conversion:
                # In binary mode, these conversions interpret the field name itself as a
                # literal. For any other conversion, value remains unchanged here.
                conversion = conversion.lower()
                if conversion == 'h':
                    value = bytes.fromhex(field)
                elif conversion == 'q':
                    value = unquote_to_bytes(field)
                elif conversion == 's':
                    value = field.encode(codec)
                elif conversion == 'u':
                    value = field.encode('utf-16le')
                elif conversion == 'a':
                    value = field.encode('latin1')
                elif conversion == 'e':
                    value = field.encode(codec).decode(
                        'unicode-escape').encode('latin1')
            elif field in symb:
                value = symb[field]
                used.add(field)
            if value is None:
                # Attempt to interpret the field as an explicit positional index.
                with contextlib.suppress(ValueError, IndexError):
                    index = int(field, 0)
                    value = args[index]
                    used.add(index)
            if value is None:
                with contextlib.suppress(KeyError):
                    value = self[field]
                    used.add(field)
            if value is None:
                try:
                    expression = PythonExpression(field, *self, *symb)
                    value = expression(self, **symb)
                except ParserError:
                    if not self.ghost:
                        raise KeyError(field)
                    # Reproduce the unresolved field in the output.
                    putstr(F'{{{field}')
                    if conversion:
                        putstr(F'!{conversion}')
                    if modifier:
                        putstr(F':{modifier}')
                    putstr('}')
                    continue
            if binary:
                modifier = modifier.strip()
                if modifier:
                    # The modifier is formatted recursively (without another fixup) and the
                    # result is handed to multibin with the value as its seed.
                    expression = self.format(modifier, codec, args, symb, True, False, used)
                    output = multibin(expression.decode(codec), reverse=True, seed=value)
                elif isbuffer(value):
                    output = value
                elif not isinstance(value, int):
                    with contextlib.suppress(TypeError):
                        output = bytes(value)
            if output is None:
                converter = {
                    'a': ascii,
                    's': str,
                    'r': repr,
                    'H': lambda b: b.hex().upper(),
                    'h': lambda b: b.hex(),
                    'u': lambda b: b.decode('utf-16le'),
                    'e': lambda b: repr(bytes(b)).lstrip('bBrR')[1:-1],
                    'q': lambda b: quote_from_bytes(bytes(b))
                }.get(conversion)
                if converter:
                    output = converter(value)
                elif modifier:
                    output = value
                elif isinstance(value, CustomStringRepresentation):
                    output = str(value)
                elif isbuffer(value):
                    output = value.decode('utf8', errors='replace')
                else:
                    output = value
                output = output.__format__(modifier)
                if binary:
                    output = output.encode(codec)
            stream.write(output)
        # The value has to be retrieved before the stream is closed by the with statement.
        return stream.getvalue()
class blz(Unit):
    """
    BriefLZ compression and decompression. The compression algorithm uses a pure Python suffix tree
    implementation: It requires a lot of time & memory.
    """
    def _begin(self, data):
        """
        Set up a reader over the input and an empty output buffer; returns the unit itself.
        """
        self._src = StructReader(memoryview(data))
        self._dst = MemoryFile(bytearray())
        return self

    def _reset(self):
        """
        Rewind the input and discard all output written so far; returns the unit itself.
        """
        self._src.seek(0)
        self._dst.seek(0)
        self._dst.truncate()
        return self

    def _decompress(self):
        """
        Decompress input that starts with a header of six big-endian 32-bit integers:
        signature, version, compressed size, compressed CRC32, decompressed size, and
        decompressed CRC32. Size and checksum mismatches are only logged as warnings.
        A wrong signature or a version above 10 raises a `ValueError`.
        """
        (
            signature,
            version,
            src_count,
            src_crc32,
            dst_count,
            dst_crc32,
        ) = self._src.read_struct('>6L')
        if signature != 0x626C7A1A:
            raise ValueError(F'Invalid BriefLZ signature: {signature:08X}, should be 626C7A1A.')
        # Note that a version of exactly 10 passes this check.
        if version > 10:
            raise ValueError(F'Invalid version number {version}, should be less than 10.')
        self.log_debug(F'signature: 0x{signature:08X} V{version}')
        self.log_debug(F'src count: 0x{src_count:08X}')
        self.log_debug(F'src crc32: 0x{src_crc32:08X}')
        self.log_debug(F'dst count: 0x{dst_count:08X}')
        self.log_debug(F'dst crc32: 0x{dst_crc32:08X}')
        # The compressed data follows the 24 byte header.
        src = self._src.getbuffer()
        src = src[24:24 + src_count]
        if len(src) < src_count:
            self.log_warn(F'Only {len(src)} bytes in buffer, but header annoucned a length of {src_count}.')
        # A checksum of zero in the header disables the corresponding check.
        if src_crc32:
            check = zlib.crc32(src)
            if check != src_crc32:
                self.log_warn(F'Invalid source data CRC {check:08X}, should be {src_crc32:08X}.')
        dst = self._decompress_chunk(dst_count)
        if not dst_crc32:
            return dst
        check = zlib.crc32(dst)
        if check != dst_crc32:
            self.log_warn(F'Invalid result data CRC {check:08X}, should be {dst_crc32:08X}.')
        return dst

    def _decompress_modded(self):
        """
        Decompress the modified format: 8 bytes are skipped, then a 64-bit total size and a
        64-bit chunk size are read. The data is decompressed chunk by chunk into the same
        output buffer, which is returned after the final chunk.
        """
        self._src.seekrel(8)
        total_size = self._src.u64()
        chunk_size = self._src.u64()
        remaining = total_size
        self.log_debug(F'total size: 0x{total_size:016X}')
        self.log_debug(F'chunk size: 0x{chunk_size:016X}')
        while remaining > chunk_size:
            self._decompress_chunk(chunk_size)
            remaining -= chunk_size
        return self._decompress_chunk(remaining)

    def _decompress_chunk(self, size=None):
        """
        Decompress one chunk from the current input position and append it to the output.
        If `size` is falsy, decoding continues until the input is exhausted, which ends in
        a `RefineryPartialResult`. Otherwise, decoding stops once `size` bytes were produced.

        Raises `RefineryPartialResult` when the input ends prematurely or a back-reference
        reaches before the start of the output, and `RuntimeError` when more than `size`
        bytes were produced.
        """
        bitcount = 0
        bitstore = 0
        # The first byte of a chunk is always a literal; it is copied below.
        decompressed = 1

        def readbit():
            # Control bits are read in 16-bit little-endian words, highest bit first.
            nonlocal bitcount, bitstore
            if not bitcount:
                bitstore = int.from_bytes(self._src.read_exactly(2), 'little')
                bitcount = 0xF
            else:
                bitcount = bitcount - 1
            return (bitstore >> bitcount) & 1

        def readint():
            # Variable length integer, at least 2: the leading bit is implicit, and every
            # data bit is followed by a bit that tells whether another data bit follows.
            result = 2 + readbit()
            while readbit():
                result <<= 1
                result += readbit()
            return result

        self._dst.write(self._src.read_exactly(1))

        try:
            while not size or decompressed < size:
                if readbit():
                    length = readint() + 2
                    sector = readint() - 2
                    offset = self._src.read_byte() + 1
                    delta = offset + 0x100 * sector
                    available = self._dst.tell()
                    if delta not in range(available + 1):
                        raise RefineryPartialResult(
                            F'Requested rewind by 0x{delta:08X} bytes with only 0x{available:08X} bytes in output buffer.',
                            partial=self._dst.getvalue())
                    # When the match is longer than the distance, the last delta bytes of
                    # the output are repeated until the full length is reached.
                    quotient, remainder = divmod(length, delta)
                    replay = memoryview(self._dst.getbuffer())
                    replay = bytes(replay[-delta:] if quotient else replay[-delta:length - delta])
                    replay = quotient * replay + replay[:remainder]
                    self._dst.write(replay)
                    decompressed += length
                else:
                    self._dst.write(self._src.read_exactly(1))
                    decompressed += 1
        except EOF as E:
            raise RefineryPartialResult(str(E), partial=self._dst.getbuffer())
        dst = self._dst.getbuffer()
        if decompressed < size:
            raise RefineryPartialResult(
                F'Attempted to decompress {size} bytes, got only {len(dst)}.', dst)
        if decompressed > size:
            raise RuntimeError('Decompressed buffer contained more bytes than expected.')
        return dst

    def _compress(self):
        """
        Compress the input into the output buffer and return it. Matches are located with a
        suffix tree over the entire input. The output starts with a 24 byte header which is
        filled in at the very end, when sizes and checksums are known.
        """
        from refinery.lib.suffixtree import SuffixTree

        # The handler below re-raises every exception unchanged.
        try:
            self.log_info('computing suffix tree')
            tree = SuffixTree(self._src.getbuffer())
        except Exception:
            raise

        bitstore = 0  # The bit stream to be written
        bitcount = 0  # The number of bits in the bit stream
        # Collects the data bytes that belong to control bits which were not written yet.
        buffer = MemoryFile(bytearray())

        # Write empty header and first byte of source
        self._dst.write(bytearray(24))
        self._dst.write(self._src.read_exactly(1))

        def writeint(n: int) -> None:
            """
            Write an integer to the bit stream.
            """
            nonlocal bitstore, bitcount
            nbits = n.bit_length()
            if nbits < 2:
                raise ValueError
            # The highest bit is implicitly assumed:
            n ^= 1 << (nbits - 1)
            remaining = nbits - 2
            # All but the last data bit are followed by a set continuation bit.
            while remaining:
                remaining -= 1
                bitstore <<= 2
                bitcount += 2
                bitstore |= ((n >> remaining) & 3) | 1
            # The last data bit is followed by a cleared continuation bit.
            bitstore <<= 2
            bitcount += 2
            bitstore |= (n & 1) << 1

        src = self._src.getbuffer()
        # The first byte was already written as a literal.
        remaining = len(src) - 1
        self.log_info('compressing data')

        while True:
            cursor = len(src) - remaining
            rest = src[cursor:]
            if bitcount >= 0x10:
                # Split off all complete 16-bit blocks; the remaining low bits stay in bitstore.
                block_count, bitcount = divmod(bitcount, 0x10)
                info_channel = bitstore >> bitcount
                bitstore = info_channel << bitcount ^ bitstore
                # The decompressor will read bits from top to bottom, and each 16 bit block has to be
                # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
                # variable, and by encoding it as a big endian integer, the stream is in the correct
                # order. However, we need to swap adjacent bytes to achieve little endian encoding for
                # each of the blocks:
                info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
                for k in range(block_count):
                    k0 = 2 * k + 0
                    k1 = 2 * k + 1
                    info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
                info_channel = memoryview(info_channel)
                data_channel = memoryview(buffer.getbuffer())
                # First control block, then all buffered data bytes except the last one,
                # then any remaining control blocks.
                self._dst.write(info_channel[:2])
                self._dst.write(data_channel[:-1])
                self._dst.write(info_channel[2:])
                # Copy the last data byte before the buffer is emptied. It is written out
                # immediately only if no control bits are left over.
                data_channel = bytes(data_channel[-1:])
                buffer.truncate(0)
                store = buffer if bitcount else self._dst
                store.write(data_channel)
            if remaining + bitcount < 0x10:
                # Final flush: pad the control bits to one full block, then write the buffered
                # data bytes followed by the rest of the input as it is.
                buffer = buffer.getbuffer()
                if rest or buffer:
                    bitstore <<= 0x10 - bitcount
                    self._dst.write(bitstore.to_bytes(2, 'little'))
                    self._dst.write(buffer)
                    self._dst.write(rest)
                elif bitcount:
                    raise RuntimeError('Bitbuffer Overflow')
                break
            node = tree.root
            length = 0
            offset = 0
            sector = None
            # Descend the suffix tree along the remaining input; stop as soon as the node
            # does not start before the cursor.
            while node.children and length < len(rest):
                for child in node.children.values():
                    if tree.data[child.start] == rest[length]:
                        node = child
                        break
                if node.start >= cursor:
                    break
                offset = node.start - length
                length = node.end + 1 - offset
            length = min(remaining, length)
            # Only matches of at least four bytes are encoded as back-references.
            if length >= 4:
                sector, offset = divmod(cursor - offset - 1, 0x100)
            bitcount += 1
            bitstore <<= 1
            if sector is None:
                # Literal: a cleared control bit and the byte itself.
                buffer.write(rest[:1])
                remaining -= 1
                continue
            # Match: a set control bit, the low offset byte, then length and sector.
            bitstore |= 1
            buffer.write(bytes((offset,)))
            writeint(length - 2)
            writeint(sector + 2)
            remaining -= length

        # Fill in the header: signature, version 1, then size and CRC32 for both the
        # compressed and the original data.
        self._dst.seek(24)
        dst = self._dst.peek()
        self._dst.seek(0)
        self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
        return self._dst.getbuffer()

    def process(self, data):
        """
        Decompress the input. If the regular decompression fails with a `ValueError`, the
        modified format is attempted instead. If that attempt fails as well with anything
        other than a partial result, a partial result from the first attempt is raised
        when one is available.
        """
        self._begin(data)
        partial = None
        try:
            return self._decompress()
        except ValueError as error:
            if isinstance(error, RefineryPartialResult):
                partial = error
            self.log_warn(F'Reverting to modified BriefLZ after decompression error: {error!s}')
            self._reset()

        try:
            return self._decompress_modded()
        except RefineryPartialResult:
            raise
        except Exception as error:
            if not partial:
                raise
            raise partial from error

    def reverse(self, data):
        """
        Compress the input.
        """
        return self._begin(data)._compress()
def _compress(self):
    """
    Compress the input into the output buffer and return it. Matches are located with a
    suffix tree over the entire input. The output starts with a 24 byte header which is
    filled in at the very end, when sizes and checksums are known.
    """
    from refinery.lib.suffixtree import SuffixTree

    # The handler below re-raises every exception unchanged.
    try:
        self.log_info('computing suffix tree')
        tree = SuffixTree(self._src.getbuffer())
    except Exception:
        raise

    bitstore = 0  # The bit stream to be written
    bitcount = 0  # The number of bits in the bit stream
    # Collects the data bytes that belong to control bits which were not written yet.
    buffer = MemoryFile(bytearray())

    # Write empty header and first byte of source
    self._dst.write(bytearray(24))
    self._dst.write(self._src.read_exactly(1))

    def writeint(n: int) -> None:
        """
        Write an integer to the bit stream.
        """
        nonlocal bitstore, bitcount
        nbits = n.bit_length()
        if nbits < 2:
            raise ValueError
        # The highest bit is implicitly assumed:
        n ^= 1 << (nbits - 1)
        remaining = nbits - 2
        # All but the last data bit are followed by a set continuation bit.
        while remaining:
            remaining -= 1
            bitstore <<= 2
            bitcount += 2
            bitstore |= ((n >> remaining) & 3) | 1
        # The last data bit is followed by a cleared continuation bit.
        bitstore <<= 2
        bitcount += 2
        bitstore |= (n & 1) << 1

    src = self._src.getbuffer()
    # The first byte was already written as a literal.
    remaining = len(src) - 1
    self.log_info('compressing data')

    while True:
        cursor = len(src) - remaining
        rest = src[cursor:]
        if bitcount >= 0x10:
            # Split off all complete 16-bit blocks; the remaining low bits stay in bitstore.
            block_count, bitcount = divmod(bitcount, 0x10)
            info_channel = bitstore >> bitcount
            bitstore = info_channel << bitcount ^ bitstore
            # The decompressor will read bits from top to bottom, and each 16 bit block has to be
            # little-endian encoded. The bit stream is encoded top to bottom bit in the bitstore
            # variable, and by encoding it as a big endian integer, the stream is in the correct
            # order. However, we need to swap adjacent bytes to achieve little endian encoding for
            # each of the blocks:
            info_channel = bytearray(info_channel.to_bytes(block_count * 2, 'big'))
            for k in range(block_count):
                k0 = 2 * k + 0
                k1 = 2 * k + 1
                info_channel[k0], info_channel[k1] = info_channel[k1], info_channel[k0]
            info_channel = memoryview(info_channel)
            data_channel = memoryview(buffer.getbuffer())
            # First control block, then all buffered data bytes except the last one,
            # then any remaining control blocks.
            self._dst.write(info_channel[:2])
            self._dst.write(data_channel[:-1])
            self._dst.write(info_channel[2:])
            # Copy the last data byte before the buffer is emptied. It is written out
            # immediately only if no control bits are left over.
            data_channel = bytes(data_channel[-1:])
            buffer.truncate(0)
            store = buffer if bitcount else self._dst
            store.write(data_channel)
        if remaining + bitcount < 0x10:
            # Final flush: pad the control bits to one full block, then write the buffered
            # data bytes followed by the rest of the input as it is.
            buffer = buffer.getbuffer()
            if rest or buffer:
                bitstore <<= 0x10 - bitcount
                self._dst.write(bitstore.to_bytes(2, 'little'))
                self._dst.write(buffer)
                self._dst.write(rest)
            elif bitcount:
                raise RuntimeError('Bitbuffer Overflow')
            break
        node = tree.root
        length = 0
        offset = 0
        sector = None
        # Descend the suffix tree along the remaining input; stop as soon as the node
        # does not start before the cursor.
        while node.children and length < len(rest):
            for child in node.children.values():
                if tree.data[child.start] == rest[length]:
                    node = child
                    break
            if node.start >= cursor:
                break
            offset = node.start - length
            length = node.end + 1 - offset
        length = min(remaining, length)
        # Only matches of at least four bytes are encoded as back-references.
        if length >= 4:
            sector, offset = divmod(cursor - offset - 1, 0x100)
        bitcount += 1
        bitstore <<= 1
        if sector is None:
            # Literal: a cleared control bit and the byte itself.
            buffer.write(rest[:1])
            remaining -= 1
            continue
        # Match: a set control bit, the low offset byte, then length and sector.
        bitstore |= 1
        buffer.write(bytes((offset,)))
        writeint(length - 2)
        writeint(sector + 2)
        remaining -= length

    # Fill in the header: signature, version 1, then size and CRC32 for both the
    # compressed and the original data.
    self._dst.seek(24)
    dst = self._dst.peek()
    self._dst.seek(0)
    self._dst.write(struct.pack('>6L', 0x626C7A1A, 1, len(dst), zlib.crc32(dst), len(src), zlib.crc32(src)))
    return self._dst.getbuffer()
def decompress_stream(self, data: ByteString, LZOv1: bool = False) -> bytearray:
    """
    An implementation of LZO decompression. We use the article
    "[LZO stream format as understood by Linux's LZO decompressor](https://www.kernel.org/doc/html/latest/staging/lzo.html)"
    as a reference since no proper specification is available.

    The `data` argument is the compressed stream. When `LZOv1` is set, the additional checks
    and the zero-run instruction of bitstream version 1 are enabled. The decompressed output
    buffer is returned when the end-of-stream instruction is encountered. An `LZOError` is
    raised for an invalid first byte, an overlong integer encoding, a literal run that is
    too short, and for copy distances that exceed the available output.
    """
    def integer() -> int:
        # Variable length integer: each zero byte adds 0xFF, the first non-zero byte ends it.
        length = 0
        while True:
            byte = src.read_byte()
            if byte:
                return length + byte
            length += 0xFF
            if length > 0x100000:
                raise LZOError('Too many zeros in integer encoding.')

    def literal(count):
        # Copy the given number of bytes from the input to the output.
        dst.write(src.read_bytes(count))

    def copy(distance: int, length: int):
        # Copy length bytes that start distance bytes before the end of the output. When
        # the regions overlap, the available block is repeated until length is reached.
        if distance > len(dst):
            raise LZOError(F'Distance {distance} > bufsize {len(dst)}')
        buffer = dst.getbuffer()
        if distance > length:
            start = len(buffer) - distance
            end = start + length
            dst.write(buffer[start:end])
        else:
            block = buffer[-distance:]
            while len(block) < length:
                block += block[:length - len(block)]
            if len(block) > length:
                block[length:] = ()
            dst.write(block)

    src = StructReader(memoryview(data))
    dst = MemoryFile()

    # The state is the number of literals that were copied by the last instruction.
    state = 0
    first = src.read_byte()

    if first == 0x10:
        raise LZOError('Invalid first stream byte 0x10.')
    elif first <= 0x12:
        # Treat the first byte as a regular instruction.
        src.seekrel(-1)
    elif first <= 0x15:
        state = first - 0x11
        literal(state)
    else:
        state = 4
        literal(first - 0x11)

    while True:
        instruction = src.read_byte()
        if instruction < 0x10:
            if state == 0:
                # Long literal run; the copy happens at the end of the loop body.
                length = instruction or integer() + 15
                state = length + 3
                if state < 4:
                    raise LZOError('Literal encoding is too short.')
            else:
                # The interpretation depends on the number of literals that the PREVIOUS
                # instruction copied, so it has to be saved before the state is replaced
                # by the S bits of this instruction. Testing the new state, which is at
                # most 3, would make the three byte copy below unreachable.
                previous = state
                state = instruction & 0b0011
                D = (instruction & 0b1100) >> 2
                H = src.read_byte()
                distance = (H << 2) + D + 1
                if previous >= 4:
                    # After a run of 4 or more literals: 3 bytes from 2 to 3 kB distance.
                    distance += 0x800
                    length = 3
                else:
                    # After 1 to 3 literals: 2 bytes from within 1 kB distance.
                    length = 2
                copy(distance, length)
        elif instruction < 0x20:
            L = instruction & 0b0111
            H = instruction & 0b1000
            length = L or integer() + 7
            argument = src.u16()
            state = argument & 3
            distance = (H << 11) + (argument >> 2)
            if not distance:
                # End of stream marker.
                return dst.getbuffer()
            # NOTE(review): Assuming that u16 reads an unsigned 16-bit integer, the value
            # of distance is at most 0x7FFF at this point. In that case, neither of the
            # following two LZOv1 conditions can be true; they may have been intended for
            # the value distance + 0x4000. Confirm against the reference before relying
            # on the LZOv1 mode.
            if LZOv1 and distance & 0x803F == 0x803F and length in range(261, 265):
                raise LZOError('Compressed data contains sequence that is banned in LZOv1.')
            if LZOv1 and distance == 0xBFFF:
                X = src.read_byte()
                count = ((X << 3) | L) + 4
                self.log_debug(F'Writing run of {X} zero bytes according to LZOv1.')
                dst.write(B'\0' * count)
            else:
                copy(distance + 0x4000, length + 2)
        elif instruction < 0x40:
            L = instruction & 0b11111
            length = L or integer() + 31
            argument = src.u16()
            state = argument & 3
            distance = (argument >> 2) + 1
            copy(distance, length + 2)
        else:
            if instruction < 0x80:
                length = 3 + ((instruction >> 5) & 1)
            else:
                length = 5 + ((instruction >> 5) & 3)
            H = src.read_byte()
            D = (instruction & 0b11100) >> 2
            state = instruction & 3
            distance = (H << 3) + D + 1
            copy(distance, length)
        # Copy the literals that trail this instruction, if any.
        if state:
            literal(state)