def recover_key_from_nonce(msg, sig, k, q, hash=sha1):
    """Recover a DSA private key (`x`) from one signed message whose nonce is known.

    A DSA signature satisfies ``s = k^-1 * (H(m) + x * r) mod q``. Once the
    nonce ``k`` is known, this is linear in ``x``:
    ``x = (s * k - H(m)) * r^-1 mod q``.

    Example:
    ```python
    >>> from bop.crypto_constructor import dsa
    >>> my_dsa = dsa()
    >>> msg = b"Hello kind stranger!"
    >>> sig = my_dsa.sign(msg)
    >>> leaked_nonce = sig._k
    >>> my_dsa.x == recover_key_from_nonce(msg, sig, leaked_nonce, my_dsa.q, hash=my_dsa.hash)
    True
    ```

    Arguments:
        msg {bytes or int} -- The signed message (non-bytes values are converted with `i2b`)
        sig {tuple or Signature} -- The signature, unpacked as `(r, s)`
        k {int} -- The leaked nonce
        q {int} -- The public parameter q

    Keyword Arguments:
        hash {callable} -- The hash function used for signing (default: {sha1})

    Returns:
        int -- The recovered private key `x`
    """
    data = msg if type(msg) == bytes else i2b(msg)
    digest = b2i(hash(data))
    r, s = sig
    return (s * k - digest) * invmod(r, q) % q
def _dsa_sign(msg, p, q, g, x, k, hash=sha1):
    """Compute a DSA signature `(r, s)` for `msg` with a fixed nonce `k`.

    This is a deterministic helper: the caller supplies the private key `x`
    and the nonce `k`. It is used, for example, to check whether a guessed
    key/nonce pair reproduces an observed signature.

    The default hash is `sha1`, matching the other DSA helpers in this module.
    The previous default was the builtin `hash()`, which returns a Python int
    instead of a digest.

    Arguments:
        msg {bytes or int} -- The message (non-bytes values are converted with `i2b`)
        p, q, g {int} -- The DSA domain parameters
        x {int} -- The private key
        k {int} -- The nonce

    Keyword Arguments:
        hash {callable} -- The hash function to use (default: {sha1})

    Returns:
        (int, int) -- The signature `(r, s)`
    """
    if type(msg) != bytes:
        msg = i2b(msg)
    h = b2i(hash(msg))
    r = pow(g, k, p) % q
    s = (invmod(k, q) * (h + r * x)) % q
    return r, s
def decrypt(self, cipher):
    """Decrypt `cipher` with the private exponent `self.d` (textbook RSA).

    An `int` is decrypted as-is and an `int` is returned. Any other input is
    converted with `b2i` first, and the result is converted back with `i2b`.
    """
    if type(cipher) == int:
        return pow(cipher, self.d, self.n)
    return i2b(pow(b2i(cipher), self.d, self.n))
def encrypt(self, plaintext):
    """Encrypt `plaintext` with the public exponent `self.e` (textbook RSA).

    An `int` is encrypted as-is and an `int` is returned. Any other input is
    converted with `b2i` first, and the result is converted back with `i2b`.
    """
    if type(plaintext) == int:
        return pow(plaintext, self.e, self.n)
    return i2b(pow(b2i(plaintext), self.e, self.n))
def verify(self, msg, sig):
    """Verify a DSA signature `(r, s)` over `msg` against the public key `self.y`.

    Arguments:
        msg {bytes or int} -- The message (non-bytes values are converted with `i2b`)
        sig {tuple or Signature} -- The signature, unpacked as `(r, s)`

    Returns:
        bool -- True if the signature is valid. False otherwise, including
        when `r` or `s` lies outside the open interval (0, q).
    """
    r, s = sig
    if r <= 0 or r >= self.q or s <= 0 or s >= self.q:
        return False
    if type(msg) != bytes:
        msg = i2b(msg)
    h = b2i(self.hash(msg))
    w = invmod(s, self.q)
    u1 = (w * h) % self.q
    u2 = (w * r) % self.q
    # The product must be reduced modulo p *before* it is reduced modulo q:
    # v = ((g^u1 * y^u2) mod p) mod q. This mirrors how `r` is produced when
    # signing (`pow(g, k, p) % q`). Skipping the `% p` gives a different
    # value and rejects valid signatures.
    v = (pow(self.g, u1, self.p) * pow(self.y, u2, self.p)) % self.p % self.q
    return v == r
def sign(self, msg):
    """Sign `msg` with the private key `self.x` (DSA).

    The nonce `k` comes from `self.generate_nonce()` when that attribute is
    set. Otherwise it is drawn uniformly from [2, q - 1] with `secrets`. A
    nonce that gives `r == 0` or `s == 0` is discarded and a new one is drawn.

    Arguments:
        msg {bytes or int} -- The message to sign (non-bytes values are converted with `i2b`)

    Returns:
        Signature -- `self.Signature(r, s, k)`; the nonce is passed along as the third argument.
    """
    if type(msg) != bytes:
        msg = i2b(msg)
    digest = b2i(self.hash(msg))
    while True:
        r = 0
        while not r:
            if self.generate_nonce is None:
                k = secrets.randbelow(self.q - 2) + 2
            else:
                k = self.generate_nonce()
            r = pow(self.g, k, self.p) % self.q
        s = (invmod(k, self.q) * (digest + r * self.x)) % self.q
        if s != 0:
            return self.Signature(r, s, k)
def bleichenbacher_forge_signature(message, key_size, hash=sha1, protocol=b"DUMMY"):
    r"""Forge an RSA (e=3) signature for the given message.

    This forgery works when signature checking is implemented incorrectly.
    The example applies a dummy PKCS1.5-style padding (see argument `protocol`).
    Correct padding right-aligns the content (the protocol information plus
    the hash of the message). An incorrect implementation may not check that
    the content is actually right-aligned, so trailing garbage is accepted.

    The forged value is the cube root (rounded up) of
    `00 01 FF..FF 00 <protocol> <hash> 00..00`. Cubing the rounded-up root
    adds less than about 4 * c^(2/3) of garbage below the payload. The payload
    is therefore placed on the first byte boundary above that bound.

    You need a minimum key size for this to work; 1024 bits is sufficient.

    Example:
    ```python3
    >>> import re
    >>> from bop.hashing import sha1
    >>> from bop.crypto_constructor import rsa
    >>> def check(message, sig, rsa):
    ...     c = rsa.encrypt(sig)
    ...     # note that we are ignoring trailing bytes
    ...     # also note that the leading zero byte gets trimmed
    ...     h = re.match(br"^\x01(\xff)*?\x00DUMMY(.{20})", c, re.DOTALL).group(2)
    ...     return h == sha1(message)
    ...
    >>> some_rsa = rsa(e=3)
    >>> my_msg = b"My very valid message"
    >>> my_sig = bleichenbacher_forge_signature(my_msg, some_rsa.key_size)
    >>> check(my_msg, my_sig, some_rsa)
    True
    ```

    Arguments:
        message {bytes} -- The message to sign
        key_size {int} -- The size of the public key parameter N in bits

    Keyword Arguments:
        hash {callable} -- The hash algorithm to use (default: {sha1})
        protocol {bytes} -- The bytes that specify metadata. In reality this
            would be some kind of ASN.1 structure, but for this toy example it
            does not matter. (default: {b"DUMMY"})

    Raises:
        ValueError: If `key_size` is too small to hold the payload above the
            cube-root error.

    Returns:
        bytes -- The forged signature of the message
    """
    # We fake this. In reality some effort is needed to encode the hash
    # algorithm used, the size of the hash, etc.
    h = hash(message)
    payload = b"\x00" + protocol + h
    payload_len = len(payload) * 8
    # c < 2^(key_size - 15). Cubing the rounded-up cube root of c overshoots
    # by less than 4 * c^(2/3), i.e. less than 2^(2 * (key_size - 15) / 3 + 2).
    # Put the payload on the first byte boundary above that bound so the
    # garbage can never carry into the hash.
    error_bits = (2 * (key_size - 15)) // 3 + 3
    position = (error_bits + 7) // 8 * 8
    # The payload must still fit below the leading "00 01" bytes.
    if position + payload_len > key_size - 16:
        raise ValueError("key_size is too small to forge a signature for this payload")
    n = (1 << payload_len) - b2i(payload)
    # 2^(key_size - 15) - 2^(payload_len + position) yields "00 01 FF..FF";
    # adding payload * 2^position places the payload right below the FF run.
    c = (1 << (key_size - 15)) - n * (1 << position)
    roots = cubic_root2(c)
    root = roots[0] if len(roots) == 1 else roots[-1]
    # The root must be rounded *up*: a cube that falls short of c would borrow
    # from the payload and change the hash.
    if root ** 3 < c:
        root += 1
    return i2b(root)
def decrypt_pkcs_padding_leak(oracle, msg, e, n):
    """Perform an adaptive chosen-ciphertext attack against a weak RSA
    implementation that leaks PKCS1v5 padding information.

    This attack was discovered by Daniel Bleichenbacher and is well described
    in the original
    [paper](http://archiv.infsec.ethz.ch/education/fs08/secsem/bleichenbacher98.pdf).

    Side note: the attack is generally faster if `msg` is already properly
    padded. Depending on the key size, the expected running time ranges from
    a few seconds (key size <= 512 bits) to around one minute for a 1024-bit
    public modulus. A test with a 2048-bit key ran for about an hour.

    Example:
    ```python3
    >>> from bop.oracles.padding import PaddingRSAOracle as Oracle
    >>> import bop.utils as utils
    >>> # We will choose smaller params to keep the running time down a little bit
    >>> p, q = utils.gen_rsa_key_params(256)
    >>> o = Oracle(p, q)
    >>> e, n = o.public_key()
    >>> plain = b"Hey! This message is safe!"
    >>> # Optional: Pad the message:
    >>> # plain = utils.pad_pkcs1v5(plain, n)
    >>> msg = o.encrypt(plain)
    >>> decrypt_pkcs_padding_leak(o, msg, e, n)
    b'Hey! This message is safe!'
    ```

    Args:
        oracle (callable): An oracle that reports (truthy/falsy) whether the
            decrypted ciphertext has a valid PKCS1v5 padding, i.e. the first
            byte is 0 and the second byte is 2
        msg (bytes or int): The message to decrypt
        e (int): The public exponent
        n (int): The public modulus

    Raises:
        ValueError: If the modulus is too small for the attack.
        RuntimeError: If no candidate interval is left, which means the
            oracle answered inconsistently.

    Returns:
        bytes or int: The decrypted message. The output type matches the input type.
    """
    was_bytes = False
    if type(msg) != int:
        msg = b2i(msg)
        was_bytes = True

    k = bit_length_exp2(n)
    if k <= 16:
        raise ValueError("The modulus is too small for this attack")
    B = 1 << (k - 16)

    # Step 1: blinding. Make sure we start from a PKCS-conforming ciphertext.
    s0 = 1
    c0 = msg
    if not oracle(msg):
        while True:
            s0 = secrets.randbelow(n - 1) + 1
            c0 = (msg * pow(s0, e, n)) % n
            if oracle(c0):
                break

    def conforming(s):
        # Is m0 * s (mod n) PKCS conforming? Queried via c0 * s^e.
        return oracle((c0 * pow(s, e, n)) % n)

    # M is a sorted list of disjoint closed intervals [a, b] containing m0.
    M = [(2 * B, 3 * B - 1)]
    s = None
    while True:
        if s is None:
            # Step 2a: the smallest conforming s starting at about n / (3B)
            s = n // (3 * B)
            while not conforming(s):
                s += 1
        elif len(M) > 1:
            # Step 2b: several intervals left -> linear search for the next s
            s += 1
            while not conforming(s):
                s += 1
        else:
            # Step 2c: a single interval left -> search over r and s
            a, b = M[0]
            r = 2 * (b * s - 2 * B) // n
            found = None
            while found is None:
                s_low = (2 * B + r * n) // b
                s_high = (3 * B + r * n) // a
                found = next(
                    (t for t in range(s_low, s_high + 1) if conforming(t)), None
                )
                r += 1
            s = found

        # Step 3: narrow every interval using the new s
        candidates = []
        for a, b in M:
            r_low = (a * s - 3 * B + 1) // n
            r_high = (b * s - 2 * B) // n
            for r in range(r_low, r_high + 1):
                # note the + s - 1 in order to round up to the next integer
                low = max(a, (2 * B + r * n + s - 1) // s)
                high = min(b, (3 * B - 1 + r * n) // s)
                if low <= high:
                    candidates.append((low, high))

        # The paper takes the union of the new intervals; merge overlaps so
        # overlapping intervals do not force Step 2b over and over.
        merged = []
        for low, high in sorted(candidates):
            if merged and low <= merged[-1][1]:
                merged[-1] = (merged[-1][0], max(merged[-1][1], high))
            else:
                merged.append((low, high))
        M = merged
        if not M:
            raise RuntimeError("No candidate interval left; the oracle seems to be inconsistent")

        # Step 4: a single one-element interval pins down the blinded plaintext
        if len(M) == 1 and M[0][0] == M[0][1]:
            plain = (M[0][0] * invmod(s0, n)) % n
            if was_bytes:
                return i2b(plain)
            return plain
def decrypt_parity_leak(oracle, msg, e, n):
    """Perform an RSA parity attack given an oracle that reports whether a
    decrypted plaintext is even or odd.

    Multiplying the ciphertext by `2^e` doubles the hidden plaintext modulo
    `n`. Because `n` is odd, the doubled value is odd exactly when it wrapped
    around the modulus. The parity answers therefore spell out the binary
    digits of `plain / n`. After `n.bit_length()` queries, the remaining
    interval is narrower than 1 and the plaintext is determined exactly,
    using integer arithmetic only.

    Example:
    ```python
    >>> from bop.oracles.parity import RsaParityOracle as Oracle
    >>> o = Oracle()
    >>> msg = o.encrypt(b"Hello Bob!")
    >>> e, n = o.public_key()
    >>> decrypt_parity_leak(o, msg, e, n)
    b'Hello Bob!'
    ```

    Args:
        oracle (callable): Decrypts a ciphertext (an int in [0, n)) that was
            encrypted with the given public key. Returns truthy if the
            plaintext is even and falsy if it is odd.
        msg (bytes or int): The encrypted message to decrypt
        e (int): The public key exponent e
        n (int): The public key modulus n (odd, as for any RSA modulus)

    Raises:
        RuntimeError: If the recovered value does not encrypt to `msg`, which
            means the oracle answered inconsistently.

    Returns:
        bytes or int: The decrypted message. The output type matches the type of `msg`.
    """
    was_bytes = type(msg) != int
    if was_bytes:
        msg = b2i(msg)

    target = msg % n
    double = pow(2, e, n)  # multiplies the hidden plaintext by 2
    k = n.bit_length()  # 2^k > n, so k digits pin down the plaintext

    cipher = target
    digits = 0
    for _ in range(k):
        cipher = (cipher * double) % n
        digits <<= 1
        if not oracle(cipher):
            # odd: the doubled plaintext wrapped around n -> next digit is 1
            digits |= 1

    # plain / n lies in [digits / 2^k, (digits + 1) / 2^k). That interval has
    # width n / 2^k < 1, so plain is the smallest integer >= digits * n / 2^k.
    plain = (digits * n + (1 << k) - 1) >> k

    if pow(plain, e, n) != target:
        raise RuntimeError("Could not find plain text. Something went wrong :(")
    if was_bytes:
        return i2b(plain)
    return plain
def recover_key_from_duplicate_nonce(gen, public_parameters, hash=sha1):
    """Attempt to recover a DSA private key (`x`) from a stream of signed
    messages by searching for a reused nonce.

    A reused nonce `k` always yields the same `r = (g^k mod p) mod q`.
    Signatures are therefore indexed by `r`, and only pairs sharing an `r`
    are examined, which takes a single pass over the stream. A candidate key
    is accepted only if it matches the public key (`g^x mod p == y`).

    Example:
    ```python
    >>> from bop.crypto_constructor import dsa
    >>> import secrets
    >>>
    >>> # Setup a weak implementation
    >>> my_dsa = dsa()
    >>> def weak_nonce_generator():
    ...     return secrets.choice(range(2, 11))
    ...
    >>> my_dsa.generate_nonce = weak_nonce_generator
    >>> def message_src():
    ...     while True:
    ...         msg = secrets.token_bytes(16)
    ...         yield msg, my_dsa.sign(msg)
    ...
    >>> x, leaked_nonce = recover_key_from_duplicate_nonce(message_src(), my_dsa.public_parameters, hash=my_dsa.hash)
    >>> my_dsa.x == x
    True
    ```

    Arguments:
        gen {iterable} -- Yields message-signature pairs `(msg, signature)`.
            It is consumed until a duplicate nonce is found or it is exhausted.
        public_parameters {tuple} -- The DSA public parameters `(p, q, g, y)`

    Keyword Arguments:
        hash {callable} -- The hash function to use (default: {sha1})

    Returns:
        (int, int) -- `(x, leaked_nonce)`, i.e. the private key and the reused
        nonce, or `(None, None)` if no duplicate was found (including for an
        empty stream).
    """
    p, q, g, y = public_parameters
    # r -> [(hash of message as int, s), ...] for every signature seen with that r
    seen = {}
    for msg, sig in gen:
        if type(msg) != bytes:
            msg = i2b(msg)
        h1 = b2i(hash(msg))
        r1, s1 = sig
        same_r = seen.setdefault(r1, [])
        for h2, s2 in same_r:
            # Identical s values (e.g. the same message signed twice with the
            # same nonce) give no solvable equation; s2 - s1 has no inverse.
            if (s2 - s1) % q == 0:
                continue
            # Assume both signatures share k: s1*k - h1 == s2*k - h2 (mod q)
            k = ((h2 - h1) * invmod(s2 - s1, q)) % q
            x = recover_key_from_nonce(msg, sig, k, q, hash=hash)
            # The assumption holds iff x matches the public key
            if pow(g, x, p) == y:
                return x, k
        same_r.append((h1, s1))
    return None, None