def find_refinement(ast, state_map, lib, assumptions):
    """Find a state variable whose representation can still be refined.

    Returns a (variable name, candidate implementations) pair for the first
    state variable whose library-provided implementations differ from its
    current type, or None if no state variable admits further refinement.
    """
    assumptions = EAll(itertools.chain(
        assumptions,
        ast.assumptions,
        (EEq(EVar(v).with_type(e.type), e) for (v, e) in state_map.items())))
    for (v, t) in ast.statevars:
        refs = list(lib.impls(EVar(v).with_type(t), assumptions=assumptions))
        if not (len(refs) == 1 and refs[0] == t):
            return (v, refs)
    return None
def visit_EUnaryOp(self, e):
    op = e.op
    if op == UOp.Distinct:
        return self.visit_iterable(e)
    elif op == UOp.The:
        return self.find_one(e.e)
    elif op == UOp.Sum:
        # sum = 0; for x in bag: sum += x
        sum_var = fresh_var(e.type, "sum")
        loop_var = fresh_var(e.e.type.elem_type, "x")
        self.stms.append(simplify_and_optimize(seq([
            SDecl(sum_var, ENum(0).with_type(e.type)),
            SForEach(loop_var, e.e,
                SAssign(sum_var, EBinOp(sum_var, "+", loop_var).with_type(INT)))])))
        return sum_var
    elif op == UOp.Length:
        # Length(bag) is lowered as Sum(Map(bag, lambda x: 1))
        arg = EVar("x").with_type(e.e.type.elem_type)
        return self.visit(EUnaryOp(UOp.Sum,
            EMap(e.e, ELambda(arg, ONE)).with_type(INT_BAG)).with_type(INT))
    elif op == UOp.All:
        # All(bag) is lowered as Empty(Filter(bag, lambda x: not x))
        arg = EVar("x").with_type(e.e.type.elem_type)
        return self.visit(EUnaryOp(UOp.Empty,
            EFilter(e.e, ELambda(arg, ENot(arg))).with_type(INT_BAG)).with_type(INT))
    elif op == UOp.Any:
        # Any(bag) is lowered as Exists(Filter(bag, lambda x: x))
        arg = EVar("x").with_type(e.e.type.elem_type)
        return self.visit(EUnaryOp(UOp.Exists,
            EFilter(e.e, ELambda(arg, arg)).with_type(INT_BAG)).with_type(INT))
    elif op == UOp.Empty:
        # v = True; scan the collection and escape with v = False at the
        # first element, so v ends up True exactly when the collection is empty
        iterable = e.e
        v = fresh_var(BOOL, "v")
        label = fresh_name("label")
        x = fresh_var(iterable.type.elem_type, "x")
        decl = SDecl(v, ETRUE)
        find = SEscapableBlock(label,
            SForEach(x, iterable, seq([
                SAssign(v, EFALSE),
                SEscapeBlock(label)])))
        self.stms.append(simplify_and_optimize(seq([decl, find])))
        return v
    elif op == UOp.Exists:
        return self.visit(ENot(EUnaryOp(UOp.Empty, e.e).with_type(BOOL)))
    # elif op == UOp.AreUnique:
    #     s = fresh_var(TSet(e.e.type.elem_type), "unique_elems")
    #     u = fresh_var(BOOL, "is_unique")
    #     x = fresh_var(e.e.type.elem_type)
    #     label = fresh_name("label")
    #     self.visit(seq([
    #         SDecl(s, EEmptyList().with_type(s.type)),
    #         SDecl(u, ETRUE),
    #         SEscapableBlock(label,
    #             SForEach(x, e.e,
    #                 SIf(EEscape("{s}.find({x}) != {s}.end()", ("s", "x"), (s, x)).with_type(BOOL),
    #                     seq([SAssign(u, EFALSE), SEscapeBlock(label)]),
    #                     SEscape("{indent}{s}.insert({x});\n", ("s", "x"), (s, x)))))]))
    #     return u.id
    return self.visit_Exp(e)
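# --- Illustrative sketch (not part of the Cozy IR or this code generator) ---
# visit_EUnaryOp lowers Length, All, and Any by rewriting them into Sum,
# Empty, and Exists before emitting code.  The same rewrites, expressed over
# plain Python iterables with hypothetical helper names:
def length(bag):
    # Length(bag) == Sum(Map(bag, lambda x: 1))
    return sum(1 for _ in bag)

def all_of(bools):
    # All(bag) == Empty(Filter(bag, lambda x: not x))
    return not any(not x for x in bools)

def any_of(bools):
    # Any(bag) == Exists(Filter(bag, lambda x: x))
    return any(x for x in bools)

assert length([1, 2, 3]) == 3
assert all_of([True, True]) and not all_of([True, False])
assert any_of([False, True]) and not any_of([False, False])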
def pull_temps(s : Stm, decls_out : [SDecl], exp_is_bad) -> Stm:
    """Hoist unwanted subexpressions out of a statement.

    Subexpressions for which `exp_is_bad` returns True are replaced by fresh
    variables; the corresponding declarations are appended to `decls_out`.
    """
    def pull(e : Exp) -> Exp:
        if exp_is_bad(e):
            v = fresh_var(e.type)
            decls_out.append(SDecl(v.id, e))
            return v
        return e
    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        s1 = pull_temps(s.s1, decls_out, exp_is_bad)
        s2 = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(s1, s2)
    if isinstance(s, SIf):
        cond = pull(s.cond)
        s1 = pull_temps(s.then_branch, decls_out, exp_is_bad)
        s2 = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(cond, s1, s2)
    if isinstance(s, SForEach):
        bag = pull(s.iter)
        d_tmp = []
        body = pull_temps(s.body, d_tmp, exp_is_bad)
        to_fix, ok = partition(d_tmp, lambda d: s.id in free_vars(d.val))
        decls_out.extend(ok)
        for d in to_fix:
            # the hoisted value mentions the loop variable, so precompute it
            # into a map keyed by the loop variable and read it in the body
            v = EVar(d.id).with_type(d.val.type)
            mt = TMap(s.id.type, v.type)
            m = EMakeMap2(bag, ELambda(s.id, d.val)).with_type(mt)
            mv = fresh_var(m.type)
            md = SDecl(mv.id, m)
            decls_out.append(md)
            body = subst(body, {v.id: EMapGet(mv, s.id).with_type(v.type)})
        return SForEach(s.id, bag, body)
    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))
    if isinstance(s, SCall):
        return SCall(s.target, s.func, tuple(pull(arg) for arg in s.args))
    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))
    if isinstance(s, SMapPut):
        return SMapPut(s.map, pull(s.key), pull(s.value))
    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        d_tmp = []
        change = pull_temps(s.change, d_tmp, exp_is_bad)
        for d in d_tmp:
            if s.val_var in free_vars(d.val):
                decls_out.append(SDecl(d.id, subst(d.val, {
                    s.val_var.id: EMapGet(s.map, key).with_type(s.val_var.type)})))
            else:
                decls_out.append(d)
        return SMapUpdate(s.map, key, s.val_var, change)
    raise NotImplementedError(s)
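# --- Illustrative sketch (standalone; names are hypothetical) ---
# In the SForEach case above, a hoisted declaration that mentions the loop
# variable cannot simply move outside the loop, so pull_temps precomputes it
# into a map keyed by the loop variable (EMakeMap2) and rewrites the body to
# read from that map (EMapGet).  In plain Python the transformation is roughly:
def loop_before(bag, f, use):
    for x in bag:
        use(f(x))                     # "bad" expression evaluated in the body

def loop_after(bag, f, use):
    m = {x: f(x) for x in bag}        # hoisted map, like EMakeMap2(bag, ...)
    for x in bag:
        use(m[x])                     # body rewritten, like EMapGet(m, x)

_out1, _out2 = [], []
loop_before([1, 2, 3], lambda x: x * x, _out1.append)
loop_after([1, 2, 3], lambda x: x * x, _out2.append)
assert _out1 == _out2 == [1, 4, 9]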
def visit_Query(self, q):
    if q.visibility != Visibility.Public:
        return ""
    ret_type = q.ret.type
    if is_collection(ret_type):
        x = EVar(self.fn("x")).with_type(ret_type.elem_type)
        def body(x):
            return SEscape("{indent}_callback.accept({x});\n", ["x"], [x])
        if q.docstring:
            self.write(indent_lines(q.docstring, self.get_indent()), "\n")
        self.begin_statement()
        self.write("public ", self.visit(TNative("void"), q.name), "(")
        self.visit_args(itertools.chain(q.args, [
            ("_callback", TNative("java.util.function.Consumer<{t}>".format(
                t=self.visit(ret_type.elem_type, ""))))]))
        self.write(") ")
        with self.block():
            self.visit(simplify_and_optimize(SForEach(x, q.ret,
                SEscape("{indent}_callback.accept({x});\n", ["x"], [x]))))
    else:
        if q.docstring:
            self.write(indent_lines(q.docstring, self.get_indent()), "\n")
        self.begin_statement()
        self.write("public ", self.visit(ret_type, q.name), "(")
        self.visit_args(q.args)
        self.write(") ")
        with self.block():
            self.visit(simplify_and_optimize(SReturn(q.ret)))
    self.end_statement()
def queries_equivalent(q1 : Query, q2 : Query,
        state_vars : [EVar],
        extern_funcs : {str: TFunc},
        assumptions : Exp = T):
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False
        q1args = dict(q1.args)
        q2args = dict(q2.args)
        if q1args != q2args:
            return False

        args = FrozenDict(q1args)
        key = (args, assumptions)
        checker = _qe_cache.get(key)
        if checker is None:
            checker = ModelCachingSolver(
                vars=list(itertools.chain(
                    state_vars,
                    (EVar(v).with_type(t) for v, t in args.items()))),
                funcs=extern_funcs,
                assumptions=assumptions)
            _qe_cache[key] = checker

        q1a = EAll(q1.assumptions)
        q2a = EAll(q2.assumptions)
        return checker.valid(EEq(q1a, q2a)) and checker.valid(
            EImplies(q1a, EEq(q1.ret, q2.ret)))
def queries_equivalent(q1 : Query, q2 : Query,
        state_vars : [EVar],
        extern_funcs : {str: TFunc},
        assumptions : Exp = ETRUE):
    """Determine whether two queries always return the same result.

    This function also checks that the two queries have semantically
    equivalent preconditions.  Checking the preconditions is necessary to
    ensure semantic equivalence of the queries: a query object should be
    interpreted to mean "if my preconditions hold then I compute and return
    my body expression".  If two queries do not have semantically equivalent
    preconditions, then there might be cases where one is obligated to return
    a value and the other has no defined behavior.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False
        q1args = dict(q1.args)
        q2args = dict(q2.args)
        if q1args != q2args:
            return False

        checker = solver_for_context(
            context=RootCtx(
                state_vars=state_vars,
                args=[EVar(a).with_type(t) for (a, t) in q1.args],
                funcs=extern_funcs),
            assumptions=assumptions)

        q1a = EAll(q1.assumptions)
        q2a = EAll(q2.assumptions)
        return checker.valid(EEq(q1a, q2a)) and checker.valid(
            EImplies(q1a, EEq(q1.ret, q2.ret)))
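# --- Illustrative sketch (toy values, not Cozy Query objects or the solver) ---
# The two validity checks above say: (1) the preconditions must be equivalent,
# and (2) whenever the preconditions hold, the bodies must agree.  A finite,
# example-based version of the same idea:
def toy_queries_equivalent(pre1, body1, pre2, body2, inputs):
    same_preconditions = all(pre1(x) == pre2(x) for x in inputs)
    same_results = all(body1(x) == body2(x) for x in inputs if pre1(x))
    return same_preconditions and same_results

# "given x >= 0, return x" vs. "given x >= 0, return abs(x)"
assert toy_queries_equivalent(
    lambda x: x >= 0, lambda x: x,
    lambda x: x >= 0, abs,
    inputs=range(-3, 4))
# Unlike the solver-backed check, this only tests the supplied inputs.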
def subset_of(self, other) -> bool:
    """Test for subset inclusion.

    If this returns True, then it could be the case that an expression with
    this fingerprint always returns a strict subset of the elements that
    would be returned by an expression with the other fingerprint.
    """
    if not is_collection(self.type):
        raise ValueError("this fingerprint is not for a collection-type expression")
    if not is_collection(other.type):
        raise ValueError("other fingerprint is not for a collection-type expression")
    self._require_comparable_to(other)
    x = EVar("x").with_type(self.type)
    y = EVar("y").with_type(other.type)
    is_subset = EIsSubset(x, y)
    return all(eval_bulk(
        is_subset,
        [{x.id: a, y.id: b} for (a, b) in zip(self.outputs, other.outputs)]))
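# --- Illustrative sketch (hypothetical helper; uses sets for brevity where
# Cozy's EIsSubset works over bags) ---
# A fingerprint records an expression's outputs on a fixed list of example
# inputs, so subset_of compares the two output lists pointwise: a single
# counterexample among the examples rules the relationship out.
def outputs_subset_of(outputs_a, outputs_b):
    return all(set(a) <= set(b) for a, b in zip(outputs_a, outputs_b))

assert outputs_subset_of([[1], [2, 3]], [[1, 4], [2, 3, 5]])
assert not outputs_subset_of([[1, 9]], [[1, 4]])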
def reconcile_jobs():
    """Sync up the current set of jobs and the set of queries.

    This function spawns new jobs for new queries and cleans up old jobs
    whose queries have been dead-code-eliminated."""

    # figure out what new jobs we need
    job_query_names = set(j.q.name for j in improvement_jobs)
    new = []
    for q in impl.query_specs:
        if q.name not in job_query_names:
            states_maintained_by_q = impl.states_maintained_by(q)
            print("STARTING IMPROVEMENT JOB {}".format(q.name))
            new.append(ImproveQueryJob(
                impl.abstract_state,
                list(impl.spec.assumptions) + list(q.assumptions),
                q,
                context=impl.context_for_method(q),
                solutions_q=solutions_q.handle_for_subjobs(),
                hints=[EStateVar(c).with_type(c.type)
                       for c in impl.concretization_functions.values()],
                freebies=[e for (v, e) in impl.concretization_functions.items()
                          if EVar(v) in states_maintained_by_q],
                ops=impl.op_specs,
                improve_count=improve_count))

    # figure out what old jobs we can stop
    impl_query_names = set(q.name for q in impl.query_specs)
    old = [j for j in improvement_jobs if j.q.name not in impl_query_names]

    # make it so
    stop_jobs(old)
    for j in new:
        j.start()
    improvement_jobs.extend(new)
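# --- Illustrative sketch (hypothetical names; no job objects) ---
# reconcile_jobs is an instance of the usual reconciliation pattern: diff the
# desired set against the running set by name, start what is missing, and stop
# what is no longer wanted.
def reconcile(desired_names, running_names):
    to_start = desired_names - running_names
    to_stop = running_names - desired_names
    return to_start, to_stop

assert reconcile({"q1", "q2"}, {"q2", "q3"}) == ({"q1"}, {"q3"})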
def visit_Spec(self, spec : Spec, state_exps : {str: Exp}, sharing, abstract_state=()):
    self.state_exps = state_exps
    self.funcs = {f.name: f for f in spec.extern_funcs}
    self.queries = {q.name: q for q in spec.methods if isinstance(q, Query)}
    self.vars = set(e.id for e in all_exps(spec) if isinstance(e, EVar))

    self.write("#pragma once\n")
    self.write("#include <algorithm>\n")
    self.write("#include <set>\n")
    self.write("#include <functional>\n")
    self.write("#include <vector>\n")
    self.write("#include <unordered_set>\n")
    self.write("#include <string>\n")
    if self.use_qhash:
        self.write("#include <QHash>\n")
    else:
        self.write("#include <unordered_map>\n")

    if spec.header:
        self.write("\n" + spec.header.strip() + "\n")

    self.write("{}\nclass {} {{\n".format(
        ("\n" + spec.docstring) if spec.docstring else "",
        spec.name))
    self.write("public:\n")

    print("Setting up auxiliary types...")
    self.setup_types(spec, state_exps, sharing)
    with self.indented():
        for t, name in self.types.items():
            self.define_type(spec.name, t, name, sharing)
            self.begin_statement()
            if isinstance(t, THandle):
                # No overridden hash code! We use pointers instead.
                continue
            self.write("struct _Hash", name, " ")
            with self.block():
                self.write_stmt("typedef ", spec.name, "::", name, " argument_type;")
                self.write_stmt("typedef std::size_t result_type;")
                self.begin_statement()
                self.write("result_type operator()(const argument_type& x) const noexcept ")
                x = EVar("x").with_type(t)
                if isinstance(t, TEnum):
                    fields = [EEnumToInt(x).with_type(INT)]
                elif isinstance(t, TRecord):
                    fields = [EGetField(x, f).with_type(ft) for (f, ft) in t.fields]
                elif isinstance(t, TTuple):
                    fields = [ETupleGet(x, n).with_type(tt) for (n, tt) in enumerate(t.ts)]
                else:
                    raise NotImplementedError(t)
                with self.block():
                    self.visit(self.compute_hash(fields))
                self.end_statement()
            self.write(";")
            self.end_statement()

    print("Setting up member variables...")
    self.write("protected:\n")
    with self.indented():
        for name, t in spec.statevars:
            self.statevar_name = name
            self.declare_field(name, t)

    self.write("public:\n")
    with self.indented():
        print("Generating constructors...")

        # default constructor
        self.begin_statement()
        self.write("inline ", spec.name, "() ")
        with self.block():
            for name, t in spec.statevars:
                initial_value = state_exps[name]
                fvs = free_vars(initial_value)
                initial_value = subst(initial_value,
                    {v.id: evaluation.construct_value(v.type) for v in fvs})
                stm = simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value))
                self.visit(stm)
        self.end_statement()

        # explicit constructor
        if abstract_state:
            self.begin_statement()
            self.write("explicit inline ", spec.name, "(")
            self.visit_args(abstract_state)
            self.write(") ")
            with self.block():
                for name, t in spec.statevars:
                    initial_value = state_exps[name]
                    self.visit(simplify_and_optimize(
                        SAssign(EVar(name).with_type(t), initial_value)))
            self.end_statement()

        # disable copy constructor (TODO: support this in the future?)
        self.begin_statement()
        self.write(spec.name, "(const ", spec.name, "& other) = delete;")
        self.end_statement()

        # generate methods
        for op in spec.methods:
            print("Generating method {}...".format(op.name))
            self.visit(op)
    self.write("};\n")

    if spec.footer:
        self.write("\n", spec.footer)
        if not spec.footer.endswith("\n"):
            self.write("\n")
def define_type(self, toplevel_name, t, name, sharing):
    if isinstance(t, TEnum):
        self.begin_statement()
        self.write("enum ", name, " ")
        with self.block():
            for case in t.cases:
                self.begin_statement()
                self.write(case, ",")
                self.end_statement()
        self.write(";")
        self.end_statement()
    elif isinstance(t, THandle):
        fields = [("val", t.value_type)]
        self.begin_statement()
        self.write("struct ", name, " ")
        with self.block():
            with self.deindented():
                self.write_stmt("public:")
            for (f, ft) in fields:
                self.declare_field(f, ft)
            with self.deindented():
                self.write_stmt("private:")
        self.write(";")
        self.end_statement()
        # s = "struct {name} {{\n".format(indent=indent, name=name)
        # s += "public:\n".format(indent=indent)
        # for (f, ft) in fields:
        #     s += self.declare_field(f, ft)
        # s += "private:\n".format(indent=indent)
        # s += "friend class {toplevel_name};\n".format(indent=indent+INDENT, toplevel_name=toplevel_name)
        # for group in sharing.get(t, []):
        #     s += "union {{\n".format(indent=indent+INDENT)
        #     for gt in group:
        #         intrusive_data = gt.intrusive_data(t)
        #         s += "struct {{\n".format(indent=indent+INDENT*2)
        #         for (f, ft) in intrusive_data:
        #             s += "{field_decl};\n".format(indent=indent+INDENT*3, field_decl=self.visit(ft, f))
        #         s += "}};\n".format(indent=indent+INDENT*2)
        #     s += "}};\n".format(indent=indent+INDENT)
        # s += "}};\n".format(indent=indent)
        # return s
    elif isinstance(t, TRecord):
        self.begin_statement()
        self.write("struct ", name, " ")
        with self.block():
            # TODO: sort fields by size descending for better packing
            for f, ft in t.fields:
                self.declare_field(f, ft)
            self.write_stmt("inline ", name, "() { }")
            self.begin_statement()
            self.write("inline ", name, "(")
            self.visit_args([("_" + f, t) for (f, t) in t.fields])
            self.write(") : ")
            for i, (f, ft) in enumerate(t.fields):
                if i > 0:
                    self.write(", ")
                self.write(f, "(::std::move(_", f, "))")
            self.write(" { }")
            self.end_statement()
            self.begin_statement()
            self.write("inline bool operator==(const ", name, "& other) const ")
            with self.block():
                this = EEscape("(*this)", (), ()).with_type(t)
                other = EVar("other").with_type(t)
                r = self.visit(EAll([
                    EEq(EGetField(this, f).with_type(ft),
                        EGetField(other, f).with_type(ft))
                    for f, ft in t.fields]))
                self.begin_statement()
                self.write("return ", r, ";")
                self.end_statement()
            self.end_statement()
        self.write(";")
        self.end_statement()
    elif isinstance(t, TTuple):
        return self.define_type(
            toplevel_name,
            TRecord(tuple(("_{}".format(i), t.ts[i]) for i in range(len(t.ts)))),
            name,
            sharing)
    else:
        return ""
def visit_Spec(self, spec, state_exps, sharing, abstract_state=()):
    self.state_exps = state_exps
    self.funcs = {f.name: f for f in spec.extern_funcs}
    self.queries = {q.name: q for q in spec.methods if isinstance(q, Query)}
    self.vars = set(e.id for e in all_exps(spec) if isinstance(e, EVar))

    self.setup_types(spec, state_exps, sharing)

    if guava.value:
        self.write("import com.google.common.collect.TreeMultiset;\n")
        self.write("import com.google.common.collect.Iterators;\n")
    if spec.header:
        self.write(spec.header.strip() + "\n\n")

    if spec.docstring:
        self.write(spec.docstring + "\n")

    self.write("public class {} implements java.io.Serializable ".format(spec.name))
    with self.block():

        for name, t in spec.types:
            self.types[t] = name

        # member variables
        for name, t in spec.statevars:
            self.write("{}protected {};\n".format(INDENT, self.visit(t, name)))

        # constructor
        self.write(
            "{indent}public {name}() {{\n{indent2}clear();\n{indent}}}\n\n"
            .format(indent=INDENT, indent2=INDENT+INDENT, name=spec.name))

        # explicit constructor
        if abstract_state:
            self.begin_statement()
            self.write("public ", spec.name, "(")
            self.visit_args(abstract_state)
            self.write(") ")
            with self.block():
                for name, t in spec.statevars:
                    initial_value = state_exps[name]
                    self.visit(simplify_and_optimize(
                        SAssign(EVar(name).with_type(t), initial_value)))
            self.end_statement()

        # clear
        self.begin_statement()
        self.write("public void clear() ")
        with self.block():
            for name, t in spec.statevars:
                initial_value = state_exps[name]
                fvs = free_vars(initial_value)
                initial_value = subst(initial_value,
                    {v.id: evaluation.construct_value(v.type) for v in fvs})
                setup = simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value))
                self.visit(setup)
        self.end_statement()

        # methods
        for op in spec.methods:
            self.visit(op)

        # generate auxiliary types
        for t, name in self.types.items():
            self.define_type(spec.name, t, name, sharing)

    self.write("\n")
    self.write(spec.footer)
    if not spec.footer.endswith("\n"):
        self.write("\n")
def define_type(self, toplevel_name, t, name, sharing):
    if isinstance(t, TEnum):
        self.begin_statement()
        self.write("public enum ", name, " ")
        with self.block():
            for case in t.cases:
                self.begin_statement()
                self.write(case)
                self.end_statement()
        self.end_statement()
    elif isinstance(t, THandle) or isinstance(t, TRecord):
        public_fields = []
        private_fields = []
        value_equality = True
        handle_val_is_this = False
        if isinstance(t, THandle):
            if isinstance(t.value_type, TRecord):
                handle_val_is_this = True
            else:
                public_fields = [("val", t.value_type)]
            value_equality = False
            for group in sharing.get(t, []):
                for gt in group:
                    intrusive_data = gt.intrusive_data(t)
                    for (f, ft) in intrusive_data:
                        private_fields.append((f, ft))
        else:
            public_fields = list(t.fields)
        all_fields = public_fields + private_fields
        self.begin_statement()
        self.write("public static class ", name)
        if handle_val_is_this:
            self.write(" extends ", self.visit(t.value_type, ""))
        self.write(" implements java.io.Serializable ")
        with self.block():
            for (f, ft) in public_fields + private_fields:
                self.begin_statement()
                self.write("private {field_decl};".format(field_decl=self.visit(ft, f)))
                self.end_statement()
            for (f, ft) in public_fields:
                self.begin_statement()
                self.write("public {type} get{Field}() {{ return {field}; }}".format(
                    type=self.visit(ft, ""),
                    Field=common.capitalize(f),
                    field=f))
                self.end_statement()
            if handle_val_is_this:
                self.begin_statement()
                self.write("public {type} getVal() {{ return this; }}".format(
                    type=self.visit(t.value_type, "")))
                self.end_statement()

            def flatten(field_types):
                args = []
                exps = []
                for ft in field_types:
                    if isinstance(ft, TRecord):
                        aa, ee = flatten([t for (f, t) in ft.fields])
                        args.extend(aa)
                        exps.append(EMakeRecord(tuple(
                            (f, e) for ((f, _), e) in zip(ft.fields, ee))).with_type(ft))
                    elif isinstance(ft, TTuple):
                        aa, ee = flatten(ft.ts)
                        args.extend(aa)
                        exps.append(ETuple(tuple(ee)).with_type(ft))
                    else:
                        v = self.fv(ft, "v")
                        args.append((v.id, ft))
                        exps.append(v)
                return args, exps

            if isinstance(t, THandle):
                args, exps = flatten([ft for (f, ft) in
                    (t.value_type.fields if handle_val_is_this else public_fields)])
            else:
                args = public_fields
                exps = [EVar(f) for (f, ft) in args]
            self.begin_statement()
            self.write("public {ctor}({args}) ".format(
                ctor=name,
                args=", ".join(self.visit(ft, f) for (f, ft) in args)))
            with self.block():
                if handle_val_is_this:
                    es = [self.visit(e) for e in exps]
                    self.begin_statement()
                    self.write("super({args});\n".format(args=", ".join(es)))
                for ((f, ft), e) in zip(public_fields, exps):
                    e = self.visit(e)
                    self.begin_statement()
                    self.write("this.{f} = {e};\n".format(f=f, e=e))
            self.end_statement()

            if value_equality:
                self.begin_statement()
                self.write("@Override")
                self.end_statement()
                self.begin_statement()
                self.write("public int hashCode() ")
                with self.block():
                    (compute, hc) = self.compute_hash(public_fields + private_fields)
                    self.write(compute)
                    self.begin_statement()
                    self.write("return ", hc, ";")
                    self.end_statement()
                self.end_statement()

                self.begin_statement()
                self.write("@Override")
                self.end_statement()
                self.begin_statement()
                self.write("public boolean equals(Object other) ")
                with self.block():
                    self.write(self.get_indent(), "if (other == null) return false;\n")
                    self.write(self.get_indent(), "if (other == this) return true;\n")
                    self.write(self.get_indent(), "if (!(other instanceof {name})) return false;\n".format(name=name))
                    self.write(self.get_indent(), "{name} o = ({name})other;\n".format(name=name))
                    eq = self.visit(EAll([EEq(
                            EEscape("this.{}".format(f), (), ()).with_type(ft),
                            EEscape("o.{}".format(f), (), ()).with_type(ft))
                        for (f, ft) in all_fields]))
                    self.write(self.get_indent(), "return ", eq, ";\n")
                self.end_statement()
        self.end_statement()
    elif isinstance(t, TTuple):
        return self.define_type(
            toplevel_name,
            TRecord(tuple(("_{}".format(i), t.ts[i]) for i in range(len(t.ts)))),
            name,
            sharing)
    else:
        return ""
def context_for_method(self, m : Method) -> Context:
    """Construct a context describing expressions in the given method."""
    return RootCtx(
        state_vars=self.abstract_state,
        args=[EVar(a).with_type(t) for (a, t) in m.args],
        funcs=self.extern_funcs)
def abstract_state(self) -> [EVar]:
    """Returns the abstract state of this data structure."""
    return [EVar(name).with_type(t) for (name, t) in self.spec.statevars]
def search_for_improvements(
        targets       : [Exp],
        wf_solver     : ModelCachingSolver,
        context       : Context,
        examples      : [{str: object}],
        cost_model    : CostModel,
        stop_callback : Callable[[], bool],
        hints         : [Exp],
        ops           : [Op],
        blacklist     : {(Exp, Context, Pool, Exp): str}):
    """Search for potential improvements to any of the target expressions.

    This function yields expressions that look like improvements (or are
    ambiguous with respect to some target).  The expressions are only
    guaranteed to be correct on the given examples.

    This function may add new items to the given blacklist.
    """

    root_ctx = context
    def check_wf(e, ctx, pool):
        with task("pruning", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not is_wf:
                return is_wf
            res = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not res:
                return res
            if cost_pruning.value and pool == RUNTIME_POOL and cost_model.compare(e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    with task("setting up hints"):
        frags = list(unique(itertools.chain(
            *[all_subexpressions_with_context_information(t, root_ctx) for t in targets],
            *[all_subexpressions_with_context_information(h, root_ctx) for h in hints])))
        frags.sort(key=hint_order)

    enum = Enumerator(
        examples=examples,
        cost_model=cost_model,
        check_wf=check_wf,
        hints=frags,
        heuristics=try_optimize,
        stop_callback=stop_callback,
        do_eviction=enable_eviction.value)

    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(all_subexpressions_with_context_information(target, context=root_ctx, pool=RUNTIME_POOL)):
                l = watches_by_context.get(ctx)
                if l is None:
                    l = []
                    watches_by_context[ctx] = l
                l.append((target, e, pool))

        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                k = (fp, ctx, pool)
                l = watches.get(k)
                if l is None:
                    l = []
                    watches[k] = l
                l.append((target, e))

    watched_ctxs = list(unique(
        (ctx, pool) for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(
        context=root_ctx,
        targets=targets,
        target_fingerprint=target_fp,
        examples=examples,
        check_wf=check_wf,
        cost_model=cost_model,
        blacklist=blacklist)

    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions", ctx=ctx, pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("searching for obvious substitution", expression=pprint(info.e)):
                        fp = info.fingerprint
                        for ((fpx, cc, pp), reses) in watches.items():
                            if cc != ctx or pp != pool:
                                continue
                            if not fpx.equal_to(fp):
                                continue
                            for target, watched_e in reses:
                                replacement = info.e
                                event("possible substitution: {} ---> {}".format(
                                    pprint(watched_e), pprint(replacement)))
                                event("replacement locations: {}".format(
                                    pprint(replace(target, root_ctx, RUNTIME_POOL,
                                        watched_e, ctx, pool, EVar("___")))))
                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue
                                yield from _consider_replacement(
                                    target, watched_e, ctx, pool, replacement, search_info)

        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with task("checking substitutions",
                        target=pprint(replace(target, root_ctx, RUNTIME_POOL,
                            e, ctx, pool, EVar("___"))),
                        e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                        with task("checking substitution", expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            should_consider = should_consider_replacement(
                                target, root_ctx,
                                e, ctx, pool,
                                Fingerprint.of(e, ctx.instantiate_examples(examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event("skipped; `should_consider_replacement` returned {}".format(should_consider))
                                continue
                            yield from _consider_replacement(
                                target, e, ctx, pool, replacement, search_info)

        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL, size):
            raise StopException("no more expressions can exist above size={}".format(size))

        size += 1
def search(self):
    root_ctx = self.context
    def check_wf(e, ctx, pool):
        with task("checking well-formedness", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=self.wf_solver)
            if not is_wf:
                return is_wf
            res = good_idea_recursive(self.wf_solver, e, ctx, pool, ops=self.ops)
            if not res:
                return res
            if pool == RUNTIME_POOL and self.cost_model.compare(e, self.targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    frags = list(unique(itertools.chain(
        *[shred(t, root_ctx) for t in self.targets],
        *[shred(h, root_ctx) for h in self.hints])))
    frags.sort(key=hint_order)

    enum = Enumerator(
        examples=self.examples,
        cost_model=self.cost_model,
        check_wf=check_wf,
        hints=frags,
        heuristics=try_optimize,
        stop_callback=self.stop_callback,
        do_eviction=enable_eviction.value)

    size = 0

    target_fp = fingerprint(self.targets[0], self.examples)

    watches = OrderedDict()
    for target in self.targets:
        for e, ctx, pool in unique(shred(target, context=root_ctx, pool=RUNTIME_POOL)):
            exs = ctx.instantiate_examples(self.examples)
            fp = fingerprint(e, exs)
            k = (fp, ctx, pool)
            l = watches.get(k)
            if l is None:
                l = []
                watches[k] = l
            l.append((target, e))
    watched_ctxs = list(unique((ctx, pool) for fp, ctx, pool in watches.keys()))

    def consider_new_target(old_target, e, ctx, pool, replacement):
        nonlocal n
        n += 1
        k = (e, ctx, pool, replacement)
        if enable_blacklist.value and k in self.blacklist:
            event("blacklisted")
            print("skipping blacklisted substitution: {} ---> {} ({})".format(
                pprint(e), pprint(replacement), self.blacklist[k]))
            return
        new_target = freshen_binders(replace(
            target, root_ctx, RUNTIME_POOL,
            e, ctx, pool,
            replacement), root_ctx)
        if any(alpha_equivalent(t, new_target) for t in self.targets):
            event("already seen")
            return
        wf = check_wf(new_target, root_ctx, RUNTIME_POOL)
        if not wf:
            msg = "not well-formed [wf={}]".format(wf)
            event(msg)
            self.blacklist[k] = msg
            return
        if not fingerprints_match(fingerprint(new_target, self.examples), target_fp):
            msg = "not correct"
            event(msg)
            self.blacklist[k] = msg
            return
        if self.cost_model.compare(new_target, target, root_ctx, RUNTIME_POOL) not in (Order.LT, Order.AMBIGUOUS):
            msg = "not an improvement"
            event(msg)
            self.blacklist[k] = msg
            return
        print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
        print(" * in {}".format(pprint(old_target), pprint(e), pprint(replacement)))
        print(" * replacing {}".format(pprint(e)))
        print(" * with {}".format(pprint(replacement)))
        yield new_target

    while True:

        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if self.stop_callback():
            raise StopException()

        n = 0

        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions", ctx=ctx, pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("searching for obvious substitution", expression=pprint(info.e)):
                        fp = info.fingerprint
                        for ((fpx, cc, pp), reses) in watches.items():
                            if cc != ctx or pp != pool:
                                continue
                            if not fingerprints_match(fpx, fp):
                                continue
                            for target, watched_e in reses:
                                replacement = info.e
                                event("possible substitution: {} ---> {}".format(
                                    pprint(watched_e), pprint(replacement)))
                                event("replacement locations: {}".format(
                                    pprint(replace(target, root_ctx, RUNTIME_POOL,
                                        watched_e, ctx, pool, EVar("___")))))
                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue
                                yield from consider_new_target(
                                    target, watched_e, ctx, pool, replacement)

        if check_all_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(self.targets, root_ctx):
                with task("checking substitutions",
                        target=pprint(replace(target, root_ctx, RUNTIME_POOL,
                            e, ctx, pool, EVar("___"))),
                        e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                        with task("checking substitution", expression=pprint(info.e)):
                            if self.stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            should_consider = should_consider_replacement(
                                target, root_ctx,
                                e, ctx, pool,
                                fingerprint(e, ctx.instantiate_examples(self.examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event("skipped; `should_consider_replacement` returned {}".format(should_consider))
                                continue
                            yield from consider_new_target(target, e, ctx, pool, replacement)

        print("CONSIDERED {}".format(n))
        size += 1