예제 #1
0
    def set_impl(self, q : Query, rep : [(EVar, Exp)], ret : Exp):
        """Install a new implementation for the query named `q.name`.

        `rep` is a list of (state variable, defining expression) pairs and
        `ret` is the expression that computes the query's result from them.

        Any pair whose expression the state solver proves equal (under the
        spec's assumptions) to one already in `self.concrete_state` is dropped,
        and `ret` is rewritten to refer to the existing variable instead.  The
        surviving pairs are appended to `self.concrete_state`, and for each of
        them an update statement is derived per operation and stored in
        `self.updates`.  Subgoals produced while deriving those statements are
        registered through `_add_subquery`.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                duplicates = set()
                for (var, proj) in rep:
                    # Types are compared first so the solver is only consulted
                    # for candidates that could possibly be equal.
                    existing = find_one(
                        known_var
                        for (known_var, known_proj) in self.concrete_state
                        if proj.type == known_proj.type
                        and self.state_solver.valid(EImplies(
                            EAll(self.spec.assumptions),
                            EEq(proj, known_proj))))
                    if existing is None:
                        continue
                    event("state var {} is equivalent to {}".format(var.id, existing.id))
                    ret = subst(ret, { var.id : existing })
                    duplicates.add(var)
                rep = [ pair for pair in rep if pair[0] not in duplicates ]

            self.concrete_state.extend(rep)

            def new_ret(prev):
                # The previous return expression is discarded entirely.
                return ret
            self.query_impls[q.name] = rewrite_ret(q, new_ret, keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for new_member, projection in rep:
                        # mutate_in_place reports the subgoals it needs
                        # through this output list.
                        subgoals = []
                        update = inc.mutate_in_place(
                            new_member,
                            projection,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            subgoals_out=subgoals)
                        for subgoal in subgoals:
                            subgoal.docstring = "[{}] {}".format(op.name, subgoal.docstring)
                            update = self._add_subquery(sub_q=subgoal, used_by=update)
                        self.updates[(new_member, op.name)] = update
예제 #2
0
    def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
        """Register the subquery `sub_q` and return `used_by` adjusted for it.

        Works on a shallow copy of `sub_q`: handle assumptions that are not
        already implied by its assumptions are appended, then the assumptions
        and the return expression are simplified.  If a query in
        `self.query_specs` is found equivalent to the copy (a match requires
        `dedup_queries.value` to be truthy), calls to `sub_q` inside `used_by`
        are redirected to that query with their arguments reordered to match
        its parameter order; otherwise the copy is added as a new query and
        `used_by` is returned unchanged.
        """
        with task("adding query", query=sub_q.name):
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                handles = reachable_handles_at_method(self.spec, sub_q)
                extra_assumptions = implicit_handle_assumptions_for_method(handles, sub_q)
                already_implied = valid(EImplies(EAll(sub_q.assumptions), EAll(extra_assumptions)))
                if not already_implied:
                    event("we do!")
                    sub_q.assumptions = [*sub_q.assumptions, *extra_assumptions]

            with task("simplifying"):
                def total_size(exps):
                    return sum(x.size() for x in exps)

                def dump(title, exps):
                    print(title)
                    for x in exps:
                        print(" - {}".format(pprint(x)))

                assumptions_before = sub_q.assumptions
                assumptions_size_before = total_size(assumptions_before)
                ret_size_before = sub_q.ret.size()
                sub_q.assumptions = tuple(simplify_or_ignore(a) for a in assumptions_before)
                sub_q.ret = simplify(sub_q.ret)
                assumptions_size_after = total_size(sub_q.assumptions)
                ret_size_after = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(assumptions_size_before, assumptions_size_after))
                event("|ret|: {} -> {}".format(ret_size_before, ret_size_after))

                # Simplification must never make the assumptions larger; dump
                # both versions and fail if it did.
                if assumptions_size_after > assumptions_size_before:
                    print("NO, BAD SIMPLIFICATION")
                    dump("original", assumptions_before)
                    dump("simplified", sub_q.assumptions)
                    assert False

            state_vars = self.abstract_state
            funcs = self.extern_funcs

            def same_query(candidate):
                return dedup_queries.value and queries_equivalent(
                    candidate, sub_q, state_vars=state_vars, extern_funcs=funcs)

            existing = find_one(self.query_specs, same_query)
            if existing is None:
                self.add_query(sub_q)
                return used_by

            event("subgoal {} is equivalent to {}".format(sub_q.name, existing.name))
            # arg_reorder[i] is the position, in sub_q's argument list, of the
            # i-th argument of the existing query.
            sub_arg_names = [arg[0] for arg in sub_q.args]
            arg_reorder = [sub_arg_names.index(name) for (name, _) in existing.args]

            class Repl(BottomUpRewriter):
                def visit_ECall(self, e):
                    callee = e.func
                    actuals = tuple(self.visit(a) for a in e.args)
                    if e.func == sub_q.name:
                        callee = existing.name
                        actuals = tuple(actuals[idx] for idx in arg_reorder)
                    return ECall(callee, actuals).with_type(e.type)

            return Repl().visit(used_by)
예제 #3
0
    def set_impl(self, q: Query, rep: [(EVar, Exp)], ret: Exp):
        """Replace the implementation of the query named `q.name`.

        Arguments:
            q   - the query whose implementation is being replaced
            rep - (state variable, expression) pairs forming the new concrete
                  representation
            ret - expression computing the query result from that
                  representation

        Members of `rep` that the state solver proves equal to an existing
        concretization function are not added again; `ret` is rewritten to use
        the existing variable.  Maintaining the remaining members across each
        update operation may introduce additional "subqueries", which are
        added to the implementation as a side effect.
        """

        def existing_equivalent(exp):
            # First stored state var with the same type whose definition the
            # solver proves equal to `exp`.
            return find_one(
                known_var
                for (known_var, known_exp) in self._concretization_functions
                if exp.type == known_exp.type
                and self.state_solver.valid(EEq(exp, known_exp)))

        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                dropped = set()
                for (var, definition) in rep:
                    twin = existing_equivalent(definition)
                    if twin is None:
                        continue
                    event("state var {} is equivalent to {}".format(
                        var.id, twin.id))
                    ret = subst(ret, {var.id: twin})
                    dropped.add(var)
                rep = list(filter(lambda pair: pair[0] not in dropped, rep))

            self._concretization_functions.extend(rep)

            def new_ret(prev):
                # The previous return expression is ignored.
                return ret

            self.query_impls[q.name] = rewrite_ret(q,
                                                   new_ret,
                                                   keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for new_member, projection in rep:
                        # Filled in by mutate_in_place with the subgoals the
                        # update statement relies on.
                        pending = []
                        stm = inc.mutate_in_place(
                            new_member,
                            projection,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            invariants=self.abstract_invariants,
                            subgoals_out=pending)
                        for subgoal in pending:
                            subgoal.docstring = "[{}] {}".format(
                                op.name, subgoal.docstring)
                            stm = self._add_subquery(sub_q=subgoal,
                                                     used_by=stm)
                        self.updates[(new_member, op.name)] = stm
예제 #4
0
파일: misc.py 프로젝트: uwplse/cozy
def queries_equivalent(q1 : Query, q2 : Query, state_vars : [EVar], extern_funcs : { str : TFunc }, assumptions : Exp = ETRUE):
    """Check whether two queries are interchangeable.

    Two queries are considered equivalent when they have the same return type
    and the same named, typed arguments, their preconditions are provably
    equal, and, whenever the precondition of `q1` holds, their return
    expressions are provably equal.

    The preconditions are compared as well as the bodies because a query only
    promises "if my preconditions hold then I return my body expression": with
    differing preconditions, one query could be obligated to produce a value
    in a situation where the other has no defined behavior.
    """

    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False
        # Arguments are compared as name -> type mappings, so their order
        # does not matter.
        if dict(q1.args) != dict(q2.args):
            return False

        arg_vars = [EVar(name).with_type(typ) for (name, typ) in q1.args]
        root = RootCtx(
            state_vars=state_vars,
            args=arg_vars,
            funcs=extern_funcs)
        checker = solver_for_context(context=root, assumptions=assumptions)

        pre1 = EAll(q1.assumptions)
        pre2 = EAll(q2.assumptions)
        same_preconditions = checker.valid(EEq(pre1, pre2))
        if not same_preconditions:
            return same_preconditions
        return checker.valid(EImplies(pre1, EEq(q1.ret, q2.ret)))
예제 #5
0
파일: misc.py 프로젝트: wcphkust/cozy
def queries_equivalent(q1: Query,
                       q2: Query,
                       state_vars: [EVar],
                       extern_funcs: {str: TFunc},
                       assumptions: Exp = ETRUE):
    """Decide whether `q1` and `q2` always produce the same result.

    Besides the return expressions, the preconditions of the two queries must
    be semantically equivalent.  A query means "if my preconditions hold then
    I compute and return my body expression", so if the preconditions differ
    there may be inputs on which one query must return a value while the
    other's behavior is undefined.

    Queries whose return types or whose argument name/type mappings differ are
    reported as not equivalent without consulting the solver.
    """

    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type or dict(q1.args) != dict(q2.args):
            return False

        arg_vars = [EVar(name).with_type(typ) for (name, typ) in q1.args]
        checker = solver_for_context(
            context=RootCtx(state_vars=state_vars,
                            args=arg_vars,
                            funcs=extern_funcs),
            assumptions=assumptions)

        precondition1 = EAll(q1.assumptions)
        precondition2 = EAll(q2.assumptions)
        preconditions_agree = checker.valid(
            EEq(precondition1, precondition2))
        # The return expressions only need to agree where q1's precondition
        # holds.
        return preconditions_agree and checker.valid(
            EImplies(precondition1, EEq(q1.ret, q2.ret)))
예제 #6
0
def queries_equivalent(q1: Query,
                       q2: Query,
                       state_vars: [EVar],
                       extern_funcs: {str: TFunc},
                       assumptions: Exp = T):
    """Decide whether `q1` and `q2` are equivalent queries.

    The queries must have the same return type and the same argument
    name/type mapping.  Beyond that, their preconditions must be provably
    equal and their return expressions provably equal whenever `q1`'s
    precondition holds.

    Solvers are memoized in `_qe_cache`, keyed by the argument mapping and
    `assumptions`.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False
        q1args = dict(q1.args)
        if q1args != dict(q2.args):
            return False
        args = FrozenDict(q1args)

        # NOTE(review): the key does not include state_vars or extern_funcs,
        # so a cached solver is reused even if a later call passes different
        # ones -- confirm callers always pass the same values.
        key = (args, assumptions)
        checker = _qe_cache.get(key)
        if checker is None:
            arg_vars = (EVar(v).with_type(t) for v, t in args.items())
            checker = ModelCachingSolver(
                vars=list(itertools.chain(state_vars, arg_vars)),
                funcs=extern_funcs,
                assumptions=assumptions)
            _qe_cache[key] = checker

        pre1 = EAll(q1.assumptions)
        pre2 = EAll(q2.assumptions)
        same_preconditions = checker.valid(EEq(pre1, pre2))
        if not same_preconditions:
            return same_preconditions
        return checker.valid(EImplies(pre1, EEq(q1.ret, q2.ret)))
예제 #7
0
 def check_wf(e, ctx, pool):
     """Decide whether candidate expression `e` should be kept.

     Returns:
       - `No(...)` naming the offending subexpression and reason when
         `exp_wf` raises `ExpIsNotWf`;
       - the first falsy result of `good_idea` over the pieces produced by
         `shred(e, ctx, pool)`;
       - `No("too expensive")` when `e` is in the runtime pool and the cost
         model orders it after the first target;
       - True otherwise.
     """
     with task("checking well-formedness", size=e.size()):
         try:
             exp_wf(e,
                    pool=pool,
                    context=ctx,
                    assumptions=self.assumptions,
                    solver=self.wf_solver)
         except ExpIsNotWf as failure:
             where = pprint(failure.offending_subexpression)
             return No("at {}: {}".format(where, failure.reason))

         # Evaluated lazily so that checking stops at the first rejection.
         verdicts = (good_idea(self.wf_solver,
                               sub,
                               sub_ctx,
                               sub_pool,
                               assumptions=self.assumptions)
                     for (sub, sub_ctx, sub_pool) in shred(e, ctx, pool))
         for verdict in verdicts:
             if not verdict:
                 return verdict

         too_expensive = (
             pool == RUNTIME_POOL
             and self.cost_model.compare(
                 e, self.targets[0], ctx, pool) == Order.GT)
         if too_expensive:
             return No("too expensive")
         return True
예제 #8
0
파일: cost_model.py 프로젝트: wcphkust/cozy
 def compare(self, e1: Exp, e2: Exp, context: Context, pool: Pool) -> Order:
     """Order expressions `e1` and `e2` by cost.

     The criteria depend on `cost_model_selection.value`:
       0 - expression size only
       1 - runtime pool: `polynomial_runtime`, then size; otherwise size only
       2 - runtime pool: `asymptotic_runtime`, then `max_storage_size`, then
           `rt`, then size; otherwise `storage_size`, then size
       3 - like 2, except that in the runtime pool the
           `max_storage_size`/`rt` pair is combined, unprioritized, with one
           `maintenance_cost` criterion per operation in `self.ops`

     Raises ValueError for any other selection value.
     """
     with task("compare costs", context=context):

         # Each criterion is a zero-argument callable so that the ordering
         # combinators decide when (and whether) it is evaluated.
         def by_size():
             return order_objects(e1.size(), e2.size())

         def by_polynomial_runtime():
             return order_objects(polynomial_runtime(e1),
                                  polynomial_runtime(e2))

         def by_asymptotic_runtime():
             return order_objects(asymptotic_runtime(e1),
                                  asymptotic_runtime(e2))

         def by_max_storage():
             return self._compare(max_storage_size(e1, self.freebies),
                                  max_storage_size(e2, self.freebies),
                                  context)

         def by_runtime():
             return self._compare(rt(e1), rt(e2), context)

         def by_storage():
             return self._compare(storage_size(e1, self.freebies),
                                  storage_size(e2, self.freebies),
                                  context)

         def by_storage_runtime_and_maintenance():
             # `op=op` binds each operation at lambda creation time.
             per_op = [
                 lambda op=op: self._compare(
                     maintenance_cost(e1, op, self.freebies),
                     maintenance_cost(e2, op, self.freebies),
                     context) for op in self.ops
             ]
             return unprioritized_order(
                 lambda: prioritized_order(by_max_storage, by_runtime),
                 *per_op)

         selection = cost_model_selection.value
         if selection == 0:
             return by_size()
         if selection == 1:
             if pool == RUNTIME_POOL:
                 return prioritized_order(by_polynomial_runtime, by_size)
             return by_size()
         if selection == 2:
             if pool == RUNTIME_POOL:
                 return prioritized_order(by_asymptotic_runtime,
                                          by_max_storage,
                                          by_runtime,
                                          by_size)
             return prioritized_order(by_storage, by_size)
         if selection == 3:
             if pool == RUNTIME_POOL:
                 return prioritized_order(
                     by_asymptotic_runtime,
                     by_storage_runtime_and_maintenance,
                     by_size)
             return prioritized_order(by_storage, by_size)
         raise ValueError("illegal value for --{}: {}".format(
             cost_model_selection.name, selection))
예제 #9
0
파일: core.py 프로젝트: uwplse/cozy
 def check_wf(e, ctx, pool):
     """Decide whether candidate expression `e` should be kept.

     Returns the falsy result of `exp_wf` or `possibly_useful` when either
     rejects `e`, `No("too expensive")` when cost pruning is enabled and `e`
     (in the runtime pool) is ordered after the first target by the cost
     model, and True otherwise.
     """
     with task("pruning", size=e.size()):
         well_formed = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
         if well_formed:
             useful = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
             if useful:
                 over_budget = (
                     cost_pruning.value
                     and pool == RUNTIME_POOL
                     and cost_model.compare(e, targets[0], ctx, pool) == Order.GT)
                 if over_budget:
                     return No("too expensive")
                 return True
             # Hand back the rejection object itself, not a plain False.
             return useful
         return well_formed
예제 #10
0
def repair_well_formedness(e: Exp,
                           context: Context,
                           extra_available_state: [Exp] = None) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e                     - the expression to repair
        context               - the intended context for e
        extra_available_state - extra state expressions that e can use
                                (optional; omitted or None means none)

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression that
    behaves like `e`.

    The repaired expression is re-checked with `exp_wf`; an AssertionError
    carrying the check's result is raised if it is rejected (when assertions
    are enabled).
    """

    # None stands in for "no extra state" so that no mutable list object is
    # shared between calls as a default value.
    if extra_available_state is None:
        extra_available_state = ()

    with task("repairing"):

        e = strip_EStateVar(e)

        # state expressions in decreasing order of size, so that larger
        # expressions are wrapped before the smaller ones they may contain
        state_vars = (v for v, p in context.vars() if p == STATE_POOL)
        available_state = sorted(
            unique(itertools.chain(state_vars, extra_available_state)),
            key=lambda s: -s.size())

        with task("making replacements", size=e.size()):
            for s in available_state:
                # The filter skips nodes that are already EStateVar.
                e = replace(e,
                            s,
                            EStateVar(s).with_type(s.type),
                            match=alpha_equivalent,
                            filter=lambda e: not isinstance(e, EStateVar))

        with task("freshening binders"):
            e = freshen_binders(e, context)

        with task("checking correctness"):
            res = exp_wf(e, context, RUNTIME_POOL)
            assert res, str(res)

        return e
예제 #11
0
 def check_wf(e, ctx, pool):
     """Decide whether candidate expression `e` should be kept.

     Runs `exp_wf` and then `possibly_useful`; the first falsy result is
     returned as-is.  If both accept and cost pruning is enabled, an `e` in
     the runtime pool that the cost model orders after the first target is
     rejected with `No("too expensive")`.  Otherwise returns True.
     """
     with task("pruning", size=e.size()):
         # Deferred so the second check only runs if the first one accepts.
         checks = (
             lambda: exp_wf(e, pool=pool, context=ctx, solver=wf_solver),
             lambda: possibly_useful(wf_solver, e, ctx, pool, ops=ops),
         )
         for check in checks:
             outcome = check()
             if not outcome:
                 return outcome

         too_expensive = (
             cost_pruning.value
             and pool == RUNTIME_POOL
             and cost_model.compare(e, targets[0], ctx, pool) == Order.GT)
         return No("too expensive") if too_expensive else True
예제 #12
0
파일: impls.py 프로젝트: uwplse/cozy
    def set_impl(self, q : Query, rep : [(EVar, Exp)], ret : Exp):
        """Give the query named `q.name` a new implementation.

        `rep` holds the (state variable, expression) pairs of the concrete
        representation and `ret` the computation that produces the query's
        result.  State variables whose expression the state solver proves
        equal to an existing concretization function are reused rather than
        added a second time.

        Keeping the new representation up to date under each update operation
        may require extra "subqueries"; these are added to the implementation
        by this call.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                already_stored = set()
                for (var, definition) in rep:
                    candidates = (
                        stored_var
                        for (stored_var, stored_definition) in self._concretization_functions
                        if definition.type == stored_definition.type
                        and self.state_solver.valid(EEq(definition, stored_definition)))
                    twin = find_one(candidates)
                    if twin is not None:
                        event("state var {} is equivalent to {}".format(var.id, twin.id))
                        # Point the result computation at the stored variable.
                        ret = subst(ret, { var.id : twin })
                        already_stored.add(var)
                rep = [ entry for entry in rep if entry[0] not in already_stored ]

            self._concretization_functions.extend(rep)

            def computed_ret(prev):
                return ret
            self.query_impls[q.name] = rewrite_ret(q, computed_ret, keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for new_member, projection in rep:
                        needed_queries = []
                        code = inc.mutate_in_place(
                            new_member,
                            projection,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            invariants=self.abstract_invariants,
                            subgoals_out=needed_queries)
                        for needed in needed_queries:
                            needed.docstring = "[{}] {}".format(op.name, needed.docstring)
                            code = self._add_subquery(sub_q=needed, used_by=code)
                        self.updates[(new_member, op.name)] = code
예제 #13
0
파일: wf.py 프로젝트: uwplse/cozy
def repair_well_formedness(e : Exp, context : Context, extra_available_state : [Exp] = None) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e                     - the expression to repair
        context               - the intended context for e
        extra_available_state - extra state expressions that e can use
                                (optional; omitted or None means none)

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression that
    behaves like `e`.

    The repaired expression is re-checked with `exp_wf`; an AssertionError
    carrying the check's result is raised if it is rejected (when assertions
    are enabled).
    """

    # None stands in for "no extra state" so that no mutable list object is
    # shared between calls as a default value.
    if extra_available_state is None:
        extra_available_state = ()

    with task("repairing"):

        e = strip_EStateVar(e)

        # state expressions in decreasing order of size, so that larger
        # expressions are wrapped before the smaller ones they may contain
        available_state = sorted(unique(itertools.chain(
            (v for v, p in context.vars() if p == STATE_POOL),
            extra_available_state)), key=lambda s: -s.size())

        with task("making replacements", size=e.size()):
            for s in available_state:
                # The filter skips nodes that are already EStateVar.
                e = replace(e, s, EStateVar(s).with_type(s.type),
                    match=alpha_equivalent,
                    filter=lambda e: not isinstance(e, EStateVar))

        with task("freshening binders"):
            e = freshen_binders(e, context)

        with task("checking correctness"):
            res = exp_wf(e, context, RUNTIME_POOL)
            assert res, str(res)

        return e
예제 #14
0
파일: cost_model.py 프로젝트: timwee/cozy
 def compare(self, e1: Exp, e2: Exp, context: Context, pool: Pool) -> Order:
     """Order expressions `e1` and `e2` by cost.

     In the runtime pool the criteria passed to `composite_order` are, in
     order: `asymptotic_runtime`, `max_storage_size`, `rt`, and expression
     size.  In any other pool they are `storage_size` and expression size.
     """
     with task("compare costs", context=context):

         # Criteria are zero-argument callables handed to composite_order,
         # which decides when each one is evaluated.
         def by_size():
             return order_objects(e1.size(), e2.size())

         def by_asymptotic_runtime():
             return order_objects(asymptotic_runtime(e1),
                                  asymptotic_runtime(e2))

         def by_max_storage():
             return self._compare(max_storage_size(e1, self.freebies),
                                  max_storage_size(e2, self.freebies),
                                  context)

         def by_runtime():
             return self._compare(rt(e1), rt(e2), context)

         def by_storage():
             return self._compare(storage_size(e1, self.freebies),
                                  storage_size(e2, self.freebies),
                                  context)

         if pool == RUNTIME_POOL:
             criteria = (by_asymptotic_runtime, by_max_storage, by_runtime,
                         by_size)
         else:
             criteria = (by_storage, by_size)
         return composite_order(*criteria)
예제 #15
0
 def check_wf(e, ctx, pool):
     """Decide whether candidate expression `e` should be kept.

     Returns the falsy result of `exp_wf` or `good_idea_recursive` when
     either rejects `e`, `No("too expensive")` when `e` is in the runtime
     pool and the cost model orders it after the first target, and True
     otherwise.
     """
     with task("checking well-formedness", size=e.size()):
         verdict = exp_wf(e,
                          pool=pool,
                          context=ctx,
                          solver=self.wf_solver)
         if not verdict:
             return verdict

         verdict = good_idea_recursive(self.wf_solver,
                                       e,
                                       ctx,
                                       pool,
                                       ops=self.ops)
         if not verdict:
             return verdict

         # Cost is only compared for runtime-pool expressions.
         if not (pool == RUNTIME_POOL):
             return True
         order = self.cost_model.compare(e, self.targets[0], ctx, pool)
         return No("too expensive") if order == Order.GT else True
예제 #16
0
 def compare(self, e1: Exp, e2: Exp, context: Context, pool: Pool) -> Order:
     """Order expressions `e1` and `e2` by cost.

     Outside the runtime pool the criteria are `storage_size`, then
     expression size.

     In the runtime pool the first criterion is `asymptotic_runtime` and the
     last is expression size.  In between, when
     `consider_maintenance_cost.value` is truthy, the prioritized
     `max_storage_size`/`rt` pair is combined, unprioritized, with one
     `maintenance_cost` criterion per operation in `self.ops`; otherwise
     `max_storage_size` and then `rt` are used directly.
     """
     with task("compare costs", context=context):

         # Criteria are zero-argument callables; the ordering combinators
         # decide when (and whether) each one is evaluated.
         def by_size():
             return order_objects(e1.size(), e2.size())

         def by_asymptotic_runtime():
             return order_objects(asymptotic_runtime(e1),
                                  asymptotic_runtime(e2))

         def by_max_storage():
             return self._compare(max_storage_size(e1, self.freebies),
                                  max_storage_size(e2, self.freebies),
                                  context)

         def by_runtime():
             return self._compare(rt(e1), rt(e2), context)

         def by_storage():
             return self._compare(storage_size(e1, self.freebies),
                                  storage_size(e2, self.freebies),
                                  context)

         def by_storage_runtime_and_maintenance():
             # `op=op` binds each operation at lambda creation time.
             per_op = [
                 lambda op=op: self._compare(
                     maintenance_cost(e1, op, self.freebies),
                     maintenance_cost(e2, op, self.freebies),
                     context) for op in self.ops
             ]
             return unprioritized_order(
                 lambda: prioritized_order(by_max_storage, by_runtime),
                 *per_op)

         include_maintenance = consider_maintenance_cost.value
         if pool == RUNTIME_POOL:
             if include_maintenance:
                 return prioritized_order(
                     by_asymptotic_runtime,
                     by_storage_runtime_and_maintenance,
                     by_size)
             return prioritized_order(by_asymptotic_runtime,
                                      by_max_storage,
                                      by_runtime,
                                      by_size)
         # The non-runtime ordering is the same with or without maintenance
         # cost.
         return prioritized_order(by_storage, by_size)
예제 #17
0
    def search(self):

        root_ctx = self.context

        def check_wf(e, ctx, pool):
            with task("checking well-formedness", size=e.size()):
                is_wf = exp_wf(e,
                               pool=pool,
                               context=ctx,
                               solver=self.wf_solver)
                if not is_wf:
                    return is_wf
                res = good_idea_recursive(self.wf_solver,
                                          e,
                                          ctx,
                                          pool,
                                          ops=self.ops)
                if not res:
                    return res
                if pool == RUNTIME_POOL and self.cost_model.compare(
                        e, self.targets[0], ctx, pool) == Order.GT:
                    return No("too expensive")
                return True

        frags = list(
            unique(
                itertools.chain(*[shred(t, root_ctx) for t in self.targets],
                                *[shred(h, root_ctx) for h in self.hints])))
        frags.sort(key=hint_order)
        enum = Enumerator(examples=self.examples,
                          cost_model=self.cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=self.stop_callback,
                          do_eviction=enable_eviction.value)

        size = 0
        target_fp = fingerprint(self.targets[0], self.examples)

        watches = OrderedDict()
        for target in self.targets:
            for e, ctx, pool in unique(
                    shred(target, context=root_ctx, pool=RUNTIME_POOL)):
                exs = ctx.instantiate_examples(self.examples)
                fp = fingerprint(e, exs)
                k = (fp, ctx, pool)
                l = watches.get(k)
                if l is None:
                    l = []
                    watches[k] = l
                l.append((target, e))
        watched_ctxs = list(
            unique((ctx, pool) for fp, ctx, pool in watches.keys()))

        def consider_new_target(old_target, e, ctx, pool, replacement):
            nonlocal n
            n += 1
            k = (e, ctx, pool, replacement)
            if enable_blacklist.value and k in self.blacklist:
                event("blacklisted")
                print("skipping blacklisted substitution: {} ---> {} ({})".
                      format(pprint(e), pprint(replacement),
                             self.blacklist[k]))
                return
            new_target = freshen_binders(
                replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool,
                        replacement), root_ctx)
            if any(alpha_equivalent(t, new_target) for t in self.targets):
                event("already seen")
                return
            wf = check_wf(new_target, root_ctx, RUNTIME_POOL)
            if not wf:
                msg = "not well-formed [wf={}]".format(wf)
                event(msg)
                self.blacklist[k] = msg
                return
            if not fingerprints_match(fingerprint(new_target, self.examples),
                                      target_fp):
                msg = "not correct"
                event(msg)
                self.blacklist[k] = msg
                return
            if self.cost_model.compare(new_target, target, root_ctx,
                                       RUNTIME_POOL) not in (Order.LT,
                                                             Order.AMBIGUOUS):
                msg = "not an improvement"
                event(msg)
                self.blacklist[k] = msg
                return
            print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
            print(" * in {}".format(pprint(old_target), pprint(e),
                                    pprint(replacement)))
            print(" * replacing {}".format(pprint(e)))
            print(" * with {}".format(pprint(replacement)))
            yield new_target

        while True:

            print("starting minor iteration {} with |cache|={}".format(
                size, enum.cache_size()))
            if self.stop_callback():
                raise StopException()

            n = 0

            for ctx, pool in watched_ctxs:
                with task("searching for obvious substitutions",
                          ctx=ctx,
                          pool=pool_name(pool)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("searching for obvious substitution",
                                  expression=pprint(info.e)):
                            fp = info.fingerprint
                            for ((fpx, cc, pp), reses) in watches.items():
                                if cc != ctx or pp != pool:
                                    continue

                                if not fingerprints_match(fpx, fp):
                                    continue

                                for target, watched_e in reses:
                                    replacement = info.e
                                    event("possible substitution: {} ---> {}".
                                          format(pprint(watched_e),
                                                 pprint(replacement)))
                                    event("replacement locations: {}".format(
                                        pprint(
                                            replace(target, root_ctx,
                                                    RUNTIME_POOL, watched_e,
                                                    ctx, pool, EVar("___")))))

                                    if alpha_equivalent(
                                            watched_e, replacement):
                                        event("no change")
                                        continue

                                    yield from consider_new_target(
                                        target, watched_e, ctx, pool,
                                        replacement)

            if check_all_substitutions.value:
                print("Guessing at substitutions...")
                for target, e, ctx, pool in exploration_order(
                        self.targets, root_ctx):
                    with task("checking substitutions",
                              target=pprint(
                                  replace(target, root_ctx, RUNTIME_POOL, e,
                                          ctx, pool, EVar("___"))),
                              e=pprint(e)):
                        for info in enum.enumerate_with_info(size=size,
                                                             context=ctx,
                                                             pool=pool):
                            with task("checking substitution",
                                      expression=pprint(info.e)):
                                if self.stop_callback():
                                    raise StopException()
                                replacement = info.e
                                if replacement.type != e.type:
                                    event("wrong type (is {}, need {})".format(
                                        pprint(replacement.type),
                                        pprint(e.type)))
                                    continue
                                if alpha_equivalent(replacement, e):
                                    event("no change")
                                    continue
                                should_consider = should_consider_replacement(
                                    target, root_ctx, e, ctx, pool,
                                    fingerprint(
                                        e,
                                        ctx.instantiate_examples(
                                            self.examples)), info.e,
                                    info.fingerprint)
                                if not should_consider:
                                    event(
                                        "skipped; `should_consider_replacement` returned {}"
                                        .format(should_consider))
                                    continue

                                yield from consider_new_target(
                                    target, e, ctx, pool, replacement)

            print("CONSIDERED {}".format(n))
            size += 1
예제 #18
0
파일: impls.py 프로젝트: uwplse/cozy
    def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
        """Add a query that helps maintain some other state.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        The query is first normalized on a shallow copy (the caller's object
        is not modified): implicit handle assumptions are appended when the
        existing assumptions do not already imply them, state var boundaries
        in the return expression are repaired, and both the assumptions and
        the return expression are simplified.

        If a query already exists that is equivalent to `sub_q`, this method
        returns `used_by` rewritten to use the existing query and does not add
        the query to the implementation.  Otherwise it returns `used_by`
        unchanged.

        Raises:
            AssertionError - if simplification made the assumptions larger
        """

        with task("adding query", query=sub_q.name):
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

            with task("repairing state var boundaries"):
                extra_available_state = [e for v, e in self._concretization_functions]
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret),
                    self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    # Raise explicitly instead of `assert False`: assert
                    # statements are removed when Python runs with -O, which
                    # would let the bad simplification through silently.
                    raise AssertionError(
                        "simplification grew the assumptions from size {} to {}".format(
                            orig_a_size, a_size))

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            # When `dedup_queries` is disabled the predicate is always false,
            # so no existing query is ever reused.
            qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs, assumptions=EAll(self.abstract_invariants)))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
                # For each argument of `qq` (in order), the position of the
                # same-named argument in `sub_q`.  The name list is built once
                # rather than once per argument.
                sub_q_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_q_arg_names.index(a) for (a, t) in qq.args]
                class Repl(BottomUpRewriter):
                    """Redirect calls to `sub_q` so that they call `qq` instead."""
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            # Permute the arguments into `qq`'s parameter order.
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)
                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by
예제 #19
0
파일: core.py 프로젝트: timwee/cozy
    def next(self):
        """Yield candidate improvements for the current targets.

        This is a generator.  It enumerates expressions of increasing size
        and, for every (target, sub-expression, context, pool) produced by
        `exploration_order`, tries substituting each enumerated expression for
        that sub-expression.  A rewritten target is yielded only if:

          - the replacement has the same type as the sub-expression and is not
            alpha-equivalent to it,
          - the substitution is not in `self.blacklist`,
          - the rewritten target is not alpha-equivalent to an existing target,
          - its fingerprint on `self.examples` matches that of
            `self.targets[0]`,
          - it passes `check_wf` (well-formedness plus a cost check), and
          - the cost model orders it as `Order.LT` or `Order.AMBIGUOUS`
            relative to the target it was derived from.

        Substitutions that fail any of the last three checks are added to
        `self.blacklist`.  The blacklist attribute is created on first use if
        the instance does not have one.

        The outer loop has no exit condition of its own; iteration ends only
        when an exception propagates or the consumer abandons the generator.

        Raises:
            StopException - when `self.stop_callback()` returns a true value
        """

        class No(object):
            """Falsy result that carries the reason a check failed."""

            def __init__(self, msg):
                self.msg = msg

            def __bool__(self):
                return False

            def __str__(self):
                return "no: {}".format(self.msg)

        root_ctx = self.context

        def check_wf(e, ctx, pool):
            """Return True if `e` is acceptable, otherwise a falsy `No`."""
            with task("checking well-formedness", size=e.size()):
                try:
                    exp_wf(e,
                           pool=pool,
                           context=ctx,
                           assumptions=self.assumptions,
                           solver=self.wf_solver)
                except ExpIsNotWf as exc:
                    return No("at {}: {}".format(
                        pprint(exc.offending_subexpression), exc.reason))
                # Runtime expressions the cost model ranks as strictly worse
                # than the first target are rejected outright.
                if pool == RUNTIME_POOL and self.cost_model.compare(
                        e, self.targets[0], ctx, pool) == Order.GT:
                    return No("too expensive")
                return True

        # Fragments of the targets and hints are handed to the enumerator as
        # its `hints` argument.
        frags = list(
            unique(
                itertools.chain(*[shred(t, root_ctx) for t in self.targets],
                                *[shred(h, root_ctx) for h in self.hints])))
        enum = Enumerator(examples=self.examples,
                          cost_model=self.cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=self.stop_callback)

        size = 0
        target_fp = fingerprint(self.targets[0], self.examples)

        if not hasattr(self, "blacklist"):
            self.blacklist = set()

        while True:

            print("starting minor iteration {} with |cache|={}".format(
                size, enum.cache_size()))
            if self.stop_callback():
                raise StopException()

            # Number of substitutions examined during this minor iteration.
            n = 0
            for target, e, ctx, pool in exploration_order(
                    self.targets, root_ctx):
                with task("checking substitutions",
                          target=pprint(
                              replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                      pool, EVar("___"))),
                          e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("checking substitution",
                                  expression=pprint(info.e)):
                            if self.stop_callback():
                                raise StopException()
                            if info.e.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(info.e.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(info.e, e):
                                event("no change")
                                continue

                            k = (e, ctx, pool, info.e)
                            if k in self.blacklist:
                                event("blacklisted")
                                continue

                            n += 1
                            ee = freshen_binders(
                                replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                        pool, info.e), root_ctx)
                            # "Already seen" results are skipped but not
                            # blacklisted.
                            if any(
                                    alpha_equivalent(t, ee)
                                    for t in self.targets):
                                event("already seen")
                                continue
                            if not self.matches(fingerprint(ee, self.examples),
                                                target_fp):
                                event("incorrect")
                                self.blacklist.add(k)
                                continue
                            wf = check_wf(ee, root_ctx, RUNTIME_POOL)
                            if not wf:
                                event("not well-formed [wf={}]".format(wf))
                                self.blacklist.add(k)
                                continue
                            if self.cost_model.compare(
                                    ee, target, root_ctx,
                                    RUNTIME_POOL) not in (Order.LT,
                                                          Order.AMBIGUOUS):
                                event("not an improvement")
                                self.blacklist.add(k)
                                continue
                            print(
                                "FOUND A GUESS AFTER {} CONSIDERED".format(n))
                            yield ee

            print("CONSIDERED {}".format(n))
            size += 1
예제 #20
0
def search_for_improvements(targets: [Exp], wf_solver: ModelCachingSolver,
                            context: Context, examples: [{
                                str: object
                            }], cost_model: CostModel,
                            stop_callback: Callable[[], bool], hints: [Exp],
                            ops: [Op], blacklist: {
                                (Exp, Context, Pool, Exp): str
                            }):
    """Generate candidate replacements for the target expressions.

    Each yielded expression looks like an improvement over some target, or
    is ambiguous with respect to it.  Yielded expressions are only known to
    be correct on the supplied examples.

    New entries may be inserted into `blacklist` as a side effect.

    The search proceeds in "minor iterations" of increasing expression size.
    It raises StopException when `stop_callback()` returns a true value, or
    when the enumerator reports that no larger expressions can exist.
    """

    root_ctx = context

    def check_wf(e, ctx, pool):
        # Returns True, or the first falsy verdict explaining the rejection.
        with task("pruning", size=e.size()):
            wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if wf:
                useful = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
                if useful:
                    prune = (cost_pruning.value and pool == RUNTIME_POOL
                             and cost_model.compare(e, targets[0], ctx,
                                                    pool) == Order.GT)
                    return No("too expensive") if prune else True
                return useful
            return wf

    with task("setting up hints"):
        # Sub-expressions of every target come first, then those of the hints.
        per_root = [
            all_subexpressions_with_context_information(root, root_ctx)
            for root in itertools.chain(targets, hints)
        ]
        frags = sorted(unique(itertools.chain.from_iterable(per_root)),
                       key=hint_order)
        enum = Enumerator(examples=examples,
                          cost_model=cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=stop_callback,
                          do_eviction=enable_eviction.value)

    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # First group the watched sub-expressions by context, so that the
        # examples only need to be instantiated once per context below.
        watches_by_context = OrderedDict()
        for target in targets:
            located = unique(
                all_subexpressions_with_context_information(
                    target, context=root_ctx, pool=RUNTIME_POOL))
            for e, ctx, pool in located:
                watches_by_context.setdefault(ctx, []).append(
                    (target, e, pool))

        # Then index them by (fingerprint, context, pool).
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                key = (Fingerprint.of(e, exs), ctx, pool)
                watches.setdefault(key, []).append((target, e))

        ctx_pool_pairs = (
            (ctx, pool)
            for _, _, ctx, pool in exploration_order(targets, root_ctx))
        watched_ctxs = list(unique(ctx_pool_pairs))

    search_info = SearchInfo(context=root_ctx,
                             targets=targets,
                             target_fingerprint=target_fp,
                             examples=examples,
                             check_wf=check_wf,
                             cost_model=cost_model,
                             blacklist=blacklist)

    def obvious_substitutions(size):
        # Replace watched sub-expressions by enumerated expressions whose
        # fingerprint is equal to theirs.
        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions",
                      ctx=ctx,
                      pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size,
                                                     context=ctx,
                                                     pool=pool):
                    with task("searching for obvious substitution",
                              expression=pprint(info.e)):
                        fp = info.fingerprint
                        for (watched_fp, watched_ctx,
                             watched_pool), hits in watches.items():
                            if (watched_ctx != ctx or watched_pool != pool
                                    or not watched_fp.equal_to(fp)):
                                continue

                            for target, watched_e in hits:
                                replacement = info.e
                                event(
                                    "possible substitution: {} ---> {}".format(
                                        pprint(watched_e),
                                        pprint(replacement)))
                                holed = replace(target, root_ctx,
                                                RUNTIME_POOL, watched_e, ctx,
                                                pool, EVar("___"))
                                event("replacement locations: {}".format(
                                    pprint(holed)))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                else:
                                    yield from _consider_replacement(
                                        target, watched_e, ctx, pool,
                                        replacement, search_info)

    def blind_substitutions(size):
        # Try every enumerated expression at every explored location.
        for target, e, ctx, pool in exploration_order(targets, root_ctx):
            with task("checking substitutions",
                      target=pprint(
                          replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                  pool, EVar("___"))),
                      e=pprint(e)):
                for info in enum.enumerate_with_info(size=size,
                                                     context=ctx,
                                                     pool=pool):
                    with task("checking substitution",
                              expression=pprint(info.e)):
                        if stop_callback():
                            raise StopException()
                        replacement = info.e
                        if replacement.type != e.type:
                            event("wrong type (is {}, need {})".format(
                                pprint(replacement.type), pprint(e.type)))
                            continue
                        if alpha_equivalent(replacement, e):
                            event("no change")
                            continue
                        e_fp = Fingerprint.of(
                            e, ctx.instantiate_examples(examples))
                        should_consider = should_consider_replacement(
                            target, root_ctx, e, ctx, pool, e_fp, info.e,
                            info.fingerprint)
                        if not should_consider:
                            event(
                                "skipped; `should_consider_replacement` returned {}"
                                .format(should_consider))
                            continue

                        yield from _consider_replacement(
                            target, e, ctx, pool, replacement, search_info)

    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(
            size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        yield from obvious_substitutions(size)

        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            yield from blind_substitutions(size)

        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL,
                                                     size):
            raise StopException(
                "no more expressions can exist above size={}".format(size))

        size += 1
예제 #21
0
    def _enumerate_with_info(self, context: Context, size: int,
                             pool: Pool) -> [EnumeratedExp]:
        """Helper for enumerate_with_info that bypasses the cache.

        Note that this method DOES affect the cache: it writes its output into
        the cache and may do evictions.  The enumerate_with_info method ensures
        that there is only ever one call to this method for a given (context,
        size, pool).

        This is a generator.  Each candidate from `_enumerate_core` has its
        binders freshened, is checked with `self.check_wf`, fingerprinted on
        the examples instantiated for `context`, and compared against cached
        expressions with an equal fingerprint.  Candidates that survive are
        yielded as `EnumeratedExp` values and then added to the cache.

        Raises:
            StopException - when `self.stop_callback()` returns a true value
        """

        examples = context.instantiate_examples(self.examples)
        cache = self.cache
        queue = self._enumerate_core(context, size, pool)
        cost_model = self.cost_model

        # NOTE: this is a `while` loop with an explicit `next` call rather
        # than a `for` loop because `queue` is rebound at the bottom of the
        # loop body when accelerations are pushed in front of it.
        while True:
            if self.stop_callback():
                raise StopException()

            try:
                e = next(queue)
            except StopIteration:
                # StopIteration is a "control flow exception" indicating that
                # there isn't a next element.  Since the queue is exhausted,
                # breaking out of the loop is the right thing to do.
                break

            self.stat_timer.check()

            e = freshen_binders(e, context)
            _consider(e, size, context, pool)

            wf = self.check_wf(e, context, pool)
            if not wf:
                _skip(e, size, context, pool, "wf={}".format(wf))
                continue

            fp = Fingerprint.of(e, examples)

            # Collect all expressions from parent contexts that are
            # fingerprint-equivalent to this one.  There might be more than one
            # because of how `retention_policy` works.
            known_equivalents = list(
                cache.find_equivalent_expressions(context, pool, fp))
            to_evict = []

            if any(e.type == prev_entry.e.type
                   and alpha_equivalent(e, prev_entry.e)
                   for prev_entry in known_equivalents):
                _skip(e, size, context, pool, "duplicate")
                should_keep = False
            else:
                # decide whether to keep this expression
                should_keep = True
                if known_equivalents:
                    with task("comparing to cached equivalents",
                              count=len(known_equivalents)):
                        for entry in known_equivalents:
                            prev_exp = entry.e
                            event("previous: {}".format(pprint(prev_exp)))
                            to_keep = retention_policy(e, context, prev_exp,
                                                       context, pool,
                                                       cost_model)
                            if e not in to_keep:
                                # The cached expression wins; stop comparing.
                                _skip(e, size, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break
                            if prev_exp not in to_keep:
                                to_evict.append(entry)

            # Invariant checked here: a rejected expression never has pending
            # evictions.
            assert not (to_evict and not should_keep)

            if should_keep:

                if self.do_eviction and to_evict:
                    # Report the number of evictions, matching the `count=`
                    # usage of the "comparing to cached equivalents" task
                    # above (previously the list itself was passed).
                    with task("evicting", count=len(to_evict)):
                        for entry in to_evict:
                            _evict(entry.e, entry.size, context, pool, e, size)
                            cache.remove(context, pool, entry)

                _accept(e, size, context, pool, fp)
                info = EnumeratedExp(e=e, fingerprint=fp, size=size)
                # NOTE: the entry is yielded before it is written to the
                # cache; keep this order.
                yield info
                cache.add(context, pool, info)

                if size == 0:
                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            event("trying {} accelerations of {}".format(
                                len(to_try), pprint(e)))
                            # Accelerated forms are processed before the rest
                            # of the queue.
                            queue = itertools.chain(to_try, queue)
예제 #22
0
파일: core.py 프로젝트: uwplse/cozy
def search_for_improvements(
        targets       : [Exp],
        wf_solver     : ModelCachingSolver,
        context       : Context,
        examples      : [{str:object}],
        cost_model    : CostModel,
        stop_callback : Callable[[], bool],
        hints         : [Exp],
        ops           : [Op],
        blacklist     : {(Exp, Context, Pool, Exp) : str}):
    """Look for expressions that might improve one of the targets.

    The generator yields expressions that appear to be improvements (or
    whose ordering relative to some target is ambiguous).  Their correctness
    is only guaranteed on the given examples.

    The given blacklist may receive new items while the search runs.

    Expression sizes are explored in increasing order, one "minor iteration"
    per size.  StopException is raised when `stop_callback()` returns a true
    value or when the enumerator reports that nothing larger can exist.
    """

    root_ctx = context

    def check_wf(e, ctx, pool):
        # Returns True, or a falsy value describing why `e` was pruned.
        with task("pruning", size=e.size()):
            wf_result = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not wf_result:
                return wf_result
            useful = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not useful:
                return useful
            if not (cost_pruning.value and pool == RUNTIME_POOL):
                return True
            if cost_model.compare(e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    with task("setting up hints"):
        # Targets contribute their sub-expressions first, then the hints.
        fragment_sources = []
        for t in targets:
            fragment_sources.append(all_subexpressions_with_context_information(t, root_ctx))
        for h in hints:
            fragment_sources.append(all_subexpressions_with_context_information(h, root_ctx))
        frags = list(unique(itertools.chain(*fragment_sources)))
        frags.sort(key=hint_order)
        enum = Enumerator(
            examples=examples,
            cost_model=cost_model,
            check_wf=check_wf,
            hints=frags,
            heuristics=try_optimize,
            stop_callback=stop_callback,
            do_eviction=enable_eviction.value)

    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # context -> [(target, sub-expression, pool)]
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(all_subexpressions_with_context_information(target, context=root_ctx, pool=RUNTIME_POOL)):
                bucket = watches_by_context.setdefault(ctx, [])
                bucket.append((target, e, pool))

        # (fingerprint, context, pool) -> [(target, sub-expression)]
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                bucket = watches.setdefault((fp, ctx, pool), [])
                bucket.append((target, e))

        watched_ctxs = list(unique((ctx, pool) for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(
        context=root_ctx,
        targets=targets,
        target_fingerprint=target_fp,
        examples=examples,
        check_wf=check_wf,
        cost_model=cost_model,
        blacklist=blacklist)

    # One pass of this loop is one "minor iteration" at a fixed size.
    for size in itertools.count():

        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        # Phase 1: enumerated expressions whose fingerprint equals that of a
        # watched sub-expression in the same context and pool.
        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions", ctx=ctx, pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("searching for obvious substitution", expression=pprint(info.e)):
                        candidate_fp = info.fingerprint
                        for (watch_fp, watch_ctx, watch_pool), watched in watches.items():
                            if watch_ctx != ctx or watch_pool != pool:
                                continue

                            if not watch_fp.equal_to(candidate_fp):
                                continue

                            for target, watched_e in watched:
                                replacement = info.e
                                event("possible substitution: {} ---> {}".format(pprint(watched_e), pprint(replacement)))
                                with_hole = replace(target, root_ctx, RUNTIME_POOL, watched_e, ctx, pool, EVar("___"))
                                event("replacement locations: {}".format(pprint(with_hole)))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue

                                yield from _consider_replacement(target, watched_e, ctx, pool, replacement, search_info)

        # Phase 2 (optional): try every enumerated expression at every
        # explored location.
        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with_hole = replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, EVar("___"))
                with task("checking substitutions", target=pprint(with_hole), e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                        with task("checking substitution", expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            original_fp = Fingerprint.of(e, ctx.instantiate_examples(examples))
                            should_consider = should_consider_replacement(
                                target, root_ctx,
                                e, ctx, pool, original_fp,
                                info.e, info.fingerprint)
                            if not should_consider:
                                event("skipped; `should_consider_replacement` returned {}".format(should_consider))
                                continue

                            yield from _consider_replacement(target, e, ctx, pool, replacement, search_info)

        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL, size):
            raise StopException("no more expressions can exist above size={}".format(size))
예제 #23
0
    def _add_subquery(self, sub_q: Query, used_by: Stm) -> Stm:
        """Add a query that helps maintain some other state.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        Returns:
            The statement to use in place of `used_by`.

        If a query already exists that is equivalent to `sub_q`, this method
        returns `used_by` rewritten to use the existing query and does not add
        the query to the implementation.  Otherwise it adds the (repaired and
        simplified) query via `self.add_query` and returns `used_by`
        unchanged.

        All adjustments below are applied to a shallow copy of `sub_q`; its
        attributes are reassigned rather than mutated in place.
        """

        with task("adding query", query=sub_q.name):
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                # Only extend the assumptions when the existing ones do not
                # already imply the implicit handle assumptions.
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(
                        itertools.chain(sub_q.assumptions, new_a))

            with task("repairing state var boundaries"):
                extra_available_state = [
                    e for v, e in self._concretization_functions
                ]
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret), self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(
                    simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                # Debugging aid: simplification is expected never to grow the
                # assumptions.  Dump both versions before failing.
                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    assert False

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            # Look for an already-registered query equivalent to `sub_q`.
            # The `dedup_queries` option short-circuits the (solver-backed)
            # equivalence check when deduplication is disabled.
            qq = find_one(
                self.query_specs, lambda qq: dedup_queries.value and
                queries_equivalent(qq,
                                   sub_q,
                                   state_vars=state_vars,
                                   extern_funcs=funcs,
                                   assumptions=EAll(self.abstract_invariants)))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(
                    sub_q.name, qq.name))
                # arg_reorder[i] is the position in `sub_q.args` of the i-th
                # argument of `qq`.  The name list is loop-invariant, so it
                # is built once instead of once per argument of `qq`.
                sub_q_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_q_arg_names.index(a)
                               for (a, t) in qq.args]

                class Repl(BottomUpRewriter):
                    """Redirect calls to `sub_q` so they call `qq` instead."""

                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            # Permute the actual arguments into `qq`'s
                            # parameter order.
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)

                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by
예제 #24
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Yield `EnumeratedExp` records for expressions of the given size.

        This is a generator.  In order, it yields:
          1. when `context` has a parent: everything this method yields for
             the parent context (same size and pool), with each fingerprint
             recomputed against this context's examples;
          2. the entries for the key `(pool, size, context)`, taken from
             `self.cache` when present, and otherwise produced from
             `self.enumerate_core` and filtered as they are generated.

        If `self.canonical_context(context)` is a different object than
        `context`, the request is served from the canonical context and each
        result is adapted back with `context.adapt`.

        Raises:
            StopException: if `self.stop_callback()` returns a true value
                while new expressions are being enumerated.
        """
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        examples = context.instantiate_examples(self.examples)
        if context.parent() is not None:
            for info in self.enumerate_with_info(context.parent(), size, pool):
                e = info.e
                # Re-fingerprint with this context's examples rather than
                # reusing the parent's fingerprint.
                yield EnumeratedExp(e=e, fingerprint=fingerprint(e, examples))

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            for e in res:
                yield e
        else:
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            # Publish the (still growing) result list in the cache right
            # away; the recursive calls made below for the same key take the
            # cached branch and see what has been accepted so far.
            res = []
            self.cache[k] = res
            queue = self.enumerate_core(context, size, pool)
            cost_model = self.cost_model
            while True:
                if self.stop_callback():
                    raise StopException()

                try:
                    e = next(queue)
                except StopIteration:
                    break

                fvs = free_vars(e)
                if not belongs_in_context(fvs, context):
                    continue

                e = freshen_binders(e, context)
                _consider(e, size, context, pool)

                wf = self.check_wf(e, context, pool)
                if not wf:
                    _skip(e, size, context, pool, "wf={}".format(wf))
                    continue

                fp = fingerprint(e, examples)

                # collect all expressions from parent contexts
                # (Gathers everything yielded for this context at sizes
                # 0..size inclusive, then keeps only expressions whose
                # fingerprint equals `fp`.)
                with task("collecting prev exps",
                          size=size,
                          context=context,
                          pool=pool_name(pool)):
                    prev = []
                    for sz in range(0, size + 1):
                        prev.extend(self.enumerate_with_info(
                            context, sz, pool))
                    prev = [p.e for p in prev if p.fingerprint == fp]

                if any(alpha_equivalent(e, p) for p in prev):
                    _skip(e, size, context, pool, "duplicate")
                    should_keep = False
                else:
                    # decide whether to keep this expression
                    should_keep = True
                    with task("comparing to cached equivalents"):
                        for prev_exp in prev:
                            event("previous: {}".format(pprint(prev_exp)))
                            to_keep = eviction_policy(e, context, prev_exp,
                                                      context, pool,
                                                      cost_model)
                            # Drop `e` as soon as the policy prefers any
                            # previously seen expression over it.
                            if e not in to_keep:
                                _skip(e, size, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break

                if should_keep:

                    if self.do_eviction:
                        with task("evicting"):
                            # Collect first, remove afterwards: the cache
                            # lists must not be modified while iterated.
                            to_evict = []
                            for (key, exps) in self.cache.items():
                                (p, s, c) = key
                                # Only entries for the same pool and the
                                # same context are eviction candidates.
                                if p == pool and c == context:
                                    for ee in exps:
                                        if ee.fingerprint == fp:
                                            event("considering eviction of {}".
                                                  format(pprint(ee.e)))
                                            to_keep = eviction_policy(
                                                e, context, ee.e, c, pool,
                                                cost_model)
                                            if ee.e not in to_keep:
                                                to_evict.append((key, ee))
                            for key, ee in to_evict:
                                (p, s, c) = key
                                _evict(ee.e, s, c, pool, e)
                                self.cache[key].remove(ee)
                                self.seen[(c, p, fp)].remove(ee.e)

                    _accept(e, size, context, pool)
                    seen_key = (context, pool, fp)
                    if seen_key not in self.seen:
                        self.seen[seen_key] = []
                    self.seen[seen_key].append(e)
                    info = EnumeratedExp(e=e, fingerprint=fp)
                    res.append(info)
                    yield info

                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            event("trying {} accelerations of {}".format(
                                len(to_try), pprint(e)))
                            # Heuristic candidates are processed before the
                            # remaining items of the regular queue.
                            queue = itertools.chain(to_try, queue)

            # NOTE(review): this is skipped if the loop above exits via an
            # exception (e.g. StopException), leaving `k` in `in_progress`
            # -- confirm that this is intended.
            self.in_progress.remove(k)
예제 #25
0
파일: acceleration.py 프로젝트: timwee/cozy
def map_accelerate(e, context):
    """Yield map-based candidate rewrites of `e`.

    This is a generator of `(expression, pool)` pairs.  It walks the
    sub-expressions of `e`; for each sub-expression `arg` that
    `context.legal_for` accepts, every occurrence of `arg` in `e` is replaced
    by a fresh binder.  If the result no longer mentions any runtime-pool
    variable of `context`, then for each candidate key collection it yields

        (m, STATE_POOL)     where m = EMakeMap2(keys, \\binder -> value)
        (mg, RUNTIME_POOL)  where mg = EMapGet(EStateVar(m), arg)

    Key collections come from `reachable_values_of_type` applied to each
    state-pool variable of `context` (results that are `EEmptyList` are
    discarded) and are cached per key type.

    Nothing is yielded when `is_constant_time(e)` holds.

    Parameters:
        e - the expression to accelerate
        context - the context `e` lives in; supplies the variables and pools
    """
    with task("map_accelerate", size=e.size()):
        if is_constant_time(e):
            event("skipping map lookup inference for constant-time exp: {}".
                  format(pprint(e)))
            return

        # One binder per key type, reused for every `arg` of that type.
        @lru_cache()
        def make_binder(t):
            return fresh_var(t, hint="key")

        args = OrderedSet(v for (v, p) in context.vars() if p == RUNTIME_POOL)
        possible_keys = {}  # type -> [exp]
        i = 0

        # Explicit-stack traversal of `e`; entries may be expressions,
        # tuples of children, or non-expression values (which are ignored).
        stk = [e]
        while stk:
            event("exp {} / {}".format(i, e.size()))
            i += 1
            arg = stk.pop()
            if isinstance(arg, tuple):
                stk.extend(arg)
                continue
            if not isinstance(arg, Exp):
                continue
            if isinstance(arg, ELambda):
                stk.append(arg.body)
                continue

            if context.legal_for(free_vars(arg)):
                # all the work happens here
                binder = make_binder(arg.type)
                value = replace(
                    e,
                    arg,
                    binder,
                    match=lambda e1, e2: type(e1) == type(e2) and e1.type == e2
                    .type and alpha_equivalent(e1, e2))
                value = strip_EStateVar(value)
                if any(v in args for v in free_vars(value)):
                    event("not all args were eliminated")
                else:
                    if arg.type not in possible_keys:
                        reachable = [
                            reachable_values_of_type(sv, arg.type)
                            for (sv, p) in context.vars() if p == STATE_POOL
                        ]
                        possible_keys[arg.type] = OrderedSet(
                            x for x in reachable
                            if not isinstance(x, EEmptyList))
                    for keys in possible_keys[arg.type]:
                        m = EMakeMap2(keys, ELambda(binder, value)).with_type(
                            TMap(arg.type, e.type))
                        assert not any(
                            v in args
                            for v in free_vars(m)), "oops! {}; args={}".format(
                                pprint(m), ", ".join(pprint(a) for a in args))
                        yield (m, STATE_POOL)
                        mg = EMapGet(EStateVar(m).with_type(m.type),
                                     arg).with_type(e.type)
                        yield (mg, RUNTIME_POOL)

            if isinstance(arg, EStateVar):
                # do not visit state expressions
                continue

            # Count how many children of `arg` mention a runtime argument.
            # Children are only visited when fewer than two of them do.
            num_with_args = 0
            stk2 = list(arg.children())
            while stk2:
                child = stk2.pop()
                if isinstance(child, tuple):
                    # Expand tuples on the *local* counting stack.  Pushing
                    # them onto `stk` (as the code used to) let tuple members
                    # bypass the check below and be visited a second time via
                    # `stk.extend(arg.children())`.
                    stk2.extend(child)
                    continue
                if not isinstance(child, Exp):
                    continue
                fvs = free_vars(child)
                if fvs & args:
                    num_with_args += 1
                    if num_with_args >= 2:
                        break
            if num_with_args < 2:
                stk.extend(arg.children())
            else:
                event("refusing to visit children of {}".format(pprint(arg)))
예제 #26
0
파일: core.py 프로젝트: uwplse/cozy
def improve(
        target        : Exp,
        context       : Context,
        assumptions   : Exp                = ETRUE,
        stop_callback : Callable[[], bool] = never_stop,
        hints         : [Exp]              = (),
        examples      : [{str:object}]     = (),
        cost_model    : CostModel          = None,
        ops           : [Op]               = (),
        improve_count : Value              = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
        - assumptions: a precondition.  The yielded improvements will only be
          correct when the assumptions are true.
        - hints: expressions that might be useful.  These will be explored
          first when looking for improvements.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions.  This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide any.
        - ops: update operations.  This function may make different choices
          about what expressions are state expressions based on what changes
          can happen to that state.
        - improve_count: optional shared counter.  When given, its `value` is
          incremented (while holding `improve_count.get_lock()`) each time a
          round of the search loop ends.

    Special cases:
        - If `heuristic_done(target)` already holds, nothing is yielded.
        - If the assumptions are unsatisfiable, a single value built by
          `construct_value(target.type)` is yielded.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, allowing us
          to find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.  This is more desirable than running for
          an indeterminate amount of time doing nothing.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r},
        ops={ops!r})""".format(
            target=target,
            context=context,
            assumptions=assumptions,
            stop_callback=stop_callback,
            hints=hints,
            examples=examples,
            cost_model=cost_model,
            ops=ops))

    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # Hints are explored first: caller-supplied hints, each conjunct of the
    # assumptions, and the target itself.
    hints = ([freshen_binders(h, context) for h in hints]
        + [freshen_binders(wrap_naked_statevars(a, state_vars), context) for a in break_conj(assumptions)]
        + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    # The solver already carries `assumptions`, so checking ETRUE asks
    # whether the assumptions themselves are satisfiable.
    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(is_good, target)

    # Private copy: counterexamples are appended below.
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    # Removal is deferred so `watched_targets` is not
                    # modified while it is being iterated.
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print("Whoops! Looks like we already found something better.")
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    # Only yield when the candidate displaced the current
                    # best target.
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1
예제 #27
0
파일: core.py 프로젝트: timwee/cozy
def improve(target: Exp,
            context: Context,
            assumptions: Exp = T,
            stop_callback=never_stop,
            hints: [Exp] = (),
            examples: [{
                str: object
            }] = (),
            cost_model: CostModel = None):
    """
    Improve the target expression using enumerative synthesis.
    This function is a generator that yields increasingly better and better
    versions of the input expression `target`.

    Parameters:
        - target: the expression to improve.
        - context: the context `target` lives in; supplies variables, pools
          and functions.
        - assumptions: a precondition under which candidates are verified.
        - stop_callback: passed on to the `Learner`.
        - hints: extra expressions handed to the learner, together with the
          conjuncts of `assumptions` and the target itself.
        - examples: initial examples; verification counterexamples are added
          to a private copy as the search runs.
        - cost_model: defaults to `CostModel(funcs=..., assumptions=...)`.

    If the assumptions are unsatisfiable, a single value built by
    `construct_value(target.type)` is yielded.

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, and lets us
          find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r})""".format(target=target,
                                             context=context,
                                             assumptions=assumptions,
                                             stop_callback=stop_callback,
                                             hints=hints,
                                             examples=examples,
                                             cost_model=cost_model))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    # An ill-formed initial target is reported but deliberately tolerated.
    try:
        assert exp_wf(target, context=context, assumptions=assumptions)
    except ExpIsNotWf as ex:
        print(
            "WARNING: initial target is not well-formed [{}]; this might go poorly..."
            .format(str(ex)))
        print(pprint(ex.offending_subexpression))
        print(pprint(ex.offending_subexpression.type))

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    all_vars = [v for (v, p) in context.vars()]
    funcs = context.funcs()

    solver = None
    if incremental.value:
        # The incremental solver permanently carries `assumptions`;
        # per-candidate facts are added inside a push/pop scope below.
        solver = IncrementalSolver(vars=all_vars, funcs=funcs)
        solver.add_assumption(assumptions)
        _sat = solver.satisfy
    else:
        _sat = lambda e: satisfy(e, vars=all_vars, funcs=funcs)

    if _sat(assumptions) is None:
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    # Private copy: counterexamples are appended below.
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    learner = Learner(watched_targets, assumptions, context, examples,
                      cost_model, stop_callback, hints)
    try:
        while True:
            # 1. find any potential improvement to any sub-exp of target
            for new_target in learner.next():
                print("Found candidate improvement: {}".format(
                    pprint(new_target)))

                # 2. check
                with task("verifying candidate"):
                    not_equal = ENot(
                        EBinOp(target, "==", new_target).with_type(BOOL))
                    if incremental.value:
                        # Every path through the rest of this loop body
                        # leaves via break/continue/raise, so the scope must
                        # be closed here.  Otherwise the negated equality
                        # would stay asserted and corrupt all later checks.
                        solver.push()
                        try:
                            solver.add_assumption(not_equal)
                            counterexample = _sat(T)
                        finally:
                            solver.pop()
                    else:
                        counterexample = _sat(EAll([assumptions, not_equal]))
                if counterexample is not None:
                    if counterexample in examples:
                        print("assumptions = {!r}".format(assumptions))
                        print("duplicate example: {!r}".format(counterexample))
                        print("old target = {!r}".format(target))
                        print("new target = {!r}".format(new_target))
                        raise Exception("got a duplicate example")
                    # a. if incorrect: add example, reset the learner
                    examples.append(counterexample)
                    event("new example: {!r}".format(counterexample))
                    print("wrong; restarting with {} examples".format(
                        len(examples)))
                    learner.reset(examples)
                    break
                else:
                    # b. if correct: yield it, watch the new target, goto 1
                    print("The candidate is valid!")
                    print(repr(new_target))
                    print("Determining whether to yield it...")
                    with task("updating frontier"):
                        to_evict = []
                        keep = True
                        old_better = None
                        for old_target in watched_targets:
                            evc = eviction_policy(new_target, context,
                                                  old_target, context,
                                                  RUNTIME_POOL, cost_model)
                            if old_target not in evc:
                                to_evict.append(old_target)
                            if new_target not in evc:
                                old_better = old_target
                                keep = False
                                break
                        # Removal is deferred so `watched_targets` is not
                        # modified while it is being iterated.
                        for t in to_evict:
                            watched_targets.remove(t)
                        if not keep:
                            print(
                                "Whoops! Looks like we already found something better."
                            )
                            print(" --> {}".format(pprint(old_better)))
                            continue
                        # Only yield when the candidate displaced the
                        # current best target.
                        if target in to_evict:
                            print("Yep, it's an improvement!")
                            yield new_target
                            if heuristic_done(new_target):
                                print("target now matches doneness heuristic")
                                raise NoMoreImprovements()
                            target = new_target
                        else:
                            print("Nope, it isn't substantially better!")

                    watched_targets.append(new_target)
                    print("Now watching {} targets".format(
                        len(watched_targets)))
                    learner.watch(watched_targets)
                    break
    except NoMoreImprovements:
        return
예제 #28
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Yield `EnumeratedExp` records for expressions of the given size.

        This is a generator.  In order, it yields:
          1. when `context` has a parent: everything this method yields for
             the parent context at the same size and pool, unchanged;
          2. the entries for the key `(pool, size, context)`, taken from
             `self.cache` when present, and otherwise produced from
             `self.enumerate_core` and filtered as they are generated.

        If `self.canonical_context(context)` is a different object than
        `context`, the request is served from the canonical context and each
        result is adapted back with `context.adapt`.

        Raises:
            StopException: if `self.stop_callback()` returns a true value
                while new expressions are being enumerated.
        """
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        if context.parent() is not None:
            yield from self.enumerate_with_info(context.parent(), size, pool)

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            for e in res:
                yield e
        else:
            examples = context.instantiate_examples(self.examples)
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            # Publish the (still growing) result list in the cache right
            # away; the recursive calls made below for the same key take the
            # cached branch and see what has been accepted so far.
            res = []
            self.cache[k] = res
            queue = self.enumerate_core(context, size, pool)
            cost_model = self.cost_model
            while True:
                if self.stop_callback():
                    raise StopException()

                try:
                    e = next(queue)
                except StopIteration:
                    break

                fvs = free_vars(e)
                if not belongs_in_context(fvs, context):
                    continue

                e = freshen_binders(e, context)
                _consider(e, context, pool)

                wf = self.check_wf(e, context, pool)
                if not wf:
                    _skip(e, context, pool, "wf={}".format(wf))
                    continue

                fp = fingerprint(e, examples)

                # collect all expressions from parent contexts
                # (Gathers everything yielded for this context at sizes
                # 0..size inclusive, then keeps only expressions whose
                # fingerprint equals `fp`.)
                with task("collecting prev exps",
                          size=size,
                          context=context,
                          pool=pool_name(pool)):
                    prev = []
                    for sz in range(0, size + 1):
                        prev.extend(self.enumerate_with_info(
                            context, sz, pool))
                    prev = [p.e for p in prev if p.fingerprint == fp]

                if any(alpha_equivalent(e, p) for p in prev):
                    _skip(e, context, pool, "duplicate")
                    should_keep = False
                else:
                    # decide whether to keep this expression,
                    # decide which can be evicted
                    should_keep = True
                    with task("comparing to cached equivalents"):
                        for prev_exp in prev:
                            event("previous: {}".format(pprint(prev_exp)))
                            # The keep/skip decision is delegated entirely to
                            # `eviction_policy`; `e` is dropped as soon as the
                            # policy prefers any previous expression over it.
                            to_keep = eviction_policy(e, context, prev_exp,
                                                      context, pool,
                                                      cost_model)
                            if e not in to_keep:
                                _skip(e, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break

                if should_keep:

                    with task("evicting"):
                        # Collect first, remove afterwards: the cache lists
                        # must not be modified while they are iterated.
                        to_evict = []
                        for (key, exps) in self.cache.items():
                            (p, s, c) = key
                            # Candidates: same pool, and cached under this
                            # context or one of its parent contexts.
                            if p == pool and c in itertools.chain(
                                [context], parent_contexts(context)):
                                for ee in exps:
                                    if ee.fingerprint == fp:
                                        to_keep = eviction_policy(
                                            e, context, ee.e, c, pool,
                                            cost_model)
                                        if ee.e not in to_keep:
                                            to_evict.append((key, ee))
                        for key, ee in to_evict:
                            (p, s, c) = key
                            _evict(ee.e, c, pool, e)
                            self.cache[key].remove(ee)
                            self.seen[(c, p, fp)].remove(ee.e)

                    _accept(e, context, pool)
                    seen_key = (context, pool, fp)
                    if seen_key not in self.seen:
                        self.seen[seen_key] = []
                    self.seen[seen_key].append(e)
                    info = EnumeratedExp(e=e, fingerprint=fp, cost=None)
                    res.append(info)
                    yield info

                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            # Heuristic candidates are processed before the
                            # remaining items of the regular queue.
                            queue = itertools.chain(to_try, queue)

            # NOTE(review): this is skipped if the loop above exits via an
            # exception (e.g. StopException), leaving `k` in `in_progress`
            # -- confirm that this is intended.
            self.in_progress.remove(k)
예제 #29
0
def improve(target: Exp,
            context: Context,
            assumptions: Exp = ETRUE,
            stop_callback: Callable[[], bool] = never_stop,
            hints: [Exp] = (),
            examples: [{
                str: object
            }] = (),
            cost_model: CostModel = None,
            ops: [Op] = (),
            improve_count: Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
        - assumptions: a precondition.  The yielded improvements will only be
          correct when the assumptions are true.
        - hints: expressions that might be useful.  These will be explored
          first when looking for improvements.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions.  This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide any.
        - cost_model: the cost model used to compare expressions.  When None,
          a `CostModel` is built from the context's functions and the
          (freshened) assumptions.
        - ops: update operations.  This function may make different choices
          about what expressions are state expressions based on what changes
          can happen to that state.
        - improve_count: optional shared counter.  When not None, its `value`
          is incremented (while holding `get_lock()`) every time one pass of
          the search loop finishes.

    Special cases:
        - If the target already satisfies `heuristic_done`, nothing is
          yielded.
        - If the assumptions are unsatisfiable, a single value built by
          `construct_value(target.type)` is yielded and the generator stops.

    Raises:
        - AssertionError if the initial target is not well-formed or is not
          considered possibly useful.
        - Exception if the solver returns a counterexample that is already in
          the example list.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, allowing us
          to find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.  This is more desirable than running for
          an indeterminate amount of time doing nothing.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r},
        ops={ops!r})""".format(target=target,
                               context=context,
                               assumptions=assumptions,
                               stop_callback=stop_callback,
                               hints=hints,
                               examples=examples,
                               cost_model=cost_model,
                               ops=ops))

    # Normalize the inputs before doing anything else with them.
    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # The search is seeded with the caller's hints, each conjunct of the
    # assumptions, and the target itself.
    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(
        is_good, target)

    # Copy so that appending counterexamples never mutates the caller's
    # sequence (and so the immutable default `()` becomes a list).
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    # Created once, so the same dict is handed to every restart of the search.
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(targets=watched_targets,
                                                  wf_solver=solver,
                                                  context=context,
                                                  examples=examples,
                                                  cost_model=cost_model,
                                                  stop_callback=stop_callback,
                                                  hints=hints,
                                                  ops=ops,
                                                  blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    # The search is given the examples, so seeing the same
                    # counterexample twice is treated as an internal error.
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(
                    len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        # The result is used as the collection of expressions
                        # worth keeping out of the (new, old) pair.
                        evc = retention_policy(new_target, context, old_target,
                                               context, RUNTIME_POOL,
                                               cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    # Evictions collected so far are applied even when the
                    # new target itself ends up being rejected.
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print(
                            "Whoops! Looks like we already found something better."
                        )
                        print(" --> {}".format(pprint(old_better)))
                        # Keep consuming candidates from the same search.
                        continue
                    if target in to_evict:
                        # The new target displaced the current best one.
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1