def watch(self, new_target):
    """Point the learner at `new_target` and recompute its root expressions.

    Every non-lambda subexpression of `new_target` and of each hint is offered
    to each pool in ALL_POOLS (with EStateVar wrappers stripped for the state
    pool).  A candidate whose free variables are all in `legal_free_vars` and
    which `is_legal_in_pool` accepts becomes a root.  A state-pool root whose
    free variables are all state variables is also registered as a runtime
    root wrapped in EStateVar.  For every root type and every binder type, a
    `construct_value` of that type is added as a runtime root.

    Afterwards `self.roots` is a list sorted by increasing expression size.
    `self._watches` groups the fragments of `new_target` by (pool, type), with
    each group sorted largest-first.
    """
    print("watching new target...")
    self.backlog_counter = 0
    self.target = new_target
    self.roots = OrderedSet()
    root_types = OrderedSet()

    target_exps = all_exps(new_target)
    hint_exps = [all_exps(hint) for hint in self.hints]
    for candidate in itertools.chain(target_exps, *hint_exps):
        if isinstance(candidate, ELambda):
            continue
        for pool in ALL_POOLS:
            root = strip_EStateVar(candidate) if pool == STATE_POOL else candidate
            root_fvs = free_vars(root)
            if not (all(v in self.legal_free_vars for v in root_fvs)
                    and self.is_legal_in_pool(root, pool)):
                _on_exp(root, "rejected root", pool_name(pool))
                continue
            _on_exp(root, "new root", pool_name(pool))
            root._root = True
            self.roots.add((root, pool))
            if pool == STATE_POOL and all(v in self.state_vars for v in root_fvs):
                # a pure-state root may also be read at runtime through EStateVar
                self.roots.add((EStateVar(root).with_type(root.type), RUNTIME_POOL))
            root_types.add(root.type)

    for binder in self.binders:
        root_types.add(binder.type)
    for ty in root_types:
        self.roots.add((construct_value(ty), RUNTIME_POOL))

    # smallest roots first (sorted() is stable, like list.sort)
    self.roots = sorted(self.roots, key=lambda entry: entry[0].size())

    def largest_first(ctxs):
        return sorted(ctxs, key=lambda ctx: -ctx.e.size())

    self._watches = group_by(
        enumerate_fragments2(new_target),
        k=lambda ctx: (ctx.pool, ctx.e.type),
        v=largest_first)
    print("done!")
def optimized_sum(xs, args):
    """Yield candidate expressions for the sum of the elements of the bag `xs`.

    Candidates are produced by structural rules (state bags, bag union and
    difference, excluded elements, singletons, flat-maps over a sum).  The last
    candidate yielded is always the direct `sum_of(xs)`.  `args` is passed
    through to the helpers that decide what may be moved into state.
    """
    elem_type = xs.type.elem_type
    is_union = isinstance(xs, EBinOp) and xs.op == "+"
    is_difference = isinstance(xs, EBinOp) and xs.op == "-"

    if isinstance(xs, EStateVar):
        # a sum over pure state can itself be kept as state
        yield EStateVar(sum_of(strip_EStateVar(xs))).with_type(elem_type)

    if is_union:
        # sum(A + B) == sum(A) + sum(B)
        for left in optimized_sum(xs.e1, args=args):
            for right in optimized_sum(xs.e2, args=args):
                yield EBinOp(left, "+", right).with_type(elem_type)

    if is_difference:
        # sum(A - B): subtract the sum of the elements of B that occur in A
        v = fresh_var(elem_type)
        for left in optimized_sum(xs.e1, args=args):
            occurs_in_lhs = ELambda(v, optimized_in(v, xs.e1))
            for removed in _simple_filter(xs.e2, occurs_in_lhs, args):
                for right in optimized_sum(removed, args=args):
                    yield EBinOp(left, "-", right).with_type(elem_type)

    excluded = excluded_element(xs, args)
    if excluded is not None:
        bag, x = excluded
        for partial in optimized_sum(bag, args):
            yield EBinOp(partial, "-", x).with_type(x.type)

    if isinstance(xs, ESingleton):
        yield xs.e

    if isinstance(xs, EFlatMap):
        fn = xs.transform_function
        body = fn.body
        if isinstance(body, EBinOp) and body.op == "+":
            for lhs in optimized_flatmap(xs.e, ELambda(fn.arg, body.e1), args):
                for rhs in optimized_flatmap(xs.e, ELambda(fn.arg, body.e2), args):
                    yield from optimized_sum(EBinOp(lhs, "+", rhs).with_type(lhs.type), args)

    # fallback: the direct sum
    yield sum_of(xs)
def test_mutate_preserves_statevar(self):
    """Mutating `EStateVar(x) + 1` by `x = x + 1` should repair to `(EStateVar(x) + 1) + 1`.

    NOTE(review): an identically named test also appears elsewhere in this
    file; if both live in the same class, only the later definition runs.
    """
    state = EVar("x").with_type(INT)
    before = EBinOp(EStateVar(state), "+", ONE)
    assert retypecheck(before)

    bump = SAssign(state, EBinOp(state, "+", ONE).with_type(INT))
    mutated = strip_EStateVar(inc.mutate(before, bump))
    ctx = RootCtx(state_vars=[state], args=[])
    after = repair_well_formedness(mutated, context=ctx)

    print(pprint(before))
    print(pprint(after))
    expected = EBinOp(EBinOp(EStateVar(state), "+", ONE), "+", ONE)
    assert after == expected
def test_mutate_preserves_statevar(self):
    """Check that `EStateVar(x) + 1`, after `x = x + 1`, repairs to `(EStateVar(x) + 1) + 1`.

    NOTE(review): a test with this exact name is defined elsewhere in this
    file as well; within one class body the later definition replaces the
    earlier one.
    """
    x = EVar("x").with_type(INT)
    original = EBinOp(EStateVar(x), "+", ONE)
    assert retypecheck(original)
    increment = SAssign(x, EBinOp(x, "+", ONE).with_type(INT))
    repaired = repair_well_formedness(
        strip_EStateVar(inc.mutate(original, increment)),
        context=RootCtx(state_vars=[x], args=[]))
    for shown in (original, repaired):
        print(pprint(shown))
    assert repaired == EBinOp(EBinOp(EStateVar(x), "+", ONE), "+", ONE)
def repair_EStateVar(e: syntax.Exp, available_state: [syntax.Exp]) -> syntax.Exp:
    """Rebuild the EStateVar wrappers in `e` from scratch.

    All existing EStateVar nodes are stripped first.  The tree is then
    rewritten so that a subexpression alpha-equivalent to some entry of
    `available_state` is wrapped in `target_syntax.EStateVar`.  A node is
    tested before its children are visited, and matched nodes are not
    descended into.
    """
    def is_available(candidate):
        return any(alpha_equivalent(candidate, s) for s in available_state)

    class _WrapAvailableState(BottomUpRewriter):
        def visit_EStateVar(self, node):
            # already a state expression; leave it untouched
            return node

        def visit_Exp(self, node):
            if not is_available(node):
                return super().visit_ADT(node)
            return target_syntax.EStateVar(node).with_type(node.type)

    return _WrapAvailableState().visit(strip_EStateVar(e))
def make_subgoal(e, a=[], docstring=None):
    """Replace `e` with a call to a fresh internal query and record that query.

    EStateVar wrappers are stripped from `e` first.  When
    `skip_stateless_synthesis` is enabled and `e` mentions no variable of
    `ctx`, the stripped expression is returned as-is.  Otherwise a new
    `syntax.Query` (with `assumptions + a` as its assumptions) is appended to
    `subgoals`, and an `ECall` to it is returned.

    NOTE(review): `ctx`, `assumptions` and `subgoals` are not parameters; they
    come from an enclosing scope not visible here.
    """
    body = strip_EStateVar(e)
    if skip_stateless_synthesis.value and all(v not in ctx for v in free_vars(body)):
        return body

    name = fresh_name("query")
    query = syntax.Query(name, syntax.Visibility.Internal, [], assumptions + a, body, docstring)
    # free variables that are not abstract state become the query's parameters
    params = [v for v in free_vars(query) if v not in ctx]
    query.args = [(p.id, p.type) for p in params]
    subgoals.append(query)
    return syntax.ECall(name, tuple(params)).with_type(body.type)
def can_elim_vars(spec : Exp, assumptions : Exp, vs : [EVar]):
    """Determine whether the value of `spec` is independent of the variables `vs`.

    A variable can appear syntactically in an expression like `spec` without
    affecting its value.  This function renames every variable in `vs` to a
    fresh variable and asks the solver whether the following holds: whenever
    both `assumptions` and the renamed copy of `assumptions` are satisfied,
    `spec` equals its renamed copy.

    Returns:
        the result of `valid(...)` on that implication.  It is true when no
        choice of values for `vs` can change `spec`, i.e. the variables can be
        eliminated.  It is false when some execution of `spec` does depend on
        at least one variable in `vs`.
    """
    # compare the bare expressions, without EStateVar wrappers
    spec = strip_EStateVar(spec)
    sub = { v.id : fresh_var(v.type) for v in vs }
    return valid(EImplies(
        EAll([assumptions, subst(assumptions, sub)]),
        EEq(spec, subst(spec, sub))))
def simplify_sum(e):
    """Cancel terms that are both added and subtracted in the integer sum `e`.

    `break_sum(e)` is expected to yield `(flag, term)` pairs.  A truthy flag
    marks a term that is added; a falsy flag marks a term that is subtracted.
    (This is inferred from how the terms are reassembled below; confirm it
    against `break_sum`.)

    Each added term cancels at most one subtracted term that is
    alpha-equivalent to it once EStateVar wrappers are ignored.  The surviving
    added terms, followed by the negations of the surviving subtracted terms,
    are summed left to right.

    Returns:
        the simplified INT expression, or ZERO when every term cancels.
    """
    added_parts, subtracted_parts = partition(list(break_sum(e)), lambda part: part[0])
    added = [part[1] for part in added_parts]
    subtracted = [part[1] for part in subtracted_parts]

    # Strip each subtracted term once, instead of once per added term.
    pending = [(term, strip_EStateVar(term)) for term in subtracted]

    kept = []
    for term in added:
        key = strip_EStateVar(term)
        # Locate the match by index and compare against None explicitly,
        # rather than relying on the truthiness of an expression object.
        match = next(
            (i for i, (_, neg) in enumerate(pending) if alpha_equivalent(key, neg)),
            None)
        if match is None:
            kept.append(term)
        else:
            del pending[match]  # term - term cancels out

    kept.extend(EUnaryOp("-", term).with_type(INT) for term, _ in pending)
    if not kept:
        return ZERO
    total = kept[0]
    for term in kept[1:]:
        total = EBinOp(total, "+", term).with_type(INT)
    return total
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Determine whether the value of `spec` is independent of the variables `vs`.

    A variable can appear syntactically in an expression like `spec` without
    affecting its value.  This function renames every variable in `vs` to a
    fresh variable and asks the solver whether the following holds: whenever
    both `assumptions` and the renamed copy of `assumptions` are satisfied,
    `spec` equals its renamed copy.

    Returns:
        the result of `valid(...)` on that implication.  It is true when no
        choice of values for `vs` can change `spec`, i.e. the variables can be
        eliminated.  It is false when some execution of `spec` does depend on
        at least one variable in `vs`.
    """
    # compare the bare expressions, without EStateVar wrappers
    spec = strip_EStateVar(spec)
    sub = {v.id: fresh_var(v.type) for v in vs}
    return valid(
        EImplies(EAll([assumptions, subst(assumptions, sub)]),
                 EEq(spec, subst(spec, sub))))
def extract_listcomp(e):
    """Recover the components of a desugared list comprehension.

    The comprehension

        { V(p0, ..., pn) | p0 <- P_0, ..., pn <- P_n, C(p0, ..., pn) }

    is expected in its desugared form: one EFlatMap per generator variable,
    optionally an EMap that produces V, and an innermost EFilter that carries C.

    :param e: the desugared list comprehension expression
    :return: a dict (an "lc") with the keys
        "P": maps each generator variable p_i to the collection it is drawn
             from (with EStateVar wrappers stripped),
        "C": the filter condition,
        "V": the produced value (only present when an EMap was found);
      or None if `e` does not have this shape.
      "C" and "V" are returned as-is, so they mention the generator variables
      as free variables.
    """
    if isinstance(e, EFlatMap):
        fn = e.transform_function
        inner = extract_listcomp(fn.body)
        if inner is None:
            return None
        # each flat-map level contributes one generator variable
        inner["P"][fn.arg] = strip_EStateVar(e.e)
        return inner

    if isinstance(e, EMap):
        body = e.transform_function.body
        inner = extract_listcomp(e.e)
        if inner is None:
            return None
        inner["V"] = body
        return inner

    if isinstance(e, EFilter):
        pred = e.predicate
        return {"C": pred.body, "P": {pred.arg: strip_EStateVar(e.e)}}

    return None
def _simple_filter(xs, p, args):
    """Return an expression intended to be equivalent to `EFilter(xs, p)`, simplified where possible.

    `args` are treated as runtime arguments.  A predicate that mentions any of
    them is never folded into an EStateVar.  When no rewrite rule applies, the
    plain `EFilter(xs, p)` is returned.
    """
    body = p.body
    if body == T:
        return xs
    if body == F:
        return EEmptyList().with_type(xs.type)
    if isinstance(xs, EEmptyList):
        return xs

    bag_type = xs.type

    def p_is_arg_free():
        return all(v not in args for v in free_vars(p))

    if isinstance(xs, EStateVar) and p_is_arg_free():
        # filtering pure state by an arg-free predicate stays pure state
        inner = EFilter(xs.e, strip_EStateVar(p)).with_type(bag_type)
        return EStateVar(inner).with_type(bag_type)

    if isinstance(xs, EMapGet) and isinstance(xs.map, EStateVar) and p_is_arg_free():
        # push the filter into every value of the stored map
        filtered_map = map_values(xs.map.e, lambda ys: _simple_filter(ys, p, args))
        stored = EStateVar(filtered_map).with_type(filtered_map.type)
        return EMapGet(stored, xs.key).with_type(bag_type)

    if isinstance(xs, EBinOp) and xs.op == "+":
        left = _simple_filter(xs.e1, p, args)
        right = _simple_filter(xs.e2, p, args)
        return EBinOp(left, "+", right).with_type(bag_type)

    if not (isinstance(body, EBinOp) and body.op == "=="):
        return EFilter(xs, p).with_type(bag_type)

    lhs, rhs = body.e1, body.e2
    rhs_fvs = free_vars(rhs)
    if lhs == p.arg and p.arg not in rhs_fvs:
        # filter(xs, \x -> x == c)  ==>  [c] if c in xs else []
        return optimized_cond(
            optimized_in(rhs, xs),
            ESingleton(rhs).with_type(bag_type),
            EEmptyList().with_type(bag_type)).with_type(bag_type)

    lhs_fvs = free_vars(lhs)
    indexable = (p.arg in lhs_fvs
                 and all(a not in lhs_fvs for a in args)
                 and p.arg not in rhs_fvs
                 and isinstance(xs, EStateVar))
    if indexable:
        # state map from lhs(x) to the elements with that key, looked up at rhs
        k = fresh_var(lhs.type)
        lookup = EMapGet(
            EStateVar(
                EMakeMap2(
                    EMap(xs.e, ELambda(p.arg, lhs)),
                    ELambda(k, EFilter(xs.e, ELambda(p.arg, EEq(lhs, k)))))),
            rhs)
        res = retypecheck(lookup)
        assert res
        return lookup

    return EFilter(xs, p).with_type(bag_type)
def fold_into_map(e, context):
    """Yield rewrites of `e` that move the code around a state-map lookup into the map.

    For every runtime subexpression of `e` of the form `EStateVar(m)[key]`, the
    rest of `e` is abstracted into a function `f` of the looked-up value.  The
    rewrite `EStateVar(map_values(m, f))[key]` is yielded when two conditions
    hold: `f` mentions only state variables, and `key` only mentions variables
    available in `context`.

    Parameters:
        e - the runtime expression to rewrite
        context - the context in which `e` is evaluated
    """
    fvs = free_vars(e)
    context_vars = [v for v, p in context.vars()]
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in all_subexpressions_with_context_information(e, context, RUNTIME_POOL):
        if not (isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar)):
            continue
        # e is of the form `... EStateVar(state_map)[key] ...`
        state_map = subexp.map.e
        key = subexp.key
        # The rewrite evaluates `key` at the top level instead of at its
        # original position.  If `key` mentions a variable bound inside `e`
        # (e.g. a lambda argument), the result would contain a dangling
        # free variable, so skip such lookups.
        if not all(v in context_vars for v in free_vars(key)):
            continue
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(arg, replace(
            e, context, RUNTIME_POOL,
            subexp, subcontext, subpool,
            arg))
        # func ends up inside an EStateVar, so it may only mention state
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(state_map, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
def fold_into_map(e, context):
    """Yield rewrites of `e` that move the code around a state-map lookup into the map.

    For every runtime subexpression of `e` of the form `EStateVar(m)[key]`, the
    rest of `e` is abstracted into a function `f` of the looked-up value.  The
    rewrite `EStateVar(map_values(m, f))[key]` is yielded when two conditions
    hold: `f` mentions only state variables, and `key` only mentions variables
    available in `context`.

    Parameters:
        e - the runtime expression to rewrite
        context - the context in which `e` is evaluated
    """
    fvs = free_vars(e)
    context_vars = [v for v, p in context.vars()]
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in all_subexpressions_with_context_information(e, context, RUNTIME_POOL):
        if not (isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar)):
            continue
        # e is of the form `... EStateVar(state_map)[key] ...`
        state_map = subexp.map.e
        key = subexp.key
        # The rewrite evaluates `key` at the top level instead of at its
        # original position.  If `key` mentions a variable bound inside `e`
        # (e.g. a lambda argument), the result would contain a dangling
        # free variable, so skip such lookups.
        if not all(v in context_vars for v in free_vars(key)):
            continue
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(arg, replace(
            e, context, RUNTIME_POOL,
            subexp, subcontext, subpool,
            arg))
        # func ends up inside an EStateVar, so it may only mention state
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(state_map, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
def repair_well_formedness(e: Exp, context: Context, extra_available_state: [Exp] = []) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e - the expression to repair
        context - the intended context for e
        extra_available_state - extra state expressions that e can use

    All EStateVar wrappers are removed from `e`.  The state variables of
    `context`, plus `extra_available_state` (deduplicated), are then tried
    largest-first.  Each subexpression alpha-equivalent to one of them is
    wrapped in EStateVar.  Binders are freshened, and the result is asserted
    to be well-formed in the runtime pool.

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression
    that behaves like `e`.
    """
    def descending_size(node):
        return -node.size()

    def not_a_state_marker(node):
        return not isinstance(node, EStateVar)

    with task("repairing"):
        out = strip_EStateVar(e)
        context_state = (v for v, pool in context.vars() if pool == STATE_POOL)
        candidates = unique(itertools.chain(context_state, extra_available_state))
        # presumably largest-first so a big state expression is wrapped whole
        # before any of its parts could be
        available_state = sorted(candidates, key=descending_size)

        with task("making replacements", size=out.size()):
            for s in available_state:
                wrapped = EStateVar(s).with_type(s.type)
                out = replace(out, s, wrapped,
                              match=alpha_equivalent,
                              filter=not_a_state_marker)

        with task("freshening binders"):
            out = freshen_binders(out, context)

        with task("checking correctness"):
            verdict = exp_wf(out, context, RUNTIME_POOL)
            assert verdict, str(verdict)

        return out
def map_accelerate(e, state_vars, binders, args, cache, size):
    """Yield (expression, pool) candidates that replace part of `e` with a map lookup.

    Each fragment `arg` of `e` that mentions no state variable is a candidate
    key.  For such a fragment, `e` is abstracted over `arg` using each binder
    of the same type not already bound at that position.  For every non-empty
    cached state collection of the key type (of the given `size`), two
    candidates are yielded: a map from its elements to the abstracted `e` (in
    STATE_POOL), and a runtime lookup of `arg` in that stored map (in
    RUNTIME_POOL).  Maps that still mention a binder are skipped.
    """
    def only_state_or_binders(exp):
        return all(v in state_vars or v in binders for v in free_vars(exp))

    result_type = e.type
    for (_, arg, f, bound) in enumerate_fragments(strip_EStateVar(e)):
        if any(v in state_vars for v in free_vars(arg)):
            continue
        key_type = arg.type
        usable_binders = [b for b in binders if b.type == key_type and b not in bound]
        for binder in usable_binders:
            value = f(binder)
            if not only_state_or_binders(value):
                continue
            for bag in cache.find_collections(pool=STATE_POOL, size=size, of=key_type):
                if isinstance(bag, EEmptyList):
                    continue
                lookup_map = EMakeMap2(bag, ELambda(binder, value)).with_type(TMap(key_type, result_type))
                map_fvs = free_vars(lookup_map)
                assert not any(v in args for v in map_fvs)
                if any(v in binders for v in map_fvs):
                    continue
                yield (lookup_map, STATE_POOL)
                stored = EStateVar(lookup_map).with_type(lookup_map.type)
                yield (EMapGet(stored, arg).with_type(result_type), RUNTIME_POOL)
def repair_well_formedness(e : Exp, context : Context, extra_available_state : [Exp] = []) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e - the expression to repair
        context - the intended context for e
        extra_available_state - extra state expressions that e can use

    The existing EStateVar wrappers are discarded.  Every subexpression
    alpha-equivalent to an available state expression (the context's
    state-pool variables plus `extra_available_state`, tried from largest to
    smallest) is re-wrapped.  Binders are then freshened, and well-formedness
    is asserted.

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression
    that behaves like `e`.
    """
    with task("repairing"):
        result = strip_EStateVar(e)

        state_pool_vars = (var for var, pool in context.vars() if pool == STATE_POOL)
        everything_available = itertools.chain(state_pool_vars, extra_available_state)
        by_size_desc = sorted(unique(everything_available), key=lambda s: -s.size())

        with task("making replacements", size=result.size()):
            for s in by_size_desc:
                result = replace(
                    result,
                    s,
                    EStateVar(s).with_type(s.type),
                    match=alpha_equivalent,
                    # never match an existing EStateVar node itself
                    filter=lambda node: not isinstance(node, EStateVar))

        with task("freshening binders"):
            result = freshen_binders(result, context)

        with task("checking correctness"):
            wf = exp_wf(result, context, RUNTIME_POOL)
            assert wf, str(wf)

        return result
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = [],
        invariants: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    Strategy, chosen by the type of `lval`:
      - bag/set: remove the deleted elements, then add the new ones;
      - numeric (when `update_numbers_with_deltas` is enabled): add the delta;
      - tuple/record: update each component recursively;
      - map: delete the removed keys, then recursively update the value of
        every new or modified key;
      - anything else: recompute the whole value.
    If `valid` proves `old_value == new_value` under the assumptions and
    invariants, the result is a no-op with no subgoals.

    NOTE(review): the list defaults of `assumptions`/`invariants` are shared
    between calls; they are only read here (`assumptions + a` builds a new
    list), so this is safe as written.
    """
    # Nothing to do when the value provably does not change.
    if valid(
            syntax.EImplies(
                syntax.EAll(itertools.chain(assumptions, invariants)),
                syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    # Only new_value is stripped of EStateVar wrappers; old_value is used as given.
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        # Turn `e` into a call to a fresh internal query, recorded in
        # `subgoals`.  When stateless synthesis is skipped and `e` mentions no
        # variable of `ctx`, `e` is returned inline instead.
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        # free variables that are not abstract state become query parameters
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Recursive call whose subgoals are collected into this call's list.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        # NOTE(review): `t.t` is assumed to be the element type of the
        # bag/set; confirm against the TBag/TSet definitions.
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        # removals are emitted before additions
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        # lval := lval + (new_value - old_value)
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        # update each tuple component independently
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions, invariants=invariants)
            for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        # update each record field independently
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions, invariants=invariants)
            for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)
        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) enter/mod set: keys that are new, or whose value changed
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(
                k,
                syntax.EAny([
                    syntax.ENot(syntax.EIn(k, old_keys)),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k), value_at(new_value, k)))
                ]))).with_type(key_bag)
        # code that updates one value `v` (the old value at key `k`) in place
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, new_or_modified),
                                   syntax.EEq(v, value_at(old_value, k))
                               ],
                               invariants=invariants)
        s2 = syntax.SForEach(
            k,
            make_subgoal(new_or_modified,
                         docstring="new or modified keys from {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))
        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
def build(self, cache, size):
    """Yield candidate expressions built by rewriting expressions already in `cache`.

    Most rules take cached expressions of size `size - 1`.  Others take two
    parts whose sizes come from `pick_to_sum(2, size - 1)` (presumably pairs
    summing to `size - 1`).  Many rules only consider expressions for which
    `is_root` holds.  Every candidate is passed to `self.check` together with
    the pool it belongs to, and the result of `self.check` is yielded.
    Candidates from `self.wrapped.build(cache, size)` are yielded last.
    """
    # Cancel matching +/- terms in integer sums.
    for e in cache.find(pool=RUNTIME_POOL, size=size - 1, type=INT):
        if not is_root(e):
            continue
        e2 = simplify_sum(e)
        if e != e2:
            yield self.check(e2, RUNTIME_POOL)

    # for e in cache.find(pool=RUNTIME_POOL, size=size-1):
    #     if isinstance(e, EMapGet) and isinstance(e.map, EMakeMap2):
    #         x = e.map.value.apply_to(e.key)
    #         x._tag = True
    #         yield self.check(x, RUNTIME_POOL)

    # [x] - ys
    for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(e):
            continue
        if isinstance(e, EBinOp) and e.op == "-" and isinstance(e.e1, ESingleton):
            # [x] - ys  ==>  [] if x in ys else [x]
            x = e.e1.e
            y = e.e2
            x = ECond(
                EBinOp(x, BOp.In, y).with_type(BOOL),
                EEmptyList().with_type(e.type),
                e.e1).with_type(e.type)
            yield self.check(x, RUNTIME_POOL)
        elif isinstance(e, EUnaryOp) and e.op == UOp.Distinct:
            # distinct(xs)  ==>  keys of a state map built from xs
            e = strip_EStateVar(e)
            m = EMakeMap2(e.e, mk_lambda(e.type.t, lambda x: T)).with_type(
                TMap(e.type.t, BOOL))
            yield self.check(m, STATE_POOL)
            m = EStateVar(m).with_type(m.type)
            yield self.check(m, RUNTIME_POOL)
            x = EMapKeys(m).with_type(e.type)
            # x._tag = True
            yield self.check(x, RUNTIME_POOL)

    # # x in ys ----> (count x in ys) > 0
    # for e in cache.find(pool=RUNTIME_POOL, type=BOOL, size=size-1):
    #     if isinstance(e, EBinOp) and e.op == BOp.In:
    #         for b in self.binders:
    #             if b.type != e.e1.type:
    #                 continue
    #             x = EGt(
    #                 EUnaryOp(UOp.Length, EFilter(e.e2, ELambda(b, EEq(e.e1, b))).with_type(e.e2.type)).with_type(INT),
    #                 ZERO)
    #             x._tag = True
    #             yield self.check(x, RUNTIME_POOL)

    # argmin/argmax over a bag union: combine the best of each side.
    for e in cache.find(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(e):
            continue
        if (isinstance(e, EArgMin) or isinstance(e, EArgMax)) and isinstance(e.e, EBinOp) and e.e.op == "+":
            l = e.e.e1
            r = e.e.e2
            op = e.e.op
            # same argmin/argmax (same key function) applied to another bag
            f = lambda x: type(e)(x, e.f).with_type(e.type)
            # the best of a state operand is itself computed as state
            ll = EStateVar(f(l.e)).with_type(e.type) if isinstance(l, EStateVar) else f(l)
            rr = EStateVar(f(r.e)).with_type(e.type) if isinstance(r, EStateVar) else f(r)
            # if l nonempty: (if r nonempty: best of {ll, rr} else ll) else rr
            x = ECond(
                EUnaryOp(UOp.Exists, l).with_type(BOOL),
                ECond(
                    EUnaryOp(UOp.Exists, r).with_type(BOOL),
                    f(EBinOp(
                        ESingleton(ll).with_type(e.e.type),
                        op,
                        ESingleton(rr).with_type(e.e.type)).with_type(e.e.type)),
                    ll).with_type(e.type),
                rr).with_type(e.type)
            # from cozy.solver import valid
            # assert valid(EEq(e, x), model_callback=print)
            x._tag = True
            yield self.check(x, RUNTIME_POOL)

    # is-last(x, l)
    # NOTE(review): the candidate actually built below is
    # `count(e1 in e2) > 1` via a precomputed histogram map.
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for e1 in cache.find(pool=RUNTIME_POOL, size=sz1):
            if not is_root(e1):
                continue
            for e2 in cache.find_collections(pool=STATE_POOL, size=sz2, of=e1.type):
                if not is_root(e2):
                    continue
                # NOTE(review): `b` is not used in the loop body, so the same
                # candidate is re-checked once per binder whose type matches.
                for b in self.binders:
                    if b.type != e1.type:
                        continue
                    # histogram: element -> number of occurrences in e2
                    m = EMakeMap2(
                        e2,
                        mk_lambda(
                            e2.type.t,
                            lambda x: EUnaryOp(
                                UOp.Length,
                                EFilter(
                                    e2,
                                    mk_lambda(e2.type.t, lambda y: EEq(x, y))).with_type(e2.type)).with_type(INT))).with_type(
                        TMap(e2.type.t, INT))
                    # filt = EFilter(e2, ELambda(b, EEq(e1, b))).with_type(e2.type)
                    # x = EEq(
                    #     EUnaryOp(UOp.Length, filt).with_type(INT),
                    #     ONE)
                    x = EGt(
                        EMapGet(EStateVar(m).with_type(m.type), e1).with_type(INT),
                        ONE)
                    # x._tag = True
                    yield self.check(x, RUNTIME_POOL)

    # histogram
    # for e in cache.find_collections(pool=STATE_POOL, size=size-1):
    #     m = EMakeMap2(e,
    #         mk_lambda(e.type.t, lambda x:
    #             EUnaryOp(UOp.Length, EFilter(e,
    #                 mk_lambda(e.type.t, lambda y: EEq(x, y))).with_type(e.type)).with_type(INT))).with_type(TMap(e.type.t, INT))
    #     m._tag = True
    #     yield self.check(m, STATE_POOL)

    # Fixup EFilter(\x -> ECond...)
    # filter(xs, P[c ? a : b])  ==>  filter(xs, c and P[a]) + filter(xs, not c and P[b])
    for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(e):
            continue
        if isinstance(e, EFilter):
            for (_, x, r, _) in enumerate_fragments(e.p.body):
                if isinstance(x, ECond):
                    lhs = EFilter(
                        e.e,
                        ELambda(e.p.arg, EAll([x.cond, r(x.then_branch)]))).with_type(e.type)
                    rhs = EFilter(
                        e.e,
                        ELambda(e.p.arg, EAll([ENot(x.cond), r(x.else_branch)]))).with_type(e.type)
                    union = EBinOp(lhs, "+", rhs).with_type(e.type)
                    # yield self.check(lhs.p.body, RUNTIME_POOL)
                    # yield self.check(rhs.p.body, RUNTIME_POOL)
                    yield self.check(lhs, RUNTIME_POOL)
                    yield self.check(rhs, RUNTIME_POOL)
                    yield self.check(union, RUNTIME_POOL)

    # Try instantiating bound expressions
    for pool in (STATE_POOL, RUNTIME_POOL):
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for e1 in cache.find(pool=pool, size=sz1):
                if not is_root(e1):
                    continue
                for v in free_vars(e1):
                    if pool == RUNTIME_POOL:
                        # Wrap every state var except `v` in EStateVar.  Note
                        # that `e1` is reassigned here; it is stripped again
                        # on the next iteration, so each `v` starts from the
                        # same bare expression.
                        e1 = subst(
                            strip_EStateVar(e1),
                            {sv.id: EStateVar(sv).with_type(sv.type)
                             for sv in self.state_vars if sv != v})
                    for e2 in cache.find(pool=pool, type=v.type, size=sz2):
                        yield self.check(subst(e1, {v.id: e2}), pool)

    # Map- and filter-acceleration of runtime expressions.
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for e in cache.find(pool=RUNTIME_POOL, size=sz1):
            if not is_root(e):
                continue
            for x, pool in map_accelerate(e, self.state_vars, self.binders, self.args, cache, sz2):
                yield self.check(x, pool)
            if isinstance(e, EFilter) and not any(v in self.binders for v in free_vars(e)):
                for x, pool in accelerate_filter(e.e, e.p, self.state_vars, self.binders, self.args, cache, sz2):
                    yield self.check(x, pool)

    for bag in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(bag):
            continue
        # Filter a bag by whether an argument is (or is not) in a state
        # collection of the same type.
        for a in self.args:
            for v in self.state_vars:
                if is_collection(v.type) and v.type == a.type:
                    v = EStateVar(v).with_type(v.type)
                    cond = EBinOp(a, BOp.In, v).with_type(BOOL)
                    yield self.check(
                        EFilter(bag, mk_lambda(bag.type.t, lambda _: cond)).with_type(bag.type),
                        RUNTIME_POOL)
                    yield self.check(
                        EFilter(bag, mk_lambda(bag.type.t, lambda _: ENot(cond))).with_type(bag.type),
                        RUNTIME_POOL)

        if isinstance(bag, EFilter):
            # the remaining rules need the filtered collection to be pure state
            if any(v not in self.state_vars for v in free_vars(bag.e)):
                continue

            # separate filter conds
            if isinstance(bag.p.body, EBinOp) and bag.p.body.op == BOp.And:
                p1 = ELambda(bag.p.arg, bag.p.body.e1)
                p2 = ELambda(bag.p.arg, bag.p.body.e2)
                f1 = EFilter(bag.e, p1).with_type(bag.type)
                f2 = EFilter(bag.e, p2).with_type(bag.type)
                f3 = EFilter(f1, p2).with_type(bag.type)
                f4 = EFilter(f2, p1).with_type(bag.type)
                yield self.check(f1, RUNTIME_POOL)
                yield self.check(f2, RUNTIME_POOL)
                yield self.check(f3, RUNTIME_POOL)
                yield self.check(f4, RUNTIME_POOL)

            # construct map lookups
            binder = bag.p.arg
            inf = infer_map_lookup(bag.p.body, binder, set(self.state_vars))
            if inf:
                key_proj, key_lookup, remaining_filter = inf
                bag_binder = find_one(
                    self.binders,
                    lambda b: b.type == key_proj.type and b != binder)
                if bag_binder:
                    # state map: key_proj(x) -> elements of bag.e with that key
                    m = strip_EStateVar(
                        EMakeMap2(
                            EMap(bag.e, ELambda(binder, key_proj)).with_type(
                                type(bag.type)(key_proj.type)),
                            ELambda(
                                bag_binder,
                                EFilter(
                                    bag.e,
                                    ELambda(binder, EEq(key_proj, bag_binder))).with_type(
                                        bag.type))).with_type(
                                            TMap(key_proj.type, bag.type)))
                    assert not any(v in self.args for v in free_vars(m))
                    yield self.check(m, STATE_POOL)
                    m = EStateVar(m).with_type(m.type)
                    mg = EMapGet(m, key_lookup).with_type(bag.type)
                    yield self.check(mg, RUNTIME_POOL)
                    yield self.check(
                        EFilter(mg, ELambda(binder, remaining_filter)).with_type(mg.type),
                        RUNTIME_POOL)

    # for e in cache.find(size=size-1):
    #     # F(xs +/- ys) ---> F(xs), F(ys)
    #     for z in break_plus_minus(e):
    #         if z != e:
    #             # print("broke {} --> {}".format(pprint(e), pprint(z)))
    #             yield z
    #     # try reordering operations
    #     for (_, e1, f) in enumerate_fragments(e):
    #         if e1.type == e.type and e1 != e:
    #             for (_, e2, g) in enumerate_fragments(e1):
    #                 if e2.type == e.type and e2 != e1:
    #                     # e == f(g(e2))
    #                     yield g(f(e2))

    yield from self.wrapped.build(cache, size)
def sketch_update(
        lval        : syntax.Exp,
        old_value   : syntax.Exp,
        new_value   : syntax.Exp,
        ctx         : [syntax.EVar],
        assumptions : [syntax.Exp] = [],
        invariants  : [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    Strategy, chosen by the type of `lval`:
      - bag/set: remove the deleted elements, then add the new ones;
      - numeric (when `update_numbers_with_deltas` is enabled): add the delta;
      - tuple/record: update each component recursively;
      - map: delete the removed keys, then recursively update the value of
        every new or modified key;
      - anything else: recompute the whole value.
    If `valid` proves `old_value == new_value` under the assumptions and
    invariants, the result is a no-op with no subgoals.

    NOTE(review): the list defaults of `assumptions`/`invariants` are shared
    between calls; they are only read here (`assumptions + a` builds a new
    list), so this is safe as written.
    """
    # Nothing to do when the value provably does not change.
    if valid(syntax.EImplies(
            syntax.EAll(itertools.chain(assumptions, invariants)),
            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    # Only new_value is stripped of EStateVar wrappers; old_value is used as given.
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        # Turn `e` into a call to a fresh internal query, recorded in
        # `subgoals`.  When stateless synthesis is skipped and `e` mentions no
        # variable of `ctx`, `e` is returned inline instead.
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        # free variables that are not abstract state become query parameters
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Recursive call whose subgoals are collected into this call's list.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(syntax.EBinOp(old_value, "-", new_value).with_type(t), docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.elem_type)
        # removals are emitted before additions
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        # lval := lval + (new_value - old_value)
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        # update each tuple component independently
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.ts))])
    elif isinstance(t, syntax.TRecord):
        # update each record field independently
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.fields))])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        s1 = syntax.SForEach(k, make_subgoal(deleted_keys, docstring="keys removed from {}".format(pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) enter/mod set: keys that are new, or whose value changed
        new_or_modified = target_syntax.EFilter(new_keys,
            syntax.ELambda(k, syntax.EAny([syntax.ENot(syntax.EIn(k, old_keys)), syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k)))]))).with_type(key_bag)
        # code that updates one value `v` (the old value at key `k`) in place
        update_value = recurse(
            v,
            value_at(old_value, k),
            value_at(new_value, k),
            ctx = ctx,
            assumptions = assumptions + [syntax.EIn(k, new_or_modified), syntax.EEq(v, value_at(old_value, k))],
            invariants = invariants)
        s2 = syntax.SForEach(k, make_subgoal(new_or_modified, docstring="new or modified keys from {}".format(pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)

    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(lval, make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
def _add_subquery(self, sub_q: Query, used_by: Stm) -> Stm:
    """Add a query that helps maintain some other state.

    Parameters:
        sub_q - the specification of the helper query
        used_by - the statement that calls `sub_q`

    If a query already exists that is equivalent to `sub_q`, this method
    returns `used_by` rewritten to use the existing query and does not add
    the query to the implementation. Otherwise it returns `used_by` unchanged.

    Before the equivalence check, a shallow copy of `sub_q` is prepared:
      1. implicit handle assumptions are appended when `valid` cannot show
         they already follow from the existing assumptions;
      2. the return expression's EStateVar boundaries are rebuilt with
         `repair_well_formedness`, using the concretization functions'
         expressions as extra available state;
      3. the assumptions and return expression are simplified.  The method
         asserts that the assumptions did not grow.
    Deduplication only happens when `dedup_queries.value` is set.
    """
    with task("adding query", query=sub_q.name):
        # work on a copy so the caller's query object is left as it was
        sub_q = shallow_copy(sub_q)
        with task("checking whether we need more handle assumptions"):
            new_a = implicit_handle_assumptions(
                reachable_handles_at_method(self.spec, sub_q))
            if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                event("we do!")
                sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

        with task("repairing state var boundaries"):
            extra_available_state = [e for v, e in self._concretization_functions]
            sub_q.ret = repair_well_formedness(
                strip_EStateVar(sub_q.ret),
                self.context_for_method(sub_q),
                extra_available_state)

        with task("simplifying"):
            orig_a = sub_q.assumptions
            orig_a_size = sum(a.size() for a in sub_q.assumptions)
            orig_ret_size = sub_q.ret.size()
            sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
            sub_q.ret = simplify(sub_q.ret)
            a_size = sum(a.size() for a in sub_q.assumptions)
            ret_size = sub_q.ret.size()
            event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
            event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

            # Simplification must never make the assumptions larger.
            # NOTE(review): `assert False` is skipped under `python -O`.
            if a_size > orig_a_size:
                print("NO, BAD SIMPLIFICATION")
                print("original")
                for a in orig_a:
                    print(" - {}".format(pprint(a)))
                print("simplified")
                for a in sub_q.assumptions:
                    print(" - {}".format(pprint(a)))
                assert False

        state_vars = self.abstract_state
        funcs = self.extern_funcs
        # first existing query equivalent to sub_q (never found when deduplication is off)
        qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs, assumptions=EAll(self.abstract_invariants)))
        if qq is not None:
            event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
            # For each of qq's parameters (by name), the position of the
            # same-named parameter in sub_q.  `.index` would raise ValueError
            # if a name were missing; presumably equivalent queries share
            # parameter names.
            arg_reorder = [[x[0] for x in sub_q.args].index(a) for (a, t) in qq.args]
            class Repl(BottomUpRewriter):
                # Redirect calls of sub_q to qq, permuting the arguments.
                def visit_ECall(self, e):
                    args = tuple(self.visit(a) for a in e.args)
                    if e.func == sub_q.name:
                        args = tuple(args[idx] for idx in arg_reorder)
                        return ECall(qq.name, args).with_type(e.type)
                    else:
                        return ECall(e.func, args).with_type(e.type)
            used_by = Repl().visit(used_by)
        else:
            self.add_query(sub_q)
        return used_by
def _simple_filter(xs : Exp, p : ELambda, args : {EVar}):
    """Assumes the body of p is already in negation normal form

    Generator: yields candidate expressions intended to be equivalent to
    `EFilter(xs, p)`.  The trivial cases (p is always true, p is always
    false, xs is empty) yield exactly one result.  Otherwise every applicable
    rewrite rule contributes candidates.  When no rule yields anything, the
    plain `EFilter(xs, p)` is yielded.

    `args` are treated as runtime arguments: a predicate mentioning any of
    them is never folded into an EStateVar.
    """
    if p.body == ETRUE:
        yield xs
        return
    if p.body == EFALSE:
        yield EEmptyList().with_type(xs.type)
        return
    if isinstance(xs, EEmptyList):
        yield xs
        return
    # tracks whether any rule below produced a candidate
    yielded = False
    if isinstance(xs, ESingleton):
        # filter([x], p)  ==>  [x] if p(x) else []
        yielded = True
        yield optimized_cond(p.apply_to(xs.e), xs, EEmptyList().with_type(xs.type))
    if isinstance(p.body, EBinOp) and p.body.op == BOp.Or:
        # filter(xs, a or b)  ==>  filter(xs, a) + filter(xs, b and not a),
        # tried in both orders.
        # NOTE(review): ENot(e1) may take the predicate out of negation
        # normal form for the recursive call.
        for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]):
            for r1 in _simple_filter(xs, ELambda(p.arg, e1), args):
                for r2 in _simple_filter(xs, ELambda(p.arg, EAll([e2, ENot(e1)])), args):
                    yielded = True
                    yield EBinOp(r1, "+", r2).with_type(xs.type)
    if isinstance(p.body, EBinOp) and p.body.op == BOp.And:
        # filter(xs, a and b)  ==>  filter(filter(xs, a), b), in both orders
        for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]):
            for r1 in _simple_filter(xs, ELambda(p.arg, e1), args):
                yielded = True
                yield from _simple_filter(r1, ELambda(p.arg, e2), args)
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(p)):
        # filtering pure state by an arg-free predicate stays pure state
        yielded = True
        yield EStateVar(EFilter(xs.e, strip_EStateVar(p)).with_type(xs.type)).with_type(xs.type)
    if isinstance(xs, EMapGet) and isinstance(xs.map, EStateVar) and not any(v in args for v in free_vars(p)):
        # push the filter into every value of the stored map
        for m in map_values_multi(xs.map.e, lambda ys: _simple_filter(ys, p, args)):
            yielded = True
            yield EMapGet(EStateVar(m).with_type(m.type), xs.key).with_type(xs.type)
    if isinstance(xs, EBinOp) and xs.op in ("+", "-"):
        # filter distributes over bag union and difference
        for e1 in _simple_filter(xs.e1, p, args):
            for e2 in _simple_filter(xs.e2, p, args):
                yielded = True
                yield EBinOp(e1, xs.op, e2).with_type(xs.type)
    if isinstance(p.body, EBinOp) and p.body.op == "==":
        e1 = p.body.e1
        e2 = p.body.e2
        fvs2 = free_vars(e2)
        fvs1 = free_vars(e1)
        # Try both orientations of the equality: e1 must depend on the
        # element (and on no runtime arg), e2 must not depend on the element.
        for (e1, fvs1), (e2, fvs2) in itertools.permutations([(e1, fvs1), (e2, fvs2)]):
            if p.arg in fvs1 and not any(a in fvs1 for a in args) and p.arg not in fvs2 and isinstance(xs, EStateVar):
                if e1 == p.arg:
                    # filter(xs, \x -> x == c)  ==>  [c] if c in xs else []
                    # (`yielded` is set a few lines below in this same branch)
                    yield optimized_cond(
                        optimized_in(e2, xs),
                        ESingleton(e2).with_type(xs.type),
                        EEmptyList().with_type(xs.type))

                # state map from e1(x) to the elements with that key, looked up at e2
                k = fresh_var(e1.type)
                e = EMapGet(
                    EStateVar(
                        EMakeMap2(
                            EMap(xs.e, ELambda(p.arg, e1)),
                            ELambda(k, EFilter(xs.e, ELambda(p.arg, EEq(e1, k)))))),
                    e2)
                res = retypecheck(e)
                assert res
                yielded = True
                yield e

    if not yielded:
        yield EFilter(xs, p).with_type(xs.type)
def optimized_best(xs, keyfunc, op, args):
    """Generate expressions equivalent to the best element of `xs`.

    "Best" means the element with minimal key (`op == "<"`, EArgMin) or
    maximal key (any other `op`, EArgMax) according to `keyfunc`. This is a
    generator that yields candidate rewrites by case analysis on the shape
    of `xs`. The last item it yields is always the plain EArgMin/EArgMax
    expression.

    Parameters:
        xs      - the collection to search
        keyfunc - ELambda computing an element's comparison key
        op      - "<" for argmin, anything else for argmax
        args    - runtime variables; some rules only fire when the relevant
                  expressions mention none of these
    """
    argbest = EArgMin if op == "<" else EArgMax
    elem_type = xs.type.elem_type
    key_type = keyfunc.body.type
    if excluded_element(xs, args) is not None:
        # Presumably xs is `bag` with the single element `x` removed --
        # confirm against excluded_element.
        bag, x = excluded_element(xs, args)
        if all(v not in args for v in free_vars(bag)):
            # Maintain a heap over `bag` as state. If x is in bag and equals
            # bag's best element, answer with the heap peek (EHeapPeek2,
            # presumably the second-best element); otherwise bag's best
            # element is still the answer.
            heap_type, make_heap = (TMinHeap, EMakeMinHeap) if op == "<" else (TMaxHeap, EMakeMaxHeap)
            bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
            for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
                prev_min = EStateVar(prev_min).with_type(elem_type)
                heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type)).with_type(elem_type)
                conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
                if isinstance(x, EUnaryOp) and x.op == UOp.The:
                    # Presumably guards against `the(ys)` where ys is empty.
                    conds = [optimized_exists(x.e)] + conds
                yield optimized_cond(
                    EAll(conds),
                    heap_peek,
                    prev_min)
    if isinstance(xs, EEmptyList):
        # Best of an empty collection is the default value of the element type.
        yield construct_value(elem_type)
    if isinstance(xs, ESingleton):
        yield xs.e
    if isinstance(xs, EBinOp) and xs.op == "+":
        # Combine the best of each summand, falling back to whichever side
        # is non-empty, or to the default value when both are empty.
        a_ex = optimized_exists(xs.e1)
        b_ex = optimized_exists(xs.e2)
        bag_type = TBag(xs.type.elem_type)
        for a in optimized_best(xs.e1, keyfunc, op, args=args):
            for b in optimized_best(xs.e2, keyfunc, op, args=args):
                yield optimized_cond(a_ex,
                    optimized_cond(b_ex,
                        argbest(EBinOp(ESingleton(a).with_type(bag_type), "+", ESingleton(b).with_type(bag_type)).with_type(bag_type), keyfunc).with_type(elem_type),
                        a),
                    optimized_cond(b_ex, b, construct_value(elem_type)))
        # if isinstance(xs.e1, EStateVar) or isinstance(xs.e2, EStateVar):
        #     sv, other = (xs.e1, xs.e2) if isinstance(xs.e1, EStateVar) else (xs.e2, xs.e1)
        #     sv_best = optimized_best(sv, keyfunc, op, args=args)
        #     yield optimized_cond(
        #         optimized_exists(sv),
        #         argbest(EBinOp(ESingleton(sv_best).with_type(xs.type), "+", other).with_type(xs.type), keyfunc).with_type(elem_type),
        #         optimized_best(other, keyfunc, op, args=args))
        # else:
        #     parts = break_sum(xs)
        #     found = EFALSE
        #     best = construct_value(elem_type)
        #     for p in parts:
        #         ex = optimized_exists(p)
        #         best_here = optimized_best(p, keyfunc, op, args=args)
        #         best = optimized_cond(found,
        #             optimized_cond(ex,
        #                 optimized_cond(EBinOp(keyfunc.apply_to(best_here), op, keyfunc.apply_to(best)).with_type(BOOL),
        #                     best_here,
        #                     best),
        #                 best),
        #             best_here)
        #         found = EAny([found, ex])
        #     yield best
    if isinstance(xs, EMap):
        # best(map(f, xs)) == f(best of xs under key . f), or the default
        # value of the mapped element type when xs is empty.
        for b in optimized_best(xs.e, compose(keyfunc, xs.transform_function), op, args):
            yield optimized_cond(optimized_exists(xs.e),
                xs.transform_function.apply_to(b),
                construct_value(elem_type))
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(keyfunc)):
        # The key function mentions no runtime args: precompute the whole
        # argmin/argmax as state.
        yield EStateVar(argbest(xs.e, keyfunc).with_type(elem_type)).with_type(elem_type)
    if isinstance(xs, ECond):
        # Push the search into both branches of the conditional.
        for a in optimized_best(xs.then_branch, keyfunc, op, args=args):
            for b in optimized_best(xs.else_branch, keyfunc, op, args=args):
                yield optimized_cond(xs.cond, a, b)
    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        # Removing duplicates does not change which element is best.
        yield from optimized_best(xs.e, keyfunc, op, args=args)
    # if isinstance(xs, EFilter):
    #     yield optimized_cond(
    #         xs.predicate.apply_to(optimized_best(xs.e, keyfunc, op, args=args)),
    #         optimized_best(xs.e, keyfunc, op, args=args),
    #         argbest(xs, keyfunc).with_type(elem_type))
    # Always offer the unoptimized form as a final candidate.
    yield argbest(xs, keyfunc).with_type(elem_type)
def _try_optimize(e, context, pool):
    """Generate faster expressions equivalent to `e`.

    Nothing is generated unless the `accelerate` option is enabled and
    `pool` is RUNTIME_POOL. Every candidate is passed through `_check`
    before being yielded.

    Parameters:
        e       - the expression to accelerate
        context - supplies the state and runtime variables in scope
        pool    - the pool `e` belongs to
    """
    if not accelerate.value:
        return

    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    args = [v for v, p in context.vars() if p == RUNTIME_POOL]

    if pool == RUNTIME_POOL:

        if all(v in state_vars for v in free_vars(e)):
            # e mentions only state variables: it can be precomputed as state.
            nsv = strip_EStateVar(e)
            sv = EStateVar(nsv).with_type(e.type)
            yield _check(sv, context, RUNTIME_POOL)

        for ee, p in map_accelerate(e, context):
            # map_accelerate also yields the state-side map definitions;
            # only the runtime-side lookups are candidates here.
            if p == RUNTIME_POOL:
                yield _check(ee, context, p)

        if isinstance(e, EListGet) and e.index == ZERO:
            yield _check(EUnaryOp(UOp.The, e.e).with_type(e.type), context, RUNTIME_POOL)

        if isinstance(e, EArgMin) or isinstance(e, EArgMax):
            ee = optimized_best(e.e, e.f, "<" if isinstance(e, EArgMin) else ">", args=args)
            yield _check(ee, context, RUNTIME_POOL)

        if is_collection(e.type) and isinstance(e, EBinOp) and e.op == "-":
            ee = optimized_bag_difference(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EBinOp) and e.op == "===" and isinstance(e.e1.type, THandle):
            yield _check(EAll([
                optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)),
                optimized_eq(optimized_val(e.e1), optimized_val(e.e2)).with_type(BOOL)]),
                context, RUNTIME_POOL)

        if isinstance(e, EBinOp) and e.op == BOp.In:
            ee = optimized_in(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Empty:
            ee = optimized_empty(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Exists:
            ee = optimized_exists(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Length:
            ee = optimized_len(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.The:
            for ee in optimize_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EFilter):
            ee = optimize_filter_as_if_distinct(e.e, e.p, args=args)
            yield _check(ee, context, RUNTIME_POOL)
            if isinstance(e.e, EFilter):
                # try swizzle: filter(filter(xs, q), p) == filter(filter(xs, p), q).
                # _simple_filter is a generator, so wrap each expression it
                # produces; passing the generator object itself to EFilter
                # would build a malformed expression.
                for inner in _simple_filter(e.e.e, e.p, args=args):
                    ee = EFilter(inner, e.e.p).with_type(e.type)
                    yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EMap):
            for ee in optimize_map(e.e, e.f, args=args):
                yield _check(ee, context, RUNTIME_POOL)
def map_accelerate(e, context):
    """Generate map-lookup rewrites of `e`.

    Walks the subexpressions of `e`. For each subexpression `arg` whose
    free variables are legal in `context`, `e` is abstracted over `arg`:
    occurrences alpha-equivalent to `arg` are replaced by a fresh binder.
    If the abstracted body mentions no runtime variables, then for every
    candidate key collection reachable from a state variable this yields:

      * `(m, STATE_POOL)`, where `m` maps each key to the abstracted body;
      * `(m[arg], RUNTIME_POOL)`, a lookup in that map standing in for `e`.

    Nothing is yielded when `e` is already constant-time.

    Yields:
        (expression, pool) pairs.
    """
    size = e.size()
    with task("map_accelerate", size=size):
        if is_constant_time(e):
            event("skipping map lookup inference for constant-time exp: {}".format(pprint(e)))
            return

        @lru_cache()
        def make_binder(t):
            # One fresh binder per type, shared by every candidate `arg`.
            return fresh_var(t, hint="key")

        args = OrderedSet(v for (v, p) in context.vars() if p == RUNTIME_POOL)
        possible_keys = {}  # type -> [exp]

        i = 0
        stk = [e]
        while stk:
            event("exp {} / {}".format(i, size))
            i += 1
            arg = stk.pop()
            if isinstance(arg, tuple):
                stk.extend(arg)
                continue
            if not isinstance(arg, Exp):
                continue
            if isinstance(arg, ELambda):
                stk.append(arg.body)
                continue

            if context.legal_for(free_vars(arg)):
                # all the work happens here
                binder = make_binder(arg.type)
                value = replace(e, arg, binder,
                    match=lambda e1, e2: type(e1) == type(e2) and e1.type == e2.type and alpha_equivalent(e1, e2))
                value = strip_EStateVar(value)
                if any(v in args for v in free_vars(value)):
                    event("not all args were eliminated")
                else:
                    if arg.type not in possible_keys:
                        l = [reachable_values_of_type(sv, arg.type)
                             for (sv, p) in context.vars() if p == STATE_POOL]
                        l = OrderedSet(x for x in l if not isinstance(x, EEmptyList))
                        possible_keys[arg.type] = l
                    for keys in possible_keys[arg.type]:
                        m = EMakeMap2(keys, ELambda(binder, value)).with_type(TMap(arg.type, e.type))
                        assert not any(v in args for v in free_vars(m)), "oops! {}; args={}".format(
                            pprint(m), ", ".join(pprint(a) for a in args))
                        yield (m, STATE_POOL)
                        mg = EMapGet(EStateVar(m).with_type(m.type), arg).with_type(e.type)
                        yield (mg, RUNTIME_POOL)

            if isinstance(arg, EStateVar):
                # do not visit state expressions
                continue

            # Only descend into `arg` if fewer than two of its children
            # mention runtime variables.
            num_with_args = 0
            stk2 = list(arg.children())
            while stk2:
                child = stk2.pop()
                if isinstance(child, tuple):
                    # Flatten tuple children into the *local* work list so
                    # their elements are counted too. Pushing them onto `stk`
                    # would skip the count and queue them for a visit even
                    # when descent is refused below.
                    stk2.extend(child)
                    continue
                if not isinstance(child, Exp):
                    continue
                fvs = free_vars(child)
                if fvs & args:
                    num_with_args += 1
                    if num_with_args >= 2:
                        break
            if num_with_args < 2:
                stk.extend(arg.children())
            else:
                event("refusing to visit children of {}".format(pprint(arg)))
def optimized_best(xs, keyfunc, op, args):
    """Generate expressions equivalent to the best element of `xs`.

    "Best" means the element with minimal key (`op == "<"`, EArgMin) or
    maximal key (any other `op`, EArgMax) according to `keyfunc`. This is a
    generator that yields candidate rewrites by case analysis on the shape
    of `xs`. The last item it yields is always the plain EArgMin/EArgMax
    expression.

    Parameters:
        xs      - the collection to search
        keyfunc - ELambda computing an element's comparison key
        op      - "<" for argmin, anything else for argmax
        args    - runtime variables; some rules only fire when the relevant
                  expressions mention none of these
    """
    argbest = EArgMin if op == "<" else EArgMax
    elem_type = xs.type.elem_type
    key_type = keyfunc.body.type

    # Compute once and reuse (as optimized_sum does) instead of calling
    # excluded_element twice.
    excluded = excluded_element(xs, args)
    if excluded is not None:
        # Presumably xs is `bag` with the element `x` removed -- confirm
        # against excluded_element.
        bag, x = excluded
        if all(v not in args for v in free_vars(bag)):
            # Maintain a heap over `bag` as state. If x is in bag and equals
            # bag's best element, answer with the heap peek (EHeapPeek2,
            # presumably the second-best element); otherwise bag's best
            # element is still the answer.
            heap_type, make_heap = (TMinHeap, EMakeMinHeap) if op == "<" else (TMaxHeap, EMakeMaxHeap)
            bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
            for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
                prev_min = EStateVar(prev_min).with_type(elem_type)
                heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type)).with_type(elem_type)
                conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
                if isinstance(x, EUnaryOp) and x.op == UOp.The:
                    # Presumably guards against `the(ys)` where ys is empty.
                    conds = [optimized_exists(x.e)] + conds
                yield optimized_cond(EAll(conds), heap_peek, prev_min)

    if isinstance(xs, EEmptyList):
        # Best of an empty collection is the default value of the element type.
        yield construct_value(elem_type)

    if isinstance(xs, ESingleton):
        yield xs.e

    if isinstance(xs, EBinOp) and xs.op == "+":
        # Combine the best of each summand, falling back to whichever side
        # is non-empty, or to the default value when both are empty.
        # (A disabled alternative that special-cased EStateVar summands and
        # folded over break_sum(xs) used to be kept here as comments.)
        a_ex = optimized_exists(xs.e1)
        b_ex = optimized_exists(xs.e2)
        bag_type = TBag(xs.type.elem_type)
        for a in optimized_best(xs.e1, keyfunc, op, args=args):
            for b in optimized_best(xs.e2, keyfunc, op, args=args):
                pair = EBinOp(
                    ESingleton(a).with_type(bag_type),
                    "+",
                    ESingleton(b).with_type(bag_type)).with_type(bag_type)
                yield optimized_cond(
                    a_ex,
                    optimized_cond(b_ex, argbest(pair, keyfunc).with_type(elem_type), a),
                    optimized_cond(b_ex, b, construct_value(elem_type)))

    if isinstance(xs, EMap):
        # best(map(f, xs)) == f(best of xs under key . f), or the default
        # value of the mapped element type when xs is empty.
        for b in optimized_best(xs.e, compose(keyfunc, xs.transform_function), op, args):
            yield optimized_cond(
                optimized_exists(xs.e),
                xs.transform_function.apply_to(b),
                construct_value(elem_type))

    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(keyfunc)):
        # The key function mentions no runtime args: precompute the whole
        # argmin/argmax as state.
        yield EStateVar(argbest(xs.e, keyfunc).with_type(elem_type)).with_type(elem_type)

    if isinstance(xs, ECond):
        # Push the search into both branches of the conditional.
        for a in optimized_best(xs.then_branch, keyfunc, op, args=args):
            for b in optimized_best(xs.else_branch, keyfunc, op, args=args):
                yield optimized_cond(xs.cond, a, b)

    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        # Removing duplicates does not change which element is best.
        yield from optimized_best(xs.e, keyfunc, op, args=args)

    # Always offer the unoptimized form as a final candidate.
    yield argbest(xs, keyfunc).with_type(elem_type)
def _try_optimize(e: Exp, context: Context, pool: Pool):
    """Generate faster expressions equivalent to `e`.

    Nothing is generated unless the `accelerate` option is enabled and
    `pool` is RUNTIME_POOL. Broad "rewrite schemes" are tried first,
    followed by rewrites specific to particular AST node types. Every
    candidate is passed through `_check` before being yielded.

    Parameters:
        e       - the expression to accelerate
        context - supplies the state and runtime variables in scope
        pool    - the pool `e` belongs to
    """
    if not accelerate.value:
        return
    if pool != RUNTIME_POOL:
        return

    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    args = [v for v, p in context.vars() if p == RUNTIME_POOL]

    # ---------------------------------------------------------------------
    # "Rewrite schemes": these trigger on many different AST shapes
    # They are listed first because they are more powerful than the
    # specific rewrite rules below.

    if not free_vars(e) and not free_funcs(e):
        # Closed expression: try to evaluate it to a constant.
        try:
            yield _check(uneval(e.type, eval(e, {})), context, RUNTIME_POOL)
        except NotImplementedError:
            print("Unable to evaluate {!r}".format(e))

    if all(v in state_vars for v in free_vars(e)):
        # e mentions only state variables: it can be precomputed as state.
        nsv = strip_EStateVar(e)
        sv = EStateVar(nsv).with_type(e.type)
        yield _check(sv, context, RUNTIME_POOL)

    for ee in fold_into_map(e, context):
        yield _check(ee, context, pool)

    # ---------------------------------------------------------------------
    # "Rewrites": these trigger on specific AST nodes

    if isinstance(e, EBinOp):
        if e.op == "-" and is_collection(e.type):
            ee = optimized_bag_difference(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == "===" and isinstance(e.e1.type, THandle):
            yield _check(EAll([
                optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)),
                optimized_eq(optimized_val(e.e1), optimized_val(e.e2)).with_type(BOOL)]),
                context, RUNTIME_POOL)

        if e.op == BOp.In:
            ee = optimized_in(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, ECond):
        yield _check(optimized_cond(e.cond, e.then_branch, e.else_branch), context, RUNTIME_POOL)

    if isinstance(e, EGetField):
        for ee in optimized_get_field(e.e, e.field_name, args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and e.index == ZERO:
        for res in optimized_the(e.e, args):
            yield _check(res, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and isinstance(e.e, ECond):
        # (c ? a : b)[i] == (c ? a[i] : b[i]). Routed through _check like
        # every other candidate.
        yield _check(
            optimized_cond(
                e.e.cond,
                EListGet(e.e.then_branch, e.index).with_type(e.type),
                EListGet(e.e.else_branch, e.index).with_type(e.type)),
            context, RUNTIME_POOL)

    from cozy.syntax import ESorted
    from cozy.structures.treemultiset import (
        ETreeMultisetElems, ETreeMultisetPeek,
        EMakeMaxTreeMultiset, TMaxTreeMultiset,
        EMakeMinTreeMultiset, TMinTreeMultiset)

    if isinstance(e, EListGet) and isinstance(e.e, ETreeMultisetElems):
        # Indexing into a tree multiset's elements becomes a direct peek.
        yield _check(ETreeMultisetPeek(e.e.e, e.index).with_type(e.type), context, RUNTIME_POOL)

    if isinstance(e, EMapGet):
        ee = inline_mapget(e, context)
        yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):
        if e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Length:
            ee = optimized_len(e.e)
            yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Empty:
            ee = optimized_empty(e.e)
            yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Exists:
            ee = optimized_exists(e.e)
            yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Distinct:
            for ee in optimized_distinct(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.The:
            for ee in optimized_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EArgMin) or isinstance(e, EArgMax):
        for ee in optimized_best(e.e, e.key_function, "<" if isinstance(e, EArgMin) else ">", args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        for ee in optimized_filter(e.e, e.predicate, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for ee in optimized_map(e.e, e.transform_function, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, ESorted) and isinstance(e.e, EStateVar):
        # Sorting a state collection: keep both a min- and a max-ordered tree
        # multiset as state and choose between them on `e.asc`.
        e_max = EMakeMaxTreeMultiset(e.e.e).with_type(TMaxTreeMultiset(e.e.e.type.elem_type))
        e_min = EMakeMinTreeMultiset(e.e.e).with_type(TMinTreeMultiset(e.e.e.type.elem_type))
        ee = optimized_cond(
            e.asc,
            ETreeMultisetElems(EStateVar(e_min).with_type(e_min.type)).with_type(e.type),
            ETreeMultisetElems(EStateVar(e_max).with_type(e_max.type)).with_type(e.type))
        yield _check(ee, context, RUNTIME_POOL)
def _try_optimize(e, context, pool):
    """Generate faster expressions equivalent to `e`.

    Nothing is generated unless the `accelerate` option is enabled and
    `pool` is RUNTIME_POOL. Every candidate is passed through `_check`
    before being yielded.

    Parameters:
        e       - the expression to accelerate
        context - supplies the state and runtime variables in scope
        pool    - the pool `e` belongs to
    """
    if not accelerate.value:
        return

    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    args = [v for v, p in context.vars() if p == RUNTIME_POOL]

    if pool == RUNTIME_POOL:

        if not free_vars(e) and not free_funcs(e):
            # Closed expression: try to evaluate it to a constant.
            try:
                yield _check(uneval(e.type, eval(e, {})), context, RUNTIME_POOL)
            except NotImplementedError:
                print("Unable to evaluate {!r}".format(e))

        if all(v in state_vars for v in free_vars(e)):
            # e mentions only state variables: it can be precomputed as state.
            nsv = strip_EStateVar(e)
            sv = EStateVar(nsv).with_type(e.type)
            yield _check(sv, context, RUNTIME_POOL)

        for ee in fold_into_map(e, context):
            yield _check(ee, context, pool)

        if isinstance(e, EListGet) and e.index == ZERO:
            for res in optimize_the(e.e, args):
                yield _check(res, context, RUNTIME_POOL)

        if isinstance(e, EArgMin) or isinstance(e, EArgMax):
            for ee in optimized_best(e.e, e.f, "<" if isinstance(e, EArgMin) else ">", args=args):
                yield _check(ee, context, RUNTIME_POOL)

        if is_collection(e.type) and isinstance(e, EBinOp) and e.op == "-":
            ee = optimized_bag_difference(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EBinOp) and e.op == "===" and isinstance(e.e1.type, THandle):
            yield _check(EAll([
                optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)),
                optimized_eq(optimized_val(e.e1), optimized_val(e.e2)).with_type(BOOL)]),
                context, RUNTIME_POOL)

        if isinstance(e, ECond):
            yield _check(optimized_cond(e.cond, e.then_branch, e.else_branch), context, RUNTIME_POOL)

        if isinstance(e, EGetField):
            for ee in optimized_get_field(e.e, e.f, args):
                yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EBinOp) and e.op == BOp.In:
            ee = optimized_in(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Empty:
            ee = optimized_empty(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Exists:
            ee = optimized_exists(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.Length:
            ee = optimized_len(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EUnaryOp) and e.op == UOp.The:
            for ee in optimize_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EFilter):
            ee = optimize_filter_as_if_distinct(e.e, e.p, args=args)
            yield _check(ee, context, RUNTIME_POOL)
            if isinstance(e.e, EFilter):
                # try swizzle: filter(filter(xs, q), p) == filter(filter(xs, p), q).
                # _simple_filter is a generator, so wrap each expression it
                # produces; passing the generator object itself to EFilter
                # would build a malformed expression.
                for inner in _simple_filter(e.e.e, e.p, args=args):
                    ee = EFilter(inner, e.e.p).with_type(e.type)
                    yield _check(ee, context, RUNTIME_POOL)

        if isinstance(e, EMap):
            for ee in optimize_map(e.e, e.f, args=args):
                yield _check(ee, context, RUNTIME_POOL)
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Check whether `spec` does not actually depend on the variables `vs`.

    Every variable in `vs` is renamed to a fresh variable of the same type.
    Returns what `valid` reports for the claim: "if the assumptions hold
    under both the original and the renamed variables, then `spec` has the
    same value under both". EStateVar wrappers are stripped from `spec`
    before the check.
    """
    stripped = strip_EStateVar(spec)
    renaming = {}
    for var in vs:
        renaming[var.id] = fresh_var(var.type)
    both_hold = EAll([assumptions, subst(assumptions, renaming)])
    unchanged = EEq(stripped, subst(stripped, renaming))
    return valid(EImplies(both_hold, unchanged))
def is_lenof(e, xs):
    """Return whether `e` is alpha-equivalent to `len(xs)`, ignoring EStateVar wrappers on both."""
    candidate = strip_EStateVar(e)
    length_of_xs = ELen(strip_EStateVar(xs))
    return alpha_equivalent(candidate, length_of_xs)
def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
    """Add a query that helps maintain some other state.

    Parameters:
        sub_q - the specification of the helper query
        used_by - the statement that calls `sub_q`

    If a query already exists that is equivalent to `sub_q`, this method
    returns `used_by` rewritten to use the existing query and does not add
    the query to the implementation.  Otherwise it returns `used_by`
    unchanged.

    Before the equivalence check, a shallow copy of `sub_q` is prepared:
    handle assumptions are added if needed, state-variable boundaries in its
    return expression are repaired, and its assumptions and return
    expression are simplified.

    Raises:
        AssertionError - if simplification makes the assumptions larger.
    """
    with task("adding query", query=sub_q.name):
        sub_q = shallow_copy(sub_q)

        with task("checking whether we need more handle assumptions"):
            new_a = implicit_handle_assumptions(
                reachable_handles_at_method(self.spec, sub_q))
            if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                event("we do!")
                sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

        with task("repairing state var boundaries"):
            extra_available_state = [e for v, e in self._concretization_functions]
            sub_q.ret = repair_well_formedness(
                strip_EStateVar(sub_q.ret),
                self.context_for_method(sub_q),
                extra_available_state)

        with task("simplifying"):
            orig_a = sub_q.assumptions
            orig_a_size = sum(a.size() for a in sub_q.assumptions)
            orig_ret_size = sub_q.ret.size()

            sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
            sub_q.ret = simplify(sub_q.ret)
            a_size = sum(a.size() for a in sub_q.assumptions)
            ret_size = sub_q.ret.size()
            event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
            event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

            if a_size > orig_a_size:
                print("NO, BAD SIMPLIFICATION")
                print("original")
                for a in orig_a:
                    print(" - {}".format(pprint(a)))
                print("simplified")
                for a in sub_q.assumptions:
                    print(" - {}".format(pprint(a)))
                # An explicit raise (rather than `assert False`) keeps this
                # sanity check active under `python -O`.
                raise AssertionError("simplification increased the size of the assumptions")

        state_vars = self.abstract_state
        funcs = self.extern_funcs
        qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(
            qq, sub_q,
            state_vars=state_vars,
            extern_funcs=funcs,
            assumptions=EAll(self.abstract_invariants)))

        if qq is not None:
            event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
            # Build the sub-query's argument-name list once, instead of once
            # per argument of qq.
            sub_q_arg_names = [x[0] for x in sub_q.args]
            # Position in sub_q's argument list of each of qq's arguments.
            arg_reorder = [sub_q_arg_names.index(a) for (a, t) in qq.args]

            class Repl(BottomUpRewriter):
                # Redirects calls to sub_q to the equivalent query qq,
                # permuting arguments into qq's order.
                def visit_ECall(self, e):
                    args = tuple(self.visit(a) for a in e.args)
                    if e.func == sub_q.name:
                        args = tuple(args[idx] for idx in arg_reorder)
                        return ECall(qq.name, args).with_type(e.type)
                    else:
                        return ECall(e.func, args).with_type(e.type)

            used_by = Repl().visit(used_by)
        else:
            self.add_query(sub_q)
        return used_by
def _try_optimize(e : Exp, context : Context, pool : Pool):
    """Generate faster expressions equivalent to `e`.

    Nothing is generated unless the `accelerate` option is enabled and
    `pool` is RUNTIME_POOL. Broad "rewrite schemes" are tried first,
    followed by rewrites specific to particular AST node types. Every
    candidate is passed through `_check` before being yielded.

    Parameters:
        e       - the expression to accelerate
        context - supplies the state and runtime variables in scope
        pool    - the pool `e` belongs to
    """
    if not accelerate.value:
        return
    if pool != RUNTIME_POOL:
        return

    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    args = [v for v, p in context.vars() if p == RUNTIME_POOL]

    # ---------------------------------------------------------------------
    # "Rewrite schemes": these trigger on many different AST shapes
    # They are listed first because they are more powerful than the
    # specific rewrite rules below.

    if not free_vars(e) and not free_funcs(e):
        # Closed expression: try to evaluate it to a constant.
        try:
            yield _check(uneval(e.type, eval(e, {})), context, RUNTIME_POOL)
        except NotImplementedError:
            print("Unable to evaluate {!r}".format(e))

    if all(v in state_vars for v in free_vars(e)):
        # e mentions only state variables: it can be precomputed as state.
        nsv = strip_EStateVar(e)
        sv = EStateVar(nsv).with_type(e.type)
        yield _check(sv, context, RUNTIME_POOL)

    for ee in fold_into_map(e, context):
        yield _check(ee, context, pool)

    # ---------------------------------------------------------------------
    # "Rewrites": these trigger on specific AST nodes

    if isinstance(e, EBinOp):
        if e.op == "-" and is_collection(e.type):
            ee = optimized_bag_difference(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == "===" and isinstance(e.e1.type, THandle):
            yield _check(EAll([
                optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)),
                optimized_eq(optimized_val(e.e1), optimized_val(e.e2)).with_type(BOOL)]),
                context, RUNTIME_POOL)

        if e.op == BOp.In:
            ee = optimized_in(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, ECond):
        yield _check(optimized_cond(e.cond, e.then_branch, e.else_branch), context, RUNTIME_POOL)

    if isinstance(e, EGetField):
        for ee in optimized_get_field(e.e, e.field_name, args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and e.index == ZERO:
        for res in optimized_the(e.e, args):
            yield _check(res, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and isinstance(e.e, ECond):
        # (c ? a : b)[i] == (c ? a[i] : b[i]). Routed through _check like
        # every other candidate.
        yield _check(
            optimized_cond(
                e.e.cond,
                EListGet(e.e.then_branch, e.index).with_type(e.type),
                EListGet(e.e.else_branch, e.index).with_type(e.type)),
            context, RUNTIME_POOL)

    from cozy.syntax import ESorted
    from cozy.structures.treemultiset import (
        ETreeMultisetElems, ETreeMultisetPeek,
        EMakeMaxTreeMultiset, TMaxTreeMultiset,
        EMakeMinTreeMultiset, TMinTreeMultiset)

    if isinstance(e, EListGet) and isinstance(e.e, ETreeMultisetElems):
        # Indexing into a tree multiset's elements becomes a direct peek.
        yield _check(ETreeMultisetPeek(e.e.e, e.index).with_type(e.type), context, RUNTIME_POOL)

    if isinstance(e, EMapGet):
        ee = inline_mapget(e, context)
        yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):
        if e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Length:
            ee = optimized_len(e.e)
            yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Empty:
            ee = optimized_empty(e.e)
            yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Exists:
            ee = optimized_exists(e.e)
            yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.Distinct:
            for ee in optimized_distinct(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)
        if e.op == UOp.The:
            for ee in optimized_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EArgMin) or isinstance(e, EArgMax):
        for ee in optimized_best(e.e, e.key_function, "<" if isinstance(e, EArgMin) else ">", args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        for ee in optimized_filter(e.e, e.predicate, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for ee in optimized_map(e.e, e.transform_function, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, ESorted) and isinstance(e.e, EStateVar):
        # Sorting a state collection: keep both a min- and a max-ordered tree
        # multiset as state and choose between them on `e.asc`.
        e_max = EMakeMaxTreeMultiset(e.e.e).with_type(TMaxTreeMultiset(e.e.e.type.elem_type))
        e_min = EMakeMinTreeMultiset(e.e.e).with_type(TMinTreeMultiset(e.e.e.type.elem_type))
        ee = optimized_cond(
            e.asc,
            ETreeMultisetElems(EStateVar(e_min).with_type(e_min.type)).with_type(e.type),
            ETreeMultisetElems(EStateVar(e_max).with_type(e_max.type)).with_type(e.type))
        yield _check(ee, context, RUNTIME_POOL)
def is_lenof(e, xs):
    """Check whether `e` denotes the length of `xs`.

    Both arguments are stripped of EStateVar wrappers before comparing, and
    the comparison is made up to alpha-equivalence.
    """
    lhs, rhs = strip_EStateVar(e), ELen(strip_EStateVar(xs))
    return alpha_equivalent(lhs, rhs)
def optimized_best(xs, keyfunc, op, args):
    """Return an expression equivalent to the best element of `xs`.

    "Best" means the element with minimal key (`op == "<"`, EArgMin) or
    maximal key (any other `op`, EArgMax) according to `keyfunc`. Exactly
    one expression is returned, chosen by case analysis on the shape of
    `xs`. The best element of an empty collection is
    `construct_value(elem_type)`.

    Parameters:
        xs      - the collection to search
        keyfunc - ELambda computing an element's comparison key
        op      - "<" for argmin, anything else for argmax
        args    - runtime variables; some rules only fire when the relevant
                  expressions mention none of these
    """
    argbest = EArgMin if op == "<" else EArgMax
    elem_type = xs.type.t
    key_type = keyfunc.body.type

    excluded = excluded_element(xs, args)
    if excluded is not None:
        # Presumably xs is `bag` with the element `x` removed -- confirm
        # against excluded_element.
        bag, x = excluded
        if all(v not in args for v in free_vars(bag)):
            # Maintain a heap over `bag` as state. If x is in bag and equals
            # bag's best element, answer with the heap peek; otherwise bag's
            # best element is still the answer.
            heap_type, make_heap = (TMinHeap, EMakeMinHeap) if op == "<" else (TMaxHeap, EMakeMaxHeap)
            bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
            prev_min = EStateVar(
                optimized_best(bag.e, keyfunc, op, args=args).with_type(elem_type)).with_type(elem_type)
            heap_peek = EHeapPeek2(
                EStateVar(h).with_type(h.type),
                EStateVar(ELen(bag.e)).with_type(INT)).with_type(elem_type)
            conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
            if isinstance(x, EUnaryOp) and x.op == UOp.The:
                # Same guard as the generator versions of this function:
                # `the(ys)` only stands for a removed element when ys is
                # non-empty.
                conds = [optimized_exists(x.e)] + conds
            return optimized_cond(EAll(conds), heap_peek, prev_min)

    if isinstance(xs, EEmptyList):
        return construct_value(elem_type)

    if isinstance(xs, ESingleton):
        return xs.e

    if isinstance(xs, EBinOp) and xs.op == "+":
        if isinstance(xs.e1, EStateVar) or isinstance(xs.e2, EStateVar):
            # Precompute the best of the state summand, then compare it with
            # the other summand at runtime.
            sv, other = (xs.e1, xs.e2) if isinstance(xs.e1, EStateVar) else (xs.e2, xs.e1)
            sv_best = optimized_best(sv, keyfunc, op, args=args)
            return optimized_cond(
                optimized_exists(sv),
                argbest(EBinOp(ESingleton(sv_best).with_type(xs.type), "+", other).with_type(xs.type),
                        keyfunc).with_type(elem_type),
                optimized_best(other, keyfunc, op, args=args))
        else:
            # Fold over the summands, tracking whether any non-empty summand
            # has been seen (`found`) and the best element so far (`best`).
            parts = break_sum(xs)
            found = F
            best = construct_value(elem_type)
            for p in parts:
                ex = optimized_exists(p)
                best_here = optimized_best(p, keyfunc, op, args=args)
                best = optimized_cond(
                    found,
                    optimized_cond(
                        ex,
                        optimized_cond(
                            EBinOp(keyfunc.apply_to(best_here), op, keyfunc.apply_to(best)).with_type(BOOL),
                            best_here,
                            best),
                        best),
                    best_here)
                found = EAny([found, ex])
            return best

    if isinstance(xs, EMap):
        # best(map(f, xs)) == f(best of xs under key . f). When xs is empty
        # the answer must be the default of the *mapped* element type (as in
        # the EEmptyList rule above), not f applied to the default of xs's
        # element type.
        return optimized_cond(
            optimized_exists(xs.e),
            xs.f.apply_to(optimized_best(xs.e, compose(keyfunc, xs.f), op, args)),
            construct_value(elem_type))

    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(keyfunc)):
        # The key function mentions no runtime args: precompute the whole
        # argmin/argmax as state.
        return EStateVar(argbest(xs.e, keyfunc).with_type(elem_type)).with_type(elem_type)

    if isinstance(xs, ECond):
        return optimized_cond(
            xs.cond,
            optimized_best(xs.then_branch, keyfunc, op, args=args),
            optimized_best(xs.else_branch, keyfunc, op, args=args))

    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        # Removing duplicates does not change which element is best.
        return optimized_best(xs.e, keyfunc, op, args=args)

    return argbest(xs, keyfunc).with_type(elem_type)
def _simple_filter(xs: Exp, p: ELambda, args: {EVar}):
    """Generate expressions equivalent to `Filter(xs, p)`.

    Assumes the body of p is already in negation normal form.

    This is a generator. It yields candidate rewrites of `EFilter(xs, p)`
    chosen by case analysis on the shape of `xs` and of the predicate body.
    If no rewrite rule fires, the plain `EFilter(xs, p)` is yielded as a
    fallback, so at least one expression is always produced.

    Parameters:
        xs   - the collection being filtered
        p    - the filter predicate
        args - runtime variables; the rules that build state expressions
               (EStateVar) only fire when the relevant expression mentions
               none of these
    """
    if p.body == ETRUE:
        # filter(xs, true) == xs
        yield xs
        return
    if p.body == EFALSE:
        # filter(xs, false) == []
        yield EEmptyList().with_type(xs.type)
        return
    if isinstance(xs, EEmptyList):
        yield xs
        return

    # Records whether any rewrite rule fired; if none did, the plain
    # EFilter is yielded at the very end.
    yielded = False

    if isinstance(xs, ESingleton):
        # filter([x], p) == (p(x) ? [x] : [])
        yielded = True
        yield optimized_cond(p.apply_to(xs.e), xs, EEmptyList().with_type(xs.type))

    if isinstance(p.body, EBinOp) and p.body.op == BOp.Or:
        # filter(xs, a or b) == filter(xs, a) + filter(xs, b and not a).
        # The "and not a" keeps elements satisfying both disjuncts from being
        # counted twice. Both orderings of the disjuncts are tried.
        for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]):
            for r1 in _simple_filter(xs, ELambda(p.arg, e1), args):
                for r2 in _simple_filter(xs, ELambda(p.arg, EAll([e2, ENot(e1)])), args):
                    yielded = True
                    yield EBinOp(r1, "+", r2).with_type(xs.type)

    if isinstance(p.body, EBinOp) and p.body.op == BOp.And:
        # filter(xs, a and b) == filter(filter(xs, a), b), in both orders.
        for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]):
            for r1 in _simple_filter(xs, ELambda(p.arg, e1), args):
                yielded = True
                yield from _simple_filter(r1, ELambda(p.arg, e2), args)

    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(p)):
        # The predicate mentions no runtime args, so the whole filter can be
        # moved into the state expression.
        yielded = True
        yield EStateVar(EFilter(xs.e, strip_EStateVar(p)).with_type(
            xs.type)).with_type(xs.type)

    if isinstance(xs, EMapGet) and isinstance(
            xs.map, EStateVar) and not any(v in args for v in free_vars(p)):
        # Push the filter into the values of a state map. map_values_multi
        # presumably yields copies of the map with each value rewritten by
        # the callback -- confirm against its definition.
        for m in map_values_multi(xs.map.e, lambda ys: _simple_filter(ys, p, args)):
            yielded = True
            yield EMapGet(EStateVar(m).with_type(m.type), xs.key).with_type(xs.type)

    if isinstance(xs, EBinOp) and xs.op in ("+", "-"):
        # Filtering distributes over both operands of "+" and "-".
        for e1 in _simple_filter(xs.e1, p, args):
            for e2 in _simple_filter(xs.e2, p, args):
                yielded = True
                yield EBinOp(e1, xs.op, e2).with_type(xs.type)

    if isinstance(p.body, EBinOp) and p.body.op == "==":
        e1 = p.body.e1
        e2 = p.body.e2
        fvs2 = free_vars(e2)
        fvs1 = free_vars(e1)
        # Try both orientations of the equality.
        for (e1, fvs1), (e2, fvs2) in itertools.permutations([(e1, fvs1), (e2, fvs2)]):
            # Applies when e1 depends on the element but on no runtime args,
            # e2 does not depend on the element, and xs is a state expression.
            if p.arg in fvs1 and not any(
                    a in fvs1 for a in args) and p.arg not in fvs2 and isinstance(
                        xs, EStateVar):
                if e1 == p.arg:
                    # filter(xs, x == e2) rewritten as (e2 in xs ? [e2] : []).
                    # NOTE(review): this yields a single copy of e2 even if xs
                    # holds several equal elements -- confirm xs is
                    # duplicate-free here or that candidates are re-verified.
                    yield optimized_cond(optimized_in(e2, xs),
                                         ESingleton(e2).with_type(xs.type),
                                         EEmptyList().with_type(xs.type))
                # General case: build a state map from each key e1(x) to the
                # elements of xs having that key, then look up e2 at runtime.
                k = fresh_var(e1.type)
                e = EMapGet(
                    EStateVar(
                        EMakeMap2(
                            EMap(xs.e, ELambda(p.arg, e1)),
                            ELambda(k, EFilter(xs.e, ELambda(p.arg, EEq(e1, k)))))),
                    e2)
                # The nodes above were built without explicit types; presumably
                # retypecheck fills them in and reports success.
                res = retypecheck(e)
                assert res
                yielded = True
                yield e

    if not yielded:
        # No rule applied: fall back to the unoptimized filter.
        yield EFilter(xs, p).with_type(xs.type)