def visit_EUnaryOp(self, e):
    """Generate code for the unary-operator expression `e`.

    Depending on `e.op` this either delegates to another visitor method,
    rewrites `e` into an equivalent expression and visits that instead, or
    appends freshly built statements to `self.stms` and returns the variable
    that holds the result.

    Operators handled here:
      - Distinct: delegated to `self.visit_iterable`.
      - The:      delegated to `self.find_one` on the argument.
      - Sum:      explicit accumulation loop; returns the accumulator variable.
      - Length:   rewritten as the sum of a bag of ones.
      - All:      rewritten as "no element is false".
      - Any:      rewritten as "some element is true".
      - Empty:    escapable loop that clears a flag on the first element;
                  returns the flag variable.
      - Exists:   rewritten as the negation of Empty.

    Any other operator falls through to `self.visit_Exp`.
    """
    op = e.op
    if op == UOp.Distinct:
        return self.visit_iterable(e)
    elif op == UOp.The:
        return self.find_one(e.e)
    elif op == UOp.Sum:
        sum_var = fresh_var(e.type, "sum")
        loop_var = fresh_var(e.e.type.elem_type, "x")
        self.stms.append(simplify_and_optimize(seq([
            SDecl(sum_var, ENum(0).with_type(e.type)),
            SForEach(
                loop_var,
                e.e,
                # The addition must carry the same type as the accumulator
                # it is assigned to; it was previously hard-coded to INT.
                SAssign(sum_var, EBinOp(sum_var, "+", loop_var).with_type(e.type)))])))
        return sum_var
    elif op == UOp.Length:
        arg = EVar("x").with_type(e.e.type.elem_type)
        ones = EMap(e.e, ELambda(arg, ONE)).with_type(INT_BAG)
        return self.visit(EUnaryOp(UOp.Sum, ones).with_type(INT))
    elif op == UOp.All:
        arg = EVar("x").with_type(e.e.type.elem_type)
        # Filtering keeps the collection type of `e.e`, and emptiness is a
        # boolean; these were previously mistyped as INT_BAG / INT.
        false_elems = EFilter(e.e, ELambda(arg, ENot(arg))).with_type(e.e.type)
        return self.visit(EUnaryOp(UOp.Empty, false_elems).with_type(BOOL))
    elif op == UOp.Any:
        arg = EVar("x").with_type(e.e.type.elem_type)
        # Same typing fix as for UOp.All.
        true_elems = EFilter(e.e, ELambda(arg, arg)).with_type(e.e.type)
        return self.visit(EUnaryOp(UOp.Exists, true_elems).with_type(BOOL))
    elif op == UOp.Empty:
        iterable = e.e
        v = fresh_var(BOOL, "v")
        label = fresh_name("label")
        x = fresh_var(iterable.type.elem_type, "x")
        decl = SDecl(v, ETRUE)
        # Assume empty; the first element seen clears the flag and leaves
        # the escapable block, so at most one iteration runs.
        find = SEscapableBlock(label,
            SForEach(x, iterable, seq([
                SAssign(v, EFALSE),
                SEscapeBlock(label)])))
        self.stms.append(simplify_and_optimize(seq([decl, find])))
        return v
    elif op == UOp.Exists:
        return self.visit(ENot(EUnaryOp(UOp.Empty, e.e).with_type(BOOL)))
    # elif op == UOp.AreUnique:
    #     s = fresh_var(TSet(e.e.type.elem_type), "unique_elems")
    #     u = fresh_var(BOOL, "is_unique")
    #     x = fresh_var(e.e.type.elem_type)
    #     label = fresh_name("label")
    #     self.visit(seq([
    #         SDecl(s, EEmptyList().with_type(s.type)),
    #         SDecl(u, ETRUE),
    #         SEscapableBlock(label,
    #             SForEach(x, e.e,
    #                 SIf(EEscape("{s}.find({x}) != {s}.end()", ("s", "x"), (s, x)).with_type(BOOL),
    #                     seq([SAssign(u, EFALSE), SEscapeBlock(label)]),
    #                     SEscape("{indent}{s}.insert({x});\n", ("s", "x"), (s, x)))))]))
    #     return u.id
    return self.visit_Exp(e)
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Replace "bad" expressions in `s` with temporaries.

    Walks the statement `s` and, for each expression it inspects for which
    `exp_is_bad` returns a true value, creates a fresh variable, appends a
    declaration of that variable to `decls_out`, and uses the variable in
    place of the expression.  The rewritten statement is returned.

    Raises NotImplementedError for statement kinds that are not handled.
    """

    def hoist(expr: Exp) -> Exp:
        # Acceptable expressions are returned untouched.
        if not exp_is_bad(expr):
            return expr
        tmp = fresh_var(expr.type)
        decls_out.append(SDecl(tmp.id, expr))
        return tmp

    if isinstance(s, SNoOp):
        return s

    if isinstance(s, SSeq):
        first = pull_temps(s.s1, decls_out, exp_is_bad)
        second = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(first, second)

    if isinstance(s, SIf):
        new_cond = hoist(s.cond)
        new_then = pull_temps(s.then_branch, decls_out, exp_is_bad)
        new_else = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(new_cond, new_then, new_else)

    if isinstance(s, SForEach):
        new_iter = hoist(s.iter)
        body_decls = []
        new_body = pull_temps(s.body, body_decls, exp_is_bad)
        # Declarations that mention the loop variable cannot simply move out
        # of the loop; the others can.
        dependent, independent = partition(
            body_decls,
            lambda d: s.id in free_vars(d.val))
        decls_out.extend(independent)
        for decl in dependent:
            tmp = EVar(decl.id).with_type(decl.val.type)
            # Precompute the value for every element of the iterated bag as
            # a map, then look the value up inside the loop body.
            table_type = TMap(s.id.type, tmp.type)
            table = EMakeMap2(new_iter, ELambda(s.id, decl.val)).with_type(table_type)
            table_var = fresh_var(table.type)
            decls_out.append(SDecl(table_var.id, table))
            lookup = EMapGet(table_var, s.id).with_type(tmp.type)
            new_body = subst(new_body, {tmp.id: lookup})
        return SForEach(s.id, new_iter, new_body)

    if isinstance(s, SAssign):
        return SAssign(s.lhs, hoist(s.rhs))

    if isinstance(s, SCall):
        new_args = tuple(hoist(arg) for arg in s.args)
        return SCall(s.target, s.func, new_args)

    if isinstance(s, SMapDel):
        return SMapDel(s.map, hoist(s.key))

    if isinstance(s, SMapPut):
        new_key = hoist(s.key)
        new_value = hoist(s.value)
        return SMapPut(s.map, new_key, new_value)

    if isinstance(s, SMapUpdate):
        new_key = hoist(s.key)
        change_decls = []
        new_change = pull_temps(s.change, change_decls, exp_is_bad)
        for decl in change_decls:
            if s.val_var in free_vars(decl.val):
                # Outside the update the value variable is not bound, so
                # refer to the map entry directly instead.
                entry = EMapGet(s.map, new_key).with_type(s.val_var.type)
                decl = SDecl(decl.id, subst(decl.val, {s.val_var.id: entry}))
            decls_out.append(decl)
        return SMapUpdate(s.map, new_key, s.val_var, new_change)

    raise NotImplementedError(s)
def _setup_handle_updates(self):
    """
    Build the update code for handle objects that each op may modify.

    For every op in `self.op_specs` and every handle type reachable at that
    op, this records a statement in `self.handle_updates` under the key
    `(handle_type, op.name)`.

    Must be called once after all user-specified queries have been added.
    """
    for op in self.op_specs:
        print("Setting up handle updates for {}...".format(op.name))
        reachable = reachable_handles_at_method(self.spec, op)
        for handle_type, bag in reachable.items():
            handle = fresh_var(handle_type)
            lval = EGetField(handle, "val").with_type(handle_type.value_type)
            new_val = inc.mutate(lval, op.body)

            # Query for the distinct handles whose value differs after the op.
            modified_handles = Query(
                fresh_name("modified_handles"),
                Visibility.Internal,
                [],
                op.assumptions,
                EFilter(
                    EUnaryOp(UOp.Distinct, bag).with_type(bag.type),
                    ELambda(handle, ENot(EEq(lval, new_val)))).with_type(bag.type),
                "[{}] modified handles of type {}".format(op.name, pprint(handle_type)))
            # Every free variable that is not abstract state becomes a
            # query argument.
            query_vars = [
                fv for fv in free_vars(modified_handles)
                if fv not in self.abstract_state]
            modified_handles.args = [(qv.id, qv.type) for qv in query_vars]

            # Code that updates a single modified handle in place.
            subqueries = []
            update_assumptions = list(op.assumptions) + [
                EDeepIn(handle, bag),
                EIn(handle, modified_handles.ret)]
            state_update_stm = inc.mutate_in_place(
                lval,
                lval,
                op.body,
                abstract_state=self.abstract_state,
                assumptions=update_assumptions,
                invariants=self.abstract_invariants,
                subgoals_out=subqueries)
            for sub_q in subqueries:
                sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                state_update_stm = self._add_subquery(
                    sub_q=sub_q,
                    used_by=state_update_stm)

            # Only loop over the modified handles (and register the query
            # that finds them) when there is actual work to do per handle.
            if state_update_stm != SNoOp():
                modified = ECall(modified_handles.name, query_vars).with_type(bag.type)
                state_update_stm = SForEach(handle, modified, state_update_stm)
                state_update_stm = self._add_subquery(
                    sub_q=modified_handles,
                    used_by=state_update_stm)

            self.handle_updates[(handle_type, op.name)] = state_update_stm
def stream(iterable : Exp, loop_var : EVar, body : Stm) -> Stm:
    """Turn iteration over an expression into a streaming statement.

    Input:
      iterable - an expression with an iterable type (Bag, Set, or List),
                 not yet optimized
      loop_var - the variable to use as the loop variable
      body     - the statement to run for each element, not yet optimized

    Output:
      A statement equivalent to

          for (loop_var in iterable) { body }

      that avoids building intermediate collections and objects wherever
      possible.

    NOTE: The output of this function will not be correct if the body
    modifies any free variable of the iterable expression, or writes to any
    pointer that the iterable expression reads.

    Motivation: generating code for

        Map {func} (Filter {predicate} big_collection)

    naively might create two new collections as large as `big_collection`,
    one for the filter result and one for the map result.  When the caller
    only needs to iterate over the result, neither collection is necessary.

    This function is mutually recursive with `simplify_and_optimize`, so the
    transformations performed there also apply to the output produced here.
    """
    if isinstance(iterable, EEmptyList):
        return SNoOp()

    if isinstance(iterable, ESingleton):
        setup, value = simplify_and_optimize_expression(iterable.e)
        # Wrap in SScoped: for an iterable such as [x] + [y] the body can
        # appear twice in the same block, and a body that declares variables
        # would then cause problems in languages like Java or C++.
        scoped_body = SScoped(re_use(value, loop_var, simplify_and_optimize(body)))
        return seq([setup, scoped_body])

    if isinstance(iterable, ECond):
        cond_setup, cond = simplify_and_optimize_expression(iterable.cond)
        on_true = stream(iterable.then_branch, loop_var, body)
        on_false = stream(iterable.else_branch, loop_var, body)
        return seq([cond_setup, SIf(cond, on_true, on_false)])

    if isinstance(iterable, EUnaryOp) and iterable.op == UOp.Distinct:
        # Track elements already visited and run the body only for new ones.
        seen = fresh_var(TSet(iterable.type.elem_type), "distinct_elems")
        unseen = ENot(EBinOp(loop_var, BOp.In, seen).with_type(BOOL))
        visit_once = SIf(
            unseen,
            seq([body, SCall(seen, "add", [loop_var])]),
            SNoOp())
        return seq([
            SDecl(seen, EEmptyList().with_type(seen.type)),
            stream(iterable.e, loop_var, visit_once)])

    if isinstance(iterable, EBinOp) and iterable.op == "+":
        left = stream(iterable.e1, loop_var, body)
        right = stream(iterable.e2, loop_var, body)
        return seq([left, right])

    if isinstance(iterable, EBinOp) and iterable.op == "-":
        if is_hashable(iterable.type.elem_type):
            # Count the elements of the right-hand side, then let each
            # left-hand element either consume one count or run the body.
            h_setup, h = histogram(iterable.e2)
            val_ref = fresh_var(INT, "count")
            consume = SMapUpdate(
                h, loop_var, val_ref,
                SAssign(val_ref, EBinOp(val_ref, "-", ONE).with_type(INT)))
            has_count = EGt(EMapGet(h, loop_var).with_type(INT), ZERO)
            return seq([
                simplify_and_optimize(h_setup),
                stream(iterable.e1, loop_var, SIf(has_count, consume, body))])

        # Elements are not hashable: materialize the right-hand side and
        # remove matching elements from it as they are encountered.
        rhs = fresh_var(iterable.e2.type, "bag_subtraction_right")
        skip_or_run = SIf(
            EIn(loop_var, rhs),
            SCall(rhs, "remove", (loop_var,)),
            body)
        return seq([
            simplify_and_optimize(SDecl(rhs, iterable.e2)),
            stream(iterable.e1, loop_var, skip_or_run)])

    if isinstance(iterable, EFilter):
        # A filter is a flat-map yielding a singleton or nothing.
        pred = iterable.predicate
        keep_or_drop = ECond(
            pred.body,
            ESingleton(pred.arg).with_type(iterable.type),
            EEmptyList().with_type(iterable.type)).with_type(iterable.type)
        as_flatmap = EFlatMap(
            iterable.e,
            ELambda(pred.arg, keep_or_drop)).with_type(iterable.type)
        return stream(as_flatmap, loop_var, body)

    if isinstance(iterable, EMap):
        # A map is a flat-map yielding exactly one element.
        func = iterable.transform_function
        as_flatmap = EFlatMap(
            iterable.e,
            ELambda(func.arg, ESingleton(func.body).with_type(iterable.type))).with_type(iterable.type)
        return stream(as_flatmap, loop_var, body)

    if isinstance(iterable, EFlatMap):
        func = iterable.transform_function
        inner_loop_var = fresh_var(func.arg.type, func.arg.id)
        inner = stream(func.apply_to(inner_loop_var), loop_var, body)
        return stream(iterable.e, inner_loop_var, inner)

    if isinstance(iterable, EListSlice):
        elem_type = iterable.type.elem_type
        l = fresh_var(iterable.e.type, "list")
        s = fresh_var(INT, "start")
        e = fresh_var(INT, "end")
        # Index-based loop over the slice bounds clamped to [0, len(list)].
        step = seq([
            SDecl(loop_var, EListGet(l, s).with_type(elem_type)),
            body,
            SAssign(s, EBinOp(s, "+", ONE).with_type(INT))])
        return simplify_and_optimize(seq([
            SDecl(l, iterable.e),
            SDecl(s, max_of(iterable.start, ZERO)),
            SDecl(e, min_of(iterable.end, ELen(l))),
            SWhile(ELt(s, e), step)]))

    if isinstance(iterable, ELet):
        func = iterable.body_function
        v = fresh_var(func.arg.type, func.arg.id)
        return seq([
            simplify_and_optimize(SDecl(v, iterable.e)),
            stream(func.apply_to(v), loop_var, body)])

    if isinstance(iterable, EMove):
        return stream(iterable.e, loop_var, body)

    # No streaming rule applies: fall back to a plain loop over the
    # optimized expression.
    assert is_collection(iterable.type), repr(iterable)
    setup, e = simplify_and_optimize_expression(iterable)
    return seq([setup, SForEach(loop_var, e, simplify_and_optimize(body))])
def possibly_useful_nonrecursive(
        solver,
        e: Exp,
        context: Context,
        pool=RUNTIME_POOL,
        assumptions: Exp = ETRUE,
        ops: [Op] = ()) -> bool:
    """Heuristic filter to ignore expressions that are almost certainly useless.

    Returns True when none of the checks below rejects `e`; otherwise returns
    a `No(...)` value carrying the reason.  Only `e` itself is inspected by
    the checks written here (hence "nonrecursive").
    """
    state_vars = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    args = OrderedSet(v for v, p in context.vars() if p == RUNTIME_POOL)
    assumptions = EAll([assumptions, context.path_condition()])
    at_runtime = pool == RUNTIME_POOL

    # Give a registered extension handler for this node type the first say.
    h = extension_handler(type(e))
    if h is not None:
        res = h.possibly_useful(e, context, pool, assumptions, ops, solver)
        if not res:
            return res

    if isinstance(e, EStateVar) and not free_vars(e.e):
        return No("constant value in state position")

    if isinstance(e, (EDropFront, EDropBack)) and not at_runtime:
        return No("EDrop* in state position")

    if not allow_big_sets.value and isinstance(e, EFlatMap) and not at_runtime:
        return No("EFlatMap in state position")

    if (not allow_int_arithmetic_state.value
            and not at_runtime
            and isinstance(e, EBinOp)
            and e.type == INT):
        return No("integer arithmetic in state position")

    if is_collection(e.type) and not is_scalar(e.type.elem_type):
        return No("collection of nonscalar: e {}\n elem_type: {}\n".format(
            e, e.type.elem_type))

    if isinstance(e.type, TMap):
        if not is_scalar(e.type.k):
            return No("bad key type {}".format(pprint(e.type.k)))
        if isinstance(e.type.v, TMap):
            return No("map to map")

    # This check is probably a bad idea: whether `the` is legal may depend on
    # the context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    # if isinstance(e, EUnaryOp) and e.op == UOp.The:
    #     len = EUnaryOp(UOp.Length, e.e).with_type(INT)
    #     if not valid(EImplies(assumptions, EBinOp(len, "<=", ENum(1).with_type(INT)).with_type(BOOL))):
    #         return No("illegal application of 'the': could have >1 elems")

    if (not at_runtime
            and isinstance(e, EBinOp)
            and e.op == "-"
            and is_collection(e.type)):
        return No("collection subtraction in state position")

    # if not at_runtime and isinstance(e, ESingleton):
    #     return No("singleton in state position")

    if (not allow_nonzero_state_constants.value
            and not at_runtime
            and isinstance(e, ENum)
            and e.val != 0):
        return No("nonzero integer constant in state position")

    if (not allow_binop_state.value
            and at_runtime
            and isinstance(e, EStateVar)
            and isinstance(e.e, EBinOp)
            and is_scalar(e.e.e1.type)
            and is_scalar(e.e.e2.type)):
        return No("constant-time binary operator {!r} in state position".format(e.e.op))

    if (not allow_conditional_state.value
            and not at_runtime
            and isinstance(e, ECond)):
        return No("conditional in state position")

    if isinstance(e, EMakeMap2):
        if isinstance(e.e, EEmptyList):
            return No("trivially empty map")
        if isinstance(e.e, ESingleton):
            return No("really tiny map")

    if not at_runtime and isinstance(e, (EArgMin, EArgMax)):
        # Cozy has no way to efficiently implement mins/maxes when more than
        # one element may leave the collection.
        from cozy.state_maintenance import mutate
        for op in ops:
            elems = e.e
            elems_prime = mutate(elems, op.body)
            removed = EBinOp(elems, "-", elems_prime).with_type(elems.type)
            formula = EAll(
                [assumptions]
                + list(op.assumptions)
                + [EGt(ELen(removed), ONE)])
            if solver.satisfiable(formula):
                return No("more than one element might be removed during {}".format(op.name))

    if not allow_peels.value and not at_runtime and isinstance(e, EFilter):
        # catch "peels": removal of zero or one elements
        rejected = EFilter(
            e.e,
            ELambda(e.predicate.arg, ENot(e.predicate.body))).with_type(e.type)
        if solver.valid(EImplies(assumptions, ELe(ELen(rejected), ONE))):
            return No("filter is a peel")

    if (not allow_big_maps.value
            and not at_runtime
            and isinstance(e, EMakeMap2)
            and is_collection(e.type.v)):
        # Reject the map unless the total number of stored values is provably
        # bounded by the combined size of all state collections.
        all_collections = [sv for sv in state_vars if is_collection(sv.type)]
        total_size = ENum(0).with_type(INT)
        for c in all_collections:
            c_size = EUnaryOp(UOp.Length, c).with_type(INT)
            total_size = EBinOp(total_size, "+", c_size).with_type(INT)
        all_values = EFlatMap(
            EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type),
            e.value_function).with_type(e.type.v)
        my_size = EUnaryOp(UOp.Length, all_values).with_type(INT)
        s = EImplies(
            assumptions,
            EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            return No("non-polynomial-sized map")

    return True
def build_lambdas(bag, pool, body_size):
    """Yield lambdas whose argument ranges over the elements of `bag`.

    A fresh argument variable is created with the element type of `bag`,
    omitting every variable already present in the enclosing `context`.
    Each body of size `body_size` enumerated under a binder for that
    variable is wrapped in an `ELambda` and yielded.
    """
    taken = set(var for var, _pool in context.vars())
    v = fresh_var(bag.type.elem_type, omit=taken)
    inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
    yield from (
        ELambda(v, lam_body)
        for lam_body in self.enumerate(inner_context, body_size, pool))
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Remove "bad" expressions from `s`.

    Returns a statement `new_s` in which each inspected expression of `s`
    that `exp_is_bad` flags has been replaced by a fresh variable.  On
    return, `decls_out` has been extended with the declarations of those
    fresh variables, so that the combined statement

        decls_out; new_s

    should produce the same result as `s`.

    Raises NotImplementedError for statement kinds that are not handled.
    """

    def hoist(expr: Exp) -> Exp:
        """Return `expr`, or a fresh variable declared in `decls_out` if it is bad."""
        if not exp_is_bad(expr):
            return expr
        tmp = fresh_var(expr.type)
        decls_out.append(SDecl(tmp, expr))
        return tmp

    if isinstance(s, SNoOp):
        return s

    if isinstance(s, SSeq):
        first = pull_temps(s.s1, decls_out, exp_is_bad)
        second = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(first, second)

    if isinstance(s, SDecl):
        return SDecl(s.var, hoist(s.val))

    if isinstance(s, SIf):
        new_cond = hoist(s.cond)
        new_then = pull_temps(s.then_branch, decls_out, exp_is_bad)
        new_else = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(new_cond, new_then, new_else)

    if isinstance(s, SForEach):
        new_iter = hoist(s.iter)
        body_decls = []
        new_body = pull_temps(s.body, body_decls, exp_is_bad)
        # Temporaries whose definition mentions the loop variable need
        # special treatment; the rest are emitted as they are.
        dependent, independent = partition(
            body_decls,
            lambda d: s.loop_var in free_vars(d.val))
        decls_out.extend(independent)
        for decl in dependent:
            tmp = decl.var
            # Precompute the value for every element of the iterated bag as
            # a map, then look the value up inside the loop body.
            table_type = TMap(s.loop_var.type, tmp.type)
            table = EMakeMap2(new_iter, ELambda(s.loop_var, decl.val)).with_type(table_type)
            table_var = fresh_var(table.type)
            decls_out.append(SDecl(table_var, table))
            lookup = EMapGet(table_var, s.loop_var).with_type(tmp.type)
            new_body = subst(new_body, {tmp.id: lookup})
        return SForEach(s.loop_var, new_iter, new_body)

    if isinstance(s, SAssign):
        return SAssign(s.lhs, hoist(s.rhs))

    if isinstance(s, SCall):
        new_args = tuple(hoist(arg) for arg in s.args)
        return SCall(s.target, s.func, new_args)

    if isinstance(s, SMapDel):
        return SMapDel(s.map, hoist(s.key))

    if isinstance(s, SMapPut):
        new_key = hoist(s.key)
        new_value = hoist(s.value)
        return SMapPut(s.map, new_key, new_value)

    if isinstance(s, SMapUpdate):
        new_key = hoist(s.key)
        change_decls = []
        new_change = pull_temps(s.change, change_decls, exp_is_bad)
        for decl in change_decls:
            if s.val_var in free_vars(decl.val):
                # Outside the update the value variable is not bound, so
                # refer to the map entry directly instead.
                entry = EMapGet(s.map, new_key).with_type(s.val_var.type)
                decl = SDecl(decl.var, subst(decl.val, {s.val_var.id: entry}))
            decls_out.append(decl)
        return SMapUpdate(s.map, new_key, s.val_var, new_change)

    raise NotImplementedError(s)