Exemple #1
0
    def eval(cls, bv):
        """Return ``bv`` with the order of its bits reversed.

        Widths 1, 2 and 3 are handled directly. Any other input is
        zero-extended one bit at a time until its width is a power of two,
        reversed by swapping neighbouring groups of 1, 2, 4, ... bits, and
        finally sliced with ``bv[:padded_width - original_width]`` so that the
        result has the original width again.
        """
        # Source: Hacker's Delight

        if bv.width == 1:
            return bv
        if bv.width == 2:
            return operation.RotateLeft(bv, 1)
        if bv.width == 3:
            return operation.Concat(operation.Concat(bv[0], bv[1]), bv[2])

        original_width = bv.width
        # Pad until the width has a single bit set (i.e. is a power of two).
        while (bv.width & (bv.width - 1)) != 0:
            bv = operation.ZeroExtend(bv, 1)
        padded_width = bv.width
        num_stages = padded_width.bit_length() - 1

        # masks[i] selects the lower half of every group of 2 * 2**i bits.
        masks = [
            repeat_pattern(pattern01(2 ** stage), padded_width)
            for stage in range(num_stages)
        ]

        def swap_groups(value, stage, mask):
            # Exchange every pair of adjacent groups of 2**stage bits.
            distance = core.Constant(2 ** stage, value.width)
            return ((value & mask) << distance) | ((value >> distance) & mask)

        if padded_width > 32:
            for stage, mask in enumerate(masks):
                bv = swap_groups(bv, stage, mask)
            return bv[:bv.width - original_width]

        for stage, mask in enumerate(masks[:3]):
            bv = swap_groups(bv, stage, mask)  # generic case

        if len(masks) == 4:
            # 16-bit case: one more generic step that swaps the two bytes.
            bv = swap_groups(bv, 3, masks[3])
        elif len(masks) == 5:
            # 32-bit case: reverse the byte order with two rotations.
            byte_mask = masks[3]
            rotated_right = operation.RotateRight(bv & byte_mask, 8)
            rotated_left = operation.RotateLeft(bv, 8)
            bv = rotated_right | (rotated_left & byte_mask)

        return bv[:bv.width - original_width]
Exemple #2
0
def pattern01(width):
    """Build the pattern 0...01...1 for the given 0-width.

    The result is the concatenation of the ``width``-bit zero constant with
    its bitwise complement.
    """
    all_zeros = core.Constant(0, width)
    all_ones = ~all_zeros
    return operation.Concat(all_zeros, all_ones)
Exemple #3
0
def bv2pysmt(bv):
    """Convert a bit-vector type to a pySMT type.

    Python integers are returned unchanged. Variables, constants and the
    supported operations are translated (recursively, for operations) into
    pySMT formulas.

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.diffcrypt.smt import bv2pysmt
        >>> bv2pysmt(Constant(0b00000001, 8))
        1_8
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> bv2pysmt(x)
        x
        >>> bv2pysmt(x +  y)
        (x + y)
        >>> bv2pysmt(x <=  y)
        (x u<= y)
        >>> bv2pysmt(x[4: 2])
        x[2:4]

    Raises:
        NotImplementedError: if no conversion is known for ``bv``, either
            because its type is not supported at all or because it is an
            operation without a pySMT counterpart.

    """
    msg = "unknown conversion of {} to a pySMT type".format(type(bv).__name__)

    if isinstance(bv, int):
        return bv

    if isinstance(bv, core.Variable):
        return sc.Symbol(bv.name, typing.BVType(bv.width))

    if isinstance(bv, core.Constant):
        return sc.BV(bv.val, bv.width)

    if not isinstance(bv, operation.Operation):
        # Unsupported objects used to fall off the end of the function and
        # silently yield None; fail loudly instead.
        raise NotImplementedError(msg)

    args = [bv2pysmt(a) for a in bv.args]
    op = type(bv)

    def is_power_of_two(width):
        return (width & (width - 1)) == 0

    def shift_on_padded_width(do_shift):
        # The shift is only translated directly when the left operand width is
        # a power of two: otherwise zero-extend both operands up to the next
        # power of two, shift there and extract the low bits again.
        x, r = bv.args
        offset = 0
        while not is_power_of_two(x.width):
            x = operation.ZeroExtend(x, 1)
            r = operation.ZeroExtend(r, 1)
            offset += 1

        shift = bv2pysmt(do_shift(x, r))
        return sc.BVExtract(shift, end=shift.bv_width() - offset - 1)

    # Operations whose converted arguments are passed straight to one pySMT
    # builder. BvComp is deliberately mapped to Equals rather than BVComp.
    direct_builders = (
        (operation.BvAnd, sc.BVAnd),
        (operation.BvOr, sc.BVOr),
        (operation.BvXor, sc.BVXor),
        (operation.BvComp, sc.Equals),
        (operation.BvUlt, sc.BVULT),
        (operation.BvUle, sc.BVULE),
        (operation.BvUgt, sc.BVUGT),
        (operation.BvUge, sc.BVUGE),
        (operation.Concat, sc.BVConcat),
        (operation.ZeroExtend, sc.BVZExt),
        (operation.BvNeg, sc.BVNeg),
        (operation.BvAdd, sc.BVAdd),
        (operation.BvSub, sc.BVSub),
        (operation.BvMul, sc.BVMul),
        (operation.BvUdiv, sc.BVUDiv),
        (operation.BvUrem, sc.BVURem),
    )
    for op_class, builder in direct_builders:
        if op == op_class:
            return builder(*args)

    if op == operation.BvNot:
        # The negation of an equality is a boolean Not, not a bit-vector one.
        if args[0].is_equals():
            return sc.Not(*args)
        return sc.BVNot(*args)

    if op == operation.BvShl:
        if is_power_of_two(args[0].bv_width()):
            return sc.BVLShl(*args)
        return shift_on_padded_width(lambda x, r: x << r)

    if op == operation.BvLshr:
        if is_power_of_two(args[0].bv_width()):
            return sc.BVLShr(*args)
        return shift_on_padded_width(lambda x, r: x >> r)

    if op == operation.RotateLeft:
        # Left hand side width must be a power of 2 for a direct translation;
        # otherwise express the rotation as a concatenation of two slices.
        if is_power_of_two(args[0].bv_width()):
            return sc.BVRol(*args)
        x, r = bv.args
        n = x.width
        return bv2pysmt(operation.Concat(x[n - r - 1:], x[n - 1:n - r]))

    if op == operation.RotateRight:
        # Same restriction and workaround as for RotateLeft.
        if is_power_of_two(args[0].bv_width()):
            return sc.BVRor(*args)
        x, r = bv.args
        n = x.width
        return bv2pysmt(operation.Concat(x[r - 1:], x[n - 1:r]))

    if op == operation.Ite:
        # The condition is used as is when it is already an equality;
        # otherwise it is compared against the 1-bit constant 1.
        if args[0].is_equals():
            condition = args[0]
        else:
            condition = sc.Equals(args[0], bv2pysmt(core.Constant(1, 1)))
        return sc.Ite(condition, *args[1:])

    if op == operation.Extract:
        # The last two arguments are passed to BVExtract in swapped order.
        return sc.BVExtract(args[0], args[2], args[1])

    if op == operation.Repeat:
        return args[0].BVRepeat(args[1])

    raise NotImplementedError(msg)
Exemple #4
0
def bv2pysmt(bv, boolean=False, strict_shift=False, env=None):
    """Translate a bit-vector term into the corresponding pySMT formula.

    Args:
        bv: the bit-vector `Term` to convert (a Python ``int`` is returned
            unchanged)
        boolean: if True, boolean pySMT types (e.g., `pysmt.shortcuts.Bool`)
            are used instead of bit-vector pySMT types
            (e.g., `pysmt.shortcuts.BV`).
        strict_shift: if `True`, shifts and rotations of bit-vectors whose
            width is not a power of two are emulated (zero-extension plus
            extraction for shifts, concatenation of slices for rotations);
            otherwise they are translated to pySMT's shifts and rotations
            directly.
        env: a `pysmt.environment.Environment`; if not specified, the
            environment returned by ``environment.reset_env()`` is used.

    Raises:
        NotImplementedError: for a ``PartialOperation`` or a term with no
            known conversion.
        ValueError: if ``boolean`` is requested for an operation that cannot
            return a boolean type.
    ::

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.smt.types import bv2pysmt
        >>> s = bv2pysmt(Constant(0b00000001, 8), boolean=False)
        >>> s, s.get_type()
        (1_8, BV{8})
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> s = bv2pysmt(x)
        >>> s, s.get_type()
        (x, BV{8})
        >>> s = bv2pysmt(x +  y)
        >>> s, s.get_type()
        ((x + y), BV{8})
        >>> s = bv2pysmt(x <=  y)
        >>> s, s.get_type()
        ((x u<= y), Bool)
        >>> s = bv2pysmt(x[4: 2])
        >>> s, s.get_type()
        (x[2:4], BV{3})

    """
    msg = "unknown conversion of {} to a pySMT type".format(type(bv).__name__)

    if env is None:
        env = environment.reset_env()
    fm = env.formula_manager

    if isinstance(bv, int):
        return bv

    def convert(term, as_boolean):
        # Recurse with the same shift policy and environment.
        return bv2pysmt(term, as_boolean, strict_shift, env)

    def builder_name(term_type, table):
        # Name of the formula-manager method paired with term_type, or None.
        for entry in table:
            if term_type == entry[0]:
                return entry[1:] if len(entry) > 2 else entry[1]
        return None

    pysmt_bv = None

    if isinstance(bv, core.Variable):
        if boolean:
            assert bv.width == 1
            sort = env.type_manager.BOOL()
        else:
            sort = env.type_manager.BVType(bv.width)
        pysmt_bv = fm.Symbol(bv.name, sort)

    elif isinstance(bv, core.Constant):
        if boolean:
            assert bv.width == 1
            pysmt_bv = fm.Bool(bool(bv))
        else:
            pysmt_bv = fm.BV(bv.val, bv.width)

    elif isinstance(bv, operation.Operation):
        # only 1st layer can return a boolean
        # Equals and Ite work well with BV, the rest don't

        if issubclass(type(bv), extraop.PartialOperation):
            raise NotImplementedError("PartialOperation is not yet supported")

        op_type = type(bv)
        want_bool = bool(boolean)

        # (operation, boolean builder name, bit-vector builder name)
        bitwise_table = (
            (operation.BvNot, "Not", "BVNot"),
            (operation.BvAnd, "And", "BVAnd"),
            (operation.BvOr, "Or", "BVOr"),
            (operation.BvXor, "Xor", "BVXor"),
        )
        bitwise = builder_name(op_type, bitwise_table)

        if bitwise is not None:
            # Operands are converted in the same mode as the result.
            if boolean:
                assert bv.width == 1
            operands = [convert(a, want_bool) for a in bv.args]
            chosen = bitwise[0] if boolean else bitwise[1]
            pysmt_bv = getattr(fm, chosen)(*operands)

        elif op_type == operation.Ite:
            # fm.Ite requires a Boolean type for the condition but
            # converting it with boolean=True caused an error
            # (if the condition is BvComp, it can be further optimized)
            condition = convert(bv.args[0], False)
            if condition.get_type().is_bv_type():
                condition = fm.Equals(condition, fm.BV(1, 1))
            if boolean:
                assert bv.width == 1
            branches = [convert(a, want_bool) for a in bv.args[1:]]
            pysmt_bv = fm.Ite(condition, *branches)

        else:
            args = [convert(a, False) for a in bv.args]

            comparison_table = (
                (operation.BvUlt, "BVULT"),
                (operation.BvUle, "BVULE"),
                (operation.BvUgt, "BVUGT"),
                (operation.BvUge, "BVUGE"),
            )
            plain_table = (
                (operation.Concat, "BVConcat"),
                (operation.ZeroExtend, "BVZExt"),
                (operation.BvNeg, "BVNeg"),
                (operation.BvAdd, "BVAdd"),
                (operation.BvSub, "BVSub"),
                (operation.BvMul, "BVMul"),
                (operation.BvUdiv, "BVUDiv"),
                (operation.BvUrem, "BVURem"),
            )
            comparison = builder_name(op_type, comparison_table)

            if op_type == operation.BvComp:
                pysmt_bv = fm.Equals(*args) if boolean else fm.BVComp(*args)

            elif comparison is not None:
                pysmt_bv = getattr(fm, comparison)(*args)

            elif boolean:
                raise ValueError("{} cannot return a boolean type".format(
                    type(bv).__name__))

            elif op_type in [operation.BvShl, operation.BvLshr]:
                if strict_shift and not _is_power_of_2(args[0].bv_width()):
                    # Pad both operands to a power-of-two width, shift there
                    # and extract the low bits again.
                    x, r = bv.args
                    offset = 0
                    while not _is_power_of_2(x.width):
                        x = operation.ZeroExtend(x, 1)
                        r = operation.ZeroExtend(r, 1)
                        offset += 1

                    shift = convert(op_type(x, r), False)
                    pysmt_bv = fm.BVExtract(shift,
                                            end=shift.bv_width() - offset - 1)
                elif op_type == operation.BvShl:
                    pysmt_bv = fm.BVLShl(*args)
                else:
                    pysmt_bv = fm.BVLShr(*args)

            elif op_type == operation.RotateLeft:
                if strict_shift and not _is_power_of_2(args[0].bv_width()):
                    # Left hand side width must be a power of 2
                    x, r = bv.args
                    n = x.width
                    rotated = operation.Concat(x[n - r - 1:], x[n - 1:n - r])
                    pysmt_bv = convert(rotated, False)
                else:
                    pysmt_bv = fm.BVRol(*args)

            elif op_type == operation.RotateRight:
                if strict_shift and not _is_power_of_2(args[0].bv_width()):
                    # Left hand side width must be a power of 2
                    x, r = bv.args
                    n = x.width
                    rotated = operation.Concat(x[r - 1:], x[n - 1:r])
                    pysmt_bv = convert(rotated, False)
                else:
                    pysmt_bv = fm.BVRor(*args)

            elif op_type == operation.Extract:
                # pySMT Extract(bv, start, end)
                pysmt_bv = fm.BVExtract(args[0], args[2], args[1])

            elif op_type == operation.Repeat:
                pysmt_bv = args[0].BVRepeat(args[1])

            else:
                plain = builder_name(op_type, plain_table)
                if plain is not None:
                    pysmt_bv = getattr(fm, plain)(*args)
                else:
                    # Unknown operation: evaluate it one step and convert the
                    # result, unless evaluation made no progress.
                    bv2 = bv.doit()
                    assert bv.width == bv2.width, "{} == {}\n{}\n{}".format(
                        bv.width, bv2.width, bv.vrepr(), bv2.vrepr())
                    if bv != bv2:  # avoid cyclic loop
                        pysmt_bv = bv2pysmt(bv2,
                                            boolean=boolean,
                                            strict_shift=strict_shift,
                                            env=env)
                    else:
                        raise NotImplementedError("(doit) " + msg)

    elif isinstance(bv, (difference.Difference, mask.Mask)):
        pysmt_bv = convert(bv.val, boolean)

    if pysmt_bv is None:
        raise NotImplementedError(msg)

    try:
        pysmt_bv_width = pysmt_bv.bv_width()
    except (AssertionError, TypeError):
        pysmt_bv_width = 1  # boolean type

    assert bv.width == pysmt_bv_width
    return pysmt_bv
Exemple #5
0
    def _weight(self, output_diff, prefix=None, debug=False, version=2):
        """Build the bit-vector expression of the weight for ``output_diff``.

        NOTE(review): presumably this is the weight of the differential
        ``self.input_diff[0] -> output_diff`` over the addition by the
        constant ``self.op.constant`` -- confirm against the enclosing class.

        Args:
            output_diff: object whose ``val`` attribute is the output
                difference bit-vector.
            prefix: name prefix of the auxiliary variables created when the
                differences are not constants; by default ``"tmp"`` followed
                by a number derived from the hashes of the operands.
            debug: if True, print the binary representation of the
                intermediate values.
            version: 0, 1 or 2; selects the variant of the formulas
                (0-reference, 1-w/o extra reverse, 2-s_000 and no HW2 in fr).

        Returns:
            ``int_frac`` if both the input and the output difference are
            ``core.Constant``; otherwise the pair ``(int_frac, assertions)``
            where ``assertions`` is the list of ``BvComp`` constraints that
            define the auxiliary variables used inside ``int_frac``.
            The integer part is concatenated with ``k`` zero bits before the
            fraction counts are subtracted, ``k`` being
            ``self._effective_precision`` (presumably a fixed-point value with
            ``k`` fractional bits -- TODO confirm).

        Raises:
            ValueError: if ``self._effective_precision`` is not one of
                0, 1, 2, 3, 4.
        """
        u = self.input_diff[0].val
        v = output_diff.val
        a = self.op.constant
        n = a.width
        one = core.Constant(1, n)

        assert self._effective_width == n - 1

        assert version in [0, 1, 2]  # 0-reference, 1-w/o extra reverse, 2-s_000 and no HW2 in fr

        if prefix is None:
            prefix = "tmp" + str(abs(hash(u) + hash(v) + hash(a)))

        if isinstance(u, core.Constant) and isinstance(v, core.Constant):
            are_cte_differences = True
        else:
            # Symbolic case: auxiliary variables are numbered with a counter
            # kept on the instance and their defining constraints collected.
            self._i_auxvar = 0
            assertions = []
            are_cte_differences = False

        # With constant differences rev() and lz() return the operation itself.
        # Otherwise every call creates a fresh auxiliary variable, records the
        # constraint "aux == operation(x)" and returns the variable, so the
        # names of the auxiliary variables depend on the order of the calls.
        def rev(x):
            if are_cte_differences:
                return extraop.Reverse(x)
            else:
                aux = core.Variable("{}_{}rev".format(prefix, self._i_auxvar), x.width)
                self._i_auxvar += 1
                assertions.append(operation.BvComp(aux, extraop.Reverse(x)))
                return aux

        def lz(x):
            if are_cte_differences:
                return extraop.LeadingZeros(x)
            else:
                aux = core.Variable("{}_{}lz".format(prefix, self._i_auxvar), x.width)
                self._i_auxvar += 1
                assertions.append(operation.BvComp(aux, extraop.LeadingZeros(x)))
                return aux

        def carry(x, y):
            # Sum XOR both operands (the carry bits of x + y).
            return (x + y) ^ x ^ y

        def rev_carry(x, y):
            # carry() applied to the reversed operands, reversed back.
            return rev(carry(rev(x), rev(y)))

        if version in [0, 1]:
            s00_old = (~(u << one)) & (~(v << one))  # i-bit is True if S_{i} = 00*
        else:
            s00_old = ((~u) & (~v)) << one
        s00_ = s00_old & (~lz(~s00_old))  # if x is 001*...*, then lz(x) = 1100...0

        if version == 0:
            e_i1 = s00_ & (~ (s00_ >> one))  # e_{i-1}
            e_ili = ~s00_ & (s00_ >> one)  # e_{i-l_i}
        else:
            e_i1 = s00_old & (~ (s00_old >> one))  # e_{i-1}
            e_ili = ~s00_old & (s00_old >> one)  # e_{i-l_i}

        q = ~( (a << one) ^ (u ^ v) )  # q[i] = ~(a[i-1]^u[i]^v[i])
        q = ((q >> one ) & e_i1)  # q[i-1, i-3] = (a[i-1]^u[i]^v[i], 0, 0)

        if version == 0:
            s = ((a << one) & e_ili) + (a & (s00_ >> one))
        else:
            s = ((a << one) & e_ili) + (a & (s00_old >> one))

        if version == 0:
            d = rev_carry(s00_, q) | q
        else:
            # rev(s00_old) is kept in a variable because versions 1 and 2
            # reuse it below when computing rev_h.
            rev_s00_old = rev(s00_old)
            d = rev(carry(rev_s00_old, rev(q))) | q

        w = (q - (s & d)) | (s & (~d))

        # Version 0 works with w and h; versions 1 and 2 work with their
        # reversed counterparts rev_w and rev_h instead.
        if version == 0:
            w = w << one
            h = rev_carry(s00_ << one, w & (s00_ << one))
        elif version == 1:
            rev_w = rev(w) >> one
            # NOTE(review): `>>` binds tighter than `&`, so the second argument
            # is rev_w & (rev(s00_) >> one) -- confirm this is the intent.
            rev_h = carry( (rev_s00_old + one) >> one, rev_w & (rev(s00_)) >> one)
        else:
            rev_w = rev(w)
            rev_h = carry(rev_s00_old + one, rev_w & rev_s00_old)

        sbnegb = (u ^ v) << one  # i-bit is True if S_{i} = (b, \neg b, *)

        # NOTE: the local name `int` shadows the builtin for the rest of the
        # method; the builtin is not called below.
        if version == 0:
            int = extraop.PopCountDiff(sbnegb | s00_, h)   # or hw(sbminb_) + (hw(s00_) - hw(h))
        else:
            int = extraop.PopCountDiff(sbnegb | s00_, rev_h)

        # Addition/subtraction of operands of possibly different widths: the
        # narrower operand is zero-extended to the width of the wider one.
        def smart_add(x, y):
            if x.width == y.width:
                return x + y
            elif x.width < y.width:
                return operation.ZeroExtend(x, y.width - x.width) + y
            else:
                return x + operation.ZeroExtend(y, x.width - y.width)

        def smart_sub(x, y):
            # cannot be replaced by smart_add(x, -y)
            if x.width == y.width:
                return x - y
            elif x.width < y.width:
                return operation.ZeroExtend(x, y.width - x.width) - y
            else:
                return x - operation.ZeroExtend(y, x.width - y.width)

        k = self._effective_precision

        # For k > 0 the integer part is concatenated with k zero bits and the
        # fraction counts (f1, f12, f3, f34) are subtracted from it.
        if k == 0:
            int_frac = int
        elif k == 1:
            int = operation.Concat(int, core.Constant(0, 1))
            if version == 0:
                f1 = extraop.PopCount(w & h & (~(h >> one)))  # each one adds 2^(-1)
            else:
                f1 = extraop.PopCount(rev_w & rev_h & (~(rev_h << one)))
            int_frac = smart_sub(int, f1)
        else:
            two = core.Constant(2, n)
            three = core.Constant(3, n)
            four = core.Constant(4, n)

            # f12 is computed for every k >= 2, including values of k that are
            # rejected with ValueError further down.
            if version == 0:
                f12 = extraop.PopCountSum2(
                    w & h & (~(h >> one)),
                    w & h & ((~(h >> one)) | (~(h >> two)) & (h >> one))
                )  # each one adds 2^(-2), that's why ~(h >> one) need to be counted twice
            elif version == 1:
                f12 = extraop.PopCountSum2(
                    rev_w & rev_h & (~(rev_h << one)),
                    rev_w & rev_h & ((~(rev_h << one)) | (~(rev_h << two)) & (rev_h << one))
                )
            else:
                f12 = extraop.PopCount(
                    # ( ( rev_w & rev_h & (~(rev_h << one)) ) >> one ) |
                    ( ( (rev_w & rev_h) >> one) & (~rev_h)  ) |
                    (rev_w & rev_h & ((~(rev_h << one)) | (~(rev_h << two)) & (rev_h << one)))
                )

            if k == 2:
                int = operation.Concat(int, core.Constant(0, 2))
                int_frac = smart_sub(int, f12)
            elif k == 3:
                # f3 cannot be included in f12, since ~(h >> one) would need to be counted 4 times
                if version == 0:
                    f3 = extraop.PopCount(w & h & (h >> one) & (h >> two) & (~(h >> three)))
                else:
                    f3 = extraop.PopCount(rev_w & rev_h & (rev_h << one) & (rev_h << two) & (~(rev_h << three)))
                int = operation.Concat(int, core.Constant(0, 3))
                # f12 gets one extra zero bit appended before it is added to f3.
                f12 = operation.Concat(f12, core.Constant(0, 1))
                int_frac = smart_sub(int, smart_add(f12, f3))
            elif k == 4:
                if version == 0:
                    f34 = extraop.PopCountSum2(
                        w & h & (h >> one) & (h >> two) & (~(h >> three)),
                        w & h & (h >> one) & (h >> two) & ((~(h >> three)) | (~(h >> four) & (h >> three)))
                    )
                elif version == 1:
                    f34 = extraop.PopCountSum2(
                        rev_w & rev_h & (rev_h << one) & (rev_h << two) & (~(rev_h << three)),
                        rev_w & rev_h & (rev_h << one) & (rev_h << two) & ((~(rev_h << three)) | (~(rev_h << four) & (rev_h << three)))
                    )
                else:
                    f34 = extraop.PopCount(
                        # ( (rev_w & rev_h & (rev_h << one) & (rev_h << two) & (~(rev_h << three))) >> one ) |
                        ( ((rev_w & rev_h) >> one) & rev_h & (rev_h << one) & (~(rev_h << two))) |
                        (rev_w & rev_h & (rev_h << one) & (rev_h << two) & ((~(rev_h << three)) | (~(rev_h << four) & (rev_h << three))))
                    )
                int = operation.Concat(int, core.Constant(0, 4))
                # f12 gets two extra zero bits appended before it is added to f34.
                f12 = operation.Concat(f12, core.Constant(0, 2))
                int_frac = smart_sub(int, smart_add(f12, f34))
            else:
                raise ValueError("precision must be between 0 and 4")

        if debug:
            # Dump the intermediate values; which ones exist depends on
            # `version` and on the precision k.
            print("\n\n ~~ ")
            print("u:            ", u.bin())
            print("v:            ", v.bin())
            print("a:            ", a.bin())
            print("s00_:         ", s00_.bin())
            print("e_i1:         ", e_i1.bin())
            print("e_ili1:       ", e_ili.bin())
            print("q:            ", q.bin())
            print("s:            ", s.bin())
            print("d:            ", d.bin())
            print("w:            ", w.bin())
            if version == 0:
                print("h:            ", h.bin())
            else:
                print("rev_w:        ", rev_w.bin())
                print("rev_h:        ", rev_h.bin())
            print("sbnegb:       ", sbnegb.bin())
            print("int:          ", int.bin())
            if k == 1:
                print("f1:           ", f1.bin())
            elif k > 1:
                print("f12:          ", f12.bin())
                if k == 3:
                    print("f3:           ", f3.bin())
                elif k == 4:
                    print("f34:          ", f34.bin())
            print("int_frac:     ", int_frac.bin())

        # The constraints on the auxiliary variables only exist in the
        # symbolic case, so they are only returned there.
        if are_cte_differences:
            return int_frac
        else:
            return int_frac, assertions