def execute(self, ciphertexts: list, iterations: int=3) -> list:
    """
    Recovers plaintexts from ciphertexts that were XORed with the same keystream.

    All samples are truncated to the length of the shortest one. A first pass treats
    every byte position (column) as a single-byte XOR problem; `iterations` follow-up
    passes then look for a per-position correction that raises the combined score of
    the whole texts.

    Parameters:
        ciphertexts (list): List of bytes-like ciphertexts using the same keystream.
        iterations   (int): Number of iterations of the full-text analysis phase. Accuracy-time trade-off.

    Returns:
        list: List of recovered plaintexts.
    """
    min_size = min(len(sample) for sample in ciphertexts)
    truncated = [sample[:min_size] for sample in ciphertexts]
    columns = [bytearray(column) for column in zip(*truncated)]

    # Sanity check: transposing the columns back must give the truncated samples again
    assert [bytearray(row) for row in zip(*columns)] == truncated

    log.debug("Starting initial transposition analysis")

    # Phase 1: per-column single-byte XOR search
    column_plaintexts = []
    for column in RUNTIME.report_progress(columns, desc='Transposition analysis', unit='ciphers'):
        scored_candidates = []
        for key_byte in range(256):
            candidate = Bytes(struct.pack('B', key_byte)).stretch(len(column)) ^ column
            scored_candidates.append((self.analyzer.analyze(candidate), candidate))

        # `max` returns the first of equally-scored candidates, exactly like taking
        # element 0 of a stable descending sort
        _, best_candidate = max(scored_candidates, key=lambda scored: scored[0])
        column_plaintexts.append(best_candidate)

    recovered = [bytearray(row) for row in zip(*column_plaintexts)]

    log.debug("Starting full-text analysis on retransposed text")

    # Phase 2: character-by-character clean up using whole-text scores
    for round_idx in RUNTIME.report_progress(range(iterations), desc='Higher-context analysis'):
        log.debug("Starting iteration {}/{}".format(round_idx + 1, iterations))
        differential_mask = bytearray()

        for position in RUNTIME.report_progress(range(min_size), desc='Building differential mask', unit='bytes'):
            scored_deltas = []
            for delta in range(256):
                total = 0
                for text in recovered:
                    # Score a copy of the text with only `position` XORed by `delta`
                    trial = bytearray(text)
                    trial[position] = ord(Bytes(struct.pack('B', delta)) ^ struct.pack('B', text[position]))
                    total += self.analyzer.analyze(trial)

                scored_deltas.append((total, delta))

            _, best_delta = max(scored_deltas, key=lambda scored: scored[0])
            differential_mask += struct.pack('B', best_delta)

        # Apply this round's corrections to every text before the next round
        recovered = [Bytes.wrap(text) ^ differential_mask for text in recovered]

    return recovered
def decrypt(self, nonce: bytes, authed_ciphertext: bytes, data: bytes = b'') -> Bytes:
    """
    Checks the trailing 16-byte tag of `authed_ciphertext` and decrypts the remainder.

    Parameters:
        nonce             (bytes): Bytes-like nonce.
        authed_ciphertext (bytes): Bytes-like ciphertext with its 16-byte tag appended.
        data              (bytes): Bytes-like additional data to be authenticated.

    Returns:
        Bytes: Resulting plaintext.

    Raises:
        Exception: If the recomputed tag differs from the tag carried by `authed_ciphertext`.
    """
    from samson.utilities.runtime import RUNTIME

    wrapped = Bytes.wrap(authed_ciphertext)
    ciphertext = wrapped[:-16]
    received_tag = wrapped[-16:]

    # `clock_ctr` is called before `auth`; its result is the mask handed to `auth`.
    # NOTE(review): presumably it also primes `self.ctr` for the decrypt below -- confirm.
    tag_mask = self.clock_ctr(nonce)
    expected_tag = self.auth(ciphertext, Bytes.wrap(data), tag_mask)

    if RUNTIME.compare_bytes(expected_tag, received_tag):
        return self.ctr.decrypt(ciphertext)

    raise Exception('Tag mismatch: authentication failed!')
def execute(self, ciphertexts: list, word_ranges: list = [2, 3], delimiter: str = ' ') -> list:
    """
    Executes the attack.

    The two ciphertexts are XORed together, and word combinations from `self.wordlist`
    are XORed against the result ("crib dragging" from offset 0). Combinations are grown
    one word at a time, keeping only the best-scoring ones at each depth.

    Parameters:
        ciphertexts (list): List of bytes-like ciphertexts using the same keystream. Must contain exactly two samples.
        word_ranges (list): List of numbers of words to try. E.G. [2, 3, 4] means try the Cartesian product of 2, 3, and 4-tuple word combinations. Only iterated, never mutated.
        delimiter    (str): Delimiter to use between word combinations.

    Returns:
        list: One entry per element of `word_ranges`; each entry is a list of at most 10 candidate plaintext strings, best-scoring first.

    Raises:
        ValueError: If `ciphertexts` does not contain exactly two samples.
    """
    if len(ciphertexts) != 2:
        raise ValueError('`ciphertexts` MUST contain two samples.')

    two_time = xor_buffs(*ciphertexts)
    cipher_len = len(two_time)

    def crib_drag(candidate: str):
        """XORs `candidate` (zero-padded or truncated to `cipher_len` bytes) against the two-time pad."""
        padded = (bytes(candidate, 'utf-8') + b'\x00' * cipher_len)[:cipher_len]
        return xor_buffs(padded, two_time)

    # Empty words are dropped as well: scores are normalised by `len(word)**2`, so an
    # empty entry (e.g. a trailing blank line in the wordlist) would divide by zero.
    trimmed_list = [word for word in self.wordlist if 0 < len(word) <= cipher_len]

    prepend_list = ['']
    last_num_processed = 0
    results = []

    for j in RUNTIME.report_progress(word_ranges):
        log.debug(f"Starting word range {j}")

        # Only extend by the number of words not already processed for a previous range
        for i in range(j - last_num_processed):
            word_scores = []

            for prepend in prepend_list:
                for word in trimmed_list:
                    # `strip` removes the leading delimiter when `prepend` is empty
                    mod_word = (prepend + delimiter + word).strip()
                    analysis = self.analyzer.analyze(crib_drag(mod_word)[:cipher_len])
                    word_scores.append((mod_word, analysis / (len(word)**2)))

            # The number of survivors grows tenfold with each additional word
            ranked = sorted(word_scores, key=lambda score: score[1], reverse=True)
            prepend_list = [word for word, _ in ranked[:10**(i + 1 + last_num_processed)]]

        last_num_processed = j

        # Final ranking for this range uses the raw (non-normalised) analysis score
        results.append(
            sorted(prepend_list,
                   key=lambda word: self.analyzer.analyze(crib_drag(word)),
                   reverse=True)[:10])

    return results
def verify(self, plaintext: bytes, signature: bytes, strict_type_match: bool = True) -> bool:
    """
    Verifies the `plaintext` against the `signature`.

    The signature is passed through `self.rsa.encrypt`, unpadded, and parsed as a DER
    sequence; the second item of that sequence is compared with the hash of `plaintext`.

    Parameters:
        plaintext         (bytes): Plaintext to verify.
        signature         (bytes): Signature to verify plaintext against.
        strict_type_match  (bool): Whether or not to force use of `hash_obj` vs using the OID provided in the signature.

    Returns:
        bool: Whether or not the signature passed verification. Any exception raised while processing the signature is reported as `False`.
    """
    from samson.utilities.runtime import RUNTIME

    try:
        recovered = Bytes(self.rsa.encrypt(signature))
        sequence = bytes_to_der_sequence(self.padder.unpad(recovered))

        hasher = self.hash_obj
        if not strict_type_match:
            # Trust the hash algorithm named by the OID embedded in the signature
            hasher = INVERSE_HASH_OID_LOOKUP[sequence[0][0]]()

        embedded_digest = Bytes(sequence[1])
        return RUNTIME.compare_bytes(embedded_digest, hasher.hash(plaintext))

    # Deliberately broad: a malformed signature must read as "invalid", not crash
    except Exception:
        return False
def execute(self, public_key: int, max_factor_size: int = 2**16) -> int:
    """
    Executes the attack.

    Small factors of `(p - 1) // order` are used to craft public keys for the oracle;
    the residues it returns are combined with the CRT, and whatever part of the key
    is still unknown is searched for with Pollard's kangaroo.

    Parameters:
        public_key      (int): Diffie-Hellman public key to crack.
        max_factor_size (int): Max factor size to prevent attempting to factor forever.

    Returns:
        int: Private key.
    """
    # Factor as much as we can
    cofactor = (self.p - 1) // self.order
    factors = [r for r in factorint(cofactor, use_rho=False, limit=max_factor_size) if r < max_factor_size]

    log.debug(f'Found factors: {factors}')

    # Request residues from crafted public keys
    residues = []
    for factor in RUNTIME.report_progress(factors, desc='Sending malicious public keys', unit='factor'):
        exponent = (self.p - 1) // factor

        # Keep drawing until the crafted key is not the identity
        h = 1
        while h == 1:
            h = pow(random_int_between(1, self.p), exponent, self.p)

        residues.append((self.oracle.request(h, factor), factor))

    # Build partials using CRT
    n, r = crt(residues)

    # Oh, I guess we already found it...
    if r > self.order:
        return n

    # Reduce to x = n + m*r and solve for the bounded m
    g_prime = pow(self.g, r, self.p)
    g_n_inverse = mod_inv(pow(self.g, n, self.p), self.p)
    y_prime = (public_key * g_n_inverse) % self.p

    log.info(f'Recovered {"%.2f"%math.log(reduce(int.__mul__, factors, 1), 2)}/{"%.2f"%math.log(self.order, 2)} bits')
    log.info(f'Found relation: x = {n} + m*{r}')
    log.debug(f"g' = {g_prime}")
    log.debug(f"y' = {y_prime}")
    log.info('Attempting to catch a kangaroo...')

    # Probabilistically solve DLP
    R = (ZZ / ZZ(self.p)).mul_group()
    upper_bound = (self.order - 1) // r
    m = pollards_kangaroo(R(g_prime), R(y_prime), a=0, b=upper_bound)
    return n + m * r
def verify(self, plaintext: bytes, signature: bytes) -> bool:
    """
    Verifies the `plaintext` against the `signature`.

    `signature` is treated as the encoded message: it is left-filled to the modulus
    byte length and then checked for the trailer byte, the zero top bits, the
    `00..01` prefix of the unmasked data block, and finally the salted hash.

    Parameters:
        plaintext (bytes): Plaintext to verify.
        signature (bytes): Signature to verify against plaintext.

    Returns:
        bool: Whether or not the plaintext is verified.
    """
    from samson.utilities.runtime import RUNTIME

    plaintext = Bytes.wrap(plaintext)
    signature = Bytes.wrap(signature).zfill((self.modulus_len + 7) // 8)
    message_hash = self.hash_obj.hash(plaintext)

    hash_len = self.hash_obj.digest_size
    em_bits = self.modulus_len - 1
    em_len = (em_bits + 7) // 8

    # The encoded message must have room for the hash, the salt and two fixed octets
    if em_len < (hash_len + self.salt_len + 2):
        return False

    # Trailer octet
    if signature[-1] != 0xbc:
        return False

    # Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
    # and let H be the next hLen octets.
    mask_len = em_len - hash_len - 1
    masked_db = signature[:mask_len]
    embedded_hash = signature[mask_len:mask_len + hash_len]

    # If the leftmost 8emLen - emBits bits of the leftmost octet in
    # maskedDB are not all equal to zero, output "inconsistent" and
    # stop.
    unused_bits = (8 * em_len) - em_bits
    left_mask = (2**(len(masked_db) * 8) - 1) >> unused_bits

    if masked_db & left_mask != masked_db:
        return False

    # Unmask the data block and clear its unused top bits
    db = (masked_db ^ self.mgf(embedded_hash, mask_len)) & left_mask

    # If the emLen - hLen - sLen - 2 leftmost octets of DB are not
    # zero or if the octet at position emLen - hLen - sLen - 1 (the
    # leftmost position is "position 1") does not have hexadecimal
    # value 0x01, output "inconsistent" and stop.
    if db[:mask_len - self.salt_len].int() != 1:
        return False

    # A zero salt length needs special handling: `db[-0:]` would be the whole block
    if self.salt_len:
        salt = db[-self.salt_len:]
    else:
        salt = b''

    recomputed_hash = self.hash_obj.hash(b'\x00' * 8 + message_hash + salt)
    return RUNTIME.compare_bytes(recomputed_hash, embedded_hash)
def run(self, generations: int) -> OptimizationResult:
    """
    Evolves `self.population` for at most `generations` generations.

    Each generation calls `self.obj_func` on the population, keeps the best
    `self.parent_pool_size` chromosomes (by `fitness`) as parents, evaluates the
    termination heuristics, and then breeds the next population with
    `self.crossover_func` and `self.mutation_func`.

    Parameters:
        generations (int): Maximum number of generations to run.

    Returns:
        OptimizationResult: Built from the fittest chromosome of the final population, the index of the last generation started (0 when `generations` is 0), and the reason the run ended.
    """
    min_conv_counter = 0

    # Baseline to beat: a huge value when minimizing, zero when maximizing
    current_best = (not self.maximize) * 2**8192
    termination_reason = TerminationReason.MAX_ITERATION_REACHED
    granularized_minimum_convergence = self.convergence_granularity * self.minimum_convergence

    # Bound up front so that `generations == 0` returns a result instead of raising NameError
    iteration = 0

    for iteration in RUNTIME.report_progress(range(generations), desc='Generations', unit='gens'):
        # 1) Measure
        self.obj_func(self.population)

        # 2) Select
        parent_pool = sorted(self.population, key=lambda chromo: chromo.fitness, reverse=self.maximize)[:self.parent_pool_size]
        best_fitness = parent_pool[0].fitness

        # Test for minimum convergence heuristic
        change = abs(best_fitness - current_best)
        scale = current_best * self.convergence_granularity

        if scale:
            stalled = change // scale < granularized_minimum_convergence
        else:
            # A relative change cannot be computed against a zero baseline (which is the
            # initial baseline when maximizing). Dividing would raise ZeroDivisionError,
            # so any change at all counts as progress here.
            stalled = not change

        if stalled:
            if min_conv_counter < self.min_conv_tolerance:
                min_conv_counter += 1
            else:
                termination_reason = TerminationReason.MINIMUM_CONVERGENCE_UNSATISFIED
                break
        else:
            current_best = best_fitness
            min_conv_counter = 0

        # When minimizing, a fitness of zero cannot be improved upon
        if not self.maximize and not current_best:
            termination_reason = TerminationReason.GLOBAL_OPTIMA_REACHED
            break

        next_population = []

        # Elitism
        if self.elitism:
            next_population.extend(parent_pool)

        # Immigration
        parent_pool.extend([Chromosome(individual) for individual in self.initialization_func(self.num_immigrants)])

        # Breeding
        while len(next_population) < self.population_size:
            parents = random.sample(parent_pool, k=self.num_parents)

            # 3) Crossover
            individual = Chromosome(self.crossover_func(parents))

            # 4) Mutate
            individual = self.mutation_func(individual)
            next_population.append(individual)

        self.population = next_population

    # Measure once more so the final population can be ranked
    self.obj_func(self.population)
    fittest = sorted(self.population, key=lambda chromo: chromo.fitness, reverse=self.maximize)[0]
    return OptimizationResult(fittest, iteration, termination_reason)
def execute(self) -> Bytes:
    """
    Executes the attack.

    Recovers the data the oracle appends to attacker input, one byte at a time: a
    payload of chosen length leaves exactly one unknown byte in the target block, and
    all 256 values are tried until the oracle's first block matches that target.

    Returns:
        Bytes: The recovered plaintext.
    """
    baseline = len(self.oracle.encrypt(b''))
    block_size = self.oracle.test_io_relation()['block_size']

    recovered_blocks = []
    for block_idx in RUNTIME.report_progress(range(baseline // block_size), unit='blocks'):
        log.debug("Starting iteration {}".format(block_idx))

        known = b''
        for offset in RUNTIME.report_progress(range(block_size), unit='bytes'):
            # The first block is aligned with filler; later blocks reuse the tail of
            # the previously recovered block
            if block_idx == 0:
                payload = b'A' * (block_size - (offset + 1))
            else:
                payload = recovered_blocks[-1][offset + 1:]

            target = get_blocks(self.oracle.encrypt(payload), block_size=block_size)[block_idx]

            for guess in range(256):
                guess_byte = bytes([guess])
                ciphertext = self.oracle.encrypt(payload + known + guess_byte)

                # We're always editing the first block to look like block 'block_idx'
                if get_blocks(ciphertext, block_size=block_size)[0] == target:
                    known += guess_byte
                    break

        recovered_blocks.append(known)

    return Bytes(b''.join(recovered_blocks))
def color_format(color: ConsoleColors, text: str):
    """
    Wraps `text` in the console escape sequence for `color`.

    Parameters:
        color (ConsoleColors): Color to apply. A `ConsoleColors` member is replaced by its `value`; anything else is interpolated as given.
        text            (str): Text to format.

    Returns:
        str: The formatted text, or `text` unchanged when the runtime context has `use_color` disabled.
    """
    from samson.utilities.runtime import RUNTIME

    if type(color) is ConsoleColors:
        color = color.value

    if not RUNTIME.get_context().use_color:
        return text

    return f"{PREFIX}{color}m{text}{SUFFIX}"
def rand_bytes(size: int=16) -> bytes:
    """
    Reads bytes from RUNTIME.random.

    Parameters:
        size (int): Number of bytes to read.

    Returns:
        bytes: Random bytes.
    """
    # Imported lazily, inside the function, as in the other helpers of this file
    from samson.utilities.runtime import RUNTIME

    random_source = RUNTIME.random
    return random_source(size)
def decrypt(self, key: bytes, iv: bytes, ciphertext: bytes, auth_data: bytes, auth_tag: bytes) -> Bytes:
    """
    Authenticates `ciphertext` with an HMAC and, if the tag matches, decrypts it with Rijndael in CBC mode.

    Parameters:
        key        (bytes): Combined key; split by `chunk(self.chunk_size)` into the MAC key followed by the encryption key. Must support `chunk` (e.g. `Bytes`).
        iv         (bytes): Initialization vector for CBC.
        ciphertext (bytes): Ciphertext to decrypt.
        auth_data  (bytes): Additional authenticated data.
        auth_tag   (bytes): Authentication tag to check against.

    Returns:
        Bytes: Resulting plaintext.

    Raises:
        AssertionError: If the recomputed tag does not match `auth_tag`.
    """
    mac_key, enc_key = key.chunk(self.chunk_size)

    # MAC input: AAD || IV || ciphertext || bit length of the AAD as 8 bytes
    aad_bit_length = Bytes(len(auth_data) * 8).zfill(8)
    expected_tag = HMAC(mac_key, self.hash_obj).generate(auth_data + iv + ciphertext + aad_bit_length)[:self.chunk_size]

    # Raised explicitly rather than with `assert`: assert statements are removed under
    # `python -O`, which would silently skip authentication. The exception type is kept
    # so existing callers that catch AssertionError keep working.
    if not RUNTIME.compare_bytes(expected_tag, auth_tag):
        raise AssertionError('Authentication tag mismatch')

    rij = Rijndael(enc_key)
    cbc = CBC(rij, iv=iv)
    return cbc.decrypt(ciphertext)
def recover_plaintext(cls: object, ciphertext: int, pub: list, alpha: int = 1) -> Bytes:
    """
    Attempts to recover the plaintext without the private key.

    A matrix is built from an identity block, the public key and the negated
    ciphertext, extended with a penalty vector, and reduced with LLL. A reduced row
    whose relevant entries all lie in [0, 1] is read as the plaintext bits.

    Parameters:
        ciphertext (int): A ciphertext sum.
        pub       (list): The public key.
        alpha      (int): Punishment coefficient for deviation from guessed bit distribution.

    Returns:
        Bytes: Recovered plaintext. If no penalty value yields a suitable row, the last reduced matrix is returned instead.
    """
    from samson.math.matrix import Matrix
    from samson.math.algebra.all import QQ

    num_elements = len(pub)

    # Construct the problem matrix
    ident = Matrix.identity(num_elements, coeff_ring=QQ)
    pub_matrix = ident.col_join(Matrix([pub], coeff_ring=QQ))
    target_column = Matrix([[0] * num_elements + [-ciphertext]], coeff_ring=QQ).T
    problem_matrix = pub_matrix.row_join(target_column)

    # Attempt to crack the ciphertext using various punishment coefficients
    for i in RUNTIME.report_progress(range(num_elements), desc='Alphaspace searched'):
        alpha_penalizer = Matrix([[alpha] * num_elements + [-alpha * i]], coeff_ring=QQ)
        solution_matrix = problem_matrix.col_join(alpha_penalizer).T.LLL(0.99)

        for row in solution_matrix.rows:
            # The last two entries are bookkeeping, not message bits
            relevant = row[:-2]

            if all(item >= QQ.zero() and item <= QQ.one() for item in relevant):
                bits = ''.join(str(int(float(val))) for val in relevant)
                return Bytes(int(bits, 2))

    return solution_matrix
def crack(self, lm_hash: bytes, charset: bytes=None) -> Bytes:
    """
    Cracks both halves simultaneously.

    Every candidate of length 1 through 7 over `charset` is hashed once and compared
    against both 8-byte halves of `lm_hash`. The search ends as soon as both halves are
    known, when the candidate space is exhausted, or when interrupted with Ctrl-C; in
    every case whatever has been recovered so far is returned.

    Parameters:
        lm_hash (bytes): Hash to crack. Must support `zfill`/`chunk` (e.g. `Bytes`).
        charset (bytes): Character set to use. Defaults to ASCII uppercase letters, digits and punctuation.

    Returns:
        Bytes: The first half (after `pad_congruent_right(7)`) followed by the second half. A first half that was not found is represented by a single zero byte before padding; a missing second half contributes nothing.
    """
    h1, h2 = lm_hash.zfill(16).chunk(8)
    h1_pt, h2_pt = None, None

    # Halves flagged as null need no search
    h1_null, h2_null = self.check_halves_null(lm_hash)

    if h1_null:
        h1_pt = Bytes(b'').zfill(8)

    if h2_null:
        h2_pt = Bytes(b'').zfill(8)

    if not charset:
        charset = bytes(string.ascii_uppercase + string.digits + string.punctuation, 'utf-8')

    try:
        for i in RUNTIME.report_progress(range(1, 8), unit='length'):
            for attempt in itertools.product(charset, repeat=i):
                b_attempt = bytes(attempt)
                hashed = self.hash(b_attempt)[:8]

                if hashed == h1:
                    h1_pt = b_attempt

                if hashed == h2:
                    h2_pt = b_attempt

                # Same exit path as a user interrupt: stop and report what we have
                if h1_pt and h2_pt:
                    raise KeyboardInterrupt()

    except KeyboardInterrupt:
        pass

    # Previously only the interrupt path returned a value. An exhausted search fell off
    # the end and returned None, discarding a half that had been found.
    return Bytes(h1_pt or b'\x00').pad_congruent_right(7) + Bytes(h2_pt or b'')
def derive_from_backdoor(cls: object, P: WeierstrassPoint, Q: WeierstrassPoint, d: int, observed_out: bytes) -> list:
    """
    Recovers the internal state of a Dual EC generator and builds a replica.

    The first 30 observed bytes are completed with every possible 2-byte prefix to form
    candidate x-coordinates. Each candidate that is accepted by the curve is multiplied
    by the backdoor `d` and checked against the remaining observed bytes.

    Parameters:
        P    (WeierstrassPoint): Elliptical curve point `P`.
        Q    (WeierstrassPoint): Elliptical curve point `Q`.
        d                 (int): Backdoor that relates Q to P.
        observed_out    (bytes): Observed output from the compromised Dual EC generator. At least 30 bytes.

    Returns:
        list: List of possible internal states.
    """
    assert len(observed_out) >= 30

    curve = P.curve
    first_output = observed_out[:30]
    next_output = observed_out[30:]

    Q_cache = Q.cache_mul(curve.cardinality().bit_length())
    candidates = []

    for prefix in RUNTIME.report_progress(range(2**16), desc='Statespace searched', unit='states'):
        # Restore the two truncated high bytes of the x-coordinate
        test_x = int.from_bytes(prefix.to_bytes(2, 'big') + first_output, 'big')

        try:
            R = curve(test_x)
            state = int((d * R).x)
            predicted = Q_cache * state

            # Compare the prediction (minus its two truncated bytes) with what was observed
            if int(predicted.x).to_bytes(32, 'big')[2:2 + len(next_output)] == next_output:
                candidates.append(DualEC(P, Q, state))

        # NOTE(review): presumably raised when `test_x` is not on the curve -- confirm
        except AssertionError:
            continue

    return candidates
def crack_truncated(cls: object, outputs: list, outputs_to_predict: list, multiplier: int, increment: int, modulus: int, trunc_amount: int) -> object:
    """
    Given a decent number of truncated states (about 200 when there's only 3-bit outputs), returns a replica LCG.

    Parameters:
        outputs            (list): List of truncated-state outputs (in order).
        outputs_to_predict (list): Next few outputs to compare against. Accuracy/number of samples trade-off. When empty, the last two `outputs` are used instead.
        multiplier          (int): The LCG's multiplier.
        increment           (int): The LCG's increment.
        modulus             (int): The LCG's modulus.
        trunc_amount        (int): Number of low state bits missing from each output; these bits are bruteforced.

    Returns:
        LCG: Replica LCG that predicts all future outputs of the original.

    Raises:
        SearchspaceExhaustedException: If no candidate for the low bits reproduces `outputs_to_predict`.

    References:
        https://github.com/mariuslp/PCG_attack
        "Reconstructing Truncated Integer Variables Satisfying Linear Congruences" (https://www.math.cmu.edu/~af1p/Texfiles/RECONTRUNC.pdf)
    """
    if not outputs_to_predict:
        outputs_to_predict = outputs[-2:]
        outputs = outputs[:-2]

    # Trivial case
    if increment == 0:
        computed_seed = LCG.solve_tlcg(outputs + outputs_to_predict, multiplier, modulus, trunc_amount)

        # Here we take the second to last seed since our implementation edits the state BEFORE it returns
        start_state = (multiplier * computed_seed[-2]) % modulus
        return LCG(start_state, multiplier, increment, modulus, trunc=trunc_amount)

    # Differences of consecutive outputs no longer depend on the increment
    diffs = [later - earlier for earlier, later in zip(outputs, outputs[1:])]
    solved = LCG.solve_tlcg(diffs, multiplier, modulus, trunc_amount)
    seed_diffs = [int(entry) % modulus for row in solved for entry in row]

    # Bruteforce low bits
    for low_bits in RUNTIME.report_progress(range(2**trunc_amount), desc='Seedspace searched', unit='seeds'):
        x_0 = (outputs[0] << trunc_amount) + low_bits
        x_1 = (seed_diffs[0] + x_0) % modulus
        computed_c = (x_1 - multiplier * x_0) % modulus

        # The guess is consistent only if it also explains the second difference
        computed_x_2 = (multiplier * x_1 + computed_c) % modulus
        actual_x_2 = (seed_diffs[1] + x_1) % modulus

        if computed_x_2 != actual_x_2:
            continue

        computed_seeds = [x_0]
        for diff in seed_diffs:
            computed_seeds.append((diff + computed_seeds[-1]) % modulus)

        # It's possible to find a spectrum of nearly-equivalent LCGs.
        # The accuracy of `predicted_lcg` is dependent on the size of `outputs_to_predict` and the
        # parameters of the LCG.
        predicted_seed = (multiplier * computed_seeds[-2] + computed_c) % modulus
        predicted_lcg = LCG(X=int(predicted_seed), a=multiplier, c=int(computed_c), m=modulus, trunc=trunc_amount)

        if [predicted_lcg.generate() for _ in outputs_to_predict] == outputs_to_predict:
            return predicted_lcg

    raise SearchspaceExhaustedException('Seedspace exhausted')
def real_decorator(cls):
    """Registers `cls` with the runtime's exploit mapping for the enclosing scope's `attack` and returns it unchanged."""
    register = RUNTIME.register_exploit_mapping
    register(cls, attack)
    return cls
def real_decorator(cls):
    """Registers `cls` as a primitive with the runtime and returns it unchanged."""
    register = RUNTIME.register_primitive
    register(cls)
    return cls
def real_decorator(cls):
    """Registers `cls` as an exploit, with the enclosing scope's `consequence` and `requirements`, and returns it unchanged."""
    register = RUNTIME.register_exploit
    register(cls, consequence, requirements)
    return cls
def _generate_tree(self):
    """
    Builds a binary hash tree of colliding, intermediary Merkle-Damgard construction states.

    Layer 0 pairs up `self.prefixes`; every pair is passed to `self.collision_func`, and
    the resulting collisions are paired up again to form the next layer, for `self.k`
    layers in total. Results are stored in `self.hash_tree` (keyed by each member of a
    pair) and `self.crafted_hash` (the collision of the final layer).
    """
    log.debug('Generating hash tree')

    # One (initially empty) list of pairs per layer
    tree = []
    for i in range(self.k):
        tree.append([])

    promoted_prefix = None

    # Determine if we need to "promote" a prefix. Basically, if we have an odd number of
    # prefixes on this layer, we put the last prefix in waiting as there will eventually be
    # another layer which has one less node than a power of two.
    curr_prefix_list = self.prefixes
    if len(self.prefixes) % 2 == 1:
        promoted_prefix = self.prefixes[-1]
        curr_prefix_list = self.prefixes[:-1]

    tree[0] = [(curr_prefix_list[i], curr_prefix_list[i + 1]) for i in range(0, len(curr_prefix_list), 2)]

    # One list of solved pairs per layer; each entry is
    # (p1, p2, p1_suffix, p2_suffix, intermediary_collision)
    solution_tree = []
    for i in range(self.k):
        solution_tree.append([])

    for i in RUNTIME.report_progress(range(self.k)):
        for (p1, p2) in tree[i]:
            p1_suffix, p2_suffix, intermediary_collision = self.collision_func(p1, p2)
            solution_tree[i].append((p1, p2, p1_suffix, p2_suffix, intermediary_collision))

        # Add solutions
        if i < (self.k - 1):
            # If there's an odd number of solutions at this level, then there's either a prefix
            # in waiting, or we need to promote one.
            next_level_states = solution_tree[i]
            if len(solution_tree[i]) % 2 == 1:
                # Work on a copy so `solution_tree[i]` itself is not altered
                next_level_states = deepcopy(solution_tree[i])

                # Last level is a multiple of 2 but not a power of 2 (e.g. 6).
                # Promote our last prefix.
                if promoted_prefix == None:
                    promoted_prefix = next_level_states[-1][-1]
                    next_level_states = next_level_states[:-1]

                # We have a prefix in waiting. Use it immediately.
                # NOTE(review): `promoted_prefix` is not cleared after being used here -- confirm this is intended.
                else:
                    next_level_states.append((promoted_prefix, ))

            # The last element of each entry is the state carried into the next layer
            tree[i + 1] = [(next_level_states[sol][-1], next_level_states[sol + 1][-1]) for sol in range(0, len(next_level_states), 2)]

    # We're done generating the tree; time to set the output fields
    for layer in solution_tree:
        for p1_init, p2_init, p1_msg, p2_msg, result in layer:
            self.hash_tree[p1_init] = (p1_init, p2_init, p1_msg, p2_msg, result)
            self.hash_tree[p2_init] = (p1_init, p2_init, p1_msg, p2_msg, result)

    self.crafted_hash = solution_tree[-1][0][-1]
def execute(self, ciphertext: bytes) -> Bytes:
    """
    Executes Manger's attack.

    The plaintext is narrowed down to an interval `[m_min, m_max]` using answers from
    `self._greater_equal_B`, and the interval is shrunk until it is empty of slack.

    Parameters:
        ciphertext (bytes): The ciphertext to decrypt.

    Returns:
        Bytes: The ciphertext's corresponding plaintext.
    """
    ciphertext = Bytes.wrap(ciphertext)
    ct_int = ciphertext.int()

    # k: byte length of the modulus; B: 2^(8*(k-1))
    # NOTE(review): `math.log` works in floating point; for very large moduli confirm k is not off by one.
    k = math.ceil(math.log(self.rsa.n, 256))
    B = 2**(8 * (k - 1))
    n = self.rsa.n
    e = self.rsa.e

    log.debug(f"k: {k}, B: {B}, n: {n}, e: {e}")

    # Step 1
    # Double f1 until `_greater_equal_B` first answers true, then step back once
    f1 = 2
    log.info("Starting step 1")
    while not self._greater_equal_B(f1, ct_int, e, n):
        f1 *= 2

    f1 //= 2
    log.debug(f"Found f1: {f1}")

    # Step 2
    # Starting from floor((n + B) / B) * f1, add f1 until `_greater_equal_B` answers false
    nB = n + B
    nB_B = nB // B
    f2 = nB_B * f1

    log.info("Starting step 2")
    while self._greater_equal_B(f2, ct_int, e, n):
        f2 += f1

    log.debug(f"Found f2: {f2}")

    # Step 3
    # m_min = ceil(n / f2), m_max = floor((n + B) / f2)
    div_mod = 1 if n % f2 else 0
    m_min = n // f2 + div_mod
    m_max = nB // f2
    BB = 2 * B
    diff = m_max - m_min
    ctr = 0

    log.info("Starting step 3")
    log.debug(f"B-(diff * f2) = {B - (diff * f2)}")

    # Reporting
    # NOTE(review): `math.log(diff, 2)` raises ValueError if the interval is already empty (diff == 0) -- confirm this cannot happen.
    last_log_diff = math.log(diff, 2)
    progress = RUNTIME.report_progress(None, total=last_log_diff)

    while diff > 0:
        if ctr % 100 == 0:
            log.debug(f"Iteration {ctr} difference: {diff}")

        f = BB // diff
        f_min = f * m_min
        i = f_min // n
        iN = i * n

        # f3 = ceil(i*n / m_min)
        div_mod = 1 if iN % m_min else 0
        f3 = iN // m_min + div_mod
        iNB = iN + B

        # The answer decides which end of the interval moves
        if self._greater_equal_B(f3, ct_int, e, n):
            # m_min = ceil((i*n + B) / f3)
            div_mod = 1 if iNB % f3 else 0
            m_min = iNB // f3 + div_mod
        else:
            m_max = iNB // f3

        diff = m_max - m_min

        # Update progress
        # (+1 keeps the logarithm defined once the interval has collapsed)
        log_diff = math.log(diff + 1, 2)
        progress.update(last_log_diff - log_diff)
        last_log_diff = log_diff
        ctr += 1

    return Bytes(m_min)
def execute(self, initial_filter=BASIC_FILTER, min_input_len: int = 1) -> Fingerprint:
    """
    Fingerprints the primitive behind `self.oracle`.

    Candidate primitives are selected by block size and input/output relation, then
    scored from their usage frequency, how explicitly they match the block size, and
    modifiers derived from the oracle's maximum input and output length.

    Parameters:
        initial_filter: Filter passed to `RUNTIME.search_primitives` to select the primitives considered.
        min_input_len (int): Length of the probe input; also passed to the oracle's I/O relation test.

    Returns:
        Fingerprint: Scored candidates, possible block cipher modes, the max-input analysis, the I/O relation and the block size.
    """
    sample = self.oracle.encrypt(b'a' * min_input_len)
    base_len = len(sample)

    filtered = RUNTIME.search_primitives(initial_filter)
    io_rel_analysis = self.oracle.test_io_relation(min_input_len)
    io_relation = io_rel_analysis['io_relation']
    block_size = io_rel_analysis['block_size']

    max_val_analysis = IntegerAnalysis.analyze(self.oracle.test_max_input())

    # Per-primitive score adjustments, applied when scoring below
    modifiers = {}

    # NOTE(review): presumably `n == -1` means no maximum input was found -- confirm.
    if max_val_analysis.n != -1:
        if max_val_analysis.prime_name:
            log.debug(f'Max input size is a well-known modulus: {max_val_analysis.prime_name}')

        # Add modifiers for matching primitives
        if not max_val_analysis.is_prime and not max_val_analysis.byte_aligned and max_val_analysis.is_uniform:
            from samson.public_key.rsa import RSA
            modifiers[RSA] = 1
            log.debug(f'Max input size looks like RSA modulus')

        elif max_val_analysis.is_prime and max_val_analysis.is_uniform:
            from samson.protocols.dragonfly import Dragonfly
            from samson.public_key.elgamal import ElGamal

            # Process Diffie-Hellman-like primitives
            # Safe primes and well-known (named) primes each add one to the modifier
            dh_modifier = 1 + max_val_analysis.is_safe_prime + bool(max_val_analysis.prime_name)
            modifiers[DiffieHellman] = dh_modifier
            modifiers[Dragonfly] = dh_modifier
            modifiers[ElGamal] = dh_modifier
            log.debug(f'Max input size looks like Diffie-Hellman modulus')

    # Candidates must support the observed block size (in bits) and I/O relation
    matching = [match for match in filtered if block_size * 8 in match.BLOCK_SIZE and match.IO_RELATION_TYPE == io_relation]
    bc_modes = []

    # Punish IV/nonce/AEAD primitives if we can prove the output doesn't contain their ephemeral/tag
    # This is only really possible if the output is smaller than their ephemeral/tag
    def calculate_min_size(size):
        # Returns (minimum size, typical size) for a size descriptor
        min_size = 0
        typical_size = 0
        sizes = size.sizes

        # A single fixed size is both the minimum and the typical size
        if type(sizes) is int:
            min_size += sizes
            typical_size += sizes
        else:
            if size.size_type not in [SizeType.ARBITRARY, SizeType.DEPENDENT]:
                min_size += sizes[0]

            if size.typical:
                typical_size += size.typical[0]

        return min_size, typical_size

    # If the primitive is a block cipher mode and its ephemeral/tag is DEPENDENT, we'll want to check
    # against known block ciphers.
    block_ciphers = [prim for prim in filtered if issubclass(prim, BlockCipher) and block_size * 8 in prim.BLOCK_SIZE]
    minimum_bc = min([calculate_min_size(block_cipher.BLOCK_SIZE)[0] for block_cipher in block_ciphers]) if block_ciphers else 0

    for match in matching:
        min_size = 0
        typical_size = 0

        # Collect the size descriptors of everything the output would have to carry
        all_sizes = []
        if hasattr(match, 'EPHEMERAL') and not match.EPHEMERAL.ephemeral_type == EphemeralType.KEY:
            all_sizes.append(match.EPHEMERAL.size)

        if hasattr(match, 'AUTH_TAG_SIZE'):
            all_sizes.append(match.AUTH_TAG_SIZE)

        for size in all_sizes:
            component_min, component_typical = calculate_min_size(size)
            min_size += component_min
            typical_size += component_typical

            if issubclass(match, BlockCipherMode) and size.size_type == SizeType.DEPENDENT:
                min_size += minimum_bc

        # One penalty for each of the two sizes that the observed output is too short for
        for size in [min_size, typical_size]:
            if base_len * 8 < size:
                if not match in modifiers:
                    modifiers[match] = 0

                modifiers[match] -= 1

    # Find possible block cipher modes. We exclude StreamingBlockCipherModes because they're handled
    # by their block size above.
    if any([issubclass(match, BlockCipher) for match in matching]):
        from samson.block_ciphers.modes.ecb import ECB
        log.debug(f'Block ciphers in candidates. Attempting to find possible block cipher modes')

        bc_modes = [prim for prim in filtered if issubclass(prim, BlockCipherMode) and not issubclass(prim, StreamingBlockCipherMode)]

        # Check for ECB
        if self.oracle.test_stateless(block_size):
            log.info(f'Stateless blocks detected')
            bc_modes = [ECB]
        else:
            if ECB in bc_modes:
                bc_modes.remove(ECB)

    # Score matches higher if they're more explicit
    scored_matches = {}
    for match in matching:
        bitsize = block_size * 8
        base_freq = match.USAGE_FREQUENCY.value + (modifiers[match] if match in modifiers else 0)
        scored_matches[match] = base_freq

        # If it matches a SINGLE value, that's significant
        if bitsize == match.BLOCK_SIZE.sizes:
            scored_matches[match] += FrequencyType.PROLIFIC.value

        # If it's in a RANGE, it's a bit less significant
        elif bitsize in match.BLOCK_SIZE.sizes:
            scored_matches[match] += FrequencyType.NORMAL.value

        # Add a modifier for being in 'typical'
        if bitsize in match.BLOCK_SIZE.typical:
            scored_matches[match] += 1

    return Fingerprint(candidates=scored_matches, modes=bc_modes, max_input_analysis=max_val_analysis, io_relation=io_relation, block_size=block_size)
def execute(self, ciphertext: bytes) -> Bytes:
    """
    Executes the attack.

    Blocks are processed from last to first. For each block, a forged preceding block
    is built per candidate byte so that a correct guess makes the oracle accept the
    padding; the plaintext is recovered from its last byte to its first.

    Parameters:
        ciphertext (bytes): Bytes-like ciphertext to be decrypted.

    Returns:
        Bytes: Plaintext corresponding to the inputted ciphertext.
    """
    blocks = Bytes.wrap(ciphertext).chunk(self.block_size)
    reversed_blocks = blocks[::-1]

    plaintexts = []

    for i, block in enumerate(RUNTIME.report_progress(reversed_blocks, desc='Blocks cracked', unit='blocks')):
        log.debug("Starting iteration {}".format(i))
        plaintext = Bytes(b'')

        # The first ciphertext block (last in reversed order) is preceded by the IV
        if i == len(reversed_blocks) - 1:
            preceding_block = self.iv
        else:
            preceding_block = reversed_blocks[i + 1]

        for _ in RUNTIME.report_progress(range(len(block)), desc='Bytes cracked', unit='bytes'):
            last_working_char = None

            # Maps each forged two-block ciphertext to the byte it tests
            exploit_blocks = {}

            # Generate candidate blocks
            for possible_char in self.alphabet:
                test_byte = struct.pack('B', possible_char)
                payload = test_byte + plaintext
                prefix = b'\x00' * (self.block_size - len(payload))

                # XOR the guessed tail with the padding value it should decrypt to
                padding = (struct.pack('B', len(payload)) * (len(payload))) ^ payload

                fake_block = prefix + padding
                exploit_block = fake_block ^ preceding_block
                new_cipher = bytes(exploit_block + block)

                exploit_blocks[new_cipher] = test_byte

            if self.batch_requests:
                # In batch mode the oracle is given every candidate and its return value
                # is used as the key of the accepted one
                best_block = self.oracle.check_padding([k for k,v in exploit_blocks.items()])
                last_working_char = exploit_blocks[best_block]
                log.debug("Found working byte: {}".format(last_working_char))
            else:
                # Oracle can't handle batch requests. Feed blocks into it.
                for exploit_block, byte in exploit_blocks.items():
                    if self.oracle.check_padding(exploit_block):
                        log.debug("Found working byte: {}".format(byte))
                        last_working_char = byte

                        # Early out optimization. Note, we're being careful about PKCS7 padding here.
                        if last_working_char and ord(byte) >= self.block_size:
                            break

            # NOTE(review): if no candidate is accepted, `last_working_char` is still None and this raises TypeError.
            plaintext = last_working_char + plaintext

        plaintexts.append(plaintext)

    # Blocks were recovered last-to-first; restore the original order
    return Bytes(b''.join(plaintexts[::-1]))
def execute(self, ciphertext: int, n: int, e: int, key_length: int) -> Bytes:
    """
    Executes the attack.

    Multipliers `s` whose products the oracle accepts as correctly padded are used to
    narrow the set of intervals `M` that can contain the plaintext, until a single
    interval of one value remains.

    Parameters:
        ciphertext (int): The ciphertext represented as an integer.
        n          (int): The RSA instance's modulus.
        e          (int): The RSA instance's public exponent.
        key_length (int): The the bit length of the RSA instance (2048, 4096, etc).

    Returns:
        Bytes: The ciphertext's corresponding plaintext.

    Raises:
        Exception: If interval narrowing yields an inverted interval or no interval at all.
    """
    key_byte_len = key_length // 8

    # Convenience variables
    B = 2**(8 * (key_byte_len - 2))

    # Initial values
    c = ciphertext
    c_0 = ciphertext

    # A correctly padded plaintext lies in [2B, 3B - 1]
    M = [(2 * B, 3 * B - 1)]
    i = 1

    if not self.oracle.check_padding(c):
        log.debug("Initial padding not correct; attempting blinding")

        # Step 1: Blinding
        # NOTE(review): the blinding value `s` is not divided back out of the result returned below -- confirm whether the returned value should be unblinded.
        while True:
            s = randint(0, n - 1)
            c_0 = (c * pow(s, e, n)) % n

            if self.oracle.check_padding(c_0):
                log.debug("Padding is now correct; blinding complete")
                break

    # Setup reporting
    last_log_diff = math.log(M[0][1] - M[0][0], 2)
    progress = RUNTIME.report_progress(None, total=last_log_diff)

    # Step 2
    while True:
        log.debug("Starting iteration {}".format(i))
        log.debug("Current intervals: {}".format(M))

        # Progress is measured as log2 of the total remaining interval width
        diff = math.log(sum([interval[1] - interval[0] for interval in M]) + 1, 2)
        progress.update(last_log_diff - diff)
        last_log_diff = diff

        # Step 2.a
        # First iteration: search upwards from ceil(n / 3B)
        if i == 1:
            s = _ceil(n, 3 * B)

            log.debug("Starting search at {}".format(s))

            while True:
                c = c_0 * pow(s, e, n) % n
                if self.oracle.check_padding(c):
                    break

                s += 1

        # Step 2.b
        # Several intervals left: continue the linear search from the previous s
        elif len(M) >= 2:
            log.debug("Intervals left: {}".format(M))

            while True:
                s += 1
                c = c_0 * pow(s, e, n) % n

                if self.oracle.check_padding(c):
                    break

        # Step 2.c
        elif len(M) == 1:
            log.debug("Only one interval")

            a, b = M[0]

            # The interval has collapsed to a single value: done
            if a == b:
                return Bytes(b'\x00' + int_to_bytes(a, 'big'))

            r = _ceil(2 * (b * s - 2 * B), n)
            s = _ceil(2 * B + r * n, b)

            while True:
                c = c_0 * pow(s, e, n) % n
                if self.oracle.check_padding(c):
                    break

                s += 1

                # Past the upper bound for this r: move on to the next r
                if s > (3 * B + r * n) // a:
                    r += 1
                    s = _ceil(2 * B + r * n, b)

        # Narrow every interval using the s that was just found
        M_new = []

        for a, b in M:
            min_r = _ceil(a * s - 3 * B + 1, n)
            max_r = (b * s - 2 * B) // n

            for r in range(min_r, max_r + 1):
                new_a = max(a, _ceil(2 * B + r * n, s))
                new_b = min(b, (3 * B - 1 + r * n) // s)

                if new_a > new_b:
                    raise Exception("Step 3: new_a > new_b? new_a: {} new_b: {}".format(new_a, new_b))

                # Now we need to check for overlap between ranges and merge them
                _append_and_merge(new_a, new_b, M_new)

        if len(M_new) == 0:
            raise Exception("There are zero intervals in 'M_new'")

        M = M_new
        i += 1
def execute(self, secret_length: int, sample_size: int = 2**23, chunk_size: int = 2**19) -> Bytes:
    """
    Executes the attack.

    For each secret position not yet covered, the secret is shifted (by zero-byte
    padding) onto one of `self.strongest_biases`, many ciphertexts are sampled across a
    process pool, and the most frequent ciphertext byte at each biased position is
    combined with the known keystream bias to yield a candidate plaintext byte.

    Parameters:
        secret_length (int): The length of the secret you're trying to recover.
        sample_size   (int): The amount of samples to collect per byte of the secret. Higher numbers are slower but more accurate.
        chunk_size    (int): The size of sample chunks per CPU before a forceful garbage collection saves the day.

    Returns:
        list: Every candidate secret (each a `Bytes`) obtained by combining the candidate values recovered for each position. Note this is a list despite the `Bytes` return annotation.
    """
    cracked_indices = [set() for _ in range(secret_length)]
    cpu_count = multiprocessing.cpu_count()
    num_chunks = math.ceil(sample_size / chunk_size)

    # The context manager shuts the worker processes down on exit, including on errors.
    # Every async result is collected with `.get()` before the block ends.
    with multiprocessing.Pool(processes=cpu_count) as pool:
        log.info(f"Running with {cpu_count} cores")

        for i in RUNTIME.report_progress(range(secret_length), unit='bytes'):
            log.debug(f"Starting iteration {i + 1}/{secret_length}")

            # An earlier sampling round may already have covered this position
            if len(cracked_indices[i]) > 0:
                continue

            # Biases at or beyond this position can be reached by padding
            applicable_biases = [bias for bias in self.strongest_biases if i <= bias]

            # Shift the secret so that byte `i` lands on the first applicable bias
            padding_len = max(applicable_biases[0] - i, 0)

            # Every bias that falls inside the shifted secret is exploited in this round
            active_biases = [bias for bias in applicable_biases if padding_len + secret_length > bias]
            payload = b'\x00' * padding_len

            log.debug(f"Sampling {sample_size} ciphertexts")

            flattened_list = []
            for batch in range(math.ceil(num_chunks / cpu_count)):
                # At most one chunk per CPU at a time; the last batch may be smaller
                chunks_in_batch = min(num_chunks - (batch * cpu_count), cpu_count)
                pending = [
                    pool.apply_async(self._encrypt_chunk, (payload, chunk_size))
                    for _ in range(chunks_in_batch)
                ]

                flattened_list.extend([result for async_result in pending for result in async_result.get()])
                gc.collect()

            log.debug("Generating bias map")
            bias_map = generate_rc4_bias_map(flattened_list)

            for bias_idx in active_biases:
                # Known keystream bias XOR most common ciphertext byte = plaintext candidate
                cracked_indices[bias_idx - padding_len].add(RC4_BIAS_MAP[bias_idx] ^ bias_map[bias_idx][0][0])

    all_branches = itertools.product(*[list(results) for results in cracked_indices])
    return [Bytes(struct.pack('B' * len(branch), *branch)) for branch in all_branches]
def execute(self, public_key: WeierstrassPoint, invalid_curves: List[WeierstrassCurve] = None, max_factor_size: int = 2**16) -> int:
    """
    Executes the attack.

    Low-order points on invalid curves are sent to the oracle to learn the private key
    modulo small factors. Once the product of the factors exceeds the real curve's
    cardinality, the sign of each residue is bruteforced against `public_key` and the
    residues are combined with the CRT.

    Parameters:
        public_key (WeierstrassPoint): ECDH public key to crack.
        invalid_curves         (list): List of invalid curves to use in the attack.
        max_factor_size         (int): Max factor size to prevent attempting to factor forever.

    Returns:
        int: Private key.
    """
    residues = []
    factors_seen = set()

    # Running product of the factors used so far
    total = 1
    reached_card = False
    cardinality = self.curve.cardinality()

    if not invalid_curves:
        invalid_curves = []

    # Generate invalid curves if the user doesn't specify them or have enough factors
    def curve_gen():
        orig = self.curve

        while True:
            # Draw a random `b` different from the real curve's
            b = orig.b

            while b == orig.b:
                b = orig.ring.random()

            curve = WeierstrassCurve(a=orig.a, b=b, ring=orig.ring)
            # NOTE(review): result discarded; presumably called so the cardinality is computed up front -- confirm
            curve.cardinality()
            curve.G_cache = orig.G_cache

            yield curve

    # User-supplied curves are consumed first, then generated ones without end
    for inv_curve in itertools.chain(invalid_curves, curve_gen()):
        # Factor as much as we can
        factors = [r for r, _ in factorint(inv_curve.cardinality(), use_rho=False, limit=max_factor_size).items() if r > 2 and r < max_factor_size]

        log.debug(f'Found factors: {factors}')

        # Request residues from crafted public keys
        for factor in RUNTIME.report_progress(set(factors) - factors_seen, desc='Sending malicious public keys', unit='factor'):
            if total > cardinality:
                reached_card = True
                break

            if factor in factors_seen:
                continue

            total *= factor

            # Generate a low-order point on the invalid curve
            bad_pub = inv_curve.POINT_AT_INFINITY

            while bad_pub == inv_curve.POINT_AT_INFINITY or bad_pub == inv_curve.G:
                point = inv_curve.random()
                bad_pub = point * (inv_curve.cardinality() // factor)

            residue = self.oracle.request(bad_pub, factor)
            residues.append((ZZ / ZZ(factor))(residue))
            factors_seen.add(factor)

        if reached_card:
            break

    # We have to take into account the fact we can end up on the "negative" side of the field
    negations = [(residue, -residue) for residue in residues]

    G_cache = self.curve.G.cache_mul(cardinality.bit_length())

    # Just bruteforce the correct configuration based off of the public key
    # NOTE(review): if no configuration matches, the loop ends without `break` and the last candidate is returned unchecked.
    for residue_subset in RUNTIME.report_progress(itertools.product(*negations), desc='Bruteforcing residue configuration', unit='residue set', total=2**len(residues)):
        n, _ = crt(residue_subset)

        if G_cache * (int(n) % cardinality) == public_key:
            break

    return n.val