def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Replace "bad" expressions in `s` with fresh temporary variables.

    Each expression held directly by a statement in `s` for which
    `exp_is_bad` returns True is replaced by a fresh variable, and a
    declaration for that variable is appended to `decls_out`.  The intent is
    that running the declarations in `decls_out` followed by the returned
    statement gives the same result as running `s`.

    Arguments:
        s          : the statement to rewrite
        decls_out  : output list that receives an `SDecl` for every
                     temporary that is introduced
        exp_is_bad : predicate over expressions; True means "pull this
                     expression out into a temporary"

    Returns the rewritten statement.

    Raises NotImplementedError if `s` is a kind of statement that is not
    handled below.
    """

    def pull(e: Exp) -> Exp:
        """Return a fresh variable standing for `e` if `e` is bad, else `e`.

        When a variable is created, its declaration is appended to
        `decls_out`.
        """
        if exp_is_bad(e):
            v = fresh_var(e.type)
            decls_out.append(SDecl(v.id, e))
            return v
        return e

    if isinstance(s, SNoOp):
        return s

    if isinstance(s, SSeq):
        s1 = pull_temps(s.s1, decls_out, exp_is_bad)
        s2 = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(s1, s2)

    if isinstance(s, SDecl):
        # Declarations used to fall through to NotImplementedError.  Only the
        # initial value can contain a bad expression; the declared name is
        # kept as-is.
        return SDecl(s.id, pull(s.val))

    if isinstance(s, SIf):
        cond = pull(s.cond)
        s1 = pull_temps(s.then_branch, decls_out, exp_is_bad)
        s2 = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(cond, s1, s2)

    if isinstance(s, SForEach):
        bag = pull(s.iter)
        # Collect the body's temporaries separately: those that mention the
        # loop variable cannot simply be declared before the loop.
        d_tmp = []
        body = pull_temps(s.body, d_tmp, exp_is_bad)
        to_fix, ok = partition(d_tmp, lambda d: s.id in free_vars(d.val))
        decls_out.extend(ok)
        for d in to_fix:
            # Declare a map built from the bag with the temporary's value as
            # the per-element function, then make the body read the value
            # for the current loop variable out of that map.
            v = EVar(d.id).with_type(d.val.type)
            mt = TMap(s.id.type, v.type)
            m = EMakeMap2(bag, ELambda(s.id, d.val)).with_type(mt)
            mv = fresh_var(m.type)
            md = SDecl(mv.id, m)
            decls_out.append(md)
            body = subst(body, {v.id: EMapGet(mv, s.id).with_type(v.type)})
        return SForEach(s.id, bag, body)

    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))

    if isinstance(s, SCall):
        return SCall(s.target, s.func, tuple(pull(arg) for arg in s.args))

    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))

    if isinstance(s, SMapPut):
        return SMapPut(s.map, pull(s.key), pull(s.value))

    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        d_tmp = []
        change = pull_temps(s.change, d_tmp, exp_is_bad)
        for d in d_tmp:
            if s.val_var in free_vars(d.val):
                # `val_var` only exists inside the update, so a hoisted
                # declaration has to read the map entry explicitly instead.
                decls_out.append(
                    SDecl(
                        d.id,
                        subst(
                            d.val,
                            {s.val_var.id: EMapGet(s.map, key).with_type(s.val_var.type)})))
            else:
                decls_out.append(d)
        return SMapUpdate(s.map, key, s.val_var, change)

    raise NotImplementedError(s)
def simplify_and_optimize_lvalue(e : Exp) -> (Stm, Exp):
    """Optimize an L-value expression without destroying its L-value-ness.

    Helper for simplify_and_optimize.

    Input:
      e - an L-value expression to optimize

    Output:
      A pair (s, e') such that executing s and then evaluating e' is the same
      as evaluating the input expression.

    In contrast to `simplify_and_optimize_expression`, the result still
    denotes the same storage location as `e`; in particular `e` is never
    replaced with a fresh variable.

    Raises NotImplementedError for L-value forms not handled here.
    """
    assert is_lvalue(e), "not an L-value: {}".format(pprint(e))

    # A plain variable needs no setup at all.
    if isinstance(e, EVar):
        return (SNoOp(), e)

    if isinstance(e, EGetField):
        # A handle-typed base is optimized as an ordinary expression; any
        # other base is optimized recursively as an L-value.
        if isinstance(e.e.type, THandle):
            optimize_base = simplify_and_optimize_expression
        else:
            optimize_base = simplify_and_optimize_lvalue
        base_setup, base = optimize_base(e.e)
        field_access = EGetField(base, e.field_name).with_type(e.type)
        return (base_setup, field_access)

    if isinstance(e, EMapGet):
        # The map is an L-value, the key is an ordinary expression.
        map_setup, map_lvalue = simplify_and_optimize_lvalue(e.map)
        key_setup, key_exp = simplify_and_optimize_expression(e.key)
        lookup = EMapGet(map_lvalue, key_exp).with_type(e.type)
        return (seq([map_setup, key_setup]), lookup)

    if isinstance(e, EListGet):
        # The list is an L-value, the index is an ordinary expression.
        list_setup, list_lvalue = simplify_and_optimize_lvalue(e.e)
        index_setup, index_exp = simplify_and_optimize_expression(e.index)
        element = EListGet(list_lvalue, index_exp).with_type(e.type)
        return (seq([list_setup, index_setup]), element)

    raise NotImplementedError(repr(e))
def visit_SMapUpdate(self, update):
    """Emit code for an `SMapUpdate` statement.

    The update is emitted as: declare `update.val_var` from a lookup of
    `update.key` in `update.map`, emit `update.change`, then write the
    variable back with a `<map>.put(<key>, <val>);` statement.

    Returns "" without emitting anything when the change is an `SNoOp`;
    otherwise returns None.
    """
    if isinstance(update.change, SNoOp):
        return ""

    # TODO: liveness analysis to avoid this map lookup in some cases
    current_value = EMapGet(update.map, update.key).with_type(update.map.type.v)
    self.declare(update.val_var, current_value)

    # The map and key were already visited as part of the lookup above, so
    # they are visited a second time here.
    map_code = self.visit(update.map)  # TODO: deduplicate
    key_code = self.visit(update.key)  # TODO: deduplicate

    self.visit(update.change)

    # Store the (possibly modified) value back into the map.
    self.begin_statement()
    self.write("{map}.put({key}, {val});\n".format(
        map=map_code,
        key=key_code,
        val=update.val_var.id))
    self.end_statement()
def stream(iterable : Exp, loop_var : EVar, body : Stm) -> Stm:
    """Convert an iterable expression to a streaming operation.

    Input:
      iterable - an expression with an iterable type (Bag, Set, or List), not
        yet optimized
      loop_var - a variable to use as the loop variable
      body - a statement to run on that variable, not yet optimized

    Output:
      A statement equivalent to
        for (loop_var in iterable) { body }
      that eliminates as many intermediate collections and objects as
      possible.

    NOTE: The output of this function will not be correct if the body
    modifies any free variable in the iterable expression or writes to any
    pointers that are read by the iterable expression.

    Motivation: generating code for the expression

        Map {func} (Filter {predicate} big_collection)

    might create two new collections as large as `big_collection`: one to
    hold the result of the filter and one to hold the result of the map.  If
    the result is only ever iterated over, neither collection needs to exist.

    This function is mutually recursive with `simplify_and_optimize`, so any
    transformations performed by that method are also applied to the output
    of this one.
    """

    # Nothing to iterate over.
    if isinstance(iterable, EEmptyList):
        return SNoOp()

    # Exactly one element: bind it and run the body once.
    if isinstance(iterable, ESingleton):
        setup, value = simplify_and_optimize_expression(iterable.e)
        optimized_body = simplify_and_optimize(body)
        # SScoped because if the iterable is e.g. [x] + [y], then the body
        # might appear in the same block twice.  If the body declares any
        # variables, that will cause problems in languages like Java or C++.
        return seq([setup, SScoped(re_use(value, loop_var, optimized_body))])

    # Conditional collection: branch first, then stream whichever side applies.
    if isinstance(iterable, ECond):
        cond_setup, cond = simplify_and_optimize_expression(iterable.cond)
        on_true = stream(iterable.then_branch, loop_var, body)
        on_false = stream(iterable.else_branch, loop_var, body)
        return seq([cond_setup, SIf(cond, on_true, on_false)])

    # Distinct: remember the elements seen so far and skip repeats.
    if isinstance(iterable, EUnaryOp) and iterable.op == UOp.Distinct:
        seen = fresh_var(TSet(iterable.type.elem_type), "distinct_elems")
        declare_seen = SDecl(seen, EEmptyList().with_type(seen.type))
        unseen = ENot(EBinOp(loop_var, BOp.In, seen).with_type(BOOL))
        guarded_body = SIf(
            unseen,
            seq([body, SCall(seen, "add", [loop_var])]),
            SNoOp())
        return seq([declare_seen, stream(iterable.e, loop_var, guarded_body)])

    # Concatenation: stream the left operand, then the right one.
    if isinstance(iterable, EBinOp) and iterable.op == "+":
        left = stream(iterable.e1, loop_var, body)
        right = stream(iterable.e2, loop_var, body)
        return seq([left, right])

    # Subtraction: each element of the right operand cancels one matching
    # element of the left operand.
    if isinstance(iterable, EBinOp) and iterable.op == "-":
        if is_hashable(iterable.type.elem_type):
            # Hashable elements: count down in a histogram of the right side.
            h_setup, h = histogram(iterable.e2)
            val_ref = fresh_var(INT, "count")
            optimized_setup = simplify_and_optimize(h_setup)
            decrement = SAssign(val_ref, EBinOp(val_ref, "-", ONE).with_type(INT))
            guarded_body = SIf(
                EGt(EMapGet(h, loop_var).with_type(INT), ZERO),
                SMapUpdate(h, loop_var, val_ref, decrement),
                body)
            return seq([
                optimized_setup,
                stream(iterable.e1, loop_var, guarded_body)])

        # Otherwise: materialize the right side and remove elements from it.
        rhs = fresh_var(iterable.e2.type, "bag_subtraction_right")
        declare_rhs = simplify_and_optimize(SDecl(rhs, iterable.e2))
        guarded_body = SIf(
            EIn(loop_var, rhs),
            SCall(rhs, "remove", (loop_var,)),
            body)
        return seq([
            declare_rhs,
            stream(iterable.e1, loop_var, guarded_body)])

    # Filter is expressed as a FlatMap producing [x] or [] per element.
    if isinstance(iterable, EFilter):
        predicate = iterable.predicate
        result_type = iterable.type
        keep_or_drop = ECond(
            predicate.body,
            ESingleton(predicate.arg).with_type(result_type),
            EEmptyList().with_type(result_type)).with_type(result_type)
        as_flatmap = EFlatMap(
            iterable.e,
            ELambda(predicate.arg, keep_or_drop)).with_type(result_type)
        return stream(as_flatmap, loop_var, body)

    # Map is expressed as a FlatMap producing [f(x)] per element.
    if isinstance(iterable, EMap):
        transform = iterable.transform_function
        result_type = iterable.type
        as_flatmap = EFlatMap(
            iterable.e,
            ELambda(
                transform.arg,
                ESingleton(transform.body).with_type(result_type))).with_type(result_type)
        return stream(as_flatmap, loop_var, body)

    # FlatMap becomes two nested streams.
    if isinstance(iterable, EFlatMap):
        transform = iterable.transform_function
        inner_loop_var = fresh_var(transform.arg.type, transform.arg.id)
        inner_stream = stream(transform.apply_to(inner_loop_var), loop_var, body)
        return stream(iterable.e, inner_loop_var, inner_stream)

    # List slice: an index-based while loop over the clamped range.
    if isinstance(iterable, EListSlice):
        elem_type = iterable.type.elem_type
        list_var = fresh_var(iterable.e.type, "list")
        start_var = fresh_var(INT, "start")
        end_var = fresh_var(INT, "end")
        loop = SWhile(ELt(start_var, end_var), seq([
            SDecl(loop_var, EListGet(list_var, start_var).with_type(elem_type)),
            body,
            SAssign(start_var, EBinOp(start_var, "+", ONE).with_type(INT))]))
        return simplify_and_optimize(seq([
            SDecl(list_var, iterable.e),
            SDecl(start_var, max_of(iterable.start, ZERO)),
            SDecl(end_var, min_of(iterable.end, ELen(list_var))),
            loop]))

    # Let: declare the bound value, then stream the let-body.
    if isinstance(iterable, ELet):
        body_function = iterable.body_function
        bound = fresh_var(body_function.arg.type, body_function.arg.id)
        declare_bound = simplify_and_optimize(SDecl(bound, iterable.e))
        return seq([
            declare_bound,
            stream(body_function.apply_to(bound), loop_var, body)])

    # EMove is looked through; only the wrapped expression is streamed.
    if isinstance(iterable, EMove):
        return stream(iterable.e, loop_var, body)

    # Fallback: an ordinary for-each loop over the optimized collection.
    assert is_collection(iterable.type), repr(iterable)
    setup, collection = simplify_and_optimize_expression(iterable)
    return seq([setup, SForEach(loop_var, collection, simplify_and_optimize(body))])
def _enumerate_core(self, context: Context, size: int, pool: Pool) -> [Exp]:
    """Generate new expressions of exactly the given size.

    Arguments:
        context : a Context object describing the vars in scope
        size    : size of expressions to enumerate; each expression in
                  the output will have this size
        pool    : pool to enumerate

    This function is not cached.  Clients should call `enumerate` instead.

    The body is meant to read as a plain description of the Cozy grammar.
    It makes no attempt to deduplicate its output; deduplication is handled
    elsewhere (by equivalence class deduplication).
    """

    if size < 0:
        return

    if size == 0:
        # Size-0 expressions: literals, default values of interesting types,
        # in-scope variables of this pool, and applicable hints.
        yield from LITERALS

        interesting_types = OrderedSet(self.hint_types)
        for var, _ in context.vars():
            interesting_types |= all_types(var.type)
        for typ in interesting_types:
            default_value = construct_value(typ)
            if default_value not in LITERALS:
                yield default_value

        for var, var_pool in context.vars():
            if var_pool == pool:
                yield var

        for hint, hint_ctx, hint_pool in self.hints:
            if hint_pool == pool:
                fvs = free_vars(hint)
                if hint_ctx.alpha_equivalent(context.generalize(fvs)):
                    yield context.adapt(hint, hint_ctx, e_fvs=fvs)
        return

    if not do_enumerate.value:
        return

    def build_lambdas(bag, pool, body_size):
        """Yield lambdas over the elements of `bag` with bodies of `body_size`."""
        elem_type = bag.type.elem_type
        taken = set(var for var, _ in context.vars())
        arg = fresh_var(elem_type, omit=taken)
        inner_context = UnderBinder(context, v=arg, bag=bag, bag_pool=pool)
        for lam_body in self.enumerate(inner_context, body_size, pool):
            yield ELambda(arg, lam_body)

    # Load all smaller expressions in this context and pool.
    # by_size[S] contains expressions of size S in this context and pool.
    by_size = [list(self.enumerate(context, sz, pool)) for sz in range(size)]

    # Enable use of a state-pool expression at runtime
    if pool == RUNTIME_POOL:
        for state_exp in self.enumerate(context.root(), size - 1, STATE_POOL):
            yield EStateVar(state_exp).with_type(state_exp.type)

    # Arity-1 expressions
    for arg in by_size[size - 1]:
        arg_type = arg.type

        if is_collection(arg_type):
            elem_type = arg_type.elem_type

            # This method of generating EEmptyList() ensures that we visit
            # empty collections of all possible types.
            yield EEmptyList().with_type(arg_type)

            if is_numeric(elem_type):
                yield EUnaryOp(UOp.Sum, arg).with_type(elem_type)

            yield EUnaryOp(UOp.Length, arg).with_type(INT)
            yield EUnaryOp(UOp.Empty, arg).with_type(BOOL)
            yield EUnaryOp(UOp.Exists, arg).with_type(BOOL)
            yield EUnaryOp(UOp.The, arg).with_type(elem_type)
            yield EUnaryOp(UOp.Distinct, arg).with_type(arg_type)
            yield EUnaryOp(UOp.AreUnique, arg).with_type(BOOL)

            if elem_type == BOOL:
                for quantifier in (UOp.Any, UOp.All):
                    yield EUnaryOp(quantifier, arg).with_type(BOOL)

        yield ESingleton(arg).with_type(TBag(arg_type))

        if isinstance(arg_type, TRecord):
            for field_name, field_type in arg_type.fields:
                yield EGetField(arg, field_name).with_type(field_type)

        if isinstance(arg_type, THandle):
            yield EGetField(arg, "val").with_type(arg_type.value_type)

        if isinstance(arg_type, TTuple):
            for position, component_type in enumerate(arg_type.ts):
                yield ETupleGet(arg, position).with_type(component_type)

        if arg_type == BOOL:
            yield EUnaryOp(UOp.Not, arg).with_type(BOOL)

        if is_numeric(arg_type):
            yield EUnaryOp("-", arg).with_type(arg_type)

        if isinstance(arg_type, TMap):
            yield EMapKeys(arg).with_type(TBag(arg_type.k))

    # Arity-2 expressions
    for sz1, sz2 in pick_to_sum(2, size - 1):
        # sz1 + sz2 = size - 1
        for e1 in by_size[sz1]:
            t = e1.type

            if is_numeric(t):
                for rhs in of_type(by_size[sz2], t):
                    for op in ("+", "-"):
                        yield EBinOp(e1, op, rhs).with_type(t)

            if is_ordered(t):
                for rhs in of_type(by_size[sz2], t):
                    for op in (">", "<", ">=", "<="):
                        yield EBinOp(e1, op, rhs).with_type(BOOL)

            if t == BOOL:
                for rhs in of_type(by_size[sz2], BOOL):
                    for op in (BOp.And, BOp.Or):
                        yield EBinOp(e1, op, rhs).with_type(BOOL)
                    # Cozy supports the implication operator "=>", but this
                    # function does not enumerate it because
                    #  - (a => b) is equivalent to ((not a) or b)
                    #  - there isn't an implication operator in any of our
                    #    current target languages, so we would need to
                    #    desugar it to ((not a) or b) anyway.

            if not isinstance(t, TMap):
                for rhs in of_type(by_size[sz2], t):
                    yield EEq(e1, rhs)
                    yield EBinOp(e1, "!=", rhs).with_type(BOOL)

            if isinstance(t, TMap):
                for key in of_type(by_size[sz2], t.k):
                    yield EMapGet(e1, key).with_type(t.v)
                    yield EHasKey(e1, key).with_type(BOOL)

            if isinstance(t, TList):
                for index in of_type(by_size[sz2], INT):
                    yield EListGet(e1, index).with_type(e1.type.elem_type)

            if is_collection(t):
                elem_type = t.elem_type

                for other in of_type(by_size[sz2], t):
                    for op in ("+", "-"):
                        yield EBinOp(e1, op, other).with_type(t)

                for member in of_type(by_size[sz2], elem_type):
                    yield EBinOp(member, BOp.In, e1).with_type(BOOL)

                for lam in build_lambdas(e1, pool, sz2):
                    body_type = lam.body.type
                    yield EMap(e1, lam).with_type(TBag(body_type))
                    if body_type == BOOL:
                        yield EFilter(e1, lam).with_type(t)
                    if is_numeric(body_type):
                        yield EArgMin(e1, lam).with_type(elem_type)
                        yield EArgMax(e1, lam).with_type(elem_type)
                    if is_collection(body_type):
                        yield EFlatMap(e1, lam).with_type(
                            TBag(body_type.elem_type))
                    if pool == STATE_POOL and is_hashable(elem_type):
                        yield EMakeMap2(e1, lam).with_type(
                            TMap(elem_type, body_type))

            # Let-expressions: the bound value is presented to
            # build_lambdas as a singleton bag.
            e1_singleton = ESingleton(e1).with_type(TBag(e1.type))
            for lam in build_lambdas(e1_singleton, pool, sz2):
                yield ELet(e1, lam).with_type(lam.body.type)

    # Arity-3 expressions
    for sz1, sz2, sz3 in pick_to_sum(3, size - 1):
        # sz1 + sz2 + sz3 = size - 1
        for e1 in by_size[sz1]:
            if e1.type == BOOL:
                cond = e1
                for then_branch in by_size[sz2]:
                    branch_type = then_branch.type
                    for else_branch in of_type(by_size[sz3], branch_type):
                        yield ECond(cond, then_branch, else_branch).with_type(
                            then_branch.type)

            if isinstance(e1.type, TList):
                for start in of_type(by_size[sz2], INT):
                    for end in of_type(by_size[sz3], INT):
                        yield EListSlice(e1, start, end).with_type(e1.type)
                # It is not necessary to create slice expressions of
                # the form a[:i] or a[i:].  Those are desugared
                # after parsing to a[0:i] and a[i:len(a)]
                # respectively, and Cozy is perfectly capable of
                # discovering these expanded forms as well.

    # Finally, let every registered extension handler contribute.
    for handler in all_extension_handlers():
        yield from handler.enumerate(context, size, pool, self.enumerate, build_lambdas)
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Remove "bad" expressions from `s`.

    Returns a statement `new_s` in which every expression of `s` that
    `exp_is_bad` flags has been replaced by a fresh variable.  On return,
    `decls_out` has been extended with definitions for those fresh variables,
    so that the whole statement

        decls_out; new_s

    should return the same result as `s`.

    Raises NotImplementedError for statement kinds not handled here.
    """

    def pull(e: Exp) -> Exp:
        """Pull `e` into a temporary if it is bad.

        For a bad expression, creates a fresh variable, appends its
        declaration to `decls_out`, and returns the variable.  Any other
        expression is returned unchanged.
        """
        if not exp_is_bad(e):
            return e
        tmp = fresh_var(e.type)
        decls_out.append(SDecl(tmp, e))
        return tmp

    def recurse(stm, sink):
        """Apply pull_temps to a sub-statement, collecting decls in `sink`."""
        return pull_temps(stm, sink, exp_is_bad)

    if isinstance(s, SNoOp):
        return s

    if isinstance(s, SSeq):
        first = recurse(s.s1, decls_out)
        second = recurse(s.s2, decls_out)
        return SSeq(first, second)

    if isinstance(s, SDecl):
        return SDecl(s.var, pull(s.val))

    if isinstance(s, SIf):
        cond = pull(s.cond)
        then_branch = recurse(s.then_branch, decls_out)
        else_branch = recurse(s.else_branch, decls_out)
        return SIf(cond, then_branch, else_branch)

    if isinstance(s, SForEach):
        loop_var = s.loop_var
        bag = pull(s.iter)

        # Temporaries from the body are collected separately: the ones that
        # mention the loop variable cannot be declared before the loop as-is.
        body_decls = []
        body = recurse(s.body, body_decls)
        dependent, independent = partition(
            body_decls,
            lambda d: loop_var in free_vars(d.val))
        decls_out.extend(independent)

        for decl in dependent:
            # Declare a map built from the bag with the temporary's value as
            # the per-element function, and have the body read the value for
            # the current loop variable out of that map.
            tmp = decl.var
            table_type = TMap(loop_var.type, tmp.type)
            table = EMakeMap2(bag, ELambda(loop_var, decl.val)).with_type(table_type)
            table_var = fresh_var(table.type)
            decls_out.append(SDecl(table_var, table))
            lookup = EMapGet(table_var, loop_var).with_type(tmp.type)
            body = subst(body, {tmp.id: lookup})

        return SForEach(loop_var, bag, body)

    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))

    if isinstance(s, SCall):
        pulled_args = tuple(pull(arg) for arg in s.args)
        return SCall(s.target, s.func, pulled_args)

    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))

    if isinstance(s, SMapPut):
        key = pull(s.key)
        value = pull(s.value)
        return SMapPut(s.map, key, value)

    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        change_decls = []
        change = recurse(s.change, change_decls)
        for decl in change_decls:
            if s.val_var not in free_vars(decl.val):
                decls_out.append(decl)
                continue
            # `val_var` only exists inside the update, so a hoisted
            # declaration has to read the map entry explicitly instead.
            current_value = EMapGet(s.map, key).with_type(s.val_var.type)
            hoisted = subst(decl.val, {s.val_var.id: current_value})
            decls_out.append(SDecl(decl.var, hoisted))
        return SMapUpdate(s.map, key, s.val_var, change)

    raise NotImplementedError(s)