Beispiel #1
0
    def test_verbose_search_RkCh(self):
        """Run related-key characteristic searches at verbose levels 1 and 2.

        For every verbose level a ``BlockCipher(LeaCipher, None, 1)`` helper
        is built and ``test_search_related_key_ch`` is invoked once per
        search mode; all output goes through the smart printer bound to
        ``self.filename``.
        """
        smart_print = _get_smart_print(self.filename)
        diff_type = XorDiff
        search_modes = (
            RkChSearchMode.FirstMinSum,
            RkChSearchMode.OptimalMinSumDifferential,
        )

        for verbose_level in range(1, 3):
            smart_print("VERBOSE LEVEL:", verbose_level)

            bc = BlockCipher(LeaCipher, None, 1)
            for option in search_modes:
                test_search_related_key_ch(
                    cipher=bc.cipher,
                    diff_type=diff_type,
                    initial_ew=3,
                    initial_kw=0,
                    solver_name="btor",
                    rounds=bc.rk_rounds,
                    key_der_mode=DerMode.XDCA_Approx,
                    enc_der_mode=DerMode.Default,
                    allow_zero_enc_input_diff=True,
                    search_mode=option,
                    # modes listed in NoCheckModes are run without checking
                    check=option not in NoCheckModes,
                    verbose_level=verbose_level,
                    filename=self.filename)
                smart_print("\n~~\n")

            smart_print("\n\n-----\n\n")
Beispiel #2
0
    def test_verbose_search_Ch(self):
        """Run characteristic searches at verbose levels 1 to 3.

        For every verbose level a ``BvFunction(LeaCipher.key_schedule, 1)``
        helper is built and ``test_search_ch_skch`` is invoked once per
        search mode; all output goes through the smart printer bound to
        ``self.filename``.
        """
        smart_print = _get_smart_print(self.filename)
        diff_type = XorDiff
        search_modes = (
            ChSearchMode.FirstCh,
            ChSearchMode.OptimalDifferential,
            ChSearchMode.TopDifferentials,
        )

        for verbose_level in range(1, 4):
            smart_print("VERBOSE LEVEL:", verbose_level)

            bvf = BvFunction(LeaCipher.key_schedule, 1)
            for option in search_modes:
                test_search_ch_skch(
                    bvf_cipher=bvf.function,
                    diff_type=diff_type,
                    initial_weight=3,
                    solver_name="btor",
                    rounds=bvf.rounds,
                    der_mode=DerMode.XDCA_Approx,
                    search_mode=option,
                    # modes listed in NoCheckModes are run without checking
                    check=option not in NoCheckModes,
                    verbose_level=verbose_level,
                    filename=self.filename)

                smart_print("\n~~\n")

            smart_print("\n\n-----\n\n")
Beispiel #3
0
def test_search_ch_skch(bvf_cipher, diff_type, initial_weight, solver_name,
                        rounds, der_mode, search_mode, check, verbose_level,
                        filename):
    """Build and solve a characteristic search problem, printing progress.

    ``bvf_cipher`` must be a subclass of ``BvFunction`` (searched with
    ``BvCharacteristic`` / ``SearchCh``) or of ``Cipher`` (searched with
    ``SingleKeyCh`` / ``SearchSkCh``). If ``rounds`` is not ``None`` it is
    applied with ``set_rounds`` first.

    Output is written through the smart printer bound to ``filename``:
    level 1 prints the summary lines and the solution, level 2 adds the
    characteristic, the problem representation and ``vrepr()`` of the
    solution, level 3 requests the full problem representation.

    Returns whatever ``problem.solve`` returns (``None`` is reported as
    unsatisfiable).
    """
    smart_print = _get_smart_print(filename)

    if rounds is not None:
        bvf_cipher.set_rounds(rounds)

    is_bv_function = issubclass(bvf_cipher, BvFunction)

    if is_bv_function:
        input_diff_names = [
            "dp" + str(index) for index in range(len(bvf_cipher.input_widths))
        ]
        ch = BvCharacteristic(bvf_cipher, diff_type, input_diff_names)
    else:
        assert issubclass(bvf_cipher, Cipher)
        ch = SingleKeyCh(bvf_cipher, diff_type)

    if verbose_level >= 1:
        if rounds is None:
            str_rounds = ""
        else:
            str_rounds = "{} rounds".format(rounds)
        smart_print(str_rounds, bvf_cipher.__name__, diff_type.__name__,
                    type(ch).__name__)
        if verbose_level >= 2:
            smart_print("Characteristic:")
            smart_print(ch)

    problem_cls = SearchCh if is_bv_function else SearchSkCh
    problem = problem_cls(ch, der_mode=der_mode)

    if verbose_level >= 1:
        smart_print(
            type(problem).__name__, der_mode, search_mode, solver_name,
            "size:", problem.formula_size())
        if verbose_level >= 2:
            smart_print(problem.hrepr(verbose_level >= 3))

    sol = problem.solve(initial_weight,
                        solver_name=solver_name,
                        search_mode=search_mode,
                        check=check,
                        verbose_level=verbose_level,
                        filename=filename)

    if verbose_level >= 1:
        if sol is None:
            smart_print("\nUnsatisfiable")
        else:
            smart_print("\nSolution:")
            smart_print(sol)
            if verbose_level >= 2:
                # search_mode TopDifferentials yields a sequence of
                # solutions; only the first one is printed in detail
                detailed = sol[0] if isinstance(
                    sol, collections.abc.Sequence) else sol
                smart_print(detailed.vrepr())
        smart_print()

    return sol
Beispiel #4
0
def test_search_related_key_ch(cipher, diff_type, initial_ew, initial_kw,
                               solver_name, rounds, key_der_mode, enc_der_mode,
                               allow_zero_enc_input_diff, search_mode, check,
                               verbose_level, filename):
    """Build and solve a related-key characteristic search, printing progress.

    If ``rounds`` is not ``None`` it is applied to ``cipher`` with
    ``set_rounds``. A ``RelatedKeyCh`` is built for ``cipher`` and
    ``diff_type`` and wrapped in a ``SearchRkCh`` problem, which is solved
    starting from the weights ``initial_ew`` and ``initial_kw``.

    Output is written through the smart printer bound to ``filename``:
    level 1 prints the summary lines and the solution, level 2 adds the
    characteristic, the problem representation and ``vrepr()`` of the
    solution, level 3 requests the full problem representation.

    Returns whatever ``problem.solve`` returns (``None`` is reported as
    unsatisfiable).
    """
    # kept from the original (disabled):
    # assert search_mode != RkChSearchMode.AllValid

    smart_print = _get_smart_print(filename)

    if rounds is not None:
        cipher.set_rounds(rounds)

    ch = RelatedKeyCh(cipher, diff_type)

    if verbose_level >= 1:
        smart_print(rounds, "round(s)", cipher.__name__, diff_type.__name__,
                    type(ch).__name__)
        if verbose_level >= 2:
            smart_print("Characteristic:")
            smart_print(ch)

    problem = SearchRkCh(
        rkch=ch,
        key_der_mode=key_der_mode,
        enc_der_mode=enc_der_mode,
        allow_zero_enc_input_diff=allow_zero_enc_input_diff,
    )

    if verbose_level >= 1:
        smart_print(
            type(problem).__name__, "key/enc mode:", key_der_mode,
            enc_der_mode, search_mode, solver_name, "size:",
            problem.formula_size())
        if verbose_level >= 2:
            smart_print(problem.hrepr(verbose_level >= 3))

    solve_kwargs = dict(
        initial_ew=initial_ew,
        initial_kw=initial_kw,
        solver_name=solver_name,
        search_mode=search_mode,
        check=check,
        verbose_level=verbose_level,
        filename=filename,
    )
    sol = problem.solve(**solve_kwargs)

    if verbose_level >= 1:
        if sol is not None:
            smart_print("\nSolution:")
            smart_print(sol)
            if verbose_level >= 2:
                smart_print(sol.vrepr())
        else:
            smart_print("\nUnsatisfiable")
        smart_print()

    return sol
def _fast_empirical_weight_distribution(ch_found,
                                        cipher,
                                        rk_dict_diffs=None,
                                        verbose_lvl=0,
                                        debug=False,
                                        filename=None,
                                        precision=0):
    """Return the distribution of empirical weights over random master keys.

    The characteristic ``ch_found`` is split into consecutive
    sub-characteristics whose exact weight is bounded. For each of
    ``KEY_SAMPLES`` random master keys the round keys are computed with
    ``cipher.key_schedule`` and substituted into the SSA of every
    sub-characteristic; C code is generated for it and run through
    ``compile_run_empirical_weight`` in a ``multiprocessing.Pool``.
    The per-key empirical weight is the sum of the weights of its
    sub-characteristics.

    Args:
        ch_found: the characteristic found by a search problem. Its
            symbolic characteristic is read from ``ch_found.ch``.
        cipher: the cipher class; only ``cipher.key_schedule`` is used.
        rk_dict_diffs: optional dictionary with the keys
            ``"nonlinear_diffs"`` and ``"output_diff"`` holding
            (variable, difference) pairs of the key schedule. When given,
            the second round key of each pair is derived from the first
            one with ``get_pair_element`` (related-key setting); otherwise
            both elements of the pair use the same round keys.
        verbose_lvl: verbosity of the printed output (``2`` and ``3``
            enable progressively more output).
        debug: if true, additionally prints the characteristics and
            the generated C code.
        filename: passed to ``_get_smart_print`` to select where the
            output is printed.
        precision: if ``0``, finite weights are truncated with ``int``;
            otherwise they are rounded to ``precision`` digits.

    Returns:
        A ``collections.Counter`` mapping each empirical weight to the
        number of sampled keys that produced it.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_differential import SearchSkCh
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_differential import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkCh(ch)
        >>> ch_found = search_problem.solve(0)
        >>> _fast_empirical_weight_distribution(ch_found, Speck32)
        Counter({0: 256})

    """
    # similar to _empirical_distribution_weight of characteristic module

    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    exact_weight = ch_found.get_exact_weight()

    if rk_dict_diffs is not None:
        assert "nonlinear_diffs" in rk_dict_diffs and "output_diff" in rk_dict_diffs

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("Characteristic found:")
        smart_print(ch_found)
        if rk_dict_diffs is not None:
            smart_print("rk_dict_diffs:", rk_dict_diffs)
        smart_print()

    # exact weight of every nonlinear derivative, evaluated on the
    # concrete differences of the characteristic found
    der_weights = []
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        actual_diff = ch_found.nonlinear_diffs[i][1]
        new_input_diff = [(d.xreplace(ch_found._diff_model))
                          for d in der.input_diff]
        der_weights.append(
            der._replace_input_diff(new_input_diff).exact_weight(actual_diff))

    # NOTE(review): for exact_weight >= MAX_WEIGHT the second branch is
    # algebraically MAX_WEIGHT (up to float rounding) -- confirm whether a
    # rounded number of chunks was intended
    max_subch_weight = exact_weight if exact_weight < MAX_WEIGHT else exact_weight / (
        exact_weight / MAX_WEIGHT)
    # a sub-characteristic must be able to hold at least one derivative
    max_subch_weight = max(1, max_subch_weight, *der_weights)
    if debug:
        smart_print("max_subch_weight:", max_subch_weight)
        smart_print()

    # greedily group consecutive derivatives into sub-characteristics so
    # that each group's exact weight does not exceed max_subch_weight
    subch_listdiffder = [[]]  # for each subch, a list of [diff, der] pairs
    subch_index = 0
    current_subch_weight = 0  # exact_weight
    subch_weight = []  # the weight of each subch
    assert len(ch_found.ch.nonlinear_diffs.items()) > 0
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        der_weight = der_weights[i]
        if current_subch_weight + der_weight > max_subch_weight:
            subch_weight.append(current_subch_weight)
            current_subch_weight = 0
            subch_index += 1
            subch_listdiffder.append([])
        current_subch_weight += der_weight
        subch_listdiffder[subch_index].append([diff, der])

    subch_weight.append(current_subch_weight)
    assert len(subch_weight) == len(subch_listdiffder)

    num_subch = len(subch_listdiffder)

    if verbose_lvl >= 3:
        smart_print(
            "- characteristic decomposed into {} subcharacteristics with exact weights {}"
            .format(num_subch, subch_weight))

    # symbolic round-key variables that will be replaced by concrete values
    if rk_dict_diffs is not None:
        rk_var = [var.val for var, _ in rk_dict_diffs["output_diff"]]
    else:
        rk_var = []
        for i, width in enumerate(cipher.key_schedule.output_widths):
            rk_var.append(core.Variable("k" + str(i), width))

    # map every variable of the characteristic to the value of its difference
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff,
                                                ch_found.nonlinear_diffs,
                                                ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val
    if rk_dict_diffs is not None:
        for var, diff in rk_dict_diffs["output_diff"]:
            var2diffval[var.val] = diff.val
        for var, diff in rk_dict_diffs["nonlinear_diffs"]:
            var2diffval[var.val] = diff.val
    # remaining variables are obtained by substituting the known values
    # into their symbolic difference
    for var, diff in ch_found.ch._var2diff.items():
        if var not in var2diffval:
            if isinstance(diff, core.Term):
                # e.g., symbolic computations with the key
                var2diffval[var] = diff.xreplace(var2diffval)
            else:
                var2diffval[var] = diff.val.xreplace(var2diffval)

    # for each related-key pair, we associated a pair of subch_ssa
    rkey2pair_subchssa = [None for _ in range(KEY_SAMPLES)]
    for key_index in range(KEY_SAMPLES):
        # sample a random master key and derive the round keys
        master_key = []
        for width in cipher.key_schedule.input_widths:
            master_key.append(core.Constant(random.randrange(2**width), width))
        rk_val = cipher.key_schedule(*master_key)
        if rk_dict_diffs is not None:
            # related-key setting: second round keys follow the key differences
            rk_other_val = tuple([
                d.get_pair_element(r)
                for r, (_, d) in zip(rk_val, rk_dict_diffs["output_diff"])
            ])
        else:
            rk_other_val = rk_val
        assert len(rk_var) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        def subch_listdiffder2subch_ssa(listdiffder,
                                        first_der_var_next_subch,
                                        var2val,
                                        first_subch=False):
            """Extract the SSA of one sub-characteristic.

            Collects the assignments of the full SSA from the first
            derivative variable of ``listdiffder`` (or from the beginning
            if ``first_subch``) up to, but excluding, the assignment of
            ``first_der_var_next_subch``. Every expression has ``var2val``
            substituted into it. ``output_vars`` is left empty and must
            be filled by the caller.
            """
            first_der_var = listdiffder[0][0].val

            input_vars = []
            inter_vars = set()
            assignments = []
            add_assignment = first_subch
            for var, expr in ch_found.ch.ssa["assignments"]:
                if var == first_der_var:
                    add_assignment = True
                elif var == first_der_var_next_subch:
                    break

                expr = expr.xreplace(var2val)

                if add_assignment:
                    input_vars.extend([
                        atom for atom in expr.atoms(core.Variable)
                        if atom not in input_vars
                    ])
                    inter_vars.add(var)
                    assignments.append([var, expr])

            subch_ssa = {}
            # variables used but never assigned within this sub-characteristic
            subch_ssa["input_vars"] = [
                var for var in input_vars if var not in inter_vars
            ]
            subch_ssa["output_vars"] = []
            subch_ssa["inter_vars"] = inter_vars
            subch_ssa["assignments"] = assignments
            return subch_ssa

        pair_subchssa = []
        for index_pair in range(2):
            current_rk_val = rk_val if index_pair == 0 else rk_other_val
            rkvar2rkval = {
                var: val
                for var, val in zip(rk_var, current_rk_val)
            }
            subch_ssa = [None for _ in range(num_subch)]
            # built back to front: the outputs of sub-characteristic i are
            # the inputs of sub-characteristic i + 1
            for i in reversed(range(num_subch)):
                if i == num_subch - 1:
                    subch_ssa[i] = subch_listdiffder2subch_ssa(
                        subch_listdiffder[i], None, rkvar2rkval, i == 0)
                    subch_ssa[i]["output_vars"] = list(
                        ch_found.ch.ssa["output_vars"])
                else:
                    first_var_next_ssa = subch_listdiffder[i + 1][0][0].val
                    subch_ssa[i] = subch_listdiffder2subch_ssa(
                        subch_listdiffder[i], first_var_next_ssa, rkvar2rkval,
                        i == 0)
                    subch_ssa[i]["output_vars"] = subch_ssa[i +
                                                            1]["input_vars"][:]

                # an output that is not computed here must be passed
                # through, so it also becomes an input
                for var in subch_ssa[i]["output_vars"]:
                    if var not in subch_ssa[i][
                            "inter_vars"] and var not in subch_ssa[i][
                                "input_vars"]:
                        subch_ssa[i]["input_vars"].append(var)

                del subch_ssa[i]["inter_vars"]
                subch_ssa[i]["weight"] = subch_weight[i]

            # give a fresh name ("<name>_o<index>") to every output variable
            # that is also an input or that appears more than once in the
            # outputs
            for _, ssa in enumerate(subch_ssa):
                for j in range(len(ssa["output_vars"])):
                    var_j = ssa["output_vars"][j]
                    index_out = 0
                    if var_j in ssa["input_vars"]:
                        new_var = type(var_j)(var_j.name + "_o" +
                                              str(index_out), var_j.width)
                        index_out += 1
                        ssa["assignments"].append([new_var, var_j])
                        ssa["output_vars"][j] = new_var
                        var2diffval[new_var] = var2diffval[var_j]

                    for k in range(j + 1, len(ssa["output_vars"])):
                        if var_j == ssa["output_vars"][k]:
                            new_var = type(var_j)(var_j.name + "_o" +
                                                  str(index_out), var_j.width)
                            index_out += 1
                            ssa["assignments"].append([new_var, var_j])
                            ssa["output_vars"][k] = new_var
                            var2diffval[new_var] = var2diffval[var_j]

            pair_subchssa.append(subch_ssa)

        assert len(pair_subchssa[0]) == len(pair_subchssa[1]) == num_subch
        rkey2pair_subchssa[key_index] = pair_subchssa

    # for each related-key pair, we associated the list of the weight of each subch
    rkey2subch_ew = [[0 for _ in range(num_subch)] for _ in range(KEY_SAMPLES)]

    # start multiprocessing
    with multiprocessing.Pool() as pool:
        # sub-characteristics are processed one index at a time: all keys
        # are scheduled for index i and collected before moving to i + 1
        for i in range(num_subch):
            for key_index in range(KEY_SAMPLES):
                ssa1 = rkey2pair_subchssa[key_index][0][i]
                ssa2 = rkey2pair_subchssa[key_index][1][i]

                # verbose output is limited to the first two key samples
                if key_index <= 1:
                    if verbose_lvl >= 2 and key_index == 0:
                        smart_print("- sub-characteristic {}".format(i))
                    if verbose_lvl >= 3 and key_index == 0:
                        smart_print("  - listdiffder:", subch_listdiffder[i])
                    if verbose_lvl >= 3:
                        smart_print("  - related-key pair index", key_index)
                        smart_print("  - ssa1:", ssa1)
                        if ssa1 == ssa2:
                            smart_print("  - ssa2: (same as ssa1)")
                        else:
                            smart_print("  - ssa2:", ssa2)

                # an infinite weight in the previous sub-characteristic
                # makes the whole characteristic invalid for this key
                if i > 0 and rkey2subch_ew[key_index][i - 1] == math.inf:
                    rkey2subch_ew[key_index][i] = math.inf
                    if key_index <= 1 and verbose_lvl >= 2:
                        smart_print(
                            "  - rk{} | skipping since invalid sub-ch[{}]".
                            format(key_index, i - 1))
                    continue

                # identical SSAs (same round keys) only need the
                # single-function C code
                if ssa1 == ssa2:
                    ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
                else:
                    ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)

                if key_index <= 1 and debug:
                    smart_print(ccode[0])
                    smart_print(ccode[1])
                    smart_print()

                input_diff_c = [
                    v.xreplace(var2diffval) for v in ssa1["input_vars"]
                ]
                output_diff_c = [
                    v.xreplace(var2diffval) for v in ssa1["output_vars"]
                ]

                if key_index <= 1 and verbose_lvl >= 2:
                    smart_print(
                        "  - rk{} | checking {} -> {} with weight {}".format(
                            key_index, '|'.join([str(d)
                                                 for d in input_diff_c]),
                            '|'.join([str(d) for d in output_diff_c]),
                            ssa1["weight"]))

                assert ssa1["weight"] == ssa2["weight"]

                # all differences must be fully evaluated before being
                # converted to plain integers
                assert all(
                    isinstance(d, (int, core.Constant))
                    for d in input_diff_c), "{}".format(input_diff_c)
                assert all(
                    isinstance(d, (int, core.Constant))
                    for d in output_diff_c), "{}".format(output_diff_c)

                input_diff_c = [int(d) for d in input_diff_c]
                output_diff_c = [int(d) for d in output_diff_c]

                # schedule the job; the AsyncResult is stored temporarily
                # in place of the weight
                rkey2subch_ew[key_index][i] = pool.apply_async(
                    compile_run_empirical_weight,
                    (ccode, "_libver" + ch_found.ch.func.__name__ + str(i),
                     input_diff_c, output_diff_c, ssa1["weight"], False))

            # wait until all i-subch have been compiled and run
            # and replace the Async object by the result
            for key_index in range(KEY_SAMPLES):
                if isinstance(rkey2subch_ew[key_index][i],
                              multiprocessing.pool.AsyncResult):
                    rkey2subch_ew[key_index][i] = rkey2subch_ew[key_index][
                        i].get()

                if key_index <= 1 and verbose_lvl >= 2:
                    smart_print(
                        "  - rk{} | exact/empirical weight: {}, {}".format(
                            key_index, subch_weight[i],
                            rkey2subch_ew[key_index][i]))

    # end multiprocessing

    # aggregate: the weight of a key is the sum over its sub-characteristics
    empirical_weight_distribution = collections.Counter()
    all_rkey_weights = []
    for key_index in range(KEY_SAMPLES):
        rkey_weight = sum(rkey2subch_ew[key_index])

        if precision == 0:
            # int() cannot convert infinity, so it is kept as is
            weight = int(rkey_weight) if rkey_weight != math.inf else math.inf
        else:
            weight = round(rkey_weight, precision)

        all_rkey_weights.append(rkey_weight)
        empirical_weight_distribution[weight] += 1

    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(
            empirical_weight_distribution))

    if verbose_lvl >= 3:
        smart_print("- list empirical weights:",
                    [round(x, 8) for x in all_rkey_weights if x != math.inf])

    return empirical_weight_distribution
def fast_empirical_weight(ch_found, verbose_lvl=0, debug=False, filename=None):
    """Computes the empirical weight of the model using C code.

    The characteristic is split into consecutive sub-characteristics
    whose exact weight is bounded; C code is generated for each of them
    and run with ``compile_run_empirical_weight``. The result is the sum
    of the empirical weights of the sub-characteristics, or ``math.inf``
    as soon as one sub-characteristic has infinite empirical weight.

    If ``filename`` is not ``None``, the output will be printed
    to the given file rather than to stdout.

    The argument ``verbose_lvl`` can take an integer between
    ``0`` (no verbose) and ``3`` (full verbose). Values of ``4`` or more
    additionally pass ``verbose=True`` to ``compile_run_empirical_weight``.

    If ``debug`` is true, the characteristics and the generated C code
    are printed as well.

        >>> from arxpy.differential.difference import XorDiff, RXDiff
        >>> from arxpy.differential.characteristic import BvCharacteristic
        >>> from arxpy.primitives.chaskey import ChaskeyPi
        >>> from arxpy.smt.search_differential import SearchCh
        >>> from arxpy.smt.verification_differential import fast_empirical_weight
        >>> ChaskeyPi.set_rounds(2)
        >>> ch = BvCharacteristic(ChaskeyPi, XorDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> search_problem = SearchCh(ch)
        >>> ch_found = search_problem.solve(0)
        >>> ch_found.ch_weight
        0x04
        >>> 3 <= fast_empirical_weight(ch_found) <= 5
        True
        >>> ChaskeyPi.set_rounds(1)
        >>> ch = BvCharacteristic(ChaskeyPi, RXDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> ic = [operation.BvComp(0, d.val) for d in ch.input_diff]
        >>> ic += [operation.BvComp(0, d[1].val) for d in ch.output_diff]
        >>> ch_found = SearchCh(ch, allow_zero_input_diff=True, initial_constraints=ic).solve(5)
        >>> ch_found.ch_weight
        0x05
        >>> 4 - 1 <= fast_empirical_weight(ch_found) <= 8
        True

    """
    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    exact_weight = ch_found.get_exact_weight()

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("Characteristic found:")
        smart_print(ch_found)
        smart_print()

    # exact weight of every nonlinear derivative, evaluated on the
    # concrete differences of the characteristic found
    der_weights = []
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        actual_diff = ch_found.nonlinear_diffs[i][1]
        new_input_diff = [(d.xreplace(ch_found._diff_model))
                          for d in der.input_diff]
        der_weights.append(
            der._replace_input_diff(new_input_diff).exact_weight(actual_diff))

    max_subch_weight = exact_weight if exact_weight < MAX_WEIGHT else exact_weight / (
        exact_weight / MAX_WEIGHT)
    # a sub-characteristic must be able to hold at least one derivative
    max_subch_weight = max(1, max_subch_weight, *der_weights)
    if debug:
        smart_print("max_subch_weight:", max_subch_weight)
        smart_print()

    # greedily group consecutive derivatives into sub-characteristics so
    # that each group's exact weight does not exceed max_subch_weight
    subch_listdiffder = [[]]  # for each subch, a list of [diff, der] pairs
    subch_index = 0
    current_subch_weight = 0  # exact_weight
    subch_weight = []  # the weight of each subch
    assert len(ch_found.ch.nonlinear_diffs.items()) > 0
    for i, (diff, der) in enumerate(ch_found.ch.nonlinear_diffs.items()):
        der_weight = der_weights[i]
        if current_subch_weight + der_weight > max_subch_weight:
            subch_weight.append(current_subch_weight)
            current_subch_weight = 0
            subch_index += 1
            subch_listdiffder.append([])
        current_subch_weight += der_weight
        subch_listdiffder[subch_index].append([diff, der])

    subch_weight.append(current_subch_weight)
    assert len(subch_weight) == len(subch_listdiffder)

    num_subch = len(subch_listdiffder)

    if verbose_lvl >= 3:
        smart_print(
            "- characteristic decomposed into {} subcharacteristics with exact weights {}"
            .format(num_subch, subch_weight))

    def subch_listdiffder2subch_ssa(listdiffder,
                                    first_var_next_subch,
                                    first_subch=False):
        """Extract the SSA of one sub-characteristic.

        Collects the assignments of the full SSA from the first
        derivative variable of ``listdiffder`` (or from the beginning if
        ``first_subch``) up to, but excluding, the assignment of
        ``first_var_next_subch``. ``output_vars`` is left empty and must
        be filled by the caller.
        """
        first_var = listdiffder[0][0].val

        input_vars = []
        inter_vars = set()
        assignments = []
        add_assignment = first_subch
        for var, expr in ch_found.ch.ssa["assignments"]:
            if var == first_var:
                add_assignment = True
            elif var == first_var_next_subch:
                break

            if add_assignment:
                input_vars.extend([
                    atom for atom in expr.atoms(core.Variable)
                    if atom not in input_vars
                ])
                inter_vars.add(var)
                assignments.append([var, expr])

        subch_ssa = {}
        # variables used but never assigned within this sub-characteristic
        subch_ssa["input_vars"] = [
            var for var in input_vars if var not in inter_vars
        ]
        subch_ssa["output_vars"] = []
        subch_ssa["inter_vars"] = inter_vars
        subch_ssa["assignments"] = assignments
        return subch_ssa

    subch_ssa = [None for _ in range(num_subch)]
    # built back to front: the outputs of sub-characteristic i are the
    # inputs of sub-characteristic i + 1
    for i in reversed(range(num_subch)):
        if i == num_subch - 1:
            subch_ssa[i] = subch_listdiffder2subch_ssa(subch_listdiffder[i],
                                                       None, i == 0)
            subch_ssa[i]["output_vars"] = list(ch_found.ch.ssa["output_vars"])
        else:
            first_var_next_ssa = subch_listdiffder[i + 1][0][0].val
            subch_ssa[i] = subch_listdiffder2subch_ssa(subch_listdiffder[i],
                                                       first_var_next_ssa,
                                                       i == 0)
            subch_ssa[i]["output_vars"] = subch_ssa[i + 1]["input_vars"][:]

        # an output that is not computed here must be passed through,
        # so it also becomes an input
        for diff_var in subch_ssa[i]["output_vars"]:
            if diff_var not in subch_ssa[i][
                    "inter_vars"] and diff_var not in subch_ssa[i][
                        "input_vars"]:
                subch_ssa[i]["input_vars"].append(diff_var)

        del subch_ssa[i]["inter_vars"]
        subch_ssa[i]["weight"] = subch_weight[i]

    # map every variable of the characteristic to the value of its difference
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(ch_found.input_diff,
                                                ch_found.nonlinear_diffs,
                                                ch_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val
    for var, diff in ch_found.ch._var2diff.items():
        if var not in var2diffval:
            var2diffval[var] = diff.val.xreplace(var2diffval)

    # fixing duplicate var problem: give a fresh name ("<name>_o<index>")
    # to every output variable that is also an input or that appears
    # more than once in the outputs
    for ssa in subch_ssa:
        for j in range(len(ssa["output_vars"])):
            var_j = ssa["output_vars"][j]
            index_out = 0
            if var_j in ssa["input_vars"]:
                new_var = type(var_j)(var_j.name + "_o" + str(index_out),
                                      var_j.width)
                index_out += 1
                ssa["assignments"].append([new_var, var_j])
                ssa["output_vars"][j] = new_var
                var2diffval[new_var] = var2diffval[var_j]

            for k in range(j + 1, len(ssa["output_vars"])):
                if var_j == ssa["output_vars"][k]:
                    new_var = type(var_j)(var_j.name + "_o" + str(index_out),
                                          var_j.width)
                    index_out += 1
                    ssa["assignments"].append([new_var, var_j])
                    ssa["output_vars"][k] = new_var
                    var2diffval[new_var] = var2diffval[var_j]

    total_empirical_weight = 0
    for i, ssa in enumerate(subch_ssa):
        ccode = ssa2ccode(ssa, ch_found.ch.diff_type)

        if verbose_lvl >= 2:
            smart_print("- sub-characteristic {}".format(i))
        if verbose_lvl >= 3:
            smart_print("  - ssa:", subch_ssa[i])
            smart_print("  - listdiffder:", subch_listdiffder[i])
        if debug:
            smart_print(ccode[0])
            smart_print(ccode[1])
            smart_print()

        input_diff_c = [v.xreplace(var2diffval) for v in ssa["input_vars"]]
        output_diff_c = [v.xreplace(var2diffval) for v in ssa["output_vars"]]

        if verbose_lvl >= 2:
            smart_print("  - checking {} -> {} with weight {}".format(
                '|'.join([str(d) for d in input_diff_c]),
                '|'.join([str(d) for d in output_diff_c]), ssa["weight"]))

        # Validate BEFORE converting: once the values are plain ints the
        # check can never fail, and a difference that was not fully
        # evaluated would otherwise surface as an unrelated error.
        assert all(isinstance(d, (int, core.Constant))
                   for d in input_diff_c), "{}".format(input_diff_c)
        assert all(isinstance(d, (int, core.Constant))
                   for d in output_diff_c), "{}".format(output_diff_c)

        # int() accepts both plain ints and constants
        input_diff_c = [int(d) for d in input_diff_c]
        output_diff_c = [int(d) for d in output_diff_c]

        current_empirical_weight = compile_run_empirical_weight(
            ccode,
            "_libver" + ch_found.ch.func.__name__ + str(i),
            input_diff_c,
            output_diff_c,
            ssa["weight"],
            verbose=verbose_lvl >= 4)

        if verbose_lvl >= 2:
            smart_print("  - exact/empirical weight: {}, {}".format(
                ssa["weight"], current_empirical_weight))

        # one invalid sub-characteristic invalidates the whole characteristic
        if current_empirical_weight == math.inf:
            return math.inf

        total_empirical_weight += current_empirical_weight

    return total_empirical_weight
Beispiel #7
0
def fast_empirical_weight(id_found, verbose_lvl=0, debug=False, filename=None):
    """Computes the empirical weight of the model using C code.

    C code is generated for the whole SSA of ``id_found.ch`` and run with
    ``compile_run_empirical_weight`` on the input and output differences
    of ``id_found``, using ``MAX_WEIGHT`` as the weight argument. The
    value returned by that call is returned unchanged.

    If ``filename`` is not ``None``, the output will be printed
    to the given file rather than to stdout.

    The argument ``verbose_lvl`` can take an integer between
    ``0`` (no verbose) and ``3`` (full verbose). Values of ``4`` or more
    additionally pass ``verbose=True`` to ``compile_run_empirical_weight``.

    If ``debug`` is true, the characteristic, the ID found and the
    generated C code are printed as well.

        >>> from arxpy.differential.difference import XorDiff, RXDiff
        >>> from arxpy.differential.characteristic import BvCharacteristic
        >>> from arxpy.primitives.chaskey import ChaskeyPi
        >>> from arxpy.smt.search_impossible import SearchID
        >>> from arxpy.smt.verification_impossible import fast_empirical_weight
        >>> ChaskeyPi.set_rounds(2)
        >>> ch = BvCharacteristic(ChaskeyPi, XorDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> search_problem = SearchID(ch)
        >>> id_found = search_problem.solve(2)
        >>> fast_empirical_weight(id_found)
        inf
        >>> ch = BvCharacteristic(ChaskeyPi, RXDiff, ["dv0", "dv1", "dv2", "dv3"])
        >>> search_problem = SearchID(ch)
        >>> id_found = search_problem.solve(2)
        >>> fast_empirical_weight(id_found)
        inf

    """
    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(id_found.ch)
        smart_print("ID found:")
        smart_print(id_found)
        smart_print()

    assert len(id_found.ch.nonlinear_diffs.items()) > 0

    # work on a copy so that the SSA of the characteristic is not modified
    ssa = id_found.ch.ssa.copy()
    ssa["assignments"] = list(ssa["assignments"])
    ssa["output_vars"] = list(ssa["output_vars"])

    # map the input/output variables to the value of their difference
    var2diffval = {}
    for diff_var, diff_value in itertools.chain(id_found.input_diff, id_found.output_diff):
        var2diffval[diff_var.val] = diff_value.val

    # fixing duplicate var problem: give a fresh name ("<name>_o<index>")
    # to every output variable that is also an input or that appears
    # more than once in the outputs
    for j in range(len(ssa["output_vars"])):
        var_j = ssa["output_vars"][j]
        index_out = 0
        if var_j in ssa["input_vars"]:
            new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
            index_out += 1
            ssa["assignments"].append([new_var, var_j])
            ssa["output_vars"][j] = new_var
            var2diffval[new_var] = var2diffval[var_j]

        for k in range(j + 1, len(ssa["output_vars"])):
            if var_j == ssa["output_vars"][k]:
                new_var = type(var_j)(var_j.name + "_o" + str(index_out), var_j.width)
                index_out += 1
                ssa["assignments"].append([new_var, var_j])
                ssa["output_vars"][k] = new_var
                var2diffval[new_var] = var2diffval[var_j]

    ccode = ssa2ccode(ssa, id_found.ch.diff_type)

    if verbose_lvl >= 3:
        smart_print("  - ssa:", ssa)
    if debug:
        smart_print(ccode[0])
        smart_print(ccode[1])
        smart_print()

    input_diff_c = [v.xreplace(var2diffval) for v in ssa["input_vars"]]
    output_diff_c = [v.xreplace(var2diffval) for v in ssa["output_vars"]]

    if verbose_lvl >= 2:
        smart_print("  - checking {} -> {} pairs 2**{}".format(
            '|'.join([str(d) for d in input_diff_c]), '|'.join([str(d) for d in output_diff_c]),
            MAX_WEIGHT))

    # Validate BEFORE converting: once the values are plain ints the check
    # can never fail, and a difference that was not fully evaluated would
    # otherwise surface as an unrelated error.
    assert all(isinstance(d, (int, core.Constant)) for d in input_diff_c), "{}".format(input_diff_c)
    assert all(isinstance(d, (int, core.Constant)) for d in output_diff_c), "{}".format(output_diff_c)

    # int() accepts both plain ints and constants
    input_diff_c = [int(d) for d in input_diff_c]
    output_diff_c = [int(d) for d in output_diff_c]

    current_empirical_weight = compile_run_empirical_weight(
        ccode,
        "_libver" + id_found.ch.func.__name__,
        input_diff_c,
        output_diff_c,
        MAX_WEIGHT,
        verbose=verbose_lvl >= 4)

    if verbose_lvl >= 2:
        smart_print("  - empirical weight: {}".format(current_empirical_weight))

    # math.inf signals that no valid pair was found
    return current_empirical_weight
Beispiel #8
0
def _fast_empirical_weight_distribution(ch_found, cipher, rk_dict_diffs=None,
                                        verbose_lvl=0, debug=False, filename=None, precision=0):
    """Return the distribution of empirical weights of ``ch_found`` over random keys.

    ``KEY_SAMPLES`` master keys are drawn with ``random.randrange``. For each
    of them, the round keys returned by ``cipher.key_schedule`` replace the
    variables ``k0``, ``k1``, ... in the assignments of ``ch_found.ch.ssa``;
    the resulting SSA is translated to C code and submitted to
    ``compile_run_empirical_weight`` through a ``multiprocessing.Pool``.

    The result is a ``collections.Counter`` mapping each empirical weight to
    the number of sampled keys that produced it. With ``precision == 0``
    finite weights are truncated with ``int``; otherwise every weight is
    rounded to ``precision`` digits.

    ``rk_dict_diffs`` must be ``None`` (any other value raises ``ValueError``).
    ``debug`` and ``verbose_lvl`` (thresholds 2 and 3) enable extra output,
    limited to the first two sampled keys for the per-key messages, through
    the printer returned by ``_get_smart_print(filename)``.

        >>> from arxpy.differential.difference import XorDiff
        >>> from arxpy.differential.characteristic import SingleKeyCh
        >>> from arxpy.smt.search_impossible import SearchSkID
        >>> from arxpy.primitives import speck
        >>> from arxpy.smt.verification_impossible import _fast_empirical_weight_distribution
        >>> Speck32 = speck.get_Speck_instance(speck.SpeckInstance.speck_32_64)
        >>> Speck32.set_rounds(1)
        >>> ch = SingleKeyCh(Speck32, XorDiff)
        >>> search_problem = SearchSkID(ch)
        >>> id_found = search_problem.solve(2)
        >>> _fast_empirical_weight_distribution(id_found, Speck32)
        Counter({inf: 256})

    """
    if rk_dict_diffs is not None:
        raise ValueError("rk_dict_diffs must be None")

    from arxpy.smt.search_differential import _get_smart_print  # avoid cyclic imports

    smart_print = _get_smart_print(filename)

    if debug:
        smart_print("Symbolic characteristic:")
        smart_print(ch_found.ch)
        smart_print("ID found:")
        smart_print(ch_found)
        smart_print()

    # symbolic round keys k0, k1, ... (one per output of the key schedule)
    round_key_vars = [
        core.Variable("k" + str(position), width)
        for position, width in enumerate(cipher.key_schedule.output_widths)
    ]

    # difference variable -> concrete difference value (inputs and outputs)
    var2diffval = {
        diff_var.val: diff_value.val
        for diff_var, diff_value in itertools.chain(ch_found.input_diff, ch_found.output_diff)
    }

    def with_round_keys(substitution):
        # copy of the characteristic's SSA with `substitution` applied to
        # the right-hand side of every assignment
        template = ch_found.ch.ssa
        keyed_ssa = template.copy()
        keyed_ssa["output_vars"] = list(keyed_ssa["output_vars"])
        keyed_ssa["assignments"] = [
            (target, expr.xreplace(substitution))
            for target, expr in template["assignments"]
        ]
        return keyed_ssa

    def rename_output(ssa, position, source, suffix):
        # make output `position` a fresh variable assigned from `source`,
        # carrying over the difference value recorded for `source`
        fresh = type(source)(source.name + "_o" + str(suffix), source.width)
        ssa["assignments"].append([fresh, source])
        ssa["output_vars"][position] = fresh
        var2diffval[fresh] = var2diffval[source]

    def separate_outputs(ssa):
        # an output that is also an input, or that occurs again later in
        # the output list, is replaced by a fresh "<name>_o<i>" variable
        outputs = ssa["output_vars"]
        for first in range(len(outputs)):
            current = outputs[first]
            suffix = 0
            if current in ssa["input_vars"]:
                rename_output(ssa, first, current, suffix)
                suffix += 1

            for later in range(first + 1, len(outputs)):
                if current == outputs[later]:
                    rename_output(ssa, later, current, suffix)
                    suffix += 1

    # one pair of SSAs per sampled key
    rkey2pair_ssa = []
    for _ in range(KEY_SAMPLES):
        master_key = [
            core.Constant(random.randrange(2 ** width), width)
            for width in cipher.key_schedule.input_widths
        ]
        rk_val = cipher.key_schedule(*master_key)
        # related-key differences are not supported (rk_dict_diffs is None),
        # so both elements of the pair use the same round keys
        rk_other_val = rk_val
        assert len(round_key_vars) == len(rk_other_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_val)
        assert all(isinstance(rk, core.Constant) for rk in rk_other_val)

        pair_ssa = []
        for current_rk_val in (rk_val, rk_other_val):
            ssa = with_round_keys(dict(zip(round_key_vars, current_rk_val)))
            separate_outputs(ssa)
            pair_ssa.append(ssa)

        rkey2pair_ssa.append(pair_ssa)

    # empirical weight of each sampled key, in sampling order
    rkey2subch_ew = []

    # start multiprocessing
    with multiprocessing.Pool() as pool:
        pending = []
        for key_index, (ssa1, ssa2) in enumerate(rkey2pair_ssa):
            report = key_index <= 1  # per-key messages only for the first two keys
            same_ssa = ssa1 == ssa2

            if report and verbose_lvl >= 3:
                smart_print("  - related-key pair index", key_index)
                smart_print("  - ssa1:", ssa1)
                if same_ssa:
                    smart_print("  - ssa2: (same as ssa1)")
                else:
                    smart_print("  - ssa2:", ssa2)

            if same_ssa:
                ccode = ssa2ccode(ssa1, ch_found.ch.diff_type)
            else:
                ccode = relatedssa2ccode(ssa1, ssa2, ch_found.ch.diff_type)

            if report and debug:
                smart_print(ccode[0])
                smart_print(ccode[1])
                smart_print()

            input_diff_c = [v.xreplace(var2diffval) for v in ssa1["input_vars"]]
            output_diff_c = [v.xreplace(var2diffval) for v in ssa1["output_vars"]]

            if report and verbose_lvl >= 2:
                smart_print("  - rk{} | checking {} -> {} with pairs 2**{}".format(
                    key_index,
                    '|'.join(str(d) for d in input_diff_c),
                    '|'.join(str(d) for d in output_diff_c),
                    MAX_WEIGHT))

            assert all(isinstance(d, (int, core.Constant)) for d in input_diff_c), "{}".format(input_diff_c)
            assert all(isinstance(d, (int, core.Constant)) for d in output_diff_c), "{}".format(output_diff_c)

            input_diff_c = [int(d) for d in input_diff_c]
            output_diff_c = [int(d) for d in output_diff_c]

            job_args = (
                ccode,
                "_libver" + ch_found.ch.func.__name__,
                input_diff_c,
                output_diff_c,
                MAX_WEIGHT,
                False,
            )
            pending.append(pool.apply_async(compile_run_empirical_weight, job_args))

        # wait until all have been compiled and run, collecting the results
        # while the pool is still alive
        for key_index, job in enumerate(pending):
            rkey2subch_ew.append(job.get())

            if key_index <= 1 and verbose_lvl >= 2:
                smart_print("  - rk{} | empirical weight: {}".format(
                    key_index, rkey2subch_ew[key_index]))

    # end multiprocessing

    def to_bucket(rkey_weight):
        # key under which a weight is counted in the distribution
        if precision == 0:
            return int(rkey_weight) if rkey_weight != math.inf else math.inf
        return round(rkey_weight, precision)

    empirical_weight_distribution = collections.Counter(
        to_bucket(rkey_weight) for rkey_weight in rkey2subch_ew
    )

    if verbose_lvl >= 2:
        smart_print("- distribution empirical weights: {}".format(empirical_weight_distribution))

    if verbose_lvl >= 3:
        smart_print("- list empirical weights:", [round(x, 8) for x in rkey2subch_ew if x != math.inf])

    return empirical_weight_distribution