def __new__(cls, *args, **options):
    """Create an operation, evaluating, simplifying and memoizing it.

    Recognised keyword options (each defaults to the current value of the
    matching ``context`` manager): ``validate_operands``, ``evaluate``,
    ``simplify``, ``state`` and ``notevaluate``. Any other keyword options
    are not forwarded to the parent constructor.

    When evaluation is enabled and ``cls`` is not listed in
    ``notevaluate``, ``cls.eval`` is tried first; a non-None result is used
    instead of a freshly built instance. If a memoization ``state`` is set,
    operands already stored in it are replaced by their identifiers, and a
    resulting ``Operation`` is looked up in (or added to) that state, whose
    ``get_id`` / ``add_op`` result is returned.

    Raises:
        ValueError: if a memoization state is set and the final operation
            still has an ``Operation`` among its arguments.
    """
    validate = options.pop("validate_operands", context.Validation.current_context)
    do_evaluate = options.pop("evaluate", context.Evaluation.current_context)
    do_simplify = options.pop("simplify", context.Simplification.current_context)
    state = options.pop("state", context.Memoization.current_context)
    skipped = _tuplify(
        options.pop("notevaluate", context.NotEvaluation.current_context))

    if validate:
        args = cls._parse_args(*args)

    if state is not None:
        # Operands that are already memoized are referred to by identifier.
        args = [
            state.get_id(operand)
            if isinstance(operand, Operation) and state.contain_op(operand)
            else operand
            for operand in args
        ]

    width = cls.output_width(*args)

    if skipped is None:
        skipped = []

    result = None
    # NOTE(review): eval runs with memoization switched off, presumably so
    # that terms built while evaluating are not recorded in ``state``.
    with context.Memoization(None):
        if do_evaluate and cls not in skipped:
            result = cls.eval(*args)

    # A non-None eval result is already a Term/Operation (possibly simplified).
    obj = super().__new__(cls, *args, width=width) if result is None else result

    if isinstance(obj, Operation) and do_simplify and do_evaluate:
        with context.Simplification(False), context.Memoization(None):
            # Keep simplifying until nothing changes or obj stops being
            # an Operation.
            modified = True
            while modified and isinstance(obj, Operation):
                obj, modified = obj._simplify()

    if state is None or not isinstance(obj, Operation):
        return obj

    for operand in obj.args:
        if isinstance(operand, Operation):
            raise ValueError("arg {} of {} was not memoized".format(
                operand, obj))

    if state.contain_op(obj):
        return state.get_id(obj)
    return state.add_op(obj)
def __new__(cls, *args, **options):
    """Create the object.

    Recognised keyword options (each defaults to the current value of the
    matching ``context`` manager): ``validate_operands``, ``evaluate``,
    ``simplify`` and ``state``. Remaining keyword options are forwarded to
    the parent constructor together with the computed ``width``.

    A non-None result of ``cls.eval`` (tried when evaluation is enabled) is
    returned directly. Otherwise a new instance is built. If a stateful
    execution ``state`` is set, an ``Operation`` instance is looked up in
    (or added to) that state and the ``get_id`` / ``add_op`` result is
    returned without simplification. With no state, the operation is
    simplified when both simplification and evaluation are enabled.
    """
    validate = options.pop("validate_operands", context.Validation.current_context)
    do_evaluate = options.pop("evaluate", context.Evaluation.current_context)
    do_simplify = options.pop("simplify", context.Simplification.current_context)
    state = options.pop("state", context.StatefulExecution.current_context)

    if validate:
        args = cls._parse_args(*args)

    if state is not None:
        # Operands already registered in the state are referred to by id.
        args = [
            state.get_id(operand)
            if isinstance(operand, Operation) and state.contain_op(operand)
            else operand
            for operand in args
        ]

    width = cls.output_width(*args)

    result = None
    # Evaluation runs with stateful execution switched off.
    with context.StatefulExecution(None):
        if do_evaluate:
            result = cls.eval(*args)

    if result is not None:
        return result

    obj = super().__new__(cls, *args, **options, width=width)

    if isinstance(obj, Operation):
        if state is not None:
            return state.get_id(obj) if state.contain_op(obj) else state.add_op(obj)
        if do_simplify and do_evaluate:
            with context.Simplification(False):
                # Repeat until no change is reported or obj is no longer
                # an Operation.
                modified = True
                while modified and isinstance(obj, Operation):
                    obj, modified = obj._simplify()

    return obj
def _empirical_weight_distribution(self, cipher, input_diff, output_diff,
                                   pair_samples, key_samples, precision=1,
                                   rk_output_diff=None):
    """Return a Counter of rounded empirical weights over random master keys.

    For each of ``key_samples`` random master keys, the round keys of
    ``self.func`` are set from ``cipher.key_schedule`` and the empirical
    weight of the differential ``input_diff -> output_diff`` is measured on
    one fixed set of input pairs (shared by all keys). If ``pair_samples``
    is at least the size of the input space, every input is enumerated
    once instead of being sampled at random.

    Args:
        cipher: object whose ``key_schedule`` (and its ``input_widths``)
            produces the round keys from a random master key.
        input_diff: sequence of input differences with constant values.
        output_diff: sequence of output differences with constant values.
        pair_samples: number of random input pairs.
        key_samples: number of random master keys.
        precision: number of decimal digits each weight is rounded to.
        rk_output_diff: if given, the second element of every pair is
            processed by a copy of ``self.func`` whose round keys are the
            pair elements of the first round keys under these differences.

    Returns:
        collections.Counter mapping each rounded weight
        (``-log2(correct / pair_samples)``, or ``math.inf`` when no pair is
        correct) to the number of master keys that produced it.

    The original ``self.func.round_keys`` is restored before returning,
    also when an exception is raised.
    """
    # this function is not part of SingleKeyCh since it must be accessible
    # for the encryption characteristic of RelatedKeyCh (which is a
    # plain BvCharacteristic)
    assert isinstance(input_diff, collections.abc.Sequence)
    assert isinstance(output_diff, collections.abc.Sequence)
    assert all(isinstance(d, difference.Difference) for d in input_diff)
    assert all(isinstance(d, difference.Difference) for d in output_diff)
    assert all(isinstance(d.val, core.Constant) for d in input_diff)
    assert all(isinstance(d.val, core.Constant) for d in output_diff)
    assert len(input_diff) == len(self.input_diff)
    assert len(output_diff) == len(self.output_diff)
    assert len(self.ssa["input_vars"]) == len(input_diff)
    assert len(self.ssa["output_vars"]) == len(output_diff)

    old_round_keys = self.func.round_keys
    empirical_weights = collections.Counter()

    if rk_output_diff is not None:
        # A subclass gets its own ``round_keys`` class attribute, so the
        # related-key copy can be keyed independently of ``self.func``.
        class RelatedFunc(self.func):
            pass
    else:
        RelatedFunc = self.func

    # self.func.round_keys is overwritten for every key sample; the
    # finally clause guarantees the caller's keys are put back even if an
    # assertion or the evaluation itself fails.
    try:
        with context.Simplification(False):
            input_widths = [d.val.width for d in self.input_diff]
            if pair_samples >= 2**sum(input_widths):
                # Exhaustive: enumerate every possible input exactly once.
                iterators = [range(2 ** w) for w in input_widths]
                list_pairs = []
                for x in itertools.product(*iterators):
                    pt = [core.Constant(x_i, w) for x_i, w in zip(x, input_widths)]
                    other_pt = [diff.get_pair_element(pt[i])
                                for i, diff in enumerate(input_diff)]
                    list_pairs.append([pt, other_pt])
                pair_samples = len(list_pairs)
                assert pair_samples == 2**sum(input_widths)
            else:
                list_pairs = []
                for _ in range(pair_samples):
                    pt = []
                    other_pt = []
                    for diff in input_diff:
                        random_int = random.randrange(2 ** diff.val.width)
                        random_bv = core.Constant(random_int, diff.val.width)
                        pt.append(random_bv)
                        other_pt.append(diff.get_pair_element(random_bv))
                    list_pairs.append([pt, other_pt])

            for _ in range(key_samples):
                master_key = [core.Constant(random.randrange(2 ** width), width)
                              for width in cipher.key_schedule.input_widths]
                self.func.round_keys = cipher.key_schedule(*master_key)
                assert all(isinstance(rk, core.Constant)
                           for rk in self.func.round_keys), str(self.func.round_keys)

                if rk_output_diff is not None:
                    RelatedFunc.round_keys = [
                        d.get_pair_element(r)
                        for r, d in zip(self.func.round_keys, rk_output_diff)]
                    assert all(isinstance(rk, core.Constant)
                               for rk in RelatedFunc.round_keys), str(RelatedFunc.round_keys)

                correct_pairs = 0
                for pt, other_pt in list_pairs:
                    ct = self.func(*pt)
                    other_ct = RelatedFunc(*other_pt)
                    assert all(isinstance(x, core.Constant) for x in ct), str(ct)
                    assert all(isinstance(x, core.Constant) for x in other_ct), str(other_ct)
                    for i, diff in enumerate(output_diff):
                        # noinspection PyUnresolvedReferences
                        if self.diff_type.from_pair(ct[i], other_ct[i]) != diff:
                            break
                    else:
                        correct_pairs += 1

                if correct_pairs == 0:
                    weight = math.inf
                else:
                    # abs() turns -0.0 (all pairs correct) into 0.0.
                    weight = abs(-math.log(correct_pairs * 1.0 / pair_samples, 2))

                weight = round(weight, precision)
                empirical_weights[weight] += 1
    finally:
        self.func.round_keys = old_round_keys

    return empirical_weights
def empirical_weight(self, input_diff, output_diff, pair_samples):
    """Return the empirical weight of a given differential.

    A differential is a pair (input difference, output difference); its
    probability is the fraction of input pairs with the given input
    difference whose outputs have the given output difference. This method
    estimates the weight of that probability (``-log2``) from
    ``pair_samples`` input pairs. When ``pair_samples`` is at least the
    size of the whole input space, every input is enumerated exactly once
    instead of being sampled at random. If no correct output pair is found,
    `math.inf` is returned.

    >>> from arxpy.bitvector.core import Constant
    >>> from arxpy.differential.difference import XorDiff, RXDiff
    >>> from arxpy.differential.characteristic import BvCharacteristic
    >>> from arxpy.primitives.chaskey import ChaskeyPi
    >>> ChaskeyPi.set_rounds(1)
    >>> ch = BvCharacteristic(ChaskeyPi, XorDiff, ["dv" + str(i) for i in range(4)])
    >>> zero, one = XorDiff(Constant(0, 32)), XorDiff(Constant(1, 32))
    >>> ch.empirical_weight([zero, zero, zero, zero], [zero, zero, zero, zero], 100)
    0.0
    >>> ch.empirical_weight([zero, zero, zero, zero], [one, one, one, one], 100)
    inf
    >>> ch = BvCharacteristic(ChaskeyPi, RXDiff, ["dv" + str(i) for i in range(4)])
    >>> zero, one = RXDiff(Constant(0, 32)), RXDiff(Constant(1, 32))
    >>> 4 - 1 <= ch.empirical_weight([zero]*4, [zero]*4, 3 * 2**6) <= 8
    True
    >>> ch.empirical_weight([zero]*4, [one]*4, 3 * 2**6)
    inf
    """
    assert isinstance(input_diff, collections.abc.Sequence)
    assert isinstance(output_diff, collections.abc.Sequence)
    assert all(isinstance(d, difference.Difference) for d in input_diff)
    assert all(isinstance(d, difference.Difference) for d in output_diff)
    assert all(isinstance(d.val, core.Constant) for d in input_diff)
    assert all(isinstance(d.val, core.Constant) for d in output_diff)
    assert len(input_diff) == len(self.input_diff)
    assert len(output_diff) == len(self.output_diff)
    assert len(self.ssa["input_vars"]) == len(input_diff)
    assert len(self.ssa["output_vars"]) == len(output_diff)

    with context.Simplification(False):
        widths = [d.val.width for d in self.input_diff]
        space_size = 2**sum(widths)

        pairs = []
        if pair_samples >= space_size:
            # Exhaustive enumeration of the input space.
            for values in itertools.product(*[range(2 ** w) for w in widths]):
                first = [core.Constant(v, w) for v, w in zip(values, widths)]
                second = [d.get_pair_element(x) for d, x in zip(input_diff, first)]
                pairs.append([first, second])
            pair_samples = len(pairs)
            assert pair_samples == space_size
        else:
            for _ in range(pair_samples):
                first, second = [], []
                for d in input_diff:
                    sample = core.Constant(random.randrange(2 ** d.val.width), d.val.width)
                    first.append(sample)
                    second.append(d.get_pair_element(sample))
                pairs.append([first, second])

        hits = 0
        for first, second in pairs:
            ct = self.func(*first)
            other_ct = self.func(*second)
            assert all(isinstance(x, core.Constant) for x in ct), str(ct)
            assert all(isinstance(x, core.Constant) for x in other_ct), str(other_ct)
            # A pair is correct when no output difference mismatches.
            # noinspection PyUnresolvedReferences
            mismatch = any(self.diff_type.from_pair(ct[i], other_ct[i]) != d
                           for i, d in enumerate(output_diff))
            if not mismatch:
                hits += 1

    if hits == 0:
        return math.inf
    # abs() turns -0.0 (all pairs correct) into 0.0.
    return abs(-math.log(hits * 1.0 / pair_samples, 2))
def empirical_weight(self, differences, constant_conversion=True, theoretical_weight=None):
    """Return the empirical weight for a particular sequence of differences.

    The probability *p* of a characteristic is the probability that a
    random input pair with the given input difference follows the
    characteristic. The weight is -log2(p), and the empirical weight is the
    weight computed by sampling random pairs and counting the correct ones.

    >>> from arxpy.bitvector.core import Constant
    >>> from arxpy.bitvector.function import Function
    >>> from arxpy.diffcrypt.difference import XorDiff, DiffVar
    >>> from arxpy.diffcrypt.characteristic import Characteristic
    >>> class MyFunction(Function):
    ...     input_widths = [8, 8, 8]
    ...     output_widths = [8, 8]
    ...     @classmethod
    ...     def eval(cls, x, y, k):
    ...         return (y + k, (y + k) ^ x)
    >>> x, y, k = DiffVar("x", 8), DiffVar("y", 8), DiffVar("k", 8)
    >>> ch = Characteristic(MyFunction, XorDiff, [x, y, k])
    >>> zero = Constant(0, 8)
    >>> ch.empirical_weight([zero for d in ch.sequence])  # doctest:+SKIP
    0

    The theoretical weight can be specified to adjust the number of pairs
    that will be sampled. Note that math.inf is returned if no correct
    pairs are found.

    Args:
        differences: one constant per element of ``self.sequence``. The
            first ones are paired with the symbolic inputs of ``self.func``;
            the remaining ones are matched, in order, with the entries of
            its symbolic execution.
        constant_conversion: if False, ``context.Validation(False)`` is
            active while sampling.
        theoretical_weight: if given and below 32, at least
            ``2**theoretical_weight`` and at most six times that many pairs
            are sampled; otherwise between ``2**10`` and ``2**20``.

    Sampling continues until at least 3 correct pairs and the minimum
    number of pairs have been seen, or until the maximum number of pairs
    has been drawn.

    Returns:
        ``math.inf`` if no pair is correct, ``0`` if every sampled pair is
        correct, otherwise ``-log2(correct_pairs / total_pairs)``.
    """
    assert len(differences) == len(self.sequence)
    assert all(isinstance(d, core.Constant) for d in differences)

    min_correct_pairs = 3

    if theoretical_weight is not None and theoretical_weight < 32:
        min_pairs = 2**theoretical_weight
        max_pairs = 6 * min_pairs
    else:
        min_pairs = 2**10
        max_pairs = 2**20

    input_vars = self.func._symbolic_input()
    sym_exec = self.func.symbolic_execution(*input_vars)
    # Differences of the intermediate/output values follow the input ones.
    inner_differences = differences[len(input_vars):]

    correct_pairs = 0
    total_pairs = 0

    with contextlib.ExitStack() as stack:
        stack.enter_context(context.Simplification(False))
        if not constant_conversion:
            stack.enter_context(context.Validation(False))

        # Never draw more than max_pairs pairs.
        while total_pairs < max_pairs and (
                correct_pairs < min_correct_pairs or total_pairs < min_pairs):
            total_pairs += 1
            exec_state1 = {}
            exec_state2 = {}
            for v, diff in zip(input_vars, differences):
                random_int = random.randrange(2**v.width)
                exec_state1[v] = core.Constant(random_int, v.width)
                exec_state2[v] = self.diff_type.get_pair_element(
                    exec_state1[v], diff)
            for (identifier, op), diff in zip(sym_exec[1].items(), inner_differences):
                x = op.xreplace(exec_state1)
                y = op.xreplace(exec_state2)
                exec_state1[identifier] = x
                exec_state2[identifier] = y
                if self.diff_type.get_difference(x, y) != diff:
                    # The pair leaves the characteristic; stop early.
                    break
            else:
                correct_pairs += 1

    if correct_pairs == 0:
        return math.inf
    elif total_pairs == correct_pairs:
        return 0
    else:
        return -math.log(correct_pairs / total_pairs, 2)