예제 #1
0
파일: optimization.py 프로젝트: uwplse/cozy
def simplify_and_optimize_lvalue(e : Exp) -> (Stm, Exp):
    """Optimize an L-value expression without destroying its L-value-ness.

    Helper for `simplify_and_optimize`.

    Input:
      e - an L-value expression to optimize

    Output:
      A pair (s, e') such that executing s and then evaluating e' is the same
      as evaluating the input expression.

    In contrast to `simplify_and_optimize_expression`, the returned expression
    still denotes the same storage location as `e`; in particular `e` is never
    swapped out for a fresh variable.

    Raises NotImplementedError for L-value forms other than variables, field
    accesses, map lookups, and list lookups.
    """
    assert is_lvalue(e), "not an L-value: {}".format(pprint(e))

    if isinstance(e, EVar):
        # A plain variable needs no setup and is already as simple as it gets.
        return (SNoOp(), e)

    if isinstance(e, EGetField):
        # When the field's owner has handle type it is optimized as an
        # ordinary expression; any other owner must itself stay an L-value.
        if isinstance(e.e.type, THandle):
            optimize_owner = simplify_and_optimize_expression
        else:
            optimize_owner = simplify_and_optimize_lvalue
        owner_setup, owner = optimize_owner(e.e)
        return (owner_setup, EGetField(owner, e.field_name).with_type(e.type))

    if isinstance(e, EMapGet):
        map_setup, map_lvalue = simplify_and_optimize_lvalue(e.map)
        key_setup, key_exp = simplify_and_optimize_expression(e.key)
        setup = seq([map_setup, key_setup])
        return (setup, EMapGet(map_lvalue, key_exp).with_type(e.type))

    if isinstance(e, EListGet):
        list_setup, list_lvalue = simplify_and_optimize_lvalue(e.e)
        index_setup, index_exp = simplify_and_optimize_expression(e.index)
        setup = seq([list_setup, index_setup])
        return (setup, EListGet(list_lvalue, index_exp).with_type(e.type))

    raise NotImplementedError(repr(e))
예제 #2
0
 def find_one(self, iterable):
     """Emit code that picks one element of `iterable`; return the result variable.

     The generated statements declare a result variable initialized with
     `evaluation.construct_value` for the element type, then loop over
     `iterable`, storing the first element visited and immediately escaping
     the loop.  The optimized statements are appended to `self.stms`.
     """
     elem_type = iterable.type.elem_type
     result = fresh_var(elem_type, "v")
     exit_label = fresh_name("label")
     elem = fresh_var(elem_type, "x")

     # If the loop body never runs, `result` keeps this initial value.
     init = SDecl(result, evaluation.construct_value(result.type))
     take_first = seq([
         SAssign(result, elem),
         SEscapeBlock(exit_label)])
     search = SEscapableBlock(exit_label, SForEach(elem, iterable, take_first))

     self.stms.append(simplify_and_optimize(seq([init, search])))
     return result
예제 #3
0
    def visit_EUnaryOp(self, e):
        """Compile a unary operator expression into simpler statements.

        Depending on `e.op`, this either delegates to another visitor method,
        rewrites the expression in terms of a simpler operator and visits the
        rewritten form, or appends an explicit loop to `self.stms` and returns
        the variable that holds the loop's result.  Operators without a
        dedicated case fall through to `self.visit_Exp`.
        """
        op = e.op

        if op == UOp.Distinct:
            return self.visit_iterable(e)

        if op == UOp.The:
            return self.find_one(e.e)

        if op == UOp.Sum:
            # total = 0; for x in e.e: total = total + x
            total = fresh_var(e.type, "sum")
            elem = fresh_var(e.e.type.elem_type, "x")
            init = SDecl(total, ENum(0).with_type(e.type))
            accumulate = SForEach(elem, e.e,
                SAssign(total, EBinOp(total, "+", elem).with_type(INT)))
            self.stms.append(simplify_and_optimize(seq([init, accumulate])))
            return total

        if op == UOp.Length:
            # Length is rewritten as the sum of a bag of ones.
            arg = EVar("x").with_type(e.e.type.elem_type)
            ones = EMap(e.e, ELambda(arg, ONE)).with_type(INT_BAG)
            return self.visit(EUnaryOp(UOp.Sum, ones).with_type(INT))

        if op == UOp.All:
            # All is rewritten as "the collection of false elements is empty".
            arg = EVar("x").with_type(e.e.type.elem_type)
            false_elems = EFilter(e.e, ELambda(arg, ENot(arg))).with_type(INT_BAG)
            return self.visit(EUnaryOp(UOp.Empty, false_elems).with_type(INT))

        if op == UOp.Any:
            # Any is rewritten as "the collection of true elements is nonempty".
            arg = EVar("x").with_type(e.e.type.elem_type)
            true_elems = EFilter(e.e, ELambda(arg, arg)).with_type(INT_BAG)
            return self.visit(EUnaryOp(UOp.Exists, true_elems).with_type(INT))

        if op == UOp.Empty:
            # Assume empty, then flip the flag and escape on the first element.
            collection = e.e
            is_empty = fresh_var(BOOL, "v")
            exit_label = fresh_name("label")
            elem = fresh_var(collection.type.elem_type, "x")
            init = SDecl(is_empty, ETRUE)
            on_first_elem = seq([
                SAssign(is_empty, EFALSE),
                SEscapeBlock(exit_label)])
            probe = SEscapableBlock(exit_label,
                SForEach(elem, collection, on_first_elem))
            self.stms.append(simplify_and_optimize(seq([init, probe])))
            return is_empty

        if op == UOp.Exists:
            return self.visit(ENot(EUnaryOp(UOp.Empty, e.e).with_type(BOOL)))

        # NOTE: UOp.AreUnique has no case here (an earlier attempt was disabled),
        # so it takes the generic path below like any other unhandled operator.
        return self.visit_Exp(e)
예제 #4
0
 def visit_EArrayIndexOf(self, e):
     """Emit a `std::find`-based search for `e.x` in the array `e.a`.

     Only arrays that are a variable, or a tuple component of a variable, are
     accepted; anything else raises NotImplementedError.  Two declarations
     are visited: an iterator holding the `std::find` result, and an integer
     that is -1 when the iterator equals `end()` and the iterator's offset
     from `begin()` otherwise.  Returns the id of that integer variable.
     """
     array = e.a
     is_var = isinstance(array, EVar)
     is_var_component = isinstance(array, ETupleGet) and isinstance(array.e, EVar)
     if not (is_var or is_var_component):
         # TODO: make this fast when this is false
         raise NotImplementedError("finding index of non-var array")

     iterator_type = TNative(
         "{}::const_iterator".format(self.visit(array.type, "").strip()))
     cursor = self.fv(iterator_type, "cursor")
     index = self.fv(INT, "index")

     find_call = EEscape(
         "std::find({a}.begin(), {a}.end(), {x})",
         ("a", "x"),
         (array, e.x)).with_type(cursor.type)
     end_call = EEscape("{a}.end()", ("a", ), (array, )).with_type(cursor.type)
     not_found = ENum(-1).with_type(INT)
     offset = EEscape(
         "({it} - {a}.begin())",
         ("it", "a", ),
         (cursor, array, )).with_type(INT)
     index_value = ECond(EEq(cursor, end_call), not_found, offset).with_type(INT)

     self.visit(seq([
         SDecl(cursor, find_call),
         SDecl(index, index_value)]))
     return index.id
예제 #5
0
 def visit_EBinOp(self, e):
     """Compile a binary operator expression.

     "+" and "-" on collections are handed to `visit_iterable`.  Membership
     tests whose right-hand side is not a set become an explicit linear
     search appended to `self.stms`, and the variable holding the answer is
     returned.  Everything else goes to `visit_Exp`.
     """
     if e.op in ("+", "-") and is_collection(e.type):
         return self.visit_iterable(e)

     if e.op == BOp.In and not isinstance(e.e2.type, TSet):
         found = fresh_var(BOOL, "found")
         elem = fresh_var(e.e1.type, "x")
         exit_label = fresh_name("label")

         # found = false; scan e.e2 and escape as soon as a match is seen.
         init = SDecl(found, EFALSE)
         on_match = seq([SAssign(found, ETRUE), SEscapeBlock(exit_label)])
         check = SIf(
             EBinOp(elem, "==", e.e1).with_type(BOOL),
             on_match,
             SNoOp())
         scan = SEscapableBlock(exit_label, SForEach(elem, e.e2, check))

         self.stms.append(simplify_and_optimize(seq([init, scan])))
         return found

     return self.visit_Exp(e)
예제 #6
0
파일: cxx.py 프로젝트: uwplse/cozy
 def compute_hash(self, fields : [Exp]) -> Stm:
     """Build a statement that hashes `fields` and returns the combined code.

     Two integer variables are declared and zeroed: a running hash and a
     scratch slot.  For each field, `compute_hash_1` is asked for a statement
     targeting the scratch slot, after which the running hash is updated to
     `(running * 31) ^ (scratch)`.  The statement ends with a `return` of
     the running hash.
     """
     running = self.fv(INT, "hash_code")
     scratch = self.fv(INT, "hash_code")

     stm = SSeq(
         SDecl(running, ENum(0).with_type(running.type)),
         SDecl(scratch, ENum(0).with_type(scratch.type)))

     for field in fields:
         hash_field = self.compute_hash_1(scratch, field)
         mix = SAssign(
             running,
             EEscape("({hc} * 31) ^ ({h})", ("hc", "h"), (running, scratch)).with_type(INT))
         stm = seq([stm, hash_field, mix])

     return SSeq(stm, SEscape("{indent}return {e};\n", ("e",), (running,)))
예제 #7
0
 def min_or_max(self, op, e, f):
     """Compile an arg-min/arg-max over collection `e` with key function `f`.

     `op` is the comparison applied between keys; the result variable is
     named "min" when `op` is "<" and "max" otherwise.  A sum of exactly two
     singletons is compiled to a single conditional expression.  In every
     other case an explicit loop is appended to `self.stms` and the variable
     holding the winner is returned.
     """
     is_two_singletons = (
         isinstance(e, EBinOp)
         and e.op == "+"
         and isinstance(e.e1, ESingleton)
         and isinstance(e.e2, ESingleton))
     if is_two_singletons:
         # argmin_f ([a] + [b]) ---> f(a) < f(b) ? a : b
         a = e.e1.e
         b = e.e2.e
         a_wins = EBinOp(f.apply_to(a), op, f.apply_to(b)).with_type(BOOL)
         return self.visit(ECond(a_wins, a, b).with_type(a.type))

     best = fresh_var(e.type.elem_type, "min" if op == "<" else "max")
     is_first = fresh_var(BOOL, "first")
     elem = fresh_var(e.type.elem_type, "x")

     init_best = SDecl(best, evaluation.construct_value(best.type))
     init_first = SDecl(is_first, ETRUE)

     # Take the element if it is the first one seen or it beats the current best.
     beats_best = EBinOp(f.apply_to(elem), op, f.apply_to(best)).with_type(BOOL)
     should_take = EBinOp(is_first, BOp.Or, beats_best).with_type(BOOL)
     take = seq([SAssign(is_first, EFALSE), SAssign(best, elem)])
     scan = SForEach(elem, e, SIf(should_take, take, SNoOp()))

     self.stms.append(simplify_and_optimize(seq([init_best, init_first, scan])))
     return best
예제 #8
0
def simplify_and_optimize_expression(e : Exp) -> (Stm, Exp):
    """Optimize a single expression.

    Helper for `simplify_and_optimize`.

    Input:
      e - an expression to optimize

    Output:
      A pair (s, e') such that executing s and then evaluating e' is the same
      as evaluating the input expression.  `s` is the sequence of statements
      that a fresh `ExpressionOptimizer` accumulated while visiting `e`.
    """
    visitor = ExpressionOptimizer()
    optimized = visitor.visit(e)
    # Read the statement list only after the visit has populated it.
    setup = seq(visitor.stms)
    return (setup, optimized)
예제 #9
0
파일: cxx.py 프로젝트: uwplse/cozy
 def visit_EArrayIndexOf(self, e):
     """Emit a `std::find`-based lookup of `e.x` inside the array `e.a`.

     The array must be a variable or a tuple component of a variable;
     otherwise NotImplementedError is raised.  The visited statements declare
     an iterator initialized from `std::find`, then an integer that is -1 if
     the iterator equals `end()` and the distance from `begin()` otherwise.
     The id of the integer variable is returned.
     """
     arr = e.a
     supported = isinstance(arr, EVar) or (
         isinstance(arr, ETupleGet) and isinstance(arr.e, EVar))
     if not supported:
         raise NotImplementedError("finding index of non-var array") # TODO: make this fast when this is false

     elem_iter_type = TNative("{}::const_iterator".format(self.visit(arr.type, "").strip()))
     it = self.fv(elem_iter_type, "cursor")
     res = self.fv(INT, "index")

     found_at = EEscape("std::find({a}.begin(), {a}.end(), {x})", ("a", "x"), (arr, e.x)).with_type(it.type)
     past_end = EEscape("{a}.end()", ("a",), (arr,)).with_type(it.type)
     missing = ENum(-1).with_type(INT)
     distance = EEscape("({it} - {a}.begin())", ("it", "a",), (it, arr,)).with_type(INT)

     declare_cursor = SDecl(it, found_at)
     declare_index = SDecl(res, ECond(EEq(it, past_end), missing, distance).with_type(INT))
     self.visit(seq([declare_cursor, declare_index]))
     return res.id
예제 #10
0
def histogram(e : Exp) -> (Stm, EVar):
    """Build code that counts occurrences of each element of the iterable `e`.

    Returns a pair: an unoptimized statement that declares an empty map from
    element to INT and then, for every element of `e`, increments that
    element's entry by one; and the fresh variable naming the map.
    """
    key_type = e.type.elem_type
    hist = fresh_var(TMap(key_type, INT), "histogram")
    elem = fresh_var(key_type, "x")
    tally = fresh_var(INT, "count")

    declare = SDecl(hist, EEmptyMap().with_type(hist.type))
    increment = SAssign(tally, EBinOp(tally, "+", ONE).with_type(INT))
    fill = SForEach(elem, e, SMapUpdate(hist, elem, tally, increment))

    return (seq([declare, fill]), hist)
예제 #11
0
 def compute_hash(self, fields: [Exp]) -> Stm:
     """Return a statement that computes and returns a hash over `fields`.

     The statement declares an accumulator and a per-field temporary, both
     starting at zero.  Each field contributes the statement produced by
     `compute_hash_1` for the temporary, followed by an assignment folding
     the temporary into the accumulator as `(acc * 31) ^ (tmp)`.  A final
     escape statement returns the accumulator.
     """
     acc = self.fv(INT, "hash_code")
     tmp = self.fv(INT, "hash_code")

     declare_acc = SDecl(acc, ENum(0).with_type(acc.type))
     declare_tmp = SDecl(tmp, ENum(0).with_type(tmp.type))
     body = SSeq(declare_acc, declare_tmp)

     for f in fields:
         per_field = self.compute_hash_1(tmp, f)
         fold = SAssign(
             acc,
             EEscape("({hc} * 31) ^ ({h})", ("hc", "h"), (acc, tmp)).with_type(INT))
         body = seq([body, per_field, fold])

     body = SSeq(body, SEscape("{indent}return {e};\n", ("e", ), (acc, )))
     return body
예제 #12
0
def sketch_update(
        lval        : syntax.Exp,
        old_value   : syntax.Exp,
        new_value   : syntax.Exp,
        ctx         : [syntax.EVar],
        assumptions : [syntax.Exp] = None,
        invariants  : [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    `assumptions` and `invariants` may be omitted (or None), in which case
    they are treated as empty.  Any iterable of expressions is accepted; the
    caller's objects are never mutated.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    The update strategy is chosen by the type of `lval`:
      - bags and sets: remove the deleted elements, then add the new ones
      - numbers (when `update_numbers_with_deltas` is on): add a delta
      - tuples and records: update each component recursively
      - maps: delete removed keys, then update new or modified keys
      - anything else: assign a freshly computed value
    """

    # Copy into fresh lists: avoids shared mutable defaults and lets the
    # `+` concatenations below work for any iterable the caller passed.
    assumptions = [] if assumptions is None else list(assumptions)
    invariants = [] if invariants is None else list(invariants)

    # Nothing to do if the value provably does not change.
    if valid(syntax.EImplies(
            syntax.EAll(itertools.chain(assumptions, invariants)),
            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=None, docstring=None):
        """Turn `e` into a call to a new internal query, recording the query.

        If `skip_stateless_synthesis` is on and `e` mentions no variable in
        `ctx`, `e` is returned unchanged and no query is created.
        """
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        extra_assumptions = [] if a is None else list(a)
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + extra_assumptions, e, docstring)
        # Everything the query mentions outside of `ctx` becomes an argument.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        """Run `sketch_update` on a sub-problem, collecting its subgoals here."""
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(syntax.EBinOp(old_value, "-", new_value).with_type(t), docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.elem_type)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        def get(val, i):
            return syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.ts))])
    elif isinstance(t, syntax.TRecord):
        def get(val, i):
            return syntax.EGetField(val, t.fields[i][0]).with_type(t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.fields))])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        s1 = syntax.SForEach(k, make_subgoal(deleted_keys, docstring="keys removed from {}".format(pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) enter/mod set
        new_or_modified = target_syntax.EFilter(new_keys,
            syntax.ELambda(k, syntax.EAny([syntax.ENot(syntax.EIn(k, old_keys)), syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k)))]))).with_type(key_bag)
        # The per-key update may assume that k is new or modified and that v
        # currently holds the old value stored at k.
        update_value = recurse(
            v,
            value_at(old_value, k),
            value_at(new_value, k),
            ctx = ctx,
            assumptions = assumptions + [syntax.EIn(k, new_or_modified), syntax.EEq(v, value_at(old_value, k))],
            invariants = invariants)
        s2 = syntax.SForEach(k, make_subgoal(new_or_modified, docstring="new or modified keys from {}".format(pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(lval, make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
예제 #13
0
파일: java.py 프로젝트: wcphkust/cozy
 def visit_SSwap(self, s):
     """Compile a swap of `s.lval1` and `s.lval2` through a temporary.

     Visits the three-statement sequence: save `lval1` in a fresh variable,
     overwrite `lval1` with `lval2`, then overwrite `lval2` with the saved
     value.  Returns whatever `self.visit` returns for that sequence.
     """
     saved = self.fv(s.lval1.type, "swap_tmp")
     save_first = SDecl(saved, s.lval1)
     first_from_second = SAssign(s.lval1, s.lval2)
     second_from_saved = SAssign(s.lval2, saved)
     return self.visit(seq([save_first, first_from_second, second_from_saved]))
예제 #14
0
파일: impls.py 프로젝트: uwplse/cozy
    def code(self) -> Spec:
        """Get the current code corresponding to this implementation.

        The code is returned as a Cozy specification object, but the returned
        object throws away any unused abstract state as well as all invariants
        and assumptions on methods. It implements the same data structure, but
        probably more efficiently.

        For every operator, the per-state-variable update statements are
        ordered so that (where possible) a statement that reads a state
        variable runs before the statement that writes it.  Query calls that
        would still observe already-updated state are lifted into temporaries
        declared at the top of the operator.
        """

        state_read_by_query = {
            query_name : free_vars(query)
            for query_name, query in self.query_impls.items() }

        # prevent read-after-write by lifting reads before writes.

        # operator name -> list of SDecls to emit before that operator's updates
        temps = defaultdict(list)
        updates = dict(self.updates)

        # operator name -> order in which state variables are updated.  Each
        # operator gets its own order, and that operator's statements must be
        # emitted in exactly that order, because the temporaries below are
        # lifted relative to it.
        update_order = {}

        state_vars = [v for v, e in self._concretization_functions]

        # Hoisted: does not change while this method runs.
        query_spec_names = {q.name for q in self.query_specs}

        for operator in self.op_specs:

            # Compute order constraints between statements:
            #   v1 -> v2 means that the update code for v1 should (if possible)
            #   appear before the update code for v2
            #   (i.e. the update code for v1 reads v2)
            def state_used_during_update(v1 : EVar) -> [EVar]:
                v1_update_code = self.updates[(v1, operator.name)]
                v1_queries = list(self.queries_used_by(v1_update_code))
                res = OrderedSet()
                for q in v1_queries:
                    res |= state_read_by_query[q]
                return res
            g = DirectedGraph(
                nodes=state_vars,
                successors=state_used_during_update)

            # Find the minimum set of edges we need to break cycles (see
            # "feedback arc set problem")
            edges_to_break = g.minimum_feedback_arc_set()
            g.delete_edges(edges_to_break)
            # As before, this order also seeds the next operator's graph.
            state_vars = list(g.toposort())
            update_order[operator.name] = state_vars

            # Lift auxiliary declarations as needed
            things_updated = []
            for v in state_vars:
                things_updated.append(v)
                stm = updates[(v, operator.name)]
                def problematic(e):
                    # True if `e` calls a query that reads state which has
                    # already been updated at this point in the operator.
                    for x in all_exps(e):
                        if isinstance(x, ECall) and x.func in query_spec_names:
                            problems = set(things_updated) & state_read_by_query[x.func]
                            if problems:
                                return True
                    return False
                stm = pull_temps(stm,
                    decls_out=temps[operator.name],
                    exp_is_bad=problematic)
                updates[(v, operator.name)] = stm

        # construct new op implementations
        new_ops = []
        for op in self.op_specs:

            # Use the order computed for THIS operator, not whichever order
            # the last loop iteration above happened to leave behind.
            stms = [ updates[(v, op.name)] for v in update_order[op.name] ]
            stms.extend(hup for ((t, op_name), hup) in self.handle_updates.items() if op.name == op_name)
            new_stms = seq(temps[op.name] + stms)
            new_ops.append(Op(
                op.name,
                op.args,
                [],
                new_stms,
                op.docstring))

        # assemble final result
        return Spec(
            self.spec.name,
            self.spec.types,
            self.spec.extern_funcs,
            [(v.id, e.type) for (v, e) in self._concretization_functions],
            [],
            list(self.query_impls.values()) + new_ops,
            self.spec.header,
            self.spec.footer,
            self.spec.docstring)
예제 #15
0
def re_use(value : Exp, v : EVar, s : Stm) -> Stm:
    """Make `v` stand for `value` inside statement `s`.

    If `value` is efficiently reusable, or `v` occurs free in `s` at most
    once, `value` is substituted for `v` directly.  Otherwise `v` is declared
    with `value` ahead of `s`.
    """
    can_inline = efficiently_reuseable(value) or count_occurrences_of_free_var(s, v) <= 1
    if not can_inline:
        return seq([SDecl(v, value), s])
    return subst(s, {v.id : value})
예제 #16
0
def stream(iterable : Exp, loop_var : EVar, body : Stm) -> Stm:
    """Convert an iterable expression to a streaming operation.

    Input:
      iterable - an expression with an iterable type (Bag, Set, or List), not
        yet optimized
      loop_var - a variable to use as the loop variable
      body - a statement to run on that variable, not yet optimized

    Output:
      A statement equivalent to
        for (loop_var in iterable) { body }
      that eliminates as many intermediate collections and objects as possible.

    NOTE: The output of this function will not be correct if the body modifies
    any free variable in the iterable expression or writes to any pointers
    that are read by the iterable expression.

    Generating code for the expression

        Map {func} (Filter {predicate} big_collection)

    might create two new collections as large as `big_collection`: one to hold
    the result of the filter and one to hold the result of the map.  If all the
    code needs to do is to iterate over the result, then there is no reason to
    make the two new collections.

    This function is mutually recursive with `simplify_and_optimize`, so any
    transformations performed by that method are also applied to the output of
    this one.
    """

    if isinstance(iterable, EEmptyList):
        # Nothing to iterate over.
        return SNoOp()
    elif isinstance(iterable, ESingleton):
        # One element: run the body once with loop_var standing for it.
        setup, value = simplify_and_optimize_expression(iterable.e)
        # SScoped because if the iterable is e.g. [x] + [y], then the body
        # might appear in the same block twice.  If the body declares any
        # variables, that will cause problems in languages like Java or C++.
        return seq([setup, SScoped(re_use(value, loop_var, simplify_and_optimize(body)))])
    elif isinstance(iterable, ECond):
        # Branch once on the condition, then stream whichever side is chosen.
        cond_setup, cond = simplify_and_optimize_expression(iterable.cond)
        return seq([
            cond_setup,
            SIf(cond,
                stream(iterable.then_branch, loop_var, body),
                stream(iterable.else_branch, loop_var, body))])
    elif isinstance(iterable, EUnaryOp) and iterable.op == UOp.Distinct:
        # Remember the elements seen so far; run the body only on new ones.
        tmp = fresh_var(TSet(iterable.type.elem_type), "distinct_elems")
        return seq([
            SDecl(tmp, EEmptyList().with_type(tmp.type)),
            stream(iterable.e, loop_var, SIf(
                ENot(EBinOp(loop_var, BOp.In, tmp).with_type(BOOL)),
                seq([body, SCall(tmp, "add", [loop_var])]),
                SNoOp()))])
    elif isinstance(iterable, EBinOp) and iterable.op == "+":
        # Concatenation: stream the left operand, then the right operand.
        return seq([
            stream(iterable.e1, loop_var, body),
            stream(iterable.e2, loop_var, body)])
    elif isinstance(iterable, EBinOp) and iterable.op == "-":
        if is_hashable(iterable.type.elem_type):
            # Count the right operand's elements, then walk the left operand:
            # an element with a positive remaining count consumes one count
            # instead of running the body.
            h_setup, h = histogram(iterable.e2)
            val_ref = fresh_var(INT, "count")
            return seq([
                simplify_and_optimize(h_setup),
                stream(
                    iterable.e1,
                    loop_var,
                    SIf(EGt(EMapGet(h, loop_var).with_type(INT), ZERO),
                        SMapUpdate(h, loop_var, val_ref, SAssign(val_ref, EBinOp(val_ref, "-", ONE).with_type(INT))),
                        body))])
        else:
            # Same idea without a histogram: keep the right operand in a
            # variable and remove matched elements from it directly.
            rhs = fresh_var(iterable.e2.type, "bag_subtraction_right")
            return seq([
                simplify_and_optimize(SDecl(rhs, iterable.e2)),
                stream(
                    iterable.e1,
                    loop_var,
                    SIf(EIn(loop_var, rhs),
                        SCall(rhs, "remove", (loop_var,)),
                        body))])
    elif isinstance(iterable, EFilter):
        # Filter p xs  ==>  FlatMap (\x -> p(x) ? [x] : []) xs
        return stream(
            EFlatMap(iterable.e, ELambda(iterable.predicate.arg,
                ECond(iterable.predicate.body,
                    ESingleton(iterable.predicate.arg).with_type(iterable.type),
                    EEmptyList().with_type(iterable.type)).with_type(iterable.type))).with_type(iterable.type),
            loop_var,
            body)
    elif isinstance(iterable, EMap):
        # Map f xs  ==>  FlatMap (\x -> [f(x)]) xs
        return stream(
            EFlatMap(iterable.e, ELambda(iterable.transform_function.arg,
                ESingleton(iterable.transform_function.body).with_type(iterable.type))).with_type(iterable.type),
            loop_var,
            body)
    elif isinstance(iterable, EFlatMap):
        # Nested streaming: for each element of the source, stream the
        # collection produced by the transform function.
        inner_loop_var = fresh_var(
            iterable.transform_function.arg.type,
            iterable.transform_function.arg.id)
        return stream(
            iterable.e,
            inner_loop_var,
            stream(iterable.transform_function.apply_to(inner_loop_var), loop_var, body))
    elif isinstance(iterable, EListSlice):
        # Index-based while loop over the slice bounds, with the start
        # clamped to at least zero and the end to at most the list length.
        elem_type = iterable.type.elem_type
        l = fresh_var(iterable.e.type, "list")
        s = fresh_var(INT, "start")
        e = fresh_var(INT, "end")
        return simplify_and_optimize(seq([
            SDecl(l, iterable.e),
            SDecl(s, max_of(iterable.start, ZERO)),
            SDecl(e, min_of(iterable.end, ELen(l))),
            SWhile(ELt(s, e), seq([
                SDecl(loop_var, EListGet(l, s).with_type(elem_type)),
                body,
                SAssign(s, EBinOp(s, "+", ONE).with_type(INT))]))]))
    elif isinstance(iterable, ELet):
        # Bind the let-value to a fresh variable, then stream the let-body.
        v = fresh_var(
            iterable.body_function.arg.type,
            iterable.body_function.arg.id)
        return seq([
            simplify_and_optimize(SDecl(v, iterable.e)),
            stream(iterable.body_function.apply_to(v), loop_var, body)])
    elif isinstance(iterable, EMove):
        # Streaming looks through the move wrapper.
        return stream(iterable.e, loop_var, body)
    else:
        # No special handling: optimize the iterable and emit an ordinary loop.
        assert is_collection(iterable.type), repr(iterable)
        setup, e = simplify_and_optimize_expression(iterable)
        return seq([setup, SForEach(loop_var, e, simplify_and_optimize(body))])
예제 #17
0
 def visit_ECond(self, e):
     """Compile a conditional expression into an if-statement.

     Appends to `self.stms` the optimized form of: declare a result variable
     initialized with `evaluation.construct_value` for the expression's type,
     then assign it from the then-branch or the else-branch depending on the
     condition.  Returns the result variable.
     """
     result = fresh_var(e.type, "conditional_result")
     declare = SDecl(result, evaluation.construct_value(e.type))
     choose = SIf(
         e.cond,
         SAssign(result, e.then_branch),
         SAssign(result, e.else_branch))
     self.stms.append(simplify_and_optimize(seq([declare, choose])))
     return result
예제 #18
0
    def code(self) -> Spec:
        """Get the current code corresponding to this implementation.

        The code is returned as a Cozy specification object, but the returned
        object throws away any unused abstract state as well as all invariants
        and assumptions on methods. It implements the same data structure, but
        probably more efficiently.

        For every operator, the per-state-variable update statements are
        ordered so that (where possible) a statement that reads a state
        variable runs before the statement that writes it.  Query calls that
        would still observe already-updated state are lifted into temporaries
        declared at the top of the operator.
        """

        state_read_by_query = {
            query_name: free_vars(query)
            for query_name, query in self.query_impls.items()
        }

        # prevent read-after-write by lifting reads before writes.

        # operator name -> list of SDecls to emit before that operator's updates
        temps = defaultdict(list)
        updates = dict(self.updates)

        # operator name -> order in which state variables are updated.  The
        # temporaries below are lifted relative to this per-operator order, so
        # the operator's statements must later be emitted in the same order.
        update_order = {}

        state_vars = [
            v for v, e in self._concretization_functions
        ]

        # Hoisted: does not change while this method runs.
        query_spec_names = {q.name for q in self.query_specs}

        for operator in self.op_specs:

            # Compute order constraints between statements:
            #   v1 -> v2 means that the update code for v1 should (if possible)
            #   appear before the update code for v2
            #   (i.e. the update code for v1 reads v2)
            def state_used_during_update(v1: EVar) -> [EVar]:
                v1_update_code = self.updates[(v1, operator.name)]
                v1_queries = list(self.queries_used_by(v1_update_code))
                res = OrderedSet()
                for q in v1_queries:
                    res |= state_read_by_query[q]
                return res

            g = DirectedGraph(nodes=state_vars,
                              successors=state_used_during_update)

            # Find the minimum set of edges we need to break cycles (see
            # "feedback arc set problem")
            edges_to_break = g.minimum_feedback_arc_set()
            g.delete_edges(edges_to_break)
            # As before, this order also seeds the next operator's graph.
            state_vars = list(g.toposort())
            update_order[operator.name] = state_vars

            # Lift auxiliary declarations as needed
            things_updated = []
            for v in state_vars:
                things_updated.append(v)
                stm = updates[(v, operator.name)]

                def problematic(e):
                    # True if `e` calls a query that reads state which has
                    # already been updated at this point in the operator.
                    for x in all_exps(e):
                        if isinstance(x, ECall) and x.func in query_spec_names:
                            problems = set(
                                things_updated) & state_read_by_query[x.func]
                            if problems:
                                return True
                    return False

                stm = pull_temps(stm,
                                 decls_out=temps[operator.name],
                                 exp_is_bad=problematic)
                updates[(v, operator.name)] = stm

        # construct new op implementations
        new_ops = []
        for op in self.op_specs:

            # Use the order computed for THIS operator, not whichever order
            # the last loop iteration above happened to leave behind.
            stms = [updates[(v, op.name)] for v in update_order[op.name]]
            stms.extend(hup
                        for ((t, op_name), hup) in self.handle_updates.items()
                        if op.name == op_name)
            new_stms = seq(temps[op.name] + stms)
            new_ops.append(Op(op.name, op.args, [], new_stms, op.docstring))

        # assemble final result
        return Spec(self.spec.name, self.spec.types, self.spec.extern_funcs,
                    [(v.id, e.type)
                     for (v, e) in self._concretization_functions], [],
                    list(self.query_impls.values()) + new_ops,
                    self.spec.header, self.spec.footer, self.spec.docstring)
예제 #19
0
파일: optimization.py 프로젝트: uwplse/cozy
def simplify_and_optimize(s : Stm) -> Stm:
    """Simplify and optimize a statement.

    Input:
      s - a statement to optimize

    Output:
      A statement that is functionally equivalent to the input.

    Raises:
      ValueError - for an SWhile whose condition needs setup statements
      NotImplementedError - for statement types not handled below

    This function makes two big transformations:
      - "compile" many kinds of expressions (listed below) to simpler forms so
        that downstream code generation has less work to do
      - avoid creating short-lived intermediate objects (see `stream`)

    Expression types eliminated by this procedure:
      - EMap, EFilter, EFlatMap
      - EArg{Min,Max}
      - unary ops:
        Distinct,
        AreUnique,
        Length, Sum, All, Any,
        Exists, Empty,
        The
      - binary ops:
        In (where the collection is a Bag or List)
        "-" on collections
        "+" on collections
      - EMakeMap2
      - ELet
      - EListSlice
      - EStm
    """
    assert isinstance(s, Stm)

    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        # TODO: while the first basic statement in s1 is an SDecl, we should
        # apply `re_use` to perhaps eliminate or inline the declaration.
        return seq([simplify_and_optimize(s.s1), simplify_and_optimize(s.s2)])
    if isinstance(s, SAssign):
        # Only the right-hand side is optimized; the target is left as-is.
        setup, e = simplify_and_optimize_expression(s.rhs)
        return seq([setup, SAssign(s.lhs, e)])
    if isinstance(s, SReturn):
        setup, e = simplify_and_optimize_expression(s.e)
        return seq([setup, SReturn(e)])
    if isinstance(s, SDecl):
        setup, e = simplify_and_optimize_expression(s.val)
        return seq([setup, SDecl(s.var, e)])
    if isinstance(s, SForEach):
        return stream(s.iter, s.loop_var, s.body)
    if isinstance(s, SEscape):
        return s
    if isinstance(s, SIf):
        setup, test = simplify_and_optimize_expression(s.cond)
        # A constant condition selects its branch statically.
        # NOTE(review): `setup` is discarded on these two paths; presumably it
        # is empty whenever the condition folds to a constant -- confirm.
        if test == ETRUE:
            return simplify_and_optimize(s.then_branch)
        if test == EFALSE:
            return simplify_and_optimize(s.else_branch)
        return seq([setup, SIf(test, simplify_and_optimize(s.then_branch), simplify_and_optimize(s.else_branch))])
    if isinstance(s, SWhile):
        setup, cond = simplify_and_optimize_expression(s.e)
        if setup != SNoOp():
            # This is a problem because we don't want to duplicate the setup
            # condition.
            # TODO: introduce an SEscapableBlock/SEscapeBlock to do it right
            raise ValueError("oops! setup for condition {} is very long:\n{}".format(pprint(s.e), pprint(setup)))
        return SWhile(cond, simplify_and_optimize(s.body))
    if isinstance(s, SScoped):
        return SScoped(simplify_and_optimize(s.s))
    if isinstance(s, SMapUpdate):
        # The map is optimized as an L-value so the update still targets the
        # original storage; the key is an ordinary expression.
        mapsetup, map = simplify_and_optimize_lvalue(s.map)
        keysetup, key = simplify_and_optimize_expression(s.key)
        return seq([
            mapsetup,
            keysetup,
            SMapUpdate(map, key, s.val_var, simplify_and_optimize(s.change))])
    if isinstance(s, SMapDel):
        mapsetup, map = simplify_and_optimize_lvalue(s.map)
        keysetup, key = simplify_and_optimize_expression(s.key)
        return seq([
            mapsetup,
            keysetup,
            SMapDel(map, key)])
    if isinstance(s, SCall):
        # Split the (setup, expression) pairs by hand rather than with
        # `zip(*...)`: unzipping an empty sequence yields nothing to unpack,
        # which would crash on a call that takes no arguments.
        optimized_args = [simplify_and_optimize_expression(a) for a in s.args]
        setups = [arg_setup for (arg_setup, _) in optimized_args]
        args = tuple(arg for (_, arg) in optimized_args)
        return seq(setups + [SCall(s.target, s.func, args)])
    if isinstance(s, SEscapableBlock):
        return SEscapableBlock(s.label, simplify_and_optimize(s.body))
    if isinstance(s, SEscapeBlock):
        return s
    if isinstance(s, SArrayAlloc):
        setup, cap = simplify_and_optimize_expression(s.capacity)
        return seq([setup, SArrayAlloc(s.a, cap)])
    if isinstance(s, SArrayReAlloc):
        setup, cap = simplify_and_optimize_expression(s.new_capacity)
        return seq([setup, SArrayReAlloc(s.a, cap)])
    if isinstance(s, SEnsureCapacity):
        setup, cap = simplify_and_optimize_expression(s.capacity)
        return seq([setup, SEnsureCapacity(s.a, cap)])
    if isinstance(s, SSwap):
        # TODO: if we want to optimize the operands we will need a special
        # procedure that optimizes lvalues while preserving meaning... same
        # goes for SAssign case above.
        return s
    if isinstance(s, SSwitch):
        setup, e = simplify_and_optimize_expression(s.e)
        new_cases = [(case, simplify_and_optimize(stm)) for (case, stm) in s.cases]
        new_default = simplify_and_optimize(s.default)
        return seq([setup, SSwitch(e, new_cases, new_default)])
    if isinstance(s, SErase) or isinstance(s, SInsert):
        return s
    raise NotImplementedError(repr(s))
예제 #20
0
def simplify_and_optimize(s : Stm) -> Stm:
    """Simplify and optimize a statement.

    Input:
      s - a statement to optimize

    Output:
      A statement that is functionally equivalent to the input.

    This function makes two big transformations:
      - "compile" many kinds of expressions (listed below) to simpler forms so
        that downstream code generation has less work to do
      - avoid creating short-lived intermediate objects (see `stream`)

    Expression types eliminated by this procedure:
      - EMap, EFilter, EFlatMap
      - EArg{Min,Max}
      - unary ops:
        Distinct,
        AreUnique,
        Length, Sum, All, Any,
        Exists, Empty,
        The
      - binary ops:
        In (where the collection is a Bag or List)
        "-" on collections
        "+" on collections
      - EMakeMap2
      - ELet
      - EListSlice
      - EStm

    Raises:
      ValueError - if the condition of an SWhile needs setup statements
      NotImplementedError - if `s` is a kind of statement not handled below
    """
    assert isinstance(s, Stm)

    # Most cases share one shape: `simplify_and_optimize_expression` yields a
    # (setup statement, simplified expression) pair, and the result is the
    # setup followed by the original statement rebuilt around the simplified
    # expression.

    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        # TODO: while the first basic statement in s1 is an SDecl, we should
        # apply `re_use` to perhaps eliminate or inline the declaration.
        return seq([simplify_and_optimize(s.s1), simplify_and_optimize(s.s2)])
    if isinstance(s, SAssign):
        # Only the right-hand side is optimized; `s.lhs` is reused unchanged
        # (see the note in the SSwap case).
        setup, e = simplify_and_optimize_expression(s.rhs)
        return seq([setup, SAssign(s.lhs, e)])
    if isinstance(s, SReturn):
        setup, e = simplify_and_optimize_expression(s.e)
        return seq([setup, SReturn(e)])
    if isinstance(s, SDecl):
        setup, e = simplify_and_optimize_expression(s.val)
        return seq([setup, SDecl(s.var, e)])
    if isinstance(s, SForEach):
        # Loops are handled entirely by `stream`.
        return stream(s.iter, s.loop_var, s.body)
    if isinstance(s, SEscape):
        return s
    if isinstance(s, SIf):
        setup, test = simplify_and_optimize_expression(s.cond)
        # A condition that simplifies to a constant selects its branch
        # statically.
        # NOTE(review): `setup` is dropped in these two cases -- presumably
        # safe because it only computes temporaries for `test`; confirm.
        if test == ETRUE:
            return simplify_and_optimize(s.then_branch)
        if test == EFALSE:
            return simplify_and_optimize(s.else_branch)
        return seq([setup, SIf(test, simplify_and_optimize(s.then_branch), simplify_and_optimize(s.else_branch))])
    if isinstance(s, SWhile):
        setup, cond = simplify_and_optimize_expression(s.e)
        if setup != SNoOp():
            # This is a problem because we don't want to duplicate the setup
            # condition.
            # TODO: introduce an SEscapableBlock/SEscapeBlock to do it right
            raise ValueError("oops! setup for condition {} is very long:\n{}".format(pprint(s.e), pprint(setup)))
        return SWhile(cond, simplify_and_optimize(s.body))
    if isinstance(s, SScoped):
        return SScoped(simplify_and_optimize(s.s))
    if isinstance(s, SMapUpdate):
        # The map is optimized as an L-value and the key as an ordinary
        # expression; the nested `change` statement is optimized recursively.
        mapsetup, map = simplify_and_optimize_lvalue(s.map)
        keysetup, key = simplify_and_optimize_expression(s.key)
        return seq([
            mapsetup,
            keysetup,
            SMapUpdate(map, key, s.val_var, simplify_and_optimize(s.change))])
    if isinstance(s, SMapDel):
        mapsetup, map = simplify_and_optimize_lvalue(s.map)
        keysetup, key = simplify_and_optimize_expression(s.key)
        return seq([
            mapsetup,
            keysetup,
            SMapDel(map, key)])
    if isinstance(s, SCall):
        # Collect the per-argument results with explicit lists.  Unpacking
        # `zip(*pairs)` into two names raises ValueError when the call has
        # no arguments, because zip() of nothing yields nothing.
        setups = []
        args = []
        for a in s.args:
            argsetup, arg = simplify_and_optimize_expression(a)
            setups.append(argsetup)
            args.append(arg)
        return seq(setups + [SCall(s.target, s.func, tuple(args))])
    if isinstance(s, SEscapableBlock):
        return SEscapableBlock(s.label, simplify_and_optimize(s.body))
    if isinstance(s, SEscapeBlock):
        return s
    if isinstance(s, SArrayAlloc):
        setup, cap = simplify_and_optimize_expression(s.capacity)
        return seq([setup, SArrayAlloc(s.a, cap)])
    if isinstance(s, SArrayReAlloc):
        setup, cap = simplify_and_optimize_expression(s.new_capacity)
        return seq([setup, SArrayReAlloc(s.a, cap)])
    if isinstance(s, SEnsureCapacity):
        setup, cap = simplify_and_optimize_expression(s.capacity)
        return seq([setup, SEnsureCapacity(s.a, cap)])
    if isinstance(s, SSwap):
        # TODO: if we want to optimize the operands we will need a special
        # procedure that optimizes lvalues while preserving meaning... same
        # goes for SAssign case above.
        return s
    if isinstance(s, SSwitch):
        setup, e = simplify_and_optimize_expression(s.e)
        new_cases = [(case, simplify_and_optimize(stm)) for (case, stm) in s.cases]
        new_default = simplify_and_optimize(s.default)
        return seq([setup, SSwitch(e, new_cases, new_default)])
    if isinstance(s, SErase) or isinstance(s, SInsert):
        return s
    # Fail loudly instead of falling off the end and returning None.
    raise NotImplementedError(repr(s))
예제 #21
0
    if isinstance(s, SArrayReAlloc):
        setup, cap = simplify_and_optimize_expression(s.new_capacity)
        return seq([setup, SArrayReAlloc(s.a, cap)])
    if isinstance(s, SEnsureCapacity):
        setup, cap = simplify_and_optimize_expression(s.capacity)
        return seq([setup, SEnsureCapacity(s.a, cap)])
    if isinstance(s, SSwap):
        # TODO: if we want to optimize the operands we will need a special
        # procedure that optimizes lvalues while preserving meaning... same
        # goes for SAssign case above.
        return s
    if isinstance(s, SSwitch):
        setup, e = simplify_and_optimize_expression(s.e)
        new_cases = [(case, simplify_and_optimize(stm)) for (case, stm) in s.cases]
        new_default = simplify_and_optimize(s.default)
        return seq([setup, SSwitch(e, new_cases, new_default)])
    if isinstance(s, SErase) or isinstance(s, SInsert):
        return s
    raise NotImplementedError(repr(s))

class ExpressionOptimizer(BottomUpRewriter):
    """Rewriter that accumulates setup statements while rewriting expressions.

    Statements that must run before the rewritten expression is evaluated are
    appended, in order, to `self.stms`.
    """

    def __init__(self):
        super().__init__()
        # Setup statements emitted so far, in execution order.
        self.stms = []

    def visit_iterable(self, e):
        """Materialize the iterable `e` into a fresh list variable.

        Appends to `self.stms` a declaration of a fresh variable initialized
        to an empty list, followed by an optimized loop that adds every
        element of `e` to it.  Returns an `EMove` of that variable.
        """
        out = fresh_var(e.type)
        empty = EEmptyList().with_type(e.type)
        self.stms.append(SDecl(out, empty))

        elem = fresh_var(e.type.elem_type)
        fill = SForEach(elem, e, SCall(out, "add", (elem,)))
        self.stms.append(simplify_and_optimize(fill))

        return EMove(out).with_type(out.type)
예제 #22
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = [],
        invariants: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Produce code that brings `lval` from `old_value` to `new_value`.

    Variables in `ctx` are treated as part of the data structure's abstract
    state.  `assumptions` are attached to every generated subgoal;
    `invariants` are only used, together with the assumptions, to decide
    whether the value can change at all.

    Returns a pair: the update statement for `lval`, and the list of subgoals
    (new queries referenced by that statement).

    The strategy depends on the type of `lval`:
      - bags and sets: remove the deleted elements, then add the new ones
      - numbers (when `update_numbers_with_deltas` is on): add a delta
      - tuples and records: update each component recursively
      - maps: delete removed keys, then update new-or-modified keys
      - anything else: assign a freshly computed value
    """

    # Nothing to emit when old and new values are provably equal.
    premises = syntax.EAll(itertools.chain(assumptions, invariants))
    unchanged = syntax.EEq(old_value, new_value)
    if valid(syntax.EImplies(premises, unchanged)):
        return (syntax.SNoOp(), [])

    new_value = strip_EStateVar(new_value)
    subgoals = []

    def make_subgoal(e, a=[], docstring=None):
        # With stateless synthesis skipped, an expression that mentions no
        # state variable is used inline instead of becoming a query.
        if skip_stateless_synthesis.value:
            if all(fv not in ctx for fv in free_vars(e)):
                return e
        name = fresh_name("query")
        query = syntax.Query(
            name, syntax.Visibility.Internal, [], assumptions + a, e,
            docstring)
        # Every free variable outside the context becomes a query argument.
        params = [fv for fv in free_vars(query) if fv not in ctx]
        query.args = [(p.id, p.type) for p in params]
        subgoals.append(query)
        return syntax.ECall(name, tuple(params)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Nested updates feed their subgoals into this call's list.
        nested_code, nested_goals = sketch_update(*args, **kwargs)
        subgoals.extend(nested_goals)
        return nested_code

    t = lval.type
    if isinstance(t, (syntax.TBag, syntax.TSet)):
        gained = syntax.EBinOp(new_value, "-", old_value).with_type(t)
        to_add = make_subgoal(
            gained, docstring="additions to {}".format(pprint(lval)))
        lost = syntax.EBinOp(old_value, "-", new_value).with_type(t)
        to_del = make_subgoal(
            lost, docstring="deletions from {}".format(pprint(lval)))
        elem = fresh_var(t.t)
        remove_loop = syntax.SForEach(
            elem, to_del, syntax.SCall(lval, "remove", [elem]))
        add_loop = syntax.SForEach(
            elem, to_add, syntax.SCall(lval, "add", [elem]))
        stm = syntax.seq([remove_loop, add_loop])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        delta = syntax.EBinOp(new_value, "-", old_value).with_type(t)
        change = make_subgoal(
            delta, docstring="delta for {}".format(pprint(lval)))
        bumped = syntax.EBinOp(lval, "+", change).with_type(t)
        stm = syntax.SAssign(lval, bumped)
    elif isinstance(t, syntax.TTuple):
        def component(val, i):
            return syntax.ETupleGet(val, i).with_type(t.ts[i])

        pieces = []
        for i in range(len(t.ts)):
            pieces.append(
                recurse(component(lval, i),
                        component(old_value, i),
                        component(new_value, i),
                        ctx,
                        assumptions,
                        invariants=invariants))
        stm = syntax.seq(pieces)
    elif isinstance(t, syntax.TRecord):
        def field(val, i):
            return syntax.EGetField(val, t.fields[i][0]).with_type(
                t.fields[i][1])

        pieces = []
        for i in range(len(t.fields)):
            pieces.append(
                recurse(field(lval, i),
                        field(old_value, i),
                        field(new_value, i),
                        ctx,
                        assumptions,
                        invariants=invariants))
        stm = syntax.seq(pieces)
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set: keys present before but not after
        deleted_keys = syntax.EBinOp(old_keys, "-",
                                     new_keys).with_type(key_bag)
        removed = make_subgoal(
            deleted_keys,
            docstring="keys removed from {}".format(pprint(lval)))
        s1 = syntax.SForEach(k, removed, target_syntax.SMapDel(lval, k))

        # (2) enter/mod set: keys that are new or whose value differs
        is_new = syntax.ENot(syntax.EIn(k, old_keys))
        is_modified = syntax.ENot(
            syntax.EEq(value_at(old_value, k), value_at(new_value, k)))
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(k, syntax.EAny([is_new, is_modified]))
        ).with_type(key_bag)
        update_value = recurse(
            v,
            value_at(old_value, k),
            value_at(new_value, k),
            ctx=ctx,
            assumptions=assumptions + [
                syntax.EIn(k, new_or_modified),
                syntax.EEq(v, value_at(old_value, k))
            ],
            invariants=invariants)
        touched = make_subgoal(
            new_or_modified,
            docstring="new or modified keys from {}".format(pprint(lval)))
        s2 = syntax.SForEach(
            k, touched, target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        fresh_value = make_subgoal(
            new_value, docstring="new value for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, fresh_value)

    return (stm, subgoals)
예제 #23
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    The update strategy is chosen by the type of `lval`:
      - bags and sets: remove deleted elements, then add new elements
      - numbers: add the difference between new and old value
      - tuples and records: update each component recursively
      - handles: no code is generated here
      - maps: delete removed keys, update altered keys, insert new keys
      - anything else: assign a freshly computed value

    `assumptions` is only read (new lists are built with `+`), so the shared
    default list is never mutated.
    """

    # If the value provably does not change, no code is needed.
    if valid(
            syntax.EImplies(syntax.EAll(assumptions),
                            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []

    def make_subgoal(e, a=[], docstring=None):
        # Turn `e` into a call to a new internal query, recording the query
        # in `subgoals`.  `a` holds extra assumptions for this query only.
        e = strip_EStateVar(e)
        # Expressions that mention no state variable may be used inline.
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + a, e, docstring)
        # Free variables outside the context become the query's arguments.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Sketch a nested update and fold its subgoals into ours.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t):
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.THandle):
        # handles are tricky, and are dealt with at a higher level
        stm = syntax.SNoOp()
    elif isinstance(t, syntax.TMap):
        value_at = lambda m, k: target_syntax.EMapGet(m, k).with_type(lval.type
                                                                      .v)
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, new_keys)))).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) modify set
        common_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k,
                                  syntax.EIn(k,
                                             new_keys))).with_type(key_bag)
        # Code that updates the value `v` stored at an altered key `k`.
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, common_keys),
                                   syntax.ENot(
                                       syntax.EEq(value_at(old_value, k),
                                                  value_at(new_value, k)))
                               ])
        altered_keys = target_syntax.EFilter(
            common_keys,
            target_syntax.ELambda(
                k,
                syntax.ENot(
                    syntax.EEq(value_at(old_value, k),
                               value_at(new_value,
                                        k))))).with_type(key_bag)
        s2 = syntax.SForEach(
            k,
            make_subgoal(altered_keys,
                         docstring="altered keys in {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        # (3) enter set
        fresh_keys = target_syntax.EFilter(
            new_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, old_keys)))).with_type(key_bag)
        s3 = syntax.SForEach(
            k,
            make_subgoal(fresh_keys,
                         docstring="new keys in {}".format(pprint(lval))),
            target_syntax.SMapPut(
                lval, k,
                make_subgoal(value_at(new_value, k),
                             a=[syntax.EIn(k, fresh_keys)],
                             docstring="new value inserted at {}".format(
                                 pprint(target_syntax.EMapGet(lval, k))))))

        stm = syntax.seq([s1, s2, s3])
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)