def native_map_get(self, e, default_value):
    """Return code for the map lookup `e` (an EMapGet), defaulting on a miss.

    Three strategies are used, chosen by the map's representation:
      * INT keys: the map is an array; an out-of-bounds read produces the
        default value of the map's value type.
      * Trove maps: Trove returns the "no entry" value configured at
        construction time. The exception is Object-valued maps whose
        default value is not null; those need an explicit null check.
      * Other maps: `getOrDefault` with the default value of `e.type`.

    `default_value` is accepted for interface compatibility; this
    implementation does not use it.
    """
    if e.key.type == INT:
        return self.visit(ECond(
            self.array_in_bounds(e.map.type.v, e.map, e.key),
            self.array_get(e.map.type.v, e.map, e.key),
            evaluation.construct_value(e.map.type.v)).with_type(e.map.type.v))
    if self.use_trove(e.map.type):
        if self.trovename(e.map.type.v) == "Object" and not isinstance(evaluation.construct_value(e.map.type.v), ENull):
            # Le sigh... Trove reports a missing key as null for Object
            # values, but the Cozy default here is not null, so the result
            # must be fetched into a temporary and null-checked.
            emap = self.visit(e.map)
            ekey = self.visit(e.key)
            v = self.fv(e.map.type.v, hint="v")
            with self.boxed_mode():
                # This visitor emits statements as a side effect of visiting
                # them (as visit_EMapGet does); the previous code tried to
                # concatenate undefined `smap`/`skey` strings, a NameError.
                self.visit(SDecl(v, EEscape("{emap}.get({ekey})".format(emap=emap, ekey=ekey), [], []).with_type(e.type)))
            return self.visit(ECond(
                EEq(v, ENull().with_type(v.type)),
                evaluation.construct_value(e.map.type.v),
                v).with_type(e.type))
        else:
            # For Trove, defaults are set at construction time
            emap = self.visit(e.map)
            ekey = self.visit(e.key)
            return "{emap}.get({ekey})".format(emap=emap, ekey=ekey)
    else:
        emap = self.visit(e.map)
        ekey = self.visit(e.key)
        edefault = self.visit(evaluation.construct_value(e.type))
        return "{emap}.getOrDefault({ekey}, {edefault})".format(emap=emap, ekey=ekey, edefault=edefault)
def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar) -> Stm:
    """Return statements that compute the heap-related expression `e` into `out`.

    In this representation a heap is a bare array. `EHeapPeek2` takes the
    number of elements from its own `n` field.

    Handles EMakeMinHeap/EMakeMaxHeap (build), EHeapElems (enumerate) and
    EHeapPeek2 (runner-up element). Raises NotImplementedError for
    EHeapPeek and for any other expression.
    """
    if isinstance(e, EMakeMinHeap) or isinstance(e, EMakeMaxHeap):
        # View `out` at its concrete (array) representation type, allocate
        # an array sized to the input collection, then insert every element.
        out_raw = EVar(out.id).with_type(self.rep_type(e.type))
        l = fresh_var(INT, "alloc_len")
        return seq([
            SDecl(l.id, ELen(e.e)),
            SArrayAlloc(out_raw, l),
            SCall(out, "add_all", (ZERO, e.e))])
    elif isinstance(e, EHeapElems):
        elem_type = e.type.t
        if isinstance(e.e, EMakeMinHeap) or isinstance(e.e, EMakeMaxHeap):
            # The elements of a freshly-built heap are its source collection.
            x = fresh_var(elem_type, "x")
            return SForEach(x, e.e.e, SCall(out, "add", (x,)))
        # Otherwise walk the backing array from index 0 to its length.
        i = fresh_var(INT, "i")
        return seq([
            SDecl(i.id, ZERO),
            SWhile(ELt(i, EArrayLen(e.e).with_type(INT)), seq([
                SCall(out, "add", (EArrayGet(e.e, i).with_type(elem_type),)),
                SAssign(i, EBinOp(i, "+", ONE).with_type(INT))]))])
    elif isinstance(e, EHeapPeek):
        raise NotImplementedError()
    elif isinstance(e, EHeapPeek2):
        from cozy.evaluation import construct_value
        best = EArgMin if isinstance(e.e.type, TMinHeap) else EArgMax
        f = heap_func(e.e, concretization_functions)
        # Size 0 or 1: there is no second element, so use the default.
        # Size 2: the second element is slot 1.
        # Otherwise: the better of slots 1 and 2 under the heap's key.
        return SSwitch(e.n, (
            (ZERO, SAssign(out, construct_value(e.type))),
            (ONE, SAssign(out, construct_value(e.type))),
            (TWO, SAssign(out, EArrayGet(e.e, ONE).with_type(e.type)))),
            SAssign(out, best(
                EBinOp(
                    ESingleton(EArrayGet(e.e, ONE).with_type(e.type)).with_type(TBag(out.type)),
                    "+",
                    ESingleton(EArrayGet(e.e, TWO).with_type(e.type)).with_type(TBag(out.type))).with_type(TBag(out.type)),
                f).with_type(out.type)))
    else:
        raise NotImplementedError(e)
def visit_EMapGet(self, e, indent=""):
    """Generate code for the map lookup `e.map[e.key]`.

    EStateVar wrappers are looked through. Lookups into an EMakeMap2
    literal or into a conditional map are rewritten into ELet/ECond
    expressions that bind the key once. Lookups on map variables go to
    `native_map_get` for native/plain maps, or through the map type's
    own `get_key`. Any other map form raises NotImplementedError.
    """
    m = e.map
    if isinstance(m, EStateVar):
        return self.visit(EMapGet(m.e, e.key).with_type(e.type), indent=indent)

    if isinstance(m, EMakeMap2):
        def lookup(k):
            # k's value if k is one of the map's keys, else the default.
            return ECond(
                EBinOp(k, BOp.In, m.e).with_type(BOOL),
                m.value.apply_to(k),
                evaluation.construct_value(m.type.v)).with_type(e.type)
        bound = ELet(e.key, mk_lambda(e.key.type, lookup)).with_type(e.type)
        return self.visit(bound, indent=indent)

    if isinstance(m, ECond):
        def lookup(k):
            # Push the lookup into both branches of the conditional map.
            return ECond(
                m.cond,
                EMapGet(m.then_branch, k).with_type(e.type),
                EMapGet(m.else_branch, k).with_type(e.type)).with_type(e.type)
        bound = ELet(e.key, mk_lambda(e.key.type, lookup)).with_type(e.type)
        return self.visit(bound, indent=indent)

    if isinstance(m, EVar):
        map_type = m.type
        if isinstance(map_type, library.TNativeMap) or type(map_type) is TMap:
            value_type = map_type.v
            return self.native_map_get(
                e,
                lambda out: self.construct_concrete(value_type, evaluation.construct_value(value_type), out),
                indent)
        return self.visit(map_type.get_key(m, e.key), indent)

    raise NotImplementedError(type(m))
def codegen(self, e : Exp, concretization_functions : { str : Exp }, out : EVar) -> Stm:
    """Return statements that write the result of `e` to `out`.

    The returned statements declare `out` themselves; the caller must not
    declare it. `concretization_functions` describes the invariants of the
    variables that appear in `e`.

    In this representation a heap is a tuple: component 0 is the element
    count and component 1 is the backing array.
    """
    if isinstance(e, (EMakeMinHeap, EMakeMaxHeap)):
        assert out.type == self.rep_type(e.type)
        elem_type = e.type.elem_type
        concretizations = dict(concretization_functions)
        concretizations[out.id] = e
        # `out` viewed at the abstract heap type, so that `add_all` can be
        # implemented against the heap's specification.
        abstract_heap = EVar(out.id).with_type(e.type)
        storage = fresh_var(TArray(elem_type), "heap_elems")
        allocate = SArrayAlloc(storage, ZERO)
        declare = SDecl(out, ETuple((ZERO, storage)).with_type(out.type))
        fill = self.implement_stmt(SCall(abstract_heap, "add_all", (ZERO, e.e)), concretizations)
        return seq([allocate, declare, fill])

    if isinstance(e, EHeapElems):
        elem_type = e.type.elem_type
        if isinstance(e.e, (EMakeMinHeap, EMakeMaxHeap)):
            # The elements of a freshly-built heap are its source collection.
            x = fresh_var(elem_type, "x")
            return seq([
                SDecl(out, EEmptyList().with_type(out.type)),
                SForEach(x, e.e.e, SCall(out, "add", (x,)))])
        i = fresh_var(INT, "i")  # index into the backing array
        count = ETupleGet(e.e, 0).with_type(INT)
        current = EArrayGet(ETupleGet(e.e, 1), i).with_type(elem_type)
        step = seq([
            SCall(out, "add", (current,)),
            SAssign(i, EBinOp(i, "+", ONE).with_type(INT))])
        return seq([
            SDecl(out, EEmptyList().with_type(out.type)),
            SDecl(i, ZERO),
            SWhile(ELt(i, count), step)])

    if isinstance(e, EHeapPeek):
        raise NotImplementedError()

    if isinstance(e, EHeapPeek2):
        from cozy.evaluation import construct_value
        best = EArgMin if isinstance(e.e.type, TMinHeap) else EArgMax
        f = heap_func(e.e, concretization_functions)
        bag_type = TBag(out.type)

        def slot(index):
            # A fresh node reading element `index` of the backing array.
            return EArrayGet(ETupleGet(e.e, 1), index).with_type(e.type)

        cases = (
            (ZERO, SAssign(out, construct_value(e.type))),
            (ONE, SAssign(out, construct_value(e.type))),
            (TWO, SAssign(out, slot(ONE))))
        # Three or more elements: take the better of slots 1 and 2
        # (presumably the root's children in a binary-heap layout).
        pair = EBinOp(
            ESingleton(slot(ONE)).with_type(bag_type),
            "+",
            ESingleton(slot(TWO)).with_type(bag_type)).with_type(bag_type)
        fallback = SAssign(out, best(pair, f).with_type(out.type))
        return seq([
            SDecl(out, construct_value(out.type)),
            SSwitch(ETupleGet(e.e, 0), cases, fallback)])

    raise NotImplementedError(e)
def test_mapget_of_makemap(self):
    """A lookup into an EMakeMap2 equals an explicit membership conditional.

    The expected form applies the map's value function to "the" key in the
    key bag that equals the looked-up key. It falls back to the default
    value when the key is absent.
    """
    handle_t = THandle("T", INT)
    pair_t = TTuple((INT, INT))
    xs = EVar("xs").with_type(TBag(handle_t))
    x = EVar("x").with_type(handle_t)
    y = EVar("y").with_type(handle_t)
    value_fn = ELambda(x, ETuple((
        EGetField(x, "val").with_type(INT),
        EGetField(y, "val").with_type(INT))).with_type(pair_t))
    lookup = EMapGet(
        EMakeMap2(xs, value_fn).with_type(TMap(handle_t, pair_t)),
        y).with_type(pair_t)

    keys = lookup.map.e
    key = lookup.key
    fn = lookup.map.value
    key_present = EIn(key, keys)
    matching = EFilter(
        keys,
        mk_lambda(fn.arg.type, lambda candidate: EEq(candidate, key))).with_type(keys.type)
    stored_key = EUnaryOp(UOp.The, matching).with_type(keys.type.t)
    expected = ECond(
        key_present,
        fn.apply_to(stored_key),
        construct_value(lookup.type)).with_type(lookup.type)
    self.assert_same(lookup, expected)
def visit_EListGet(self, e):
    """Render `e.e[e.index]` as a bounds-checked conditional expression.

    An out-of-range index yields the default value of `e.type`, which is
    spliced in through the `{default}` escape placeholder. The rendered
    list and index code appear several times in the output text.
    """
    lst = self.visit(e.e)
    idx = self.visit(e.index)
    code = "".join([
        "(", idx, " >= 0 && ", idx, " < ", lst, ".size()) ? ",
        lst, "[", idx, "] : {default}"])
    fallback = evaluation.construct_value(e.type)
    return self.visit(EEscape(code, ("default",), (fallback,)).with_type(e.type))
def optimize_the(xs, args):
    """Lazily yield candidate replacements for `EUnaryOp(UOp.The, xs)`.

    Structural rewrites are yielded first:
      * conditionals: push `The` into both branches;
      * state variables: compute the whole `The` in state;
      * lists with one element excluded: first element, or the second one
        when the first is the excluded element;
      * maps: apply the map function to `The` of the source, guarded by
        non-emptiness;
      * concatenations: `The` of the left operand if it is non-empty,
        otherwise of the right.
    The last candidate is always the unmodified `The(xs)`.
    """
    t = xs.type.t
    if isinstance(xs, ECond):
        yield from (
            optimized_cond(xs.cond, then_the, else_the)
            for then_the in optimize_the(xs.then_branch, args)
            for else_the in optimize_the(xs.else_branch, args))
    if isinstance(xs, EStateVar):
        yield EStateVar(EUnaryOp(UOp.The, xs.e).with_type(t)).with_type(t)
    if isinstance(xs.type, TList):
        excluded = excluded_element(xs, args)
        if excluded is not None:
            bag, removed = excluded
            for head in optimize_the(bag, args):
                yield optimized_cond(EEq(head, removed), EListGet(bag, ONE).with_type(t), head)
    if isinstance(xs, EMap):
        exists = optimized_exists(xs.e)
        for inner in optimize_the(xs.e, args):
            yield optimized_cond(exists, xs.f.apply_to(inner), construct_value(t))
    if isinstance(xs, EBinOp) and xs.op == "+":
        left_nonempty = optimized_exists(xs.e1)
        yield from (
            optimized_cond(left_nonempty, left, right)
            for left in optimize_the(xs.e1, args)
            for right in optimize_the(xs.e2, args))
    yield EUnaryOp(UOp.The, xs).with_type(t)
def watch(self, new_target):
    """Switch to `new_target` and rebuild the enumeration roots for it.

    Side effects:
      * resets `self.backlog_counter` to 0 and sets `self.target`;
      * replaces `self.roots` with a list of `(expression, pool)` pairs,
        sorted by expression size, smallest first;
      * replaces `self._watches` with the fragments of `new_target`
        (from `enumerate_fragments2`), grouped by `(pool, type)`, each
        group ordered largest fragment first.
    """
    print("watching new target...")
    self.backlog_counter = 0
    self.target = new_target
    self.roots = OrderedSet()
    types = OrderedSet()
    # Candidate roots: every sub-expression of the target and of each hint.
    for e in itertools.chain(all_exps(new_target), *[all_exps(h) for h in self.hints]):
        # Lambdas are never used as roots.
        if isinstance(e, ELambda):
            continue
        for pool in ALL_POOLS:
            exp = e
            # In the state pool, EStateVar wrappers are removed before the
            # legality check.
            if pool == STATE_POOL:
                exp = strip_EStateVar(e)
            fvs = free_vars(exp)
            if all(v in self.legal_free_vars for v in fvs) and self.is_legal_in_pool(exp, pool):
                _on_exp(exp, "new root", pool_name(pool))
                exp._root = True
                self.roots.add((exp, pool))
                # A state expression over state variables only is also
                # offered to the runtime pool, wrapped in EStateVar.
                if pool == STATE_POOL and all(v in self.state_vars for v in fvs):
                    self.roots.add((EStateVar(exp).with_type(exp.type), RUNTIME_POOL))
                types.add(exp.type)
            else:
                _on_exp(exp, "rejected root", pool_name(pool))
    # Every seen type (and every binder type) also gets its default value
    # as a runtime root.
    for b in self.binders:
        types.add(b.type)
    for t in types:
        self.roots.add((construct_value(t), RUNTIME_POOL))
    self.roots = list(self.roots)
    self.roots.sort(key = lambda tup: tup[0].size())
    self._watches = group_by(
        enumerate_fragments2(new_target),
        k=lambda ctx: (ctx.pool, ctx.e.type),
        v=lambda ctxs: sorted(ctxs, key=lambda ctx: -ctx.e.size()))
    print("done!")
def visit_EListGet(self, e):
    """Render a list element read with an inline bounds check.

    Produces `(i >= 0 && i < l.size()) ? l[i] : <default>`, where the
    default is the default value of `e.type`, substituted through the
    `{default}` escape placeholder.
    """
    list_code = self.visit(e.e)
    index_code = self.visit(e.index)
    guarded = "({i} >= 0 && {i} < {l}.size()) ? {l}[{i}] : ".format(
        i=index_code, l=list_code) + "{default}"
    escape = EEscape(
        guarded,
        ("default", ),
        (evaluation.construct_value(e.type), ))
    return self.visit(escape.with_type(e.type))
def visit_ETreeMultisetPeek(self, e):
    """Generate code for the element at position `e.index` of multiset `e.e`.

    The element is reached with `std::next` from `begin()`. When the index
    is not below `size()`, the default value of `e.type` is produced
    instead. `e.e` and `e.index` are each rendered twice (bounds check and
    access).
    """
    # The size() escape is typed explicitly, like every other EEscape in
    # this visitor; it used to be the only untyped operand passed to ELt.
    size = EEscape("{xs}.size()", ("xs", ), (e.e, )).with_type(INT)
    # NOTE(review): there is no explicit `>= 0` check. If the index is a
    # signed C++ int, the usual arithmetic conversions against size()'s
    # unsigned type turn a negative index into a huge value, so it fails
    # this test -- confirm the generated index type.
    in_bounds = ELt(e.index, size).with_type(BOOL)
    element = EEscape(
        "*std::next({xs}.begin(), {i})",
        ("xs", "i"),
        (e.e, e.index)).with_type(e.type)
    return self.visit(
        ECond(in_bounds, element, evaluation.construct_value(e.type)).with_type(e.type))
def value_at(m, k):
    """Build an AST node for the map lookup `m[k]`.

    EMakeMap2 literals become `k in keys ? value(k) : default`, and
    conditional maps push the lookup into both branches. Anything else is
    a plain EMapGet.
    """
    if isinstance(m, target_syntax.EMakeMap2):
        present = syntax.EIn(k, m.e)
        return syntax.ECond(
            present,
            m.value.apply_to(k),
            construct_value(m.type.v)).with_type(m.type.v)
    if isinstance(m, syntax.ECond):
        then_value = value_at(m.then_branch, k)
        else_value = value_at(m.else_branch, k)
        return syntax.ECond(m.cond, then_value, else_value).with_type(m.type.v)
    return target_syntax.EMapGet(m, k).with_type(m.type.v)
def native_map_get(self, e, default_value, indent=""):
    """Return `(statements, expression)` code for the map lookup `e`.

    INT-keyed maps are arrays with a bounds check. Trove maps use `get`,
    relying on the "no entry" value set at construction. Object-valued
    Trove maps whose default is not null instead fetch into a boxed
    temporary and null-check it. All other maps use `getOrDefault` with a
    temporary initialised by `default_value(temp)`.
    """
    value_type = e.map.type.v
    if e.key.type == INT:
        bounded = ECond(
            self.array_in_bounds(value_type, e.map, e.key),
            self.array_get(value_type, e.map, e.key),
            evaluation.construct_value(value_type)).with_type(value_type)
        return self.visit(bounded, indent=indent)

    if not self.use_trove(e.map.type):
        (smap, emap) = self.visit(e.map, indent)
        (skey, ekey) = self.visit(e.key, indent)
        fallback = self.fv(e.type, "lookup_result")
        sfallback = indent + self.visit(fallback.type, fallback.id) + ";\n"
        sfallback += self.visit(default_value(fallback), indent)
        lookup = "{emap}.getOrDefault({ekey}, {edefault})".format(
            emap=emap, ekey=ekey, edefault=fallback.id)
        return (smap + skey + sfallback, lookup)

    if self.trovename(value_type) == "Object" and not isinstance(
            evaluation.construct_value(value_type), ENull):
        # Le sigh... a miss comes back as null, but the Cozy default is not.
        (smap, emap) = self.visit(e.map, indent)
        (skey, ekey) = self.visit(e.key, indent)
        boxed = self.fv(value_type, hint="v")
        with self.boxed_mode():
            fetch = EEscape(
                "{emap}.get({ekey})".format(emap=emap, ekey=ekey),
                [], []).with_type(e.type)
            decl = self.visit(SDecl(boxed.id, fetch), indent=indent)
        checked = ECond(
            EEq(boxed, ENull().with_type(boxed.type)),
            evaluation.construct_value(value_type),
            boxed).with_type(e.type)
        s_result, e_result = self.visit(checked, indent=indent)
        return (smap + skey + decl + s_result, e_result)

    # For Trove, defaults are set at construction time
    (smap, emap) = self.visit(e.map, indent)
    (skey, ekey) = self.visit(e.key, indent)
    return (smap + skey, "{emap}.get({ekey})".format(emap=emap, ekey=ekey))
def find_one(self, iterable, indent=""):
    """Return `(code, name)` that stores some element of `iterable` in a variable.

    The variable starts as the default value of the element type. A
    labelled loop assigns the first element it visits and then escapes the
    block, so an empty iterable leaves the default in place.
    """
    elem_type = iterable.type.t
    result = self.fv(elem_type, "v")
    label = fresh_name("label")
    cursor = self.fv(elem_type, "x")
    take_and_stop = seq([SAssign(result, cursor), SEscapeBlock(label)])
    code = seq([
        SDecl(result.id, evaluation.construct_value(result.type)),
        SEscapableBlock(label, SForEach(cursor, iterable, take_and_stop))])
    return (self.visit(code, indent), result.id)
def visit_EMapGet(self, e):
    """Return a Java expression for the map lookup `e.map[e.key]`.

    Non-Trove maps use `getOrDefault` with the default value of `e.type`.
    Trove maps use `get`, because their "no entry" value is fixed at
    construction time. The exception is Object-valued Trove maps whose
    default is not null: there the result is fetched into a temporary and
    null-checked.
    """
    map_type = e.map.type
    if not self.use_trove(map_type):
        emap = self.visit(e.map)
        ekey = self.visit(e.key)
        edefault = self.visit(evaluation.construct_value(e.type))
        return "{emap}.getOrDefault({ekey}, {edefault})".format(emap=emap, ekey=ekey, edefault=edefault)

    value_type = map_type.v
    needs_null_check = (
        self.trovename(value_type) == "Object"
        and not isinstance(evaluation.construct_value(value_type), ENull))
    emap = self.visit(e.map)
    ekey = self.visit(e.key)
    if not needs_null_check:
        # For Trove, defaults are set at construction time
        return "{emap}.get({ekey})".format(emap=emap, ekey=ekey)

    # Le sigh... a miss yields null, so substitute the real default.
    temp = self.fv(self.box_if_boolean(value_type), hint="v")
    fetch = EEscape("{emap}.get({ekey})".format(emap=emap, ekey=ekey), [], []).with_type(e.type)
    self.visit(SDecl(temp, fetch))
    return self.visit(ECond(
        EEq(temp, ENull().with_type(temp.type)),
        evaluation.construct_value(value_type),
        temp).with_type(e.type))
def find_one(self, iterable):
    """Emit code storing some element of `iterable` in a fresh variable; return it.

    The variable starts as the default value of the element type. A
    labelled loop assigns the first element visited and escapes
    immediately, so an empty iterable leaves the default in place. The
    statements are simplified and appended to `self.stms`.
    """
    elem_type = iterable.type.elem_type
    found = fresh_var(elem_type, "v")
    label = fresh_name("label")
    cursor = fresh_var(elem_type, "x")
    init = SDecl(found, evaluation.construct_value(found.type))
    take_first = seq([SAssign(found, cursor), SEscapeBlock(label)])
    search = SEscapableBlock(label, SForEach(cursor, iterable, take_first))
    self.stms.append(simplify_and_optimize(seq([init, search])))
    return found
def visit_EMapGet(self, e):
    """Return a C++ expression for the map lookup `e.map[e.key]`.

    With `use_qhash`, QHash's `value()` is used directly. Otherwise a
    `const_iterator` from `find()` is declared, and the expression yields
    the default value of `e.type` when the iterator equals `end()`, or
    `it->second` otherwise.
    """
    emap = self.visit(e.map)
    if self.use_qhash:
        return "{}.value({})".format(emap, self.visit(e.key))

    iterator_type = self.visit(e.map.type, "").strip() + "::const_iterator"
    it = self.fv(TNative(iterator_type), "map_iterator")
    self.declare(it, EEscape(emap + ".find({key})", ("key",), (e.key,)))
    missing = EEscape("{it} == " + emap + ".end()", ("it",), (it,)).with_type(BOOL)
    hit = EEscape("{it}->second", ("it",), (it,)).with_type(e.type)
    return self.visit(ECond(missing, evaluation.construct_value(e.type), hit).with_type(e.type))
def inline_mapget(e: EMapGet, context: Context) -> Exp:
    """Rewrite the map lookup `e` in terms of the map's definition, if possible.

    The result is `key in mapkeys(map) ? f(key) : default`, where `f` is
    the map's value function and its application is repaired for
    well-formedness in `context`. If any step fails, a warning is printed
    and `e` is returned unchanged.
    """
    try:
        keys = mapkeys(e.map)
        cond = optimized_in(e.key, keys)
        f = map_value_func(e.map)
        return optimized_cond(
            cond,
            repair_well_formedness(f.apply_to(e.key), context),
            construct_value(e.type))
    except Exception:
        # Best effort: unsupported map shapes fall through to the warning.
        # `Exception` (not a bare `except:`) so KeyboardInterrupt and
        # SystemExit still propagate.
        pass
    print("warning: unable to inline {}".format(e))
    return e
def value_at(m, k):
    """Make an AST node for the lookup `m[k]`.

    EMakeMap2 literals become `k in keys ? value_function(k) : default`,
    and conditional maps distribute the lookup over both branches. Any
    other map is looked up with a plain EMapGet.
    """
    value_type = m.type.v
    if isinstance(m, target_syntax.EMakeMap2):
        return syntax.ECond(
            syntax.EIn(k, m.e),
            m.value_function.apply_to(k),
            construct_value(value_type)).with_type(value_type)
    if not isinstance(m, syntax.ECond):
        return target_syntax.EMapGet(m, k).with_type(value_type)
    return syntax.ECond(
        m.cond,
        value_at(m.then_branch, k),
        value_at(m.else_branch, k)).with_type(value_type)
def inline_mapget(e : EMapGet, context : Context) -> Exp:
    """Try to replace the map lookup `e` with an expression over the map's definition.

    Produces `key in mapkeys(map) ? f(key) : default`, with `f` the map's
    value function (repaired for well-formedness in `context`). If any
    step fails, a warning is printed and `e` is returned unchanged.
    """
    try:
        keys = mapkeys(e.map)
        cond = optimized_in(e.key, keys)
        f = map_value_func(e.map)
        return optimized_cond(
            cond,
            repair_well_formedness(f.apply_to(e.key), context),
            construct_value(e.type))
    except Exception:
        # Deliberate best-effort fallback, but no longer a bare `except:`,
        # which also swallowed KeyboardInterrupt and SystemExit.
        pass
    print("warning: unable to inline {}".format(e))
    return e
def visit_EGetField(self, e):
    """Return C++ for reading field `e.field_name` of `e.e`.

    Reads through a handle (THandle) are guarded: a null handle yields the
    default value of `e.type`, otherwise `->val` is read. All other reads
    render as `(obj.field)`.
    """
    ee = self.visit(e.e)
    if isinstance(e.e.type, THandle):
        # Ugh, we really need Cozy to know about partial functions...
        # Cozy doesn't know that handle types (aka pointers) can be null.
        # It assumes that reads of null pointers produce default-
        # constructed values, so we need to generate appropriate code.
        ee = EEscape(ee, (), ()).with_type(e.e.type)
        null = ENull().with_type(e.e.type)
        return self.visit(ECond(
            EEq(ee, null),
            evaluation.construct_value(e.type),
            EEscape("{ee}->val", ("ee",), (ee,)).with_type(e.type)).with_type(e.type))
    return "({ee}.{f})".format(ee=ee, f=e.field_name)
def min_or_max(self, op, e, f, indent=""):
    """Return `(code, name)` for a loop selecting the best element of `e` under `f`.

    `op` is the comparison: "<" names the result "min", anything else
    "max". The first element visited is always taken; later elements
    replace it only when `f(x) op f(best)` holds. An empty collection
    leaves the default value.
    """
    elem_type = e.type.t
    best = self.fv(elem_type, "min" if op == "<" else "max")
    first = self.fv(BOOL, "first")
    x = self.fv(elem_type, "x")
    init_best = SDecl(best.id, evaluation.construct_value(best.type))
    init_first = SDecl(first.id, T)
    improves = EBinOp(f.apply_to(x), op, f.apply_to(best)).with_type(BOOL)
    take = SIf(
        EBinOp(first, BOp.Or, improves).with_type(BOOL),
        seq([SAssign(first, F), SAssign(best, x)]),
        SNoOp())
    code = seq([init_best, init_first, SForEach(x, e, take)])
    return (self.visit(code, indent), best.id)
def visit_EGetField(self, e):
    """Return C++ for reading field `e.field_name` of `e.e`.

    Handle (THandle) reads are null-guarded: a null handle produces the
    default value of `e.type`, otherwise `->val` is read. Other reads
    become `(obj.field)`.
    """
    ee = self.visit(e.e)
    if isinstance(e.e.type, THandle):
        # Ugh, we really need Cozy to know about partial functions...
        # Cozy doesn't know that handle types (aka pointers) can be null.
        # It assumes that reads of null pointers produce default-
        # constructed values, so we need to generate appropriate code.
        ee = EEscape(ee, (), ()).with_type(e.e.type)
        null = ENull().with_type(e.e.type)
        return self.visit(
            ECond(EEq(ee, null), evaluation.construct_value(e.type),
                  EEscape("{ee}->val", ("ee", ), (ee, )).with_type(e.type)).with_type(e.type))
    return "({ee}.{f})".format(ee=ee, f=e.field_name)
def native_map_get(self, e, default_value):
    """Return a C++ expression for the map lookup `e`, defaulting on a miss.

    With `use_qhash`, QHash's `value()` is used directly. Otherwise the
    map is searched once with `find()` into a fresh iterator variable, and
    the expression yields the default value of `e.type` when the iterator
    equals `end()`.

    `default_value` is accepted for interface compatibility; this
    implementation does not use it.
    """
    emap = self.visit(e.map)
    if self.use_qhash:
        # The key must be rendered before use; previously `ekey` was never
        # assigned on this path and the lookup raised NameError.
        ekey = self.visit(e.key)
        return "{}.value({})".format(emap, ekey)
    t = self.visit(e.map.type, "").strip() + "::const_iterator"
    iterator = self.fv(TNative(t), "map_iterator")
    self.declare(iterator, EEscape(emap + ".find({key})", ("key", ), (e.key, )))
    return self.visit(
        ECond(
            EEscape("{it} == " + emap + ".end()", ("it", ), (iterator, )).with_type(BOOL),
            evaluation.construct_value(e.type),
            EEscape("{it}->second", ("it", ), (iterator, )).with_type(e.type)).with_type(e.type))
def visit_EMapGet(self, e):
    """Simplify a map lookup after simplifying its map and key.

    When the map simplifies to an EMakeMap2 literal, the lookup is rewritten
    in terms of the literal:
      * if `is_simple(k.type)`: `k in keys ? value(k) : default`;
      * otherwise: `The(Map(value, Filter(keys, kk == k)))`, which applies
        the value function to the stored key that equals `k`.
    Otherwise an EMapGet over the simplified operands is returned.
    """
    m = self.visit(e.map)
    k = self.visit(e.key)
    if not isinstance(m, EMakeMap2):
        return EMapGet(m, k).with_type(e.type)
    key_type = m.type.k
    value_type = m.type.v
    if is_simple(k.type):
        direct = ECond(EIn(k, m.e), m.value.apply_to(k), construct_value(value_type))
        return self.visit(direct.with_type(value_type))
    equal_keys = EFilter(m.e, mk_lambda(key_type, lambda kk: EEq(kk, k))).with_type(TBag(key_type))
    values = EMap(equal_keys, m.value).with_type(TBag(value_type))
    return self.visit(EUnaryOp(UOp.The, values).with_type(value_type))
def visit_EMapGet(self, e):
    """Simplify a map lookup, expanding lookups into EMakeMap2 literals.

    After simplifying both operands, a lookup into an EMakeMap2 becomes:
      * `k in keys ? value_function(k) : default` when
        `equality_implies_deep_equality(k.type)` holds;
      * otherwise `The(Map(value_function, Filter(keys, kk == k)))`, so the
        value function sees the stored key rather than `k`.
    Lookups into other maps stay EMapGet nodes.
    """
    m = self.visit(e.map)
    k = self.visit(e.key)
    if isinstance(m, EMakeMap2):
        key_type, value_type = m.type.k, m.type.v
        if not equality_implies_deep_equality(k.type):
            same_key = mk_lambda(key_type, lambda kk: EEq(kk, k))
            matches = EFilter(m.e, same_key).with_type(TBag(key_type))
            mapped = EMap(matches, m.value_function).with_type(TBag(value_type))
            return self.visit(EUnaryOp(UOp.The, mapped).with_type(value_type))
        return self.visit(ECond(
            EIn(k, m.e),
            m.value_function.apply_to(k),
            construct_value(value_type)).with_type(value_type))
    return EMapGet(m, k).with_type(e.type)
def test_mapget_of_makemap2(self):
    """Looking up an EMakeMap2 equals an explicit membership conditional.

    The value function is applied to "the" stored key equal to the
    lookup key; absent keys produce the default value.
    """
    elem_t = THandle("elem_type", INT)
    pair_t = TTuple((INT, INT))
    xs = EVar("xs").with_type(TBag(elem_t))
    x = EVar("x").with_type(elem_t)
    y = EVar("y").with_type(elem_t)
    pair = ETuple((
        EGetField(x, "val").with_type(INT),
        EGetField(y, "val").with_type(INT))).with_type(pair_t)
    lookup = EMapGet(
        EMakeMap2(xs, ELambda(x, pair)).with_type(TMap(elem_t, pair_t)),
        y).with_type(pair_t)

    keys = lookup.map.e
    key = lookup.key
    fn = lookup.map.value_function
    in_keys = EIn(key, keys)
    equal_to_key = mk_lambda(fn.arg.type, lambda candidate: EEq(candidate, key))
    stored_key = EUnaryOp(
        UOp.The,
        EFilter(keys, equal_to_key).with_type(keys.type)).with_type(keys.type.elem_type)
    expected = ECond(
        in_keys,
        fn.apply_to(stored_key),
        construct_value(lookup.type)).with_type(lookup.type)
    self.assert_same(lookup, expected)
def visit_EEmptyMap(self, e):
    """Return Java code constructing an empty map of type `e.type`.

    Non-Trove maps use their no-argument constructor. Trove maps get an
    initial capacity of 64, a load factor of 0.5 and, where the
    constructor takes one, a "no entry" value equal to the default value of
    the value type (substituted for `{default}` via EEscape).
    """
    map_type = e.type
    if not self.use_trove(map_type):
        return "new {}".format(self.visit(map_type, name="()"))
    # Trove constructor arguments:
    #   int initialCapacity, float loadFactor, K noEntryKey, V noEntryValue
    # A loadFactor of 0.5 (Trove's default) means the map has 2x more
    # buckets than entries.
    if self.trovename(map_type.k) == "Object":
        args = "64, 0.5f, {default}"
    elif self.trovename(map_type.v) == "Object":
        args = "64, 0.5f"
    else:
        args = "64, 0.5f, 0, {default}"
    init = "new {}({})".format(self.visit(map_type, name=""), args)
    default = evaluation.construct_value(map_type.v)
    return self.visit(EEscape(init, ["default"], [default]))
def construct_concrete(self, e: Exp, out: Exp):
    """Return statements that store the map `e` into the vector-backed map `out`.

    First, every slot `0 .. len(self.all_keys) - 1` is reset to the default
    value of `self.v`. Then, for each key `k` in `e.e`, slot `k` is
    assigned `e.value` applied to `k` (via ELet).

    `out` must have this map type, and `e` must be an EMakeMap2.
    """
    # The message now describes the value actually being checked (`out`),
    # instead of printing `e`.
    assert out.type == self, "{} : {} (expected {})".format(pprint(out), pprint(out.type), pprint(self))
    out = shallow_copy(out).with_type(self.rep_type())
    assert isinstance(e, EMakeMap2) # TODO?
    k = fresh_var(self.k, "k")
    # Only the position matters for the reset; the key itself is unused
    # (it used to be bound to a second variable named `k`).
    defaults = [
        SAssign(
            EVectorGet(out, ENum(i).with_type(INT)).with_type(self.v),
            construct_value(self.v))
        for i, _ in enumerate(self.all_keys)]
    fill = SForEach(
        k,
        e.e,
        SAssign(
            EVectorGet(out, k).with_type(self.v),
            ELet(k, e.value).with_type(self.v)))
    return seq(defaults + [fill])
def initialize_native_map(self, out):
    """Return a statement initialising the native map variable `out`.

    INT-keyed maps are arrays with 64 initial slots. Trove maps get capacity
    64, load factor 0.5 and, where the constructor takes one, a "no entry"
    value equal to the default value of the map's value type. Other maps
    use their no-argument constructor.
    """
    map_type = out.type
    if map_type.k == INT:
        return self.initialize_array(map_type.v, ENum(64).with_type(INT), out)
    if not self.use_trove(map_type):
        ctor = "new {};\n".format(self.visit(map_type, name="()"))
        return SEscape("{indent}{e} = " + ctor, ["e"], [out])
    # Trove constructor arguments:
    #   int initialCapacity, float loadFactor, K noEntryKey, V noEntryValue
    # A loadFactor of 0.5 (Trove's default) means the map has 2x more
    # buckets than entries.
    if self.trovename(map_type.k) == "Object":
        args = "64, 0.5f, {default}"
    elif self.trovename(map_type.v) == "Object":
        args = "64, 0.5f"
    else:
        args = "64, 0.5f, 0, {default}"
    ctor = "new {}({});\n".format(self.visit(map_type, name=""), args)
    default = evaluation.construct_value(map_type.v)
    return SEscape("{indent}{e} = " + ctor, ["e", "default"], [out, default])
def visit_EListGet(self, e):
    """Generate a bounds-checked read of `e.e[e.index]`.

    The list is first made an lvalue (`to_lvalue`) and the index is bound
    once to a fresh INT variable through ELet. An index outside
    `[0, size())` yields the default value of `e.type`.
    """
    lst = self.to_lvalue(e.e)
    idx = self.fv(INT)
    size = EEscape("{l}.size()", ("l", ), (lst, )).with_type(INT)
    in_range = EAll([EGe(idx, ZERO), ELt(idx, size)])
    element = EEscape("{l}[{i}]", ("l", "i"), (lst, idx)).with_type(e.type)
    guarded = ECond(in_range, element, evaluation.construct_value(e.type)).with_type(e.type)
    return self.visit(ELet(e.index, ELambda(idx, guarded)).with_type(e.type))
def min_or_max(self, op, e, f):
    """Compute the best element of `e` under key function `f`; return an expression for it.

    `op` is the comparison: "<" selects a minimum (result named "min"),
    anything else a maximum ("max"). The first element visited wins ties:
    later elements replace it only when `f(x) op f(best)` holds. An empty
    collection yields the default value.

    `[a] + [b]` is special-cased to a single conditional. Other inputs emit
    a loop that is simplified and appended to `self.stms`.
    """
    if isinstance(e, EBinOp) and e.op == "+" and isinstance(e.e1, ESingleton) and isinstance(e.e2, ESingleton):
        # argbest_f ([a] + [b]) ---> f(b) op f(a) ? b : a
        # `b` is chosen only when strictly better, so on ties this returns
        # `a`, matching the first-wins loop below. (The old form
        # `f(a) op f(b) ? a : b` returned `b` on ties.)
        a = e.e1.e
        b = e.e2.e
        return self.visit(ECond(
            EBinOp(f.apply_to(b), op, f.apply_to(a)).with_type(BOOL),
            b,
            a).with_type(a.type))
    out = fresh_var(e.type.elem_type, "min" if op == "<" else "max")
    first = fresh_var(BOOL, "first")
    x = fresh_var(e.type.elem_type, "x")
    decl1 = SDecl(out, evaluation.construct_value(out.type))
    decl2 = SDecl(first, ETRUE)
    find = SForEach(x, e, SIf(
        EBinOp(
            first,
            BOp.Or,
            EBinOp(f.apply_to(x), op, f.apply_to(out)).with_type(BOOL)).with_type(BOOL),
        seq([SAssign(first, EFALSE), SAssign(out, x)]),
        SNoOp()))
    self.stms.append(simplify_and_optimize(seq([decl1, decl2, find])))
    return out
def visit_SMapUpdate(self, update):
    """Emit C++ for an in-place update of the map entry at `update.key`.

    A missing key is first inserted with the default value of the value
    type (via `emplace`). `update.val_var` is then bound as a reference to
    the stored value, and `update.change` is emitted.
    """
    emap = self.visit(update.map)
    ekey = self.visit(update.key)
    iterator_type = self.visit(update.map.type, "").strip() + "::iterator"
    it = self.fv(TNative(iterator_type), "map_iterator")
    self.declare(it, EEscape(emap + ".find(" + ekey + ")", (), ()))

    absent = EEscape("{it} == " + emap + ".end()", ("it", ), (it, )).with_type(BOOL)
    initial = evaluation.construct_value(update.val_var.type)
    emplace = EEscape(
        emap + ".emplace(" + ekey + ", {value}).first",
        ("value", ),
        (initial, )).with_type(it.type)
    self.visit(SIf(absent, SAssign(it, emplace), SNoOp()))

    self.begin_statement()
    ref_decl = self.visit(TRef(update.val_var.type), update.val_var.id)
    self.write("{decl} = {it}->second;\n".format(decl=ref_decl, it=it.id))
    self.end_statement()
    self.visit(update.change)
def optimized_the(xs, args):
    """Lazily yield candidate replacements for `EUnaryOp(UOp.The, xs)`.

    Rewrites are tried in this order: conditionals, state variables,
    lists with one excluded element, maps, and concatenations. The final
    candidate is always the plain `The(xs)`.
    """
    elem_type = xs.type.elem_type
    if isinstance(xs, ECond):
        for then_choice in optimized_the(xs.then_branch, args):
            for else_choice in optimized_the(xs.else_branch, args):
                yield optimized_cond(xs.cond, then_choice, else_choice)
    if isinstance(xs, EStateVar):
        # Maintain the whole `The` as state.
        yield EStateVar(EUnaryOp(UOp.The, xs.e).with_type(elem_type)).with_type(elem_type)
    if isinstance(xs.type, TList):
        split = excluded_element(xs, args)
        if split is not None:
            source, dropped = split
            for head in optimized_the(source, args):
                # If the head is the dropped element, the answer is source[1].
                yield optimized_cond(
                    EEq(head, dropped),
                    EListGet(source, ONE).with_type(elem_type),
                    head)
    if isinstance(xs, EMap):
        nonempty = optimized_exists(xs.e)
        for pre_image in optimized_the(xs.e, args):
            yield optimized_cond(
                nonempty,
                xs.transform_function.apply_to(pre_image),
                construct_value(elem_type))
    if isinstance(xs, EBinOp) and xs.op == "+":
        left_nonempty = optimized_exists(xs.e1)
        for left in optimized_the(xs.e1, args):
            for right in optimized_the(xs.e2, args):
                yield optimized_cond(left_nonempty, left, right)
    yield EUnaryOp(UOp.The, xs).with_type(elem_type)
def improve(
        target : Exp,
        context : Context,
        assumptions : Exp = ETRUE,
        stop_callback : Callable[[], bool] = never_stop,
        hints : [Exp] = (),
        examples : [{str:object}] = (),
        cost_model : CostModel = None,
        ops : [Op] = (),
        improve_count : Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`. The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
        - assumptions: a precondition. The yielded improvements will only
          be correct when the assumptions are true.
        - hints: expressions that might be useful. These will be explored
          first when looking for improvements.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions. This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide
          any.
        - ops: update operations. This function may make different choices
          about what expressions are state expressions based on what
          changes can happen to that state.
        - improve_count: if given, a shared counter (an object with
          `get_lock()` and `value`, presumably a `multiprocessing.Value`)
          incremented under its lock after every round of the search loop.

    If the assumptions are unsatisfiable, a single default value of the
    target's type is yielded and the generator stops.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime"
          expressions, allowing this algorithm to choose what things to
          store on the data structure and what things to compute at query
          execution time. (The cost model is ultimately responsible for
          this choice.)
        - If a better version of *any subexpression* for the target is
          found, it is immediately substituted in and the overall
          expression is returned. This "smooths out" the search space a
          little, allowing us to find kinda-good solutions very quickly,
          even if the best possible solution is out of reach. This is more
          desirable than running for an indeterminate amount of time doing
          nothing.
    """

    print("call to improve:")
    print("""improve( target={target!r}, context={context!r}, assumptions={assumptions!r}, stop_callback={stop_callback!r}, hints={hints!r}, examples={examples!r}, cost_model={cost_model!r}, ops={ops!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model,
        ops=ops))

    # Normalise the inputs: inline lets and give binders fresh names.
    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # Hints: caller-provided hints, each conjunct of the assumptions (with
    # bare state variables wrapped), and the target itself.
    hints = ([freshen_binders(h, context) for h in hints]
        + [freshen_binders(wrap_naked_statevars(a, state_vars), context) for a in break_conj(assumptions)]
        + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    # NOTE(review): `vars` shadows the builtin and is not used below.
    vars = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(is_good, target)

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    # The frontier of mutually non-dominated targets being searched from.
    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            # Look for an input on which the candidate and the current
            # target disagree.
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                # A repeated example means the search failed to make
                # progress with the examples it already had.
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    # Compare the candidate with every watched target:
                    # evict those it beats; drop the candidate itself if
                    # some watched target beats it.
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                if not keep:
                    print("Whoops! Looks like we already found something better.")
                    print(" --> {}".format(pprint(old_better)))
                    continue
                # Only yield when the candidate displaced the current target.
                if target in to_evict:
                    print("Yep, it's an improvement!")
                    yield new_target
                    if heuristic_done(new_target):
                        print("target now matches doneness heuristic")
                        return
                    target = new_target
                else:
                    print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1
def visit_SMapUpdate(self, update):
    """Emit C++ that updates, in place, the value stored under `update.key`.

    If the key is absent, it is inserted with the default value of the
    value type. `update.val_var` is declared as a reference to the stored
    value, and then the statements of `update.change` are emitted.
    """
    map_code = self.visit(update.map)
    key_code = self.visit(update.key)
    it_type = TNative(self.visit(update.map.type, "").strip() + "::iterator")
    cursor = self.fv(it_type, "map_iterator")
    self.declare(cursor, EEscape(map_code + ".find(" + key_code + ")", (), ()))

    insert_default = SAssign(
        cursor,
        EEscape(
            map_code + ".emplace(" + key_code + ", {value}).first",
            ("value",),
            (evaluation.construct_value(update.val_var.type),)).with_type(cursor.type))
    not_found = EEscape("{it} == " + map_code + ".end()", ("it",), (cursor,)).with_type(BOOL)
    self.visit(SIf(not_found, insert_default, SNoOp()))

    self.begin_statement()
    line = "{decl} = {it}->second;\n".format(
        decl=self.visit(TRef(update.val_var.type), update.val_var.id),
        it=cursor.id)
    self.write(line)
    self.end_statement()
    self.visit(update.change)
def optimized_best(xs, keyfunc, op, args):
    """Lazily yield candidate expressions for the best element of `xs` under `keyfunc`.

    `op` picks the direction: "<" means minimum (EArgMin, min-heaps);
    anything else means maximum (EArgMax, max-heaps). `args` appears to be
    the set of runtime variables: rewrites that move a sub-expression into
    state (EStateVar, heaps) are only produced when that sub-expression
    mentions none of them -- TODO confirm against callers.

    Each candidate is intended to be equivalent to `argbest(xs, keyfunc)`.
    The final candidate is always that plain expression, so the generator
    never comes up empty.
    """
    argbest = EArgMin if op == "<" else EArgMax
    elem_type = xs.type.elem_type
    key_type = keyfunc.body.type
    # xs is `bag` with element `x` removed (as reported by excluded_element).
    if excluded_element(xs, args) is not None:
        bag, x = excluded_element(xs, args)
        if all(v not in args for v in free_vars(bag)):
            # Keep a heap over `bag` in state. The best of (bag - x) is the
            # heap's runner-up (EHeapPeek2) when x is in bag and x is bag's
            # best; otherwise it is simply bag's best.
            heap_type, make_heap = (TMinHeap, EMakeMinHeap) if op == "<" else (TMaxHeap, EMakeMaxHeap)
            bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
            for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
                prev_min = EStateVar(prev_min).with_type(elem_type)
                heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type)).with_type(elem_type)
                conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
                # `The(...)` is only meaningful when its argument is non-empty.
                if isinstance(x, EUnaryOp) and x.op == UOp.The:
                    conds = [optimized_exists(x.e)] + conds
                yield optimized_cond(
                    EAll(conds),
                    heap_peek,
                    prev_min)
    if isinstance(xs, EEmptyList):
        yield construct_value(elem_type)
    if isinstance(xs, ESingleton):
        yield xs.e
    if isinstance(xs, EBinOp) and xs.op == "+":
        # Best of a concatenation: compare the best of each side, falling
        # back to whichever side is non-empty (or the default if neither).
        a_ex = optimized_exists(xs.e1)
        b_ex = optimized_exists(xs.e2)
        bag_type = TBag(xs.type.elem_type)
        for a in optimized_best(xs.e1, keyfunc, op, args=args):
            for b in optimized_best(xs.e2, keyfunc, op, args=args):
                yield optimized_cond(a_ex,
                    optimized_cond(b_ex,
                        argbest(EBinOp(ESingleton(a).with_type(bag_type), "+", ESingleton(b).with_type(bag_type)).with_type(bag_type), keyfunc).with_type(elem_type),
                        a),
                    optimized_cond(b_ex, b, construct_value(elem_type)))
        # NOTE: disabled alternative strategy, kept for reference:
        # if isinstance(xs.e1, EStateVar) or isinstance(xs.e2, EStateVar):
        #     sv, other = (xs.e1, xs.e2) if isinstance(xs.e1, EStateVar) else (xs.e2, xs.e1)
        #     sv_best = optimized_best(sv, keyfunc, op, args=args)
        #     yield optimized_cond(
        #         optimized_exists(sv),
        #         argbest(EBinOp(ESingleton(sv_best).with_type(xs.type), "+", other).with_type(xs.type), keyfunc).with_type(elem_type),
        #         optimized_best(other, keyfunc, op, args=args))
        # else:
        #     parts = break_sum(xs)
        #     found = EFALSE
        #     best = construct_value(elem_type)
        #     for p in parts:
        #         ex = optimized_exists(p)
        #         best_here = optimized_best(p, keyfunc, op, args=args)
        #         best = optimized_cond(found,
        #             optimized_cond(ex,
        #                 optimized_cond(EBinOp(keyfunc.apply_to(best_here), op, keyfunc.apply_to(best)).with_type(BOOL),
        #                     best_here,
        #                     best),
        #                 best),
        #             best_here)
        #         found = EAny([found, ex])
        #     yield best
    if isinstance(xs, EMap):
        # best(map f xs) = f(best of xs under keyfunc . f), or the default
        # when xs is empty.
        for b in optimized_best(xs.e, compose(keyfunc, xs.transform_function), op, args):
            yield optimized_cond(optimized_exists(xs.e), xs.transform_function.apply_to(b), construct_value(elem_type))
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(keyfunc)):
        # Whole computation can live in state when the key function does.
        yield EStateVar(argbest(xs.e, keyfunc).with_type(elem_type)).with_type(elem_type)
    if isinstance(xs, ECond):
        for a in optimized_best(xs.then_branch, keyfunc, op, args=args):
            for b in optimized_best(xs.else_branch, keyfunc, op, args=args):
                yield optimized_cond(xs.cond, a, b)
    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        # Removing duplicates does not change which element is best.
        yield from optimized_best(xs.e, keyfunc, op, args=args)
    # NOTE: disabled EFilter rewrite, kept for reference:
    # if isinstance(xs, EFilter):
    #     yield optimized_cond(
    #         xs.predicate.apply_to(optimized_best(xs.e, keyfunc, op, args=args)),
    #         optimized_best(xs.e, keyfunc, op, args=args),
    #         argbest(xs, keyfunc).with_type(elem_type))
    yield argbest(xs, keyfunc).with_type(elem_type)
def improve(
        target: Exp,
        context: Context,
        assumptions: Exp = T,
        stop_callback=never_stop,
        hints: [Exp] = (),
        examples: [{str: object}] = (),
        cost_model: CostModel = None):
    """
    Improve the target expression using enumerative synthesis.
    This function is a generator that yields increasingly better and better
    versions of the input expression `target`.

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, and lets us
          find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.

    Candidate verification asks the solver for a model of `assumptions` in
    which `target != candidate`.  In incremental mode that disagreement
    constraint lives on its own solver frame which is popped right after the
    query, so constraints from earlier candidates never leak into later checks.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r})""".format(
        target=target,
        context=context,
        assumptions=assumptions,
        stop_callback=stop_callback,
        hints=hints,
        examples=examples,
        cost_model=cost_model))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    try:
        assert exp_wf(target, context=context, assumptions=assumptions)
    except ExpIsNotWf as ex:
        print(
            "WARNING: initial target is not well-formed [{}]; this might go poorly..."
            .format(str(ex)))
        print(pprint(ex.offending_subexpression))
        print(pprint(ex.offending_subexpression.type))
        # raise

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    hints = ([freshen_binders(h, context) for h in hints]
             + [freshen_binders(wrap_naked_statevars(a, state_vars), context)
                for a in break_conj(assumptions)]
             + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))

    # Named `var_list` to avoid shadowing the builtin `vars`.
    var_list = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    solver = None
    if incremental.value:
        solver = IncrementalSolver(vars=var_list, funcs=funcs)
        solver.add_assumption(assumptions)
        _sat = solver.satisfy
    else:
        _sat = lambda e: satisfy(e, vars=var_list, funcs=funcs)

    if _sat(assumptions) is None:
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    learner = Learner(watched_targets, assumptions, context, examples, cost_model, stop_callback, hints)
    try:
        while True:
            # 1. find any potential improvement to any sub-exp of target
            for new_target in learner.next():
                print("Found candidate improvement: {}".format(pprint(new_target)))

                # 2. check
                with task("verifying candidate"):
                    if incremental.value:
                        # Scope the disagreement constraint to this single
                        # query.  Previously the matching pop was unreachable
                        # (every path left the loop via break/continue first),
                        # so stale constraints accumulated across candidates.
                        solver.push()
                        try:
                            solver.add_assumption(
                                ENot(EBinOp(target, "==", new_target).with_type(BOOL)))
                            counterexample = _sat(T)
                        finally:
                            solver.pop()
                    else:
                        formula = EAll([
                            assumptions,
                            ENot(EBinOp(target, "==", new_target).with_type(BOOL))])
                        counterexample = _sat(formula)

                if counterexample is not None:
                    if counterexample in examples:
                        print("assumptions = {!r}".format(assumptions))
                        print("duplicate example: {!r}".format(counterexample))
                        print("old target = {!r}".format(target))
                        print("new target = {!r}".format(new_target))
                        raise Exception("got a duplicate example")
                    # a. if incorrect: add example, reset the learner
                    examples.append(counterexample)
                    event("new example: {!r}".format(counterexample))
                    print("wrong; restarting with {} examples".format(len(examples)))
                    learner.reset(examples)
                    break

                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = eviction_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print("Whoops! Looks like we already found something better.")
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            raise NoMoreImprovements()
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                learner.watch(watched_targets)
                break

    except NoMoreImprovements:
        return
def visit_Spec(self, spec : Spec, state_exps : { str : Exp }, sharing, abstract_state=()):
    """Emit a complete C++ header declaring the class for `spec`.

    Output order: include preamble, optional user header, class opening,
    each auxiliary type from `self.types` (plus a `_Hash<name>` functor for
    every non-handle type), state-variable fields, default constructor,
    optional explicit constructor over `abstract_state`, deleted copy
    constructor, each method, closing brace, optional user footer.

    Side effects on `self`: records `state_exps`, the extern functions, the
    queries, and the set of all variable names appearing in `spec`.
    """
    self.state_exps = state_exps
    self.funcs = {func.name: func for func in spec.extern_funcs}
    self.queries = {meth.name: meth for meth in spec.methods if isinstance(meth, Query)}
    self.vars = {node.id for node in all_exps(spec) if isinstance(node, EVar)}

    preamble = (
        "#pragma once\n",
        "#include <algorithm>\n",
        "#include <set>\n",
        "#include <functional>\n",
        "#include <vector>\n",
        "#include <unordered_set>\n",
        "#include <string>\n",
    )
    for line in preamble:
        self.write(line)
    self.write("#include <QHash>\n" if self.use_qhash else "#include <unordered_map>\n")

    if spec.header:
        self.write("\n" + spec.header.strip() + "\n")

    doc_prefix = ("\n" + spec.docstring) if spec.docstring else ""
    self.write("{}\nclass {} {{\n".format(doc_prefix, spec.name))
    self.write("public:\n")

    def hash_fields(ty):
        # The component expressions (over a variable `x`) fed to the hash.
        subject = EVar("x").with_type(ty)
        if isinstance(ty, TEnum):
            return [EEnumToInt(subject).with_type(INT)]
        if isinstance(ty, TRecord):
            return [EGetField(subject, fname).with_type(ftype) for (fname, ftype) in ty.fields]
        if isinstance(ty, TTuple):
            return [ETupleGet(subject, idx).with_type(etype) for (idx, etype) in enumerate(ty.ts)]
        raise NotImplementedError(ty)

    def emit_initialisers(default_free_vars):
        # One assignment per state variable; optionally replace free
        # variables in the initial value with their type's default value.
        for field_name, field_type in spec.statevars:
            init = state_exps[field_name]
            if default_free_vars:
                init = subst(init, {fv.id : evaluation.construct_value(fv.type) for fv in free_vars(init)})
            self.visit(simplify_and_optimize(SAssign(EVar(field_name).with_type(field_type), init)))

    print("Setting up auxiliary types...")
    self.setup_types(spec, state_exps, sharing)
    with self.indented():
        for ty, type_name in self.types.items():
            self.define_type(spec.name, ty, type_name, sharing)
            self.begin_statement()
            if isinstance(ty, THandle):
                # Handles hash by pointer, so no custom functor is emitted.
                continue
            self.write("struct _Hash", type_name, " ")
            with self.block():
                self.write_stmt("typedef ", spec.name, "::", type_name, " argument_type;")
                self.write_stmt("typedef std::size_t result_type;")
                self.begin_statement()
                self.write("result_type operator()(const argument_type& x) const noexcept ")
                components = hash_fields(ty)
                with self.block():
                    self.visit(self.compute_hash(components))
                self.end_statement()
            self.write(";")
            self.end_statement()

    print("Setting up member variables...")
    self.write("protected:\n")
    with self.indented():
        for field_name, field_type in spec.statevars:
            self.statevar_name = field_name
            self.declare_field(field_name, field_type)

    self.write("public:\n")
    with self.indented():
        print("Generating constructors...")

        # Default constructor.
        self.begin_statement()
        self.write("inline ", spec.name, "() ")
        with self.block():
            emit_initialisers(default_free_vars=True)
        self.end_statement()

        # Explicit constructor over the abstract state.
        if abstract_state:
            self.begin_statement()
            self.write("explicit inline ", spec.name, "(")
            self.visit_args(abstract_state)
            self.write(") ")
            with self.block():
                emit_initialisers(default_free_vars=False)
            self.end_statement()

        # Copying is not supported (TODO: support this in the future?).
        self.begin_statement()
        self.write(spec.name, "(const ", spec.name, "& other) = delete;")
        self.end_statement()

        for method in spec.methods:
            print("Generating method {}...".format(method.name))
            self.visit(method)

    self.write("};\n")

    if not spec.footer:
        return
    self.write("\n", spec.footer)
    if not spec.footer.endswith("\n"):
        self.write("\n")
def visit_ECond(self, e):
    """Lower a conditional expression into statements.

    Declares a fresh variable initialised to the default value of `e.type`,
    assigns it from whichever branch `e.cond` selects, appends the simplified
    statement sequence to `self.stms`, and returns the variable to stand in
    for `e`.
    """
    result_var = fresh_var(e.type, "conditional_result")
    declaration = SDecl(result_var, evaluation.construct_value(e.type))
    choose_branch = SIf(
        e.cond,
        SAssign(result_var, e.then_branch),
        SAssign(result_var, e.else_branch))
    lowered = simplify_and_optimize(seq([declaration, choose_branch]))
    self.stms.append(lowered)
    return result_var
def visit_ETreeMultisetPeek(self, e):
    """Lower a TreeMultiset peek to a guarded Guava `Iterators.get` call.

    Produces `index < xs.size() ? Iterators.get(xs.iterator(), index) : default`,
    where `default` is the default value of `e.type`, and visits the result.
    NOTE(review): only the upper bound is checked; a negative index would
    reach `Iterators.get` unguarded.
    """
    in_range = ELt(
        e.index,
        EEscape("{xs}.size()", ("xs",), (e.e,))).with_type(BOOL)
    element = EEscape(
        "Iterators.get({xs}.iterator(), {i})",
        ("xs", "i"),
        (e.e, e.index)).with_type(e.type)
    fallback = evaluation.construct_value(e.type)
    guarded = ECond(in_range, element, fallback).with_type(e.type)
    return self.visit(guarded)
def enumerate_core(self, context: Context, size: int, pool: Pool) -> [Exp]:
    """
    Arguments:
        context : a Context object describing the vars in scope
        size    : size to enumerate
        pool    : pool to enumerate

    Yields all expressions of the given size legal in the given context and
    pool.  An n-ary node of size `size` splits its remaining `size - 1` budget
    among its children via `pick_to_sum`.
    """

    if size < 0:
        return

    if size == 0:
        for (e, p) in LITERALS:
            if p == pool:
                yield e

        for (v, p) in context.vars():
            if p == pool:
                yield v
            for t in all_types(v):
                yield construct_value(t)

        for (e, ctx, p) in self.hints:
            if p == pool and ctx.alpha_equivalent(context):
                yield context.adapt(e, ctx)
            for t in all_types(e):
                yield construct_value(t)

        return

    yield from self.heuristic_enumeration(context, size, pool)

    # Unary forms: one child of size `size - 1`.
    for e in collections(self.enumerate(context, size - 1, pool)):
        yield EEmptyList().with_type(e.type)
        if is_numeric(e.type.t):
            yield EUnaryOp(UOp.Sum, e).with_type(e.type.t)

    for e in self.enumerate(context, size - 1, pool):
        yield ESingleton(e).with_type(TBag(e.type))

    for e in self.enumerate(context, size - 1, pool):
        if isinstance(e.type, TRecord):
            for (f, t) in e.type.fields:
                yield EGetField(e, f).with_type(t)

    for e in self.enumerate(context, size - 1, pool):
        if isinstance(e.type, THandle):
            yield EGetField(e, "val").with_type(e.type.value_type)

    for e in self.enumerate(context, size - 1, pool):
        if isinstance(e.type, TTuple):
            for n in range(len(e.type.ts)):
                yield ETupleGet(e, n).with_type(e.type.ts[n])

    for e in of_type(self.enumerate(context, size - 1, pool), BOOL):
        yield EUnaryOp(UOp.Not, e).with_type(BOOL)

    for e in self.enumerate(context, size - 1, pool):
        if is_numeric(e.type):
            yield EUnaryOp("-", e).with_type(e.type)

    for m in self.enumerate(context, size - 1, pool):
        if isinstance(m.type, TMap):
            yield EMapKeys(m).with_type(TBag(m.type.k))

    # Binary forms: children of sizes sz1 + sz2 == size - 1.
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for a1 in self.enumerate(context, sz1, pool):
            t = a1.type
            if not is_numeric(t):
                continue
            for a2 in of_type(self.enumerate(context, sz2, pool), t):
                yield EBinOp(a1, "+", a2).with_type(t)
                yield EBinOp(a1, "-", a2).with_type(t)
                yield EBinOp(a1, ">", a2).with_type(BOOL)
                yield EBinOp(a1, "<", a2).with_type(BOOL)
                yield EBinOp(a1, ">=", a2).with_type(BOOL)
                yield EBinOp(a1, "<=", a2).with_type(BOOL)
        for a1 in collections(self.enumerate(context, sz1, pool)):
            for a2 in of_type(self.enumerate(context, sz2, pool), a1.type):
                yield EBinOp(a1, "+", a2).with_type(a1.type)
                yield EBinOp(a1, "-", a2).with_type(a1.type)
            for a2 in of_type(self.enumerate(context, sz2, pool), a1.type.t):
                yield EBinOp(a2, BOp.In, a1).with_type(BOOL)
        for a1 in of_type(self.enumerate(context, sz1, pool), BOOL):
            for a2 in of_type(self.enumerate(context, sz2, pool), BOOL):
                yield EBinOp(a1, BOp.And, a2).with_type(BOOL)
                yield EBinOp(a1, BOp.Or, a2).with_type(BOOL)
        for a1 in self.enumerate(context, sz1, pool):
            if not isinstance(a1.type, TMap):
                for a2 in of_type(self.enumerate(context, sz2, pool), a1.type):
                    yield EEq(a1, a2)
                    yield EBinOp(a1, "!=", a2).with_type(BOOL)
        for m in self.enumerate(context, sz1, pool):
            if isinstance(m.type, TMap):
                for k in of_type(self.enumerate(context, sz2, pool), m.type.k):
                    yield EMapGet(m, k).with_type(m.type.v)
                    yield EHasKey(m, k).with_type(BOOL)
        for lst in self.enumerate(context, sz1, pool):
            if not isinstance(lst.type, TList):
                continue
            for i in of_type(self.enumerate(context, sz2, pool), INT):
                yield EListGet(lst, i).with_type(lst.type.t)

    # Ternary forms: children of sizes sz1 + sz2 + sz3 == size - 1.
    for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
        for cond in of_type(self.enumerate(context, sz1, pool), BOOL):
            for then_branch in self.enumerate(context, sz2, pool):
                # The else branch gets its own budget, sz3.  (Using sz2 here
                # produced conditionals of the wrong total size and never
                # produced those whose branches differ in size.)
                for else_branch in of_type(self.enumerate(context, sz3, pool), then_branch.type):
                    yield ECond(cond, then_branch, else_branch).with_type(then_branch.type)
        for lst in self.enumerate(context, sz1, pool):
            if not isinstance(lst.type, TList):
                continue
            for st in of_type(self.enumerate(context, sz2, pool), INT):
                for ed in of_type(self.enumerate(context, sz3, pool), INT):
                    yield EListSlice(lst, st, ed).with_type(lst.type)

    for bag in collections(self.enumerate(context, size - 1, pool)):
        # len of bag
        count = EUnaryOp(UOp.Length, bag).with_type(INT)
        yield count
        # empty?
        yield EUnaryOp(UOp.Empty, bag).with_type(BOOL)
        # exists?
        yield EUnaryOp(UOp.Exists, bag).with_type(BOOL)
        # singleton?
        yield EEq(count, ONE)

        yield EUnaryOp(UOp.The, bag).with_type(bag.type.t)
        yield EUnaryOp(UOp.Distinct, bag).with_type(bag.type)
        yield EUnaryOp(UOp.AreUnique, bag).with_type(BOOL)

        if bag.type.t == BOOL:
            yield EUnaryOp(UOp.Any, bag).with_type(BOOL)
            yield EUnaryOp(UOp.All, bag).with_type(BOOL)

    def build_lambdas(bag, pool, body_size):
        # Lambdas over a fresh variable bound to elements of `bag`.
        v = fresh_var(bag.type.t, omit=set(v for v, p in context.vars()))
        inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
        for lam_body in self.enumerate(inner_context, body_size, pool):
            yield ELambda(v, lam_body)

    # Iteration
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for bag in collections(self.enumerate(context, sz1, pool)):
            for lam in build_lambdas(bag, pool, sz2):
                body_type = lam.body.type
                yield EMap(bag, lam).with_type(TBag(body_type))
                if body_type == BOOL:
                    yield EFilter(bag, lam).with_type(bag.type)
                if is_numeric(body_type):
                    yield EArgMin(bag, lam).with_type(bag.type.t)
                    yield EArgMax(bag, lam).with_type(bag.type.t)
                if is_collection(body_type):
                    yield EFlatMap(bag, lam).with_type(TBag(body_type.t))

    # Enable use of a state-pool expression at runtime
    if pool == RUNTIME_POOL:
        for e in self.enumerate(context, size - 1, STATE_POOL):
            yield EStateVar(e).with_type(e.type)

    # Create maps
    if pool == STATE_POOL:
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in collections(self.enumerate(context, sz1, STATE_POOL)):
                if not is_scalar(bag.type.t):
                    continue
                for lam in build_lambdas(bag, STATE_POOL, sz2):
                    t = TMap(bag.type.t, lam.body.type)
                    m = EMakeMap2(bag, lam).with_type(t)
                    yield m
def visit_Spec(self, spec: Spec, state_exps: {str: Exp}, sharing, abstract_state=()):
    """Generate the C++ header text for the data structure `spec`.

    Writes the include preamble, the user header (if any), the class body
    (auxiliary types with `_Hash` functors, fields, constructors, a deleted
    copy constructor, methods), the closing brace, and the user footer
    (if any).  Also stores `state_exps`, extern functions, queries and the
    names of every variable in `spec` on `self`.
    """
    self.state_exps = state_exps
    self.funcs = {fn.name: fn for fn in spec.extern_funcs}
    self.queries = {
        method.name: method
        for method in spec.methods
        if isinstance(method, Query)
    }
    self.vars = set(node.id for node in all_exps(spec) if isinstance(node, EVar))

    if self.use_qhash:
        map_include = "#include <QHash>\n"
    else:
        map_include = "#include <unordered_map>\n"
    includes = [
        "#pragma once\n",
        "#include <algorithm>\n",
        "#include <set>\n",
        "#include <functional>\n",
        "#include <vector>\n",
        "#include <unordered_set>\n",
        "#include <string>\n",
        map_include,
    ]
    for include in includes:
        self.write(include)

    if spec.header:
        self.write("\n" + spec.header.strip() + "\n")

    self.write("{}\nclass {} {{\n".format(
        "\n" + spec.docstring if spec.docstring else "",
        spec.name))
    self.write("public:\n")

    def emit_hash_struct(ty, label):
        # Writes `struct _Hash<label> { ... };` hashing the parts of `ty`.
        self.write("struct _Hash", label, " ")
        with self.block():
            self.write_stmt("typedef ", spec.name, "::", label, " argument_type;")
            self.write_stmt("typedef std::size_t result_type;")
            self.begin_statement()
            self.write("result_type operator()(const argument_type& x) const noexcept ")
            arg = EVar("x").with_type(ty)
            if isinstance(ty, TEnum):
                parts = [EEnumToInt(arg).with_type(INT)]
            elif isinstance(ty, TRecord):
                parts = [EGetField(arg, fname).with_type(ftype) for fname, ftype in ty.fields]
            elif isinstance(ty, TTuple):
                parts = [ETupleGet(arg, pos).with_type(ptype) for pos, ptype in enumerate(ty.ts)]
            else:
                raise NotImplementedError(ty)
            with self.block():
                self.visit(self.compute_hash(parts))
            self.end_statement()
        self.write(";")
        self.end_statement()

    def with_defaulted_free_vars(exp):
        # Replace each free variable by the default value of its type.
        defaults = {fv.id: evaluation.construct_value(fv.type) for fv in free_vars(exp)}
        return subst(exp, defaults)

    def emit_state_assignments(prepare):
        for var_name, var_type in spec.statevars:
            value = prepare(state_exps[var_name])
            self.visit(simplify_and_optimize(SAssign(EVar(var_name).with_type(var_type), value)))

    print("Setting up auxiliary types...")
    self.setup_types(spec, state_exps, sharing)
    with self.indented():
        for ty, label in self.types.items():
            self.define_type(spec.name, ty, label, sharing)
            self.begin_statement()
            # Handle types hash by pointer; they get no functor.
            if not isinstance(ty, THandle):
                emit_hash_struct(ty, label)

    print("Setting up member variables...")
    self.write("protected:\n")
    with self.indented():
        for var_name, var_type in spec.statevars:
            self.statevar_name = var_name
            self.declare_field(var_name, var_type)

    self.write("public:\n")
    with self.indented():
        print("Generating constructors...")

        # default constructor
        self.begin_statement()
        self.write("inline ", spec.name, "() ")
        with self.block():
            emit_state_assignments(with_defaulted_free_vars)
        self.end_statement()

        # explicit constructor
        if abstract_state:
            self.begin_statement()
            self.write("explicit inline ", spec.name, "(")
            self.visit_args(abstract_state)
            self.write(") ")
            with self.block():
                emit_state_assignments(lambda exp: exp)
            self.end_statement()

        # disable copy constructor (TODO: support this in the future?)
        self.begin_statement()
        self.write(spec.name, "(const ", spec.name, "& other) = delete;")
        self.end_statement()

        # generate methods
        for method in spec.methods:
            print("Generating method {}...".format(method.name))
            self.visit(method)

    self.write("};\n")

    if spec.footer:
        self.write("\n", spec.footer)
        needs_newline = not spec.footer.endswith("\n")
        if needs_newline:
            self.write("\n")
def visit_Spec(self, spec, state_exps, sharing, abstract_state=()):
    """Emit a complete Java source file for the data structure `spec`.

    Output order: optional Guava imports, optional user header, optional
    docstring, then `public class <name> implements java.io.Serializable`
    containing the fields, a no-argument constructor that calls `clear()`,
    an optional constructor over `abstract_state`, `clear()` itself, every
    method and every auxiliary type.  The user footer follows the class and
    is always terminated by a newline.
    """
    self.state_exps = state_exps
    self.funcs = {fn.name: fn for fn in spec.extern_funcs}
    self.queries = {meth.name: meth for meth in spec.methods if isinstance(meth, Query)}
    self.vars = {node.id for node in all_exps(spec) if isinstance(node, EVar)}

    self.setup_types(spec, state_exps, sharing)

    if guava.value:
        for guava_import in ("import com.google.common.collect.TreeMultiset;\n",
                             "import com.google.common.collect.Iterators;\n"):
            self.write(guava_import)
    if spec.header:
        self.write(spec.header.strip() + "\n\n")
    if spec.docstring:
        self.write(spec.docstring + "\n")

    def emit_state_assignments(default_free_vars):
        # One assignment per state variable; optionally replace the initial
        # value's free variables with their type's default value.
        for field, field_type in spec.statevars:
            value = state_exps[field]
            if default_free_vars:
                value = subst(value, {fv.id : evaluation.construct_value(fv.type) for fv in free_vars(value)})
            self.visit(simplify_and_optimize(SAssign(EVar(field).with_type(field_type), value)))

    self.write("public class {} implements java.io.Serializable ".format(spec.name))
    with self.block():
        for type_name, ty in spec.types:
            self.types[ty] = type_name

        # Fields.
        for field, field_type in spec.statevars:
            self.write("{}protected {};\n".format(INDENT, self.visit(field_type, field)))

        # No-argument constructor delegates to clear().
        self.write(
            "{indent}public {name}() {{\n{indent2}clear();\n{indent}}}\n\n".format(
                indent=INDENT,
                indent2=INDENT + INDENT,
                name=spec.name))

        # Constructor over the abstract state.
        if abstract_state:
            self.begin_statement()
            self.write("public ", spec.name, "(")
            self.visit_args(abstract_state)
            self.write(") ")
            with self.block():
                emit_state_assignments(default_free_vars=False)
            self.end_statement()

        # clear() resets every field to its initial value.
        self.begin_statement()
        self.write("public void clear() ")
        with self.block():
            emit_state_assignments(default_free_vars=True)
        self.end_statement()

        for method in spec.methods:
            self.visit(method)

        # Auxiliary types are nested inside the class.
        for ty, type_name in self.types.items():
            self.define_type(spec.name, ty, type_name, sharing)

    self.write("\n")
    self.write(spec.footer)
    if not spec.footer.endswith("\n"):
        self.write("\n")