示例#1
0
def to_abstract(t):
    """Return the abstract collection type corresponding to `t`.

    If `t` is an instance of `syntax.TBag`, `syntax.TSet` or `syntax.TList`
    (checked in that order, so subclasses match too), a new instance of that
    base class is built with the same element type `t.t`. Any other type is
    returned unchanged.
    """
    for abstract_cls in (syntax.TBag, syntax.TSet, syntax.TList):
        if isinstance(t, abstract_cls):
            return abstract_cls(t.t)
    return t
示例#2
0
 def visit_TApp(self, t):
     """Resolve a type application by its constructor name.

     "Set", "Bag" and "List" wrap the visited `t.args` in `syntax.TSet`,
     `syntax.TBag` or `syntax.TList` respectively. Any other name is reported
     as an unknown type and `t` itself is returned.
     """
     known_constructors = (
         ("Set", syntax.TSet),
         ("Bag", syntax.TBag),
         ("List", syntax.TList),
     )
     for type_name, make_type in known_constructors:
         if t.t == type_name:
             return make_type(self.visit(t.args))
     self.report_err(t, "unknown type {}".format(t.t))
     return t
示例#3
0
 def lub(self, src, t1, t2, explanation):
     """Return the least upper bound of the types `t1` and `t2`.

     Args:
         src: node that an error is reported against if unification fails.
         t1, t2: the types to unify.
         explanation: extra context appended to the error message.

     Returns:
         `t1` when the types are equal; the result of `numeric_lub` when both
         are numeric; a `syntax.TList` when both are lists with the same
         element type; a `syntax.TBag` when both are collections with the
         same element type. Otherwise an error is reported on `src` and
         DEFAULT_TYPE is returned.
     """
     if t1 == t2:
         return t1
     if is_numeric(t1) and is_numeric(t2):
         return self.numeric_lub(src, t1, t2)
     # Collections only unify when their element types agree; building the
     # result from `t1.t` alone would silently discard `t2`'s element type.
     if (isinstance(t1, syntax.TList) and isinstance(t2, syntax.TList)
             and t1.t == t2.t):
         return syntax.TList(t1.t)
     if is_collection(t1) and is_collection(t2) and t1.t == t2.t:
         return syntax.TBag(t1.t)
     self.report_err(
         src,
         "cannot unify types {} and {} ({})".format(pprint(t1), pprint(t2),
                                                    explanation))
     return DEFAULT_TYPE
示例#4
0
 def visit_EMap(self, e):
     """Type-check a map expression and set `e.type`.

     The source collection `e.e` is visited first; its element type becomes
     the type of the lambda argument `e.f.arg` before the lambda `e.f` is
     visited. The resulting type is:
       - a bag of the lambda body's type when the source is a set;
       - the source's abstract collection class (via `to_abstract`) of the
         lambda body's type for any other collection;
       - DEFAULT_TYPE when the source type is DEFAULT_TYPE, or when it is not
         a collection (in which case an error is also reported).
     """
     self.visit(e.e)
     e.f.arg.type = self.get_collection_type(e.e)
     self.visit(e.f)
     source_type = e.e.type

     if isinstance(source_type, syntax.TSet):
         # Sets might not have distinct elements after the map transform.
         # Consider e.g. `map {\x -> 1} my_set`.
         e.type = syntax.TBag(e.f.body.type)
         return

     if is_collection(source_type):
         result_cls = type(to_abstract(source_type))
         e.type = result_cls(e.f.body.type)
         return

     if source_type is not DEFAULT_TYPE:
         self.report_err(e, "cannot map over non-collection {}".format(pprint(source_type)))
     e.type = DEFAULT_TYPE
示例#5
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = None,
        invariants: [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    The shape of the update is chosen from `lval.type`:
      - bag or set: remove each element of `old_value - new_value`, then add
        each element of `new_value - old_value`;
      - numeric, when `update_numbers_with_deltas` is enabled: add the delta
        `new_value - old_value` to `lval`;
      - tuple or record: update each component/field recursively;
      - map: delete keys that disappear, then recursively update the value at
        every key that is new or whose value changed;
      - anything else: assign a freshly computed `new_value`.

    Args:
        lval: the expression to update.
        old_value: the value of `lval` before the change.
        new_value: the value of `lval` after the change.
        ctx: abstract-state variables; they are never turned into query
            arguments.
        assumptions: facts added to every generated query (default: none).
        invariants: facts that, together with `assumptions`, are used only to
            decide whether the update is a no-op; they are forwarded to
            recursive calls but not added to queries (default: none).

    Returns:
        A pair `(stm, subgoals)`: the update statement and the list of
        `syntax.Query` objects it calls.
    """
    # `None` defaults avoid sharing a single mutable list across calls.
    assumptions = [] if assumptions is None else list(assumptions)
    invariants = [] if invariants is None else list(invariants)

    if valid(
            syntax.EImplies(
                syntax.EAll(itertools.chain(assumptions, invariants)),
                syntax.EEq(old_value, new_value))):
        # The value provably does not change: nothing to generate.
        return (syntax.SNoOp(), [])

    subgoals = []
    # Only `new_value` is passed through strip_EStateVar here; `old_value`
    # is used as given.
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=None, docstring=None):
        # Replace `e` with a call to a fresh internal query (recorded in
        # `subgoals`) whose arguments are its free variables outside `ctx`.
        # When `skip_stateless_synthesis` is on and `e` mentions no `ctx`
        # variable, `e` is returned inline instead.
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        extra_assumptions = [] if a is None else list(a)
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + extra_assumptions, e, docstring)
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Sketch a nested update and collect its subgoals into ours.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, (syntax.TBag, syntax.TSet)):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.TMap):
        # NOTE(review): `value_at` is not defined in this function; presumably
        # a module-level helper returning the map value at a key — confirm.
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set: keys present before but not after are deleted.
        deleted_keys = syntax.EBinOp(old_keys, "-",
                                     new_keys).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))), target_syntax.SMapDel(lval, k))

        # (2) enter/mod set: keys that are new, or whose value changed, get
        # their value `v` updated by a recursively sketched statement.
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(
                k,
                syntax.EAny([
                    syntax.ENot(syntax.EIn(k, old_keys)),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k),
                                   value_at(new_value, k)))
                ]))).with_type(key_bag)
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, new_or_modified),
                                   syntax.EEq(v, value_at(old_value, k))
                               ],
                               invariants=invariants)
        s2 = syntax.SForEach(
            k,
            make_subgoal(new_or_modified,
                         docstring="new or modified keys from {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
示例#6
0
 def visit_EMapKeys(self, e):
     """Type-check a map-keys expression: its type is a bag of the map's key type."""
     self.visit(e.e)
     map_types = self.get_map_type(e.e)
     key_type, _value_type = map_types
     e.type = syntax.TBag(key_type)
示例#7
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    The shape of the update is chosen from `lval.type`:
      - bag or set: remove each element of `old_value - new_value`, then add
        each element of `new_value - old_value`;
      - numeric: add the delta `new_value - old_value` to `lval`;
      - tuple or record: update each component/field recursively;
      - handle: no code here (handles are dealt with at a higher level);
      - map: delete vanished keys, recursively update keys whose value
        changed, then insert values for brand-new keys;
      - anything else: assign a freshly computed `new_value`.

    Args:
        lval: the expression to update.
        old_value: the value of `lval` before the change.
        new_value: the value of `lval` after the change.
        ctx: abstract-state variables; they are never turned into query
            arguments.
        assumptions: facts used to decide whether the update is a no-op and
            added to every generated query (default: none).

    Returns:
        A pair `(stm, subgoals)`: the update statement and the list of
        `syntax.Query` objects it calls.
    """
    # `None` default avoids sharing a single mutable list across calls.
    assumptions = [] if assumptions is None else list(assumptions)

    if valid(
            syntax.EImplies(syntax.EAll(assumptions),
                            syntax.EEq(old_value, new_value))):
        # The value provably does not change: nothing to generate.
        return (syntax.SNoOp(), [])

    subgoals = []

    def make_subgoal(e, a=None, docstring=None):
        # Replace `e` (after strip_EStateVar) with a call to a fresh internal
        # query, recorded in `subgoals`, whose arguments are its free
        # variables outside `ctx`. When `skip_stateless_synthesis` is on and
        # `e` mentions no `ctx` variable, `e` is returned inline instead.
        e = strip_EStateVar(e)
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        extra_assumptions = [] if a is None else list(a)
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + extra_assumptions, e, docstring)
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Sketch a nested update and collect its subgoals into ours.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, (syntax.TBag, syntax.TSet)):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t):
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx,
                    assumptions) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.THandle):
        # handles are tricky, and are dealt with at a higher level
        stm = syntax.SNoOp()
    elif isinstance(t, syntax.TMap):
        def value_at(m, key):
            # The value stored at `key` in the map expression `m`.
            return target_syntax.EMapGet(m, key).with_type(t.v)

        k = fresh_var(t.k)
        v = fresh_var(t.v)
        key_bag = syntax.TBag(t.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set: keys present before but not after are deleted.
        deleted_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, new_keys)))).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) modify set: keys present before and after whose value changed
        # get their value `v` updated by a recursively sketched statement.
        common_keys = target_syntax.EFilter(
            old_keys,
            target_syntax.ELambda(k,
                                  syntax.EIn(k,
                                             new_keys))).with_type(key_bag)
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, common_keys),
                                   syntax.ENot(
                                       syntax.EEq(value_at(old_value, k),
                                                  value_at(new_value, k)))
                               ])
        altered_keys = target_syntax.EFilter(
            common_keys,
            target_syntax.ELambda(
                k,
                syntax.ENot(
                    syntax.EEq(value_at(old_value, k),
                               value_at(new_value,
                                        k))))).with_type(key_bag)
        s2 = syntax.SForEach(
            k,
            make_subgoal(altered_keys,
                         docstring="altered keys in {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        # (3) enter set: keys present only after the change are inserted
        # with their new value.
        fresh_keys = target_syntax.EFilter(
            new_keys,
            target_syntax.ELambda(k, syntax.ENot(syntax.EIn(
                k, old_keys)))).with_type(key_bag)
        s3 = syntax.SForEach(
            k,
            make_subgoal(fresh_keys,
                         docstring="new keys in {}".format(pprint(lval))),
            target_syntax.SMapPut(
                lval, k,
                make_subgoal(value_at(new_value, k),
                             a=[syntax.EIn(k, fresh_keys)],
                             docstring="new value inserted at {}".format(
                                 pprint(target_syntax.EMapGet(lval, k))))))

        stm = syntax.seq([s1, s2, s3])
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)