def test_enumerator_fingerprints(self):
    """
    Fingerprints produced while enumerating under a binder must all have
    the same length, i.e. they are all computed for the context we asked
    for.
    """
    arg = EVar("x").with_type(INT)
    root = RootCtx(args=(arg, ), state_vars=())
    enumerator = Enumerator(
        examples=[{"x": 0}, {"x": 1}],
        cost_model=CostModel())

    # Context that binds `y` over the bag expression
    # singleton(ZERO) "+" singleton(ONE).
    binder = EVar("y").with_type(INT)
    zero_bag = ESingleton(ZERO).with_type(INT_BAG)
    one_bag = ESingleton(ONE).with_type(INT_BAG)
    bag = EBinOp(zero_bag, "+", one_bag).with_type(INT_BAG)
    binder_ctx = UnderBinder(root, binder, bag, RUNTIME_POOL)

    # Collect the distinct fingerprint lengths seen at size 0.
    seen_lengths = set()
    for entry in enumerator.enumerate_with_info(binder_ctx, 0, RUNTIME_POOL):
        seen_lengths.add(len(entry.fingerprint))
        print(entry)

    assert len(seen_lengths) == 1, seen_lengths
def test_heap_enumeration(self):
    """
    With acceleration disabled and min/max expressions rejected by the
    well-formedness check, the enumerator must produce `xs` at size 0,
    a heap construction at size 1 and a heap peek at size 2.
    """
    xs = EVar("xs").with_type(INT_BAG)
    context = RootCtx(state_vars=[xs])
    cost_model = CostModel()

    def not_min_or_max(e, *args, **kwargs):
        # forbid min/max to ensure that heap operations get cached
        return not isinstance(e, (EArgMin, EArgMax))

    enumerator = Enumerator(
        examples=[
            {"xs": Bag(())},
            {"xs": Bag((1, 2))},
            {"xs": Bag((1, 1))}],
        cost_model=cost_model,
        check_wf=not_min_or_max)

    def expect_some(banner, size, pool, is_wanted):
        # Enumerate one size/pool, print everything that comes out, and
        # require that at least one (retypecheckable) match shows up.
        print("-" * 20 + banner)
        seen = False
        for candidate in enumerator.enumerate(context, size, pool):
            print(pprint(candidate))
            if is_wanted(candidate):
                assert retypecheck(deep_copy(candidate))
                seen = True
                print("^^^ FOUND")
        assert seen

    with save_property(accelerate, "value"):
        accelerate.value = False

        expect_some(
            " Looking for xs...", 0, STATE_POOL,
            lambda e: e == xs)
        expect_some(
            " Looking for heap construction...", 1, STATE_POOL,
            lambda e: isinstance(e, (EMakeMinHeap, EMakeMaxHeap)))
        expect_some(
            " Looking for heap usage...", 2, RUNTIME_POOL,
            lambda e: isinstance(e, (EHeapPeek, EHeapPeek2)))
def eviction_policy(new_exp: Exp, new_ctx: Context, old_exp: Exp,
                    old_ctx: Context, pool: Pool,
                    cost_model: CostModel) -> [Exp]:
    """Choose which of two competing expressions stays in the cache.

    The two expressions are compared with `cost_model.compare` in the
    context returned by `more_specific(new_ctx, old_ctx)`.

    Returns a list holding the new expression, the old expression, or both:
      - new is strictly better (Order.LT)   -> [new_exp]
      - old is better or they tie (GT/EQUAL) -> [old_exp]
      - the ordering is ambiguous           -> [new_exp, old_exp]

    Raises ValueError if the cost model returns any other ordering.
    """
    shared_ctx = more_specific(new_ctx, old_ctx)
    verdict = cost_model.compare(new_exp, old_exp, shared_ctx, pool)

    if verdict == Order.LT:
        kept = [new_exp]
    elif verdict == Order.GT or verdict == Order.EQUAL:
        # On a tie the incumbent wins.
        kept = [old_exp]
    elif verdict == Order.AMBIGUOUS:
        kept = [new_exp, old_exp]
    else:
        raise ValueError(verdict)
    return kept
def run(self):
    """Run this improvement job, logging to a per-query file.

    Standard output is redirected to "<log_dir>/<query name>.log" for the
    duration of the job. The original query body and then every
    improvement yielded by `core.improve` are put on `self.solutions_q`
    as `(query, expression)` pairs. A `core.StopException` ends the job
    quietly.
    """
    query = self.q
    os.makedirs(log_dir.value, exist_ok=True)
    log_path = os.path.join(log_dir.value, "{}.log".format(query.name))

    with open(log_path, "w", buffering=LINE_BUFFER_MODE) as log_file:
        saved_stdout = sys.stdout
        sys.stdout = log_file
        try:
            print("STARTING IMPROVEMENT JOB {}".format(query.name))
            print(pprint(query))

            if nice_children.value:
                os.nice(20)

            should_stop = lambda: self.stop_requested

            cost_model = CostModel(
                funcs=self.context.funcs(),
                assumptions=EAll(self.assumptions),
                freebies=self.freebies,
                ops=self.ops,
                solver_args={"stop_callback": should_stop})

            # The unimproved query body is always reported first.
            baseline = (query.ret, )
            improvements = core.improve(
                target=query.ret,
                assumptions=EAll(self.assumptions),
                context=self.context,
                hints=self.hints,
                stop_callback=should_stop,
                cost_model=cost_model,
                ops=self.ops,
                improve_count=self.improve_count)

            for candidate in itertools.chain(baseline, improvements):
                self.solutions_q.put((query, candidate))

            print("PROVED OPTIMALITY FOR {}".format(query.name))
        except core.StopException:
            print("stopping synthesis of {}".format(query.name))
            return
        finally:
            # Put the original stdout handle back. Python multiprocessing
            # flushes streams as the process exits; if stdout were left
            # pointing at the log file it would refer to a closed file by
            # then.
            sys.stdout = saved_stdout
def test_hint_instantation(self):
    """
    A hint registered in one context must be enumerated (up to alpha
    equivalence) at size 0 in each of the contexts tried below.
    """
    x = EVar("x").with_type(INT)
    y = EVar("y").with_type(INT)
    z = EVar("z").with_type(INT)
    hint = ECall("f", (x, )).with_type(INT)
    context = UnderBinder(
        RootCtx(args=[x]),
        v=y,
        bag=ESingleton(x).with_type(TBag(x.type)),
        bag_pool=RUNTIME_POOL)
    cost_model = CostModel()
    successor = lambda a: a + 1
    enumerator = Enumerator(
        examples=[{"x": 1, "f": successor}, {"x": 100, "f": successor}],
        hints=[(hint, context, RUNTIME_POOL)],
        cost_model=cost_model)

    # The contexts in which the hint is expected to show up.
    same_ctx = context
    parent_ctx = context.parent()
    nested_ctx = UnderBinder(
        context,
        v=z,
        bag=ESingleton(y).with_type(TBag(y.type)),
        bag_pool=RUNTIME_POOL)
    sibling_z_ctx = UnderBinder(
        context.parent(),
        v=z,
        bag=ESingleton(x).with_type(TBag(y.type)),
        bag_pool=RUNTIME_POOL)
    sibling_y_ctx = UnderBinder(
        context.parent(),
        v=y,
        bag=ESingleton(ONE).with_type(INT_BAG),
        bag_pool=RUNTIME_POOL)
    contexts = (same_ctx, parent_ctx, nested_ctx, sibling_z_ctx, sibling_y_ctx)

    outcomes = []
    for ctx in contexts:
        print("-" * 30)
        seen = False
        for candidate in enumerator.enumerate(ctx, 0, RUNTIME_POOL):
            print(" -> {}".format(pprint(candidate)))
            # Once the hint has been seen, stop comparing.
            if not seen:
                seen = alpha_equivalent(candidate, hint)
        print("found? {}".format(seen))
        outcomes.append(seen)

    assert all(outcomes)
def test_state_pool_boundary(self):
    """
    State expressions must never be enumerated in a context where some
    binders are runtime variables: across both enumerations below, the
    state pool should be entered exactly once.
    """

    class CountingEnumerator(Enumerator):
        """Enumerator that counts `_enumerate_core` calls on the state pool."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state_enumerations = 0

        def _enumerate_core(self, context, size, pool):
            print("_enumerate_core({}, {}, {})".format(context, size, pool))
            if pool == STATE_POOL:
                self.state_enumerations += 1
            return super()._enumerate_core(context, size, pool)

    state_bag = EVar("state").with_type(INT_BAG)
    context = RootCtx(
        state_vars=[state_bag],
        args=[EVar("arg").with_type(INT)])
    enumerator = CountingEnumerator(
        examples=[
            {"state": Bag([10]), "arg": 10},
            {"state": Bag([20]), "arg": 30}],
        cost_model=CostModel())

    # Drain a size-1 runtime enumeration in the root context ...
    for _ in enumerator.enumerate(context, 1, RUNTIME_POOL):
        pass

    # ... and another one under a runtime binder over the state bag.
    binder_ctx = UnderBinder(
        context,
        EVar("x").with_type(INT),
        EStateVar(state_bag).with_type(state_bag.type),
        RUNTIME_POOL)
    for _ in enumerator.enumerate(binder_ctx, 1, RUNTIME_POOL):
        pass

    assert enumerator.state_enumerations == 1
def run(self):
    """Run this improvement job, logging to a per-query file.

    Announces the job on the current stdout, then redirects stdout to
    "<log_dir>/<query name>.log" while the job runs. The original query
    body and then every improvement yielded by `core.improve` are unpacked
    with `unpack_representation` and handed to the callback `self.k`.
    A `core.StopException` ends the job quietly.

    The original stdout handle is always restored before returning, so
    `sys.stdout` never ends up referring to the closed log file.
    """
    print("STARTING IMPROVEMENT JOB {}".format(self.q.name))
    os.makedirs(log_dir.value, exist_ok=True)
    with open(os.path.join(log_dir.value, "{}.log".format(self.q.name)), "w", buffering=LINE_BUFFER_MODE) as f:
        original_stdout = sys.stdout
        sys.stdout = f
        try:
            print("STARTING IMPROVEMENT JOB {}".format(self.q.name))
            print(pprint(self.q))

            if nice_children.value:
                os.nice(20)

            ctx = RootCtx(
                state_vars=self.state,
                args=[EVar(v).with_type(t) for (v, t) in self.q.args],
                funcs=self.funcs)

            cost_model = CostModel(
                funcs=ctx.funcs(),
                assumptions=EAll(self.assumptions),
                freebies=self.freebies,
                ops=self.ops)

            try:
                # The unimproved query body is always reported first.
                for expr in itertools.chain(
                        (self.q.ret, ),
                        core.improve(
                            target=self.q.ret,
                            assumptions=EAll(self.assumptions),
                            context=ctx,
                            hints=self.hints,
                            stop_callback=lambda: self.stop_requested,
                            cost_model=cost_model,
                            ops=self.ops)):
                    new_rep, new_ret = unpack_representation(expr)
                    self.k(new_rep, new_ret)
                print("PROVED OPTIMALITY FOR {}".format(self.q.name))
            except core.StopException:
                print("stopping synthesis of {}".format(self.q.name))
                return
        finally:
            # Restore the original stdout handle on every exit path
            # (normal completion, StopException, or any other error).
            # Otherwise sys.stdout would keep pointing at the log file,
            # which is closed as soon as the `with` block exits, and any
            # later write or flush of stdout would hit a closed file.
            sys.stdout = original_stdout
def improve(target: Exp,
            context: Context,
            assumptions: Exp = T,
            stop_callback=never_stop,
            hints: [Exp] = (),
            examples: [{ str: object }] = (),
            cost_model: CostModel = None):
    """
    Improve the target expression using enumerative synthesis. This function
    is a generator that yields increasingly better and better versions of the
    input expression `target`.

    Parameters:
     - target: the expression to improve.
     - context: the context `target` lives in; supplies the variables (with
       their pools) and functions used for solving.
     - assumptions: a precondition; candidates are only checked for
       equivalence with the target under these assumptions.
     - stop_callback: passed through to the learner.
     - hints: extra expressions handed to the learner. The conjuncts of
       `assumptions` and the target itself are appended as hints too.
     - examples: initial examples; copied into a fresh list that grows with
       every counterexample found.
     - cost_model: defaults to `CostModel(funcs=..., assumptions=...)` built
       from the context's functions when None.

    If the assumptions are unsatisfiable, a single constructed value of the
    target's type is yielded and the generator finishes.

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:

     - Expressions are either "state" expressions or "runtime" expressions,
       allowing this algorithm to choose what things to store on the data
       structure and what things to compute at query execution time. (The
       cost model is ultimately responsible for this choice.)

     - If a better version of *any subexpression* for the target is found,
       it is immediately substituted in and the overall expression is
       returned. This "smooths out" the search space a little, and lets us
       find kinda-good solutions very quickly, even if the best possible
       solution is out of reach.
    """

    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    # An ill-formed initial target is only warned about, not rejected.
    try:
        assert exp_wf(target, context=context, assumptions=assumptions)
    except ExpIsNotWf as ex:
        print("WARNING: initial target is not well-formed [{}]; this might go poorly...".format(str(ex)))
        print(pprint(ex.offending_subexpression))
        print(pprint(ex.offending_subexpression.type))
        # raise

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    vars = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    solver = None
    if incremental.value:
        solver = IncrementalSolver(vars=vars, funcs=funcs)
        solver.add_assumption(assumptions)
        _sat = solver.satisfy
    else:
        _sat = lambda e: satisfy(e, vars=vars, funcs=funcs)

    if _sat(assumptions) is None:
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    learner = Learner(watched_targets, assumptions, context, examples, cost_model, stop_callback, hints)
    try:
        while True:
            # 1. find any potential improvement to any sub-exp of target
            for new_target in learner.next():
                print("Found candidate improvement: {}".format(pprint(new_target)))

                # 2. check
                with task("verifying candidate"):
                    if incremental.value:
                        # The negated equality is only relevant to this one
                        # query, so scope it with a push/pop pair that is
                        # balanced on every path. Every branch below leaves
                        # this iteration through `break` or `continue`; a pop
                        # placed at the end of the loop body would be skipped
                        # and stale assumptions would pile up on the solver.
                        solver.push()
                        try:
                            solver.add_assumption(ENot(EBinOp(target, "==", new_target).with_type(BOOL)))
                            counterexample = _sat(T)
                        finally:
                            solver.pop()
                    else:
                        formula = EAll([assumptions, ENot(EBinOp(target, "==", new_target).with_type(BOOL))])
                        counterexample = _sat(formula)

                if counterexample is not None:
                    if counterexample in examples:
                        print("assumptions = {!r}".format(assumptions))
                        print("duplicate example: {!r}".format(counterexample))
                        print("old target = {!r}".format(target))
                        print("new target = {!r}".format(new_target))
                        raise Exception("got a duplicate example")
                    # a. if incorrect: add example, reset the learner
                    examples.append(counterexample)
                    event("new example: {!r}".format(counterexample))
                    print("wrong; restarting with {} examples".format(len(examples)))
                    learner.reset(examples)
                    break
                else:
                    # b. if correct: yield it, watch the new target, goto 1
                    print("The candidate is valid!")
                    print(repr(new_target))
                    print("Determining whether to yield it...")
                    with task("updating frontier"):
                        to_evict = []
                        keep = True
                        old_better = None
                        for old_target in watched_targets:
                            evc = eviction_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                            if old_target not in evc:
                                to_evict.append(old_target)
                            if new_target not in evc:
                                old_better = old_target
                                keep = False
                                break
                        for t in to_evict:
                            watched_targets.remove(t)
                        if not keep:
                            print("Whoops! Looks like we already found something better.")
                            print(" --> {}".format(pprint(old_better)))
                            continue
                        if target in to_evict:
                            print("Yep, it's an improvement!")
                            yield new_target
                            if heuristic_done(new_target):
                                print("target now matches doneness heuristic")
                                raise NoMoreImprovements()
                            target = new_target
                        else:
                            print("Nope, it isn't substantially better!")

                    watched_targets.append(new_target)
                    print("Now watching {} targets".format(len(watched_targets)))
                    learner.watch(watched_targets)
                    break
    except NoMoreImprovements:
        return
    except KeyboardInterrupt:
        raise
def improve(
        target : Exp,
        assumptions : Exp,
        binders : [EVar],
        state_vars : [EVar],
        args : [EVar],
        cost_model : CostModel,
        builder : ExpBuilder,
        stop_callback = never_stop,
        hints : [Exp] = None,
        examples = None):
    """
    Improve the target expression using enumerative synthesis. This function
    is a generator that yields increasingly better and better versions of the
    input expression `target`.

    Parameters:
     - target: the expression to improve.
     - assumptions: a precondition; candidates are only checked for
       equivalence with the target under these assumptions.
     - binders: extra variables the learner may use. They are substituted
       away (with a warning on stderr) if they appear free in a result.
     - state_vars, args: the state variables and arguments of the query.
     - cost_model: used to cost the target and each candidate.
     - builder: expression builder; wrapped in a FixedBuilder (and in a
       StateElimBuilder when the job does not depend on state variables).
     - stop_callback: passed through to the learner.
     - hints: optional expressions handed to the learner.
     - examples: optional list of examples. A caller-supplied list is
       appended to in place as counterexamples are found.

    If the assumptions are unsatisfiable, a single constructed value of the
    target's type is yielded and the generator finishes.

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:

     - Expressions may be built using a set of "binders"---extra free
       variables thrown into the mix at the beginning.

     - Expressions are either "state" expressions or "runtime" expressions,
       allowing this algorithm to choose what things to store on the data
       structure and what things to compute at query execution time. (The
       cost model is ultimately responsible for this choice.)

    Other features of this algorithm:

     - If a better version of *any subexpression* for the target is found,
       it is immediately substituted in and the overall expression is
       returned. This "smooths out" the search space a little, and lets us
       find kinda-good solutions very quickly, even if the best possible
       solution is out of reach.
    """

    print("call to improve:")
    print("""improve( target={target!r}, assumptions={assumptions!r}, binders={binders!r}, state_vars={state_vars!r}, args={args!r}, cost_model={cost_model!r}, builder={builder!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r})""".format(
        target=target,
        assumptions=assumptions,
        binders=binders,
        state_vars=state_vars,
        args=args,
        cost_model=cost_model,
        builder=builder,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples))

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    assert exp_wf(
        target,
        state_vars=set(state_vars),
        args=set(args),
        assumptions=assumptions)

    binders = list(binders)
    target = fixup_binders(target, binders, allow_add=False)
    hints = [fixup_binders(h, binders, allow_add=False) for h in (hints or ())]
    assumptions = fixup_binders(assumptions, binders, allow_add=False)
    builder = FixedBuilder(builder, state_vars, args, binders, assumptions)
    target_cost = cost_model.cost(target, RUNTIME_POOL)

    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        builder = StateElimBuilder(builder)

    vars = list(free_vars(target) | free_vars(assumptions))
    funcs = free_funcs(EAll([target, assumptions]))

    solver = None
    if incremental.value:
        solver = IncrementalSolver(vars=vars, funcs=funcs, collection_depth=check_depth.value)
        solver.add_assumption(assumptions)
        _sat = solver.satisfy
    else:
        _sat = lambda e: satisfy(e, vars=vars, funcs=funcs, collection_depth=check_depth.value)

    if _sat(T) is None:
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    if examples is None:
        examples = []
    learner = Learner(target, assumptions, binders, state_vars, args, vars + binders, examples, cost_model, builder, stop_callback, hints, solver=solver)
    try:
        while True:
            # 1. find any potential improvement to any sub-exp of target
            try:
                old_e, new_e, local_assumptions, repl = learner.next()
            except NoMoreImprovements:
                break

            # 2. substitute-in the improvement
            print("Found candidate replacement [{}] for [{}] in".format(pprint(new_e), pprint(old_e)))
            print(pprint(repl(EVar("@___"))))
            new_target = repl(new_e)

            # 3. check
            if incremental.value:
                solver.push()
            try:
                if incremental.value:
                    solver.add_assumption(ENot(EBinOp(target, "==", new_target).with_type(BOOL)))
                    counterexample = _sat(T)
                else:
                    formula = EAll([assumptions, ENot(EBinOp(target, "==", new_target).with_type(BOOL))])
                    counterexample = _sat(formula)
                if counterexample is not None:

                    # Ok they aren't equal. Now we need an example that
                    # differentiates BOTH target/new_target AND old_e/new_e.
                    if incremental.value:
                        # `assumptions` and `target != new_target` are
                        # already on the solver stack at this point.
                        counterexample = _sat(EAll([
                            EAll(local_assumptions),
                            ENot(EBinOp(old_e, "===", new_e).with_type(BOOL))]))
                    else:
                        counterexample = _sat(EAll([
                            assumptions,
                            EAll(local_assumptions),
                            ENot(EBinOp(target, "==", new_target).with_type(BOOL)),
                            ENot(EBinOp(old_e, "===", new_e).with_type(BOOL))]))
                    if counterexample is None:
                        print("!!! unable to satisfy top- and sub-expressions")
                        print("assumptions = {!r}".format(assumptions))
                        print("local_assumptions = {!r}".format(EAll(local_assumptions)))
                        print("old_e = {!r}".format(old_e))
                        print("target = {!r}".format(target))
                        print("new_e = {!r}".format(new_e))
                        print("new_target = {!r}".format(new_target))
                        raise Exception("unable to find an example that differentiates both the toplevel- and sub-expressions")
                    if counterexample in examples:
                        print("assumptions = {!r}".format(assumptions))
                        print("duplicate example: {!r}".format(counterexample))
                        print("old target = {!r}".format(target))
                        print("new target = {!r}".format(new_target))
                        print("old fp = {}".format(learner._fingerprint(old_e)))
                        print("new fp = {}".format(learner._fingerprint(new_e)))
                        print("old target fp = {}".format(learner._fingerprint(target)))
                        print("new target fp = {}".format(learner._fingerprint(new_target)))
                        raise Exception("got a duplicate example")
                    # a. if incorrect: add example, reset the learner
                    examples.append(counterexample)
                    print("new example: {}".format(truncate(repr(counterexample))))
                    print("restarting with {} examples".format(len(examples)))
                    learner.reset(examples)
                else:
                    # b. if correct: yield it, watch the new target, goto 1

                    if CHECK_FINAL_COST:
                        new_cost = cost_model.cost(new_target, RUNTIME_POOL)
                        print("cost: {} -----> {}".format(target_cost, new_cost))
                        if incremental.value:
                            ordering = new_cost.compare_to(target_cost, solver=solver)
                        else:
                            ordering = new_cost.compare_to(target_cost, assumptions=assumptions)
                        if ordering == Cost.WORSE:
                            if CHECK_SUBST_COST:
                                print("WHOOPS! COST GOT WORSE!")
                                if save_testcases.value:
                                    # Append a self-contained regression
                                    # test. The emitted text must itself be
                                    # valid Python, so the nested `if`/`for`
                                    # bodies are indented explicitly.
                                    with open(save_testcases.value, "a") as f:
                                        f.write("def testcase():\n")
                                        f.write("    costmodel = {}\n".format(repr(cost_model)))
                                        f.write("    old_e = {}\n".format(repr(old_e)))
                                        f.write("    new_e = {}\n".format(repr(new_e)))
                                        f.write("    target = {}\n".format(repr(target)))
                                        f.write("    new_target = {}\n".format(repr(new_target)))
                                        f.write("    if costmodel.cost(new_e, RUNTIME_POOL) <= costmodel.cost(old_e, RUNTIME_POOL) and costmodel.cost(new_target, RUNTIME_POOL) > costmodel.cost(target, RUNTIME_POOL):\n")
                                        f.write('        for name, x in zip(["old_e", "new_e", "target", "new_target"], [old_e, new_e, target, new_target]):\n')
                                        f.write('            print("{}: {}".format(name, pprint(x)))\n')
                                        f.write('            print(" cost = {}".format(costmodel.cost(x, RUNTIME_POOL)))\n')
                                        f.write("        assert False\n")
                                # raise Exception("detected nonmonotonicity")
                            else:
                                print("*** cost is worse")
                                # print(repr(target))
                                # print(repr(new_target))
                            continue
                        elif ordering == Cost.UNORDERED:
                            print("*** cost is unchanged")
                            # print(repr(target))
                            # print(repr(new_target))
                            continue
                        target_cost = new_cost
                    print("found improvement: {} -----> {}".format(pprint(old_e), pprint(new_e)))
                    # print(repr(target))
                    # print(repr(new_target))

                    # binders are not allowed to "leak" out
                    to_yield = new_target
                    if any(v in binders for v in free_vars(new_target)):
                        print("WARNING: stripping binders in {}".format(pprint(new_target)), file=sys.stderr)
                        to_yield = subst(new_target, { b.id : construct_value(b.type) for b in binders })
                    yield to_yield

                    # `ordering` is only bound when CHECK_FINAL_COST is set;
                    # the short-circuit `or` keeps it from being read
                    # otherwise.
                    if reset_on_success.value and (not CHECK_FINAL_COST or ordering != Cost.UNORDERED):
                        learner.reset(examples)
                    learner.watch(new_target)
                    target = new_target

                    if heuristic_done(new_target, args):
                        print("target now matches doneness heuristic")
                        break
            finally:
                # Always undo the push made for this candidate. The
                # `continue` paths above used to skip the pop, leaving an
                # unsatisfiable `target != new_target` assumption on the
                # solver so that every later query came back unsat.
                if incremental.value:
                    solver.pop()
    except KeyboardInterrupt:
        for e in learner.cache.random_sample(50):
            print(pprint(e))
        raise
def improve(target: Exp,
            context: Context,
            assumptions: Exp = ETRUE,
            stop_callback: Callable[[], bool] = never_stop,
            hints: [Exp] = (),
            examples: [{ str: object }] = (),
            cost_model: CostModel = None,
            ops: [Op] = (),
            improve_count: Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`. The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
     - assumptions: a precondition. The yielded improvements will only be
       correct when the assumptions are true.
     - hints: expressions that might be useful. These will be explored
       first when looking for improvements.
     - examples: inputs that will be used internally to differentiate
       semantically distinct expressions. This procedure discovers more
       examples as it runs, so there usually isn't a reason to provide any.
     - ops: update operations. This function may make different choices
       about what expressions are state expressions based on what changes
       can happen to that state.
     - improve_count: optional shared counter. When it is not None, its
       `value` is incremented while holding `get_lock()`.

    If the target already satisfies `heuristic_done`, nothing is yielded.
    If the assumptions are unsatisfiable, a single constructed value of the
    target's type is yielded and the generator finishes.

    Key differences from "regular" enumerative synthesis:
     - Expressions are either "state" expressions or "runtime" expressions,
       allowing this algorithm to choose what things to store on the data
       structure and what things to compute at query execution time. (The
       cost model is ultimately responsible for this choice.)
     - If a better version of *any subexpression* for the target is found,
       it is immediately substituted in and the overall expression is
       returned. This "smooths out" the search space a little, allowing us
       to find kinda-good solutions very quickly, even if the best possible
       solution is out of reach. This is more desirable than running for
       an indeterminate amount of time doing nothing.
    """

    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r}, ops={ops!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model,
        ops=ops))

    # Normalize the inputs before doing any work.
    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    # The result of exp_wf is used both as the truth value and in the
    # assertion message.
    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # The conjuncts of the assumptions and the target itself are added to
    # the caller-supplied hints.
    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    # NOTE(review): `vars` is not read again in this function (the solver
    # is built from `context` below) — confirm whether it can be dropped.
    vars = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(is_good, target)

    # Copy so that counterexamples appended below never touch the caller's
    # sequence (the default is an immutable tuple).
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    # `watched_targets` is the frontier of verified expressions that later
    # candidates are compared against; `blacklist` is shared across all
    # calls to search_for_improvements.
    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                # A repeated counterexample means the search yielded a
                # candidate that already disagrees with a known example;
                # treat that as a fatal internal error.
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    # Compare the candidate against every watched target;
                    # stop at the first one that beats it.
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print("Whoops! Looks like we already found something better.")
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    # Only yield when the candidate displaced the current
                    # target itself, not merely some other watched target.
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        # Count one completed pass of the search loop above.
        # NOTE(review): the nesting level of this block was reconstructed
        # from unindented source — confirm it belongs to the `while` body.
        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1