Пример #1
0
def value_at(m, k):
    """
    Build an AST node for the lookup ``m[k]``.

    When ``m`` is a map literal (``EMakeMap2``) or a conditional (``ECond``)
    the lookup is pushed inside it; any other map expression becomes a plain
    ``EMapGet``. Every node returned is typed with ``m.type.v``, the map's
    value type.
    """
    result_type = m.type.v

    if isinstance(m, target_syntax.EMakeMap2):
        # Keys outside the literal's key collection read as whatever
        # construct_value builds for the value type.
        key_present = syntax.EIn(k, m.e)
        stored = m.value.apply_to(k)
        fallback = construct_value(result_type)
        node = syntax.ECond(key_present, stored, fallback)
    elif isinstance(m, syntax.ECond):
        # Distribute the lookup over both branches of the conditional map.
        then_value = value_at(m.then_branch, k)
        else_value = value_at(m.else_branch, k)
        node = syntax.ECond(m.cond, then_value, else_value)
    else:
        node = target_syntax.EMapGet(m, k)

    return node.with_type(result_type)
Пример #2
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = None,
        invariants: [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    `invariants` are used only, together with `assumptions`, to decide
    whether the value can change at all. They are forwarded to recursive
    calls but are not attached to the generated subgoals. Both `assumptions`
    and `invariants` default to empty.

    The update strategy depends on `lval.type`:
      - bags/sets: remove the deleted elements, then add the new ones;
      - numbers (when `update_numbers_with_deltas` is enabled): add a delta;
      - tuples/records: update each component recursively;
      - maps: delete removed keys, then update new or modified keys;
      - anything else: recompute the whole value from scratch.
    """
    # Use fresh lists rather than shared mutable default arguments.
    if assumptions is None:
        assumptions = []
    if invariants is None:
        invariants = []

    # If the value provably does not change, no update code is needed.
    if valid(
            syntax.EImplies(
                syntax.EAll(itertools.chain(assumptions, invariants)),
                syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=None, docstring=None):
        """
        Return an expression that computes `e`.

        Normally this creates a fresh internal query for `e`, whose
        assumptions are `assumptions + a`. It records the query in
        `subgoals` and returns a call to it. The call's arguments are the
        query's free variables that are not in `ctx`.

        If `skip_stateless_synthesis` is set and `e` mentions no `ctx`
        variable, `e` itself is returned and no query is created.
        """
        if a is None:
            a = []
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + a, e, docstring)
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        """Call sketch_update recursively and collect its subgoals here."""
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        # Removals are emitted before additions.
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        # `v` names the value slot handed to SMapUpdate below (presumably
        # bound to lval[k] while the update code runs).
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set: keys in the old map but not the new one are deleted.
        deleted_keys = syntax.EBinOp(old_keys, "-",
                                     new_keys).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))), target_syntax.SMapDel(lval, k))

        # (2) enter/mod set: keys that are new, or whose value differs, get
        # their value updated by recursing on the per-key old/new values.
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(
                k,
                syntax.EAny([
                    syntax.ENot(syntax.EIn(k, old_keys)),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k),
                                   value_at(new_value, k)))
                ]))).with_type(key_bag)
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, new_or_modified),
                                   syntax.EEq(v, value_at(old_value, k))
                               ],
                               invariants=invariants)
        s2 = syntax.SForEach(
            k,
            make_subgoal(new_or_modified,
                         docstring="new or modified keys from {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
Пример #3
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code). `assumptions` defaults
    to empty.

    The update strategy depends on `lval.type`:
      - bags/sets: remove the deleted elements, then add the new ones;
      - numbers: add a delta;
      - tuples/records: update each component recursively;
      - handles: no code here (dealt with at a higher level);
      - maps: delete removed keys, update altered common keys, then insert
        new keys;
      - anything else: recompute the whole value from scratch.
    """
    # Use a fresh list rather than a shared mutable default argument.
    if assumptions is None:
        assumptions = []

    # If the value provably does not change, no update code is needed.
    if valid(
            syntax.EImplies(syntax.EAll(assumptions),
                            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []

    def make_subgoal(e, a=None, docstring=None):
        """
        Return an expression that computes `e`, with EStateVar wrappers
        stripped first.

        Normally this creates a fresh internal query for `e`, whose
        assumptions are `assumptions + a`. It records the query in
        `subgoals` and returns a call to it. The call's arguments are the
        query's free variables that are not in `ctx`.

        If `skip_stateless_synthesis` is set and `e` mentions no `ctx`
        variable, the stripped `e` is returned and no query is created.
        """
        if a is None:
            a = []
        e = strip_EStateVar(e)
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + a, e, docstring)
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        """Call sketch_update recursively and collect its subgoals here."""
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        # Removals are emitted before additions.
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t):
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.THandle):
        # handles are tricky, and are dealt with at a higher level
        stm = syntax.SNoOp()
    elif isinstance(t, syntax.TMap):
        def value_at(m, k):
            """Plain lookup m[k], typed with lval's map value type."""
            return target_syntax.EMapGet(m, k).with_type(lval.type.v)

        k = fresh_var(lval.type.k)
        # `v` names the value slot handed to SMapUpdate below (presumably
        # bound to lval[k] while the update code runs).
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set: keys in the old map but not the new one are deleted.
        deleted_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, new_keys)))).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) modify set: keys present in both maps whose value changed are
        # updated in place by recursing on the per-key old/new values.
        common_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k,
                                  syntax.EIn(k,
                                             new_keys))).with_type(key_bag)
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, common_keys),
                                   syntax.ENot(
                                       syntax.EEq(value_at(old_value, k),
                                                  value_at(new_value, k)))
                               ])
        altered_keys = target_syntax.EFilter(
            common_keys,
            target_syntax.ELambda(
                k,
                syntax.ENot(
                    syntax.EEq(value_at(old_value, k),
                               value_at(new_value,
                                        k))))).with_type(key_bag)
        s2 = syntax.SForEach(
            k,
            make_subgoal(altered_keys,
                         docstring="altered keys in {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        # (3) enter set: keys only in the new map get their new value
        # inserted.
        fresh_keys = target_syntax.EFilter(
            new_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, old_keys)))).with_type(key_bag)
        s3 = syntax.SForEach(
            k,
            make_subgoal(fresh_keys,
                         docstring="new keys in {}".format(pprint(lval))),
            target_syntax.SMapPut(
                lval, k,
                make_subgoal(value_at(new_value, k),
                             a=[syntax.EIn(k, fresh_keys)],
                             docstring="new value inserted at {}".format(
                                 pprint(target_syntax.EMapGet(lval, k))))))

        stm = syntax.seq([s1, s2, s3])
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)