Exemplo n.º 1
0
def length_extension(old_hash, size, new_message, type='sha1'):
    """Length extension attack against sha1 / md4.

    Given hash(secret) and len(secret), compute new_hash and old_padding such
    that hash(secret + old_padding + new_message) == new_hash, without knowing
    the secret itself.

    Args:
        old_hash(string): hash of secret value; split into 4-byte words that
            seed the hash state
        size(int): length of secret (in bytes)
        new_message(string): data appended after the original padding
        type(string): 'sha1' or 'md4'

    Returns:
        tuple: (new_hash, old_padding+new_message), or None for an unsupported type
    """
    implemented_functions = ['sha1', 'md4']
    if type not in implemented_functions:
        log.critical_error("Not implemented, type must be one of {}".format(
            implemented_functions))
        return None

    # sha1 serialises its state words big-endian, md4 little-endian
    if type == 'sha1':
        endian, hash_function = 'big', sha1
    else:
        endian, hash_function = 'little', md4

    def padding_after(length):
        # Padding appended after `length` bytes of data; the filler bytes only
        # provide the right length and are sliced off again.
        return add_md_padding(bytes(b'a' * length), endian=endian)[length:]

    old_padding = padding_after(size)
    forged_length = size + len(old_padding) + len(new_message)
    new_padding = padding_after(forged_length)

    state = [b2i(word, endian=endian) for word in chunks(old_hash, 4)]
    new_hash = hash_function(new_message, state, new_padding)
    return new_hash, old_padding + new_message
Exemplo n.º 2
0
def compare_by_frequencies(a, b, lang='English', no_of_comparisons=5):
    """Decide which of two texts has letter frequencies closer to a language.

    Only the first `no_of_comparisons` entries of frequencies[lang] are used.
    For every such letter the absolute difference between the text's frequency
    and the language's frequency is summed. A letter absent from the text
    counts its full language frequency. A smaller sum means a closer match.
    todo: add words, digraphs etc...

    Args:
        a(string)
        b(string)
        lang(string): key in the frequencies dict
        no_of_comparisons(int): how many letters to compare

    Returns:
        int: -1 if a is less similar than b, 0 if equal, 1 if a is more similar
    """
    if lang not in frequencies:
        log.critical_error("[-] Can't find language {}".format(lang))
    reference = frequencies[lang][:no_of_comparisons]

    def distance(text_frequencies):
        total = 0
        for entry in reference:
            letter, expected = entry[0], entry[1]
            if letter in text_frequencies:
                total += abs(text_frequencies[letter] - expected)
            else:
                total += expected
        return total

    freq_a = get_frequencies(a)
    freq_b = get_frequencies(b)
    score_a = distance(freq_a)
    score_b = distance(freq_b)

    # smaller distance means closer to the language frequencies
    if score_a < score_b:
        return 1
    if score_a > score_b:
        return -1
    return 0
Exemplo n.º 3
0
def break_reuse_key(ciphertexts, lang='English', no_of_comparisons=5, alphabet=None,
                    key_space=None, reliability=100.0):
    """Break sentences xored with the same key.

    All ciphertexts are truncated to the length of the shortest one and
    concatenated. This turns the problem into a repeated-key xor whose key
    length equals that size, which is solved by break_repeated_key.

    Args:
        ciphertexts(list): texts (str or bytes) xored with the same key
        lang(string): key in frequencies dict
        no_of_comparisons(int): used during comparing by frequencies
        alphabet(string/None): plaintext space
        key_space(string/None): key space
        reliability(float): between 0 and 100, used during comparing by frequencies

    Returns:
        list: sorted (by frequencies) list of tuples (key, list(plaintexts))
    """
    if len(ciphertexts) < 2:
        log.critical_error("Too less ciphertexts")

    min_size = min(map(len, ciphertexts))
    ciphertexts = [one[:min_size] for one in ciphertexts]
    log.info("Ciphertexts shrinked to {} bytes".format(min_size))

    # Join with an empty value of the ciphertexts' own type: ''.join() raises
    # TypeError on Python 3 when the ciphertexts are bytes.
    joined = ciphertexts[0][:0].join(ciphertexts)
    pairs = break_repeated_key(joined, lang=lang, no_of_comparisons=no_of_comparisons, key_size=min_size,
                               alphabet=alphabet, key_space=key_space, reliability=reliability)

    # split every candidate plaintext back into the individual messages
    return [(pair[0], chunks(pair[1], min_size)) for pair in pairs]
Exemplo n.º 4
0
    def compute_params(s, m=None, a=None, b=None):
        """Recover the parameters of an LCG prng: next_state = a*state + b mod m.

        Parameters that are already known may be passed in. Only the missing
        ones are computed, in the order m, then a, then b.

        Args:
            s(list): subsequent outputs from LCG oracle starting with seed
            m(int/None): modulus
            a(int/None): multiplier
            b(int/None): increment

        Returns:
            a, b, m(int)
        """
        if m is None:
            # With d_n = s[n+1] - s[n], every d_{n+2}*d_n - d_{n+1}**2 is a
            # multiple of m, so their gcd reveals m given enough outputs.
            diffs = [after - before for before, after in zip(s, s[1:])]
            multiples = [abs(d2 * d0 - d1 ** 2)
                         for d0, d1, d2 in zip(diffs, diffs[1:], diffs[2:])]
            m = gcd(*multiples)
            log.success("m = {}".format(m))

        if a is None:
            # s[lag+1] - s[1] == a * (s[lag] - s[0]) mod m; use the first
            # difference that is invertible modulo m
            for lag in (1, 2):
                if gcd(s[lag] - s[0], m) == 1:
                    a = (s[lag + 1] - s[1]) * invmod(s[lag] - s[0], m)
                    break
            else:
                log.critical_error("a not found")
            a = a % m
            log.success("a = {}".format(a))

        if b is None:
            b = (s[1] - s[0] * a) % m
            log.success("b = {}".format(b))

        return a, b, m
Exemplo n.º 5
0
def tonelli_shanks(n, p):
    """Return r with r*r == n (mod p); the other root is p - r.

    Implements Tonelli-Shanks. p is expected to be an odd prime and n a
    quadratic residue modulo p (otherwise log.critical_error is called).
    """
    if legendre(n, p) != 1:
        log.critical_error("Not a square root")

    # write p - 1 as q * 2**s with q odd
    q, s = p - 1, 0
    while not q & 1:
        q >>= 1
        s += 1

    # p == 3 (mod 4): closed form
    if s == 1:
        return pow(n, (p + 1) // 4, p)

    # find a quadratic non-residue z
    z = 1
    while legendre(z, p) != -1:
        z += 1

    m = s
    c = pow(z, q, p)
    t = pow(n, q, p)
    r = pow(n, (q + 1) // 2, p)
    while t != 1:
        # least i in [1, m) with t**(2**i) == 1, found by repeated squaring
        i = 1
        t_squared = pow(t, 2, p)
        while i < m and t_squared != 1:
            t_squared = pow(t_squared, 2, p)
            i += 1
        b = pow(c, 2 ** (m - i - 1), p)
        r = r * b % p
        t = t * b ** 2 % p
        c = pow(b, 2, p)
        m = i
    return r
Exemplo n.º 6
0
def _check_oracles(padding_oracle=None, decryption_oracle=None, block_size=16):
    """Check if padding or decryption oracle works"""
    if not padding_oracle and not decryption_oracle:
        log.critical_error(
            "Give padding_oracle and/or decryption_oracle functions")

    if decryption_oracle:
        try:
            decryption_oracle(bytes(b'B' * block_size))
        except NotImplementedError:
            log.critical_error("decryption_oracle not implemented")
        except Exception as e:
            log.critical_error(
                "Error in first call to decryption_oracle: {}".format(
                    e.message))

    if padding_oracle:
        try:
            padding_oracle(payload=bytes(b'B' * block_size),
                           iv=bytes(b'A' * block_size))
        except NotImplementedError:
            log.critical_error("padding_oracle not implemented")
        except Exception as e:
            log.critical_error(
                "Error in first call to padding_oracle: {}".format(e.message))
Exemplo n.º 7
0
def fake_ciphertext(new_plaintext,
                    padding_oracle=None,
                    decryption_oracle=None,
                    block_size=16):
    """Make ciphertext that will decrypt to given plaintext
    Give padding_oracle or decryption_oracle (or both)

    The result is one block longer than new_plaintext; the extra leading block
    plays the role of the iv. Blocks are fixed from the last one backwards.
    Block i is decrypted, then block i-1 is rewritten so that block i decrypts
    to the wanted plaintext block.

    Args:
        new_plaintext(string): with padding, length must be a multiple of block_size
        padding_oracle(function/None)
        decryption_oracle(function/None): maximum one block to decrypt
        block_size(int): must be a multiple of 8

    Returns:
        fake_ciphertext(string): fake ciphertext that will decrypt to new_plaintext
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    if len(new_plaintext) % block_size != 0:
        # Checked explicitly: the ciphertext built below is always exactly one
        # block longer than new_plaintext, so the block-count check further
        # down cannot catch a plaintext that is not block aligned.
        log.critical_error(
            "Wrong new plaintext length({}), should be a multiple of {}".format(
                len(new_plaintext), block_size))

    log.info("Start fake ciphertext")
    # arbitrary starting ciphertext: iv block + one block per plaintext block
    ciphertext = bytes(b'A' * (len(new_plaintext) + block_size))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)
    new_pl_blocks = chunks(new_plaintext, block_size)
    if len(new_pl_blocks) != len(blocks) - 1:
        log.critical_error(
            "Wrong new plaintext length({}), should be {}".format(
                len(new_plaintext), block_size * (len(blocks) - 1)))
    new_ct_blocks = list(blocks)

    for count_block in range(len(blocks) - 1, 0, -1):
        # Every block: modify block[count_block-1] to set block[count_block].
        # Blocks after count_block only depend on blocks >= count_block, which
        # are already fixed and are not touched again.
        log.info("Block no. {}".format(count_block))

        ciphertext_to_decrypt = bytes(b''.join(new_ct_blocks[:count_block + 1]))
        # Current plaintext of block count_block. Its predecessor is still the
        # untouched original block at this point.
        original_plaintext = decrypt(ciphertext_to_decrypt,
                                     padding_oracle=padding_oracle,
                                     decryption_oracle=decryption_oracle,
                                     block_size=block_size,
                                     amount=1,
                                     is_correct=False)
        log.info("Set block no. {}".format(count_block))
        # c'[i-1] = c[i-1] ^ p[i] ^ wanted[i] makes block i decrypt to wanted[i]
        new_ct_blocks[count_block - 1] = xor(blocks[count_block - 1],
                                             original_plaintext,
                                             new_pl_blocks[count_block - 1])

    fake_ciphertext_res = bytes(b''.join(new_ct_blocks))
    log.success("Fake ciphertext(hex): {}".format(b2h(fake_ciphertext_res)))
    return fake_ciphertext_res
Exemplo n.º 8
0
def lcm(*args):
    """Lowest common multiple

    Args:
        *args(int): at least two integers

    Returns:
        int: non-negative least common multiple; 0 if any argument is 0
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")

    if len(args) == 2:
        a, b = args
        if a == 0 or b == 0:
            # gcd(0, 0) == 0 would otherwise cause a ZeroDivisionError
            return 0
        # divide first to keep the intermediate product small
        return abs(a // gcd(a, b) * b)

    result = 1
    for number in args:
        result = lcm(result, number)
    return result
Exemplo n.º 9
0
    def decrypt(self, ciphertext):
        """Raw decryption

        Args:
            ciphertext(int/string): a number, or a value convertible to one with b2i

        Returns:
            pow(ciphertext, d, n), as computed by self.pyrsa_key.decrypt
        """
        if not isinstance(ciphertext, Number):
            try:
                ciphertext = b2i(ciphertext)
            except Exception:
                # narrower than a bare except: KeyboardInterrupt/SystemExit propagate
                log.critical_error(
                    "Ciphertext to decrypt must be number or be convertible to number ({})"
                    .format(ciphertext))
        return self.pyrsa_key.decrypt(gmpy2.mpz(ciphertext))
Exemplo n.º 10
0
    def encrypt(self, plaintext):
        """Raw encryption

        Args:
            plaintext(int/string): a number, or a value convertible to one with b2i

        Returns:
            pow(plaintext, e, n): first element returned by self.pyrsa_key.encrypt
        """
        if not isinstance(plaintext, Number):
            try:
                plaintext = b2i(plaintext)
            except Exception:
                # narrower than a bare except: KeyboardInterrupt/SystemExit propagate
                log.critical_error(
                    "Plaintext to encrypt must be number or be convertible to number ({})"
                    .format(plaintext))
        return self.pyrsa_key.encrypt(int(plaintext), 0)[0]
Exemplo n.º 11
0
def gcd(*args):
    """Greatest common divisor of two or more integers.

    Two arguments are delegated to gmpy2.gcd. With more arguments the values
    are folded left to right, starting from 0 (gcd(0, x) == |x|). The fold
    stops early once the running divisor is 1.
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")

    if len(args) != 2:
        divisor = 0
        for value in args:
            divisor = gcd(divisor, value)
            if divisor == 1:
                # cannot get any smaller
                break
        return divisor

    first, second = args
    return gmpy2.gcd(first, second)
Exemplo n.º 12
0
 def add_plaintext(self, plaintext, position=None):
     """Store a plaintext in self.texts.

     Args:
         plaintext(int/string): a number, or a value convertible to one with b2i
         position(int/None) - position in list where to add, None for new

     A failed conversion is reported through log.critical_error.
     """
     if not isinstance(plaintext, Number):
         try:
             plaintext = b2i(plaintext)
         except Exception:
             # narrower than a bare except: KeyboardInterrupt/SystemExit propagate
             log.critical_error(
                 "Plaintext to add must be number or be convertible to number ({})"
                 .format(plaintext))
     if position is None:
         # append a new text entry
         self.texts.append({'plain': plaintext})
     else:
         # overwrite only the 'plain' value of an existing entry
         self.texts[position]['plain'] = plaintext
Exemplo n.º 13
0
def egcd(*args):
    """Extended Euclidean algorithm.

    Two arguments: returns gmpy2.gcdext(a, b), i.e. (d, s, t) with d = s*a + t*b.
    More arguments: returns a list [d, c_0, c_1, ...] with
    d = c_0*args[0] + c_1*args[1] + ..., built by folding pairwise.
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")

    if len(args) == 2:
        return gmpy2.gcdext(*args)

    d, s, t = egcd(args[0], args[1])
    coefficients = [s, t]
    for value in args[2:]:
        # new d = s*old_d + t*value, so every earlier coefficient scales by s
        d, s, t = egcd(d, value)
        coefficients = [coef * s for coef in coefficients]
        coefficients.append(t)
    return [d] + coefficients
Exemplo n.º 14
0
def crt_non_coprime(a, n):
    """Solve chinese remainder theorem with general modules

    Given the system
        x = a[i] mod n[i]
    where the modules need not be coprime, a solution exists when, for every
    pair of congruences merged, a_i = a_j mod gcd(n_i, n_j). The congruences
    are merged pairwise. The solution will be modulo lcm of modules.

    Args:
        a(list/tuple): remainders
        n(list/tuple): modules

    Returns:
        int: solution to crt, or None if the system is inconsistent
    """
    if len(a) != len(n):
        log.critical_error(
            "Different number of remainders({}) and modules({})".format(
                len(a), len(n)))

    if len(n) < 2:
        log.critical_error(
            "Give at least two remainders and modules (got {})".format(len(n)))

    # Private list copies: the caller's sequences are never modified, and
    # tuple inputs work (the merge step below rebuilds the lists).
    a = list(a)
    n = list(n)

    while len(n) > 1:
        g, u, _ = egcd(n[0], n[1])

        if (a[0] - a[1]) % g != 0:
            print(
                'Not satisfied: gcd(ni, nj) | ai - aj\ngcd({}, {}) | {} - {}'.
                format(n[0], n[1], a[0], a[1]))
            return None

        # g = u*n0 + v*n1, so x = a0 - n0*u*(a0 - a1)/g satisfies both
        # x = a0 mod n0 and x = a1 mod n1
        w = (a[0] - a[1]) // g
        merged_modulus = lcm(n[0], n[1])
        x = (a[0] - n[0] * u * w) % merged_modulus

        n = [merged_modulus] + n[2:]
        a = [x] + a[2:]

    return int(a[0] % n[0])
Exemplo n.º 15
0
def crt(a, n):
    """Solve the chinese remainder theorem for pairwise coprime modules.

    Based on http://rosettacode.org/wiki/Chinese_remainder_theorem#Python.
    Finds x with x = a[i] mod n[i] for every i; the solution is modulo the
    product of the modules.

    Args:
        a(list): remainders
        n(list): modules

    Returns:
        int: solution to crt
    """
    if len(a) != len(n):
        log.critical_error("Different number of remainders({}) and modules({})".format(len(a), len(n)))

    modulus = product(n)
    # each term is a_i times the basis element that is 1 mod n_i and 0 mod the others
    terms = (
        a_i * invmod(modulus // n_i, n_i) * (modulus // n_i)
        for a_i, n_i in zip(a, n)
    )
    return int(sum(terms) % modulus)
Exemplo n.º 16
0
def egcd(*args):
    """Extended Euclidean algorithm.

    Two arguments: returns a tuple (d, s, t) with d = gcd and d = s*a + t*b.
    More arguments: returns a list [d, c_0, c_1, ...] with
    d = c_0*args[0] + c_1*args[1] + ..., built by folding pairwise.
    """
    if len(args) < 2:
        log.critical_error("Give at least two values")

    if len(args) != 2:
        gcd_value, x, y = egcd(args[0], args[1])
        coefficients = [x, y]
        for value in args[2:]:
            # new gcd = x*old_gcd + y*value, so earlier coefficients scale by x
            gcd_value, x, y = egcd(gcd_value, value)
            coefficients = [coef * x for coef in coefficients] + [y]
        return [gcd_value] + coefficients

    # iterative version keeping the Bezout coefficients of the remainders
    old_r, r = args
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r:
        quotient = old_r // r
        old_r, r = r, old_r % r
        old_s, s = s, old_s - quotient * s
        old_t, t = t, old_t - quotient * t
    return old_r, old_s, old_t
Exemplo n.º 17
0
def find_prefix_suffix_size(encryption_oracle, block_size=16):
    """Determine prefix and suffix sizes if ecb mode, sizes must be constant
    Rarely may fail (if the random byte that is sent happens to match the prefix/suffix)

    Args:
        encryption_oracle(callable): encrypts prefix || payload || suffix
        block_size(int)

    Returns:
        tuple(int,int): prefix_size, suffix_size
    """
    blocks_to_send = 5
    payload = random_bytes(1) * (blocks_to_send * block_size)
    enc_chunks = chunks(encryption_oracle(payload), block_size)
    log.debug("Encryption of length {}".format(blocks_to_send * block_size))
    log.debug(print_chunks(enc_chunks))

    # The repeated payload produces a run of identical ciphertext blocks;
    # position_start is the first block of that run.
    for position_start in range(len(enc_chunks) - 1):
        if enc_chunks[position_start] == enc_chunks[position_start + 1]:
            for y in range(2, blocks_to_send - 1):
                if enc_chunks[position_start] != enc_chunks[position_start + y]:
                    break
            else:
                log.success("Controlled payload start at chunk {}".format(position_start))
                break
    else:
        log.critical_error("Position of controlled chunks not found")

    log.info('Finding prefix')
    # Change payload bytes one at a time: the first index whose change alters
    # block position_start is the first payload byte inside that block.
    changed_char = bytes([(payload[0] - 1) % 256])
    for aligned_bytes in range(block_size):
        payload_new = payload[:aligned_bytes] + changed_char + payload[aligned_bytes + 1:]
        enc_chunks_new = chunks(encryption_oracle(payload_new), block_size)
        log.debug(print_chunks(chunks(payload_new, block_size)))
        log.debug(print_chunks(enc_chunks_new))
        if enc_chunks_new[position_start] != enc_chunks[position_start]:
            prefix_size = position_start * block_size - aligned_bytes
            log.success("Prefix size: {}".format(prefix_size))
            break
    else:
        log.critical_error("Size of prefix not found")

    log.info('Finding suffix')
    payload = random_bytes(1) * (block_size - (prefix_size % block_size))  # align to block_size
    encrypted = encryption_oracle(payload)
    suffix_size = len(encrypted) - len(payload) - prefix_size
    # For block-aligned ciphertext output, the length must grow within
    # block_size extra bytes. Bounding the loop keeps an oracle that never grows
    # from spinning forever; it also makes the else branch reachable (it was
    # dead code after `while True`).
    for _ in range(block_size):
        payload += random_bytes(1)
        suffix_size -= 1
        if len(encryption_oracle(payload)) > len(encrypted):
            log.success("Suffix size: {}".format(suffix_size))
            break
    else:
        log.critical_error("Size of suffix not found")

    return prefix_size, suffix_size
Exemplo n.º 18
0
def bit_flipping(ciphertext, plaintext, wanted, block_size=16):
    """CBC bit flipping on a two-block ciphertext.

    Rewrites the first ciphertext block so that the second block decrypts to
    `wanted` instead of `plaintext`; the first plaintext block becomes garbage.

    Args:
        ciphertext(string): size == 2*block_size
        plaintext(string): size == block_size, current plaintext of the second ciphertext block
        wanted(string): size == block_size, desired plaintext of the second block
        block_size(int)

    Returns:
        string: ciphertext that will decrypt to garbage_block+wanted_last_block
    """
    size_checks = (
        (ciphertext, 2 * block_size, "Incorrect ciphertext size"),
        (plaintext, block_size, "Incorrect plaintext size"),
        (wanted, block_size, "Incorrect wanted_last_block size"),
    )
    for value, expected_size, message in size_checks:
        if len(value) != expected_size:
            log.critical_error(message)

    # xoring the preceding block with plaintext ^ wanted flips the decrypted
    # second block from plaintext to wanted
    first_block, second_block = ciphertext[:block_size], ciphertext[block_size:]
    return xor(first_block, plaintext, wanted) + second_block
Exemplo n.º 19
0
def decrypt(encryption_oracle, constant=True, block_size=16, prefix_size=None, secret_size=None,
            alphabet=None):
    """Given encryption oracle which produce ecb(prefix || our_input || secret), find secret

    Byte-at-a-time ECB attack. The secret is recovered from its LAST byte
    backwards. Each round shifts the secret so that its still-unknown byte,
    followed by the already known tail, starts a ciphertext block. It then
    brute-forces that block by encrypting add_padding(guess + known_tail) in
    the first block after the (aligned) prefix.

    Args:
        encryption_oracle(callable): takes our_input bytes, returns ciphertext bytes
        constant(bool): True if prefix have constant length (secret must have constant length).
            Only constant=True is implemented; otherwise nothing is decrypted and None is returned.
        block_size(int/None): detected with find_block_size if falsy
        prefix_size(int/None): detected (together with secret_size) if either one is falsy
        secret_size(int/None)
        alphabet(string): plaintext space, defaults to string.printable.
            NOTE(review): alphabet is currently unused -- the guess loop tries all 256 byte values.

    Returns:
        secret(string): recovered secret (constant=True), None otherwise
    """
    log.debug("Start decrypt function")
    if not alphabet:
        alphabet = bytes(string.printable.encode())

    if not block_size:
        block_size = find_block_size(encryption_oracle, constant)

    if constant:
        log.debug("constant == True")
        # a size of 0 is falsy as well, so it is re-detected too
        if not prefix_size or not secret_size:
            prefix_size, secret_size = find_prefix_suffix_size(encryption_oracle, block_size)

        """Start decrypt"""
        secret = bytes(b'')
        # filler that pads the prefix up to a block boundary (empty if already aligned)
        aligned_bytes = random_bytes(1) * (block_size - (prefix_size % block_size))
        if len(aligned_bytes) == block_size:
            aligned_bytes = bytes(b'')

        # filler that makes aligned_bytes_suffix + secret_size a multiple of block_size
        aligned_bytes_suffix = random_bytes(1) * (block_size - (secret_size % block_size))
        if len(aligned_bytes_suffix) == block_size:
            aligned_bytes_suffix = bytes(b'')

        # Index (from the end) of the ciphertext block that starts with the unknown byte.
        # It moves one block left each time the unknown byte plus the known tail fill a
        # whole block. This assumes the oracle always appends at least one padding
        # byte, like PKCS#7 -- TODO confirm.
        block_to_find_position = -1
        # index of the first block fully controlled by our input
        controlled_block_position = (prefix_size+len(aligned_bytes)) // block_size

        while len(secret) < secret_size:
            if (len(secret)+1) % block_size == 0:
                block_to_find_position -= 1
            # 1 + len(secret) extra filler bytes shift the real secret so that its
            # last 1 + len(secret) bytes start exactly on a block boundary
            payload = aligned_bytes + aligned_bytes_suffix + random_bytes(1) + secret
            enc_chunks = chunks(encryption_oracle(payload), block_size)
            block_to_find = enc_chunks[block_to_find_position]

            log.debug("To guess at position {}:".format(block_to_find_position))
            log.debug("Plain: " + print_chunks(chunks(bytes(b'P'*prefix_size) + payload + bytes(b'S'*secret_size), block_size)))
            log.debug("Encry: " + print_chunks(enc_chunks)+"\n")

            for guessed_char in range(256):
                guessed_char = bytes([guessed_char])
                # candidate block: guess + known tail, padded, right after the aligned prefix
                payload = aligned_bytes + add_padding(guessed_char + secret, block_size)
                enc_chunks = chunks(encryption_oracle(payload), block_size)

                log.debug("Plain: " + print_chunks(chunks(bytes(b'P'*prefix_size) + payload + bytes(b'S'*secret_size), block_size)))
                log.debug("Encry: " + print_chunks(enc_chunks)+"\n")
                if block_to_find == enc_chunks[controlled_block_position]:
                    # secret grows to the left
                    secret = guessed_char + secret
                    log.debug("Found char, secret={}".format(repr(secret)))
                    break
            else:
                log.critical_error("Char not found, try change alphabet. Secret so far: {}".format(repr(secret)))
        log.success("Secret(hex): {}".format(b2h(secret)))
        return secret
    else:
        # non-constant prefix is not implemented: implicitly returns None
        log.debug("constant == False")
Exemplo n.º 20
0
def decrypt(ciphertext,
            padding_oracle=None,
            decryption_oracle=None,
            iv=None,
            block_size=16,
            is_correct=True,
            amount=0,
            known_plaintext=None,
            async_calls=False):
    """Decrypt CBC ciphertext
    Give padding_oracle or decryption_oracle (or both). If decryption_oracle is
    given it is used directly and padding_oracle is not called for decryption.

    Args:
        ciphertext(string): to decrypt, length must be a multiple of block_size
        padding_oracle(function/None): called as padding_oracle(payload=..., iv=...);
            truthy result means the payload decrypts with correct padding
        decryption_oracle(function/None): decrypts one block without the CBC xor
        iv(string): if not specified, first block of ciphertext is treated as iv
        block_size(int): must be a multiple of 8
        is_correct(bool): set if ciphertext will decrypt to something with correct padding
        amount(int): how much blocks decrypt (counting from last), zero (default) means all
        known_plaintext(string): with padding, from end (aligned to end of ciphertext)
        async_calls(bool): make asynchronous calls to oracle (not implemented yet; ignored)

    Returns:
        plaintext(string): with padding
    """
    _check_oracles(padding_oracle=padding_oracle,
                   decryption_oracle=decryption_oracle,
                   block_size=block_size)

    if block_size % 8 != 0:
        log.critical_error("Incorrect block size: {}".format(block_size))

    if len(ciphertext) % block_size != 0:
        log.critical_error("Incorrect ciphertext length: {}".format(
            len(ciphertext)))

    # Direct path: plaintext block i = D(c_i) xor c_{i-1}, from the last block backwards.
    if decryption_oracle:
        if iv:
            ciphertext = iv + ciphertext
        blocks = chunks(ciphertext, block_size)
        plaintext = bytes(b'')
        for position in range(len(blocks) - 1, 0, -1):
            plaintext = xor(decryption_oracle(blocks[position]),
                            blocks[position - 1]) + plaintext
            log.info("Plaintext(hex): {}".format(b2h(plaintext)))
            # stop after `amount` blocks (amount == 0 means all)
            if amount != 0 and len(plaintext) == amount * block_size:
                break
        log.success("Decrypted(hex): {}".format(b2h(plaintext)))
        return plaintext

    log.info("Start cbc padding oracle")
    log.debug(print_chunks(chunks(ciphertext, block_size)))

    # prepare blocks
    blocks = chunks(ciphertext, block_size)
    if iv:
        if len(iv) % block_size != 0:
            log.critical_error("Incorrect iv length: {}".format(len(iv)))
        log.info("Set iv")
        blocks.insert(0, iv)

    # From here on `amount` is the exclusive lower bound of the block index
    # loop below: blocks with index > amount are decrypted (0 -> all but iv).
    if amount != 0:
        amount = len(blocks) - amount - 1
    if amount < 0 or amount >= len(blocks):
        log.critical_error(
            "Incorrect amount of blocks to decrypt: {} (have to be in [0,{}]".
            format(amount,
                   len(blocks) - 1))
    log.info("Will decrypt {} block(s)".format(len(blocks) - 1 - amount))

    # add known plaintext
    plaintext = bytes(b'')
    position_known = 0  # bytes already known at the end of the first block to decrypt
    chars_decoded = 0
    if known_plaintext:
        # the padding is part of the known plaintext, so the is_correct
        # special handling below is not needed
        is_correct = False
        plaintext = known_plaintext
        blocks_decoded = len(plaintext) // block_size
        chars_decoded = len(plaintext) % block_size

        if blocks_decoded == len(blocks) - 1:
            log.debug("Nothing decrypted, known plaintext long enough")
            return plaintext
        if blocks_decoded > len(blocks) - 1:
            log.critical_error(
                "Too long known plaintext ({} blocks)".format(blocks_decoded))

        # drop the trailing blocks whose plaintext is fully known
        if blocks_decoded != 0:
            blocks = blocks[:-blocks_decoded]

        position_known = chars_decoded
        log.info("Have known plaintext, skip {} block(s) and {} bytes".format(
            blocks_decoded, chars_decoded))

    # start decryption
    for count_block in range(len(blocks) - 1, amount, -1):
        """ Blocks from the last to the second (all except iv) """
        log.info("Block no. {}".format(count_block))

        # sent payload = prefix blocks || block we tamper with || block being decrypted
        payload_prefix = bytes(b''.join(blocks[:count_block - 1]))
        payload_modify = blocks[count_block - 1]
        payload_decrypt = blocks[count_block]

        if chars_decoded != 0:
            # we know some chars, so modify previous block
            # (so that the known trailing bytes decrypt to padding value chars_decoded + 1)
            payload_modify = payload_modify[:-chars_decoded] +\
                             xor(plaintext[:chars_decoded], payload_modify[-chars_decoded:], bytes([chars_decoded + 1]))
            chars_decoded = 0

        position = block_size - 1 - position_known
        position_known = 0
        while position >= 0:
            """ Every position in block, from the end """
            log.debug("Position: {}".format(position))

            found_correct_char = False
            for guess_char in range(256):
                modified = payload_modify[:position] + bytes(
                    [guess_char]) + payload_modify[position + 1:]
                payload = bytes(b''.join(
                    [payload_prefix, modified, payload_decrypt]))

                # the oracle takes the first block separately as iv
                # (this rebinds the `iv` parameter, which is no longer needed)
                iv = payload[:block_size]
                payload = payload[block_size:]
                log.debug(print_chunks(chunks(iv + payload, block_size)))

                correct = padding_oracle(payload=payload, iv=iv)
                if correct:
                    """ oracle returns True """
                    padding = block_size - position  # sent ciphertext decoded to that padding
                    # original plaintext byte = original ciphertext byte ^ guess ^ padding
                    decrypted_char = bytes(
                        [payload_modify[position] ^ guess_char ^ padding])

                    if is_correct:
                        """ If we didn't send original ciphertext, then we have found original padding value.
                            Otherwise keep searching and if won't find any other correct char - padding is \x01
                        """
                        # leaving the original last byte reproduces the original,
                        # validly padded ciphertext, which proves nothing
                        if guess_char == blocks[-2][-1]:
                            log.debug(
                                "Skip this guess char ({})".format(guess_char))
                            continue

                        # dc = original padding length; the last dc plaintext bytes all equal dc
                        dc = int(decrypted_char[0])
                        log.info(
                            "Found padding value for correct ciphertext: {}".
                            format(dc))
                        if dc == 0 or dc > block_size:
                            log.critical_error(
                                "Found bad padding value (given ciphertext may not be correct)"
                            )

                        plaintext = decrypted_char * dc
                        # make the dc known bytes decrypt to dc + 1 and skip past them
                        payload_modify = payload_modify[:-dc] + xor(
                            payload_modify[-dc:], decrypted_char,
                            bytes([dc + 1]))
                        position = position - dc + 1
                        is_correct = False
                    else:
                        """ abcd efgh ijkl o|guess_char|xy || 1234 5678 9tre qwer - ciphertext
                            what ever itma ybex            || xyzw rtua lopo k|\x03|\x03\x03 - plaintext
                            abcd efgh ijkl |guess_char|wxy || 1234 5678 9tre qwer - next round ciphertext
                            some thin gels eheh            || xyzw rtua lopo guessing|\x04\x04\x04 - next round plaintext
                        """
                        if position == block_size - 1:
                            """ if we decrypt first byte, check if we didn't hit other padding than \x01 """
                            # change the byte before the guessed one (second to last
                            # byte of the tampered block) and re-query the oracle
                            payload = iv + payload
                            payload = payload[:-block_size - 2] + bytes(
                                b'A') + payload[-block_size - 1:]
                            iv = payload[:block_size]
                            payload = payload[block_size:]
                            correct = padding_oracle(payload=payload, iv=iv)
                            if not correct:
                                log.debug("Hit false positive, guess char({})".
                                          format(guess_char))
                                continue

                        # prepare for the next position: bytes from `position` on
                        # must now decrypt to padding + 1
                        payload_modify = payload_modify[:position] + xor(
                            bytes([guess_char]) +
                            payload_modify[position + 1:], bytes([padding]),
                            bytes([padding + 1]))
                        plaintext = decrypted_char + plaintext

                    found_correct_char = True
                    log.debug(
                        "Guessed char(\\x{:02x}), decrypted char(\\x{:02x})".
                        format(guess_char, decrypted_char[0]))
                    log.debug("Plaintext: {}".format(plaintext))
                    log.info("Plaintext(hex): {}".format(b2h(plaintext)))
                    break
            position -= 1
            if found_correct_char is False:
                if is_correct:
                    # only the original byte gave valid padding: original padding is \x01
                    padding = 0x01
                    payload_modify = payload_modify[:position + 1] + xor(
                        payload_modify[position + 1:], bytes([padding]),
                        bytes([padding + 1]))
                    plaintext = bytes(b"\x01")
                    is_correct = False
                else:
                    log.critical_error(
                        "Can't find correct padding (oracle function return False 256 times)"
                    )
    log.success("Decrypted(hex): {}".format(b2h(plaintext)))
    return plaintext
Exemplo n.º 21
0
def bleichenbacher_signature_forgery(key,
                                     garbage='suffix',
                                     hash_function='sha1'):
    """Bleichenbacher's signature forgery based on bug in verify implementation

    Forges signatures for a key with a small public exponent e, targeting
    verifiers that accept extra "garbage" bytes in the padded block (either
    after the hash or between the 00 01 ff header and the 00 separator).
    No private key is used: the signature is built from e-th roots of a
    crafted block and checked with s**e mod n.

    Args:
        key(RSAKey): with small e and at least one plaintext
        garbage(string): middle: 00 01 ff garbage 00 ASN.1 HASH
                         suffix: 00 01 ff 00 ASN.1 HASH garbage
        hash_function(string): md5, sha1, sha256, sha384 or sha512

    Returns:
        dict: forged signatures, signatures[no] == signature(key.texts[no]['plain'])
        update key texts: key.texts[no]['cipher'] is set to every forged signature.
        Only texts that have 'plain' and no 'cipher' are processed.
    """
    # DER-encoded DigestInfo prefixes (AlgorithmIdentifier + OCTET STRING header)
    hash_asn1 = {
        'md5':
        bytes(
            b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'
        ),
        'sha1':
        bytes(b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
        'sha256':
        bytes(
            b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'
        ),
        'sha384':
        bytes(
            b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'
        ),
        'sha512':
        bytes(
            b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'
        )
    }
    if garbage not in ['suffix', 'middle']:
        log.critical_error("Bad garbage position, must be suffix or middle")
    if hash_function not in hash_asn1:
        log.critical_error(
            "Hash function {} not implemented".format(hash_function))

    if key.e > 3:
        log.debug("May not work, because e > 3")

    signatures = {}
    if garbage == 'suffix':
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[
                    text_no]:
                log.info("Forge for plaintext no {} ({})".format(
                    text_no, key.texts[text_no]['plain']))

                digest = hashlib.new(hash_function,
                                     i2b(key.texts[text_no]['plain'])).digest()
                plaintext_prefix = bytes(b'\x00\x01\xff\x00') + hash_asn1[
                    hash_function] + digest

                # Pad with zero bytes up to the modulus length. Then try small
                # corrections around the integer e-th root until s**e mod n
                # starts with the wanted prefix (the tail becomes garbage).
                plaintext = plaintext_prefix + bytes(
                    b'\x00' * (key.size // 8 - len(plaintext_prefix)))
                plaintext = b2i(plaintext)
                for round_error in range(-5, 5):
                    signature, _ = gmpy2.iroot(plaintext, key.e)
                    signature = int(signature + round_error)
                    test_prefix = i2b(gmpy2.powmod(signature, key.e, key.n),
                                      size=key.size)[:len(plaintext_prefix)]
                    if test_prefix == plaintext_prefix:
                        log.info("Got signature: {}".format(signature))
                        log.debug("signature**e % n == {}".format(
                            i2h(gmpy2.powmod(signature, key.e, key.n),
                                size=key.size)))
                        key.texts[text_no]['cipher'] = signature
                        signatures[text_no] = signature
                        break
                else:
                    log.error(
                        "Something wrong, can't compute correct signature")
        return signatures

    elif garbage == 'middle':
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[
                    text_no]:
                log.info("Forge for plaintext no {} ({})".format(
                    text_no, key.texts[text_no]['plain']))
                digest = hashlib.new(hash_function,
                                     i2b(key.texts[text_no]['plain'])).digest()
                plaintext_suffix = bytes(
                    b'\x00') + hash_asn1[hash_function] + digest
                target_suffix = b2i(plaintext_suffix)
                if target_suffix & 1 != 1:
                    log.error(
                        "Plaintext suffix is even, can't compute signature")
                    continue

                # Compute the low bytes of the signature. Build s bit by bit so
                # that s**e == plaintext_suffix (mod 2**(8*len)). This works
                # because, for odd s and odd e, setting bit b of s flips bit b
                # of s**e and leaves the lower bits unchanged. The real
                # exponent key.e is used, not a hard-coded 3. Working modulo
                # 2**bits keeps this cheap for any e.
                suffix_bits = len(plaintext_suffix) * 8
                suffix_modulus = 1 << suffix_bits
                signature_suffix = 0b1
                for b in range(suffix_bits):
                    bit = 1 << b
                    low_power = gmpy2.powmod(signature_suffix, key.e,
                                             suffix_modulus)
                    if low_power & bit != target_suffix & bit:
                        signature_suffix |= bit
                # Use a fixed width so leading zero bytes are kept and the
                # suffix exactly replaces the low bytes of the prefix below.
                # This assumes big-endian b2i/i2b, as the tail slicing
                # here implies.
                signature_suffix = int(signature_suffix).to_bytes(
                    len(plaintext_suffix), 'big')

                # Compute the high bytes of the signature: the e-th root of a
                # random block starting with 00 01 ff. Retry until the
                # resulting s**e has no zero byte in the garbage area.
                while True:
                    plaintext_prefix = bytes(b'\x00\x01\xff') + random_bytes(
                        key.size // 8 - 3)
                    signature_prefix, _ = gmpy2.iroot(b2i(plaintext_prefix),
                                                      key.e)
                    signature_prefix = i2b(
                        int(signature_prefix),
                        size=key.size)[:-len(signature_suffix)]

                    signature = b2i(signature_prefix + signature_suffix)
                    test_plaintext = i2b(gmpy2.powmod(signature, key.e, key.n),
                                         size=key.size)
                    if bytes(b'\x00'
                             ) not in test_plaintext[2:-len(plaintext_suffix)]:
                        if test_plaintext[:3] == plaintext_prefix[:3] and test_plaintext[
                                -len(plaintext_suffix):] == plaintext_suffix:
                            log.info("Got signature: {}".format(signature))
                            key.texts[text_no]['cipher'] = signature
                            signatures[text_no] = signature
                            break
                        else:
                            log.error("Something wrong, signature={},"
                                      " signature**{}%{} is {}".format(
                                          signature, key.e, key.n,
                                          [(test_plaintext)]))
                            break
        return signatures
Exemplo n.º 22
0
def blinding(key, signing_oracle=None, decryption_oracle=None):
    """Perform signature/ciphertext blinding attack

    Each target value is multiplied by blind_enc = key.encrypt(blind) (assumed
    to be blind**e mod n; confirm against RSAKey). The product is sent to the
    oracle. The oracle's answer is multiplied by blind**-1 mod n to strip the
    blinding factor.

    Args:
        key(RSAKey): with at least one plaintext(to sign) or ciphertext(to decrypt)
        signing_oracle(callable): takes a blinded plaintext (int), returns its
            signature; a falsy return is treated as an error
        decryption_oracle(callable): takes a blinded ciphertext (int), returns
            its plaintext; a falsy return is treated as an error

    Returns:
        dict: {index: signature/plaintext, index2: signature/plaintext}
        update key texts (sets 'cipher' for signed texts, 'plain' for
        decrypted texts)
    """
    if not signing_oracle and not decryption_oracle:
        log.critical_error("Give one of signing_oracle or decryption_oracle")
    if signing_oracle and decryption_oracle:
        log.critical_error(
            "Give only one of signing_oracle or decryption_oracle")

    recovered = {}
    if signing_oracle:
        log.debug("Have signing_oracle")
        for text_no in range(len(key.texts)):
            if 'plain' in key.texts[text_no] and 'cipher' not in key.texts[
                    text_no]:
                log.info("Blinding signature of plaintext no {} ({})".format(
                    text_no, i2h(key.texts[text_no]['plain'])))

                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                blinded_plaintext = (key.texts[text_no]['plain'] *
                                     blind_enc) % key.n
                blinded_signature = signing_oracle(blinded_plaintext)
                if not blinded_signature:
                    log.critical_error(
                        "Error during call to signing_oracle({})".format(
                            blinded_plaintext))
                # remove the blinding factor: (m * r^e)^d * r^-1 == m^d
                signature = (invmod(blind, key.n) * blinded_signature) % key.n
                key.texts[text_no]['cipher'] = signature
                recovered[text_no] = signature
                log.success("Signature: {}".format(signature))

    if decryption_oracle:
        log.debug("Have decryption_oracle")
        for text_no in range(len(key.texts)):
            if 'cipher' in key.texts[text_no] and 'plain' not in key.texts[
                    text_no]:
                log.info("Blinding ciphertext no {} ({})".format(
                    text_no, key.texts[text_no]['cipher']))
                blind = random.randint(2, 100)
                blind_enc = key.encrypt(blind)
                blinded_ciphertext = (key.texts[text_no]['cipher'] *
                                      blind_enc) % key.n
                blinded_plaintext = decryption_oracle(blinded_ciphertext)
                if not blinded_plaintext:
                    # report the oracle input, not its (falsy) result
                    log.critical_error(
                        "Error during call to decryption_oracle({})".format(
                            blinded_ciphertext))
                # remove the blinding factor: (c * r^e)^d * r^-1 == m
                plaintext = (invmod(blind, key.n) * blinded_plaintext) % key.n
                key.texts[text_no]['plain'] = plaintext
                recovered[text_no] = plaintext
                log.success("Plaintext: {}".format(plaintext))

    return recovered
Exemplo n.º 23
0
def hastad(keys, ciphertexts=None):
    """Hastad's broadcast attack (small public exponent)
    Given at least e keys with public exponent equals to e and ciphertexts of the same plaintext,
    plaintext can be efficiently recovered

    The ciphertexts are combined with the CRT into m**e modulo the product of
    the moduli. If that number is an exact e-th power, its integer root is
    the plaintext.

    Args:
        keys(list): RSAKeys, all with same public exponent e, len(keys) >= e,
                    every key with only one ciphertext
        ciphertexts(list/None): if not None, use this ciphertexts

    Returns:
        NoneType/int: None on failure, recovered plaintext otherwise
        update keys texts (texts[0]['plain'] of every used key)
    """
    e = keys[0].e
    if len(keys) < e:
        log.critical_error("Not enough keys, e={}".format(e))

    if ciphertexts is None:
        for key in keys:
            if len(key.texts) > 1:
                log.info(
                    "Key have more than one ciphertext, using the first one(key=={})"
                    .format(key.identifier))
            # an empty texts list would otherwise raise IndexError here
            if not key.texts or 'cipher' not in key.texts[0]:
                log.critical_error("key {} doesn't have ciphertext".format(
                    key.identifier))

        # prepare ciphertexts and correct_keys lists
        ciphertexts, modules, correct_keys = [], [], []
        for key in keys:
            # get only first ciphertext (if exists); skip duplicated moduli
            # and duplicated ciphertexts, they add no information to the CRT
            if key.n not in modules and key.texts[0][
                    'cipher'] not in ciphertexts:
                if key.e == e:
                    modules.append(key.n)
                    correct_keys.append(key)
                    ciphertexts.append(key.texts[0]['cipher'])
                else:
                    log.info("Key {} have different e(={})".format(
                        key.identifier, key.e))
    else:
        if len(ciphertexts) != len(keys):
            log.critical_error("len(ciphertexts) != len(keys)")
        modules = [key.n for key in keys]
        correct_keys = keys

    # check if we have enough ciphertexts
    if len(modules) < e:
        log.info(
            "Not enough keys with unique modulus and ciphertext, e={}, len(modules)={}"
            .format(e, len(modules)))
        log.info("Checking for simple roots (small_e_msg)")
        for one_key in correct_keys:
            recovered_plaintexts = small_e_msg(one_key)
            if len(recovered_plaintexts) > 0:
                log.success("Found plaintext: {}".format(
                    recovered_plaintexts[0]))
                return recovered_plaintexts[0]

    if len(modules) > e:
        log.debug("Number of modules/ciphertexts larger than e")
        modules = modules[:e]
        ciphertexts = ciphertexts[:e]

    # actual Hastad
    result = crt(ciphertexts, modules)
    plaintext, correct = gmpy2.iroot(result, e)
    if correct:
        plaintext = int(plaintext)
        log.success("Found plaintext: {}".format(plaintext))
        for one_key in correct_keys:
            one_key.texts[0]['plain'] = plaintext
        return plaintext
    else:
        log.debug("Plaintext wasn't {}-th root".format(e))
        log.debug("result (from crt) = {}".format(result))
        log.debug("plaintext ({}-th root of result) = {}".format(e, plaintext))
        return None