Exemplo n.º 1
0
    def eval(cls, bv):
        """Return the Hamming weight (number of set bits) of ``bv``.

        For widths below 4 the result has width ``cls.output_width(bv)``;
        otherwise it has ``bv.width.bit_length()`` bits, which is enough to
        hold the largest possible count ``bv.width``.
        """
        # Source: Hacker's Delight

        if bv.width < 4:
            # Small widths: add up the individual bits, each zero-extended
            # to the output width.
            w = cls.output_width(bv)
            return sum([operation.ZeroExtend(bv[i], w-1) for i in range(bv.width)])

        # extend the bv until power of 2 length
        # (zero MSBs do not change the count; presumably the repeated masks
        # below need a power-of-two width to tile it exactly)
        original_width = bv.width
        while (bv.width & (bv.width - 1)) != 0:
            bv = operation.ZeroExtend(bv, 1)
        width_log2 = bv.width.bit_length() - 1

        # m_ctes[i] is the level-i mask obtained by repeating
        # pattern01(2 ** i) over the whole width -- presumably the classic
        # 0x55.., 0x33.., 0x0f.., 0x00ff.., ... masks (TODO confirm).
        m_ctes = []
        for i in range(width_log2):
            m_ctes.append(repeat_pattern(pattern01(2 ** i), bv.width))

        if bv.width > 32:
            # Wide vectors: generic divide and conquer, one masked addition
            # of adjacent 2**i-bit fields per level.
            for i, m in enumerate(m_ctes):
                bv = (bv & m) + ((bv >> core.Constant(2 ** i, bv.width)) & m)
            return bv[original_width.bit_length() - 1:]

        # Widths up to 32: the optimized sequence from Hacker's Delight.
        for i, m in enumerate(m_ctes):
            if i == 0:
                # x - ((x >> 1) & m) leaves in each 2-bit field the number
                # of ones of that field
                bv = bv - ((bv >> core.Constant(1, bv.width)) & m)
            elif i == 1:
                bv = (bv & m) + ((bv >> core.Constant(2 ** i, bv.width)) & m)  # generic case
            elif i == 2:
                # each 4-bit field holds at most 4, so the sum of two of them
                # (<= 8) fits in a nibble and one mask after the add suffices
                bv = (bv + (bv >> core.Constant(4, bv.width))) & m
            elif i == 3:
                # no mask: the low byte accumulates the count; the upper
                # bits hold garbage that the final extraction discards
                bv = bv + (bv >> core.Constant(8, bv.width))
            elif i == 4:
                bv = bv + (bv >> core.Constant(16, bv.width))

        return bv[original_width.bit_length() - 1:]
Exemplo n.º 2
0
 def smart_add(x, y):
     """Add two bit-vectors, zero-extending the narrower operand first.

     The result has width ``max(x.width, y.width)``.
     """
     gap = y.width - x.width
     if gap > 0:
         return operation.ZeroExtend(x, gap) + y
     if gap < 0:
         return x + operation.ZeroExtend(y, -gap)
     return x + y
Exemplo n.º 3
0
 def smart_sub(x, y):
     """Subtract ``y`` from ``x``, zero-extending the narrower operand first.

     The result has width ``max(x.width, y.width)``. This is not the same
     as ``smart_add(x, -y)``: negating a narrower ``y`` before widening it
     does not give the negation at the wider width.
     """
     gap = y.width - x.width
     if gap > 0:
         return operation.ZeroExtend(x, gap) - y
     if gap < 0:
         return x - operation.ZeroExtend(y, -gap)
     return x - y
Exemplo n.º 4
0
    def eval(cls, x, y, z):
        """Return ``HW(x) + HW(y) + HW(z)``, the sum of three Hamming weights.

        Only ``x.width`` is used to pick the strategy and to build the masks,
        so the three inputs are presumably required to share the same width
        (TODO confirm the class condition enforces it).
        """
        # Source: Hacker's Delight

        if x.width < 4:
            # the HW of a 1-bit/2-bit vector requires 1-bit/2-bit (HW(0b1)=0b1, HW(0b11)=0b10)
            # thus, the sum of three HW of these sizes require an extra bit (3*0b1=0b11, 3*0b10=0b110)
            # the HW of a 3-bit vector requires 2-bit (Hw(0b111)=0b11)
            # but the sum of three HW of 3-bit require two extra bit (3*0b11 = 0b1001)
            offset = 1
            if x.width == 3:
                offset = 2
            x = operation.ZeroExtend(PopCount(x), offset)
            y = operation.ZeroExtend(PopCount(y), offset)
            z = operation.ZeroExtend(PopCount(z), offset)
            return x + y + z
        elif x.width > 32:
            # Wide inputs: compute each PopCount separately, widen them to
            # the output width and add them.
            width = cls.output_width(x, y, z)
            x = PopCount(x)
            x = operation.ZeroExtend(x, width - x.width)
            y = PopCount(y)
            y = operation.ZeroExtend(y, width - y.width)
            z = PopCount(z)
            z = operation.ZeroExtend(z, width - z.width)
            return x + y + z

        # Pad every input with zero MSBs up to a power-of-two width (this
        # does not change the Hamming weights).
        orig_x, orig_y, orig_z = x, y, z
        while (x.width & (x.width - 1)) != 0:
            x = operation.ZeroExtend(x, 1)
        while (y.width & (y.width - 1)) != 0:
            y = operation.ZeroExtend(y, 1)
        while (z.width & (z.width - 1)) != 0:
            z = operation.ZeroExtend(z, 1)
        width_log2 = x.width.bit_length() - 1

        # Level masks from pattern01(2 ** i), presumably 0x55.., 0x33..,
        # 0x0f.., ... (TODO confirm).
        m_ctes = []
        for i in range(width_log2):
            m_ctes.append(repeat_pattern(pattern01(2 ** i), x.width))

        # Only the width of this placeholder is used (in the shift
        # constants) until bv is reassigned below.
        bv = core.Constant(0, x.width)
        for i, m in enumerate(m_ctes):
            if i == 0:
                # Per-input 2-bit field counts. The sum stored in bv here can
                # overflow a 2-bit field (3 * 2 > 3), but it is never used:
                # the width is >= 4 here, so i == 1 always runs and
                # overwrites it.
                x = x - ((x >> core.Constant(1, bv.width)) & m)
                y = y - ((y >> core.Constant(1, bv.width)) & m)
                z = z - ((z >> core.Constant(1, bv.width)) & m)
                bv = x + y + z
            elif i == 1:
                # Per-input 4-bit field counts (<= 4 each); their sum (<= 12)
                # still fits in a 4-bit field, so the inputs are merged here.
                x = (x & m) + ((x >> core.Constant(2 ** i, bv.width)) & m)  # generic case
                y = (y & m) + ((y >> core.Constant(2 ** i, bv.width)) & m)
                z = (z & m) + ((z >> core.Constant(2 ** i, bv.width)) & m)
                bv = x + y + z
            elif i == 2:
                # Byte sums reach 24 > 15, so the masked-add form is used
                # (masking only after the add would lose nibble carries).
                bv = (bv & m) + ((bv >> core.Constant(4, bv.width)) & m)
            elif i == 3:
                # The low byte accumulates (<= 48); the upper bits are garbage
                # dropped by the final extraction (assuming the output width
                # is at most 8 bits).
                bv = bv + (bv >> core.Constant(8, bv.width))
            elif i == 4:
                # Low byte <= 96, still within 8 bits.
                bv = bv + (bv >> core.Constant(16, bv.width))

        return bv[cls.output_width(orig_x, orig_y, orig_z) - 1:]
Exemplo n.º 5
0
    def eval(cls, bv):
        """Return ``bv`` with its bit order reversed (bit 0 becomes the MSB).

        The result has the same width as ``bv``.
        """
        # Source: Hacker's Delight

        if bv.width == 1:
            return bv
        elif bv.width == 2:
            # rotating a 2-bit vector by one position swaps its two bits
            return operation.RotateLeft(bv, 1)
        elif bv.width == 3:
            # bv[0] becomes the MSB and bv[2] the LSB
            return operation.Concat(operation.Concat(bv[0], bv[1]), bv[2])

        # Pad with zero MSBs up to a power-of-two width. After the reversal
        # the padding ends up in the low bits, so the answer is the top
        # original_width bits (``bv[:k]`` keeps the bits from the MSB down
        # to bit k).
        original_width = bv.width
        while (bv.width & (bv.width - 1)) != 0:
            bv = operation.ZeroExtend(bv, 1)
        width_log2 = bv.width.bit_length() - 1

        # Level masks from pattern01(2 ** i), presumably 0x55.., 0x33..,
        # 0x0f.., 0x00ff.., ... (TODO confirm).
        m_ctes = []
        for i in range(width_log2):
            m_ctes.append(repeat_pattern(pattern01(2 ** i), bv.width))

        if bv.width > 32:
            # Generic case: at level i swap every pair of adjacent
            # 2**i-bit fields.
            for i, m in list(enumerate(m_ctes)):
                bv = ((bv & m) << core.Constant(2 ** i, bv.width)) | ((bv >> core.Constant(2 ** i, bv.width)) & m)
            return bv[:bv.width - original_width]

        # Swap adjacent bits, 2-bit fields and nibbles (only the levels that
        # exist for 4- and 8-bit vectors are applied).
        for i, m in list(enumerate(m_ctes))[:3]:
            bv = ((bv & m) << core.Constant(2 ** i, bv.width)) | ((bv >> core.Constant(2 ** i, bv.width)) & m)  # generic case

        if len(m_ctes) == 4:
            # 16 bits: swap the two bytes
            bv = ((bv & m_ctes[3]) << core.Constant(8, bv.width)) | ((bv >> core.Constant(8, bv.width)) & m_ctes[3])
        elif len(m_ctes) == 5:
            # 32 bits: reverse the four bytes in one step. Assuming
            # m_ctes[3] is 0x00ff00ff, B3 B2 B1 B0 becomes B0 B1 B2 B3
            # (rotation-based variant of the last two swap levels).
            rol = operation.RotateLeft
            ror = operation.RotateRight
            bv = ror(bv & m_ctes[3], 8) | (rol(bv, 8) & m_ctes[3])

        return bv[:bv.width - original_width]
Exemplo n.º 6
0
    def eval(cls, bv):
        """Return a mask with ones at the leading-zero positions of ``bv``.

        Bit ``k`` of the result is 1 iff bit ``k`` of ``bv`` and every bit
        above it are 0, so an all-zero input gives an all-ones mask. The
        technique (Hacker's Delight) smears the most significant set bit
        towards the LSB with shift-or steps and then complements.
        """
        if bv.width == 1:
            return ~bv

        n = bv.width
        # Pad with zero MSBs up to a power-of-two width; the padding is zero,
        # so it never smears into the original bits.
        while bv.width & (bv.width - 1):
            bv = operation.ZeroExtend(bv, 1)

        # Shifts by 1, 2, 4, ... cover the whole padded width.
        for step in (2 ** k for k in range(bv.width.bit_length() - 1)):
            bv = bv | (bv >> core.Constant(step, bv.width))

        smeared = bv[n - 1:]
        return ~smeared
Exemplo n.º 7
0
    def eval(cls, x, y):
        """Return ``HW(x) - HW(y)``, the difference of two Hamming weights.

        In the main path the result is truncated to
        ``cls.output_width(orig_x, orig_y)`` bits, so a negative difference
        presumably wraps modulo ``2 ** output_width`` (two's complement) --
        TODO confirm against ``output_width``.
        """
        # Source: Hacker's Delight

        if x.width < 4:
            return PopCount(x) - PopCount(y)
        elif x.width > 32:
            return PopCount(x) - PopCount(y)

        # Pad both inputs with zero MSBs up to a power-of-two width W.
        orig_x, orig_y = x, y
        while (x.width & (x.width - 1)) != 0:
            x = operation.ZeroExtend(x, 1)
        while (y.width & (y.width - 1)) != 0:
            y = operation.ZeroExtend(y, 1)
        width_log2 = x.width.bit_length() - 1

        # Level masks from pattern01(2 ** i), presumably 0x55.., 0x33..,
        # 0x0f.., ... (TODO confirm).
        m_ctes = []
        for i in range(width_log2):
            m_ctes.append(repeat_pattern(pattern01(2 ** i), x.width))

        # Trick: the ones of ~y are counted instead of those of y. Since y
        # is zero-padded, HW(~y) = W - HW(y), so bv ends up holding
        # HW(x) - HW(y) + W and the final ``- y.width`` (y.width == W here)
        # removes that bias. This lets both counts be merged by additions.
        # Only the width of the placeholder below is used until bv is
        # reassigned.
        bv = core.Constant(0, x.width)
        for i, m in enumerate(m_ctes):
            if i == 0:
                # 2-bit field counts of x and of ~y; this bv is overwritten
                # at i == 1 (W >= 4 here, so at least two levels run)
                x = x - ((x >> core.Constant(1, bv.width)) & m)
                y = (~y) - (((~y) >> core.Constant(1, bv.width)) & m)
                bv = x + y
            elif i == 1:
                # 4-bit field counts (<= 4 each); their sum (<= 8) fits in a
                # 4-bit field, so the two inputs are merged here
                x = (x & m) + ((x >> core.Constant(2 ** i, bv.width)) & m)  # generic case
                y = (y & m) + ((y >> core.Constant(2 ** i, bv.width)) & m)
                bv = x + y
            elif i == 2:
                bv = (bv & m) + ((bv >> core.Constant(4, bv.width)) & m)
            elif i == 3:
                # the low byte accumulates; upper bits are garbage that the
                # final extraction drops
                bv = bv + (bv >> core.Constant(8, bv.width))
            elif i == 4:
                bv = bv + (bv >> core.Constant(16, bv.width))

        return (bv - y.width)[cls.output_width(orig_x, orig_y) - 1:]
Exemplo n.º 8
0
    def weight(self):
        """Return the weight of the differential.

        The returned bit-vector is twice the weight, so that it stays an
        integer: ``2 * HW`` of the carry-mask bits plus 3 (about 2 * 1.415)
        when the LSB parity of the shifted differences is 0, or 6 (2 * 3)
        when it is 1.

            >>> from arxpy.bitvector.core import Constant
            >>> from arxpy.diffcrypt.difference import DiffVar
            >>> from arxpy.diffcrypt.differential import RXDBvAdd
            >>> a, b, c = DiffVar("a", 8), DiffVar("b", 8), DiffVar("c", 8)
            >>> rxda = RXDBvAdd([a, b], c)
            >>> rxda.weight()  # doctest: +ELLIPSIS
            (0b00010 * (0b00 ∘ (((0x0f & ((0x33 & ((0x55 & ((0b00 ...
            >>> zero = Constant(0, 8)
            >>> rxda.weight().xreplace({a: zero, b: zero, c: zero})
            0b00011

        """
        alpha, beta = self.input_diff
        gamma = self.output_diff
        # Drop the least significant bit of every difference (the same as
        # shifting right by one).
        da, db, dc = (diff[:1] for diff in (alpha, beta, gamma))

        rhs = ((da ^ dc) | (db ^ dc)) << 1
        hw = _HammingWeight(rhs[:1])  # the shifted-in LSB of rhs is skipped

        # Enough bits for the largest value 2 * max_hw + 6, and at least 3
        # bits so that the constant 6 (0b110) fits.
        max_hw = rhs.width - 1
        weight_width = max((2 * max_hw + 6).bit_length(), hw.width, 3)

        lsb_parity = da[0] ^ db[0] ^ dc[0]
        # Bitwise implication "lsb_parity => 0", i.e. (~p) | 0:
        #   case A (2 * 1.415 ~ 3): the implication holds (parity 0)
        #   case B (2 * 3 = 6):     it does not (parity 1)
        case_a = (~lsb_parity) | core.Constant(0, 1)
        cte_part = operation.Ite(case_a,
                                 core.Constant(3, weight_width),
                                 core.Constant(6, weight_width))

        hw_extend = operation.ZeroExtend(hw, weight_width - hw.width)
        return 2 * hw_extend + cte_part
Exemplo n.º 9
0
    def eval(cls, x, y):
        """Return ``HW(x) + HW(y)``, the sum of two Hamming weights.

        Only ``x.width`` is used to pick the strategy and to build the masks,
        so both inputs are presumably required to share the same width
        (TODO confirm the class condition enforces it).
        """
        # Source: Hacker's Delight

        if x.width < 4:
            # the HW of a 1-bit/2-bit vector requires 1-bit/2-bit (HW(0b1)=0b1, HW(0b11)=0b10)
            # thus, the sum of two HW of these sizes require an extra bit
            # the HW of a 3-bit vector requires 2-bit (Hw(0b111)=0b11)
            # and the sum of two HW of 3-bit also require an extra bit
            return operation.ZeroExtend(PopCount(x), 1) + operation.ZeroExtend(PopCount(y), 1)
        elif x.width > 32:
            # Wide inputs: compute each PopCount separately, widen them to
            # the output width and add them.
            width = cls.output_width(x, y)
            x = PopCount(x)
            x = operation.ZeroExtend(x, width - x.width)
            y = PopCount(y)
            y = operation.ZeroExtend(y, width - y.width)
            return x + y

        # Pad both inputs with zero MSBs up to a power-of-two width (this
        # does not change the Hamming weights).
        orig_x, orig_y = x, y
        while (x.width & (x.width - 1)) != 0:
            x = operation.ZeroExtend(x, 1)
        while (y.width & (y.width - 1)) != 0:
            y = operation.ZeroExtend(y, 1)
        width_log2 = x.width.bit_length() - 1

        # Level masks from pattern01(2 ** i), presumably 0x55.., 0x33..,
        # 0x0f.., ... (TODO confirm).
        m_ctes = []
        for i in range(width_log2):
            m_ctes.append(repeat_pattern(pattern01(2 ** i), x.width))

        # Only the width of this placeholder is used (in the shift
        # constants) until bv is reassigned below.
        bv = core.Constant(0, x.width)
        for i, m in enumerate(m_ctes):
            if i == 0:
                # Per-input 2-bit field counts; this bv is overwritten at
                # i == 1 (the width is >= 4 here, so at least two levels run)
                x = x - ((x >> core.Constant(1, bv.width)) & m)
                y = y - ((y >> core.Constant(1, bv.width)) & m)
                bv = x + y
            elif i == 1:
                # Per-input 4-bit field counts (<= 4 each); their sum (<= 8)
                # fits in a 4-bit field, so the inputs are merged here.
                x = (x & m) + ((x >> core.Constant(2 ** i, bv.width)) & m)  # generic case
                y = (y & m) + ((y >> core.Constant(2 ** i, bv.width)) & m)
                bv = x + y
            elif i == 2:
                # byte sums reach 16 > 15, so the masked-add form is used
                bv = (bv & m) + ((bv >> core.Constant(4, bv.width)) & m)
            elif i == 3:
                # the low byte accumulates; upper bits are garbage that the
                # final extraction drops
                bv = bv + (bv >> core.Constant(8, bv.width))
            elif i == 4:
                bv = bv + (bv >> core.Constant(16, bv.width))

        return bv[cls.output_width(orig_x, orig_y) - 1:]
Exemplo n.º 10
0
    def eval(cls, bv):
        """Return the Hamming weight of ``bv``.

        The result has ``bv.width.bit_length()`` bits. Zero MSBs are
        appended until the width is a power of two; then adjacent fields of
        1, 2, 4, ... bits are summed pairwise.
        """
        def bv_pattern(pattern, width):
            """Repeat ``pattern`` as many times as needed to span ``width`` bits."""
            assert width % pattern.width == 0
            copies = width // pattern.width
            return operation.Repeat(pattern, copies)

        def simple_pattern(width):
            """Return ``width`` zero bits placed on top of ``width`` one bits."""
            zero = core.Constant(0, width)
            return operation.Concat(zero, ~zero)

        n_bits = bv.width
        while bv.width & (bv.width - 1):
            bv = operation.ZeroExtend(bv, 1)
        levels = bv.width.bit_length() - 1

        field_masks = [bv_pattern(simple_pattern(2 ** lvl), bv.width)
                       for lvl in range(levels)]

        for lvl, sel in enumerate(field_masks):
            bv = (bv & sel) + ((bv >> 2 ** lvl) & sel)

        return bv[n_bits.bit_length() - 1:]
Exemplo n.º 11
0
Arquivo: types.py Projeto: ranea/ArxPy
def bv2pysmt(bv, boolean=False, strict_shift=False, env=None):
    """Convert a bit-vector type to a pySMT type.

    Args:
        bv: the bit-vector `Term` to convert
        boolean: if True, boolean pySMT types (e.g., `pysmt.shortcuts.Bool`) are used instead of
            bit-vector pySMT types (e.g., `pysmt.shortcuts.BV`).
        strict_shift: if `True`, shifts and rotations whose first operand
            has a non-power-of-two width are emulated (by zero-extending to
            a power-of-two width, or with extracts and a concatenation).
            Otherwise, and always for power-of-two widths, they are
            translated to pySMT's shifts and rotations directly.
        env: a `pysmt.environment.Environment`; if not specified, a new pySMT environment is created.
    ::

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.smt.types import bv2pysmt
        >>> s = bv2pysmt(Constant(0b00000001, 8), boolean=False)
        >>> s, s.get_type()
        (1_8, BV{8})
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> s = bv2pysmt(x)
        >>> s, s.get_type()
        (x, BV{8})
        >>> s = bv2pysmt(x +  y)
        >>> s, s.get_type()
        ((x + y), BV{8})
        >>> s = bv2pysmt(x <=  y)
        >>> s, s.get_type()
        ((x u<= y), Bool)
        >>> s = bv2pysmt(x[4: 2])
        >>> s, s.get_type()
        (x[2:4], BV{3})

    """
    msg = "unknown conversion of {} to a pySMT type".format(type(bv).__name__)

    # Plain integers (e.g. the amount of a ZeroExtend or the count of a
    # Repeat) are passed through unchanged. No pySMT object is built for
    # them, so the global environment is not reset for such a call.
    if isinstance(bv, int):
        return bv

    if env is None:
        env = environment.reset_env()
    fm = env.formula_manager

    pysmt_bv = None

    if isinstance(bv, core.Variable):
        if boolean:
            assert bv.width == 1
            pysmt_bv = fm.Symbol(bv.name, env.type_manager.BOOL())
        else:
            pysmt_bv = fm.Symbol(bv.name, env.type_manager.BVType(bv.width))

    elif isinstance(bv, core.Constant):
        if boolean:
            assert bv.width == 1
            pysmt_bv = fm.Bool(bool(bv))
        else:
            pysmt_bv = fm.BV(bv.val, bv.width)

    elif isinstance(bv, operation.Operation):
        # only 1st layer can return a boolean
        # Equals and Ite work well with BV, the rest don't

        if issubclass(type(bv), extraop.PartialOperation):
            raise NotImplementedError("PartialOperation is not yet supported")

        if type(bv) == operation.BvNot:
            if boolean:
                assert bv.width == 1
                args = [bv2pysmt(a, True, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.Not(*args)
            else:
                args = [bv2pysmt(a, False, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.BVNot(*args)

        elif type(bv) == operation.BvAnd:
            if boolean:
                assert bv.width == 1
                args = [bv2pysmt(a, True, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.And(*args)
            else:
                args = [bv2pysmt(a, False, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.BVAnd(*args)

        elif type(bv) == operation.BvOr:
            if boolean:
                assert bv.width == 1
                args = [bv2pysmt(a, True, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.Or(*args)
            else:
                args = [bv2pysmt(a, False, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.BVOr(*args)
        elif type(bv) == operation.BvXor:
            if boolean:
                assert bv.width == 1
                args = [bv2pysmt(a, True, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.Xor(*args)
            else:
                args = [bv2pysmt(a, False, strict_shift, env) for a in bv.args]
                pysmt_bv = fm.BVXor(*args)
        elif type(bv) == operation.Ite:
            args = [None for _ in range(len(bv.args))]
            # fm.Ite requires a Boolean type for args[0] but
            # bv2pysmt(bv.args[0], True, ...)  caused an error
            # (if args[0] is BvComp, it can be further optimized)
            args[0] = bv2pysmt(bv.args[0], False, strict_shift, env)
            if args[0].get_type().is_bv_type():
                args[0] = fm.Equals(args[0], fm.BV(1, 1))
            if boolean:
                assert bv.width == 1
                args[1:] = [
                    bv2pysmt(a, True, strict_shift, env) for a in bv.args[1:]
                ]
            else:
                args[1:] = [
                    bv2pysmt(a, False, strict_shift, env) for a in bv.args[1:]
                ]
            pysmt_bv = fm.Ite(*args)
        else:
            args = [bv2pysmt(a, False, strict_shift, env) for a in bv.args]

            if type(bv) == operation.BvComp:
                if boolean:
                    pysmt_bv = fm.Equals(*args)
                else:
                    pysmt_bv = fm.BVComp(*args)

            elif type(bv) == operation.BvUlt:
                pysmt_bv = fm.BVULT(*args)

            elif type(bv) == operation.BvUle:
                pysmt_bv = fm.BVULE(*args)

            elif type(bv) == operation.BvUgt:
                pysmt_bv = fm.BVUGT(*args)

            elif type(bv) == operation.BvUge:
                pysmt_bv = fm.BVUGE(*args)

            elif boolean:
                raise ValueError("{} cannot return a boolean type".format(
                    type(bv).__name__))

            elif type(bv) in [operation.BvShl, operation.BvLshr]:
                if not strict_shift or _is_power_of_2(args[0].bv_width()):
                    if type(bv) == operation.BvShl:
                        pysmt_bv = fm.BVLShl(*args)
                    elif type(bv) == operation.BvLshr:
                        pysmt_bv = fm.BVLShr(*args)
                else:
                    # Emulate: zero-extend both operands to a power-of-two
                    # width, shift there, and keep the low original bits.
                    x, r = bv.args
                    offset = 0
                    while not _is_power_of_2(x.width):
                        x = operation.ZeroExtend(x, 1)
                        r = operation.ZeroExtend(r, 1)
                        offset += 1

                    shift = bv2pysmt(type(bv)(x, r), False, strict_shift, env)
                    pysmt_bv = fm.BVExtract(shift,
                                            end=shift.bv_width() - offset - 1)

            elif type(bv) == operation.RotateLeft:
                if not strict_shift or _is_power_of_2(args[0].bv_width()):
                    pysmt_bv = fm.BVRol(*args)
                else:
                    # Non-power-of-two width: emulate the rotation with
                    # extracts and a concatenation.
                    x, r = bv.args
                    n = x.width
                    pysmt_bv = bv2pysmt(
                        operation.Concat(x[n - r - 1:], x[n - 1:n - r]), False,
                        strict_shift, env)

            elif type(bv) == operation.RotateRight:
                if not strict_shift or _is_power_of_2(args[0].bv_width()):
                    pysmt_bv = fm.BVRor(*args)
                else:
                    # Non-power-of-two width: emulate the rotation with
                    # extracts and a concatenation.
                    x, r = bv.args
                    n = x.width
                    pysmt_bv = bv2pysmt(
                        operation.Concat(x[r - 1:], x[n - 1:r]), False,
                        strict_shift, env)

            elif type(bv) == operation.Extract:
                # pySMT Extract(bv, start, end)
                pysmt_bv = fm.BVExtract(args[0], args[2], args[1])

            elif type(bv) == operation.Concat:
                pysmt_bv = fm.BVConcat(*args)

            elif type(bv) == operation.ZeroExtend:
                pysmt_bv = fm.BVZExt(*args)

            elif type(bv) == operation.Repeat:
                # Build through this environment's formula manager like every
                # other node; FNode.BVRepeat would go through the global
                # environment instead of ``env``.
                pysmt_bv = fm.BVRepeat(args[0], args[1])

            elif type(bv) == operation.BvNeg:
                pysmt_bv = fm.BVNeg(*args)

            elif type(bv) == operation.BvAdd:
                pysmt_bv = fm.BVAdd(*args)

            elif type(bv) == operation.BvSub:
                pysmt_bv = fm.BVSub(*args)

            elif type(bv) == operation.BvMul:
                pysmt_bv = fm.BVMul(*args)

            elif type(bv) == operation.BvUdiv:
                pysmt_bv = fm.BVUDiv(*args)

            elif type(bv) == operation.BvUrem:
                pysmt_bv = fm.BVURem(*args)

            else:
                bv2 = bv.doit()
                assert bv.width == bv2.width, "{} == {}\n{}\n{}".format(
                    bv.width, bv2.width, bv.vrepr(), bv2.vrepr())
                if bv != bv2:  # avoid cyclic loop
                    pysmt_bv = bv2pysmt(bv2,
                                        boolean=boolean,
                                        strict_shift=strict_shift,
                                        env=env)
                else:
                    raise NotImplementedError("(doit) " + msg)

    elif isinstance(bv, (difference.Difference, mask.Mask)):
        pysmt_bv = bv2pysmt(bv.val, boolean, strict_shift, env)

    if pysmt_bv is not None:
        try:
            pysmt_bv_width = pysmt_bv.bv_width()
        except (AssertionError, TypeError):
            pysmt_bv_width = 1  # boolean type

        assert bv.width == pysmt_bv_width
        return pysmt_bv
    else:
        raise NotImplementedError(msg)
Exemplo n.º 12
0
    def _generate(self):
        """Generate the SMT problem.

        Builds ``self.assertions`` (stored as a tuple at the end) from the
        characteristic ``self.ch``. It also records one weight variable per
        differential step in ``self.op_weights`` and sets ``self.ch_weight``
        to the (possibly zero-extended) characteristic weight variable that
        is compared against ``self.target_weight``.
        """
        self.assertions = []

        # Forbid zero input difference with XOR difference

        if self.ch.diff_type == difference.XorDiff:
            if self.parent_ch is not None and self.parent_ch.outer_ch == self.ch:
                # Exclude the last ``inner_noutputs`` input differences
                # (presumably fed by the inner characteristic's outputs).
                # An explicit, clamped end index is used because
                # ``[:-inner_noutputs]`` would drop every difference when
                # inner_noutputs == 0.
                inner_noutputs = len(self.parent_ch.inner_ch.output_diff)
                end = max(len(self.ch.input_diff) - inner_noutputs, 0)
                non_zero_input_diff = self.ch.input_diff[:end]
            else:
                non_zero_input_diff = self.ch.input_diff
            non_zero_input_diff = functools.reduce(operation.Concat,
                                                   non_zero_input_diff)
            zero = core.Constant(0, non_zero_input_diff.width)
            self.assertions.append(
                operation.BvNot(operation.BvComp(non_zero_input_diff, zero)))

        # Assertions of the weights of the non-deterministic steps

        self.op_weights = []
        for var, propagation in self.ch.items():
            if isinstance(propagation, differential.Differential):
                self.assertions.append(propagation.is_valid())
                weight_value = propagation.weight()
                weight_var = core.Variable(propagation._weight_var_name(),
                                           weight_value.width)
                self.assertions.append(
                    operation.BvComp(weight_var, weight_value))
                self.op_weights.append(weight_var)
            else:
                self.assertions.append(operation.BvComp(var, propagation))

        # Characteristic weight assignment
        # (wide enough for the sum of every op weight at its maximum value)

        max_value = sum((2 ** ow.width) - 1 for ow in self.op_weights)
        width = max(max_value.bit_length(), 1)  # for trivial characteristic
        ext_op_weights = [operation.ZeroExtend(ow, width - ow.width)
                          for ow in self.op_weights]

        name_ch_weight = "w_{}_{}".format(
            ''.join(str(i) for i in self.ch.input_diff),
            ''.join(str(i) for i in self.ch.output_diff))
        ch_weight = core.Variable(name_ch_weight, width)

        self.assertions.append(operation.BvComp(ch_weight,
                                                sum(ext_op_weights)))

        # Condition between the weight and the target weight

        weight_function = self.ch.get_weight_function()
        target_weight = int(weight_function(self.target_weight))

        width = max(ch_weight.width, target_weight.bit_length())
        self.ch_weight = operation.ZeroExtend(ch_weight,
                                              width - ch_weight.width)

        if self.equality:
            self.assertions.append(
                operation.BvComp(self.ch_weight, target_weight))
        else:
            self.assertions.append(
                operation.BvUlt(self.ch_weight, target_weight))

        self.assertions = tuple(self.assertions)
Exemplo n.º 13
0
def bv2pysmt(bv):
    """Convert a bit-vector type to a pySMT type.

    Comparisons (``BvComp``, ``BvUlt``, ``BvUle``, ``BvUgt``, ``BvUge``)
    become pySMT Boolean formulas; the other supported terms become pySMT
    bit-vector formulas.

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.diffcrypt.smt import bv2pysmt
        >>> bv2pysmt(Constant(0b00000001, 8))
        1_8
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> bv2pysmt(x)
        x
        >>> bv2pysmt(x +  y)
        (x + y)
        >>> bv2pysmt(x <=  y)
        (x u<= y)
        >>> bv2pysmt(x[4: 2])
        x[2:4]

    """
    msg = "unknown conversion of {} to a pySMT type".format(type(bv).__name__)

    if isinstance(bv, int):
        return bv

    if isinstance(bv, core.Variable):
        return sc.Symbol(bv.name, typing.BVType(bv.width))

    if isinstance(bv, core.Constant):
        return sc.BV(bv.val, bv.width)

    if isinstance(bv, operation.Operation):
        args = [bv2pysmt(a) for a in bv.args]

        if type(bv) == operation.BvNot:
            # Any Boolean operand (Equals from BvComp, but also BVULT & co.)
            # needs the logical negation; bit-vectors use the bitwise one.
            if args[0].get_type().is_bool_type():
                return sc.Not(*args)
            else:
                return sc.BVNot(*args)

        if type(bv) == operation.BvAnd:
            return sc.BVAnd(*args)

        if type(bv) == operation.BvOr:
            return sc.BVOr(*args)

        if type(bv) == operation.BvXor:
            return sc.BVXor(*args)

        if type(bv) == operation.BvComp:
            return sc.Equals(*args)

        if type(bv) == operation.BvUlt:
            return sc.BVULT(*args)

        if type(bv) == operation.BvUle:
            return sc.BVULE(*args)

        if type(bv) == operation.BvUgt:
            return sc.BVUGT(*args)

        if type(bv) == operation.BvUge:
            return sc.BVUGE(*args)

        if type(bv) == operation.BvShl:
            # pySMT's shift is used directly only for power-of-two widths
            if (args[0].bv_width() & (args[0].bv_width() - 1)) == 0:
                return sc.BVLShl(*args)
            else:
                # zero-extend to a power-of-two width, shift, keep low bits
                x, r = bv.args
                offset = 0
                while (x.width & (x.width - 1)) != 0:
                    x = operation.ZeroExtend(x, 1)
                    r = operation.ZeroExtend(r, 1)
                    offset += 1

                shift = bv2pysmt(x << r)
                return sc.BVExtract(shift, end=shift.bv_width() - offset - 1)

        if type(bv) == operation.BvLshr:
            # pySMT's shift is used directly only for power-of-two widths
            if (args[0].bv_width() & (args[0].bv_width() - 1)) == 0:
                return sc.BVLShr(*args)
            else:
                # zero-extend to a power-of-two width, shift, keep low bits
                x, r = bv.args
                offset = 0
                while (x.width & (x.width - 1)) != 0:
                    x = operation.ZeroExtend(x, 1)
                    r = operation.ZeroExtend(r, 1)
                    offset += 1

                shift = bv2pysmt(x >> r)
                return sc.BVExtract(shift, end=shift.bv_width() - offset - 1)

        if type(bv) == operation.RotateLeft:
            # pySMT's rotation is used directly only for power-of-two widths
            if (args[0].bv_width() & (args[0].bv_width() - 1)) == 0:
                return sc.BVRol(*args)
            else:
                # emulate with extracts and a concatenation
                x, r = bv.args
                n = x.width
                return bv2pysmt(operation.Concat(x[n - r - 1:],
                                                 x[n - 1:n - r]))

        if type(bv) == operation.RotateRight:
            # pySMT's rotation is used directly only for power-of-two widths
            if (args[0].bv_width() & (args[0].bv_width() - 1)) == 0:
                return sc.BVRor(*args)
            else:
                # emulate with extracts and a concatenation
                x, r = bv.args
                n = x.width
                return bv2pysmt(operation.Concat(x[r - 1:], x[n - 1:r]))

        if type(bv) == operation.Ite:
            # pySMT's Ite needs a Boolean condition; a 1-bit vector
            # condition is compared against 1.
            if args[0].get_type().is_bool_type():
                a0 = args[0]
            else:
                a0 = sc.Equals(args[0], bv2pysmt(core.Constant(1, 1)))

            return sc.Ite(a0, *args[1:])

        if type(bv) == operation.Extract:
            # pySMT's BVExtract takes (formula, start=low, end=high)
            return sc.BVExtract(args[0], args[2], args[1])

        if type(bv) == operation.Concat:
            return sc.BVConcat(*args)

        if type(bv) == operation.ZeroExtend:
            return sc.BVZExt(*args)

        if type(bv) == operation.Repeat:
            return args[0].BVRepeat(args[1])

        if type(bv) == operation.BvNeg:
            return sc.BVNeg(*args)

        if type(bv) == operation.BvAdd:
            return sc.BVAdd(*args)

        if type(bv) == operation.BvSub:
            return sc.BVSub(*args)

        if type(bv) == operation.BvMul:
            return sc.BVMul(*args)

        if type(bv) == operation.BvUdiv:
            return sc.BVUDiv(*args)

        if type(bv) == operation.BvUrem:
            return sc.BVURem(*args)

    raise NotImplementedError(msg)
Exemplo n.º 14
0
    def weight(self, output_diff):
        """Return the weight of a possible output `RXDiff`.

        The weight is a fixed-point bit-vector with
        ``self.__class__.precision`` fraction bits. It is the Hamming weight
        of the relevant carry bits plus a constant that depends on the LSB
        case (about 1.415 or exactly 3).

            >>> from arxpy.bitvector.core import Constant, Variable
            >>> from arxpy.differential.difference import RXDiff
            >>> from arxpy.differential.derivative import RXDA
            >>> n = 4
            >>> alpha = RXDiff(Constant(0, n)), RXDiff(Constant(0, n))
            >>> f = RXDA(alpha)
            >>> f.weight(RXDiff(Constant(0, n)))
            0b001001
            >>> a0, a1, b = Variable("a0", n), Variable("a1", n), Variable("b", n)
            >>> alpha = RXDiff(a0), RXDiff(a1)
            >>> f = RXDA(alpha)
            >>> result = f.weight(RXDiff(b))
            >>> result  # doctest:+NORMALIZE_WHITESPACE
            ((0x0 :: ((0b0 :: ((((a0[:1]) ^ (b[:1])) | ((a1[:1]) ^ (b[:1])))[0])) + (0b0 :: ((((a0[:1]) ^ (b[:1])) |
            ((a1[:1]) ^ (b[:1])))[1])))) << 0b000011) + (Ite(~((a0[1]) ^ (a1[1]) ^ (b[1])), 0b001001, 0b011000))
            >>> result.xreplace({a0: Constant(0, n), a1: Constant(0, n), b: Constant(0, n)})
            0b001001

        See `Derivative.weight` for more information.
        """
        alpha, beta = [d.val for d in self.input_diff]
        gamma = output_diff.val
        # Drop the LSB of each difference (equivalent to a right shift by one).
        da, db, dc = alpha[:1], beta[:1], gamma[:1]

        # Keep bits [da.width-2 .. 0]: the same as a left shift by one with
        # the (always zero) shifted-in LSB removed. Needs alpha.width >= 3.
        rhs = ((da ^ dc) | (db ^ dc))[da.width-2:]
        hw = extraop.PopCount(rhs)

        # rhs no longer contains the shifted-in LSB, so all of its bits can
        # be set and the largest Hamming weight is rhs.width. Using
        # rhs.width - 1 made the width one bit too small whenever
        # rhs.width + 3 is a power of two (e.g. n = 7), and the case-B
        # weight then wrapped around (5 + 3 = 8 -> 0).
        max_hw = rhs.width

        # (max_hw + 3) = maximum integer part
        k = self.__class__.precision  # num fraction bits
        weight_width = (max_hw + 3).bit_length() + k

        # let lhs = LSB(lhs) = da ^ db ^ dc
        #     rhs = LSB(rhs) = 0
        # case A (w=1.415): lhs => rhs
        # case B (w=3):     lhs ^ 1 => rhs
        # 1.415 = -log2(pr propagation of a rotational pair with offset 1)
        # bin(1.415) = 1.01101010001111010111

        n = alpha.width
        w_rotational_pair = -(math.log2((1 + 2**(1 - n) + 0.5 + 2 ** (-n))) - 2)
        w_rotational_pair = int(self.__class__.decimal2bin(w_rotational_pair, k), base=2)

        def bitwise_implication(x, y):
            return (~x) | y

        cte_part = operation.Ite(
            bitwise_implication(da[0] ^ db[0] ^ dc[0], core.Constant(0, 1)),
            core.Constant(w_rotational_pair, weight_width),
            core.Constant(3, weight_width) << k
        )

        hw_extend = operation.ZeroExtend(hw, weight_width - hw.width)

        # the Hamming weight is an integer: move it past the k fraction bits
        return (hw_extend << k) + cte_part