def always(self, op, other, cards : Exp, **kwargs) -> bool:
    """Partial order on costs: does `self.formula <op> other.formula` always hold?

    When both formulas are numeric constants the comparison is evaluated
    directly. Otherwise the solver is asked whether the comparison follows
    from the cardinality facts `cards`:

      1. if `integer_cardinalities` is enabled, over integer arithmetic
         (QF_LIA, 1 second timeout);
      2. if that is disabled or inconclusive, over real arithmetic (QF_NRA,
         60 second timeout) after retyping the free variables of `cards`
         as REAL.

    If the real-arithmetic query is inconclusive too, the answer is False,
    which is always a safe answer. Extra keyword arguments go to `valid`.
    """
    comparison = EBinOp(self.formula, op, other.formula).with_type(BOOL)
    if isinstance(self.formula, ENum) and isinstance(other.formula, ENum):
        return eval(comparison, env={})

    claim = EImplies(cards, comparison)

    if integer_cardinalities.value:
        try:
            return valid(claim, logic="QF_LIA", timeout=1, **kwargs)
        except SolverReportedUnknown:
            # The integer formula may be beyond the solver. Retry over the
            # reals below: that relaxation admits impossible models (bag
            # cardinalities are integers), which can only turn a True into a
            # False, and False is always safe here.
            print("Warning: not able to solve {}".format(pprint(claim)))

    as_reals = {v.id : EVar(v.id).with_type(REAL) for v in free_vars(cards)}
    claim = subst(claim, as_reals)

    # NOTE: this timeout is dangerous; sufficiently complex specifications
    # can time out on every attempt and never make progress. It is kept for
    # liveness: Z3Py gives the Python process no way to break out of a
    # running query. Running Z3 as a subprocess would fix that (and would
    # also isolate the Z3 segfaults that are occasionally observed).
    nra_timeout = 60
    try:
        return valid(claim, logic="QF_NRA", timeout=nra_timeout, **kwargs)
    except SolverReportedUnknown:
        print("Giving up!")
        return False
def test_handle_writes(self):
    """Writing `y.val` may change `x.val`, but not when x and y are distinct handles."""
    t = THandle("T", INT)
    x = EVar("x").with_type(t)
    y = EVar("y").with_type(t)
    e1 = EGetField(x, "val").with_type(t.value_type)
    # e2 is the value of `x.val` after executing `y.val = 0`.
    e2 = inc.mutate(e1, SAssign(EGetField(y, "val").with_type(t.value_type), ZERO))
    # In general the write can change x.val (e.g. when x and y are equal) ...
    assert not valid(EEq(e1, e2))
    # ... but if x and y are different handles, x.val is unaffected.
    assert valid(EImplies(ENot(EEq(x, y)), EEq(e1, e2)))
def test_handle_writes(self):
    """Writing `y.val` may change `x.val`, but not when x and y are distinct handles."""
    t = THandle("elem_type", INT)
    x = EVar("x").with_type(t)
    y = EVar("y").with_type(t)
    e1 = EGetField(x, "val").with_type(t.value_type)
    # e2 is the value of `x.val` after executing `y.val = 0`.
    e2 = inc.mutate(e1, SAssign(EGetField(y, "val").with_type(t.value_type), ZERO))
    # In general the write can change x.val (e.g. when x and y are equal) ...
    assert not valid(EEq(e1, e2))
    # ... but if x and y are different handles, x.val is unaffected.
    assert valid(EImplies(ENot(EEq(x, y)), EEq(e1, e2)))
def test_map_eq(self):
    """Equal maps have equal key collections and equal values at every key."""
    key_type = TNative("V")
    value_type = TBag(THandle("H", key_type))
    t = TMap(key_type, value_type)
    m1 = EVar("m1").with_type(t)
    # Must be a distinct variable: it used to be named "m1" as well, which
    # made both implications below trivially true.
    m2 = EVar("m2").with_type(t)

    e = EImplies(EEq(m1, m2), EEq(EMapKeys(m1), EMapKeys(m2)))
    assert retypecheck(e)
    assert valid(e, collection_depth=3)

    k = EVar("k").with_type(t.k)
    e = EImplies(EEq(m1, m2), EEq(EMapGet(m1, k), EMapGet(m2, k)))
    assert retypecheck(e)
    assert valid(e, collection_depth=3)
def check_the_wf(spec : Spec):
    """Check that every `the` in `spec` is applied to at most one element.

    For each `the e` fragment, the solver must prove, from the facts that
    hold at that fragment, that `e` is a singleton or empty. Returns a list
    of error messages, one per fragment where that cannot be proven.
    """
    errors = []
    for facts, frag, _, _ in enumerate_fragments(spec):
        if not (isinstance(frag, EUnaryOp) and frag.op == UOp.The):
            continue
        arg = frag.e
        at_most_one = EAny([EIsSingleton(arg), EEmpty(arg)])
        if not valid(cse(EImplies(EAll(facts), at_most_one))):
            errors.append("at {}: `the` is illegal since its argument may not be singleton".format(pprint(frag)))
    return errors
def test_conditional(self):
    """mutate() through `if b then x = 1 else x = 0` yields `b ? 1 : 0`."""
    x = EVar("x").with_type(INT)
    b = EVar("b").with_type(BOOL)
    branchy = SIf(b, SAssign(x, ONE), SAssign(x, ZERO))
    after = inc.mutate(x, branchy)
    expected = ECond(b, ONE, ZERO).with_type(INT)
    assert valid(EEq(after, expected))
def check_minmax_wf(spec : Spec):
    """Check that every argmin/argmax in `spec` is applied to a non-empty collection.

    Returns a list of error messages, one per argmin/argmax fragment whose
    collection the solver cannot prove non-empty from the fragment's facts.
    """
    errors = []
    for ctx in enumerate_fragments(spec):
        frag = ctx.e
        if not isinstance(frag, (EArgMin, EArgMax)):
            continue
        nonempty = EUnaryOp(UOp.Exists, frag.e).with_type(BOOL)
        if not valid(EImplies(EAll(ctx.facts), nonempty)):
            errors.append("at {}: result is ambiguous since {} could be empty".format(pprint(frag), pprint(frag.e)))
    return errors
def check_the_wf(spec : Spec):
    """Check that every `the` in `spec` is applied to at most one element.

    A `the e` fragment is accepted when the facts holding at that fragment
    imply that `e` is a singleton or empty. Returns one error message per
    fragment where the solver cannot prove this.
    """
    errors = []
    for ctx in enumerate_fragments(spec):
        frag = ctx.e
        is_the = isinstance(frag, EUnaryOp) and frag.op == UOp.The
        if not is_the:
            continue
        singleton_or_empty = EAny([EIsSingleton(frag.e), EEmpty(frag.e)])
        if not valid(EImplies(EAll(ctx.facts), singleton_or_empty)):
            errors.append("at {}: `the` is illegal since its argument may not be singleton".format(pprint(frag)))
    return errors
def queries_equivalent(q1: Query, q2: Query):
    """Can a call to one of these queries safely be replaced by a call to the other?

    True only when the queries have
      - the same return type,
      - the same argument names and types,
      - logically equivalent preconditions (assumptions), and
      - equal return values whenever those preconditions hold.

    False is always a safe answer (it only prevents deduplication).
    """
    if q1.ret.type != q2.ret.type:
        return False
    if dict(q1.args) != dict(q2.args):
        return False
    q1a = EAll(q1.assumptions)
    q2a = EAll(q2.assumptions)
    # An implementation of a query is only correct under that query's own
    # assumptions. Requiring merely `q1a or q2a => equal results` would let a
    # caller that guarantees only q2a be redirected to q1's implementation,
    # which may misbehave when q1a is false. So the preconditions must agree.
    if not valid(EEq(q1a, q2a)):
        return False
    return valid(EImplies(q1a, EEq(q1.ret, q2.ret)))
def test_no_argument_conflict_lambda(self):
    """Substituting into a lambda body must not be captured by the lambda's binder.

    `f = lambda x: y + 1`. Replacing the free `y` with `x` must yield a
    function that still ignores its argument, so g(a) == g(b) for all a, b.
    """
    x = EVar("x").with_type(TInt())
    y = EVar("y").with_type(TInt())
    one = ENum(1).with_type(INT)
    f = ELambda(x, EBinOp(y, "+", one))
    assert retypecheck(f)

    g = subst(f, {y.id: x})
    a = EVar("a").with_type(TInt())
    b = EVar("b").with_type(TInt())
    same_output = equal(g.apply_to(a), g.apply_to(b))
    assert valid(same_output)
def test_no_argument_conflict_lambda(self):
    """subst() must rename a lambda's binder rather than capture the substituted variable."""
    x = EVar("x").with_type(TInt())
    y = EVar("y").with_type(TInt())
    plus_one = EBinOp(y, "+", ENum(1).with_type(INT))
    f = ELambda(x, plus_one)
    assert retypecheck(f)

    # After y := x, the body refers to the *outer* x, so the result must be
    # independent of the argument.
    g = subst(f, {y.id: x})
    a = EVar("a").with_type(TInt())
    b = EVar("b").with_type(TInt())
    assert valid(EEq(g.apply_to(a), g.apply_to(b)))
def test_map_discovery(self):
    """Filtering a state bag by equality with `y` should be discovered as a map lookup."""
    xs = EVar("xs").with_type(INT_BAG)
    y = EVar("y").with_type(INT)
    spec = EFilter(EStateVar(xs), mk_lambda(INT, lambda x: EEq(x, y)))
    assert retypecheck(spec)

    def is_state_map_lookup(e):
        # Accept a lookup into a map stored as state that provably equals spec.
        return isinstance(e, EMapGet) and isinstance(e.map, EStateVar) and valid(EEq(e, spec))

    assert check_discovery(spec=spec, expected=is_state_map_lookup, args=[y], state_vars=[xs])
def assert_same(self, e1, e2):
    """Assert that `e1` and `e2` are deeply equal (`===`) in every model.

    If the solver finds a counterexample, the model and both expressions
    with their values under that model are printed before the assertion
    fails.
    """
    def report_counterexample(model):
        print("model: {!r}".format(model))
        r1, r2 = eval(e1, model), eval(e2, model)
        print("e1: {}".format(pprint(e1)))
        print(" ---> {!r}".format(r1))
        print("e2: {}".format(pprint(e2)))
        print(" ---> {!r}".format(r2))

    deep_eq = EBinOp(e1, "===", e2).with_type(BOOL)
    assert valid(deep_eq, model_callback=report_counterexample)
def test_bag_plus_minus(self):
    """`(xs + [x]) - [x]` equals `xs` (under EEq) and is discoverable.

    The two sides are not always deeply (`===`) equal; such a model is used
    as an example input for discovery.
    """
    handle_type = THandle("H", INT)
    x = EVar("x").with_type(handle_type)
    xs = EVar("xs").with_type(TBag(handle_type))
    with_x = EBinOp(xs, "+", ESingleton(x))
    spec = EBinOp(with_x, "-", ESingleton(x))
    expected = xs
    assert retypecheck(spec)
    assert valid(EEq(spec, expected))

    not_deeply_equal = ENot(EBinOp(spec, "===", expected).with_type(BOOL))
    ex = satisfy(not_deeply_equal)
    assert ex is not None
    assert check_discovery(spec=spec, expected=expected, args=[x, xs], examples=[ex])
def test_map_discovery2(self):
    """Membership of `y` in a state bag should be discovered as a map lookup/key test."""
    xs = EVar("xs").with_type(INT_BAG)
    y = EVar("y").with_type(INT)
    spec = EIn(y, EStateVar(xs))
    assert retypecheck(spec)

    def is_state_map_query(e):
        # A lookup or has-key test on a map stored as state, provably equal to spec.
        if not isinstance(e, (EMapGet, EHasKey)):
            return False
        return isinstance(e.map, EStateVar) and valid(EEq(e, spec))

    assert check_discovery(spec=spec, expected=is_state_map_query, args=[y], state_vars=[xs])
def test_mutate_sequence_order2(self):
    """`xs.remove(y); xs.add(x)` leaves xs deeply equal to `(xs - [y]) + [x]`."""
    xs = EVar("xs").with_type(INT_BAG)
    x = EVar("x").with_type(INT)
    y = EVar("y").with_type(INT)
    program = SSeq(
        SCall(xs, "remove", (y,)),
        SCall(xs, "add", (x,)))
    after = inc.mutate(xs, program)
    without_y = EBinOp(xs, "-", ESingleton(y).with_type(INT_BAG)).with_type(INT_BAG)
    expected = EBinOp(without_y, "+", ESingleton(x).with_type(INT_BAG)).with_type(INT_BAG)
    assert valid(EDeepEq(after, expected))
def test_enumerate_fragments_strange_binder_behavior(self):
    """Facts about a shadowed binder must not leak into an inner lambda with the same name.

    At the innermost `T`, `x` is bound by the lambda over `xs`. So the facts
    enumerated there must not imply `x == 0`, which holds only for the
    outer `x` ranging over the zero-filtered bag.
    """
    xs = EVar("xs").with_type(TBag(INT))
    x = EVar("x").with_type(INT)
    zeros = EFilter(xs, ELambda(x, equal(x, ZERO)))
    everything = EFilter(xs, ELambda(x, T))
    empty = EEmptyList().with_type(xs.type)
    root = EFilter(zeros, ELambda(x, equal(everything, empty)))
    assert retypecheck(root)
    for (facts, frag, _r, _bound) in enumerate_fragments(root):
        if frag == T:
            assert not valid(implies(EAll(facts), equal(x, ZERO)), validate_model=True), "assumptions at {}: {}".format(pprint(frag), "; ".join(pprint(aa) for aa in facts))
def _do_cse_check(self, e):
    """Check that cse() preserves meaning for every non-lambda subexpression of `e`.

    Distinct subexpressions are checked smallest first. On the first one
    whose CSE'd form is not provably deeply equal to it, both forms are
    printed and False is returned; otherwise True.
    """
    candidates = sorted(set(all_exps(e)), key=Exp.size)
    for sub in candidates:
        if isinstance(sub, ELambda):
            continue
        print("checking {}...".format(pprint(sub)))
        rewritten = cse(sub)
        preserved = valid(EBinOp(sub, "===", rewritten).with_type(BOOL))
        if not preserved:
            print("Bad behavior!")
            print(pprint(sub))
            print(pprint(rewritten))
            return False
    return True
def impls(self, e: Exp, assumptions: Exp):
    """Enumerate candidate concrete representations for the abstract value `e`.

    This is a generator that yields one type per candidate implementation,
    recursing into map values, collection elements and tuple components.
    `assumptions` are the facts known about `e`. They are extended with a
    membership fact when recursing into collection elements, and they are
    used (via the solver) to decide whether a bag or list is known to hold
    no duplicates (`UOp.AreUnique`).
    """
    ty = e.type
    # NOTE(review): exact type tests (`type(ty) is ...`) rather than
    # isinstance; presumably deliberate if some collection types subclass
    # others. Confirm before changing.
    if type(ty) is TMap:
        # Refine the value type by looking at an arbitrary entry e[k].
        k = fresh_var(ty.k)
        for v in self.impls(EMapGet(e, k).with_type(e.type.v), assumptions):
            # is_enumerable(key type) selects a vector-backed map.
            if is_enumerable(ty.k):
                yield TVectorMap(ty.k, v)
            else:
                yield TNativeMap(ty.k, v)
    # Sets, and bags the solver proves duplicate-free under `assumptions`.
    # `print` is passed as the model callback (presumably invoked with a
    # counterexample model when the check fails; confirm in `valid`).
    elif type(ty) is TSet or (type(ty) is TBag and valid(EImplies(assumptions, EUnaryOp(UOp.AreUnique, e).with_type(BOOL)), model_callback=print)):
        # Unique collections of handles may also be kept as an intrusive list.
        if isinstance(ty.t, THandle):
            yield TIntrusiveLinkedList(ty.t)
        # Refine the element type for a native set, knowing x is in e.
        x = fresh_var(ty.t)
        for t in self.impls(x, EAll((assumptions, EIn(x, e)))):
            yield TNativeSet(t)
    elif type(ty) is TBag:
        # Bags that may contain duplicates: native list of refined elements.
        x = fresh_var(ty.t)
        for t in self.impls(x, EAll((assumptions, EIn(x, e)))):
            yield TNativeList(t)
    elif type(ty) is TList:
        # A duplicate-free list of handles may be an intrusive list; any list
        # can be a native list (element type is not refined here).
        if isinstance(ty.t, THandle) and valid(EImplies(assumptions, EUnaryOp(UOp.AreUnique, e).with_type(BOOL)), model_callback=print):
            yield TIntrusiveLinkedList(ty.t)
        yield TNativeList(ty.t)
    elif type(ty) is TTuple:
        # Every combination of refinements of the individual components.
        for refinements in cross_product([self.impls(ETupleGet(e, i).with_type(ty.ts[i]), assumptions) for i in range(len(ty.ts))]):
            yield TTuple(refinements)
    else:
        # No refinement known: the type represents itself.
        yield ty
def can_elim_vars(spec : Exp, assumptions : Exp, vs : [EVar]):
    """Can `spec` be computed without knowing the values of `vs`?

    Returns True iff `spec` does NOT depend on any of `vs`. That is, for any
    two environments that agree on every variable except (possibly) those in
    `vs`, and that both satisfy `assumptions`, `spec` evaluates to the same
    value in both.

    A variable can appear syntactically in `spec` without affecting its
    value; the solver is used to detect that case. EStateVar wrappers are
    stripped from `spec` before the check.
    """
    spec = strip_EStateVar(spec)
    # Independent fresh copies of each eliminated variable model the
    # "second environment".
    sub = { v.id : fresh_var(v.type) for v in vs }
    return valid(EImplies(
        EAll([assumptions, subst(assumptions, sub)]),
        EEq(spec, subst(spec, sub))))
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Can `spec` be computed without knowing the values of `vs`?

    Returns True iff `spec` does NOT depend on any of `vs`. For any two
    environments that differ only on `vs` and both satisfy `assumptions`,
    `spec` must have the same value.

    A variable can appear syntactically in `spec` without affecting its
    value; the solver is used to detect that case. EStateVar wrappers are
    stripped from `spec` before the check.
    """
    spec = strip_EStateVar(spec)
    sub = {v.id: fresh_var(v.type) for v in vs}
    both_satisfy_assumptions = EAll([assumptions, subst(assumptions, sub)])
    same_result = EEq(spec, subst(spec, sub))
    return valid(EImplies(both_satisfy_assumptions, same_result))
def check_ops_preserve_invariants(spec : Spec):
    """Check that every update operation of `spec` preserves every invariant.

    For each Op and each invariant, the invariant is mutated through the
    op's body. The result must follow from the op's preconditions together
    with all spec invariants. Returns one error string per (op, invariant)
    pair that cannot be proven, or [] when the check is disabled.
    """
    if not invariant_preservation_check.value:
        return []
    errors = []
    for op in spec.methods:
        if not isinstance(op, Op):
            continue
        hypotheses = list(op.assumptions) + list(spec.assumptions)
        for inv in spec.assumptions:
            print("Checking that {} preserves {}...".format(op.name, pprint(inv)))
            inv_after = mutate(inv, op.body)
            if not valid(EImplies(EAll(hypotheses), inv_after)):
                errors.append("{.name!r} may not preserve invariant {}".format(op, pprint(inv)))
    return errors
def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
    """Register helper query `sub_q`, or reuse an equivalent existing query.

    Parameters:
        sub_q   - specification of the helper query. A shallow copy is
                  edited (extra handle assumptions, simplification); the
                  caller's object is not reassigned.
        used_by - the statement that calls `sub_q`

    If `dedup_queries` is enabled and `queries_equivalent` accepts some
    query in `self.query_specs`, `used_by` is rewritten so that calls to
    `sub_q` become calls to that query (arguments reordered by name), and
    nothing is added. Otherwise the prepared copy is passed to
    `self.add_query` and `used_by` is returned unchanged.
    """
    with task("adding query", query=sub_q.name):
        sub_q = shallow_copy(sub_q)
        with task("checking whether we need more handle assumptions"):
            new_a = implicit_handle_assumptions_for_method(
                reachable_handles_at_method(self.spec, sub_q),
                sub_q)
            # Only append the handle assumptions if they are not already implied.
            if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                event("we do!")
                sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

        with task("simplifying"):
            orig_a = sub_q.assumptions
            orig_a_size = sum(a.size() for a in sub_q.assumptions)
            orig_ret_size = sub_q.ret.size()
            sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
            sub_q.ret = simplify(sub_q.ret)
            a_size = sum(a.size() for a in sub_q.assumptions)
            ret_size = sub_q.ret.size()
            event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
            event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

            # Simplification is expected never to grow the assumptions; dump
            # both versions and fail if it did. NOTE(review): `assert False`
            # is skipped under `python -O`.
            if a_size > orig_a_size:
                print("NO, BAD SIMPLIFICATION")
                print("original")
                for a in orig_a:
                    print(" - {}".format(pprint(a)))
                print("simplified")
                for a in sub_q.assumptions:
                    print(" - {}".format(pprint(a)))
                assert False

        state_vars = self.abstract_state
        funcs = self.extern_funcs
        # First existing query judged equivalent, or None (always None when
        # deduplication is disabled).
        qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs))
        if qq is not None:
            event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
            # arg_reorder[i] = position, in sub_q's argument list, of the
            # argument with the same name as qq's i-th argument.
            arg_reorder = [[x[0] for x in sub_q.args].index(a) for (a, t) in qq.args]
            class Repl(BottomUpRewriter):
                # Redirects calls to sub_q to qq (permuting arguments);
                # other calls are rebuilt with their visited arguments.
                def visit_ECall(self, e):
                    args = tuple(self.visit(a) for a in e.args)
                    if e.func == sub_q.name:
                        args = tuple(args[idx] for idx in arg_reorder)
                        return ECall(qq.name, args).with_type(e.type)
                    else:
                        return ECall(e.func, args).with_type(e.type)
            used_by = Repl().visit(used_by)
        else:
            self.add_query(sub_q)
        return used_by
def check_ops_preserve_invariants(spec : Spec):
    """Check that every update operation of `spec` preserves every invariant.

    Each invariant is mutated through each Op's body. If the result is
    alpha-equivalent to the original, the invariant is trivially preserved.
    Otherwise the solver must prove that it follows from the op's
    preconditions plus all spec invariants. Returns one error string per
    failing (op, invariant) pair, or [] when the check is disabled.
    """
    if not invariant_preservation_check.value:
        return []
    errors = []
    for op in spec.methods:
        if not isinstance(op, Op):
            continue
        hypotheses = list(op.assumptions) + list(spec.assumptions)
        for inv in spec.assumptions:
            print("Checking that {} preserves {}...".format(op.name, pprint(inv)))
            inv_after = mutate(inv, op.body)
            if alpha_equivalent(inv, inv_after):
                # Unchanged up to bound-variable renaming: nothing to prove.
                continue
            if not valid(EImplies(EAll(hypotheses), inv_after)):
                errors.append("{.name!r} may not preserve invariant {}".format(op, pprint(inv)))
    return errors
def check_calls_wf(spec : Spec):
    """Check that every call to a query in `spec` satisfies the query's preconditions.

    At each call site, each precondition of the callee is instantiated with
    the actual arguments and pushed through the call site's pending
    mutations. It must then follow (after inlining calls) from the facts
    known at that site. Calls to functions that are not queries of `spec`
    are skipped. Returns one error string per unprovable precondition.
    """
    errors = []
    queries_by_name = {m.name : m for m in spec.methods if isinstance(m, Query)}
    for ctx in enumerate_fragments(spec):
        call = ctx.e
        if not isinstance(call, ECall):
            continue
        callee = queries_by_name.get(call.func)
        if callee is None:
            continue
        print("Checking call {}...".format(pprint(call)))
        known = EAll(ctx.facts)
        for precond in callee.assumptions:
            binding = {formal : actual for ((formal, _ty), actual) in zip(callee.args, call.args)}
            precond = mutate(subst(precond, binding), ctx.mutations)
            if not valid(inline_calls(spec, EImplies(known, precond))):
                errors.append("at {}: call may not satisfy precondition {}".format(pprint(call), pprint(precond)))
    return errors
def simplify(e, validate=None, debug=False):
    """Simplify expression `e` with the `_V` rewriter.

    validate - when true (default: the `checked_simplify` option), the
               result is checked against the input with the solver. If it is
               not provably deeply equal, a reproduction hint goes to stderr
               and the ORIGINAL expression is returned.
    debug    - passed through to the rewriter.

    Any exception is reported (with the current expression) and re-raised.
    """
    if validate is None:
        validate = checked_simplify.value
    try:
        rewriter = _V(debug)
        orig = e
        e = rewriter.visit(e)
        if not validate:
            return e
        if valid(EBinOp(orig, "===", e).with_type(BOOL)):
            return e
        import sys
        print("simplify did something stupid!\nto reproduce:\nsimplify({e!r}, validate=True, debug=True)".format(e=orig), file=sys.stderr)
        return orig
    except BaseException:
        print("SIMPL FAILED")
        print(repr(e))
        raise
def check_ops_preserve_invariants(spec : Spec):
    """Check that every update operation of `spec` preserves every invariant.

    For each Op, `delta_form` yields a substitution describing the op's
    effect on the state variables. Each invariant is rewritten through it,
    and the result must follow (after CSE) from the op's preconditions plus
    all spec invariants. Returns one error string per failing (op,
    invariant) pair, or [] when the check is disabled.
    """
    if not invariant_preservation_check.value:
        return []
    errors = []
    for op in (m for m in spec.methods if isinstance(m, Op)):
        effect = delta_form(spec.statevars, op)
        for inv in spec.assumptions:
            inv_after = subst(inv, effect)
            hypotheses = list(op.assumptions) + list(spec.assumptions)
            if not valid(cse(EImplies(EAll(hypotheses), inv_after))):
                errors.append("{.name!r} may not preserve invariant {}".format(op, pprint(inv)))
    return errors
def simplify(e, validate=None, debug=False):
    """Simplify expression `e` with `_SimplificationVisitor`.

    validate - when true (default: the `checked_simplify` option), the
               result must be provably deeply equal (`===`) to the input.
               Otherwise a reproduction hint goes to stderr and the
               original expression is returned unchanged.
    debug    - passed through to the visitor.

    Any exception is reported (with the current expression) and re-raised.
    """
    if validate is None:
        validate = checked_simplify.value
    try:
        simplifier = _SimplificationVisitor(debug)
        before = e
        e = simplifier.visit(e)
        preserved = (not validate) or valid(EBinOp(before, "===", e).with_type(BOOL))
        if preserved:
            return e
        import sys
        print("simplify did something stupid!\nto reproduce:\nsimplify({e!r}, validate=True, debug=True)".format(e=before), file=sys.stderr)
        return before
    except BaseException:
        print("SIMPL FAILED")
        print(repr(e))
        raise
def cardinality_le(c1 : Exp, c2 : Exp, assumptions : Exp = T, as_f : bool = False, solver : IncrementalSolver = None) -> bool:
    """Is |c1| <= |c2| whenever `assumptions` hold?

    Parameters:
        c1, c2      - collection-valued expressions to compare by length
        assumptions - facts that may be assumed (default: true)
        as_f        - if true, return the formula `len(c1) <= len(c2)`
                      itself instead of deciding it (assumptions are ignored)
        solver      - optional incremental solver. When falsy, the global
                      `valid` is used.

    Note: a per-element formulation (no value occurs more times in c1 than
    in c2) is deliberately not used. It is too weak: with x = [y] and
    a = Filter (!= y) b it cannot prove |x| <= |a|, even though |a| is
    likely huge.
    """
    f = EBinOp(ELen(c1), "<=", ELen(c2)).with_type(BOOL)
    if as_f:
        return f
    claim = EImplies(assumptions, f)
    if solver:
        return solver.valid(claim)
    return valid(claim)
def set_impl(self, q: Query, rep: [(EVar, Exp)], ret: Exp):
    """Install a concrete representation for query `q` and incrementalize it.

    Parameters:
        q   - the query being implemented
        rep - new concrete state variables paired with the expressions they
              store
        ret - how `q` computes its result from the concrete state

    Steps:
      1. A new state variable whose expression is provably equal (under the
         spec's assumptions) to an existing concrete-state expression of the
         same type is dropped, and `ret` is rewritten to use the existing
         variable instead.
      2. The remaining variables are appended to `self.concrete_state`, and
         `q`'s implementation is recorded in `self.query_impls`.
      3. For every op and every remaining variable, an update statement is
         sketched. Its helper queries are routed through
         `self._add_subquery`, and the result is stored in `self.updates`.
    """
    to_remove = set()
    from cozy.solver import valid

    def existing_equivalent(e):
        # First existing concrete variable that provably stores the same value as `e`.
        return find_one(
            vv for (vv, ee) in self.concrete_state
            if e.type == ee.type and valid(EImplies(EAll(self.spec.assumptions), EEq(e, ee))))

    for (v, e) in rep:
        aeq = existing_equivalent(e)
        if aeq is None:
            continue
        print("########### state var {} is equivalent to {}".format(v.id, aeq.id))
        ret = subst(ret, {v.id: aeq})
        to_remove.add(v)

    rep = [pair for pair in rep if pair[0] not in to_remove]
    self.concrete_state.extend(rep)
    self.query_impls[q.name] = rewrite_ret(q, lambda prev: ret, keep_assumptions=False)

    op_deltas = {op.name: inc.delta_form(self.spec.statevars, op) for op in self.op_specs}
    for op in self.op_specs:
        delta = op_deltas[op.name]
        for new_member, projection in rep:
            (state_update_stm, subqueries) = inc.sketch_update(
                new_member,
                projection,
                subst(projection, delta),
                self.abstract_state,
                list(op.assumptions))
            for sub_q in subqueries:
                sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                state_update_stm = self._add_subquery(sub_q=sub_q, used_by=state_update_stm)
            self.updates[(new_member, op.name)] = state_update_stm
def _fix_map(m: target_syntax.EMap, rewrite: bool = False) -> syntax.Exp:
    """Optionally rewrite an element-type-preserving map into a bag-delta form.

    With the default `rewrite=False`, `m` is returned unchanged. This is the
    behavior existing callers have always had: the rewrite below used to be
    unreachable behind an unconditional `return m`.

    With `rewrite=True`, `m` is simplified first. If it is still an EMap
    whose function preserves the element type, it is rewritten as

        (m.e - changed) + Map(changed, m.f)

    where `changed` holds the elements x with `not (x === f(x))`. The rewrite
    is checked with the solver. If it cannot be proven equal, a warning is
    printed and the simplified map is returned instead.
    """
    if not rewrite:
        return m
    from cozy.simplification import simplify
    m = simplify(m)
    if not isinstance(m, target_syntax.EMap):
        return m
    elem_type = m.e.type.t
    if m.f.body.type != elem_type:
        # The rewrite mixes mapped and unmapped elements in one bag, so it
        # only type-checks for maps that preserve the element type.
        return m
    # Elements that the map function actually changes.
    changed = target_syntax.EFilter(
        m.e,
        mk_lambda(
            elem_type,
            lambda x: syntax.ENot(
                syntax.EBinOp(x, "===", m.f.apply_to(x)).with_type(syntax.BOOL)))).with_type(m.e.type)
    e = syntax.EBinOp(
        syntax.EBinOp(m.e, "-", changed).with_type(m.e.type),
        "+",
        target_syntax.EMap(changed, m.f).with_type(m.e.type)).with_type(m.e.type)
    if not valid(syntax.EEq(m, e)):
        print("WARNING: rewrite failed")
        print("_fix_map({!r})".format(m))
        return m
    return e
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    If the solver proves old_value == new_value under `assumptions`, the
    result is a no-op with no subgoals. Otherwise the strategy depends on
    the type of `lval`:
      - bag/set: remove the deleted elements, then add the new ones;
      - numeric: add the delta (new - old);
      - tuple/record: update each component recursively;
      - handle: no-op (handles are dealt with at a higher level);
      - map: delete vanished keys, recursively update values of keys whose
        value changed, then insert new keys;
      - anything else: recompute the value from scratch.
    """
    if valid(syntax.EImplies(syntax.EAll(assumptions), syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []

    def make_subgoal(e, a=[], docstring=None):
        # Replace `e` by a call to a fresh helper query recorded in
        # `subgoals`. `a` holds extra assumptions for that query; it is only
        # read, never mutated, so the shared default is safe.
        e = strip_EStateVar(e)
        # With skip_stateless_synthesis, expressions that mention no abstract
        # state are used inline instead of becoming queries.
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        # Every free variable that is not abstract state becomes a parameter.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Sketch a nested update and collect its subgoals into ours.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(syntax.EBinOp(old_value, "-", new_value).with_type(t), docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))])
    elif is_numeric(t):
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        def tuple_part(val, i):
            return syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(tuple_part(lval, i), tuple_part(old_value, i), tuple_part(new_value, i), ctx, assumptions)
            for i in range(len(t.ts))])
    elif isinstance(t, syntax.TRecord):
        def record_part(val, i):
            return syntax.EGetField(val, t.fields[i][0]).with_type(t.fields[i][1])
        stm = syntax.seq([
            recurse(record_part(lval, i), record_part(old_value, i), record_part(new_value, i), ctx, assumptions)
            for i in range(len(t.fields))])
    elif isinstance(t, syntax.THandle):
        # handles are tricky, and are dealt with at a higher level
        stm = syntax.SNoOp()
    elif isinstance(t, syntax.TMap):
        def value_at(m, k):
            return target_syntax.EMapGet(m, k).with_type(lval.type.v)
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)
        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set: keys present before but not after
        deleted_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(k, new_keys)))).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys, docstring="keys removed from {}".format(pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) modify set: keys present before and after whose value changed
        common_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k, syntax.EIn(k, new_keys))).with_type(key_bag)
        update_value = recurse(
            v, value_at(old_value, k), value_at(new_value, k),
            ctx=ctx,
            assumptions=assumptions + [
                syntax.EIn(k, common_keys),
                syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k)))])
        altered_keys = target_syntax.EFilter(
            common_keys,
            target_syntax.ELambda(
                k,
                syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k))))).with_type(key_bag)
        s2 = syntax.SForEach(
            k,
            make_subgoal(altered_keys, docstring="altered keys in {}".format(pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        # (3) enter set: keys present after but not before
        fresh_keys = target_syntax.EFilter(
            new_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(k, old_keys)))).with_type(key_bag)
        s3 = syntax.SForEach(
            k,
            make_subgoal(fresh_keys, docstring="new keys in {}".format(pprint(lval))),
            target_syntax.SMapPut(
                lval,
                k,
                make_subgoal(
                    value_at(new_value, k),
                    a=[syntax.EIn(k, fresh_keys)],
                    docstring="new value inserted at {}".format(pprint(target_syntax.EMapGet(lval, k))))))
        stm = syntax.seq([s1, s2, s3])
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(lval, make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = [],
        invariants: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """Sketch a statement that moves `lval` from `old_value` to `new_value`.

    `ctx` lists the abstract-state variables of the data structure. Every
    other free variable of a generated helper query becomes a parameter of
    that query. `assumptions` are attached to every generated query.
    `invariants` are not attached to queries; here and in recursive calls
    they are only used to prove the value unchanged, in which case a no-op
    is returned.

    Returns a pair (update statement, helper queries the statement calls).
    """
    facts = syntax.EAll(itertools.chain(assumptions, invariants))
    if valid(syntax.EImplies(facts, syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    helper_queries = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        # Turn `e` into a call to a new helper query (extra assumptions `a`),
        # unless stateless expressions are allowed inline.
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        name = fresh_name("query")
        helper = syntax.Query(name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        params = [v for v in free_vars(helper) if v not in ctx]
        helper.args = [(p.id, p.type) for p in params]
        helper_queries.append(helper)
        return syntax.ECall(name, tuple(params)).with_type(e.type)

    def nested_update(*args, **kwargs):
        # Recursive sketch whose helper queries are merged into ours.
        (code, more_queries) = sketch_update(*args, **kwargs)
        helper_queries.extend(more_queries)
        return code

    t = lval.type
    if isinstance(t, (syntax.TBag, syntax.TSet)):
        additions = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        deletions = make_subgoal(syntax.EBinOp(old_value, "-", new_value).with_type(t), docstring="deletions from {}".format(pprint(lval)))
        elem = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(elem, deletions, syntax.SCall(lval, "remove", [elem])),
            syntax.SForEach(elem, additions, syntax.SCall(lval, "add", [elem]))])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        delta = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", delta).with_type(t))
    elif isinstance(t, syntax.TTuple):
        def component(val, i):
            return syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            nested_update(component(lval, i), component(old_value, i), component(new_value, i), ctx, assumptions, invariants=invariants)
            for i in range(len(t.ts))])
    elif isinstance(t, syntax.TRecord):
        def field(val, i):
            return syntax.EGetField(val, t.fields[i][0]).with_type(t.fields[i][1])
        stm = syntax.seq([
            nested_update(field(lval, i), field(old_value, i), field(new_value, i), ctx, assumptions, invariants=invariants)
            for i in range(len(t.fields))])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)
        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # Step 1: remove keys that disappear.
        gone = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        remove_stale = syntax.SForEach(
            k,
            make_subgoal(gone, docstring="keys removed from {}".format(pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # Step 2: for keys that are new or whose value differs, run the
        # recursively sketched update on the entry bound to `v`.
        is_new = syntax.ENot(syntax.EIn(k, old_keys))
        is_modified = syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k)))
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(k, syntax.EAny([is_new, is_modified]))).with_type(key_bag)
        entry_update = nested_update(
            v, value_at(old_value, k), value_at(new_value, k),
            ctx=ctx,
            assumptions=assumptions + [
                syntax.EIn(k, new_or_modified),
                syntax.EEq(v, value_at(old_value, k))],
            invariants=invariants)
        upsert = syntax.SForEach(
            k,
            make_subgoal(new_or_modified, docstring="new or modified keys from {}".format(pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, entry_update))
        stm = syntax.SSeq(remove_stale, upsert)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(lval, make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))

    return (stm, helper_queries)
def sketch_update(
        lval : syntax.Exp,
        old_value : syntax.Exp,
        new_value : syntax.Exp,
        ctx : [syntax.EVar],
        assumptions : [syntax.Exp] = [],
        invariants : [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """Build code that changes `lval` from `old_value` into `new_value`.

    Variables in `ctx` are the data structure's abstract state. Every other
    free variable of a generated helper query becomes one of its parameters,
    and `assumptions` are attached to each such query. `invariants` are only
    used (here and in recursive calls) to prove that nothing changes, in
    which case a no-op is returned.

    Returns (statement, helper queries the statement calls).
    """
    known = itertools.chain(assumptions, invariants)
    if valid(syntax.EImplies(syntax.EAll(known), syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    queries = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        # Stateless expressions may be used inline; otherwise emit a query.
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        qname = fresh_name("query")
        q = syntax.Query(qname, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        q_params = [v for v in free_vars(q) if v not in ctx]
        q.args = [(p.id, p.type) for p in q_params]
        queries.append(q)
        return syntax.ECall(qname, tuple(q_params)).with_type(e.type)

    def recurse(*args, **kwargs):
        stm, sub_queries = sketch_update(*args, **kwargs)
        queries.extend(sub_queries)
        return stm

    t = lval.type

    if isinstance(t, (syntax.TBag, syntax.TSet)):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(syntax.EBinOp(old_value, "-", new_value).with_type(t), docstring="deletions from {}".format(pprint(lval)))
        x = fresh_var(t.elem_type)
        removals = syntax.SForEach(x, to_del, syntax.SCall(lval, "remove", [x]))
        insertions = syntax.SForEach(x, to_add, syntax.SCall(lval, "add", [x]))
        return (syntax.seq([removals, insertions]), queries)

    if is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        return (syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t)), queries)

    if isinstance(t, (syntax.TTuple, syntax.TRecord)):
        # Update component-wise; tuples by position, records by field.
        if isinstance(t, syntax.TTuple):
            arity = len(t.ts)
            def part(val, i):
                return syntax.ETupleGet(val, i).with_type(t.ts[i])
        else:
            arity = len(t.fields)
            def part(val, i):
                return syntax.EGetField(val, t.fields[i][0]).with_type(t.fields[i][1])
        stm = syntax.seq([
            recurse(part(lval, i), part(old_value, i), part(new_value, i), ctx, assumptions, invariants=invariants)
            for i in range(arity)])
        return (stm, queries)

    if isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)
        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set: delete keys that vanish
        deleted_keys = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        drop = syntax.SForEach(
            k,
            make_subgoal(deleted_keys, docstring="keys removed from {}".format(pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) enter/mod set: keys that are new or whose value changes
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(k, syntax.EAny([
                syntax.ENot(syntax.EIn(k, old_keys)),
                syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k)))]))).with_type(key_bag)
        update_value = recurse(
            v, value_at(old_value, k), value_at(new_value, k),
            ctx=ctx,
            assumptions=assumptions + [syntax.EIn(k, new_or_modified), syntax.EEq(v, value_at(old_value, k))],
            invariants=invariants)
        upsert = syntax.SForEach(
            k,
            make_subgoal(new_or_modified, docstring="new or modified keys from {}".format(pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))
        return (syntax.SSeq(drop, upsert), queries)

    # Fallback rule: just compute a new value from scratch
    stm = syntax.SAssign(lval, make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))
    return (stm, queries)
def _add_subquery(self, sub_q: Query, used_by: Stm) -> Stm:
    """Add a query that helps maintain some other state.

    Parameters:
        sub_q   - the specification of the helper query
        used_by - the statement that calls `sub_q`

    If a query already exists that is equivalent to `sub_q`, this method
    returns `used_by` rewritten to use the existing query and does not add
    the query to the implementation.  Otherwise it returns `used_by`
    unchanged.

    Raises AssertionError if simplification makes the query's assumptions
    larger.  This check also runs under `python -O`.
    """
    with task("adding query", query=sub_q.name):
        # Work on a shallow copy; its attributes are reassigned below.
        sub_q = shallow_copy(sub_q)
        with task("checking whether we need more handle assumptions"):
            new_a = implicit_handle_assumptions(
                reachable_handles_at_method(self.spec, sub_q))
            # Only append the handle assumptions if they are not already implied.
            if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                event("we do!")
                sub_q.assumptions = list(
                    itertools.chain(sub_q.assumptions, new_a))

        with task("repairing state var boundaries"):
            extra_available_state = [
                e for v, e in self._concretization_functions ]
            sub_q.ret = repair_well_formedness(
                strip_EStateVar(sub_q.ret),
                self.context_for_method(sub_q),
                extra_available_state)

        with task("simplifying"):
            orig_a = sub_q.assumptions
            orig_a_size = sum(a.size() for a in sub_q.assumptions)
            orig_ret_size = sub_q.ret.size()
            sub_q.assumptions = tuple(
                simplify_or_ignore(a) for a in sub_q.assumptions)
            sub_q.ret = simplify(sub_q.ret)
            a_size = sum(a.size() for a in sub_q.assumptions)
            ret_size = sub_q.ret.size()
            event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
            event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

            # Simplification must never grow the assumptions.  Dump both
            # versions and fail loudly if it does.
            if a_size > orig_a_size:
                print("NO, BAD SIMPLIFICATION")
                print("original")
                for a in orig_a:
                    print(" - {}".format(pprint(a)))
                print("simplified")
                for a in sub_q.assumptions:
                    print(" - {}".format(pprint(a)))
                # Explicit raise: a bare `assert False` is stripped under -O.
                raise AssertionError("bad simplification")

        state_vars = self.abstract_state
        funcs = self.extern_funcs
        qq = None
        if dedup_queries.value:
            # Build the invariant conjunction once, not once per candidate.
            invariants = EAll(self.abstract_invariants)
            qq = find_one(
                self.query_specs,
                lambda qq: queries_equivalent(qq, sub_q,
                    state_vars=state_vars,
                    extern_funcs=funcs,
                    assumptions=invariants))
        if qq is not None:
            event("subgoal {} is equivalent to {}".format(
                sub_q.name,
                qq.name))
            # arg_reorder[i] = position in sub_q's args of qq's i-th argument.
            sub_arg_names = [x[0] for x in sub_q.args]
            arg_reorder = [sub_arg_names.index(a) for (a, t) in qq.args]
            class Repl(BottomUpRewriter):
                def visit_ECall(self, e):
                    args = tuple(self.visit(a) for a in e.args)
                    if e.func == sub_q.name:
                        args = tuple(args[idx] for idx in arg_reorder)
                        return ECall(qq.name, args).with_type(e.type)
                    else:
                        return ECall(e.func, args).with_type(e.type)
            used_by = Repl().visit(used_by)
        else:
            self.add_query(sub_q)
        return used_by
def test_map_discovery(self):
    # Filtering a state bag by equality with an argument should be
    # discoverable as an EMapGet on a state-held map, equal to the spec.
    bag = EVar("xs").with_type(INT_BAG)
    key = EVar("y").with_type(INT)
    spec = EFilter(EStateVar(bag), mk_lambda(INT, lambda elem: EEq(elem, key)))
    assert retypecheck(spec)

    def is_map_lookup(e):
        if not isinstance(e, EMapGet):
            return False
        if not isinstance(e.map, EStateVar):
            return False
        return valid(EEq(e, spec))

    assert check_discovery(
        spec=spec,
        expected=is_map_lookup,
        args=[key],
        state_vars=[bag])
def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
    """Add a query that helps maintain some other state.

    Parameters:
        sub_q   - the specification of the helper query
        used_by - the statement that calls `sub_q`

    If a query already exists that is equivalent to `sub_q`, this method
    returns `used_by` rewritten to use the existing query and does not add
    the query to the implementation.  Otherwise it returns `used_by`
    unchanged.

    Raises AssertionError if simplification makes the query's assumptions
    larger.  This check also runs under `python -O`.
    """
    with task("adding query", query=sub_q.name):
        # Work on a shallow copy; its attributes are reassigned below.
        sub_q = shallow_copy(sub_q)
        with task("checking whether we need more handle assumptions"):
            new_a = implicit_handle_assumptions(
                reachable_handles_at_method(self.spec, sub_q))
            # Only append the handle assumptions if they are not already implied.
            if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                event("we do!")
                sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

        with task("repairing state var boundaries"):
            extra_available_state = [e for v, e in self._concretization_functions]
            sub_q.ret = repair_well_formedness(
                strip_EStateVar(sub_q.ret),
                self.context_for_method(sub_q),
                extra_available_state)

        with task("simplifying"):
            orig_a = sub_q.assumptions
            orig_a_size = sum(a.size() for a in sub_q.assumptions)
            orig_ret_size = sub_q.ret.size()
            sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
            sub_q.ret = simplify(sub_q.ret)
            a_size = sum(a.size() for a in sub_q.assumptions)
            ret_size = sub_q.ret.size()
            event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
            event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

            # Simplification must never grow the assumptions.  Dump both
            # versions and fail loudly if it does.
            if a_size > orig_a_size:
                print("NO, BAD SIMPLIFICATION")
                print("original")
                for a in orig_a:
                    print(" - {}".format(pprint(a)))
                print("simplified")
                for a in sub_q.assumptions:
                    print(" - {}".format(pprint(a)))
                # Explicit raise: a bare `assert False` is stripped under -O.
                raise AssertionError("bad simplification")

        state_vars = self.abstract_state
        funcs = self.extern_funcs
        qq = None
        if dedup_queries.value:
            # Build the invariant conjunction once, not once per candidate.
            invariants = EAll(self.abstract_invariants)
            qq = find_one(self.query_specs, lambda qq: queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs, assumptions=invariants))
        if qq is not None:
            event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
            # arg_reorder[i] = position in sub_q's args of qq's i-th argument.
            sub_arg_names = [x[0] for x in sub_q.args]
            arg_reorder = [sub_arg_names.index(a) for (a, t) in qq.args]
            class Repl(BottomUpRewriter):
                def visit_ECall(self, e):
                    args = tuple(self.visit(a) for a in e.args)
                    if e.func == sub_q.name:
                        args = tuple(args[idx] for idx in arg_reorder)
                        return ECall(qq.name, args).with_type(e.type)
                    else:
                        return ECall(e.func, args).with_type(e.type)
            used_by = Repl().visit(used_by)
        else:
            self.add_query(sub_q)
        return used_by
def test_map_discovery2(self):
    # A membership test against a state bag should be discoverable as an
    # EMapGet or EHasKey on a state-held map, equal to the spec.
    bag = EVar("xs").with_type(INT_BAG)
    needle = EVar("y").with_type(INT)
    spec = EIn(needle, EStateVar(bag))
    assert retypecheck(spec)

    def is_map_probe(e):
        if not isinstance(e, (EMapGet, EHasKey)):
            return False
        return isinstance(e.map, EStateVar) and valid(EEq(e, spec))

    assert check_discovery(
        spec=spec,
        expected=is_map_probe,
        args=[needle],
        state_vars=[bag])
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Check whether the variables `vs` can be eliminated from `spec`.

    Each variable in `vs` is renamed to a fresh variable of the same type,
    and the same renaming is applied to `assumptions` and `spec`.  The
    result is whether it is valid that, whenever the original and the
    renamed assumptions both hold, the original and the renamed `spec` are
    equal.  `EStateVar` wrappers are stripped from `spec` first.
    """
    body = strip_EStateVar(spec)
    renaming = {}
    for var in vs:
        renaming[var.id] = fresh_var(var.type)
    hypothesis = EAll([assumptions, subst(assumptions, renaming)])
    conclusion = EEq(body, subst(body, renaming))
    return valid(EImplies(hypothesis, conclusion))