Ejemplo n.º 1
0
    def set_impl(self, q : Query, rep : [(EVar, Exp)], ret : Exp):
        """Replace the implementation of query `q` with a new representation.

        `rep` pairs each proposed concrete state variable with the expression
        it stores, and `ret` is the query's new return expression over those
        variables.

        A proposed variable is dropped when the state solver proves (under the
        spec's assumptions) that its expression equals that of an existing
        entry in `self.concrete_state` of the same type; `ret` is then
        rewritten to use the existing variable instead.  Surviving pairs are
        appended to `self.concrete_state`, the query's implementation is
        stored in `self.query_impls` via `rewrite_ret`, and for every update
        operation a maintenance statement per new variable is recorded in
        `self.updates`.  Helper queries produced along the way are registered
        through `_add_subquery`.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):

                def existing_equivalent(expr):
                    # Solver-checked equality; the commented-out syntactic
                    # alpha-equivalence test was presumably cheaper but weaker.
                    return find_one(
                        old_var
                        for (old_var, old_expr) in self.concrete_state
                        if expr.type == old_expr.type
                        and self.state_solver.valid(EImplies(EAll(self.spec.assumptions), EEq(expr, old_expr))))

                redundant = set()
                for (new_var, new_expr) in rep:
                    twin = existing_equivalent(new_expr)
                    if twin is None:
                        continue
                    event("state var {} is equivalent to {}".format(new_var.id, twin.id))
                    ret = subst(ret, { new_var.id : twin })
                    redundant.add(new_var)
                rep = [pair for pair in rep if pair[0] not in redundant]

            self.concrete_state.extend(rep)
            self.query_impls[q.name] = rewrite_ret(q, lambda prev: ret, keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for (member, proj) in rep:
                        helpers = []
                        maintenance = inc.mutate_in_place(
                            member,
                            proj,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            subgoals_out=helpers)
                        # Each helper query is tagged with the operation that
                        # needs it, then registered (or deduplicated).
                        for helper in helpers:
                            helper.docstring = "[{}] {}".format(op.name, helper.docstring)
                            maintenance = self._add_subquery(sub_q=helper, used_by=maintenance)
                        self.updates[(member, op.name)] = maintenance
Ejemplo n.º 2
0
    def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
        """Add a helper query that maintains some other state, deduplicating.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        The work is done on a shallow copy of `sub_q`.  Implicit handle
        assumptions are appended when the current assumptions do not already
        imply them.  The assumptions and the return expression are then
        simplified.

        If an existing query in `self.query_specs` is equivalent to `sub_q`
        (checked only when `dedup_queries` is enabled), this method returns
        `used_by` with every call to `sub_q` redirected to that query.  The
        arguments are reordered to match that query's parameter order, and
        `sub_q` is not added.  Otherwise `sub_q` is added via `self.add_query`
        and `used_by` is returned unchanged.

        Raises:
            AssertionError - if simplification increased the total size of
            the assumptions.
        """
        with task("adding query", query=sub_q.name):
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions_for_method(
                    reachable_handles_at_method(self.spec, sub_q),
                    sub_q)
                # Only extend the assumptions when they do not already imply
                # the handle assumptions.
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                # Simplification must never grow the assumptions.  Dump both
                # versions for debugging, then fail.
                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    # An explicit raise (instead of `assert False`) keeps this
                    # check active when Python runs with -O.
                    raise AssertionError(
                        "simplification grew assumptions: {} -> {}".format(orig_a_size, a_size))

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
                # arg_reorder[i] is the position, in sub_q's argument list, of
                # qq's i-th argument name.
                sub_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_arg_names.index(a) for (a, t) in qq.args]
                class Repl(BottomUpRewriter):
                    # Redirects calls to sub_q to qq (reordering arguments).
                    # Every other call is rebuilt with rewritten arguments.
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)
                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by
Ejemplo n.º 3
0
    def set_impl(self, q: Query, rep: [(EVar, Exp)], ret: Exp):
        """Install a new concrete implementation for the query named `q.name`.

        `rep` lists proposed concrete state variables together with the
        expressions they store, and `ret` is the query's new return
        expression over those variables.

        A proposed variable is discarded when the state solver proves its
        expression equal to that of an existing concretization function of
        the same type; `ret` is then rewritten to refer to the existing
        variable.  The remaining pairs are appended to
        `self._concretization_functions`.  For every update operation, a
        statement that keeps each new variable up to date is recorded in
        `self.updates`, and any helper queries those statements need are
        registered through `_add_subquery`.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):

                def existing_equivalent(expr):
                    # Solver-checked equality (a syntactic alpha-equivalence
                    # test is the commented-out alternative).
                    return find_one(
                        old_var
                        for (old_var, old_expr) in self._concretization_functions
                        if expr.type == old_expr.type
                        and self.state_solver.valid(EEq(expr, old_expr)))

                redundant = set()
                for (new_var, new_expr) in rep:
                    twin = existing_equivalent(new_expr)
                    if twin is None:
                        continue
                    event("state var {} is equivalent to {}".format(
                        new_var.id, twin.id))
                    ret = subst(ret, {new_var.id: twin})
                    redundant.add(new_var)
                rep = [pair for pair in rep if pair[0] not in redundant]

            self._concretization_functions.extend(rep)
            self.query_impls[q.name] = rewrite_ret(
                q, lambda prev: ret, keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for (member, proj) in rep:
                        helpers = []
                        maintenance = inc.mutate_in_place(
                            member,
                            proj,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            invariants=self.abstract_invariants,
                            subgoals_out=helpers)
                        # Tag each helper with the operation that needs it,
                        # then register (or deduplicate) it.
                        for helper in helpers:
                            helper.docstring = "[{}] {}".format(
                                op.name, helper.docstring)
                            maintenance = self._add_subquery(
                                sub_q=helper, used_by=maintenance)
                        self.updates[(member, op.name)] = maintenance
Ejemplo n.º 4
0
Archivo: impls.py Proyecto: uwplse/cozy
    def set_impl(self, q : Query, rep : [(EVar, Exp)], ret : Exp):
        """Swap in a new concrete representation for query `q`.

        Each `(var, expr)` in `rep` proposes a concrete state variable that
        stores `expr`; `ret` computes the query's answer from them.  Proposals
        whose expression the state solver proves equal to an existing
        concretization function of the same type are dropped, and `ret` is
        redirected to the existing variable.  Survivors are appended to
        `self._concretization_functions`, and per update operation a
        maintenance statement per new variable goes into `self.updates`.
        Helper queries those statements call are registered with
        `_add_subquery`.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                redundant = set()
                for (new_var, new_expr) in rep:
                    # Semantic, solver-checked equality (alpha-equivalence is the commented-out alternative).
                    twin = find_one(old_var for (old_var, old_expr) in self._concretization_functions if new_expr.type == old_expr.type and self.state_solver.valid(EEq(new_expr, old_expr)))
                    if twin is not None:
                        event("state var {} is equivalent to {}".format(new_var.id, twin.id))
                        ret = subst(ret, { new_var.id : twin })
                        redundant.add(new_var)
                rep = [pair for pair in rep if pair[0] not in redundant]

            self._concretization_functions.extend(rep)
            self.query_impls[q.name] = rewrite_ret(q, lambda prev: ret, keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for (member, proj) in rep:
                        helpers = []
                        maintenance = inc.mutate_in_place(member, proj, op.body, abstract_state=self.abstract_state, assumptions=op.assumptions, invariants=self.abstract_invariants, subgoals_out=helpers)
                        for helper in helpers:
                            helper.docstring = "[{}] {}".format(op.name, helper.docstring)
                            maintenance = self._add_subquery(sub_q=helper, used_by=maintenance)
                        self.updates[(member, op.name)] = maintenance
Ejemplo n.º 5
0
def simplify_sum(e):
    """Cancel matching added and subtracted terms of an integer sum.

    `break_sum(e)` yields items whose first component is the partition key
    and whose second component is the term.  Assuming `partition` returns
    (items whose key is truthy, the rest), the first group holds the added
    terms and the second the subtracted ones (confirm against
    `break_sum`/`partition`).

    Each added term cancels against the first remaining subtracted term that
    is alpha-equivalent to it once EStateVar wrappers are stripped.  The
    result is a left-nested chain of INT-typed "+" nodes: the uncancelled
    added terms in order, then the uncancelled subtracted terms wrapped in
    unary "-".  Returns ZERO when everything cancels.
    """
    parts = list(break_sum(e))
    t, f = partition(parts, lambda p: p[0])
    added = [x[1] for x in t]
    subtracted = [x[1] for x in f]

    kept = []
    for term in added:
        # Strip once per term rather than once per candidate comparison.
        stripped = strip_EStateVar(term)
        opp = find_one(
            subtracted,
            lambda y: alpha_equivalent(stripped, strip_EStateVar(y)))
        # Compare with None explicitly: a found expression must always count
        # as a match, regardless of how it behaves in a boolean context.
        if opp is not None:
            subtracted.remove(opp)
        else:
            kept.append(term)
    kept.extend(EUnaryOp("-", x).with_type(INT) for x in subtracted)

    if not kept:
        return ZERO
    res = kept[0]
    for term in kept[1:]:
        res = EBinOp(res, "+", term).with_type(INT)
    return res
Ejemplo n.º 6
0
    def set_impl(self, q: Query, rep: [(EVar, Exp)], ret: Exp):
        """Replace the implementation of query `q`.

        `rep` pairs proposed concrete state variables with the expressions
        they store; `ret` is the new return expression over them.  A proposal
        is dropped when `valid` proves (under the spec's assumptions) that its
        expression equals that of an existing same-typed entry in
        `self.concrete_state`, and `ret` is rewritten to use that entry.
        Survivors are appended to `self.concrete_state`.  Then each update
        operation's delta form is computed, and `inc.sketch_update` produces a
        maintenance statement per (new variable, operation), stored in
        `self.updates`.  Helper queries it returns are registered through
        `_add_subquery`.
        """
        from cozy.solver import valid

        def existing_equivalent(expr):
            # Solver-checked equality (alpha-equivalence is the commented-out
            # syntactic alternative).
            return find_one(
                old_var for (old_var, old_expr) in self.concrete_state
                if expr.type == old_expr.type
                and valid(EImplies(EAll(self.spec.assumptions), EEq(expr, old_expr))))

        redundant = set()
        for (new_var, new_expr) in rep:
            twin = existing_equivalent(new_expr)
            if twin is None:
                continue
            print("########### state var {} is equivalent to {}".format(
                new_var.id, twin.id))
            ret = subst(ret, {new_var.id: twin})
            redundant.add(new_var)
        rep = [pair for pair in rep if pair[0] not in redundant]

        self.concrete_state.extend(rep)
        self.query_impls[q.name] = rewrite_ret(
            q, lambda prev: ret, keep_assumptions=False)

        # All deltas are computed up front, before any incrementalization.
        op_deltas = {}
        for op in self.op_specs:
            op_deltas[op.name] = inc.delta_form(self.spec.statevars, op)

        for op in self.op_specs:
            delta = op_deltas[op.name]
            for (member, proj) in rep:
                maintenance, helpers = inc.sketch_update(
                    member, proj, subst(proj, delta),
                    self.abstract_state, list(op.assumptions))
                for helper in helpers:
                    helper.docstring = "[{}] {}".format(
                        op.name, helper.docstring)
                    maintenance = self._add_subquery(
                        sub_q=helper, used_by=maintenance)
                self.updates[(member, op.name)] = maintenance
Ejemplo n.º 7
0
    def _add_subquery(self, sub_q: Query, used_by: Stm) -> Stm:
        """Add a helper query, or reuse an equivalent existing one.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        The work is done on a shallow copy of `sub_q`: implicit handle
        assumptions are appended and the return expression is simplified.

        If an existing query in `self.query_specs` is equivalent (checked only
        when `dedup_queries` is enabled), returns `used_by` with calls to
        `sub_q` redirected to it, arguments reordered to its parameter order.
        Otherwise adds `sub_q` and returns `used_by` unchanged.
        """
        print("Adding new query {}...".format(sub_q.name))
        sub_q = shallow_copy(sub_q)
        extra = tuple(
            implicit_handle_assumptions_for_method(
                reachable_handles_at_method(self.spec, sub_q), sub_q))
        if isinstance(sub_q.assumptions, list):
            # `+=` would extend the list in place.  After a shallow copy, that
            # list is presumably shared with the caller's query, so build a
            # fresh list instead.
            sub_q.assumptions = sub_q.assumptions + list(extra)
        else:
            sub_q.assumptions += extra
        sub_q.ret = simplify(sub_q.ret)

        qq = find_one(
            self.query_specs,
            lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q))
        if qq is not None:
            print("########### subgoal {} is equivalent to {}".format(
                sub_q.name, qq.name))
            # arg_reorder[i] is the position, in sub_q's argument list, of
            # qq's i-th argument name.
            sub_arg_names = [x[0] for x in sub_q.args]
            arg_reorder = [sub_arg_names.index(a) for (a, t) in qq.args]

            class Repl(BottomUpRewriter):
                # Redirects calls to sub_q to qq (reordering arguments).
                # Every other call is rebuilt with rewritten arguments.
                def visit_ECall(self, e):
                    args = tuple(self.visit(a) for a in e.args)
                    if e.func == sub_q.name:
                        args = tuple(args[idx] for idx in arg_reorder)
                        return ECall(qq.name, args).with_type(e.type)
                    else:
                        return ECall(e.func, args).with_type(e.type)

            used_by = Repl().visit(used_by)
        else:
            self.add_query(sub_q)
        return used_by
Ejemplo n.º 8
0
def find_one_or_fail(iter):
    """Return the result of `find_one(iter)`, failing loudly if it is None.

    Raises:
        ValueError: if `find_one` returned None.  A found value that is itself
        None is therefore indistinguishable from "nothing found".
    """
    res = find_one(iter)
    if res is None:
        raise ValueError("find_one_or_fail: no matching element found")
    return res
Ejemplo n.º 9
0
    def build(self, cache, size):
        """Yield candidate expressions derived from expressions in `cache`.

        Each section below looks up existing expressions in `cache`, either
        of size `size - 1` or as pairs whose sizes come from
        `pick_to_sum(2, size - 1)`.  It builds rewritten or related
        expressions from them and yields `self.check(expr, pool)` for each.
        Finally it yields everything produced by
        `self.wrapped.build(cache, size)`.
        """

        # Integer sums: cancel matching +/- terms with simplify_sum, and yield
        # the result only when it differs from the original expression.
        for e in cache.find(pool=RUNTIME_POOL, size=size - 1, type=INT):
            if not is_root(e):
                continue
            e2 = simplify_sum(e)
            if e != e2:
                yield self.check(e2, RUNTIME_POOL)

        # for e in cache.find(pool=RUNTIME_POOL, size=size-1):
        #     if isinstance(e, EMapGet) and isinstance(e.map, EMakeMap2):
        #         x = e.map.value.apply_to(e.key)
        #         x._tag = True
        #         yield self.check(x, RUNTIME_POOL)

        # [x] - ys
        # (This loop also rewrites distinct(xs) via a map over xs; see below.)
        for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(e):
                continue
            if isinstance(e, EBinOp) and e.op == "-" and isinstance(
                    e.e1, ESingleton):
                # [x] - ys  ==>  (x in ys) ? [] : [x]
                x = e.e1.e
                y = e.e2
                x = ECond(
                    EBinOp(x, BOp.In, y).with_type(BOOL),
                    EEmptyList().with_type(e.type), e.e1).with_type(e.type)
                yield self.check(x, RUNTIME_POOL)
            elif isinstance(e, EUnaryOp) and e.op == UOp.Distinct:
                # distinct(xs)  ==>  keys of a map from each element of xs to T.
                # The map is offered to the state pool, then wrapped in
                # EStateVar and offered to the runtime pool, followed by its
                # key set.
                e = strip_EStateVar(e)
                m = EMakeMap2(e.e, mk_lambda(e.type.t, lambda x: T)).with_type(
                    TMap(e.type.t, BOOL))
                yield self.check(m, STATE_POOL)
                m = EStateVar(m).with_type(m.type)
                yield self.check(m, RUNTIME_POOL)
                x = EMapKeys(m).with_type(e.type)
                # x._tag = True
                yield self.check(x, RUNTIME_POOL)

        # # x in ys ----> (count x in ys) > 0
        # for e in cache.find(pool=RUNTIME_POOL, type=BOOL, size=size-1):
        #     if isinstance(e, EBinOp) and e.op == BOp.In:
        #         for b in self.binders:
        #             if b.type != e.e1.type:
        #                 continue
        #             x = EGt(
        #                 EUnaryOp(UOp.Length, EFilter(e.e2, ELambda(b, EEq(e.e1, b))).with_type(e.e2.type)).with_type(INT),
        #                 ZERO)
        #             x._tag = True
        #             yield self.check(x, RUNTIME_POOL)

        # argmin/argmax over a concatenation (l + r): apply the same
        # argmin/argmax to each side (inside EStateVar when that side is an
        # EStateVar) and choose between the two results depending on which
        # sides are non-empty (UOp.Exists).
        for e in cache.find(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(e):
                continue
            if (isinstance(e, EArgMin) or isinstance(
                    e, EArgMax)) and isinstance(e.e, EBinOp) and e.e.op == "+":
                l = e.e.e1
                r = e.e.e2
                op = e.e.op
                f = lambda x: type(e)(x, e.f).with_type(e.type)
                ll = EStateVar(f(l.e)).with_type(e.type) if isinstance(
                    l, EStateVar) else f(l)
                rr = EStateVar(f(r.e)).with_type(e.type) if isinstance(
                    r, EStateVar) else f(r)
                x = ECond(
                    EUnaryOp(UOp.Exists, l).with_type(BOOL),
                    ECond(
                        EUnaryOp(UOp.Exists, r).with_type(BOOL),
                        f(
                            EBinOp(
                                ESingleton(ll).with_type(e.e.type), op,
                                ESingleton(rr).with_type(e.e.type)).with_type(
                                    e.e.type)), ll).with_type(e.type),
                    rr).with_type(e.type)
                # from cozy.solver import valid
                # assert valid(EEq(e, x), model_callback=print)
                x._tag = True
                yield self.check(x, RUNTIME_POOL)

        # is-last(x, l)
        # Builds a map from each element of e2 to its number of occurrences in
        # e2, wraps it in EStateVar, and yields `count(e1) > 1` looked up in it.
        # NOTE(review): that predicate means "e1 occurs more than once", which
        # does not obviously match the "is-last" label above; confirm intent.
        # NOTE(review): the binder `b` is only compared by type and is
        # otherwise unused.  The same `x` is therefore rebuilt and yielded once
        # per binder whose type matches e1.type (redundant duplicates).
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for e1 in cache.find(pool=RUNTIME_POOL, size=sz1):
                if not is_root(e1):
                    continue
                for e2 in cache.find_collections(pool=STATE_POOL,
                                                 size=sz2,
                                                 of=e1.type):
                    if not is_root(e2):
                        continue
                    for b in self.binders:
                        if b.type != e1.type:
                            continue
                        m = EMakeMap2(
                            e2,
                            mk_lambda(
                                e2.type.t, lambda x: EUnaryOp(
                                    UOp.Length,
                                    EFilter(
                                        e2,
                                        mk_lambda(e2.type.t, lambda y: EEq(
                                            x, y))).with_type(e2.type)).
                                with_type(INT))).with_type(TMap(
                                    e2.type.t, INT))
                        # filt = EFilter(e2, ELambda(b, EEq(e1, b))).with_type(e2.type)
                        # x = EEq(
                        #     EUnaryOp(UOp.Length, filt).with_type(INT),
                        #     ONE)
                        x = EGt(
                            EMapGet(EStateVar(m).with_type(m.type),
                                    e1).with_type(INT), ONE)
                        # x._tag = True
                        yield self.check(x, RUNTIME_POOL)

        # histogram
        # for e in cache.find_collections(pool=STATE_POOL, size=size-1):
        #     m = EMakeMap2(e,
        #         mk_lambda(e.type.t, lambda x:
        #             EUnaryOp(UOp.Length, EFilter(e,
        #                 mk_lambda(e.type.t, lambda y: EEq(x, y))).with_type(e.type)).with_type(INT))).with_type(TMap(e.type.t, INT))
        #     m._tag = True
        #     yield self.check(m, STATE_POOL)

        # Fixup EFilter(\x -> ECond...)
        # For every conditional fragment `c ? a : b` in a filter predicate,
        # yield the filter on `c` together with the then-branch, the filter on
        # `not c` together with the else-branch, and their concatenation (+).
        # `r` presumably rebuilds the predicate body with the fragment
        # replaced (it comes from enumerate_fragments).
        for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(e):
                continue
            if isinstance(e, EFilter):
                for (_, x, r, _) in enumerate_fragments(e.p.body):
                    if isinstance(x, ECond):
                        lhs = EFilter(
                            e.e,
                            ELambda(e.p.arg, EAll([x.cond,
                                                   r(x.then_branch)
                                                   ]))).with_type(e.type)
                        rhs = EFilter(
                            e.e,
                            ELambda(e.p.arg,
                                    EAll([ENot(x.cond),
                                          r(x.else_branch)
                                          ]))).with_type(e.type)
                        union = EBinOp(lhs, "+", rhs).with_type(e.type)
                        # yield self.check(lhs.p.body, RUNTIME_POOL)
                        # yield self.check(rhs.p.body, RUNTIME_POOL)
                        yield self.check(lhs, RUNTIME_POOL)
                        yield self.check(rhs, RUNTIME_POOL)
                        yield self.check(union, RUNTIME_POOL)

        # Try instantiating bound expressions
        # For each free variable v of e1, substitute expressions e2 of v's type
        # from the same pool.  In the runtime pool, e1 is first stripped of
        # EStateVar, and every state var other than v is re-wrapped in
        # EStateVar.
        # NOTE(review): e1 is rebound inside the `for v` loop, so later
        # iterations start from the already-rewritten e1 (the free vars being
        # iterated still come from the original e1).  This is harmless only if
        # strip_EStateVar fully undoes the previous wrapping; confirm.
        for pool in (STATE_POOL, RUNTIME_POOL):
            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for e1 in cache.find(pool=pool, size=sz1):
                    if not is_root(e1):
                        continue
                    for v in free_vars(e1):
                        if pool == RUNTIME_POOL:
                            e1 = subst(
                                strip_EStateVar(e1), {
                                    sv.id: EStateVar(sv).with_type(sv.type)
                                    for sv in self.state_vars if sv != v
                                })
                        for e2 in cache.find(pool=pool, type=v.type, size=sz2):
                            yield self.check(subst(e1, {v.id: e2}), pool)

        # Acceleration: candidates from map_accelerate for every root
        # expression, plus accelerate_filter for filters whose free variables
        # include none of the binders.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for e in cache.find(pool=RUNTIME_POOL, size=sz1):
                if not is_root(e):
                    continue
                for x, pool in map_accelerate(e, self.state_vars, self.binders,
                                              self.args, cache, sz2):
                    yield self.check(x, pool)
                if isinstance(e, EFilter) and not any(v in self.binders
                                                      for v in free_vars(e)):
                    for x, pool in accelerate_filter(e.e, e.p, self.state_vars,
                                                     self.binders, self.args,
                                                     cache, sz2):
                        yield self.check(x, pool)

        # Runtime bags:
        # (1) filter the whole bag on whether an argument is in a state
        #     collection (and on its negation);
        # (2) for filters whose source mentions only state variables, split
        #     conjunctive predicates and build map-lookup candidates.
        for bag in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(bag):
                continue
            for a in self.args:
                for v in self.state_vars:
                    # NOTE(review): this requires v.type == a.type but then
                    # builds `a in v`.  Membership normally needs a's type to
                    # be v's element type (v.type.t).  Check whether self.check
                    # discards the resulting candidates, or whether the
                    # condition should read `v.type.t == a.type`.
                    if is_collection(v.type) and v.type == a.type:
                        v = EStateVar(v).with_type(v.type)
                        cond = EBinOp(a, BOp.In, v).with_type(BOOL)
                        yield self.check(
                            EFilter(bag, mk_lambda(bag.type.t,
                                                   lambda _: cond)).with_type(
                                                       bag.type), RUNTIME_POOL)
                        yield self.check(
                            EFilter(
                                bag,
                                mk_lambda(bag.type.t,
                                          lambda _: ENot(cond))).with_type(
                                              bag.type), RUNTIME_POOL)

            if isinstance(bag, EFilter):
                # Only filters whose source collection uses nothing but state
                # variables are considered below.
                if any(v not in self.state_vars for v in free_vars(bag.e)):
                    continue

                # separate filter conds
                # filter(xs, p1 and p2) ==> filter(xs, p1), filter(xs, p2),
                # and both nested orders.
                if isinstance(bag.p.body, EBinOp) and bag.p.body.op == BOp.And:
                    p1 = ELambda(bag.p.arg, bag.p.body.e1)
                    p2 = ELambda(bag.p.arg, bag.p.body.e2)
                    f1 = EFilter(bag.e, p1).with_type(bag.type)
                    f2 = EFilter(bag.e, p2).with_type(bag.type)
                    f3 = EFilter(f1, p2).with_type(bag.type)
                    f4 = EFilter(f2, p1).with_type(bag.type)
                    yield self.check(f1, RUNTIME_POOL)
                    yield self.check(f2, RUNTIME_POOL)
                    yield self.check(f3, RUNTIME_POOL)
                    yield self.check(f4, RUNTIME_POOL)

                # construct map lookups
                # infer_map_lookup presumably factors the predicate into a key
                # projection of the element, a key to look up, and a residual
                # filter (unpacked as such below).  A state map from key value
                # to the matching sub-bag is offered.  Its runtime lookup is
                # then yielded both with and without the residual filter.
                binder = bag.p.arg
                inf = infer_map_lookup(bag.p.body, binder,
                                       set(self.state_vars))
                if inf:
                    key_proj, key_lookup, remaining_filter = inf
                    bag_binder = find_one(
                        self.binders,
                        lambda b: b.type == key_proj.type and b != binder)
                    if bag_binder:
                        m = strip_EStateVar(
                            EMakeMap2(
                                EMap(bag.e,
                                     ELambda(binder, key_proj)).with_type(
                                         type(bag.type)(key_proj.type)),
                                ELambda(
                                    bag_binder,
                                    EFilter(
                                        bag.e,
                                        ELambda(binder,
                                                EEq(key_proj,
                                                    bag_binder))).with_type(
                                                        bag.type))).with_type(
                                                            TMap(
                                                                key_proj.type,
                                                                bag.type)))
                        assert not any(v in self.args for v in free_vars(m))
                        yield self.check(m, STATE_POOL)
                        m = EStateVar(m).with_type(m.type)
                        mg = EMapGet(m, key_lookup).with_type(bag.type)
                        yield self.check(mg, RUNTIME_POOL)
                        yield self.check(
                            EFilter(mg, ELambda(
                                binder, remaining_filter)).with_type(mg.type),
                            RUNTIME_POOL)

        # for e in cache.find(size=size-1):
        #     # F(xs +/- ys) ---> F(xs), F(ys)
        #     for z in break_plus_minus(e):
        #         if z != e:
        #             # print("broke {} --> {}".format(pprint(e), pprint(z)))
        #             yield z

        #     # try reordering operations
        #     for (_, e1, f) in enumerate_fragments(e):
        #         if e1.type == e.type and e1 != e:
        #             for (_, e2, g) in enumerate_fragments(e1):
        #                 if e2.type == e.type and e2 != e1:
        #                     # e == f(g(e2))
        #                     yield g(f(e2))

        # Finally, everything the wrapped builder produces.
        yield from self.wrapped.build(cache, size)
Ejemplo n.º 10
0
Archivo: impls.py Proyecto: uwplse/cozy
    def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
        """Add a query that helps maintain some other state.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        The work is done on a shallow copy of `sub_q`:
          - implicit handle assumptions are appended when the current
            assumptions do not already imply them;
          - the return expression is passed through `repair_well_formedness`,
            with the current concretization functions' expressions as extra
            available state;
          - the assumptions and the return expression are simplified.

        If a query already exists that is equivalent to `sub_q`, this method
        returns `used_by` rewritten to use the existing query, with arguments
        reordered to its parameter order, and does not add the query.  The
        equivalence check runs only when `dedup_queries` is enabled and uses
        the abstract invariants as assumptions.  Otherwise it adds `sub_q`
        and returns `used_by` unchanged.

        Raises:
            AssertionError - if simplification increased the total size of
            the assumptions.
        """

        with task("adding query", query=sub_q.name):
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                # Only extend the assumptions when they do not already imply
                # the handle assumptions.
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

            with task("repairing state var boundaries"):
                extra_available_state = [e for _, e in self._concretization_functions]
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret),
                    self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                # Simplification must never grow the assumptions.  Dump both
                # versions for debugging, then fail.
                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    # An explicit raise (instead of `assert False`) keeps this
                    # check active when Python runs with -O.
                    raise AssertionError(
                        "simplification grew assumptions: {} -> {}".format(orig_a_size, a_size))

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs, assumptions=EAll(self.abstract_invariants)))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
                # arg_reorder[i] is the position, in sub_q's argument list, of
                # qq's i-th argument name.
                sub_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_arg_names.index(a) for (a, t) in qq.args]
                class Repl(BottomUpRewriter):
                    # Redirects calls to sub_q to qq (reordering arguments).
                    # Every other call is rebuilt with rewritten arguments.
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)
                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by
Ejemplo n.º 11
0
    def _add_subquery(self, sub_q: Query, used_by: Stm) -> Stm:
        """Add a query that helps maintain some other state.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        Before anything is added, a shallow copy of `sub_q` is prepared:
          * its assumptions are extended with the implicit handle
            assumptions for the handles reachable in the method, unless the
            existing assumptions already imply them;
          * its return expression is rebuilt with `repair_well_formedness`
            (after stripping `EStateVar` wrappers), with the expressions of
            `self._concretization_functions` made available as extra state;
          * its assumptions and return expression are simplified.

        If query deduplication is enabled and a query in `self.query_specs`
        is equivalent to the prepared query, this method returns `used_by`
        rewritten so every call to `sub_q` becomes a call to that existing
        query (arguments reordered to match its parameter order), and adds
        nothing.  Otherwise the prepared query is registered with
        `self.add_query` and `used_by` is returned unchanged.

        Raises:
            AssertionError - if simplification made the assumptions larger
                than they were before simplifying.
        """

        with task("adding query", query=sub_q.name):
            # Work on a shallow copy: the attribute reassignments below
            # happen on the copy, not on the caller's object.
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                # Only extend the assumptions when the current ones do not
                # already imply the handle assumptions.
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(
                        itertools.chain(sub_q.assumptions, new_a))

            with task("repairing state var boundaries"):
                extra_available_state = [
                    e for v, e in self._concretization_functions
                ]
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret), self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(
                    simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                if a_size > orig_a_size:
                    # Simplification is expected never to grow the
                    # assumptions; dump both versions for debugging.
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    # Raise explicitly instead of `assert False`: assert
                    # statements are stripped under `python -O`, which would
                    # let the bad simplification through silently.
                    raise AssertionError(
                        "simplification grew the assumptions of {}: {} -> {}".format(
                            sub_q.name, orig_a_size, a_size))

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            qq = None
            # Only search for an equivalent existing query when
            # deduplication is enabled; otherwise skip the scan entirely.
            if dedup_queries.value:
                qq = find_one(
                    self.query_specs, lambda candidate: queries_equivalent(
                        candidate,
                        sub_q,
                        state_vars=state_vars,
                        extern_funcs=funcs,
                        assumptions=EAll(self.abstract_invariants)))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(
                    sub_q.name, qq.name))
                # For each parameter of the existing query, the position of
                # the same-named parameter in `sub_q`'s argument list.
                sub_q_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_q_arg_names.index(a)
                               for (a, t) in qq.args]

                class Repl(BottomUpRewriter):
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            # Redirect the call to the existing query,
                            # permuting arguments into its parameter order.
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)

                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by
Ejemplo n.º 12
0
    def next(self):
        """Return the next candidate improvement found by the search.

        Improvements derived from one expression are handed out one per
        call: after the first is returned, the expression is remembered in
        `self.backlog` and `self.backlog_counter` tracks how many of its
        improvements have already been returned.

        When there is no backlog, expressions from `self.builder_iter` are
        examined.  An expression is added to the cache unless it is too
        expensive compared to the target, alpha-equivalent to a previously
        seen expression with the same fingerprint, fails the strong-progress
        check, or is worse than such a previously seen expression.  The
        first improvement an added expression yields is returned.  When the
        builder for the current size is exhausted, the search moves on to
        the next size.

        Raises:
            StopException - if `self.stop_callback()` returns true.
            NoMoreImprovements - if the last size at which an expression was
                added is below half of (current size + 1).
        """
        target_cost = self.cost_model.cost(self.target, RUNTIME_POOL)
        self.ncount += 1
        while True:
            # First, hand out any remaining improvements for the expression
            # that produced the previous result.
            if self.backlog is not None:
                if self.stop_callback():
                    raise StopException()
                (e, pool, cost) = self.backlog
                # The improvement list is recomputed on every call; this
                # presumably relies on `_possible_replacements` yielding the
                # same sequence for the same arguments.
                improvements = list(self._possible_replacements(e, pool, cost))
                if self.backlog_counter < len(improvements):
                    i = improvements[self.backlog_counter]
                    self.backlog_counter += 1
                    return i
                else:
                    self.backlog = None
                    self.backlog_counter = 0
            for (e, pool) in self.builder_iter:
                self._on_exp(e, pool)
                if self.stop_callback():
                    raise StopException()

                # # Stopgap measure... long story --Calvin
                # bad = False
                # for x in all_exps(e):
                #     if isinstance(x, EStateVar):
                #         if any(v not in self.state_vars for v in free_vars(x.e)):
                #             bad = True
                #             _on_exp(e, "skipping due to illegal free vars under EStateVar")
                # if bad:
                #     continue

                new_e = self.pre_optimize(e, pool) if preopt.value else e
                if new_e is not e:
                    _on_exp(e, "preoptimized", new_e)
                    e = new_e

                cost = self.cost_model.cost(e, pool)

                # Cull runtime expressions already worse than the target;
                # gated on a monotonic cost model or hyperaggressive culling.
                if pool == RUNTIME_POOL and (self.cost_model.is_monotonic() or hyperaggressive_culling.value) and self.compare_costs(cost, target_cost) == Cost.WORSE:
                    _on_exp(e, "too expensive", cost, target_cost)
                    continue

                # Compare against previously seen expressions in the same
                # pool that share this expression's fingerprint.
                fp = self._fingerprint(e)
                prev = list(self.seen.find_all(pool, fp))
                should_add = True
                if not prev:
                    _on_exp(e, "new", pool_name(pool))
                elif any(alpha_equivalent(e, ee) for (ee, _, _) in prev):
                    _on_exp(e, "duplicate")
                    should_add = False
                else:
                    better_than = None
                    worse_than = None
                    for prev_exp, prev_size, prev_cost in prev:
                        self._on_cost_cmp()
                        ordering = self.compare_costs(cost, prev_cost)
                        assert ordering in (Cost.WORSE, Cost.BETTER, Cost.UNORDERED)
                        if enforce_strong_progress.value and ordering != Cost.WORSE:
                            # Reject `e` if it contains the very expression
                            # it would be replacing.
                            bad = find_one(all_exps(e), lambda ee: alpha_equivalent(ee, prev_exp))
                            if bad:
                                _on_exp(e, "failed strong progress requirement", bad)
                                should_add = False
                                break
                        _on_exp(e, ordering, pool_name(pool), prev_exp)
                        if ordering == Cost.UNORDERED:
                            continue
                        elif ordering == Cost.BETTER:
                            # `e` beats `prev_exp`: evict the old expression
                            # (and optionally everything cached containing it).
                            better_than = (prev_exp, prev_size, prev_cost)
                            _on_exp(prev_exp, "found better alternative", e)
                            self.cache.evict(prev_exp, size=prev_size, pool=pool)
                            self.seen.remove(prev_exp, pool, fp)
                            if (self.cost_model.is_monotonic() or hyperaggressive_culling.value) and hyperaggressive_eviction.value:
                                for (cached_e, size, p) in list(self.cache):
                                    if p != pool:
                                        continue
                                    if prev_exp in all_exps(cached_e):
                                        _on_exp(cached_e, "evicted since it contains", prev_exp)
                                        self.cache.evict(cached_e, size=size, pool=pool)
                        else:
                            # ordering is Cost.WORSE (see the assert above).
                            should_add = False
                            worse_than = (prev_exp, prev_size, prev_cost)
                            # break
                    if worse_than and better_than:
                        # Diagnostic: `e` is both worse than one previous
                        # expression and better than another.
                        print("Uh-oh! Strange cost relationship between")
                        print("  (1) this exp: {}".format(pprint(e)))
                        print("  (2) prev. A:  {}".format(pprint(worse_than[0])))
                        print("  (3) prev. B:  {}".format(pprint(better_than[0])))
                        print("e1 = {}".format(repr(e)))
                        print("e2 = {}".format(repr(worse_than[0])))
                        print("e3 = {}".format(repr(better_than[0])))
                        print("(1) vs (2): {}".format(cost.compare_to(worse_than[2], self.assumptions)))
                        print("(2) vs (3): {}".format(worse_than[2].compare_to(better_than[2], self.assumptions)))
                        print("(3) vs (1): {}".format(better_than[2].compare_to(cost, self.assumptions)))
                        # raise Exception("insane cost model behavior")

                if should_add:
                    self.cache.add(e, pool=pool, size=self.current_size)
                    self.seen.add(e, pool, fp, self.current_size, cost)
                    self.last_progress = self.current_size
                else:
                    continue

                # Return the first improvement now; the rest are served from
                # the backlog on later calls (counter starts at 1).
                for pr in self._possible_replacements(e, pool, cost):
                    self.backlog = (e, pool, cost)
                    self.backlog_counter = 1
                    return pr

            # The builder for this size is exhausted.
            if self.last_progress < (self.current_size+1) // 2:
                raise NoMoreImprovements("hit termination condition")

            self.current_size += 1
            self.builder_iter = self.builder.build(self.cache, self.current_size)
            if self.current_size == 0:
                self.builder_iter = itertools.chain(self.builder_iter, list(self.roots))
            for f, ct in sorted(_fates.items(), key=lambda x: x[1], reverse=True):
                print("  {:6} | {}".format(ct, f))
            _fates.clear()
            self._start_minor_it()
Ejemplo n.º 13
0
def find_one_or_fail(iter):
    """Return `find_one(iter)`, raising instead of returning None.

    Parameters:
        iter - passed unchanged to `find_one` (the name shadows the builtin
            `iter` but is kept so keyword callers keep working)

    Raises:
        ValueError - if `find_one(iter)` returns None.
    """
    res = find_one(iter)
    if res is None:
        raise ValueError("find_one_or_fail: find_one returned None")
    return res
Ejemplo n.º 14
0
 def contains(self, e, pool):
     """Report whether `pool` has an expression alpha-equivalent to `e`.

     Only the entries returned by `self.find` for `e`'s type are compared.
     """
     same_type = self.find(pool=pool, type=e.type)
     match = find_one(same_type, lambda candidate: alpha_equivalent(candidate, e))
     return match is not None