def uses_intrusive_data(e: target_syntax.Exp, handle: target_syntax.Exp) -> target_syntax.Exp:
    """Build a boolean condition under which `handle` occurs in the data of `e`.

    The function walks `e` structurally and produces a symbolic formula:
      * a bag variable whose element type is `handle.type` contributes
        `handle in e`;
      * a singleton of that element type contributes `e.e == handle`;
      * a filter contributes "occurs in the source AND the predicate holds
        for `handle`";
      * maps, unary operators, binary operators, conditionals and tuples
        recurse into their sub-expressions;
      * literals, empty lists, and collections of other element types
        contribute `target_syntax.F`.

    Args:
        e: the expression whose (possibly nested) collections are inspected.
        handle: the expression being searched for. Its `.type` is compared
            with bag element types in `e`. In the `EMakeMap` case it is also
            used as a lambda binder, so it is presumably an `EVar` (confirm
            with callers).

    Returns:
        A boolean-valued expression (`target_syntax.F` when `e` cannot
        contain `handle`).

    Raises:
        NotImplementedError: if `e` is an expression form not handled here.
    """
    if isinstance(e, target_syntax.EMakeMap):
        if isinstance(e.e.type, target_syntax.TBag) and e.e.type.t == handle.type:
            # Key computed for `handle`, and a fresh variable standing for an
            # arbitrary key of the map.
            k = e.key.apply_to(handle)
            kk = syntax_tools.fresh_var(k.type, "k")
            # The group of source elements whose key equals `kk`; the map's
            # value function is applied to that group and inspected in turn.
            group = target_syntax.EFilter(
                e.e,
                target_syntax.ELambda(handle, syntax_tools.equal(k, kk))).with_type(e.e.type)
            return uses_intrusive_data(e.value.apply_to(group), handle)
        return target_syntax.F
    elif isinstance(e, target_syntax.EMakeMap2):
        if e.e.type.t == handle.type:
            # NOTE(review): `k` is left free in the returned formula; the
            # caller presumably treats it as universally quantified -- confirm.
            k = syntax_tools.fresh_var(e.type.k)
            return target_syntax.EImplies(
                target_syntax.EBinOp(k, target_syntax.BOp.In, e.e).with_type(target_syntax.BOOL),
                uses_intrusive_data(e.value.apply_to(k), handle))
        return target_syntax.F
    elif isinstance(e, target_syntax.EFilter):
        return target_syntax.EAll(
            [uses_intrusive_data(e.e, handle), e.p.apply_to(handle)])
    elif isinstance(e, target_syntax.EEmptyList):
        return target_syntax.F
    elif isinstance(e, target_syntax.EMap):
        # Only the source collection is inspected, not the mapping function.
        return uses_intrusive_data(e.e, handle)
    elif isinstance(e, target_syntax.EUnaryOp):
        return uses_intrusive_data(e.e, handle)
    elif isinstance(e, target_syntax.EBinOp):
        # Build a *symbolic* disjunction. Python's `or` would evaluate the
        # operands' truthiness and return one formula object, never a
        # combination of both sides.
        return target_syntax.EAny([
            uses_intrusive_data(e.e1, handle),
            uses_intrusive_data(e.e2, handle),
        ])
    elif isinstance(e, target_syntax.ECond):
        return target_syntax.ECond(
            e.cond,
            uses_intrusive_data(e.then_branch, handle),
            uses_intrusive_data(e.else_branch, handle)).with_type(target_syntax.BOOL)
    elif isinstance(e, target_syntax.ESingleton):
        if e.type.t == handle.type:
            return target_syntax.EEq(e.e, handle)
        return target_syntax.F
    elif isinstance(e, target_syntax.ETuple):
        return target_syntax.EAny(
            uses_intrusive_data(ee, handle) for ee in e.es)
    elif isinstance(e, target_syntax.EVar):
        if isinstance(e.type, target_syntax.TBag) and e.type.t == handle.type:
            return target_syntax.EBinOp(handle, target_syntax.BOp.In, e).with_type(target_syntax.BOOL)
        return target_syntax.F
    elif type(e) in (target_syntax.ENum, target_syntax.EBool, target_syntax.EEnumEntry):
        # Exact-type match (not isinstance): literals hold no collections.
        return target_syntax.F
    else:
        raise NotImplementedError(e)
def _fix_map(m: target_syntax.EMap) -> syntax.Exp:
    """Return `m` unchanged (the rewrite below is currently disabled).

    The first statement is `return m`, so everything after it is unreachable.
    When enabled, the remaining code would:
      1. simplify `m` and give up (returning the simplified value) if the
         result is no longer an `EMap`;
      2. require that the map function returns the element type of the
         source bag (asserted);
      3. compute `changed`, the sub-bag of elements `x` for which
         `x === f(x)` does not hold;
      4. rewrite `m` as `(m.e - changed) + Map(changed, m.f)`;
      5. keep the rewrite only if `valid` proves it equal to `m`, otherwise
         print a warning and return `m`.
    """
    # NOTE(review): this early return disables the rewrite; all code below
    # is dead. Either delete it or gate it behind an explicit flag.
    return m
    from cozy.simplification import simplify
    m = simplify(m)
    if not isinstance(m, target_syntax.EMap):
        return m
    elem_type = m.e.type.t
    # The rewrite mixes unchanged source elements with mapped ones, so the map
    # must preserve the element type.
    assert m.f.body.type == elem_type
    # Elements actually altered by `m.f` (those not `===` to their image).
    changed = target_syntax.EFilter(
        m.e,
        mk_lambda(
            elem_type,
            lambda x: syntax.ENot(
                syntax.EBinOp(x, "===", m.f.apply_to(x)).with_type(syntax.BOOL)
            ))).with_type(m.e.type)
    # (untouched elements) + (mapped altered elements)
    e = syntax.EBinOp(
        syntax.EBinOp(m.e, "-", changed).with_type(m.e.type),
        "+",
        target_syntax.EMap(changed, m.f).with_type(m.e.type)).with_type(m.e.type)
    if not valid(syntax.EEq(m, e)):
        print("WARNING: rewrite failed")
        print("_fix_map({!r})".format(m))
        return m
    return e
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = [],
        invariants: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.

    Variables in `ctx` are assumed to be part of the data structure's abstract
    state. `assumptions` are attached to every generated subgoal. `invariants`
    are used only when checking whether any change happens at all; they are
    passed down to recursive calls but are not added to subgoals.

    If `valid` proves `old_value == new_value` under the conjunction of
    `assumptions` and `invariants`, the result is `(SNoOp(), [])`. Otherwise
    the update strategy depends on `lval.type`:
      * bag/set: remove the deletions, then add the additions;
      * numeric (only when `update_numbers_with_deltas` is enabled):
        `lval = lval + delta`;
      * tuple/record: update each component recursively;
      * map: delete keys that disappeared, then update keys that are new or
        whose value changed;
      * anything else: assign a freshly computed `new_value`.

    Returns a statement (code to update `lval`) and a list of subgoals (new
    queries that appear in that code).

    NOTE(review): SOURCE also contains a later `def sketch_update` without an
    `invariants` parameter. If both live in the same module, the later one
    replaces this one, and `recurse(..., invariants=...)` below would raise
    TypeError. Confirm which definition is intended.
    """
    # Nothing to do if the value provably does not change.
    if valid(
            syntax.EImplies(
                syntax.EAll(itertools.chain(assumptions, invariants)),
                syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        # Turn `e` into a call to a fresh internal query, recording the query
        # in `subgoals`. Expressions that mention no state variable are
        # returned inline when `skip_stateless_synthesis` is enabled.
        # (`a` is a shared default list but is only concatenated, never mutated.)
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        # Free variables that are not state become query arguments.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Recursive update whose subgoals are merged into ours.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        # Removals are emitted before additions.
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                    invariants=invariants)
            for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                    invariants=invariants)
            for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.TMap):
        # NOTE(review): `value_at` is not defined in this function; presumably a
        # module-level helper returning the map's value at a key -- confirm.
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)
        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)
        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys, docstring="keys removed from {}".format(
                pprint(lval))),
            target_syntax.SMapDel(lval, k))
        # (2) enter/mod set: keys that are new, or whose value differs.
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(
                k,
                syntax.EAny([
                    syntax.ENot(syntax.EIn(k, old_keys)),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k), value_at(new_value, k)))
                ]))).with_type(key_bag)
        # Code that updates `v` (the entry at key `k`), assuming `k` is in the
        # enter/mod set and `v` equals the old value at `k`.
        update_value = recurse(
            v,
            value_at(old_value, k),
            value_at(new_value, k),
            ctx=ctx,
            assumptions=assumptions + [
                syntax.EIn(k, new_or_modified),
                syntax.EEq(v, value_at(old_value, k))
            ],
            invariants=invariants)
        s2 = syntax.SForEach(
            k,
            make_subgoal(new_or_modified, docstring="new or modified keys from {}".format(
                pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))
        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))
    return (stm, subgoals)
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.

    Variables in `ctx` are assumed to be part of the data structure's abstract
    state, and `assumptions` are attached to every generated subgoal. Every
    subgoal expression has `EStateVar` wrappers stripped before use.

    If `valid` proves `old_value == new_value` under `assumptions`, the result
    is `(SNoOp(), [])`. Otherwise the update strategy depends on `lval.type`:
      * bag/set: remove the deletions, then add the additions;
      * numeric: `lval = lval + delta`;
      * tuple/record: update each component recursively;
      * handle: no-op (handled at a higher level);
      * map: (1) delete vanished keys, (2) update keys present before and
        after whose value changed, (3) put keys that are new;
      * anything else: assign a freshly computed `new_value`.

    Returns a statement (code to update `lval`) and a list of subgoals (new
    queries that appear in that code).
    """
    # Nothing to do if the value provably does not change.
    if valid(
            syntax.EImplies(syntax.EAll(assumptions), syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])
    subgoals = []

    def make_subgoal(e, a=[], docstring=None):
        # Turn `e` into a call to a fresh internal query, recording the query
        # in `subgoals`. Expressions that mention no state variable are
        # returned inline when `skip_stateless_synthesis` is enabled.
        # (`a` is a shared default list but is only concatenated, never mutated.)
        e = strip_EStateVar(e)
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        # Free variables that are not state become query arguments.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Recursive update whose subgoals are merged into ours.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        # Removals are emitted before additions.
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    # elif isinstance(t, syntax.TList):
    #     raise NotImplementedError()
    elif is_numeric(t):
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval,
            "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions)
            for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions)
            for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.THandle):
        # handles are tricky, and are dealt with at a higher level
        stm = syntax.SNoOp()
    elif isinstance(t, syntax.TMap):
        # Typed lookup of the value stored at key `k` in map expression `m`.
        value_at = lambda m, k: target_syntax.EMapGet(m, k).with_type(lval.type.v)
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)
        # NOTE(review): the `else` branch of this `if True:` is unreachable;
        # presumably kept as an alternative single-pass strategy -- consider
        # removing it.
        if True:
            old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
            new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

            # (1) exit set: keys only in the old map.
            deleted_keys = target_syntax.EFilter(
                old_keys,
                target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                    k, new_keys)))).with_type(key_bag)
            s1 = syntax.SForEach(
                k,
                make_subgoal(deleted_keys, docstring="keys removed from {}".format(
                    pprint(lval))),
                target_syntax.SMapDel(lval, k))

            # (2) modify set: keys in both maps whose value changed.
            common_keys = target_syntax.EFilter(
                old_keys,
                target_syntax.ELambda(k, syntax.EIn(k, new_keys))).with_type(key_bag)
            # Code that updates `v` (the entry at key `k`), assuming `k` is a
            # common key whose old and new values differ.
            update_value = recurse(
                v,
                value_at(old_value, k),
                value_at(new_value, k),
                ctx=ctx,
                assumptions=assumptions + [
                    syntax.EIn(k, common_keys),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k), value_at(new_value, k)))
                ])
            altered_keys = target_syntax.EFilter(
                common_keys,
                target_syntax.ELambda(
                    k,
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k), value_at(new_value, k))))).with_type(key_bag)
            s2 = syntax.SForEach(
                k,
                make_subgoal(altered_keys, docstring="altered keys in {}".format(
                    pprint(lval))),
                target_syntax.SMapUpdate(lval, k, v, update_value))

            # (3) enter set: keys only in the new map; each gets its new value
            # computed by a separate subgoal.
            fresh_keys = target_syntax.EFilter(
                new_keys,
                target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                    k, old_keys)))).with_type(key_bag)
            s3 = syntax.SForEach(
                k,
                make_subgoal(fresh_keys, docstring="new keys in {}".format(pprint(lval))),
                target_syntax.SMapPut(
                    lval,
                    k,
                    make_subgoal(
                        value_at(new_value, k),
                        a=[syntax.EIn(k, fresh_keys)],
                        docstring="new value inserted at {}".format(
                            pprint(target_syntax.EMapGet(lval, k))))))
            stm = syntax.seq([s1, s2, s3])
        else:
            # update_value = code to update for value v at key k (given that k is an altered key)
            update_value = recurse(
                v,
                value_at(old_value, k),
                value_at(new_value, k),
                ctx=ctx,
                assumptions=assumptions + [
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k), value_at(new_value, k)))
                ])
            # altered_keys = [k | k <- distinct(lval.keys() + new_value.keys()), value_at(old_value, k) != value_at(new_value, k))]
            altered_keys = make_subgoal(
                target_syntax.EFilter(
                    syntax.EUnaryOp(
                        syntax.UOp.Distinct,
                        syntax.EBinOp(
                            target_syntax.EMapKeys(old_value).with_type(
                                key_bag),
                            "+",
                            target_syntax.EMapKeys(new_value).with_type(
                                key_bag)).with_type(key_bag)).with_type(
                        key_bag),
                    target_syntax.ELambda(
                        k,
                        syntax.ENot(
                            syntax.EEq(value_at(old_value, k),
                                       value_at(new_value, k))))).with_type(key_bag))
            stm = syntax.SForEach(
                k,
                altered_keys,
                target_syntax.SMapUpdate(lval, k, v, update_value))
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))
    return (stm, subgoals)