def execute(self, ciphertexts: list, iterations: int=3) -> list:
        """
        Executes the attack.

        The attack runs in two phases:
          1. Transposition analysis: every ciphertext is truncated to the shortest length and
             the set is transposed, so each column was XORed with a single keystream byte.
             All 256 key bytes are tried per column and the best-scoring column is kept.
          2. Higher-context analysis: for `iterations` rounds, every byte position is refined
             by scoring all 256 XOR corrections against the full (retransposed) texts and
             applying the best correction as a differential mask.

        Parameters:
            ciphertexts (list): List of bytes-like ciphertexts using the same keystream.
            iterations   (int): Number of iterations of the full-text analysis phase. Accuracy-time trade-off.

        Returns:
            list: List of recovered plaintexts, each truncated to the length of the shortest ciphertext.

        Raises:
            ValueError: If `ciphertexts` is empty.
        """
        if not ciphertexts:
            raise ValueError('`ciphertexts` MUST contain at least one sample.')

        min_size = min(len(ciphertext) for ciphertext in ciphertexts)

        same_size_ciphers = [ciphertext[:min_size] for ciphertext in ciphertexts]
        transposed_ciphers = [bytearray(column) for column in zip(*same_size_ciphers)]
        # Sanity check that transposing twice round-trips (also rejects non bytes-like rows).
        assert [bytearray(row) for row in zip(*transposed_ciphers)] == same_size_ciphers

        log.debug("Starting initial transposition analysis")

        # Phase 1: every column shares one keystream byte, so brute-force it per column.
        transposed_plaintexts = []
        for cipher in RUNTIME.report_progress(transposed_ciphers, desc='Transposition analysis', unit='ciphers'):
            candidates = []
            for char in range(256):
                plaintext = Bytes(struct.pack('B', char)).stretch(len(cipher)) ^ cipher
                candidates.append((self.analyzer.analyze(plaintext), plaintext))

            # `max` keeps the first best candidate (lowest key byte on ties), exactly like a
            # stable descending sort followed by taking element 0.
            transposed_plaintexts.append(max(candidates, key=lambda candidate: candidate[0])[1])

        retransposed_plaintexts = [bytearray(row) for row in zip(*transposed_plaintexts)]

        log.debug("Starting full-text analysis on retransposed text")

        # Phase 2: clean up with a character-by-character, higher-context analysis.
        for j in RUNTIME.report_progress(range(iterations), desc='Higher-context analysis'):
            log.debug("Starting iteration {}/{}".format(j + 1, iterations))
            differential_mask = bytearray()

            for i in RUNTIME.report_progress(range(min_size), desc='Building differential mask', unit='bytes'):
                char_scores = []

                for char in range(256):
                    total_score = 0
                    for curr_text in retransposed_plaintexts:
                        # Score every text with position `i` XOR-corrected by `char`.
                        candidate = bytearray(curr_text)
                        candidate[i] = char ^ curr_text[i]
                        total_score += self.analyzer.analyze(candidate)

                    char_scores.append(total_score)

                # First best correction wins ties (lowest byte value).
                best_char = max(range(256), key=char_scores.__getitem__)
                differential_mask.append(best_char)

            retransposed_plaintexts = [Bytes.wrap(text) ^ differential_mask for text in retransposed_plaintexts]

        return retransposed_plaintexts
Beispiel #2
0
    def decrypt(self,
                nonce: bytes,
                authed_ciphertext: bytes,
                data: bytes = b'') -> Bytes:
        """
        Authenticates and then decrypts `authed_ciphertext`.

        The trailing 16 bytes are the received tag and everything before them is the
        ciphertext. A fresh tag is computed with `self.auth` using the mask returned by
        `self.clock_ctr(nonce)`, compared with `RUNTIME.compare_bytes`, and only on a match is
        the ciphertext decrypted with `self.ctr`.

        Parameters:
            nonce             (bytes): Bytes-like nonce.
            authed_ciphertext (bytes): Bytes-like ciphertext with the 16-byte tag appended.
            data              (bytes): Bytes-like additional data to be authenticated.

        Returns:
            Bytes: Resulting plaintext.

        Raises:
            Exception: If the recomputed tag does not match the received one.
        """
        from samson.utilities.runtime import RUNTIME

        wrapped = Bytes.wrap(authed_ciphertext)
        body, received_tag = wrapped[:-16], wrapped[-16:]

        mask = self.clock_ctr(nonce)
        expected_tag = self.auth(body, Bytes.wrap(data), mask)

        if RUNTIME.compare_bytes(expected_tag, received_tag):
            return self.ctr.decrypt(body)

        raise Exception('Tag mismatch: authentication failed!')
    def execute(self,
                ciphertexts: list,
                word_ranges: list = None,
                delimiter: str = ' ') -> list:
        """
        Executes the attack.

        XORs the two ciphertexts together (cancelling the shared keystream) and runs a beam
        search over word combinations from `self.wordlist`. Each guess is zero-padded to the
        XOR length, and the other plaintext it implies is scored with `self.analyzer`.

        Parameters:
            ciphertexts (list): List of bytes-like ciphertexts using the same keystream.
            word_ranges (list): List of numbers of words to try, in ascending order. E.G. [2, 3, 4] means try 2, 3, and 4-word combinations. Defaults to [2, 3].
            delimiter    (str): Delimiter to use between word combinations.

        Returns:
            list: One list per entry in `word_ranges`, each holding the top 10 candidate plaintexts.

        Raises:
            ValueError: If `ciphertexts` does not contain exactly two samples.
        """
        if len(ciphertexts) != 2:
            raise ValueError('`ciphertexts` MUST contain two samples.')

        if word_ranges is None:
            word_ranges = [2, 3]

        two_time = xor_buffs(*ciphertexts)
        cipher_len = len(two_time)

        def score_guess(guess: str):
            # Pad/truncate the guess to the XOR length and score the plaintext it implies.
            padded = (bytes(guess, 'utf-8') + b'\x00' * cipher_len)[:cipher_len]
            return self.analyzer.analyze(xor_buffs(padded, two_time))

        # Empty words add nothing to a guess and would divide by zero in the normalization below.
        trimmed_list = [word for word in self.wordlist if 0 < len(word) <= cipher_len]
        prepend_list = ['']

        last_num_processed = 0
        results = []

        for j in RUNTIME.report_progress(word_ranges):
            log.debug(f"Starting word range {j}")

            for i in range(j - last_num_processed):
                word_scores = []
                for prepend in prepend_list:
                    for word in trimmed_list:
                        # Only put the delimiter *between* words; stripping alone cannot remove a
                        # leading non-whitespace delimiter from the first word.
                        mod_word = (prepend + delimiter + word if prepend else word).strip()
                        word_scores.append((mod_word, score_guess(mod_word) / (len(word)**2)))

                # Beam search: keep the best 10**depth partial guesses for the next word.
                word_scores.sort(key=lambda score: score[1], reverse=True)
                prepend_list = [word for word, _ in word_scores[:10**(i + 1 + last_num_processed)]]

            last_num_processed = j

            results.append(sorted(prepend_list, key=score_guess, reverse=True)[:10])

        return results
Beispiel #4
0
    def verify(self,
               plaintext: bytes,
               signature: bytes,
               strict_type_match: bool = True) -> bool:
        """
        Checks whether `signature` is a valid signature over `plaintext`.

        The signature is passed through `self.rsa.encrypt`, unpadded with `self.padder`, and
        parsed as a DER sequence. Item 0 carries the hash OID and item 1 the digest, which is
        compared against the hash of `plaintext`.

        Parameters:
            plaintext        (bytes): Plaintext to verify.
            signature        (bytes): Signature to verify plaintext against.
            strict_type_match (bool): If True, always hash with `self.hash_obj`; otherwise use the hash named by the OID in the signature.

        Returns:
            bool: True if the digests match; False on mismatch or on any decoding error.
        """
        from samson.utilities.runtime import RUNTIME

        try:
            der_encoded = self.padder.unpad(Bytes(self.rsa.encrypt(signature)))
            digest_info = bytes_to_der_sequence(der_encoded)

            if strict_type_match:
                hasher = self.hash_obj
            else:
                hasher = INVERSE_HASH_OID_LOOKUP[digest_info[0][0]]()

            expected_digest = Bytes(digest_info[1])
            return RUNTIME.compare_bytes(expected_digest, hasher.hash(plaintext))
        except Exception:
            # Any unpadding, parsing or OID lookup failure means the signature is invalid.
            return False
Beispiel #5
0
    def execute(self, public_key: int, max_factor_size: int = 2**16) -> int:
        """
        Executes the attack.

        For every small prime factor of (p - 1) / order, an element of that order is sent to
        the oracle, which returns a residue of the private key. The residues are combined
        with the CRT into x = n (mod r). The remaining m in x = n + m*r is then found with
        Pollard's kangaroo over g' = g^r and y' = y * g^-n.

        Parameters:
            public_key      (int): Diffie-Hellman public key to crack.
            max_factor_size (int): Max factor size to prevent attempting to factor forever.

        Returns:
            int: Private key.
        """
        # Factor as much as we can
        factors = [
            r for r in factorint((self.p - 1) // self.order,
                                 use_rho=False,
                                 limit=max_factor_size) if r < max_factor_size
        ]
        log.debug(f'Found factors: {factors}')

        residues = []

        # Request residues from crafted public keys
        for factor in RUNTIME.report_progress(
                factors, desc='Sending malicious public keys', unit='factor'):
            # Raise a random element to (p - 1) / factor until it is not the identity.
            h = 1
            while h == 1:
                h = pow(random_int_between(1, self.p), (self.p - 1) // factor,
                        self.p)

            residue = self.oracle.request(h, factor)
            residues.append((residue, factor))

        # Build partials using CRT: x = n (mod r)
        n, r = crt(residues)

        # Once the CRT modulus covers the whole subgroup order, the residue already is the key
        # (r == order would otherwise leave kangaroo with an empty interval).
        if r >= self.order:
            return n

        # y = g^(n + m*r)  =>  y * g^-n = (g^r)^m
        g_prime = pow(self.g, r, self.p)
        y_prime = (public_key *
                   mod_inv(pow(self.g, n, self.p), self.p)) % self.p

        log.info(
            f'Recovered {"%.2f"%math.log(reduce(int.__mul__, factors, 1), 2)}/{"%.2f"%math.log(self.order, 2)} bits'
        )
        log.info(f'Found relation: x = {n} + m*{r}')
        log.debug(f"g' = {g_prime}")
        log.debug(f"y' = {y_prime}")
        log.info('Attempting to catch a kangaroo...')

        # Probabilistically solve DLP
        R = (ZZ / ZZ(self.p)).mul_group()
        m = pollards_kangaroo(R(g_prime),
                              R(y_prime),
                              a=0,
                              b=(self.order - 1) // r)
        return n + m * r
Beispiel #6
0
    def verify(self, plaintext: bytes, signature: bytes) -> bool:
        """
        Verifies the `plaintext` against the `signature`.

        Follows the EMSA-PSS verification steps (RFC 8017, 9.1.2). `signature` is the encoded
        message EM, and its final octet must be 0xbc.

        Parameters:
            plaintext (bytes): Plaintext to verify.
            signature (bytes): Encoded message (EM) to verify against plaintext.

        Returns:
            bool: Whether or not the plaintext is verified.
        """
        from samson.utilities.runtime import RUNTIME

        em_bits = self.modulus_len - 1
        em_len  = (em_bits + 7) // 8

        plaintext = Bytes.wrap(plaintext)
        # EM is emLen octets. When (modBits - 1) is a multiple of 8, that is one octet shorter
        # than the modulus, so padding to the modulus length would shift maskedDB and H.
        signature = Bytes.wrap(signature).zfill(em_len)
        mHash     = self.hash_obj.hash(plaintext)

        if em_len < (self.hash_obj.digest_size + self.salt_len + 2):
            return False

        if bytes([signature[-1]]) != b'\xbc':
            return False

        # Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
        #    and let H be the next hLen octets.
        mask_len  = em_len - self.hash_obj.digest_size - 1
        masked_db = signature[:mask_len]
        H         = signature[mask_len:mask_len + self.hash_obj.digest_size]

        # If the leftmost 8emLen - emBits bits of the leftmost octet in
        #    maskedDB are not all equal to zero, output "inconsistent" and
        #    stop.
        left_mask = (2**(len(masked_db) * 8) - 1) >> ((8 * em_len) - em_bits)
        if masked_db & left_mask != masked_db:
            return False

        db_mask = self.mgf(H, mask_len)
        DB      = masked_db ^ db_mask
        DB     &= left_mask

        # If the emLen - hLen - sLen - 2 leftmost octets of DB are not
        #    zero or if the octet at position emLen - hLen - sLen - 1 (the
        #    leftmost position is "position 1") does not have hexadecimal
        #    value 0x01, output "inconsistent" and stop.
        if DB[:em_len - self.hash_obj.digest_size - self.salt_len - 1].int() != 1:
            return False

        salt    = DB[-self.salt_len:] if self.salt_len else b''
        m_prime = b'\x00' * 8 + mHash + salt
        h_prime = self.hash_obj.hash(m_prime)

        return RUNTIME.compare_bytes(h_prime, H)
Beispiel #7
0
    def run(self, generations: int) -> OptimizationResult:
        """
        Runs the genetic algorithm for up to `generations` generations.

        Each generation:
          - evaluates the population with `self.obj_func`;
          - selects the `parent_pool_size` fittest chromosomes;
          - applies the minimum-convergence heuristic;
          - breeds the next population. Elitism keeps the parent pool, and immigrants from
            `self.initialization_func` only join the breeding pool.

        Parameters:
            generations (int): Maximum number of generations to run.

        Returns:
            OptimizationResult: The fittest chromosome of the final population, the index of
            the last generation run (0 if `generations` is 0), and the termination reason.
        """
        min_conv_counter = 0
        # `None` means "no best recorded yet", so the first generation always counts as progress.
        # Numeric sentinels break: 0 divides by zero when maximizing, and 2**8192 overflows
        # when multiplied by a float granularity.
        current_best = None
        termination_reason = TerminationReason.MAX_ITERATION_REACHED
        granularized_minimum_convergence = self.convergence_granularity * self.minimum_convergence

        def has_stalled(new_best, old_best) -> bool:
            # Minimum-convergence heuristic: did the best fitness change too little?
            if old_best is None:
                return False

            # The scaled change is undefined against a zero best; only "no change" stalls.
            if not old_best:
                return new_best == old_best

            return abs(new_best - old_best) // (old_best * self.convergence_granularity) < granularized_minimum_convergence

        # Bound up front so the result is well-formed even when `generations` is 0.
        iteration = 0

        for iteration in RUNTIME.report_progress(range(generations), desc='Generations', unit='gens'):
            # 1) Measure
            self.obj_func(self.population)

            # 2) Select
            parent_pool = sorted(self.population, key=lambda chromo: chromo.fitness, reverse=self.maximize)[:self.parent_pool_size]

            # Test for minimum convergence heuristic
            if has_stalled(parent_pool[0].fitness, current_best):
                if min_conv_counter < self.min_conv_tolerance:
                    min_conv_counter += 1
                else:
                    termination_reason = TerminationReason.MINIMUM_CONVERGENCE_UNSATISFIED
                    break
            else:
                current_best = parent_pool[0].fitness
                min_conv_counter = 0

            # When minimizing, a best fitness of zero cannot be improved upon.
            if not self.maximize and not current_best:
                termination_reason = TerminationReason.GLOBAL_OPTIMA_REACHED
                break

            next_population = []

            # Elitism
            if self.elitism:
                next_population.extend(parent_pool)

            # Immigration (immigrants only join the breeding pool)
            parent_pool.extend([Chromosome(individual) for individual in self.initialization_func(self.num_immigrants)])

            # Breeding
            while len(next_population) < self.population_size:
                parents = random.sample(parent_pool, k=self.num_parents)

                # 3) Crossover
                individual = Chromosome(self.crossover_func(parents))

                # 4) Mutate
                individual = self.mutation_func(individual)

                next_population.append(individual)

            self.population = next_population

        # Evaluate the final population so every chromosome has an up-to-date fitness.
        self.obj_func(self.population)

        return OptimizationResult(sorted(self.population, key=lambda chromo: chromo.fitness, reverse=self.maximize)[0], iteration, termination_reason)
Beispiel #8
0
    def execute(self) -> Bytes:
        """
        Executes the attack.

        Recovers the secret the ECB oracle appends to its input, one byte at a time. Each
        unknown byte is aligned as the last byte of a block and brute-forced over all 256
        values by comparing ciphertext blocks. Every block of the oracle's output for an
        empty input is attacked.

        Returns:
            Bytes: The recovered plaintext. It is not unpadded, so it may contain bytes
            recovered from the final (padding) block.
        """
        baseline = len(self.oracle.encrypt(b''))
        block_size = self.oracle.test_io_relation()['block_size']

        plaintexts = []
        for curr_block in RUNTIME.report_progress(range(baseline //
                                                        block_size),
                                                  unit='blocks'):
            log.debug("Starting iteration {}".format(curr_block))

            plaintext = b''
            for byte_idx in RUNTIME.report_progress(range(block_size),
                                                    unit='bytes'):
                if curr_block == 0:
                    # Pad so the next unknown byte is the last byte of block 0.
                    payload = ('A' * (block_size - (byte_idx + 1))).encode()
                else:
                    # Reuse known plaintext from the previous block so that block 0 of the
                    # guess lines up with block `curr_block` of the reference ciphertext.
                    payload = plaintexts[-1][byte_idx + 1:]

                one_byte_short = get_blocks(self.oracle.encrypt(payload),
                                            block_size=block_size)[curr_block]

                for i in range(256):
                    guess = struct.pack('B', i)
                    ciphertext = self.oracle.encrypt(payload + plaintext +
                                                     guess)

                    # We're always editing the first block to look like block 'curr_block'
                    if get_blocks(ciphertext,
                                  block_size=block_size)[0] == one_byte_short:
                        plaintext += guess
                        break

            plaintexts.append(plaintext)
        return Bytes(b''.join(plaintexts))
Beispiel #9
0
def color_format(color: ConsoleColors, text: str):
    """
    Wraps `text` in a console color sequence if the runtime context has color enabled.

    Parameters:
        color (ConsoleColors): A `ConsoleColors` member or its raw value.
        text            (str): Text to format.

    Returns:
        str: `{PREFIX}{color}m{text}{SUFFIX}` when color is enabled, otherwise `text` unchanged.
    """
    from samson.utilities.runtime import RUNTIME

    if not RUNTIME.get_context().use_color:
        return text

    color_code = color.value if type(color) is ConsoleColors else color
    return f"{PREFIX}{color_code}m{text}{SUFFIX}"
Beispiel #10
0
def rand_bytes(size: int=16) -> bytes:
    """
    Returns `size` random bytes from the runtime's configured random source.

    Parameters:
        size (int): Number of bytes to read.

    Returns:
        bytes: Random bytes produced by `RUNTIME.random`.
    """
    from samson.utilities.runtime import RUNTIME as runtime

    random_source = runtime.random
    return random_source(size)
Beispiel #11
0
    def decrypt(self, key: bytes, iv: bytes, ciphertext: bytes,
                auth_data: bytes, auth_tag: bytes) -> Bytes:
        """
        Verifies the authentication tag, then CBC-decrypts `ciphertext` (encrypt-then-MAC).

        `key` is split into a MAC key and an encryption key of `self.chunk_size` bytes each.
        The expected tag is HMAC(mac_key, auth_data || iv || ciphertext || AL), truncated to
        `self.chunk_size` bytes. AL is the bit length of `auth_data`, zero-filled to 8 bytes.

        Parameters:
            key        (bytes): MAC key followed by the encryption key.
            iv         (bytes): CBC initialization vector.
            ciphertext (bytes): Ciphertext to decrypt.
            auth_data  (bytes): Additional authenticated data.
            auth_tag   (bytes): Received authentication tag.

        Returns:
            Bytes: The decrypted plaintext.

        Raises:
            AssertionError: If the tag does not match.
        """
        mac_key, enc_key = Bytes.wrap(key).chunk(self.chunk_size)

        mac_input = auth_data + iv + ciphertext + Bytes(len(auth_data) * 8).zfill(8)
        hmac = HMAC(mac_key, self.hash_obj).generate(mac_input)[:self.chunk_size]

        # Must not be an `assert` statement: those are removed under `python -O`, which would
        # silently skip authentication. AssertionError is kept for caller compatibility.
        if not RUNTIME.compare_bytes(hmac, auth_tag):
            raise AssertionError('Tag mismatch: authentication failed!')

        cbc = CBC(Rijndael(enc_key), iv=iv)
        return cbc.decrypt(ciphertext)
Beispiel #12
0
    def recover_plaintext(cls: object,
                          ciphertext: int,
                          pub: list,
                          alpha: int = 1) -> Bytes:
        """
        Attempts to recover the plaintext without the private key (lattice attack).

        Builds a lattice from the public key and the ciphertext. For each guess `i` in
        `range(len(pub))`, it adds a row weighted by `alpha` that penalizes deviation of the
        number of set bits from `i`, and reduces the lattice with LLL. It returns the first
        reduced row whose leading `len(pub)` entries all lie in [0, 1], read as plaintext bits.

        Parameters:
            ciphertext (int): A ciphertext sum.
            pub       (list): The public key.
            alpha      (int): Punishment coefficient for deviation from guessed bit distribution.

        Returns:
            Bytes: Recovered plaintext.
            NOTE(review): if no candidate row is found, the last reduced matrix is returned
            instead. An empty `pub` raises UnboundLocalError.
        """
        from samson.math.matrix import Matrix
        from samson.math.algebra.all import QQ

        size = len(pub)

        # Basis: identity stacked on the public key, plus a column holding -ciphertext.
        stacked = Matrix.identity(size, coeff_ring=QQ).col_join(Matrix([pub], coeff_ring=QQ))
        target_column = Matrix([[0] * size + [-ciphertext]], coeff_ring=QQ).T
        problem_matrix = stacked.row_join(target_column)

        # Try every guessed number of set bits.
        for i in RUNTIME.report_progress(range(size),
                                         desc='Alphaspace searched'):
            penalty_row = Matrix([[alpha] * size + [-alpha * i]], coeff_ring=QQ)
            reduced = problem_matrix.col_join(penalty_row).T.LLL(0.99)

            for row in reduced.rows:
                bits = row[:-2]
                if all(bit >= QQ.zero() and bit <= QQ.one() for bit in bits):
                    bit_string = ''.join(str(int(float(bit))) for bit in bits)
                    return Bytes(int(bit_string, 2))

        return reduced
Beispiel #13
0
    def crack(self, lm_hash: bytes, charset: bytes=None) -> Bytes:
        """
        Cracks both halves simultaneously.

        Candidates of length 1 to 7 over `charset` are hashed with `self.hash`. The first 8
        bytes of each hash are compared against both 8-byte halves of `lm_hash`. Halves that
        `self.check_halves_null` reports as null are filled with zero bytes up front.

        The search stops when both halves are known, when the search space is exhausted, or
        when the user interrupts it (Ctrl-C). In every case the possibly partial result is
        returned.

        Parameters:
            lm_hash (bytes): Hash to crack.
            charset (bytes): Character set to use. Defaults to uppercase letters, digits and punctuation.

        Returns:
            Bytes: Cracked LM hash. The first half is passed through `pad_congruent_right(7)`
            and then concatenated with the second half.
        """
        h1, h2 = lm_hash.zfill(16).chunk(8)

        h1_pt, h2_pt = None, None

        h1_null, h2_null = self.check_halves_null(lm_hash)

        if h1_null:
            h1_pt = Bytes(b'').zfill(8)

        if h2_null:
            h2_pt = Bytes(b'').zfill(8)

        if not charset:
            charset = bytes(string.ascii_uppercase + string.digits + string.punctuation, 'utf-8')

        # Lazily enumerate every candidate, shortest first, reporting progress per length.
        candidates = (
            bytes(attempt)
            for length in RUNTIME.report_progress(range(1, 8), unit='length')
            for attempt in itertools.product(charset, repeat=length)
        )

        try:
            for b_attempt in candidates:
                hashed = self.hash(b_attempt)[:8]
                if hashed == h1:
                    h1_pt = b_attempt

                if hashed == h2:
                    h2_pt = b_attempt

                if h1_pt and h2_pt:
                    break
        except KeyboardInterrupt:
            # Deliberately swallow a user interrupt so a long search can be cut short
            # and still report whichever halves were already found.
            pass

        return Bytes(h1_pt or b'\x00').pad_congruent_right(7) + Bytes(h2_pt or b'')
Beispiel #14
0
    def derive_from_backdoor(cls: object, P: WeierstrassPoint,
                             Q: WeierstrassPoint, d: int,
                             observed_out: bytes) -> list:
        """
        Recovers the internal state of a Dual EC generator and builds replicas.

        The first 30 bytes of `observed_out` are taken as the low bytes of an x-coordinate.
        All 2**16 values of the two missing high bytes are tried. For each candidate point R,
        `d * R` gives a candidate state. That state is kept if the output it predicts
        (bytes 2 onward of the 32-byte x-coordinate of `Q * state`) matches the rest of
        `observed_out`. Candidates raising AssertionError are skipped (presumably
        x-coordinates that are not on the curve).

        Parameters:
            P (WeierstrassPoint): Elliptical curve point `P`.
            Q (WeierstrassPoint): Elliptical curve point `Q`.
            d              (int): Backdoor that relates Q to P.
            observed_out (bytes): Observed output from the compromised Dual EC generator.

        Returns:
            list: `DualEC` replicas for every consistent internal state.
        """
        # At least one full 30-byte output block is required (checked via `assert`).
        assert len(observed_out) >= 30

        curve = P.curve
        first_chunk, second_chunk = observed_out[:30], observed_out[30:]

        Q_cache = Q.cache_mul(curve.cardinality().bit_length())
        candidates = []

        for high_bytes in RUNTIME.report_progress(range(2**16),
                                                  desc='Statespace searched',
                                                  unit='states'):
            x_guess = int.from_bytes(int.to_bytes(high_bytes, 2, 'big') + first_chunk, 'big')
            try:
                state = int((d * curve(x_guess)).x)
                predicted = int.to_bytes(int((Q_cache * state).x), 32, 'big')

                if predicted[2:2 + len(second_chunk)] == second_chunk:
                    candidates.append(DualEC(P, Q, state))
            except AssertionError:
                pass

        return candidates
Beispiel #15
0
    def crack_truncated(cls: object, outputs: list, outputs_to_predict: list,
                        multiplier: int, increment: int, modulus: int,
                        trunc_amount: int) -> object:
        """
        Given a decent number of truncated states (about 200 when there's only 3-bit outputs), returns a replica LCG.

        If `outputs_to_predict` is empty, the last two entries of `outputs` are used as the
        prediction set instead. With `increment == 0` (a purely multiplicative LCG), the
        truncated states are solved for directly. Otherwise the differences between
        consecutive outputs are solved for, and the truncated low bits of the first state
        are brute-forced.

        Parameters:
            outputs            (list): List of truncated-state outputs (in order).
            outputs_to_predict (list): Next few outputs to compare against. Accuracy/number of samples trade-off.
            multiplier          (int): The LCG's multiplier.
            increment           (int): The LCG's increment.
            modulus             (int): The LCG's modulus.
            trunc_amount        (int): Number of low bits truncated from each state before it is output.

        Returns:
            LCG: Replica LCG that predicts all future outputs of the original. In the nonzero-increment
            case, the returned LCG has already generated `outputs_to_predict` while being checked.

        Raises:
            SearchspaceExhaustedException: If no low-bit guess (nonzero-increment case) reproduces `outputs_to_predict`.

        References:
            https://github.com/mariuslp/PCG_attack
            "Reconstructing Truncated Integer Variables Satisfying Linear Congruences" (https://www.math.cmu.edu/~af1p/Texfiles/RECONTRUNC.pdf)
        """
        # Hold back the last two outputs to validate candidates when none were supplied.
        if not outputs_to_predict:
            outputs_to_predict = outputs[-2:]
            outputs = outputs[:-2]

        # Trivial case
        if increment == 0:
            computed_seed = LCG.solve_tlcg(outputs + outputs_to_predict,
                                           multiplier, modulus, trunc_amount)

            # Here we take the second to last seed since our implementation edits the state BEFORE it returns
            return LCG((multiplier * computed_seed[-2]) % modulus,
                       multiplier,
                       increment,
                       modulus,
                       trunc=trunc_amount)

        else:
            # Differences of consecutive outputs cancel the unknown increment, so the
            # truncated-state solver can be applied to them instead.
            diffs = [o2 - o1 for o1, o2 in zip(outputs, outputs[1:])]
            seed_diffs = LCG.solve_tlcg(diffs, multiplier, modulus,
                                        trunc_amount)
            # Flatten the solver's rows into a list of state differences reduced mod `modulus`.
            seed_diffs = [
                int(seed_diff) % modulus for row in seed_diffs
                for seed_diff in row
            ]

            # Bruteforce low bits
            for z in RUNTIME.report_progress(range(2**trunc_amount),
                                             desc='Seedspace searched',
                                             unit='seeds'):
                # Candidate first state: known high bits plus guessed low bits `z`.
                x_0 = (outputs[0] << trunc_amount) + z
                x_1 = (seed_diffs[0] + x_0) % modulus
                # Increment implied by x_1 = a*x_0 + c (mod m).
                computed_c = (x_1 - multiplier * x_0) % modulus

                # Cross-check the implied increment against the third state.
                computed_x_2 = (multiplier * x_1 + computed_c) % modulus
                actual_x_2 = (seed_diffs[1] + x_1) % modulus

                if computed_x_2 == actual_x_2:
                    # Rebuild every state from x_0 by accumulating the differences.
                    computed_seeds = [x_0]

                    for diff in seed_diffs:
                        computed_seeds.append(
                            (diff + computed_seeds[-1]) % modulus)

                    # It's possible to find a spectrum of nearly-equivalent LCGs.
                    # The accuracy of `predicted_lcg` is dependent on the size of `outputs_to_predict` and the
                    # parameters of the LCG.
                    predicted_seed = (multiplier * computed_seeds[-2] +
                                      computed_c) % modulus
                    predicted_lcg = LCG(X=int(predicted_seed),
                                        a=multiplier,
                                        c=int(computed_c),
                                        m=modulus,
                                        trunc=trunc_amount)

                    # Validate by generating the held-out outputs (this advances `predicted_lcg`).
                    if [
                            predicted_lcg.generate()
                            for _ in range(len(outputs_to_predict))
                    ] == outputs_to_predict:
                        return predicted_lcg

            raise SearchspaceExhaustedException('Seedspace exhausted')
Beispiel #16
0
 def real_decorator(cls):
     """Registers `cls` with the runtime as an exploit mapping for `attack` and returns it unchanged."""
     register = RUNTIME.register_exploit_mapping
     register(cls, attack)
     return cls
Beispiel #17
0
 def real_decorator(cls):
     """Registers `cls` with the runtime as a primitive and returns it unchanged."""
     register = RUNTIME.register_primitive
     register(cls)
     return cls
Beispiel #18
0
 def real_decorator(cls):
     """Registers `cls` with the runtime as an exploit (with its consequence and requirements) and returns it unchanged."""
     register = RUNTIME.register_exploit
     register(cls, consequence, requirements)
     return cls
    def _generate_tree(self):
        """
        Builds a binary hash tree of colliding, intermediary Merkle-Damgard construction states.

        Prefixes are paired level by level, and `self.collision_func(p1, p2)` is called for each
        pair. The intermediary collision of each pair becomes a node on the next level. Odd node
        counts are handled by holding back ("promoting") one prefix/state.

        On return:
          - `self.hash_tree` maps both inputs of every pair to
            `(p1, p2, p1_suffix, p2_suffix, intermediary_collision)`.
          - `self.crafted_hash` is the intermediary collision of the first pair on the last level.

        NOTE(review): `promoted_prefix` is never reset to None after being consumed, so a later
        level with an odd number of solutions appends the same state again. Confirm this is
        intended.
        """
        log.debug('Generating hash tree')
        # One list of (p1, p2) pairs per level.
        tree = []
        for i in range(self.k):
            tree.append([])

        promoted_prefix = None

        # Determine if we need to "promote" a prefix. Basically, if we have an odd number of
        # prefixes on this layer, we put the last prefix in waiting as there will eventually be
        # another layer which has one less node than a power of two.
        curr_prefix_list = self.prefixes
        if len(self.prefixes) % 2 == 1:
            promoted_prefix = self.prefixes[-1]
            curr_prefix_list = self.prefixes[:-1]

        tree[0] = [(curr_prefix_list[i], curr_prefix_list[i + 1])
                   for i in range(0, len(curr_prefix_list), 2)]

        # One list of 5-tuples (p1, p2, p1_suffix, p2_suffix, intermediary_collision) per level.
        solution_tree = []
        for i in range(self.k):
            solution_tree.append([])

        for i in RUNTIME.report_progress(range(self.k)):
            for (p1, p2) in tree[i]:
                p1_suffix, p2_suffix, intermediary_collision = self.collision_func(
                    p1, p2)
                solution_tree[i].append(
                    (p1, p2, p1_suffix, p2_suffix, intermediary_collision))

            # Add solutions
            if i < (self.k - 1):

                # If there's an odd number of solutions at this level, then there's either a prefix
                # in waiting, or we need to promote one.
                next_level_states = solution_tree[i]
                if len(solution_tree[i]) % 2 == 1:
                    # Copy so the recorded solutions are not modified below.
                    next_level_states = deepcopy(solution_tree[i])

                    # Last level is a multiple of 2 but not a power of 2 (e.g. 6).
                    # Promote our last prefix (the last solution's intermediary collision state).
                    if promoted_prefix == None:
                        promoted_prefix = next_level_states[-1][-1]
                        next_level_states = next_level_states[:-1]

                    # We have a prefix in waiting. Use it immediately.
                    # Wrapped in a 1-tuple so `[-1]` below yields the prefix itself.
                    else:
                        next_level_states.append((promoted_prefix, ))

                # Pair up the next level's nodes; `[-1]` is each solution's intermediary collision.
                tree[i + 1] = [(next_level_states[sol][-1],
                                next_level_states[sol + 1][-1])
                               for sol in range(0, len(next_level_states), 2)]

        # We're done generating the tree; time to set the output fields
        # Later levels overwrite earlier entries for a state that appears more than once.
        for layer in solution_tree:
            for p1_init, p2_init, p1_msg, p2_msg, result in layer:
                self.hash_tree[p1_init] = (p1_init, p2_init, p1_msg, p2_msg,
                                           result)
                self.hash_tree[p2_init] = (p1_init, p2_init, p1_msg, p2_msg,
                                           result)

        self.crafted_hash = solution_tree[-1][0][-1]
Beispiel #20
0
    def execute(self, ciphertext: bytes) -> Bytes:
        """
        Executes Manger's attack.

        Uses `self._greater_equal_B(f, c, e, n)`, an oracle revealing whether f * m is at
        least B = 2**(8*(k-1)), where k is the modulus length in bytes. It narrows the
        plaintext interval [m_min, m_max] until a single value remains (steps 1-3 of Manger's
        attack).

        Parameters:
            ciphertext (bytes): The ciphertext to decrypt.

        Returns:
            Bytes: The ciphertext's corresponding plaintext.
        """
        ciphertext = Bytes.wrap(ciphertext)
        ct_int = ciphertext.int()

        n = self.rsa.n
        e = self.rsa.e

        # Byte length of the modulus. Integer arithmetic avoids the float rounding that
        # math.ceil(math.log(n, 256)) can hit for moduli just above a byte boundary.
        k = (n.bit_length() + 7) // 8
        B = 2**(8 * (k - 1))

        log.debug(f"k: {k}, B: {B}, n: {n}, e: {e}")

        def ceil_div(numerator: int, denominator: int) -> int:
            # Exact integer ceiling division for positive denominators.
            return -(-numerator // denominator)

        # Step 1: double f1 until f1 * m >= B. The halved value f1 / 2 is used from here on.
        f1 = 2
        log.info("Starting step 1")

        while not self._greater_equal_B(f1, ct_int, e, n):
            f1 *= 2

        f1 //= 2
        log.debug(f"Found f1: {f1}")

        # Step 2: start at floor((n + B) / B) * f1/2 and step by f1/2 until f2 * m < B.
        nB = n + B
        f2 = (nB // B) * f1

        log.info("Starting step 2")
        while self._greater_equal_B(f2, ct_int, e, n):
            f2 += f1

        log.debug(f"Found f2: {f2}")

        # Step 3: shrink [ceil(n / f2), floor((n + B) / f2)] to a single value.
        m_min = ceil_div(n, f2)
        m_max = nB // f2
        BB = 2 * B
        diff = m_max - m_min
        ctr = 0

        log.info("Starting step 3")
        log.debug(f"B-(diff * f2) = {B - (diff * f2)}")

        # Reporting: progress is measured in bits of remaining uncertainty. `+ 1` keeps the
        # logarithm defined when the interval already holds a single value.
        last_log_diff = math.log(max(diff, 0) + 1, 2)
        progress = RUNTIME.report_progress(None, total=last_log_diff)

        while diff > 0:
            if ctr % 100 == 0:
                log.debug(f"Iteration {ctr} difference: {diff}")

            f_tmp = BB // diff
            i = (f_tmp * m_min) // n
            iN = i * n

            f3 = ceil_div(iN, m_min)
            iNB = iN + B

            if self._greater_equal_B(f3, ct_int, e, n):
                m_min = ceil_div(iNB, f3)
            else:
                m_max = iNB // f3

            diff = m_max - m_min

            # Update progress
            log_diff = math.log(diff + 1, 2)
            progress.update(last_log_diff - log_diff)
            last_log_diff = log_diff

            ctr += 1

        return Bytes(m_min)
Beispiel #21
0
    def execute(self,
                initial_filter=BASIC_FILTER,
                min_input_len: int = 1) -> Fingerprint:
        """
        Fingerprints the primitive behind `self.oracle`.

        The oracle is probed for its ciphertext length, input/output relation, block size and
        maximum input value. Primitives returned by `RUNTIME.search_primitives(initial_filter)`
        whose BLOCK_SIZE admits the observed block size and whose IO_RELATION_TYPE matches are
        scored. When block ciphers are among them, candidate block cipher modes are collected too.

        Parameters:
            initial_filter: Filter handed to `RUNTIME.search_primitives` to pick the starting candidates.
            min_input_len (int): Length of the probe plaintext (b'a' repeated) and of the IO-relation test.

        Returns:
            Fingerprint: Scored candidates, possible block cipher modes, the max-input analysis,
                         the IO relation and the block size.
        """
        sample = self.oracle.encrypt(b'a' * min_input_len)
        base_len = len(sample)
        filtered = RUNTIME.search_primitives(initial_filter)

        io_rel_analysis = self.oracle.test_io_relation(min_input_len)
        io_relation = io_rel_analysis['io_relation']
        block_size = io_rel_analysis['block_size']
        bitsize = block_size * 8

        max_val_analysis = IntegerAnalysis.analyze(self.oracle.test_max_input())

        # Per-primitive score adjustments, added to USAGE_FREQUENCY when scoring below
        modifiers = {}
        if max_val_analysis.n != -1:
            if max_val_analysis.prime_name:
                log.debug(f'Max input size is a well-known modulus: {max_val_analysis.prime_name}')

            # Add modifiers for matching primitives
            if not max_val_analysis.is_prime and not max_val_analysis.byte_aligned and max_val_analysis.is_uniform:
                from samson.public_key.rsa import RSA
                modifiers[RSA] = 1

                log.debug('Max input size looks like RSA modulus')

            elif max_val_analysis.is_prime and max_val_analysis.is_uniform:
                from samson.protocols.dragonfly import Dragonfly
                from samson.public_key.elgamal import ElGamal

                # Process Diffie-Hellman-like primitives; safe and well-known primes add extra weight
                dh_modifier = 1 + max_val_analysis.is_safe_prime + bool(max_val_analysis.prime_name)

                for dh_like in (DiffieHellman, Dragonfly, ElGamal):
                    modifiers[dh_like] = dh_modifier

                log.debug('Max input size looks like Diffie-Hellman modulus')

        matching = [
            match for match in filtered
            if bitsize in match.BLOCK_SIZE and match.IO_RELATION_TYPE == io_relation
        ]
        bc_modes = []

        # Punish IV/nonce/AEAD primitives if we can prove the output doesn't contain their ephemeral/tag
        # This is only really possible if the output is smaller than their ephemeral/tag

        def calculate_min_size(size):
            """Returns the (minimum, typical) size contributed by a size spec, using 0 where unknown."""
            min_size = 0
            typical_size = 0

            sizes = size.sizes
            if type(sizes) is int:
                min_size += sizes
                typical_size += sizes

            else:
                # ARBITRARY/DEPENDENT specs contribute no fixed lower bound of their own
                if size.size_type not in [SizeType.ARBITRARY, SizeType.DEPENDENT]:
                    min_size += sizes[0]

                if size.typical:
                    typical_size += size.typical[0]

            return min_size, typical_size

        # If the primitive is a block cipher mode and its ephemeral/tag is DEPENDENT, we'll want to check
        # against known block ciphers.
        block_ciphers = [
            prim for prim in filtered
            if issubclass(prim, BlockCipher) and bitsize in prim.BLOCK_SIZE
        ]
        minimum_bc = min(
            (calculate_min_size(block_cipher.BLOCK_SIZE)[0] for block_cipher in block_ciphers),
            default=0
        )

        for match in matching:
            min_size = 0
            typical_size = 0
            all_sizes = []

            if hasattr(match, 'EPHEMERAL') and match.EPHEMERAL.ephemeral_type != EphemeralType.KEY:
                all_sizes.append(match.EPHEMERAL.size)

            if hasattr(match, 'AUTH_TAG_SIZE'):
                all_sizes.append(match.AUTH_TAG_SIZE)

            for size in all_sizes:
                component_min, component_typical = calculate_min_size(size)
                min_size += component_min
                typical_size += component_typical

                if issubclass(match, BlockCipherMode) and size.size_type == SizeType.DEPENDENT:
                    min_size += minimum_bc

            # Penalize once for each bound (minimum, typical) the observed output is too short to contain
            for size in (min_size, typical_size):
                if base_len * 8 < size:
                    modifiers[match] = modifiers.get(match, 0) - 1

        # Find possible block cipher modes. We exclude StreamingBlockCipherModes because they're handled
        # by their block size above.
        if any(issubclass(match, BlockCipher) for match in matching):
            from samson.block_ciphers.modes.ecb import ECB

            log.debug('Block ciphers in candidates. Attempting to find possible block cipher modes')
            bc_modes = [
                prim for prim in filtered
                if issubclass(prim, BlockCipherMode) and not issubclass(prim, StreamingBlockCipherMode)
            ]

            # A stateless oracle is taken to mean ECB; otherwise ECB is ruled out
            if self.oracle.test_stateless(block_size):
                log.info('Stateless blocks detected')
                bc_modes = [ECB]

            elif ECB in bc_modes:
                bc_modes.remove(ECB)

        # Score matches higher if they're more explicit
        scored_matches = {}

        for match in matching:
            score = match.USAGE_FREQUENCY.value + modifiers.get(match, 0)

            # If it matches a SINGLE value, that's significant
            if bitsize == match.BLOCK_SIZE.sizes:
                score += FrequencyType.PROLIFIC.value

            # If it's in a RANGE, it's a bit less significant
            elif bitsize in match.BLOCK_SIZE.sizes:
                score += FrequencyType.NORMAL.value

            # Add a modifier for being in 'typical'
            if bitsize in match.BLOCK_SIZE.typical:
                score += 1

            scored_matches[match] = score

        return Fingerprint(candidates=scored_matches,
                           modes=bc_modes,
                           max_input_analysis=max_val_analysis,
                           io_relation=io_relation,
                           block_size=block_size)
    def execute(self, ciphertext: bytes) -> Bytes:
        """
        Executes the attack.

        Blocks are attacked from last to first. Each block is recovered byte by byte, right to left,
        by tampering with the block before it (or `self.iv` for the first block) and asking
        `self.oracle.check_padding` which guess yields valid padding.

        Parameters:
            ciphertext (bytes): Bytes-like ciphertext to be decrypted.

        Returns:
            Bytes: Plaintext corresponding to the inputted ciphertext.

        Raises:
            RuntimeError: If no candidate in `self.alphabet` produces valid padding for some byte.
        """
        blocks = Bytes.wrap(ciphertext).chunk(self.block_size)
        reversed_blocks = blocks[::-1]

        plaintexts = []

        for i, block in enumerate(RUNTIME.report_progress(reversed_blocks, desc='Blocks cracked', unit='blocks')):
            log.debug("Starting iteration {}".format(i))
            plaintext = Bytes(b'')

            # The first ciphertext block (last in this reversed order) is chained to the IV
            if i == len(reversed_blocks) - 1:
                preceding_block = self.iv
            else:
                preceding_block = reversed_blocks[i + 1]

            for _ in RUNTIME.report_progress(range(len(block)), desc='Bytes cracked', unit='bytes'):
                last_working_char = None
                exploit_blocks    = {}

                # Generate candidate blocks
                for possible_char in self.alphabet:
                    test_byte = struct.pack('B', possible_char)
                    payload   = test_byte + plaintext
                    prefix    = b'\x00' * (self.block_size - len(payload))

                    # XOR the guessed tail with the target padding value len(payload); assuming CBC chaining,
                    # a correct guess makes the decrypted tail equal len(payload) repeated len(payload) times.
                    padding = (struct.pack('B', len(payload)) * (len(payload))) ^ payload

                    fake_block    = prefix + padding
                    exploit_block = fake_block ^ preceding_block
                    new_cipher    = bytes(exploit_block + block)

                    exploit_blocks[new_cipher] = test_byte


                if self.batch_requests:
                    best_block = self.oracle.check_padding(list(exploit_blocks))
                    last_working_char = exploit_blocks[best_block]
                    log.debug("Found working byte: {}".format(last_working_char))
                else:
                    # Oracle can't handle batch requests. Feed blocks into it.
                    for exploit_block, byte in exploit_blocks.items():
                        if self.oracle.check_padding(exploit_block):
                            log.debug("Found working byte: {}".format(byte))
                            last_working_char = byte

                        # Early out optimization. Note, we're being careful about PKCS7 padding here.
                        if last_working_char and ord(byte) >= self.block_size:
                            break

                # Without this check the concatenation below fails with an opaque TypeError on None
                if last_working_char is None:
                    raise RuntimeError(
                        "No candidate byte produced valid padding for block {}".format(i))

                plaintext = last_working_char + plaintext

            plaintexts.append(plaintext)
        return Bytes(b''.join(plaintexts[::-1]))
    def execute(self, ciphertext: int, n: int, e: int,
                key_length: int) -> Bytes:
        """
        Executes the attack.

        Parameters:
            ciphertext (int): The ciphertext represented as an integer.
                     n (int): The RSA instance's modulus.
                     e (int): The RSA instance's public exponent.
            key_length (int): The bit length of the RSA instance (2048, 4096, etc).

        Returns:
            Bytes: The ciphertext's corresponding plaintext, left-padded with zero bytes to
                   `key_length // 8` bytes.
        """
        key_byte_len = key_length // 8

        # Convenience variables
        B = 2**(8 * (key_byte_len - 2))

        def _mod_inverse(value, modulus):
            # Extended Euclidean algorithm; `value` must be coprime to `modulus`
            old_r, r = value % modulus, modulus
            old_x, x = 1, 0
            while r:
                q = old_r // r
                old_r, r = r, old_r - q * r
                old_x, x = x, old_x - q * x
            return old_x % modulus

        # Initial values
        c = ciphertext
        c_0 = ciphertext
        M = [(2 * B, 3 * B - 1)]
        i = 1

        # Blinding factor; it stays 1 when the ciphertext is already PKCS conforming.
        # It is kept separate from `s` below because Step 2 overwrites `s`.
        s_0 = 1

        if not self.oracle.check_padding(c):
            log.debug("Initial padding not correct; attempting blinding")

            # Step 1: Blinding
            while True:
                s_0 = randint(1, n - 1)

                # s_0 must be invertible mod n so the blinding can be undone at the end
                if math.gcd(s_0, n) != 1:
                    continue

                c_0 = (c * pow(s_0, e, n)) % n

                if self.oracle.check_padding(c_0):
                    log.debug("Padding is now correct; blinding complete")
                    break

        s_0_inv = _mod_inverse(s_0, n)

        # Setup reporting
        last_log_diff = math.log(M[0][1] - M[0][0], 2)
        progress = RUNTIME.report_progress(None, total=last_log_diff)

        # Step 2
        while True:
            log.debug("Starting iteration {}".format(i))
            log.debug("Current intervals: {}".format(M))

            diff = math.log(
                sum([interval[1] - interval[0] for interval in M]) + 1, 2)
            progress.update(last_log_diff - diff)
            last_log_diff = diff

            # Step 2.a
            if i == 1:
                s = _ceil(n, 3 * B)

                log.debug("Starting search at {}".format(s))

                while True:
                    c = c_0 * pow(s, e, n) % n
                    if self.oracle.check_padding(c):
                        break

                    s += 1
            # Step 2.b
            elif len(M) >= 2:
                log.debug("Intervals left: {}".format(M))
                while True:
                    s += 1
                    c = c_0 * pow(s, e, n) % n

                    if self.oracle.check_padding(c):
                        break

            # Step 2.c
            elif len(M) == 1:
                log.debug("Only one interval")

                a, b = M[0]

                if a == b:
                    # Step 4: undo the blinding, m = a * s_0^-1 mod n
                    m = a * s_0_inv % n
                    plaintext = int_to_bytes(m, 'big')
                    return Bytes(b'\x00' * (key_byte_len - len(plaintext)) + plaintext)

                r = _ceil(2 * (b * s - 2 * B), n)
                s = _ceil(2 * B + r * n, b)

                while True:
                    c = c_0 * pow(s, e, n) % n
                    if self.oracle.check_padding(c):
                        break

                    s += 1
                    if s > (3 * B + r * n) // a:
                        r += 1
                        s = _ceil(2 * B + r * n, b)

            # Step 3: narrow the set of intervals containing the (blinded) message
            M_new = []

            for a, b in M:
                min_r = _ceil(a * s - 3 * B + 1, n)
                max_r = (b * s - 2 * B) // n

                for r in range(min_r, max_r + 1):
                    new_a = max(a, _ceil(2 * B + r * n, s))
                    new_b = min(b, (3 * B - 1 + r * n) // s)

                    if new_a > new_b:
                        raise Exception(
                            "Step 3: new_a > new_b? new_a: {} new_b: {}".
                            format(new_a, new_b))

                    # Now we need to check for overlap between ranges and merge them
                    _append_and_merge(new_a, new_b, M_new)

            if len(M_new) == 0:
                raise Exception("There are zero intervals in 'M_new'")

            M = M_new
            i += 1
Beispiel #24
0
    def execute(self,
                secret_length: int,
                sample_size: int = 2**23,
                chunk_size: int = 2**19) -> list:
        """
        Executes the attack.

        Parameters:
            secret_length (int): The length of the secret you're trying to recover.
            sample_size   (int): The amount of samples to collect per byte of the secret. Higher numbers are slower but more accurate.
            chunk_size    (int): The size of sample chunks per CPU before a forceful garbage collection saves the day.

        Returns:
            list: Candidate plaintexts (`Bytes`), one for every combination of the byte values
                  recovered at each position of the secret.
        """
        cracked_indices = [set() for _ in range(secret_length)]
        cpu_count = multiprocessing.cpu_count()

        # The context manager terminates the worker processes once sampling is done (or on error);
        # previously the pool was never closed and its workers leaked on every call.
        with multiprocessing.Pool(processes=cpu_count) as pool:
            log.info(f"Running with {cpu_count} cores")

            for i in RUNTIME.report_progress(range(secret_length), unit='bytes'):
                log.debug(f"Starting iteration {i + 1}/{secret_length}")

                # Already recovered as a side effect of an earlier iteration
                if cracked_indices[i]:
                    continue

                # Only biases at or after position i can be lined up with secret byte i by
                # prepending padding (the former second disjunct was implied by this one).
                applicable_biases = [bias for bias in self.strongest_biases if i <= bias]
                padding_len = max(applicable_biases[0] - i, 0)
                active_biases = [
                    bias for bias in applicable_biases
                    if padding_len + secret_length > bias
                ]

                payload = b'\x00' * padding_len
                num_chunks = math.ceil(sample_size / chunk_size)

                log.debug(f"Sampling {sample_size} ciphertexts")
                flattened_list = []
                for round_idx in range(math.ceil(num_chunks / cpu_count)):
                    jobs_this_round = min(num_chunks - (round_idx * cpu_count), cpu_count)
                    random_ciphertexts = [
                        pool.apply_async(self._encrypt_chunk, (payload, chunk_size))
                        for _ in range(jobs_this_round)
                    ]
                    flattened_list.extend(
                        result
                        for result_list in random_ciphertexts
                        for result in result_list.get()
                    )
                    gc.collect()

                log.debug("Generating bias map")
                bias_map = generate_rc4_bias_map(flattened_list)

                for bias_idx in active_biases:
                    cracked_indices[bias_idx - padding_len].add(
                        RC4_BIAS_MAP[bias_idx] ^ bias_map[bias_idx][0][0])

        all_branches = itertools.product(
            *[list(results) for results in cracked_indices])
        return [
            Bytes(struct.pack('B' * len(branch), *branch))
            for branch in all_branches
        ]
    def execute(self,
                public_key: WeierstrassPoint,
                invalid_curves: List[WeierstrassCurve] = None,
                max_factor_size: int = 2**16) -> int:
        """
        Executes the attack.

        Parameters:
            public_key (WeierstrassPoint): ECDH public key to crack.
            invalid_curves         (list): List of invalid curves to use in the attack. Once exhausted,
                                           random curves sharing `a` with the target curve are generated.
            max_factor_size         (int): Max factor size to prevent attempting to factor forever.

        Returns:
            int: Private key.

        Raises:
            RuntimeError: If no sign assignment of the recovered residues reproduces `public_key`.
        """
        residues = []
        factors_seen = set()
        total = 1

        cardinality = self.curve.cardinality()

        # Generate invalid curves if the user doesn't specify them or have enough factors
        def curve_gen():
            orig = self.curve
            while True:
                b = orig.b

                while b == orig.b:
                    b = orig.ring.random()

                curve = WeierstrassCurve(a=orig.a, b=b, ring=orig.ring)
                curve.cardinality()

                curve.G_cache = orig.G_cache
                yield curve

        for inv_curve in itertools.chain(invalid_curves or [], curve_gen()):
            # Factor as much as we can
            factors = [
                r for r, _ in factorint(inv_curve.cardinality(),
                                        use_rho=False,
                                        limit=max_factor_size).items()
                if r > 2 and r < max_factor_size
            ]
            log.debug(f'Found factors: {factors}')

            # Request residues from crafted public keys; only factors not yet seen are iterated
            for factor in RUNTIME.report_progress(
                    set(factors) - factors_seen,
                    desc='Sending malicious public keys',
                    unit='factor'):
                total *= factor

                # Generate a low-order point on the invalid curve
                bad_pub = inv_curve.POINT_AT_INFINITY

                while bad_pub == inv_curve.POINT_AT_INFINITY or bad_pub == inv_curve.G:
                    point = inv_curve.random()
                    bad_pub = point * (inv_curve.cardinality() // factor)

                residue = self.oracle.request(bad_pub, factor)
                residues.append((ZZ / ZZ(factor))(residue))
                factors_seen.add(factor)

                # Stop as soon as the moduli cover the cardinality, without generating another curve
                if total > cardinality:
                    break

            if total > cardinality:
                break

        # We have to take into account the fact we can end up on the "negative" side of the field
        negations = [(residue, -residue) for residue in residues]
        G_cache = self.curve.G.cache_mul(cardinality.bit_length())

        # Just bruteforce the correct configuration based off of the public key
        for residue_subset in RUNTIME.report_progress(
                itertools.product(*negations),
                desc='Bruteforcing residue configuration',
                unit='residue set',
                total=2**len(residues)):
            n, _ = crt(residue_subset)
            if G_cache * (int(n) % cardinality) == public_key:
                return n.val

        # Previously the last (non-matching) candidate was returned silently
        raise RuntimeError("No residue configuration reproduces the public key")