Esempio n. 1
0
def queries_equivalent(q1: Query,
                       q2: Query,
                       state_vars: [EVar],
                       extern_funcs: {str: TFunc},
                       assumptions: Exp = T):
    """Determine whether two queries always return the same result.

    Returns False immediately if the queries differ in return type or in
    their (argument name -> type) maps.  Otherwise a solver over the state
    variables, the query arguments and `extern_funcs` (under `assumptions`)
    checks that:
      1. the two queries' preconditions are equivalent, and
      2. whenever `q1`'s preconditions hold, the two return expressions are
         equal.

    Solvers are memoized in `_qe_cache`.  The cache key covers every input
    used to construct the solver (arguments, state variables, extern
    functions and assumptions), so a cached solver is only reused for an
    identical solver configuration.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False
        q1args = dict(q1.args)
        q2args = dict(q2.args)
        if q1args != q2args:
            return False
        args = FrozenDict(q1args)

        # Materialize once: state_vars feeds both the cache key and the
        # solver's variable list, so a one-shot iterable must not be
        # consumed twice.
        state_vars = tuple(state_vars)

        # The solver is built from state_vars and extern_funcs as well as the
        # arguments and assumptions.  Keying only on (args, assumptions)
        # would let a later call reuse a solver built for different
        # variables or functions.
        key = (args, state_vars, FrozenDict(extern_funcs), assumptions)
        checker = _qe_cache.get(key)
        if checker is None:
            checker = ModelCachingSolver(
                vars=list(itertools.chain(
                    state_vars,
                    (EVar(v).with_type(t) for v, t in args.items()))),
                funcs=extern_funcs,
                assumptions=assumptions)
            _qe_cache[key] = checker

        q1a = EAll(q1.assumptions)
        q2a = EAll(q2.assumptions)
        return checker.valid(EEq(q1a, q2a)) and checker.valid(
            EImplies(q1a, EEq(q1.ret, q2.ret)))
Esempio n. 2
0
def queries_equivalent(q1: Query,
                       q2: Query,
                       state_vars: [EVar],
                       extern_funcs: {str: TFunc},
                       assumptions: Exp = ETRUE):
    """Decide whether `q1` and `q2` always return the same result.

    Besides comparing return types, argument lists and return expressions,
    this also requires the two queries' preconditions to be semantically
    equivalent.  A query means "if my preconditions hold then I compute and
    return my body expression".  If the preconditions differed, there could
    be states where one query is obligated to return a value while the
    other has no defined behavior, so the two would not be interchangeable.

    Solver checks run in a root context made of `state_vars`, the query
    arguments and `extern_funcs`, under `assumptions`.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        # Cheap syntactic checks first: return types and argument signatures.
        if q1.ret.type != q2.ret.type:
            return False
        if dict(q1.args) != dict(q2.args):
            return False

        arg_vars = [EVar(arg_name).with_type(arg_type)
                    for (arg_name, arg_type) in q1.args]
        root = RootCtx(state_vars=state_vars, args=arg_vars, funcs=extern_funcs)
        solver = solver_for_context(context=root, assumptions=assumptions)

        pre1 = EAll(q1.assumptions)
        pre2 = EAll(q2.assumptions)
        same_preconditions = solver.valid(EEq(pre1, pre2))
        if not same_preconditions:
            return same_preconditions
        return solver.valid(EImplies(pre1, EEq(q1.ret, q2.ret)))
Esempio n. 3
0
 def visit_EArrayIndexOf(self, e):
     """Emit code computing the index of `e.x` in array `e.a` (or -1).

     The array expression must be a variable or a tuple-projection of a
     variable; anything else raises NotImplementedError.  The generated code
     runs `std::find` over the array into an iterator local, then declares an
     int local that is -1 when the iterator equals `end()` and
     `it - begin()` otherwise.  Returns the name of that int local.
     """
     arr = e.a
     supported = isinstance(arr, EVar) or (
         isinstance(arr, ETupleGet) and isinstance(arr.e, EVar))
     if not supported:
         # TODO: make this fast when this is false
         raise NotImplementedError("finding index of non-var array")

     iter_type = TNative("{}::const_iterator".format(
         self.visit(arr.type, "").strip()))
     cursor = self.fv(iter_type, "cursor")
     index = self.fv(INT, "index")

     find_call = EEscape("std::find({a}.begin(), {a}.end(), {x})",
                         ("a", "x"), (arr, e.x)).with_type(cursor.type)
     arr_end = EEscape("{a}.end()", ("a", ), (arr, )).with_type(cursor.type)
     offset = EEscape("({it} - {a}.begin())", ("it", "a"),
                      (cursor, arr)).with_type(INT)
     not_found = ENum(-1).with_type(INT)

     self.visit(seq([
         SDecl(cursor, find_call),
         SDecl(index,
               ECond(EEq(cursor, arr_end), not_found, offset).with_type(INT)),
     ]))
     return index.id
Esempio n. 4
0
    def _setup_handle_updates(self):
        """
        This method creates update code for handle objects modified by each op.
        Must be called once after all user-specified queries have been added.

        For every op in `self.op_specs` and every (handle type, bag) pair
        returned by `reachable_handles_at_method(self.spec, op)`, this builds:
          - an internal query returning the distinct handles in the bag whose
            `val` differs from the result of `inc.mutate(lval, op.body)`, and
          - a statement that updates each such handle's `val` in place.
        The resulting statement (possibly SNoOp) is stored in
        `self.handle_updates[(t, op.name)]`.  Subqueries produced along the way
        are registered through `self._add_subquery`.
        """
        for op in self.op_specs:
            print("Setting up handle updates for {}...".format(op.name))
            handles = reachable_handles_at_method(self.spec, op)
            # print("-"*60)
            for t, bag in handles.items():
                # print("  {} : {}".format(pprint(t), pprint(bag)))
                # `h` is a fresh variable of handle type `t`; `lval` is `h.val`.
                h = fresh_var(t)
                lval = EGetField(h, "val").with_type(t.value_type)
                # Presumably the value of `h.val` after `op.body` runs
                # (result of inc.mutate) -- TODO confirm against inc.mutate.
                new_val = inc.mutate(lval, op.body)

                # get set of modified handles: distinct handles in `bag` for
                # which `lval != new_val`, under the op's assumptions.
                modified_handles = Query(
                    fresh_name("modified_handles"), Visibility.Internal, [],
                    op.assumptions,
                    EFilter(
                        EUnaryOp(UOp.Distinct, bag).with_type(bag.type),
                        ELambda(h, ENot(EEq(lval,
                                            new_val)))).with_type(bag.type),
                    "[{}] modified handles of type {}".format(
                        op.name, pprint(t)))
                # The query's arguments are its free variables that are not
                # part of the abstract state.
                query_vars = [
                    v for v in free_vars(modified_handles)
                    if v not in self.abstract_state
                ]
                modified_handles.args = [(arg.id, arg.type)
                                         for arg in query_vars]

                # modify each one: build an in-place update of `h.val`,
                # assuming `h` is (deeply) in `bag` and is one of the
                # modified handles.
                subqueries = []
                state_update_stm = inc.mutate_in_place(
                    lval,
                    lval,
                    op.body,
                    abstract_state=self.abstract_state,
                    assumptions=list(op.assumptions) +
                    [EDeepIn(h, bag),
                     EIn(h, modified_handles.ret)],
                    invariants=self.abstract_invariants,
                    subgoals_out=subqueries)
                # Register each subquery the update needs, tagging its
                # docstring with the op name.
                for sub_q in subqueries:
                    sub_q.docstring = "[{}] {}".format(op.name,
                                                       sub_q.docstring)
                    state_update_stm = self._add_subquery(
                        sub_q=sub_q, used_by=state_update_stm)
                # Only when there is real work: loop the update over the
                # result of calling `modified_handles`, and register that
                # query as a subquery too.
                if state_update_stm != SNoOp():
                    state_update_stm = SForEach(
                        h,
                        ECall(modified_handles.name,
                              query_vars).with_type(bag.type),
                        state_update_stm)
                    state_update_stm = self._add_subquery(
                        sub_q=modified_handles, used_by=state_update_stm)
                self.handle_updates[(t, op.name)] = state_update_stm
Esempio n. 5
0
def queries_equivalent(q1: Query, q2: Query):
    """Determine whether two queries always return the same result.

    Two queries are equivalent when they have the same return type, the same
    (argument name -> type) map, semantically equivalent preconditions, and
    equal return values whenever those preconditions hold.

    Requiring equivalent preconditions is necessary.  A query means "if my
    preconditions hold then I compute and return my body".  If only one
    query's preconditions held in some state, that query would be obligated
    to return a value there while the other has no defined behavior.  The
    two would then not be interchangeable, even if their return values agree
    wherever both are defined.
    """
    if q1.ret.type != q2.ret.type:
        return False
    if dict(q1.args) != dict(q2.args):
        return False
    q1a = EAll(q1.assumptions)
    q2a = EAll(q2.assumptions)
    # Once the preconditions are known to be equivalent, checking the return
    # values under q1a alone suffices.
    return valid(EEq(q1a, q2a)) and valid(
        EImplies(q1a, EEq(q1.ret, q2.ret)))
Esempio n. 6
0
def find_refinement(ast, state_map, lib, assumptions):
    """Find the first state variable of `ast` that has a refinement.

    The given `assumptions`, `ast.assumptions`, and one equality per
    `state_map` entry (binding each variable to its expression) are combined
    into a single conjunction.  Each state variable in `ast.statevars` is
    then passed to `lib.impls` under that conjunction.

    Returns `(name, impls)` for the first variable whose implementations are
    anything other than exactly `[its own type]`, or None if no such variable
    exists.
    """
    state_eqs = (EEq(EVar(var_name).with_type(exp.type), exp)
                 for (var_name, exp) in state_map.items())
    combined = EAll(itertools.chain(assumptions, ast.assumptions, state_eqs))
    for (var_name, var_type) in ast.statevars:
        candidates = list(
            lib.impls(EVar(var_name).with_type(var_type), assumptions=combined))
        if len(candidates) == 1 and candidates[0] == var_type:
            # The only implementation is the type itself: nothing to refine.
            continue
        return (var_name, candidates)
    return None
Esempio n. 7
0
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Return True if the value of `spec` does NOT depend on any of `vs`.

    A variable can appear syntactically in an expression like `spec` without
    affecting its value.  This function uses the solver to decide whether the
    variables in `vs` can be eliminated.

    Method: each variable in `vs` is renamed to a fresh variable of the same
    type, in both `spec` and `assumptions`.  The function then checks that
    whenever both the original and the renamed assumptions hold, `spec` equals
    its renamed copy.  If that is valid, the values of `vs` cannot influence
    `spec`, and the result is True.  A False result means some execution of
    `spec` may depend on `vs`.

    EStateVar wrappers are stripped from `spec` before the check.
    """
    spec = strip_EStateVar(spec)
    sub = {v.id: fresh_var(v.type) for v in vs}
    return valid(
        EImplies(EAll([assumptions, subst(assumptions, sub)]),
                 EEq(spec, subst(spec, sub))))
Esempio n. 8
0
    def set_impl(self, q: Query, rep: [(EVar, Exp)], ret: Exp):
        """Install a new concrete implementation for the query named `q.name`.

        `rep` lists new concrete state variables paired with the expressions
        (over abstract state) they represent.  `ret` computes the query
        result from that state.

        Any new variable whose expression the state solver proves equal to an
        already-registered concretization function is dropped.  `ret` is
        rewritten to use the existing variable instead.  The remaining pairs
        are registered, and the query implementation becomes `q` with its
        return replaced by `ret` (assumptions not kept).

        Then, for every update operation, an in-place maintenance statement is
        built for each new variable and stored in `self.updates`.  Any
        subqueries those statements need are registered along the way.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                redundant = set()
                for (new_var, concretization) in rep:
                    existing = find_one(
                        old_var
                        for (old_var, old_exp) in self._concretization_functions
                        if concretization.type == old_exp.type
                        and self.state_solver.valid(EEq(concretization, old_exp)))
                    if existing is None:
                        continue
                    event("state var {} is equivalent to {}".format(
                        new_var.id, existing.id))
                    ret = subst(ret, {new_var.id: existing})
                    redundant.add(new_var)
                rep = [pair for pair in rep if pair[0] not in redundant]

            self._concretization_functions.extend(rep)
            self.query_impls[q.name] = rewrite_ret(
                q, lambda prev: ret, keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for (member, proj) in rep:
                        goals = []
                        stm = inc.mutate_in_place(
                            member,
                            proj,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            invariants=self.abstract_invariants,
                            subgoals_out=goals)
                        for goal in goals:
                            goal.docstring = "[{}] {}".format(
                                op.name, goal.docstring)
                            stm = self._add_subquery(sub_q=goal, used_by=stm)
                        self.updates[(member, op.name)] = stm
Esempio n. 9
0
 def visit_EGetField(self, e):
     """Generate code reading field `e.field_name` of `e.e`.

     For ordinary values this is `(<obj>.<field>)`.  For handle-typed
     objects (pointers), the generated code reads `->val` behind a null
     check: a null handle yields the default value of `e.type` (via
     `evaluation.construct_value`) instead of dereferencing.
     """
     ee = self.visit(e.e)
     if isinstance(e.e.type, THandle):
         # Ugh, we really need Cozy to know about partial functions...
         # Cozy doesn't know that handle types (aka pointers) can be null.
         # It assumes that reads of null pointers produce default-
         # constructed values, so we need to generate appropriate code.
         ee = EEscape(ee, (), ()).with_type(e.e.type)
         null = ENull().with_type(e.e.type)
         return self.visit(
             ECond(EEq(ee, null), evaluation.construct_value(e.type),
                   EEscape("{ee}->val", ("ee", ),
                           (ee, )).with_type(e.type)).with_type(e.type))
     return "({ee}.{f})".format(ee=ee, f=e.field_name)
Esempio n. 10
0
 def visit_EMapGet(self, e):
     """Generate Java code for a map lookup `e.map[e.key]`.

     Three cases:
       - Trove map with "Object" values whose default is not null: the value
         is fetched into a fresh local and a null result is replaced by the
         value type's default.
       - Any other Trove map: a plain `.get(key)` (defaults are set when the
         map is constructed).
       - Non-Trove map: `.getOrDefault(key, <default of e.type>)`.
     """
     trove = self.use_trove(e.map.type)
     null_checked = False
     if trove:
         vt = e.map.type.v
         null_checked = (
             self.trovename(vt) == "Object"
             and not isinstance(evaluation.construct_value(vt), ENull))

     emap = self.visit(e.map)
     ekey = self.visit(e.key)

     if null_checked:
         # A missing key reads back as null here, so substitute the value
         # type's default after the lookup.
         v = self.fv(self.box_if_boolean(e.map.type.v), hint="v")
         lookup = EEscape("{emap}.get({ekey})".format(emap=emap, ekey=ekey),
                          [], []).with_type(e.type)
         self.visit(SDecl(v, lookup))
         is_null = EEq(v, ENull().with_type(v.type))
         fallback = evaluation.construct_value(e.map.type.v)
         return self.visit(ECond(is_null, fallback, v).with_type(e.type))

     if trove:
         # For Trove, defaults are set at construction time
         return "{emap}.get({ekey})".format(emap=emap, ekey=ekey)

     edefault = self.visit(evaluation.construct_value(e.type))
     return "{emap}.getOrDefault({ekey}, {edefault})".format(
         emap=emap, ekey=ekey, edefault=edefault)
Esempio n. 11
0
 def visit_EBinOp(self, e):
     """Generate code for a binary operation.

     `==` is delegated to `self._eq` and `!=` is emitted as a negated
     equality.  `or` is rewritten to `e1 ? true : e2` and `and` to
     `e1 ? e2 : false` (as ECond expressions).  `in` is supported only when
     the right operand is a set; otherwise an Exception is raised.  Deep
     equality `===` is printed as ordinary `==`, and every remaining operator
     is printed infix as "(lhs op rhs)".
     """
     op = e.op
     lhs, rhs = e.e1, e.e2

     if op == "==":
         return self._eq(lhs, rhs)
     if op == "===":
         # rewrite deep-equality test into regular equality
         op = "=="
     elif op == "!=":
         return self.visit(ENot(EEq(lhs, rhs)))
     elif op == BOp.Or:
         return self.visit(ECond(lhs, ETRUE, rhs).with_type(BOOL))
     elif op == BOp.And:
         return self.visit(ECond(lhs, rhs, EFALSE).with_type(BOOL))
     elif op == BOp.In:
         if not isinstance(rhs.type, TSet):
             raise Exception(
                 "{!r} operator is supposed to be handled by simplify_and_optimize"
                 .format(op))
         return self.test_set_containment_native(rhs, lhs)

     left_code = self.visit(lhs)
     right_code = self.visit(rhs)
     return "({e1} {op} {e2})".format(e1=left_code, op=op, e2=right_code)
Esempio n. 12
0
    def define_type(self, toplevel_name, t, name, sharing):
        """Emit a C++ definition named `name` for type `t`.

        Supported types:
          - TEnum: an `enum` with one enumerator per case.
          - THandle: a `struct` with a single public `val` field.
          - TRecord: a `struct` with one field per record field, a default
            constructor, a field-initializing constructor (only when the
            record has at least one field), and a field-wise `operator==`.
          - TTuple: emitted as a record with fields `_0`, `_1`, ...
        Any other type emits nothing and returns "".

        `toplevel_name` and `sharing` are only forwarded by the TTuple
        recursion.  No other branch reads them.
        """
        if isinstance(t, TEnum):
            self.begin_statement()
            self.write("enum ", name, " ")
            with self.block():
                for case in t.cases:
                    self.begin_statement()
                    self.write(case, ",")
                    self.end_statement()
            self.write(";")
            self.end_statement()
        elif isinstance(t, THandle):
            fields = [("val", t.value_type)]
            self.begin_statement()
            self.write("struct ", name, " ")
            with self.block():
                with self.deindented():
                    self.write_stmt("public:")
                for (f, ft) in fields:
                    self.declare_field(f, ft)
                with self.deindented():
                    self.write_stmt("private:")
            self.write(";")
            self.end_statement()
            # TODO: the private section is where intrusive data for the
            # handle's sharing groups (`sharing.get(t, [])`, one union per
            # group) and a `friend class <toplevel_name>;` declaration were
            # intended to go; neither is emitted yet.
        elif isinstance(t, TRecord):
            self.begin_statement()
            self.write("struct ", name, " ")
            with self.block():

                # TODO: sort fields by size descending for better packing
                for f, ft in t.fields:
                    self.declare_field(f, ft)

                self.write_stmt("inline ", name, "() { }")

                # A record with no fields already has its only constructor
                # above.  Emitting this one too would produce a duplicate
                # constructor with an empty (ill-formed) initializer list.
                if t.fields:
                    self.begin_statement()
                    self.write("inline ", name, "(")
                    self.visit_args([("_" + f, ft) for (f, ft) in t.fields])
                    self.write(") : ")
                    for i, (f, ft) in enumerate(t.fields):
                        if i > 0:
                            self.write(", ")
                        self.write(f, "(::std::move(_", f, "))")
                    self.write(" { }")
                    self.end_statement()

                self.begin_statement()
                self.write("inline bool operator==(const ", name,
                           "& other) const ")
                with self.block():
                    # Field-wise equality between *this and `other`.
                    this = EEscape("(*this)", (), ()).with_type(t)
                    other = EVar("other").with_type(t)
                    r = self.visit(
                        EAll([
                            EEq(
                                EGetField(this, f).with_type(ft),
                                EGetField(other, f).with_type(ft))
                            for f, ft in t.fields
                        ]))
                    self.begin_statement()
                    self.write("return ", r, ";")
                    self.end_statement()
                self.end_statement()
            self.write(";")
            self.end_statement()
        elif isinstance(t, TTuple):
            return self.define_type(
                toplevel_name,
                TRecord(
                    tuple(
                        ("_{}".format(i), t.ts[i]) for i in range(len(t.ts)))),
                name, sharing)
        else:
            return ""
Esempio n. 13
0
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Return True if `spec`'s value is independent of every variable in `vs`.

    Each variable in `vs` is renamed to a fresh variable of the same type.
    The result is True when, assuming both the original and the renamed
    `assumptions`, `spec` (with EStateVar wrappers stripped) always equals
    its renamed copy.
    """
    spec = strip_EStateVar(spec)
    renaming = {}
    for var in vs:
        renaming[var.id] = fresh_var(var.type)
    both_hold = EAll([assumptions, subst(assumptions, renaming)])
    unchanged = EEq(spec, subst(spec, renaming))
    return valid(EImplies(both_hold, unchanged))
Esempio n. 14
0
    def define_type(self, toplevel_name, t, name, sharing):
        """Emit a Java definition named `name` for type `t`.

        - TEnum: a `public enum` with one constant per case.
        - THandle / TRecord: a `public static class` implementing
          `java.io.Serializable`, with private fields, public getters for the
          public fields, and a constructor.
            * A handle whose value type is a record extends that record's
              class and exposes `getVal()` returning `this`.  Other handles
              get a single public `val` field.
            * Handles also get private fields for the intrusive data of each
              sharing group in `sharing.get(t, [])`.
            * Records (not handles) additionally get value-based `hashCode`
              and `equals`.
        - TTuple: emitted as a record with fields `_0`, `_1`, ...
        Any other type emits nothing and returns "".
        """
        if isinstance(t, TEnum):
            self.begin_statement()
            self.write("public enum ", name, " ")
            with self.block():
                for case in t.cases:
                    self.begin_statement()
                    # Java enum constants must be comma-separated; a trailing
                    # comma after the last constant is permitted.
                    self.write(case, ",")
                    self.end_statement()
            self.end_statement()
        elif isinstance(t, THandle) or isinstance(t, TRecord):
            public_fields = []
            private_fields = []
            value_equality = True
            handle_val_is_this = False
            if isinstance(t, THandle):
                if isinstance(t.value_type, TRecord):
                    handle_val_is_this = True
                else:
                    public_fields = [("val", t.value_type)]
                value_equality = False
                for group in sharing.get(t, []):
                    for gt in group:
                        intrusive_data = gt.intrusive_data(t)
                        for (f, ft) in intrusive_data:
                            private_fields.append((f, ft))
            else:
                public_fields = list(t.fields)
            all_fields = public_fields + private_fields
            self.begin_statement()
            self.write("public static class ", name)
            if handle_val_is_this:
                self.write(" extends ", self.visit(t.value_type, ""))
            self.write(" implements java.io.Serializable ")
            with self.block():
                for (f, ft) in public_fields + private_fields:
                    self.begin_statement()
                    self.write("private {field_decl};".format(field_decl=self.visit(ft, f)))
                    self.end_statement()
                for (f, ft) in public_fields:
                    self.begin_statement()
                    self.write("public {type} get{Field}() {{ return {field}; }}".format(
                        type=self.visit(ft, ""),
                        Field=common.capitalize(f),
                        field=f))
                    self.end_statement()
                if handle_val_is_this:
                    self.begin_statement()
                    self.write("public {type} getVal() {{ return this; }}".format(
                        type=self.visit(t.value_type, "")))
                    self.end_statement()

                def flatten(field_types):
                    """Flatten nested record/tuple types into scalar ctor args.

                    Returns (args, exps): `args` is a list of (name, type)
                    pairs for fresh scalar parameters, and `exps` holds one
                    expression per input type that rebuilds that value from
                    those parameters.
                    """
                    args = []
                    exps = []
                    for ft in field_types:
                        if isinstance(ft, TRecord):
                            aa, ee = flatten([fft for (_, fft) in ft.fields])
                            args.extend(aa)
                            exps.append(EMakeRecord(tuple((f, e) for ((f, _), e) in zip(ft.fields, ee))).with_type(ft))
                        elif isinstance(ft, TTuple):
                            aa, ee = flatten(ft.ts)
                            args.extend(aa)
                            exps.append(ETuple(tuple(ee)).with_type(ft))
                        else:
                            v = self.fv(ft, "v")
                            args.append((v.id, ft))
                            exps.append(v)
                    return args, exps

                if isinstance(t, THandle):
                    args, exps = flatten([ft for (f, ft) in (t.value_type.fields if handle_val_is_this else public_fields)])
                else:
                    args = public_fields
                    exps = [EVar(f) for (f, ft) in args]
                self.begin_statement()
                self.write("public {ctor}({args}) ".format(ctor=name, args=", ".join(self.visit(ft, f) for (f, ft) in args)))
                with self.block():
                    if handle_val_is_this:
                        es = [self.visit(e) for e in exps]
                        self.begin_statement()
                        self.write("super({args});\n".format(
                            args=", ".join(es)))
                    for ((f, ft), e) in zip(public_fields, exps):
                        e = self.visit(e)
                        self.begin_statement()
                        self.write("this.{f} = {e};\n".format(f=f, e=e))
                self.end_statement()

                if value_equality:
                    self.begin_statement()
                    self.write("@Override")
                    self.end_statement()
                    self.begin_statement()
                    self.write("public int hashCode() ")
                    with self.block():
                        (compute, hc) = self.compute_hash(public_fields + private_fields)
                        self.write(compute)
                        self.begin_statement()
                        self.write("return ", hc, ";")
                        self.end_statement()
                    self.end_statement()

                    self.begin_statement()
                    self.write("@Override")
                    self.end_statement()
                    self.begin_statement()
                    self.write("public boolean equals(Object other) ")
                    with self.block():
                        self.write(self.get_indent(), "if (other == null) return false;\n")
                        self.write(self.get_indent(), "if (other == this) return true;\n")
                        self.write(self.get_indent(), "if (!(other instanceof {name})) return false;\n".format(name=name))
                        self.write(self.get_indent(), "{name} o = ({name})other;\n".format(name=name))
                        eq = self.visit(EAll([EEq(
                            EEscape("this.{}".format(f), (), ()).with_type(ft),
                            EEscape("o.{}".format(f),    (), ()).with_type(ft))
                            for (f, ft) in all_fields]))
                        self.write(self.get_indent(), "return ", eq, ";\n")
                    self.end_statement()
            self.end_statement()
        elif isinstance(t, TTuple):
            return self.define_type(toplevel_name, TRecord(tuple(("_{}".format(i), t.ts[i]) for i in range(len(t.ts)))), name, sharing)
        else:
            return ""
Esempio n. 15
0
def improve(target: Exp,
            context: Context,
            assumptions: Exp = T,
            stop_callback=never_stop,
            hints: [Exp] = (),
            examples: [{
                str: object
            }] = (),
            cost_model: CostModel = None,
            ops: [Op] = ()):
    """
    Improve the target expression using enumerative synthesis.
    This function is a generator that yields increasingly better and better
    versions of the input expression `target`.

    Parameters:
        target        : the expression to improve.
        context       : variables and functions in scope.  Variables in
                        STATE_POOL are treated as state variables.
        assumptions   : facts under which every candidate must equal
                        `target`.  The solver and the default cost model
                        are built with them.
        stop_callback : passed through to the Learner.
        hints         : extra expressions for the Learner.  The conjuncts of
                        `assumptions` and `target` itself are added as hints.
        examples      : initial examples.  Solver counterexamples are added
                        as the search proceeds.
        cost_model    : defaults to
                        CostModel(funcs=context.funcs(), assumptions=...).
        ops           : passed through to the Learner.

    Yields each verified candidate that evicts the current target from the
    frontier.  If `assumptions` are unsatisfiable, yields
    `construct_value(target.type)` once and stops.  Raises Exception("got a
    duplicate example") if the solver returns a counterexample already in
    `examples`.

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, and lets us
          find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r},
        ops={ops!r})""".format(target=target,
                               context=context,
                               assumptions=assumptions,
                               stop_callback=stop_callback,
                               hints=hints,
                               examples=examples,
                               cost_model=cost_model,
                               ops=ops))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    if not is_wf:
        print(
            "WARNING: initial target is not well-formed [{}]; this might go poorly..."
            .format(is_wf))
        print(pprint(is_wf.offending_subexpression))
        print(pprint(is_wf.offending_subexpression.type))

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    if not solver.satisfiable(T):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    learner = Learner(watched_targets,
                      solver,
                      context,
                      examples,
                      cost_model,
                      stop_callback,
                      hints,
                      ops=ops)

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in learner.search():
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, reset the learner
                examples.append(counterexample)
                event("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(
                    len(examples)))
                learner.reset(examples)
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    # Compare the new candidate against every watched target.
                    # Targets the eviction policy drops are evicted.  If the
                    # new candidate itself is dropped, something at least as
                    # good is already watched.
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = eviction_policy(new_target, context, old_target,
                                              context, RUNTIME_POOL,
                                              cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print(
                            "Whoops! Looks like we already found something better."
                        )
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                learner.watch(watched_targets)
                break
Esempio n. 16
0
    def _enumerate_core(self, context: Context, size: int,
                        pool: Pool) -> [Exp]:
        """Build new expressions of exactly the given size.

        Arguments:
            context : a Context object describing the vars in scope
            size    : size of expressions to enumerate; each expression in
                      the output will have this size
            pool    : pool to enumerate (e.g. STATE_POOL or RUNTIME_POOL)

        Yields:
            Exp objects of size `size` for `context` and `pool`.  Nothing is
            yielded when `size` is negative, or when `size` is positive and
            the `do_enumerate` option is disabled.

        Size accounting: every operator node costs 1, so the sizes of an
        operator's children sum to `size - 1`.  For operators that take a
        lambda, the lambda's body counts toward that sum; the bound variable
        does not.

        This function is not cached.  Clients should call `enumerate` instead.
        Smaller sub-expressions are obtained here through `self.enumerate`.

        This function tries to be a clean description of the Cozy grammar.  It
        does not concern itself with deduplication (which is handled
        efficiently by equivalence class deduplication), so the same
        expression may be yielded more than once.  Callers see expressions in
        exactly the order they are yielded below, so keep that in mind when
        reordering the grammar.
        """

        if size < 0:
            return

        if size == 0:
            # Size-0 expressions are leaves: literals, constructed values of
            # interesting types, variables in scope, and hints.
            for e in LITERALS:
                yield e

            # Gather self.hint_types plus every type reachable (via all_types)
            # from the types of in-scope variables, and yield the value that
            # construct_value builds for each one, skipping any value already
            # emitted as one of the LITERALS.
            all_interesting_types = OrderedSet(self.hint_types)
            for v, _ in context.vars():
                all_interesting_types |= all_types(v.type)
            for t in all_interesting_types:
                l = construct_value(t)
                if l not in LITERALS:
                    yield l

            # Variables that live in the requested pool.
            for (v, p) in context.vars():
                if p == pool:
                    yield v
            # Hints registered for this pool whose context is alpha-equivalent
            # to this context (generalized to the hint's free variables) are
            # adapted to this context and yielded.
            for (e, ctx, p) in self.hints:
                if p == pool:
                    fvs = free_vars(e)
                    if ctx.alpha_equivalent(context.generalize(fvs)):
                        yield context.adapt(e, ctx, e_fvs=fvs)
            return

        # Everything above size 0 is gated on the `do_enumerate` option; when
        # it is off, this function only ever produces size-0 expressions.
        if not do_enumerate.value:
            return

        def build_lambdas(bag, pool, body_size):
            # Yield ELambda(v, body) for every body of size `body_size`, where
            # `v` is a fresh variable of the bag's element type (its name is
            # chosen to avoid every in-scope variable) bound over `bag`.
            # Note: this `pool` parameter shadows the enclosing one.
            v = fresh_var(bag.type.elem_type,
                          omit=set(v for v, p in context.vars()))
            inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
            for lam_body in self.enumerate(inner_context, body_size, pool):
                yield ELambda(v, lam_body)

        # Load all smaller expressions in this context and pool.
        # cache[S] contains expressions of size S in this context and pool.
        cache = [list(self.enumerate(context, sz, pool)) for sz in range(size)]

        # Enable use of a state-pool expression at runtime: every state-pool
        # expression of size (size - 1), enumerated in the root context, is
        # wrapped in an EStateVar node (the wrapper itself costs 1).
        if pool == RUNTIME_POOL:
            for e in self.enumerate(context.root(), size - 1, STATE_POOL):
                yield EStateVar(e).with_type(e.type)

        # Arity-1 expressions: a single child of size (size - 1).
        for e in cache[size - 1]:
            if is_collection(e.type):
                elem_type = e.type.elem_type

                # This method of generating EEmptyList() ensures that we visit
                # empty collections of all possible types.
                yield EEmptyList().with_type(e.type)

                if is_numeric(elem_type):
                    yield EUnaryOp(UOp.Sum, e).with_type(elem_type)

                yield EUnaryOp(UOp.Length, e).with_type(INT)
                yield EUnaryOp(UOp.Empty, e).with_type(BOOL)
                yield EUnaryOp(UOp.Exists, e).with_type(BOOL)
                yield EUnaryOp(UOp.The, e).with_type(elem_type)
                yield EUnaryOp(UOp.Distinct, e).with_type(e.type)
                yield EUnaryOp(UOp.AreUnique, e).with_type(BOOL)

                # Quantifiers only make sense over collections of booleans.
                if elem_type == BOOL:
                    yield EUnaryOp(UOp.Any, e).with_type(BOOL)
                    yield EUnaryOp(UOp.All, e).with_type(BOOL)

            # Any expression can be wrapped in a singleton bag.
            yield ESingleton(e).with_type(TBag(e.type))

            # One field access per record field.
            if isinstance(e.type, TRecord):
                for (f, t) in e.type.fields:
                    yield EGetField(e, f).with_type(t)

            # A handle's payload is read through its "val" field.
            if isinstance(e.type, THandle):
                yield EGetField(e, "val").with_type(e.type.value_type)

            # One projection per tuple component.
            if isinstance(e.type, TTuple):
                for n in range(len(e.type.ts)):
                    yield ETupleGet(e, n).with_type(e.type.ts[n])

            # Boolean negation.
            if e.type == BOOL:
                yield EUnaryOp(UOp.Not, e).with_type(BOOL)

            # Arithmetic negation.
            if is_numeric(e.type):
                yield EUnaryOp("-", e).with_type(e.type)

            # The keys of a map, as a bag.
            if isinstance(e.type, TMap):
                yield EMapKeys(e).with_type(TBag(e.type.k))

        # Arity-2 expressions
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            # sz1 + sz2 = size - 1
            for e1 in cache[sz1]:
                t = e1.type

                # Arithmetic: both operands and the result share type `t`.
                if is_numeric(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", a2).with_type(t)
                        yield EBinOp(e1, "-", a2).with_type(t)

                # Comparisons between two values of the same ordered type.
                if is_ordered(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, ">", a2).with_type(BOOL)
                        yield EBinOp(e1, "<", a2).with_type(BOOL)
                        yield EBinOp(e1, ">=", a2).with_type(BOOL)
                        yield EBinOp(e1, "<=", a2).with_type(BOOL)

                if t == BOOL:
                    for a2 in of_type(cache[sz2], BOOL):
                        yield EBinOp(e1, BOp.And, a2).with_type(BOOL)
                        yield EBinOp(e1, BOp.Or, a2).with_type(BOOL)
                        # Cozy supports the implication operator "=>", but this
                        # function does not enumerate it because
                        #  - (a => b) is equivalent to ((not a) or b)
                        #  - there isn't an implication operator in any of our
                        #    current target languages, so we would need to
                        #    desugar it to ((not a) or b) anyway.

                # Equality and inequality are enumerated for every type except
                # maps.
                if not isinstance(t, TMap):
                    for a2 in of_type(cache[sz2], t):
                        yield EEq(e1, a2)
                        yield EBinOp(e1, "!=", a2).with_type(BOOL)

                # Map lookup and key-membership, keyed by expressions of the
                # map's key type.
                if isinstance(t, TMap):
                    for k in of_type(cache[sz2], t.k):
                        yield EMapGet(e1, k).with_type(t.v)
                        yield EHasKey(e1, k).with_type(BOOL)

                # List indexing with an INT index.
                if isinstance(t, TList):
                    for i in of_type(cache[sz2], INT):
                        yield EListGet(e1, i).with_type(e1.type.elem_type)

                # Collection operators: "+"/"-" with a same-typed collection,
                # element membership, and the higher-order operators whose
                # lambda body has size sz2.
                if is_collection(t):
                    elem_type = t.elem_type
                    for e2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", e2).with_type(t)
                        yield EBinOp(e1, "-", e2).with_type(t)
                    for e2 in of_type(cache[sz2], elem_type):
                        yield EBinOp(e2, BOp.In, e1).with_type(BOOL)
                    for f in build_lambdas(e1, pool, sz2):
                        body_type = f.body.type
                        yield EMap(e1, f).with_type(TBag(body_type))
                        if body_type == BOOL:
                            yield EFilter(e1, f).with_type(t)
                        if is_numeric(body_type):
                            yield EArgMin(e1, f).with_type(elem_type)
                            yield EArgMax(e1, f).with_type(elem_type)
                        if is_collection(body_type):
                            yield EFlatMap(e1, f).with_type(
                                TBag(body_type.elem_type))

                        # EMakeMap2 is only enumerated in the state pool, and
                        # only when the element type (which becomes the map's
                        # key type) is hashable.
                        if pool == STATE_POOL and is_hashable(elem_type):
                            yield EMakeMap2(e1, f).with_type(
                                TMap(elem_type, body_type))

                # Let-binding: wrapping e1 in a singleton bag lets
                # build_lambdas bind a fresh variable of e1's type.  Each
                # resulting lambda (body of size sz2) becomes ELet(e1, f).
                e1_singleton = ESingleton(e1).with_type(TBag(e1.type))
                for f in build_lambdas(e1_singleton, pool, sz2):
                    yield ELet(e1, f).with_type(f.body.type)

        # Arity-3 expressions
        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            # sz1 + sz2 + sz3 = size - 1
            for e1 in cache[sz1]:
                # Conditionals: a BOOL condition and two branches of the same
                # type.
                if e1.type == BOOL:
                    cond = e1
                    for then_branch in cache[sz2]:
                        for else_branch in of_type(cache[sz3],
                                                   then_branch.type):
                            yield ECond(cond, then_branch,
                                        else_branch).with_type(
                                            then_branch.type)
                # List slices with INT start and end bounds.
                if isinstance(e1.type, TList):
                    for start in of_type(cache[sz2], INT):
                        for end in of_type(cache[sz3], INT):
                            yield EListSlice(e1, start, end).with_type(e1.type)
                            # It is not necessary to create slice expressions of
                            # the form a[:i] or a[i:].  Those are desugared
                            # after parsing to a[0:i] and a[i:len(a)]
                            # respectively, and Cozy is perfectly capable of
                            # discovering these expanded forms as well.

        # Each registered extension handler contributes its own expressions of
        # this size.  Handlers are given `self.enumerate` and `build_lambdas`
        # to build their sub-expressions.
        for h in all_extension_handlers():
            yield from h.enumerate(context, size, pool, self.enumerate,
                                   build_lambdas)
Esempio n. 17
0
def improve(target: Exp,
            context: Context,
            assumptions: Exp = ETRUE,
            stop_callback: Callable[[], bool] = never_stop,
            hints: [Exp] = (),
            examples: [{
                str: object
            }] = (),
            cost_model: CostModel = None,
            ops: [Op] = (),
            improve_count: Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".  Apart from the unsatisfiable-assumptions
    case described below, every yielded expression has been checked by the
    solver to agree with the previous target on every input satisfying
    `assumptions`.  The generator stops once a yielded expression satisfies
    `heuristic_done`.  If the initial target already does, nothing is
    yielded.

    `stop_callback` is forwarded to `search_for_improvements`.  The search is
    expected to poll it and stop gracefully once it returns True.

    Other parameters:
        - assumptions: a precondition.  The yielded improvements will only be
          correct when the assumptions are true.  If the assumptions are
          unsatisfiable, the query can never be called.  In that case a
          single value built by `construct_value` for the target's type is
          yielded and the generator stops.
        - hints: expressions that might be useful.  These will be explored
          first when looking for improvements.  Each conjunct of
          `assumptions` and the target itself are appended to this list.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions.  This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide any.
          The caller's sequence is copied, never mutated.
        - cost_model: the CostModel used to compare candidates.  When None,
          `CostModel(funcs=context.funcs(), assumptions=assumptions)` is used.
        - ops: update operations.  This function may make different choices
          about what expressions are state expressions based on what changes
          can happen to that state.
        - improve_count: optional shared counter.  It is incremented (while
          holding `improve_count.get_lock()`) once per pass of the main
          search loop.  Presumably a `multiprocessing.Value` that lets
          another process monitor progress; confirm with callers.

    Raises:
        - AssertionError if the initial target is not well-formed or is not
          judged possibly useful.  These checks are `assert` statements, so
          they are skipped under `python -O`.
        - Exception("got a duplicate example") if the solver returns a
          counterexample that is already in `examples`.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, allowing us
          to find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.  This is more desirable than running for
          an indeterminate amount of time doing nothing.
    """

    # Debug dump of the call arguments (improve_count is not printed).
    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r},
        ops={ops!r})""".format(target=target,
                               context=context,
                               assumptions=assumptions,
                               stop_callback=stop_callback,
                               hints=hints,
                               examples=examples,
                               cost_model=cost_model,
                               ops=ops))

    # Normalize the inputs: inline let-bindings in the target, then freshen
    # binders in both target and assumptions relative to `context`.  This is
    # presumably so that bound names cannot clash with variables in scope.
    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    # Nothing to do if the target already passes the doneness heuristic.
    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    # NOTE(review): validation via `assert` is skipped under `python -O`.
    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    # Informational only: the situation is reported but not acted upon.
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # Hints are, in order: the caller's hints, each conjunct of the
    # assumptions (with bare state variables wrapped by
    # wrap_naked_statevars), and the target itself.  All binders are
    # freshened relative to `context`.
    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    # NOTE(review): `vars` shadows the builtin and is not used in the lines
    # visible here; consider removing it.
    vars = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    # A single solver, with the assumptions baked in, is used for every
    # check below.
    solver = solver_for_context(context, assumptions=assumptions)

    # Unsatisfiable assumptions mean the query can never run, so any value of
    # the right type is acceptable.
    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    # NOTE(review): also an `assert`; skipped under `python -O`.
    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(
        is_good, target)

    # Copy so that discovered counterexamples never mutate the caller's
    # sequence.
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    # Frontier of targets whose sub-expressions the search tries to improve.
    watched_targets = [target]
    # Shared across every search pass.  Presumably it records candidates the
    # search has already rejected; confirm in search_for_improvements.
    blacklist = {}

    # Each pass starts a fresh search over the current frontier.  A pass ends
    # early (via `break`) when a candidate produces a new counterexample or is
    # added to the frontier.  The only exit visible in this loop is the
    # `return` after a target matches heuristic_done.  Presumably the search
    # also raises once stop_callback fires; confirm.
    # NOTE(review): if a search pass is exhausted without a `break`, the next
    # pass reruns the search with unchanged inputs.
    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(targets=watched_targets,
                                                  wf_solver=solver,
                                                  context=context,
                                                  examples=examples,
                                                  cost_model=cost_model,
                                                  stop_callback=stop_callback,
                                                  hints=hints,
                                                  ops=ops,
                                                  blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            # Ask the solver for an input (satisfying the assumptions) on which
            # the candidate and the current target disagree.
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                # A counterexample that is already known means the search
                # proposed a candidate the existing examples should have
                # ruled out (presumably the search filters candidates against
                # `examples`).  Fail loudly instead of looping.
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(
                    len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    # Compare the new target against each watched target.
                    # `evc` is treated as the collection of expressions worth
                    # keeping:
                    #  - an old target missing from it is evicted;
                    #  - if the new target is missing from it, the new target
                    #    is discarded and the scan stops.
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target,
                                               context, RUNTIME_POOL,
                                               cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    # NOTE(review): targets queued before the `break` are
                    # still evicted even when the new target is discarded.
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print(
                            "Whoops! Looks like we already found something better."
                        )
                        print(" --> {}".format(pprint(old_better)))
                        # Keep consuming candidates from the same search pass.
                        continue
                    # Only yield when the new target beat the current target.
                    # Otherwise it is merely added to the frontier below.
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                # Watch the new target and restart the search from the
                # updated frontier.
                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        # Count this pass of the search loop.
        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1