コード例 #1
0
def construct_value(t: Type) -> Exp:
    """
    Construct an arbitrary expression e of the given type.
    eval(construct_value(t), {}) == mkval(t)
    """
    if is_numeric(t):
        e = ENum(0)
    elif t == BOOL:
        e = F
    elif t == STRING:
        e = EStr("")
    elif is_collection(t):
        e = EEmptyList()
    elif isinstance(t, TTuple):
        e = ETuple(tuple(construct_value(tt) for tt in t.ts))
    elif isinstance(t, TRecord):
        e = EMakeRecord(tuple(
            (f, construct_value(tt)) for (f, tt) in t.fields))
    elif isinstance(t, TEnum):
        e = EEnumEntry(t.cases[0])
    elif isinstance(t, THandle):
        e = EHandle(construct_value(INT), construct_value(t.value_type))
    elif isinstance(t, TNative):
        e = ENative(construct_value(INT))
    elif isinstance(t, TMap):
        e = EMakeMap2(EEmptyList().with_type(TBag(t.k)),
                      ELambda(EVar("x").with_type(t.k), construct_value(t.v)))
    else:
        h = extension_handler(type(t))
        if h is not None:
            return h.default_value(t, construct_value)
        raise NotImplementedError(pprint(t))
    return e.with_type(t)
コード例 #2
0
def _uneval(t, value):
    if is_numeric(t):
        return ENum(value).with_type(t)
    elif t == BOOL:
        return EBool(value).with_type(t)
    elif is_collection(t):
        e = EEmptyList().with_type(t)
        for x in value:
            e = EBinOp(e, "+",
                       ESingleton(uneval(t.t, x)).with_type(t)).with_type(t)
        return e
    elif isinstance(t, TString):
        return EStr(value).with_type(t)
    elif isinstance(t, TTuple):
        return ETuple(tuple(uneval(tt, x)
                            for (tt, x) in zip(t.ts, value))).with_type(t)
    elif isinstance(t, TRecord):
        return EMakeRecord(
            tuple((f, uneval(tt, value[f]))
                  for (f, tt) in t.fields)).with_type(t)
    elif isinstance(t, TEnum):
        return EEnumEntry(value).with_type(t)
    elif isinstance(t, THandle):
        return EHandle(
            ENum(value.address).with_type(INT),
            uneval(t.value_type, value.value)).with_type(t)
    elif isinstance(t, TNative):
        return ENative(ENum(value[1]).with_type(INT)).with_type(t)
    else:
        raise NotImplementedError(pprint(t))
コード例 #3
0
ファイル: cxx.py プロジェクト: sanidhya/cozy
 def construct_concrete(self, t: Type, e: Exp, out: Exp):
     """
     Construct a value of type `t` from the expression `e` and store it in
     lvalue `out`.
     """
     # from cozy.syntax_tools import pprint
     # print("construct_concrete | {} <- {}".format(pprint(out), pprint(e)))
     if hasattr(t, "construct_concrete"):
         return t.construct_concrete(e, out)
     elif isinstance(
             t, library.TNativeList) or type(t) is TBag or type(t) is TList:
         assert out not in free_vars(e)
         x = self.fv(t.t, "x")
         return SSeq(self.initialize_native_list(out),
                     SForEach(x, e, SCall(out, "add", [x])))
     elif isinstance(t, library.TNativeSet) or type(t) is TSet:
         if isinstance(e, EUnaryOp) and e.op == UOp.Distinct:
             return self.construct_concrete(t, e.e, out)
         x = self.fv(t.t, "x")
         return SSeq(self.initialize_native_set(out),
                     SForEach(x, e, SCall(out, "add", [x])))
     elif isinstance(t, library.TNativeMap) or type(t) is TMap:
         return SSeq(self.initialize_native_map(out),
                     self.construct_map(t, e, out))
     elif isinstance(t, THandle):
         return SEscape("{indent}{lhs} = {rhs};\n", ["lhs", "rhs"],
                        [out, self.addr_of(e)])
     elif is_numeric(t) or type(t) in [
             TBool, TNative, TString, TEnum, TTuple, TRecord
     ]:
         return SEscape("{indent}{lhs} = {rhs};\n", ["lhs", "rhs"],
                        [out, e])
     raise NotImplementedError(t, e, out)
コード例 #4
0
ファイル: evaluation.py プロジェクト: uwplse/cozy
def construct_value(t : Type) -> Exp:
    """
    Construct an arbitrary expression e of the given type.
    eval(construct_value(t), {}) == mkval(t)
    """
    if is_numeric(t):
        e = ENum(0)
    elif t == BOOL:
        e = EFALSE
    elif t == STRING:
        e = EStr("")
    elif is_collection(t):
        e = EEmptyList()
    elif isinstance(t, TTuple):
        e = ETuple(tuple(construct_value(tt) for tt in t.ts))
    elif isinstance(t, TRecord):
        e = EMakeRecord(tuple((f, construct_value(tt)) for (f, tt) in t.fields))
    elif isinstance(t, TEnum):
        e = EEnumEntry(t.cases[0])
    elif isinstance(t, THandle):
        e = EHandle(construct_value(INT), construct_value(t.value_type))
    elif isinstance(t, TNative):
        e = ENative(construct_value(INT))
    elif isinstance(t, TMap):
        e = EMakeMap2(
            EEmptyList().with_type(TBag(t.k)),
            ELambda(EVar("x").with_type(t.k), construct_value(t.v)))
    else:
        h = extension_handler(type(t))
        if h is not None:
            return h.default_value(t, construct_value)
        raise NotImplementedError(pprint(t))
    return e.with_type(t)
コード例 #5
0
def is_simple(t):
    if is_numeric(t):
        return True
    if isinstance(t, TString) or isinstance(t, TEnum) or isinstance(t, TBool) or isinstance(t, TNative):
        return True
    if isinstance(t, TTuple) and all(is_simple(tt) for tt in t.ts):
        return True
    if isinstance(t, TRecord) and all(is_simple(tt) for f, tt in t.fields):
        return True
    return False
コード例 #6
0
ファイル: cost_model.py プロジェクト: MostAwesomeDude/cozy
def storage_size(e, freebies: [Exp] = []):
    h = extension_handler(type(e.type))
    if h is not None:
        return h.storage_size(e, storage_size=storage_size)

    if e in freebies:
        return ZERO
    elif e.type == BOOL:
        return ONE
    elif is_numeric(e.type) or isinstance(e.type, THandle):
        return FOUR
    elif isinstance(e.type, TEnum):
        return TWO
    elif isinstance(e.type, TNative):
        return FOUR
    elif isinstance(e.type, TString):
        return TWENTY
    elif isinstance(e.type, TTuple):
        return ESum([
            storage_size(ETupleGet(e, n).with_type(t))
            for (n, t) in enumerate(e.type.ts)
        ])
    elif isinstance(e.type, TRecord):
        return ESum([
            storage_size(EGetField(e, f).with_type(t))
            for (f, t) in e.type.fields
        ])
    elif is_collection(e.type):
        v = fresh_var(e.type.t, omit=free_vars(e))
        return ESum([
            FOUR,
            EUnaryOp(UOp.Sum,
                     EMap(e, ELambda(
                         v,
                         storage_size(v))).with_type(INT_BAG)).with_type(INT)
        ])
    elif isinstance(e.type, TMap):
        k = fresh_var(e.type.k, omit=free_vars(e))
        return ESum([
            FOUR,
            EUnaryOp(
                UOp.Sum,
                EMap(
                    EMapKeys(e).with_type(TBag(e.type.k)),
                    ELambda(
                        k,
                        ESum([
                            storage_size(k),
                            storage_size(EMapGet(e, k).with_type(e.type.v))
                        ]))).with_type(INT_BAG)).with_type(INT)
        ])
    else:
        raise NotImplementedError(e.type)
コード例 #7
0
def break_sum(e):
    if not is_numeric(e.type):
        return
    if isinstance(e, EBinOp):
        if e.op == "+":
            yield from break_sum(e.e1)
            yield from break_sum(e.e2)
        else:
            assert e.op == "-"
            yield from break_sum(e.e1)
            for pos, x in break_sum(e.e2):
                yield (not pos, x)
    elif isinstance(e, EUnaryOp) and e.op == UOp.Sum:
        for pos, b in break_bag(e.e):
            yield pos, EUnaryOp(UOp.Sum, b).with_type(e.type)
    elif isinstance(e, EUnaryOp) and e.op == "-":
        for pos, x in break_sum(e.e):
            yield (not pos, x)
    # elif isinstance(e, EStateVar):
    #     yield from break_sum(e.e)
    else:
        yield True, e
コード例 #8
0
ファイル: evaluation.py プロジェクト: uwplse/cozy
def _uneval(t, value):
    if is_numeric(t):
        return ENum(value).with_type(t)
    elif t == BOOL:
        return EBool(value).with_type(t)
    elif is_collection(t):
        e = EEmptyList().with_type(t)
        for x in value:
            e = EBinOp(e, "+", ESingleton(uneval(t.elem_type, x)).with_type(t)).with_type(t)
        return e
    elif isinstance(t, TString):
        return EStr(value).with_type(t)
    elif isinstance(t, TTuple):
        return ETuple(tuple(uneval(tt, x) for (tt, x) in zip(t.ts, value))).with_type(t)
    elif isinstance(t, TRecord):
        return EMakeRecord(tuple((f, uneval(tt, value[f])) for (f, tt) in t.fields)).with_type(t)
    elif isinstance(t, TEnum):
        return EEnumEntry(value).with_type(t)
    elif isinstance(t, THandle):
        return EHandle(ENum(value.address).with_type(INT), uneval(t.value_type, value.value)).with_type(t)
    elif isinstance(t, TNative):
        return ENative(ENum(value[1]).with_type(INT)).with_type(t)
    else:
        raise NotImplementedError(pprint(t))
コード例 #9
0
    def build(self, cache, size):
        # print("Cache:")
        # for (e, sz, pool) in cache:
        #     from cozy.syntax_tools import pprint
        #     print("    @size={}, pool={}\t:\t{}".format(sz, pool, pprint(e)))
        binders_by_type = group_by(self.binders, lambda b: b.type)

        for pool in ALL_POOLS:
            if size == 1:
                yield self.check(T, pool)
                yield self.check(F, pool)
                yield self.check(ZERO, pool)
                yield self.check(ONE, pool)
                for b in self.binders:
                    yield self.check(b, pool)
                if pool == STATE_POOL:
                    for v in self.state_vars:
                        yield self.check(v, pool)
                elif pool == RUNTIME_POOL:
                    for v in self.args:
                        yield self.check(v, pool)

            if not build_exprs.value:
                return

            for e in cache.find(pool=STATE_POOL, size=size - 1):
                if all(v in self.state_vars for v in free_vars(e)):
                    yield self.check(
                        EStateVar(e).with_type(e.type), RUNTIME_POOL)

            for e in cache.find(pool=pool, size=size - 1):
                t = TBag(e.type)
                yield self.check(EEmptyList().with_type(t), pool)
                yield self.check(ESingleton(e).with_type(t), pool)

            for e in cache.find(pool=pool, type=TRecord, size=size - 1):
                for (f, t) in e.type.fields:
                    yield self.check(EGetField(e, f).with_type(t), pool)
            for e in cache.find_collections(pool=pool, size=size - 1):
                if is_numeric(e.type.t):
                    yield self.check(
                        EUnaryOp(UOp.Sum, e).with_type(e.type.t), pool)
            for e in cache.find(pool=pool, type=THandle, size=size - 1):
                yield self.check(
                    EGetField(e, "val").with_type(e.type.value_type), pool)
            for e in cache.find(pool=pool, type=TTuple, size=size - 1):
                for n in range(len(e.type.ts)):
                    yield self.check(
                        ETupleGet(e, n).with_type(e.type.ts[n]), pool)
            for e in cache.find(pool=pool, type=BOOL, size=size - 1):
                yield self.check(EUnaryOp(UOp.Not, e).with_type(BOOL), pool)
            for e in cache.find(pool=pool, type=INT, size=size - 1):
                yield self.check(EUnaryOp("-", e).with_type(INT), pool)

            for m in cache.find(pool=pool, type=TMap, size=size - 1):
                yield self.check(EMapKeys(m).with_type(TBag(m.type.k)), pool)

            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for a1 in cache.find(pool=pool, size=sz1):
                    if not is_numeric(a1.type):
                        continue
                    for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                        yield self.check(
                            EBinOp(a1, "+", a2).with_type(INT), pool)
                        yield self.check(
                            EBinOp(a1, "-", a2).with_type(INT), pool)
                        yield self.check(
                            EBinOp(a1, ">", a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, "<", a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, ">=", a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, "<=", a2).with_type(BOOL), pool)
                for a1 in cache.find_collections(pool=pool, size=sz1):
                    for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                        yield self.check(
                            EBinOp(a1, "+", a2).with_type(a1.type), pool)
                        yield self.check(
                            EBinOp(a1, "-", a2).with_type(a1.type), pool)
                    for a2 in cache.find(pool=pool, type=a1.type.t, size=sz2):
                        yield self.check(
                            EBinOp(a2, BOp.In, a1).with_type(BOOL), pool)
                for a1 in cache.find(pool=pool, type=BOOL, size=sz1):
                    for a2 in cache.find(pool=pool, type=BOOL, size=sz2):
                        yield self.check(
                            EBinOp(a1, BOp.And, a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, BOp.Or, a2).with_type(BOOL), pool)
                for a1 in cache.find(pool=pool, size=sz1):
                    if not isinstance(a1.type, TMap):
                        for a2 in cache.find(pool=pool, type=a1.type,
                                             size=sz2):
                            yield self.check(EEq(a1, a2), pool)
                            yield self.check(
                                EBinOp(a1, "!=", a2).with_type(BOOL), pool)
                for m in cache.find(pool=pool, type=TMap, size=sz1):
                    for k in cache.find(pool=pool, type=m.type.k, size=sz2):
                        yield self.check(
                            EMapGet(m, k).with_type(m.type.v), pool)
                        yield self.check(EHasKey(m, k).with_type(BOOL), pool)

            for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
                for cond in cache.find(pool=pool, type=BOOL, size=sz1):
                    for then_branch in cache.find(pool=pool, size=sz2):
                        for else_branch in cache.find(pool=pool,
                                                      size=sz3,
                                                      type=then_branch.type):
                            yield self.check(
                                ECond(cond, then_branch,
                                      else_branch).with_type(then_branch.type),
                                pool)

            for bag in cache.find_collections(pool=pool, size=size - 1):
                # len of bag
                count = EUnaryOp(UOp.Length, bag).with_type(INT)
                yield self.check(count, pool)
                # empty?
                yield self.check(
                    EUnaryOp(UOp.Empty, bag).with_type(BOOL), pool)
                # exists?
                yield self.check(
                    EUnaryOp(UOp.Exists, bag).with_type(BOOL), pool)
                # singleton?
                yield self.check(EEq(count, ONE), pool)

                yield self.check(
                    EUnaryOp(UOp.The, bag).with_type(bag.type.t), pool)
                yield self.check(
                    EUnaryOp(UOp.Distinct, bag).with_type(bag.type), pool)
                yield self.check(
                    EUnaryOp(UOp.AreUnique, bag).with_type(BOOL), pool)

                if bag.type.t == BOOL:
                    yield self.check(
                        EUnaryOp(UOp.Any, bag).with_type(BOOL), pool)
                    yield self.check(
                        EUnaryOp(UOp.All, bag).with_type(BOOL), pool)

            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for bag in cache.find_collections(pool=pool, size=sz1):
                    for binder in binders_by_type[bag.type.t]:
                        for body in itertools.chain(
                                cache.find(pool=pool, size=sz2), (binder, )):
                            yield self.check(
                                EMap(bag,
                                     ELambda(binder,
                                             body)).with_type(TBag(body.type)),
                                pool)
                            if body.type == BOOL:
                                yield self.check(
                                    EFilter(bag,
                                            ELambda(binder,
                                                    body)).with_type(bag.type),
                                    pool)
                            if body.type == INT:
                                yield self.check(
                                    EArgMin(bag, ELambda(
                                        binder, body)).with_type(bag.type.t),
                                    pool)
                                yield self.check(
                                    EArgMax(bag, ELambda(
                                        binder, body)).with_type(bag.type.t),
                                    pool)
                            if pool == RUNTIME_POOL and isinstance(
                                    body.type, TBag):
                                yield self.check(
                                    EFlatMap(bag,
                                             ELambda(binder, body)).with_type(
                                                 TBag(body.type.t)), pool)

        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in cache.find_collections(pool=STATE_POOL, size=sz1):
                if not is_scalar(bag.type.t):
                    continue
                for b in binders_by_type[bag.type.t]:
                    for val in cache.find(pool=STATE_POOL, size=sz2):
                        t = TMap(bag.type.t, val.type)
                        m = EMakeMap2(bag, ELambda(b, val)).with_type(t)
                        yield self.check(m, STATE_POOL)
コード例 #10
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = [],
        invariants: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).
    """

    if valid(
            syntax.EImplies(
                syntax.EAll(itertools.chain(assumptions, invariants)),
                syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + a, e, docstring)
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-",
                                     new_keys).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))), target_syntax.SMapDel(lval, k))

        # (2) enter/mod set
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(
                k,
                syntax.EAny([
                    syntax.ENot(syntax.EIn(k, old_keys)),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k),
                                   value_at(new_value, k)))
                ]))).with_type(key_bag)
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, new_or_modified),
                                   syntax.EEq(v, value_at(old_value, k))
                               ],
                               invariants=invariants)
        s2 = syntax.SForEach(
            k,
            make_subgoal(new_or_modified,
                         docstring="new or modified keys from {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
コード例 #11
0
    def enumerate_core(self, context: Context, size: int, pool: Pool) -> [Exp]:
        """
        Arguments:
            conext : a Context object describing the vars in scope
            size   : size to enumerate
            pool   : pool to enumerate

        Yields all expressions of the given size legal in the given context and
        pool.
        """

        if size < 0:
            return

        if size == 0:
            for (e, p) in LITERALS:
                if p == pool:
                    yield e
            for (v, p) in context.vars():
                if p == pool:
                    yield v
                for t in all_types(v):
                    yield construct_value(t)
            for (e, ctx, p) in self.hints:
                if p == pool and ctx.alpha_equivalent(context):
                    yield context.adapt(e, ctx)
                for t in all_types(e):
                    yield construct_value(t)
            return

        yield from self.heuristic_enumeration(context, size, pool)

        for e in collections(self.enumerate(context, size - 1, pool)):
            yield EEmptyList().with_type(e.type)
            if is_numeric(e.type.t):
                yield EUnaryOp(UOp.Sum, e).with_type(e.type.t)

        for e in self.enumerate(context, size - 1, pool):
            yield ESingleton(e).with_type(TBag(e.type))

        for e in self.enumerate(context, size - 1, pool):
            if isinstance(e.type, TRecord):
                for (f, t) in e.type.fields:
                    yield EGetField(e, f).with_type(t)

        for e in self.enumerate(context, size - 1, pool):
            if isinstance(e.type, THandle):
                yield EGetField(e, "val").with_type(e.type.value_type)

        for e in self.enumerate(context, size - 1, pool):
            if isinstance(e.type, TTuple):
                for n in range(len(e.type.ts)):
                    yield ETupleGet(e, n).with_type(e.type.ts[n])

        for e in of_type(self.enumerate(context, size - 1, pool), BOOL):
            yield EUnaryOp(UOp.Not, e).with_type(BOOL)

        for e in self.enumerate(context, size - 1, pool):
            if is_numeric(e.type):
                yield EUnaryOp("-", e).with_type(e.type)

        for m in self.enumerate(context, size - 1, pool):
            if isinstance(m.type, TMap):
                yield EMapKeys(m).with_type(TBag(m.type.k))

        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for a1 in self.enumerate(context, sz1, pool):
                t = a1.type
                if not is_numeric(t):
                    continue
                for a2 in of_type(self.enumerate(context, sz2, pool), t):
                    yield EBinOp(a1, "+", a2).with_type(t)
                    yield EBinOp(a1, "-", a2).with_type(t)
                    yield EBinOp(a1, ">", a2).with_type(BOOL)
                    yield EBinOp(a1, "<", a2).with_type(BOOL)
                    yield EBinOp(a1, ">=", a2).with_type(BOOL)
                    yield EBinOp(a1, "<=", a2).with_type(BOOL)
            for a1 in collections(self.enumerate(context, sz1, pool)):
                for a2 in of_type(self.enumerate(context, sz2, pool), a1.type):
                    yield EBinOp(a1, "+", a2).with_type(a1.type)
                    yield EBinOp(a1, "-", a2).with_type(a1.type)
                for a2 in of_type(self.enumerate(context, sz2, pool),
                                  a1.type.t):
                    yield EBinOp(a2, BOp.In, a1).with_type(BOOL)
            for a1 in of_type(self.enumerate(context, sz1, pool), BOOL):
                for a2 in of_type(self.enumerate(context, sz2, pool), BOOL):
                    yield EBinOp(a1, BOp.And, a2).with_type(BOOL)
                    yield EBinOp(a1, BOp.Or, a2).with_type(BOOL)
            for a1 in self.enumerate(context, sz1, pool):
                if not isinstance(a1.type, TMap):
                    for a2 in of_type(self.enumerate(context, sz2, pool),
                                      a1.type):
                        yield EEq(a1, a2)
                        yield EBinOp(a1, "!=", a2).with_type(BOOL)
            for m in self.enumerate(context, sz1, pool):
                if isinstance(m.type, TMap):
                    for k in of_type(self.enumerate(context, sz2, pool),
                                     m.type.k):
                        yield EMapGet(m, k).with_type(m.type.v)
                        yield EHasKey(m, k).with_type(BOOL)
            for l in self.enumerate(context, sz1, pool):
                if not isinstance(l.type, TList):
                    continue
                for i in of_type(self.enumerate(context, sz2, pool), INT):
                    yield EListGet(l, i).with_type(l.type.t)

        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            for cond in of_type(self.enumerate(context, sz1, pool), BOOL):
                for then_branch in self.enumerate(context, sz2, pool):
                    for else_branch in of_type(
                            self.enumerate(context, sz2, pool),
                            then_branch.type):
                        yield ECond(cond, then_branch,
                                    else_branch).with_type(then_branch.type)

            for l in self.enumerate(context, sz1, pool):
                if not isinstance(l.type, TList):
                    continue
                for st in of_type(self.enumerate(context, sz2, pool), INT):
                    for ed in of_type(self.enumerate(context, sz3, pool), INT):
                        yield EListSlice(l, st, ed).with_type(l.type)

        for bag in collections(self.enumerate(context, size - 1, pool)):
            # len of bag
            count = EUnaryOp(UOp.Length, bag).with_type(INT)
            yield count
            # empty?
            yield EUnaryOp(UOp.Empty, bag).with_type(BOOL)
            # exists?
            yield EUnaryOp(UOp.Exists, bag).with_type(BOOL)
            # singleton?
            yield EEq(count, ONE)

            yield EUnaryOp(UOp.The, bag).with_type(bag.type.t)
            yield EUnaryOp(UOp.Distinct, bag).with_type(bag.type)
            yield EUnaryOp(UOp.AreUnique, bag).with_type(BOOL)

            if bag.type.t == BOOL:
                yield EUnaryOp(UOp.Any, bag).with_type(BOOL)
                yield EUnaryOp(UOp.All, bag).with_type(BOOL)

        def build_lambdas(bag, pool, body_size):
            v = fresh_var(bag.type.t, omit=set(v for v, p in context.vars()))
            inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
            for lam_body in self.enumerate(inner_context, body_size, pool):
                yield ELambda(v, lam_body)

        # Iteration
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in collections(self.enumerate(context, sz1, pool)):
                for lam in build_lambdas(bag, pool, sz2):
                    body_type = lam.body.type
                    yield EMap(bag, lam).with_type(TBag(body_type))
                    if body_type == BOOL:
                        yield EFilter(bag, lam).with_type(bag.type)
                    if is_numeric(body_type):
                        yield EArgMin(bag, lam).with_type(bag.type.t)
                        yield EArgMax(bag, lam).with_type(bag.type.t)
                    if is_collection(body_type):
                        yield EFlatMap(bag, lam).with_type(TBag(body_type.t))

        # Enable use of a state-pool expression at runtime
        if pool == RUNTIME_POOL:
            for e in self.enumerate(context, size - 1, STATE_POOL):
                yield EStateVar(e).with_type(e.type)

        # Create maps
        if pool == STATE_POOL:
            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for bag in collections(self.enumerate(context, sz1,
                                                      STATE_POOL)):
                    if not is_scalar(bag.type.t):
                        continue
                    for lam in build_lambdas(bag, STATE_POOL, sz2):
                        t = TMap(bag.type.t, lam.body.type)
                        m = EMakeMap2(bag, lam).with_type(t)
                        yield m
コード例 #12
0
ファイル: state_maintenance.py プロジェクト: uwplse/cozy
def sketch_update(
        lval        : syntax.Exp,
        old_value   : syntax.Exp,
        new_value   : syntax.Exp,
        ctx         : [syntax.EVar],
        assumptions : [syntax.Exp] = [],
        invariants  : [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).
    """

    if valid(syntax.EImplies(
            syntax.EAll(itertools.chain(assumptions, invariants)),
            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + a, e, docstring)
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(syntax.EBinOp(old_value, "-", new_value).with_type(t), docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.elem_type)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.ts))])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.fields))])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        s1 = syntax.SForEach(k, make_subgoal(deleted_keys, docstring="keys removed from {}".format(pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) enter/mod set
        new_or_modified = target_syntax.EFilter(new_keys,
            syntax.ELambda(k, syntax.EAny([syntax.ENot(syntax.EIn(k, old_keys)), syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k)))]))).with_type(key_bag)
        update_value = recurse(
            v,
            value_at(old_value, k),
            value_at(new_value, k),
            ctx = ctx,
            assumptions = assumptions + [syntax.EIn(k, new_or_modified), syntax.EEq(v, value_at(old_value, k))],
            invariants = invariants)
        s2 = syntax.SForEach(k, make_subgoal(new_or_modified, docstring="new or modified keys from {}".format(pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(lval, make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
コード例 #13
0
    def _enumerate_core(self, context: Context, size: int,
                        pool: Pool) -> [Exp]:
        """Build new expressions of the given size.

        Arguments:
            context : a Context object describing the vars in scope
            size    : size of expressions to enumerate; each expression in
                      the output will have this size
            pool    : pool to enumerate

        This function is not cached.  Clients should call `enumerate` instead.

        This function tries to be a clean description of the Cozy grammar.  It
        does not concern itself with deduplication (which is handled
        efficiently by equivalence class deduplication).
        """

        if size < 0:
            return

        if size == 0:
            for e in LITERALS:
                yield e

            all_interesting_types = OrderedSet(self.hint_types)
            for v, _ in context.vars():
                all_interesting_types |= all_types(v.type)
            for t in all_interesting_types:
                l = construct_value(t)
                if l not in LITERALS:
                    yield l

            for (v, p) in context.vars():
                if p == pool:
                    yield v
            for (e, ctx, p) in self.hints:
                if p == pool:
                    fvs = free_vars(e)
                    if ctx.alpha_equivalent(context.generalize(fvs)):
                        yield context.adapt(e, ctx, e_fvs=fvs)
            return

        if not do_enumerate.value:
            return

        def build_lambdas(bag, pool, body_size):
            v = fresh_var(bag.type.elem_type,
                          omit=set(v for v, p in context.vars()))
            inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
            for lam_body in self.enumerate(inner_context, body_size, pool):
                yield ELambda(v, lam_body)

        # Load all smaller expressions in this context and pool.
        # cache[S] contains expressions of size S in this context and pool.
        cache = [list(self.enumerate(context, sz, pool)) for sz in range(size)]

        # Enable use of a state-pool expression at runtime
        if pool == RUNTIME_POOL:
            for e in self.enumerate(context.root(), size - 1, STATE_POOL):
                yield EStateVar(e).with_type(e.type)

        # Arity-1 expressions
        for e in cache[size - 1]:
            if is_collection(e.type):
                elem_type = e.type.elem_type

                # This method of generating EEmptyList() ensures that we visit
                # empty collections of all possible types.
                yield EEmptyList().with_type(e.type)

                if is_numeric(elem_type):
                    yield EUnaryOp(UOp.Sum, e).with_type(elem_type)

                yield EUnaryOp(UOp.Length, e).with_type(INT)
                yield EUnaryOp(UOp.Empty, e).with_type(BOOL)
                yield EUnaryOp(UOp.Exists, e).with_type(BOOL)
                yield EUnaryOp(UOp.The, e).with_type(elem_type)
                yield EUnaryOp(UOp.Distinct, e).with_type(e.type)
                yield EUnaryOp(UOp.AreUnique, e).with_type(BOOL)

                if elem_type == BOOL:
                    yield EUnaryOp(UOp.Any, e).with_type(BOOL)
                    yield EUnaryOp(UOp.All, e).with_type(BOOL)

            yield ESingleton(e).with_type(TBag(e.type))

            if isinstance(e.type, TRecord):
                for (f, t) in e.type.fields:
                    yield EGetField(e, f).with_type(t)

            if isinstance(e.type, THandle):
                yield EGetField(e, "val").with_type(e.type.value_type)

            if isinstance(e.type, TTuple):
                for n in range(len(e.type.ts)):
                    yield ETupleGet(e, n).with_type(e.type.ts[n])

            if e.type == BOOL:
                yield EUnaryOp(UOp.Not, e).with_type(BOOL)

            if is_numeric(e.type):
                yield EUnaryOp("-", e).with_type(e.type)

            if isinstance(e.type, TMap):
                yield EMapKeys(e).with_type(TBag(e.type.k))

        # Arity-2 expressions
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            # sz1 + sz2 = size - 1
            for e1 in cache[sz1]:
                t = e1.type

                if is_numeric(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", a2).with_type(t)
                        yield EBinOp(e1, "-", a2).with_type(t)

                if is_ordered(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, ">", a2).with_type(BOOL)
                        yield EBinOp(e1, "<", a2).with_type(BOOL)
                        yield EBinOp(e1, ">=", a2).with_type(BOOL)
                        yield EBinOp(e1, "<=", a2).with_type(BOOL)

                if t == BOOL:
                    for a2 in of_type(cache[sz2], BOOL):
                        yield EBinOp(e1, BOp.And, a2).with_type(BOOL)
                        yield EBinOp(e1, BOp.Or, a2).with_type(BOOL)
                        # Cozy supports the implication operator "=>", but this
                        # function does not enumerate it because
                        #  - (a => b) is equivalent to ((not a) or b)
                        #  - there isn't an implication operator in any of our
                        #    current target languages, so we would need to
                        #    desugar it to ((not a) or b) anyway.

                if not isinstance(t, TMap):
                    for a2 in of_type(cache[sz2], t):
                        yield EEq(e1, a2)
                        yield EBinOp(e1, "!=", a2).with_type(BOOL)

                if isinstance(t, TMap):
                    for k in of_type(cache[sz2], t.k):
                        yield EMapGet(e1, k).with_type(t.v)
                        yield EHasKey(e1, k).with_type(BOOL)

                if isinstance(t, TList):
                    for i in of_type(cache[sz2], INT):
                        yield EListGet(e1, i).with_type(e1.type.elem_type)

                if is_collection(t):
                    elem_type = t.elem_type
                    for e2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", e2).with_type(t)
                        yield EBinOp(e1, "-", e2).with_type(t)
                    for e2 in of_type(cache[sz2], elem_type):
                        yield EBinOp(e2, BOp.In, e1).with_type(BOOL)
                    for f in build_lambdas(e1, pool, sz2):
                        body_type = f.body.type
                        yield EMap(e1, f).with_type(TBag(body_type))
                        if body_type == BOOL:
                            yield EFilter(e1, f).with_type(t)
                        if is_numeric(body_type):
                            yield EArgMin(e1, f).with_type(elem_type)
                            yield EArgMax(e1, f).with_type(elem_type)
                        if is_collection(body_type):
                            yield EFlatMap(e1, f).with_type(
                                TBag(body_type.elem_type))

                        if pool == STATE_POOL and is_hashable(elem_type):
                            yield EMakeMap2(e1, f).with_type(
                                TMap(elem_type, body_type))

                e1_singleton = ESingleton(e1).with_type(TBag(e1.type))
                for f in build_lambdas(e1_singleton, pool, sz2):
                    yield ELet(e1, f).with_type(f.body.type)

        # Arity-3 expressions
        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            # sz1 + sz2 + sz3 = size - 1
            for e1 in cache[sz1]:
                if e1.type == BOOL:
                    cond = e1
                    for then_branch in cache[sz2]:
                        for else_branch in of_type(cache[sz3],
                                                   then_branch.type):
                            yield ECond(cond, then_branch,
                                        else_branch).with_type(
                                            then_branch.type)
                if isinstance(e1.type, TList):
                    for start in of_type(cache[sz2], INT):
                        for end in of_type(cache[sz3], INT):
                            yield EListSlice(e1, start, end).with_type(e1.type)
                            # It is not necessary to create slice expressions of
                            # the form a[:i] or a[i:].  Those are desugared
                            # after parsing to a[0:i] and a[i:len(a)]
                            # respectively, and Cozy is perfectly capable of
                            # discovering these expanded forms as well.

        for h in all_extension_handlers():
            yield from h.enumerate(context, size, pool, self.enumerate,
                                   build_lambdas)
コード例 #14
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).
    """

    if valid(
            syntax.EImplies(syntax.EAll(assumptions),
                            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []

    def make_subgoal(e, a=[], docstring=None):
        e = strip_EStateVar(e)
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + a, e, docstring)
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    # elif isinstance(t, syntax.TList):
    #     raise NotImplementedError()
    elif is_numeric(t):
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.THandle):
        # handles are tricky, and are dealt with at a higher level
        stm = syntax.SNoOp()
    elif isinstance(t, syntax.TMap):
        value_at = lambda m, k: target_syntax.EMapGet(m, k).with_type(lval.type
                                                                      .v)
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        if True:
            old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
            new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

            # (1) exit set
            deleted_keys = target_syntax.EFilter(
                old_keys,
                target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                    k, new_keys)))).with_type(key_bag)
            s1 = syntax.SForEach(
                k,
                make_subgoal(deleted_keys,
                             docstring="keys removed from {}".format(
                                 pprint(lval))),
                target_syntax.SMapDel(lval, k))

            # (2) modify set
            common_keys = target_syntax.EFilter(
                old_keys,
                target_syntax.ELambda(k,
                                      syntax.EIn(k,
                                                 new_keys))).with_type(key_bag)
            update_value = recurse(v,
                                   value_at(old_value, k),
                                   value_at(new_value, k),
                                   ctx=ctx,
                                   assumptions=assumptions + [
                                       syntax.EIn(k, common_keys),
                                       syntax.ENot(
                                           syntax.EEq(value_at(old_value, k),
                                                      value_at(new_value, k)))
                                   ])
            altered_keys = target_syntax.EFilter(
                common_keys,
                target_syntax.ELambda(
                    k,
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k),
                                   value_at(new_value,
                                            k))))).with_type(key_bag)
            s2 = syntax.SForEach(
                k,
                make_subgoal(altered_keys,
                             docstring="altered keys in {}".format(
                                 pprint(lval))),
                target_syntax.SMapUpdate(lval, k, v, update_value))

            # (3) enter set
            fresh_keys = target_syntax.EFilter(
                new_keys,
                target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                    k, old_keys)))).with_type(key_bag)
            s3 = syntax.SForEach(
                k,
                make_subgoal(fresh_keys,
                             docstring="new keys in {}".format(pprint(lval))),
                target_syntax.SMapPut(
                    lval, k,
                    make_subgoal(value_at(new_value, k),
                                 a=[syntax.EIn(k, fresh_keys)],
                                 docstring="new value inserted at {}".format(
                                     pprint(target_syntax.EMapGet(lval, k))))))

            stm = syntax.seq([s1, s2, s3])

        else:
            # update_value = code to update for value v at key k (given that k is an altered key)
            update_value = recurse(v,
                                   value_at(old_value, k),
                                   value_at(new_value, k),
                                   ctx=ctx,
                                   assumptions=assumptions + [
                                       syntax.ENot(
                                           syntax.EEq(value_at(old_value, k),
                                                      value_at(new_value, k)))
                                   ])

            # altered_keys = [k | k <- distinct(lval.keys() + new_value.keys()), value_at(old_value, k) != value_at(new_value, k))]
            altered_keys = make_subgoal(
                target_syntax.EFilter(
                    syntax.EUnaryOp(
                        syntax.UOp.Distinct,
                        syntax.EBinOp(
                            target_syntax.EMapKeys(old_value).with_type(
                                key_bag), "+",
                            target_syntax.EMapKeys(new_value).with_type(
                                key_bag)).with_type(key_bag)).with_type(
                                    key_bag),
                    target_syntax.ELambda(
                        k,
                        syntax.ENot(
                            syntax.EEq(value_at(old_value, k),
                                       value_at(new_value,
                                                k))))).with_type(key_bag))
            stm = syntax.SForEach(
                k, altered_keys,
                target_syntax.SMapUpdate(lval, k, v, update_value))
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)