Пример #1
0
def break_reuse_key(ciphertexts, lang='English', no_of_comparisons=5, alphabet=None,
                    key_space=None, reliability=100.0):
    """Sentences xored with the same key

    Every ciphertext is truncated to the length of the shortest one. The
    truncated texts are concatenated and attacked as a repeating-key xor
    whose key length equals that common length.

    Args:
        ciphertexts(list): texts xored with the same key (all str or all bytes)
        lang(string): key in frequencies dict
        no_of_comparisons(int): used during comparing by frequencies
        alphabet(string/None): plaintext space
        key_space(string/None): key space
        reliability(float): between 0 and 100, used during comparing by frequencies

    Returns:
        list: sorted (by frequencies) list of tuples (key, list(plaintexts))
    """
    if len(ciphertexts) < 2:
        log.critical_error("Too less ciphertexts")

    min_size = min(map(len, ciphertexts))
    ciphertexts = [one[:min_size] for one in ciphertexts]
    log.info("Ciphertexts shrinked to {} bytes".format(min_size))

    # Join with an empty value of the inputs' own type ('' for str, b'' for
    # bytes). ''.join would raise TypeError on a list of bytes.
    joined = ciphertexts[0][:0].join(ciphertexts)
    pairs = break_repeated_key(joined, lang=lang, no_of_comparisons=no_of_comparisons, key_size=min_size,
                               alphabet=alphabet, key_space=key_space, reliability=reliability)

    # Split each recovered plaintext back into the per-ciphertext pieces.
    return [(pair[0], chunks(pair[1], min_size)) for pair in pairs]
Пример #2
0
def test_bleichenbacher_pkcs15():
    """Run Bleichenbacher's PKCS#1 v1.5 padding-oracle attack against several key sizes.

    For each test key, a random plaintext is encrypted by the external
    rsa_oracles helper script. The attack must recover exactly that plaintext.
    """
    import sys

    print("\nTest: Bleichenbacher's PKCS 1.5 Padding Oracle")

    keys = [key_64, key_256, key_1024]
    for key in keys:
        pkcs15_padding_oracle_calls = [0]  # must be mutable (the oracle counts its calls here)
        # incremental blinding for small keys, random blinding otherwise
        incremental_blinding = key.size < 512

        if key.size > 512:
            # Force the top 16 bits to 0x0002, the PKCS#1 v1.5 type-2 header.
            # Presumably this lets the attack skip the blinding step on large keys.
            plaintext = randint(2, key.n) >> 16
            plaintext |= 0x0002 << (key.size - 16)
        else:
            plaintext = randint(2, key.n)

        # Use the interpreter running the tests. A bare "python" on PATH may be
        # a different version or may not exist at all.
        ciphertext = h2b(
            subprocess.check_output([
                sys.executable, rsa_oracles_path, "encrypt", key.identifier,
                i2h(plaintext)
            ]).strip().decode())

        msgs_recovered = bleichenbacher_pkcs15(
            pkcs15_padding_oracle,
            key.publickey(),
            ciphertext,
            incremental_blinding=incremental_blinding,
            oracle_key=key,
            pkcs15_padding_oracle_calls=pkcs15_padding_oracle_calls)
        log.info('For keysize {}: pkcs15_padding_oracle_calls = {}'.format(
            key.size, pkcs15_padding_oracle_calls[0]))
        assert msgs_recovered[0] == plaintext
        key.clear_texts()
Пример #3
0
def find_prefix_suffix_size(encryption_oracle, block_size=16):
    """Determine prefix and suffix sizes if ecb mode, sizes must be constant
    Rarely may fail (if random data that are send unhappily matches prefix/suffix)

    The oracle is presumably ECB(prefix || payload || suffix) with a fixed-length
    prefix and suffix.

    Args:
        encryption_oracle(callable): takes the payload, returns its ciphertext
        block_size(int)

    Returns:
        tuple(int,int): prefix_size, suffix_size
    """
    blocks_to_send = 5
    # A payload of 5 identical blocks yields at least 4 identical ciphertext
    # blocks, whatever the prefix alignment.
    identical_needed = blocks_to_send - 1
    payload = random_bytes(1) * (blocks_to_send * block_size)
    enc_chunks = chunks(encryption_oracle(payload), block_size)
    log.debug("Encryption of length {}".format(blocks_to_send * block_size))
    log.debug(print_chunks(enc_chunks))

    for position_start in range(len(enc_chunks) - 1):
        window = enc_chunks[position_start:position_start + identical_needed]
        # The length check stops a run near the end from indexing past the list.
        if len(window) == identical_needed and all(chunk == window[0] for chunk in window):
            log.success("Controlled payload start at chunk {}".format(position_start))
            break
    else:
        log.critical_error("Position of controlled chunks not found")

    log.info('Finding prefix')
    # Change one payload byte at a time. The first change that alters the first
    # controlled ciphertext block shows how many payload bytes fill the last
    # prefix block.
    changed_char = bytes([(payload[0] - 1) % 256])
    for aligned_bytes in range(block_size):
        payload_new = payload[:aligned_bytes] + changed_char + payload[aligned_bytes + 1:]
        enc_chunks_new = chunks(encryption_oracle(payload_new), block_size)
        log.debug(print_chunks(chunks(payload_new, block_size)))
        log.debug(print_chunks(enc_chunks_new))
        if enc_chunks_new[position_start] != enc_chunks[position_start]:
            prefix_size = position_start * block_size - aligned_bytes
            log.success("Prefix size: {}".format(prefix_size))
            break
    else:
        log.critical_error("Size of prefix not found")

    log.info('Finding suffix')
    payload = random_bytes(1) * (block_size - (prefix_size % block_size))  # align to block_size
    encrypted = encryption_oracle(payload)
    # Starts as suffix + padding; each byte added removes one byte of padding
    # until the ciphertext grows by a block.
    suffix_size = len(encrypted) - len(payload) - prefix_size
    # Ciphertext lengths are whole blocks, so the length must grow within
    # block_size extra bytes. The bound keeps a misbehaving oracle from hanging us.
    for _ in range(block_size):
        payload += random_bytes(1)
        suffix_size -= 1
        if len(encryption_oracle(payload)) > len(encrypted):
            log.success("Suffix size: {}".format(suffix_size))
            break
    else:
        log.critical_error("Size of suffix not found")

    return prefix_size, suffix_size
Пример #4
0
def parity(parity_oracle, key, min_lower_bound=None, max_upper_bound=None):
    """Decrypt RSA ciphertexts using an oracle that leaks the plaintext's parity.

    This handles every text that has a 'cipher' but no 'plain' yet. Each round:
    - the ciphertext is multiplied by key.encrypt(2), which doubles the hidden
      plaintext;
    - the oracle's answer halves the candidate range [lower_bound, upper_bound].
    This repeats until the range is one value wide.
    parity_oracle function must be implemented.

    Args:
        parity_oracle(callable): takes a ciphertext; a truthy result means "odd"
        key(RSAKey): contains ciphertexts to decrypt
        min_lower_bound(None/int): while lower_bound is below it (and the
            max_upper_bound rule does not apply), the oracle is skipped and
            the answer is taken as 1
        max_upper_bound(None/int): while upper_bound is above it, the oracle
            is skipped and the answer is taken as 0

    Returns:
        dict: {text index: decrypted plaintext}
        update key texts (sets 'plain' on every decrypted text)
    """
    recovered = {}
    for text_no, text in enumerate(key.texts):
        # Skip texts with nothing to decrypt or that are already decrypted.
        if 'cipher' not in text or 'plain' in text:
            continue

        cipher = text['cipher']
        log.info("Decrypting {}".format(cipher))
        doubling_factor = key.encrypt(2)

        modulus = key.n
        numerator, denominator, counter = 0, 1, 0
        lower_bound, upper_bound = 0, modulus
        while upper_bound > lower_bound + 1:
            # Double the hidden plaintext; the candidate range is
            # [n*numerator/denominator, n*(numerator+1)/denominator].
            cipher = (doubling_factor * cipher) % modulus
            numerator, denominator = numerator * 2, denominator * 2
            counter += 1

            if max_upper_bound is not None and upper_bound > max_upper_bound:
                is_odd = 0
            elif min_lower_bound is not None and lower_bound < min_lower_bound:
                # todo: check this branch
                is_odd = 1
            else:
                is_odd = parity_oracle(cipher)

            if is_odd:  # plaintext > n/(2**counter)
                numerator += 1
            lower_bound = (modulus * numerator) // denominator
            upper_bound = (modulus * (numerator + 1)) // denominator

            log.debug("{} {} [{}, {}]".format(counter, is_odd,
                                              int(lower_bound),
                                              int(upper_bound)))
            log.debug("{}/{}  -  {}/{}\n".format(numerator, denominator,
                                                 numerator + 1,
                                                 denominator))

        log.success("Decrypted: {}".format(i2h(upper_bound)))
        text['plain'] = upper_bound
        recovered[text_no] = upper_bound
    return recovered
Пример #5
0
def faulty(key, padding=None):
    """Faulty attack against crt-rsa, Boneh-DeMillo-Lipton
    sp = padding(m)**(d % p-1) % p
    sq' = padding(m)**(d % q-1) % q <--any error during computation
    s' = crt(sp, sq') % n <-- broken signature
    s = crt(sp, sq) % n <-- correct signature
    p = gcd(s'**e - padding(m), n)
    p = gcd(s - s', n)

    Args:
        key(RSAKey): with at least one broken signature (key.texts[no]['cipher']) and corresponding
                     plaintext (key.texts[no]['plain']), or valid and broken signature
        padding(None/function): function used before signing message

    Returns:
        NoneType/RSAKey: None on failure, recovered private key otherwise
        (the new key gets a shallow copy of key.texts)
    """
    def private_key_from(p):
        # Build the private key from the recovered factor p.
        log.info("Found p={}".format(p))
        new_key = RSAKey.construct(key.n,
                                   key.e,
                                   p=p,
                                   identifier=key.identifier + '-private')
        new_key.texts = key.texts[:]
        return new_key

    log.debug("Check signature-message pairs")
    for pair in key.texts:
        if 'plain' in pair and 'cipher' in pair:
            signature = gmpy2.mpz(pair['cipher'])
            message = pair['plain']
            if padding:
                message = padding(message)
            # gcd(x, n) depends only on x mod n, so reducing s'**e mod n gives
            # the same factor without building a full e*log2(s')-bit integer.
            p = gmpy2.gcd(gmpy2.powmod(signature, key.e, key.n) - message, key.n)
            if p != 1 and p != key.n:
                return private_key_from(p)

    log.debug("Check for valid-invalid signatures")
    signatures = [tmp['cipher'] for tmp in key.texts if 'cipher' in tmp]
    for first, second in itertools.combinations(signatures, 2):
        p = gmpy2.gcd(first - second, key.n)
        if p != 1 and p != key.n:
            return private_key_from(p)
    return None
Пример #6
0
def find_block_size(encryption_oracle, constant=True):
    """Determine block size if ecb mode

    Args:
        encryption_oracle(callable)
        constant(bool): True if prefix and suffix have constant length

    Returns:
        int/None: detected block size; None if the non-constant search finds
                  no candidate with enough identical ciphertext blocks
    """
    if constant:
        log.debug("constant == True")
        # Grow the payload one byte at a time. The first jump in ciphertext
        # length is exactly one block.
        payload = b'A'
        size = len(encryption_oracle(payload))
        while True:
            payload += b'A'
            new_size = len(encryption_oracle(payload))
            if new_size > size:
                log.info("block_size={}".format(new_size - size))
                return new_size - size

    log.debug("constant == False")
    payload = b'A'
    max_size = len(encryption_oracle(payload))
    # The block size must divide the ciphertext length.
    possible_sizes = factors(max_size)
    possible_sizes.add(max_size)
    blocks_to_send = 5
    identical_needed = blocks_to_send - 1

    for block_size in sorted(possible_sizes):
        # Send a payload of x blocks, so at least x-1 ciphertext blocks should be identical.
        payload = random_bytes(1) * (blocks_to_send * block_size)
        enc_chunks = chunks(encryption_oracle(payload), block_size)
        for x in range(len(enc_chunks) - 1):
            if enc_chunks[x] != enc_chunks[x + 1]:
                continue
            log.debug("Found two identical blocks at {}: {}".format(x, print_chunks(enc_chunks)))
            window = enc_chunks[x:x + identical_needed]
            # The length check stops a run near the end from indexing past the list.
            if len(window) == identical_needed and all(chunk == window[0] for chunk in window):
                log.info("block_size={}".format(block_size))
                return block_size
    return None
Пример #7
0
def fake_ciphertext(new_plaintext,
                    padding_oracle=None,
                    decryption_oracle=None,
                    block_size=16):
    """Forge a ciphertext that decrypts to a chosen plaintext.

    Start from an arbitrary ciphertext one block longer than the target. Then,
    from the last block backwards:
    - decrypt block k with the oracles;
    - rewrite block k-1 so that block k decrypts to the wanted plaintext.
    Give padding_oracle or decryption_oracle (or both).

    Args:
        new_plaintext(string): with padding
        padding_oracle(function/None)
        decryption_oracle(function/None): maximum one block to decrypt
        block_size(int)

    Returns:
        fake_ciphertext(string): fake ciphertext that will decrypt to new_plaintext
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8:
        log.critical_error("Incorrect block size: {}".format(block_size))

    log.info("Start fake ciphertext")
    # Arbitrary starting ciphertext with one extra block; that leading block
    # is the last one to be rewritten.
    seed_ciphertext = bytes(b'A' * (len(new_plaintext) + block_size))

    original_blocks = chunks(seed_ciphertext, block_size)
    target_blocks = chunks(new_plaintext, block_size)
    payload_block_count = len(original_blocks) - 1
    if len(target_blocks) != payload_block_count:
        log.critical_error(
            "Wrong new plaintext length({}), should be {}".format(
                len(new_plaintext), block_size * payload_block_count))
    forged_blocks = list(original_blocks)

    for count_block in reversed(range(1, len(original_blocks))):
        # Rewrite forged_blocks[count_block-1] so that block count_block decrypts to the target.
        log.info("Block no. {}".format(count_block))

        head = bytes(b''.join(forged_blocks[:count_block + 1]))
        decrypted_block = decrypt(head,
                                  padding_oracle=padding_oracle,
                                  decryption_oracle=decryption_oracle,
                                  block_size=block_size,
                                  amount=1,
                                  is_correct=False)
        log.info("Set block no. {}".format(count_block))
        # forged_blocks[count_block-1] is still the original block here, so
        # xoring out its decryption and xoring in the target gives the new block.
        forged_blocks[count_block - 1] = xor(original_blocks[count_block - 1],
                                             decrypted_block,
                                             target_blocks[count_block - 1])

    result = bytes(b''.join(forged_blocks))
    log.success("Fake ciphertext(hex): {}".format(b2h(result)))
    return result
Пример #8
0
def bleichenbacher_pkcs15(pkcs15_padding_oracle,
                          key,
                          ciphertext=None,
                          incremental_blinding=False,
                          **kwargs):
    """Given oracle that checks if ciphertext decrypts to some valid plaintext with PKCS1.5 padding
    we can decrypt whole ciphertext
    pkcs15_padding_oracle function must be implemented

    http://archiv.infsec.ethz.ch/education/fs08/secsem/bleichenbacher98.pdf
    https://www.dsi.unive.it/~focardi/RSA-padding-oracle/#eq5

    Note that this attack is very slow. Approximate number of main loop iterations == key's bit length

    Args:
        pkcs15_padding_oracle(callable): called as oracle(ciphertext, **kwargs)
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext(None/int): passed to _prepare_ciphertexts together with key
        incremental_blinding(bool): if ciphertext is not pkcs confirming we need to blind it.
                                    this may be done using random or incremental values
        **kwargs: forwarded to every oracle call

    Returns:
        dict: decrypted ciphertexts (None if the interval set degenerates)
        update key texts
    """
    def ceil(a, b):
        return a // b + (a % b > 0)

    def floor(a, b):
        return a // b

    def insert_interval(M, lb, ub):
        # Insert [lb, ub] into M (sorted, disjoint intervals) and merge overlaps.
        # Binary search for the first interval whose lower bound is >= lb.
        lo, hi = 0, len(M)
        while lo < hi:
            mid = (lo + hi) // 2
            if M[mid][0] < lb:
                lo = mid + 1
            else:
                hi = mid
        M.insert(lo, (lb, ub))

        # lb inside previous interval: merge them. Keep the larger upper bound,
        # otherwise a previous interval that contains [lb, ub] would be shrunk.
        if lo > 0 and M[lo - 1][1] >= lb:
            lb = min(lb, M[lo - 1][0])
            ub = max(ub, M[lo - 1][1])
            M[lo] = (lb, ub)
            del M[lo - 1]
            lo -= 1

        # Absorb following intervals that start inside the (possibly merged)
        # interval. They are sorted and disjoint, so the last one absorbed has
        # the largest upper bound.
        last = lo
        while last + 1 < len(M) and M[last + 1][0] <= ub:
            last += 1
        if last > lo:
            ub = max(ub, M[last][1])
            M[lo] = (M[lo][0], ub)
            del M[lo + 1:last + 1]

    def update_intervals(M, B, s, n):
        # step 3: narrow every interval using the new s
        M2 = []
        for a, b in M:
            r_min = ceil(a * s - 3 * B + 1, n)
            r_max = floor(b * s - 2 * B, n)
            for r in range(r_min, r_max + 1):
                lb = max(a, ceil(2 * B + r * n, s))
                ub = min(b, floor(3 * B - 1 + r * n, s))
                # The intersection may contain no integer; never store lb > ub.
                if lb <= ub:
                    insert_interval(M2, lb, ub)
        del M[:]
        return M2

    def find_si(si_start, si_max=None):
        # Smallest s >= si_start (and < si_max if given) for which the blinded
        # ciphertext times s**e is PKCS-conforming. Uses cipher_blinded, e and n
        # from the enclosing loop.
        si_new = si_start
        while si_max is None or si_new < si_max:
            cipheri = (cipher_blinded * gmpy2.powmod(si_new, e, n)) % n
            if pkcs15_padding_oracle(cipheri, **kwargs):
                return si_new
            si_new += 1
        return None

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key,
                                                ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))

        n = key.n
        e = key.e
        B = gmpy2.pow(2, key.size - 16)

        # step 1
        log.debug('Blinding the ciphertext (to make it PKCS1.5 confirming)')
        i = 0
        si = 1
        cipher_blinded = cipher
        while not pkcs15_padding_oracle(cipher_blinded, **kwargs):
            # the paper says to draw it, but seems like incrementation sometimes run faster
            if incremental_blinding:
                si += 1
            else:
                si = random.randint(2, 1 << (key.size - 16))
            cipher_blinded = (cipher * gmpy2.powmod(si, e, n)) % n
        Mi = [(2 * B, 3 * B - 1)]
        s0 = si
        log.debug('Found s{}: {}'.format(i, hex(si)))
        i = 1

        plaintext = None
        while plaintext is None:
            log.debug('len(M{}): {}'.format(i - 1, len(Mi)))

            if i == 1:
                # step 2.a
                si = find_si(si_start=ceil(n, (3 * B)))
                log.debug('Found s{}: {}'.format(i, hex(si)))

            elif len(Mi) > 1:
                # step 2.b
                si = find_si(si_start=si + 1)

            elif len(Mi) == 1 and Mi[0][0] != Mi[0][1]:
                # step 2.c
                a, b = Mi[0]
                ri = ceil(2 * (b * si - 2 * B), n)
                si = None
                while si is None:
                    si_min = ceil(2 * B + ri * n, b)
                    si_max = ceil(3 * B + ri * n, a)
                    si = find_si(si_start=si_min, si_max=si_max)
                    ri += 1

            else:
                log.error(
                    "Hm, something strange happend. Len(M{}) = {}".format(
                        i, len(Mi)))
                return None

            # step 3
            Mi = update_intervals(Mi, B, si, n)

            # step 4
            if len(Mi) == 1 and Mi[0][0] == Mi[0][1]:
                plaintext = Mi[0][0]
                if s0 != 1:
                    # undo the step 1 blinding
                    plaintext = (plaintext * invmod(s0, n)) % n
                log.success("Interval narrowed to one value")
                log.success("plaintext = {}".format(hex(plaintext)))
            else:
                i += 1

        recovered[text_no] = plaintext
        key.texts[text_no]['plain'] = recovered[text_no]

    return recovered
Пример #9
0
def bleichenbacher_signature_forgery(key,
                                     garbage='suffix',
                                     hash_function='sha1'):
    """Bleichenbacher's signature forgery based on bug in verify implementation

    Forges signatures s such that s**e % n starts with the 00 01 ff header and
    contains the ASN.1 DigestInfo of the message. Attacker-controlled garbage
    goes either after the hash or between the ff byte and the 00 separator.

    Args:
        key(RSAKey): with small e and at least one plaintext
        garbage(string): middle: 00 01 ff garbage 00 ASN.1 HASH
                         suffix: 00 01 ff 00 ASN.1 HASH garbage
        hash_function(string): md5, sha1, sha256, sha384 or sha512

    Returns:
        dict: forged signatures, signatures[no] == signature(key.texts[no]['plain'])
        update key texts
    """
    hash_asn1 = {
        'md5':
        bytes(
            b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'
        ),
        'sha1':
        bytes(b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
        'sha256':
        bytes(
            b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'
        ),
        'sha384':
        bytes(
            b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'
        ),
        'sha512':
        bytes(
            b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'
        )
    }
    if garbage not in ['suffix', 'middle']:
        log.critical_error("Bad garbage position, must be suffix or middle")
    if hash_function not in hash_asn1:
        log.critical_error(
            "Hash function {} not implemented".format(hash_function))

    if key.e > 3:
        log.debug("May not work, because e > 3")

    signatures = {}
    for text_no, text in enumerate(key.texts):
        # Only plaintexts that do not have a signature yet.
        if 'plain' not in text or 'cipher' in text:
            continue
        log.info("Forge for plaintext no {} ({})".format(
            text_no, text['plain']))

        # getattr picks the hashlib constructor by name, e.g. hashlib.sha1
        digest = getattr(hashlib, hash_function)(i2b(text['plain'])).digest()
        digest_info = hash_asn1[hash_function] + digest
        if garbage == 'suffix':
            signature = _forge_garbage_suffix(key, digest_info)
        else:
            signature = _forge_garbage_middle(key, digest_info)

        if signature is not None:
            text['cipher'] = signature
            signatures[text_no] = signature
    return signatures


def _forge_garbage_suffix(key, digest_info):
    """Forge s with s**e % n == 00 01 ff 00 || digest_info || garbage; None on failure."""
    plaintext_prefix = bytes(b'\x00\x01\xff\x00') + digest_info
    plaintext = b2i(plaintext_prefix + bytes(
        b'\x00' * (key.size // 8 - len(plaintext_prefix))))
    # The integer e-th root has the wanted top bytes up to rounding; try a few
    # neighbours of it.
    root, _ = gmpy2.iroot(plaintext, key.e)
    for round_error in range(-5, 5):
        signature = int(root + round_error)
        test_prefix = i2b(gmpy2.powmod(signature, key.e, key.n),
                          size=key.size)[:len(plaintext_prefix)]
        if test_prefix == plaintext_prefix:
            log.info("Got signature: {}".format(signature))
            log.debug("signature**e % n == {}".format(
                i2h(gmpy2.powmod(signature, key.e, key.n),
                    size=key.size)))
            return signature
    log.error("Something wrong, can't compute correct signature")
    return None


def _forge_garbage_middle(key, digest_info):
    """Forge s with s**e % n == 00 01 ff || garbage || 00 || digest_info; None on failure."""
    plaintext_suffix = bytes(b'\x00') + digest_info
    suffix_int = b2i(plaintext_suffix)
    if suffix_int & 1 != 1:
        log.error("Plaintext suffix is even, can't compute signature")
        return None

    # Lift the low bits of s one at a time so that s**e matches the suffix
    # modulo 2**bits. For odd s and odd e, setting bit b of s flips bit b of
    # s**e and leaves lower bits unchanged. Use the key's real exponent, not a
    # hard-coded 3, and reduce modulo 2**(b+1) so large e stays cheap.
    signature_suffix = 0b1
    for b in range(len(plaintext_suffix) * 8):
        mask = 1 << b
        if pow(signature_suffix, key.e, mask << 1) & mask != suffix_int & mask:
            signature_suffix |= mask
    # Fixed width: a leading zero byte must not be dropped, or the spliced
    # signature would get wrong low bytes.
    signature_suffix = signature_suffix.to_bytes(len(plaintext_suffix), 'big')

    while True:
        # The high part of s comes from the e-th root of a random 00 01 ff ... value.
        plaintext_prefix = bytes(b'\x00\x01\xff') + random_bytes(
            key.size // 8 - 3)
        signature_prefix, _ = gmpy2.iroot(b2i(plaintext_prefix), key.e)
        signature_prefix = i2b(
            int(signature_prefix),
            size=key.size)[:-len(signature_suffix)]

        signature = b2i(signature_prefix + signature_suffix)
        test_plaintext = i2b(gmpy2.powmod(signature, key.e, key.n),
                             size=key.size)
        # Retry while the garbage contains a zero byte (presumably it would be
        # taken as the padding terminator).
        if bytes(b'\x00') in test_plaintext[2:-len(plaintext_suffix)]:
            continue
        if test_plaintext[:3] == plaintext_prefix[:3] and test_plaintext[
                -len(plaintext_suffix):] == plaintext_suffix:
            log.info("Got signature: {}".format(signature))
            return signature
        log.error("Something wrong, signature={},"
                  " signature**{}%{} is {}".format(
                      signature, key.e, key.n,
                      [(test_plaintext)]))
        return None
Пример #10
0
def blinding(key, signing_oracle=None, decryption_oracle=None):
    """Perform signature/ciphertext blinding attack

    Multiplies the target by blind**e (blind random in [2, 100]) before
    querying the oracle, then removes the factor blind from the answer with
    invmod(blind, n).

    Args:
        key(RSAKey): with at least one plaintext(to sign) or ciphertext(to decrypt)
        signing_oracle(callable): signs an int, falsy result means failure
        decryption_oracle(callable): decrypts an int, falsy result means failure

    Returns:
        dict: {index: signature/plaintext, index2: signature/plaintext}
        update key texts
    """
    if not signing_oracle and not decryption_oracle:
        log.critical_error("Give one of signing_oracle or decryption_oracle")
    if signing_oracle and decryption_oracle:
        log.critical_error(
            "Give only one of signing_oracle or decryption_oracle")

    recovered = {}
    if signing_oracle:
        log.debug("Have signing_oracle")
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[
                    text_no]:
                log.info("Blinding signature of plaintext no {} ({})".format(
                    text_no, i2h(key.texts[text_no]['plain'])))

                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                # sign(m * blind**e) == sign(m) * blind (mod n)
                blinded_plaintext = (key.texts[text_no]['plain'] *
                                     blind_enc) % key.n
                blinded_signature = signing_oracle(blinded_plaintext)
                if not blinded_signature:
                    log.critical_error(
                        "Error during call to signing_oracle({})".format(
                            blinded_plaintext))
                signature = (invmod(blind, key.n) * blinded_signature) % key.n
                key.texts[text_no]['cipher'] = signature
                recovered[text_no] = signature
                log.success("Signature: {}".format(signature))

    if decryption_oracle:
        log.debug("Have decryption_oracle")
        for text_no in range(len(key.texts)):
            if 'cipher' in key.texts[text_no] and 'plain' not in key.texts[
                    text_no]:
                log.info("Blinding ciphertext no {} ({})".format(
                    text_no, key.texts[text_no]['cipher']))
                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                # decrypt(c * blind**e) == m * blind (mod n)
                blinded_ciphertext = (key.texts[text_no]['cipher'] *
                                      blind_enc) % key.n
                blinded_plaintext = decryption_oracle(blinded_ciphertext)
                if not blinded_plaintext:
                    # Report the argument that failed, not the (falsy) result.
                    log.critical_error(
                        "Error during call to decryption_oracle({})".format(
                            blinded_ciphertext))
                plaintext = (invmod(blind, key.n) * blinded_plaintext) % key.n
                key.texts[text_no]['plain'] = plaintext
                recovered[text_no] = plaintext
                log.success("Plaintext: {}".format(plaintext))

    return recovered
Пример #11
0
def hastad(keys, ciphertexts=None):
    """Hastad's broadcast attack (small public exponent)
    Given at least e keys with public exponent equals to e and ciphertexts of the same plaintext,
    plaintext can be efficiently recovered

    The ciphertexts are combined with CRT. If the result is an exact e-th
    power, its root is the plaintext.

    Args:
        keys(list): RSAKeys, all with same public exponent e, len(keys) >= e,
                    every key with only one ciphertext
        ciphertexts(list/None): if not None, use this ciphertexts

    Returns:
        NoneType/int: None on failure, recovered plaintext otherwise
        update keys texts
    """
    e = keys[0].e
    if len(keys) < e:
        log.critical_error("Not enough keys, e={}".format(e))

    if ciphertexts is None:
        for key in keys:
            if len(key.texts) > 1:
                log.info(
                    "Key have more than one ciphertext, using the first one(key=={})"
                    .format(key.identifier))
            # An empty texts list would otherwise fail with IndexError.
            if not key.texts or 'cipher' not in key.texts[0]:
                log.critical_error("key {} doesn't have ciphertext".format(
                    key.identifier))

        # prepare ciphertexts and correct_keys lists
        ciphertexts, modules, correct_keys = [], [], []
        for key in keys:
            # Use only the first ciphertext; skip repeated moduli and repeated ciphertexts.
            if key.n not in modules and key.texts[0][
                    'cipher'] not in ciphertexts:
                if key.e == e:
                    modules.append(key.n)
                    correct_keys.append(key)
                    ciphertexts.append(key.texts[0]['cipher'])
                else:
                    log.info("Key {} have different e(={})".format(
                        key.identifier, key.e))
    else:
        if len(ciphertexts) != len(keys):
            log.critical_error("len(ciphertexts) != len(keys)")
        modules = [key.n for key in keys]
        correct_keys = keys

    # check if we have enough ciphertexts
    if len(modules) < e:
        log.info(
            "Not enough keys with unique modulus and ciphertext, e={}, len(modules)={}"
            .format(e, len(modules)))
        log.info("Checking for simple roots (small_e_msg)")
        for one_key in correct_keys:
            recovered_plaintexts = small_e_msg(one_key)
            if len(recovered_plaintexts) > 0:
                log.success("Found plaintext: {}".format(
                    recovered_plaintexts[0]))
                return recovered_plaintexts[0]

    if len(modules) > e:
        log.debug("Number of modules/ciphertexts larger than e")
        modules = modules[:e]
        ciphertexts = ciphertexts[:e]

    # actual Hastad
    result = crt(ciphertexts, modules)
    plaintext, correct = gmpy2.iroot(result, e)
    if correct:
        plaintext = int(plaintext)
        log.success("Found plaintext: {}".format(plaintext))
        for one_key in correct_keys:
            one_key.texts[0]['plain'] = plaintext
        return plaintext

    # Format arguments fixed: previously the first message was never formatted
    # and the second printed e instead of the CRT result.
    log.debug("Plaintext wasn't {}-th root".format(e))
    log.debug("result (from crt) = {}".format(result))
    log.debug("plaintext ({}-th root of result) = {}".format(e, plaintext))
    return None
Пример #12
0
def manger(oaep_padding_oracle, key, ciphertext=None, **kwargs):
    """Decrypt RSA-OAEP ciphertexts with Manger's attack.

    The oracle answers whether a ciphertext decrypts to a plaintext that
    passes the OAEP check; here it is used as the predicate "f * m < B".
    The three steps of the paper narrow m down to a single value.
    oaep_padding_oracle function must be implemented

    https://iacr.org/archive/crypto2001/21390229.pdf

    Args:
        oaep_padding_oracle(callable): called as oracle(ciphertext, **kwargs)
        key(RSAKey): contains ciphertexts to decrypt
        ciphertext(None/int): passed to _prepare_ciphertexts together with key
        **kwargs: forwarded to every oracle call

    Returns:
        dict: decrypted ciphertexts
        update key texts
    """
    def ceil(a, b):
        return -(-a // b)

    def floor(a, b):
        return a // b

    def below_bound(cipher, factor):
        # Query the oracle for cipher * factor**e, i.e. for plaintext * factor.
        blinded = (cipher * gmpy2.powmod(factor, key.e, key.n)) % key.n
        return oaep_padding_oracle(blinded, **kwargs)

    recovered = {}
    for text_no, cipher in _prepare_ciphertexts(key=key,
                                                ciphertext=ciphertext):
        log.info("Decrypting {}".format(cipher))

        n = key.n
        B = pow(2, key.size - 8)

        # step 1: double f1 until f1 * m >= B
        log.debug('step 1')
        f1 = 2
        while below_bound(cipher, f1):
            f1 *= 2
        log.debug('step 1 done')
        log.debug('Found f1: {}'.format(hex(f1)))

        # step 2: step f2 by f1/2 until f2 * m < B again
        log.debug('step 2')
        half_f1 = f1 // 2
        f2 = int(floor(n + B, B)) * half_f1
        while not below_bound(cipher, f2):
            f2 += half_f1
        log.debug('step 2 done')
        log.debug('Found f2: {}'.format(hex(f2)))

        # step 3: binary-search-like narrowing of [m_min, m_max]
        log.debug('step 3')
        m_min = ceil(n, f2)
        m_max = floor(n + B, f2)
        while m_min < m_max:
            log.debug(hex(m_max - m_min))
            f_tmp = floor(2 * B, m_max - m_min)
            i = floor(f_tmp * m_min, n)
            f3 = ceil(i * n, m_min)
            if below_bound(cipher, f3):
                m_max = floor(i * n + B, f3)
            else:
                m_min = ceil(i * n + B, f3)

        log.debug('step 3 done')
        log.debug('m_min = {}'.format(hex(m_min)))
        log.debug('m_max = {}'.format(hex(m_max)))

        log.success("plaintext = {}".format(hex(m_min)))
        recovered[text_no] = m_min
        key.texts[text_no]['plain'] = recovered[text_no]

    return recovered
Пример #13
0
def guess_key_size(ciphertext, max_key_size=40):
    """Given sentence xored with short key, guess key size
    From: http://trustedsignal.blogspot.com/2015/06/xord-play-normalized-hamming-distance.html

    For every candidate key size the ciphertext is split into consecutive
    full-length blocks. The average Hamming distance between neighbouring
    blocks, normalized by the key size, is the score (lower is better).
    Afterwards a gcd heuristic ("case two" from the link above) moves key
    sizes that are common divisors of the best candidates to the front.

    Args:
         ciphertext(string/bytes): any sliceable sequence accepted by hamming_distance
         max_key_size(int): exclusive upper bound for candidate key sizes;
             a falsy value means len(ciphertext) // 4

    Returns:
        list: sorted list of tuples (key_size, probability),
        note that most probable key size not necessary have the largest probability
    """
    if not max_key_size:
        # integer division: range() below rejects floats on Python 3
        max_key_size = len(ciphertext) // 4

    result = {}
    for key_size in range(1, max_key_size):
        # consecutive full-length blocks; a trailing partial block is dropped
        blocks = [ciphertext[start:start + key_size]
                  for start in range(0, len(ciphertext) - key_size + 1, key_size)]
        if len(blocks) < 2:
            break

        diff = sum(hamming_distance(blocks[i], blocks[i + 1])
                   for i in range(len(blocks) - 1))
        result[key_size] = diff / float(len(blocks) - 1)  # average
        result[key_size] /= float(key_size)  # normalize
    result = sorted(result.items(), key=operator.itemgetter(1))

    # from link, case two: the real key size tends to divide many of the
    # best-scoring candidate sizes, so count gcds of pairs of the top 10
    top = result[:10]
    gcd_frequencies = defaultdict(int)
    for (size_a, _), (size_b, _) in itertools.combinations(top, 2):
        gcd_frequencies[gcd(size_a, size_b)] += 1
    gcd_frequencies = sorted(gcd_frequencies.items(), key=operator.itemgetter(1), reverse=True)

    top_distances = dict(top)
    max_distance = max(top_distances.values()) if top_distances else None
    for guessed_size, frequency in gcd_frequencies[:5]:
        if guessed_size == 1 or frequency == 0 or guessed_size not in top_distances:
            continue
        # look the entry up by key size, not by position: earlier promotions
        # reorder result, so indices taken before the loop would be stale
        entry = (guessed_size, top_distances[guessed_size])
        if entry[1] < max_distance:
            result.remove(entry)
            result = [entry] + result
    log.info("Guessed key size: {}".format(result))
    return result
Пример #14
0
def decrypt(ciphertext,
            padding_oracle=None,
            decryption_oracle=None,
            iv=None,
            block_size=16,
            is_correct=True,
            amount=0,
            known_plaintext=None,
            async_calls=False):
    """Decrypt ciphertext
    Give padding_oracle or decryption_oracle (or both)

    With a decryption_oracle every block is decrypted directly and xored with
    the previous ciphertext block. Otherwise a CBC padding oracle attack is
    run byte by byte, from the last block backwards.

    Args:
        ciphertext(string): to decrypt, length must be a multiple of block_size
        padding_oracle(function/None): called as padding_oracle(payload=..., iv=...),
            truthy when the payload decrypts to correct padding
        decryption_oracle(function/None): called with a single ciphertext block,
            returns its block decryption before the CBC xor
        iv(string): if not specified, first block of ciphertext is treated as iv;
            must be exactly block_size long
        block_size(int): must be a multiple of 8
        is_correct(bool): set if ciphertext will decrypt to something with correct padding
        amount(int): how much blocks decrypt (counting from last), zero (default) means all
        known_plaintext(string): with padding, from end (aligned to end of ciphertext)
        async_calls(bool): make asynchronous calls to oracle (not implemented yet)

    Returns:
        plaintext(string): with padding
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    if len(ciphertext) % block_size != 0:
        log.critical_error("Incorrect ciphertext length: {}".format(
            len(ciphertext)))

    # The iv is used as exactly one block in both modes. Any other length
    # (even a multiple of block_size) shifts every following block boundary.
    if iv and len(iv) != block_size:
        log.critical_error("Incorrect iv length: {}".format(len(iv)))

    if decryption_oracle:
        if iv:
            ciphertext = iv + ciphertext
        blocks = chunks(ciphertext, block_size)
        plaintext = bytes(b'')
        # from the last block down to block 1; block 0 only serves as the
        # xor mask (iv) for block 1
        for position in range(len(blocks) - 1, 0, -1):
            plaintext = xor(decryption_oracle(blocks[position]),
                            blocks[position - 1]) + plaintext
            log.info("Plaintext(hex): {}".format(b2h(plaintext)))
            if amount != 0 and len(plaintext) == amount * block_size:
                break
        log.success("Decrypted(hex): {}".format(b2h(plaintext)))
        return plaintext

    log.info("Start cbc padding oracle")
    log.debug(print_chunks(chunks(ciphertext, block_size)))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)
    if iv:
        log.info("Set iv")
        blocks.insert(0, iv)

    # Turn "number of blocks to decrypt" into the block index where the
    # decryption loop below stops (exclusive); 0 means every block after iv.
    requested_amount = amount
    if amount != 0:
        amount = len(blocks) - amount - 1
    if amount < 0 or amount >= len(blocks):
        log.critical_error(
            "Incorrect amount of blocks to decrypt: {} (have to be in [0,{}])".
            format(requested_amount,
                   len(blocks) - 1))
    log.info("Will decrypt {} block(s)".format(len(blocks) - 1 - amount))

    # add known plaintext
    plaintext = bytes(b'')
    position_known = 0
    chars_decoded = 0
    if known_plaintext:
        # known plaintext already includes the padding, so the original
        # padding value does not need to be discovered
        is_correct = False
        plaintext = known_plaintext
        blocks_decoded = len(plaintext) // block_size
        chars_decoded = len(plaintext) % block_size

        if blocks_decoded == len(blocks) - 1:
            log.debug("Nothing decrypted, known plaintext long enough")
            return plaintext
        if blocks_decoded > len(blocks) - 1:
            log.critical_error(
                "Too long known plaintext ({} blocks)".format(blocks_decoded))

        # fully known blocks are dropped from the end
        if blocks_decoded != 0:
            blocks = blocks[:-blocks_decoded]

        position_known = chars_decoded
        log.info("Have known plaintext, skip {} block(s) and {} bytes".format(
            blocks_decoded, chars_decoded))

    # start decryption
    for count_block in range(len(blocks) - 1, amount, -1):
        # blocks from the last to the second (all except iv)
        log.info("Block no. {}".format(count_block))

        payload_prefix = bytes(b''.join(blocks[:count_block - 1]))
        payload_modify = blocks[count_block - 1]
        payload_decrypt = blocks[count_block]

        if chars_decoded != 0:
            # we know some chars, so modify previous block so that the known
            # tail decrypts to padding value chars_decoded + 1
            payload_modify = payload_modify[:-chars_decoded] +\
                             xor(plaintext[:chars_decoded], payload_modify[-chars_decoded:], bytes([chars_decoded + 1]))
            chars_decoded = 0

        position = block_size - 1 - position_known
        position_known = 0
        while position >= 0:
            # every position in block, from the end
            log.debug("Position: {}".format(position))

            found_correct_char = False
            for guess_char in range(256):
                modified = payload_modify[:position] + bytes(
                    [guess_char]) + payload_modify[position + 1:]
                payload = bytes(b''.join(
                    [payload_prefix, modified, payload_decrypt]))

                iv = payload[:block_size]
                payload = payload[block_size:]
                log.debug(print_chunks(chunks(iv + payload, block_size)))

                correct = padding_oracle(payload=payload, iv=iv)
                if correct:
                    # oracle returns True
                    padding = block_size - position  # sent ciphertext decoded to that padding
                    decrypted_char = bytes(
                        [payload_modify[position] ^ guess_char ^ padding])

                    if is_correct:
                        # If we didn't send original ciphertext, then we have found original padding value.
                        # Otherwise keep searching and if won't find any other correct char - padding is \x01
                        if guess_char == blocks[-2][-1]:
                            log.debug(
                                "Skip this guess char ({})".format(guess_char))
                            continue

                        dc = int(decrypted_char[0])
                        log.info(
                            "Found padding value for correct ciphertext: {}".
                            format(dc))
                        if dc == 0 or dc > block_size:
                            log.critical_error(
                                "Found bad padding value (given ciphertext may not be correct)"
                            )

                        # whole padding is known at once: dc bytes of value dc;
                        # prepare them to decrypt to dc + 1 for the next byte
                        plaintext = decrypted_char * dc
                        payload_modify = payload_modify[:-dc] + xor(
                            payload_modify[-dc:], decrypted_char,
                            bytes([dc + 1]))
                        position = position - dc + 1
                        is_correct = False
                    else:
                        # abcd efgh ijkl o|guess_char|xy || 1234 5678 9tre qwer - ciphertext
                        # what ever itma ybex            || xyzw rtua lopo k|\x03|\x03\x03 - plaintext
                        # abcd efgh ijkl |guess_char|wxy || 1234 5678 9tre qwer - next round ciphertext
                        # some thin gels eheh            || xyzw rtua lopo guessing|\x04\x04\x04 - next round plaintext
                        if position == block_size - 1:
                            # If we decrypt first byte, check if we didn't hit other padding than \x01:
                            # change the byte before it (flip a bit, so it always differs from
                            # the original) and ask again; only \x01 padding survives this.
                            payload = iv + payload
                            flip_index = -block_size - 2
                            payload = payload[:flip_index] + bytes(
                                [payload[flip_index] ^ 0x01]) + payload[flip_index + 1:]
                            iv = payload[:block_size]
                            payload = payload[block_size:]
                            correct = padding_oracle(payload=payload, iv=iv)
                            if not correct:
                                log.debug("Hit false positive, guess char({})".
                                          format(guess_char))
                                continue

                        # make bytes from position onwards decrypt to padding + 1
                        payload_modify = payload_modify[:position] + xor(
                            bytes([guess_char]) +
                            payload_modify[position + 1:], bytes([padding]),
                            bytes([padding + 1]))
                        plaintext = decrypted_char + plaintext

                    found_correct_char = True
                    log.debug(
                        "Guessed char(\\x{:02x}), decrypted char(\\x{:02x})".
                        format(guess_char, decrypted_char[0]))
                    log.debug("Plaintext: {}".format(plaintext))
                    log.info("Plaintext(hex): {}".format(b2h(plaintext)))
                    break
            position -= 1
            if found_correct_char is False:
                if is_correct:
                    # only the original ciphertext byte gave correct padding,
                    # so the original padding is \x01
                    padding = 0x01
                    payload_modify = payload_modify[:position + 1] + xor(
                        payload_modify[position + 1:], bytes([padding]),
                        bytes([padding + 1]))
                    plaintext = bytes(b"\x01")
                    is_correct = False
                else:
                    log.critical_error(
                        "Can't find correct padding (oracle function return False 256 times)"
                    )
    log.success("Decrypted(hex): {}".format(b2h(plaintext)))
    return plaintext