def mutate(e: syntax.Exp, op: syntax.Stm) -> syntax.Exp:
    """Compute the expression that `e` denotes once statement `op` has run.

    Collection calls are lowered step by step: "add"/"remove" become
    "add_all"/"remove_all" on a singleton, which in turn become an assignment
    of `target + arg` / `target - arg`; the assignment is then handled by
    `_do_assignment`.  Conditionals produce an `ECond` unless both branches
    give alpha-equivalent results.  Sequences are processed back to front.

    Raises `Exception` for an unrecognised call name and `NotImplementedError`
    for an unrecognised statement class.
    """
    if isinstance(op, syntax.SNoOp):
        return e

    if isinstance(op, syntax.SAssign):
        return _do_assignment(op.lhs, op.rhs, e)

    if isinstance(op, syntax.SCall):
        func = op.func
        target = op.target
        if func == "add" or func == "remove":
            # Single-element form: wrap the element and defer to the bulk form.
            bulk_name = "add_all" if func == "add" else "remove_all"
            wrapped = syntax.ESingleton(op.args[0]).with_type(target.type)
            return mutate(e, syntax.SCall(target, bulk_name, (wrapped, )))
        if func == "add_all" or func == "remove_all":
            # Bulk form: express the change as an assignment to the target.
            operator = "+" if func == "add_all" else "-"
            combined = syntax.EBinOp(target, operator,
                                     op.args[0]).with_type(target.type)
            return mutate(e, syntax.SAssign(target, combined))
        raise Exception("Unknown func: {}".format(op.func))

    if isinstance(op, syntax.SIf):
        when_true = mutate(e, op.then_branch)
        when_false = mutate(e, op.else_branch)
        if not alpha_equivalent(when_true, when_false):
            return syntax.ECond(op.cond, when_true,
                                when_false).with_type(e.type)
        # Both branches agree, so the condition is irrelevant.
        return when_true

    if isinstance(op, syntax.SSeq):
        first = op.s1
        rest = op.s2
        if isinstance(first, syntax.SSeq):
            # Re-associate ((a; b); c) as (a; (b; c)) so `first` is never
            # itself a sequence below.
            return mutate(e, syntax.SSeq(first.s1,
                                         syntax.SSeq(first.s2, rest)))
        # Account for the later statement first, then the earlier one.
        after_rest = mutate(e, rest)
        result = mutate(after_rest, first)
        if isinstance(first, syntax.SDecl):
            # Replace the declared name by its initial value.
            result = subst(result, {first.id: first.val})
        return result

    if isinstance(op, syntax.SDecl):
        return e

    raise NotImplementedError(type(op))
Exemple #2
0
def _fix_map(m: target_syntax.EMap) -> syntax.Exp:
    """Return `m`; the rewrite implemented below is currently unreachable.

    NOTE(review): the unconditional `return m` on the first line means none of
    the remaining statements ever execute.  This looks like a deliberate
    switch-off of the rewrite rather than an accident -- confirm before
    deleting either the early return or the dead code.
    """
    return m
    # ---- unreachable from here on (see NOTE in the docstring) ----
    from cozy.simplification import simplify
    m = simplify(m)
    # Simplification may turn the expression into something other than a map.
    if not isinstance(m, target_syntax.EMap):
        return m
    elem_type = m.e.type.t
    assert m.f.body.type == elem_type
    # `changed`: the elements x of m.e for which NOT (x === m.f applied to x).
    changed = target_syntax.EFilter(
        m.e,
        mk_lambda(
            elem_type, lambda x: syntax.ENot(
                syntax.EBinOp(x, "===", m.f.apply_to(x)).with_type(syntax.BOOL)
            ))).with_type(m.e.type)
    # Candidate replacement: (m.e - changed) + map(changed, m.f).
    e = syntax.EBinOp(
        syntax.EBinOp(m.e, "-", changed).with_type(m.e.type), "+",
        target_syntax.EMap(changed,
                           m.f).with_type(m.e.type)).with_type(m.e.type)
    # Only use the candidate if it is checked to be equal to the original;
    # otherwise warn and keep `m`.
    if not valid(syntax.EEq(m, e)):
        print("WARNING: rewrite failed")
        print("_fix_map({!r})".format(m))
        return m
    return e
Exemple #3
0
 def p_exp(p):
     """exp : NUM
            | FLOAT
            | WORD
            | WORD OP_OPEN_PAREN exp_list OP_CLOSE_PAREN
            | KW_TRUE
            | KW_FALSE
            | exp OP_PLUS  exp
            | exp OP_MINUS exp
            | exp OP_TIMES exp
            | exp OP_EQ exp
            | exp OP_NE exp
            | exp OP_LT exp
            | exp OP_LE exp
            | exp OP_GT exp
            | exp OP_GE exp
            | exp KW_AND exp
            | exp KW_OR exp
            | exp OP_IMPLIES exp
            | exp OP_QUESTION exp OP_COLON exp
            | exp OP_OPEN_BRACKET slice OP_CLOSE_BRACKET
            | KW_NOT exp
            | OP_MINUS exp
            | exp KW_IN exp
            | KW_UNIQUE exp
            | KW_DISTINCT exp
            | KW_EMPTY exp
            | KW_THE exp
            | KW_MIN exp
            | KW_MAX exp
            | KW_ARGMIN lambda exp
            | KW_ARGMAX lambda exp
            | KW_SUM exp
            | KW_LEN exp
            | KW_ANY exp
            | KW_ALL exp
            | KW_EXISTS exp
            | KW_REVERSED exp
            | exp OP_DOT NUM
            | exp OP_DOT WORD
            | OP_OPEN_PAREN exp_list OP_CLOSE_PAREN
            | OP_OPEN_BRACE record_fields OP_CLOSE_BRACE
            | OP_OPEN_BRACKET exp OP_CLOSE_BRACKET
            | OP_OPEN_BRACKET exp OP_VBAR comprehension_body OP_CLOSE_BRACKET"""
     # NOTE(review): the docstring above looks like a grammar rule consumed by
     # a yacc-style parser generator (presumably PLY) rather than prose
     # documentation; if so it is runtime data and must not be reworded --
     # confirm.
     #
     # `p` holds the matched symbols: p[1:] are the right-hand-side values and
     # the resulting AST node is stored into p[0].  The production is
     # identified by len(p) (number of right-hand-side symbols + 1) and by
     # comparing individual symbols against token text.

     # --- one symbol: NUM / FLOAT / WORD / KW_TRUE / KW_FALSE ---
     if len(p) == 2:
         if type(p[1]) is syntax.ENum:
             # Value is already an ENum node; pass it through unchanged.
             p[0] = p[1]
         elif p[1] == "true":
             p[0] = syntax.EBool(True)
         elif p[1] == "false":
             p[0] = syntax.EBool(False)
         else:
             # Anything else is treated as a variable name.
             p[0] = syntax.EVar(p[1])
     # --- two symbols: prefix keyword/operator followed by an expression ---
     elif len(p) == 3:
         # "min"/"max" are expressed as argmin/argmax with the key function
         # `x -> x`.
         if p[1] == "min":
             p[0] = syntax.EArgMin(
                 p[2], syntax.ELambda(syntax.EVar("x"), syntax.EVar("x")))
         elif p[1] == "max":
             p[0] = syntax.EArgMax(
                 p[2], syntax.ELambda(syntax.EVar("x"), syntax.EVar("x")))
         else:
             # Every other prefix form becomes a unary operator node.
             p[0] = syntax.EUnaryOp(p[1], p[2])
     # --- three symbols ---
     elif len(p) == 4:
         if p[1] == "(":
             # Parenthesised expression list: one element is plain grouping,
             # several elements form a tuple, none is an error.
             exps = p[2]
             if len(exps) == 0:
                 raise Exception("illegal ()")
             elif len(exps) == 1:
                 p[0] = exps[0]
             elif len(exps) > 1:
                 p[0] = syntax.ETuple(tuple(exps))
         elif p[1] == "[":
             # [exp]
             p[0] = syntax.ESingleton(p[2])
         elif p[1] == "{":
             # {record_fields}
             p[0] = syntax.EMakeRecord(p[2])
         elif p[2] == ".":
             # exp.NUM reads a tuple component, exp.WORD reads a field.
             if isinstance(p[3], syntax.ENum):
                 p[0] = syntax.ETupleGet(p[1], p[3].val)
             else:
                 p[0] = syntax.EGetField(p[1], p[3])
         elif p[1] == "argmin":
             # Source order is `argmin lambda exp`; the node takes (exp, lambda).
             p[0] = syntax.EArgMin(p[3], p[2])
         elif p[1] == "argmax":
             p[0] = syntax.EArgMax(p[3], p[2])
         else:
             # Remaining three-symbol forms are infix: exp OP exp.
             p[0] = syntax.EBinOp(p[1], p[2], p[3])
     # --- four or more symbols ---
     else:
         if p[2] == "?":
             # exp ? exp : exp
             p[0] = syntax.ECond(p[1], p[3], p[5])
         elif p[2] == "[":
             # exp[slice]: p[3] is either a single index expression or a tuple
             # whose first two items are the start and end bounds.
             # NOTE(review): if p[3] is neither, p[0] is left unassigned.
             if isinstance(p[3], syntax.Exp):
                 p[0] = syntax.EListGet(p[1], p[3])
             elif isinstance(p[3], tuple):
                 start = p[3][0]
                 end = p[3][1]
                 # A missing bound defaults to zero / the length of the list.
                 if start is None:
                     start = syntax.ZERO
                 if end is None:
                     end = syntax.ELen(p[1])
                 p[0] = syntax.EListSlice(p[1], start, end)
         elif p[1] == "[":
             # [exp | comprehension_body]
             p[0] = syntax.EListComprehension(p[2], p[4])
         elif p[2] == "(":
             # WORD(exp_list)
             p[0] = syntax.ECall(p[1], p[3])
         else:
             assert False, "unknown case: {}".format(repr(p[1:]))
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = None,
        invariants: [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    `invariants` are used together with `assumptions` only for the initial
    "does the value change at all?" check; they are forwarded unchanged to
    recursive calls but are not attached to generated queries.

    Omitting `assumptions` or `invariants` (or passing None) means "none".

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).
    """
    # Default argument values are created once and shared between calls, so
    # None is used as the "not supplied" marker instead of a literal list.
    if assumptions is None:
        assumptions = []
    if invariants is None:
        invariants = []

    # Nothing to do when the value provably does not change.
    if valid(
            syntax.EImplies(
                syntax.EAll(itertools.chain(assumptions, invariants)),
                syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=(), docstring=None):
        """Turn `e` into a call to a fresh internal query and record the query.

        `a` holds extra assumptions for this query only.  When
        `skip_stateless_synthesis` is set and `e` mentions no variable from
        `ctx`, `e` is returned as-is and no query is created.
        """
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + list(a), e, docstring)
        # Every free variable outside `ctx` becomes a query argument.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        """Call `sketch_update` and merge its subgoals into this call's."""
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, (syntax.TBag, syntax.TSet)):
        # Collections: remove what disappeared, then add what appeared.
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        # Numbers: add the difference instead of recomputing the value.
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        # Tuples: update each component independently.
        def component(val, i):
            return syntax.ETupleGet(val, i).with_type(t.ts[i])

        stm = syntax.seq([
            recurse(component(lval, i),
                    component(old_value, i),
                    component(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        # Records: update each field independently.
        def field(val, i):
            return syntax.EGetField(val, t.fields[i][0]).with_type(
                t.fields[i][1])

        stm = syntax.seq([
            recurse(field(lval, i),
                    field(old_value, i),
                    field(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-",
                                     new_keys).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))), target_syntax.SMapDel(lval, k))

        # (2) enter/mod set
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(
                k,
                syntax.EAny([
                    syntax.ENot(syntax.EIn(k, old_keys)),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k),
                                   value_at(new_value, k)))
                ]))).with_type(key_bag)
        # Code that updates the value `v` stored at key `k`, given that `k` is
        # new or modified and `v` currently equals the old value at `k`.
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, new_or_modified),
                                   syntax.EEq(v, value_at(old_value, k))
                               ],
                               invariants=invariants)
        s2 = syntax.SForEach(
            k,
            make_subgoal(new_or_modified,
                         docstring="new or modified keys from {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
Exemple #5
0
def _delta_form(res: {str: syntax.Exp}, op: syntax.Stm) -> {str: syntax.Exp}:
    """Fold the effect of statement `op` into `res` in place and return `res`.

    Calls and assignments are applied through the updater obtained from
    `_rewriter`, which receives `res` and a function from the old expression
    to the new one.  For a conditional, each branch is applied to its own
    copy of `res` and the copies are merged key by key, using an `ECond`
    where the branches disagree.  For a sequence, the first statement is
    applied, then the second one after substituting `res` into it.

    Raises `Exception` for an unrecognised call name and `NotImplementedError`
    for an unrecognised statement class.
    """
    if isinstance(op, syntax.SNoOp):
        return res

    if isinstance(op, syntax.SCall):
        apply_update = _rewriter(op.target)
        func = op.func

        def singleton(old):
            # The call's argument wrapped as a one-element collection.
            return syntax.ESingleton(op.args[0]).with_type(old.type)

        # Pick the old-value -> new-value transformer for this call.
        if func == "add" or func == "add_back":
            def transform(old):
                return syntax.EBinOp(old, "+",
                                     singleton(old)).with_type(old.type)
        elif func == "add_front":
            def transform(old):
                return syntax.EBinOp(singleton(old), "+",
                                     old).with_type(old.type)
        elif func == "remove_back":
            def transform(old):
                return target_syntax.EDropBack(old).with_type(old.type)
        elif func == "remove_front":
            def transform(old):
                return target_syntax.EDropFront(old).with_type(old.type)
        elif func == "remove":
            def transform(old):
                return syntax.EBinOp(old, "-",
                                     singleton(old)).with_type(old.type)
        elif func == "remove_all":
            def transform(old):
                return syntax.EBinOp(old, "-",
                                     op.args[0]).with_type(old.type)
        else:
            raise Exception("Unknown func: {}".format(op.func))

        apply_update(res, transform)
        return res

    if isinstance(op, syntax.SAssign):
        # An assignment replaces the old value outright.
        _rewriter(op.lhs)(res, lambda old: op.rhs)
        return res

    if isinstance(op, syntax.SIf):
        taken = res.copy()
        _delta_form(taken, op.then_branch)
        not_taken = res.copy()
        _delta_form(not_taken, op.else_branch)

        for name in res:
            if_true = taken[name]
            if_false = not_taken[name]
            if if_true == if_false:
                merged = if_true
            else:
                # Branches disagree: the value depends on the condition.
                merged = syntax.ECond(op.cond, if_true,
                                      if_false).with_type(if_true.type)
            res[name] = merged
        return res

    if isinstance(op, syntax.SSeq):
        _delta_form(res, op.s1)
        # The second statement is rewritten with the results so far before
        # being applied.
        _delta_form(res, subst(op.s2, res))
        return res

    raise NotImplementedError(type(op))
Exemple #6
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    Omitting `assumptions` (or passing None) means "no assumptions".

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).
    """
    # Default argument values are created once and shared between calls, so
    # None is used as the "not supplied" marker instead of a literal list.
    if assumptions is None:
        assumptions = []

    # Nothing to do when the value provably does not change.
    if valid(
            syntax.EImplies(syntax.EAll(assumptions),
                            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []

    def make_subgoal(e, a=(), docstring=None):
        """Turn `e` into a call to a fresh internal query and record the query.

        `a` holds extra assumptions for this query only.  When
        `skip_stateless_synthesis` is set and `e` (after `strip_EStateVar`)
        mentions no variable from `ctx`, the stripped expression is returned
        as-is and no query is created.
        """
        e = strip_EStateVar(e)
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + list(a), e, docstring)
        # Every free variable outside `ctx` becomes a query argument.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        """Call `sketch_update` and merge its subgoals into this call's."""
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, (syntax.TBag, syntax.TSet)):
        # Collections: remove what disappeared, then add what appeared.
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t):
        # Numbers: add the difference instead of recomputing the value.
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        # Tuples: update each component independently.
        def component(val, i):
            return syntax.ETupleGet(val, i).with_type(t.ts[i])

        stm = syntax.seq([
            recurse(component(lval, i), component(old_value, i),
                    component(new_value, i), ctx, assumptions)
            for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        # Records: update each field independently.
        def field(val, i):
            return syntax.EGetField(val, t.fields[i][0]).with_type(
                t.fields[i][1])

        stm = syntax.seq([
            recurse(field(lval, i), field(old_value, i),
                    field(new_value, i), ctx, assumptions)
            for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.THandle):
        # handles are tricky, and are dealt with at a higher level
        stm = syntax.SNoOp()
    elif isinstance(t, syntax.TMap):
        def value_at(m, key):
            return target_syntax.EMapGet(m, key).with_type(lval.type.v)

        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        # Keys are handled in three passes: keys that left the map, keys
        # present before and after whose value changed, and keys that
        # entered the map.
        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, new_keys)))).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) modify set
        common_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k,
                                  syntax.EIn(k,
                                             new_keys))).with_type(key_bag)
        # Code that updates the value `v` stored at key `k`, given that `k`
        # is a common key whose value differs between old and new.
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, common_keys),
                                   syntax.ENot(
                                       syntax.EEq(value_at(old_value, k),
                                                  value_at(new_value, k)))
                               ])
        altered_keys = target_syntax.EFilter(
            common_keys,
            target_syntax.ELambda(
                k,
                syntax.ENot(
                    syntax.EEq(value_at(old_value, k),
                               value_at(new_value,
                                        k))))).with_type(key_bag)
        s2 = syntax.SForEach(
            k,
            make_subgoal(altered_keys,
                         docstring="altered keys in {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        # (3) enter set
        fresh_keys = target_syntax.EFilter(
            new_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, old_keys)))).with_type(key_bag)
        s3 = syntax.SForEach(
            k,
            make_subgoal(fresh_keys,
                         docstring="new keys in {}".format(pprint(lval))),
            target_syntax.SMapPut(
                lval, k,
                make_subgoal(value_at(new_value, k),
                             a=[syntax.EIn(k, fresh_keys)],
                             docstring="new value inserted at {}".format(
                                 pprint(target_syntax.EMapGet(lval, k))))))

        stm = syntax.seq([s1, s2, s3])
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)