Exemple #1
0
    def test_pysmt_operations(self, width, x, y):
        """Cross-check constant bit-vector operations against pySMT.

        ``x`` and ``y`` are reduced modulo ``2 ** width`` and wrapped both as
        ``Constant`` objects and as pySMT ``BV`` constants.  Each operation is
        applied on both sides and the simplified pySMT constant value is
        compared with the ``Constant`` result.  The test returns silently
        when pySMT cannot be imported.
        """
        try:
            from pysmt import shortcuts as sc
        except ImportError:
            return

        mask = 2 ** width
        x %= mask
        y %= mask
        bvx, bvy = Constant(x, width), Constant(y, width)
        psx, psy = sc.BV(x, width), sc.BV(y, width)

        def check(expected, term):
            # Fold the pySMT term down to a constant and compare its value.
            self.assertEqual(expected, term.simplify().constant_value())

        # Bitwise operations.
        check(~bvx, sc.BVNot(psx))
        check(bvx & bvy, sc.BVAnd(psx, psy))
        check(bvx | bvy, sc.BVOr(psx, psy))
        check(bvx ^ bvy, sc.BVXor(psx, psy))

        # Equality and unsigned comparisons.
        check(BvComp(bvx, bvy), sc.BVComp(psx, psy))
        check((bvx < bvy), sc.BVULT(psx, psy))
        check((bvx <= bvy), sc.BVULE(psx, psy))
        check((bvx > bvy), sc.BVUGT(psx, psy))
        check((bvx >= bvy), sc.BVUGE(psx, psy))

        # Shifts by a bit-vector amount, rotations by an int below the width.
        amount = y % bvx.width
        check(bvx << bvy, sc.BVLShl(psx, psy))
        check(bvx >> bvy, sc.BVLShr(psx, psy))
        check(RotateLeft(bvx, amount), sc.BVRol(psx, amount))
        check(RotateRight(bvx, amount), sc.BVRor(psx, amount))

        # If-then-else on a 1-bit condition, extraction and width changes.
        bvb = Constant(y % 2, 1)
        psb = sc.Bool(bool(bvb))
        check(Ite(bvb, bvx, bvy), sc.Ite(psb, psx, psy))
        check(bvx[:amount], sc.BVExtract(psx, start=amount))
        check(bvx[amount:], sc.BVExtract(psx, end=amount))
        check(Concat(bvx, bvy), sc.BVConcat(psx, psy))
        check(ZeroExtend(bvx, amount), sc.BVZExt(psx, amount))
        check(Repeat(bvx, 1 + amount), psx.BVRepeat(1 + amount))

        # Arithmetic; division and remainder only for a non-zero divisor.
        check(-bvx, sc.BVNeg(psx))
        check(bvx + bvy, sc.BVAdd(psx, psy))
        check(bvx - bvy, sc.BVSub(psx, psy))
        check(bvx * bvy, sc.BVMul(psx, psy))
        if bvy > 0:
            check(bvx / bvy, sc.BVUDiv(psx, psy))
            check(bvx % bvy, sc.BVURem(psx, psy))
Exemple #2
0
    def __getitem__(self, index):
        """Return a single bit or a contiguous sub-vector.

        An ``int`` index (negative values count from ``self.size``) yields a
        ``Bit`` of this vector's family that is true when the selected bit
        equals 1; an out-of-range index raises ``IndexError``.

        A ``slice`` yields a vector of ``type(self).unsized_t`` sized to the
        extracted term.  Negative bounds are offset by ``self.size``, ``stop``
        is capped at ``self.size`` and only a step of 1 (or ``None``) is
        accepted, otherwise ``IndexError`` is raised.

        Any other index type raises ``TypeError``.
        """
        width = self.size

        if isinstance(index, int):
            position = width + index if index < 0 else index
            if not (0 <= position < width):
                raise IndexError()
            bit = self.value[position]
            return self.get_family().Bit(smt.Equals(bit, smt.BV(1, 1)))

        if not isinstance(index, slice):
            raise TypeError()

        low, high, stride = index.start, index.stop, index.step

        if low is None:
            low = 0
        elif low < 0:
            low = width + low

        if high is None:
            high = width
        elif high < 0:
            high = width + high
        high = min(high, width)

        if stride is not None and stride != 1:
            raise IndexError('SMT extract does not support step != 1')

        # The underlying term is sliced with ``high - 1`` as its upper bound
        # (presumably because that slice is end-inclusive -- confirm).
        extracted = self.value[low:high - 1]
        return type(self).unsized_t[extracted.get_type().width](extracted)
Exemple #3
0
    def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC):
        """Build the vector from a symbol, an SMT term, a peer or an int.

        Args:
            value: ``SMYBOLIC`` for a new symbol, a pySMT ``FNode`` whose type
                is exactly ``BVType(self.size)``, another instance of this
                class (its term is reused) or an ``int``.
            name: Optional symbol name; only allowed for symbolic values and
                recorded in ``_name_table``.

        Raises:
            TypeError: a name is given for a non-symbolic value, the name is
                not a string, the term has the wrong type, or ``value``
                cannot be coerced.
            ValueError: the name is already present in ``_name_table``.
        """
        has_name = name is not AUTOMATIC
        is_symbolic = value is SMYBOLIC

        if has_name:
            if not is_symbolic:
                raise TypeError('Can only name symbolic variables')
            if not isinstance(name, str):
                raise TypeError('Name must be string')
            if name in _name_table:
                raise ValueError(f'Name {name} already in use')
            _name_table[name] = self

        sort = BVType(self.size)

        if is_symbolic:
            if has_name:
                term = shortcuts.Symbol(name, sort)
            else:
                term = shortcuts.FreshSymbol(sort)
        elif isinstance(value, pysmt.fnode.FNode):
            found = value.get_type()
            if found is not sort:
                raise TypeError(f'Expected {sort} not {found}')
            term = value
        elif isinstance(value, type(self)):
            term = value._value
        elif isinstance(value, int):
            term = shortcuts.BV(value, self.size)
        else:
            raise TypeError(f"Can't coerce {value} to SMTFPVector")

        self._name = name
        self._value = term
Exemple #4
0
def test_poly_smt():
    """Check how operations on a signed/unsigned mux lower to pySMT terms.

    The expected terms are built by hand with ``sc`` and compared
    structurally against the ``.value`` of the expressions under test.
    """
    signed_t = SMTSIntVector[8]
    unsigned_t = SMTUIntVector[8]

    cond = SMTBit(name='c1')
    ua = unsigned_t(name='u1')
    ub = unsigned_t(name='u2')
    sa = signed_t(name='s1')
    sb = signed_t(name='s2')

    # ``==`` on pySMT terms (reached through ``.value``) is strict structural
    # equivalence, so syntactically different terms compare unequal:
    assert ua.value == ua.value
    assert ua.value != ub.value
    assert (ua * 2).value != (ua + ua).value
    assert (ua + ub).value == (ua + ub).value
    assert (ua + ub).value != (ub + ua).value

    # The real test: ``< 1`` on a mux of unsigned and signed operands.
    less_than = cond.ite(ua, sa) < 1
    cond_term, ua_term, sa_term = cond.value, ua.value, sa.value
    mux = sc.Ite(cond_term, ua_term, sa_term)
    one = sc.BV(1, 8)
    # The comparison dispatches symbolically: unsigned on one arm of the
    # condition, signed on the other.
    expected_lt = sc.Ite(cond_term, sc.BVULT(mux, one), sc.BVSLT(mux, one))
    assert less_than.value == expected_lt

    extended = less_than.ite(cond.ite(ua, sa), cond.ite(sb, ub)).ext(1)

    other_mux = sc.Ite(cond_term, sb.value, ub.value)
    selected = sc.Ite(expected_lt, mux, other_mux)

    sign_ext = sc.BVSExt(selected, 1)
    zero_ext = sc.BVZExt(selected, 1)

    expected_ext = sc.Ite(
        expected_lt,
        sc.Ite(cond_term, zero_ext, sign_ext),
        sc.Ite(cond_term, sign_ext, zero_ext),
    )
    # ``ext`` dispatches symbolically and recursively through both muxes.
    assert extended.value == expected_ext

    # Polymorphic types only build muxes when they need to: addition is the
    # same operation on either arm, so a single BVAdd is emitted ...
    total = cond.ite(ua, sa) + 1
    plus_one = sc.BVAdd(mux, one)
    assert total.value == plus_one
    # ... rather than a mux over two identical additions.
    assert total.value != sc.Ite(cond_term, plus_one, plus_one)
Exemple #5
0
    def test_bv2pysmt(self):
        """Compare ``bv2pysmt`` output with hand-built pySMT terms.

        Uses two 8-bit variables ``x`` and ``y`` and checks constants,
        symbols and each supported operation structurally.
        """
        same = self.assertEqual
        to_smt = bv2pysmt

        var_x, var_y = Variable("x", 8), Variable("y", 8)
        smt_x, smt_y = to_smt(var_x), to_smt(var_y)

        # Leaves: constants and variables.
        same(to_smt(Constant(0, 8)), sc.BV(0, 8))
        same(smt_x, sc.Symbol("x", typing.BVType(8)))

        # Bitwise operations.
        same(to_smt(~var_x), sc.BVNot(smt_x))
        same(to_smt(var_x & var_y), sc.BVAnd(smt_x, smt_y))
        same(to_smt(var_x | var_y), sc.BVOr(smt_x, smt_y))
        same(to_smt(var_x ^ var_y), sc.BVXor(smt_x, smt_y))

        # BvComp becomes a Boolean equality; its negation a Boolean Not.
        same(to_smt(BvComp(var_x, var_y)), sc.Equals(smt_x, smt_y))
        same(to_smt(BvNot(BvComp(var_x, var_y))),
             sc.Not(sc.Equals(smt_x, smt_y)))

        # Unsigned comparisons.
        same(to_smt(var_x < var_y), sc.BVULT(smt_x, smt_y))
        same(to_smt(var_x <= var_y), sc.BVULE(smt_x, smt_y))
        same(to_smt(var_x > var_y), sc.BVUGT(smt_x, smt_y))
        same(to_smt(var_x >= var_y), sc.BVUGE(smt_x, smt_y))

        # Shifts and rotations.
        same(to_smt(var_x << var_y), sc.BVLShl(smt_x, smt_y))
        same(to_smt(var_x >> var_y), sc.BVLShr(smt_x, smt_y))
        same(to_smt(RotateLeft(var_x, 1)), sc.BVRol(smt_x, 1))
        same(to_smt(RotateRight(var_x, 1)), sc.BVRor(smt_x, 1))

        # Extraction, concatenation and repetition.  ZeroExtend is not
        # compared against BVZExt because it reduces to Concat.
        same(to_smt(var_x[4:2]), sc.BVExtract(smt_x, 2, 4))
        same(to_smt(Concat(var_x, var_y)), sc.BVConcat(smt_x, smt_y))
        same(to_smt(Repeat(var_x, 2)), smt_x.BVRepeat(2))

        # Arithmetic.  Subtraction is not compared against BVSub because it
        # reduces to an addition.
        same(to_smt(-var_x), sc.BVNeg(smt_x))
        same(to_smt(var_x + var_y), sc.BVAdd(smt_x, smt_y))
        same(to_smt(var_x * var_y), sc.BVMul(smt_x, smt_y))
        same(to_smt(var_x / var_y), sc.BVUDiv(smt_x, smt_y))
        same(to_smt(var_x % var_y), sc.BVURem(smt_x, smt_y))
Exemple #6
0
 def __encodeTerminal(symbolicExpression, type):
     """Encode a terminal sympy expression as a pySMT symbol or constant.

     Args:
         symbolicExpression: A ``sympy.Symbol`` or a constant terminal.
         type: Descriptor whose ``literal`` is ``'Integer'``, ``'Real'``,
             ``'Bool'`` or anything else, which is treated as a bit-vector
             of width ``type.size``.  (The parameter name shadows the
             builtin; it is kept for interface compatibility.)

     Returns:
         A pySMT symbol carrying the expression's name when it is a
         ``sympy.Symbol``, otherwise a pySMT constant of the requested sort.
     """
     kind = type.literal

     if isinstance(symbolicExpression, sympy.Symbol):
         if kind == 'Integer':
             sort = pysmt.INT
         elif kind == 'Real':
             sort = pysmt.REAL
         elif kind == 'Bool':
             sort = pysmt.BOOL
         else:  # kind == 'BitVector'
             sort = pysmt.BVType(type.size)
         return pysmt.Symbol(symbolicExpression.name, sort)

     if kind == 'Integer':
         return pysmt.Int(symbolicExpression.p)

     if kind == 'Real':
         if isinstance(symbolicExpression, sympy.Rational):
             # Hand over numerator and denominator as a tuple so the constant
             # stays exact; ``p / q`` would round through a binary float
             # (1/3 would no longer be 1/3).
             return pysmt.Real((int(symbolicExpression.p),
                                int(symbolicExpression.q)))
         if isinstance(symbolicExpression, sympy.Float):
             # Convert to a native float before building the constant.
             return pysmt.Real(float(symbolicExpression))
         return pysmt.Real(symbolicExpression)

     if kind == 'Bool':
         # Convert sympy Boolean constants to a native bool; a native bool
         # passes through unchanged.
         return pysmt.Bool(bool(symbolicExpression))

     # kind == 'BitVector': convert to a native int before building the
     # bit-vector constant.
     return pysmt.BV(int(symbolicExpression), type.size)
Exemple #7
0
    def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC, prefix=AUTOMATIC):
        """Build a bit-vector of ``self.size`` bits around a pySMT term.

        Args:
            value: One of
                * ``SMYBOLIC`` -- create a symbol (named via ``name``,
                  ``prefix`` or an auto-generated name);
                * a pySMT ``FNode`` whose type is exactly
                  ``BVType(self.size)``;
                * an ``SMTBitVector`` -- truncated (with a warning) or
                  zero-extended to ``self.size``;
                * an ``SMTBit`` -- mapped to ``BVOne``/``BVZero``;
                * a sequence of exactly ``self.size`` elements, each
                  converted to a 1-bit vector and concatenated;
                * an ``int`` or any object with ``__int__`` -- reduced
                  modulo ``2 ** self.size``.
            name: Explicit symbol name (symbolic values only).
            prefix: Prefix for a generated symbol name (symbolic values
                only); mutually exclusive with ``name``.

        Raises:
            TypeError: ``name``/``prefix`` given for a non-symbolic value,
                ``name`` is not a string, the term has the wrong type, or
                ``value`` cannot be coerced.
            ValueError: both ``name`` and ``prefix`` are given, the name is
                already in ``_name_table``, or a sequence has the wrong
                length.
        """
        if (name is not AUTOMATIC
                or prefix is not AUTOMATIC) and value is not SMYBOLIC:
            raise TypeError('Can only name symbolic variables')
        elif name is not AUTOMATIC and prefix is not AUTOMATIC:
            raise ValueError('Can only set either name or prefix not both')
        elif name is not AUTOMATIC:
            if not isinstance(name, str):
                raise TypeError('Name must be string')
            elif name in _name_table:
                raise ValueError(f'Name {name} already in use')
            elif _name_re.fullmatch(name):
                warnings.warn(
                    'Name looks like an auto generated name, this might break things'
                )
            _name_table[name] = self
        elif prefix is not AUTOMATIC:
            name = _gen_name(prefix)
            _name_table[name] = self
        elif name is AUTOMATIC and value is SMYBOLIC:
            name = _gen_name()
            _name_table[name] = self

        self._name = name

        T = BVType(self.size)

        if value is SMYBOLIC:
            self._value = smt.Symbol(name, T)
        elif isinstance(value, pysmt.fnode.FNode):
            t = value.get_type()
            if t is T:
                self._value = value
            else:
                raise TypeError(f'Expected {T} not {t}')
        elif isinstance(value, SMTBitVector):
            # NOTE(review): a non-symbolic value combined with an explicit
            # name is rejected at the top of this method, so this warning
            # looks unreachable -- confirm before removing it.
            if name is not AUTOMATIC and name != value.name:
                warnings.warn(
                    'Changing the name of a SMTBitVector does not cause a new underlying smt variable to be created'
                )

            ext = self.size - value.size

            if ext < 0:
                warnings.warn(
                    f'Truncating value from {type(value)} to {type(self)}')
                self._value = value[:self.size].value
            elif ext > 0:
                self._value = value.zext(ext).value
            else:
                self._value = value.value

        elif isinstance(value, SMTBit):
            self._value = smt.Ite(value.value, smt.BVOne(self.size),
                                  smt.BVZero(self.size))

        elif isinstance(value, tp.Sequence):
            if len(value) != self.size:
                raise ValueError('Iterable is not the correct size')
            cls = type(self)
            B1 = cls.unsized_t[1]
            # Convert every element to a 1-bit vector and fold them together
            # with ``concat`` in iteration order.
            self._value = ft.reduce(lambda acc, elem: acc.concat(elem),
                                    map(B1, value)).value
        elif isinstance(value, int):
            self._value = smt.BV(value % (1 << self.size), self.size)

        elif hasattr(value, '__int__'):
            # Wrap int-like objects exactly like plain ints so negative or
            # oversized values reduce modulo 2**size instead of reaching
            # ``smt.BV`` unreduced.
            self._value = smt.BV(int(value) % (1 << self.size), self.size)
        else:
            raise TypeError(
                f"Can't coerce {type(value)} to SMTBitVector")

        self._value = smt.simplify(self._value)
        assert self._value.get_type() is T
Exemple #8
0
# Lazily filled cache for ``_direct_converters``; it stays empty until the
# first operation is converted so nothing is looked up at import time.
_DIRECT_CONVERTERS = {}


def _direct_converters():
    """Return the map from operation type to its one-to-one pySMT builder."""
    if not _DIRECT_CONVERTERS:
        _DIRECT_CONVERTERS.update({
            operation.BvAnd: sc.BVAnd,
            operation.BvOr: sc.BVOr,
            operation.BvXor: sc.BVXor,
            # BvComp is mapped to a Boolean equality, not to sc.BVComp.
            operation.BvComp: sc.Equals,
            operation.BvUlt: sc.BVULT,
            operation.BvUle: sc.BVULE,
            operation.BvUgt: sc.BVUGT,
            operation.BvUge: sc.BVUGE,
            operation.Concat: sc.BVConcat,
            operation.ZeroExtend: sc.BVZExt,
            operation.BvNeg: sc.BVNeg,
            operation.BvAdd: sc.BVAdd,
            operation.BvSub: sc.BVSub,
            operation.BvMul: sc.BVMul,
            operation.BvUdiv: sc.BVUDiv,
            operation.BvUrem: sc.BVURem,
        })
    return _DIRECT_CONVERTERS


def _is_power_of_two(n):
    """Return True when ``n & (n - 1)`` is zero (the width test used below)."""
    return (n & (n - 1)) == 0


def _padded_shift(bv, shift):
    """Convert a shift whose operand width is not a power of two.

    Both operands are zero-extended one bit at a time until the width is a
    power of two, the shift is converted at that width, and the low bits are
    extracted again so the result has the original width.
    """
    x, r = bv.args
    padding = 0
    while not _is_power_of_two(x.width):
        x = operation.ZeroExtend(x, 1)
        r = operation.ZeroExtend(r, 1)
        padding += 1

    shifted = bv2pysmt(shift(x, r))
    return sc.BVExtract(shifted, end=shifted.bv_width() - padding - 1)


def bv2pysmt(bv):
    """Convert a bit-vector type to a pySMT type.

    Plain ``int`` arguments (rotation amounts, extraction bounds, ...) are
    returned unchanged.  Variables become symbols, constants become ``BV``
    constants and operations are converted recursively.

    Raises:
        NotImplementedError: if ``bv`` is not an ``int``, variable, constant
            or supported operation.

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.diffcrypt.smt import bv2pysmt
        >>> bv2pysmt(Constant(0b00000001, 8))
        1_8
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> bv2pysmt(x)
        x
        >>> bv2pysmt(x +  y)
        (x + y)
        >>> bv2pysmt(x <=  y)
        (x u<= y)
        >>> bv2pysmt(x[4: 2])
        x[2:4]

    """
    if isinstance(bv, int):
        return bv

    if isinstance(bv, core.Variable):
        return sc.Symbol(bv.name, typing.BVType(bv.width))

    if isinstance(bv, core.Constant):
        return sc.BV(bv.val, bv.width)

    if isinstance(bv, operation.Operation):
        args = [bv2pysmt(a) for a in bv.args]
        op_type = type(bv)

        convert = _direct_converters().get(op_type)
        if convert is not None:
            return convert(*args)

        if op_type is operation.BvNot:
            # The negation of a Boolean equality is a Boolean Not; anything
            # else is a bit-vector complement.
            if args[0].is_equals():
                return sc.Not(*args)
            return sc.BVNot(*args)

        if op_type is operation.BvShl:
            # The direct shift is only used when the left-hand side width is
            # a power of two; other widths go through the padded fallback.
            if _is_power_of_two(args[0].bv_width()):
                return sc.BVLShl(*args)
            return _padded_shift(bv, lambda x, r: x << r)

        if op_type is operation.BvLshr:
            if _is_power_of_two(args[0].bv_width()):
                return sc.BVLShr(*args)
            return _padded_shift(bv, lambda x, r: x >> r)

        if op_type is operation.RotateLeft:
            if _is_power_of_two(args[0].bv_width()):
                return sc.BVRol(*args)
            # Otherwise build the rotation from two extractions: the low
            # ``n - r`` bits become the high part of the result.
            x, r = bv.args
            n = x.width
            return bv2pysmt(operation.Concat(x[n - r - 1:],
                                             x[n - 1:n - r]))

        if op_type is operation.RotateRight:
            if _is_power_of_two(args[0].bv_width()):
                return sc.BVRor(*args)
            # The low ``r`` bits become the high part of the result.
            x, r = bv.args
            n = x.width
            return bv2pysmt(operation.Concat(x[r - 1:], x[n - 1:r]))

        if op_type is operation.Ite:
            # The condition must be Boolean: reuse an equality as-is,
            # otherwise compare the 1-bit condition with the constant 1.
            if args[0].is_equals():
                condition = args[0]
            else:
                condition = sc.Equals(args[0], bv2pysmt(core.Constant(1, 1)))
            return sc.Ite(condition, *args[1:])

        if op_type is operation.Extract:
            # The operation stores its bounds as (i, j); BVExtract receives
            # them swapped, as (j, i).
            return sc.BVExtract(args[0], args[2], args[1])

        if op_type is operation.Repeat:
            return args[0].BVRepeat(args[1])

    # Reached for unsupported operations and for any other unsupported type,
    # so an unknown input can no longer fall through and return None.
    raise NotImplementedError(
        "unknown conversion of {} to a pySMT type".format(type(bv).__name__))