def test_pysmt_operations(self, width, x, y):
    """Cross-check ``Constant`` bit-vector operations against pySMT.

    ``x`` and ``y`` are reduced modulo ``2 ** width`` and wrapped both as
    ``Constant`` values and as pySMT ``BV`` constants. Each operation is
    applied on both representations and the simplified pySMT constant must
    equal the ``Constant`` result.

    The test is reported as skipped (instead of silently passing) when
    pysmt cannot be imported.
    """
    try:
        from pysmt import shortcuts as sc
    except ImportError:
        # A bare ``return`` would count as a pass although nothing was
        # checked; skipping makes the missing optional dependency visible.
        self.skipTest("pysmt is not installed")

    modulus = 2 ** width
    x = x % modulus
    y = y % modulus

    bvx = Constant(x, width)
    bvy = Constant(y, width)
    psx = sc.BV(x, width)
    psy = sc.BV(y, width)

    def eval_pysmt(pysmt_var):
        # Fold the pySMT expression down to a plain constant value.
        return pysmt_var.simplify().constant_value()

    # Bitwise operations.
    self.assertEqual(~bvx, eval_pysmt(sc.BVNot(psx)))
    self.assertEqual(bvx & bvy, eval_pysmt(sc.BVAnd(psx, psy)))
    self.assertEqual(bvx | bvy, eval_pysmt(sc.BVOr(psx, psy)))
    self.assertEqual(bvx ^ bvy, eval_pysmt(sc.BVXor(psx, psy)))

    # Comparisons.
    self.assertEqual(BvComp(bvx, bvy), eval_pysmt(sc.BVComp(psx, psy)))
    self.assertEqual((bvx < bvy), eval_pysmt(sc.BVULT(psx, psy)))
    self.assertEqual((bvx <= bvy), eval_pysmt(sc.BVULE(psx, psy)))
    self.assertEqual((bvx > bvy), eval_pysmt(sc.BVUGT(psx, psy)))
    self.assertEqual((bvx >= bvy), eval_pysmt(sc.BVUGE(psx, psy)))

    # One in-range integer offset derived from ``y`` serves as rotation
    # amount, slice bound, extension size and repeat count below.
    offset = y % bvx.width

    # Shifts and rotations.
    self.assertEqual(bvx << bvy, eval_pysmt(sc.BVLShl(psx, psy)))
    self.assertEqual(bvx >> bvy, eval_pysmt(sc.BVLShr(psx, psy)))
    self.assertEqual(RotateLeft(bvx, offset), eval_pysmt(sc.BVRol(psx, offset)))
    self.assertEqual(RotateRight(bvx, offset), eval_pysmt(sc.BVRor(psx, offset)))

    # If-then-else driven by the lowest bit of ``y``.
    bvb = Constant(y % 2, 1)
    psb = sc.Bool(bool(bvb))
    self.assertEqual(Ite(bvb, bvx, bvy), eval_pysmt(sc.Ite(psb, psx, psy)))

    # Extraction, concatenation, extension and repetition.
    self.assertEqual(bvx[:offset], eval_pysmt(sc.BVExtract(psx, start=offset)))
    self.assertEqual(bvx[offset:], eval_pysmt(sc.BVExtract(psx, end=offset)))
    self.assertEqual(Concat(bvx, bvy), eval_pysmt(sc.BVConcat(psx, psy)))
    self.assertEqual(ZeroExtend(bvx, offset), eval_pysmt(sc.BVZExt(psx, offset)))
    self.assertEqual(Repeat(bvx, 1 + offset), eval_pysmt(psx.BVRepeat(1 + offset)))

    # Arithmetic.
    self.assertEqual(-bvx, eval_pysmt(sc.BVNeg(psx)))
    self.assertEqual(bvx + bvy, eval_pysmt(sc.BVAdd(psx, psy)))
    self.assertEqual(bvx - bvy, eval_pysmt(sc.BVSub(psx, psy)))
    self.assertEqual(bvx * bvy, eval_pysmt(sc.BVMul(psx, psy)))
    # Division and remainder are only compared for a non-zero divisor.
    if bvy > 0:
        self.assertEqual(bvx / bvy, eval_pysmt(sc.BVUDiv(psx, psy)))
        self.assertEqual(bvx % bvy, eval_pysmt(sc.BVURem(psx, psy)))
def test_bv2pysmt(self):
    """Check that ``bv2pysmt`` maps each bit-vector term to the expected pySMT node."""
    expect_equal = self.assertEqual
    convert = bv2pysmt

    bvx = Variable("x", 8)
    bvy = Variable("y", 8)
    psx = convert(bvx)
    psy = convert(bvy)

    # Leaves: constants and variables.
    expect_equal(convert(Constant(0, 8)), sc.BV(0, 8))
    expect_equal(psx, sc.Symbol("x", typing.BVType(8)))

    # Bitwise operations.
    expect_equal(convert(~bvx), sc.BVNot(psx))
    expect_equal(convert(bvx & bvy), sc.BVAnd(psx, psy))
    expect_equal(convert(bvx | bvy), sc.BVOr(psx, psy))
    expect_equal(convert(bvx ^ bvy), sc.BVXor(psx, psy))

    # BvComp becomes a pySMT equality; negating it becomes a Boolean Not.
    expect_equal(convert(BvComp(bvx, bvy)), sc.Equals(psx, psy))
    expect_equal(convert(BvNot(BvComp(bvx, bvy))), sc.Not(sc.Equals(psx, psy)))

    # Unsigned comparisons.
    expect_equal(convert(bvx < bvy), sc.BVULT(psx, psy))
    expect_equal(convert(bvx <= bvy), sc.BVULE(psx, psy))
    expect_equal(convert(bvx > bvy), sc.BVUGT(psx, psy))
    expect_equal(convert(bvx >= bvy), sc.BVUGE(psx, psy))

    # Shifts and rotations.
    expect_equal(convert(bvx << bvy), sc.BVLShl(psx, psy))
    expect_equal(convert(bvx >> bvy), sc.BVLShr(psx, psy))
    expect_equal(convert(RotateLeft(bvx, 1)), sc.BVRol(psx, 1))
    expect_equal(convert(RotateRight(bvx, 1)), sc.BVRor(psx, 1))

    # Extraction, concatenation and repetition.
    expect_equal(convert(bvx[4:2]), sc.BVExtract(psx, 2, 4))
    expect_equal(convert(Concat(bvx, bvy)), sc.BVConcat(psx, psy))
    # zeroextend reduces to Concat, so ZeroExtend(bvx, 2) is not compared
    # against sc.BVZExt(psx, 2).
    expect_equal(convert(Repeat(bvx, 2)), psx.BVRepeat(2))

    # Arithmetic.
    expect_equal(convert(-bvx), sc.BVNeg(psx))
    expect_equal(convert(bvx + bvy), sc.BVAdd(psx, psy))
    # bvsum reduces to add, so bvx - bvy is not compared against
    # sc.BVSub(psx, psy).
    expect_equal(convert(bvx * bvy), sc.BVMul(psx, psy))
    expect_equal(convert(bvx / bvy), sc.BVUDiv(psx, psy))
    expect_equal(convert(bvx % bvy), sc.BVURem(psx, psy))
def bvxnor(self, other):
    """Return the bitwise XNOR of ``self`` and ``other``.

    Built as ``BVNot(BVXor(...))`` over the two ``value`` attributes and
    wrapped in the same class as ``self``.
    """
    wrapper = type(self)
    xor_term = smt.BVXor(self.value, other.value)
    negated = smt.BVNot(xor_term)
    return wrapper(negated)
def bvnand(self, other):
    """Return the bitwise NAND of ``self`` and ``other``.

    Built as ``BVNot(BVAnd(...))`` over the two ``value`` attributes and
    wrapped in the same class as ``self``.
    """
    wrapper = type(self)
    and_term = smt.BVAnd(self.value, other.value)
    negated = smt.BVNot(and_term)
    return wrapper(negated)
def bvnot(self):
    """Return the bitwise complement of ``self``, wrapped in the same class."""
    wrapper = type(self)
    negated = smt.BVNot(self.value)
    return wrapper(negated)
def __getExpressionTree(symbolicExpression):
    """Translate a sympy expression into a pysmt formula.

    Returns a ``(formula, Type)`` pair. Terminals (expressions without
    ``args``) are encoded through ``Solver.__encodeTerminal``; for every
    other node the children are translated recursively first and the
    matching pysmt constructor is then applied to them.

    NOTE(review): a node that has arguments but matches none of the handled
    sympy classes falls off the end and yields ``None`` - confirm that
    callers expect this.
    """
    # TODO LATER: take into account bitwise shift operations
    if len(symbolicExpression.args) == 0:
        # Terminal node: pick its solver type, then encode it.
        if isinstance(symbolicExpression, sympy.Symbol):
            declared = Variable.symbolTypes[symbolicExpression.name]
            solverLiteral = declared.getTypeForSolver()
            solverDesignator = declared.designatorExpr1
            terminalType = Type(solverLiteral, solverDesignator)
        elif isinstance(symbolicExpression, sympy.Integer):
            terminalType = Type('Integer')
        else:
            # sympy.Rational, sympy.Float and any other terminal are all
            # encoded as reals.
            terminalType = Type('Real')
        return Solver.__encodeTerminal(symbolicExpression, terminalType), terminalType

    # Translate the children and work out the common type of the operands.
    operands = []
    castType = None
    for child in symbolicExpression.args:
        operand, operandType = Solver.__getExpressionTree(child)
        operands.append(operand)
        if castType is None:
            castType = operandType
        elif castType.literal == 'Integer' and operandType.literal == 'Real':
            # Mixing integers with reals promotes the whole node to Real.
            castType = operandType
    # TODO LATER: consider other possible castings
    if castType.literal == 'Real':
        operands = [pysmt.ToReal(operand) for operand in operands]

    if isinstance(symbolicExpression, sympy.Not):
        kind = castType.literal
        if kind == 'Integer':
            # Negating a number is expressed as "equals zero".
            return pysmt.Equals(operands[0], pysmt.Int(0)), Type('Bool')
        if kind == 'Real':
            return pysmt.Equals(operands[0], pysmt.Real(0)), Type('Bool')
        if kind == 'Bool':
            return pysmt.Not(operands[0]), Type('Bool')
        # castType.literal == 'BitVector'
        return pysmt.BVNot(operands[0]), Type('BitVector')

    # Binary relations all produce a Boolean formula.
    relations = (
        (sympy.Lt, pysmt.LT),
        (sympy.Gt, pysmt.GT),
        (sympy.Ge, pysmt.GE),
        (sympy.Le, pysmt.LE),
        (sympy.Eq, pysmt.Equals),
        (sympy.Ne, pysmt.NotEquals),
    )
    for sympyRelation, buildRelation in relations:
        if isinstance(symbolicExpression, sympyRelation):
            return buildRelation(operands[0], operands[1]), Type('Bool')

    if isinstance(symbolicExpression, sympy.And):
        if castType.literal == 'Bool':
            return pysmt.And(operands[0], operands[1]), Type('Bool')
        # otherwise the operands are treated as bit-vectors
        return pysmt.BVAnd(operands[0], operands[1]), castType
    if isinstance(symbolicExpression, sympy.Or):
        if castType.literal == 'Bool':
            return pysmt.Or(operands[0], operands[1]), Type('Bool')
        # otherwise the operands are treated as bit-vectors
        return pysmt.BVOr(operands[0], operands[1]), castType
    if isinstance(symbolicExpression, sympy.Xor):
        return pysmt.BVXor(operands[0], operands[1]), castType
    if isinstance(symbolicExpression, sympy.Add):
        return pysmt.Plus(operands), castType
    if isinstance(symbolicExpression, sympy.Mul):
        return pysmt.Times(operands), castType
    if isinstance(symbolicExpression, sympy.Pow):
        return pysmt.Pow(operands[0], operands[1]), castType
    # TODO LATER: deal with missing modulo operator from pysmt
    return None
def _bv2pysmt_is_power_of_two(width):
    """Return True when ``width`` has at most one bit set."""
    return (width & (width - 1)) == 0


def _bv2pysmt_padded_shift(bv, apply_shift):
    """Convert a shift whose left operand width is not a power of two.

    Both operands are zero-extended one bit at a time until the width is a
    power of two, the shift (``apply_shift(x, r)``) is converted at that
    width, and the padding bits are dropped again with ``BVExtract``.
    """
    x, r = bv.args
    padding = 0
    while not _bv2pysmt_is_power_of_two(x.width):
        x = operation.ZeroExtend(x, 1)
        r = operation.ZeroExtend(r, 1)
        padding += 1
    shifted = bv2pysmt(apply_shift(x, r))
    return sc.BVExtract(shifted, end=shifted.bv_width() - padding - 1)


def _bv2pysmt_sliced_rotation(bv, left):
    """Convert a rotation as a concatenation of two slices of the operand."""
    x, r = bv.args
    n = x.width
    if r == 0:
        # Rotating by zero is the identity. The slices below would
        # otherwise have a high index below their low index.
        return bv2pysmt(x)
    if left:
        return bv2pysmt(operation.Concat(x[n - r - 1:], x[n - 1:n - r]))
    return bv2pysmt(operation.Concat(x[r - 1:], x[n - 1:r]))


def bv2pysmt(bv):
    """Convert a bit-vector type to a pySMT type.

        >>> from arxpy.bitvector.core import Constant, Variable
        >>> from arxpy.diffcrypt.smt import bv2pysmt
        >>> bv2pysmt(Constant(0b00000001, 8))
        1_8
        >>> x, y = Variable("x", 8), Variable("y", 8)
        >>> bv2pysmt(x)
        x
        >>> bv2pysmt(x + y)
        (x + y)
        >>> bv2pysmt(x <= y)
        (x u<= y)
        >>> bv2pysmt(x[4: 2])
        x[2:4]

    Python ints are returned unchanged, so integer arguments of an
    operation pass through the recursive conversion untouched.

    Raises:
        NotImplementedError: if ``bv`` is neither an int, a ``Variable``,
            a ``Constant`` nor one of the operations handled below.
    """
    if isinstance(bv, int):
        return bv

    if isinstance(bv, core.Variable):
        return sc.Symbol(bv.name, typing.BVType(bv.width))

    if isinstance(bv, core.Constant):
        return sc.BV(bv.val, bv.width)

    if isinstance(bv, operation.Operation):
        args = [bv2pysmt(a) for a in bv.args]
        # Exact type matching is kept on purpose: subclasses of an
        # operation are not converted like their base class.
        kind = type(bv)

        if kind == operation.BvNot:
            # The negation of an equality is a Boolean Not, not a BVNot.
            if args[0].is_equals():
                return sc.Not(*args)
            return sc.BVNot(*args)

        if kind == operation.BvAnd:
            return sc.BVAnd(*args)

        if kind == operation.BvOr:
            return sc.BVOr(*args)

        if kind == operation.BvXor:
            return sc.BVXor(*args)

        if kind == operation.BvComp:
            # Converted to a Boolean equality rather than sc.BVComp.
            return sc.Equals(*args)

        if kind == operation.BvUlt:
            return sc.BVULT(*args)

        if kind == operation.BvUle:
            return sc.BVULE(*args)

        if kind == operation.BvUgt:
            return sc.BVUGT(*args)

        if kind == operation.BvUge:
            return sc.BVUGE(*args)

        # Shifts and rotations are only emitted directly when the width of
        # the left-hand side is a power of 2; otherwise they are rewritten.
        if kind == operation.BvShl:
            if _bv2pysmt_is_power_of_two(args[0].bv_width()):
                return sc.BVLShl(*args)
            return _bv2pysmt_padded_shift(bv, lambda x, r: x << r)

        if kind == operation.BvLshr:
            if _bv2pysmt_is_power_of_two(args[0].bv_width()):
                return sc.BVLShr(*args)
            return _bv2pysmt_padded_shift(bv, lambda x, r: x >> r)

        if kind == operation.RotateLeft:
            if _bv2pysmt_is_power_of_two(args[0].bv_width()):
                return sc.BVRol(*args)
            return _bv2pysmt_sliced_rotation(bv, left=True)

        if kind == operation.RotateRight:
            if _bv2pysmt_is_power_of_two(args[0].bv_width()):
                return sc.BVRor(*args)
            return _bv2pysmt_sliced_rotation(bv, left=False)

        if kind == operation.Ite:
            # pySMT needs a Boolean condition: reuse an equality as it is,
            # otherwise compare the 1-bit condition against the constant 1.
            if args[0].is_equals():
                condition = args[0]
            else:
                condition = sc.Equals(args[0], bv2pysmt(core.Constant(1, 1)))
            return sc.Ite(condition, *args[1:])

        if kind == operation.Extract:
            # The two index arguments are passed to pySMT in swapped order.
            return sc.BVExtract(args[0], args[2], args[1])

        if kind == operation.Concat:
            return sc.BVConcat(*args)

        if kind == operation.ZeroExtend:
            return sc.BVZExt(*args)

        if kind == operation.Repeat:
            return args[0].BVRepeat(args[1])

        if kind == operation.BvNeg:
            return sc.BVNeg(*args)

        if kind == operation.BvAdd:
            return sc.BVAdd(*args)

        if kind == operation.BvSub:
            return sc.BVSub(*args)

        if kind == operation.BvMul:
            return sc.BVMul(*args)

        if kind == operation.BvUdiv:
            return sc.BVUDiv(*args)

        if kind == operation.BvUrem:
            return sc.BVURem(*args)

    # The message is only built when the conversion actually fails.
    raise NotImplementedError(
        "unknown conversion of {} to a pySMT type".format(type(bv).__name__)
    )