def pysmt2bv(ps):
    """Convert a pySMT type to a bit-vector type.

    Only `pysmt.shortcuts.Symbol`, `pysmt.shortcuts.BV` and the boolean
    constants are handled; any other pySMT node raises `NotImplementedError`.

        >>> from pysmt import shortcuts, typing  # pySMT shortcuts and typing modules
        >>> from arxpy.smt.types import pysmt2bv
        >>> env = shortcuts.reset_env()
        >>> pysmt2bv(env.formula_manager.Symbol("x", env.type_manager.BVType(8))).vrepr()
        "Variable('x', width=8)"
        >>> pysmt2bv(env.formula_manager.BV(1, 8)).vrepr()
        'Constant(0b00000001, width=8)'
    """
    if ps.is_symbol():
        # Boolean symbols are represented as 1-bit variables.
        width = 1 if str(ps.get_type()) == "Bool" else ps.bv_width()
        return core.Variable(ps.symbol_name(), width)
    if ps.is_bv_constant():
        return core.Constant(int(ps.constant_value()), ps.bv_width())
    if ps.is_false():
        return core.Constant(0, 1)
    if ps.is_true():
        return core.Constant(1, 1)
    raise NotImplementedError(
        "unknown conversion of {} ({} {}) to a bit-vector type".format(
            ps, ps.get_type(), type(ps).__name__))
def lz(x):
    # Leading-zeros of x: use the operation directly when the differences
    # are constant; otherwise introduce a fresh auxiliary variable that is
    # asserted equal to LeadingZeros(x).
    if are_cte_differences:
        return extraop.LeadingZeros(x)
    aux_var = core.Variable("{}_{}lz".format(prefix, self._i_auxvar), x.width)
    self._i_auxvar += 1
    assertions.append(operation.BvComp(aux_var, extraop.LeadingZeros(x)))
    return aux_var
def rev(x):
    # Bit-reversal of x: use the operation directly when the differences
    # are constant; otherwise introduce a fresh auxiliary variable that is
    # asserted equal to Reverse(x).
    if are_cte_differences:
        return extraop.Reverse(x)
    aux_var = core.Variable("{}_{}rev".format(prefix, self._i_auxvar), x.width)
    self._i_auxvar += 1
    assertions.append(operation.BvComp(aux_var, extraop.Reverse(x)))
    return aux_var
def _symbolic_input(cls, symbol_prefix="i"):
    """Return a tuple of fresh variables matching ``cls.input_widths``.

    The i-th variable is named ``<symbol_prefix><i>`` and has the i-th
    input width.
    """
    return tuple(
        core.Variable("{}{}".format(symbol_prefix, index), width)
        for index, width in enumerate(cls.input_widths)
    )
def add_op(self, expr):
    """Add a bit-vector expression and return its identifier.

    The expression is stored under a fresh `Variable` named
    ``<id_prefix><counter>`` with the expression's width.
    """
    from arxpy.bitvector import operation
    assert isinstance(expr, operation.Operation)
    assert not self.contain_op(expr)
    identifier = core.Variable(
        "{}{}".format(self.id_prefix, self.counter), expr.width)
    self.counter += 1
    self.table[identifier] = expr
    return identifier
def add_op(self, op):
    """Register *op* in the state table and return its fresh identifier.

    The identifier is a `Variable` named ``<id_prefix><counter>`` with the
    operation's width.
    """
    from arxpy.bitvector import operation
    assert isinstance(op, operation.Operation)
    assert not self.contain_op(op)
    identifier = core.Variable(
        "{}{}".format(self.id_prefix, self.counter), op.width)
    self.counter += 1
    self.table[identifier] = op
    return identifier
def __init__(self, bv_cipher, diff_type):
    """Build the characteristic of the cipher encryption with symbolic round keys."""
    assert issubclass(bv_cipher, primitives.Cipher)
    assert issubclass(diff_type, difference.Difference)

    # One symbolic round key "k<i>" per key-schedule output.
    round_key_vars = tuple(
        core.Variable("k" + str(i), width)
        for i, width in enumerate(bv_cipher.key_schedule.output_widths)
    )

    # Subclass the encryption so that it uses the symbolic round keys.
    class Encryption(bv_cipher.encryption):
        round_keys = round_key_vars

    # Input differences are named "dp0", "dp1", ...; intermediate ones use "dx".
    input_diff_names = ["dp" + str(i) for i in range(len(Encryption.input_widths))]
    super().__init__(Encryption, diff_type, input_diff_names, "dx")
    self._cipher = bv_cipher
def __init__(self, func, diff_type, input_diff_names, prefix="d", initial_var2diff=None):
    """Propagate symbolic differences of type *diff_type* through *func*.

    Args:
        func: a subclass of `primitives.BvFunction`
        diff_type: a subclass of `difference.Difference`
        input_diff_names: one name per input of *func*
        prefix: prefix for the intermediate difference variables
        initial_var2diff: optional mapping from SSA variables to differences
            (it must not override the input differences)

    Raises:
        ValueError: if an output of *func* is a constant or if
            *initial_var2diff* tries to replace an input difference.
        NotImplementedError: for keyed expressions that cannot be handled.
    """
    assert issubclass(func, primitives.BvFunction)
    assert issubclass(diff_type, difference.Difference)
    assert len(input_diff_names) == len(func.input_widths)
    input_diff = []
    for name, width in zip(input_diff_names, func.input_widths):
        input_diff.append(diff_type(core.Variable(name, width)))
    input_diff = tuple(input_diff)
    self.func = func
    self.diff_type = diff_type
    self.input_diff = input_diff

    # Propagate the input difference through the function
    names = [d.val.name for d in self.input_diff]
    ssa = self.func.ssa(names, id_prefix=prefix)
    self.ssa = ssa
    self._prefix = prefix
    self._input_diff_names = input_diff_names
    for var in ssa["output_vars"]:
        if isinstance(var, core.Constant):
            raise ValueError("constant outputs (independent of the inputs) are not supported")
    var2diff = {}  # Variable to Difference
    for var, diff in zip(ssa["input_vars"], self.input_diff):
        var2diff[var] = diff
    if initial_var2diff is not None:
        for var in initial_var2diff:
            if str(var) in names:
                raise ValueError("the input differences cannot be replaced by initial_var2diff")
        var2diff.update(initial_var2diff)
    self.nonlinear_diffs = collections.OrderedDict()
    for var, expr in ssa["assignments"]:
        expr_args = []
        for arg in expr.args:
            if isinstance(arg, int):
                expr_args.append(arg)  # 'int' object has no attribute 'xreplace'
            else:
                expr_args.append(arg.xreplace(var2diff))
        if all(not isinstance(arg, diff_type) for arg in expr_args):
            # symbolic computations with the key
            var2diff[var] = expr
            continue
        if all(isinstance(arg, diff_type) for arg in expr_args):
            der = self.diff_type.derivative(type(expr), expr_args)
        else:
            def contains_key_var(term):
                # True iff any sub-term of `term` is a round-key variable.
                from sympy.core import basic
                for sub in basic.preorder_traversal(term):
                    if sub in func.round_keys:
                        return True
                # BUGFIX: the original returned False after inspecting only the
                # first node of the traversal, so nested key variables were
                # silently missed; return False only after the full traversal.
                return False

            if type(expr) == operation.BvAdd and hasattr(func, 'round_keys') and \
                    all(isinstance(r, core.Variable) for r in func.round_keys) and \
                    any(contains_key_var(a) for a in expr_args):
                # temporary solution to Derivative(BvAddCte_k(x)) != Derivative(x + k)
                # with x a Diff and k a key variable
                keyed_indices = []
                for i, a in enumerate(expr_args):
                    if contains_key_var(a):
                        keyed_indices.append(i)
                if len(keyed_indices) != 1 or expr_args[keyed_indices[0]] not in func.round_keys:
                    raise NotImplementedError("invalid expression: op={}, args={}".format(
                        type(expr).__name__, expr_args))
                # expr_args[keyed_indices[0]] replaced to the zero diff
                zero_diff = diff_type(core.Constant(0, expr_args[keyed_indices[0]].width))
                der = self.diff_type.derivative(
                    type(expr), [expr_args[(keyed_indices[0] + 1) % 2], zero_diff])
            elif hasattr(expr, "xor_derivative"):
                # temporary solution to operations containing a custom derivative
                input_diff_expr = []
                for i, arg in enumerate(expr_args):
                    if isinstance(arg, diff_type):
                        input_diff_expr.append(arg)
                    else:
                        assert isinstance(arg, core.Term)  # int arguments currently not supported
                        input_diff_expr.append(diff_type.from_pair(arg, arg))
                der = self.diff_type.derivative(type(expr), input_diff_expr)
            else:
                # Fix the non-difference arguments and take the derivative of
                # the resulting partial operation w.r.t. the difference args.
                fixed_args = []
                for i, arg in enumerate(expr_args):
                    if not isinstance(arg, diff_type):
                        fixed_args.append(arg)
                    else:
                        fixed_args.append(None)
                new_op = extraop.make_partial_operation(type(expr), tuple(fixed_args))
                der = self.diff_type.derivative(
                    new_op, [arg for arg in expr_args if isinstance(arg, diff_type)])
        if isinstance(der, derivative.Derivative):
            # Non-deterministic propagation: record the derivative.
            diff = self.diff_type(var)
            var2diff[var] = diff
            self.nonlinear_diffs[diff] = der
        else:
            var2diff[var] = der
    self._var2diff = var2diff
    self.output_diff = []
    for var in ssa["output_vars"]:
        self.output_diff.append([self.diff_type(var), var2diff[var]])
def ssa(cls, input_names, id_prefix):
    """Return a static single assignment program representing the function.

    Args:
        input_names: the names for the input variables
        id_prefix: the prefix to denote the intermediate variables

    Return:
        : a dictionary with three keys

        - *input_vars*: a list of `Variable` representing the inputs
        - *output_vars*: a list of `Variable` representing the outputs
        - *assignments*: an ordered sequence of pairs (`Variable`, `Operation`)
          representing each assignment of the SSA program.

    ::

        >>> from arxpy.primitives.chaskey import ChaskeyPi
        >>> ChaskeyPi.set_rounds(1)
        >>> ChaskeyPi.ssa(["v0", "v1", "v2", "v3"], "x")  # doctest: +NORMALIZE_WHITESPACE
        {'input_vars': (v0, v1, v2, v3), 'output_vars': (x7, x12, x13, x9), 'assignments': ((x0, v0 + v1), (x1, v1 <<< 5), (x2, x0 ^ x1), (x3, x0 <<< 16), (x4, v2 + v3), (x5, v3 <<< 8), (x6, x4 ^ x5), (x7, x3 + x6), (x8, x6 <<< 13), (x9, x7 ^ x8), (x10, x2 + x4), (x11, x2 <<< 7), (x12, x10 ^ x11), (x13, x10 <<< 16))}
    """
    # Build one symbolic variable per input.
    input_vars = []
    for name, width in zip(input_names, cls.input_widths):
        input_vars.append(core.Variable(name, width))
    input_vars = tuple(input_vars)
    # Evaluate the function symbolically; the memoization context records
    # every intermediate assignment in `table` with names <id_prefix><i>.
    table = context.MemoizationTable(id_prefix=id_prefix)
    with context.Memoization(table):
        # noinspection PyArgumentList
        output_vars = cls(*input_vars, symbolic_inputs=True)
    ssa_dict = {
        "input_vars": input_vars,
        "output_vars": output_vars,
        "assignments": tuple(table.items())
    }
    # Sanity check: every assignment must be fully decomposed, i.e. its
    # arguments are variables/constants, never nested operations.
    for var, expr in ssa_dict["assignments"]:
        for arg in expr.args:
            if isinstance(arg, operation.Operation):
                raise ValueError(
                    "assignment {} <- {} was not decomposed".format(
                        var, expr))
    # Dead-code elimination: walking the assignments backwards, keep only
    # those whose left-hand side is needed by the outputs (directly or
    # transitively); the rest are collected in `to_delete`.
    to_delete = []
    vars_needed = set()
    for var in output_vars:
        vars_needed.add(var)
    for var, expr in reversed(ssa_dict["assignments"]):
        if var in vars_needed:
            for arg in expr.atoms(core.Variable):
                vars_needed.add(arg)
        else:
            to_delete.append((var, expr))
            # raise ValueError("assignment {} <- {} is redundant in \n{}".format(var, expr, ssa_dict))
    if len(to_delete) > 0:
        # Redundant assignments are dropped with a warning instead of an error.
        import warnings
        warnings.warn("removing redundant assignments {} in \n{}".format(
            to_delete, ssa_dict))
        ssa_dict["assignments"] = list(ssa_dict["assignments"])
        for assignment in to_delete:
            ssa_dict["assignments"].remove(assignment)
        ssa_dict["assignments"] = tuple(ssa_dict["assignments"])
    return ssa_dict
def _fast_empirical_weight_distribution(ch_found, cipher, rk_dict_diffs=None,
                                        verbose_lvl=0, debug=False, filename=None, precision=0):
    """Estimate the empirical weight distribution of a characteristic
    over ``KEY_SAMPLES`` random keys, by splitting the characteristic into
    sub-characteristics that are verified with compiled C code.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_differential import SearchSkCh
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_differential import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkCh(ch)
        >>> ch_found = search_problem.solve(0)
        >>> _fast_empirical_weight_distribution(ch_found, Speck32)
        Counter({0: 256})
    """
    # similar to _empirical_distribution_weight of characteristic module
    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports
    smart_print = _get_smart_print(filename)
    exact_weight = ch_found.get_exact_weight()
    if rk_dict_diffs is not None:
        assert "nonlinear_diffs" in rk_dict_diffs and "output_diff" in rk_dict_diffs
    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("Characteristic found:")
        smart_print(ch_found)
        if rk_dict_diffs is not None:
            smart_print("rk_dict_diffs:", rk_dict_diffs)
        smart_print()
    # Exact weight of each non-deterministic step, with the found
    # difference values substituted in.
    der_weights = []
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        actual_diff = ch_found.nonlinear_diffs[i][1]
        new_input_diff = [(d.xreplace(ch_found._diff_model)) for d in der.input_diff]
        der_weights.append(
            der._replace_input_diff(new_input_diff).exact_weight(actual_diff))
    # Cap the weight of each sub-characteristic around MAX_WEIGHT so that
    # the empirical verification stays feasible.
    max_subch_weight = exact_weight if exact_weight < MAX_WEIGHT else exact_weight / (
        exact_weight / MAX_WEIGHT)
    max_subch_weight = max(1, max_subch_weight, *der_weights)
    if debug:
        smart_print("max_subch_weight:", max_subch_weight)
        smart_print()
    # Greedily pack consecutive [diff, der] pairs into sub-characteristics
    # whose accumulated exact weight does not exceed max_subch_weight.
    subch_listdiffder = [[]]  # for each subch, a list of [diff, der] pairs
    subch_index = 0
    current_subch_weight = 0  # exact_weight
    subch_weight = []  # the weight of each subch
    assert len(ch_found.ch.nonlinear_diffs.items()) > 0
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        der_weight = der_weights[i]
        if current_subch_weight + der_weight > max_subch_weight:
            subch_weight.append(current_subch_weight)
            current_subch_weight = 0
            subch_index += 1
            subch_listdiffder.append([])
        current_subch_weight += der_weight
        subch_listdiffder[subch_index].append([diff, der])
    subch_weight.append(current_subch_weight)
    assert len(subch_weight) == len(subch_listdiffder)
    num_subch = len(subch_listdiffder)
    if verbose_lvl >= 3:
        smart_print(
            "- characteristic decomposed into {} subcharacteristics with exact weights {}"
            .format(num_subch, subch_weight))
    # Symbolic round-key variables ("k<i>" unless given by rk_dict_diffs).
    if rk_dict_diffs is not None:
        rk_var = [var.val for var, _ in rk_dict_diffs["output_diff"]]
    else:
        rk_var = []
        for i, width in enumerate(cipher.key_schedule.output_widths):
            rk_var.append(core.Variable("k" + str(i), width))
    # Map every difference variable to its found (constant) value.
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff,
                                                ch_found.nonlinear_diffs,
                                                ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val
    if rk_dict_diffs is not None:
        for var, diff in rk_dict_diffs["output_diff"]:
            var2diffval[var.val] = diff.val
        for var, diff in rk_dict_diffs["nonlinear_diffs"]:
            var2diffval[var.val] = diff.val
    for var, diff in ch_found.ch._var2diff.items():
        if var not in var2diffval:
            if isinstance(diff, core.Term):
                # e.g., symbolic computations with the key
                var2diffval[var] = diff.xreplace(var2diffval)
            else:
                var2diffval[var] = diff.val.xreplace(var2diffval)
    # for each related-key pair, we associated a pair of subch_ssa
    rkey2pair_subchssa = [None for _ in range(KEY_SAMPLES)]
    for key_index in range(KEY_SAMPLES):
        # Sample a random master key and derive the two round-key tuples
        # of the related-key pair.
        master_key = []
        for width in cipher.key_schedule.input_widths:
            master_key.append(core.Constant(random.randrange(2**width), width))
        rk_val = cipher.key_schedule(*master_key)
        if rk_dict_diffs is not None:
            rk_other_val = tuple([
                d.get_pair_element(r) for r, (_, d) in zip(rk_val, rk_dict_diffs["output_diff"])
            ])
        else:
            rk_other_val = rk_val
        assert len(rk_var) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        def subch_listdiffder2subch_ssa(listdiffder, first_der_var_next_subch, var2val, first_subch=False):
            # Slice the full SSA between the first derivative variable of this
            # sub-characteristic and the first one of the next, substituting
            # the round-key values of var2val into each expression.
            first_der_var = listdiffder[0][0].val
            input_vars = []
            inter_vars = set()
            assignments = []
            add_assignment = first_subch
            for var, expr in ch_found.ch.ssa["assignments"]:
                if var == first_der_var:
                    add_assignment = True
                elif var == first_der_var_next_subch:
                    break
                expr = expr.xreplace(var2val)
                if add_assignment:
                    input_vars.extend([
                        atom for atom in expr.atoms(core.Variable)
                        if atom not in input_vars
                    ])
                    inter_vars.add(var)
                    assignments.append([var, expr])
            subch_ssa = {}
            # inputs = variables read but never assigned within the slice
            subch_ssa["input_vars"] = [
                var for var in input_vars if var not in inter_vars
            ]
            subch_ssa["output_vars"] = []
            subch_ssa["inter_vars"] = inter_vars
            subch_ssa["assignments"] = assignments
            return subch_ssa

        # Build the sub-characteristic SSA slices for both members of the pair.
        pair_subchssa = []
        for index_pair in range(2):
            current_rk_val = rk_val if index_pair == 0 else rk_other_val
            rkvar2rkval = {
                var: val for var, val in zip(rk_var, current_rk_val)
            }
            subch_ssa = [None for _ in range(num_subch)]
            # Walk backwards so each slice's outputs can be taken from the
            # inputs of the next slice.
            for i in reversed(range(num_subch)):
                if i == num_subch - 1:
                    subch_ssa[i] = subch_listdiffder2subch_ssa(
                        subch_listdiffder[i], None, rkvar2rkval, i == 0)
                    subch_ssa[i]["output_vars"] = list(
                        ch_found.ch.ssa["output_vars"])
                else:
                    first_var_next_ssa = subch_listdiffder[i + 1][0][0].val
                    subch_ssa[i] = subch_listdiffder2subch_ssa(
                        subch_listdiffder[i], first_var_next_ssa, rkvar2rkval, i == 0)
                    subch_ssa[i]["output_vars"] = subch_ssa[i + 1]["input_vars"][:]
                    # Outputs that this slice neither computes nor reads must
                    # be forwarded as extra inputs.
                    for var in subch_ssa[i]["output_vars"]:
                        if var not in subch_ssa[i][
                                "inter_vars"] and var not in subch_ssa[i][
                                "input_vars"]:
                            subch_ssa[i]["input_vars"].append(var)
                del subch_ssa[i]["inter_vars"]
                subch_ssa[i]["weight"] = subch_weight[i]
            # Rename duplicated/passthrough output variables ("<name>_o<j>")
            # so every output is a distinct assigned variable.
            for _, ssa in enumerate(subch_ssa):
                for j in range(len(ssa["output_vars"])):
                    var_j = ssa["output_vars"][j]
                    index_out = 0
                    if var_j in ssa["input_vars"]:
                        new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                        index_out += 1
                        ssa["assignments"].append([new_var, var_j])
                        ssa["output_vars"][j] = new_var
                        var2diffval[new_var] = var2diffval[var_j]
                    for k in range(j + 1, len(ssa["output_vars"])):
                        if var_j == ssa["output_vars"][k]:
                            new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                            index_out += 1
                            ssa["assignments"].append([new_var, var_j])
                            ssa["output_vars"][k] = new_var
                            var2diffval[new_var] = var2diffval[var_j]
            pair_subchssa.append(subch_ssa)
        assert len(pair_subchssa[0]) == len(pair_subchssa[1]) == num_subch
        rkey2pair_subchssa[key_index] = pair_subchssa
    # for each related-key pair, we associated the list of the weight of each subch
    rkey2subch_ew = [[0 for _ in range(num_subch)] for _ in range(KEY_SAMPLES)]
    # start multiprocessing
    with multiprocessing.Pool() as pool:
        for i in range(num_subch):
            for key_index in range(KEY_SAMPLES):
                ssa1 = rkey2pair_subchssa[key_index][0][i]
                ssa2 = rkey2pair_subchssa[key_index][1][i]
                # Verbose output only for the first two related-key pairs.
                if key_index <= 1:
                    if verbose_lvl >= 2 and key_index == 0:
                        smart_print("- sub-characteristic {}".format(i))
                    if verbose_lvl >= 3 and key_index == 0:
                        smart_print(" - listdiffder:", subch_listdiffder[i])
                    if verbose_lvl >= 3:
                        smart_print(" - related-key pair index", key_index)
                        smart_print(" - ssa1:", ssa1)
                        if ssa1 == ssa2:
                            smart_print(" - ssa2: (same as ssa1)")
                        else:
                            smart_print(" - ssa2:", ssa2)
                # If the previous sub-characteristic was already impossible,
                # the whole trail has infinite weight; skip this slice.
                if i > 0 and rkey2subch_ew[key_index][i - 1] == math.inf:
                    rkey2subch_ew[key_index][i] = math.inf
                    if key_index <= 1 and verbose_lvl >= 2:
                        smart_print(
                            " - rk{} | skipping since invalid sub-ch[{}]".format(
                                key_index, i - 1))
                    continue
                # Generate the C verification code for the slice (single or
                # related-key variant).
                if ssa1 == ssa2:
                    ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
                else:
                    ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)
                if key_index <= 1 and debug:
                    smart_print(ccode[0])
                    smart_print(ccode[1])
                    smart_print()
                input_diff_c = [
                    v.xreplace(var2diffval) for v in ssa1["input_vars"]
                ]
                output_diff_c = [
                    v.xreplace(var2diffval) for v in ssa1["output_vars"]
                ]
                if key_index <= 1 and verbose_lvl >= 2:
                    smart_print(
                        " - rk{} | checking {} -> {} with weight {}".format(
                            key_index,
                            '|'.join([str(d) for d in input_diff_c]),
                            '|'.join([str(d) for d in output_diff_c]),
                            ssa1["weight"]))
                assert ssa1["weight"] == ssa2["weight"]
                assert all(
                    isinstance(d, (int, core.Constant))
                    for d in input_diff_c), "{}".format(input_diff_c)
                assert all(
                    isinstance(d, (int, core.Constant))
                    for d in output_diff_c), "{}".format(output_diff_c)
                input_diff_c = [int(d) for d in input_diff_c]
                output_diff_c = [int(d) for d in output_diff_c]
                # Compile and run asynchronously; the result is collected below.
                rkey2subch_ew[key_index][i] = pool.apply_async(
                    compile_run_empirical_weight,
                    (ccode,
                     "_libver" + ch_found.ch.func.__name__ + str(i),
                     input_diff_c, output_diff_c, ssa1["weight"], False))
            # wait until all i-subch have been compiled and run
            # and replace the Async object by the result
            for key_index in range(KEY_SAMPLES):
                if isinstance(rkey2subch_ew[key_index][i],
                              multiprocessing.pool.AsyncResult):
                    rkey2subch_ew[key_index][i] = rkey2subch_ew[key_index][i].get()
                if key_index <= 1 and verbose_lvl >= 2:
                    smart_print(
                        " - rk{} | exact/empirical weight: {}, {}".format(
                            key_index, subch_weight[i],
                            rkey2subch_ew[key_index][i]))
    # end multiprocessing
    # Aggregate the per-key weights (sum over the sub-characteristics) into
    # a Counter keyed by the (rounded) weight.
    empirical_weight_distribution = collections.Counter()
    all_rkey_weights = []
    for key_index in range(KEY_SAMPLES):
        rkey_weight = sum(rkey2subch_ew[key_index])
        if precision == 0:
            weight = int(rkey_weight) if rkey_weight != math.inf else math.inf
        else:
            weight = round(rkey_weight, precision)
        all_rkey_weights.append(rkey_weight)
        empirical_weight_distribution[weight] += 1
    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(
            empirical_weight_distribution))
    if verbose_lvl >= 3:
        smart_print("- list empirical weights:",
                    [round(x, 8) for x in all_rkey_weights if x != math.inf])
    return empirical_weight_distribution
def _generate(self):
    """Build the SMT assertions of the search problem."""
    assertions = []

    # With XOR differences, exclude the all-zero input difference.
    if self.ch.diff_type == difference.XorDiff:
        if self.parent_ch is not None and self.parent_ch.outer_ch == self.ch:
            inner_noutputs = len(self.parent_ch.inner_ch.output_diff)
            nonzero_diffs = self.ch.input_diff[:-inner_noutputs]
        else:
            nonzero_diffs = self.ch.input_diff
        concat_diff = functools.reduce(operation.Concat, nonzero_diffs)
        zero = core.Constant(0, concat_diff.width)
        assertions.append(operation.BvNot(operation.BvComp(concat_diff, zero)))

    # Validity and weight of each non-deterministic propagation;
    # deterministic propagations become plain equalities.
    op_weights = []
    for var, propagation in self.ch.items():
        if not isinstance(propagation, differential.Differential):
            assertions.append(operation.BvComp(var, propagation))
            continue
        assertions.append(propagation.is_valid())
        weight_value = propagation.weight()
        weight_var = core.Variable(propagation._weight_var_name(), weight_value.width)
        assertions.append(operation.BvComp(weight_var, weight_value))
        op_weights.append(weight_var)
    self.op_weights = op_weights

    # Characteristic weight = sum of the zero-extended step weights.
    max_value = 0
    for w in op_weights:
        max_value += (2**w.width) - 1
    width = max(max_value.bit_length(), 1)  # at least 1 bit for the trivial characteristic
    ext_op_weights = [operation.ZeroExtend(w, width - w.width) for w in op_weights]
    name_ch_weight = "w_{}_{}".format(
        ''.join([str(i) for i in self.ch.input_diff]),
        ''.join([str(i) for i in self.ch.output_diff]))
    ch_weight = core.Variable(name_ch_weight, width)
    assertions.append(operation.BvComp(ch_weight, sum(ext_op_weights)))

    # Relate the characteristic weight to the target weight.
    weight_function = self.ch.get_weight_function()
    target_weight = int(weight_function(self.target_weight))
    width = max(ch_weight.width, target_weight.bit_length())
    self.ch_weight = operation.ZeroExtend(ch_weight, width - ch_weight.width)
    if self.equality:
        assertions.append(operation.BvComp(self.ch_weight, target_weight))
    else:
        assertions.append(operation.BvUlt(self.ch_weight, target_weight))
    self.assertions = tuple(assertions)
def to_Variable(self):
    """Return a `Variable` with the same name and width as this DiffVar."""
    name, width = self.name, self.width
    return core.Variable(name, width)
def _fast_empirical_weight_distribution(ch_found, cipher, rk_dict_diffs=None,
                                        verbose_lvl=0, debug=False, filename=None, precision=0):
    """Empirically check an impossible differential over ``KEY_SAMPLES``
    random keys by compiling and running C verification code; a key with no
    satisfying pair contributes weight ``inf`` to the returned Counter.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_impossible import SearchSkID
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_impossible import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkID(ch)
        >>> id_found = search_problem.solve(2)
        >>> _fast_empirical_weight_distribution(id_found, Speck32)
        Counter({inf: 256})
    """
    # Related-key mode is not supported in the impossible-differential variant.
    if rk_dict_diffs is not None:
        raise ValueError("rk_dict_diffs must be None")
    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports
    smart_print = _get_smart_print(filename)
    # if rk_dict_diffs is not None:
    #     assert "nonlinear_diffs" in rk_dict_diffs and "output_diff" in rk_dict_diffs
    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("ID found:")
        smart_print(ch_found)
        # if rk_dict_diffs is not None:
        #     smart_print("rk_dict_diffs:", rk_dict_diffs)
        smart_print()
    # Symbolic round-key variables "k<i>", one per key-schedule output.
    # if rk_dict_diffs is not None:
    #     rk_var = [var.val for var, _ in rk_dict_diffs["output_diff"]]
    # else:
    rk_var = []
    for i, width in enumerate(cipher.key_schedule.output_widths):
        rk_var.append(core.Variable("k" + str(i), width))
    # Map each difference variable to its found (constant) value.
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff,
                                                ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val
    # if rk_dict_diffs is not None:
    #     for var, diff in rk_dict_diffs["output_diff"]:
    #         var2diffval[var.val] = diff.val
    # for each related-key pair, we associated a pair of ssa
    rkey2pair_ssa = [None for _ in range(KEY_SAMPLES)]
    for key_index in range(KEY_SAMPLES):
        # Sample a random master key and compute its round keys.
        master_key = []
        for width in cipher.key_schedule.input_widths:
            master_key.append(core.Constant(random.randrange(2 ** width), width))
        rk_val = cipher.key_schedule(*master_key)
        # if rk_dict_diffs is not None:
        #     rk_other_val = tuple([d.get_pair_element(r) for r, (_, d) in zip(rk_val, rk_dict_diffs["output_diff"])])
        # else:
        rk_other_val = rk_val
        assert len(rk_var) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        def replace_roundkeys(var2val):
            # Return a copy of the SSA with the round-key values of var2val
            # substituted into every assignment.
            new_ssa = ch_found.ch.ssa.copy()
            new_ssa["assignments"] = list(new_ssa["assignments"])
            new_ssa["output_vars"] = list(new_ssa["output_vars"])
            for i, (var, expr) in enumerate(ch_found.ch.ssa["assignments"]):
                new_ssa["assignments"][i] = (var, expr.xreplace(var2val))
            return new_ssa

        pair_ssa = []
        for index_pair in range(2):
            current_rk_val = rk_val if index_pair == 0 else rk_other_val
            rkvar2rkval = {var: val for var, val in zip(rk_var, current_rk_val)}
            ssa = replace_roundkeys(rkvar2rkval)
            # Rename duplicated/passthrough output variables ("<name>_o<j>")
            # so every output is a distinct assigned variable.
            for j in range(len(ssa["output_vars"])):
                var_j = ssa["output_vars"][j]
                index_out = 0
                if var_j in ssa["input_vars"]:
                    new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                    index_out += 1
                    ssa["assignments"].append([new_var, var_j])
                    ssa["output_vars"][j] = new_var
                    var2diffval[new_var] = var2diffval[var_j]
                for k in range(j + 1, len(ssa["output_vars"])):
                    if var_j == ssa["output_vars"][k]:
                        new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                        index_out += 1
                        ssa["assignments"].append([new_var, var_j])
                        ssa["output_vars"][k] = new_var
                        var2diffval[new_var] = var2diffval[var_j]
            pair_ssa.append(ssa)
        rkey2pair_ssa[key_index] = pair_ssa
    # for each related-key pair, we associated their weight
    rkey2subch_ew = [0 for _ in range(KEY_SAMPLES)]
    # start multiprocessing
    with multiprocessing.Pool() as pool:
        for key_index in range(KEY_SAMPLES):
            ssa1 = rkey2pair_ssa[key_index][0]
            ssa2 = rkey2pair_ssa[key_index][1]
            # Verbose output only for the first two related-key pairs.
            if key_index <= 1:
                if verbose_lvl >= 3:
                    smart_print(" - related-key pair index", key_index)
                    smart_print(" - ssa1:", ssa1)
                    if ssa1 == ssa2:
                        smart_print(" - ssa2: (same as ssa1)")
                    else:
                        smart_print(" - ssa2:", ssa2)
            # Generate the C verification code (single or related-key variant).
            if ssa1 == ssa2:
                ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
            else:
                ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)
            if key_index <= 1 and debug:
                smart_print(ccode[0])
                smart_print(ccode[1])
                smart_print()
            input_diff_c = [v.xreplace(var2diffval) for v in ssa1["input_vars"]]
            output_diff_c = [v.xreplace(var2diffval) for v in ssa1["output_vars"]]
            if key_index <= 1 and verbose_lvl >= 2:
                smart_print(" - rk{} | checking {} -> {} with pairs 2**{}".format(
                    key_index,
                    '|'.join([str(d) for d in input_diff_c]),
                    '|'.join([str(d) for d in output_diff_c]),
                    MAX_WEIGHT))
            assert all(isinstance(d, (int, core.Constant))
                       for d in input_diff_c), "{}".format(input_diff_c)
            assert all(isinstance(d, (int, core.Constant))
                       for d in output_diff_c), "{}".format(output_diff_c)
            input_diff_c = [int(d) for d in input_diff_c]
            output_diff_c = [int(d) for d in output_diff_c]
            # Compile and run asynchronously; the result is collected below.
            rkey2subch_ew[key_index] = pool.apply_async(
                compile_run_empirical_weight, (
                    ccode,
                    "_libver" + ch_found.ch.func.__name__,
                    input_diff_c,
                    output_diff_c,
                    MAX_WEIGHT,
                    False
                )
            )
        # wait until all have been compiled and run
        # and replace the Async object by the result
        for key_index in range(KEY_SAMPLES):
            if isinstance(rkey2subch_ew[key_index], multiprocessing.pool.AsyncResult):
                rkey2subch_ew[key_index] = rkey2subch_ew[key_index].get()
            if key_index <= 1 and verbose_lvl >= 2:
                smart_print(" - rk{} | empirical weight: {}".format(
                    key_index, rkey2subch_ew[key_index]))
    # end multiprocessing
    # Aggregate the per-key weights into a Counter keyed by (rounded) weight.
    empirical_weight_distribution = collections.Counter()
    all_rkey_weights = []
    for key_index in range(KEY_SAMPLES):
        rkey_weight = rkey2subch_ew[key_index]
        if precision == 0:
            weight = int(rkey_weight) if rkey_weight != math.inf else math.inf
        else:
            weight = round(rkey_weight, precision)
        all_rkey_weights.append(rkey_weight)
        empirical_weight_distribution[weight] += 1
    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(
            empirical_weight_distribution))
    if verbose_lvl >= 3:
        smart_print("- list empirical weights:",
                    [round(x, 8) for x in all_rkey_weights if x != math.inf])
    return empirical_weight_distribution