def small_e_msg(key, ciphertexts=None, max_times=100):
    """Recover plaintexts whose e-th power exceeds the modulus only a little.

    With a small public exponent and a short message, ``m**e`` is only a few
    multiples of ``n`` above the ciphertext, so ``m`` is the exact integer
    e-th root of ``ciphertext + k*n`` for some small ``k``.

    Args:
        key(RSAKey): with small e, at least one ciphertext
        ciphertexts(list): optional ciphertexts, resolved via get_mutable_texts
        max_times(int): how many multiples k of n to try (k = 0 .. max_times-1)

    Returns:
        list: recovered plaintexts (at most one per ciphertext)
    """
    found = []
    for target in get_mutable_texts(key, ciphertexts):
        log.debug("Find msg for ciphertext {}".format(target))
        for multiple in range(max_times):
            offset = multiple * key.n
            root, exact = gmpy2.iroot(target + offset, key.e)
            # Accept the root only if it is exact AND re-encrypts to target.
            if not exact or gmpy2.powmod(root, key.e, key.n) != target:
                continue
            root = int(root)
            log.success("Found msg: {}, times=={}".format(
                i2h(root), offset // key.n))
            found.append(root)
            break
    return found
def common_primes(keys):
    """Find common primes between key moduli and build the private keys.

    For every pair of keys whose moduli share a non-trivial gcd p, each
    modulus n that p actually splits (p != n) is factored as p * (n // p),
    and the private key is rebuilt from it.

    Args:
        keys(list): RSAKeys

    Returns:
        list: RSAKeys (private) for which factorization of n was found; each
            key appears at most once, even if it shares primes with several
            other keys
    """
    priv_keys = []
    cracked = set()  # indices (into keys) of keys already turned private
    for (idx_a, key_a), (idx_b, key_b) in itertools.combinations(
            enumerate(keys), 2):
        prime = gmpy2.gcd(key_a.n, key_b.n)
        if prime == 1:
            continue
        log.success("Found common prime in: {}, {}".format(
            key_a.identifier, key_b.identifier))
        for idx, pub_key in ((idx_a, key_a), (idx_b, key_b)):
            if idx in cracked:
                log.debug("Key {} already in priv_keys".format(
                    pub_key.identifier))
                continue
            if prime == pub_key.n:
                # gcd equals the whole modulus (e.g. identical moduli):
                # this pair gives no factorization of this key.
                log.debug("Modulus of key {} not split by this pair".format(
                    pub_key.identifier))
                continue
            other_prime = pub_key.n // prime
            d = int(invmod(pub_key.e, (prime - 1) * (other_prime - 1)))
            new_key = RSAKey.construct(int(pub_key.n), int(pub_key.e), int(d),
                                       identifier=pub_key.identifier + '-private')
            new_key.texts = pub_key.texts[:]
            priv_keys.append(new_key)
            cracked.add(idx)
    return priv_keys
def _prepare_ciphertexts(key=None, ciphertext=None, ciphertexts=None): """Helper function used in various rsa functions. Update key (if provided) and yield ciphertexts to crack Yields: tuple: (text_no, ciphertext_to_crack) """ if ciphertexts is None: ciphertexts = [] if ciphertext is not None: ciphertexts.append(ciphertext) if key is None: for cipher in ciphertexts: yield None, cipher else: for cipher in ciphertexts: matching_texts = [(text_no, text) for text_no, text in enumerate(key.texts) if 'cipher' in text and text['cipher'] == cipher] if len(matching_texts) == 0: key.add_ciphertext(cipher) yield len(key.texts) - 1, key.texts[len(key.texts) - 1]['cipher'] else: assert len(matching_texts) == 1 text_no, text = matching_texts if 'plain' in text: log.success( "Plaintext for ciphertext {cipher} already known: {plain}!" .format(**text)) else: yield text_no, text[text_no]['cipher']
def fake_ciphertext(new_plaintext, padding_oracle=None, decryption_oracle=None, block_size=16):
    """Make a CBC ciphertext that will decrypt to the given plaintext.

    Give padding_oracle or decryption_oracle (or both). Starting from the last
    block, each block's decryption is learned through the oracle, and the
    previous block is then chosen so that this block decrypts to the wanted
    plaintext block.

    Args:
        new_plaintext(bytes): with padding, length a multiple of block_size
        padding_oracle(function/None)
        decryption_oracle(function/None): maximum one block to decrypt
        block_size(int)

    Returns:
        fake_ciphertext(bytes): iv + blocks that will decrypt to new_plaintext
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    log.info("Start fake ciphertext")
    ciphertext = bytes(b'A' * (len(new_plaintext) + block_size))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)
    new_pl_blocks = chunks(new_plaintext, block_size)
    # Comparing block counts is not enough: chunks() rounds up, so an
    # unaligned plaintext would pass. Compare byte lengths instead.
    expected_length = block_size * (len(blocks) - 1)
    if len(new_plaintext) != expected_length:
        log.critical_error(
            "Wrong new plaintext length({}), should be {}".format(
                len(new_plaintext), expected_length))
    new_ct_blocks = list(blocks)

    # add known plaintext
    for count_block in range(len(blocks) - 1, 0, -1):
        # Every block, modify block[count_block-1] to set block[count_block]
        log.info("Block no. {}".format(count_block))

        ciphertext_to_decrypt = bytes(b''.join(new_ct_blocks[:count_block + 1]))
        original_plaintext = decrypt(ciphertext_to_decrypt,
                                     padding_oracle=padding_oracle,
                                     decryption_oracle=decryption_oracle,
                                     block_size=block_size,
                                     amount=1,
                                     is_correct=False)
        log.info("Set block no. {}".format(count_block))
        # new_ct_blocks[count_block - 1] is still blocks[count_block - 1] here,
        # so original_plaintext == D(block) ^ blocks[count_block - 1]; xoring
        # both away leaves D(block) ^ wanted, which decrypts to wanted.
        new_ct_blocks[count_block - 1] = xor(blocks[count_block - 1],
                                             original_plaintext,
                                             new_pl_blocks[count_block - 1])

    fake_ciphertext_res = bytes(b''.join(new_ct_blocks))
    log.success("Fake ciphertext(hex): {}".format(b2h(fake_ciphertext_res)))
    return fake_ciphertext_res
def parity(parity_oracle, key, min_lower_bound=None, max_upper_bound=None):
    """Decrypt RSA ciphertexts given an oracle returning the LSB of the plaintext.

    The ciphertext is repeatedly multiplied by encrypt(2); after k steps the
    oracle reveals the parity of floor(2**k * m / n), i.e. in which half of
    the current interval the plaintext m lies (binary search over [0, n)).

    Args:
        parity_oracle(callable): parity_oracle(ciphertext) -> truthy if the
            decrypted value is odd
        key(RSAKey): contains ciphertexts to decrypt
        min_lower_bound(None/int): known inclusive lower bound on the plaintext;
            bits it already determines are not queried
        max_upper_bound(None/int): known upper bound on the plaintext;
            bits it already determines are not queried

    Returns:
        dict: decrypted ciphertexts {text_no: plaintext}
        update key texts
    """
    recovered = {}
    for text_no in range(len(key.texts)):
        if 'cipher' in key.texts[text_no] and 'plain' not in key.texts[text_no]:
            cipher = key.texts[text_no]['cipher']
            log.info("Decrypting {}".format(cipher))
            two_encrypted = key.encrypt(2)

            counter = lower_bound = numerator = 0
            upper_bound = key.n
            denominator = 1
            while lower_bound + 1 < upper_bound:
                cipher = (two_encrypted * cipher) % key.n
                denominator *= 2
                numerator *= 2
                counter += 1

                # m is in the upper half iff m >= n*(numerator+1)/denominator.
                # Compare the known bounds with this exact split point (the
                # bounds alone say nothing about which half m is in).
                split = key.n * (numerator + 1)
                if max_upper_bound is not None and max_upper_bound * denominator < split:
                    is_odd = 0
                elif min_lower_bound is not None and min_lower_bound * denominator >= split:
                    is_odd = 1
                else:
                    is_odd = parity_oracle(cipher)

                if is_odd:  # plaintext > n/(2**counter)
                    numerator += 1
                lower_bound = (key.n * numerator) // denominator
                upper_bound = (key.n * (numerator + 1)) // denominator

                log.debug("{} {} [{}, {}]".format(counter, is_odd, int(lower_bound), int(upper_bound)))
                log.debug("{}/{} - {}/{}\n".format(numerator, denominator, numerator + 1, denominator))

            # m lies strictly inside (n*num/den, n*(num+1)/den) and the floors
            # differ by one, so the only integer there is upper_bound.
            log.success("Decrypted: {}".format(i2h(upper_bound)))
            key.texts[text_no]['plain'] = upper_bound
            recovered[text_no] = upper_bound
    return recovered
def iv_as_key(ciphertext, plaintext, padding_oracle=None, decryption_oracle=None, block_size=16):
    """Recover a CBC key that was also used as the IV.

    If the first ciphertext block repeats at index j, then
    D(C0) = IV ^ P0 = C[j-1] ^ P[j], so IV = P0 ^ C[j-1] ^ P[j] is computed
    with no oracle calls. Otherwise D(C0) is obtained through the
    decryption/padding oracle and xored with P0.

    Args:
        ciphertext(bytes): first block must be AES.encrypt(iv xor plaintext[0])
        plaintext(bytes): with padding
        padding_oracle(function/None)
        decryption_oracle(function/None)
        block_size(int)

    Returns:
        bytes: key (== iv)
    """
    ct_blocks = chunks(ciphertext, block_size)
    pt_blocks = chunks(plaintext, block_size)
    recovered_key = None

    try:
        repeat_at = ct_blocks.index(ct_blocks[0], 1)
        log.debug("Position of the same block as the first is {}".format(
            repeat_at))
        recovered_key = xor(pt_blocks[0], ct_blocks[repeat_at - 1], pt_blocks[repeat_at])
    except ValueError:
        log.debug(
            "first ciphertext block is not repeated, will use decryption/padding oracle"
        )

    if recovered_key is None:
        dummy_iv = bytes(b'A' * block_size)
        decrypted_first = decrypt(ct_blocks[0],
                                  padding_oracle=padding_oracle,
                                  decryption_oracle=decryption_oracle,
                                  iv=dummy_iv,
                                  block_size=block_size,
                                  is_correct=False,
                                  amount=1)
        # Remove our dummy IV to get D(C0) = key ^ P0, then remove P0.
        recovered_key = xor(xor(decrypted_first, dummy_iv), pt_blocks[0])

    log.success("Key(hex): {}".format(b2h(recovered_key)))
    return recovered_key
def compute_params(s, m=None, a=None, b=None):
    """Recover the parameters of an LCG prng: next_state = a*state + b mod m

    Only the parameters passed as None are computed; the given ones are
    used as-is. The seed is not computed: it is s[0] itself.

    Args:
        s(list): subsequent outputs from LCG oracle starting with seed.
            Recovering m needs at least 4 values (with few values the gcd may
            still be a multiple of the true modulus). Recovering a needs
            s[0..2], or s[0..3] when s[1] - s[0] is not invertible mod m.
        m(int/None)
        a(int/None)
        b(int/None)

    Returns:
        tuple: a, b, m
    """
    if m is None:
        if len(s) < 4:
            log.critical_error(
                "Need at least 4 outputs to recover m, got {}".format(len(s)))
        # t[n+1] = a*t[n] (mod m), hence t[n+2]*t[n] - t[n+1]**2 = 0 (mod m):
        # every u is a multiple of m, and their gcd is m.
        t = [nxt - cur for cur, nxt in zip(s, s[1:])]
        u = [abs(t[n + 2] * t[n] - t[n + 1]**2) for n in range(len(t) - 2)]
        m = gcd(*u)
        log.success("m = {}".format(m))

    if a is None:
        # s[2] - s[1] = a*(s[1] - s[0]) and s[3] - s[1] = a*(s[2] - s[0]) (mod m)
        if gcd(s[1] - s[0], m) == 1:
            a = (s[2] - s[1]) * invmod(s[1] - s[0], m)
        elif gcd(s[2] - s[0], m) == 1:
            a = (s[3] - s[1]) * invmod(s[2] - s[0], m)
        else:
            log.critical_error("a not found")
        a = a % m
        log.success("a = {}".format(a))

    if b is None:
        b = (s[1] - s[0] * a) % m
        log.success("b = {}".format(b))

    return a, b, m
def test_manger():
    """Manger attack round-trip for several key sizes.

    A random plaintext (shifted right 8 bits so it fits under the oracle's
    bound) is encrypted by the external oracle script, recovered with manger()
    and compared. Key texts are cleared even if the assertion fails, so later
    tests start from a clean state.
    """
    import sys

    keys = [key_64, key_256, key_1024, key_2048]
    for key in keys:
        manger_padding_oracle_calls = [0]
        plaintext = randint(2, key.n) >> 8
        # Run the oracle script with the interpreter running this test rather
        # than whatever "python" happens to be first on PATH.
        ciphertext = h2b(
            subprocess.check_output([
                sys.executable, rsa_oracles_path, "encrypt", key.identifier,
                i2h(plaintext)
            ]).strip().decode())
        try:
            msgs_recovered = manger(
                oaep_padding_oracle,
                key.publickey(),
                ciphertext,
                oracle_key=key,
                manger_padding_oracle_calls=manger_padding_oracle_calls)
            log.success('For keysize {}: oaep_padding_oracle_calls = {}'.format(
                key.size, manger_padding_oracle_calls[0]))
            assert msgs_recovered[0] == plaintext
        finally:
            key.clear_texts()
def find_prefix_suffix_size(encryption_oracle, block_size=16):
    """Determine prefix and suffix sizes of an ECB encryption oracle.

    The oracle is expected to compute ecb(prefix || our_input || suffix);
    prefix and suffix sizes must be constant. May rarely fail if the random
    filler byte happens to match bytes of the prefix/suffix.

    Args:
        encryption_oracle(callable): bytes -> ciphertext bytes
        block_size(int)

    Returns:
        tuple(int,int): prefix_size, suffix_size
    """
    blocks_to_send = 5
    payload = random_bytes(1) * (blocks_to_send * block_size)
    enc_chunks = chunks(encryption_oracle(payload), block_size)
    log.debug("Encryption of length {}".format(blocks_to_send * block_size))
    log.debug(print_chunks(enc_chunks))

    # Look for (blocks_to_send - 1) identical consecutive ciphertext blocks:
    # that is where our block-aligned payload starts. The range stops early
    # enough that position_start + blocks_to_send - 2 stays in bounds.
    for position_start in range(len(enc_chunks) - (blocks_to_send - 2)):
        if enc_chunks[position_start] == enc_chunks[position_start + 1]:
            for y in range(2, blocks_to_send - 1):
                if enc_chunks[position_start] != enc_chunks[position_start + y]:
                    break
            else:
                log.success("Controlled payload start at chunk {}".format(position_start))
                break
    else:
        log.critical_error("Position of controlled chunks not found")

    log.info('Finding prefix')
    # Change one payload byte at a time; the first index whose change alters
    # block position_start is the first payload byte of that block.
    changed_char = bytes([(payload[0] - 1) % 256])
    for aligned_bytes in range(block_size):
        payload_new = payload[:aligned_bytes] + changed_char + payload[aligned_bytes + 1:]
        enc_chunks_new = chunks(encryption_oracle(payload_new), block_size)
        log.debug(print_chunks(chunks(payload_new, block_size)))
        log.debug(print_chunks(enc_chunks_new))
        if enc_chunks_new[position_start] != enc_chunks[position_start]:
            prefix_size = position_start * block_size - aligned_bytes
            log.success("Prefix size: {}".format(prefix_size))
            break
    else:
        log.critical_error("Size of prefix not found")

    log.info('Finding suffix')
    payload = random_bytes(1) * (block_size - (prefix_size % block_size))  # align to block_size
    encrypted = encryption_oracle(payload)
    # Start from "suffix + padding" and remove one byte per appended byte
    # until the ciphertext grows by a block (padding fully consumed).
    suffix_size = len(encrypted) - len(payload) - prefix_size
    # With block padding the length must grow within block_size extra bytes;
    # bounding the loop turns a non-growing oracle into an error, not a hang.
    for _ in range(block_size):
        payload += random_bytes(1)
        suffix_size -= 1
        if len(encryption_oracle(payload)) > len(encrypted):
            log.success("Suffix size: {}".format(suffix_size))
            break
    else:
        log.critical_error("Size of suffix not found")

    return prefix_size, suffix_size
def decrypt(encryption_oracle, constant=True, block_size=16, prefix_size=None, secret_size=None, alphabet=None):
    """Given encryption oracle which produce ecb(prefix || our_input || secret), find secret

    The secret is recovered byte by byte from its end: the payload length
    shifts one unknown secret byte into a block whose other bytes are known,
    and that block is matched against encryptions of guessed blocks.

    Args:
        encryption_oracle(callable)
        constant(bool): True if prefix have constant length (secret must have constant length)
        block_size(int/None)
        prefix_size(int/None): None means detect it (0 is a valid size)
        secret_size(int/None): None means detect it (0 is a valid size)
        alphabet(bytes/str): likely plaintext bytes, tried first; all other
            byte values are tried afterwards

    Returns:
        secret(bytes); None when constant is False (not implemented)
    """
    log.debug("Start decrypt function")
    if not alphabet:
        alphabet = bytes(string.printable.encode())
    if isinstance(alphabet, str):
        alphabet = alphabet.encode()
    # The alphabet only orders the search; the remaining byte values follow,
    # so a secret byte outside the alphabet is still found.
    candidates = list(dict.fromkeys(list(alphabet) + list(range(256))))

    if not block_size:
        block_size = find_block_size(encryption_oracle, constant)

    if constant:
        log.debug("constant == True")
        if prefix_size is None or secret_size is None:
            prefix_size, secret_size = find_prefix_suffix_size(encryption_oracle, block_size)

        # Start decrypt
        secret = bytes(b'')
        # fill the prefix up to a block boundary
        aligned_bytes = random_bytes(1) * (block_size - (prefix_size % block_size))
        if len(aligned_bytes) == block_size:
            aligned_bytes = bytes(b'')

        # make (filler + secret) a whole number of blocks
        aligned_bytes_suffix = random_bytes(1) * (block_size - (secret_size % block_size))
        if len(aligned_bytes_suffix) == block_size:
            aligned_bytes_suffix = bytes(b'')

        block_to_find_position = -1
        controlled_block_position = (prefix_size + len(aligned_bytes)) // block_size

        while len(secret) < secret_size:
            # each time a full block of secret is known, the target block
            # moves one block further from the end
            if (len(secret) + 1) % block_size == 0:
                block_to_find_position -= 1
            payload = aligned_bytes + aligned_bytes_suffix + random_bytes(1) + secret
            enc_chunks = chunks(encryption_oracle(payload), block_size)
            block_to_find = enc_chunks[block_to_find_position]

            log.debug("To guess at position {}:".format(block_to_find_position))
            log.debug("Plain: " + print_chunks(chunks(bytes(b'P' * prefix_size) + payload + bytes(b'S' * secret_size), block_size)))
            log.debug("Encry: " + print_chunks(enc_chunks) + "\n")

            for guessed_value in candidates:
                guessed_char = bytes([guessed_value])
                payload = aligned_bytes + add_padding(guessed_char + secret, block_size)
                enc_chunks = chunks(encryption_oracle(payload), block_size)

                log.debug("Plain: " + print_chunks(chunks(bytes(b'P' * prefix_size) + payload + bytes(b'S' * secret_size), block_size)))
                log.debug("Encry: " + print_chunks(enc_chunks) + "\n")

                if block_to_find == enc_chunks[controlled_block_position]:
                    secret = guessed_char + secret
                    log.debug("Found char, secret={}".format(repr(secret)))
                    break
            else:
                log.critical_error("Char not found, try change alphabet. Secret so far: {}".format(repr(secret)))
        log.success("Secret(hex): {}".format(b2h(secret)))
        return secret
    else:
        # Variable-length prefix is not implemented; falls through returning None.
        log.debug("constant == False")
def bleichenbacher_pkcs15(pkcs15_padding_oracle, key, ciphertext=None, incremental_blinding=False, **kwargs):
    """Given oracle that checks if ciphertext decrypts to some valid plaintext with PKCS1.5 padding
    we can decrypt whole ciphertext
    pkcs15_padding_oracle function must be implemented

    http://archiv.infsec.ethz.ch/education/fs08/secsem/bleichenbacher98.pdf
    https://www.dsi.unive.it/~focardi/RSA-padding-oracle/#eq5

    Note that this attack is very slow. Approximate number of main loop iterations == key's bit length

    Args:
        pkcs15_padding_oracle(callable): called as
            pkcs15_padding_oracle(ciphertext, **kwargs)
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext(optional): extra ciphertext to decrypt
        incremental_blinding(bool): if ciphertext is not pkcs confirming we need to blind it.
                                    this may be done using random or incremental values
        **kwargs: forwarded to every oracle call

    Returns:
        dict: decrypted ciphertexts; None if the interval set becomes inconsistent
        update key texts
    """
    def ceil(a, b):
        return a // b + (a % b > 0)

    def floor(a, b):
        return a // b

    def insert_interval(M, lb, ub):
        """Insert [lb, ub] into the sorted, disjoint interval list M, merging overlaps."""
        # bisect on lower bounds
        lo, hi = 0, len(M)
        while lo < hi:
            mid = (lo + hi) // 2
            if M[mid][0] < lb:
                lo = mid + 1
            else:
                hi = mid
        # insert it
        M.insert(lo, (lb, ub))

        # lb inside previous interval: merge with it, keeping the larger upper
        # bound (the previous interval may extend beyond the new one)
        if lo > 0 and M[lo - 1][1] >= lb:
            lb = min(lb, M[lo - 1][0])
            ub = max(ub, M[lo - 1][1])
            M[lo] = (lb, ub)
            del M[lo - 1]
            lo -= 1

        # remove covered intervals
        i = lo + 1
        to_remove_first = i
        to_remove_last = lo
        while i < len(M) and M[i][0] <= ub:
            to_remove_last += 1
            i += 1
        if to_remove_last > lo:
            new_ub = max(ub, M[to_remove_last][1])
            M[lo] = (M[lo][0], new_ub)
            del M[to_remove_first:to_remove_last + 1]

    def update_intervals(M, B, s, n):
        # step 3: narrow every interval using the new s
        M2 = []
        for a, b in M:
            r_min = ceil(a * s - 3 * B + 1, n)
            r_max = floor(b * s - 2 * B, n)
            for r in range(r_min, r_max + 1):
                lb = max(a, ceil(2 * B + r * n, s))
                ub = min(b, floor(3 * B - 1 + r * n, s))
                insert_interval(M2, lb, ub)
        del M[:]
        return M2

    def find_si(si_start, si_max=None):
        # smallest si >= si_start (and < si_max if given) such that
        # cipher_blinded * si**e is PKCS1.5 conforming; None if none found
        si_new = si_start
        while si_max is None or si_new < si_max:
            cipheri = (cipher_blinded * gmpy2.powmod(si_new, e, n)) % n
            if pkcs15_padding_oracle(cipheri, **kwargs):
                return si_new
            si_new += 1
        return None

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key, ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))
        n = key.n
        e = key.e
        B = gmpy2.pow(2, key.size - 16)

        # step 1
        log.debug('Blinding the ciphertext (to make it PKCS1.5 confirming)')
        i = 0
        si = 1
        cipher_blinded = cipher
        while not pkcs15_padding_oracle(cipher_blinded, **kwargs):
            # the paper says to draw it, but seems like incrementation sometimes run faster
            if incremental_blinding:
                si += 1
            else:
                si = random.randint(2, 1 << (key.size - 16))
            cipher_blinded = (cipher * gmpy2.powmod(si, e, n)) % n

        Mi = [(2 * B, 3 * B - 1)]
        s0 = si
        log.debug('Found s{}: {}'.format(i, hex(si)))
        i = 1

        plaintext = None
        while plaintext is None:
            log.debug('len(M{}): {}'.format(i - 1, len(Mi)))
            if i == 1:
                # step 2.a
                si = find_si(si_start=ceil(n, (3 * B)))
                log.debug('Found s{}: {}'.format(i, hex(si)))

            elif len(Mi) > 1:
                # step 2.b
                si = find_si(si_start=si + 1)

            elif len(Mi) == 1 and Mi[0][0] != Mi[0][1]:
                # step 2.c
                a, b = Mi[0]
                ri = ceil(2 * (b * si - 2 * B), n)
                si = None
                while si is None:
                    si_min = ceil(2 * B + ri * n, b)
                    si_max = ceil(3 * B + ri * n, a)
                    si = find_si(si_start=si_min, si_max=si_max)
                    ri += 1

            else:
                log.error(
                    "Hm, something strange happend. Len(M{}) = {}".format(
                        i, len(Mi)))
                return None

            # step 3
            Mi = update_intervals(Mi, B, si, n)

            # step 4
            if len(Mi) == 1 and Mi[0][0] == Mi[0][1]:
                plaintext = Mi[0][0]
                if s0 != 1:
                    # undo the step-1 blinding
                    plaintext = (plaintext * invmod(s0, n)) % n
                log.success("Interval narrowed to one value")
                log.success("plaintext = {}".format(hex(plaintext)))
            else:
                i += 1

        recovered[text_no] = plaintext
        key.texts[text_no]['plain'] = recovered[text_no]
    return recovered
def blinding(key, signing_oracle=None, decryption_oracle=None):
    """Perform signature/ciphertext blinding attack

    With signing_oracle: every text with 'plain' but no 'cipher' is
    multiplied by key.encrypt(r) for a random r, signed by the oracle, and the
    result is multiplied by r^-1 mod n; this is stored as the text's 'cipher'.
    With decryption_oracle: every text with 'cipher' but no 'plain' is blinded
    the same way, decrypted by the oracle, and unblinded into 'plain'.

    Args:
        key(RSAKey): with at least one plaintext(to sign) or ciphertext(to decrypt)
        signing_oracle(callable)
        decryption_oracle(callable)

    Returns:
        dict: {index: signature/plaintext, index2: signature/plaintext}
        update key texts
    """
    if not signing_oracle and not decryption_oracle:
        log.critical_error("Give one of signing_oracle or decryption_oracle")
    if signing_oracle and decryption_oracle:
        log.critical_error(
            "Give only one of signing_oracle or decryption_oracle")

    recovered = {}
    if signing_oracle:
        log.debug("Have signing_oracle")
        for text_no in range(len(key.texts)):
            text = key.texts[text_no]
            if 'plain' not in text or 'cipher' in text:
                continue
            log.info("Blinding signature of plaintext no {} ({})".format(
                text_no, i2h(text['plain'])))

            blind = random.randint(2, 100)
            blind_enc = key.encrypt(blind)
            blinded_plaintext = (text['plain'] * blind_enc) % key.n
            blinded_signature = signing_oracle(blinded_plaintext)
            if not blinded_signature:
                log.critical_error(
                    "Error during call to signing_oracle({})".format(
                        blinded_plaintext))
            # the oracle's result carries a factor of blind; remove it
            signature = (invmod(blind, key.n) * blinded_signature) % key.n
            text['cipher'] = signature
            recovered[text_no] = signature
            log.success("Signature: {}".format(signature))

    if decryption_oracle:
        log.debug("Have decryption_oracle")
        for text_no in range(len(key.texts)):
            text = key.texts[text_no]
            if 'cipher' not in text or 'plain' in text:
                continue
            log.info("Blinding ciphertext no {} ({})".format(
                text_no, text['cipher']))

            blind = random.randint(2, 100)
            blind_enc = key.encrypt(blind)
            blinded_ciphertext = (text['cipher'] * blind_enc) % key.n
            blinded_plaintext = decryption_oracle(blinded_ciphertext)
            if not blinded_plaintext:
                # report the input we sent, not the (falsy) oracle result
                log.critical_error(
                    "Error during call to decryption_oracle({})".format(
                        blinded_ciphertext))
            plaintext = (invmod(blind, key.n) * blinded_plaintext) % key.n
            text['plain'] = plaintext
            recovered[text_no] = plaintext
            log.success("Plaintext: {}".format(plaintext))

    return recovered
def hastad(keys, ciphertexts=None):
    """Hastad's broadcast attack (small public exponent)

    Given at least e keys with public exponent equals to e and ciphertexts of
    the same plaintext, the ciphertexts are combined with CRT and the integer
    e-th root of the result is the plaintext. If there are fewer than e usable
    keys, small_e_msg is tried on each key first.

    Args:
        keys(list): RSAKeys, all with same public exponent e, len(keys) >= e,
            every key with only one ciphertext (texts[0] is used)
        ciphertexts(list/None): if not None, use this ciphertexts (one per key,
            in keys order)

    Returns:
        NoneType/int: None on failure, recovered plaintext otherwise
        update keys texts (texts[0]['plain'])
    """
    e = keys[0].e
    if len(keys) < e:
        log.critical_error("Not enough keys, e={}".format(e))

    if ciphertexts is None:
        for key in keys:
            if len(key.texts) > 1:
                log.info(
                    "Key have more than one ciphertext, using the first one(key=={})"
                    .format(key.identifier))
            if not key.texts or 'cipher' not in key.texts[0]:
                log.critical_error("key {} doesn't have ciphertext".format(
                    key.identifier))

        # prepare ciphertexts and correct_keys lists
        ciphertexts, modules, correct_keys = [], [], []
        for key in keys:
            # get only first ciphertext; skip repeated moduli/ciphertexts
            if key.n not in modules and key.texts[0]['cipher'] not in ciphertexts:
                if key.e == e:
                    modules.append(key.n)
                    correct_keys.append(key)
                    ciphertexts.append(key.texts[0]['cipher'])
                else:
                    log.info("Key {} have different e(={})".format(
                        key.identifier, key.e))
    else:
        if len(ciphertexts) != len(keys):
            log.critical_error("len(ciphertexts) != len(keys)")
        modules = [key.n for key in keys]
        correct_keys = keys

    # check if we have enough ciphertexts
    if len(modules) < e:
        log.info(
            "Not enough keys with unique modulus and ciphertext, e={}, len(modules)={}"
            .format(e, len(modules)))
        log.info("Checking for simple roots (small_e_msg)")
        for one_key in correct_keys:
            recovered_plaintexts = small_e_msg(one_key)
            if len(recovered_plaintexts) > 0:
                log.success("Found plaintext: {}".format(
                    recovered_plaintexts[0]))
                return recovered_plaintexts[0]

    if len(modules) > e:
        log.debug("Number of modules/ciphertexts larger than e")
        modules = modules[:e]
        ciphertexts = ciphertexts[:e]

    # actual Hastad
    result = crt(ciphertexts, modules)
    plaintext, correct = gmpy2.iroot(result, e)
    if correct:
        plaintext = int(plaintext)
        log.success("Found plaintext: {}".format(plaintext))
        for one_key in correct_keys:
            if one_key.texts:
                one_key.texts[0]['plain'] = plaintext
        return plaintext
    else:
        log.debug("Plaintext wasn't {}-th root".format(e))
        log.debug("result (from crt) = {}".format(result))
        log.debug("plaintext ({}-th root of result) = {}".format(e, plaintext))
        return None
def manger(oaep_padding_oracle, key, ciphertext=None, **kwargs):
    """Decrypt RSA-OAEP ciphertexts with Manger's chosen-ciphertext attack.

    Given oracle that checks if ciphertext decrypts to some valid plaintext
    with OAEP padding we can decrypt whole ciphertext.
    oaep_padding_oracle function must be implemented; it is presumably True
    when the decrypted value is below B = 2**(key.size - 8) (leading zero
    byte), per the paper.
    https://iacr.org/archive/crypto2001/21390229.pdf

    Args:
        oaep_padding_oracle(callable): called as oaep_padding_oracle(ciphertext, **kwargs)
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext(optional): extra ciphertext to decrypt
        **kwargs: forwarded to every oracle call

    Returns:
        dict: decrypted ciphertexts {text_no: plaintext}
        update key texts
    """
    def ceil(a, b):
        quotient, remainder = divmod(a, b)
        return quotient + (remainder > 0)

    def floor(a, b):
        return a // b

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key, ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))
        n, e = key.n, key.e
        B = pow(2, key.size - 8)

        def query(factor):
            # ask the oracle about cipher * factor**e, i.e. plaintext * factor
            return oaep_padding_oracle((cipher * gmpy2.powmod(factor, e, n)) % n, **kwargs)

        # step 1: double f1 while the oracle keeps answering True
        log.debug('step 1')
        f1 = 2
        while query(f1):
            f1 *= 2
        log.debug('step 1 done')
        log.debug('Found f1: {}'.format(hex(f1)))

        # step 2: step f2 by f1/2 until the oracle answers True
        log.debug('step 2')
        f1_half = f1 // 2
        f2 = int(floor(n + B, B)) * f1_half
        while not query(f2):
            f2 += f1_half
        log.debug('step 2 done')
        log.debug('Found f2: {}'.format(hex(f2)))

        # step 3: shrink [m_min, m_max] until it holds a single value
        log.debug('step 3')
        m_min, m_max = ceil(n, f2), floor(n + B, f2)
        while m_min < m_max:
            log.debug(hex(m_max - m_min))
            f_tmp = floor(2 * B, m_max - m_min)
            boundary = floor(f_tmp * m_min, n)
            f3 = ceil(boundary * n, m_min)
            if query(f3):
                m_max = floor(boundary * n + B, f3)
            else:
                m_min = ceil(boundary * n + B, f3)
        log.debug('step 3 done')
        log.debug('m_min = {}'.format(hex(m_min)))
        log.debug('m_max = {}'.format(hex(m_max)))

        plaintext = m_min
        log.success("plaintext = {}".format(hex(plaintext)))
        recovered[text_no] = plaintext
        key.texts[text_no]['plain'] = plaintext

    return recovered
def decrypt(ciphertext, padding_oracle=None, decryption_oracle=None, iv=None, block_size=16, is_correct=True, amount=0,
            known_plaintext=None, async_calls=False):
    """Decrypt CBC ciphertext

    Give padding_oracle or decryption_oracle (or both). If decryption_oracle
    is given it is used (one call per block, then xor with the previous
    block); otherwise the byte-by-byte CBC padding oracle attack is run.

    Args:
        ciphertext(bytes): to decrypt, length a multiple of block_size
        padding_oracle(function/None): called as padding_oracle(payload=..., iv=...)
        decryption_oracle(function/None): decrypts a single block
        iv(bytes/None): if not specified, first block of ciphertext is treated as iv
        block_size(int): multiple of 8
        is_correct(bool): set if ciphertext will decrypt to something with correct padding
        amount(int): how much blocks decrypt (counting from last), zero (default) means all
        known_plaintext(bytes/None): with padding, from end (aligned to end of ciphertext)
        async_calls(bool): make asynchronous calls to oracle (not implemented yet, ignored)

    Returns:
        plaintext(bytes): with padding
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    if len(ciphertext) % block_size != 0:
        log.critical_error("Incorrect ciphertext length: {}".format(len(ciphertext)))

    # Block decryption oracle: plain CBC decryption, one oracle call per block.
    if decryption_oracle:
        if iv:
            ciphertext = iv + ciphertext
        blocks = chunks(ciphertext, block_size)
        plaintext = bytes(b'')
        for position in range(len(blocks) - 1, 0, -1):
            plaintext = xor(decryption_oracle(blocks[position]), blocks[position - 1]) + plaintext
            log.info("Plaintext(hex): {}".format(b2h(plaintext)))
            if amount != 0 and len(plaintext) == amount * block_size:
                break
        log.success("Decrypted(hex): {}".format(b2h(plaintext)))
        return plaintext

    log.info("Start cbc padding oracle")
    log.debug(print_chunks(chunks(ciphertext, block_size)))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)
    if iv:
        if len(iv) % block_size != 0:
            log.critical_error("Incorrect iv length: {}".format(len(iv)))
        log.info("Set iv")
        blocks.insert(0, iv)

    # From here on `amount` is the index of the first block NOT decrypted.
    if amount != 0:
        amount = len(blocks) - amount - 1
        if amount < 0 or amount >= len(blocks):
            log.critical_error("Incorrect amount of blocks to decrypt: {} (have to be in [0,{}]".format(amount, len(blocks) - 1))
    log.info("Will decrypt {} block(s)".format(len(blocks) - 1 - amount))

    # add known plaintext
    plaintext = bytes(b'')
    position_known = 0
    chars_decoded = 0
    if known_plaintext:
        is_correct = False
        plaintext = known_plaintext
        blocks_decoded = len(plaintext) // block_size
        chars_decoded = len(plaintext) % block_size

        if blocks_decoded == len(blocks) - 1:
            log.debug("Nothing decrypted, known plaintext long enough")
            return plaintext

        if blocks_decoded > len(blocks) - 1:
            log.critical_error("Too long known plaintext ({} blocks)".format(blocks_decoded))

        if blocks_decoded != 0:
            blocks = blocks[:-blocks_decoded]
        position_known = chars_decoded
        log.info("Have known plaintext, skip {} block(s) and {} bytes".format(blocks_decoded, chars_decoded))

    # start decryption
    for count_block in range(len(blocks) - 1, amount, -1):
        # Blocks from the last to the second (all except iv)
        log.info("Block no. {}".format(count_block))

        payload_prefix = bytes(b''.join(blocks[:count_block - 1]))
        payload_modify = blocks[count_block - 1]
        payload_decrypt = blocks[count_block]

        if chars_decoded != 0:
            # we know some chars, so modify previous block so that the known
            # tail decrypts to padding value chars_decoded + 1
            payload_modify = payload_modify[:-chars_decoded] + \
                xor(plaintext[:chars_decoded], payload_modify[-chars_decoded:], bytes([chars_decoded + 1]))
            chars_decoded = 0

        position = block_size - 1 - position_known
        position_known = 0
        while position >= 0:
            # Every position in block, from the end
            log.debug("Position: {}".format(position))

            found_correct_char = False
            for guess_char in range(256):
                modified = payload_modify[:position] + bytes([guess_char]) + payload_modify[position + 1:]
                payload = bytes(b''.join([payload_prefix, modified, payload_decrypt]))

                iv = payload[:block_size]
                payload = payload[block_size:]
                log.debug(print_chunks(chunks(iv + payload, block_size)))

                correct = padding_oracle(payload=payload, iv=iv)
                if correct:
                    # oracle returns True
                    padding = block_size - position  # sent ciphertext decoded to that padding
                    decrypted_char = bytes([payload_modify[position] ^ guess_char ^ padding])

                    if is_correct:
                        # The original ciphertext is validly padded, so the
                        # unmodified byte always passes; skip it. Another byte
                        # that passes reveals the original padding value. If
                        # none passes, the padding is \x01 (handled below).
                        if guess_char == blocks[-2][-1]:
                            log.debug("Skip this guess char ({})".format(guess_char))
                            continue

                        dc = int(decrypted_char[0])
                        log.info("Found padding value for correct ciphertext: {}".format(dc))
                        if dc == 0 or dc > block_size:
                            log.critical_error("Found bad padding value (given ciphertext may not be correct)")

                        plaintext = decrypted_char * dc
                        payload_modify = payload_modify[:-dc] + xor(payload_modify[-dc:], decrypted_char,
                                                                    bytes([dc + 1]))
                        position = position - dc + 1
                        is_correct = False
                    else:
                        # After finding a byte, re-key the tail so it decrypts
                        # to padding + 1 and move one position left, e.g.
                        # ...k|\x03|\x03\x03 becomes ...?|\x04\x04\x04.
                        if position == block_size - 1:
                            # For the last byte, make sure we hit \x01 and not
                            # a longer padding (e.g. \x02\x02): change byte
                            # block_size-2 of the modified block and ask again.
                            # Flip a bit rather than overwrite it, because an
                            # overwrite is a no-op when the byte already has
                            # that value.
                            payload = iv + payload
                            flipped = bytes([payload[-block_size - 2] ^ 0x01])
                            payload = payload[:-block_size - 2] + flipped + payload[-block_size - 1:]
                            iv = payload[:block_size]
                            payload = payload[block_size:]
                            correct = padding_oracle(payload=payload, iv=iv)
                            if not correct:
                                log.debug("Hit false positive, guess char({})".format(guess_char))
                                continue

                        payload_modify = payload_modify[:position] + xor(
                            bytes([guess_char]) + payload_modify[position + 1:], bytes([padding]),
                            bytes([padding + 1]))
                        plaintext = decrypted_char + plaintext

                    found_correct_char = True
                    log.debug("Guessed char(\\x{:02x}), decrypted char(\\x{:02x})".format(guess_char,
                                                                                       decrypted_char[0]))
                    log.debug("Plaintext: {}".format(plaintext))
                    log.info("Plaintext(hex): {}".format(b2h(plaintext)))
                    break
            position -= 1
            if not found_correct_char:
                if is_correct:
                    # only the original byte passed: original padding is \x01
                    padding = 0x01
                    payload_modify = payload_modify[:position + 1] + xor(payload_modify[position + 1:],
                                                                         bytes([padding]), bytes([padding + 1]))
                    plaintext = bytes(b"\x01")
                    is_correct = False
                else:
                    log.critical_error("Can't find correct padding (oracle function return False 256 times)")

    log.success("Decrypted(hex): {}".format(b2h(plaintext)))
    return plaintext