def consider_new_target(old_target, e, ctx, pool, replacement): nonlocal n n += 1 k = (e, ctx, pool, replacement) if enable_blacklist.value and k in self.blacklist: event("blacklisted") print("skipping blacklisted substitution: {} ---> {} ({})". format(pprint(e), pprint(replacement), self.blacklist[k])) return new_target = freshen_binders( replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, replacement), root_ctx) if any(alpha_equivalent(t, new_target) for t in self.targets): event("already seen") return wf = check_wf(new_target, root_ctx, RUNTIME_POOL) if not wf: msg = "not well-formed [wf={}]".format(wf) event(msg) self.blacklist[k] = msg return if not fingerprints_match(fingerprint(new_target, self.examples), target_fp): msg = "not correct" event(msg) self.blacklist[k] = msg return if self.cost_model.compare(new_target, target, root_ctx, RUNTIME_POOL) not in (Order.LT, Order.AMBIGUOUS): msg = "not an improvement" event(msg) self.blacklist[k] = msg return print("FOUND A GUESS AFTER {} CONSIDERED".format(n)) print(" * in {}".format(pprint(old_target), pprint(e), pprint(replacement))) print(" * replacing {}".format(pprint(e))) print(" * with {}".format(pprint(replacement))) yield new_target
def repair_well_formedness(e: Exp, context: Context, extra_available_state: [Exp] = ()) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e - the expression to repair
        context - the intended context for e
        extra_available_state - extra state expressions that e can use
            (any iterable; it is only read, never modified)

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression
    that behaves like `e`.

    The result is checked with `exp_wf` before it is returned; a failed
    check trips an assertion whose message is the check's result.
    """
    with task("repairing"):
        e = strip_EStateVar(e)

        # State expressions in decreasing order of size, so that larger
        # expressions are wrapped before the smaller ones they may contain.
        available_state = sorted(
            unique(itertools.chain(
                (v for v, p in context.vars() if p == STATE_POOL),
                extra_available_state)),
            key=lambda state_exp: -state_exp.size())

        with task("making replacements", size=e.size()):
            for s in available_state:
                e = replace(
                    e, s, EStateVar(s).with_type(s.type),
                    match=alpha_equivalent,
                    # Never rewrite a node that is already an EStateVar.
                    filter=lambda node: not isinstance(node, EStateVar))

        with task("freshening binders"):
            e = freshen_binders(e, context)

        with task("checking correctness"):
            res = exp_wf(e, context, RUNTIME_POOL)
            assert res, str(res)

        return e
def repair_well_formedness(e : Exp, context : Context, extra_available_state : [Exp] = ()) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e - the expression to repair
        context - the intended context for e
        extra_available_state - extra state expressions that e can use
            (any iterable; it is only read, never modified)

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression
    that behaves like `e`.

    The result is checked with `exp_wf` before it is returned; a failed
    check trips an assertion whose message is the check's result.
    """
    with task("repairing"):
        e = strip_EStateVar(e)

        # State expressions in decreasing order of size, so that larger
        # expressions are wrapped before the smaller ones they may contain.
        candidates = itertools.chain(
            (v for v, p in context.vars() if p == STATE_POOL),
            extra_available_state)
        available_state = sorted(
            unique(candidates),
            key=lambda state_exp: -state_exp.size())

        with task("making replacements", size=e.size()):
            for s in available_state:
                e = replace(
                    e, s, EStateVar(s).with_type(s.type),
                    match=alpha_equivalent,
                    # Never rewrite a node that is already an EStateVar.
                    filter=lambda node: not isinstance(node, EStateVar))

        with task("freshening binders"):
            e = freshen_binders(e, context)

        with task("checking correctness"):
            res = exp_wf(e, context, RUNTIME_POOL)
            assert res, str(res)

        return e
def enumerate_with_info(self, context: Context, size: int, pool: Pool) -> [EnumeratedExp]:
    """Enumerate expressions for the given context, size, and pool.

    This is a generator of `EnumeratedExp` records.  In order, it:

     1. Redirects to the canonical form of `context` (adapting each result
        back to `context`) when the two are not the same object.
     2. Yields everything enumerated for the parent context, if any.
     3. Yields the cached results for `(pool, size, context)` if present.
     4. Otherwise pulls candidates from `self.enumerate_core`, discards
        those that do not belong in the context, fail `self.check_wf`,
        duplicate a previous expression, or lose to a previous expression
        under `eviction_policy`; evicts cached expressions the newcomer
        beats; then records and yields the newcomer.

    Raises StopException when `self.stop_callback()` returns true.

    Side effects: writes to `self.cache`, `self.seen` and
    `self.in_progress`.
    """
    canonical_context = self.canonical_context(context)
    if canonical_context is not context:
        print("adapting request: {} ---> {}".format(
            context, canonical_context))
        for info in self.enumerate_with_info(canonical_context, size, pool):
            yield info._replace(e=context.adapt(info.e, canonical_context))
        return

    # Everything enumerated for the parent context is yielded here too.
    if context.parent() is not None:
        yield from self.enumerate_with_info(context.parent(), size, pool)

    k = (pool, size, context)
    res = self.cache.get(k)
    if res is not None:
        # print("[[{} cached @ size={}]]".format(len(res), size))
        for e in res:
            yield e
    else:
        # print("ENTER {}".format(k))
        examples = context.instantiate_examples(self.examples)
        assert k not in self.in_progress, "recursive enumeration?? {}".format(
            k)
        self.in_progress.add(k)
        # The (initially empty) result list is published in the cache
        # before enumeration starts; the nested calls below that ask for
        # this same key therefore see the partial results gathered so far.
        res = []
        self.cache[k] = res
        queue = self.enumerate_core(context, size, pool)
        cost_model = self.cost_model
        while True:
            if self.stop_callback():
                raise StopException()

            # `queue` may be re-bound at the bottom of the loop, which is
            # why this is a manual next() rather than a for loop.
            try:
                e = next(queue)
            except StopIteration:
                break

            fvs = free_vars(e)
            if not belongs_in_context(fvs, context):
                continue

            e = freshen_binders(e, context)
            _consider(e, context, pool)

            wf = self.check_wf(e, context, pool)
            if not wf:
                _skip(e, context, pool, "wf={}".format(wf))
                continue

            fp = fingerprint(e, examples)

            # collect all expressions from parent contexts
            with task("collecting prev exps", size=size, context=context, pool=pool_name(pool)):
                prev = []
                for sz in range(0, size + 1):
                    prev.extend(self.enumerate_with_info(
                        context, sz, pool))
                # Only expressions with the same fingerprint are rivals.
                prev = [p.e for p in prev if p.fingerprint == fp]

            if any(alpha_equivalent(e, p) for p in prev):
                _skip(e, context, pool, "duplicate")
                should_keep = False
            else:
                # decide whether to keep this expression,
                # decide which can be evicted
                should_keep = True
                # cost = self.cost_model.cost(e, pool)
                # print("prev={}".format(prev))
                # print("seen={}".format(self.seen))
                with task("comparing to cached equivalents"):
                    for prev_exp in prev:
                        event("previous: {}".format(pprint(prev_exp)))
                        # prev_cost = self.cost_model.cost(prev_exp, pool)
                        # ordering = cost.compare_to(prev_cost)
                        to_keep = eviction_policy(e, context, prev_exp, context, pool, cost_model)
                        # The first previous expression that the policy
                        # prefers over `e` settles the matter.
                        if e not in to_keep:
                            _skip(e, context, pool, "preferring {}".format(pprint(prev_exp)))
                            should_keep = False
                            break

                        # if ordering == Order.LT:
                        #     pass
                        # elif ordering == Order.GT:
                        #     self.blacklist.add(e_key)
                        #     _skip(e, context, pool, "worse than {}".format(pprint(prev_exp)))
                        #     should_keep = False
                        #     break
                        # else:
                        #     self.blacklist.add(e_key)
                        #     _skip(e, context, pool, "{} to cached {}".format(
                        #         "equal" if ordering == Order.EQUAL else "similar",
                        #         pprint(prev_exp)))
                        #     assert ordering in (Order.EQUAL, Order.AMBIGUOUS)
                        #     should_keep = False
                        #     break

            if should_keep:

                with task("evicting"):
                    # Evictions are collected first and applied afterwards
                    # so the cache is not modified while it is iterated.
                    to_evict = []
                    for (key, exps) in self.cache.items():
                        (p, s, c) = key
                        if p == pool and c in itertools.chain(
                                [context], parent_contexts(context)):
                            for ee in exps:
                                if ee.fingerprint == fp:  # and cost_model.compare(e, ee.e, context, pool) == Order.LT:
                                    # to_evict.append((key, ee))
                                    to_keep = eviction_policy(
                                        e, context, ee.e, c, pool, cost_model)
                                    if ee.e not in to_keep:
                                        to_evict.append((key, ee))
                    for key, ee in to_evict:
                        (p, s, c) = key
                        # self.blacklist.add((ee.e, c, pool))
                        _evict(ee.e, c, pool, e)
                        self.cache[key].remove(ee)
                        self.seen[(c, p, fp)].remove(ee.e)

                _accept(e, context, pool)
                seen_key = (context, pool, fp)
                if seen_key not in self.seen:
                    self.seen[seen_key] = []
                self.seen[seen_key].append(e)
                info = EnumeratedExp(e=e, fingerprint=fp, cost=None)
                res.append(info)
                yield info

                with task("accelerating"):
                    to_try = make_random_access(
                        self.heuristics(e, context, pool))
                    if to_try:
                        # print("trying {} accelerations".format(len(to_try)))
                        # Heuristic rewrites of an accepted expression are
                        # processed before the rest of the queue.
                        queue = itertools.chain(to_try, queue)

        # print("EXIT {}".format(k))
        # NOTE(review): this is skipped if the generator is abandoned or
        # StopException is raised above, leaving `k` in `in_progress`.
        self.in_progress.remove(k)
def enumerate_with_info(self, context: Context, size: int, pool: Pool) -> [EnumeratedExp]:
    """Enumerate expressions for the given context, size, and pool.

    This is a generator of `EnumeratedExp` records.  In order, it:

     1. Redirects to the canonical form of `context` (adapting each result
        back to `context`) when the two are not the same object.
     2. Yields everything enumerated for the parent context, if any, with
        fingerprints recomputed against this context's examples.
     3. Yields the cached results for `(pool, size, context)` if present.
     4. Otherwise pulls candidates from `self.enumerate_core`, discards
        those that do not belong in the context, fail `self.check_wf`,
        duplicate a previous expression, or lose to a previous expression
        under `eviction_policy`; optionally evicts cached expressions the
        newcomer beats; then records and yields the newcomer.

    Raises StopException when `self.stop_callback()` returns true.

    Side effects: writes to `self.cache`, `self.seen` and
    `self.in_progress`.
    """
    canonical_context = self.canonical_context(context)
    if canonical_context is not context:
        print("adapting request: {} ---> {}".format(
            context, canonical_context))
        for info in self.enumerate_with_info(canonical_context, size, pool):
            yield info._replace(e=context.adapt(info.e, canonical_context))
        return

    examples = context.instantiate_examples(self.examples)
    if context.parent() is not None:
        for info in self.enumerate_with_info(context.parent(), size, pool):
            e = info.e
            # The parent's fingerprint is not reused; it is recomputed
            # with the examples instantiated for this context.
            yield EnumeratedExp(e=e, fingerprint=fingerprint(e, examples))

    k = (pool, size, context)
    res = self.cache.get(k)
    if res is not None:
        for e in res:
            yield e
    else:
        assert k not in self.in_progress, "recursive enumeration?? {}".format(
            k)
        self.in_progress.add(k)
        # The (initially empty) result list is published in the cache
        # before enumeration starts; the nested calls below that ask for
        # this same key therefore see the partial results gathered so far.
        res = []
        self.cache[k] = res
        queue = self.enumerate_core(context, size, pool)
        cost_model = self.cost_model
        while True:
            if self.stop_callback():
                raise StopException()

            # `queue` may be re-bound at the bottom of the loop, which is
            # why this is a manual next() rather than a for loop.
            try:
                e = next(queue)
            except StopIteration:
                break

            fvs = free_vars(e)
            if not belongs_in_context(fvs, context):
                continue

            e = freshen_binders(e, context)
            _consider(e, size, context, pool)

            wf = self.check_wf(e, context, pool)
            if not wf:
                _skip(e, size, context, pool, "wf={}".format(wf))
                continue

            fp = fingerprint(e, examples)

            # collect all expressions from parent contexts
            with task("collecting prev exps", size=size, context=context, pool=pool_name(pool)):
                prev = []
                for sz in range(0, size + 1):
                    prev.extend(self.enumerate_with_info(
                        context, sz, pool))
                # Only expressions with the same fingerprint are rivals.
                prev = [p.e for p in prev if p.fingerprint == fp]

            if any(alpha_equivalent(e, p) for p in prev):
                _skip(e, size, context, pool, "duplicate")
                should_keep = False
            else:
                # decide whether to keep this expression
                should_keep = True
                with task("comparing to cached equivalents"):
                    for prev_exp in prev:
                        event("previous: {}".format(pprint(prev_exp)))
                        to_keep = eviction_policy(e, context, prev_exp, context, pool, cost_model)
                        # The first previous expression that the policy
                        # prefers over `e` settles the matter.
                        if e not in to_keep:
                            _skip(e, size, context,
                                  pool, "preferring {}".format(pprint(prev_exp)))
                            should_keep = False
                            break

            if should_keep:

                if self.do_eviction:
                    with task("evicting"):
                        # Evictions are collected first and applied
                        # afterwards so the cache is not modified while it
                        # is being iterated.
                        to_evict = []
                        for (key, exps) in self.cache.items():
                            (p, s, c) = key
                            if p == pool and c == context:
                                for ee in exps:
                                    if ee.fingerprint == fp:
                                        event("considering eviction of {}".
                                              format(pprint(ee.e)))
                                        to_keep = eviction_policy(
                                            e, context, ee.e, c, pool, cost_model)
                                        if ee.e not in to_keep:
                                            to_evict.append((key, ee))
                        for key, ee in to_evict:
                            (p, s, c) = key
                            _evict(ee.e, s, c, pool, e)
                            self.cache[key].remove(ee)
                            self.seen[(c, p, fp)].remove(ee.e)

                _accept(e, size, context, pool)
                seen_key = (context, pool, fp)
                if seen_key not in self.seen:
                    self.seen[seen_key] = []
                self.seen[seen_key].append(e)
                info = EnumeratedExp(e=e, fingerprint=fp)
                res.append(info)
                yield info

                with task("accelerating"):
                    to_try = make_random_access(
                        self.heuristics(e, context, pool))
                    if to_try:
                        event("trying {} accelerations of {}".format(
                            len(to_try), pprint(e)))
                        # Heuristic rewrites of an accepted expression are
                        # processed before the rest of the queue.
                        queue = itertools.chain(to_try, queue)

        # NOTE(review): this is skipped if the generator is abandoned or
        # StopException is raised above, leaving `k` in `in_progress`.
        self.in_progress.remove(k)
def next(self):
    """Search for candidate improvements to any of `self.targets`.

    This is a generator.  For increasing sizes (starting at 0) it walks
    the (target, subexpression, context, pool) tuples produced by
    `exploration_order`, asks a fresh `Enumerator` for expressions of the
    current size, and substitutes each one for the subexpression.  A
    substituted target is yielded when it:

      - has the same type as the subexpression it replaces and is not
        alpha-equivalent to it,
      - is not alpha-equivalent to a target already in `self.targets`,
      - matches the fingerprint of `self.targets[0]` on `self.examples`,
      - passes the local well-formedness / cost check, and
      - compares as LT or AMBIGUOUS against its target in the cost model.

    Substitutions rejected by the last three checks are remembered in
    `self.blacklist` (created on first use) and skipped from then on.

    Raises StopException when `self.stop_callback()` returns true.
    """

    class No(object):
        """A falsy result that carries the reason for the failure."""

        def __init__(self, msg):
            self.msg = msg

        def __bool__(self):
            return False

        def __str__(self):
            return "no: {}".format(self.msg)

    # with task("pre-computing cardinalities"):
    #     cards = [self.cost_model.cardinality(ctx.e) for ctx in enumerate_fragments(self.target) if is_collection(ctx.e.type)]

    root_ctx = self.context

    def check_wf(e, ctx, pool):
        """Return True if `e` is acceptable, or a falsy `No` saying why not."""
        with task("checking well-formedness", size=e.size()):
            try:
                exp_wf(e, pool=pool, context=ctx, assumptions=self.assumptions, solver=self.wf_solver)
            except ExpIsNotWf as exc:
                return No("at {}: {}".format(
                    pprint(exc.offending_subexpression), exc.reason))
            # A runtime expression that costs more than the first target
            # is rejected outright.
            if pool == RUNTIME_POOL and self.cost_model.compare(
                    e, self.targets[0], ctx, pool) == Order.GT:
                # from cozy.cost_model import debug_comparison
                # debug_comparison(self.cost_model, e, self.target, ctx)
                return No("too expensive")
            # if isinstance(e.type, TBag):
            #     c = self.cost_model.cardinality(e)
            #     if all(cc < c for cc in cards):
            #         # print("too big: {}".format(pprint(e)))
            #         return No("too big")
            return True

    # The enumerator is seeded with the de-duplicated output of `shred`
    # applied to every target and every hint.
    frags = list(
        unique(
            itertools.chain(*[shred(t, root_ctx) for t in self.targets], *[shred(h, root_ctx) for h in self.hints])))
    enum = Enumerator(examples=self.examples, cost_model=self.cost_model, check_wf=check_wf, hints=frags, heuristics=try_optimize, stop_callback=self.stop_callback)

    size = 0
    # target_cost = self.cost_model.cost(self.target, RUNTIME_POOL)
    target_fp = fingerprint(self.targets[0], self.examples)

    if not hasattr(self, "blacklist"):
        self.blacklist = set()

    while True:

        print("starting minor iteration {} with |cache|={}".format(
            size, enum.cache_size()))
        if self.stop_callback():
            raise StopException()

        # Number of substitutions actually evaluated in this iteration.
        n = 0
        for target, e, ctx, pool in exploration_order(
                self.targets, root_ctx):
            with task("checking substitutions", target=pprint(
                    replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, EVar("___"))), e=pprint(e)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("checking substitution",
                              expression=pprint(info.e)):
                        if self.stop_callback():
                            raise StopException()
                        if info.e.type != e.type:
                            event("wrong type (is {}, need {})".format(
                                pprint(info.e.type), pprint(e.type)))
                            continue
                        if alpha_equivalent(info.e, e):
                            event("no change")
                            continue

                        k = (e, ctx, pool, info.e)
                        if k in self.blacklist:
                            event("blacklisted")
                            continue

                        n += 1
                        ee = freshen_binders(
                            replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, info.e), root_ctx)
                        if any(
                                alpha_equivalent(t, ee) for t in self.targets):
                            event("already seen")
                            continue
                        if not self.matches(fingerprint(ee, self.examples), target_fp):
                            event("incorrect")
                            self.blacklist.add(k)
                            continue
                        wf = check_wf(ee, root_ctx, RUNTIME_POOL)
                        if not wf:
                            event("not well-formed [wf={}]".format(wf))
                            # if "expensive" in str(wf):
                            #     print(repr(self.cost_model.examples))
                            #     print(repr(ee))
                            self.blacklist.add(k)
                            continue
                        if self.cost_model.compare(
                                ee, target, root_ctx, RUNTIME_POOL) not in (Order.LT, Order.AMBIGUOUS):
                            event("not an improvement")
                            self.blacklist.add(k)
                            continue
                        print(
                            "FOUND A GUESS AFTER {} CONSIDERED".format(n))
                        yield ee

        print("CONSIDERED {}".format(n))
        size += 1

    # NOTE(review): the loop above contains no `break`, so this statement
    # looks unreachable -- confirm whether it is still wanted.
    raise NoMoreImprovements()
def improve(target: Exp,
            context: Context,
            assumptions: Exp = T,
            stop_callback=never_stop,
            hints: [Exp] = (),
            examples: [{str: object}] = (),
            cost_model: CostModel = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target`.

    Parameters:
        target - the expression to improve
        context - the context used for freshening, well-formedness checks
            and cost comparisons
        assumptions - a precondition; candidates are only verified against
            inputs that satisfy it
        stop_callback - passed through to the Learner
        hints - extra expressions handed to the Learner, in addition to
            the conjuncts of `assumptions` and the target itself
        examples - initial examples; the list is copied and then extended
            with every counterexample found
        cost_model - defaults to a CostModel built from the context's
            functions and the assumptions

    If the assumptions are unsatisfiable, a single constructed value of
    the target's type is yielded and the generator finishes.

    Raises Exception("got a duplicate example") if the solver returns a
    counterexample that is already in the example list.

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:

     - Expressions are either "state" expressions or "runtime" expressions,
       allowing this algorithm to choose what things to store on the data
       structure and what things to compute at query execution time. (The
       cost model is ultimately responsible for this choice.)

     - If a better version of *any subexpression* for the target is found,
       it is immediately substituted in and the overall expression is
       returned. This "smooths out" the search space a little, and lets us
       find kinda-good solutions very quickly, even if the best possible
       solution is out of reach.
    """

    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    # An ill-formed initial target is reported but deliberately tolerated.
    try:
        assert exp_wf(target, context=context, assumptions=assumptions)
    except ExpIsNotWf as ex:
        print(
            "WARNING: initial target is not well-formed [{}]; this might go poorly..."
            .format(str(ex)))
        print(pprint(ex.offending_subexpression))
        print(pprint(ex.offending_subexpression.type))
        # raise

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))

    all_vars = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    solver = None
    if incremental.value:
        # The incremental solver carries the assumptions permanently;
        # per-candidate constraints are added inside push()/pop() below.
        solver = IncrementalSolver(vars=all_vars, funcs=funcs)
        solver.add_assumption(assumptions)
        _sat = solver.satisfy
    else:
        def _sat(e):
            return satisfy(e, vars=all_vars, funcs=funcs)

    if _sat(assumptions) is None:
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    learner = Learner(watched_targets, assumptions, context, examples, cost_model, stop_callback, hints)
    try:
        while True:
            # 1. find any potential improvement to any sub-exp of target
            for new_target in learner.next():
                print("Found candidate improvement: {}".format(
                    pprint(new_target)))

                # 2. check
                with task("verifying candidate"):
                    differs = ENot(
                        EBinOp(target, "==", new_target).with_type(BOOL))
                    if incremental.value:
                        # The negated equality must not outlive this one
                        # query: every path below (break, continue, raise)
                        # leaves the solver as it was before the push.
                        solver.push()
                        try:
                            solver.add_assumption(differs)
                            counterexample = _sat(T)
                        finally:
                            solver.pop()
                    else:
                        counterexample = _sat(EAll([assumptions, differs]))

                if counterexample is not None:
                    if counterexample in examples:
                        print("assumptions = {!r}".format(assumptions))
                        print("duplicate example: {!r}".format(counterexample))
                        print("old target = {!r}".format(target))
                        print("new target = {!r}".format(new_target))
                        raise Exception("got a duplicate example")
                    # a. if incorrect: add example, reset the learner
                    examples.append(counterexample)
                    event("new example: {!r}".format(counterexample))
                    print("wrong; restarting with {} examples".format(
                        len(examples)))
                    learner.reset(examples)
                    break
                else:
                    # b. if correct: yield it, watch the new target, goto 1
                    print("The candidate is valid!")
                    print(repr(new_target))
                    print("Determining whether to yield it...")
                    with task("updating frontier"):
                        to_evict = []
                        keep = True
                        old_better = None
                        for old_target in watched_targets:
                            evc = eviction_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                            if old_target not in evc:
                                to_evict.append(old_target)
                            if new_target not in evc:
                                old_better = old_target
                                keep = False
                                break
                        for t in to_evict:
                            watched_targets.remove(t)

                    if not keep:
                        print(
                            "Whoops! Looks like we already found something better."
                        )
                        print(" --> {}".format(pprint(old_better)))
                        continue

                    # The candidate is only reported to the caller when it
                    # displaced the current best target.
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            raise NoMoreImprovements()
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                    watched_targets.append(new_target)
                    print("Now watching {} targets".format(
                        len(watched_targets)))
                    learner.watch(watched_targets)
                    break
    except NoMoreImprovements:
        return
def _consider_replacement(target: Exp, e: Exp, ctx: Context, pool: Pool, replacement: Exp, info: SearchInfo):
    """Helper for search_for_improvements.

    This procedure decides whether replacing `e` with `replacement` in the
    given `target` is an improvement.  If yes, it yields the result of the
    replacement.  Otherwise it yields nothing.

    Parameters:
     - target: the top-level expression to improve
     - e: a subexpression of the target
     - ctx: e's context in the target
     - pool: e's pool in the target
     - replacement: a possible replacement for e
     - info: a SearchInfo object with auxiliary data

    This procedure may add items to info.blacklist.  Each item maps the key
    `(e, ctx, pool, replacement)` to the reason the replacement was
    rejected.
    """
    context = info.context
    blacklist = info.blacklist

    k = (e, ctx, pool, replacement)
    if enable_blacklist.value and k in blacklist:
        event("blacklisted")
        print("skipping blacklisted substitution: {} ---> {} ({})".format(
            pprint(e), pprint(replacement), blacklist[k]))
        return

    new_target = freshen_binders(
        replace(target, context, RUNTIME_POOL, e, ctx, pool, replacement),
        context)

    # Nothing to gain from a target that is already being watched.
    if any(alpha_equivalent(t, new_target) for t in info.targets):
        event("already seen")
        return

    wf = info.check_wf(new_target, context, RUNTIME_POOL)
    if not wf:
        msg = "not well-formed [wf={}]".format(wf)
        event(msg)
        blacklist[k] = msg
        return

    if not Fingerprint.of(new_target, info.examples).equal_to(
            info.target_fingerprint):
        msg = "not correct"
        event(msg)
        blacklist[k] = msg
        return

    if not info.cost_model.compare(new_target, target, context, RUNTIME_POOL).could_be(Order.LT):
        msg = "not an improvement"
        event(msg)
        blacklist[k] = msg
        return

    print("FOUND A GUESS")
    # The format string has a single placeholder, so only the target is
    # passed; `e` and `replacement` are reported on the next two lines.
    print(" * in {}".format(pprint(target)))
    print(" * replacing {}".format(pprint(e)))
    print(" * with {}".format(pprint(replacement)))

    from cozy.structures.treemultiset import ETreeMultisetElems
    if isinstance(e, ETreeMultisetElems) and isinstance(e.e, EStateVar) and \
            isinstance(replacement, EStateVar) and isinstance(replacement.e, ETreeMultisetElems):
        # FIXME(zhen): current enumerator will always try to make ETreeMultisetElems a state var
        # FIXME(zhen): we don't want this because we need to put TreeSet into state var, rather than its iterator
        # FIXME(zhen): I still don't know how to fix this in a sensible way, but giving up an "improvement"
        # FIXME(zhen): should be okay temporarily
        print("give up {} -> {}".format(pprint(e), pprint(replacement)))
        return

    yield new_target
def improve(target: Exp,
            context: Context,
            assumptions: Exp = ETRUE,
            stop_callback: Callable[[], bool] = never_stop,
            hints: [Exp] = (),
            examples: [{str: object}] = (),
            cost_model: CostModel = None,
            ops: [Op] = (),
            improve_count: Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
        - assumptions: a precondition.  The yielded improvements will only
          be correct when the assumptions are true.
        - hints: expressions that might be useful.  These will be explored
          first when looking for improvements.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions.  This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide
          any.
        - ops: update operations.  This function may make different choices
          about what expressions are state expressions based on what changes
          can happen to that state.
        - improve_count: if not None, its `.value` is incremented (while
          holding `improve_count.get_lock()`) every time one pass of the
          inner search loop ends.

    If the assumptions are unsatisfiable, a single constructed value of
    the target's type is yielded and the generator finishes.  If the target
    already satisfies `heuristic_done`, nothing is yielded.

    Raises Exception("got a duplicate example") if the solver returns a
    counterexample that is already in the example list.

    Key differences from "regular" enumerative synthesis:

        - Expressions are either "state" expressions or "runtime"
          expressions, allowing this algorithm to choose what things to
          store on the data structure and what things to compute at query
          execution time. (The cost model is ultimately responsible for
          this choice.)

        - If a better version of *any subexpression* for the target is
          found, it is immediately substituted in and the overall expression
          is returned. This "smooths out" the search space a little,
          allowing us to find kinda-good solutions very quickly, even if
          the best possible solution is out of reach.  This is more
          desirable than running for an indeterminate amount of time doing
          nothing.
    """

    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r}, ops={ops!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model,
        ops=ops))

    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))

    funcs = context.funcs()

    # The solver is built with the assumptions, so checking ETRUE below is
    # a satisfiability check of the assumptions themselves.
    solver = solver_for_context(context, assumptions=assumptions)
    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(
        is_good, target)

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(
                    len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)

                if not keep:
                    print(
                        "Whoops! Looks like we already found something better."
                    )
                    print(" --> {}".format(pprint(old_better)))
                    continue

                # The candidate is only reported to the caller when it
                # displaced the current best target.
                if target in to_evict:
                    print("Yep, it's an improvement!")
                    yield new_target
                    if heuristic_done(new_target):
                        print("target now matches doneness heuristic")
                        return
                    target = new_target
                else:
                    print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1
def _enumerate_with_info(self, context: Context, size: int, pool: Pool) -> [EnumeratedExp]:
    """Helper for enumerate_with_info that bypasses the cache.

    Note that this method DOES affect the cache: it writes its output into
    the cache and may do evictions.  The enumerate_with_info method ensures
    that there is only ever one call to this method for a given (context,
    size, pool).

    This is a generator of `EnumeratedExp` records.  Each candidate from
    `self._enumerate_core` is freshened, checked with `self.check_wf`,
    fingerprinted, and compared against the cached expressions with the
    same fingerprint.  It is yielded (and then cached) unless it is a
    duplicate of, or loses under `retention_policy` to, one of them.

    Raises StopException when `self.stop_callback()` returns true.
    """

    examples = context.instantiate_examples(self.examples)
    cache = self.cache
    queue = self._enumerate_core(context, size, pool)
    cost_model = self.cost_model

    while True:
        if self.stop_callback():
            raise StopException()

        try:
            e = next(queue)
        except StopIteration:
            # StopIteration is a "control flow exception" indicating that
            # there isn't a next element.  Since the queue is exhausted,
            # breaking out of the loop is the right thing to do.
            break

        self.stat_timer.check()

        e = freshen_binders(e, context)
        _consider(e, size, context, pool)

        wf = self.check_wf(e, context, pool)
        if not wf:
            _skip(e, size, context, pool, "wf={}".format(wf))
            continue

        fp = Fingerprint.of(e, examples)

        # Collect all expressions from parent contexts that are
        # fingerprint-equivalent to this one.  There might be more than one
        # because of how `retention_policy` works.
        known_equivalents = list(
            cache.find_equivalent_expressions(context, pool, fp))
        to_evict = []

        if any(e.type == prev_entry.e.type and alpha_equivalent(e, prev_entry.e)
               for prev_entry in known_equivalents):
            _skip(e, size, context, pool, "duplicate")
            should_keep = False
        else:
            # decide whether to keep this expression
            should_keep = True
            if known_equivalents:
                with task("comparing to cached equivalents", count=len(known_equivalents)):
                    for entry in known_equivalents:
                        prev_exp = entry.e
                        event("previous: {}".format(pprint(prev_exp)))
                        to_keep = retention_policy(e, context, prev_exp, context, pool, cost_model)
                        if e not in to_keep:
                            _skip(e, size, context, pool,
                                  "preferring {}".format(pprint(prev_exp)))
                            should_keep = False
                            break
                        if prev_exp not in to_keep:
                            to_evict.append(entry)

        # Sanity check: an expression that is itself being discarded must
        # not have been chosen over any cached entry.
        assert not (to_evict and not should_keep)

        if should_keep:

            if self.do_eviction and to_evict:
                # Report the number of evictions, consistent with the
                # `count=len(...)` used by the comparison task above.
                with task("evicting", count=len(to_evict)):
                    for entry in to_evict:
                        _evict(entry.e, entry.size, context, pool, e, size)
                        cache.remove(context, pool, entry)

            _accept(e, size, context, pool, fp)
            info = EnumeratedExp(e=e, fingerprint=fp, size=size)
            # The entry is cached only once the consumer resumes this
            # generator after receiving it.
            yield info
            cache.add(context, pool, info)

            # Heuristic rewrites are only attempted for size-0 requests;
            # they are processed before the rest of the queue.
            if size == 0:
                with task("accelerating"):
                    to_try = make_random_access(
                        self.heuristics(e, context, pool))
                    if to_try:
                        event("trying {} accelerations of {}".format(
                            len(to_try), pprint(e)))
                        queue = itertools.chain(to_try, queue)
def improve(
        target : Exp,
        context : Context,
        assumptions : Exp = ETRUE,
        stop_callback : Callable[[], bool] = never_stop,
        hints : [Exp] = (),
        examples : [{str:object}] = (),
        cost_model : CostModel = None,
        ops : [Op] = (),
        improve_count : Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
        - assumptions: a precondition.  The yielded improvements will only be
          correct when the assumptions are true.
        - hints: expressions that might be useful.  These will be explored
          first when looking for improvements.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions.  This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide any.
        - cost_model: if None, a CostModel is built from the context's
          functions and the assumptions.
        - ops: update operations.  This function may make different choices
          about what expressions are state expressions based on what changes
          can happen to that state.
        - improve_count: if not None, an object exposing `get_lock()` and a
          numeric `.value`; the value is incremented (under the lock) once
          per pass of the outer search loop.

    Special cases visible in the body:
        - If `heuristic_done(target)` holds up front, nothing is yielded.
        - If the assumptions are unsatisfiable, a single
          `construct_value(target.type)` is yielded and the generator ends.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, allowing us
          to find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.  This is more desirable than running for
          an indeterminate amount of time doing nothing.
    """
    print("call to improve:")
    # Adjacent literals concatenate into the exact message text before
    # `.format` is applied.
    print(
        "improve( target={target!r}, context={context!r}, "
        "assumptions={assumptions!r}, stop_callback={stop_callback!r}, "
        "hints={hints!r}, examples={examples!r}, "
        "cost_model={cost_model!r}, ops={ops!r})".format(
            target=target,
            context=context,
            assumptions=assumptions,
            stop_callback=stop_callback,
            hints=hints,
            examples=examples,
            cost_model=cost_model,
            ops=ops))

    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # Hints are: the caller's hints, each conjunct of the assumptions (with
    # naked state variables wrapped), and the target itself.
    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)
    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(is_good, target)

    # Private copy: counterexamples are appended below.
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    # Evictions queued before a rejection are still applied.
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print("Whoops! Looks like we already found something better.")
                        print(" --> {}".format(pprint(old_better)))
                        # Move on to the next candidate from the same search.
                        continue
                    # Only a candidate that displaces the current target is
                    # reported to the caller.
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                # Restart the search against the updated set of watched targets.
                break

        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1
def _consider_replacement(
        target : Exp,
        e : Exp,
        ctx : Context,
        pool : Pool,
        replacement : Exp,
        info : SearchInfo):
    """Helper for search_for_improvements.

    This procedure decides whether replacing `e` with `replacement` in the
    given `target` is an improvement.  If yes, it yields the result of the
    replacement.  Otherwise it yields nothing.

    Parameters:
        - target: the top-level expression to improve
        - e: a subexpression of the target
        - ctx: e's context in the target
        - pool: e's pool in the target
        - replacement: a possible replacement for e
        - info: a SearchInfo object with auxiliary data

    A replacement is rejected (nothing is yielded) when, in order:
        - it is already blacklisted (only if `enable_blacklist.value`),
        - the resulting target is alpha-equivalent to one in `info.targets`,
        - the resulting target fails `info.check_wf`,
        - its fingerprint on `info.examples` differs from
          `info.target_fingerprint`,
        - the cost model says it cannot be cheaper than `target`, or
        - it is the ETreeMultisetElems/EStateVar swap described in the FIXME
          below.

    This procedure may add items to info.blacklist.
    """
    context = info.context
    blacklist = info.blacklist
    k = (e, ctx, pool, replacement)
    if enable_blacklist.value and k in blacklist:
        event("blacklisted")
        print("skipping blacklisted substitution: {} ---> {} ({})".format(pprint(e), pprint(replacement), blacklist[k]))
        return

    new_target = freshen_binders(replace(
        target, context, RUNTIME_POOL,
        e, ctx, pool,
        replacement), context)

    # Not blacklisted: "already seen" depends on info.targets, not on k alone.
    if any(alpha_equivalent(t, new_target) for t in info.targets):
        event("already seen")
        return

    wf = info.check_wf(new_target, context, RUNTIME_POOL)
    if not wf:
        msg = "not well-formed [wf={}]".format(wf)
        event(msg)
        blacklist[k] = msg
        return

    if not Fingerprint.of(new_target, info.examples).equal_to(info.target_fingerprint):
        msg = "not correct"
        event(msg)
        blacklist[k] = msg
        return

    if not info.cost_model.compare(new_target, target, context, RUNTIME_POOL).could_be(Order.LT):
        msg = "not an improvement"
        event(msg)
        blacklist[k] = msg
        return

    print("FOUND A GUESS")
    # The format string has a single placeholder, so only the target is
    # rendered here; `e` and `replacement` are printed on the next lines.
    print(" * in {}".format(pprint(target)))
    print(" * replacing {}".format(pprint(e)))
    print(" * with {}".format(pprint(replacement)))

    # Function-scope import, as in the original.
    from cozy.structures.treemultiset import ETreeMultisetElems
    if isinstance(e, ETreeMultisetElems) and isinstance(e.e, EStateVar) and \
            isinstance(replacement, EStateVar) and isinstance(replacement.e, ETreeMultisetElems):
        # FIXME(zhen): current enumerator will always try to make ETreeMultisetElems a state var
        # FIXME(zhen): we don't want this because we need to put TreeSet into state var, rather than its iterator
        # FIXME(zhen): I still don't know how to fix this in a sensible way, but giving up an "improvement"
        # FIXME(zhen): should be okay temporarily
        print("give up {} -> {}".format(pprint(e), pprint(replacement)))
        return

    yield new_target