def visit_EUnaryOp(self, e):
    """Lower a unary-operator expression, emitting any needed setup statements.

    Sum and Empty are lowered into explicit loops whose statements are
    appended to `self.stms`; the returned expression is the fresh variable
    that holds the result. Length, All, Any and Exists are rewritten in terms
    of Sum/Empty and re-visited. Distinct and The are delegated to
    `self.visit_iterable` and `self.find_one`. Any other operator falls
    through to `self.visit_Exp`.
    """
    op = e.op
    if op == UOp.Distinct:
        return self.visit_iterable(e)
    elif op == UOp.The:
        return self.find_one(e.e)
    elif op == UOp.Sum:
        sum_var = fresh_var(e.type, "sum")
        loop_var = fresh_var(e.e.type.elem_type, "x")
        self.stms.append(simplify_and_optimize(seq([
            SDecl(sum_var, ENum(0).with_type(e.type)),
            SForEach(loop_var, e.e,
                # Type the addition with the sum's own type (the same type
                # used for the declaration above) rather than hard-coding INT.
                SAssign(sum_var, EBinOp(sum_var, "+", loop_var).with_type(e.type)))])))
        return sum_var
    elif op == UOp.Length:
        # len(xs) == sum(map(lambda x: 1, xs))
        arg = EVar("x").with_type(e.e.type.elem_type)
        return self.visit(EUnaryOp(UOp.Sum, EMap(e.e, ELambda(arg, ONE)).with_type(INT_BAG)).with_type(INT))
    elif op == UOp.All:
        # all(xs) == empty(filter(lambda x: not x, xs)).  Filtering keeps the
        # collection type of `e.e` (a collection of BOOL), and Empty is BOOL.
        arg = EVar("x").with_type(e.e.type.elem_type)
        return self.visit(EUnaryOp(UOp.Empty, EFilter(e.e, ELambda(arg, ENot(arg))).with_type(e.e.type)).with_type(BOOL))
    elif op == UOp.Any:
        # any(xs) == exists(filter(lambda x: x, xs)); same typing as above.
        arg = EVar("x").with_type(e.e.type.elem_type)
        return self.visit(EUnaryOp(UOp.Exists, EFilter(e.e, ELambda(arg, arg)).with_type(e.e.type)).with_type(BOOL))
    elif op == UOp.Empty:
        # v = true; iterate the collection and, on the first element, set
        # v = false and escape out of the loop.
        iterable = e.e
        v = fresh_var(BOOL, "v")
        label = fresh_name("label")
        x = fresh_var(iterable.type.elem_type, "x")
        decl = SDecl(v, ETRUE)
        find = SEscapableBlock(label,
            SForEach(x, iterable, seq([
                SAssign(v, EFALSE),
                SEscapeBlock(label)])))
        self.stms.append(simplify_and_optimize(seq([decl, find])))
        return v
    elif op == UOp.Exists:
        return self.visit(ENot(EUnaryOp(UOp.Empty, e.e).with_type(BOOL)))
    # Everything else (including UOp.AreUnique, which has no special
    # lowering here) is handled generically.
    return self.visit_Exp(e)
def visit_SWhile(self, w):
    """Emit the while-statement `w` as an infinite `for (;;)` loop.

    Inside the loop's block, a guard that executes `break;` when `w.e` is
    false is emitted first, followed by the loop body `w.body`.
    """
    exit_guard = SIf(ENot(w.e), SEscape("{indent}break;\n", (), ()), SNoOp())
    self.begin_statement()
    self.write("for (;;) ")
    with self.block():
        self.visit(exit_guard)
        self.visit(w.body)
    self.end_statement()
def _setup_handle_updates(self):
    """
    This method creates update code for handle objects modified by each op.
    Must be called once after all user-specified queries have been added.

    For every op in `self.op_specs` and every (handle type `t`, bag) pair
    returned by `reachable_handles_at_method(self.spec, op)`, this builds:

      * an internal query selecting the distinct handles in the bag whose
        `val` field differs after running the op's body, and
      * a statement (from `inc.mutate_in_place`) updating one handle's `val`
        in place; if that statement is not `SNoOp()`, it is wrapped in a loop
        over the query's result and the query is registered through
        `self._add_subquery`.

    The resulting statement is stored in `self.handle_updates[(t, op.name)]`.
    """
    for op in self.op_specs:
        print("Setting up handle updates for {}...".format(op.name))
        # Iterated below as (handle type, bag expression) pairs.
        handles = reachable_handles_at_method(self.spec, op)
        for t, bag in handles.items():
            # `h` is both the binder of the filter below and the loop
            # variable of the final SForEach, so the assumptions about `h`
            # passed to mutate_in_place describe the handle being updated.
            h = fresh_var(t)
            lval = EGetField(h, "val").with_type(t.value_type)
            # The value of `h.val` after running the op's body.
            new_val = inc.mutate(lval, op.body)

            # get set of modified handles
            modified_handles = Query(
                fresh_name("modified_handles"),
                Visibility.Internal,
                [],
                op.assumptions,
                EFilter(
                    EUnaryOp(UOp.Distinct, bag).with_type(bag.type),
                    ELambda(h, ENot(EEq(lval, new_val)))).with_type(bag.type),
                "[{}] modified handles of type {}".format(
                    op.name, pprint(t)))
            # The query's arguments are its free variables that are not part
            # of the abstract state.
            query_vars = [
                v for v in free_vars(modified_handles)
                if v not in self.abstract_state
            ]
            modified_handles.args = [(arg.id, arg.type) for arg in query_vars]

            # modify each one
            subqueries = []
            state_update_stm = inc.mutate_in_place(
                lval,
                lval,
                op.body,
                abstract_state=self.abstract_state,
                assumptions=list(op.assumptions) + [EDeepIn(h, bag), EIn(h, modified_handles.ret)],
                invariants=self.abstract_invariants,
                subgoals_out=subqueries)
            # `subqueries` was passed as `subgoals_out`; tag each collected
            # query with the op name and register it, threading the update
            # statement through `_add_subquery`.
            for sub_q in subqueries:
                sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                state_update_stm = self._add_subquery(
                    sub_q=sub_q, used_by=state_update_stm)
            # Only loop over (and register) modified_handles when there is
            # something to do for each handle.
            if state_update_stm != SNoOp():
                state_update_stm = SForEach(
                    h,
                    ECall(modified_handles.name, query_vars).with_type(bag.type),
                    state_update_stm)
                state_update_stm = self._add_subquery(
                    sub_q=modified_handles, used_by=state_update_stm)
            self.handle_updates[(t, op.name)] = state_update_stm
def array_resize_for_index(self, elem_type, a, i):
    """Resize the array `a` until `i` is a legal index into it.

    Emits a while-loop that runs while `i >= 0` and `i` is not in bounds for
    `a`. Each iteration allocates a new array of length `(a.length << 1) + 1`,
    copies the old contents into it with `System.arraycopy`, and reassigns
    `a` to the new array. When `i < 0` the loop condition is false, so
    nothing happens.

    For BOOL elements the emitted array type is `long[]`; otherwise it is the
    rendered element type with generics stripped.
    """
    new_a = fresh_name(hint="new_array")
    if elem_type == BOOL:
        t = "long"
    else:
        t = self.strip_generics(self.visit(elem_type, name=""))
    # Named `a_len` rather than `len` to avoid shadowing the builtin; the
    # escape placeholder it fills is still called "len".
    a_len = EEscape("{a}.length", ["a"], [a]).with_type(INT)
    # `{t}` and `{new_a}` are filled in now; the doubled braces survive
    # `.format` as `{indent}`, `{len}` and `{a}` placeholders for SEscape.
    template = (
        "{{indent}}{t}[] {new_a} = new {t}[({{len}} << 1) + 1];\n"
        "{{indent}}System.arraycopy({{a}}, 0, {new_a}, 0, {{len}});\n"
        "{{indent}}{{a}} = {new_a};\n")
    double_and_incr_size = SEscape(
        template.format(t=t, new_a=new_a),
        ["a", "len"], [a, a_len])
    self.visit(SWhile(
        EAll([EBinOp(i, ">=", ZERO).with_type(BOOL),
              ENot(self.array_in_bounds(elem_type, a, i))]),
        double_and_incr_size))
def visit_EBinOp(self, e):
    """Render a binary-operator expression as target-language code.

    `==` goes through `self._eq`. `!=`, `or` and `and` are rewritten (to a
    negated equality and to conditionals) and re-visited. `in` on a TSet is
    delegated to `self.test_set_containment_native`; `in` on anything else
    raises. Every other operator, including `===` (emitted as a plain `==`),
    is rendered infix as `(lhs op rhs)`.
    """
    lhs, rhs = e.e1, e.e2
    op = e.op

    if op == "==":
        return self._eq(lhs, rhs)
    if op == "!=":
        return self.visit(ENot(EEq(lhs, rhs)))
    if op == BOp.Or:
        # a or b  ==>  a ? true : b
        return self.visit(ECond(lhs, ETRUE, rhs).with_type(BOOL))
    if op == BOp.And:
        # a and b  ==>  a ? b : false
        return self.visit(ECond(lhs, rhs, EFALSE).with_type(BOOL))
    if op == BOp.In:
        if not isinstance(rhs.type, TSet):
            raise Exception(
                "{!r} operator is supposed to be handled by simplify_and_optimize"
                .format(op))
        return self.test_set_containment_native(rhs, lhs)
    if op == "===":
        # Rewrite the deep-equality test into a plain `==`.
        # NOTE(review): unlike "==", this bypasses self._eq; confirm that a
        # raw `==` is the intended rendering for deep equality.
        op = "=="

    return "({e1} {op} {e2})".format(e1=self.visit(lhs), op=op, e2=self.visit(rhs))
def stream(iterable: Exp, loop_var: EVar, body: Stm) -> Stm:
    """Convert an iterable expression to a streaming operation.

    Input:
      iterable - an expression with an iterable type (Bag, Set, or List),
          not yet optimized
      loop_var - a variable to use as the loop variable
      body - a statement to run on that variable, not yet optimized

    Output:
      A statement equivalent to
          for (loop_var in iterable) { body }
      that eliminates as many intermediate collections and objects as
      possible.

    NOTE: The output of this function will not be correct if the body
    modifies any free variable in the iterable expression or writes to any
    pointers that are read by the iterable expression.

    Generating code for the expression

        Map {func} (Filter {predicate} big_collection)

    might create two new collections as large as `big_collection`: one to
    hold the result of the filter and one to hold the result of the map. If
    all the code needs to do is to iterate over the result, then there is no
    reason to make the two new collections.

    This function is mutually recursive with `simplify_and_optimize`, so any
    transformations performed by that method are also applied to the output
    of this one.
    """
    # [] -- nothing to iterate over.
    if isinstance(iterable, EEmptyList):
        return SNoOp()

    # [x] -- bind the single value to loop_var and run the body once.
    if isinstance(iterable, ESingleton):
        setup, value = simplify_and_optimize_expression(iterable.e)
        # SScoped because if the iterable is e.g. [x] + [y], then the body
        # might appear in the same block twice. If the body declares any
        # variables, that will cause problems in languages like Java or C++.
        scoped_body = SScoped(re_use(value, loop_var, simplify_and_optimize(body)))
        return seq([setup, scoped_body])

    # c ? xs : ys -- branch once, then stream whichever side is chosen.
    if isinstance(iterable, ECond):
        cond_setup, cond = simplify_and_optimize_expression(iterable.cond)
        then_stm = stream(iterable.then_branch, loop_var, body)
        else_stm = stream(iterable.else_branch, loop_var, body)
        return seq([cond_setup, SIf(cond, then_stm, else_stm)])

    # distinct xs -- remember elements already seen in a temporary set.
    if isinstance(iterable, EUnaryOp) and iterable.op == UOp.Distinct:
        seen = fresh_var(TSet(iterable.type.elem_type), "distinct_elems")
        seen_decl = SDecl(seen, EEmptyList().with_type(seen.type))
        first_time = ENot(EBinOp(loop_var, BOp.In, seen).with_type(BOOL))
        guarded_body = SIf(
            first_time,
            seq([body, SCall(seen, "add", [loop_var])]),
            SNoOp())
        return seq([seen_decl, stream(iterable.e, loop_var, guarded_body)])

    # xs + ys -- stream each side in turn.
    if isinstance(iterable, EBinOp) and iterable.op == "+":
        left_stm = stream(iterable.e1, loop_var, body)
        right_stm = stream(iterable.e2, loop_var, body)
        return seq([left_stm, right_stm])

    # xs - ys -- skip elements of xs that are "paid for" by ys.
    if isinstance(iterable, EBinOp) and iterable.op == "-":
        if is_hashable(iterable.type.elem_type):
            # Count occurrences in ys; decrement instead of running the body
            # while a positive count remains.
            hist_setup, hist = histogram(iterable.e2)
            count_var = fresh_var(INT, "count")
            setup_stm = simplify_and_optimize(hist_setup)
            still_owed = EGt(EMapGet(hist, loop_var).with_type(INT), ZERO)
            decrement = SMapUpdate(
                hist, loop_var, count_var,
                SAssign(count_var, EBinOp(count_var, "-", ONE).with_type(INT)))
            return seq([
                setup_stm,
                stream(iterable.e1, loop_var, SIf(still_owed, decrement, body))])
        # Not hashable: keep a copy of ys and remove matched elements from it.
        rhs = fresh_var(iterable.e2.type, "bag_subtraction_right")
        rhs_decl = simplify_and_optimize(SDecl(rhs, iterable.e2))
        skip_or_run = SIf(EIn(loop_var, rhs), SCall(rhs, "remove", (loop_var,)), body)
        return seq([rhs_decl, stream(iterable.e1, loop_var, skip_or_run)])

    # Filter p xs  ==>  FlatMap (\x. p(x) ? [x] : []) xs
    if isinstance(iterable, EFilter):
        pred = iterable.predicate
        coll_type = iterable.type
        keep_or_drop = ECond(
            pred.body,
            ESingleton(pred.arg).with_type(coll_type),
            EEmptyList().with_type(coll_type)).with_type(coll_type)
        as_flatmap = EFlatMap(iterable.e, ELambda(pred.arg, keep_or_drop)).with_type(coll_type)
        return stream(as_flatmap, loop_var, body)

    # Map f xs  ==>  FlatMap (\x. [f(x)]) xs
    if isinstance(iterable, EMap):
        func = iterable.transform_function
        as_flatmap = EFlatMap(
            iterable.e,
            ELambda(func.arg, ESingleton(func.body).with_type(iterable.type))).with_type(iterable.type)
        return stream(as_flatmap, loop_var, body)

    # FlatMap f xs -- nested streams with a fresh outer loop variable.
    if isinstance(iterable, EFlatMap):
        func = iterable.transform_function
        inner_loop_var = fresh_var(func.arg.type, func.arg.id)
        inner_stm = stream(func.apply_to(inner_loop_var), loop_var, body)
        return stream(iterable.e, inner_loop_var, inner_stm)

    # xs[start:end] -- index loop over the clamped range.
    if isinstance(iterable, EListSlice):
        elem_type = iterable.type.elem_type
        lst = fresh_var(iterable.e.type, "list")
        idx = fresh_var(INT, "start")
        stop = fresh_var(INT, "end")
        return simplify_and_optimize(seq([
            SDecl(lst, iterable.e),
            SDecl(idx, max_of(iterable.start, ZERO)),
            SDecl(stop, min_of(iterable.end, ELen(lst))),
            SWhile(ELt(idx, stop), seq([
                SDecl(loop_var, EListGet(lst, idx).with_type(elem_type)),
                body,
                SAssign(idx, EBinOp(idx, "+", ONE).with_type(INT))]))]))

    # let v = e in xs -- declare v, then stream the body expression.
    if isinstance(iterable, ELet):
        func = iterable.body_function
        bound = fresh_var(func.arg.type, func.arg.id)
        bound_decl = simplify_and_optimize(SDecl(bound, iterable.e))
        return seq([bound_decl, stream(func.apply_to(bound), loop_var, body)])

    # move(xs) -- iterating does not need the move.
    if isinstance(iterable, EMove):
        return stream(iterable.e, loop_var, body)

    # Fallback: materialize the collection and loop over it.
    assert is_collection(iterable.type), repr(iterable)
    setup, coll = simplify_and_optimize_expression(iterable)
    return seq([setup, SForEach(loop_var, coll, simplify_and_optimize(body))])
def possibly_useful_nonrecursive(
        solver,
        e: Exp,
        context: Context,
        pool=RUNTIME_POOL,
        assumptions: Exp = ETRUE,
        ops: [Op] = ()) -> bool:
    """Heuristic filter to ignore expressions that are almost certainly useless.

    The checks look at `e` itself (and at most a level or two into its
    children); there is no recursive traversal here. Returns the first
    `No(reason)` whose check fires, or `True` otherwise. If an extension
    handler exists for `type(e)`, it is consulted first and a falsy result
    from it is returned unchanged.

    Parameters:
      solver - used for the satisfiability/validity queries in the
          ArgMin/ArgMax, peel, and big-map checks
      e - the candidate expression
      context - supplies the state variables and a path condition
      pool - compared against RUNTIME_POOL; checks gated on `not at_runtime`
          concern expressions in state position
      assumptions - extra facts, conjoined with the context's path condition
      ops - update operations, used to reject ArgMin/ArgMax state from which
          an op might remove more than one element
    """
    state_vars = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    assumptions = EAll([assumptions, context.path_condition()])
    at_runtime = pool == RUNTIME_POOL

    h = extension_handler(type(e))
    if h is not None:
        res = h.possibly_useful(e, context, pool, assumptions, ops, solver)
        if not res:
            return res

    if isinstance(e, EStateVar) and not free_vars(e.e):
        return No("constant value in state position")
    if isinstance(e, (EDropFront, EDropBack)) and not at_runtime:
        return No("EDrop* in state position")
    if not allow_big_sets.value and isinstance(e, EFlatMap) and not at_runtime:
        return No("EFlatMap in state position")
    if (not allow_int_arithmetic_state.value and not at_runtime
            and isinstance(e, EBinOp) and e.type == INT):
        return No("integer arithmetic in state position")
    if is_collection(e.type) and not is_scalar(e.type.elem_type):
        return No("collection of nonscalar: e {}\n elem_type: {}\n".format(
            e, e.type.elem_type))
    if isinstance(e.type, TMap) and not is_scalar(e.type.k):
        return No("bad key type {}".format(pprint(e.type.k)))
    if isinstance(e.type, TMap) and isinstance(e.type.v, TMap):
        return No("map to map")
    # There is deliberately no check that `the` is applied to a collection
    # with at most one element: whether `the` is legal may depend on the
    # context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    if (not at_runtime and isinstance(e, EBinOp) and e.op == "-"
            and is_collection(e.type)):
        return No("collection subtraction in state position")
    if (not allow_nonzero_state_constants.value and not at_runtime
            and isinstance(e, ENum) and e.val != 0):
        return No("nonzero integer constant in state position")
    # At runtime, an EStateVar wrapping a scalar binary operation could just
    # as well be computed from its (stored) operands.
    if (not allow_binop_state.value and at_runtime
            and isinstance(e, EStateVar) and isinstance(e.e, EBinOp)
            and is_scalar(e.e.e1.type) and is_scalar(e.e.e2.type)):
        return No(
            "constant-time binary operator {!r} in state position".format(
                e.e.op))
    if not allow_conditional_state.value and not at_runtime and isinstance(e, ECond):
        return No("conditional in state position")
    if isinstance(e, EMakeMap2) and isinstance(e.e, EEmptyList):
        return No("trivially empty map")
    if isinstance(e, EMakeMap2) and isinstance(e.e, ESingleton):
        return No("really tiny map")
    if not at_runtime and isinstance(e, (EArgMin, EArgMax)):
        # Cozy has no way to efficiently implement mins/maxes when more than
        # one element may leave the collection.
        from cozy.state_maintenance import mutate
        for op in ops:
            elems = e.e
            elems_prime = mutate(elems, op.body)
            formula = EAll([assumptions] + list(op.assumptions) + [
                EGt(
                    ELen(EBinOp(elems, "-", elems_prime).with_type(elems.type)),
                    ONE)
            ])
            if solver.satisfiable(formula):
                return No(
                    "more than one element might be removed during {}".format(
                        op.name))
    if not allow_peels.value and not at_runtime and isinstance(e, EFilter):
        # catch "peels": removal of zero or one elements
        if solver.valid(
                EImplies(
                    assumptions,
                    ELe(
                        ELen(
                            EFilter(
                                e.e,
                                ELambda(e.predicate.arg, ENot(e.predicate.body))
                            ).with_type(e.type)),
                        ONE))):
            return No("filter is a peel")
    if (not allow_big_maps.value and not at_runtime
            and isinstance(e, EMakeMap2) and is_collection(e.type.v)):
        # Reject maps whose total value size might exceed the combined size
        # of all state collections.
        all_collections = [sv for sv in state_vars if is_collection(sv.type)]
        total_size = ENum(0).with_type(INT)
        for c in all_collections:
            total_size = EBinOp(
                total_size, "+",
                EUnaryOp(UOp.Length, c).with_type(INT)).with_type(INT)
        my_size = EUnaryOp(
            UOp.Length,
            EFlatMap(
                EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type),
                e.value_function).with_type(e.type.v)).with_type(INT)
        s = EImplies(assumptions, EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            return No("non-polynomial-sized map")
    return True
def improve(
        target: Exp,
        context: Context,
        assumptions: Exp = ETRUE,
        stop_callback: Callable[[], bool] = never_stop,
        hints: [Exp] = (),
        examples: [{str: object}] = (),
        cost_model: CostModel = None,
        ops: [Op] = (),
        improve_count: Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`. The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
      - assumptions: a precondition. The yielded improvements will only be
        correct when the assumptions are true.
      - hints: expressions that might be useful. These will be explored first
        when looking for improvements.
      - examples: inputs that will be used internally to differentiate
        semantically distinct expressions. This procedure discovers more
        examples as it runs, so there usually isn't a reason to provide any.
      - ops: update operations. This function may make different choices
        about what expressions are state expressions based on what changes
        can happen to that state.
      - improve_count: optional shared counter exposing `get_lock()` and
        `.value`; it is incremented under its lock after each pass of the
        search loop.

    Special cases:
      - If `target` already satisfies `heuristic_done`, nothing is yielded.
      - If the assumptions are unsatisfiable, a single value built by
        `construct_value(target.type)` is yielded and the generator stops.
      - AssertionError is raised if the target is not well-formed or is
        already judged not possibly useful; Exception is raised if the
        solver returns a counterexample that was already known.

    Key differences from "regular" enumerative synthesis:
      - Expressions are either "state" expressions or "runtime" expressions,
        allowing this algorithm to choose what things to store on the data
        structure and what things to compute at query execution time. (The
        cost model is ultimately responsible for this choice.)
      - If a better version of *any subexpression* for the target is found,
        it is immediately substituted in and the overall expression is
        returned. This "smooths out" the search space a little, allowing us
        to find kinda-good solutions very quickly, even if the best possible
        solution is out of reach. This is more desirable than running for an
        indeterminate amount of time doing nothing.
    """

    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r}, ops={ops!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model,
        ops=ops))

    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(
        is_good, target)

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(
                    len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target,
                                               context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print(
                            "Whoops! Looks like we already found something better."
                        )
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1
def improve(
        target: Exp,
        context: Context,
        assumptions: Exp = T,
        stop_callback=never_stop,
        hints: [Exp] = (),
        examples: [{str: object}] = (),
        cost_model: CostModel = None,
        ops: [Op] = ()):
    """
    Improve the target expression using enumerative synthesis.
    This function is a generator that yields increasingly better and better
    versions of the input expression `target`.

    Special cases:
      - If the assumptions are unsatisfiable, a single value built by
        `construct_value(target.type)` is yielded and the generator stops.
      - A target that is not well-formed only produces printed warnings;
        the search still runs.
      - Exception is raised if the solver returns a counterexample that was
        already known.

    Notes on internals of this algorithm follow.
    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, and lets us
          find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.
    """

    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r}, ops={ops!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model,
        ops=ops))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    if not is_wf:
        print(
            "WARNING: initial target is not well-formed [{}]; this might go poorly..."
            .format(is_wf))
        print(pprint(is_wf.offending_subexpression))
        print(pprint(is_wf.offending_subexpression.type))

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    if not solver.satisfiable(T):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    # The learner is handed the same `watched_targets` and `examples` lists
    # that are mutated below; it is also told explicitly via reset()/watch().
    learner = Learner(watched_targets, solver, context, examples, cost_model,
                      stop_callback, hints, ops=ops)
    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in learner.search():
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, reset the learner
                examples.append(counterexample)
                event("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(
                    len(examples)))
                learner.reset(examples)
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = eviction_policy(new_target, context, old_target,
                                              context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print(
                            "Whoops! Looks like we already found something better."
                        )
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                learner.watch(watched_targets)
                break