Exemplo n.º 1
0
def small_e_msg(key, ciphertexts=None, max_times=100):
    """Recover plaintexts whose e-th power exceeds the modulus only a little.

    With a small public exponent and a small plaintext m, m**e wraps around n
    only a few times, so m is the exact integer e-th root of c + k*n for some
    small k. Every k in [0, max_times) is tried for each ciphertext.

    Args:
        key(RSAKey): key with small e, at least one ciphertext
        ciphertexts(list): forwarded to get_mutable_texts to select ciphertexts
        max_times(int): how many times plaintext**e may exceed the modulus

    Returns:
        list: recovered plaintexts (ints); ciphertexts for which no root was
        found contribute nothing
    """
    found = []
    for cipher in get_mutable_texts(key, ciphertexts):
        log.debug("Find msg for ciphertext {}".format(cipher))
        for wraps in range(max_times):
            offset = wraps * key.n
            root, exact = gmpy2.iroot(cipher + offset, key.e)
            # The root must be exact AND re-encrypt to the ciphertext.
            if not exact or gmpy2.powmod(root, key.e, key.n) != cipher:
                continue
            plain = int(root)
            log.success("Found msg: {}, times=={}".format(
                i2h(plain), offset // key.n))
            found.append(plain)
            break
    return found
Exemplo n.º 2
0
def common_primes(keys):
    """Find RSA moduli that share a prime factor and build their private keys.

    For every pair of keys gcd(n1, n2) is computed. A non-trivial gcd is a
    common prime that factors both moduli, which gives the private exponent.

    Args:
        keys(list): RSAKeys

    Returns:
        list: new private RSAKeys (identifier suffixed with '-private'), at
        most one per input key whose modulus was factored; each gets a copy
        of the original key's texts
    """
    priv_keys = []
    # Original keys that were already converted. priv_keys holds freshly
    # constructed private keys, so testing an original key against it may
    # never match, and a key sharing primes with several others would be
    # converted repeatedly.
    factored = []
    for pair in itertools.combinations(keys, 2):
        prime = gmpy2.gcd(pair[0].n, pair[1].n)
        if prime == 1:
            continue
        log.success("Found common prime in: {}, {}".format(
            pair[0].identifier, pair[1].identifier))
        for key in pair:
            if key in factored:
                log.debug("Key {} already in priv_keys".format(
                    key.identifier))
                continue
            if prime == key.n:
                # gcd is the whole modulus (e.g. duplicated n): there is no
                # proper factor, so phi cannot be derived from it.
                log.debug("Common factor equals modulus of key {}, "
                          "skipping".format(key.identifier))
                continue
            # phi(n) = (p - 1) * (q - 1), assuming n = p * q with p, q prime
            phi = (prime - 1) * (key.n // prime - 1)
            d = int(invmod(key.e, phi))
            new_key = RSAKey.construct(
                int(key.n),
                int(key.e),
                d,
                identifier=key.identifier + '-private')
            new_key.texts = key.texts[:]
            priv_keys.append(new_key)
            factored.append(key)
    return priv_keys
Exemplo n.º 3
0
def _prepare_ciphertexts(key=None, ciphertext=None, ciphertexts=None):
    """Helper function used in various rsa functions.
    Update key (if provided) and yield ciphertexts to crack.

    The ciphertexts to process are `ciphertexts` followed by `ciphertext`.
    - Without a key, every ciphertext is yielded as-is.
    - With a key, a ciphertext not yet present in key.texts is added through
      key.add_ciphertext.
    - A ciphertext whose plaintext is already known is only reported, not
      yielded.

    Args:
        key(RSAKey/None): key whose texts are looked up and updated
        ciphertext(int/None): single ciphertext to process
        ciphertexts(list/None): ciphertexts to process; the caller's list is
            not modified

    Yields:
        tuple: (text_no, ciphertext_to_crack); text_no is None without a key

    Raises:
        ValueError: if a ciphertext appears more than once in key.texts
    """
    # Copy so that appending `ciphertext` never mutates the caller's list.
    ciphertexts = list(ciphertexts) if ciphertexts is not None else []
    if ciphertext is not None:
        ciphertexts.append(ciphertext)

    if key is None:
        for cipher in ciphertexts:
            yield None, cipher
        return

    for cipher in ciphertexts:
        matching_texts = [(text_no, text)
                          for text_no, text in enumerate(key.texts)
                          if 'cipher' in text and text['cipher'] == cipher]
        if not matching_texts:
            key.add_ciphertext(cipher)
            text_no = len(key.texts) - 1
            yield text_no, key.texts[text_no]['cipher']
            continue
        if len(matching_texts) > 1:
            raise ValueError(
                "Ciphertext {} present {} times in key texts".format(
                    cipher, len(matching_texts)))
        text_no, text = matching_texts[0]
        if 'plain' in text:
            log.success(
                "Plaintext for ciphertext {cipher} already known: {plain}!"
                .format(**text))
        else:
            yield text_no, text['cipher']
Exemplo n.º 4
0
def fake_ciphertext(new_plaintext,
                    padding_oracle=None,
                    decryption_oracle=None,
                    block_size=16):
    """Forge a CBC ciphertext that decrypts to a chosen plaintext.

    Works backwards from the last block. The block being processed is left as
    is, decrypt(..., amount=1) recovers what it currently decrypts to, and the
    block before it is rewritten so the XOR gives the wanted plaintext block.
    Give padding_oracle or decryption_oracle (or both).

    Args:
        new_plaintext(string): desired plaintext, with padding
        padding_oracle(function/None)
        decryption_oracle(function/None): maximum one block to decrypt
        block_size(int)

    Returns:
        fake_ciphertext(string): forged ciphertext, one block longer than
            new_plaintext (the leading block acts as IV)
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    log.info("Start fake ciphertext")
    # Arbitrary starting ciphertext with one extra leading block.
    initial_blocks = chunks(bytes(b'A' * (len(new_plaintext) + block_size)),
                            block_size)
    wanted_blocks = chunks(new_plaintext, block_size)
    if len(wanted_blocks) != len(initial_blocks) - 1:
        log.critical_error(
            "Wrong new plaintext length({}), should be {}".format(
                len(new_plaintext), block_size * (len(initial_blocks) - 1)))
    forged = list(initial_blocks)

    last = len(initial_blocks) - 1
    for target in reversed(range(1, last + 1)):
        # forged[target] is final by now; tune forged[target - 1] so that
        # block `target` decrypts to wanted_blocks[target - 1].
        log.info("Block no. {}".format(target))
        head = bytes(b''.join(forged[:target + 1]))
        decrypted_now = decrypt(head,
                                padding_oracle=padding_oracle,
                                decryption_oracle=decryption_oracle,
                                block_size=block_size,
                                amount=1,
                                is_correct=False)
        log.info("Set block no. {}".format(target))
        forged[target - 1] = xor(initial_blocks[target - 1],
                                 decrypted_now,
                                 wanted_blocks[target - 1])

    result = bytes(b''.join(forged))
    log.success("Fake ciphertext(hex): {}".format(b2h(result)))
    return result
Exemplo n.º 5
0
def parity(parity_oracle, key, min_lower_bound=None, max_upper_bound=None):
    """Decrypt RSA ciphertexts with an oracle returning the plaintext's LSB.

    Each step multiplies the ciphertext by key.encrypt(2), doubling the hidden
    plaintext modulo n. The parity answer then halves the fraction interval
    [numerator/denominator, (numerator+1)/denominator] of n that contains the
    plaintext. The loop stops once the bounds are adjacent integers.

    Args:
        parity_oracle(callable): returns LSB of decrypted ciphertext
        key(RSAKey): contains ciphertexts to decrypt
        min_lower_bound(None/int): while the lower bound is below this, the
            bit is forced to 1 instead of querying the oracle
        max_upper_bound(None/int): while the upper bound is above this, the
            bit is forced to 0 (checked before min_lower_bound)

    Returns:
        dict: {text_no: plaintext}; matching key.texts entries get 'plain' set
    """
    recovered = {}
    for text_no, text in enumerate(key.texts):
        if 'cipher' not in text or 'plain' in text:
            continue
        cipher = text['cipher']
        log.info("Decrypting {}".format(cipher))
        doubler = key.encrypt(2)

        counter = numerator = 0
        denominator = 1
        lower_bound, upper_bound = 0, key.n
        while upper_bound > lower_bound + 1:
            cipher = (doubler * cipher) % key.n
            numerator, denominator = numerator * 2, denominator * 2
            counter += 1

            if max_upper_bound is not None and upper_bound > max_upper_bound:
                is_odd = 0
            elif min_lower_bound is not None and lower_bound < min_lower_bound:
                # TODO(review): this heuristic was flagged "check below" upstream
                is_odd = 1
            else:
                is_odd = parity_oracle(cipher)

            if is_odd:  # plaintext > n/(2**counter)
                numerator += 1
            lower_bound = (key.n * numerator) // denominator
            upper_bound = (key.n * (numerator + 1)) // denominator

            log.debug("{} {} [{}, {}]".format(counter, is_odd,
                                              int(lower_bound),
                                              int(upper_bound)))
            log.debug("{}/{}  -  {}/{}\n".format(numerator, denominator,
                                                 numerator + 1,
                                                 denominator))
        log.success("Decrypted: {}".format(i2h(upper_bound)))
        text['plain'] = upper_bound
        recovered[text_no] = upper_bound
    return recovered
Exemplo n.º 6
0
def iv_as_key(ciphertext,
              plaintext,
              padding_oracle=None,
              decryption_oracle=None,
              block_size=16):
    """Recover the key of CBC encryption that used the key as IV.

    Two strategies are used:
      * Known plaintext: if the first ciphertext block repeats at block r, then
        iv ^ P[0] == C[r-1] ^ P[r], so key == P[0] ^ C[r-1] ^ P[r].
      * Oracle: otherwise the first block is decrypted under a dummy IV; after
        removing that IV the result is iv ^ P[0], so key == result ^ P[0].

    Args:
        ciphertext(string): first block must be AES.encrypt(iv xor plaintext[0])
        plaintext(string): with padding
        padding_oracle(function/None)
        decryption_oracle(function/None)
        block_size(int)

    Returns:
        string: key (== iv)
    """
    ct_blocks = chunks(ciphertext, block_size)
    pt_blocks = chunks(plaintext, block_size)

    found_key = None
    try:
        repeat_at = ct_blocks.index(ct_blocks[0], 1)
        log.debug("Position of the same block as the first is {}".format(
            repeat_at))
        found_key = xor(pt_blocks[0], ct_blocks[repeat_at - 1],
                        pt_blocks[repeat_at])
    except ValueError:
        log.debug(
            "first ciphertext block is not repeated, will use decryption/padding oracle"
        )

    if found_key is None:
        dummy_iv = bytes(b'A' * block_size)
        first_block_plain = decrypt(ct_blocks[0],
                                    padding_oracle=padding_oracle,
                                    decryption_oracle=decryption_oracle,
                                    iv=dummy_iv,
                                    block_size=block_size,
                                    is_correct=False,
                                    amount=1)
        # Strip the dummy IV to get the raw block decryption (== iv ^ P[0]).
        raw_decryption = xor(first_block_plain, dummy_iv)
        found_key = xor(raw_decryption, pt_blocks[0])
    log.success("Key(hex): {}".format(b2h(found_key)))
    return found_key
Exemplo n.º 7
0
    def compute_params(s, m=None, a=None, b=None):
        """Recover parameters of an LCG prng from consecutive outputs.

        The generator is next_state = a*state + b mod m. Parameters passed in
        are used as-is; missing ones are derived from the outputs.

        Args:
            s(list): subsequent outputs from LCG oracle starting with seed
            m(int/None): modulus, derived from s when None
            a(int/None): multiplier, derived from s and m when None
            b(int/None): increment, derived from s, a and m when None

        Returns:
            a, b, m(int)
        """
        if m is None:
            # With t_k = s_{k+1} - s_k we get t_{k+1} = a*t_k (mod m), so every
            # t_{k+2}*t_k - t_{k+1}**2 is a multiple of m; their gcd is m (or,
            # with too few outputs, a multiple of it).
            diffs = [nxt - cur for cur, nxt in zip(s, s[1:])]
            multiples = [abs(diffs[k + 2] * diffs[k] - diffs[k + 1] ** 2)
                         for k in range(len(diffs) - 2)]
            m = gcd(*multiples)
            log.success("m = {}".format(m))

        if a is None:
            # s[j+1] - s[i+1] = a * (s[j] - s[i]) (mod m): divide by an
            # invertible difference.
            if gcd(s[1] - s[0], m) == 1:
                a = (s[2] - s[1]) * invmod(s[1] - s[0], m)
            elif gcd(s[2] - s[0], m) == 1:
                a = (s[3] - s[1]) * invmod(s[2] - s[0], m)
            else:
                log.critical_error("a not found")
            a %= m
            log.success("a = {}".format(a))

        if b is None:
            b = (s[1] - s[0] * a) % m
            log.success("b = {}".format(b))

        return a, b, m
Exemplo n.º 8
0
def test_manger():
    """Manger's attack must recover a random plaintext for every test key.

    Each plaintext is encrypted by the external rsa_oracles script and then
    recovered through oaep_padding_oracle. The key's texts are cleared
    afterwards even if the check fails.
    """
    import sys

    keys = [key_64, key_256, key_1024, key_2048]
    for key in keys:
        # Single-element list passed through to the oracle; presumably the
        # oracle increments it (it is read back for the log line below).
        manger_padding_oracle_calls = [0]
        # The shift keeps the plaintext below 2**(key.size - 8), the bound
        # manger() uses for B.
        plaintext = randint(2, key.n) >> 8
        try:
            # Use the running interpreter, not whatever "python" is on PATH.
            ciphertext = h2b(
                subprocess.check_output([
                    sys.executable, rsa_oracles_path, "encrypt",
                    key.identifier,
                    i2h(plaintext)
                ]).strip().decode())

            msgs_recovered = manger(
                oaep_padding_oracle,
                key.publickey(),
                ciphertext,
                oracle_key=key,
                manger_padding_oracle_calls=manger_padding_oracle_calls)
            log.success(
                'For keysize {}: oaep_padding_oracle_calls = {}'.format(
                    key.size, manger_padding_oracle_calls[0]))
            assert msgs_recovered[0] == plaintext
        finally:
            # The key objects are shared across tests; always reset them.
            key.clear_texts()
Exemplo n.º 9
0
def find_prefix_suffix_size(encryption_oracle, block_size=16):
    """Determine prefix and suffix sizes if ecb mode, sizes must be constant
    Rarely may fail (if random data that are send unhappily matches prefix/suffix)

    The suffix search assumes block padding: the ciphertext must grow after
    at most block_size extra payload bytes.

    Args:
        encryption_oracle(callable)
        block_size(int)

    Returns:
        tuple(int,int): prefix_size, suffix_size
    """
    blocks_to_send = 5
    # The payload may start mid-block, so only blocks_to_send - 1 ciphertext
    # blocks are guaranteed to be fully controlled (and therefore equal).
    equal_needed = blocks_to_send - 1
    payload = random_bytes(1) * (blocks_to_send * block_size)
    enc_chunks = chunks(encryption_oracle(payload), block_size)
    log.debug("Encryption of length {}".format(blocks_to_send * block_size))
    log.debug(print_chunks(enc_chunks))

    # Bound the scan so position_start + equal_needed - 1 stays in range.
    for position_start in range(len(enc_chunks) - equal_needed + 1):
        first = enc_chunks[position_start]
        if all(enc_chunks[position_start + y] == first
               for y in range(1, equal_needed)):
            log.success("Controlled payload start at chunk {}".format(position_start))
            break
    else:
        log.critical_error("Position of controlled chunks not found")

    log.info('Finding prefix')
    # Change one payload byte at a time; the first change that alters the
    # first controlled block shows where that block starts inside the payload.
    changed_char = bytes([(payload[0] - 1) % 256])
    for aligned_bytes in range(block_size):
        payload_new = payload[:aligned_bytes] + changed_char + payload[aligned_bytes + 1:]
        enc_chunks_new = chunks(encryption_oracle(payload_new), block_size)
        log.debug(print_chunks(chunks(payload_new, block_size)))
        log.debug(print_chunks(enc_chunks_new))
        if enc_chunks_new[position_start] != enc_chunks[position_start]:
            prefix_size = position_start * block_size - aligned_bytes
            log.success("Prefix size: {}".format(prefix_size))
            break
    else:
        log.critical_error("Size of prefix not found")

    log.info('Finding suffix')
    payload = random_bytes(1) * (block_size - (prefix_size % block_size))  # align to block_size
    encrypted = encryption_oracle(payload)
    # Initially suffix + padding; each added byte that does not grow the
    # ciphertext consumes one padding byte.
    suffix_size = len(encrypted) - len(payload) - prefix_size
    # Bounded: with block padding the length grows within block_size bytes,
    # and a non-padding oracle can no longer make this loop run forever.
    for _ in range(block_size):
        payload += random_bytes(1)
        suffix_size -= 1
        if len(encryption_oracle(payload)) > len(encrypted):
            log.success("Suffix size: {}".format(suffix_size))
            break
    else:
        log.critical_error("Size of suffix not found")

    return prefix_size, suffix_size
Exemplo n.º 10
0
def decrypt(encryption_oracle, constant=True, block_size=16, prefix_size=None, secret_size=None,
            alphabet=None):
    """Given encryption oracle which produce ecb(prefix || our_input || secret), find secret

    The secret is recovered byte by byte from its end. Only the constant-size
    case is implemented: with constant=False the function logs and returns None.

    Args:
        encryption_oracle(callable)
        constant(bool): True if prefix have constant length (secret must have constant length)
        block_size(int/None): found with find_block_size when falsy
        prefix_size(int/None): both sizes are detected when either one is None
        secret_size(int/None)
        alphabet(string): likely plaintext bytes, tried first at every position
            (default string.printable); all other byte values are still tried

    Returns:
        secret(string)
    """
    log.debug("Start decrypt function")
    if not alphabet:
        alphabet = bytes(string.printable.encode())

    if not block_size:
        block_size = find_block_size(encryption_oracle, constant)

    if constant:
        log.debug("constant == True")
        # 0 is a valid size, so only detect sizes that were not given.
        if prefix_size is None or secret_size is None:
            prefix_size, secret_size = find_prefix_suffix_size(encryption_oracle, block_size)

        # Guess order: alphabet bytes first, then every remaining byte value.
        # ECB maps distinct blocks to distinct ciphertext blocks, so at most
        # one guess matches and the order only changes the number of calls.
        preferred = [ord(c) if isinstance(c, (str, bytes)) else c for c in alphabet]
        candidates = []
        seen = set()
        for value in preferred + list(range(256)):
            if 0 <= value < 256 and value not in seen:
                seen.add(value)
                candidates.append(value)

        # Start decrypt
        secret = bytes(b'')
        # Filler so that our input starts on a block boundary.
        aligned_bytes = random_bytes(1) * (block_size - (prefix_size % block_size))
        if len(aligned_bytes) == block_size:
            aligned_bytes = bytes(b'')

        aligned_bytes_suffix = random_bytes(1) * (block_size - (secret_size % block_size))
        if len(aligned_bytes_suffix) == block_size:
            aligned_bytes_suffix = bytes(b'')

        block_to_find_position = -1
        controlled_block_position = (prefix_size + len(aligned_bytes)) // block_size

        while len(secret) < secret_size:
            if (len(secret) + 1) % block_size == 0:
                block_to_find_position -= 1
            payload = aligned_bytes + aligned_bytes_suffix + random_bytes(1) + secret
            enc_chunks = chunks(encryption_oracle(payload), block_size)
            block_to_find = enc_chunks[block_to_find_position]

            log.debug("To guess at position {}:".format(block_to_find_position))
            log.debug("Plain: " + print_chunks(chunks(bytes(b'P'*prefix_size) + payload + bytes(b'S'*secret_size), block_size)))
            log.debug("Encry: " + print_chunks(enc_chunks)+"\n")

            for value in candidates:
                guessed_char = bytes([value])
                payload = aligned_bytes + add_padding(guessed_char + secret, block_size)
                enc_chunks = chunks(encryption_oracle(payload), block_size)

                log.debug("Plain: " + print_chunks(chunks(bytes(b'P'*prefix_size) + payload + bytes(b'S'*secret_size), block_size)))
                log.debug("Encry: " + print_chunks(enc_chunks)+"\n")
                if block_to_find == enc_chunks[controlled_block_position]:
                    # Bytes are found from the end of the secret backwards.
                    secret = guessed_char + secret
                    log.debug("Found char, secret={}".format(repr(secret)))
                    break
            else:
                log.critical_error("Char not found, try change alphabet. Secret so far: {}".format(repr(secret)))
        log.success("Secret(hex): {}".format(b2h(secret)))
        return secret
    else:
        # Non-constant prefix is not implemented; falls through to return None.
        log.debug("constant == False")
Exemplo n.º 11
0
def bleichenbacher_pkcs15(pkcs15_padding_oracle,
                          key,
                          ciphertext=None,
                          incremental_blinding=False,
                          **kwargs):
    """Given oracle that checks if ciphertext decrypts to some valid plaintext with PKCS1.5 padding
    we can decrypt whole ciphertext
    pkcs15_padding_oracle function must be implemented

    http://archiv.infsec.ethz.ch/education/fs08/secsem/bleichenbacher98.pdf
    https://www.dsi.unive.it/~focardi/RSA-padding-oracle/#eq5

    Note that this attack is very slow. Approximate number of main loop iterations == key's bit length

    Args:
        pkcs15_padding_oracle(callable): called as oracle(ciphertext, **kwargs),
            truthy when the ciphertext is PKCS1.5 conforming
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext(int/None): additional ciphertext to decrypt
        incremental_blinding(bool): if ciphertext is not pkcs confirming we need to blind it.
                                    this may be done using random or incremental values
        **kwargs: forwarded to every pkcs15_padding_oracle call

    Returns:
        dict: decrypted ciphertexts (None if the interval set reaches an
        unexpected state)
        update key texts
    """
    def ceil(a, b):
        return a // b + (a % b > 0)

    def floor(a, b):
        return a // b

    def insert_interval(M, lb, ub):
        """Add [lb, ub] to the sorted, disjoint list M, merging overlaps."""
        if lb > ub:
            return
        # Binary search: index of the first interval starting at or after lb.
        lo, hi = 0, len(M)
        while lo < hi:
            mid = (lo + hi) // 2
            if M[mid][0] < lb:
                lo = mid + 1
            else:
                hi = mid

        # Absorb the previous interval if it reaches lb. Keep the larger upper
        # bound so that no part of the previous interval is lost.
        if lo > 0 and M[lo - 1][1] >= lb:
            lo -= 1
            lb = M[lo][0]
            ub = max(ub, M[lo][1])
            del M[lo]

        # Absorb every following interval that starts inside [lb, ub].
        end = lo
        while end < len(M) and M[end][0] <= ub:
            ub = max(ub, M[end][1])
            end += 1
        M[lo:end] = [(lb, ub)]

    def update_intervals(M, B, s, n):
        # step 3: narrow each interval with the newly found s
        M2 = []
        for a, b in M:
            r_min = ceil(a * s - 3 * B + 1, n)
            r_max = floor(b * s - 2 * B, n)
            for r in range(r_min, r_max + 1):
                lb = max(a, ceil(2 * B + r * n, s))
                ub = min(b, floor(3 * B - 1 + r * n, s))
                # An empty intersection must not become an interval.
                if lb <= ub:
                    insert_interval(M2, lb, ub)
        del M[:]
        return M2

    def find_si(si_start, si_max=None):
        # Smallest s in [si_start, si_max) for which cipher_blinded * s**e
        # is PKCS conforming. cipher_blinded, e and n are taken from the loop.
        si_new = si_start
        while si_max is None or si_new < si_max:
            cipheri = (cipher_blinded * gmpy2.powmod(si_new, e, n)) % n
            if pkcs15_padding_oracle(cipheri, **kwargs):
                return si_new
            si_new += 1
        return None

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key,
                                                ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))

        n = key.n
        e = key.e
        B = gmpy2.pow(2, key.size - 16)

        # step 1
        log.debug('Blinding the ciphertext (to make it PKCS1.5 confirming)')
        i = 0
        si = 1
        cipher_blinded = cipher
        while not pkcs15_padding_oracle(cipher_blinded, **kwargs):
            # the paper says to draw it, but seems like incrementation sometimes run faster
            if incremental_blinding:
                si += 1
            else:
                si = random.randint(2, 1 << (key.size - 16))
            cipher_blinded = (cipher * gmpy2.powmod(si, e, n)) % n
        Mi = [(2 * B, 3 * B - 1)]
        s0 = si
        log.debug('Found s{}: {}'.format(i, hex(si)))
        i = 1

        plaintext = None
        while plaintext is None:
            log.debug('len(M{}): {}'.format(i - 1, len(Mi)))

            if i == 1:
                # step 2.a
                si = find_si(si_start=ceil(n, (3 * B)))
                log.debug('Found s{}: {}'.format(i, hex(si)))

            elif len(Mi) > 1:
                # step 2.b
                si = find_si(si_start=si + 1)

            elif len(Mi) == 1 and Mi[0][0] != Mi[0][1]:
                # step 2.c
                a, b = Mi[0]
                ri = ceil(2 * (b * si - 2 * B), n)
                si = None
                while si is None:
                    si_min = ceil(2 * B + ri * n, b)
                    si_max = ceil(3 * B + ri * n, a)
                    si = find_si(si_start=si_min, si_max=si_max)
                    ri += 1

            else:
                log.error(
                    "Hm, something strange happend. Len(M{}) = {}".format(
                        i, len(Mi)))
                return None

            # step 3
            Mi = update_intervals(Mi, B, si, n)

            # step 4
            if len(Mi) == 1 and Mi[0][0] == Mi[0][1]:
                plaintext = Mi[0][0]
                if s0 != 1:
                    # undo the step 1 blinding
                    plaintext = (plaintext * invmod(s0, n)) % n
                log.success("Interval narrowed to one value")
                log.success("plaintext = {}".format(hex(plaintext)))
            else:
                i += 1

        recovered[text_no] = plaintext
        key.texts[text_no]['plain'] = recovered[text_no]

    return recovered
Exemplo n.º 12
0
def blinding(key, signing_oracle=None, decryption_oracle=None):
    """Perform signature/ciphertext blinding attack

    The oracle only ever sees a blinded value: the target multiplied by
    key.encrypt(blind) mod n, for a random blind in [2, 100]. The blind is
    removed from the oracle's answer by multiplying with invmod(blind, n).
    Exactly one of the two oracles must be given.

    Args:
        key(RSAKey): with at least one plaintext(to sign) or ciphertext(to decrypt)
        signing_oracle(callable): signs an int; a falsy result is an error
        decryption_oracle(callable): decrypts an int; a falsy result is an error

    Returns:
        dict: {index: signature/plaintext, index2: signature/plaintext}
        update key texts
    """
    if not signing_oracle and not decryption_oracle:
        log.critical_error("Give one of signing_oracle or decryption_oracle")
    if signing_oracle and decryption_oracle:
        log.critical_error(
            "Give only one of signing_oracle or decryption_oracle")

    recovered = {}
    if signing_oracle:
        log.debug("Have signing_oracle")
        for text_no in range(len(key.texts)):
            # sign plaintexts that have no signature yet
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[
                    text_no]:
                log.info("Blinding signature of plaintext no {} ({})".format(
                    text_no, i2h(key.texts[text_no]['plain'])))

                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                blinded_plaintext = (key.texts[text_no]['plain'] *
                                     blind_enc) % key.n
                blinded_signature = signing_oracle(blinded_plaintext)
                if not blinded_signature:
                    log.critical_error(
                        "Error during call to signing_oracle({})".format(
                            blinded_plaintext))
                signature = (invmod(blind, key.n) * blinded_signature) % key.n
                key.texts[text_no]['cipher'] = signature
                recovered[text_no] = signature
                log.success("Signature: {}".format(signature))

    if decryption_oracle:
        log.debug("Have decryption_oracle")
        for text_no in range(len(key.texts)):
            # decrypt ciphertexts that have no plaintext yet
            if 'cipher' in key.texts[text_no] and 'plain' not in key.texts[
                    text_no]:
                log.info("Blinding ciphertext no {} ({})".format(
                    text_no, key.texts[text_no]['cipher']))
                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                blinded_ciphertext = (key.texts[text_no]['cipher'] *
                                      blind_enc) % key.n
                blinded_plaintext = decryption_oracle(blinded_ciphertext)
                if not blinded_plaintext:
                    # report the argument the oracle failed on, not its result
                    log.critical_error(
                        "Error during call to decryption_oracle({})".format(
                            blinded_ciphertext))
                plaintext = (invmod(blind, key.n) * blinded_plaintext) % key.n
                key.texts[text_no]['plain'] = plaintext
                recovered[text_no] = plaintext
                log.success("Plaintext: {}".format(plaintext))

    return recovered
Exemplo n.º 13
0
def hastad(keys, ciphertexts=None):
    """Hastad's broadcast attack (small public exponent)
    Given at least e keys with public exponent equals to e and ciphertexts of the same plaintext,
    plaintext can be efficiently recovered: CRT combines the ciphertexts into a value congruent
    to m**e modulo the product of the moduli, and when m**e is smaller than that product its
    exact integer e-th root is m.

    Args:
        keys(list): RSAKeys, all with same public exponent e, len(keys) >= e,
                    every key with only one ciphertext
        ciphertexts(list/None): if not None, use this ciphertexts (one per key, same order)

    Returns:
        NoneType/int: None on failure, recovered plaintext otherwise
        update keys texts
    """
    e = keys[0].e
    if len(keys) < e:
        log.critical_error("Not enough keys, e={}".format(e))

    if ciphertexts is None:
        for key in keys:
            if len(key.texts) > 1:
                log.info(
                    "Key have more than one ciphertext, using the first one(key=={})"
                    .format(key.identifier))
            # Also covers keys with no texts at all, which would otherwise
            # raise IndexError on key.texts[0].
            if not key.texts or 'cipher' not in key.texts[0]:
                log.critical_error("key {} doesn't have ciphertext".format(
                    key.identifier))

        # prepare ciphertexts and correct_keys lists
        ciphertexts, modules, correct_keys = [], [], []
        for key in keys:
            # get only first ciphertext; skip repeated moduli/ciphertexts,
            # which add no information to the CRT
            if key.n not in modules and key.texts[0][
                    'cipher'] not in ciphertexts:
                if key.e == e:
                    modules.append(key.n)
                    correct_keys.append(key)
                    ciphertexts.append(key.texts[0]['cipher'])
                else:
                    log.info("Key {} have different e(={})".format(
                        key.identifier, key.e))
    else:
        if len(ciphertexts) != len(keys):
            log.critical_error("len(ciphertexts) != len(keys)")
        modules = [key.n for key in keys]
        correct_keys = keys

    # check if we have enough ciphertexts
    if len(modules) < e:
        log.info(
            "Not enough keys with unique modulus and ciphertext, e={}, len(modules)={}"
            .format(e, len(modules)))
        log.info("Checking for simple roots (small_e_msg)")
        for one_key in correct_keys:
            recovered_plaintexts = small_e_msg(one_key)
            if len(recovered_plaintexts) > 0:
                log.success("Found plaintext: {}".format(
                    recovered_plaintexts[0]))
                return recovered_plaintexts[0]

    if len(modules) > e:
        log.debug("Number of modules/ciphertexts larger than e")
        modules = modules[:e]
        ciphertexts = ciphertexts[:e]

    # actual Hastad
    result = crt(ciphertexts, modules)
    plaintext, correct = gmpy2.iroot(result, e)
    if correct:
        plaintext = int(plaintext)
        log.success("Found plaintext: {}".format(plaintext))
        for one_key in correct_keys:
            one_key.texts[0]['plain'] = plaintext
        return plaintext
    else:
        log.debug("Plaintext wasn't {}-th root".format(e))
        log.debug("result (from crt) = {}".format(result))
        log.debug("plaintext ({}-th root of result) = {}".format(e, plaintext))
        return None
Exemplo n.º 14
0
def manger(oaep_padding_oracle, key, ciphertext=None, **kwargs):
    """Decrypt RSA-OAEP ciphertexts with Manger's chosen-ciphertext attack.

    The plaintext m is multiplied by chosen factors f (the ciphertext is
    multiplied by f**e mod n). The attack logic treats a truthy oracle answer
    as "f*m mod n is below B = 2**(key.size - 8)" and narrows
    [m_min, m_max] until a single value remains.

    https://iacr.org/archive/crypto2001/21390229.pdf

    Args:
        oaep_padding_oracle(callable): called as oracle(ciphertext, **kwargs)
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext(int/None): additional ciphertext to decrypt
        **kwargs: forwarded to every oracle call

    Returns:
        dict: {text_no: plaintext}; matching key.texts entries get 'plain' set
    """
    def ceil(a, b):
        return a // b + (a % b > 0)

    def floor(a, b):
        return a // b

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key,
                                                ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))

        n, e = key.n, key.e
        B = pow(2, key.size - 8)

        def query(factor):
            # Oracle answer for the plaintext multiplied by `factor`.
            return oaep_padding_oracle(
                (cipher * gmpy2.powmod(factor, e, n)) % n, **kwargs)

        # step 1: double f1 until f1*m no longer passes the oracle
        log.debug('step 1')
        f1 = 2
        while query(f1):
            f1 *= 2
        log.debug('step 1 done')
        log.debug('Found f1: {}'.format(hex(f1)))

        # step 2: walk f2 in steps of f1/2 until f2*m passes again
        log.debug('step 2')
        half_f1 = f1 // 2
        f2 = int(floor(n + B, B)) * half_f1
        while not query(f2):
            f2 += half_f1
        log.debug('step 2 done')
        log.debug('Found f2: {}'.format(hex(f2)))

        # step 3: shrink [m_min, m_max] until it holds one value
        log.debug('step 3')
        m_min, m_max = ceil(n, f2), floor(n + B, f2)
        while m_min < m_max:
            log.debug(hex(m_max - m_min))
            f_tmp = floor(2 * B, m_max - m_min)
            i = floor(f_tmp * m_min, n)
            f3 = ceil(i * n, m_min)
            boundary = i * n + B
            if query(f3):
                m_max = floor(boundary, f3)
            else:
                m_min = ceil(boundary, f3)
        log.debug('step 3 done')
        log.debug('m_min = {}'.format(hex(m_min)))
        log.debug('m_max = {}'.format(hex(m_max)))

        plaintext = m_min
        log.success("plaintext = {}".format(hex(plaintext)))
        recovered[text_no] = plaintext
        key.texts[text_no]['plain'] = recovered[text_no]

    return recovered
Exemplo n.º 15
0
def decrypt(ciphertext,
            padding_oracle=None,
            decryption_oracle=None,
            iv=None,
            block_size=16,
            is_correct=True,
            amount=0,
            known_plaintext=None,
            async_calls=False):
    """Decrypt CBC ciphertext with a padding oracle and/or a decryption oracle.

    Give padding_oracle or decryption_oracle (or both). When a decryption
    oracle is given it is used exclusively: each block is passed to the
    oracle and the result is XORed with the preceding ciphertext block.
    Otherwise the byte-by-byte CBC padding oracle attack is performed,
    from the last block towards the first.

    Args:
        ciphertext(bytes): to decrypt, length must be a multiple of block_size
        padding_oracle(function/None): called as padding_oracle(payload=..., iv=...);
            a truthy result is treated as "payload decrypts with correct padding"
        decryption_oracle(function/None): called with a single ciphertext block;
            its result is XORed with the previous block to get plaintext
        iv(bytes/None): if not specified, first block of ciphertext is treated as iv
        block_size(int): must be a multiple of 8
        is_correct(bool): set if ciphertext will decrypt to something with correct
            padding; the original padding value is then recovered first
        amount(int): how many blocks to decrypt (counting from last),
            zero (default) means all
        known_plaintext(bytes/None): with padding, from end (aligned to end of
            ciphertext); these bytes are not decrypted again
        async_calls(bool): make asynchronous calls to oracle (not implemented yet)

    Returns:
        bytes: plaintext with padding (when amount != 0, only the decrypted tail)
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    if len(ciphertext) % block_size != 0:
        log.critical_error("Incorrect ciphertext length: {}".format(
            len(ciphertext)))

    # Validate the iv for both modes; a misaligned iv would shift every block.
    if iv and len(iv) % block_size != 0:
        log.critical_error("Incorrect iv length: {}".format(len(iv)))

    if decryption_oracle:
        if iv:
            ciphertext = iv + ciphertext
        blocks = chunks(ciphertext, block_size)
        plaintext = b''
        for position in range(len(blocks) - 1, 0, -1):
            # CBC: P_i = D(C_i) xor C_{i-1}
            plaintext = xor(decryption_oracle(blocks[position]),
                            blocks[position - 1]) + plaintext
            log.info("Plaintext(hex): {}".format(b2h(plaintext)))
            if amount != 0 and len(plaintext) == amount * block_size:
                break
        log.success("Decrypted(hex): {}".format(b2h(plaintext)))
        return plaintext

    log.info("Start cbc padding oracle")
    log.debug(print_chunks(chunks(ciphertext, block_size)))

    # prepare blocks; blocks[0] plays the role of the iv for blocks[1]
    blocks = chunks(ciphertext, block_size)
    if iv:
        log.info("Set iv")
        blocks.insert(0, iv)

    # Convert "number of blocks from the end" into the lowest block index
    # the decryption loop must NOT reach (0 means: decrypt everything).
    requested_amount = amount
    if amount != 0:
        amount = len(blocks) - amount - 1
    if amount < 0 or amount >= len(blocks):
        log.critical_error(
            "Incorrect amount of blocks to decrypt: {} (have to be in [0,{}])".
            format(requested_amount,
                   len(blocks) - 1))
    log.info("Will decrypt {} block(s)".format(len(blocks) - 1 - amount))

    # add known plaintext
    plaintext = b''
    position_known = 0
    chars_decoded = 0
    if known_plaintext:
        is_correct = False
        plaintext = known_plaintext
        blocks_decoded = len(plaintext) // block_size
        chars_decoded = len(plaintext) % block_size
        max_known = (len(blocks) - 1) * block_size

        # Compare full lengths so a trailing partial block that exceeds the
        # ciphertext is rejected too (not only whole extra blocks).
        if len(plaintext) > max_known:
            log.critical_error(
                "Too long known plaintext ({} blocks)".format(blocks_decoded))
        if len(plaintext) == max_known:
            log.debug("Nothing decrypted, known plaintext long enough")
            return plaintext

        if blocks_decoded != 0:
            blocks = blocks[:-blocks_decoded]

        position_known = chars_decoded
        log.info("Have known plaintext, skip {} block(s) and {} bytes".format(
            blocks_decoded, chars_decoded))

    # start decryption
    for count_block in range(len(blocks) - 1, amount, -1):
        # Blocks from the last to the second (all except iv)
        log.info("Block no. {}".format(count_block))

        payload_prefix = b''.join(blocks[:count_block - 1])
        payload_modify = blocks[count_block - 1]
        payload_decrypt = blocks[count_block]

        if chars_decoded != 0:
            # We know some trailing chars of this block: tweak the previous
            # block so they decrypt to chars_decoded + 1 (the next padding).
            payload_modify = payload_modify[:-chars_decoded] + xor(
                plaintext[:chars_decoded], payload_modify[-chars_decoded:],
                bytes([chars_decoded + 1]))
            chars_decoded = 0

        position = block_size - 1 - position_known
        position_known = 0
        while position >= 0:
            # Every position in block, from the end
            log.debug("Position: {}".format(position))

            found_correct_char = False
            for guess_char in range(256):
                modified = payload_modify[:position] + bytes(
                    [guess_char]) + payload_modify[position + 1:]
                payload = b''.join(
                    [payload_prefix, modified, payload_decrypt])

                iv_block = payload[:block_size]
                payload = payload[block_size:]
                log.debug(print_chunks(chunks(iv_block + payload, block_size)))

                if not padding_oracle(payload=payload, iv=iv_block):
                    continue

                if is_correct and guess_char == blocks[-2][-1]:
                    # The original byte trivially reproduces the original
                    # (valid) padding; we need the other byte giving \x01.
                    log.debug("Skip this guess char ({})".format(guess_char))
                    continue

                if position == block_size - 1:
                    # On the last byte a hit may be a longer padding
                    # (e.g. \x02\x02) rather than \x01. Change the preceding
                    # byte: \x01 padding stays valid, longer paddings break.
                    # XOR guarantees the byte really changes (overwriting it
                    # with a constant does nothing if it already holds it).
                    probe = bytearray(iv_block + payload)
                    probe[-block_size - 2] ^= 0x01
                    probe = bytes(probe)
                    if not padding_oracle(payload=probe[block_size:],
                                          iv=probe[:block_size]):
                        log.debug("Hit false positive, guess char({})".
                                  format(guess_char))
                        continue

                padding = block_size - position  # sent ciphertext decoded to that padding
                decrypted_char = bytes(
                    [payload_modify[position] ^ guess_char ^ padding])

                if is_correct:
                    # guess_char produced \x01, so decrypted_char is the
                    # original padding value of the given ciphertext.
                    dc = int(decrypted_char[0])
                    log.info(
                        "Found padding value for correct ciphertext: {}".
                        format(dc))
                    if dc == 0 or dc > block_size:
                        log.critical_error(
                            "Found bad padding value (given ciphertext may not be correct)"
                        )

                    plaintext = decrypted_char * dc
                    # Make the whole known padding decrypt to dc + 1 and
                    # jump past it.
                    payload_modify = payload_modify[:-dc] + xor(
                        payload_modify[-dc:], decrypted_char,
                        bytes([dc + 1]))
                    position = position - dc + 1
                    is_correct = False
                else:
                    # abcd efgh ijkl o|guess_char|xy || 1234 5678 9tre qwer - ciphertext
                    # what ever itma ybex            || xyzw rtua lopo k|\x03|\x03\x03 - plaintext
                    # abcd efgh ijkl |guess_char|wxy || 1234 5678 9tre qwer - next round ciphertext
                    # some thin gels eheh            || xyzw rtua lopo guessing|\x04\x04\x04 - next round plaintext
                    payload_modify = payload_modify[:position] + xor(
                        bytes([guess_char]) +
                        payload_modify[position + 1:], bytes([padding]),
                        bytes([padding + 1]))
                    plaintext = decrypted_char + plaintext

                found_correct_char = True
                log.debug(
                    "Guessed char(\\x{:02x}), decrypted char(\\x{:02x})".
                    format(guess_char, decrypted_char[0]))
                log.debug("Plaintext: {}".format(plaintext))
                log.info("Plaintext(hex): {}".format(b2h(plaintext)))
                break
            position -= 1
            if not found_correct_char:
                if is_correct:
                    # Only the original byte gave valid padding, so the
                    # original padding is \x01; re-target it to \x02.
                    padding = 0x01
                    payload_modify = payload_modify[:position + 1] + xor(
                        payload_modify[position + 1:], bytes([padding]),
                        bytes([padding + 1]))
                    plaintext = b"\x01"
                    is_correct = False
                else:
                    log.critical_error(
                        "Can't find correct padding (oracle function return False 256 times)"
                    )
    log.success("Decrypted(hex): {}".format(b2h(plaintext)))
    return plaintext