def _blind_value(c, s, n, e, with_encryption):
    """Return ``c`` multiplied by the blinding factor ``s`` modulo ``n``.

    With ``with_encryption`` the factor is applied as ``pow(s, e, n)``;
    otherwise ``c`` is multiplied by ``s`` directly.
    """
    factor = pow(s, e, n) if with_encryption else s
    return (c * factor) % n


def _align_up(value, lcm):
    """Round ``value`` up to a multiple of ``lcm`` (assumes ``divide_ceil`` is ceiling division)."""
    return divide_ceil(value, lcm) * lcm


def _align_down(value, lcm):
    """Round ``value`` down to a multiple of ``lcm``."""
    return (value // lcm) * lcm


def _resolve_candidate_bounds(config, oracle, rsa, n, B, e, new_a, new_b,
                              candidate_bounds, si, parallel_msg, serial_msg):
    """Resolve several candidate intervals into the one holding the message.

    Dispatches to ``determine_proper_bounds_parallel_threads`` or
    ``determine_proper_bounds`` according to
    ``config["use_parallel_threads_method"]`` and prints the query count
    using ``parallel_msg`` or ``serial_msg`` respectively.

    Returns ``(new_a, new_b, r, si, queries)`` as produced by the helper.
    """
    if config["use_parallel_threads_method"]:
        resolve, msg = determine_proper_bounds_parallel_threads, parallel_msg
    else:
        resolve, msg = determine_proper_bounds, serial_msg
    new_a, new_b, r, si, qs = resolve(
        oracle, rsa, n, B, e, new_a, new_b, candidate_bounds, si)
    print(msg % qs)
    return new_a, new_b, r, si, qs


def _narrow_with_negative_answers(config, oracle, c, n, e, B, lcm,
                                  new_a, new_b, r, oracle_queries):
    """Shrink ``[new_a, new_b]`` using conformant AND non-conformant answers.

    Only valid for the TTT oracle (``noterm`` and ``shortpad`` both set):
    there a non-conformant answer tells us the blinded message lies outside
    the conformant window, which can move one end of the interval. This is
    used for s2, s3, ... (it cannot help find s1).

    Loops until the interval collapses and returns the updated query count.
    Exits the process if ``c`` falls outside the bounds after a conformant
    answer.
    """
    while True:
        r = 2 * r
        # TODO: try every si up to (r * n + oracle.mmax) // new_a, not only
        # the smallest one.
        si = 1 + (r * n + oracle.mmin) // new_b
        oracle_queries += 1
        if oracle.call(_blind_value(c, si, n, e, config["with_encryption"])):
            na, nb = get_new_bounds(si, r, n, oracle.mmin, oracle.mmax)
            if na > new_a:
                new_a = _align_up(na, lcm)
            if nb < new_b:
                new_b = _align_down(nb, lcm)
            if c < new_a or c > new_b:
                print("implementation error 3")
                sys.exit(1)
        else:
            # Not conformant: the message is outside
            # [(r*n + mmin) / si, (r*n + mmax) / si]. If that window overlaps
            # an end of [new_a, new_b], the message lies beyond the window and
            # that end can be moved past it.
            moved = False
            if si * new_a > r * n + oracle.mmin:
                # TODO(review): confirm rounding of this bound.
                new_a = _align_up((r * n + oracle.mmax) // si, lcm)
                moved = True
            if si * new_b < r * n + oracle.mmax:
                # TODO(review): confirm rounding of this bound.
                new_b = _align_down((r * n + oracle.mmin) // si, lcm)
                moved = True
            if not moved:
                # Debug sanity check against the known value c.
                if oracle.mmin < c * si % n < oracle.mmax:
                    print("this is conformant")
                qs, new_a, new_b, r = get_conformant_from_negative(
                    oracle, new_a, new_b, si, r, lcm, c, B)
                oracle_queries += qs
        if new_b == new_a:
            print("oracle queries: %s" % (oracle_queries))
            return oracle_queries


def _narrow_with_conformant_answers(config, oracle, rsa, c, n, e, B, lcm,
                                    new_a, new_b, r, oracle_queries,
                                    optimized_si):
    """Shrink ``[new_a, new_b]`` using only conformant multipliers s2, s3, ...

    Each iteration finds a conformant ``si`` (via ``get_conformant`` with a
    doubled ``r``, or ``get_conformant_optimized_si`` with ``r`` picked from
    the current interval width), intersects the bounds and aligns them to
    ``lcm``. Returns the updated query count once the interval collapses.
    Exits the process if ``c`` falls outside the bounds.
    """
    # Passed to get_conformant_optimized_si on every call; presumably it
    # records r values that did not work — TODO confirm.
    unsuccessful_rs = []
    while True:
        if optimized_si:
            tmp_s = 2 * B // (new_b - new_a)
            r = tmp_s * new_a // n
            si, r, queries = get_conformant_optimized_si(
                oracle, r, c, new_a, new_b, rsa, n, B, e, unsuccessful_rs)
        else:
            r = 2 * r
            si, r, queries = get_conformant(
                oracle, r, c, new_a, new_b, rsa, n, B, e)
        oracle_queries += queries
        # NOTE: r must divide (s - 1) (unverified; not enforced here).

        candidate_bounds = get_candidate_bounds(
            si, new_a, new_b, oracle.mmin, oracle.mmax, n)
        if len(candidate_bounds) > 1:
            print("this does not happen in TTT, but in other oracles it can")
            new_a, new_b, r, si, qs = _resolve_candidate_bounds(
                config, oracle, rsa, n, B, e, new_a, new_b, candidate_bounds,
                si, "calls to determine bounds (parallel): %s",
                "calls to determine bounds: %s")
            oracle_queries += qs
        else:
            new_a, new_b = candidate_bounds[0]
            r = new_a * si // n
        new_a = _align_up(new_a, lcm)
        new_b = _align_down(new_b, lcm)

        if c < new_a or c > new_b:
            print("implementation error 2")
            sys.exit(1)
        if new_b == new_a:
            print("oracle queries: %s" % (oracle_queries))
            print("--------------------------------------------------")
            print(new_a)
            return oracle_queries


def find_plaintext(config, c, rsa, optimized_si=False, use_negative=False,
                   use_mod_trimmer=False):
    """Narrow the interval containing ``c`` with a padding oracle until it collapses.

    Bleichenbacher-style search: optionally trim the initial interval, find
    the first conformant multiplier s1, then use further multipliers
    s2, s3, ... to shrink ``[new_a, new_b]`` until ``new_a == new_b``.

    Args:
        config: dict of flags; this function reads ``"noterm"``,
            ``"shortpad"``, ``"use_trimmer"``, ``"with_encryption"`` and
            ``"use_parallel_threads_method"``.
        c: integer under attack. The consistency checks compare it directly
            with the plaintext bounds, so it appears to be the target message.
            NOTE(review): with ``"with_encryption"`` set, ``c`` is multiplied
            by ``pow(s, e, n)``, which suggests a ciphertext — confirm that
            the bound checks still hold in that mode.
        rsa: key object; ``rsa.n`` and ``rsa.e`` are read.
        optimized_si: choose the next r from the interval width and search
            with ``get_conformant_optimized_si`` instead of doubling r.
        use_negative: also exploit non-conformant answers. Requires both
            ``config["noterm"]`` and ``config["shortpad"]`` (TTT oracle).
        use_mod_trimmer: use ``ModTrimmer`` instead of ``Trimmer`` when
            ``config["use_trimmer"]`` is set.

    Returns:
        ``(oracle_queries, s1_oracle_queries)``: total oracle calls, and the
        calls spent (trimming included) up to finding s1.

    Exits the process on an invalid option combination, or when ``c`` falls
    outside the computed bounds during the s2+ phase.
    """
    # use_negative is only sound for the TTT oracle: in weaker oracles a
    # "not conformant" answer may just mean the padding is invalid, so it
    # does not tell where the blinded message lies.
    if use_negative and not (config["noterm"] and config["shortpad"]):
        print("use_negative cannot be used for this kind of an oracle")
        sys.exit(1)

    # Modulus and exponent come from the key so they are consistent with B.
    n = rsa.n
    e = rsa.e
    mod_bits = Crypto.Util.number.size(n)
    k = ceil_div(mod_bits, 8)  # length of n in bytes
    B = pow(2, 8 * (k - 2))
    oracle = Oracle(rsa, B, config["noterm"], config["shortpad"])
    oracle_queries = 0

    if config["use_trimmer"]:
        # Shrink [mmin, mmax] with the trimmer's best fractions; bounds are
        # then aligned to multiples of the returned lcm.
        trimmer_cls = ModTrimmer if use_mod_trimmer else Trimmer
        trimmer = trimmer_cls(config, oracle, c)
        best_fractions, lcm, trimmer_calls = trimmer.get_best_fractions(
            2 * n // (9 * B), oracle_queries)
        oracle_queries += trimmer_calls
        lower_fraction, upper_fraction = best_fractions
        new_a = _align_up(
            oracle.mmin * lower_fraction[1] // lower_fraction[0], lcm)
        new_b = _align_down(
            divide_ceil(oracle.mmax * upper_fraction[1], upper_fraction[0]),
            lcm)
    else:
        # lcm is used for rounding in every later phase; 1 makes that
        # rounding a no-op (previously this left lcm unbound).
        lcm = 1
        new_a = oracle.mmin
        new_b = oracle.mmax

    # Find s1: the first multiplier from (n + new_a) // new_b whose blinded
    # value is conformant. Multipliers with no candidate r are skipped
    # without querying the oracle.
    si = (n + new_a) // new_b
    while True:
        if get_rs(si, new_a, new_b, n, oracle.mmin, oracle.mmax):
            oracle_queries += 1
            if oracle.call(
                    _blind_value(c, si, n, e, config["with_encryption"])):
                break
        si += 1
    print("Found s1: %s; oracle queries: %s" % (si, oracle_queries))
    s1_oracle_queries = oracle_queries

    candidate_bounds = get_candidate_bounds(
        si, new_a, new_b, oracle.mmin, oracle.mmax, n)
    if len(candidate_bounds) == 1:
        new_a, new_b = candidate_bounds[0]
        r = new_a * si // n
        print("---")
        print(r)
        print(get_intersect_rs(oracle, n, B, c, 0, 10))
    else:
        # Several intervals are possible: spend queries (e.g. on another
        # conformant si) to find the one holding the message.
        new_a, new_b, r, si, qs = _resolve_candidate_bounds(
            config, oracle, rsa, n, B, e, new_a, new_b, candidate_bounds, si,
            "parallel threads determine proper bounds queries: %s",
            "determine proper bounds queries: %s")
        oracle_queries += qs
    if c < new_a or c > new_b:
        print("implementation error 1")

    if use_negative:
        oracle_queries = _narrow_with_negative_answers(
            config, oracle, c, n, e, B, lcm, new_a, new_b, r, oracle_queries)
    else:
        oracle_queries = _narrow_with_conformant_answers(
            config, oracle, rsa, c, n, e, B, lcm, new_a, new_b, r,
            oracle_queries, optimized_si)
    return oracle_queries, s1_oracle_queries