def setup_types(self, spec, state_exps, sharing):
    """Rebuild `self.types`, giving a name to every handle, record, tuple or
    enum type reachable from `spec` or from the expressions in `state_exps`.

    A type that `spec.types` declares keeps its declared name.  Any other
    type gets the value returned by `self.fn("Type")` (presumably a
    fresh-name generator -- confirm).  `sharing` is accepted but unused.
    """
    self.types.clear()
    declared = dict((ty, nm) for (nm, ty) in spec.types)
    named_kinds = (THandle, TRecord, TTuple, TEnum)

    # Types of the spec first, then the types of every state expression.
    type_sources = [all_types(spec)]
    type_sources.extend(all_types(exp) for _, exp in state_exps.items())

    for ty in itertools.chain.from_iterable(type_sources):
        if ty in self.types or type(ty) not in named_kinds:
            continue
        # NOTE: self.fn("Type") is called even when `ty` has a declared name.
        fallback = self.fn("Type")
        self.types[ty] = declared.get(ty, fallback)
def setup_types(self, spec, state_exps, sharing):
    """Populate `self.types` with a name for each handle/record/tuple/enum
    type found in `spec` or in the values of `state_exps`.

    Names declared in `spec.types` win; otherwise the name comes from
    `self.fn("Type")`.  The `sharing` argument is not consulted.
    """
    self.types.clear()

    names = {}
    for name, ty in spec.types:
        names[ty] = name

    spec_types = all_types(spec)
    state_types = [all_types(exp) for _key, exp in state_exps.items()]

    for candidate in itertools.chain(spec_types, *state_types):
        if candidate in self.types:
            continue
        if type(candidate) not in [THandle, TRecord, TTuple, TEnum]:
            continue
        # the fallback is drawn before the lookup, whether or not it is used
        default_name = self.fn("Type")
        self.types[candidate] = names.get(candidate, default_name)
def compute_sharing(state_map: dict, true_types: dict) -> dict:
    """Group implementation types whose intrusive data is never live at once.

    Takes a dictionary mapping { state_var_id : state_exp } and a dictionary
    mapping { state_var_id : refined_type } and returns a dictionary
    { ht : groups } for each handle type ht.  Each group is a list of
    implementation types whose intrusive data will never be used at the
    same time.
    """
    all_ts = set(t for e in state_map.values() for t in syntax_tools.all_types(e.type))
    handle_types = set(t for t in all_ts if isinstance(t, target_syntax.THandle))

    out = {}
    for ht in handle_types:
        handle = syntax_tools.fresh_var(ht, "handle")

        # For each implementation type carrying intrusive data: the condition
        # (over `handle`) under which that data is used.
        # NOTE(review): a type reachable from several state vars keeps only
        # the condition of the last var visited -- confirm this is intended.
        usage = {}
        for var, exp in state_map.items():
            use = uses_intrusive_data(exp, handle)
            for impl_type in syntax_tools.all_types(true_types[var]):
                if hasattr(impl_type, "intrusive_data"):
                    usage[impl_type] = use

        # First-fit grouping: a type joins the first group none of whose
        # members can be in use at the same time as it (their conjunction is
        # unsatisfiable); otherwise it starts a new group.
        groups = []
        for impl_type, cond in usage.items():
            for group in groups:
                if all(not solver.satisfy(target_syntax.EAll([cond, usage[member]]))
                       for member in group):
                    group.append(impl_type)
                    break
            else:
                groups.append([impl_type])
        out[ht] = groups
    return out
def __init__(self, examples, cost_model: CostModel, check_wf=None, hints=None, heuristics=None, stop_callback=None, do_eviction=True):
    """Create an enumerator with an empty expression cache.

    Parameters:
        - examples: example inputs used to deduplicate expressions
        - cost_model: decides which of several equivalent expressions to prefer
        - check_wf: optional extra filter restricting which expressions are
          visited (default: accept everything)
        - hints: (exp, context, pool) triples treated as size 0 during
          enumeration, so they are enumerated very early
        - heuristics: optional function to improve visited expressions
          (default: yields nothing)
        - stop_callback: polled periodically; enumeration stops when it
          returns true (default: never stop)
        - do_eviction: if true, time is spent evicting older, slower versions
          of expressions from the cache
    """
    self.examples = list(examples)
    self.cost_model = cost_model
    self.cache = ExpCache()

    # (pool, size, context) tuples whose enumeration has started but not
    # finished.  Enumerating in one context can recursively require another
    # context, so this set is used to detect infinite recursion bugs.
    self.in_progress = set()

    # (pool, size, context) tuples that are fully enumerated, with results
    # stored in `self.cache`.  Disjoint from `self.in_progress`.
    self.complete = set()

    self.check_wf = check_wf if check_wf is not None else (lambda e, ctx, pool: True)

    if hints is None:
        hints = ()
    # Each hint's context is generalized to the free variables of its expression.
    generalized_hints = ((e, ctx.generalize(free_vars(e)), p) for (e, ctx, p) in hints)
    self.hints = OrderedSet(generalized_hints)

    # Every type mentioned anywhere inside a hint expression.
    self.hint_types = OrderedSet()
    for hint_exp, _ctx, _pool in self.hints:
        self.hint_types |= all_types(hint_exp)

    self.heuristics = heuristics if heuristics is not None else (lambda e, ctx, pool: ())
    self.stop_callback = stop_callback if stop_callback is not None else (lambda: False)
    self.do_eviction = do_eviction
    self.stat_timer = Periodically(self.print_stats, timespan=datetime.timedelta(seconds=2))
def desugar(spec: Spec) -> Spec:
    """Lower a parsed Spec into a simpler form.

    Steps, in order:
      1. replace variables naming enum constants with EEnumEntry nodes;
      2. copy the spec, turning each set-typed state variable into a bag and
         adding an AreUnique assumption for it;
      3. inline every call to a query defined in the spec, then drop the
         private queries;
      4. desugar list comprehensions in every expression.

    `retypecheck` is run after steps 2 and 4; an AssertionError is raised if
    it reports failure.
    """
    # rewrite enums
    repl = {
        name: EEnumEntry(name).with_type(t)
        for t in all_types(spec)
        if isinstance(t, TEnum)
        for name in t.cases}
    spec = subst(spec, repl)

    # Copy the spec's lists so the set -> bag conversion below does not
    # mutate the caller's spec.
    spec = Spec(
        spec.name,
        list(spec.types),
        list(spec.extern_funcs),
        list(spec.statevars),
        list(spec.assumptions),
        list(spec.methods),
        spec.header,
        spec.footer,
        spec.docstring)

    for i, (v, t) in enumerate(spec.statevars):
        if isinstance(t, TSet):
            # Sets become bags w/ implicit unique assumptions.
            t = TBag(t.t)
            spec.statevars[i] = (v, t)
            v = EVar(v).with_type(t)
            spec.assumptions.append(EUnaryOp(UOp.AreUnique, v).with_type(BOOL))

    # retypecheck must run outside `assert`: under `python -O` the assert
    # expression is never evaluated, which would silently skip the call
    # (presumably it also refreshes .type annotations -- confirm).
    ok = retypecheck(spec, env={})
    assert ok, "desugar: spec failed to typecheck after converting sets to bags"

    # organize queries by name
    queries = {q.name: q for q in spec.methods if isinstance(q, Query)}

    class InlineQueryCalls(BottomUpRewriter):
        # Replace a call to a spec query with the query's return expression,
        # substituting the actual arguments for its parameters.
        def visit_ECall(self, e):
            q = queries.get(e.func)
            if q is not None:
                return self.visit(
                    subst(
                        q.ret,
                        {arg_name: arg
                         for ((arg_name, ty), arg) in zip(q.args, e.args)}))
            else:
                return ECall(e.func, tuple(self.visit(a) for a in e.args)).with_type(e.type)

    spec = InlineQueryCalls().visit(spec)

    # Calls to queries were inlined above; private queries are now dropped.
    spec.methods = [
        m for m in spec.methods
        if not (isinstance(m, Query) and m.visibility == Visibility.Private)]

    class DesugarListComprehensions(BottomUpRewriter):
        def visit_Exp(self, e):
            return desugar_list_comprehensions(e)

    spec = DesugarListComprehensions().visit(spec)

    ok = retypecheck(spec, env={})
    assert ok, "desugar: spec failed to typecheck after desugaring"
    return spec
def inline_enum_constants(syntax_tree: ADT) -> ADT:
    """Replace variables that name enum constants with EEnumEntry nodes.

    An enum type introduces both a type name and a set of constant names.
    Every variable whose name is one of those constants is substituted by an
    EEnumEntry node typed with its enum.  Most other functions in Cozy assume
    this transformation has already happened, i.e. that no variable is a name
    for an enum constant.
    """
    replacements = {}
    for ty in all_types(syntax_tree):
        if not isinstance(ty, TEnum):
            continue
        for case_name in ty.cases:
            # a later enum with the same case name overrides an earlier one
            replacements[case_name] = EEnumEntry(case_name).with_type(ty)
    return subst(syntax_tree, replacements)
def enumerate_core(self, context: Context, size: int, pool: Pool) -> [Exp]:
    """Yield expressions of the given size legal in the given context and pool.

    Arguments:
        context : a Context object describing the vars in scope
        size    : size to enumerate
        pool    : pool to enumerate

    Besides the grammar below, everything produced by
    `self.heuristic_enumeration` for this (context, size, pool) is yielded
    too.  For compound expressions, the sizes of the direct subexpressions
    are chosen with `pick_to_sum(k, size - 1)`, so they add up to size - 1.
    """

    if size < 0:
        return

    if size == 0:
        for (e, p) in LITERALS:
            if p == pool:
                yield e
        for (v, p) in context.vars():
            if p == pool:
                yield v
            # default values for every type mentioned by the variable
            for t in all_types(v):
                yield construct_value(t)
        for (e, ctx, p) in self.hints:
            if p == pool and ctx.alpha_equivalent(context):
                yield context.adapt(e, ctx)
            for t in all_types(e):
                yield construct_value(t)
        return

    yield from self.heuristic_enumeration(context, size, pool)

    # Arity-1 expressions
    for e in collections(self.enumerate(context, size - 1, pool)):
        yield EEmptyList().with_type(e.type)
        if is_numeric(e.type.t):
            yield EUnaryOp(UOp.Sum, e).with_type(e.type.t)

    for e in self.enumerate(context, size - 1, pool):
        yield ESingleton(e).with_type(TBag(e.type))

    for e in self.enumerate(context, size - 1, pool):
        if isinstance(e.type, TRecord):
            for (f, t) in e.type.fields:
                yield EGetField(e, f).with_type(t)

    for e in self.enumerate(context, size - 1, pool):
        if isinstance(e.type, THandle):
            yield EGetField(e, "val").with_type(e.type.value_type)

    for e in self.enumerate(context, size - 1, pool):
        if isinstance(e.type, TTuple):
            for n in range(len(e.type.ts)):
                yield ETupleGet(e, n).with_type(e.type.ts[n])

    for e in of_type(self.enumerate(context, size - 1, pool), BOOL):
        yield EUnaryOp(UOp.Not, e).with_type(BOOL)

    for e in self.enumerate(context, size - 1, pool):
        if is_numeric(e.type):
            yield EUnaryOp("-", e).with_type(e.type)

    for m in self.enumerate(context, size - 1, pool):
        if isinstance(m.type, TMap):
            yield EMapKeys(m).with_type(TBag(m.type.k))

    # Arity-2 expressions: sz1 + sz2 = size - 1
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for a1 in self.enumerate(context, sz1, pool):
            t = a1.type
            if not is_numeric(t):
                continue
            for a2 in of_type(self.enumerate(context, sz2, pool), t):
                yield EBinOp(a1, "+", a2).with_type(t)
                yield EBinOp(a1, "-", a2).with_type(t)
                yield EBinOp(a1, ">", a2).with_type(BOOL)
                yield EBinOp(a1, "<", a2).with_type(BOOL)
                yield EBinOp(a1, ">=", a2).with_type(BOOL)
                yield EBinOp(a1, "<=", a2).with_type(BOOL)
        for a1 in collections(self.enumerate(context, sz1, pool)):
            for a2 in of_type(self.enumerate(context, sz2, pool), a1.type):
                yield EBinOp(a1, "+", a2).with_type(a1.type)
                yield EBinOp(a1, "-", a2).with_type(a1.type)
            for a2 in of_type(self.enumerate(context, sz2, pool), a1.type.t):
                yield EBinOp(a2, BOp.In, a1).with_type(BOOL)
        for a1 in of_type(self.enumerate(context, sz1, pool), BOOL):
            for a2 in of_type(self.enumerate(context, sz2, pool), BOOL):
                yield EBinOp(a1, BOp.And, a2).with_type(BOOL)
                yield EBinOp(a1, BOp.Or, a2).with_type(BOOL)
        for a1 in self.enumerate(context, sz1, pool):
            if not isinstance(a1.type, TMap):
                for a2 in of_type(self.enumerate(context, sz2, pool), a1.type):
                    yield EEq(a1, a2)
                    yield EBinOp(a1, "!=", a2).with_type(BOOL)
        for m in self.enumerate(context, sz1, pool):
            if isinstance(m.type, TMap):
                for k in of_type(self.enumerate(context, sz2, pool), m.type.k):
                    yield EMapGet(m, k).with_type(m.type.v)
                    yield EHasKey(m, k).with_type(BOOL)
        for l in self.enumerate(context, sz1, pool):
            if not isinstance(l.type, TList):
                continue
            for i in of_type(self.enumerate(context, sz2, pool), INT):
                yield EListGet(l, i).with_type(l.type.t)

    # Arity-3 expressions: sz1 + sz2 + sz3 = size - 1
    for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
        for cond in of_type(self.enumerate(context, sz1, pool), BOOL):
            for then_branch in self.enumerate(context, sz2, pool):
                # The else branch must have size sz3 (previously sz2), so the
                # conditional has size 1 + sz1 + sz2 + sz3 == size.
                for else_branch in of_type(
                        self.enumerate(context, sz3, pool), then_branch.type):
                    yield ECond(cond, then_branch, else_branch).with_type(then_branch.type)
        for l in self.enumerate(context, sz1, pool):
            if not isinstance(l.type, TList):
                continue
            for st in of_type(self.enumerate(context, sz2, pool), INT):
                for ed in of_type(self.enumerate(context, sz3, pool), INT):
                    yield EListSlice(l, st, ed).with_type(l.type)

    # Whole-collection queries
    for bag in collections(self.enumerate(context, size - 1, pool)):
        # len of bag
        count = EUnaryOp(UOp.Length, bag).with_type(INT)
        yield count
        # empty?
        yield EUnaryOp(UOp.Empty, bag).with_type(BOOL)
        # exists?
        yield EUnaryOp(UOp.Exists, bag).with_type(BOOL)
        # singleton?
        yield EEq(count, ONE)

        yield EUnaryOp(UOp.The, bag).with_type(bag.type.t)
        yield EUnaryOp(UOp.Distinct, bag).with_type(bag.type)
        yield EUnaryOp(UOp.AreUnique, bag).with_type(BOOL)

        if bag.type.t == BOOL:
            yield EUnaryOp(UOp.Any, bag).with_type(BOOL)
            yield EUnaryOp(UOp.All, bag).with_type(BOOL)

    def build_lambdas(bag, pool, body_size):
        # One ELambda per body of size `body_size`, enumerated under a binder
        # for a fresh variable ranging over the elements of `bag`.
        v = fresh_var(bag.type.t, omit=set(v for v, p in context.vars()))
        inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
        for lam_body in self.enumerate(inner_context, body_size, pool):
            yield ELambda(v, lam_body)

    # Iteration
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for bag in collections(self.enumerate(context, sz1, pool)):
            for lam in build_lambdas(bag, pool, sz2):
                body_type = lam.body.type
                yield EMap(bag, lam).with_type(TBag(body_type))
                if body_type == BOOL:
                    yield EFilter(bag, lam).with_type(bag.type)
                if is_numeric(body_type):
                    yield EArgMin(bag, lam).with_type(bag.type.t)
                    yield EArgMax(bag, lam).with_type(bag.type.t)
                if is_collection(body_type):
                    yield EFlatMap(bag, lam).with_type(TBag(body_type.t))

    # Enable use of a state-pool expression at runtime
    if pool == RUNTIME_POOL:
        for e in self.enumerate(context, size - 1, STATE_POOL):
            yield EStateVar(e).with_type(e.type)

    # Create maps
    if pool == STATE_POOL:
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in collections(self.enumerate(context, sz1, STATE_POOL)):
                if not is_scalar(bag.type.t):
                    continue
                for lam in build_lambdas(bag, STATE_POOL, sz2):
                    t = TMap(bag.type.t, lam.body.type)
                    m = EMakeMap2(bag, lam).with_type(t)
                    yield m
def improve_implementation(impl: Implementation,
                           timeout: datetime.timedelta = datetime.timedelta(seconds=60),
                           progress_callback=None) -> Implementation:
    """Improve the query implementations of `impl` using background jobs.

    One ImproveQueryJob is run per query spec.  Solutions arrive on a shared
    queue; for each query only the most recent solution in a batch is kept
    and installed with `impl.set_impl`.  After each batch, `impl.cleanup()`
    runs, `progress_callback` (if given) receives
    `(impl, impl.code, impl.concretization_functions)`, and the set of
    running jobs is brought back in line with `impl.query_specs`.

    The loop ends when every job is done or `timeout` expires.  Then all
    remaining jobs are stopped and the improved copy is returned.
    """
    start_time = datetime.datetime.now()

    # `impl` is mutated below, so work on a copy of its containers.
    impl = Implementation(
        impl.spec,
        list(impl.concrete_state),
        list(impl.query_specs),
        OrderedDict(impl.query_impls),
        defaultdict(SNoOp, impl.updates),
        defaultdict(SNoOp, impl.handle_updates))

    # gather root types
    types = list(all_types(impl.spec))
    basic_types = set(t for t in types if is_scalar(t))
    basic_types |= {BOOL, INT}
    print("basic types:")
    for ty in basic_types:
        print(" --> {}".format(pprint(ty)))
    basic_types = list(basic_types)
    ctx = SynthCtx(all_types=types, basic_types=basic_types)

    # the background worker jobs
    improvement_jobs = []

    with jobs.SafeQueue() as solutions_q:

        def stop_jobs(js):
            # Stop the given jobs and forget about them.
            js = list(js)
            jobs.stop_jobs(js)
            for job in js:
                improvement_jobs.remove(job)

        def make_callback(query):
            # Bind `query` now; the callback forwards each solution of that
            # query's job to the shared queue.
            return lambda new_rep, new_ret: solutions_q.put((query, new_rep, new_ret))

        def reconcile_jobs():
            # Start a job for each query that lacks one ...
            running_names = set(job.q.name for job in improvement_jobs)
            fresh_jobs = []
            for query in impl.query_specs:
                if query.name in running_names:
                    continue
                fresh_jobs.append(ImproveQueryJob(
                    ctx,
                    impl.abstract_state,
                    list(impl.spec.assumptions) + list(query.assumptions),
                    query,
                    k=make_callback(query),
                    hints=[EStateVar(c).with_type(c.type)
                           for c in impl.concretization_functions.values()]))

            # ... and stop jobs whose query no longer exists.
            live_names = set(query.name for query in impl.query_specs)
            stale_jobs = [job for job in improvement_jobs if job.q.name not in live_names]

            stop_jobs(stale_jobs)
            for job in fresh_jobs:
                job.start()
            improvement_jobs.extend(fresh_jobs)

        reconcile_jobs()

        timeout = Timeout(timeout)
        done = False
        while not done and not timeout.is_timed_out():
            for job in improvement_jobs:
                if not job.done:
                    continue
                if job.successful:
                    job.join()
                else:
                    print("failed job: {}".format(job), file=sys.stderr)
            done = all(job.done for job in improvement_jobs)

            try:
                # list of (Query, new_rep, new_ret) tuples
                results = solutions_q.drain(block=True, timeout=0.5)
            except Empty:
                continue

            # Keep only the latest (i.e. best) solution per query name.
            print("updating with {} new solutions".format(len(results)))
            latest_by_name = OrderedDict()
            killed = 0
            for entry in results:
                query, _rep, _ret = entry
                if query.name in latest_by_name:
                    killed += 1
                latest_by_name[query.name] = entry
            if killed:
                print(" --> dropped {} worse solutions".format(killed))
            improvements = list(latest_by_name.values())

            # Apply updates in the order their queries appear in
            # impl.query_specs (first match wins; unknown names sort first).
            position = {}
            for idx, qq in enumerate(impl.query_specs):
                position.setdefault(qq.name, idx)
            improvements.sort(key=lambda entry: position.get(entry[0].name, -1))

            print("update order:")
            for (query, _, _) in improvements:
                print(" --> {}".format(query.name))

            # update query implementations
            total = len(improvements)
            for idx, (q, new_rep, new_ret) in enumerate(improvements, start=1):
                print("considering update {}/{}...".format(idx, total))
                # this guard might be false if a better solution was
                # enqueued but the job has already been cleaned up
                if not any(qq.name == q.name for qq in impl.query_specs):
                    continue
                elapsed = datetime.datetime.now() - start_time
                rep_size = sum(proj.size() for (v, proj) in new_rep)
                print("SOLUTION FOR {} AT {} [size={}]".format(
                    q.name, elapsed, new_ret.size() + rep_size))
                print("-" * 40)
                for (sv, proj) in new_rep:
                    print(" {} : {} = {}".format(sv.id, pprint(sv.type), pprint(proj)))
                print(" return {}".format(pprint(new_ret)))
                print("-" * 40)
                impl.set_impl(q, new_rep, new_ret)

            # clean up
            impl.cleanup()
            if progress_callback is not None:
                progress_callback((impl, impl.code, impl.concretization_functions))
            reconcile_jobs()

        # stop jobs
        print("Stopping jobs")
        stop_jobs(list(improvement_jobs))
        return impl
def _enumerate_core(self, context: Context, size: int, pool: Pool) -> [Exp]:
    """Build new expressions of the given size.

    Arguments:
        context : a Context object describing the vars in scope
        size    : size of expressions to enumerate; each expression in the
                  output will have this size
        pool    : pool to enumerate

    This function is not cached.  Clients should call `enumerate` instead.

    This function tries to be a clean description of the Cozy grammar.  It
    does not concern itself with deduplication (which is handled efficiently
    by equivalence class deduplication).

    Expressions are yielded lazily, in the order of the statements below.
    NOTE(review): downstream deduplication may depend on that order (e.g.
    which of several equivalent expressions is seen first); confirm before
    reordering anything here.
    """

    if size < 0:
        return

    if size == 0:
        # Size-0 expressions: literals, default values of "interesting" types,
        # variables of this pool, and hints whose context matches.
        for e in LITERALS:
            yield e

        # Interesting types: every type in any hint, plus every type mentioned
        # by the type of any in-scope variable (regardless of its pool).
        all_interesting_types = OrderedSet(self.hint_types)
        for v, _ in context.vars():
            all_interesting_types |= all_types(v.type)
        for t in all_interesting_types:
            l = construct_value(t)
            # skip values already produced by the LITERALS loop above
            if l not in LITERALS:
                yield l

        for (v, p) in context.vars():
            if p == pool:
                yield v

        # A hint is usable when its (generalized) context is alpha-equivalent
        # to this context generalized over the hint's free variables.
        for (e, ctx, p) in self.hints:
            if p == pool:
                fvs = free_vars(e)
                if ctx.alpha_equivalent(context.generalize(fvs)):
                    yield context.adapt(e, ctx, e_fvs=fvs)

        return

    # When do_enumerate is off, nothing beyond size 0 is produced here.
    # (do_enumerate is defined elsewhere -- presumably an option flag; confirm.)
    if not do_enumerate.value:
        return

    def build_lambdas(bag, pool, body_size):
        # Yield one ELambda per body of size `body_size`, enumerated under a
        # binder for a fresh variable ranging over the elements of `bag`.
        v = fresh_var(bag.type.elem_type, omit=set(v for v, p in context.vars()))
        inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
        for lam_body in self.enumerate(inner_context, body_size, pool):
            yield ELambda(v, lam_body)

    # Load all smaller expressions in this context and pool.
    # cache[S] contains expressions of size S in this context and pool.
    cache = [list(self.enumerate(context, sz, pool)) for sz in range(size)]

    # Enable use of a state-pool expression at runtime.  State expressions
    # are enumerated in the root of the context, not the current one.
    if pool == RUNTIME_POOL:
        for e in self.enumerate(context.root(), size - 1, STATE_POOL):
            yield EStateVar(e).with_type(e.type)

    # Arity-1 expressions
    for e in cache[size - 1]:
        if is_collection(e.type):
            elem_type = e.type.elem_type

            # This method of generating EEmptyList() ensures that we visit
            # empty collections of all possible types.
            yield EEmptyList().with_type(e.type)

            if is_numeric(elem_type):
                yield EUnaryOp(UOp.Sum, e).with_type(elem_type)

            yield EUnaryOp(UOp.Length, e).with_type(INT)
            yield EUnaryOp(UOp.Empty, e).with_type(BOOL)
            yield EUnaryOp(UOp.Exists, e).with_type(BOOL)
            yield EUnaryOp(UOp.The, e).with_type(elem_type)
            yield EUnaryOp(UOp.Distinct, e).with_type(e.type)
            yield EUnaryOp(UOp.AreUnique, e).with_type(BOOL)

            if elem_type == BOOL:
                yield EUnaryOp(UOp.Any, e).with_type(BOOL)
                yield EUnaryOp(UOp.All, e).with_type(BOOL)

        yield ESingleton(e).with_type(TBag(e.type))

        if isinstance(e.type, TRecord):
            for (f, t) in e.type.fields:
                yield EGetField(e, f).with_type(t)

        if isinstance(e.type, THandle):
            yield EGetField(e, "val").with_type(e.type.value_type)

        if isinstance(e.type, TTuple):
            for n in range(len(e.type.ts)):
                yield ETupleGet(e, n).with_type(e.type.ts[n])

        if e.type == BOOL:
            yield EUnaryOp(UOp.Not, e).with_type(BOOL)

        if is_numeric(e.type):
            yield EUnaryOp("-", e).with_type(e.type)

        if isinstance(e.type, TMap):
            yield EMapKeys(e).with_type(TBag(e.type.k))

    # Arity-2 expressions
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        # sz1 + sz2 = size - 1
        for e1 in cache[sz1]:
            t = e1.type

            if is_numeric(t):
                for a2 in of_type(cache[sz2], t):
                    yield EBinOp(e1, "+", a2).with_type(t)
                    yield EBinOp(e1, "-", a2).with_type(t)

            if is_ordered(t):
                for a2 in of_type(cache[sz2], t):
                    yield EBinOp(e1, ">", a2).with_type(BOOL)
                    yield EBinOp(e1, "<", a2).with_type(BOOL)
                    yield EBinOp(e1, ">=", a2).with_type(BOOL)
                    yield EBinOp(e1, "<=", a2).with_type(BOOL)

            if t == BOOL:
                for a2 in of_type(cache[sz2], BOOL):
                    yield EBinOp(e1, BOp.And, a2).with_type(BOOL)
                    yield EBinOp(e1, BOp.Or, a2).with_type(BOOL)
                # Cozy supports the implication operator "=>", but this
                # function does not enumerate it because
                #  - (a => b) is equivalent to ((not a) or b)
                #  - there isn't an implication operator in any of our
                #    current target languages, so we would need to
                #    desugar it to ((not a) or b) anyway.

            # Equality is enumerated for every type except maps.
            if not isinstance(t, TMap):
                for a2 in of_type(cache[sz2], t):
                    yield EEq(e1, a2)
                    yield EBinOp(e1, "!=", a2).with_type(BOOL)

            if isinstance(t, TMap):
                for k in of_type(cache[sz2], t.k):
                    yield EMapGet(e1, k).with_type(t.v)
                    yield EHasKey(e1, k).with_type(BOOL)

            if isinstance(t, TList):
                for i in of_type(cache[sz2], INT):
                    yield EListGet(e1, i).with_type(e1.type.elem_type)

            if is_collection(t):
                elem_type = t.elem_type
                for e2 in of_type(cache[sz2], t):
                    yield EBinOp(e1, "+", e2).with_type(t)
                    yield EBinOp(e1, "-", e2).with_type(t)
                for e2 in of_type(cache[sz2], elem_type):
                    yield EBinOp(e2, BOp.In, e1).with_type(BOOL)
                # Higher-order operations: the lambda body has size sz2.
                for f in build_lambdas(e1, pool, sz2):
                    body_type = f.body.type
                    yield EMap(e1, f).with_type(TBag(body_type))
                    if body_type == BOOL:
                        yield EFilter(e1, f).with_type(t)
                    if is_numeric(body_type):
                        yield EArgMin(e1, f).with_type(elem_type)
                        yield EArgMax(e1, f).with_type(elem_type)
                    if is_collection(body_type):
                        yield EFlatMap(e1, f).with_type(
                            TBag(body_type.elem_type))
                    # Maps are only built in the state pool, and only over
                    # hashable keys.
                    if pool == STATE_POOL and is_hashable(elem_type):
                        yield EMakeMap2(e1, f).with_type(
                            TMap(elem_type, body_type))

            # ELet binds e1: the binder ranges over the singleton bag {e1}.
            e1_singleton = ESingleton(e1).with_type(TBag(e1.type))
            for f in build_lambdas(e1_singleton, pool, sz2):
                yield ELet(e1, f).with_type(f.body.type)

    # Arity-3 expressions
    for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
        # sz1 + sz2 + sz3 = size - 1
        for e1 in cache[sz1]:
            if e1.type == BOOL:
                cond = e1
                for then_branch in cache[sz2]:
                    for else_branch in of_type(cache[sz3], then_branch.type):
                        yield ECond(cond, then_branch, else_branch).with_type(
                            then_branch.type)

            if isinstance(e1.type, TList):
                for start in of_type(cache[sz2], INT):
                    for end in of_type(cache[sz3], INT):
                        yield EListSlice(e1, start, end).with_type(e1.type)
                # It is not necessary to create slice expressions of
                # the form a[:i] or a[i:].  Those are desugared
                # after parsing to a[0:i] and a[i:len(a)]
                # respectively, and Cozy is perfectly capable of
                # discovering these expanded forms as well.

    # Extension handlers contribute extra grammar productions of this size.
    for h in all_extension_handlers():
        yield from h.enumerate(context, size, pool, self.enumerate, build_lambdas)