예제 #1
0
    def __new__(cls, *args, **options):
        """Build an instance, optionally evaluating, simplifying and memoizing it.

        Recognized keyword options are popped from ``options``. When one is
        absent, the matching ``context`` value is used instead:

        - ``validate_operands``: if truthy, ``args`` go through
          ``cls._parse_args`` first.
        - ``evaluate``: if truthy and ``cls`` is not in ``notevaluate``,
          ``cls.eval`` is called. A non-``None`` result is used as the
          returned object instead of building a new one.
        - ``simplify``: if truthy (and ``evaluate`` is truthy), a freshly built
          ``Operation`` is passed through ``_simplify`` repeatedly. This stops
          when no modification is reported or the result is no longer an
          ``Operation``.
        - ``state``: a memoization table. Operands already stored in it are
          replaced by their identifiers, and the final ``Operation`` is looked
          up in (or added to) the table; its identifier is returned.
        - ``notevaluate``: classes for which ``eval`` is skipped.

        Any remaining keyword options are not forwarded to the parent
        ``__new__``.

        Raises:
            ValueError: if ``state`` is set and the final ``Operation`` still
                has an ``Operation`` among its arguments.
        """
        validate = options.pop("validate_operands",
                               context.Validation.current_context)
        evaluate = options.pop("evaluate", context.Evaluation.current_context)
        simplify = options.pop("simplify",
                               context.Simplification.current_context)
        state = options.pop("state", context.Memoization.current_context)
        skip_eval = _tuplify(
            options.pop("notevaluate", context.NotEvaluation.current_context))

        if validate:
            args = cls._parse_args(*args)

        if state is not None:
            # Operands that are already memoized are referenced by identifier.
            args = [
                state.get_id(a)
                if isinstance(a, Operation) and state.contain_op(a) else a
                for a in args
            ]

        width = cls.output_width(*args)

        if skip_eval is None:
            skip_eval = []

        # NOTE(review): evaluation runs with Memoization(None), presumably so
        # that operations built inside ``eval`` are not recorded — confirm.
        with context.Memoization(None):
            should_eval = evaluate and cls not in skip_eval
            result = cls.eval(*args) if should_eval else None

        if result is None:
            obj = super().__new__(cls, *args, width=width)
            if simplify and evaluate and isinstance(obj, Operation):
                with context.Simplification(False), context.Memoization(None):
                    changed = True
                    while changed and isinstance(obj, Operation):
                        obj, changed = obj._simplify()
        else:
            # ``eval`` already produced the final Term/Operation.
            obj = result

        if state is None or not isinstance(obj, Operation):
            return obj

        for arg in obj.args:
            if isinstance(arg, Operation):
                raise ValueError("arg {} of {} was not memoized".format(
                    arg, obj))

        if state.contain_op(obj):
            return state.get_id(obj)
        return state.add_op(obj)
예제 #2
0
    def __new__(cls, *args, **options):
        """Create the object, optionally evaluating, memoizing or simplifying it.

        Recognized keyword options are popped from ``options``. When one is
        absent, the matching ``context`` value is used:

        - ``validate_operands``: if truthy, ``args`` go through
          ``cls._parse_args`` first.
        - ``evaluate``: if truthy, ``cls.eval`` is called (with
          ``StatefulExecution`` disabled). A non-``None`` result is returned
          directly.
        - ``simplify``: if truthy (and ``evaluate`` is truthy), and no state is
          active, a built ``Operation`` is repeatedly passed through
          ``_simplify``.
        - ``state``: a table of operations. Operands already stored are
          replaced by their identifiers, and a built ``Operation`` is looked
          up in (or added to) the table; its identifier is returned.

        Any other keyword options are forwarded to the parent ``__new__``.
        """
        validate = options.pop("validate_operands",
                               context.Validation.current_context)
        evaluate = options.pop("evaluate", context.Evaluation.current_context)
        simplify = options.pop("simplify", context.Simplification.current_context)
        state = options.pop("state", context.StatefulExecution.current_context)

        if validate:
            args = cls._parse_args(*args)

        if state is not None:
            # Operands already stored in the state are referenced by identifier.
            args = [
                state.get_id(a)
                if isinstance(a, Operation) and state.contain_op(a) else a
                for a in args
            ]

        width = cls.output_width(*args)

        with context.StatefulExecution(None):
            result = cls.eval(*args) if evaluate else None

        if result is not None:
            return result

        obj = super().__new__(cls, *args, **options, width=width)

        if not isinstance(obj, Operation):
            return obj

        if state is not None:
            # NOTE(review): this returns before the simplification below, so
            # ``simplify`` has no effect while a state is active — confirm
            # this is intended.
            if state.contain_op(obj):
                return state.get_id(obj)
            return state.add_op(obj)

        if simplify and evaluate:
            with context.Simplification(False):
                changed = True
                while changed and isinstance(obj, Operation):
                    obj, changed = obj._simplify()

        return obj
예제 #3
0
    def _empirical_weight_distribution(self, cipher, input_diff, output_diff, pair_samples, key_samples,
                                       precision=1, rk_output_diff=None):
        """Return a histogram of empirical weights over randomly sampled keys.

        For each of ``key_samples`` random master keys:

        - The round keys of ``self.func`` are set to
          ``cipher.key_schedule(*master_key)``.
        - The method measures the fraction of input pairs whose outputs have
          difference ``output_diff``. The second element of each pair is
          obtained from ``input_diff``.
        - The weight ``-log2(fraction)`` (``math.inf`` when no pair is
          correct), rounded to ``precision`` decimal places, is counted.

        If ``pair_samples`` is at least ``2**sum(input widths)``, every
        possible input is enumerated once. Otherwise ``pair_samples`` random
        inputs are drawn once and reused for every key.

        The original ``self.func.round_keys`` is always restored on exit, even
        when an assertion or the cipher raises.

        Args:
            cipher: object whose ``key_schedule`` is callable on master-key
                constants and exposes ``input_widths``.
            input_diff: sequence of ``difference.Difference`` with constant
                values, one per input of ``self.func``.
            output_diff: sequence of ``difference.Difference`` with constant
                values, one per output of ``self.func``.
            pair_samples: number of input pairs to sample.
            key_samples: number of random master keys to sample.
            precision: number of decimal places passed to ``round``.
            rk_output_diff: optional sequence of differences, zipped with the
                round keys. When given, the second element of each pair is
                evaluated with round keys ``d.get_pair_element(rk)``
                (related-key setting).

        Returns:
            collections.Counter mapping each rounded weight to the number of
            sampled keys that produced it.
        """
        # this function is not part of SingleKeyCh since it must be accessible
        # for the encryption characteristic of RelatedKeyCh (which is a
        # plain BvCharacteristic)
        assert isinstance(input_diff, collections.abc.Sequence)
        assert isinstance(output_diff, collections.abc.Sequence)
        assert all(isinstance(d, difference.Difference) for d in input_diff)
        assert all(isinstance(d, difference.Difference) for d in output_diff)
        assert all(isinstance(d.val, core.Constant) for d in input_diff)
        assert all(isinstance(d.val, core.Constant) for d in output_diff)

        assert len(input_diff) == len(self.input_diff)
        assert len(output_diff) == len(self.output_diff)
        assert len(self.ssa["input_vars"]) == len(input_diff)
        assert len(self.ssa["output_vars"]) == len(output_diff)

        old_round_keys = self.func.round_keys

        empirical_weights = collections.Counter()

        if rk_output_diff is not None:
            # Throwaway subclass: related round keys are assigned on it, so
            # self.func's own round_keys attribute is not overwritten by them.
            class RelatedFunc(self.func):
                pass
        else:
            RelatedFunc = self.func

        try:
            with context.Simplification(False):
                input_widths = [d.val.width for d in self.input_diff]
                if pair_samples >= 2**sum(input_widths):
                    # Exhaustive: every possible input is used exactly once.
                    iterators = [range(2 ** w) for w in input_widths]
                    list_pairs = []
                    for x in itertools.product(*iterators):
                        pt = [core.Constant(x_i, w) for x_i, w in zip(x, input_widths)]
                        other_pt = [diff.get_pair_element(p) for p, diff in zip(pt, input_diff)]
                        list_pairs.append([pt, other_pt])
                    pair_samples = len(list_pairs)
                    assert pair_samples == 2**sum(input_widths)
                else:
                    list_pairs = []
                    for _ in range(pair_samples):
                        pt = []
                        other_pt = []
                        for diff in input_diff:
                            random_int = random.randrange(2 ** diff.val.width)
                            random_bv = core.Constant(random_int, diff.val.width)
                            pt.append(random_bv)
                            other_pt.append(diff.get_pair_element(random_bv))
                        list_pairs.append([pt, other_pt])

                for _ in range(key_samples):
                    master_key = [core.Constant(random.randrange(2 ** width), width)
                                  for width in cipher.key_schedule.input_widths]
                    # Mutates class-level state; restored in the finally below.
                    self.func.round_keys = cipher.key_schedule(*master_key)
                    assert all(isinstance(rk, core.Constant) for rk in self.func.round_keys), str(self.func.round_keys)

                    if rk_output_diff is not None:
                        RelatedFunc.round_keys = [d.get_pair_element(r) for r, d in zip(self.func.round_keys, rk_output_diff)]
                        assert all(isinstance(rk, core.Constant) for rk in RelatedFunc.round_keys), str(RelatedFunc.round_keys)

                    correct_pairs = 0

                    for pt, other_pt in list_pairs:
                        ct = self.func(*pt)
                        other_ct = RelatedFunc(*other_pt)

                        assert all(isinstance(x, core.Constant) for x in ct), str(ct)
                        assert all(isinstance(x, core.Constant) for x in other_ct), str(other_ct)

                        for i, diff in enumerate(output_diff):
                            # noinspection PyUnresolvedReferences
                            if self.diff_type.from_pair(ct[i], other_ct[i]) != diff:
                                break
                        else:
                            correct_pairs += 1

                    if correct_pairs == 0:
                        weight = math.inf
                    else:
                        # abs() turns a -0.0 (all pairs correct) into 0.0.
                        weight = abs(-math.log(correct_pairs * 1.0 / pair_samples, 2))
                    weight = round(weight, precision)
                    empirical_weights[weight] += 1
        finally:
            # Always undo the round-key mutation of the shared class, even if
            # an assertion or the cipher raised mid-sampling.
            self.func.round_keys = old_round_keys

        return empirical_weights
예제 #4
0
    def empirical_weight(self, input_diff, output_diff, pair_samples):
        """Return the empirical weight of a given differential.

        A differential is a pair (input difference, output difference). Its
        probability is the fraction of input pairs with the input difference
        whose outputs have the output difference.

        This method approximates the weight (``-log2`` of that probability) by
        evaluating ``self.func`` on ``pair_samples`` input pairs. If
        ``pair_samples`` is at least the number of possible inputs, every
        input is enumerated exactly once instead of sampling at random.

        `math.inf` is returned when no correct output pair is found.

            >>> from arxpy.bitvector.core import Constant
            >>> from arxpy.differential.difference import XorDiff, RXDiff
            >>> from arxpy.differential.characteristic import BvCharacteristic
            >>> from arxpy.primitives.chaskey import ChaskeyPi
            >>> ChaskeyPi.set_rounds(1)
            >>> ch = BvCharacteristic(ChaskeyPi, XorDiff, ["dv" + str(i) for i in range(4)])
            >>> zero, one = XorDiff(Constant(0, 32)), XorDiff(Constant(1, 32))
            >>> ch.empirical_weight([zero, zero, zero, zero], [zero, zero, zero, zero], 100)
            0.0
            >>> ch.empirical_weight([zero, zero, zero, zero], [one, one, one, one], 100)
            inf
            >>> ch = BvCharacteristic(ChaskeyPi, RXDiff, ["dv" + str(i) for i in range(4)])
            >>> zero, one = RXDiff(Constant(0, 32)), RXDiff(Constant(1, 32))
            >>> 4 - 1 <= ch.empirical_weight([zero]*4, [zero]*4, 3 * 2**6) <= 8
            True
            >>> ch.empirical_weight([zero]*4, [one]*4, 3 * 2**6)
            inf

        """
        assert isinstance(input_diff, collections.abc.Sequence)
        assert isinstance(output_diff, collections.abc.Sequence)
        assert all(isinstance(d, difference.Difference) for d in input_diff)
        assert all(isinstance(d, difference.Difference) for d in output_diff)
        assert all(isinstance(d.val, core.Constant) for d in input_diff)
        assert all(isinstance(d.val, core.Constant) for d in output_diff)

        assert len(input_diff) == len(self.input_diff)
        assert len(output_diff) == len(self.output_diff)
        assert len(self.ssa["input_vars"]) == len(input_diff)
        assert len(self.ssa["output_vars"]) == len(output_diff)

        with context.Simplification(False):
            widths = [d.val.width for d in self.input_diff]
            input_space = 2**sum(widths)
            pairs = []

            if pair_samples >= input_space:
                # Exhaustive mode: enumerate every possible input once.
                ranges = [range(2 ** w) for w in widths]
                for values in itertools.product(*ranges):
                    first = [core.Constant(v, w) for v, w in zip(values, widths)]
                    second = [d.get_pair_element(p) for p, d in zip(first, input_diff)]
                    pairs.append([first, second])
                pair_samples = len(pairs)
                assert pair_samples == input_space
            else:
                # Random mode: draw each input word uniformly, in diff order.
                for _ in range(pair_samples):
                    first, second = [], []
                    for d in input_diff:
                        bv = core.Constant(random.randrange(2 ** d.val.width), d.val.width)
                        first.append(bv)
                        second.append(d.get_pair_element(bv))
                    pairs.append([first, second])

            hits = 0
            for first, second in pairs:
                ct = self.func(*first)
                other_ct = self.func(*second)

                assert all(isinstance(x, core.Constant) for x in ct), str(ct)
                assert all(isinstance(x, core.Constant) for x in other_ct), str(other_ct)

                # noinspection PyUnresolvedReferences
                mismatch = any(
                    self.diff_type.from_pair(ct[i], other_ct[i]) != d
                    for i, d in enumerate(output_diff)
                )
                if not mismatch:
                    hits += 1

            if hits == 0:
                weight = math.inf
            else:
                # abs() normalises -0.0 (every pair correct) to 0.0.
                weight = abs(math.log(hits * 1.0 / pair_samples, 2))

        return weight
예제 #5
0
    def empirical_weight(self,
                         differences,
                         constant_conversion=True,
                         theoretical_weight=None):
        """Return the empirical weight for a particular sequence of differences.

        The probability of a characteristic *p* is the probability that
        a random input pair with given input difference follows
        the characteristic.  The weight is the -log(p) and the
        empirical weight is the weight computed by sampling random
        pairs and counting the correct ones.

            >>> from arxpy.bitvector.core import Constant
            >>> from arxpy.bitvector.function import Function
            >>> from arxpy.diffcrypt.difference import XorDiff, DiffVar
            >>> from arxpy.diffcrypt.characteristic import Characteristic
            >>> class MyFunction(Function):
            ...     input_widths = [8, 8, 8]
            ...     output_widths = [8, 8]
            ...     @classmethod
            ...     def eval(cls, x, y, k):
            ...         return (y + k, (y + k) ^ x)
            >>> x, y, k = DiffVar("x", 8), DiffVar("y", 8), DiffVar("k", 8)
            >>> ch = Characteristic(MyFunction, XorDiff, [x, y, k])
            >>> zero = Constant(0, 8)
            >>> ch.empirical_weight([zero for d in ch.sequence])  # doctest:+SKIP
            0

        The theoretical weight can be specified to adjust the
        number of pairs that will be sampled.

        Args:
            differences: one ``core.Constant`` per element of
                ``self.sequence``. The leading entries (one per symbolic input)
                are the input differences. The remaining entries are compared,
                in order, with the differences of the intermediate values
                produced by the symbolic execution.
            constant_conversion: if False, ``context.Validation`` is disabled
                while sampling.
            theoretical_weight: if given and below 32, at least
                ``2**theoretical_weight`` and at most
                ``6 * 2**theoretical_weight`` pairs are sampled. Otherwise
                between ``2**10`` and ``2**20`` pairs are sampled.

        Sampling stops once at least ``min_pairs`` pairs were tried and at
        least 3 of them were correct, or once ``max_pairs`` pairs were tried.

        Returns:
            ``math.inf`` if no correct pair was found, ``0`` if every sampled
            pair was correct, otherwise ``-log2(correct / total)``.
        """
        assert len(differences) == len(self.sequence)
        assert all(isinstance(d, core.Constant) for d in differences)

        min_correct_pairs = 3
        if theoretical_weight is not None and theoretical_weight < 32:
            min_pairs = 2**theoretical_weight
            max_pairs = 6 * min_pairs
        else:
            min_pairs = 2**10
            max_pairs = 2**20

        input_vars = self.func._symbolic_input()
        sym_exec = self.func.symbolic_execution(*input_vars)

        correct_pairs = 0
        total_pairs = 0

        with contextlib.ExitStack() as stack:
            stack.enter_context(context.Simplification(False))
            if not constant_conversion:
                stack.enter_context(context.Validation(False))

            # The max_pairs bound is checked before each sample, so at most
            # max_pairs pairs are drawn (a `>` check here would allow one extra).
            while total_pairs < max_pairs and (
                    correct_pairs < min_correct_pairs or total_pairs < min_pairs):
                total_pairs += 1

                exec_state1 = {}
                exec_state2 = {}

                for v, diff in zip(input_vars, differences):
                    random_int = random.randrange(2**v.width)
                    exec_state1[v] = core.Constant(random_int, v.width)
                    exec_state2[v] = self.diff_type.get_pair_element(
                        exec_state1[v], diff)

                # Propagate both inputs through the symbolic execution step by
                # step, and stop at the first intermediate difference mismatch.
                for (identifier,
                     op), diff in zip(sym_exec[1].items(),
                                      differences[len(input_vars):]):
                    x = op.xreplace(exec_state1)
                    y = op.xreplace(exec_state2)
                    exec_state1[identifier] = x
                    exec_state2[identifier] = y

                    if self.diff_type.get_difference(x, y) != diff:
                        break
                else:
                    correct_pairs += 1

        if correct_pairs == 0:
            return math.inf
        elif total_pairs == correct_pairs:
            # Avoid returning -0.0 from -log(1).
            return 0
        else:
            return -math.log(correct_pairs / total_pairs, 2)