Exemple #1
0
def solve(formula, n, max_models=None, solver="msat"):
    """Enumerate distinct models of ``formula``.

    After each model is found, the conjunction of its variable assignments
    (over the variables returned by ``variables(n)``) is negated and asserted.
    Every later model therefore differs from all previous ones on at least
    one of those variables.

    :param formula: pySMT formula whose models are enumerated.
    :param n: problem size forwarded to ``variables`` and ``to_bn``.
    :param max_models: stop after this many models. ``None`` (or any falsy
        value, including 0) means "no limit".
    :param solver: name of the pySMT solver backend.
    :return: generator yielding ``to_bn(model, n)`` for each model found; it
        yields nothing when ``formula`` is unsatisfiable.
    """
    # The ``with`` block releases the backend solver when the generator is
    # exhausted, closed, or garbage-collected, even if the caller stops
    # iterating early.
    with Solver(name=solver) as s:
        s.add_assertion(formula)
        vs = [x for xs in variables(n) for x in xs]
        k = 0
        # Check the model budget *before* calling the solver so no extra
        # (possibly expensive) solve() is issued once max_models is reached.
        while (not max_models or k < max_models) and s.solve():
            k += 1
            model = s.get_model()
            # Block this exact assignment so the next solve() finds a new one.
            s.add_assertion(Not(And([EqualsOrIff(v, model[v]) for v in vs])))
            yield to_bn(model, n)
Exemple #2
0
def StatisticsSolver(statistics: Statistics, name=None, logic=None, **kwargs):
    """Create a new PySMT solver which also updates the sat check timer automatically.

    The solver is built with ``Solver(name, logic, **kwargs)``. Its ``is_sat``
    is then replaced by a bound wrapper. The wrapper starts
    ``statistics.sat_check_time`` and delegates to the original ``is_sat``.
    It stops the timer afterwards, even if the delegated call raises.

    :param statistics: object whose ``sat_check_time`` timer is driven.
    :param name: solver backend name, forwarded to ``Solver``.
    :param logic: logic, forwarded to ``Solver``.
    :return: the newly created solver with a timed ``is_sat``.
    """
    smt = Solver(name, logic, **kwargs)
    untimed_is_sat = smt.is_sat

    def _timed(_bound_self, *call_args, **call_kwargs):
        statistics.sat_check_time.start_timer()
        try:
            return untimed_is_sat(*call_args, **call_kwargs)
        finally:
            # Runs after the result is computed and before it is returned.
            statistics.sat_check_time.stop_timer()

    smt.is_sat = MethodType(_timed, smt)
    return smt
Exemple #3
0
def analyze_rv32_interpreter(program: List[Instruction],
                             bbs: List[BasicBlock]):
    """Symbolically execute ``program`` on a symbolic RV32I ADD and check it.

    A 32-bit ADD instruction word with symbolic ``rs1``/``rs2``/``rd`` fields
    is placed at address 0 of the start state, and execution starts at PC
    0x56. After symbolic execution, two checks are made:

    * Completeness: the disjunction of all path conditions is valid
      ("Complete? ...").
    * Correctness, per path: under the path condition (and R[0] == 0), the 32
      registers as stored in memory (4 bytes each, see ``to_mem_addrs``)
      after execution equal the result of ``sym_exec_rsicv_add`` on the
      registers stored before execution ("Correct? ..."). On the first
      incorrect path a counterexample is printed and checking stops.

    Side effects: renders ``cfg.pdf`` via ``mk_dot``, writes one
    ``path_NN.smt2`` file per path via ``write_smtlib``, and prints progress.

    :param program: instructions to execute symbolically.
    :param bbs: basic blocks used to render the control-flow graph.
    :return: None
    """
    mk_dot(dot_cfg(bbs), filename="cfg.pdf")

    # start at MainStart @ 0x0056
    start_pc = 0x56
    # symbolic instruction: ADD rs2, rs1, rd -- only the register fields are
    # symbolic; funct7/funct3/opcode are fixed to the ADD encoding.
    funct7 = BitVecVal(0, 7)
    rs2 = Symbol("RV32I_ADD_rs2", BVType(5))
    rs1 = Symbol("RV32I_ADD_rs1", BVType(5))
    funct3 = BitVecVal(0b00, 3)  # ADD
    rd = Symbol("RV32I_ADD_rd", BVType(5))
    opcode = BitVecVal(0b0110011, 7)  # OP
    RV32I_instr = cat(funct7, rs2, rs1, funct3, rd, opcode)
    print(f"Symbolically executing: {RV32I_instr}")

    def place_instr(loc, instr, st) -> MachineState:
        """Return ``st`` with ``instr`` stored at ``loc``..``loc + 3`` (bits
        [7:0] at ``loc``) and ``loc`` split into R10 (low byte) / R11 (high
        byte)."""
        # Check the type first: the range check below needs a concrete int.
        if not isinstance(loc, int):
            raise NotImplementedError("TODO: support symbolic address")
        # make sure PC fits into two 8-bit registers
        if loc & 0xffff != loc:
            raise ValueError(f"address {loc!r} does not fit into 16 bits")
        msb, lsb = BitVecVal(loc >> 8, 8), BitVecVal(loc & 0xff, 8)
        st = st.update(R=st.R.update(10, lsb).update(11, msb))
        # byte jj of the instruction word is bits [8*jj, 8*jj + 7]
        instr_parts = [BVExtract(instr, jj * 8, jj * 8 + 7) for jj in range(4)]
        mem = st.MEM
        for offset, val in enumerate(instr_parts):
            mem = mem.update(loc + offset, val)
        return st.update(MEM=mem)

    # interpreter start state: PC at MainStart, ADD instruction at address 0
    orig_state = MachineState().update(PC=BitVecVal(start_pc, 16))
    orig_state = place_instr(loc=0, instr=RV32I_instr, st=orig_state)

    mf8_ex = SymExec()
    ex = SymbolicExecutionEngine(program=program,
                                 start_state=orig_state,
                                 semantics=mf8_ex)

    print()
    print()
    print("SYM EXEC")
    print("--------")
    done, end_state = ex.run(max_steps=2000)
    print(ex.taken)
    print(f"DONE? {done}")
    for ii, (cond, _) in enumerate(end_state):
        print(str(ii) + ") " + cond.serialize())

    # check for completeness: the path conditions must cover every input.
    # The n-ary Or is FALSE for an empty path list, where reduce() without an
    # initial value would raise TypeError.
    conds = Or([cond for cond, _ in end_state])
    with Solver(name="z3", logic=QF_AUFBV) as solver:
        complete = not solver.is_sat(Not(conds))
    print(f"Complete? {complete}")

    # check result of every path:
    def to_mem_addrs(reg_index):
        """Memory addresses of the 4 bytes of register ``reg_index``,
        highest address first."""
        return reversed([0xf100 + reg_index * 8 + jj for jj in range(4)])

    def relate_regs(mem, regs):
        """Formula equating each of the 32 entries of ``regs`` with the
        concatenation of its 4 bytes in ``mem``."""
        def relate_loc(ii):
            mem_locs = [
                Select(mem, BitVecVal(addr, 16)) for addr in to_mem_addrs(ii)
            ]
            return Equals(cat(*mem_locs), Select(regs, BitVecVal(ii, 5)))

        return And([relate_loc(ii) for ii in range(32)])

    def name_value(solver, name, val):
        """Bind ``val`` to a fresh symbol ``name`` so it shows up in models."""
        sym = Symbol(name, val.get_type())
        solver.add_assertion(Equals(sym, val))

    def locs_to_str(name, array, locs):
        """Render ``array`` at the distinct, sorted ``locs`` as hex."""
        return "; ".join(f"{name}[{ii:04x}] = 0x{array[ii]:02x}"
                         for ii in sorted(set(locs)))

    for ii, (cond, end_st) in enumerate(end_state):
        # clean-slate solver per path; ``with`` releases it afterwards,
        # including on the ``break`` below.
        with Solver(name="cvc4", logic=QF_AUFBV,
                    generate_models=True) as solver:
            # symbolically execute the RISC-V add
            regs = Symbol('RV32I_REGS', ArrayType(BVType(5), BVType(32)))
            regs_n = sym_exec_rsicv_add(rs1=rs1, rs2=rs2, rd=rd, regs=regs)
            name_value(solver, "DBG_RV32I_REGS_N", regs_n)
            # add mem to regs relation (R[0] is hard-wired to zero)
            mem_orig = orig_state.MEM.array()
            pre = And(And(cond, relate_regs(mem_orig, regs)),
                      Equals(Select(regs, BitVecVal(0, 5)), BitVecVal(0, 32)))
            mem_n = end_st.MEM.array()
            post = relate_regs(mem_n, regs_n)
            # DEBUG: add symbols for every memory write
            mem_data = end_st._mem._data
            mem_write_locs = [
                Symbol(f"DBG_MF8_MEM_WRITE_LOC_{ii}", BVType(16))
                for ii in range(len(mem_data))
            ]
            for sym, (expr, _) in zip(mem_write_locs, mem_data):
                solver.add_assertion(Equals(sym, expr))
            # now check for validity
            formula = Implies(pre, post)
            write_smtlib(Not(formula), f"path_{ii:02}.smt2")
            correct = solver.is_valid(formula)
            print(f"Correct? {correct}")
            if not correct:
                print("Path condition:")
                print(cond.serialize())
                print("Symbolic Mem:")
                ex.print_mem(end_st)
                print("Model:")
                rs1_val = solver.get_value(rs1).bv_unsigned_value()
                rs2_val = solver.get_value(rs2).bv_unsigned_value()
                rd_val = solver.get_value(rd).bv_unsigned_value()
                regs_val = ArrayValue(solver.get_value(regs))
                regs_n_val = ArrayValue(solver.get_value(regs_n))
                mem_val = ArrayValue(solver.get_value(mem_orig))
                mem_n_val = ArrayValue(solver.get_value(mem_n))
                reg_addrs = [rd_val, rs1_val, rs2_val]
                mem_write_locs_vals = [
                    solver.get_value(ll).bv_unsigned_value()
                    for ll in mem_write_locs
                ]
                # bytes backing rd/rs1/rs2, followed by every written address
                mem_addrs = [
                    addr for reg in reg_addrs for addr in to_mem_addrs(reg)
                ] + mem_write_locs_vals
                print(f"R[{rd_val}] <- R[{rs1_val}] + R[{rs2_val}]")
                print(f"Pre:  {locs_to_str('R', regs_val, reg_addrs)}")
                print(f"      {locs_to_str('M',  mem_val, mem_addrs)}")
                print(f"Post: {locs_to_str('R', regs_n_val, reg_addrs)}")
                print(f"      {locs_to_str('M',  mem_n_val, mem_addrs)}")
                print(
                    f"MEM write addresses: {[f'0x{loc:04x}' for loc in mem_write_locs_vals]}"
                )
                # TODO: check PC post-condition
                # TODO: add pre and post conditions for program memory equivalence
                # TODO: add pre and post conditions for data memory equivalence
                break
class TestCounterEnc(unittest.TestCase):
    """Tests for ``CounterEnc``, the boolean encoding of bounded counters.

    As the expectations below state, a counter declared with maximum ``m`` is
    encoded with bit variables ``b0, b1, ...`` (``b0`` least significant).
    ``eq_val`` raises ``AssertionError`` for values above ``m``. ``get_mask``
    excludes the bit patterns of values above ``m``.
    """

    def setUp(self):
        self.enc = CounterEnc(get_env(), False)
        self.solver = Solver(logic=pysmt.logics.BOOL)
        # Release the solver backend after every test instead of leaking one
        # solver instance per test method.
        self.addCleanup(self.solver.exit)

    def _is_eq(self, a, b):
        """Assert that formulas ``a`` and ``b`` are logically equivalent."""
        self.assertTrue(self.solver.is_valid(Iff(a, b)),
                        f"{a} is not equivalent to {b}")

    def test_0(self):
        var_name = "counter_0"
        self.enc.add_var(var_name, 0)

        b0 = self.enc._get_bitvar(var_name, 0)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, Not(b0))

        with self.assertRaises(AssertionError):
            self.enc.eq_val(var_name, 1)

        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, Not(b0))

    def test_1(self):
        var_name = "counter_1"
        self.enc.add_var(var_name, 1)

        b0 = self.enc._get_bitvar(var_name, 0)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, Not(b0))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, b0)

        with self.assertRaises(AssertionError):
            self.enc.eq_val(var_name, 2)

        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, TRUE())

    def test_2(self):
        # need 2 bits
        var_name = "counter_2"
        self.enc.add_var(var_name, 2)

        b0 = self.enc._get_bitvar(var_name, 0)
        b1 = self.enc._get_bitvar(var_name, 1)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, And(Not(b0), Not(b1)))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, And(b0, Not(b1)))
        e = self.enc.eq_val(var_name, 2)
        self._is_eq(e, And(Not(b0), b1))

        with self.assertRaises(AssertionError):
            # out of the counter bound
            self.enc.eq_val(var_name, 3)

        # only the pattern for 3 (b0 & b1) is outside the bound
        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, Not(And(b0, b1)))

    def test_3(self):
        # need 2 bits
        var_name = "counter_3"
        self.enc.add_var(var_name, 3)

        b0 = self.enc._get_bitvar(var_name, 0)
        b1 = self.enc._get_bitvar(var_name, 1)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, And(Not(b0), Not(b1)))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, And(b0, Not(b1)))
        e = self.enc.eq_val(var_name, 2)
        self._is_eq(e, And(Not(b0), b1))
        e = self.enc.eq_val(var_name, 3)
        self._is_eq(e, And(b0, b1))

        # every 2-bit pattern is a valid value
        mask = self.enc.get_mask(var_name)
        self._is_eq(mask, TRUE())

    def test_4(self):
        # need 3 bits
        var_name = "counter_4"
        self.enc.add_var(var_name, 4)

        b0 = self.enc._get_bitvar(var_name, 0)
        b1 = self.enc._get_bitvar(var_name, 1)
        b2 = self.enc._get_bitvar(var_name, 2)

        e = self.enc.eq_val(var_name, 0)
        self._is_eq(e, And([Not(b0), Not(b1), Not(b2)]))
        e = self.enc.eq_val(var_name, 1)
        self._is_eq(e, And([b0, Not(b1), Not(b2)]))
        e = self.enc.eq_val(var_name, 2)
        self._is_eq(e, And([Not(b0), b1, Not(b2)]))
        e = self.enc.eq_val(var_name, 3)
        self._is_eq(e, And([b0, b1, Not(b2)]))
        e = self.enc.eq_val(var_name, 4)
        self._is_eq(e, And([Not(b0), Not(b1), b2]))

        with self.assertRaises(AssertionError):
            self.enc.eq_val(var_name, 5)

        # patterns for 5, 6 and 7 are outside the bound
        mask = self.enc.get_mask(var_name)
        models = Or([And([b0, Not(b1), b2]),
                     And([Not(b0), b1, b2]),
                     And([b0, b1, b2])])
        self._is_eq(mask, Not(models))

    def test_value(self):
        """Decoding a model of ``eq_val(v)`` yields ``v`` again."""
        var_name = "counter_4"
        self.enc.add_var(var_name, 4)

        for value in range(5):
            with self.subTest(value=value):
                eq_val = self.enc.eq_val(var_name, value)
                # A model only exists (and get_model is only meaningful) if
                # the constraint is satisfiable.
                self.assertTrue(self.solver.is_sat(eq_val))
                model = self.solver.get_model()
                res = self.enc.get_counter_value(var_name, model, False)
                self.assertEqual(res, value)