Example #1
0
File: types.py Project: ranea/ArxPy
def pysmt2bv(ps):
    """Convert a pySMT type to a bit-vector type.

    Currently, only conversion from `pysmt.shortcuts.BV`
    and `pysmt.shortcuts.Symbol` is supported. Boolean symbols and the
    Boolean constants ``FALSE``/``TRUE`` are mapped to width-1 bit-vectors
    (a variable, ``0b0`` and ``0b1`` respectively).

        >>> from pysmt import shortcuts, typing  # pySMT shortcuts and typing modules
        >>> from arxpy.smt.types import pysmt2bv
        >>> env = shortcuts.reset_env()
        >>> pysmt2bv(env.formula_manager.Symbol("x", env.type_manager.BVType(8))).vrepr()
        "Variable('x', width=8)"
        >>> pysmt2bv(env.formula_manager.BV(1, 8)).vrepr()
        'Constant(0b00000001, width=8)'

    Raises:
        NotImplementedError: if ``ps`` is not a Boolean or bit-vector symbol,
            a bit-vector constant or a Boolean constant.
    """
    if ps.is_symbol():
        ps_type = ps.get_type()
        if ps_type.is_bool_type():
            return core.Variable(ps.symbol_name(), 1)
        elif ps_type.is_bv_type():
            return core.Variable(ps.symbol_name(), ps.bv_width())
        # other symbol types (Int, Real, ...) fall through to the error below
    elif ps.is_bv_constant():
        return core.Constant(int(ps.constant_value()), ps.bv_width())
    elif ps.is_false():
        return core.Constant(0, 1)
    elif ps.is_true():
        return core.Constant(1, 1)

    # The message is only built on the failure path.
    msg = "unknown conversion of {} ({} {}) to a bit-vector type".format(
        ps, ps.get_type(), type(ps).__name__)
    raise NotImplementedError(msg)
Example #2
0
 def lz(x):
     """Return ``LeadingZeros(x)`` or an auxiliary variable constrained to it.

     When ``are_cte_differences`` is truthy the operation is returned as is.
     Otherwise a fresh variable ``<prefix>_<counter>lz`` of width ``x.width``
     is created, ``self._i_auxvar`` is incremented, the assertion
     ``BvComp(aux, LeadingZeros(x))`` is appended to ``assertions`` and the
     variable is returned.
     """
     if not are_cte_differences:
         counter = self._i_auxvar
         fresh = core.Variable("{}_{}lz".format(prefix, counter), x.width)
         self._i_auxvar += 1
         assertions.append(operation.BvComp(fresh, extraop.LeadingZeros(x)))
         return fresh
     return extraop.LeadingZeros(x)
Example #3
0
 def rev(x):
     """Return ``Reverse(x)`` or an auxiliary variable constrained to it.

     When ``are_cte_differences`` is truthy the operation is returned as is.
     Otherwise a fresh variable ``<prefix>_<counter>rev`` of width ``x.width``
     is created, ``self._i_auxvar`` is incremented, the assertion
     ``BvComp(aux, Reverse(x))`` is appended to ``assertions`` and the
     variable is returned.
     """
     if not are_cte_differences:
         counter = self._i_auxvar
         fresh = core.Variable("{}_{}rev".format(prefix, counter), x.width)
         self._i_auxvar += 1
         assertions.append(operation.BvComp(fresh, extraop.Reverse(x)))
         return fresh
     return extraop.Reverse(x)
Example #4
0
    def _symbolic_input(cls, symbol_prefix="i"):
        """Build one symbolic variable per entry of ``cls.input_widths``.

        The i-th variable is named ``"<symbol_prefix><i>"`` and has the i-th
        input width.

        Returns:
            tuple: the ``core.Variable`` objects, in input order.
        """
        return tuple(
            core.Variable("{}{}".format(symbol_prefix, index), bit_width)
            for index, bit_width in enumerate(cls.input_widths)
        )
Example #5
0
    def add_op(self, expr):
        """Add a bit-vector operation to the table and return its identifier.

        The identifier is a new ``core.Variable`` named
        ``<id_prefix><counter>`` with the width of ``expr``; it is mapped to
        ``expr`` in ``self.table`` and ``self.counter`` is incremented.
        ``expr`` must be an ``Operation`` not already in the table.
        """
        from arxpy.bitvector import operation
        assert isinstance(expr, operation.Operation)
        assert not self.contain_op(expr)

        current = self.counter
        self.counter = current + 1
        new_id = core.Variable("{}{}".format(self.id_prefix, current), expr.width)
        self.table[new_id] = expr
        return new_id
Example #6
0
    def add_op(self, op):
        """Add an operation to the state and return its new identifier.

        The identifier is a fresh ``core.Variable`` named
        ``<id_prefix><counter>`` with the width of ``op``; the mapping
        identifier -> ``op`` is stored in ``self.table`` and ``self.counter``
        is incremented. ``op`` must be an ``Operation`` not already stored.
        """
        from arxpy.bitvector import operation
        assert isinstance(op, operation.Operation)
        assert not self.contain_op(op)

        current = self.counter
        self.counter = current + 1
        new_id = core.Variable("{}{}".format(self.id_prefix, current), op.width)
        self.table[new_id] = op
        return new_id
Example #7
0
    def __init__(self, bv_cipher, diff_type):
        """Build the characteristic of the encryption of ``bv_cipher``.

        The encryption's round keys are fixed to symbolic variables
        ``k0, k1, ...`` (one per key-schedule output width), the input
        differences are named ``dp0, dp1, ...`` and intermediate differences
        use the prefix ``"dx"``.

        Args:
            bv_cipher: a `primitives.Cipher` subclass.
            diff_type: a `difference.Difference` subclass.
        """
        assert issubclass(bv_cipher, primitives.Cipher)
        assert issubclass(diff_type, difference.Difference)

        symbolic_round_keys = tuple(
            core.Variable("k" + str(index), key_width)
            for index, key_width in enumerate(bv_cipher.key_schedule.output_widths)
        )

        class Encryption(bv_cipher.encryption):
            round_keys = symbolic_round_keys

        diff_names = ["dp" + str(j) for j in range(len(Encryption.input_widths))]
        super().__init__(Encryption, diff_type, diff_names, "dx")
        self._cipher = bv_cipher
Example #8
0
    def __init__(self, func, diff_type, input_diff_names, prefix="d", initial_var2diff=None):
        """Propagate symbolic input differences through ``func``.

        Args:
            func: a `primitives.BvFunction` subclass.
            diff_type: a `difference.Difference` subclass.
            input_diff_names: one name per input of ``func``; the i-th input
                difference is ``diff_type(core.Variable(name_i, width_i))``.
            prefix: the ``id_prefix`` passed to ``func.ssa`` for the
                intermediate variables.
            initial_var2diff: optional extra variable -> difference entries
                merged into the propagation map before propagating; none of
                its keys may (as a string) equal an input difference name.

        Sets ``func``, ``diff_type``, ``input_diff``, ``ssa``,
        ``nonlinear_diffs`` (difference -> derivative, in SSA assignment
        order), ``output_diff`` (list of ``[diff_type(var), propagated value]``
        pairs) and the private ``_prefix``, ``_input_diff_names`` and
        ``_var2diff`` attributes.

        Raises:
            ValueError: if an SSA output is a constant, or if
                ``initial_var2diff`` tries to replace an input difference.
            NotImplementedError: for a ``BvAdd`` with key-dependent arguments
                unless exactly one argument depends on the key and that
                argument is itself a round key.
        """
        assert issubclass(func, primitives.BvFunction)
        assert issubclass(diff_type, difference.Difference)

        assert len(input_diff_names) == len(func.input_widths)
        input_diff = []
        for name, width in zip(input_diff_names, func.input_widths):
            input_diff.append(diff_type(core.Variable(name, width)))
        input_diff = tuple(input_diff)

        self.func = func
        self.diff_type = diff_type
        self.input_diff = input_diff

        # Propagate the input difference through the function

        names = [d.val.name for d in self.input_diff]
        ssa = self.func.ssa(names, id_prefix=prefix)
        self.ssa = ssa
        self._prefix = prefix
        self._input_diff_names = input_diff_names

        for var in ssa["output_vars"]:
            if isinstance(var, core.Constant):
                raise ValueError("constant outputs (independent of the inputs) are not supported")

        var2diff = {}  # Variable to Difference
        for var, diff in zip(ssa["input_vars"], self.input_diff):
            var2diff[var] = diff

        if initial_var2diff is not None:
            for var in initial_var2diff:
                if str(var) in names:
                    raise ValueError("the input differences cannot be replaced by initial_var2diff")
            var2diff.update(initial_var2diff)

        # Walk the SSA assignments in order, replacing each argument by its
        # already-propagated difference (or expression).
        self.nonlinear_diffs = collections.OrderedDict()
        for var, expr in ssa["assignments"]:
            expr_args = []
            for arg in expr.args:
                if isinstance(arg, int):
                    expr_args.append(arg)  # 'int' object has no attribute 'xreplace'
                else:
                    expr_args.append(arg.xreplace(var2diff))

            # No argument is a difference: keep the original expression
            # (not the substituted one) as the value of ``var``.
            if all(not isinstance(arg, diff_type) for arg in expr_args):
                # symbolic computations with the key
                var2diff[var] = expr
                continue

            if all(isinstance(arg, diff_type) for arg in expr_args):
                # Every argument is a difference: plain derivative of the op.
                der = self.diff_type.derivative(type(expr), expr_args)
            else:
                # Mixed difference / non-difference arguments.

                # True if any sub-term of ``term`` is one of func.round_keys
                # (the for-else returns False when the loop finishes).
                def contains_key_var(term):
                    from sympy.core import basic
                    for sub in basic.preorder_traversal(term):
                        if sub in func.round_keys:
                            return True
                    else:
                        return False

                # Case 1: BvAdd with a symbolic round-key argument.
                if type(expr) == operation.BvAdd and hasattr(func, 'round_keys') and \
                        all(isinstance(r, core.Variable) for r in func.round_keys) and \
                        any(contains_key_var(a) for a in expr_args):
                    # temporary solution to Derivative(BvAddCte_k(x)) != Derivative(x + k)
                    # with x a Diff and k a key variable
                    keyed_indices = []
                    for i, a in enumerate(expr_args):
                        if contains_key_var(a):
                            keyed_indices.append(i)
                    if len(keyed_indices) != 1 or expr_args[keyed_indices[0]] not in func.round_keys:
                        raise NotImplementedError("invalid expression: op={}, args={}".format(
                            type(expr).__name__, expr_args))
                    # expr_args[keyed_indices[0]] replaced to the zero diff
                    # ((index + 1) % 2 picks the other argument; assumes a
                    # two-argument BvAdd)
                    zero_diff = diff_type(core.Constant(0, expr_args[keyed_indices[0]].width))
                    der = self.diff_type.derivative(type(expr), [expr_args[(keyed_indices[0] + 1) % 2], zero_diff])
                # Case 2: the operation class provides its own derivative;
                # non-difference arguments become the difference of the
                # pair (arg, arg).
                elif hasattr(expr, "xor_derivative"):
                    # temporary solution to operations containing a custom derivative
                    input_diff_expr = []
                    for i, arg in enumerate(expr_args):
                        if isinstance(arg, diff_type):
                            input_diff_expr.append(arg)
                        else:
                            assert isinstance(arg, core.Term)  # int arguments currently not supported
                            input_diff_expr.append(diff_type.from_pair(arg, arg))
                    der = self.diff_type.derivative(type(expr), input_diff_expr)
                # Case 3: fix the non-difference arguments into a partial
                # operation (None marks the difference positions) and take
                # the derivative w.r.t. the remaining difference arguments.
                else:
                    fixed_args = []
                    for i, arg in enumerate(expr_args):
                        if not isinstance(arg, diff_type):
                            fixed_args.append(arg)
                        else:
                            fixed_args.append(None)
                    new_op = extraop.make_partial_operation(type(expr), tuple(fixed_args))
                    der = self.diff_type.derivative(new_op, [arg for arg in expr_args if isinstance(arg, diff_type)])

            # A Derivative result introduces a new symbolic difference for
            # ``var`` and is recorded in nonlinear_diffs; any other result is
            # used directly as the propagated difference.
            if isinstance(der, derivative.Derivative):
                diff = self.diff_type(var)
                var2diff[var] = diff
                self.nonlinear_diffs[diff] = der
            else:
                var2diff[var] = der

        self._var2diff = var2diff

        # Pair each output variable (as a difference) with its propagated value.
        self.output_diff = []
        for var in ssa["output_vars"]:
            self.output_diff.append([self.diff_type(var), var2diff[var]])
Example #9
0
    def ssa(cls, input_names, id_prefix):
        """Return a static single assignment program representing the function.

        Args:
            input_names: the names  for the input variables
            id_prefix: the prefix to denote the intermediate variables

        Return:
            : a dictionary with three keys

            - *input_vars*: a list of `Variable` representing the inputs
            - *output_vars*: a list of `Variable` representing the outputs
            - *assignments*: an ordered sequence of pairs
              (`Variable`, `Operation`) representing each assignment
              of the SSA program.

        Raises:
            ValueError: if an assignment has a nested (non-decomposed)
                `Operation` as an argument.

        Assignments that do not contribute to any output are removed, and a
        warning is emitted listing them.

        ::

                >>> from arxpy.primitives.chaskey import ChaskeyPi
                >>> ChaskeyPi.set_rounds(1)
                >>> ChaskeyPi.ssa(["v0", "v1", "v2", "v3"], "x")  # doctest: +NORMALIZE_WHITESPACE
                {'input_vars': (v0, v1, v2, v3),
                'output_vars': (x7, x12, x13, x9),
                'assignments': ((x0, v0 + v1), (x1, v1 <<< 5), (x2, x0 ^ x1), (x3, x0 <<< 16), (x4, v2 + v3),
                (x5, v3 <<< 8), (x6, x4 ^ x5), (x7, x3 + x6), (x8, x6 <<< 13), (x9, x7 ^ x8), (x10, x2 + x4),
                (x11, x2 <<< 7), (x12, x10 ^ x11), (x13, x10 <<< 16))}

        """
        input_vars = tuple(
            core.Variable(name, width)
            for name, width in zip(input_names, cls.input_widths))

        table = context.MemoizationTable(id_prefix=id_prefix)

        # Evaluating the function under memoization records every
        # intermediate operation as an assignment in ``table``.
        with context.Memoization(table):
            # noinspection PyArgumentList
            output_vars = cls(*input_vars, symbolic_inputs=True)

        ssa_dict = {
            "input_vars": input_vars,
            "output_vars": output_vars,
            "assignments": tuple(table.items())
        }

        for var, expr in ssa_dict["assignments"]:
            for arg in expr.args:
                if isinstance(arg, operation.Operation):
                    raise ValueError(
                        "assignment {} <- {} was not decomposed".format(
                            var, expr))

        # Backward liveness pass: an assignment is needed if its variable is
        # an output or is read by a needed assignment.
        vars_needed = set(output_vars)
        to_delete = []
        for var, expr in reversed(ssa_dict["assignments"]):
            if var in vars_needed:
                vars_needed.update(expr.atoms(core.Variable))
            else:
                to_delete.append((var, expr))

        if to_delete:
            import warnings
            warnings.warn("removing redundant assignments {} in \n{}".format(
                to_delete, ssa_dict))
            # The assigned variables are the (unique) keys of the memoization
            # table, so filtering by variable removes exactly the redundant
            # assignments in one order-preserving pass (repeated list.remove
            # was quadratic).
            deleted_vars = {var for var, _ in to_delete}
            ssa_dict["assignments"] = tuple(
                assignment for assignment in ssa_dict["assignments"]
                if assignment[0] not in deleted_vars)

        return ssa_dict
Example #10
0
def _fast_empirical_weight_distribution(ch_found,
                                        cipher,
                                        rk_dict_diffs=None,
                                        verbose_lvl=0,
                                        debug=False,
                                        filename=None,
                                        precision=0):
    """Compute the empirical weight distribution of a found characteristic.

    The nonlinear derivatives of the characteristic are greedily grouped
    into consecutive sub-characteristics whose exact weight does not exceed
    ``max_subch_weight``. For each of ``KEY_SAMPLES`` random master keys, the
    round keys (and, when ``rk_dict_diffs`` is given, the related round keys
    obtained from its output differences) are substituted into the SSA of
    each sub-characteristic. The result is turned into C code with
    ``ssa2ccode``/``relatedssa2ccode``, and its empirical weight is computed
    in a worker process by ``compile_run_empirical_weight``. The per-key
    empirical weight is the sum of the sub-characteristic weights.

    Args:
        ch_found: the characteristic found (uses its ``ch``, ``input_diff``,
            ``nonlinear_diffs``, ``output_diff``, ``_diff_model`` and
            ``get_exact_weight()``).
        cipher: the cipher whose ``key_schedule`` is sampled.
        rk_dict_diffs: optional dict with keys ``"nonlinear_diffs"`` and
            ``"output_diff"`` giving the round-key differences.
        verbose_lvl: >= 2 prints per-sub-characteristic progress, >= 3 also
            prints the decomposition and the sub-characteristic SSAs.
        debug: print the characteristics and the generated C code.
        filename: forwarded to ``_get_smart_print``.
        precision: if 0, each per-key weight is truncated with ``int``
            (``math.inf`` is kept); otherwise it is rounded to ``precision``
            digits.

    Returns:
        collections.Counter: maps each per-key empirical weight (after
        truncation/rounding) to the number of sampled keys with that weight.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_differential import SearchSkCh
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_differential import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkCh(ch)
        >>> ch_found = search_problem.solve(0)
        >>> _fast_empirical_weight_distribution(ch_found, Speck32)
        Counter({0: 256})

    """
    # similar to _empirical_distribution_weight of characteristic module

    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    exact_weight = ch_found.get_exact_weight()

    if rk_dict_diffs is not None:
        assert "nonlinear_diffs" in rk_dict_diffs and "output_diff" in rk_dict_diffs

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("Characteristic found:")
        smart_print(ch_found)
        if rk_dict_diffs is not None:
            smart_print("rk_dict_diffs:", rk_dict_diffs)
        smart_print()

    # Exact weight of each nonlinear derivative, with its input differences
    # instantiated from ch_found._diff_model and its output difference taken
    # from the found characteristic.
    # NOTE(review): assumes ch_found.nonlinear_diffs is in the same order as
    # ch_found.ch.nonlinear_diffs — confirm.
    der_weights = []
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        actual_diff = ch_found.nonlinear_diffs[i][1]
        new_input_diff = [(d.xreplace(ch_found._diff_model))
                          for d in der.input_diff]
        der_weights.append(
            der._replace_input_diff(new_input_diff).exact_weight(actual_diff))

    # Upper bound on the exact weight of a sub-characteristic. In the else
    # branch the expression algebraically equals MAX_WEIGHT (up to float
    # rounding). The bound is then raised to at least 1 and to the heaviest
    # single derivative, so every derivative fits in some sub-characteristic.
    max_subch_weight = exact_weight if exact_weight < MAX_WEIGHT else exact_weight / (
        exact_weight / MAX_WEIGHT)
    max_subch_weight = max(1, max_subch_weight, *der_weights)
    if debug:
        smart_print("max_subch_weight:", max_subch_weight)
        smart_print()

    # Greedy split: walk the derivatives in order and start a new
    # sub-characteristic whenever adding the next one would exceed the bound.
    subch_listdiffder = [[]]  # for each subch, a list of [diff, der] pairs
    subch_index = 0
    current_subch_weight = 0  # exact_weight
    subch_weight = []  # the weight of each subch
    assert len(ch_found.ch.nonlinear_diffs.items()) > 0
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        der_weight = der_weights[i]
        if current_subch_weight + der_weight > max_subch_weight:
            subch_weight.append(current_subch_weight)
            current_subch_weight = 0
            subch_index += 1
            subch_listdiffder.append([])
        current_subch_weight += der_weight
        subch_listdiffder[subch_index].append([diff, der])

    subch_weight.append(current_subch_weight)
    assert len(subch_weight) == len(subch_listdiffder)

    num_subch = len(subch_listdiffder)

    if verbose_lvl >= 3:
        smart_print(
            "- characteristic decomposed into {} subcharacteristics with exact weights {}"
            .format(num_subch, subch_weight))

    # Symbolic round-key variables appearing in the SSA: taken from
    # rk_dict_diffs when given, otherwise rebuilt as k0, k1, ... (presumably
    # matching the names used when the characteristic was built — confirm).
    if rk_dict_diffs is not None:
        rk_var = [var.val for var, _ in rk_dict_diffs["output_diff"]]
    else:
        rk_var = []
        for i, width in enumerate(cipher.key_schedule.output_widths):
            rk_var.append(core.Variable("k" + str(i), width))

    # Map every variable to its concrete difference value: first the input,
    # nonlinear and output differences of the found characteristic (and the
    # round-key differences if given), then the remaining variables of
    # _var2diff by substitution, in that dict's iteration order.
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff,
                                                ch_found.nonlinear_diffs,
                                                ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val
    if rk_dict_diffs is not None:
        for var, diff in rk_dict_diffs["output_diff"]:
            var2diffval[var.val] = diff.val
        for var, diff in rk_dict_diffs["nonlinear_diffs"]:
            var2diffval[var.val] = diff.val
    for var, diff in ch_found.ch._var2diff.items():
        if var not in var2diffval:
            if isinstance(diff, core.Term):
                # e.g., symbolic computations with the key
                var2diffval[var] = diff.xreplace(var2diffval)
            else:
                var2diffval[var] = diff.val.xreplace(var2diffval)

    # for each related-key pair, we associated a pair of subch_ssa
    rkey2pair_subchssa = [None for _ in range(KEY_SAMPLES)]
    for key_index in range(KEY_SAMPLES):
        # Random master key -> round keys; the related round keys are the
        # other element of each round-key difference pair (or the same round
        # keys when rk_dict_diffs is None).
        master_key = []
        for width in cipher.key_schedule.input_widths:
            master_key.append(core.Constant(random.randrange(2**width), width))
        rk_val = cipher.key_schedule(*master_key)
        if rk_dict_diffs is not None:
            rk_other_val = tuple([
                d.get_pair_element(r)
                for r, (_, d) in zip(rk_val, rk_dict_diffs["output_diff"])
            ])
        else:
            rk_other_val = rk_val
        assert len(rk_var) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        # Slice the SSA of one sub-characteristic out of ch_found.ch.ssa:
        # assignments from the one defining its first nonlinear difference
        # (or from the start if first_subch) up to, but excluding, the one
        # defining first_der_var_next_subch (to the end if None). Round keys
        # are substituted via var2val; variables read but not assigned in the
        # slice become its inputs. output_vars is filled in by the caller.
        def subch_listdiffder2subch_ssa(listdiffder,
                                        first_der_var_next_subch,
                                        var2val,
                                        first_subch=False):
            first_der_var = listdiffder[0][0].val

            input_vars = []
            inter_vars = set()
            assignments = []
            add_assignment = first_subch
            for var, expr in ch_found.ch.ssa["assignments"]:
                if var == first_der_var:
                    add_assignment = True
                elif var == first_der_var_next_subch:
                    break

                expr = expr.xreplace(var2val)

                if add_assignment:
                    input_vars.extend([
                        atom for atom in expr.atoms(core.Variable)
                        if atom not in input_vars
                    ])
                    inter_vars.add(var)
                    assignments.append([var, expr])

            subch_ssa = {}
            subch_ssa["input_vars"] = [
                var for var in input_vars if var not in inter_vars
            ]
            subch_ssa["output_vars"] = []
            subch_ssa["inter_vars"] = inter_vars
            subch_ssa["assignments"] = assignments
            return subch_ssa

        # index_pair 0 uses rk_val, index_pair 1 uses rk_other_val.
        pair_subchssa = []
        for index_pair in range(2):
            current_rk_val = rk_val if index_pair == 0 else rk_other_val
            rkvar2rkval = {
                var: val
                for var, val in zip(rk_var, current_rk_val)
            }
            # Build the sub-characteristics from last to first so that the
            # outputs of sub-ch i can be set to the inputs of sub-ch i + 1;
            # the last one outputs the characteristic's output variables.
            subch_ssa = [None for _ in range(num_subch)]
            for i in reversed(range(num_subch)):
                if i == num_subch - 1:
                    subch_ssa[i] = subch_listdiffder2subch_ssa(
                        subch_listdiffder[i], None, rkvar2rkval, i == 0)
                    subch_ssa[i]["output_vars"] = list(
                        ch_found.ch.ssa["output_vars"])
                else:
                    first_var_next_ssa = subch_listdiffder[i + 1][0][0].val
                    subch_ssa[i] = subch_listdiffder2subch_ssa(
                        subch_listdiffder[i], first_var_next_ssa, rkvar2rkval,
                        i == 0)
                    subch_ssa[i]["output_vars"] = subch_ssa[i +
                                                            1]["input_vars"][:]

                # Outputs that are neither computed nor read in this
                # sub-characteristic are passed through as extra inputs.
                for var in subch_ssa[i]["output_vars"]:
                    if var not in subch_ssa[i][
                            "inter_vars"] and var not in subch_ssa[i][
                                "input_vars"]:
                        subch_ssa[i]["input_vars"].append(var)

                del subch_ssa[i]["inter_vars"]
                subch_ssa[i]["weight"] = subch_weight[i]

            # Make each output variable distinct from the inputs and from the
            # other outputs by copying it into a fresh "<name>_o<k>" variable
            # with the same difference value (presumably needed by the C code
            # generation — confirm).
            for _, ssa in enumerate(subch_ssa):
                for j in range(len(ssa["output_vars"])):
                    var_j = ssa["output_vars"][j]
                    index_out = 0
                    if var_j in ssa["input_vars"]:
                        new_var = type(var_j)(var_j.name + "_o" +
                                              str(index_out), var_j.width)
                        index_out += 1
                        ssa["assignments"].append([new_var, var_j])
                        ssa["output_vars"][j] = new_var
                        var2diffval[new_var] = var2diffval[var_j]

                    for k in range(j + 1, len(ssa["output_vars"])):
                        if var_j == ssa["output_vars"][k]:
                            new_var = type(var_j)(var_j.name + "_o" +
                                                  str(index_out), var_j.width)
                            index_out += 1
                            ssa["assignments"].append([new_var, var_j])
                            ssa["output_vars"][k] = new_var
                            var2diffval[new_var] = var2diffval[var_j]

            pair_subchssa.append(subch_ssa)

        assert len(pair_subchssa[0]) == len(pair_subchssa[1]) == num_subch
        rkey2pair_subchssa[key_index] = pair_subchssa

    # for each related-key pair, we associated the list of the weight of each subch
    rkey2subch_ew = [[0 for _ in range(num_subch)] for _ in range(KEY_SAMPLES)]

    # Sub-characteristics are processed in order. For each one, every key is
    # dispatched to the pool and then all results are awaited, because a key
    # whose previous sub-characteristic weight is math.inf is skipped (its
    # weight set to math.inf). Progress is printed only for key indices 0-1.
    # start multiprocessing
    with multiprocessing.Pool() as pool:
        for i in range(num_subch):
            for key_index in range(KEY_SAMPLES):
                ssa1 = rkey2pair_subchssa[key_index][0][i]
                ssa2 = rkey2pair_subchssa[key_index][1][i]

                if key_index <= 1:
                    if verbose_lvl >= 2 and key_index == 0:
                        smart_print("- sub-characteristic {}".format(i))
                    if verbose_lvl >= 3 and key_index == 0:
                        smart_print("  - listdiffder:", subch_listdiffder[i])
                    if verbose_lvl >= 3:
                        smart_print("  - related-key pair index", key_index)
                        smart_print("  - ssa1:", ssa1)
                        if ssa1 == ssa2:
                            smart_print("  - ssa2: (same as ssa1)")
                        else:
                            smart_print("  - ssa2:", ssa2)

                if i > 0 and rkey2subch_ew[key_index][i - 1] == math.inf:
                    rkey2subch_ew[key_index][i] = math.inf
                    if key_index <= 1 and verbose_lvl >= 2:
                        smart_print(
                            "  - rk{} | skipping since invalid sub-ch[{}]".
                            format(key_index, i - 1))
                    continue

                # Identical SSAs (same round keys in both elements of the
                # pair) use the single-key generator.
                if ssa1 == ssa2:
                    ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
                else:
                    ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)

                if key_index <= 1 and debug:
                    smart_print(ccode[0])
                    smart_print(ccode[1])
                    smart_print()

                # Concrete input/output differences of this sub-characteristic.
                input_diff_c = [
                    v.xreplace(var2diffval) for v in ssa1["input_vars"]
                ]
                output_diff_c = [
                    v.xreplace(var2diffval) for v in ssa1["output_vars"]
                ]

                if key_index <= 1 and verbose_lvl >= 2:
                    smart_print(
                        "  - rk{} | checking {} -> {} with weight {}".format(
                            key_index, '|'.join([str(d)
                                                 for d in input_diff_c]),
                            '|'.join([str(d) for d in output_diff_c]),
                            ssa1["weight"]))

                assert ssa1["weight"] == ssa2["weight"]

                assert all(
                    isinstance(d, (int, core.Constant))
                    for d in input_diff_c), "{}".format(input_diff_c)
                assert all(
                    isinstance(d, (int, core.Constant))
                    for d in output_diff_c), "{}".format(output_diff_c)

                input_diff_c = [int(d) for d in input_diff_c]
                output_diff_c = [int(d) for d in output_diff_c]

                # Stores an AsyncResult, replaced by its value below.
                rkey2subch_ew[key_index][i] = pool.apply_async(
                    compile_run_empirical_weight,
                    (ccode, "_libver" + ch_found.ch.func.__name__ + str(i),
                     input_diff_c, output_diff_c, ssa1["weight"], False))

            # wait until all i-subch have been compiled and run
            # and replace the Async object by the result
            for key_index in range(KEY_SAMPLES):
                if isinstance(rkey2subch_ew[key_index][i],
                              multiprocessing.pool.AsyncResult):
                    rkey2subch_ew[key_index][i] = rkey2subch_ew[key_index][
                        i].get()

                if key_index <= 1 and verbose_lvl >= 2:
                    smart_print(
                        "  - rk{} | exact/empirical weight: {}, {}".format(
                            key_index, subch_weight[i],
                            rkey2subch_ew[key_index][i]))

    # end multiprocessing

    # Per-key weight = sum of its sub-characteristic weights; count keys
    # per (truncated or rounded) weight.
    empirical_weight_distribution = collections.Counter()
    all_rkey_weights = []
    for key_index in range(KEY_SAMPLES):
        rkey_weight = sum(rkey2subch_ew[key_index])

        if precision == 0:
            weight = int(rkey_weight) if rkey_weight != math.inf else math.inf
        else:
            weight = round(rkey_weight, precision)

        all_rkey_weights.append(rkey_weight)
        empirical_weight_distribution[weight] += 1

    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(
            empirical_weight_distribution))

    if verbose_lvl >= 3:
        smart_print("- list empirical weights:",
                    [round(x, 8) for x in all_rkey_weights if x != math.inf])

    return empirical_weight_distribution
Example #11
0
    def _generate(self):
        """Generate the SMT problem.

        Sets ``self.assertions`` (a tuple), ``self.op_weights`` and
        ``self.ch_weight``:

        - with `difference.XorDiff`, forbid the all-zero input difference.
          For the outer characteristic of ``self.parent_ch``, the trailing
          ``len(parent_ch.inner_ch.output_diff)`` input differences are
          excluded.
        - for each `differential.Differential` propagation, assert its
          validity and bind its weight to a fresh weight variable; any other
          propagation is asserted equal to its variable.
        - bind the characteristic weight variable to the (zero-extended) sum
          of the weight variables.
        - compare the characteristic weight with the target weight: equality
          if ``self.equality``, otherwise unsigned strict less-than.
        """
        self.assertions = []

        # Forbid zero input difference with XOR difference

        if self.ch.diff_type == difference.XorDiff:
            input_diff = self.ch.input_diff
            if self.parent_ch is not None and self.parent_ch.outer_ch == self.ch:
                inner_noutputs = len(self.parent_ch.inner_ch.output_diff)
                # Explicit end index: ``input_diff[:-n]`` is empty (instead
                # of the whole sequence) when ``n == 0``. Clamped at 0 so that
                # n > len(input_diff) still gives an empty slice as before.
                end = max(len(input_diff) - inner_noutputs, 0)
                non_zero_input_diff = input_diff[:end]
            else:
                non_zero_input_diff = input_diff
            non_zero_input_diff = functools.reduce(operation.Concat,
                                                   non_zero_input_diff)
            zero = core.Constant(0, non_zero_input_diff.width)
            self.assertions.append(
                operation.BvNot(operation.BvComp(non_zero_input_diff, zero)))

        # Assertions of the weights of the non-deterministic steps

        self.op_weights = []
        for var, propagation in self.ch.items():
            if isinstance(propagation, differential.Differential):
                self.assertions.append(propagation.is_valid())
                weight_value = propagation.weight()
                weight_var = core.Variable(propagation._weight_var_name(),
                                           weight_value.width)
                self.assertions.append(
                    operation.BvComp(weight_var, weight_value))
                self.op_weights.append(weight_var)
            else:
                self.assertions.append(operation.BvComp(var, propagation))

        # Characteristic weight assignment

        # Width large enough to hold the sum of every weight at its maximum.
        max_value = sum((2**ow.width) - 1 for ow in self.op_weights)
        width = max(max_value.bit_length(), 1)  # for trivial characteristic
        ext_op_weights = [operation.ZeroExtend(ow, width - ow.width)
                          for ow in self.op_weights]

        name_ch_weight = "w_{}_{}".format(
            ''.join([str(i) for i in self.ch.input_diff]),
            ''.join([str(i) for i in self.ch.output_diff]))
        ch_weight = core.Variable(name_ch_weight, width)

        self.assertions.append(operation.BvComp(ch_weight,
                                                sum(ext_op_weights)))

        # Condition between the weight and the target weight

        weight_function = self.ch.get_weight_function()
        target_weight = int(weight_function(self.target_weight))

        # Extend the characteristic weight so the target weight fits.
        width = max(ch_weight.width, target_weight.bit_length())
        self.ch_weight = operation.ZeroExtend(ch_weight,
                                              width - ch_weight.width)

        if self.equality:
            self.assertions.append(
                operation.BvComp(self.ch_weight, target_weight))
        else:
            self.assertions.append(
                operation.BvUlt(self.ch_weight, target_weight))

        self.assertions = tuple(self.assertions)
Example #12
0
 def to_Variable(self):
     """Return a `core.Variable` with the same name and width as this DiffVar."""
     name, width = self.name, self.width
     return core.Variable(name, width)
Example #13
0
def _replace_ssa_roundkeys(ssa, var2val):
    """Return a copy of ``ssa`` with ``var2val`` substituted in every assignment.

    The copy is shallow, except that the ``"assignments"`` and ``"output_vars"``
    lists are fresh lists, so the caller may mutate them without affecting ``ssa``.
    """
    new_ssa = ssa.copy()
    new_ssa["assignments"] = [(var, expr.xreplace(var2val)) for var, expr in ssa["assignments"]]
    new_ssa["output_vars"] = list(ssa["output_vars"])
    return new_ssa


def _alias_output_var(ssa, var2diffval, var, index_out):
    """Return a new variable ``<var.name>_o<index_out>`` that is assigned ``var`` in ``ssa``.

    The assignment ``[new_var, var]`` is appended to ``ssa["assignments"]`` and
    ``new_var`` is mapped in ``var2diffval`` to the same difference value as ``var``.
    """
    new_var = type(var)(var.name + "_o" + str(index_out), var.width)
    ssa["assignments"].append([new_var, var])
    var2diffval[new_var] = var2diffval[var]
    return new_var


def _rename_repeated_output_vars(ssa, var2diffval):
    """Replace output variables of ``ssa`` that are inputs or repeated, in place.

    For the output at position ``j``:

    - if it is also an input variable, it is replaced by a fresh alias;
    - every later occurrence of the same variable is also replaced by a fresh alias.

    The aliases of position ``j`` are numbered ``_o0``, ``_o1``, ... (see
    ``_alias_output_var``). Presumably this is done so that the C translation
    sees distinct output variables. NOTE(review): confirm against ``ssa2ccode``.
    """
    output_vars = ssa["output_vars"]
    # enumerate reads the list lazily, so replacements made at positions k > j
    # are seen when the loop reaches k (same as indexing the list directly).
    for j, var_j in enumerate(output_vars):
        index_out = 0
        if var_j in ssa["input_vars"]:
            output_vars[j] = _alias_output_var(ssa, var2diffval, var_j, index_out)
            index_out += 1

        for k in range(j + 1, len(output_vars)):
            if var_j == output_vars[k]:
                output_vars[k] = _alias_output_var(ssa, var2diffval, var_j, index_out)
                index_out += 1


def _fast_empirical_weight_distribution(ch_found, cipher, rk_dict_diffs=None,
                                        verbose_lvl=0, debug=False, filename=None, precision=0):
    """Compute the empirical weight distribution of an impossible differential.

    For each of ``KEY_SAMPLES`` random master keys, the round keys produced by
    ``cipher.key_schedule`` are substituted for the round-key variables
    ``k0, k1, ...`` in the SSA of ``ch_found.ch``. The resulting SSA is translated
    to C code (``ssa2ccode``, or ``relatedssa2ccode`` when the pair of SSAs
    differs). The empirical weight of ``input_diff -> output_diff`` is then
    computed by ``compile_run_empirical_weight`` with ``MAX_WEIGHT`` (reported as
    ``2**MAX_WEIGHT`` pairs in the verbose output). Key samples are processed in
    parallel with a ``multiprocessing.Pool``.

    Args:
        ch_found: the characteristic found by the search (e.g. the result of
            ``SearchSkID.solve``). Its ``ch`` attribute (symbolic characteristic
            with ``ssa``, ``diff_type`` and ``func``), ``input_diff`` and
            ``output_diff`` are read.
        cipher: the cipher whose ``key_schedule`` generates the round keys.
        rk_dict_diffs: must be ``None``; round-key differences are not supported.
        verbose_lvl: ``>= 2`` prints per-key information for the first two key
            samples and the final distribution; ``>= 3`` also prints their SSAs
            and the list of finite empirical weights.
        debug: if true, prints the characteristic and the C code of the first
            two key samples.
        filename: forwarded to ``_get_smart_print`` to choose where output goes.
        precision: ``0`` truncates finite weights to ``int`` (``inf`` is kept);
            otherwise weights are rounded to this many decimal digits.

    Returns:
        collections.Counter: maps each (rounded) empirical weight to the number
        of key samples that produced it.

    Raises:
        ValueError: if ``rk_dict_diffs`` is not ``None``.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_impossible import SearchSkID
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_impossible import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkID(ch)
        >>> id_found = search_problem.solve(2)
        >>> _fast_empirical_weight_distribution(id_found, Speck32)
        Counter({inf: 256})

    """
    if rk_dict_diffs is not None:
        raise ValueError("rk_dict_diffs must be None")

    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("ID found:")
        smart_print(ch_found)
        smart_print()

    # Round-key variables k0, k1, ... replaced in the SSA by the sampled round keys.
    rk_var = [core.Variable("k" + str(i), width)
              for i, width in enumerate(cipher.key_schedule.output_widths)]

    # Concrete value of every input/output difference variable.
    var2diffval = {
        diff_var.val: diff_value.val
        for diff_var, diff_value in itertools.chain(ch_found.input_diff, ch_found.output_diff)
    }

    # For each key sample, we associate a pair of SSAs (one per element of the pair).
    rkey2pair_ssa = []
    for _ in range(KEY_SAMPLES):
        master_key = [core.Constant(random.randrange(2 ** width), width)
                      for width in cipher.key_schedule.input_widths]
        rk_val = cipher.key_schedule(*master_key)
        # Round-key differences are not supported (rk_dict_diffs must be None),
        # so both elements of the pair use the same round keys.
        rk_other_val = rk_val
        assert len(rk_var) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        pair_ssa = []
        for current_rk_val in (rk_val, rk_other_val):
            ssa = _replace_ssa_roundkeys(ch_found.ch.ssa, dict(zip(rk_var, current_rk_val)))
            _rename_repeated_output_vars(ssa, var2diffval)
            pair_ssa.append(ssa)

        rkey2pair_ssa.append(pair_ssa)

    # For each key sample, we associate its empirical weight
    # (an AsyncResult until the results are collected below).
    rkey2subch_ew = [0] * KEY_SAMPLES

    # start multiprocessing
    with multiprocessing.Pool() as pool:
        for key_index, (ssa1, ssa2) in enumerate(rkey2pair_ssa):
            # Verbose/debug output is limited to the first two key samples.
            show_details = key_index <= 1

            if show_details and verbose_lvl >= 3:
                smart_print("  - related-key pair index", key_index)
                smart_print("  - ssa1:", ssa1)
                if ssa1 == ssa2:
                    smart_print("  - ssa2: (same as ssa1)")
                else:
                    smart_print("  - ssa2:", ssa2)

            if ssa1 == ssa2:
                ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
            else:
                ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)

            if show_details and debug:
                smart_print(ccode[0])
                smart_print(ccode[1])
                smart_print()

            input_diff_c = [v.xreplace(var2diffval) for v in ssa1["input_vars"]]
            output_diff_c = [v.xreplace(var2diffval) for v in ssa1["output_vars"]]

            if show_details and verbose_lvl >= 2:
                smart_print("  - rk{} | checking {} -> {} with pairs 2**{}".format(
                    key_index,
                    '|'.join([str(d) for d in input_diff_c]), '|'.join([str(d) for d in output_diff_c]),
                    MAX_WEIGHT))

            # After substitution every difference must be a concrete value.
            assert all(isinstance(d, (int, core.Constant)) for d in input_diff_c), "{}".format(input_diff_c)
            assert all(isinstance(d, (int, core.Constant)) for d in output_diff_c), "{}".format(output_diff_c)

            input_diff_c = [int(d) for d in input_diff_c]
            output_diff_c = [int(d) for d in output_diff_c]

            rkey2subch_ew[key_index] = pool.apply_async(
                compile_run_empirical_weight,
                (
                    ccode,
                    "_libver" + ch_found.ch.func.__name__,
                    input_diff_c,
                    output_diff_c,
                    MAX_WEIGHT,
                    False
                )
            )

        # Wait until every key sample has been compiled and run,
        # and replace each AsyncResult by its result.
        for key_index in range(KEY_SAMPLES):
            if isinstance(rkey2subch_ew[key_index], multiprocessing.pool.AsyncResult):
                rkey2subch_ew[key_index] = rkey2subch_ew[key_index].get()

            if key_index <= 1 and verbose_lvl >= 2:
                smart_print("  - rk{} | empirical weight: {}".format(
                    key_index, rkey2subch_ew[key_index]))

    # end multiprocessing

    empirical_weight_distribution = collections.Counter()
    all_rkey_weights = []
    for rkey_weight in rkey2subch_ew:
        if precision == 0:
            # Truncate finite weights; int(math.inf) would raise OverflowError.
            weight = int(rkey_weight) if rkey_weight != math.inf else math.inf
        else:
            weight = round(rkey_weight, precision)

        all_rkey_weights.append(rkey_weight)
        empirical_weight_distribution[weight] += 1

    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(empirical_weight_distribution))

    if verbose_lvl >= 3:
        smart_print("- list empirical weights:", [round(x, 8) for x in all_rkey_weights if x != math.inf])

    return empirical_weight_distribution