def check_wf(e, ctx, pool):
    """Decide whether candidate expression `e` may be used in (`ctx`, `pool`).

    Returns True when every check passes. Otherwise it returns a falsy
    value for the first failed check:
      - a `No` built from the `ExpIsNotWf` raised by `exp_wf`,
      - the falsy result of `good_idea` for some shredded subexpression, or
      - `No("too expensive")` when `pool` is RUNTIME_POOL and the cost model
        orders `e` as greater than `self.targets[0]`.

    NOTE(review): `self` is a free variable here, so this function is
    expected to be defined inside a method that provides it.
    """
    with task("checking well-formedness", size=e.size()):
        try:
            exp_wf(e, pool=pool, context=ctx,
                   assumptions=self.assumptions, solver=self.wf_solver)
        except ExpIsNotWf as failure:
            location = pprint(failure.offending_subexpression)
            return No("at {}: {}".format(location, failure.reason))

        # Each subexpression must individually be a "good idea".
        for part, part_ctx, part_pool in shred(e, ctx, pool):
            verdict = good_idea(self.wf_solver, part, part_ctx, part_pool,
                                assumptions=self.assumptions)
            if not verdict:
                return verdict

        # Only runtime expressions are checked against the target's cost.
        if pool != RUNTIME_POOL:
            return True
        if self.cost_model.compare(e, self.targets[0], ctx, pool) == Order.GT:
            return No("too expensive")
        return True
def test_let(self):
    """`x` is shredded exactly once from `ELet(ZERO, ELambda(x, x))`, and
    replacing that occurrence with ZERO gives `ELet(ZERO, ELambda(x, ZERO))`."""
    let_exp = ELet(ZERO, ELambda(x, x))
    top_ctx = RootCtx(args=(), state_vars=())
    assert retypecheck(let_exp)
    occurrences = 0
    for sub, sub_ctx, sub_pool in shred(let_exp, top_ctx, RUNTIME_POOL):
        if sub == x:
            rewritten = replace(let_exp, top_ctx, RUNTIME_POOL,
                                x, sub_ctx, sub_pool,
                                ZERO)
            assert rewritten == ELet(ZERO, ELambda(x, ZERO))
            occurrences += 1
    assert occurrences == 1
def test_pool_affects_alpha_equivalence(self):
    """The context in which ONE is shredded depends on the starting pool.

    Shredding the same expression from RUNTIME_POOL and from STATE_POOL
    must give contexts for ONE that are neither equal nor alpha-equivalent.
    """
    e = EMap(EEmptyList().with_type(INT_BAG), ELambda(x, ONE))
    root_ctx = RootCtx(args=(), state_vars=())
    assert retypecheck(e)

    def only_ctx_of_one(start_pool):
        # Exactly one shredded occurrence of ONE is expected per pool.
        found = [sub_ctx
                 for sub, sub_ctx, _ in shred(e, root_ctx, start_pool)
                 if sub == ONE]
        assert len(found) == 1
        return found[0]

    runtime_ctx = only_ctx_of_one(RUNTIME_POOL)
    state_ctx = only_ctx_of_one(STATE_POOL)
    assert runtime_ctx != state_ctx
    assert not runtime_ctx.alpha_equivalent(state_ctx)
def exp_wf(e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = T, solver = None):
    """Check that `e` is well-formed in `context` and `pool`.

    Every subexpression produced by `shred(e, context, pool)` is checked with
    `exp_wf_nonrecursive`. The check uses `assumptions` adapted from
    `context` to that subexpression's own context.

    Parameters:
        e           - the expression to check
        context     - the context describing e's variables
        pool        - the pool e lives in
        assumptions - facts assumed true when e begins executing
        solver      - solver to use; a fresh ModelCachingSolver is made if None

    Returns True if every check passes.

    Raises ExpIsNotWf(e, offending_subexpression, reason) for the first
    failing subexpression. The ExpIsNotWf raised by the non-recursive check
    is chained as its cause.
    """
    if solver is None:
        solver = ModelCachingSolver(vars=[], funcs={})
    for sub, sub_ctx, sub_pool in shred(e, context, pool):
        try:
            exp_wf_nonrecursive(solver, sub, sub_ctx, sub_pool,
                                assumptions=sub_ctx.adapt(assumptions, context))
        except ExpIsNotWf as exc:
            # Report the failure against the top-level expression, keeping
            # the inner error as the explicit cause for tracebacks.
            raise ExpIsNotWf(e, sub, exc.reason) from exc
    return True
def test_estatevar_ctx(self):
    """A variable under EStateVar is shredded in a RootCtx and can be replaced
    there (as a STATE_POOL expression) from the top-level runtime expression."""
    xs = EVar("xs").with_type(INT_BAG)
    x = EVar("x").with_type(INT)
    y = EVar("y").with_type(BOOL)
    e = EMap(xs, ELambda(x, EStateVar(y)))
    # Kept under its own name: the shred loop below must not shadow it,
    # because it is still needed as the root context for `replace`.
    root_ctx = RootCtx(args=(xs, ), state_vars=(y, ))
    assert retypecheck(e)
    y_ctxs = []
    for ee, ctx, pool in shred(e, root_ctx):
        if ee == y:
            assert isinstance(ctx, RootCtx)
            y_ctxs.append(ctx)
    # Guard against a vacuous pass: `y` must actually be found, exactly once.
    assert len(y_ctxs) == 1
    e = replace(e, root_ctx, RUNTIME_POOL, y, y_ctxs[0], STATE_POOL, T)
    assert e == EMap(xs, ELambda(x, EStateVar(T))), pprint(e)
def good_idea_recursive(solver, e: Exp, context: Context, pool=RUNTIME_POOL, assumptions: Exp = T, ops: [Op] = ()) -> bool:
    """Apply `good_idea` to every subexpression of `e` yielded by `shred`.

    Stops at the first subexpression whose `good_idea` verdict is falsy and
    returns that verdict unchanged (despite the `bool` annotation). Returns
    True if every subexpression passes.
    """
    verdicts = (
        good_idea(solver, sub_e, sub_ctx, sub_pool,
                  assumptions=assumptions, ops=ops)
        for sub_e, sub_ctx, sub_pool in shred(e, context, pool))
    for verdict in verdicts:
        if not verdict:
            return verdict
    return True
def fold_into_map(e, context):
    """Yield rewrites of `e` that move its computation into a state map.

    For each runtime subexpression `EStateVar(m)[key]` of `e`, `e` is
    abstracted over that lookup as

        func = λarg. e[EStateVar(m)[key] := arg]

    If every free variable of `func` is a state variable of `context`, the
    candidate

        EStateVar(map_values(m, func'))[key]

    is yielded. Here `func'` is `func` after `strip_EStateVar`. Each
    candidate is given the type of `e`.

    Lookups whose `key` mentions a variable that is not one of `context`'s
    variables (e.g. a lambda parameter bound inside `e`) are skipped. The
    key is hoisted to the top level of the result, where such a variable
    would be unbound.

    Parameters:
        e       - a runtime expression (its `.type` is read)
        context - the context `e` lives in
    """
    fvs = free_vars(e)
    context_vars = list(context.vars())
    state_vars = [v for v, p in context_vars if p == STATE_POOL]
    # Variables that may legally appear in the hoisted key.
    # NOTE(review): a binder inside `e` that shadows a context variable by
    # name would still pass this check; callers are assumed to have
    # freshened binders.
    in_scope_vars = [v for v, p in context_vars]
    for subexp, subcontext, subpool in shred(e, context, RUNTIME_POOL):
        if not (isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar)):
            continue
        # e is of the form `... EStateVar(state_map)[key] ...`
        state_map = subexp.map.e
        key = subexp.key
        if not all(v in in_scope_vars for v in free_vars(key)):
            continue
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(
            arg,
            replace(e, context, RUNTIME_POOL, subexp, subcontext, subpool, arg))
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(state_map, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
def exploration_order(targets: [Exp], context: Context, pool: Pool = RUNTIME_POOL):
    """Order in which subexpressions of `targets` are explored for improvements.

    Yields (target, subexpression, subcontext, subpool) tuples. Targets are
    visited in the given order. Within a target, the distinct shredded
    subexpressions are sorted by, in decreasing priority:
      1. runtime-pool subexpressions before all others,
      2. lower `ctx.complexity()` first,
      3. smaller `e.size()` first.
    """
    def priority(entry):
        sub, sub_ctx, sub_pool = entry
        pool_rank = 0 if sub_pool == RUNTIME_POOL else 1
        return (pool_rank, sub_ctx.complexity(), sub.size())

    for target in targets:
        candidates = list(unique(shred(target, context, pool=pool)))
        candidates.sort(key=priority)
        for sub, sub_ctx, sub_pool in candidates:
            yield (target, sub, sub_ctx, sub_pool)
def exp_wf(e: Exp, context: Context, pool=RUNTIME_POOL, assumptions: Exp = T, solver=None):
    """Check the well-formedness of `e`.

    Parameters:
        e           - an expression to check
        context     - a context describing e's variables
        pool        - what pool e lives in
        assumptions - facts that are true whenever e begins executing
                      (NOTE: this does NOT need to include the path
                      conditions from the context, but it is fine if it does.)
        solver      - a ModelCachingSolver to use for solving formulas; a
                      fresh one is created when None

    Returns True when every subexpression passes `exp_wf_nonrecursive`.
    Otherwise it returns a falsy value for the first failing subexpression:
    an `ExpIsNotWf(e, subexpression, msg)` when the non-recursive check
    returned a `No`, or else that check's falsy result itself.

    This function requires that:
     - all free variables in `e` are used in the correct pool
     - EStateVar only occurs in runtime expressions
    """
    checker = solver if solver is not None else ModelCachingSolver(vars=[], funcs={})
    for sub, sub_ctx, sub_pool in shred(e, context, pool):
        local_assumptions = sub_ctx.adapt(assumptions, context)
        verdict = exp_wf_nonrecursive(checker, sub, sub_ctx, sub_pool,
                                      assumptions=local_assumptions)
        if verdict:
            continue
        if isinstance(verdict, No):
            return ExpIsNotWf(e, sub, verdict.msg)
        return verdict
    return True
def test_shred_minheap(self):
    """Shredding an EMakeMinHeap expression must run to completion without raising."""
    key_fn = ELambda(x, x)
    empty = EEmptyList().with_type(INT_BAG)
    heap = EMakeMinHeap(empty, key_fn).with_type(TMinHeap(INT, key_fn))
    root = RootCtx(args=(), state_vars=())
    # Exhaust the generator; only the absence of an exception is checked.
    for _ in shred(heap, root):
        pass
def next(self):
    """Generator yielding candidate improvements of the targets in `self.targets`.

    Each minor iteration enumerates expressions of size `size` (0, 1, 2, ...).
    Each one is tried as a replacement for every subexpression of every
    target, in `exploration_order`. A substituted target `ee` is yielded
    when all of the following hold:
      - it is not alpha-equivalent to any current target,
      - its fingerprint on `self.examples` matches that of
        `self.targets[0]` (via `self.matches`),
      - it passes `check_wf`,
      - the cost model orders it as Order.LT or Order.AMBIGUOUS against
        the target it came from.

    Raises StopException as soon as `self.stop_callback()` is truthy.
    """
    # Falsy verdict carrying a human-readable reason for a rejection.
    class No(object):
        def __init__(self, msg):
            self.msg = msg
        def __bool__(self):
            return False
        def __str__(self):
            return "no: {}".format(self.msg)

    # with task("pre-computing cardinalities"):
    #     cards = [self.cost_model.cardinality(ctx.e) for ctx in enumerate_fragments(self.target) if is_collection(ctx.e.type)]

    root_ctx = self.context

    # Returns True, or a falsy `No` explaining why `e` is unusable in
    # (ctx, pool): not well-formed, or (for runtime expressions) ordered as
    # more expensive than `self.targets[0]`.
    def check_wf(e, ctx, pool):
        with task("checking well-formedness", size=e.size()):
            try:
                exp_wf(e, pool=pool, context=ctx, assumptions=self.assumptions, solver=self.wf_solver)
            except ExpIsNotWf as exc:
                return No("at {}: {}".format(
                    pprint(exc.offending_subexpression),
                    exc.reason))
            if pool == RUNTIME_POOL and self.cost_model.compare(e, self.targets[0], ctx, pool) == Order.GT:
                # from cozy.cost_model import debug_comparison
                # debug_comparison(self.cost_model, e, self.target, ctx)
                return No("too expensive")
            # if isinstance(e.type, TBag):
            #     c = self.cost_model.cardinality(e)
            #     if all(cc < c for cc in cards):
            #         # print("too big: {}".format(pprint(e)))
            #         return No("too big")
            return True

    # Distinct shredded (subexpression, context, pool) triples of all targets
    # and hints, handed to the enumerator as hints.
    frags = list(unique(itertools.chain(
        *[shred(t, root_ctx) for t in self.targets],
        *[shred(h, root_ctx) for h in self.hints])))
    enum = Enumerator(
        examples=self.examples,
        cost_model=self.cost_model,
        check_wf=check_wf,
        hints=frags,
        heuristics=try_optimize,
        stop_callback=self.stop_callback)

    size = 0
    # target_cost = self.cost_model.cost(self.target, RUNTIME_POOL)
    # Every candidate must agree with the first target on the examples.
    target_fp = fingerprint(self.targets[0], self.examples)

    # Rejected (e, ctx, pool, replacement) substitutions. The set is created
    # lazily and stored on `self`, so it persists across calls.
    if not hasattr(self, "blacklist"):
        self.blacklist = set()

    while True:
        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if self.stop_callback():
            raise StopException()
        n = 0  # substitutions that reached the correctness/wf/cost checks
        for target, e, ctx, pool in exploration_order(self.targets, root_ctx):
            with task("checking substitutions",
                    target=pprint(replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, EVar("___"))),
                    e=pprint(e)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("checking substitution", expression=pprint(info.e)):
                        if self.stop_callback():
                            raise StopException()
                        if info.e.type != e.type:
                            event("wrong type (is {}, need {})".format(pprint(info.e.type), pprint(e.type)))
                            continue
                        if alpha_equivalent(info.e, e):
                            event("no change")
                            continue

                        k = (e, ctx, pool, info.e)
                        if k in self.blacklist:
                            event("blacklisted")
                            continue

                        n += 1
                        ee = freshen_binders(replace(
                            target, root_ctx, RUNTIME_POOL,
                            e, ctx, pool,
                            info.e), root_ctx)
                        if any(alpha_equivalent(t, ee) for t in self.targets):
                            event("already seen")
                            continue
                        if not self.matches(fingerprint(ee, self.examples), target_fp):
                            event("incorrect")
                            self.blacklist.add(k)
                            continue
                        wf = check_wf(ee, root_ctx, RUNTIME_POOL)
                        if not wf:
                            event("not well-formed [wf={}]".format(wf))
                            # if "expensive" in str(wf):
                            #     print(repr(self.cost_model.examples))
                            #     print(repr(ee))
                            self.blacklist.add(k)
                            continue
                        if self.cost_model.compare(ee, target, root_ctx, RUNTIME_POOL) not in (Order.LT, Order.AMBIGUOUS):
                            event("not an improvement")
                            self.blacklist.add(k)
                            continue

                        print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
                        yield ee

        print("CONSIDERED {}".format(n))
        size += 1

    # NOTE(review): the `while True` loop above contains no `break`, so this
    # line is unreachable. The generator only ends via StopException or when
    # the consumer stops iterating.
    raise NoMoreImprovements()
def search(self):
    """Generator yielding candidate improvements of the targets in `self.targets`.

    Each minor iteration enumerates expressions of size `size` (starting at 0
    and growing by one per iteration) with an `Enumerator`, then:

    1. "Obvious substitutions": each enumerated expression whose fingerprint
       matches a watched subexpression of some target (same context and pool)
       is tried as a replacement for that subexpression.
    2. If `check_all_substitutions` is enabled, every subexpression of every
       target (in `exploration_order`) is tried against every enumerated
       expression of the same type that `should_consider_replacement` accepts.

    A substituted target is yielded only if all of the following hold:
      - it is not alpha-equivalent to any current target,
      - it passes `check_wf`,
      - its fingerprint matches that of `self.targets[0]` on `self.examples`,
      - the cost model orders it as Order.LT or Order.AMBIGUOUS against the
        target it came from.

    Rejected substitutions are recorded in `self.blacklist` (used as a dict,
    key -> reason). NOTE(review): `self.blacklist` must already exist; it is
    not initialised here.

    Raises StopException when `self.stop_callback()` is truthy. The loop
    otherwise never terminates on its own.
    """
    root_ctx = self.context

    def check_wf(e, ctx, pool):
        # Returns True, or the first falsy verdict explaining why `e` must be
        # rejected in (ctx, pool).
        with task("checking well-formedness", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=self.wf_solver)
            if not is_wf:
                return is_wf
            res = good_idea_recursive(self.wf_solver, e, ctx, pool, ops=self.ops)
            if not res:
                return res
            if pool == RUNTIME_POOL and self.cost_model.compare(e, self.targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    # Every distinct shredded subexpression of every target and hint is
    # offered to the enumerator as a hint, ordered by `hint_order`.
    frags = list(unique(itertools.chain(
        *[shred(t, root_ctx) for t in self.targets],
        *[shred(h, root_ctx) for h in self.hints])))
    frags.sort(key=hint_order)
    enum = Enumerator(
        examples=self.examples,
        cost_model=self.cost_model,
        check_wf=check_wf,
        hints=frags,
        heuristics=try_optimize,
        stop_callback=self.stop_callback,
        do_eviction=enable_eviction.value)

    size = 0
    target_fp = fingerprint(self.targets[0], self.examples)

    # Index each (target, subexpression) pair by the subexpression's
    # fingerprint on the examples instantiated for its context, plus that
    # context and pool. Insertion order is preserved.
    watches = OrderedDict()
    for target in self.targets:
        for e, ctx, pool in unique(shred(target, context=root_ctx, pool=RUNTIME_POOL)):
            exs = ctx.instantiate_examples(self.examples)
            fp = fingerprint(e, exs)
            watches.setdefault((fp, ctx, pool), []).append((target, e))
    watched_ctxs = list(unique((ctx, pool) for fp, ctx, pool in watches.keys()))

    def consider_new_target(old_target, e, ctx, pool, replacement):
        # Try replacing `e` (found at ctx/pool in `old_target`) with
        # `replacement`; yields the new target if it is acceptable.
        nonlocal n
        n += 1
        k = (e, ctx, pool, replacement)
        if enable_blacklist.value and k in self.blacklist:
            event("blacklisted")
            print("skipping blacklisted substitution: {} ---> {} ({})".format(
                pprint(e), pprint(replacement), self.blacklist[k]))
            return
        # Build from the `old_target` parameter, not the enclosing scope's
        # late-bound loop variable `target`.
        new_target = freshen_binders(
            replace(old_target, root_ctx, RUNTIME_POOL, e, ctx, pool, replacement),
            root_ctx)
        if any(alpha_equivalent(t, new_target) for t in self.targets):
            event("already seen")
            return
        wf = check_wf(new_target, root_ctx, RUNTIME_POOL)
        if not wf:
            msg = "not well-formed [wf={}]".format(wf)
            event(msg)
            self.blacklist[k] = msg
            return
        if not fingerprints_match(fingerprint(new_target, self.examples), target_fp):
            msg = "not correct"
            event(msg)
            self.blacklist[k] = msg
            return
        if self.cost_model.compare(new_target, old_target, root_ctx, RUNTIME_POOL) not in (Order.LT, Order.AMBIGUOUS):
            msg = "not an improvement"
            event(msg)
            self.blacklist[k] = msg
            return
        print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
        print(" * in {}".format(pprint(old_target)))
        print(" * replacing {}".format(pprint(e)))
        print(" * with {}".format(pprint(replacement)))
        yield new_target

    while True:
        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if self.stop_callback():
            raise StopException()

        n = 0  # substitutions passed to `consider_new_target` this iteration

        # Phase 1: replace watched subexpressions by enumerated expressions
        # that have a matching fingerprint in the same context and pool.
        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions", ctx=ctx, pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("searching for obvious substitution", expression=pprint(info.e)):
                        # Poll for cancellation per candidate, as phase 2 does.
                        if self.stop_callback():
                            raise StopException()
                        fp = info.fingerprint
                        for ((fpx, cc, pp), reses) in watches.items():
                            if cc != ctx or pp != pool:
                                continue
                            if not fingerprints_match(fpx, fp):
                                continue
                            for target, watched_e in reses:
                                replacement = info.e
                                event("possible substitution: {} ---> {}".format(pprint(watched_e), pprint(replacement)))
                                event("replacement locations: {}".format(pprint(replace(target, root_ctx, RUNTIME_POOL, watched_e, ctx, pool, EVar("___")))))
                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue
                                yield from consider_new_target(target, watched_e, ctx, pool, replacement)

        # Phase 2 (optional): try every enumerated expression of the right
        # type against every subexpression of every target.
        if check_all_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(self.targets, root_ctx):
                with task("checking substitutions",
                          target=pprint(replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, EVar("___"))),
                          e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                        with task("checking substitution", expression=pprint(info.e)):
                            if self.stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            should_consider = should_consider_replacement(
                                target, root_ctx,
                                e, ctx, pool, fingerprint(e, ctx.instantiate_examples(self.examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event("skipped; `should_consider_replacement` returned {}".format(should_consider))
                                continue
                            yield from consider_new_target(target, e, ctx, pool, replacement)

        print("CONSIDERED {}".format(n))
        size += 1