def length_extension(old_hash, size, new_message, type='sha1'):
    """Merkle-Damgard length extension attack.

    Knowing only hash(secret) and len(secret), forge a hash of
    secret || glue_padding || new_message without learning the secret.

    Args:
        old_hash(string): hash of secret value
        size(int): length of secret (in bytes)
        new_message(string): data to append after the glue padding
        type(string): sha1 or md4

    Returns:
        tuple: (new_hash, old_padding+new_message), or None for an unsupported type
    """
    implemented_functions = ['sha1', 'md4']
    if type == 'sha1':
        endian, hash_function = 'big', sha1
    elif type == 'md4':
        endian, hash_function = 'little', md4
    else:
        log.critical_error("Not implemented, type must be one of {}".format(
            implemented_functions))
        return None

    # Glue padding the hash appended after the (unknown) secret.
    glue = add_md_padding(bytes(b'a' * size), endian=endian)[size:]
    # Padding the hash will append after secret || glue || new_message.
    forged_length = size + len(glue) + len(new_message)
    final_padding = add_md_padding(bytes(b'a' * forged_length), endian=endian)[forged_length:]
    # Resume hashing from the internal state encoded in the old digest (4-byte words).
    state = [b2i(word, endian=endian) for word in chunks(old_hash, 4)]
    return hash_function(new_message, state, final_padding), glue + new_message
def compare_by_frequencies(a, b, lang='English', no_of_comparisons=5):
    """Decide which of two texts has letter frequencies closer to a language.

    Only the first `no_of_comparisons` entries of frequencies[lang] are used.
    todo: add words, diagraphs etc...

    Args:
        a(string)
        b(string)
        lang(string): key in the frequencies dict
        no_of_comparisons(int): how many reference letters to compare

    Returns:
        int: 1 if a is closer to the language than b, -1 if b is closer, 0 on a tie
    """
    if lang not in frequencies:
        log.critical_error("[-] Can't find language {}".format(lang))
    reference = frequencies[lang][:no_of_comparisons]

    observed = {text: get_frequencies(text) for text in (a, b)}
    distance = {a: 0, b: 0}
    for entry in reference:
        letter, expected = entry[0], entry[1]
        for text in (a, b):
            text_freq = observed[text]
            # A letter missing from the text contributes its full expected frequency.
            if letter in text_freq:
                distance[text] += abs(text_freq[letter] - expected)
            else:
                distance[text] += expected

    score_a, score_b = distance[a], distance[b]
    # Smaller distance means closer to the language frequencies.
    if score_a < score_b:
        return 1
    if score_a > score_b:
        return -1
    return 0
def break_reuse_key(ciphertexts, lang='English', no_of_comparisons=5, alphabet=None, key_space=None,
                    reliability=100.0):
    """Break sentences xored with the same key.

    All ciphertexts are truncated to the shortest one, concatenated, and broken
    as a repeated-key xor with key size equal to that length.

    Args:
        ciphertexts(list): texts (str or bytes) xored with the same key
        lang(string): key in frequencies dict
        no_of_comparisons(int): used during comparing by frequencies
        alphabet(string/None): plaintext space
        key_space(string/None): key space
        reliability(float): between 0 and 100, used during comparing by frequencies

    Returns:
        list: sorted (by frequencies) list of tuples (key, list(plaintexts))
    """
    if len(ciphertexts) < 2:
        log.critical_error("Too less ciphertexts")

    min_size = min(len(one) for one in ciphertexts)
    ciphertexts = [one[:min_size] for one in ciphertexts]
    log.info("Ciphertexts shrinked to {} bytes".format(min_size))

    # Join with an empty value of the ciphertexts' own type: ''.join() raises
    # TypeError for bytes on Python 3, b''.join() for str.
    joined = ciphertexts[0][:0].join(ciphertexts)

    pairs = break_repeated_key(joined, lang=lang, no_of_comparisons=no_of_comparisons, key_size=min_size,
                               alphabet=alphabet, key_space=key_space, reliability=reliability)
    # Split each recovered plaintext back into the individual messages.
    return [(pair[0], chunks(pair[1], min_size)) for pair in pairs]
def compute_params(s, m=None, a=None, b=None):
    """Recover the parameters of an LCG: next_state = a*state + b mod m.

    Unknown parameters are derived from consecutive outputs: m from the gcd of
    the values |t[n+2]*t[n] - t[n+1]**2| (t being output differences), a from a
    modular inverse, and b from the first step.

    Args:
        s(list): subsequent outputs from LCG oracle starting with seed
        m(int/None)
        a(int/None)
        b(int/None)

    Returns:
        a, b, m(int)
    """
    if m is None:
        diffs = [following - current for current, following in zip(s, s[1:])]
        multiples = [abs(d2 * d0 - d1 ** 2) for d0, d1, d2 in zip(diffs, diffs[1:], diffs[2:])]
        m = gcd(*multiples)
        log.success("m = {}".format(m))

    if a is None:
        # a = (s2 - s1) / (s1 - s0) mod m, falling back to a two-step difference.
        if gcd(s[1] - s[0], m) == 1:
            a = (s[2] - s[1]) * invmod(s[1] - s[0], m)
        elif gcd(s[2] - s[0], m) == 1:
            a = (s[3] - s[1]) * invmod(s[2] - s[0], m)
        else:
            log.critical_error("a not found")
        a = a % m
        log.success("a = {}".format(a))

    if b is None:
        b = (s[1] - s[0] * a) % m
        log.success("b = {}".format(b))

    return a, b, m
def tonelli_shanks(n, p):
    """Return a square root r of n modulo prime p (r*r % p == n % p).

    The other root is p - r. Fails via log.critical_error when n is not a
    quadratic residue modulo p.
    """
    if legendre(n, p) != 1:
        log.critical_error("Not a square root")

    # Write p - 1 as q * 2**s with q odd.
    q, s = p - 1, 0
    while not q & 1:
        q >>= 1
        s += 1

    if s == 1:
        # p % 4 == 3: closed-form root.
        return pow(n, (p + 1) // 4, p)

    # Any quadratic non-residue will do.
    z = 1
    while legendre(z, p) != -1:
        z += 1

    m = s
    c = pow(z, q, p)
    r = pow(n, (q + 1) // 2, p)
    t = pow(n, q, p)
    while t != 1:
        # Least i in [1, m) with t**(2**i) == 1 (mod p).
        i = 1
        while i < m and pow(t, 2 ** i, p) != 1:
            i += 1
        b = pow(c, 2 ** (m - i - 1), p)
        r = r * b % p
        t = t * b * b % p
        c = b * b % p
        m = i
    return r
def _check_oracles(padding_oracle=None, decryption_oracle=None, block_size=16):
    """Check that the given padding and/or decryption oracle can be called.

    Each supplied oracle is called once with dummy block-sized data; any
    failure is reported through log.critical_error.

    Args:
        padding_oracle(function/None): called as padding_oracle(payload=..., iv=...)
        decryption_oracle(function/None): called with a single block
        block_size(int)
    """
    if not padding_oracle and not decryption_oracle:
        log.critical_error(
            "Give padding_oracle and/or decryption_oracle functions")

    if decryption_oracle:
        try:
            decryption_oracle(bytes(b'B' * block_size))
        except NotImplementedError:
            log.critical_error("decryption_oracle not implemented")
        except Exception as e:
            # Python 3 exceptions have no `.message`; format the exception itself,
            # otherwise the handler raises AttributeError and hides the real error.
            log.critical_error(
                "Error in first call to decryption_oracle: {}".format(e))

    if padding_oracle:
        try:
            padding_oracle(payload=bytes(b'B' * block_size),
                           iv=bytes(b'A' * block_size))
        except NotImplementedError:
            log.critical_error("padding_oracle not implemented")
        except Exception as e:
            log.critical_error(
                "Error in first call to padding_oracle: {}".format(e))
def fake_ciphertext(new_plaintext, padding_oracle=None, decryption_oracle=None, block_size=16):
    """Make a CBC ciphertext (iv included) that will decrypt to given plaintext.

    Give padding_oracle or decryption_oracle (or both).
    Works backwards from an arbitrary last block. For each block, the current
    decryption of the next block is learned through the oracle, and the
    previous block is set so the xor yields the wanted plaintext.

    Args:
        new_plaintext(string): with padding; length must be a non-zero multiple of block_size
        padding_oracle(function/None)
        decryption_oracle(function/None): maximum one block to decrypt
        block_size(int)

    Returns:
        fake_ciphertext(string): fake ciphertext that will decrypt to new_plaintext
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    log.info("Start fake ciphertext")

    # A partial last block would previously slip through the block-count check
    # and produce a truncated ciphertext block.
    if len(new_plaintext) == 0 or len(new_plaintext) % block_size != 0:
        log.critical_error(
            "Wrong new plaintext length({}), should be a non-zero multiple of {}".format(
                len(new_plaintext), block_size))

    # One extra block acts as the iv.
    ciphertext = bytes(b'A' * (len(new_plaintext) + block_size))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)
    new_pl_blocks = chunks(new_plaintext, block_size)
    new_ct_blocks = list(blocks)

    for count_block in range(len(blocks) - 1, 0, -1):
        # Modify block[count_block - 1] so that block[count_block] decrypts to
        # the wanted plaintext block.
        log.info("Block no. {}".format(count_block))

        ciphertext_to_decrypt = bytes(b''.join(new_ct_blocks[:count_block + 1]))
        # Plaintext of the last block while the previous block is still the original one.
        original_plaintext = decrypt(ciphertext_to_decrypt,
                                     padding_oracle=padding_oracle,
                                     decryption_oracle=decryption_oracle,
                                     block_size=block_size,
                                     amount=1,
                                     is_correct=False)
        log.info("Set block no. {}".format(count_block))
        new_ct_blocks[count_block - 1] = xor(blocks[count_block - 1],
                                             original_plaintext,
                                             new_pl_blocks[count_block - 1])

    fake_ciphertext_res = bytes(b''.join(new_ct_blocks))
    log.success("Fake ciphertext(hex): {}".format(b2h(fake_ciphertext_res)))
    return fake_ciphertext_res
def lcm(*args):
    """Lowest common multiple of two or more integers.

    lcm(0, x) is 0 by convention, and the result is non-negative.
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")
    if len(args) == 2:
        a, b = args
        if a == 0 or b == 0:
            # Avoids ZeroDivisionError from gcd(0, 0) == 0.
            return 0
        return abs(a * b) // gcd(a, b)
    result = 1
    for number in args:
        result = lcm(result, number)
    return result
def decrypt(self, ciphertext):
    """Raw decryption

    Args:
        ciphertext(int/string): a number, or a value b2i can convert to one

    Returns:
        pow(ciphertext, d, n), as returned by self.pyrsa_key.decrypt
    """
    if not isinstance(ciphertext, Number):
        try:
            ciphertext = b2i(ciphertext)
        except Exception:
            # Only conversion failures; a bare except would also swallow
            # KeyboardInterrupt/SystemExit.
            log.critical_error(
                "Ciphertext to decrypt must be number or be convertible to number ({})"
                .format(ciphertext))
    return self.pyrsa_key.decrypt(gmpy2.mpz(ciphertext))
def encrypt(self, plaintext):
    """Raw encryption

    Args:
        plaintext(int/string): a number, or a value b2i can convert to one

    Returns:
        pow(plaintext, e, n): first element of self.pyrsa_key.encrypt(...)
    """
    if not isinstance(plaintext, Number):
        try:
            plaintext = b2i(plaintext)
        except Exception:
            # Only conversion failures; a bare except would also swallow
            # KeyboardInterrupt/SystemExit.
            log.critical_error(
                "Plaintext to encrypt must be number or be convertible to number ({})"
                .format(plaintext))
    return self.pyrsa_key.encrypt(int(plaintext), 0)[0]
def gcd(*args):
    """Greatest common divisor of two or more integers.

    The two-argument case is delegated to gmpy2.gcd. Longer argument lists
    are folded pairwise, stopping early once the divisor reaches 1.
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")
    if len(args) != 2:
        result = 0
        for value in args:
            result = gcd(result, value)
            if result == 1:
                # Nothing can reduce the divisor below 1.
                break
        return result
    first, second = args
    return gmpy2.gcd(first, second)
def add_plaintext(self, plaintext, position=None):
    """Store a plaintext in self.texts.

    Args:
        plaintext(int/string): a number, or a value b2i can convert to one
        position(int/None): position in list where to add, None for new
    """
    if not isinstance(plaintext, Number):
        try:
            plaintext = b2i(plaintext)
        except Exception:
            # Only conversion failures; a bare except would also swallow
            # KeyboardInterrupt/SystemExit.
            log.critical_error(
                "Plaintext to add must be number or be convertible to number ({})"
                .format(plaintext))
    if position is None:
        self.texts.append({'plain': plaintext})
    else:
        self.texts[position]['plain'] = plaintext
def egcd(*args):
    """Extended Euclidean algorithm (gmpy2-backed).

    Two arguments: returns gmpy2.gcdext(a, b), i.e. (g, s, t) with s*a + t*b == g.
    More arguments: returns the list [g, c0, c1, ...] with sum(ci * args[i]) == g.
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")
    if len(args) == 2:
        return gmpy2.gcdext(*args)

    d, s, t = egcd(args[0], args[1])
    coefficients = [s, t]
    for value in args[2:]:
        d, s, t = egcd(d, value)
        # Rescale earlier coefficients so they still combine into the new gcd.
        coefficients = [coefficient * s for coefficient in coefficients]
        coefficients.append(t)
    return [d] + coefficients
def crt_non_coprime(a, n):
    """Solve chinese remainder theorem with general modules

    Given:
        x = a % n
        x = b % m
    If modules n, m are not coprime, but a = b mod gcd(n, m),
    then a solution can be found.
    The solution will be modulo lcm of modules.

    Congruences are merged pairwise from left to right. The input sequences
    (lists or tuples) are not modified.

    Args:
        a(list): remainders
        n(list): modules

    Returns:
        int: solution to crt, or None if some pair of congruences is inconsistent
    """
    if len(a) != len(n):
        log.critical_error(
            "Different number of remainders({}) and modules({})".format(
                len(a), len(n)))
    if len(n) < 2:
        log.critical_error(
            "Give at least two remainders and modules (got {})".format(len(n)))

    x, modulus = a[0], n[0]
    for a_i, n_i in zip(a[1:], n[1:]):
        g, u, _ = egcd(modulus, n_i)
        if (x - a_i) % g != 0:
            log.error(
                'Not satisfied: gcd(ni, nj) | ai - aj\ngcd({}, {}) | {} - {}'.
                format(modulus, n_i, x, a_i))
            return None
        # With u*modulus + v*n_i == g, x - modulus*u*(x - a_i)/g satisfies both congruences.
        w = (x - a_i) // g
        merged_modulus = lcm(modulus, n_i)
        x = (x - modulus * u * w) % merged_modulus
        modulus = merged_modulus

    return int(x % modulus)
def crt(a, n):
    """Solve the chinese remainder theorem for pairwise coprime modules.

    from: http://rosettacode.org/wiki/Chinese_remainder_theorem#Python
    The solution will be modulo product of modules.

    Args:
        a(list): remainders
        n(list): modules

    Returns:
        int: solution to crt
    """
    if len(a) != len(n):
        log.critical_error("Different number of remainders({}) and modules({})".format(len(a), len(n)))

    modulus = product(n)
    # Each term is a_i * M_i * (M_i^-1 mod n_i), where M_i = modulus // n_i.
    total = sum(a_i * invmod(modulus // n_i, n_i) * (modulus // n_i)
                for n_i, a_i in zip(n, a))
    return int(total % modulus)
def egcd(*args):
    """Extended Euclidean algorithm (pure Python).

    Two arguments: returns the tuple (g, s, t) with s*a + t*b == g.
    More arguments: returns the list [g, c0, c1, ...] with sum(ci * args[i]) == g.
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")

    if len(args) == 2:
        old_r, r = args
        old_s, s = 1, 0
        old_t, t = 0, 1
        while r:
            quotient, remainder = divmod(old_r, r)
            old_r, r = r, remainder
            old_s, s = s, old_s - quotient * s
            old_t, t = t, old_t - quotient * t
        return old_r, old_s, old_t

    d, s, t = egcd(args[0], args[1])
    coefficients = [s, t]
    for value in args[2:]:
        d, s, t = egcd(d, value)
        # Rescale earlier coefficients so they still combine into the new gcd.
        coefficients = [coefficient * s for coefficient in coefficients]
        coefficients.append(t)
    return [d] + coefficients
def find_prefix_suffix_size(encryption_oracle, block_size=16):
    """Determine prefix and suffix sizes

    Works for ecb mode, where sizes must be constant.
    May rarely fail, if the random data sent happens to match the prefix/suffix.

    Args:
        encryption_oracle(callable): returns ecb(prefix || our_input || suffix)
        block_size(int)

    Returns:
        tuple(int,int): prefix_size, suffix_size
    """
    blocks_to_send = 5
    payload = random_bytes(1) * (blocks_to_send * block_size)
    enc_chunks = chunks(encryption_oracle(payload), block_size)
    log.debug("Encryption of length {}".format(blocks_to_send * block_size))
    log.debug(print_chunks(enc_chunks))

    # The first run of several identical ciphertext blocks marks where our
    # repeated-byte payload becomes block aligned.
    for position_start in range(len(enc_chunks) - 1):
        if enc_chunks[position_start] == enc_chunks[position_start + 1]:
            for y in range(2, blocks_to_send - 1):
                # Out-of-range counts as a mismatch instead of raising IndexError.
                if position_start + y >= len(enc_chunks) or \
                        enc_chunks[position_start] != enc_chunks[position_start + y]:
                    break
            else:
                log.success("Controlled payload start at chunk {}".format(position_start))
                break
    else:
        log.critical_error("Position of controlled chunks not found")

    log.info('Finding prefix')
    # Change one payload byte at a time. The first index that alters the first
    # controlled block tells how many payload bytes fill the prefix's last block.
    changed_char = bytes([(payload[0] - 1) % 256])
    for aligned_bytes in range(block_size):
        payload_new = payload[:aligned_bytes] + changed_char + payload[aligned_bytes + 1:]
        enc_chunks_new = chunks(encryption_oracle(payload_new), block_size)
        log.debug(print_chunks(chunks(payload_new, block_size)))
        log.debug(print_chunks(enc_chunks_new))
        if enc_chunks_new[position_start] != enc_chunks[position_start]:
            prefix_size = position_start * block_size - aligned_bytes
            log.success("Prefix size: {}".format(prefix_size))
            break
    else:
        log.critical_error("Size of prefix not found")

    log.info('Finding suffix')
    payload = random_bytes(1) * (block_size - (prefix_size % block_size))  # align to block_size
    encrypted = encryption_oracle(payload)
    # Initially counts suffix + padding; each extra byte removes one padding byte.
    suffix_size = len(encrypted) - len(payload) - prefix_size
    # Padding is 1..block_size bytes, so the ciphertext must grow within
    # block_size extra bytes. Bounding the loop makes the "not found" branch
    # reachable; the previous `while True` never terminated here.
    for _ in range(block_size):
        payload += random_bytes(1)
        suffix_size -= 1
        if len(encryption_oracle(payload)) > len(encrypted):
            log.success("Suffix size: {}".format(suffix_size))
            break
    else:
        log.critical_error("Size of suffix not found")

    return prefix_size, suffix_size
def bit_flipping(ciphertext, plaintext, wanted, block_size=16):
    """CBC bit flipping on a two-block ciphertext.

    The first ciphertext block is rewritten so that the second block decrypts
    to `wanted` instead of `plaintext`. The first block then decrypts to garbage.

    Args:
        ciphertext(string): size == 2*block_size
        plaintext(string): size == block_size, plaintext of second ciphertext block
        wanted(string): size == block_size, what the last ciphertext block should decrypt to
        block_size(int)

    Returns:
        string: ciphertext that will decrypt to garbage_block+wanted_last_block
    """
    size_checks = (
        (len(ciphertext) == 2 * block_size, "Incorrect ciphertext size"),
        (len(plaintext) == block_size, "Incorrect plaintext size"),
        (len(wanted) == block_size, "Incorrect wanted_last_block size"),
    )
    for is_valid, message in size_checks:
        if not is_valid:
            log.critical_error(message)

    first_block, second_block = ciphertext[:block_size], ciphertext[block_size:]
    return xor(first_block, plaintext, wanted) + second_block
def decrypt(encryption_oracle, constant=True, block_size=16, prefix_size=None, secret_size=None, alphabet=None):
    """Given encryption oracle which produce ecb(prefix || our_input || secret), find secret

    The secret is recovered one byte at a time, from its last byte towards the first.

    Args:
        encryption_oracle(callable)
        constant(bool): True if prefix have constant length (secret must have constant length)
        block_size(int/None): detected with find_block_size when falsy
        prefix_size(int/None): detected (together with secret_size) when None
        secret_size(int/None): detected (together with prefix_size) when None
        alphabet(string/bytes/None): plaintext space. When given, only these bytes
            are tried. When None, every byte value is tried (printable characters first).

    Returns:
        secret(string), or None when constant is False (not implemented)
    """
    log.debug("Start decrypt function")

    if alphabet:
        # An explicit alphabet restricts the search space (it was previously ignored).
        if isinstance(alphabet, str):
            alphabet = alphabet.encode()
        candidates = bytes(alphabet)
    else:
        # ECB is deterministic, so at most one byte matches and the search order
        # does not change the result. Trying printable bytes first just finishes sooner.
        printable = bytes(string.printable.encode())
        candidates = printable + bytes(c for c in range(256) if c not in printable)

    if not block_size:
        block_size = find_block_size(encryption_oracle, constant)

    if not constant:
        log.debug("constant == False")
        # Variable-length prefixes are not implemented.
        return None

    log.debug("constant == True")
    # `is None` so that an explicit 0 (a valid size) is respected.
    if prefix_size is None or secret_size is None:
        prefix_size, secret_size = find_prefix_suffix_size(encryption_oracle, block_size)

    # Start decrypt
    secret = bytes(b'')
    # Fill the prefix's last block so our controlled data starts on a block boundary.
    aligned_bytes = random_bytes(1) * (block_size - (prefix_size % block_size))
    if len(aligned_bytes) == block_size:
        aligned_bytes = bytes(b'')
    # Complement the secret length to a block multiple.
    aligned_bytes_suffix = random_bytes(1) * (block_size - (secret_size % block_size))
    if len(aligned_bytes_suffix) == block_size:
        aligned_bytes_suffix = bytes(b'')

    block_to_find_position = -1
    controlled_block_position = (prefix_size + len(aligned_bytes)) // block_size

    while len(secret) < secret_size:
        # Once a whole block of secret is known, the target block moves one block left.
        if (len(secret) + 1) % block_size == 0:
            block_to_find_position -= 1

        payload = aligned_bytes + aligned_bytes_suffix + random_bytes(1) + secret
        enc_chunks = chunks(encryption_oracle(payload), block_size)
        block_to_find = enc_chunks[block_to_find_position]

        log.debug("To guess at position {}:".format(block_to_find_position))
        log.debug("Plain: " + print_chunks(chunks(bytes(b'P' * prefix_size) + payload + bytes(b'S' * secret_size), block_size)))
        log.debug("Encry: " + print_chunks(enc_chunks) + "\n")

        for guessed_byte in candidates:
            guessed_char = bytes([guessed_byte])
            payload = aligned_bytes + add_padding(guessed_char + secret, block_size)
            enc_chunks = chunks(encryption_oracle(payload), block_size)

            log.debug("Plain: " + print_chunks(chunks(bytes(b'P' * prefix_size) + payload + bytes(b'S' * secret_size), block_size)))
            log.debug("Encry: " + print_chunks(enc_chunks) + "\n")

            if block_to_find == enc_chunks[controlled_block_position]:
                secret = guessed_char + secret
                log.debug("Found char, secret={}".format(repr(secret)))
                break
        else:
            log.critical_error("Char not found, try change alphabet. Secret so far: {}".format(repr(secret)))

    log.success("Secret(hex): {}".format(b2h(secret)))
    return secret
def decrypt(ciphertext, padding_oracle=None, decryption_oracle=None, iv=None, block_size=16, is_correct=True, amount=0,
            known_plaintext=None, async_calls=False):
    """Decrypt CBC ciphertext

    Give padding_oracle or decryption_oracle (or both).
    With a decryption_oracle every block is decrypted directly. Otherwise the
    classic padding oracle attack runs byte by byte, from the last block backwards.

    Args:
        ciphertext(string): to decrypt
        padding_oracle(function/None): called as padding_oracle(payload=..., iv=...), truthy if padding is valid
        decryption_oracle(function/None): decrypts one block
        iv(string): if not specified, first block of ciphertext is treated as iv
        block_size(int)
        is_correct(bool): set if ciphertext will decrypt to something with correct padding
        amount(int): how much blocks decrypt (counting from last), zero (default) means all
        known_plaintext(string): with padding, from end (aligned to end of ciphertext)
        async_calls(bool): make asynchronous calls to oracle (not implemented yet)

    Returns:
        plaintext(string): with padding
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    if len(ciphertext) % block_size != 0:
        log.critical_error("Incorrect ciphertext length: {}".format(
            len(ciphertext)))

    if decryption_oracle:
        # P_i = D(C_i) xor C_{i-1}
        if iv:
            ciphertext = iv + ciphertext
        blocks = chunks(ciphertext, block_size)
        plaintext = bytes(b'')
        for position in range(len(blocks) - 1, 0, -1):
            plaintext = xor(decryption_oracle(blocks[position]), blocks[position - 1]) + plaintext
            log.info("Plaintext(hex): {}".format(b2h(plaintext)))
            if amount != 0 and len(plaintext) == amount * block_size:
                break
        log.success("Decrypted(hex): {}".format(b2h(plaintext)))
        return plaintext

    log.info("Start cbc padding oracle")
    log.debug(print_chunks(chunks(ciphertext, block_size)))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)

    if iv:
        if len(iv) % block_size != 0:
            log.critical_error("Incorrect iv length: {}".format(len(iv)))
        log.info("Set iv")
        blocks.insert(0, iv)

    # From here on `amount` is the (exclusive) lowest block index to decrypt.
    if amount != 0:
        amount = len(blocks) - amount - 1
    if amount < 0 or amount >= len(blocks):
        log.critical_error(
            "Incorrect amount of blocks to decrypt: {} (have to be in [0,{}]".
            format(amount, len(blocks) - 1))
    log.info("Will decrypt {} block(s)".format(len(blocks) - 1 - amount))

    # add known plaintext
    plaintext = bytes(b'')
    position_known = 0
    chars_decoded = 0
    if known_plaintext:
        is_correct = False
        plaintext = known_plaintext
        blocks_decoded = len(plaintext) // block_size
        chars_decoded = len(plaintext) % block_size

        if blocks_decoded == len(blocks) - 1:
            log.debug("Nothing decrypted, known plaintext long enough")
            return plaintext
        if blocks_decoded > len(blocks) - 1:
            log.critical_error(
                "Too long known plaintext ({} blocks)".format(blocks_decoded))

        if blocks_decoded != 0:
            blocks = blocks[:-blocks_decoded]
        position_known = chars_decoded
        log.info("Have known plaintext, skip {} block(s) and {} bytes".format(
            blocks_decoded, chars_decoded))

    # start decryption
    for count_block in range(len(blocks) - 1, amount, -1):
        # Blocks from the last to the second (all except iv)
        log.info("Block no. {}".format(count_block))

        payload_prefix = bytes(b''.join(blocks[:count_block - 1]))
        payload_modify = blocks[count_block - 1]
        payload_decrypt = blocks[count_block]

        if chars_decoded != 0:
            # We know some chars, so set the previous block's tail to produce
            # padding value chars_decoded + 1 over them.
            payload_modify = payload_modify[:-chars_decoded] + \
                xor(plaintext[:chars_decoded], payload_modify[-chars_decoded:], bytes([chars_decoded + 1]))
            chars_decoded = 0

        position = block_size - 1 - position_known
        position_known = 0
        while position >= 0:
            # Every position in block, from the end
            log.debug("Position: {}".format(position))

            found_correct_char = False
            for guess_char in range(256):
                modified = payload_modify[:position] + bytes([guess_char]) + payload_modify[position + 1:]
                payload = bytes(b''.join([payload_prefix, modified, payload_decrypt]))

                iv = payload[:block_size]
                payload = payload[block_size:]
                log.debug(print_chunks(chunks(iv + payload, block_size)))

                correct = padding_oracle(payload=payload, iv=iv)
                if correct:
                    # oracle returns True
                    padding = block_size - position  # sent ciphertext decoded to that padding
                    decrypted_char = bytes([payload_modify[position] ^ guess_char ^ padding])

                    if is_correct:
                        # If we didn't send the original ciphertext byte, we have found
                        # the original padding value. Otherwise keep searching; if no
                        # other byte gives valid padding, the padding is \x01.
                        if guess_char == blocks[-2][-1]:
                            log.debug("Skip this guess char ({})".format(guess_char))
                            continue

                        dc = int(decrypted_char[0])
                        log.info("Found padding value for correct ciphertext: {}".format(dc))
                        if dc == 0 or dc > block_size:
                            log.critical_error("Found bad padding value (given ciphertext may not be correct)")

                        plaintext = decrypted_char * dc
                        payload_modify = payload_modify[:-dc] + xor(payload_modify[-dc:], decrypted_char,
                                                                    bytes([dc + 1]))
                        position = position - dc + 1
                        is_correct = False
                    else:
                        # abcd efgh ijkl o|guess_char|xy || 1234 5678 9tre qwer - ciphertext
                        # what ever itma ybex || xyzw rtua lopo k|\x03|\x03\x03 - plaintext
                        # abcd efgh ijkl |guess_char|wxy || 1234 5678 9tre qwer - next round ciphertext
                        # some thin gels eheh || xyzw rtua lopo guessing|\x04\x04\x04 - next round plaintext
                        if position == block_size - 1:
                            # When decrypting the last byte, check we didn't hit a padding
                            # other than \x01: change the byte before it and ask again.
                            # Flip one bit rather than writing a fixed value; writing 'A'
                            # changed nothing when that byte already was 0x41.
                            payload = iv + payload
                            payload = payload[:-block_size - 2] + bytes([payload[-block_size - 2] ^ 1]) + \
                                payload[-block_size - 1:]
                            iv = payload[:block_size]
                            payload = payload[block_size:]
                            correct = padding_oracle(payload=payload, iv=iv)
                            if not correct:
                                log.debug("Hit false positive, guess char({})".format(guess_char))
                                continue

                        # Re-target the known tail from `padding` to `padding + 1` for the next byte.
                        payload_modify = payload_modify[:position] + xor(
                            bytes([guess_char]) + payload_modify[position + 1:], bytes([padding]),
                            bytes([padding + 1]))
                        plaintext = decrypted_char + plaintext

                    found_correct_char = True
                    log.debug("Guessed char(\\x{:02x}), decrypted char(\\x{:02x})".format(guess_char,
                                                                                       decrypted_char[0]))
                    log.debug("Plaintext: {}".format(plaintext))
                    log.info("Plaintext(hex): {}".format(b2h(plaintext)))
                    break
            position -= 1
            if found_correct_char is False:
                if is_correct:
                    # Only the original byte gave valid padding, so the padding is \x01.
                    padding = 0x01
                    payload_modify = payload_modify[:position + 1] + xor(payload_modify[position + 1:],
                                                                         bytes([padding]),
                                                                         bytes([padding + 1]))
                    plaintext = bytes(b"\x01")
                    is_correct = False
                else:
                    log.critical_error("Can't find correct padding (oracle function return False 256 times)")

    log.success("Decrypted(hex): {}".format(b2h(plaintext)))
    return plaintext
def bleichenbacher_signature_forgery(key, garbage='suffix', hash_function='sha1'):
    """Bleichenbacher's signature forgery based on bug in verify implementation

    Args:
        key(RSAKey): with small e and at least one plaintext
        garbage(string): middle: 00 01 ff garbage 00 ASN.1 HASH
                         suffix: 00 01 ff 00 ASN.1 HASH garbage
        hash_function(string)

    Returns:
        dict: forged signatures, signatures[no] == signature(key.texts[no]['plain'])
        update key texts
    """
    hash_asn1 = {
        'md5': bytes(b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'),
        'sha1': bytes(b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
        'sha256': bytes(b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'),
        'sha384': bytes(b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'),
        'sha512': bytes(b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40')
    }
    if garbage not in ['suffix', 'middle']:
        log.critical_error("Bad garbage position, must be suffix or middle")
    if hash_function not in list(hash_asn1.keys()):
        log.critical_error(
            "Hash function {} not implemented".format(hash_function))

    if key.e > 3:
        log.debug("May not work, because e > 3")

    signatures = {}
    if garbage == 'suffix':
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[text_no]:
                log.info("Forge for plaintext no {} ({})".format(
                    text_no, key.texts[text_no]['plain']))

                # getattr selects hashlib.<hash_function>, e.g. hashlib.sha1
                digest = getattr(hashlib, hash_function)(i2b(key.texts[text_no]['plain'])).digest()
                plaintext_prefix = bytes(b'\x00\x01\xff\x00') + hash_asn1[hash_function] + digest
                plaintext = plaintext_prefix + bytes(
                    b'\x00' * (key.size // 8 - len(plaintext_prefix)))
                plaintext = b2i(plaintext)

                # The integer e-th root is loop-invariant; only the rounding offset varies.
                root, _ = gmpy2.iroot(plaintext, key.e)
                for round_error in range(-5, 5):
                    signature = int(root + round_error)
                    test_prefix = i2b(gmpy2.powmod(signature, key.e, key.n),
                                      size=key.size)[:len(plaintext_prefix)]
                    if test_prefix == plaintext_prefix:
                        log.info("Got signature: {}".format(signature))
                        log.debug("signature**e % n == {}".format(
                            i2h(gmpy2.powmod(signature, key.e, key.n), size=key.size)))
                        key.texts[text_no]['cipher'] = signature
                        signatures[text_no] = signature
                        break
                else:
                    log.error(
                        "Something wrong, can't compute correct signature")
        return signatures

    elif garbage == 'middle':
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[text_no]:
                log.info("Forge for plaintext no {} ({})".format(
                    text_no, key.texts[text_no]['plain']))

                # getattr selects hashlib.<hash_function>, e.g. hashlib.sha1
                digest = getattr(hashlib, hash_function)(i2b(key.texts[text_no]['plain'])).digest()
                plaintext_suffix = bytes(b'\x00') + hash_asn1[hash_function] + digest
                suffix_int = b2i(plaintext_suffix)
                if suffix_int & 1 != 1:
                    log.error(
                        "Plaintext suffix is even, can't compute signature")
                    continue

                # compute suffix: find s with s**e == suffix (mod 2**bits), one bit at a
                # time from the lowest. For odd s and odd e, setting bit b of s flips bit b
                # of s**e, so each bit is fixed independently. Uses key.e (was hard-coded 3)
                # and only the low b+1 bits of the power.
                signature_suffix = 0b1
                for b in range(len(plaintext_suffix) * 8):
                    mask = 1 << b
                    if (pow(signature_suffix, key.e, mask << 1) & mask) != (suffix_int & mask):
                        signature_suffix |= mask
                signature_suffix = i2b(signature_suffix)[-len(plaintext_suffix):]

                # compute prefix: e-th root of 00 01 ff || random, with its low bytes
                # replaced by the suffix root. Retry until the garbage has no 00 byte.
                while True:
                    plaintext_prefix = bytes(b'\x00\x01\xff') + random_bytes(key.size // 8 - 3)
                    signature_prefix, _ = gmpy2.iroot(b2i(plaintext_prefix), key.e)
                    signature_prefix = i2b(int(signature_prefix), size=key.size)[:-len(signature_suffix)]

                    signature = b2i(signature_prefix + signature_suffix)
                    test_plaintext = i2b(gmpy2.powmod(signature, key.e, key.n), size=key.size)
                    if bytes(b'\x00') not in test_plaintext[2:-len(plaintext_suffix)]:
                        if test_plaintext[:3] == plaintext_prefix[:3] and \
                                test_plaintext[-len(plaintext_suffix):] == plaintext_suffix:
                            log.info("Got signature: {}".format(signature))
                            key.texts[text_no]['cipher'] = signature
                            signatures[text_no] = signature
                            break
                        else:
                            log.error("Something wrong, signature={},"
                                      " signature**{}%{} is {}".format(signature, key.e, key.n,
                                                                       [(test_plaintext)]))
                            break
        return signatures
def blinding(key, signing_oracle=None, decryption_oracle=None):
    """Perform signature/ciphertext blinding attack

    The target is multiplied by r**e mod n for a random r. The oracle's result
    on that blinded value is then multiplied by r**-1 mod n.

    Args:
        key(RSAKey): with at least one plaintext(to sign) or ciphertext(to decrypt)
        signing_oracle(callable)
        decryption_oracle(callable)

    Returns:
        dict: {index: signature/plaintext, index2: signature/plaintext}
        update key texts
    """
    if not signing_oracle and not decryption_oracle:
        log.critical_error("Give one of signing_oracle or decryption_oracle")
    if signing_oracle and decryption_oracle:
        log.critical_error(
            "Give only one of signing_oracle or decryption_oracle")

    recovered = {}
    if signing_oracle:
        log.debug("Have signing_oracle")
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[text_no]:
                log.info("Blinding signature of plaintext no {} ({})".format(
                    text_no, i2h(key.texts[text_no]['plain'])))

                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                blinded_plaintext = (key.texts[text_no]['plain'] * blind_enc) % key.n
                blinded_signature = signing_oracle(blinded_plaintext)
                if not blinded_signature:
                    log.critical_error(
                        "Error during call to signing_oracle({})".format(
                            blinded_plaintext))
                # (m * r**e)**d == m**d * r, so divide r back out.
                signature = (invmod(blind, key.n) * blinded_signature) % key.n
                key.texts[text_no]['cipher'] = signature
                recovered[text_no] = signature
                log.success("Signature: {}".format(signature))

    if decryption_oracle:
        log.debug("Have decryption_oracle")
        for text_no in range(len(key.texts)):
            if 'cipher' in key.texts[text_no] and 'plain' not in key.texts[text_no]:
                log.info("Blinding ciphertext no {} ({})".format(
                    text_no, key.texts[text_no]['cipher']))

                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                blinded_ciphertext = (key.texts[text_no]['cipher'] * blind_enc) % key.n
                blinded_plaintext = decryption_oracle(blinded_ciphertext)
                if not blinded_plaintext:
                    # Report the oracle's input, not its (failed) output.
                    log.critical_error(
                        "Error during call to decryption_oracle({})".format(
                            blinded_ciphertext))
                # (c * r**e)**d == m * r, so divide r back out.
                plaintext = (invmod(blind, key.n) * blinded_plaintext) % key.n
                key.texts[text_no]['plain'] = plaintext
                recovered[text_no] = plaintext
                log.success("Plaintext: {}".format(plaintext))

    return recovered
def hastad(keys, ciphertexts=None):
    """Hastad's broadcast attack (small public exponent)

    Given at least e keys with public exponent equals to e and ciphertexts of the same plaintext,
    plaintext can be efficiently recovered: CRT gives m**e modulo the product of
    the modules, and an exact integer e-th root yields m.

    Args:
        keys(list): RSAKeys, all with same public exponent e, len(keys) >= e,
                    every key with only one ciphertext
        ciphertexts(list/None): if not None, use this ciphertexts

    Returns:
        NoneType/int: None on failure, recovered plaintext otherwise
        update keys texts
    """
    e = keys[0].e
    if len(keys) < e:
        log.critical_error("Not enough keys, e={}".format(e))

    if ciphertexts is None:
        for key in keys:
            if not key.texts:
                # Previously fell through to key.texts[0] and raised IndexError.
                log.critical_error("key {} doesn't have ciphertext".format(
                    key.identifier))
            if len(key.texts) != 1:
                log.info(
                    "Key have more than one ciphertext, using the first one(key=={})"
                    .format(key.identifier))
            if 'cipher' not in key.texts[0]:
                log.critical_error("key {} doesn't have ciphertext".format(
                    key.identifier))

        # prepare ciphertexts and correct_keys lists
        ciphertexts, modules, correct_keys = [], [], []
        for key in keys:
            # get only first ciphertext (if exists); skip duplicate modules/ciphertexts
            if key.n not in modules and key.texts[0]['cipher'] not in ciphertexts:
                if key.e == e:
                    modules.append(key.n)
                    correct_keys.append(key)
                    ciphertexts.append(key.texts[0]['cipher'])
                else:
                    log.info("Key {} have different e(={})".format(
                        key.identifier, key.e))
    else:
        if len(ciphertexts) != len(keys):
            log.critical_error("len(ciphertexts) != len(keys)")
        modules = [key.n for key in keys]
        correct_keys = keys

    # check if we have enough ciphertexts
    if len(modules) < e:
        log.info(
            "Not enough keys with unique modulus and ciphertext, e={}, len(modules)={}"
            .format(e, len(modules)))
        log.info("Checking for simple roots (small_e_msg)")
        for one_key in correct_keys:
            recovered_plaintexts = small_e_msg(one_key)
            if len(recovered_plaintexts) > 0:
                log.success("Found plaintext: {}".format(
                    recovered_plaintexts[0]))
                return recovered_plaintexts[0]

    if len(modules) > e:
        log.debug("Number of modules/ciphertexts larger than e")
        modules = modules[:e]
        ciphertexts = ciphertexts[:e]

    # actual Hastad
    result = crt(ciphertexts, modules)
    plaintext, correct = gmpy2.iroot(result, e)
    if correct:
        plaintext = int(plaintext)
        log.success("Found plaintext: {}".format(plaintext))
        for one_key in correct_keys:
            one_key.texts[0]['plain'] = plaintext
        return plaintext
    else:
        log.debug("Plaintext wasn't {}-th root".format(e))
        log.debug("result (from crt) = {}".format(result))
        log.debug("plaintext ({}-th root of result) = {}".format(e, plaintext))
        return None