Exemplo n.º 1
0
    def get_antialias_constraint(self, address, register="sp"):
        """Build a z3 constraint keeping ``address`` clear of a register value.

        :param address: z3 bit-vector expression for the address to constrain.
        :param register: register name looked up in ``self.arch.registers``
            (defaults to ``"sp"``); its value is obtained through
            ``self.get_reg_before``.
        :return: a z3 boolean that holds when ``address`` lies outside
            ``[register - num_bytes, register + num_bytes]`` (unsigned) and
            neither ``address`` nor the register value wraps around when
            offset by ``num_bytes`` in either direction.
        """
        register = self.get_reg_before(self.arch.registers[register][0])
        # Architecture width in bytes. Floor division keeps this an int on
        # Python 3 (``/`` would yield a float such as 8.0).
        num_bytes = self.arch.bits // 8
        return z3.And(
            # Don't allow the address to be overlapping the register
            z3.Or(z3.ULT(address, register - num_bytes),
                  z3.UGT(address, register + num_bytes)),

            # Don't allow the address or register to wrap around
            z3.ULT(address, address + num_bytes),
            z3.UGT(address, address - num_bytes),
            z3.ULT(register, register + num_bytes),
            z3.UGT(register, register - num_bytes),
        )
Exemplo n.º 2
0
 def UGT(self, other):
     """Return a ``BoolExpr`` for the unsigned comparison ``self > other``.

     ``other`` may be a Python int, which is wrapped with ``BVV`` at this
     vector's size, or a ``BV`` of the same size (checked by assertion).
     """
     if isinstance(other, int):
         rhs = BVV(other, self.size)
     else:
         assert isinstance(other, BV)
         assert self.size == other.size
         rhs = other
     return BoolExpr(z3.UGT(self.z3obj, rhs.z3obj))
Exemplo n.º 3
0
    def _max(self,
             expr,
             extra_constraints=(),
             solver=None,
             model_callback=None):
        """Return the largest unsigned value ``expr`` can take.

        Binary-searches ``[0, 2**expr.size() - 1]``. Each step asks the solver
        whether ``expr`` can exceed the midpoint while staying at or below the
        current upper bound, then narrows the range accordingly.

        :param expr: z3 bit-vector expression to maximize.
        :param extra_constraints: constraints (converted with
            ``self.convert``) asserted only for the duration of this query.
        :param solver: z3 solver holding the base constraints. Every frame
            pushed here is popped before returning, even if the solver raises.
        :param model_callback: optional callable, invoked with
            ``self._generic_model(model)`` for each satisfying model found.
        :return: the maximum value as a Python int.
        """
        global solve_count

        lo = 0
        hi = 2**expr.size() - 1
        vals = set()

        numpop = 0
        try:
            if len(extra_constraints) > 0:
                solver.push()
                numpop += 1
                solver.add(*[self.convert(e) for e in extra_constraints])

            # TODO: Can only deal with bitvectors, not floats
            while hi - lo > 1:
                middle = (lo + hi) // 2

                solver.push()
                numpop += 1
                solver.add(z3.UGT(expr, middle), z3.ULE(expr, hi))

                solve_count += 1
                l.debug("Doing a check!")
                if solver.check() == z3.sat:
                    l.debug("... still sat")
                    # Keep this frame: later steps only search above middle.
                    lo = middle
                    vals.add(self._primitive_from_model(solver.model(), expr))
                    if model_callback is not None:
                        model_callback(self._generic_model(solver.model()))
                else:
                    l.debug("... now unsat")
                    hi = middle
                    solver.pop()
                    numpop -= 1

            if hi == lo:
                vals.add(hi)
            else:
                # Decide between the two remaining candidates. This check must
                # run while extra_constraints are still asserted; otherwise
                # ``hi`` could be reported even when it is only reachable
                # without them.
                solver.push()
                numpop += 1
                solver.add(expr == hi)
                l.debug("Doing a check!")
                if solver.check() == z3.sat:
                    if model_callback is not None:
                        model_callback(self._generic_model(solver.model()))
                    vals.add(hi)
                else:
                    vals.add(lo)
        finally:
            # Restore the solver to the state the caller handed us.
            for _ in range(numpop):
                solver.pop()

        return max(vals)
Exemplo n.º 4
0
def UGT(a: BitVec, b: BitVec) -> Bool:
    """Create an unsigned greater-than expression.

    :param a: left-hand bit-vector operand
    :param b: right-hand bit-vector operand
    :return: a ``Bool`` wrapping ``z3.UGT(a.raw, b.raw)``, carrying the
        annotations of both operands (combined with ``+``)
    """
    combined = a.annotations + b.annotations
    raw_cmp = z3.UGT(a.raw, b.raw)
    return Bool(raw_cmp, combined)
Exemplo n.º 5
0
    def _max_values(self,
                    expr,
                    extra_constraints=(),
                    result=None,
                    solver=None):
        """Binary-search the largest unsigned value of ``expr``.

        :param expr: z3 bit-vector expression to maximize.
        :param extra_constraints: constraints (converted with
            ``self.convert``) asserted only for the duration of this query.
        :param result: unused here; kept for interface compatibility.
        :param solver: z3 solver holding the base constraints. Every frame
            pushed here is popped before returning, even if the solver raises.
        :return: a set of values seen during the search (model values plus
            the final candidate); its largest element is the maximum.
        """
        global solve_count

        lo = 0
        hi = 2**expr.size() - 1
        vals = set()

        numpop = 0
        try:
            if len(extra_constraints) > 0:
                solver.push()
                numpop += 1
                solver.add(*[self.convert(e) for e in extra_constraints])

            while hi - lo > 1:
                # Floor division: with ``/`` the bounds become floats, which
                # lose precision for wide bit-vectors and can stall the search.
                middle = (lo + hi) // 2

                solver.push()
                numpop += 1
                solver.add(z3.UGT(expr, middle), z3.ULE(expr, hi))

                solve_count += 1
                l.debug("Doing a check!")
                if solver.check() == z3.sat:
                    l.debug("... still sat")
                    # Keep this frame: later steps only search above middle.
                    lo = middle
                    vals.add(self._primitive_from_model(solver.model(), expr))
                else:
                    l.debug("... now unsat")
                    hi = middle
                    solver.pop()
                    numpop -= 1

            if hi == lo:
                vals.add(hi)
            else:
                # Decide between the two remaining candidates while
                # extra_constraints are still asserted.
                solver.push()
                numpop += 1
                solver.add(expr == hi)
                l.debug("Doing a check!")
                if solver.check() == z3.sat:
                    vals.add(hi)
                else:
                    vals.add(lo)
        finally:
            # Restore the solver to the state the caller handed us.
            for _ in range(numpop):
                solver.pop()

        return vals
Exemplo n.º 6
0
 def BVS(self, ast, result=None):  #pylint:disable=unused-argument
     """Translate a symbolic bit-vector AST into a z3 expression.

     ``ast.args`` is unpacked as ``(name, mn, mx, stride, _)``. A fresh z3
     bit-vector named ``name`` of ``ast.size()`` bits is created in
     ``self._context``. When given, it is clamped (unsigned) to at least
     ``mn`` and at most ``mx``, then rounded down to a multiple of ``stride``.
     """
     name, mn, mx, stride, _ = ast.args
     size = ast.size()
     expr = z3.BitVec(name, size, ctx=self._context)
     if mn is not None:
         expr = z3.If(z3.ULT(expr, mn), mn, expr, ctx=self._context)
     if mx is not None:
         expr = z3.If(z3.UGT(expr, mx), mx, expr, ctx=self._context)
     if stride is not None:
         # The clamps above compare unsigned, so round with unsigned division
         # as well. ``/`` on z3 bit-vectors is *signed* division: for values
         # with the top bit set it truncates toward zero, i.e. rounds *up* as
         # an unsigned number, which can push the result past ``mx``.
         # NOTE(review): rounding is relative to 0, so the result can drop
         # below ``mn`` when ``mn`` is not a multiple of ``stride`` -- confirm.
         expr = z3.UDiv(expr, stride) * stride
     return expr
Exemplo n.º 7
0
 def ComparisonToSmt(self):
     """Return the z3 boolean for this comparison vertex's two operands.

     Ordering comparisons are unsigned (``UGT``/``UGE``/``ULT``/``ULE``).
     ``None`` is returned if the operator matches none of the handled codes.
     """
     assert (VertexNode.OpCode.IsComparison(self.operator))
     left = self.operands[0].VertexNameToSmt()
     right = self.operands[1].VertexNameToSmt()
     opcodes = VertexNode.OpCode
     builders = (
         (opcodes.GT, z3.UGT),
         (opcodes.GE, z3.UGE),
         (opcodes.LT, z3.ULT),
         (opcodes.LE, z3.ULE),
         (opcodes.EQ, lambda p, q: p == q),
         (opcodes.NE, lambda p, q: p != q),
     )
     for opcode, build in builders:
         if self.operator == opcode:
             return build(left, right)
     return None
Exemplo n.º 8
0
 def bvugt(self, other):
     """Return this family's ``Bit`` for the unsigned test ``self > other``."""
     family = self.get_family()
     return family.Bit(z3.UGT(self.value, other.value))
def is_fn_valid(fn):
    """Return a z3 constraint that ``0 < fn < dt.NFILE`` (unsigned)."""
    above_zero = z3.UGT(fn, 0)
    below_limit = z3.ULT(fn, dt.NFILE)
    return z3.And(above_zero, below_limit)
Exemplo n.º 10
0
def parse_int_params(name):
    """Return the integer parameters written in brackets after a sort name.

    ``"bv[8]"`` yields ``[8]``, ``"arr[3][16]"`` yields ``[3, 16]``, and a
    name without brackets yields ``[]``.

    :raises SyntaxError: if a bracket group is unterminated or does not
        contain an integer.
    """
    things = name.split('[')[1:]
    if not all(t.endswith(']') for t in things):
        raise SyntaxError('malformed sort parameters: {!r}'.format(name))
    try:
        return [int(t[:-1]) for t in things]
    except ValueError:
        # Report every malformed parameter list the same way instead of
        # leaking a ValueError for contents such as "bv[]" or "bv[x]".
        raise SyntaxError('non-integer sort parameter in {!r}'.format(name))


def is_solver_sort(name):
    """Tell whether ``name`` is ``'int'`` or a bracketed ``bv[...]`` sort."""
    if name == 'int':
        return True
    return name.startswith('bv[') and name.endswith(']')


# Comparison builders keyed by relation symbol. When a z3 bit-vector is
# involved the comparison is unsigned (z3.ULT & co.); otherwise Python's own
# comparison operators are used.
#
# Both operands are checked: with ``x`` a Python int and ``y`` a bit-vector,
# ``x < y`` would dispatch to BitVecRef's reflected operator, which compares
# *signed*.
relations_dict = {
    '<': (lambda x, y: z3.ULT(x, y) if z3.is_bv(x) or z3.is_bv(y) else x < y),
    '<=': (lambda x, y: z3.ULE(x, y) if z3.is_bv(x) or z3.is_bv(y) else x <= y),
    '>': (lambda x, y: z3.UGT(x, y) if z3.is_bv(x) or z3.is_bv(y) else x > y),
    '>=': (lambda x, y: z3.UGE(x, y) if z3.is_bv(x) or z3.is_bv(y) else x >= y),
}


def relations(name):
    """Return the comparison builder for ``name``, or ``None`` if unknown."""
    try:
        return relations_dict[name]
    except KeyError:
        return None


# Binary term constructors keyed by function symbol; each takes (x, y).
# "-" is delegated to my_minus, defined elsewhere in this module.
functions_dict = dict([
    ("+", lambda x, y: x + y),
    ("-", my_minus),
    ("*", lambda x, y: x * y),
    ("concat", lambda x, y: z3.Concat(x, y)),
])
Exemplo n.º 11
0
log.info("Min: 0x%x (%d queries, %02fs)", lb, q, end - start)

# Now search for the largest value v can take.
val_at_least = lb = 0
ub = 0xffffffff
q = 0

# Walk the val_at_least and ub towards each other. The val_at_least represents
# the value which we know the upper bound definitely is equal to, or greater
# than.
start = time.time()
while val_at_least != ub:
    s, v = get_model()

    log.debug("val_at_least: 0x%x, ub: 0x%x, lb: 0x%x", val_at_least, ub, lb)
    log.debug("Checking v > 0x%x ... ", lb)
    s.add(z3.UGT(v, lb))

    q += 1
    if s.check() == z3.sat:
        log.debug("... SAT")
        val_at_least = lb + 1
        lb += (ub - lb) // 2
    else:
        log.debug("... UNSAT")
        ub = lb
        if ub - val_at_least == 1:
            lb = val_at_least
        else:
            # Floor division keeps lb an int, matching the SAT branch; a
            # float here would break the "%x" log formatting above.
            lb = ub - (ub - val_at_least) // 2

end = time.time()
Exemplo n.º 12
0
 def GT(self, gstate, a, b):
     """Push the unsigned comparison ``z3.UGT(a, b)`` onto the state's stack.

     Both operands are first converted with ``svm_utils.convert_to_bitvec``.
     """
     lhs = svm_utils.convert_to_bitvec(a)
     rhs = svm_utils.convert_to_bitvec(b)
     gstate.mstate.stack.append(z3.UGT(lhs, rhs))
Exemplo n.º 13
0
def is_s_len_valid(s_len):
    """Return a z3 constraint that ``s_len`` is nonzero (unsigned ``> 0``)."""
    nonzero = z3.UGT(s_len, 0)
    return nonzero
Exemplo n.º 14
0
if args.z3:
    # This check is done on 32 bits, but the size is not used by Miasm
    # formulas, so it should be OK for any size > 0
    x1 = ExprId("x1", 32)
    x2 = ExprId("x2", 32)
    i1_tmp = ExprInt(1, 1)

    x1_z3 = trans.from_expr(x1)
    x2_z3 = trans.from_expr(x2)
    i1_z3 = trans.from_expr(i1_tmp)

    # (Assumptions, function(arg1, arg2) -> True/False (= i1/i0) to check)
    tests = [
        (x1_z3 == x2_z3, expr_is_equal),
        (x1_z3 != x2_z3, expr_is_not_equal),
        (z3.UGT(x1_z3, x2_z3), expr_is_unsigned_greater),
        (z3.UGE(x1_z3, x2_z3), expr_is_unsigned_greater_or_equal),
        (z3.ULT(x1_z3, x2_z3), expr_is_unsigned_lower),
        (z3.ULE(x1_z3, x2_z3), expr_is_unsigned_lower_or_equal),
        (x1_z3 > x2_z3, expr_is_signed_greater),
        (x1_z3 >= x2_z3, expr_is_signed_greater_or_equal),
        (x1_z3 < x2_z3, expr_is_signed_lower),
        (x1_z3 <= x2_z3, expr_is_signed_lower_or_equal),
    ]

    for assumption, func in tests:
        func_z3 = trans.from_expr(func(x1, x2))

        # When the assumption holds, the Miasm formula must evaluate to i1 ...
        solver = z3.Solver()
        solver.add(assumption)
        solver.add(func_z3 != i1_z3)
        assert solver.check() == z3.unsat

        # ... and when it does not hold, the formula must not evaluate to i1.
        # Without this converse check, a formula that always returns i1 would
        # pass.
        solver = z3.Solver()
        solver.add(z3.Not(assumption))
        solver.add(func_z3 == i1_z3)
        assert solver.check() == z3.unsat
Exemplo n.º 15
0
 def ugt(self, ctx, return_type, a, atype, b, btype, **kwargs):
     """Return ``z3.UGT(a, b)``; both operand types must be equal."""
     assert atype == btype
     comparison = z3.UGT(a, b)
     return comparison
Exemplo n.º 16
0
 def semantics(self, a, b):
     """Return a 32-bit all-ones mask if ``a > b`` (unsigned), else 0."""
     is_greater = z3.UGT(a, b)
     all_ones = z3.BitVecVal(-1, 32, self.ctx)
     zero = z3.BitVecVal(0, 32, self.ctx)
     return z3.If(is_greater, all_ones, zero)
def opcode_gt(opcode_data1, opcode_data2):
    """Return a 256-bit 1 if ``opcode_data1 > opcode_data2`` (unsigned), else 0."""
    is_greater = z3.UGT(opcode_data1, opcode_data2)
    one = z3.BitVecVal(1, 256)
    zero = z3.BitVecVal(0, 256)
    return z3.If(is_greater, one, zero)
Exemplo n.º 18
0
# Register z3 encodings for these term classes with BaseSMTTranslator.
# cttz and bv_log2 are helpers defined elsewhere; each receives the result
# type's bit width and the translated first operand.
eval.register(
    TrailingZerosCnxp, BaseSMTTranslator,
    lambda term, smt: cttz(smt.type(term).width, smt.eval(term._args[0])))

eval.register(
    Log2Cnxp, BaseSMTTranslator,
    lambda term, smt: bv_log2(smt.type(term).width, smt.eval(term._args[0])))

# NOTE(review): _handler is defined elsewhere; it presumably translates the
# term's operands and applies the given builder to them -- confirm.
eval.register(LShrFunCnxp, BaseSMTTranslator, _handler(z3.LShR))
# Signed max/min: Python's ``>`` on z3 bit-vectors is a signed comparison.
eval.register(SMaxCnxp, BaseSMTTranslator,
              _handler(lambda x, y: z3.If(x > y, x, y)))
eval.register(SMinCnxp, BaseSMTTranslator,
              _handler(lambda x, y: z3.If(x > y, y, x)))
# Unsigned max/min use z3.UGT explicitly.
eval.register(UMaxCnxp, BaseSMTTranslator,
              _handler(lambda x, y: z3.If(z3.UGT(x, y), x, y)))
eval.register(UMinCnxp, BaseSMTTranslator,
              _handler(lambda x, y: z3.If(z3.UGT(x, y), y, x)))


@eval.register(SExtCnxp, BaseSMTTranslator)
def _(term, smt):
    """Encode sign extension of the operand up to the term's bit width."""
    operand = term._args[0]
    value = smt.eval(operand)
    src_width = smt.type(operand).width
    tgt_width = smt.type(term).width
    return z3.SignExt(tgt_width - src_width, value)


@eval.register(ZExtCnxp, BaseSMTTranslator)
def _(term, smt):
Exemplo n.º 19
0
def dec(old):
    """Conditionally decrement ``old.counter``.

    :return: ``(cond, state)``, where ``cond`` is ``z3.UGT(old.counter, 0)``
        and ``state`` is ``util.If(cond, copy_with_counter_minus_one, old)``.
    """
    can_decrement = z3.UGT(old.counter, 0)
    decremented = old.copy()
    decremented.counter -= 1
    return can_decrement, util.If(can_decrement, decremented, old)
Exemplo n.º 20
0
    def __init__(self, ns):
        """Create the Z3 context and solver and the ESBMC function-kind table.

        :param ns: passed through unchanged to the base-class constructor.
        """
        super(Z3python, self).__init__(False, ns, False, True, True)
        self.ctx = z3.Context()
        self.ast_list = []
        self.sort_list = []
        self.bool_sort = self.mk_sort((esbmc.solve.smt_sort_kind.bool, ))
        self.solver = z3.Solver(solver=None, ctx=self.ctx)
        self.fresh_arr_idx = 0

        # Each entry maps an ESBMC smt_func_kind to a builder called as
        # builder(ctx, args, asts); builders index their operands out of
        # ``asts``, and a few also pass ``ctx`` through to z3.
        # No 'int' support at this time
        self.func_map = {
            #SMT_FUNC_ADD
            esbmc.solve.smt_func_kind.bvadd :
                lambda ctx, args, asts: asts[0] + asts[1],
            #SMT_FUNC_SUB
            esbmc.solve.smt_func_kind.bvsub :
                lambda ctx, args, asts: asts[0] - asts[1],
            #SMT_FUNC_MUL,
            esbmc.solve.smt_func_kind.bvmul :
                lambda ctx, args, asts: asts[0] * asts[1],
            #SMT_FUNC_DIV,
            esbmc.solve.smt_func_kind.bvudiv :
                lambda ctx, args, asts: z3.UDiv(asts[0], asts[1]),
            esbmc.solve.smt_func_kind.bvsdiv :
                lambda ctx, args, asts: asts[0] / asts[1],
            #SMT_FUNC_MOD,
            esbmc.solve.smt_func_kind.bvsmod :
                lambda ctx, args, asts: asts[0] % asts[1],
            esbmc.solve.smt_func_kind.bvumod :
                lambda ctx, args, asts: z3.URem(asts[0], asts[1]),
            #SMT_FUNC_SHL,
            esbmc.solve.smt_func_kind.bvshl :
                lambda ctx, args, asts: asts[0] << asts[1],
            esbmc.solve.smt_func_kind.bvashr :
                lambda ctx, args, asts: asts[0] >> asts[1],
            #SMT_FUNC_NEG,
            esbmc.solve.smt_func_kind.bvneg :
                lambda ctx, args, asts: -asts[0],
            esbmc.solve.smt_func_kind.bvlshr :
                lambda ctx, args, asts: z3.LShR(asts[0], asts[1]),
            esbmc.solve.smt_func_kind.bvnot :
                lambda ctx, args, asts: ~asts[0],
            #SMT_FUNC_BVNXOR, These aren't actually used anywhere
            #SMT_FUNC_BVNOR,
            #SMT_FUNC_BVNAND,
            esbmc.solve.smt_func_kind.bvxor :
                lambda ctx, args, asts: asts[0] ^ asts[1],
            esbmc.solve.smt_func_kind.bvor :
                lambda ctx, args, asts: asts[0] | asts[1],
            esbmc.solve.smt_func_kind.bvand :
                lambda ctx, args, asts: asts[0] & asts[1],
            esbmc.solve.smt_func_kind.implies :
                lambda ctx, args, asts: z3.Implies(asts[0], asts[1], ctx),
            # ``xor`` is separate from ``bvxor`` above, so presumably the
            # Boolean connective. z3.Xor works in every z3 release, whereas
            # the ``^`` operator is not defined on BoolRef in older ones.
            esbmc.solve.smt_func_kind.xor :
                lambda ctx, args, asts: z3.Xor(asts[0], asts[1]),
            esbmc.solve.smt_func_kind._or : # or is a keyword in python
                lambda ctx, args, asts: z3.Or(asts[0], asts[1]),
            esbmc.solve.smt_func_kind._and :
                lambda ctx, args, asts: z3.And(asts[0], asts[1]),
            esbmc.solve.smt_func_kind._not :
                lambda ctx, args, asts: z3.Not(asts[0]),
            #SMT_FUNC_LT,
            esbmc.solve.smt_func_kind.bvslt :
                lambda ctx, args, asts: asts[0] < asts[1],
            esbmc.solve.smt_func_kind.bvult :
                lambda ctx, args, asts: z3.ULT(asts[0], asts[1]),
            #SMT_FUNC_GT,
            esbmc.solve.smt_func_kind.bvsgt :
                lambda ctx, args, asts: asts[0] > asts[1],
            esbmc.solve.smt_func_kind.bvugt :
                lambda ctx, args, asts: z3.UGT(asts[0], asts[1]),
            #SMT_FUNC_LTE,
            # Expressed as the negation of the strict greater-than.
            esbmc.solve.smt_func_kind.bvslte :
                lambda ctx, args, asts: z3.Not(asts[0] > asts[1]),
            esbmc.solve.smt_func_kind.bvulte :
                lambda ctx, args, asts: z3.Not(z3.UGT(asts[0], asts[1])),
            #SMT_FUNC_GTE,
            # Expressed as the negation of the strict less-than.
            esbmc.solve.smt_func_kind.bvsgte :
                lambda ctx, args, asts: z3.Not(asts[0] < asts[1]),
            esbmc.solve.smt_func_kind.bvugte :
                lambda ctx, args, asts: z3.Not(z3.ULT(asts[0], asts[1])),
            esbmc.solve.smt_func_kind.eq :
                lambda ctx, args, asts: asts[0] == asts[1],
            esbmc.solve.smt_func_kind.noteq :
                lambda ctx, args, asts: asts[0] != asts[1],
            esbmc.solve.smt_func_kind.ite :
                lambda ctx, args, asts: z3.If(asts[0], asts[1], asts[2], ctx),
            #SMT_FUNC_STORE, Handled via ast functions
            #SMT_FUNC_SELECT,
            esbmc.solve.smt_func_kind.concat :
                lambda ctx, args, asts: z3.Concat(asts[0], asts[1]),
            #SMT_FUNC_EXTRACT, // Not for going through mk app due to sillyness.
            # Int related stuff
            #SMT_FUNC_INT2REAL,
            #SMT_FUNC_REAL2INT,
            #SMT_FUNC_IS_INT,
        }

        # Various accounting structures for the address space modeling need to
        # be set up, but that needs to happen after the solver is online. Thus,
        # we have to call this once the object is ready to create asts.
        self.smt_post_init()
Exemplo n.º 21
0
 def ugt(self, other):
     """Return a ``Bool`` for the unsigned comparison ``self > other``."""
     lhs, rhs = self.symbol, other.symbol
     return Bool(z3.UGT(lhs, rhs))
Exemplo n.º 22
0
    def VertexOperationToSmt(self):
        """Return the z3 constraint that defines this vertex's value.

        The constraint equates ``self.VertexNameToSmt()`` with the operation
        this vertex applies to its operands' SMT names. ``None`` is returned
        for vertices that need no defining constraint (input variables,
        immediates, input arrays, function names) and for operators or
        function names not handled below.
        """
        assert (self.type != VertexNode.VertexType.NONE)

        if self.type == VertexNode.VertexType.VAR:
            # Possible Vertex : input Variable, name = operand1
            # input variable: there is nothing to do.
            # assigned Variable: name = operands[0]

            # It's an input variable if there is no operand :
            if self.operands is None:
                return None
            # otherwise, it's an assigned variable, but make sure just in case
            assert (self.operator == VertexNode.OpCode.ASSIGN)
            return self.VertexNameToSmt() == self.operands[0].VertexNameToSmt()

        elif self.type == VertexNode.VertexType.TEMP:
            # Possible Vertex : Function Call, Array Load, Binary Operation, Comparison,
            #                   Conditional Assignment, Unary Operation
            # function call: name = func_name(arguments)
            # array load: name = array[index]
            # binary operation: name = operand1 op operand2
            # comparison: name = operand1 comp operand2
            # conditional assignment: name = ite(operand1, operand2, operand3)
            # unary operation: name = op operand1

            # It's a function call
            if self.operator == VertexNode.OpCode.FUNCCALL:
                assert (self.operands[0].type == VertexNode.VertexType.FUNC)
                # There are four possible functions that can last until now:
                if self.operands[0].name == "merge":
                    args = []
                    for op in self.operands[1:]:
                        args.append(op.VertexNameToSmt())
                    return self.VertexNameToSmt() == z3.Concat(args)
                elif self.operands[0].name == "split":
                    toSplit = self.operands[1].VertexNameToSmt()
                    # Extract requires actual numerical value.
                    lowerBound = self.operands[2].value
                    upperBound = self.operands[3].value
                    return self.VertexNameToSmt() == z3.Extract(
                        upperBound, lowerBound, toSplit)
                elif self.operands[0].name == "zeroext":
                    toExtend = self.operands[1].VertexNameToSmt()
                    # ZeroExt requires actual numerical value
                    n = self.operands[2].value
                    return self.VertexNameToSmt() == z3.ZeroExt(n, toExtend)
                elif self.operands[0].name == "concat":
                    args = []
                    for op in self.operands[1:]:
                        args.append(op.VertexNameToSmt())
                    return self.VertexNameToSmt() == z3.Concat(args)

            # It's an array load
            elif self.operator == VertexNode.OpCode.LOAD:
                array = self.operands[0].VertexNameToSmt()
                arrayIndex = self.operands[1].VertexNameToSmt()
                return self.VertexNameToSmt() == z3.Select(array, arrayIndex)

            # It's a conditional statement
            elif self.operator == VertexNode.OpCode.CONDITIONAL:
                cond = self.operands[0].VertexNameToSmt()
                truePath = self.operands[1].VertexNameToSmt()
                falsePath = self.operands[2].VertexNameToSmt()
                return self.VertexNameToSmt() == z3.If(cond, truePath,
                                                       falsePath)

            # It's a comparison (x < y); ordering comparisons are unsigned
            elif VertexNode.OpCode.IsComparison(self.operator):
                lhs = self.operands[0].VertexNameToSmt()
                rhs = self.operands[1].VertexNameToSmt()
                if self.operator == VertexNode.OpCode.GT:
                    return self.VertexNameToSmt() == z3.UGT(lhs, rhs)
                elif self.operator == VertexNode.OpCode.GE:
                    return self.VertexNameToSmt() == z3.UGE(lhs, rhs)
                elif self.operator == VertexNode.OpCode.LT:
                    return self.VertexNameToSmt() == z3.ULT(lhs, rhs)
                elif self.operator == VertexNode.OpCode.LE:
                    return self.VertexNameToSmt() == z3.ULE(lhs, rhs)
                elif self.operator == VertexNode.OpCode.EQ:
                    return self.VertexNameToSmt() == (lhs == rhs)
                elif self.operator == VertexNode.OpCode.NE:
                    return self.VertexNameToSmt() == (lhs != rhs)

            # It's a binary operation
            elif VertexNode.OpCode.IsBinaryOp(self.operator):
                lhs = self.operands[0].VertexNameToSmt()
                rhs = self.operands[1].VertexNameToSmt()
                if self.operator == VertexNode.OpCode.PLUS:
                    return self.VertexNameToSmt() == (lhs + rhs)
                elif self.operator == VertexNode.OpCode.MINUS:
                    return self.VertexNameToSmt() == (lhs - rhs)
                elif self.operator == VertexNode.OpCode.AND:
                    return self.VertexNameToSmt() == (lhs & rhs)
                elif self.operator == VertexNode.OpCode.OR:
                    return self.VertexNameToSmt() == (lhs | rhs)
                elif self.operator == VertexNode.OpCode.XOR:
                    return self.VertexNameToSmt() == (lhs ^ rhs)
                elif self.operator == VertexNode.OpCode.SHL:
                    return self.VertexNameToSmt() == (lhs << rhs)
                elif self.operator == VertexNode.OpCode.SHR:
                    return self.VertexNameToSmt() == (z3.LShR(lhs, rhs))
                elif self.operator == VertexNode.OpCode.ROL:
                    return self.VertexNameToSmt() == (z3.RotateLeft(lhs, rhs))
                elif self.operator == VertexNode.OpCode.ROR:
                    return self.VertexNameToSmt() == (z3.RotateRight(lhs, rhs))
                elif self.operator == VertexNode.OpCode.MUL:
                    return self.VertexNameToSmt() == (lhs * rhs)
                elif self.operator == VertexNode.OpCode.DIV:
                    # NOTE(review): z3's ``/`` on bit-vectors is signed
                    # division, while the comparisons above are unsigned --
                    # confirm DIV is meant to be signed.
                    return self.VertexNameToSmt() == (lhs / rhs)

            # It's a unary operation
            elif VertexNode.OpCode.IsUnaryOp(self.operator):
                rhs = self.operands[0].VertexNameToSmt()
                if self.operator == VertexNode.OpCode.NOT:
                    return self.VertexNameToSmt() == ~rhs

        elif self.type == VertexNode.VertexType.IMM:
            # Possible Vertex : Immediate Value
            return None
        elif self.type == VertexNode.VertexType.ARR:
            # Possible Vertex : Input array, array store
            # input array: there is nothing to do
            # array store: newarray = store(array, index, value)

            # if operator is None (or NONE), it's an "input" array
            if self.operator is None:
                return None
            if self.operator == VertexNode.OpCode.NONE:
                return None
            # Otherwise, it must be an array store operation vertex
            assert (self.operator == VertexNode.OpCode.STORE)
            oldArray = self.operands[0].VertexNameToSmt()
            index = self.operands[1].VertexNameToSmt()
            value = self.operands[2].VertexNameToSmt()
            newArray = self.VertexNameToSmt()
            return newArray == z3.Store(oldArray, index, value)

        elif self.type == VertexNode.VertexType.FUNC:
            # Possible Vertex : Name of the function
            return None
Exemplo n.º 23
0
ly2, ly3, ly4 = z3.Ints('ly2 ly3 ly4')

# Line-number variables for component operands (lx11/lx12 appear to pair
# with X11/X12, and so on).
lx11, lx12 = z3.Ints('lx11 lx12')
lx21, lx22 = z3.Ints('lx21 lx22')
lx31, lx32 = z3.Ints('lx31 lx32')
lx41, lx42 = z3.Ints('lx41 lx42')

# Component library (phi-lib): one constraint per available instruction.
phi1 = (Y1 == X11 ^ X12)
phi2 = (Y2 == -z3.If(z3.UGT(X21, X22), z3.BitVecVal(0, 8), z3.BitVecVal(1, 8)))
phi3 = (Y3 == X31 & X32)
phi4 = (Y4 == X41 ^ X42)

# Specification: O is the unsigned maximum of I and J.
spec = z3.And(
    z3.Implies(z3.UGE(J, I), O == J),
    z3.Implies(z3.UGT(I, J), O == I),
)

# phicons: no two different instructions may share a line number.
phicons = z3.And(
    ly1 != ly2, ly2 != ly3, ly1 != ly3,
    ly1 != ly4, ly4 != ly2, ly4 != ly3,
)

# We have four instructions (components).
# Bound the line number of each instruction and operand.
phibound = z3.And(ly1 >= 1, ly1 <= 4, ly2 >= 1, ly2 <= 4, ly3 >= 1, ly3 <= 4,
                  ly4 >= 1, ly4 <= 4, lx11 >= -1, lx11 <= 4, lx12 >= -1,
Exemplo n.º 24
0
Arquivo: vm.py Projeto: tcrs/evm-sym
def run_block(s, solver, log_trace=False):
    def end_trace(reason, *args):
        s.end_type = reason
        s.end_info = args
        pass

    while True:
        op = s.code[s.pc]
        try:
            instr = vm_isa.opcodes[op]
        except KeyError:
            end_trace('invalid')
            return

        if log_trace:
            print('{:04}: {}'.format(s.pc, instr.name))
            print('> ' + ';; '.join(str(z3.simplify(x)) for x in s.stack))
            #print('> {} | {} | {}'.format(stack, mem, store))

        try:
            instr_args = [s.stack.pop() for i in range(instr.pop)]
        except IndexError:
            end_trace('stack underflow')
            return

        def reducestack(fn):
            s.stack.append(fn(*instr_args))

        oplen = 1

        s.gas = s.gas - instr.base_gas
        if instr.extra_gas is not None:
            s.gas = s.gas - instr.extra_gas(s, *instr_args)
        s.gas = z3.simplify(s.gas)

        if op >= 0x80 and op <= 0x8f:  # DUPn
            # instr_args[0] = old top of stack
            for v in reversed(instr_args):
                s.stack.append(v)
            s.stack.append(instr_args[-1])
        elif op >= 0x90 and op <= 0x9f:  #SWAPn
            # Old top of stack pushed first
            s.stack.append(instr_args[0])
            # Then the middle section (in original order)
            for v in reversed(instr_args[1:-1]):
                s.stack.append(v)
            # Then bottom value on top
            s.stack.append(instr_args[-1])
        elif op >= 0x60 and op <= 0x7f:  #PUSHn
            npush = op - 0x60 + 1
            s.stack.append(z3.BitVecVal(util.pushval(s.code, s.pc), 256))
            oplen += npush
        elif instr.name == 'ADD':
            reducestack(lambda x, y: x + y)
        elif instr.name == 'MUL':
            reducestack(lambda x, y: x * y)
        elif instr.name == 'SUB':
            reducestack(lambda x, y: x - y)
        elif instr.name == 'DIV':
            reducestack(lambda x, y: z3.If(y == 0, z3.BitVecVal(0, 256),
                                           z3.UDiv(x, y)))
        elif instr.name == 'SDIV':
            reducestack(lambda x, y: z3.If(
                y == 0, z3.BitVecVal(0, 256),
                z3.If(x == -2**255 and y == -1, z3.BitVecVal(-2**255, 256), x /
                      y)))
        elif instr.name == 'MOD':
            reducestack(lambda x, y: z3.If(y == 0, z3.BitVecVal(0, 256),
                                           z3.URem(x, y)))
        elif instr.name == 'SMOD':
            reducestack(lambda x, y: z3.If(y == 0, z3.BitVecVal(0, 256),
                                           z3.SRem(x, y)))
        elif instr.name == 'ADDMOD':
            reducestack(lambda x, y, z: z3.If(
                z == 0, z3.BitVecVal(0, 256),
                z3.Extract(
                    255, 0,
                    z3.URem(
                        z3.ZeroExt(1, x) + z3.ZeroExt(1, y), z3.ZeroExt(1, z)))
            ))
        elif instr.name == 'MULMOD':
            reducestack(lambda x, y, z: z3.If(
                z == 0, z3.BitVecVal(0, 256),
                z3.Extract(
                    255, 0,
                    z3.URem(
                        z3.ZeroExt(256, x) * z3.ZeroExt(256, y),
                        z3.ZeroExt(256, z)))))
        elif instr.name == 'EXP':
            # TODO z3 currently doesn't seem to provide __pow__ on BitVecs?
            reducestack(lambda x, y: z3.BitVecVal(
                pow(x.as_long(), y.as_long(), 1 << 256), 256))
        elif instr.name == 'LT':
            reducestack(lambda x, y: _bool_to_01(z3.ULT(x, y)))
        elif instr.name == 'GT':
            reducestack(lambda x, y: _bool_to_01(z3.UGT(x, y)))
        elif instr.name == 'SLT':
            reducestack(lambda x, y: _bool_to_01(x < y))
        elif instr.name == 'SGT':
            reducestack(lambda x, y: _bool_to_01(x > y))
        elif instr.name == 'EQ':
            reducestack(lambda x, y: _bool_to_01(x == y))
        elif instr.name == 'ISZERO':
            reducestack(lambda x: _bool_to_01(x == 0))
        elif instr.name == 'AND':
            reducestack(lambda x, y: x & y)
        elif instr.name == 'OR':
            reducestack(lambda x, y: x | y)
        elif instr.name == 'XOR':
            reducestack(lambda x, y: x ^ y)
        elif instr.name == 'NOT':
            reducestack(lambda x: ~x)
        elif instr.name == 'BYTE':
            idx, val = instr_args
            bidx = as_concrete(idx)
            if bidx <= 31:
                s.stack.append(z3.ZeroExt(248, get_byte(val, bidx)))
            else:
                s.stack.append(z3.BitVecVal(0, 256))
        elif instr.name == 'SIGNEXTEND':
            idx, val = instr_args
            bidx = as_concrete(idx)
            if bidx <= 31:
                nbits = 8 * (bidx + 1)
                to_extend = z3.Extract(nbits - 1, 0, val)
                s.stack.append(z3.SignExt(256 - nbits, to_extend))
            else:
                s.stack.append(val)
        elif instr.name == 'CODESIZE':
            s.stack.append(z3.BitVecVal(s.code.size(), 256))
        elif instr.name == 'SHA3':
            start, sz = instr_args
            v = MemoryEmpty
            n = as_concrete(sz)
            for i in range(n):
                v = z3.Store(v, i, s.memory.select(start + i))
            s.stack.append(sha3(v))
            # TODO when n == 0 or all values are concrete, simplify!
            #start, sz = as_concrete(start), as_concrete(sz)
            #stack.append(ethereum.utils.sha3_256([as_concrete(
        elif instr.name in {
                'GASPRICE', 'COINBASE', 'TIMESTAMP', 'NUMBER', 'DIFFICULTY',
                'GASLIMIT', 'ORIGIN'
        }:
            reducestack(getattr(s.transaction, instr.name.lower()))
        elif instr.name in {'BALANCE', 'BLOCKHASH', 'EXTCODESIZE'}:
            reducestack(lambda x: (getattr(s.transaction, instr.name.lower())
                                   ())(x))
        elif instr.name == 'ADDRESS':
            s.stack.append(s.addr)
        elif instr.name == 'CALLVALUE':
            s.stack.append(s.callinfo.value)
        elif instr.name == 'CALLDATASIZE':
            s.stack.append(s.callinfo.calldata.size)
        elif instr.name == 'CALLER':
            s.stack.append(s.caller)
        elif instr.name == 'CODECOPY':
            # TODO handle non-concrete size
            start_mem, start_code, sz = instr_args
            start_code = as_concrete(start_code)
            for i in range(as_concrete(sz)):
                s.memory.store(start_mem + i, s.code[start_code + i])
        elif instr.name == 'CALLDATACOPY':
            src, dest, sz = instr_args
            cd_mem, cd_off, cd_sz = s.callinfo.calldata
            # TODO cache this limited calldata memory object - this is so that
            # out of range calldata reads correctly return 0s
            limited_cdmem = mem.Memory()
            limited_cdmem.overlay(cd_mem, 0, cd_off, cd_sz)
            s.memory.overlay(limited_cdmem, dest, cd_off + src, sz)
        elif instr.name == 'CALLDATALOAD':
            addr, = instr_args
            cd_mem, cd_off, cd_sz, *_ = s.callinfo.calldata
            s.stack.append(
                z3.simplify(
                    z3.Concat(*[
                        z3.If(addr + i < cd_sz, cd_mem.select(cd_off + addr +
                                                              i), 0)
                        for i in range(32)
                    ])))
        elif instr.name == 'RETURNDATASIZE':
            if hasattr(s, retdata):
                s.stack.append(s.retdata.size)
            else:
                s.stack.append(z3.BitVecVal(0, 256))
        elif instr.name == 'RETURNDATACOPY':
            src, dest, sz = instr_args
            # TODO non-concrete length, retdata overflow (should terminate)
            if hasattr(s, retdata):
                for i in range(sz.as_long()):
                    s.memory.store(
                        dest + i,
                        z3.Select(s.retdata.mem, s.retdata.offset + src + i))
        elif instr.name == 'POP':
            pass
        elif instr.name == 'MLOAD':
            addr, = instr_args
            s.stack.append(
                z3.simplify(
                    z3.Concat(*[s.memory.select(addr + i)
                                for i in range(32)])))
        elif instr.name == 'MSTORE':
            dst, word = instr_args
            for i in range(32):
                s.memory.store(dst + i, get_byte(word, i))
        elif instr.name == 'MSTORE8':
            dst, word = instr_args
            s.memory.store(dst, get_byte(word, 31))
        elif instr.name == 'SLOAD':
            addr, = instr_args
            s.stack.append(z3.simplify(z3.Select(s.storage, addr)))
        elif instr.name == 'SSTORE':
            addr, word = instr_args
            s.storage = z3.Store(s.storage, addr, word)
        elif instr.name == 'PC':
            s.stack.append(z3.BitVecVal(s.pc, 256))
        elif instr.name == 'GAS':
            # TODO actually track gas usage?
            s.stack.append(z3.BitVec('{}:GAS'.format(s.pc), 256))
        elif instr.name in 'STOP':
            end_trace('stop')
            return
        elif instr.name == 'RETURN':
            ret_start, ret_size = instr_args
            s.make_child_return(ret_start, ret_size)
            return
        elif instr.name == 'REVERT':
            ret_start, ret_size = instr_args
            s.make_child_revert(ret_start, ret_size)
            return
        elif instr.name in {'CALL', 'CALLCODE', 'DELEGATECALL'}:
            if instr.name in {'CALL', 'CALLCODE'}:
                gas, addr, value, in_off, in_sz, out_off, out_sz = instr_args
                caller = s.addr
            elif instr.name == 'DELEGATECALL':
                gas, addr, in_off, in_sz, out_off, out_sz = instr_args
                value = s.callinfo.value
                caller = s.caller
            else:
                assert False, instr.name
            addr = z3.simplify(addr)
            if instr.name == 'CALL':
                call_addr = addr
                code_addr = addr
            else:
                call_addr = z3.BitVecVal(s.addr, 256)
                code_addr = addr

            callres = z3.BitVec(
                '{}:{}({})'.format(s.pc, instr.name, z3.simplify(call_addr)),
                256)
            s.stack.append(callres)
            if is_concrete(call_addr):
                s.make_child_call(addr=call_addr.as_long(),
                                  code_addr=code_addr.as_long(),
                                  caller=caller,
                                  retinfo=ReturnInfo(s, s.pc + 1, out_off,
                                                     out_sz, callres),
                                  callinfo=CallInfo(
                                      MemRange(s.memory, in_off, in_sz),
                                      z3.BV2Int(gas), value))
                return
            else:
                end_trace('call', call_addr, value, gas)
                s.make_child_branch(
                    new_pc=s.pc + 1,
                    preds=[s.gas > 0,
                           z3.Or(callres == 1, callres == 0)])
                return
        elif instr.name == 'CREATE':
            value, in_off, in_sz = instr_args
            res = z3.BitVec('{}:CREATE({})'.format(s.pc, value), 256)
            s.stack.append(res)
            end_trace('create', value)
            s.make_child_branch(new_pc=s.pc + 1,
                                preds=[s.gas > 0,
                                       z3.Or(res == 0, res == 1)])
            return
        elif instr.name == 'SELFDESTRUCT':
            to_addr = instr_args
            end_trace('suicide', to_addr)
            # No successors
            return
        elif instr.name == 'JUMPI':
            end_trace(None)
            loc, cond = instr_args

            fallthrough_pc = None

            solver.push()
            solver.add(cond == 0)
            fallthrough_state = None
            if solver.check() == z3.sat:
                # Also might not take the jump
                fallthrough_pc = s.pc + 1
            solver.pop()

            solver.push()
            solver.add(cond != 0)
            if solver.check() == z3.sat:
                # OK, can take the jump
                if is_concrete(loc):
                    loc_conc = loc.as_long()
                    if loc_conc == fallthrough_pc:
                        # Fuse fallthrough and jump if to same location
                        fallthrough_pc = None
                        s.make_child_branch(new_pc=loc_conc, preds=[s.gas > 0])
                    else:
                        s.make_child_branch(new_pc=loc_conc,
                                            preds=[s.gas > 0, cond != 0])
                else:
                    for dest in s.code.all_jumpdests():
                        solver.push()
                        solver.add(loc == dest)
                        if solver.check() == z3.sat:
                            if dest == fallthrough_pc:
                                fallthrough_pc = None
                                s.make_child_branch(
                                    new_pc=dest,
                                    preds=[s.gas > 0, loc == dest])
                            else:
                                s.make_child_branch(
                                    new_pc=dest,
                                    preds=[s.gas > 0, cond != 0, loc == dest])
                        solver.pop()
            solver.pop()
            if fallthrough_pc is not None:
                s.make_child_branch(new_pc=fallthrough_pc,
                                    preds=[s.gas > 0, cond == 0])
            return
        elif instr.name == 'JUMP':
            end_trace(None)
            (loc, ) = instr_args
            if is_concrete(loc):
                s.make_child_branch(new_pc=loc.as_long(), preds=[s.gas > 0])
            else:
                successors = []
                for dest in s.code.all_jumpdests():
                    solver.push()
                    solver.add(loc == dest)
                    if solver.check() == z3.sat:
                        s.make_child_branch(new_pc=dest,
                                            preds=[s.gas > 0, loc == dest])
                    solver.pop()
            # No fallthrough
            return
        elif instr.name == 'JUMPDEST':
            s.jumpdests.add(s.pc)
        elif instr.name in {'LOG0', 'LOG1', 'LOG2', 'LOG3', 'LOG4'}:
            pass
        else:
            raise NotImplementedError(instr.name)

        if log_trace:
            print('< ' + ';; '.join(str(z3.simplify(x)) for x in s.stack))
        s.pc += oplen