def break_reuse_key(ciphertexts, lang='English', no_of_comparisons=5, alphabet=None, key_space=None,
                    reliability=100.0):
    """Break several texts that were xored with one shared key

    Every ciphertext is cut to the length of the shortest one, the pieces are
    concatenated and handed to break_repeated_key with the key size set to
    that common length.

    Args:
        ciphertexts(list): texts xored with the same key
        lang(string): key in frequencies dict
        no_of_comparisons(int): used during comparing by frequencies
        alphabet(string/None): plaintext space
        key_space(string/None): key space
        reliability(float): between 0 and 100, used during comparing by frequencies

    Returns:
        list: sorted (by frequencies) list of tuples (key, list(plaintexts))
    """
    if len(ciphertexts) < 2:
        log.critical_error("Too less ciphertexts")

    common_size = min(len(text) for text in ciphertexts)
    truncated = [text[:common_size] for text in ciphertexts]
    log.info("Ciphertexts shrinked to {} bytes".format(common_size))

    candidates = break_repeated_key(
        ''.join(truncated),
        lang=lang,
        no_of_comparisons=no_of_comparisons,
        key_size=common_size,
        alphabet=alphabet,
        key_space=key_space,
        reliability=reliability)

    # split every recovered plaintext stream back into per-ciphertext pieces
    result = []
    for candidate in candidates:
        result.append((candidate[0], chunks(candidate[1], common_size)))
    return result
def test_bleichenbacher_pkcs15():
    """Check bleichenbacher_pkcs15 against the 64, 256 and 1024 bit test keys

    For each key a random plaintext is encrypted by the external oracle
    script, recovered through the padding oracle and compared with the
    original value.
    """
    print("\nTest: Bleichenbacher's PKCS 1.5 Padding Oracle")
    for key in [key_64, key_256, key_1024]:
        oracle_calls = [0]  # must be mutable
        use_incremental_blinding = key.size < 512

        plaintext = randint(2, key.n)
        if key.size > 512:
            # force the 00 02 header so the plaintext is already PKCS 1.5 conforming
            plaintext >>= 16
            plaintext |= 0x0002 << (key.size - 16)

        encrypt_command = ["python", rsa_oracles_path, "encrypt", key.identifier, i2h(plaintext)]
        ciphertext = h2b(subprocess.check_output(encrypt_command).strip().decode())

        msgs_recovered = bleichenbacher_pkcs15(
            pkcs15_padding_oracle,
            key.publickey(),
            ciphertext,
            incremental_blinding=use_incremental_blinding,
            oracle_key=key,
            pkcs15_padding_oracle_calls=oracle_calls)

        log.info('For keysize {}: pkcs15_padding_oracle_calls = {}'.format(key.size, oracle_calls[0]))
        assert msgs_recovered[0] == plaintext
        key.clear_texts()
def find_prefix_suffix_size(encryption_oracle, block_size=16):
    """Determine prefix and suffix sizes if ecb mode, sizes must be constant

    Rarely may fail (if random data that are send unhappily matches prefix/suffix)

    Steps:
        1. Send a payload made of one repeated byte and look for a run of
           identical ciphertext chunks - the run marks the controlled data.
        2. Replace payload bytes one position at a time (from the front) until
           the first chunk of that run changes; the position gives the prefix size.
        3. Send a payload that aligns the prefix to a block boundary, then
           extend it byte by byte until the ciphertext grows; the number of
           added bytes gives the suffix size.

    Args:
        encryption_oracle(callable): called with payload bytes, returns ciphertext
        block_size(int)

    Returns:
        tuple(int,int): prefix_size, suffix_size
    """
    blocks_to_send = 5
    # number of consecutive identical chunks required to recognise the payload
    run_length = blocks_to_send - 1
    payload = random_bytes(1) * (blocks_to_send * block_size)
    enc_chunks = chunks(encryption_oracle(payload), block_size)
    log.debug("Encryption of length {}".format(blocks_to_send * block_size))
    log.debug(print_chunks(enc_chunks))

    for position_start in range(len(enc_chunks) - 1):
        # slicing (instead of indexing position_start + y) can't run past the
        # last chunk; a too-short window simply doesn't match
        window = enc_chunks[position_start:position_start + run_length]
        if len(window) == run_length and all(chunk == window[0] for chunk in window):
            log.success("Controlled payload start at chunk {}".format(position_start))
            break
    else:
        log.critical_error("Position of controlled chunks not found")

    log.info('Finding prefix')
    changed_char = bytes([(payload[0] - 1) % 256])
    for aligned_bytes in range(block_size):
        payload_new = payload[:aligned_bytes] + changed_char + payload[aligned_bytes + 1:]
        enc_chunks_new = chunks(encryption_oracle(payload_new), block_size)
        log.debug(print_chunks(chunks(payload_new, block_size)))
        log.debug(print_chunks(enc_chunks_new))
        if enc_chunks_new[position_start] != enc_chunks[position_start]:
            # payload[aligned_bytes] is the first byte of chunk position_start
            prefix_size = position_start * block_size - aligned_bytes
            log.success("Prefix size: {}".format(prefix_size))
            break
    else:
        log.critical_error("Size of prefix not found")

    log.info('Finding suffix')
    payload = random_bytes(1) * (block_size - (prefix_size % block_size))  # align to block_size
    encrypted = encryption_oracle(payload)
    suffix_size = len(encrypted) - len(payload) - prefix_size
    # Ciphertext of a block cipher has to grow within block_size extra bytes;
    # bounding the loop makes the error branch reachable instead of spinning
    # forever on an oracle whose output length never changes.
    for _ in range(block_size):
        payload += random_bytes(1)
        suffix_size -= 1
        if len(encryption_oracle(payload)) > len(encrypted):
            log.success("Suffix size: {}".format(suffix_size))
            break
    else:
        log.critical_error("Size of suffix not found")

    return prefix_size, suffix_size
def parity(parity_oracle, key, min_lower_bound=None, max_upper_bound=None):
    """Decrypt ciphertexts with an oracle that leaks the LSB of the decrypted value

    The ciphertext is repeatedly multiplied by encrypt(2); each oracle answer
    halves the interval [lower_bound, upper_bound] that contains the plaintext.
    parity_oracle function must be implemented

    Args:
        parity_oracle(callable): called with a ciphertext (int), truthy result
                                 is treated as "decrypted value is odd"
        key(RSAKey): contains ciphertexts to decrypt
        min_lower_bound(None/int): while lower_bound is below it, the oracle
                                   is skipped and "odd" is assumed
        max_upper_bound(None/int): while upper_bound is above it, the oracle
                                   is skipped and "even" is assumed

    Returns:
        dict: decrypted ciphertexts
        update key texts
    """
    recovered = {}
    for text_no, text in enumerate(key.texts):
        if 'cipher' not in text or 'plain' in text:
            continue

        cipher = text['cipher']
        log.info("Decrypting {}".format(cipher))
        two_encrypted = key.encrypt(2)

        counter = 0
        numerator = 0
        denominator = 1
        lower_bound = 0
        upper_bound = key.n
        while lower_bound + 1 < upper_bound:
            cipher = (two_encrypted * cipher) % key.n
            denominator *= 2
            numerator *= 2
            counter += 1

            if max_upper_bound is not None and upper_bound > max_upper_bound:
                is_odd = 0
            elif min_lower_bound is not None and lower_bound < min_lower_bound:
                # todo: check this branch
                is_odd = 1
            else:
                is_odd = parity_oracle(cipher)

            if is_odd:  # plaintext > n/(2**counter)
                numerator += 1
            lower_bound = (key.n * numerator) // denominator
            upper_bound = (key.n * (numerator + 1)) // denominator

            log.debug("{} {} [{}, {}]".format(counter, is_odd, int(lower_bound), int(upper_bound)))
            log.debug("{}/{} - {}/{}\n".format(numerator, denominator, numerator + 1, denominator))

        log.success("Decrypted: {}".format(i2h(upper_bound)))
        text['plain'] = upper_bound
        recovered[text_no] = upper_bound
    return recovered
def faulty(key, padding=None):
    """Faulty attack against crt-rsa, Boneh-DeMillo-Lipton

    sp = padding(m)**(d % p-1) % p
    sq' = padding(m)**(d % q-1) % q <--any error during computation
    s' = crt(sp, sq') % n <-- broken signature
    s = crt(sp, sq) % n <-- correct signature
    p = gcd(s'**e - padding(m), n)
    p = gcd(s - s', n)

    Args:
        key(RSAKey): with at least one broken signature (key.texts[no]['cipher']) and corresponding
                     plaintext (key.texts[no]['plain']), or valid and broken signature
        padding(None/function): function used before signing message

    Returns:
        NoneType/RSAKey: None on failure, recovered private key otherwise
    """
    def is_proper_factor(candidate):
        """True if candidate is a non-trivial divisor of n"""
        return candidate != 1 and candidate != key.n

    def private_key_from(factor):
        """Build the private key from a recovered prime, keeping a copy of the texts"""
        log.info("Found p={}".format(factor))
        new_key = RSAKey.construct(key.n, key.e, p=factor, identifier=key.identifier + '-private')
        new_key.texts = key.texts[:]
        return new_key

    log.debug("Check signature-message pairs")
    for pair in key.texts:
        if 'plain' in pair and 'cipher' in pair:
            signature = gmpy2.mpz(pair['cipher'])
            message = pair['plain']
            if padding:
                message = padding(message)
            # gcd(x, n) == gcd(x % n, n), so reduce the power modulo n instead
            # of building the full signature**e (enormous for large e)
            p = gmpy2.gcd(gmpy2.powmod(signature, key.e, key.n) - message, key.n)
            if is_proper_factor(p):
                return private_key_from(p)

    log.debug("Check for valid-invalid signatures")
    signatures = [tmp['cipher'] for tmp in key.texts if 'cipher' in tmp]
    for first, second in itertools.combinations(signatures, 2):
        p = gmpy2.gcd(first - second, key.n)
        if is_proper_factor(p):
            return private_key_from(p)
    return None
def find_block_size(encryption_oracle, constant=True):
    """Determine block size if ecb mode

    With constant prefix/suffix the payload is extended byte by byte and the
    first jump of the ciphertext length is the block size. Otherwise every
    factor of the ciphertext length is tried as a candidate: a repeated-byte
    payload of 5 candidate blocks is sent and a run of identical ciphertext
    chunks confirms the candidate.

    Args:
        encryption_oracle(callable): called with payload bytes, returns ciphertext
        constant(bool): True if prefix and suffix have constant length

    Returns:
        int/None: block size; None when constant is False and no candidate
                  size produced a run of identical chunks
    """
    if constant:
        log.debug("constant == True")
        payload = bytes(b'A')
        size = len(encryption_oracle(payload))
        while True:
            payload += bytes(b'A')
            new_size = len(encryption_oracle(payload))
            if new_size > size:
                log.info("block_size={}".format(new_size - size))
                return new_size - size
    else:
        log.debug("constant == False")
        payload = bytes(b'A')
        max_size = len(encryption_oracle(payload))
        possible_sizes = factors(max_size)
        possible_sizes.add(max_size)

        blocks_to_send = 5
        # number of consecutive identical chunks required to accept a candidate
        run_length = blocks_to_send - 1
        for block_size in sorted(possible_sizes):
            # send payload of length x, so at least x-1 blocks should be identical
            payload = random_bytes(1) * (blocks_to_send * block_size)
            enc_chunks = chunks(encryption_oracle(payload), block_size)
            for x in range(len(enc_chunks) - 1):
                if enc_chunks[x] != enc_chunks[x + 1]:
                    continue
                log.debug("Found two identical blocks at {}: {}".format(x, print_chunks(enc_chunks)))
                # slice instead of indexing x + y: a pair of identical chunks
                # near the end must not raise IndexError, it just doesn't match
                window = enc_chunks[x:x + run_length]
                if len(window) == run_length and all(chunk == enc_chunks[x] for chunk in window):
                    log.info("block_size={}".format(block_size))
                    return block_size
        return None
def fake_ciphertext(new_plaintext, padding_oracle=None, decryption_oracle=None, block_size=16):
    """Forge a ciphertext that decrypts to the chosen plaintext

    Starts from a dummy ciphertext and walks the blocks from the last to the
    first: the current last block is decrypted with the oracle, then the
    preceding block is rewritten so that the xor yields the wanted plaintext.
    Give padding_oracle or decryption_oracle (or both)

    Args:
        new_plaintext(string): with padding
        padding_oracle(function/None)
        decryption_oracle(function/None): maximum one block to decrypt
        block_size(int)

    Returns:
        fake_ciphertext(string): fake ciphertext that will decrypt to new_plaintext
    """
    _check_oracles(padding_oracle=padding_oracle, decryption_oracle=decryption_oracle, block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    log.info("Start fake ciphertext")
    dummy_ciphertext = bytes(b'A' * (len(new_plaintext) + block_size))

    # prepare blocks
    blocks = chunks(dummy_ciphertext, block_size)
    wanted_blocks = chunks(new_plaintext, block_size)
    expected_block_count = len(blocks) - 1
    if len(wanted_blocks) != expected_block_count:
        log.critical_error(
            "Wrong new plaintext length({}), should be {}".format(
                len(new_plaintext), block_size * expected_block_count))

    forged_blocks = list(blocks)
    for block_no in reversed(range(1, len(blocks))):
        # every block: modify block[block_no - 1] to set block[block_no]
        log.info("Block no. {}".format(block_no))
        previous = block_no - 1
        to_decrypt = bytes(b''.join(forged_blocks[:block_no + 1]))
        current_plaintext = decrypt(
            to_decrypt,
            padding_oracle=padding_oracle,
            decryption_oracle=decryption_oracle,
            block_size=block_size,
            amount=1,
            is_correct=False)
        log.info("Set block no. {}".format(block_no))
        forged_blocks[previous] = xor(blocks[previous], current_plaintext, wanted_blocks[previous])

    fake_ciphertext_res = bytes(b''.join(forged_blocks))
    log.success("Fake ciphertext(hex): {}".format(b2h(fake_ciphertext_res)))
    return fake_ciphertext_res
def bleichenbacher_pkcs15(pkcs15_padding_oracle, key, ciphertext=None, incremental_blinding=False, **kwargs):
    """Given oracle that checks if ciphertext decrypts to some valid plaintext with PKCS1.5 padding
    we can decrypt whole ciphertext
    pkcs15_padding_oracle function must be implemented

    http://archiv.infsec.ethz.ch/education/fs08/secsem/bleichenbacher98.pdf
    https://www.dsi.unive.it/~focardi/RSA-padding-oracle/#eq5

    Note that this attack is very slow. Approximate number of main loop iterations == key's bit length

    Args:
        pkcs15_padding_oracle(callable): called as pkcs15_padding_oracle(cipher, **kwargs),
                                         truthy result is treated as "padding is valid"
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext: forwarded together with key to _prepare_ciphertexts, which yields
                    the (text_no, cipher) pairs to attack
        incremental_blinding(bool): if ciphertext is not pkcs confirming we need to blind it.
                                    this may be done using random or incremental values
        **kwargs: passed unchanged to every pkcs15_padding_oracle call

    Returns:
        dict: decrypted ciphertexts (None if the interval set reaches an unexpected state)
        update key texts
    """

    def ceil(a, b):
        # integer division rounded up
        return a // b + (a % b > 0)

    def floor(a, b):
        # integer division rounded down
        return a // b

    def insert_interval(M, lb, ub):
        # Insert (lb, ub) into M, which is kept sorted by lower bound, and
        # merge it with neighbours it touches. M is modified in place.

        # binary search for the first interval whose lower bound is >= lb
        lo, hi = 0, len(M)
        while lo < hi:
            mid = (lo + hi) // 2
            if M[mid][0] < lb:
                lo = mid + 1
            else:
                hi = mid

        # insert it
        M.insert(lo, (lb, ub))

        # lb inside previous interval
        if lo > 0 and M[lo - 1][1] >= lb:
            lb = min(lb, M[lo - 1][0])
            M[lo] = (lb, M[lo][1])
            del M[lo - 1]
            lo -= 1

        # remove covered intervals
        i = lo + 1
        to_remove_first = i
        to_remove_last = lo
        while i < len(M) and M[i][0] <= ub:
            to_remove_last += 1
            i += 1
        if to_remove_last > lo:
            # the merged interval extends to the furthest upper bound seen
            new_ub = max(ub, M[to_remove_last][1])
            M[lo] = (M[lo][0], new_ub)
            del M[to_remove_first:to_remove_last + 1]

    def update_intervals(M, B, s, n):
        # step 3
        # Build the narrowed interval set for multiplier s; the old list M is
        # emptied and a new list is returned.
        M2 = []
        for a, b in M:
            r_min = ceil(a * s - 3 * B + 1, n)
            r_max = floor(b * s - 2 * B, n)
            for r in range(r_min, r_max + 1):
                lb = max(a, ceil(2 * B + r * n, s))
                ub = min(b, floor(3 * B - 1 + r * n, s))
                insert_interval(M2, lb, ub)
        del M[:]
        return M2

    def find_si(si_start, si_max=None):
        # Linear search for the first multiplier >= si_start (and < si_max, if
        # given) accepted by the oracle; returns None when the bounded range
        # is exhausted. cipher_blinded, e and n are read from the enclosing
        # scope at call time.
        si_new = si_start
        while si_max is None or si_new < si_max:
            cipheri = (cipher_blinded * gmpy2.powmod(si_new, e, n)) % n
            if pkcs15_padding_oracle(cipheri, **kwargs):
                return si_new
            si_new += 1
        return None

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key, ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))
        n = key.n
        e = key.e
        B = gmpy2.pow(2, key.size - 16)

        # step 1
        log.debug('Blinding the ciphertext (to make it PKCS1.5 confirming)')
        i = 0
        si = 1
        cipher_blinded = cipher
        while not pkcs15_padding_oracle(cipher_blinded, **kwargs):
            # the paper says to draw it, but seems like incrementation sometimes run faster
            if incremental_blinding:
                si += 1
            else:
                si = random.randint(2, 1 << (key.size - 16))
            cipher_blinded = (cipher * gmpy2.powmod(si, e, n)) % n

        # initial interval set: a single interval [2B, 3B - 1]
        Mi = [(2 * B, 3 * B - 1)]
        # remember the blinding value, it is divided out in step 4
        s0 = si
        log.debug('Found s{}: {}'.format(i, hex(si)))
        i = 1

        plaintext = None
        while plaintext is None:
            log.debug('len(M{}): {}'.format(i - 1, len(Mi)))
            if i == 1:
                # step 2.a
                si = find_si(si_start=ceil(n, (3 * B)))
                log.debug('Found s{}: {}'.format(i, hex(si)))
            elif len(Mi) > 1:
                # step 2.b
                si = find_si(si_start=si + 1)
            elif len(Mi) == 1 and Mi[0][0] != Mi[0][1]:
                # step 2.c
                a, b = Mi[0]
                ri = ceil(2 * (b * si - 2 * B), n)
                si = None
                # try successive ri until some si in [si_min, si_max) is accepted
                while si is None:
                    si_min = ceil(2 * B + ri * n, b)
                    si_max = ceil(3 * B + ri * n, a)
                    si = find_si(si_start=si_min, si_max=si_max)
                    ri += 1
            else:
                # reached when Mi is empty (a single degenerate interval ends the loop in step 4)
                log.error(
                    "Hm, something strange happend. Len(M{}) = {}".format(
                        i, len(Mi)))
                return None

            # step 3
            Mi = update_intervals(Mi, B, si, n)

            # step 4
            if len(Mi) == 1 and Mi[0][0] == Mi[0][1]:
                plaintext = Mi[0][0]
                if s0 != 1:
                    # undo the blinding from step 1
                    plaintext = (plaintext * invmod(s0, n)) % n
                log.success("Interval narrowed to one value")
                log.success("plaintext = {}".format(hex(plaintext)))
            else:
                i += 1
        recovered[text_no] = plaintext
        key.texts[text_no]['plain'] = recovered[text_no]
    return recovered
def bleichenbacher_signature_forgery(key, garbage='suffix', hash_function='sha1'):
    """Bleichenbacher's signature forgery based on bug in verify implementation

    For every text that has a plaintext but no signature yet, a value is
    searched whose e-th power matches the bytes a sloppy verifier checks.

    Args:
        key(RSAKey): with small e and at least one plaintext
        garbage(string): middle: 00 01 ff garbage 00 ASN.1 HASH
                         suffix: 00 01 ff 00 ASN.1 HASH garbage
        hash_function(string): one of md5, sha1, sha256, sha384, sha512

    Returns:
        dict: forged signatures, signatures[no] == signature(key.texts[no]['plain'])
        update key texts
    """
    hash_asn1 = {
        'md5': bytes(b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'),
        'sha1': bytes(b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
        'sha256': bytes(b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'),
        'sha384': bytes(b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'),
        'sha512': bytes(b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40')
    }
    if garbage not in ['suffix', 'middle']:
        log.critical_error("Bad garbage position, must be suffix or middle")
    if hash_function not in hash_asn1:
        log.critical_error("Hash function {} not implemented".format(hash_function))

    if key.e > 3:
        log.debug("May not work, because e > 3")

    def asn1_digest(text_no):
        """ASN.1 header followed by the digest of the given plaintext"""
        # getattr: hack to call hashlib.hash_function
        digest = getattr(hashlib, hash_function)(i2b(key.texts[text_no]['plain'])).digest()
        return hash_asn1[hash_function] + digest

    signatures = {}
    if garbage == 'suffix':
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[text_no]:
                log.info("Forge for plaintext no {} ({})".format(text_no, key.texts[text_no]['plain']))

                plaintext_prefix = bytes(b'\x00\x01\xff\x00') + asn1_digest(text_no)
                plaintext = plaintext_prefix + bytes(b'\x00' * (key.size // 8 - len(plaintext_prefix)))
                plaintext = b2i(plaintext)

                # the integer root does not depend on round_error, compute it once
                root, _ = gmpy2.iroot(plaintext, key.e)
                for round_error in range(-5, 5):
                    signature = int(root + round_error)
                    test_prefix = i2b(gmpy2.powmod(signature, key.e, key.n), size=key.size)[:len(plaintext_prefix)]
                    if test_prefix == plaintext_prefix:
                        log.info("Got signature: {}".format(signature))
                        log.debug("signature**e % n == {}".format(
                            i2h(gmpy2.powmod(signature, key.e, key.n), size=key.size)))
                        key.texts[text_no]['cipher'] = signature
                        signatures[text_no] = signature
                        break
                else:
                    log.error("Something wrong, can't compute correct signature")
        return signatures

    elif garbage == 'middle':
        exponent = int(key.e)
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[text_no]:
                log.info("Forge for plaintext no {} ({})".format(text_no, key.texts[text_no]['plain']))

                plaintext_suffix = bytes(b'\x00') + asn1_digest(text_no)
                target = b2i(plaintext_suffix)
                if target & 1 != 1:
                    log.error("Plaintext suffix is even, can't compute signature")
                    continue

                # compute suffix: fix the bits of the signature from the least
                # significant one up, so that signature**e ends with the target.
                # Setting bit b of an odd number flips bit b of its odd power
                # and leaves the lower bits alone. The power is reduced modulo
                # 2**(b+1), which keeps bit b intact and stays cheap for big e.
                # The public exponent is used here (it used to be fixed to 3).
                signature_suffix = 0b1
                for b in range(len(plaintext_suffix) * 8):
                    bit = 1 << b
                    if pow(signature_suffix, exponent, bit << 1) & bit != target & bit:
                        signature_suffix |= bit
                signature_suffix = i2b(signature_suffix)[-len(plaintext_suffix):]

                # compute prefix: draw random garbage until it contains no zero byte
                while True:
                    plaintext_prefix = bytes(b'\x00\x01\xff') + random_bytes(key.size // 8 - 3)
                    signature_prefix, _ = gmpy2.iroot(b2i(plaintext_prefix), key.e)
                    signature_prefix = i2b(int(signature_prefix), size=key.size)[:-len(signature_suffix)]

                    signature = b2i(signature_prefix + signature_suffix)
                    test_plaintext = i2b(gmpy2.powmod(signature, key.e, key.n), size=key.size)
                    if bytes(b'\x00') not in test_plaintext[2:-len(plaintext_suffix)]:
                        if test_plaintext[:3] == plaintext_prefix[:3] and \
                                test_plaintext[-len(plaintext_suffix):] == plaintext_suffix:
                            log.info("Got signature: {}".format(signature))
                            key.texts[text_no]['cipher'] = signature
                            signatures[text_no] = signature
                            break
                        else:
                            log.error("Something wrong, signature={},"
                                      " signature**{}%{} is {}".format(
                                          signature, key.e, key.n, [(test_plaintext)]))
                            break
        return signatures
def blinding(key, signing_oracle=None, decryption_oracle=None):
    """Perform signature/ciphertext blinding attack

    The value is multiplied by encrypt(blind) before it is sent to the oracle,
    and the oracle's answer is multiplied by blind**-1 mod n afterwards.

    Args:
        key(RSAKey): with at least one plaintext(to sign) or ciphertext(to decrypt)
        signing_oracle(callable): called with the blinded plaintext, a falsy
                                  result is reported as an error
        decryption_oracle(callable): called with the blinded ciphertext, a falsy
                                     result is reported as an error

    Returns:
        dict: {index: signature/plaintext, index2: signature/plaintext}
        update key texts
    """
    if not signing_oracle and not decryption_oracle:
        log.critical_error("Give one of signing_oracle or decryption_oracle")
    if signing_oracle and decryption_oracle:
        log.critical_error("Give only one of signing_oracle or decryption_oracle")

    recovered = {}

    def run_oracle(oracle, oracle_name, known, wanted, start_msg, show, done_msg):
        """Blind every text that has `known` but lacks `wanted`, query the oracle, unblind"""
        for text_no, text in enumerate(key.texts):
            if known not in text or wanted in text:
                continue
            log.info(start_msg.format(text_no, show(text[known])))

            blind = random.randint(2, 100)
            blind_enc = key.encrypt(blind)
            blinded_input = (text[known] * blind_enc) % key.n
            blinded_output = oracle(blinded_input)
            if not blinded_output:
                # report the value that was sent to the oracle (the decryption
                # branch used to print the falsy result here)
                log.critical_error("Error during call to {}({})".format(oracle_name, blinded_input))

            result = (invmod(blind, key.n) * blinded_output) % key.n
            text[wanted] = result
            recovered[text_no] = result
            log.success(done_msg.format(result))

    if signing_oracle:
        log.debug("Have signing_oracle")
        run_oracle(signing_oracle, "signing_oracle", 'plain', 'cipher',
                   "Blinding signature of plaintext no {} ({})", i2h, "Signature: {}")

    if decryption_oracle:
        log.debug("Have decryption_oracle")
        run_oracle(decryption_oracle, "decryption_oracle", 'cipher', 'plain',
                   "Blinding ciphertext no {} ({})", lambda value: value, "Plaintext: {}")

    return recovered
def hastad(keys, ciphertexts=None):
    """Hastad's broadcast attack (small public exponent)
    Given at least e keys with public exponent equals to e and ciphertexts of the same plaintext,
    plaintext can be efficiently recovered

    Args:
        keys(list): RSAKeys, all with same public exponent e, len(keys) >= e,
                    every key with only one ciphertext
        ciphertexts(list/None): if not None, use this ciphertexts (one per key,
                                in the same order as keys)

    Returns:
        NoneType/int: None on failure, recovered plaintext otherwise
        update keys texts
    """
    e = keys[0].e
    if len(keys) < e:
        log.critical_error("Not enough keys, e={}".format(e))

    if ciphertexts is None:
        for key in keys:
            if len(key.texts) != 1:
                log.info("Key have more than one ciphertext, using the first one(key=={})".format(key.identifier))
            if 'cipher' not in key.texts[0]:
                log.critical_error("key {} doesn't have ciphertext".format(key.identifier))

        # prepare ciphertexts and correct_keys lists
        ciphertexts, modules, correct_keys = [], [], []
        for key in keys:
            # get only first ciphertext (if exists)
            cipher = key.texts[0]['cipher']
            if key.n in modules or cipher in ciphertexts:
                # duplicated modulus or ciphertext adds nothing to the crt
                continue
            if key.e != e:
                log.info("Key {} have different e(={})".format(key.identifier, key.e))
                continue
            modules.append(key.n)
            correct_keys.append(key)
            ciphertexts.append(cipher)
    else:
        if len(ciphertexts) != len(keys):
            log.critical_error("len(ciphertexts) != len(keys)")
        modules = [key.n for key in keys]
        correct_keys = keys

    # check if we have enough ciphertexts
    if len(modules) < e:
        log.info("Not enough keys with unique modulus and ciphertext, e={}, len(modules)={}".format(e, len(modules)))
        log.info("Checking for simple roots (small_e_msg)")
        for one_key in correct_keys:
            recovered_plaintexts = small_e_msg(one_key)
            if len(recovered_plaintexts) > 0:
                log.success("Found plaintext: {}".format(recovered_plaintexts[0]))
                return recovered_plaintexts[0]

    if len(modules) > e:
        log.debug("Number of modules/ciphertexts larger than e")
        modules = modules[:e]
        ciphertexts = ciphertexts[:e]

    # actual Hastad
    result = crt(ciphertexts, modules)
    plaintext, correct = gmpy2.iroot(result, e)
    if correct:
        plaintext = int(plaintext)
        log.success("Found plaintext: {}".format(plaintext))
        for one_key in correct_keys:
            one_key.texts[0]['plain'] = plaintext
        return plaintext

    # iroot was not exact, so the crt result is not a perfect e-th power
    log.debug("Plaintext wasn't {}-th root".format(e))
    log.debug("result (from crt) = {}".format(result))
    log.debug("plaintext ({}-th root of result) = {}".format(e, plaintext))
    return None
def manger(oaep_padding_oracle, key, ciphertext=None, **kwargs):
    """Decrypt ciphertexts with an oracle that tells if the OAEP padding is valid

    Each query sends cipher * f**e mod n for a chosen multiplier f; the three
    steps pick f so that every answer narrows the range of the plaintext.
    oaep_padding_oracle function must be implemented

    https://iacr.org/archive/crypto2001/21390229.pdf

    Args:
        oaep_padding_oracle(callable): called as oaep_padding_oracle(cipher, **kwargs)
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext: forwarded together with key to _prepare_ciphertexts
        **kwargs: passed unchanged to every oracle call

    Returns:
        dict: decrypted ciphertexts
        update key texts
    """
    def div_ceil(numerator, denominator):
        # integer division rounded up
        quotient, remainder = divmod(numerator, denominator)
        return quotient + (remainder > 0)

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key, ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))
        n, e = key.n, key.e
        B = pow(2, key.size - 8)

        def oracle_accepts(multiplier):
            # query the oracle with the ciphertext multiplied by multiplier**e
            return oaep_padding_oracle((cipher * gmpy2.powmod(multiplier, e, n)) % n, **kwargs)

        # step 1
        log.debug('step 1')
        f1 = 2
        while oracle_accepts(f1):
            f1 *= 2
        log.debug('step 1 done')
        log.debug('Found f1: {}'.format(hex(f1)))

        # step 2
        log.debug('step 2')
        f1_half = f1 // 2
        f2 = int((n + B) // B) * f1_half
        while not oracle_accepts(f2):
            f2 += f1_half
        log.debug('step 2 done')
        log.debug('Found f2: {}'.format(hex(f2)))

        # step 3
        log.debug('step 3')
        m_min = div_ceil(n, f2)
        m_max = (n + B) // f2
        while m_min < m_max:
            log.debug(hex(m_max - m_min))
            f_tmp = (2 * B) // (m_max - m_min)
            i = (f_tmp * m_min) // n
            f3 = div_ceil(i * n, m_min)
            if oracle_accepts(f3):
                m_max = (i * n + B) // f3
            else:
                m_min = div_ceil(i * n + B, f3)
        log.debug('step 3 done')
        log.debug('m_min = {}'.format(hex(m_min)))
        log.debug('m_max = {}'.format(hex(m_max)))

        plaintext = m_min
        log.success("plaintext = {}".format(hex(plaintext)))
        recovered[text_no] = plaintext
        key.texts[text_no]['plain'] = recovered[text_no]
    return recovered
def guess_key_size(ciphertext, max_key_size=40):
    """Given sentence xored with short key, guess key size
    From: http://trustedsignal.blogspot.com/2015/06/xord-play-normalized-hamming-distance.html

    For every candidate size the ciphertext is cut into full-length blocks and
    the hamming distance of neighbouring blocks is averaged and normalized by
    the size. Afterwards the most frequent gcds among the ten best sizes are
    moved to the front of the list.

    Args:
        ciphertext(string)
        max_key_size(int): sizes 1..max_key_size-1 are tried; a falsy value
                           means len(ciphertext) // 4

    Returns:
        list: sorted list of tuples (key_size, probability), note that most probable key size not necessary
              have the largest probability
    """
    if not max_key_size:
        # floor division: range() below needs an int
        max_key_size = len(ciphertext) // 4

    normalized = {}
    for key_size in range(1, max_key_size):
        # full-length blocks only; slicing works for any sliceable sequence,
        # not only for the str type the regex-based split required
        blocks = [ciphertext[start:start + key_size]
                  for start in range(0, len(ciphertext) - key_size + 1, key_size)]
        if len(blocks) < 2:
            break
        diff = sum(hamming_distance(first, second) for first, second in zip(blocks, blocks[1:]))
        average = diff / float(len(blocks) - 1)
        normalized[key_size] = average / float(key_size)

    result = sorted(normalized.items(), key=operator.itemgetter(1))
    best = result[:10]

    # from link, case two; yep, black magic it is
    gcd_frequencies = defaultdict(int)
    for first, second in itertools.combinations(best, 2):
        gcd_frequencies[gcd(first[0], second[0])] += 1
    gcd_frequencies = sorted(gcd_frequencies.items(), key=operator.itemgetter(1), reverse=True)

    # Look entries up by key size: positions in result change every time an
    # entry is moved to the front, so indexes computed up front go stale.
    best_distances = dict(best)
    for candidate_size, frequency in gcd_frequencies[:5]:
        if candidate_size == 1 or frequency == 0 or candidate_size not in best_distances:
            continue
        entry = (candidate_size, best_distances[candidate_size])
        if entry[1] < max(best_distances.values()):
            result.remove(entry)
            result = [entry] + result

    log.info("Guessed key size: {}".format(result))
    return result
def decrypt(ciphertext, padding_oracle=None, decryption_oracle=None, iv=None, block_size=16, is_correct=True, amount=0,
            known_plaintext=None, async_calls=False):
    """Decrypt ciphertext
    Give padding_oracle or decryption_oracle (or both)

    When decryption_oracle is given it is used directly: every block is
    decrypted by the oracle and xored with the preceding ciphertext block.
    Otherwise the padding oracle attack is run: for every byte position (from
    the end of a block) all 256 values are tried in the preceding block until
    the oracle reports valid padding.

    Args:
        ciphertext(string): to decrypt
        padding_oracle(function/None): called as padding_oracle(payload=..., iv=...),
                                       truthy result is treated as "padding is valid"
        decryption_oracle(function/None): called with a single block
        iv(string): if not specified, first block of ciphertext is treated as iv
        block_size(int)
        is_correct(bool): set if ciphertext will decrypt to something with correct padding
        amount(int): how much blocks decrypt (counting from last), zero (default) means all
        known_plaintext(string): with padding, from end (aligned to end of ciphertext)
        async_calls(bool): make asynchronous calls to oracle (not implemented yet)

    Returns:
        plaintext(string): with padding
    """
    _check_oracles(padding_oracle=padding_oracle, decryption_oracle=decryption_oracle, block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    if len(ciphertext) % block_size != 0:
        log.critical_error("Incorrect ciphertext length: {}".format(
            len(ciphertext)))

    # fast path: a decryption oracle makes the padding oracle unnecessary
    if decryption_oracle:
        if iv:
            ciphertext = iv + ciphertext
        blocks = chunks(ciphertext, block_size)
        plaintext = bytes(b'')
        for position in range(len(blocks) - 1, 0, -1):
            plaintext = xor(decryption_oracle(blocks[position]), blocks[position - 1]) + plaintext
            log.info("Plaintext(hex): {}".format(b2h(plaintext)))
            # stop once the requested number of blocks is decrypted
            if amount != 0 and len(plaintext) == amount * block_size:
                break
        log.success("Decrypted(hex): {}".format(b2h(plaintext)))
        return plaintext

    log.info("Start cbc padding oracle")
    log.debug(print_chunks(chunks(ciphertext, block_size)))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)
    if iv:
        if len(iv) % block_size != 0:
            log.critical_error("Incorrect iv length: {}".format(len(iv)))
        log.info("Set iv")
        blocks.insert(0, iv)

    # from here on amount is the index of the last block that is NOT decrypted
    # (exclusive lower bound of the main loop below)
    if amount != 0:
        amount = len(blocks) - amount - 1
    if amount < 0 or amount >= len(blocks):
        # NOTE(review): the message below lacks the closing parenthesis
        log.critical_error(
            "Incorrect amount of blocks to decrypt: {} (have to be in [0,{}]".
            format(amount, len(blocks) - 1))
    log.info("Will decrypt {} block(s)".format(len(blocks) - 1 - amount))

    # add known plaintext
    plaintext = bytes(b'')
    position_known = 0
    chars_decoded = 0
    if known_plaintext:
        # with known plaintext the original padding value is not searched for
        is_correct = False
        plaintext = known_plaintext
        blocks_decoded = len(plaintext) // block_size
        chars_decoded = len(plaintext) % block_size

        if blocks_decoded == len(blocks) - 1:
            log.debug("Nothing decrypted, known plaintext long enough")
            return plaintext
        if blocks_decoded > len(blocks) - 1:
            log.critical_error(
                "Too long known plaintext ({} blocks)".format(blocks_decoded))

        # fully known blocks are dropped from the end, the partially known
        # block is handled in the first iteration of the main loop
        if blocks_decoded != 0:
            blocks = blocks[:-blocks_decoded]
        position_known = chars_decoded
        log.info("Have known plaintext, skip {} block(s) and {} bytes".format(
            blocks_decoded, chars_decoded))

    # start decryption
    for count_block in range(len(blocks) - 1, amount, -1):
        """ Blocks from the last to the second (all except iv) """
        log.info("Block no. {}".format(count_block))

        # payload sent to the oracle = untouched prefix + modified block + block being decrypted
        payload_prefix = bytes(b''.join(blocks[:count_block - 1]))
        payload_modify = blocks[count_block - 1]
        payload_decrypt = blocks[count_block]

        if chars_decoded != 0:
            # we know some chars, so modify previous block
            payload_modify = payload_modify[:-chars_decoded] +\
                xor(plaintext[:chars_decoded], payload_modify[-chars_decoded:], bytes([chars_decoded + 1]))
            chars_decoded = 0

        # start after the already known bytes (only relevant for the first block)
        position = block_size - 1 - position_known
        position_known = 0
        while position >= 0:
            """ Every position in block, from the end """
            log.debug("Position: {}".format(position))

            found_correct_char = False
            for guess_char in range(256):
                modified = payload_modify[:position] + bytes(
                    [guess_char]) + payload_modify[position + 1:]
                payload = bytes(b''.join(
                    [payload_prefix, modified, payload_decrypt]))

                # the first block of the assembled payload is passed as iv
                iv = payload[:block_size]
                payload = payload[block_size:]
                log.debug(print_chunks(chunks(iv + payload, block_size)))
                correct = padding_oracle(payload=payload, iv=iv)

                if correct:
                    """ oracle returns True """
                    padding = block_size - position  # sent ciphertext decoded to that padding
                    decrypted_char = bytes(
                        [payload_modify[position] ^ guess_char ^ padding])

                    if is_correct:
                        """ If we didn't send original ciphertext, then we have found original padding value. Otherwise keep searching and if won't find any other correct char - padding is \x01 """
                        # the guess equal to the original byte reproduces the
                        # original ciphertext, so it carries no information
                        if guess_char == blocks[-2][-1]:
                            log.debug(
                                "Skip this guess char ({})".format(guess_char))
                            continue

                        dc = int(decrypted_char[0])
                        log.info(
                            "Found padding value for correct ciphertext: {}".
                            format(dc))
                        if dc == 0 or dc > block_size:
                            log.critical_error(
                                "Found bad padding value (given ciphertext may not be correct)"
                            )

                        # the whole padding (dc bytes of value dc) is known at once
                        plaintext = decrypted_char * dc
                        payload_modify = payload_modify[:-dc] + xor(
                            payload_modify[-dc:], decrypted_char, bytes([dc + 1]))
                        # skip the padding bytes (position is decremented once more below)
                        position = position - dc + 1
                        is_correct = False
                    else:
                        """ abcd efgh ijkl o|guess_char|xy || 1234 5678 9tre qwer - ciphertext what ever itma ybex || xyzw rtua lopo k|\x03|\x03\x03 - plaintext abcd efgh ijkl |guess_char|wxy || 1234 5678 9tre qwer - next round ciphertext some thin gels eheh || xyzw rtua lopo guessing|\x04\x04\x04 - next round plaintext """
                        if position == block_size - 1:
                            """ if we decrypt first byte, check if we didn't hit other padding than \x01 """
                            # change the byte just before the guessed one and ask again
                            payload = iv + payload
                            payload = payload[:-block_size - 2] + bytes(
                                b'A') + payload[-block_size - 1:]
                            iv = payload[:block_size]
                            payload = payload[block_size:]
                            correct = padding_oracle(payload=payload, iv=iv)
                            if not correct:
                                log.debug("Hit false positive, guess char({})".
                                          format(guess_char))
                                continue

                        # prepare the tail of the block for the next padding value
                        payload_modify = payload_modify[:position] + xor(
                            bytes([guess_char]) + payload_modify[position + 1:],
                            bytes([padding]), bytes([padding + 1]))
                        plaintext = decrypted_char + plaintext

                    found_correct_char = True
                    log.debug(
                        "Guessed char(\\x{:02x}), decrypted char(\\x{:02x})".
                        format(guess_char, decrypted_char[0]))
                    log.debug("Plaintext: {}".format(plaintext))
                    log.info("Plaintext(hex): {}".format(b2h(plaintext)))
                    break
            position -= 1
            if found_correct_char is False:
                if is_correct:
                    # only the original byte gave valid padding, so the padding is \x01
                    padding = 0x01
                    payload_modify = payload_modify[:position + 1] + xor(
                        payload_modify[position + 1:], bytes([padding]),
                        bytes([padding + 1]))
                    plaintext = bytes(b"\x01")
                    is_correct = False
                else:
                    log.critical_error(
                        "Can't find correct padding (oracle function return False 256 times)"
                    )
    log.success("Decrypted(hex): {}".format(b2h(plaintext)))
    return plaintext