def excluded_element(xs, args):
    if isinstance(xs, EMap):
        res = excluded_element(xs.e, args)
        if res is not None:
            bag, x = res
            return (EMap(bag, xs.transform_function).with_type(xs.type),
                    xs.transform_function.apply_to(x))
    if isinstance(xs, EFilter):
        arg = xs.predicate.arg
        e = xs.predicate.body
        if isinstance(e, EUnaryOp) and e.op == UOp.Not and isinstance(e.e, EBinOp) and e.e.op == "==":
            e = EBinOp(e.e.e1, "!=", e.e.e2).with_type(BOOL)
        if isinstance(e, EBinOp) and e.op == "!=":
            arg_left = arg in free_vars(e.e1)
            arg_right = arg in free_vars(e.e2)
            if arg_left and not arg_right:
                return (xs.e, EUnaryOp(UOp.The, find_one_or_fail(_simple_filter(xs.e, ELambda(arg, EEq(e.e1, e.e2)), args=args))).with_type(xs.type.elem_type))
            if arg_right and not arg_left:
                return (xs.e, EUnaryOp(UOp.The, find_one_or_fail(_simple_filter(xs.e, ELambda(arg, EEq(e.e1, e.e2)), args=args))).with_type(xs.type.elem_type))
        return (xs.e, EFirst(find_one_or_fail(_simple_filter(xs.e, ELambda(xs.predicate.arg, ENot(xs.predicate.body)), args))))
    if isinstance(xs, EBinOp) and xs.op == "-" and isinstance(xs.e2, ESingleton):
        return (xs.e1, xs.e2.e)
    if isinstance(xs, EBinOp) and xs.op == "-":
        return (xs.e1, EFirst(xs.e2))
    if isinstance(xs, EBinOp) and xs.op == "+" and isinstance(xs.e1, EListSlice) and isinstance(xs.e2, EListSlice):
        for e1, e2 in [(xs.e1, xs.e2), (xs.e2, xs.e1)]:
            if e1.e == e2.e and e1.start == ZERO and e2.start == EBinOp(e1.end, "+", ONE) and is_lenof(e2.end, e2.e):
                return (e1.e, EListGet(e1.e, e1.end).with_type(xs.type.elem_type))
    return None
def test_edeepin(self):
    ht = THandle("H", INT)
    hb = EVar("hb").with_type(TBag(ht))
    h = fresh_var(ht, omit=free_vars(hb))
    arg = fresh_var(ht, omit=free_vars(h) | free_vars(hb))
    f1 = EDeepIn(h, hb)
    f2 = EUnaryOp(UOp.Any,
        EMap(hb, ELambda(arg, EBinOp(arg, "===", h).with_type(BOOL))).with_type(BOOL_BAG)).with_type(BOOL)
    self.assert_same(f1, f2)
def make_subgoal(e, a=[], docstring=None):
    if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
        return e
    query_name = fresh_name("query")
    query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
    query_vars = [v for v in free_vars(query) if v not in ctx]
    query.args = [(arg.id, arg.type) for arg in query_vars]
    subgoals.append(query)
    return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)
def retypecheck(exp, env : {str : syntax.Type} = None, fenv = None):
    """Add or fix the .type annotations on the given tree.

    Returns True or False to indicate success or failure.  If it fails, it
    prints type errors to stdout.

    Unlike `typecheck`, this procedure attempts to guess the types of
    variables and functions in the expression using their .type annotations.
    The `env` dictionary overrides these guesses and forces variables to be
    annotated with a particular type.
    """
    if env is None:
        env = { v.id : v.type for v in free_vars(exp) }
    if fenv is not None:
        fenv = { f : (tuple(ty.arg_types), ty.ret_type) for f, ty in fenv.items() }
    else:
        fenv = { }
    for e in all_exps(exp):
        if isinstance(e, syntax.EEnumEntry):
            env[e.name] = e.type
        if isinstance(e, syntax.ECall):
            if e.func not in fenv:
                fenv[e.func] = (tuple(arg.type for arg in e.args), e.type)
    errs = typecheck(exp, env=env, fenv=fenv)
    if errs:
        print("errors")
        for e in errs:
            print(" --> {}".format(e))
    return not errs
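# A minimal usage sketch (assumed, not part of the source): fill in .type
# annotations on a small expression, forcing "x" to be an int.  `EVar`,
# `ENum`, `EBinOp`, and `INT` are the cozy.syntax names used elsewhere in
# these snippets; the variable names here are hypothetical.
e = EBinOp(EVar("x"), "+", ENum(1).with_type(INT))
ok = retypecheck(e, env={"x": INT})
# On success, `ok` is True and every node of `e` now carries a .type.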
def optimized_in(x, xs):
    if isinstance(xs, EStateVar):
        m = EMakeMap2(xs.e, mk_lambda(x.type, lambda x: ETRUE)).with_type(TMap(x.type, BOOL))
        m = EStateVar(m).with_type(m.type)
        return EHasKey(m, x).with_type(BOOL)
    elif isinstance(xs, EBinOp) and xs.op == "-" and isinstance(xs.e1, EStateVar) and isinstance(xs.e2, ESingleton):
        return optimized_cond(
            optimized_eq(x, xs.e2.e),
            EGt(optimized_count(x, xs.e1), ONE),
            optimized_in(x, xs.e1)).with_type(BOOL)
    elif isinstance(xs, EBinOp) and xs.op == "-":
        return EGt(optimized_count(x, xs.e1), optimized_count(x, xs.e2))
    elif isinstance(xs, EBinOp) and xs.op == "+":
        return EAny([optimized_in(x, xs.e1), optimized_in(x, xs.e2)])
    elif isinstance(xs, ECond):
        return optimized_cond(xs.cond,
            optimized_in(x, xs.then_branch),
            optimized_in(x, xs.else_branch)).with_type(BOOL)
    elif isinstance(xs, EFilter):
        return EAll([xs.predicate.apply_to(x), optimized_in(x, xs.e)])
    elif isinstance(xs, EMap) and xs.transform_function.arg not in free_vars(x):
        return optimized_any_matches(xs.e,
            ELambda(xs.transform_function.arg, optimized_eq(xs.transform_function.body, x)))
    elif isinstance(xs, ESingleton):
        return optimized_eq(x, xs.e)
    elif isinstance(xs, EEmptyList):
        return EFALSE
    else:
        return EBinOp(x, BOp.In, xs).with_type(BOOL)
def deps(thing):
    if isinstance(thing, str):
        yield from free_vars(self.query_impls[thing])
    elif isinstance(thing, EVar):
        for op in self.op_specs:
            yield self.updates[(thing, op.name)]
    elif isinstance(thing, Stm):
        yield from self.queries_used_by(thing)
    else:
        raise ValueError(repr(thing))
def generalize(self, fvs):
    if self.var not in fvs:
        return self._parent.generalize(fvs)
    new_parent = self._parent.generalize(fvs - { self.var } | free_vars(self.bag))
    if new_parent is self._parent:
        return self
    return UnderBinder(
        new_parent,
        self.var,
        self.bag,
        self.pool)
def fold_into_map(e, context):
    fvs = free_vars(e)
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in all_subexpressions_with_context_information(e, context, RUNTIME_POOL):
        if isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar):
            map = subexp.map.e
            key = subexp.key
            key_type = key.type
            value_type = subexp.type
            # e is of the form `... EStateVar(map)[key] ...`
            arg = fresh_var(subexp.type, omit=fvs)
            func = ELambda(arg, replace(
                e, context, RUNTIME_POOL,
                subexp, subcontext, subpool,
                arg))
            if not all(v in state_vars for v in free_vars(func)):
                continue
            func = strip_EStateVar(func)
            new_map = map_values(map, func.apply_to)
            yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
def fold_into_map(e, context):
    fvs = free_vars(e)
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in shred(e, context, RUNTIME_POOL):
        if isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar):
            map = subexp.map.e
            key = subexp.key
            key_type = key.type
            value_type = subexp.type
            # e is of the form `... EStateVar(map)[key] ...`
            arg = fresh_var(subexp.type, omit=fvs)
            func = ELambda(
                arg,
                replace(e, context, RUNTIME_POOL, subexp, subcontext, subpool, arg))
            if not all(v in state_vars for v in free_vars(func)):
                continue
            func = strip_EStateVar(func)
            new_map = map_values(map, func.apply_to)
            yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
def add_query(self, q : Query):
    """Given a query in terms of abstract state, add an initial concrete
    implementation."""
    print("Adding query {}...".format(q.name))
    self.query_specs.append(q)
    fvs = free_vars(q)
    # initial rep
    qargs = set(EVar(a).with_type(t) for (a, t) in q.args)
    rep, ret = tease_apart(wrap_naked_statevars(q.ret, self.abstract_state))
    self.set_impl(q, rep, ret)
def adapt(self, e : Exp, ctx, e_fvs=None) -> Exp:
    if self == ctx:
        return e
    if e_fvs is None:
        e_fvs = free_vars(e)
    if isinstance(ctx, UnderBinder):
        if ctx.var not in e_fvs:
            return self.adapt(e, ctx.parent(), e_fvs=e_fvs)
        if alpha_equivalent(self.bag, self._parent.adapt(ctx.bag, ctx._parent)):
            e = self._parent.adapt(e, ctx._parent, e_fvs=e_fvs)
            return subst(e, { ctx.var.id : self.var })
    return self._parent.adapt(e, ctx, e_fvs=e_fvs)
def __init__(self, ctx: SynthCtx, state: [EVar], assumptions: [Exp],
             q: Query, k, hints: [Exp] = [], examples: [dict] = None):
    super().__init__()
    self.ctx = ctx
    self.state = state
    self.assumptions = assumptions
    self.q = shallow_copy(q)
    assert all(v in state for v in free_vars(q)), \
        "Oops, query looks malformed due to {}:\n{}\nfree_vars({})".format(
            [v for v in free_vars(q) if v not in state], pprint(q), repr(q))
    q.ret = wrap_naked_statevars(q.ret, OrderedSet(state))
    self.hints = hints
    self.examples = examples
    self.k = k
def cleanup(self):
    """Remove unused state, queries, and updates."""

    def deps(thing):
        if isinstance(thing, str):
            yield from free_vars(self.query_impls[thing])
        elif isinstance(thing, EVar):
            for op in self.op_specs:
                yield self.updates[(thing, op.name)]
        elif isinstance(thing, Stm):
            yield from self.queries_used_by(thing)
        else:
            raise ValueError(repr(thing))

    g = DirectedGraph(
        nodes=itertools.chain(
            self.query_impls.keys(),
            (v for v, _ in self._concretization_functions),
            self.updates.values()),
        successors=deps)
    roots = [q.name for q in self.query_specs if q.visibility == Visibility.Public]
    roots.extend(itertools.chain(*[
        self.queries_used_by(code)
        for ((ht, op_name), code) in self.handle_updates.items()]))
    queries_to_keep = set(q for q in g.reachable_nodes(roots) if isinstance(q, str))

    # remove old specs
    for q in list(self.query_specs):
        if q.name not in queries_to_keep:
            self.query_specs.remove(q)

    # remove old implementations
    for qname in list(self.query_impls.keys()):
        if qname not in queries_to_keep:
            del self.query_impls[qname]

    # remove old state vars
    self._concretization_functions = [
        v for v in self._concretization_functions
        if any(v[0] in free_vars(q) for q in self.query_impls.values())]

    # remove old method implementations
    for k in list(self.updates.keys()):
        v, op_name = k
        if v not in [var for (var, exp) in self._concretization_functions]:
            del self.updates[k]
def _simple_filter(xs, p, args):
    if p.body == T:
        return xs
    if p.body == F:
        return EEmptyList().with_type(xs.type)
    if isinstance(xs, EEmptyList):
        return xs
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(p)):
        return EStateVar(EFilter(xs.e, strip_EStateVar(p)).with_type(xs.type)).with_type(xs.type)
    if isinstance(xs, EMapGet) and isinstance(xs.map, EStateVar) and not any(v in args for v in free_vars(p)):
        m = map_values(xs.map.e, lambda ys: _simple_filter(ys, p, args))
        return EMapGet(EStateVar(m).with_type(m.type), xs.key).with_type(xs.type)
    if isinstance(xs, EBinOp) and xs.op == "+":
        return EBinOp(_simple_filter(xs.e1, p, args), "+", _simple_filter(xs.e2, p, args)).with_type(xs.type)
    if isinstance(xs, EBinOp) and xs.op == "-":
        return EBinOp(_simple_filter(xs.e1, p, args), "-", _simple_filter(xs.e2, p, args)).with_type(xs.type)
    if isinstance(p.body, EBinOp) and p.body.op == "==":
        fvs2 = free_vars(p.body.e2)
        fvs1 = free_vars(p.body.e1)
        if p.arg in fvs1 and not any(a in fvs1 for a in args) and p.arg not in fvs2 and isinstance(xs, EStateVar):
            k = fresh_var(p.body.e1.type)
            e = EMapGet(
                EStateVar(
                    EMakeMap2(
                        EMap(xs.e, ELambda(p.arg, p.body.e1)),
                        ELambda(k, EFilter(xs.e, ELambda(p.arg, EEq(p.body.e1, k)))))),
                p.body.e2)
            res = retypecheck(e)
            assert res
            return e
    return EFilter(xs, p).with_type(xs.type)
def retypecheck(exp, env=None):
    from cozy.syntax_tools import free_vars
    if env is None:
        env = {v.id: v.type for v in free_vars(exp)}
    for e in all_exps(exp):
        if isinstance(e, syntax.EEnumEntry):
            env[e.name] = e.type
    errs = typecheck(exp, env=env)
    if errs:
        print("errors")
        for e in errs:
            print(" --> {}".format(e))
    return not errs
def __init__(self, ctx: SynthCtx, state: [EVar], assumptions: [Exp],
             q: Query, k, hints: [Exp] = [], funcs: {str: TFunc} = {}):
    assert all(v in state for v in free_vars(q)), \
        "Oops, query looks malformed due to {}:\n{}\nfree_vars({})".format(
            [v for v in free_vars(q) if v not in state], pprint(q), repr(q))
    super().__init__()
    self.ctx = ctx
    self.state = state
    self.assumptions = assumptions
    q = shallow_copy(q)
    q.ret = wrap_naked_statevars(q.ret, OrderedSet(state))
    self.q = q
    self.hints = hints
    self.k = k
    self.funcs = OrderedDict(funcs)
def _satisfy(e, solver, assumptions):
    """Heuristic to decide quickly whether e is satisfiable.

    :param e: expression to test for satisfiability
    :param solver: the default solver
    :param assumptions: a list of expressions that are assumed true

    This is a partial procedure: the possible outputs are a satisfying
    assignment or None (indicating "unknown").  It is allowed to indicate
    "unknown" by raising an arbitrary exception (in which case falling back
    to the symbolic solver is a reasonable choice).
    """
    if isinstance(e, EUnaryOp) and e.op == "not" and isinstance(e.e, EBinOp) and e.e.op == "==":
        e1 = e.e.e1
        e2 = e.e.e2
        if isinstance(e1, EFlatMap) and isinstance(e2, EFlatMap):
            lc1 = extract_listcomp(e1)
            lc2 = extract_listcomp(e2)
            if lc1 is not None and lc2 is not None:
                cond1 = get_cond(lc1)
                cond2 = get_cond(lc2)
                sat1 = solver.satisfy(cond1)
                sat2 = solver.satisfy(cond2)
                if sat1 is None and sat2 is not None:
                    return {k: v for k, v in sat2.items() if k not in lc2["P"]}
                if sat1 is not None and sat2 is None:
                    return {k: v for k, v in sat1.items() if k not in lc1["P"]}
    iterables = [random_value(v.type) for v in free_vars(e)]
    ids = [v.id for v in free_vars(e)]
    for vs in product(*iterables):
        assignments = {}
        for id_, val in zip(ids, vs):
            assignments[id_] = val
        sat = eval(EAll([e] + assumptions), assignments)
        if sat:
            return assignments
    return None
def _compare(self, e1: Exp, e2: Exp, context: Context):
    e1_constant = not free_vars(e1) and not free_funcs(e1)
    e2_constant = not free_vars(e2) and not free_funcs(e2)
    if e1_constant and e2_constant:
        e1v = eval(e1, {})
        e2v = eval(e2, {})
        event("comparison obvious on constants: {} vs {}".format(e1v, e2v))
        return order_objects(e1v, e2v)
    if alpha_equivalent(e1, e2):
        event("shortcutting comparison of identical terms")
        return Order.EQUAL
    path_condition = EAll(context.path_conditions())
    always_le = self.solver.valid(EImplies(path_condition, ELe(e1, e2)))
    always_ge = self.solver.valid(EImplies(path_condition, EGe(e1, e2)))
    if always_le and always_ge:
        return Order.EQUAL
    if always_le:
        return Order.LT
    if always_ge:
        return Order.GT
    return Order.AMBIGUOUS
def storage_size(e, freebies : [Exp] = []):
    h = extension_handler(type(e.type))
    if h is not None:
        return h.storage_size(e, storage_size=storage_size)
    if e in freebies:
        return ZERO
    elif e.type == BOOL:
        return ONE
    elif is_numeric(e.type) or isinstance(e.type, THandle):
        return FOUR
    elif isinstance(e.type, TEnum):
        return TWO
    elif isinstance(e.type, TNative):
        return FOUR
    elif isinstance(e.type, TString):
        return TWENTY
    elif isinstance(e.type, TTuple):
        return ESum([storage_size(ETupleGet(e, n).with_type(t)) for (n, t) in enumerate(e.type.ts)])
    elif isinstance(e.type, TRecord):
        return ESum([storage_size(EGetField(e, f).with_type(t)) for (f, t) in e.type.fields])
    elif is_collection(e.type):
        v = fresh_var(e.type.elem_type, omit=free_vars(e))
        return ESum([
            FOUR,
            EUnaryOp(UOp.Sum, EMap(e, ELambda(v, storage_size(v))).with_type(INT_BAG)).with_type(INT)])
    elif isinstance(e.type, TMap):
        k = fresh_var(e.type.k, omit=free_vars(e))
        return ESum([
            FOUR,
            EUnaryOp(UOp.Sum, EMap(
                EMapKeys(e).with_type(TBag(e.type.k)),
                ELambda(k, ESum([
                    storage_size(k),
                    storage_size(EMapGet(e, k).with_type(e.type.v))]))).with_type(INT_BAG)).with_type(INT)])
    else:
        raise NotImplementedError(e.type)
def map_accelerate(e, state_vars, binders, args, cache, size):
    for (_, arg, f, bound) in enumerate_fragments(strip_EStateVar(e)):
        if any(v in state_vars for v in free_vars(arg)):
            continue
        for binder in (b for b in binders if b.type == arg.type and b not in bound):
            value = f(binder)
            if any(v not in state_vars and v not in binders for v in free_vars(value)):
                continue
            for bag in cache.find_collections(pool=STATE_POOL, size=size, of=arg.type):
                if isinstance(bag, EEmptyList):
                    continue
                m = EMakeMap2(bag, ELambda(binder, value)).with_type(TMap(arg.type, e.type))
                assert not any(v in args for v in free_vars(m))
                if any(v in binders for v in free_vars(m)):
                    continue
                yield (m, STATE_POOL)
                yield (EMapGet(EStateVar(m).with_type(m.type), arg).with_type(e.type), RUNTIME_POOL)
def replace_get_value(e : syntax.Exp, ptr : syntax.Exp, new_value : syntax.Exp) -> syntax.Exp:
    """Return an expression representing the value of `e` after writing
    `new_value` to `ptr`.

    This amounts to replacing all instances of `_.val` in `e` with

        (_ == ptr) ? (new_value) : (_.val)
    """
    t = ptr.type
    fvs = free_vars(ptr) | free_vars(new_value)

    class V(BottomUpRewriter):
        def visit_ELambda(self, e):
            if e.arg in fvs:
                v = fresh_var(e.arg.type, omit=fvs)
                e = syntax.ELambda(v, e.apply_to(v))
            return syntax.ELambda(e.arg, self.visit(e.body))

        def visit_EGetField(self, e):
            ee = self.visit(e.e)
            res = syntax.EGetField(ee, e.field_name).with_type(e.type)
            if e.e.type == t and e.field_name == "val":
                res = syntax.ECond(syntax.EEq(ee, ptr), new_value, res).with_type(e.type)
            return res

    return V().visit(e)
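# Shape of the rewrite, for illustration only (h, ptr, and new_value are
# hypothetical names):
#
#   replace_get_value(EGetField(h, "val"), ptr, new_value)
#       ==>  ECond(EEq(h, ptr), new_value, EGetField(h, "val"))
#
# i.e. every read of a handle's .val field is guarded by a pointer-equality
# test against the written-to handle, exactly as the docstring describes.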
def retypecheck(exp, env=None):
    if env is None:
        env = { v.id : v.type for v in free_vars(exp) }
    fenv = { }
    for e in all_exps(exp):
        if isinstance(e, syntax.EEnumEntry):
            env[e.name] = e.type
        if isinstance(e, syntax.ECall):
            fenv[e.func] = (tuple(arg.type for arg in e.args), e.type)
    errs = typecheck(exp, env=env, fenv=fenv)
    if errs:
        print("errors")
        for e in errs:
            print(" --> {}".format(e))
    return not errs
def _possible_replacements(self, e, pool, cost):
    """Yields watched expressions that appear as worse versions of the given
    expression.  There may be more than one."""
    # return
    free_binders = OrderedSet(v for v in free_vars(e) if v in self.binders)
    for ctx in self._watched_contexts(pool, e.type):
        watched_e = ctx.e
        p = ctx.pool
        r = ctx.replace_e_with

        assert e.type == watched_e.type
        assert p == pool
        _on_exp(e, "considering replacement of", watched_e)
        # if e.type != watched_e.type:
        #     # _on_exp(e, "wrong type")
        #     continue
        # if p != pool:
        #     # _on_exp(e, "wrong pool")
        #     continue
        if e == watched_e:
            # _on_exp(e, "no change")
            continue

        unbound_binders = [b for b in free_binders if b not in ctx.bound_vars]
        if unbound_binders:
            _on_exp(e, "skipped exp with free binders", ", ".join(b.id for b in unbound_binders))
            continue

        if CHECK_SUBST_COST:
            watched_cost = self.cost_model.cost(watched_e, pool=pool)
            ordering = self.compare_costs(cost, watched_cost)
            if ordering == Cost.WORSE:
                _on_exp(e, "skipped worse replacement", pool_name(pool), watched_e)
                continue
            if ordering == Cost.UNORDERED:
                _on_exp(e, "skipped equivalent replacement", pool_name(pool), watched_e)
                # print("  e1 = {!r}".format(e))
                # print("  e2 = {!r}".format(watched_e))
                continue

        # assert all(eval_bulk(self.assumptions, self.all_examples))
        if all(eval_bulk(EEq(self.target, r(e)), self.all_examples)):
            yield (watched_e, e, ctx.facts, r)
        else:
            _on_exp(e, "visited pointless replacement", watched_e)
def eval_bulk(e: Exp, envs: [{str: object}], use_default_values_for_undefined_vars: bool = False):
    """Evaluate an expression in many different environments.

    This function accepts the same arguments as `eval`, but takes a list of
    environments instead of just one.  The call eval_bulk(e, envs) is
    equivalent to [eval(e, env) for env in envs].  However, using `eval_bulk`
    is much faster than repeatedly calling `eval` on the same expression.
    """
    e = purify(e)
    if not envs:
        return []
    ops = []
    vars = OrderedSet(free_vars_and_funcs(e))
    types = {v.id: v.type for v in free_vars(e)}
    vmap = {v: i for (i, v) in enumerate(vars)}
    try:
        envs = [
            [(env.get(v, mkval(types[v])) if (use_default_values_for_undefined_vars and v in types) else env[v])
             for v in vars]
            for env in envs]
    except KeyError:
        import sys
        print("OH NO", file=sys.stderr)
        print("e = {}".format(pprint(e)), file=sys.stderr)
        print("eval_bulk({!r}, {!r}, use_default_values_for_undefined_vars={!r})".format(
            e, envs, use_default_values_for_undefined_vars), file=sys.stderr)
        raise
    _compile(e, vmap, ops)
    return [_eval_compiled(ops, env) for env in envs]
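# A usage sketch (assumed, not part of the source): evaluate x + 1 under
# three different bindings for x in a single call.  The constructors and
# `INT` are the cozy.syntax names used throughout these snippets.
x = EVar("x").with_type(INT)
e = EBinOp(x, "+", ENum(1).with_type(INT)).with_type(INT)
results = eval_bulk(e, [{"x": 1}, {"x": 2}, {"x": 3}])
# Per the docstring, this behaves like [eval(e, env) for env in envs],
# so `results` should be [2, 3, 4].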
def optimized_any_matches(xs, p):
    if isinstance(xs, EEmptyList):
        return EFALSE
    if isinstance(xs, ESingleton):
        return p.apply_to(xs.e)
    if isinstance(xs, EMap):
        return optimized_any_matches(xs.e, compose(p, xs.transform_function))
    # exists filter (not-in xs) ys
    if isinstance(p.body, EUnaryOp) and p.body.op == UOp.Not and isinstance(p.body.e, EBinOp) and p.body.e.op == BOp.In:
        if p.arg not in free_vars(p.body.e.e2):
            # er, this only works when xs is a subset of ys
            return EGt(optimized_len(xs), optimized_len(p.body.e.e2))
    if isinstance(p.body, EBinOp) and p.body.op == BOp.Or:
        return EAny([
            optimized_any_matches(xs, ELambda(p.arg, p.body.e1)).with_type(xs.type),
            optimized_any_matches(xs, ELambda(p.arg, p.body.e2)).with_type(xs.type)])
    if isinstance(xs, EFilter):
        return optimized_any_matches(xs.e, ELambda(p.arg, EAll([p.body, xs.predicate.apply_to(p.arg)])))
    if isinstance(xs, EBinOp) and xs.op == "+":
        return EAny([optimized_any_matches(xs.e1, p), optimized_any_matches(xs.e2, p)])
    if isinstance(xs, EBinOp) and xs.op == "-":
        return EAll([
            optimized_any_matches(xs.e1, p),
            ENot(optimized_any_matches(xs.e2, p))])
    if isinstance(xs, ECond):
        return optimized_cond(xs.cond,
            optimized_any_matches(xs.then_branch, p),
            optimized_any_matches(xs.else_branch, p)).with_type(BOOL)
    return EUnaryOp(UOp.Exists, EFilter(xs, p).with_type(xs.type)).with_type(BOOL)
def visit_ELambda(self, e):
    if e.arg in binders_by_type[e.arg.type]:
        return super().visit_ADT(e)
    fvs = free_vars(e.body)
    legal_repls = [b for b in binders_by_type[e.arg.type] if b not in fvs]
    if not legal_repls:
        if allow_add:
            print("Adding aux binder {} and returning {}".format(e.arg, pprint(ELambda(e.arg, e.body))), file=sys.stderr)
            binders_to_use.append(e.arg)
            binders_by_type[e.arg.type].append(e.arg)
            return ELambda(e.arg, self.visit(e.body))
        else:
            if throw:
                print("No legal binder to use for {}".format(pprint(e)))
                raise Exception(pprint(e))
            else:
                return ELambda(e.arg, self.visit(e.body))
    b = legal_repls[0]
    return ELambda(b, self.visit(subst(e.body, { e.arg.id : b })))
def _setup_handle_updates(self):
    """This method creates update code for handle objects modified by each op.
    Must be called once after all user-specified queries have been added."""
    for op in self.op_specs:
        print("Setting up handle updates for {}...".format(op.name))
        handles = reachable_handles_at_method(self.spec, op)
        # print("-"*60)
        for t, bag in handles.items():
            # print("  {} : {}".format(pprint(t), pprint(bag)))
            h = fresh_var(t)
            lval = EGetField(h, "val").with_type(t.value_type)
            new_val = inc.mutate(lval, op.body)

            # get set of modified handles
            modified_handles = Query(
                fresh_name("modified_handles"),
                Visibility.Internal, [], op.assumptions,
                EFilter(EUnaryOp(UOp.Distinct, bag).with_type(bag.type),
                        ELambda(h, ENot(EEq(lval, new_val)))).with_type(bag.type),
                "[{}] modified handles of type {}".format(op.name, pprint(t)))
            query_vars = [v for v in free_vars(modified_handles) if v not in self.abstract_state]
            modified_handles.args = [(arg.id, arg.type) for arg in query_vars]

            # modify each one
            subqueries = []
            state_update_stm = inc.mutate_in_place(
                lval,
                lval,
                op.body,
                abstract_state=self.abstract_state,
                assumptions=list(op.assumptions) + [EDeepIn(h, bag), EIn(h, modified_handles.ret)],
                invariants=self.abstract_invariants,
                subgoals_out=subqueries)
            for sub_q in subqueries:
                sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                state_update_stm = self._add_subquery(sub_q=sub_q, used_by=state_update_stm)
            if state_update_stm != SNoOp():
                state_update_stm = SForEach(h, ECall(modified_handles.name, query_vars).with_type(bag.type), state_update_stm)
                state_update_stm = self._add_subquery(sub_q=modified_handles, used_by=state_update_stm)
            self.handle_updates[(t, op.name)] = state_update_stm
def EMax(es):
    es = make_random_access(es)
    assert es
    assert all(isinstance(e, Exp) for e in es), es
    res = es[0]
    t = res.type
    fvs = set(v.id for v in free_vars(res))
    for i in range(1, len(es)):
        res = maybe_inline(res, lambda v1:
              maybe_inline(es[i], lambda v2:
                  ECond(EGt(v1, v2), v1, v2).with_type(t)))
        # v1 = fresh_var(res.type, omit=fvs)
        # fvs.add(v1.id)
        # v2 = fresh_var(res.type, omit=fvs)
        # fvs.add(v2.id)
        # fvs |= set(v.id for v in free_vars(es[i]))
        # res = ELet(res, ELambda(v1,
        #     ELet(es[i], ELambda(v2,
        #         ECond(EGt(v1, v2), v1, v2).with_type(res.type))).with_type(res.type))).with_type(res.type)
        # res = ECond(EGt(res, es[i]), res, es[i]).with_type(res.type)
    return res
def eval_bulk(
        e : Exp,
        envs : [{str:object}],
        use_default_values_for_undefined_vars : bool = False):
    """Evaluate an expression in many different environments.

    This function accepts the same arguments as `eval`, but takes a list of
    environments instead of just one.  The call eval_bulk(e, envs) is
    equivalent to [eval(e, env) for env in envs].  However, using `eval_bulk`
    is much faster than repeatedly calling `eval` on the same expression.
    """
    if not envs:
        return []
    e = purify(e)
    ops = []
    types = { v.id : v.type for v in free_vars(e) }
    vars = OrderedSet(itertools.chain(types.keys(), free_funcs(e).keys()))
    vmap = { v : i for (i, v) in enumerate(vars) }
    try:
        envs = [
            [(env.get(v, mkval(types[v])) if (use_default_values_for_undefined_vars and v in types) else env[v])
             for v in vars]
            for env in envs]
    except KeyError:
        import sys
        print("OH NO", file=sys.stderr)
        print("e = {}".format(pprint(e)), file=sys.stderr)
        print("eval_bulk({!r}, {!r}, use_default_values_for_undefined_vars={!r})".format(
            e, envs, use_default_values_for_undefined_vars), file=sys.stderr)
        raise
    _compile(e, vmap, ops)
    return [_eval_compiled(ops, env) for env in envs]
def _setup_handle_updates(self):
    """This method creates update code for handle objects modified by each op.
    Must be called once after all user-specified queries have been added."""
    for op in self.op_specs:
        print("Setting up handle updates for {}...".format(op.name))
        handles = reachable_handles_at_method(self.spec, op)
        # print("-"*60)
        for t, bag in handles.items():
            # print("  {} : {}".format(pprint(t), pprint(bag)))
            h = fresh_var(t)
            lval = EGetField(h, "val").with_type(t.value_type)
            new_val = inc.mutate(lval, op.body)

            # get set of modified handles
            modified_handles = Query(
                fresh_name("modified_handles"),
                Visibility.Internal, [], op.assumptions,
                EFilter(EUnaryOp(UOp.Distinct, bag).with_type(bag.type),
                        ELambda(h, ENot(EEq(lval, new_val)))).with_type(bag.type),
                "[{}] modified handles of type {}".format(op.name, pprint(t)))
            query_vars = [v for v in free_vars(modified_handles) if v not in self.abstract_state]
            modified_handles.args = [(arg.id, arg.type) for arg in query_vars]

            # modify each one
            subqueries = []
            state_update_stm = inc.mutate_in_place(
                lval,
                lval,
                op.body,
                abstract_state=self.abstract_state,
                assumptions=list(op.assumptions) + [EDeepIn(h, bag), EIn(h, modified_handles.ret)],
                subgoals_out=subqueries)
            for sub_q in subqueries:
                sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                state_update_stm = self._add_subquery(sub_q=sub_q, used_by=state_update_stm)
            if state_update_stm != SNoOp():
                state_update_stm = SForEach(h, ECall(modified_handles.name, query_vars).with_type(bag.type), state_update_stm)
                state_update_stm = self._add_subquery(sub_q=modified_handles, used_by=state_update_stm)
            self.handle_updates[(t, op.name)] = state_update_stm
def optimized_map(xs, f, args):
    res_type = type(xs.type)(f.body.type)
    if f.arg == f.body:
        yield xs
    if isinstance(xs, ESingleton):
        yield ESingleton(f.apply_to(xs.e)).with_type(res_type)
    if isinstance(xs, EEmptyList):
        yield EEmptyList().with_type(res_type)
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(f)):
        yield EStateVar(EMap(xs.e, f).with_type(res_type)).with_type(res_type)
    if isinstance(xs, EBinOp):
        if xs.op in ("+", "-"):
            for a in optimized_map(xs.e1, f, args):
                for b in optimized_map(xs.e2, f, args):
                    yield EBinOp(a, xs.op, b).with_type(res_type)
    if isinstance(f.body, ECond):
        for true_elems in optimized_filter(xs, ELambda(f.arg, f.body.cond), args=args):
            for false_elems in optimized_filter(xs, ELambda(f.arg, ENot(f.body.cond)), args=args):
                for a in optimized_map(true_elems, ELambda(f.arg, f.body.then_branch), args):
                    for b in optimized_map(false_elems, ELambda(f.arg, f.body.else_branch), args):
                        yield EBinOp(a, "+", b).with_type(res_type)
    yield EMap(xs, f).with_type(res_type)
def __init__(self, parent : Context, v : EVar, bag : Exp, bag_pool : Pool):
    """Construct a context under a binder.

    Parameters:
        parent   - the context of the enclosing lambda
        v        - the bound variable
        bag      - the bag that v must be a member of
        bag_pool - the pool that the bag belongs to

    `v`'s type annotation must be the same as the type of elements in the
    bag.  `bag` must be legal in the parent context.  `v` must not already
    be described by the parent context.
    """
    assert v.type == bag.type.elem_type
    assert parent.legal_for(free_vars(bag)), "cannot create context for {} in {}, {}".format(v.id, pprint(bag), parent)
    assert not any(v == vv for vv, p in parent.vars()), "binder {} already free in {}".format(v.id, parent)
    self._parent = parent
    self.var = v
    self.bag = bag
    self.pool = bag_pool
def __init__(self, parent : Context, v : EVar, bag : Exp, bag_pool : Pool):
    """Construct a context under a binder.

    Parameters:
        parent   - the context of the enclosing lambda
        v        - the bound variable
        bag      - the bag that v must be a member of
        bag_pool - the pool that the bag belongs to

    `v`'s type annotation must be the same as the type of elements in the
    bag.  `bag` must be legal in the parent context.  `v` must not already
    be described by the parent context.
    """
    assert v.type == bag.type.t
    assert parent.legal_for(free_vars(bag)), "cannot create context for {} in {}, {}".format(v.id, pprint(bag), parent)
    assert not any(v == vv for vv, p in parent.vars()), "binder {} already free in {}".format(v.id, parent)
    self._parent = parent
    self.var = v
    self.bag = bag
    self.pool = bag_pool
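# Illustration only (names are hypothetical): the context for the bound
# variable of EFilter(xs, ELambda(x, ...)) explored at runtime would be
# built roughly as
#
#   UnderBinder(parent_ctx, x, xs, RUNTIME_POOL)
#
# where x.type must equal the element type of xs, per the asserts above.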
def __init__(self, examples, cost_model: CostModel, check_wf=None, hints=None, heuristics=None, stop_callback=None):
    self.examples = list(examples)
    self.cost_model = cost_model
    self.cache = {}  # keys -> [exp]
    self.seen = {}   # (ctx, pool, fp) -> frontier, i.e. [exp]
    self.in_progress = set()
    if check_wf is None:
        check_wf = lambda e, ctx, pool: True
    self.check_wf = check_wf
    self.hints = OrderedSet((e, ctx.generalize(free_vars(e)), p) for (e, ctx, p) in hints)
    if heuristics is None:
        heuristics = lambda e, ctx, pool: ()
    self.heuristics = heuristics
    if stop_callback is None:
        stop_callback = lambda: False
    self.stop_callback = stop_callback
def eval_bulk(e, envs, use_default_values_for_undefined_vars: bool = False):
    e = purify(e)
    if not envs:
        return []
    ops = []
    vars = OrderedSet(free_vars_and_funcs(e))
    types = {v.id: v.type for v in free_vars(e)}
    vmap = {v: i for (i, v) in enumerate(vars)}
    try:
        envs = [
            [(env.get(v, mkval(types[v])) if (use_default_values_for_undefined_vars and v in types) else env[v])
             for v in vars]
            for env in envs]
    except KeyError:
        import sys
        print("OH NO", file=sys.stderr)
        print("e = {}".format(pprint(e)), file=sys.stderr)
        print("eval_bulk({!r}, {!r}, use_default_values_for_undefined_vars={!r})".format(
            e, envs, use_default_values_for_undefined_vars), file=sys.stderr)
        raise
    _compile(e, vmap, ops)
    return [_eval_compiled(ops, env) for env in envs]
def optimized_best(xs, keyfunc, op, args):
    argbest = EArgMin if op == "<" else EArgMax
    elem_type = xs.type.elem_type
    key_type = keyfunc.body.type
    if excluded_element(xs, args) is not None:
        bag, x = excluded_element(xs, args)
        if all(v not in args for v in free_vars(bag)):
            heap_type, make_heap = (TMinHeap, EMakeMinHeap) if op == "<" else (TMaxHeap, EMakeMaxHeap)
            bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
            for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
                prev_min = EStateVar(prev_min).with_type(elem_type)
                heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type)).with_type(elem_type)
                conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
                if isinstance(x, EUnaryOp) and x.op == UOp.The:
                    conds = [optimized_exists(x.e)] + conds
                yield optimized_cond(
                    EAll(conds),
                    heap_peek,
                    prev_min)
    if isinstance(xs, EEmptyList):
        yield construct_value(elem_type)
    if isinstance(xs, ESingleton):
        yield xs.e
    if isinstance(xs, EBinOp) and xs.op == "+":
        a_ex = optimized_exists(xs.e1)
        b_ex = optimized_exists(xs.e2)
        bag_type = TBag(xs.type.elem_type)
        for a in optimized_best(xs.e1, keyfunc, op, args=args):
            for b in optimized_best(xs.e2, keyfunc, op, args=args):
                yield optimized_cond(a_ex,
                    optimized_cond(b_ex,
                        argbest(EBinOp(ESingleton(a).with_type(bag_type), "+", ESingleton(b).with_type(bag_type)).with_type(bag_type), keyfunc).with_type(elem_type),
                        a),
                    optimized_cond(b_ex, b, construct_value(elem_type)))
        # if isinstance(xs.e1, EStateVar) or isinstance(xs.e2, EStateVar):
        #     sv, other = (xs.e1, xs.e2) if isinstance(xs.e1, EStateVar) else (xs.e2, xs.e1)
        #     sv_best = optimized_best(sv, keyfunc, op, args=args)
        #     yield optimized_cond(
        #         optimized_exists(sv),
        #         argbest(EBinOp(ESingleton(sv_best).with_type(xs.type), "+", other).with_type(xs.type), keyfunc).with_type(elem_type),
        #         optimized_best(other, keyfunc, op, args=args))
        # else:
        #     parts = break_sum(xs)
        #     found = EFALSE
        #     best = construct_value(elem_type)
        #     for p in parts:
        #         ex = optimized_exists(p)
        #         best_here = optimized_best(p, keyfunc, op, args=args)
        #         best = optimized_cond(found,
        #             optimized_cond(ex,
        #                 optimized_cond(EBinOp(keyfunc.apply_to(best_here), op, keyfunc.apply_to(best)).with_type(BOOL),
        #                     best_here,
        #                     best),
        #                 best),
        #             best_here)
        #         found = EAny([found, ex])
        #     yield best
    if isinstance(xs, EMap):
        for b in optimized_best(xs.e, compose(keyfunc, xs.transform_function), op, args):
            yield optimized_cond(optimized_exists(xs.e),
                xs.transform_function.apply_to(b),
                construct_value(elem_type))
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(keyfunc)):
        yield EStateVar(argbest(xs.e, keyfunc).with_type(elem_type)).with_type(elem_type)
    if isinstance(xs, ECond):
        for a in optimized_best(xs.then_branch, keyfunc, op, args=args):
            for b in optimized_best(xs.else_branch, keyfunc, op, args=args):
                yield optimized_cond(xs.cond, a, b)
    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        yield from optimized_best(xs.e, keyfunc, op, args=args)
    # if isinstance(xs, EFilter):
    #     yield optimized_cond(
    #         xs.predicate.apply_to(optimized_best(xs.e, keyfunc, op, args=args)),
    #         optimized_best(xs.e, keyfunc, op, args=args),
    #         argbest(xs, keyfunc).with_type(elem_type))
    yield argbest(xs, keyfunc).with_type(elem_type)
def visit_Spec(self, spec : Spec, state_exps : { str : Exp }, sharing, abstract_state=()):
    self.state_exps = state_exps
    self.funcs = { f.name: f for f in spec.extern_funcs }
    self.queries = { q.name: q for q in spec.methods if isinstance(q, Query) }
    self.vars = set(e.id for e in all_exps(spec) if isinstance(e, EVar))

    self.write("#pragma once\n")
    self.write("#include <algorithm>\n")
    self.write("#include <set>\n")
    self.write("#include <functional>\n")
    self.write("#include <vector>\n")
    self.write("#include <unordered_set>\n")
    self.write("#include <string>\n")
    if self.use_qhash:
        self.write("#include <QHash>\n")
    else:
        self.write("#include <unordered_map>\n")

    if spec.header:
        self.write("\n" + spec.header.strip() + "\n")

    self.write("{}\nclass {} {{\n".format(
        ("\n" + spec.docstring) if spec.docstring else "",
        spec.name))
    self.write("public:\n")

    print("Setting up auxiliary types...")
    self.setup_types(spec, state_exps, sharing)
    with self.indented():
        for t, name in self.types.items():
            self.define_type(spec.name, t, name, sharing)
            self.begin_statement()
            if isinstance(t, THandle):
                # No overridden hash code!  We use pointers instead.
                continue
            self.write("struct _Hash", name, " ")
            with self.block():
                self.write_stmt("typedef ", spec.name, "::", name, " argument_type;")
                self.write_stmt("typedef std::size_t result_type;")
                self.begin_statement()
                self.write("result_type operator()(const argument_type& x) const noexcept ")
                x = EVar("x").with_type(t)
                if isinstance(t, TEnum):
                    fields = [EEnumToInt(x).with_type(INT)]
                elif isinstance(t, TRecord):
                    fields = [EGetField(x, f).with_type(ft) for (f, ft) in t.fields]
                elif isinstance(t, TTuple):
                    fields = [ETupleGet(x, n).with_type(tt) for (n, tt) in enumerate(t.ts)]
                else:
                    raise NotImplementedError(t)
                with self.block():
                    self.visit(self.compute_hash(fields))
                self.end_statement()
            self.write(";")
            self.end_statement()

    print("Setting up member variables...")
    self.write("protected:\n")
    with self.indented():
        for name, t in spec.statevars:
            self.statevar_name = name
            self.declare_field(name, t)

    self.write("public:\n")
    with self.indented():
        print("Generating constructors...")

        # default constructor
        self.begin_statement()
        self.write("inline ", spec.name, "() ")
        with self.block():
            for name, t in spec.statevars:
                initial_value = state_exps[name]
                fvs = free_vars(initial_value)
                initial_value = subst(initial_value, {v.id : evaluation.construct_value(v.type) for v in fvs})
                stm = simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value))
                self.visit(stm)
        self.end_statement()

        # explicit constructor
        if abstract_state:
            self.begin_statement()
            self.write("explicit inline ", spec.name, "(")
            self.visit_args(abstract_state)
            self.write(") ")
            with self.block():
                for name, t in spec.statevars:
                    initial_value = state_exps[name]
                    self.visit(simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value)))
            self.end_statement()

        # disable copy constructor (TODO: support this in the future?)
        self.begin_statement()
        self.write(spec.name, "(const ", spec.name, "& other) = delete;")
        self.end_statement()

        # generate methods
        for op in spec.methods:
            print("Generating method {}...".format(op.name))
            self.visit(op)
    self.write("};\n")

    if spec.footer:
        self.write("\n", spec.footer)
        if not spec.footer.endswith("\n"):
            self.write("\n")
def possibly_useful_nonrecursive(solver, e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = ETRUE, ops : [Op] = ()) -> bool:
    """Heuristic filter to ignore expressions that are almost certainly useless."""
    state_vars  = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    args        = OrderedSet(v for v, p in context.vars() if p == RUNTIME_POOL)
    assumptions = EAll([assumptions, context.path_condition()])
    at_runtime  = pool == RUNTIME_POOL

    h = extension_handler(type(e))
    if h is not None:
        res = h.possibly_useful(e, context, pool, assumptions, ops, solver)
        if not res:
            return res

    if isinstance(e, EStateVar) and not free_vars(e.e):
        return No("constant value in state position")
    if (isinstance(e, EDropFront) or isinstance(e, EDropBack)) and not at_runtime:
        return No("EDrop* in state position")
    if not allow_big_sets.value and isinstance(e, EFlatMap) and not at_runtime:
        return No("EFlatMap in state position")
    if not allow_int_arithmetic_state.value and not at_runtime and isinstance(e, EBinOp) and e.type == INT:
        return No("integer arithmetic in state position")
    if is_collection(e.type) and not is_scalar(e.type.elem_type):
        return No("collection of nonscalar: e {}\n elem_type: {}\n".format(e, e.type.elem_type))
    if isinstance(e.type, TMap) and not is_scalar(e.type.k):
        return No("bad key type {}".format(pprint(e.type.k)))
    if isinstance(e.type, TMap) and isinstance(e.type.v, TMap):
        return No("map to map")
    # This check is probably a bad idea: whether `the` is legal may depend on
    # the context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    # if isinstance(e, EUnaryOp) and e.op == UOp.The:
    #     len = EUnaryOp(UOp.Length, e.e).with_type(INT)
    #     if not valid(EImplies(assumptions, EBinOp(len, "<=", ENum(1).with_type(INT)).with_type(BOOL))):
    #         return No("illegal application of 'the': could have >1 elems")
    if not at_runtime and isinstance(e, EBinOp) and e.op == "-" and is_collection(e.type):
        return No("collection subtraction in state position")
    # if not at_runtime and isinstance(e, ESingleton):
    #     return No("singleton in state position")
    if not allow_nonzero_state_constants.value and not at_runtime and isinstance(e, ENum) and e.val != 0:
        return No("nonzero integer constant in state position")
    if not allow_binop_state.value and at_runtime and isinstance(e, EStateVar) and isinstance(e.e, EBinOp) and is_scalar(e.e.e1.type) and is_scalar(e.e.e2.type):
        return No("constant-time binary operator {!r} in state position".format(e.e.op))
    if not allow_conditional_state.value and not at_runtime and isinstance(e, ECond):
        return No("conditional in state position")
    if isinstance(e, EMakeMap2) and isinstance(e.e, EEmptyList):
        return No("trivially empty map")
    if isinstance(e, EMakeMap2) and isinstance(e.e, ESingleton):
        return No("really tiny map")
    if not at_runtime and (isinstance(e, EArgMin) or isinstance(e, EArgMax)):
        # Cozy has no way to efficiently implement mins/maxes when more than
        # one element may leave the collection.
        from cozy.state_maintenance import mutate
        for op in ops:
            elems = e.e
            elems_prime = mutate(elems, op.body)
            formula = EAll([assumptions] + list(op.assumptions) + [EGt(ELen(EBinOp(elems, "-", elems_prime).with_type(elems.type)), ONE)])
            if solver.satisfiable(formula):
                return No("more than one element might be removed during {}".format(op.name))
    if not allow_peels.value and not at_runtime and isinstance(e, EFilter):
        # catch "peels": removal of zero or one elements
        if solver.valid(EImplies(assumptions, ELe(ELen(EFilter(e.e, ELambda(e.predicate.arg, ENot(e.predicate.body))).with_type(e.type)), ONE))):
            return No("filter is a peel")
    if not allow_big_maps.value and not at_runtime and isinstance(e, EMakeMap2) and is_collection(e.type.v):
        all_collections = [sv for sv in state_vars if is_collection(sv.type)]
        total_size = ENum(0).with_type(INT)
        for c in all_collections:
            total_size = EBinOp(total_size, "+", EUnaryOp(UOp.Length, c).with_type(INT)).with_type(INT)
        my_size = EUnaryOp(UOp.Length, EFlatMap(EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type), e.value_function).with_type(e.type.v)).with_type(INT)
        s = EImplies(
            assumptions,
            EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            return No("non-polynomial-sized map")

    return True
def build(self, cache, size):
    # print("Cache:")
    # for (e, sz, pool) in cache:
    #     from cozy.syntax_tools import pprint
    #     print("    @size={}, pool={}\t:\t{}".format(sz, pool, pprint(e)))
    binders_by_type = group_by(self.binders, lambda b: b.type)

    for pool in ALL_POOLS:
        if size == 1:
            yield self.check(T, pool)
            yield self.check(F, pool)
            yield self.check(ZERO, pool)
            yield self.check(ONE, pool)
            for b in self.binders:
                yield self.check(b, pool)
            if pool == STATE_POOL:
                for v in self.state_vars:
                    yield self.check(v, pool)
            elif pool == RUNTIME_POOL:
                for v in self.args:
                    yield self.check(v, pool)

        if not build_exprs.value:
            return

        for e in cache.find(pool=STATE_POOL, size=size - 1):
            if all(v in self.state_vars for v in free_vars(e)):
                yield self.check(EStateVar(e).with_type(e.type), RUNTIME_POOL)

        for e in cache.find(pool=pool, size=size - 1):
            t = TBag(e.type)
            yield self.check(EEmptyList().with_type(t), pool)
            yield self.check(ESingleton(e).with_type(t), pool)

        for e in cache.find(pool=pool, type=TRecord, size=size - 1):
            for (f, t) in e.type.fields:
                yield self.check(EGetField(e, f).with_type(t), pool)

        for e in cache.find_collections(pool=pool, size=size - 1):
            if is_numeric(e.type.t):
                yield self.check(EUnaryOp(UOp.Sum, e).with_type(e.type.t), pool)

        for e in cache.find(pool=pool, type=THandle, size=size - 1):
            yield self.check(EGetField(e, "val").with_type(e.type.value_type), pool)

        for e in cache.find(pool=pool, type=TTuple, size=size - 1):
            for n in range(len(e.type.ts)):
                yield self.check(ETupleGet(e, n).with_type(e.type.ts[n]), pool)

        for e in cache.find(pool=pool, type=BOOL, size=size - 1):
            yield self.check(EUnaryOp(UOp.Not, e).with_type(BOOL), pool)

        for e in cache.find(pool=pool, type=INT, size=size - 1):
            yield self.check(EUnaryOp("-", e).with_type(INT), pool)

        for m in cache.find(pool=pool, type=TMap, size=size - 1):
            yield self.check(EMapKeys(m).with_type(TBag(m.type.k)), pool)

        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for a1 in cache.find(pool=pool, size=sz1):
                if not is_numeric(a1.type):
                    continue
                for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                    yield self.check(EBinOp(a1, "+", a2).with_type(INT), pool)
                    yield self.check(EBinOp(a1, "-", a2).with_type(INT), pool)
                    yield self.check(EBinOp(a1, ">", a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, "<", a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, ">=", a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, "<=", a2).with_type(BOOL), pool)
            for a1 in cache.find_collections(pool=pool, size=sz1):
                for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                    yield self.check(EBinOp(a1, "+", a2).with_type(a1.type), pool)
                    yield self.check(EBinOp(a1, "-", a2).with_type(a1.type), pool)
                for a2 in cache.find(pool=pool, type=a1.type.t, size=sz2):
                    yield self.check(EBinOp(a2, BOp.In, a1).with_type(BOOL), pool)
            for a1 in cache.find(pool=pool, type=BOOL, size=sz1):
                for a2 in cache.find(pool=pool, type=BOOL, size=sz2):
                    yield self.check(EBinOp(a1, BOp.And, a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, BOp.Or, a2).with_type(BOOL), pool)
            for a1 in cache.find(pool=pool, size=sz1):
                if not isinstance(a1.type, TMap):
                    for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                        yield self.check(EEq(a1, a2), pool)
                        yield self.check(EBinOp(a1, "!=", a2).with_type(BOOL), pool)
            for m in cache.find(pool=pool, type=TMap, size=sz1):
                for k in cache.find(pool=pool, type=m.type.k, size=sz2):
                    yield self.check(EMapGet(m, k).with_type(m.type.v), pool)
                    yield self.check(EHasKey(m, k).with_type(BOOL), pool)

        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            for cond in cache.find(pool=pool, type=BOOL, size=sz1):
                for then_branch in cache.find(pool=pool, size=sz2):
                    for else_branch in cache.find(pool=pool, size=sz3, type=then_branch.type):
                        yield self.check(ECond(cond, then_branch, else_branch).with_type(then_branch.type), pool)

        for bag in cache.find_collections(pool=pool, size=size - 1):
            # len of bag
            count = EUnaryOp(UOp.Length, bag).with_type(INT)
            yield self.check(count, pool)
            # empty?
            yield self.check(EUnaryOp(UOp.Empty, bag).with_type(BOOL), pool)
            # exists?
            yield self.check(EUnaryOp(UOp.Exists, bag).with_type(BOOL), pool)
            # singleton?
            yield self.check(EEq(count, ONE), pool)

            yield self.check(EUnaryOp(UOp.The, bag).with_type(bag.type.t), pool)
            yield self.check(EUnaryOp(UOp.Distinct, bag).with_type(bag.type), pool)
            yield self.check(EUnaryOp(UOp.AreUnique, bag).with_type(BOOL), pool)

            if bag.type.t == BOOL:
                yield self.check(EUnaryOp(UOp.Any, bag).with_type(BOOL), pool)
                yield self.check(EUnaryOp(UOp.All, bag).with_type(BOOL), pool)

        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in cache.find_collections(pool=pool, size=sz1):
                for binder in binders_by_type[bag.type.t]:
                    for body in itertools.chain(cache.find(pool=pool, size=sz2), (binder,)):
                        yield self.check(EMap(bag, ELambda(binder, body)).with_type(TBag(body.type)), pool)
                        if body.type == BOOL:
                            yield self.check(EFilter(bag, ELambda(binder, body)).with_type(bag.type), pool)
                        if body.type == INT:
                            yield self.check(EArgMin(bag, ELambda(binder, body)).with_type(bag.type.t), pool)
                            yield self.check(EArgMax(bag, ELambda(binder, body)).with_type(bag.type.t), pool)
                        if pool == RUNTIME_POOL and isinstance(body.type, TBag):
                            yield self.check(EFlatMap(bag, ELambda(binder, body)).with_type(TBag(body.type.t)), pool)

        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in cache.find_collections(pool=STATE_POOL, size=sz1):
                if not is_scalar(bag.type.t):
                    continue
                for b in binders_by_type[bag.type.t]:
                    for val in cache.find(pool=STATE_POOL, size=sz2):
                        t = TMap(bag.type.t, val.type)
                        m = EMakeMap2(bag, ELambda(b, val)).with_type(t)
                        yield self.check(m, STATE_POOL)
def can_serve_as_value(e, binder, state):
    fvs = free_vars(e)
    return binder not in fvs and not any(v == binder or v in state for v in fvs)
def ECountIn(e, collection):
    """Count the number of times e occurs in the collection"""
    from cozy.syntax_tools import free_vars, fresh_var
    assert e.type == collection.type.elem_type
    arg = fresh_var(e.type, omit=free_vars(e))
    return EUnaryOp(UOp.Length, EFilter(collection, ELambda(arg, EEq(arg, e))).with_type(collection.type)).with_type(INT)
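# Illustration only (variable names are hypothetical): counting occurrences
# of x in a bag xs desugars to Length(Filter(xs, \a -> a == x)).
x = EVar("x").with_type(INT)
xs = EVar("xs").with_type(TBag(INT))
count_expr = ECountIn(x, xs)
# count_expr is EUnaryOp(Length, EFilter(xs, ELambda(a, EEq(a, x))))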
def build(self, cache, size):
    for e in cache.find(pool=RUNTIME_POOL, size=size - 1, type=INT):
        if not is_root(e):
            continue
        e2 = simplify_sum(e)
        if e != e2:
            yield self.check(e2, RUNTIME_POOL)

    # for e in cache.find(pool=RUNTIME_POOL, size=size-1):
    #     if isinstance(e, EMapGet) and isinstance(e.map, EMakeMap2):
    #         x = e.map.value.apply_to(e.key)
    #         x._tag = True
    #         yield self.check(x, RUNTIME_POOL)

    # [x] - ys
    for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(e):
            continue
        if isinstance(e, EBinOp) and e.op == "-" and isinstance(e.e1, ESingleton):
            x = e.e1.e
            y = e.e2
            x = ECond(
                EBinOp(x, BOp.In, y).with_type(BOOL),
                EEmptyList().with_type(e.type),
                e.e1).with_type(e.type)
            yield self.check(x, RUNTIME_POOL)
        elif isinstance(e, EUnaryOp) and e.op == UOp.Distinct:
            e = strip_EStateVar(e)
            m = EMakeMap2(e.e, mk_lambda(e.type.t, lambda x: T)).with_type(TMap(e.type.t, BOOL))
            yield self.check(m, STATE_POOL)
            m = EStateVar(m).with_type(m.type)
            yield self.check(m, RUNTIME_POOL)
            x = EMapKeys(m).with_type(e.type)
            # x._tag = True
            yield self.check(x, RUNTIME_POOL)

    # # x in ys ----> (count x in ys) > 0
    # for e in cache.find(pool=RUNTIME_POOL, type=BOOL, size=size-1):
    #     if isinstance(e, EBinOp) and e.op == BOp.In:
    #         for b in self.binders:
    #             if b.type != e.e1.type:
    #                 continue
    #             x = EGt(
    #                 EUnaryOp(UOp.Length, EFilter(e.e2, ELambda(b, EEq(e.e1, b))).with_type(e.e2.type)).with_type(INT),
    #                 ZERO)
    #             x._tag = True
    #             yield self.check(x, RUNTIME_POOL)

    for e in cache.find(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(e):
            continue
        if (isinstance(e, EArgMin) or isinstance(e, EArgMax)) and isinstance(e.e, EBinOp) and e.e.op == "+":
            l = e.e.e1
            r = e.e.e2
            op = e.e.op
            f = lambda x: type(e)(x, e.f).with_type(e.type)
            ll = EStateVar(f(l.e)).with_type(e.type) if isinstance(l, EStateVar) else f(l)
            rr = EStateVar(f(r.e)).with_type(e.type) if isinstance(r, EStateVar) else f(r)
            x = ECond(
                EUnaryOp(UOp.Exists, l).with_type(BOOL),
                ECond(
                    EUnaryOp(UOp.Exists, r).with_type(BOOL),
                    f(EBinOp(
                        ESingleton(ll).with_type(e.e.type),
                        op,
                        ESingleton(rr).with_type(e.e.type)).with_type(e.e.type)),
                    ll).with_type(e.type),
                rr).with_type(e.type)
            # from cozy.solver import valid
            # assert valid(EEq(e, x), model_callback=print)
            x._tag = True
            yield self.check(x, RUNTIME_POOL)

    # is-last(x, l)
    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for e1 in cache.find(pool=RUNTIME_POOL, size=sz1):
            if not is_root(e1):
                continue
            for e2 in cache.find_collections(pool=STATE_POOL, size=sz2, of=e1.type):
                if not is_root(e2):
                    continue
                for b in self.binders:
                    if b.type != e1.type:
                        continue
                    m = EMakeMap2(e2,
                        mk_lambda(e2.type.t, lambda x:
                            EUnaryOp(UOp.Length, EFilter(e2,
                                mk_lambda(e2.type.t, lambda y: EEq(x, y))).with_type(e2.type)).with_type(INT))).with_type(TMap(e2.type.t, INT))
                    # filt = EFilter(e2, ELambda(b, EEq(e1, b))).with_type(e2.type)
                    # x = EEq(
                    #     EUnaryOp(UOp.Length, filt).with_type(INT),
                    #     ONE)
                    x = EGt(
                        EMapGet(EStateVar(m).with_type(m.type), e1).with_type(INT),
                        ONE)
                    # x._tag = True
                    yield self.check(x, RUNTIME_POOL)

    # histogram
    # for e in cache.find_collections(pool=STATE_POOL, size=size-1):
    #     m = EMakeMap2(e,
    #         mk_lambda(e.type.t, lambda x:
    #             EUnaryOp(UOp.Length, EFilter(e,
    #                 mk_lambda(e.type.t, lambda y: EEq(x, y))).with_type(e.type)).with_type(INT))).with_type(TMap(e.type.t, INT))
    #     m._tag = True
    #     yield self.check(m, STATE_POOL)

    # Fixup EFilter(\x -> ECond...)
    for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(e):
            continue
        if isinstance(e, EFilter):
            for (_, x, r, _) in enumerate_fragments(e.p.body):
                if isinstance(x, ECond):
                    lhs = EFilter(e.e, ELambda(e.p.arg, EAll([x.cond, r(x.then_branch)]))).with_type(e.type)
                    rhs = EFilter(e.e, ELambda(e.p.arg, EAll([ENot(x.cond), r(x.else_branch)]))).with_type(e.type)
                    union = EBinOp(lhs, "+", rhs).with_type(e.type)
                    # yield self.check(lhs.p.body, RUNTIME_POOL)
                    # yield self.check(rhs.p.body, RUNTIME_POOL)
                    yield self.check(lhs, RUNTIME_POOL)
                    yield self.check(rhs, RUNTIME_POOL)
                    yield self.check(union, RUNTIME_POOL)

    # Try instantiating bound expressions
    for pool in (STATE_POOL, RUNTIME_POOL):
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for e1 in cache.find(pool=pool, size=sz1):
                if not is_root(e1):
                    continue
                for v in free_vars(e1):
                    if pool == RUNTIME_POOL:
                        e1 = subst(
                            strip_EStateVar(e1),
                            { sv.id : EStateVar(sv).with_type(sv.type) for sv in self.state_vars if sv != v })
                    for e2 in cache.find(pool=pool, type=v.type, size=sz2):
                        yield self.check(subst(e1, {v.id: e2}), pool)

    for (sz1, sz2) in pick_to_sum(2, size - 1):
        for e in cache.find(pool=RUNTIME_POOL, size=sz1):
            if not is_root(e):
                continue
            for x, pool in map_accelerate(e, self.state_vars, self.binders, self.args, cache, sz2):
                yield self.check(x, pool)
            if isinstance(e, EFilter) and not any(v in self.binders for v in free_vars(e)):
                for x, pool in accelerate_filter(e.e, e.p, self.state_vars, self.binders, self.args, cache, sz2):
                    yield self.check(x, pool)

    for bag in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
        if not is_root(bag):
            continue
        for a in self.args:
            for v in self.state_vars:
                if is_collection(v.type) and v.type == a.type:
                    v = EStateVar(v).with_type(v.type)
                    cond = EBinOp(a, BOp.In, v).with_type(BOOL)
                    yield self.check(EFilter(bag, mk_lambda(bag.type.t, lambda _: cond)).with_type(bag.type), RUNTIME_POOL)
                    yield self.check(EFilter(bag, mk_lambda(bag.type.t, lambda _: ENot(cond))).with_type(bag.type), RUNTIME_POOL)

        if isinstance(bag, EFilter):
            if any(v not in self.state_vars for v in free_vars(bag.e)):
                continue

            # separate filter conds
            if isinstance(bag.p.body, EBinOp) and bag.p.body.op == BOp.And:
                p1 = ELambda(bag.p.arg, bag.p.body.e1)
                p2 = ELambda(bag.p.arg, bag.p.body.e2)
                f1 = EFilter(bag.e, p1).with_type(bag.type)
                f2 = EFilter(bag.e, p2).with_type(bag.type)
                f3 = EFilter(f1, p2).with_type(bag.type)
                f4 = EFilter(f2, p1).with_type(bag.type)
                yield self.check(f1, RUNTIME_POOL)
                yield self.check(f2, RUNTIME_POOL)
                yield self.check(f3, RUNTIME_POOL)
                yield self.check(f4, RUNTIME_POOL)

            # construct map lookups
            binder = bag.p.arg
            inf = infer_map_lookup(bag.p.body, binder, set(self.state_vars))
            if inf:
                key_proj, key_lookup, remaining_filter = inf
                bag_binder = find_one(self.binders, lambda b: b.type == key_proj.type and b != binder)
                if bag_binder:
                    m = strip_EStateVar(
                        EMakeMap2(
                            EMap(bag.e, ELambda(binder, key_proj)).with_type(type(bag.type)(key_proj.type)),
                            ELambda(bag_binder, EFilter(bag.e, ELambda(binder, EEq(key_proj, bag_binder))).with_type(bag.type))).with_type(TMap(key_proj.type, bag.type)))
                    assert not any(v in self.args for v in free_vars(m))
                    yield self.check(m, STATE_POOL)
                    m = EStateVar(m).with_type(m.type)
                    mg = EMapGet(m, key_lookup).with_type(bag.type)
                    yield self.check(mg, RUNTIME_POOL)
                    yield self.check(EFilter(mg, ELambda(binder, remaining_filter)).with_type(mg.type), RUNTIME_POOL)

    # for e in cache.find(size=size-1):
    #     # F(xs +/- ys) ---> F(xs), F(ys)
    #     for z in break_plus_minus(e):
    #         if z != e:
    #             # print("broke {} --> {}".format(pprint(e), pprint(z)))
    #             yield z
    #
    #     # try reordering operations
    #     for (_, e1, f) in enumerate_fragments(e):
    #         if e1.type == e.type and e1 != e:
    #             for (_, e2, g) in enumerate_fragments(e1):
    #                 if e2.type == e.type and e2 != e1:
    #                     # e == f(g(e2))
    #                     yield g(f(e2))

    yield from self.wrapped.build(cache, size)
def can_serve_as_key(e, binder, state):
    fvs = free_vars(e)
    return binder in fvs and all(v == binder or v in state for v in fvs)
def code(self) -> Spec:
    """Get the current code corresponding to this implementation.

    The code is returned as a Cozy specification object, but the returned
    object throws away any unused abstract state as well as all invariants
    and assumptions on methods.  It implements the same data structure, but
    probably more efficiently.
    """
    state_read_by_query = {
        query_name : free_vars(query)
        for query_name, query in self.query_impls.items() }

    # prevent read-after-write by lifting reads before writes.

    # list of SDecls
    temps = defaultdict(list)
    updates = dict(self.updates)

    _concretization_functions = [v for v, e in self._concretization_functions]

    for operator in self.op_specs:

        # Compute order constraints between statements:
        #   v1 -> v2 means that the update code for v1 should (if possible)
        #   appear before the update code for v2
        #   (i.e. the update code for v1 reads v2)
        def state_used_during_update(v1 : EVar) -> [EVar]:
            v1_update_code = self.updates[(v1, operator.name)]
            v1_queries = list(self.queries_used_by(v1_update_code))
            res = OrderedSet()
            for q in v1_queries:
                res |= state_read_by_query[q]
            return res

        g = DirectedGraph(
            nodes=_concretization_functions,
            successors=state_used_during_update)

        # Find the minimum set of edges we need to break cycles (see
        # "feedback arc set problem")
        edges_to_break = g.minimum_feedback_arc_set()
        g.delete_edges(edges_to_break)
        _concretization_functions = list(g.toposort())

        # Lift auxiliary declarations as needed
        things_updated = []
        for v in _concretization_functions:
            things_updated.append(v)
            stm = updates[(v, operator.name)]

            def problematic(e):
                for x in all_exps(e):
                    if isinstance(x, ECall) and x.func in [q.name for q in self.query_specs]:
                        problems = set(things_updated) & state_read_by_query[x.func]
                        if problems:
                            return True
                return False

            stm = pull_temps(stm,
                decls_out=temps[operator.name],
                exp_is_bad=problematic)
            updates[(v, operator.name)] = stm

    # construct new op implementations
    new_ops = []
    for op in self.op_specs:
        stms = [ updates[(v, op.name)] for v in _concretization_functions ]
        stms.extend(hup for ((t, op_name), hup) in self.handle_updates.items() if op.name == op_name)
        new_stms = seq(temps[op.name] + stms)
        new_ops.append(Op(
            op.name,
            op.args,
            [],
            new_stms,
            op.docstring))

    # assemble final result
    return Spec(
        self.spec.name,
        self.spec.types,
        self.spec.extern_funcs,
        [(v.id, e.type) for (v, e) in self._concretization_functions],
        [],
        list(self.query_impls.values()) + new_ops,
        self.spec.header,
        self.spec.footer,
        self.spec.docstring)
def f(self, e): if enforce_estatevar_wf.value: from cozy.syntax_tools import free_vars, pprint if not all(not v.id.startswith("_") for v in free_vars(e)): raise IllegalStateVarBoundary(pprint(e)) old(self, e)
def cleanup(self): """ Remove unused state, queries, and updates. """ def deps(thing): if isinstance(thing, str): yield from free_vars(self.query_impls[thing]) elif isinstance(thing, EVar): for op in self.op_specs: yield self.updates[(thing, op.name)] elif isinstance(thing, Stm): yield from self.queries_used_by(thing) else: raise ValueError(repr(thing)) g = DirectedGraph( nodes=itertools.chain(self.query_impls.keys(), (v for v, _ in self._concretization_functions), self.updates.values()), successors=deps) roots = [q.name for q in self.query_specs if q.visibility == Visibility.Public] roots.extend(itertools.chain(*[self.queries_used_by(code) for ((ht, op_name), code) in self.handle_updates.items()])) queries_to_keep = set(q for q in g.reachable_nodes(roots) if isinstance(q, str)) # remove old specs for q in list(self.query_specs): if q.name not in queries_to_keep: self.query_specs.remove(q) # remove old implementations for qname in list(self.query_impls.keys()): if qname not in queries_to_keep: del self.query_impls[qname] # remove old state vars self._concretization_functions = [ v for v in self._concretization_functions if any(v[0] in free_vars(q) for q in self.query_impls.values()) ] # remove old method implementations for k in list(self.updates.keys()): v, op_name = k if v not in [var for (var, exp) in self._concretization_functions]: del self.updates[k]
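The heart of cleanup is graph reachability: keep only the queries and state reachable from the public entry points. A standalone sketch of that step, with a hypothetical dependency map standing in for the real DirectedGraph and successor function:

def reachable(roots, successors):
    """Everything reachable from `roots` by repeatedly applying `successors`."""
    seen = set()
    stack = list(roots)
    while stack:
        n = stack.pop()
        if n in seen:
            continue
        seen.add(n)
        stack.extend(successors(n))
    return seen

# Hypothetical dependencies: public queries depend on the state they read;
# anything unreachable from a public root can be deleted.
deps = {
    "q_public": {"total"},
    "q_helper": {"count"},   # only used by code that is no longer kept
    "total": set(),
    "count": set(),
}
keep = reachable({"q_public"}, lambda n: deps.get(n, set()))
assert keep == {"q_public", "total"}   # q_helper and count get removed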
def _try_optimize(e : Exp, context : Context, pool : Pool): if not accelerate.value: return if pool != RUNTIME_POOL: return state_vars = [v for v, p in context.vars() if p == STATE_POOL] args = [v for v, p in context.vars() if p == RUNTIME_POOL] # --------------------------------------------------------------------- # "Rewrite schemes": these trigger on many different AST shapes # They are listed first because they are more powerful than the # specific rewrite rules below. if not free_vars(e) and not free_funcs(e): try: yield _check(uneval(e.type, eval(e, {})), context, RUNTIME_POOL) except NotImplementedError: print("Unable to evaluate {!r}".format(e)) if all(v in state_vars for v in free_vars(e)): nsv = strip_EStateVar(e) sv = EStateVar(nsv).with_type(e.type) yield _check(sv, context, RUNTIME_POOL) for ee in fold_into_map(e, context): yield _check(ee, context, pool) # --------------------------------------------------------------------- # "Rewrites": these trigger on specific AST nodes if isinstance(e, EBinOp): if e.op == "-" and is_collection(e.type): ee = optimized_bag_difference(e.e1, e.e2) yield _check(ee, context, RUNTIME_POOL) if e.op == "===" and isinstance(e.e1.type, THandle): yield _check(EAll([ optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)), optimized_eq(optimized_val(e.e1), optimized_val(e.e2)).with_type(BOOL)]), context, RUNTIME_POOL) if e.op == BOp.In: ee = optimized_in(e.e1, e.e2) yield _check(ee, context, RUNTIME_POOL) if isinstance(e, ECond): yield _check(optimized_cond(e.cond, e.then_branch, e.else_branch), context, RUNTIME_POOL) if isinstance(e, EGetField): for ee in optimized_get_field(e.e, e.field_name, args): yield _check(ee, context, RUNTIME_POOL) if isinstance(e, EListGet) and e.index == ZERO: for res in optimized_the(e.e, args): yield _check(res, context, RUNTIME_POOL) if isinstance(e, EListGet) and isinstance(e.e, ECond): yield optimized_cond(e.e.cond, EListGet(e.e.then_branch, e.index).with_type(e.type), EListGet(e.e.else_branch, e.index).with_type(e.type)) from cozy.structures.treemultiset import ETreeMultisetElems, ETreeMultisetPeek if isinstance(e, EListGet) and isinstance(e.e, ETreeMultisetElems): yield ETreeMultisetPeek(e.e.e, e.index).with_type(e.type) if isinstance(e, EMapGet): ee = inline_mapget(e, context) yield _check(ee, context, RUNTIME_POOL) if isinstance(e, EUnaryOp): if e.op == UOp.Sum: for ee in optimized_sum(e.e, args): yield _check(ee, context, RUNTIME_POOL) if e.op == UOp.Length: ee = optimized_len(e.e) yield _check(ee, context, RUNTIME_POOL) if e.op == UOp.Empty: ee = optimized_empty(e.e) yield _check(ee, context, RUNTIME_POOL) if e.op == UOp.Exists: ee = optimized_exists(e.e) yield _check(ee, context, RUNTIME_POOL) if e.op == UOp.Distinct: for ee in optimized_distinct(e.e, args): yield _check(ee, context, RUNTIME_POOL) if e.op == UOp.The: for ee in optimized_the(e.e, args): yield _check(ee, context, RUNTIME_POOL) if isinstance(e, EArgMin) or isinstance(e, EArgMax): for ee in optimized_best(e.e, e.key_function, "<" if isinstance(e, EArgMin) else ">", args=args): yield _check(ee, context, RUNTIME_POOL) if isinstance(e, EFilter): for ee in optimized_filter(e.e, e.predicate, args=args): yield _check(ee, context, RUNTIME_POOL) if isinstance(e, EMap): for ee in optimized_map(e.e, e.transform_function, args=args): yield _check(ee, context, RUNTIME_POOL) from cozy.syntax import ESorted from cozy.structures.treemultiset import EMakeMaxTreeMultiset, TMaxTreeMultiset, EMakeMinTreeMultiset, TMinTreeMultiset, ETreeMultisetElems target = e if 
isinstance(target, ESorted) and isinstance(target.e, EStateVar): e_max = EMakeMaxTreeMultiset(target.e.e).with_type(TMaxTreeMultiset(target.e.e.type.elem_type)) e_min = EMakeMinTreeMultiset(target.e.e).with_type(TMinTreeMultiset(target.e.e.type.elem_type)) ee = optimized_cond(target.asc, ETreeMultisetElems(EStateVar(e_min).with_type(e_min.type)).with_type(target.type), ETreeMultisetElems(EStateVar(e_max).with_type(e_max.type)).with_type(target.type)) yield _check(ee, context, RUNTIME_POOL)
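A toy illustration of the first rewrite scheme in _try_optimize above: an expression with no free variables or functions is evaluated once and replaced by its value. The tuple-based AST below is purely illustrative and is not Cozy's eval/uneval machinery.

def fold_constants(e):
    if isinstance(e, int):            # literal
        return e
    if isinstance(e, str):            # variable reference
        return e
    op, a, b = e                      # ("+" | "*", lhs, rhs)
    a, b = fold_constants(a), fold_constants(b)
    if isinstance(a, int) and isinstance(b, int):
        # No variables below this node: compute the result now.
        return a + b if op == "+" else a * b
    return (op, a, b)

assert fold_constants(("+", ("*", 2, 3), "x")) == ("+", 6, "x")
assert fold_constants(("*", ("+", 1, 1), ("+", 2, 2))) == 8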
def free_vars_and_funcs(e): for v in free_vars(e): yield v.id for f in free_funcs(e): yield f
def _simple_filter(xs : Exp, p : ELambda, args : {EVar}): """Assumes the body of p is already in negation normal form""" if p.body == ETRUE: yield xs return if p.body == EFALSE: yield EEmptyList().with_type(xs.type) return if isinstance(xs, EEmptyList): yield xs return yielded = False if isinstance(xs, ESingleton): yielded = True yield optimized_cond(p.apply_to(xs.e), xs, EEmptyList().with_type(xs.type)) if isinstance(p.body, EBinOp) and p.body.op == BOp.Or: for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]): for r1 in _simple_filter(xs, ELambda(p.arg, e1), args): for r2 in _simple_filter(xs, ELambda(p.arg, EAll([e2, ENot(e1)])), args): yielded = True yield EBinOp(r1, "+", r2).with_type(xs.type) if isinstance(p.body, EBinOp) and p.body.op == BOp.And: for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]): for r1 in _simple_filter(xs, ELambda(p.arg, e1), args): yielded = True yield from _simple_filter(r1, ELambda(p.arg, e2), args) if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(p)): yielded = True yield EStateVar(EFilter(xs.e, strip_EStateVar(p)).with_type(xs.type)).with_type(xs.type) if isinstance(xs, EMapGet) and isinstance(xs.map, EStateVar) and not any(v in args for v in free_vars(p)): for m in map_values_multi(xs.map.e, lambda ys: _simple_filter(ys, p, args)): yielded = True yield EMapGet(EStateVar(m).with_type(m.type), xs.key).with_type(xs.type) if isinstance(xs, EBinOp) and xs.op in ("+", "-"): for e1 in _simple_filter(xs.e1, p, args): for e2 in _simple_filter(xs.e2, p, args): yielded = True yield EBinOp(e1, xs.op, e2).with_type(xs.type) if isinstance(p.body, EBinOp) and p.body.op == "==": e1 = p.body.e1 e2 = p.body.e2 fvs2 = free_vars(e2) fvs1 = free_vars(e1) for (e1, fvs1), (e2, fvs2) in itertools.permutations([(e1, fvs1), (e2, fvs2)]): if p.arg in fvs1 and not any(a in fvs1 for a in args) and p.arg not in fvs2 and isinstance(xs, EStateVar): if e1 == p.arg: yield optimized_cond( optimized_in(e2, xs), ESingleton(e2).with_type(xs.type), EEmptyList().with_type(xs.type)) k = fresh_var(e1.type) e = EMapGet( EStateVar( EMakeMap2( EMap(xs.e, ELambda(p.arg, e1)), ELambda(k, EFilter(xs.e, ELambda(p.arg, EEq(e1, k)))))), e2) res = retypecheck(e) assert res yielded = True yield e if not yielded: yield EFilter(xs, p).with_type(xs.type)
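A list-and-dict sketch of the equality rewrite near the end of _simple_filter: when the predicate equates a function of the element with a value known only at query time, the filter can be answered by a map precomputed over the state. The helper names and toy data are illustrative, not part of Cozy.

from collections import defaultdict

def precompute_groups(xs, key):
    """State-time work: group elements by key(x)."""
    groups = defaultdict(list)
    for x in xs:
        groups[key(x)].append(x)
    return groups

def filter_by_key(groups, k):
    """Query-time work: the filter becomes a single map lookup."""
    return groups.get(k, [])

xs = [("alice", 3), ("bob", 5), ("carol", 3)]
by_score = precompute_groups(xs, key=lambda r: r[1])
assert filter_by_key(by_score, 3) == [x for x in xs if x[1] == 3]
assert filter_by_key(by_score, 9) == []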
def visit_Spec(self, spec, state_exps, sharing, abstract_state=()): self.state_exps = state_exps self.funcs = { f.name: f for f in spec.extern_funcs } self.queries = { q.name: q for q in spec.methods if isinstance(q, Query) } self.vars = set(e.id for e in all_exps(spec) if isinstance(e, EVar)) self.setup_types(spec, state_exps, sharing) if guava.value: self.write("import com.google.common.collect.TreeMultiset;\n") self.write("import com.google.common.collect.Iterators;\n") if spec.header: self.write(spec.header.strip() + "\n\n") if spec.docstring: self.write(spec.docstring + "\n") self.write("public class {} implements java.io.Serializable ".format(spec.name)) with self.block(): for name, t in spec.types: self.types[t] = name # member variables for name, t in spec.statevars: self.write("{}protected {};\n".format(INDENT, self.visit(t, name))) # constructor self.write( "{indent}public {name}() {{\n{indent2}clear();\n{indent}}}\n\n" .format(indent=INDENT, indent2=INDENT+INDENT, name=spec.name)) # explicit constructor if abstract_state: self.begin_statement() self.write("public ", spec.name, "(") self.visit_args(abstract_state) self.write(") ") with self.block(): for name, t in spec.statevars: initial_value = state_exps[name] self.visit(simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value))) self.end_statement() # clear self.begin_statement() self.write("public void clear() ") with self.block(): for name, t in spec.statevars: initial_value = state_exps[name] fvs = free_vars(initial_value) initial_value = subst(initial_value, {v.id : evaluation.construct_value(v.type) for v in fvs}) setup = simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value)) self.visit(setup) self.end_statement() # methods for op in spec.methods: self.visit(op) # generate auxiliary types for t, name in self.types.items(): self.define_type(spec.name, t, name, sharing) self.write("\n") self.write(spec.footer) if not spec.footer.endswith("\n"): self.write("\n")
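The visitor above relies on a small writer API (write, block, begin_statement). A toy version of that pattern, assuming nothing about the real writer class beyond what those calls suggest; it only shows how a braced, indented Java-style block can be emitted with a context manager.

import contextlib, io

class ToyWriter:
    def __init__(self):
        self.out = io.StringIO()
        self.indent = 0

    def write(self, *parts):
        self.out.write("".join(str(p) for p in parts))

    @contextlib.contextmanager
    def block(self):
        # Emit "{", indent the body, then emit a matching "}".
        self.write("{\n")
        self.indent += 1
        yield
        self.indent -= 1
        self.write("    " * self.indent + "}\n")

    def begin_statement(self):
        self.write("    " * self.indent)

w = ToyWriter()
w.write("public class Example ")
with w.block():
    w.begin_statement()
    w.write("protected int count;\n")
print(w.out.getvalue())   # public class Example { protected int count; }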
def pull_temps(s : Stm, decls_out : [SDecl], exp_is_bad) -> Stm: """Remove "bad" expressions from `s`. This procedure returns a statement new_s that replaces every expression in `s` where `exp_is_bad` returns True with a fresh variable. After running, `decls_out` contains definitions for the fresh variables so that the whole statement decls_out; new_s should return the same result as `s`. """ def pull(e : Exp) -> Exp: """Pull an expression into a temporary. Creates a fresh variable for `e`, writes a declaration into `decls_out`, and returns the fresh variable. """ if exp_is_bad(e): v = fresh_var(e.type) decls_out.append(SDecl(v, e)) return v return e if isinstance(s, SNoOp): return s if isinstance(s, SSeq): s1 = pull_temps(s.s1, decls_out, exp_is_bad) s2 = pull_temps(s.s2, decls_out, exp_is_bad) return SSeq(s1, s2) if isinstance(s, SDecl): return SDecl(s.var, pull(s.val)) if isinstance(s, SIf): cond = pull(s.cond) s1 = pull_temps(s.then_branch, decls_out, exp_is_bad) s2 = pull_temps(s.else_branch, decls_out, exp_is_bad) return SIf(cond, s1, s2) if isinstance(s, SForEach): bag = pull(s.iter) d_tmp = [] body = pull_temps(s.body, d_tmp, exp_is_bad) to_fix, ok = partition(d_tmp, lambda d: s.loop_var in free_vars(d.val)) decls_out.extend(ok) for d in to_fix: v = d.var mt = TMap(s.loop_var.type, v.type) m = EMakeMap2(bag, ELambda(s.loop_var, d.val)).with_type(mt) mv = fresh_var(m.type) md = SDecl(mv, m) decls_out.append(md) body = subst(body, { v.id : EMapGet(mv, s.loop_var).with_type(v.type) }) return SForEach(s.loop_var, bag, body) if isinstance(s, SAssign): return SAssign(s.lhs, pull(s.rhs)) if isinstance(s, SCall): return SCall(s.target, s.func, tuple(pull(arg) for arg in s.args)) if isinstance(s, SMapDel): return SMapDel(s.map, pull(s.key)) if isinstance(s, SMapPut): return SMapPut(s.map, pull(s.key), pull(s.value)) if isinstance(s, SMapUpdate): key = pull(s.key) d_tmp = [] change = pull_temps(s.change, d_tmp, exp_is_bad) for d in d_tmp: if s.val_var in free_vars(d.val): decls_out.append(SDecl(d.var, subst(d.val, { s.val_var.id : EMapGet(s.map, key).with_type(s.val_var.type) }))) else: decls_out.append(d) return SMapUpdate(s.map, key, s.val_var, change) raise NotImplementedError(s)
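A much-reduced sketch of the pull_temps idea on a toy statement form: statements are ("seq", [stmts]) or ("assign", name, expr), expressions are nested tuples, and any expression flagged by exp_is_bad is hoisted into a fresh temporary declared ahead of the rewritten statement. It skips the loop and map-update cases; all names are hypothetical.

import itertools

_fresh = itertools.count()

def pull_temps_demo(stm, decls_out, exp_is_bad):
    kind = stm[0]
    if kind == "seq":
        return ("seq", [pull_temps_demo(s, decls_out, exp_is_bad) for s in stm[1]])
    if kind == "assign":
        _, lhs, rhs = stm
        if exp_is_bad(rhs):
            tmp = "tmp%d" % next(_fresh)
            decls_out.append(("decl", tmp, rhs))   # evaluate rhs early, exactly once
            rhs = ("var", tmp)
        return ("assign", lhs, rhs)
    raise NotImplementedError(kind)

decls = []
prog = ("seq", [("assign", "x", ("call", "expensive_query")),
                ("assign", "y", ("lit", 0))])
new_prog = pull_temps_demo(prog, decls, lambda e: e[0] == "call")
assert decls == [("decl", "tmp0", ("call", "expensive_query"))]
assert new_prog == ("seq", [("assign", "x", ("var", "tmp0")),
                            ("assign", "y", ("lit", 0))])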
def EDeepIn(e1, e2): from cozy.syntax_tools import free_vars, fresh_var arg = fresh_var(e1.type, omit=free_vars(e1)) return EUnaryOp(UOp.Any, EMap(e2, ELambda(arg, EDeepEq(arg, e1))).with_type(BOOL_BAG)).with_type(BOOL)
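EDeepIn builds an expression of the shape Any(Map(e2, lambda arg: DeepEq(arg, e1))); in plain Python that is just any(deep_eq(x, e1) for x in e2). A toy illustration with hypothetical (address, value) "handles", where ordinary equality is stricter than the deep comparison:

def deep_in(needle, bag, deep_eq):
    # Same shape as EDeepIn: any element of the bag that is deep-equal to the needle.
    return any(deep_eq(x, needle) for x in bag)

# Hypothetical handles: ordinary == compares the whole (address, value) pair,
# while deep equality ignores the address and compares only the payload.
deep_eq = lambda a, b: a[1] == b[1]
bag = [(0x1, "red"), (0x2, "blue")]
assert deep_in((0x99, "blue"), bag, deep_eq)   # same payload, different address
assert (0x99, "blue") not in bag               # ordinary membership says no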