def queries_equivalent(q1: Query, q2: Query, state_vars: [EVar], extern_funcs: {str: TFunc}, assumptions: Exp = T):
    """Decide whether two queries can be used interchangeably.

    Returns False immediately when the return types differ or when the
    (name -> type) argument mappings differ.  Otherwise a solver is asked
    whether

      1. the conjunctions of the two queries' assumptions are equal, and
      2. under q1's assumptions, the two return expressions are equal.

    Solvers are memoized in `_qe_cache`, keyed by the argument mapping and
    `assumptions`.

    NOTE(review): the cache key does not include `state_vars` or
    `extern_funcs`, although both are baked into the solver when it is first
    built -- presumably they never change between calls; confirm.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False

        arg_types = dict(q1.args)
        other_arg_types = dict(q2.args)
        if arg_types != other_arg_types:
            return False

        frozen_args = FrozenDict(arg_types)
        cache_key = (frozen_args, assumptions)
        solver = _qe_cache.get(cache_key)
        if solver is None:
            # First request for this (args, assumptions) pair: build a solver
            # over the state variables plus one variable per query argument.
            arg_vars = (EVar(name).with_type(typ) for name, typ in frozen_args.items())
            solver = ModelCachingSolver(
                vars=list(itertools.chain(state_vars, arg_vars)),
                funcs=extern_funcs,
                assumptions=assumptions)
            _qe_cache[cache_key] = solver

        pre1 = EAll(q1.assumptions)
        pre2 = EAll(q2.assumptions)
        same_preconditions = solver.valid(EEq(pre1, pre2))
        return same_preconditions and solver.valid(EImplies(pre1, EEq(q1.ret, q2.ret)))
def queries_equivalent(q1: Query, q2: Query, state_vars: [EVar], extern_funcs: {str: TFunc}, assumptions: Exp = ETRUE):
    """Determine whether two queries always return the same result.

    Besides comparing the return expressions, this also requires the two
    queries to have semantically equivalent preconditions.  A query means
    "if my preconditions hold, then I compute and return my body
    expression"; if the preconditions differed, there could be inputs on
    which one query is obligated to return a value while the other has no
    defined behavior.

    Returns False without consulting the solver when the return types or the
    (name -> type) argument mappings differ.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False

        args_of_q1 = dict(q1.args)
        args_of_q2 = dict(q2.args)
        if args_of_q1 != args_of_q2:
            return False

        # The solver context holds the state variables, q1's arguments (known
        # to match q2's at this point) and the external functions.
        arg_vars = [EVar(a).with_type(t) for (a, t) in q1.args]
        root = RootCtx(state_vars=state_vars, args=arg_vars, funcs=extern_funcs)
        checker = solver_for_context(context=root, assumptions=assumptions)

        pre1 = EAll(q1.assumptions)
        pre2 = EAll(q2.assumptions)
        same_preconditions = checker.valid(EEq(pre1, pre2))
        return same_preconditions and checker.valid(EImplies(pre1, EEq(q1.ret, q2.ret)))
def visit_EArrayIndexOf(self, e):
    """Emit code that finds the index of `e.x` in the array `e.a`.

    Only arrays that are a plain variable, or a tuple component of a plain
    variable, are supported; anything else raises NotImplementedError.

    Two declarations are emitted: an iterator holding the result of
    `std::find` over the array, and an int that is -1 when the iterator
    equals the array's end and the iterator's offset from `begin()`
    otherwise.  The name of that int variable is returned.
    """
    array_is_simple = isinstance(e.a, EVar) or (
        isinstance(e.a, ETupleGet) and isinstance(e.a.e, EVar))
    if not array_is_simple:
        # TODO: make this fast when this is false
        raise NotImplementedError("finding index of non-var array")

    iterator_type = TNative("{}::const_iterator".format(self.visit(e.a.type, "").strip()))
    it = self.fv(iterator_type, "cursor")
    res = self.fv(INT, "index")

    find_call = EEscape(
        "std::find({a}.begin(), {a}.end(), {x})",
        ("a", "x"),
        (e.a, e.x)).with_type(it.type)
    array_end = EEscape("{a}.end()", ("a", ), (e.a, )).with_type(it.type)
    not_found = ENum(-1).with_type(INT)
    offset = EEscape(
        "({it} - {a}.begin())",
        ("it", "a", ),
        (it, e.a, )).with_type(INT)
    index_value = ECond(EEq(it, array_end), not_found, offset).with_type(INT)

    self.visit(seq([
        SDecl(it, find_call),
        SDecl(res, index_value),
    ]))
    return res.id
def _setup_handle_updates(self):
    """Create update code for the handle objects modified by each op.

    For every op in `self.op_specs` and every (handle type, bag) pair
    reported by `reachable_handles_at_method`:

      1. build an internal "modified handles" query that filters the
         distinct handles in the bag down to those whose `val` differs
         after the op's body is applied;
      2. ask `inc.mutate_in_place` for a statement that updates a single
         handle's value, registering every subgoal it produces;
      3. unless that statement is a no-op, wrap it in a loop over the
         modified handles and register the "modified handles" query too.

    The resulting statement is stored in
    `self.handle_updates[(handle type, op name)]`.

    Must be called once after all user-specified queries have been added.
    """
    for op in self.op_specs:
        print("Setting up handle updates for {}...".format(op.name))
        reachable = reachable_handles_at_method(self.spec, op)
        for handle_type, bag in reachable.items():
            h = fresh_var(handle_type)
            lval = EGetField(h, "val").with_type(handle_type.value_type)
            new_val = inc.mutate(lval, op.body)

            # Query computing the set of handles whose value the op changes.
            query_name = fresh_name("modified_handles")
            distinct_handles = EUnaryOp(UOp.Distinct, bag).with_type(bag.type)
            value_changes = ELambda(h, ENot(EEq(lval, new_val)))
            modified_handles = Query(
                query_name,
                Visibility.Internal,
                [],
                op.assumptions,
                EFilter(distinct_handles, value_changes).with_type(bag.type),
                "[{}] modified handles of type {}".format(op.name, pprint(handle_type)))
            # The query's arguments are its free variables that are not part
            # of the abstract state.
            query_vars = [
                fv for fv in free_vars(modified_handles)
                if fv not in self.abstract_state
            ]
            modified_handles.args = [(qv.id, qv.type) for qv in query_vars]

            # Statement that updates one modified handle in place.
            subqueries = []
            update_stm = inc.mutate_in_place(
                lval,
                lval,
                op.body,
                abstract_state=self.abstract_state,
                assumptions=list(op.assumptions) + [EDeepIn(h, bag), EIn(h, modified_handles.ret)],
                invariants=self.abstract_invariants,
                subgoals_out=subqueries)
            for sub_q in subqueries:
                sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                update_stm = self._add_subquery(sub_q=sub_q, used_by=update_stm)

            if update_stm != SNoOp():
                modified = ECall(modified_handles.name, query_vars).with_type(bag.type)
                update_stm = SForEach(h, modified, update_stm)
                update_stm = self._add_subquery(sub_q=modified_handles, used_by=update_stm)

            self.handle_updates[(handle_type, op.name)] = update_stm
def queries_equivalent(q1: Query, q2: Query):
    """Check whether two queries compute the same result.

    Returns False when the return types or the (name -> type) argument
    mappings differ.  Otherwise returns the result of `valid` on the
    implication "EAny of the two queries' combined assumptions implies that
    the two return expressions are equal".
    """
    if q1.ret.type != q2.ret.type:
        return False

    first_args = dict(q1.args)
    second_args = dict(q2.args)
    if first_args != second_args:
        return False

    pre1 = EAll(q1.assumptions)
    pre2 = EAll(q2.assumptions)
    either_precondition = EAny([pre1, pre2])
    same_result = EEq(q1.ret, q2.ret)
    return valid(EImplies(either_precondition, same_result))
def find_refinement(ast, state_map, lib, assumptions):
    """Find the first state variable whose type still needs refinement.

    The given assumptions are combined with `ast.assumptions` and with one
    equality per entry of `state_map` (variable == mapped expression).  For
    each `(name, type)` in `ast.statevars`, `lib.impls` is asked for the
    implementations of that variable under the combined assumptions.

    Returns `(name, impls)` for the first variable whose implementation list
    is anything other than exactly `[type]`, or None when there is no such
    variable.
    """
    state_equalities = (
        EEq(EVar(name).with_type(exp.type), exp)
        for (name, exp) in state_map.items())
    assumptions = EAll(itertools.chain(assumptions, ast.assumptions, state_equalities))

    for (name, typ) in ast.statevars:
        candidates = list(lib.impls(EVar(name).with_type(typ), assumptions=assumptions))
        already_refined = len(candidates) == 1 and candidates[0] == typ
        if not already_refined:
            return (name, candidates)
    return None
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Can the variables `vs` be eliminated from `spec`?

    A variable may appear in an expression like `spec` without affecting its
    value.  This function uses the solver to find out: every variable in
    `vs` is replaced by a fresh variable of the same type, and the solver
    checks that, when the assumptions hold for both the original and the
    renamed variables, `spec` and its renamed copy are equal.

    Returns the result of that validity check, i.e. a true result means that
    no execution of `spec` depends on any of `vs`.
    """
    spec = strip_EStateVar(spec)

    renaming = {}
    for var in vs:
        renaming[var.id] = fresh_var(var.type)

    renamed_assumptions = subst(assumptions, renaming)
    premise = EAll([assumptions, renamed_assumptions])
    renamed_spec = subst(spec, renaming)
    return valid(EImplies(premise, EEq(spec, renamed_spec)))
def set_impl(self, q: Query, rep: [(EVar, Exp)], ret: Exp):
    """Update the implementation of a query.

    The query having the same name as `q` gets its implementation replaced
    by the concrete representation `rep` (pairs of new state variable and
    the expression it stores) and the computation `ret`.

    Representation members that the state solver proves equal to an already
    registered concretization function are not added again; `ret` is
    rewritten to use the existing variable instead.

    This call may add additional "subqueries" to the implementation in order
    to maintain the new representation when each update operation is called.
    """
    with task("updating implementation", query=q.name):
        with task("finding duplicated state vars"):
            duplicates = set()
            for (new_var, new_exp) in rep:
                # First existing state var with the same type whose
                # definition is provably equal to the new one, if any.
                existing = find_one(
                    old_var
                    for (old_var, old_exp) in self._concretization_functions
                    if new_exp.type == old_exp.type
                    and self.state_solver.valid(EEq(new_exp, old_exp)))
                if existing is None:
                    continue
                event("state var {} is equivalent to {}".format(new_var.id, existing.id))
                ret = subst(ret, {new_var.id: existing})
                duplicates.add(new_var)
            rep = [member for member in rep if member[0] not in duplicates]

        self._concretization_functions.extend(rep)
        self.query_impls[q.name] = rewrite_ret(q, lambda prev: ret, keep_assumptions=False)

        for op in self.op_specs:
            with task("incrementalizing query", query=q.name, op=op.name):
                for new_member, projection in rep:
                    subqueries = []
                    update_stm = inc.mutate_in_place(
                        new_member,
                        projection,
                        op.body,
                        abstract_state=self.abstract_state,
                        assumptions=op.assumptions,
                        invariants=self.abstract_invariants,
                        subgoals_out=subqueries)
                    for sub_q in subqueries:
                        sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                        update_stm = self._add_subquery(sub_q=sub_q, used_by=update_stm)
                    self.updates[(new_member, op.name)] = update_stm
def visit_EGetField(self, e):
    """Emit code that reads field `e.field_name` of the expression `e.e`.

    For non-handle values this returns the string "(<target>.<field>)".

    For handle values the read is emitted as a conditional: when the handle
    equals null the result is `evaluation.construct_value(e.type)`,
    otherwise it is `<target>->val`.  The visited conditional is returned.
    """
    ee = self.visit(e.e)
    if isinstance(e.e.type, THandle):
        # Ugh, we really need Cozy to know about partial functions...
        # Cozy doesn't know that handle types (aka pointers) can be null.
        # It assumes that reads of null pointers produce default-
        # constructed values, so we need to generate appropriate code.
        ee = EEscape(ee, (), ()).with_type(e.e.type)
        null = ENull().with_type(e.e.type)
        dereference = EEscape("{ee}->val", ("ee", ), (ee, )).with_type(e.type)
        return self.visit(
            ECond(
                EEq(ee, null),
                evaluation.construct_value(e.type),
                dereference).with_type(e.type))
    return "({ee}.{f})".format(ee=ee, f=e.field_name)
def visit_EMapGet(self, e):
    """Emit code for looking up `e.key` in the map `e.map`.

    Three cases:
      - the map does not use Trove: emit `getOrDefault` with the visited
        `evaluation.construct_value(e.type)` as the default;
      - Trove map whose value type has the Trove name "Object" and whose
        constructed default is not an ENull: emit a `get` into a fresh
        variable and substitute the default when the variable equals null;
      - any other Trove map: emit a plain `get`.
    """
    if not self.use_trove(e.map.type):
        emap = self.visit(e.map)
        ekey = self.visit(e.key)
        edefault = self.visit(evaluation.construct_value(e.type))
        return "{emap}.getOrDefault({ekey}, {edefault})".format(emap=emap, ekey=ekey, edefault=edefault)

    needs_null_check = (
        self.trovename(e.map.type.v) == "Object"
        and not isinstance(evaluation.construct_value(e.map.type.v), ENull))
    if needs_null_check:
        # Le sigh...
        emap = self.visit(e.map)
        ekey = self.visit(e.key)
        v = self.fv(self.box_if_boolean(e.map.type.v), hint="v")
        lookup = EEscape("{emap}.get({ekey})".format(emap=emap, ekey=ekey), [], []).with_type(e.type)
        self.visit(SDecl(v, lookup))
        return self.visit(
            ECond(
                EEq(v, ENull().with_type(v.type)),
                evaluation.construct_value(e.map.type.v),
                v).with_type(e.type))

    # For Trove, defaults are set at construction time
    emap = self.visit(e.map)
    ekey = self.visit(e.key)
    return "{emap}.get({ekey})".format(emap=emap, ekey=ekey)
def visit_EBinOp(self, e):
    """Emit code for the binary operation `e`.

    Special cases:
      - "=="         is delegated to `self._eq`;
      - "!="         is emitted as the negation of an equality;
      - BOp.Or/And   are emitted as conditionals;
      - BOp.In       is supported only when the right operand is a TSet;
                     any other right operand raises Exception;
      - "==="        is emitted with the operator "==".

    Everything else is emitted as "(<lhs> <op> <rhs>)".
    """
    op = e.op
    if op == "==":
        return self._eq(e.e1, e.e2)
    if op == "!=":
        return self.visit(ENot(EEq(e.e1, e.e2)))
    if op == BOp.Or:
        return self.visit(ECond(e.e1, ETRUE, e.e2).with_type(BOOL))
    if op == BOp.And:
        return self.visit(ECond(e.e1, e.e2, EFALSE).with_type(BOOL))
    if op == BOp.In:
        if not isinstance(e.e2.type, TSet):
            raise Exception(
                "{!r} operator is supposed to be handled by simplify_and_optimize"
                .format(op))
        return self.test_set_containment_native(e.e2, e.e1)

    # rewrite deep-equality test into regular equality
    emitted_op = "==" if op == "===" else op
    lhs = self.visit(e.e1)
    rhs = self.visit(e.e2)
    return "({e1} {op} {e2})".format(e1=lhs, op=emitted_op, e2=rhs)
def define_type(self, toplevel_name, t, name, sharing):
    """Write the C++ definition of type `t` under the identifier `name`.

    - TEnum:   an `enum` listing every case.
    - THandle: a `struct` with a single public field `val` holding the
               handle's value.
    - TRecord: a `struct` with one field per record field, a default
               constructor, a constructor that moves each argument into its
               field, and a field-wise `operator==`.
    - TTuple:  defined as a record whose fields are named `_0`, `_1`, ...

    `toplevel_name` and `sharing` are not used by the code emitted here
    (they are only forwarded when a tuple is redefined as a record).  An
    earlier, now removed, draft of the handle case used them to emit a
    `friend class` declaration and unions of intrusive data.

    Returns "" for any other kind of type and None otherwise; the output is
    produced through `self.write` and friends.
    """
    if isinstance(t, TEnum):
        self.begin_statement()
        self.write("enum ", name, " ")
        with self.block():
            for case in t.cases:
                self.begin_statement()
                self.write(case, ",")
                self.end_statement()
        self.write(";")
        self.end_statement()
    elif isinstance(t, THandle):
        self.begin_statement()
        self.write("struct ", name, " ")
        with self.block():
            with self.deindented():
                self.write_stmt("public:")
            self.declare_field("val", t.value_type)
            with self.deindented():
                self.write_stmt("private:")
        self.write(";")
        self.end_statement()
    elif isinstance(t, TRecord):
        self.begin_statement()
        self.write("struct ", name, " ")
        with self.block():
            # TODO: sort fields by size descending for better packing
            for f, ft in t.fields:
                self.declare_field(f, ft)

            # Default constructor.
            self.write_stmt("inline ", name, "() { }")

            # Constructor taking one "_<field>" argument per field.
            self.begin_statement()
            self.write("inline ", name, "(")
            self.visit_args([("_" + f, ft) for (f, ft) in t.fields])
            self.write(") : ")
            for i, (f, ft) in enumerate(t.fields):
                if i > 0:
                    self.write(", ")
                self.write(f, "(::std::move(_", f, "))")
            self.write(" { }")
            self.end_statement()

            # Field-wise equality.
            self.begin_statement()
            self.write("inline bool operator==(const ", name, "& other) const ")
            with self.block():
                this = EEscape("(*this)", (), ()).with_type(t)
                other = EVar("other").with_type(t)
                r = self.visit(
                    EAll([
                        EEq(
                            EGetField(this, f).with_type(ft),
                            EGetField(other, f).with_type(ft))
                        for f, ft in t.fields
                    ]))
                self.begin_statement()
                self.write("return ", r, ";")
                self.end_statement()
            self.end_statement()
        self.write(";")
        self.end_statement()
    elif isinstance(t, TTuple):
        as_record = TRecord(
            tuple(("_{}".format(i), elem_type) for i, elem_type in enumerate(t.ts)))
        return self.define_type(toplevel_name, as_record, name, sharing)
    else:
        return ""
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Report whether `spec` is independent of the variables `vs`.

    Each variable in `vs` is replaced by a fresh variable of the same type.
    The result is whatever `valid` says about: "if the assumptions hold both
    before and after that replacement, then `spec` (with EStateVar wrappers
    stripped) equals its replaced copy".
    """
    spec = strip_EStateVar(spec)

    replacements = {}
    for var in vs:
        replacements[var.id] = fresh_var(var.type)

    assumptions_after = subst(assumptions, replacements)
    hypothesis = EAll([assumptions, assumptions_after])
    spec_after = subst(spec, replacements)
    return valid(EImplies(hypothesis, EEq(spec, spec_after)))
def define_type(self, toplevel_name, t, name, sharing):
    """Write the Java definition of type `t` under the identifier `name`.

    - TEnum:   a `public enum` listing every case.
    - THandle / TRecord: a `public static class` implementing
      `java.io.Serializable`, with private fields, a getter per public
      field and a constructor.
        * A record's public fields are its own fields; it additionally gets
          `hashCode` and `equals` overrides based on all of its fields.
        * A handle whose value type is a record extends that record's class
          (its `getVal()` returns `this`); any other handle has one public
          field `val`.  Handles get private fields for the intrusive data
          of every type in `sharing.get(t, [])` and no equality overrides.
    - TTuple:  defined as a record whose fields are named `_0`, `_1`, ...

    Returns "" for any other kind of type and None otherwise; the output is
    produced through `self.write` and friends.
    """
    if isinstance(t, TEnum):
        self.begin_statement()
        self.write("public enum ", name, " ")
        with self.block():
            for case in t.cases:
                self.begin_statement()
                self.write(case)
                self.end_statement()
        self.end_statement()
    elif isinstance(t, THandle) or isinstance(t, TRecord):
        public_fields = []
        private_fields = []
        value_equality = True
        handle_val_is_this = False
        if isinstance(t, THandle):
            if isinstance(t.value_type, TRecord):
                handle_val_is_this = True
            else:
                public_fields = [("val", t.value_type)]
            value_equality = False
            for group in sharing.get(t, []):
                for gt in group:
                    private_fields.extend(
                        (f, ft) for (f, ft) in gt.intrusive_data(t))
        else:
            public_fields = list(t.fields)
        all_fields = public_fields + private_fields

        self.begin_statement()
        self.write("public static class ", name)
        if handle_val_is_this:
            self.write(" extends ", self.visit(t.value_type, ""))
        self.write(" implements java.io.Serializable ")
        with self.block():
            # Field declarations.
            for (f, ft) in all_fields:
                self.begin_statement()
                self.write("private {field_decl};".format(field_decl=self.visit(ft, f)))
                self.end_statement()

            # Getters.
            for (f, ft) in public_fields:
                self.begin_statement()
                self.write("public {type} get{Field}() {{ return {field}; }}".format(
                    type=self.visit(ft, ""),
                    Field=common.capitalize(f),
                    field=f))
                self.end_statement()
            if handle_val_is_this:
                self.begin_statement()
                self.write("public {type} getVal() {{ return this; }}".format(
                    type=self.visit(t.value_type, "")))
                self.end_statement()

            def flatten(field_types):
                """Expand record and tuple types into scalar constructor args.

                Returns `(args, exps)`: `args` lists one (fresh name, type)
                pair per scalar leaf, and `exps` holds one expression per
                entry of `field_types` that rebuilds it from those leaves.
                """
                args = []
                exps = []
                for ft in field_types:
                    if isinstance(ft, TRecord):
                        aa, ee = flatten([sub_type for (_, sub_type) in ft.fields])
                        args.extend(aa)
                        exps.append(EMakeRecord(tuple(
                            (f, e) for ((f, _), e) in zip(ft.fields, ee))).with_type(ft))
                    elif isinstance(ft, TTuple):
                        aa, ee = flatten(ft.ts)
                        args.extend(aa)
                        exps.append(ETuple(tuple(ee)).with_type(ft))
                    else:
                        v = self.fv(ft, "v")
                        args.append((v.id, ft))
                        exps.append(v)
                return args, exps

            # Constructor.
            if isinstance(t, THandle):
                ctor_fields = t.value_type.fields if handle_val_is_this else public_fields
                args, exps = flatten([ft for (f, ft) in ctor_fields])
            else:
                args = public_fields
                exps = [EVar(f) for (f, ft) in args]
            self.begin_statement()
            self.write("public {ctor}({args}) ".format(
                ctor=name,
                args=", ".join(self.visit(ft, f) for (f, ft) in args)))
            with self.block():
                if handle_val_is_this:
                    es = [self.visit(e) for e in exps]
                    self.begin_statement()
                    self.write("super({args});\n".format(args=", ".join(es)))
                for ((f, ft), e) in zip(public_fields, exps):
                    rendered = self.visit(e)
                    self.begin_statement()
                    self.write("this.{f} = {e};\n".format(f=f, e=rendered))
            self.end_statement()

            if value_equality:
                # hashCode()
                self.begin_statement()
                self.write("@Override")
                self.end_statement()
                self.begin_statement()
                self.write("public int hashCode() ")
                with self.block():
                    (compute, hc) = self.compute_hash(all_fields)
                    self.write(compute)
                    self.begin_statement()
                    self.write("return ", hc, ";")
                    self.end_statement()
                self.end_statement()

                # equals()
                self.begin_statement()
                self.write("@Override")
                self.end_statement()
                self.begin_statement()
                self.write("public boolean equals(Object other) ")
                with self.block():
                    self.write(self.get_indent(), "if (other == null) return false;\n")
                    self.write(self.get_indent(), "if (other == this) return true;\n")
                    self.write(self.get_indent(), "if (!(other instanceof {name})) return false;\n".format(name=name))
                    self.write(self.get_indent(), "{name} o = ({name})other;\n".format(name=name))
                    eq = self.visit(EAll([
                        EEq(
                            EEscape("this.{}".format(f), (), ()).with_type(ft),
                            EEscape("o.{}".format(f), (), ()).with_type(ft))
                        for (f, ft) in all_fields]))
                    self.write(self.get_indent(), "return ", eq, ";\n")
                self.end_statement()
        self.end_statement()
    elif isinstance(t, TTuple):
        as_record = TRecord(
            tuple(("_{}".format(i), elem_type) for i, elem_type in enumerate(t.ts)))
        return self.define_type(toplevel_name, as_record, name, sharing)
    else:
        return ""
def improve(target: Exp, context: Context, assumptions: Exp = T, stop_callback=never_stop, hints: [Exp] = (), examples: [{str: object}] = (), cost_model: CostModel = None, ops: [Op] = ()):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target`.

    Parameters:
     - target, context: the expression to improve and the context it lives
       in.
     - assumptions: the candidate check is made by a solver built with
       these assumptions.  If they are unsatisfiable, a single default
       value of the target's type is yielded and the generator stops.
     - stop_callback, ops: handed to the Learner unchanged.
     - hints: extra expressions handed to the Learner, together with the
       conjuncts of `assumptions` and the target itself.
     - examples: initial examples; every counterexample found while
       checking candidates is appended to (a copy of) this list.
     - cost_model: defaults to a CostModel built from the context's
       functions and the assumptions.

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:
     - Expressions are either "state" expressions or "runtime" expressions,
       allowing this algorithm to choose what things to store on the data
       structure and what things to compute at query execution time. (The
       cost model is ultimately responsible for this choice.)
     - If a better version of *any subexpression* for the target is found,
       it is immediately substituted in and the overall expression is
       returned. This "smooths out" the search space a little, and lets us
       find kinda-good solutions very quickly, even if the best possible
       solution is out of reach.
    """
    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r}, ops={ops!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model,
        ops=ops))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    if not is_wf:
        # Best effort: report the problem and carry on anyway.
        print(
            "WARNING: initial target is not well-formed [{}]; this might go poorly..."
            .format(is_wf))
        print(pprint(is_wf.offending_subexpression))
        print(pprint(is_wf.offending_subexpression.type))

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)
    if not solver.satisfiable(T):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    # Copy so that counterexamples never leak into the caller's sequence.
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    learner = Learner(watched_targets, solver, context, examples, cost_model, stop_callback, hints, ops=ops)

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in learner.search():
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, reset the learner
                examples.append(counterexample)
                event("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(len(examples)))
                learner.reset(examples)
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = eviction_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print("Whoops! Looks like we already found something better.")
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                learner.watch(watched_targets)
                break
def _enumerate_core(self, context: Context, size: int, pool: Pool) -> [Exp]:
    """Build new expressions of the given size.

    Arguments:
        context : a Context object describing the vars in scope
        size    : size of expressions to enumerate; each expression in
                  the output will have this size
        pool    : pool to enumerate

    This function is not cached.  Clients should call `enumerate` instead.

    This function tries to be a clean description of the Cozy grammar.  It
    does not concern itself with deduplication (which is handled
    efficiently by equivalence class deduplication).

    Nothing is produced for negative sizes.  Size 0 produces literals,
    default values of interesting types, in-pool variables and matching
    hints.  Larger sizes produce nothing unless `do_enumerate.value` is
    set.
    """
    if size < 0:
        return

    if size == 0:
        yield from LITERALS

        # Default values of every type mentioned by a hint or a variable.
        interesting_types = OrderedSet(self.hint_types)
        for var, _ in context.vars():
            interesting_types |= all_types(var.type)
        for typ in interesting_types:
            default_value = construct_value(typ)
            if default_value not in LITERALS:
                yield default_value

        for (var, var_pool) in context.vars():
            if var_pool == pool:
                yield var

        for (hint, hint_ctx, hint_pool) in self.hints:
            if hint_pool == pool:
                fvs = free_vars(hint)
                if hint_ctx.alpha_equivalent(context.generalize(fvs)):
                    yield context.adapt(hint, hint_ctx, e_fvs=fvs)
        return

    if not do_enumerate.value:
        return

    def build_lambdas(bag, pool, body_size):
        """Yield lambdas over the elements of `bag` with bodies of `body_size`."""
        v = fresh_var(bag.type.elem_type, omit={var for var, _ in context.vars()})
        inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
        for lam_body in self.enumerate(inner_context, body_size, pool):
            yield ELambda(v, lam_body)

    # Load all smaller expressions in this context and pool.
    # cache[S] contains expressions of size S in this context and pool.
    cache = [list(self.enumerate(context, sz, pool)) for sz in range(size)]

    # Enable use of a state-pool expression at runtime
    if pool == RUNTIME_POOL:
        for state_exp in self.enumerate(context.root(), size - 1, STATE_POOL):
            yield EStateVar(state_exp).with_type(state_exp.type)

    # Arity-1 expressions
    for e in cache[size - 1]:
        if is_collection(e.type):
            elem_type = e.type.elem_type

            # This method of generating EEmptyList() ensures that we visit
            # empty collections of all possible types.
            yield EEmptyList().with_type(e.type)

            if is_numeric(elem_type):
                yield EUnaryOp(UOp.Sum, e).with_type(elem_type)

            yield EUnaryOp(UOp.Length, e).with_type(INT)
            yield EUnaryOp(UOp.Empty, e).with_type(BOOL)
            yield EUnaryOp(UOp.Exists, e).with_type(BOOL)
            yield EUnaryOp(UOp.The, e).with_type(elem_type)
            yield EUnaryOp(UOp.Distinct, e).with_type(e.type)
            yield EUnaryOp(UOp.AreUnique, e).with_type(BOOL)

            if elem_type == BOOL:
                yield EUnaryOp(UOp.Any, e).with_type(BOOL)
                yield EUnaryOp(UOp.All, e).with_type(BOOL)

        yield ESingleton(e).with_type(TBag(e.type))

        if isinstance(e.type, TRecord):
            for (field_name, field_type) in e.type.fields:
                yield EGetField(e, field_name).with_type(field_type)

        if isinstance(e.type, THandle):
            yield EGetField(e, "val").with_type(e.type.value_type)

        if isinstance(e.type, TTuple):
            for n, component_type in enumerate(e.type.ts):
                yield ETupleGet(e, n).with_type(component_type)

        if e.type == BOOL:
            yield EUnaryOp(UOp.Not, e).with_type(BOOL)

        if is_numeric(e.type):
            yield EUnaryOp("-", e).with_type(e.type)

        if isinstance(e.type, TMap):
            yield EMapKeys(e).with_type(TBag(e.type.k))

    # Arity-2 expressions
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        # sz1 + sz2 = size - 1
        for e1 in cache[sz1]:
            t = e1.type

            if is_numeric(t):
                for a2 in of_type(cache[sz2], t):
                    yield EBinOp(e1, "+", a2).with_type(t)
                    yield EBinOp(e1, "-", a2).with_type(t)

            if is_ordered(t):
                for a2 in of_type(cache[sz2], t):
                    yield EBinOp(e1, ">", a2).with_type(BOOL)
                    yield EBinOp(e1, "<", a2).with_type(BOOL)
                    yield EBinOp(e1, ">=", a2).with_type(BOOL)
                    yield EBinOp(e1, "<=", a2).with_type(BOOL)

            if t == BOOL:
                for a2 in of_type(cache[sz2], BOOL):
                    yield EBinOp(e1, BOp.And, a2).with_type(BOOL)
                    yield EBinOp(e1, BOp.Or, a2).with_type(BOOL)
                    # Cozy supports the implication operator "=>", but this
                    # function does not enumerate it because
                    #  - (a => b) is equivalent to ((not a) or b)
                    #  - there isn't an implication operator in any of our
                    #    current target languages, so we would need to
                    #    desugar it to ((not a) or b) anyway.

            if not isinstance(t, TMap):
                for a2 in of_type(cache[sz2], t):
                    yield EEq(e1, a2)
                    yield EBinOp(e1, "!=", a2).with_type(BOOL)

            if isinstance(t, TMap):
                for k in of_type(cache[sz2], t.k):
                    yield EMapGet(e1, k).with_type(t.v)
                    yield EHasKey(e1, k).with_type(BOOL)

            if isinstance(t, TList):
                for i in of_type(cache[sz2], INT):
                    yield EListGet(e1, i).with_type(e1.type.elem_type)

            if is_collection(t):
                elem_type = t.elem_type
                for e2 in of_type(cache[sz2], t):
                    yield EBinOp(e1, "+", e2).with_type(t)
                    yield EBinOp(e1, "-", e2).with_type(t)
                for e2 in of_type(cache[sz2], elem_type):
                    yield EBinOp(e2, BOp.In, e1).with_type(BOOL)
                for f in build_lambdas(e1, pool, sz2):
                    body_type = f.body.type
                    yield EMap(e1, f).with_type(TBag(body_type))
                    if body_type == BOOL:
                        yield EFilter(e1, f).with_type(t)
                    if is_numeric(body_type):
                        yield EArgMin(e1, f).with_type(elem_type)
                        yield EArgMax(e1, f).with_type(elem_type)
                    if is_collection(body_type):
                        yield EFlatMap(e1, f).with_type(TBag(body_type.elem_type))

                    if pool == STATE_POOL and is_hashable(elem_type):
                        yield EMakeMap2(e1, f).with_type(TMap(elem_type, body_type))

            # Let-bindings: the bound value is presented to the lambda
            # builder as a singleton bag.
            e1_singleton = ESingleton(e1).with_type(TBag(e1.type))
            for f in build_lambdas(e1_singleton, pool, sz2):
                yield ELet(e1, f).with_type(f.body.type)

    # Arity-3 expressions
    for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
        # sz1 + sz2 + sz3 = size - 1
        for e1 in cache[sz1]:
            if e1.type == BOOL:
                cond = e1
                for then_branch in cache[sz2]:
                    for else_branch in of_type(cache[sz3], then_branch.type):
                        yield ECond(cond, then_branch, else_branch).with_type(then_branch.type)
            if isinstance(e1.type, TList):
                for start in of_type(cache[sz2], INT):
                    for end in of_type(cache[sz3], INT):
                        yield EListSlice(e1, start, end).with_type(e1.type)
                        # It is not necessary to create slice expressions of
                        # the form a[:i] or a[i:].  Those are desugared
                        # after parsing to a[0:i] and a[i:len(a)]
                        # respectively, and Cozy is perfectly capable of
                        # discovering these expanded forms as well.

    # Expressions contributed by extension handlers.
    for handler in all_extension_handlers():
        yield from handler.enumerate(context, size, pool, self.enumerate, build_lambdas)
def improve(target: Exp, context: Context, assumptions: Exp = ETRUE, stop_callback: Callable[[], bool] = never_stop, hints: [Exp] = (), examples: [{str: object}] = (), cost_model: CostModel = None, ops: [Op] = (), improve_count: Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".

    If the target already satisfies `heuristic_done`, nothing is yielded.
    If the assumptions are unsatisfiable, a single default value of the
    target's type is yielded and the generator stops.

    Other parameters:
     - assumptions: a precondition.  The yielded improvements are only
       checked to be correct when the assumptions are true.
     - stop_callback: handed to the search unchanged.
     - hints: expressions that might be useful.  They are handed to the
       search together with the conjuncts of `assumptions` and the target
       itself.
     - examples: inputs used internally to differentiate semantically
       distinct expressions.  Every counterexample found while checking a
       candidate is appended to (a copy of) this list, so there usually
       isn't a reason to provide any.
     - cost_model: defaults to a CostModel built from the context's
       functions and the assumptions.
     - ops: update operations; handed to the search unchanged.
     - improve_count: optional shared counter.  When given, its `value` is
       incremented, under `get_lock()`, each time a pass over the search
       results ends.

    Raises AssertionError when the target is not well-formed or is not
    `possibly_useful`, and Exception when the solver returns a
    counterexample that is already among the examples.

    Key differences from "regular" enumerative synthesis:
     - Expressions are either "state" expressions or "runtime" expressions,
       allowing this algorithm to choose what things to store on the data
       structure and what things to compute at query execution time. (The
       cost model is ultimately responsible for this choice.)
     - If a better version of *any subexpression* for the target is found,
       it is immediately substituted in and the overall expression is
       returned. This "smooths out" the search space a little, allowing us
       to find kinda-good solutions very quickly, even if the best possible
       solution is out of reach.  This is more desirable than running for
       an indeterminate amount of time doing nothing.
    """
    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r}, ops={ops!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model,
        ops=ops))

    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)
    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(
        is_good, target)

    # Copy so that counterexamples never leak into the caller's sequence.
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print("Whoops! Looks like we already found something better.")
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        # One pass over the search results has ended (by `break` or by
        # exhaustion); record it for whoever is watching the counter.
        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1