Beispiel #1
0
    def gen(self, deg, traces, inps):
        """Infer equality invariants of degree ``deg`` at every location.

        Initial data is first gathered for each location of ``traces`` with
        ``_get_init_traces``; locations for which that yields nothing are
        dropped. The remaining locations are handed to ``Miscs.run_mp``,
        whose worker runs ``_infer`` on each, and the equations found are
        collected into a ``DInvs``.

        :param deg: template degree (must be >= 1)
        :param traces: non-empty ``DTraces``; its keys are the locations
        :param inps: ``Inps`` seen so far
        :returns: ``DInvs`` mapping each solved location to an ``Invs``
        """
        assert deg >= 1, deg
        assert isinstance(traces, DTraces) and traces, traces
        assert isinstance(inps, Inps), inps

        # Phase 1: obtain enough traces for every location, then keep only
        # the locations for which something was produced.
        candidates = [
            (loc, self._get_init_traces(loc, deg, traces, inps,
                                        settings.EQT_RATE))
            for loc in traces.keys()
        ]
        tasks = list(filter(lambda pair: pair[1], candidates))

        # Phase 2: solve/prove in parallel; each worker receives a chunk of
        # (loc, (template, uks, exprs)) tasks.
        def solve_chunk(chunk):
            solved = []
            for loc, (template, uks, exprs) in chunk:
                solved.append(
                    (loc, self._infer(loc, template, uks, exprs, traces, inps)))
            return solved

        results = Miscs.run_mp('find eqts', tasks, solve_chunk)

        # Phase 3: assemble the per-location results.
        dinvs = DInvs()
        for loc, (eqts, cexs) in results:
            new_inps = inps.merge(cexs, self.inp_decls.names)
            mlog.debug("{}: got {} eqts, {} new inps".format(
                loc, len(eqts), len(new_inps)))
            if eqts:
                mlog.debug('\n'.join(str(eqt) for eqt in eqts))
            dinvs[loc] = Invs(eqts)
        return dinvs
Beispiel #2
0
    def gen(self, dinvs, traces):
        """Build pre/post-condition invariants for every post location.

        For each location of ``dinvs`` whose name contains
        ``settings.POST_LOC``, the ``Eqt`` invariants found there are turned
        into candidate postconditions with ``get_postconds``. Each distinct
        candidate is then passed to ``get_preposts1``, and every truthy
        result is kept.

        :param dinvs: ``DInvs`` holding the invariants found so far
        :param traces: ``DTraces``; only type-checked here
        :returns: ``DInvs`` with the pre/post invariants per post location
                  (locations that yielded none are omitted)
        """
        assert isinstance(dinvs, DInvs), dinvs
        assert isinstance(traces, DTraces), traces

        result = DInvs()
        post_locs = [name for name in dinvs if settings.POST_LOC in name]
        for loc in post_locs:
            # Collect the distinct postconditions derivable from the
            # equality invariants at this location.
            candidates = set()
            for inv in dinvs[loc]:
                if not isinstance(inv, Eqt):
                    continue
                pcs = self.get_postconds(inv.inv)
                if not pcs:
                    continue
                for p in pcs:
                    candidates.add(p)

            found = [pp for pp in
                     (self.get_preposts1(loc, postcond)
                      for postcond in candidates)
                     if pp]
            if found:
                result[loc] = Invs(found)

        return result
Beispiel #3
0
    def infer_from_traces(self,
                          itraces,
                          traceid,
                          inps=None,
                          maxdeg=1,
                          simpl=False):
        """Run DIG on the traces of ``traceid`` and return its invariants.

        ``inps`` is split into training and test inputs with ``self._split``.
        The traces for each split are selected with ``get_traces_by_id`` and
        fed to ``DigTraces``. ``dig_settings.DO_SIMPLIFY`` is set to
        ``simpl`` for the duration of the run and is always restored.

        :param itraces: traces to select from (passed to get_traces_by_id)
        :param traceid: id whose invariants are requested
        :param inps: inputs to split into train/test sets (optional)
        :param maxdeg: maximum degree passed to ``dig.start``
        :param simpl: temporary value for ``dig_settings.DO_SIMPLIFY``
        :returns: the ``Invs`` for ``traceid`` (an empty ``Invs`` if DIG found
                  nothing for it), or ``None`` if inference raised
        """
        old_do_simplify = dig_settings.DO_SIMPLIFY
        dig_settings.DO_SIMPLIFY = simpl

        try:
            train_inps, test_inps = self._split(inps)
            train_dtraces = self.get_traces_by_id(itraces, traceid, train_inps)
            test_dtraces = self.get_traces_by_id(itraces, traceid, test_inps)

            import alg as dig_alg
            dig = dig_alg.DigTraces.from_dtraces(self.inv_decls, train_dtraces,
                                                 test_dtraces)
            invs, _ = dig.start(self.seed, maxdeg)
            mlog.debug(
                "invs: {}".format(invs))  # <class 'data.inv.invs.DInvs'>
            return invs[traceid] if traceid in invs else Invs()
        except Exception as e:
            # Best effort: callers still get None on failure, but the error
            # and its traceback are now visible instead of silently lost at
            # debug level.
            mlog.warning("exception: {}".format(e), exc_info=True)
            return None
        finally:
            dig_settings.DO_SIMPLIFY = old_do_simplify
Beispiel #4
0
    def _infer(self, loc, template, uks, exprs, dtraces, inps):
        """Iteratively solve for and check equations at ``loc``.

        Each round does the following:

        * solves ``template`` for the unknowns ``uks`` from the current
          ``exprs`` with ``Miscs.solve_eqts``;
        * checks the candidates not already in the cache with
          ``self.check``;
        * keeps every checked invariant that is not disproved;
        * if the check reported counterexamples for ``loc``, instantiates
          ``template`` on them to get more expressions for the next round.

        The loop stops when a round yields no unchecked candidates, or when
        ``loc`` is absent from the counterexamples.

        :param loc: location name (non-empty str)
        :param template: template expression containing the unknowns
        :param uks: list of unknowns to solve for
        :param exprs: non-empty set of instantiated template expressions
        :param dtraces: non-empty ``DTraces`` (only type-checked here)
        :param inps: non-empty ``Inps`` (only type-checked here)
        :returns: ``(eqts, new_cexs)``, where ``eqts`` is the set of
                  invariants not disproved and ``new_cexs`` is the list of
                  truthy counterexample results from each round
        """
        assert isinstance(loc, str) and loc, loc
        assert Miscs.is_expr(template), template
        assert isinstance(uks, list), uks
        assert isinstance(exprs, set) and exprs, exprs
        assert isinstance(dtraces, DTraces) and dtraces, dtraces
        assert isinstance(inps, Inps) and inps, inps

        cache = set()  # candidates whose check assigned a status
        eqts = set()  # results
        exprs = list(exprs)

        new_cexs = []
        cur_iter = 0

        while True:
            cur_iter += 1
            mlog.debug("{}, iter {} infer using {} exprs".format(
                loc, cur_iter, len(exprs)))

            new_eqts = Miscs.solve_eqts(exprs, uks, template)
            unchecks = [eqt for eqt in new_eqts if eqt not in cache]

            if not unchecks:
                mlog.debug("{}: no new results -- break".format(loc))
                break

            mlog.debug('{}: {} candidates:\n{}'.format(
                loc, len(new_eqts), '\n'.join(map(str, new_eqts))))

            mlog.debug("{}: check {} unchecked ({} candidates)".format(
                loc, len(unchecks), len(new_eqts)))

            dinvs = DInvs.mk(loc, Invs(list(map(data.inv.eqt.Eqt, unchecks))))
            cexs, dinvs = self.check(dinvs, None)
            if cexs:
                new_cexs.append(cexs)

            # set.update rather than list comprehensions run for side effects
            eqts.update(inv for inv in dinvs[loc] if not inv.is_disproved)
            cache.update(inv.inv for inv in dinvs[loc]
                         if inv.stat is not None)

            if loc not in cexs:
                mlog.debug("{}: no disproved candidates -- break".format(loc))
                break

            # Turn the counterexamples into new template instances so the
            # next round is solved with more constraints.
            cexs = Traces.extract(cexs[loc])
            cexs = cexs.padzeros(set(self.inv_decls[loc].names))
            exprs_ = cexs.instantiate(template, None)
            mlog.debug("{}: {} new cex exprs".format(loc, len(exprs_)))
            exprs.extend(exprs_)

        return eqts, new_cexs
Beispiel #5
0
    def get_disj_preconds(self, loc, preconds, postcond_expr, traces):
        """Keep the preconditions that pass ``self.check`` against the postcondition.

        A precondition ``pc`` is kept when
        ``self.check(pc.expr(self.use_reals), postcond_expr, loc)`` is truthy.
        This is presumably an implication check ``pc => postcond``; confirm
        against ``check``. When two or more survive, they are simplified
        together with ``Invs._simplify`` using ``is_conj=False``.

        :param loc: location name
        :param preconds: iterable of ``Inv`` candidates
        :param postcond_expr: z3 expression of the postcondition
        :param traces: unused here
        :returns: the surviving (possibly simplified) preconditions
        """
        assert all(isinstance(p, Inv) for p in preconds), preconds
        assert z3.is_expr(postcond_expr), postcond_expr

        kept = [pc for pc in preconds
                if self.check(pc.expr(self.use_reals), postcond_expr, loc)]

        if len(kept) < 2:
            return kept
        return Invs._simplify(kept, False, self.use_reals)
Beispiel #6
0
    def get_preposts1(self, loc, postcond):
        """Find a conjunctive precondition that proves ``postcond`` at ``loc``.

        ``infer.opt.Ieq`` generates candidate preconditions for the
        postcondition. ``self.symstates.mcheck_d`` then checks
        ``And(preconds) => postcond``.

        :param loc: location name
        :param postcond: equality postcondition (operator must be ``==``)
        :returns: a ``PrePost`` with ``stat=Inv.PROVED`` and ``is_conj=True``
                  when the check succeeds without counterexamples, otherwise
                  ``None``
        """
        assert isinstance(loc, str), loc
        assert postcond.operator() == operator.eq, postcond

        import infer.opt
        solver = infer.opt.Ieq(self.symstates, self.prog)
        postcond_expr = Eqt(postcond).expr(self.use_reals)

        found = solver.gen([loc], postcond_expr)
        if loc not in found:
            return None
        preconds = list(found[loc])
        if not preconds:
            return None

        precond_expr = z3.And([pc.expr(self.use_reals) for pc in preconds])
        claim = z3.Implies(precond_expr, postcond_expr)
        cexs, is_succ = self.symstates.mcheck_d(loc, claim, None, 1)
        if cexs or not is_succ:
            return None

        prepost = PrePost(Invs(preconds), postcond, stat=Inv.PROVED)
        prepost.is_conj = True
        return prepost
Beispiel #7
0
    def get_preposts(self, loc, postconds, traces):
        """Compute conjunctive and disjunctive pre/post invariants at ``loc``.

        Postconditions are wrapped in ``Eqt`` and ordered by how many of
        ``traces`` satisfy them, most first. Ties keep the shorter-string-first
        order from the initial sort. For each postcondition:

        * preconditions from ``self.preconds`` that hold on the traces
          satisfying only this postcondition (none of the others) are
          refined with ``get_conj_preconds``; if any remain, a PrePost with
          ``is_conj=True`` is recorded;
        * all of ``self.preconds`` are filtered with ``get_disj_preconds``;
          if any remain, a PrePost with ``is_conj=False`` is recorded.

        Postconditions whose ``expr`` raises ``NotImplementedError`` (e.g.
        ``sqrt(x)``) are skipped.

        :param loc: location name
        :param postconds: non-empty set of equality postconditions
        :param traces: ``Traces`` observed at ``loc``
        :returns: ``Invs`` of the recorded ``PrePost`` objects
        """
        assert isinstance(loc, str), loc
        assert isinstance(postconds, set) and postconds, postconds
        assert all(p.operator() == operator.eq for p in postconds), postconds
        assert isinstance(traces, Traces), traces

        preconds = list(self.preconds)

        # Shortest first; the later (stable) sort keeps this as tie-breaker.
        postconds = sorted(postconds, key=lambda d: len(str(d)))
        postconds = [Eqt(p) for p in postconds]

        # find traces satisfying each postcond
        ptraces = {
            p: Traces([t for t in traces if p.test_single_trace(t)])
            for p in postconds
        }

        preposts = []  # results

        def myappend(mypreconds, postcond, is_conj):
            # postcond is passed explicitly instead of being captured from
            # the enclosing loop variable.
            # TODO: check, stat=Inv.PROVED ?
            prepost = PrePost(Invs(mypreconds), postcond, stat=Inv.PROVED)
            prepost.is_conj = is_conj
            preposts.append(prepost)

        postconds = sorted(postconds,
                           key=lambda d: len(ptraces[d]),
                           reverse=True)

        for idx, postcond in enumerate(postconds):
            try:
                postcond_expr = postcond.expr(self.use_reals)
            except NotImplementedError:
                # cannot parse something like sqrt(x)
                continue

            # traces satisfying this postcond but none of the others
            others = postconds[:idx] + postconds[idx + 1:]
            traces_ = Traces([
                t for t in ptraces[postcond]
                if all(t not in ptraces[other] for other in others)
            ])

            conj_preconds = [pc for pc in preconds if pc.test(traces_)]
            conj_preconds = self.get_conj_preconds(loc, conj_preconds,
                                                   postcond_expr)
            if conj_preconds:
                myappend(conj_preconds, postcond, is_conj=True)

            disj_preconds = self.get_disj_preconds(loc, preconds,
                                                   postcond_expr, traces)
            if disj_preconds:
                myappend(disj_preconds, postcond, is_conj=False)

        return Invs(preposts)
Beispiel #8
0
 def myappend(mypreconds, is_conj):
     """Append a proved ``PrePost`` built from ``mypreconds`` to ``preposts``.

     ``postcond`` and ``preposts`` are free variables taken from the
     enclosing scope. ``is_conj`` is stored on the new object.
     """
     # TODO: check, stat=Inv.PROVED ?
     wrapped_preconds = Invs(mypreconds)
     new_prepost = PrePost(wrapped_preconds, postcond, stat=Inv.PROVED)
     new_prepost.is_conj = is_conj
     preposts.append(new_prepost)