Exemplo n.º 1
0
    def visit_EUnaryOp(self, e):
        """Lower a unary-operator expression.

        Sum and Empty are lowered by appending a loop to `self.stms` and
        returning the variable that holds the result.  Length, All, Any and
        Exists are rewritten in terms of Sum/Empty and re-visited.  Distinct
        and The are delegated to `self.visit_iterable` / `self.find_one`.
        Any other operator falls back to `self.visit_Exp`.
        """
        op = e.op
        if op == UOp.Distinct:
            return self.visit_iterable(e)
        elif op == UOp.The:
            return self.find_one(e.e)
        elif op == UOp.Sum:
            sum_var = fresh_var(e.type, "sum")
            loop_var = fresh_var(e.e.type.elem_type, "x")
            self.stms.append(simplify_and_optimize(seq([
                SDecl(sum_var, ENum(0).with_type(e.type)),
                SForEach(loop_var, e.e,
                    # The running total has the type of the whole sum (the
                    # same type `sum_var` is declared with), not always INT.
                    SAssign(sum_var, EBinOp(sum_var, "+", loop_var).with_type(e.type)))])))
            return sum_var
        elif op == UOp.Length:
            # len(xs) == sum(map(lambda x: 1, xs))
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Sum, EMap(e.e, ELambda(arg, ONE)).with_type(INT_BAG)).with_type(INT))
        elif op == UOp.All:
            # all(xs) == empty(filter(lambda x: not x, xs)).  Filtering keeps
            # the collection type of `xs` (its elements are booleans), and
            # Empty yields a BOOL.
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Empty, EFilter(e.e, ELambda(arg, ENot(arg))).with_type(e.e.type)).with_type(BOOL))
        elif op == UOp.Any:
            # any(xs) == exists(filter(lambda x: x, xs)); same typing as All.
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Exists, EFilter(e.e, ELambda(arg, arg)).with_type(e.e.type)).with_type(BOOL))
        elif op == UOp.Empty:
            # v starts true; the first element seen sets it to false and
            # escapes the loop early.
            iterable = e.e
            v = fresh_var(BOOL, "v")
            label = fresh_name("label")
            x = fresh_var(iterable.type.elem_type, "x")
            decl = SDecl(v, ETRUE)
            find = SEscapableBlock(label,
                SForEach(x, iterable, seq([
                    SAssign(v, EFALSE),
                    SEscapeBlock(label)])))
            self.stms.append(simplify_and_optimize(seq([decl, find])))
            return v
        elif op == UOp.Exists:
            return self.visit(ENot(EUnaryOp(UOp.Empty, e.e).with_type(BOOL)))
        # NOTE: disabled implementation of UOp.AreUnique; such expressions
        # currently fall through to `self.visit_Exp` below.
        # elif op == UOp.AreUnique:
        #     s = fresh_var(TSet(e.e.type.elem_type), "unique_elems")
        #     u = fresh_var(BOOL, "is_unique")
        #     x = fresh_var(e.e.type.elem_type)
        #     label = fresh_name("label")
        #     self.visit(seq([
        #         SDecl(s, EEmptyList().with_type(s.type)),
        #         SDecl(u, ETRUE),
        #         SEscapableBlock(label,
        #             SForEach(x, e.e,
        #                 SIf(EEscape("{s}.find({x}) != {s}.end()", ("s", "x"), (s, x)).with_type(BOOL),
        #                     seq([SAssign(u, EFALSE), SEscapeBlock(label)]),
        #                     SEscape("{indent}{s}.insert({x});\n", ("s", "x"), (s, x)))))]))
        #     return u.id

        return self.visit_Exp(e)
Exemplo n.º 2
0
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Replace every "bad" expression in `s` with a fresh temporary.

    Returns a new statement in which each expression for which `exp_is_bad`
    returns True has been replaced by a fresh variable.  The declarations of
    those variables are appended to `decls_out`, so that running `decls_out`
    followed by the returned statement is intended to behave like `s`.

    Inside a loop, temporaries that depend on the loop variable (directly, or
    through other such temporaries) are replaced by lookups into maps that
    are precomputed over the loop's collection.

    Raises NotImplementedError for statement kinds not handled below.
    """
    def pull(e: Exp) -> Exp:
        # If `e` is bad, hoist it into a fresh temporary declared in
        # `decls_out` and return that variable; otherwise return `e` as-is.
        if exp_is_bad(e):
            v = fresh_var(e.type)
            decls_out.append(SDecl(v.id, e))
            return v
        return e

    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        s1 = pull_temps(s.s1, decls_out, exp_is_bad)
        s2 = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(s1, s2)
    if isinstance(s, SIf):
        cond = pull(s.cond)
        s1 = pull_temps(s.then_branch, decls_out, exp_is_bad)
        s2 = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(cond, s1, s2)
    if isinstance(s, SForEach):
        bag = pull(s.iter)
        d_tmp = []
        body = pull_temps(s.body, d_tmp, exp_is_bad)
        # A temporary pulled out of the body must be recomputed per iteration
        # if it reads the loop variable -- directly or through another such
        # temporary (e.g. the map built for a nested loop whose collection was
        # itself pulled out).  Everything else is hoisted above the loop.
        loop_dependent = [s.id]
        to_fix = []
        ok = []
        for d in d_tmp:
            fvs = free_vars(d.val)
            if any(dep in fvs for dep in loop_dependent):
                to_fix.append(d)
                loop_dependent.append(EVar(d.id).with_type(d.val.type))
            else:
                ok.append(d)
        decls_out.extend(ok)
        # Precompute each loop-dependent temporary for every element of the
        # bag and replace it with a map lookup keyed on the loop variable.
        # Earlier replacements are substituted into later definitions so no
        # map refers to a temporary that no longer exists.
        remap = {}
        for d in to_fix:
            v = EVar(d.id).with_type(d.val.type)
            val = subst(d.val, remap) if remap else d.val
            mt = TMap(s.id.type, v.type)
            m = EMakeMap2(bag, ELambda(s.id, val)).with_type(mt)
            mv = fresh_var(m.type)
            decls_out.append(SDecl(mv.id, m))
            remap[v.id] = EMapGet(mv, s.id).with_type(v.type)
        if remap:
            body = subst(body, remap)
        return SForEach(s.id, bag, body)
    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))
    if isinstance(s, SCall):
        return SCall(s.target, s.func, tuple(pull(arg) for arg in s.args))
    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))
    if isinstance(s, SMapPut):
        return SMapPut(s.map, pull(s.key), pull(s.value))
    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        d_tmp = []
        change = pull_temps(s.change, d_tmp, exp_is_bad)
        for d in d_tmp:
            # Temporaries that read the update's value variable are rewritten
            # to read the map entry directly, since they now run outside it.
            if s.val_var in free_vars(d.val):
                decls_out.append(
                    SDecl(
                        d.id,
                        subst(
                            d.val, {
                                s.val_var.id:
                                EMapGet(s.map, key).with_type(s.val_var.type)
                            })))
            else:
                decls_out.append(d)
        return SMapUpdate(s.map, key, s.val_var, change)
    raise NotImplementedError(s)
Exemplo n.º 3
0
    def _setup_handle_updates(self):
        """
        This method creates update code for handle objects modified by each op.
        Must be called once after all user-specified queries have been added.

        For every op in `self.op_specs` and every handle type `t` returned by
        `reachable_handles_at_method`, a statement updating the `val` field of
        each modified handle is stored in `self.handle_updates[(t, op.name)]`
        (an `SNoOp()` when nothing needs updating).  Prints one progress line
        per op.
        """
        for op in self.op_specs:
            print("Setting up handle updates for {}...".format(op.name))
            # Maps each handle type `t` to a `bag` expression of handles of
            # that type that `op` can reach.
            handles = reachable_handles_at_method(self.spec, op)
            # print("-"*60)
            for t, bag in handles.items():
                # print("  {} : {}".format(pprint(t), pprint(bag)))
                # `h` stands for one handle of type `t`; `lval` is h.val.
                h = fresh_var(t)
                lval = EGetField(h, "val").with_type(t.value_type)
                # Presumably the value of h.val after running op.body, as
                # computed by inc.mutate.
                new_val = inc.mutate(lval, op.body)

                # get set of modified handles
                # (an internal query: the distinct handles in `bag` whose val
                # differs from `new_val`)
                modified_handles = Query(
                    fresh_name("modified_handles"), Visibility.Internal, [],
                    op.assumptions,
                    EFilter(
                        EUnaryOp(UOp.Distinct, bag).with_type(bag.type),
                        ELambda(h, ENot(EEq(lval,
                                            new_val)))).with_type(bag.type),
                    "[{}] modified handles of type {}".format(
                        op.name, pprint(t)))
                # Free variables that are not abstract state become the
                # query's parameters.
                query_vars = [
                    v for v in free_vars(modified_handles)
                    if v not in self.abstract_state
                ]
                modified_handles.args = [(arg.id, arg.type)
                                         for arg in query_vars]

                # modify each one
                # Build an in-place update of h.val, assuming h is in `bag`
                # and in the modified-handles result.  Helper queries that the
                # update needs are collected into `subqueries`.
                subqueries = []
                state_update_stm = inc.mutate_in_place(
                    lval,
                    lval,
                    op.body,
                    abstract_state=self.abstract_state,
                    assumptions=list(op.assumptions) +
                    [EDeepIn(h, bag),
                     EIn(h, modified_handles.ret)],
                    invariants=self.abstract_invariants,
                    subgoals_out=subqueries)
                # Tag each helper query with the op name and register it; the
                # update statement is replaced by what _add_subquery returns.
                for sub_q in subqueries:
                    sub_q.docstring = "[{}] {}".format(op.name,
                                                       sub_q.docstring)
                    state_update_stm = self._add_subquery(
                        sub_q=sub_q, used_by=state_update_stm)
                # Only when there is real work: loop over the result of calling
                # the modified-handles query with `query_vars`, and register
                # that query as well.
                if state_update_stm != SNoOp():
                    state_update_stm = SForEach(
                        h,
                        ECall(modified_handles.name,
                              query_vars).with_type(bag.type),
                        state_update_stm)
                    state_update_stm = self._add_subquery(
                        sub_q=modified_handles, used_by=state_update_stm)
                self.handle_updates[(t, op.name)] = state_update_stm
Exemplo n.º 4
0
def stream(iterable : Exp, loop_var : EVar, body : Stm) -> Stm:
    """Generate loop code for `iterable` without building needless collections.

    Input:
      iterable - a not-yet-optimized expression of iterable type (Bag, Set,
        or List)
      loop_var - the variable each element is bound to
      body - a not-yet-optimized statement to run for each element

    Output:
      A statement equivalent to
        for (loop_var in iterable) { body }
      in which as many intermediate collections and objects as possible have
      been fused away.

    NOTE: the result is only correct if `body` neither modifies a free
    variable of `iterable` nor writes through a pointer that `iterable` reads.

    Motivation: naively generating code for

        Map {func} (Filter {predicate} big_collection)

    may allocate two new collections as large as `big_collection`, one for the
    filter's result and one for the map's.  When the caller only iterates over
    the result, neither collection is needed.

    `stream` and `simplify_and_optimize` call each other, so the output here
    also receives every transformation that `simplify_and_optimize` performs.
    """
    # Nothing to iterate over.
    if isinstance(iterable, EEmptyList):
        return SNoOp()

    # Exactly one element: bind it and run the body once.
    if isinstance(iterable, ESingleton):
        elem_setup, elem = simplify_and_optimize_expression(iterable.e)
        # The body is wrapped in SScoped: for e.g. [x] + [y] it can end up in
        # the same block twice, and duplicate declarations break languages
        # such as Java or C++.
        once = SScoped(re_use(elem, loop_var, simplify_and_optimize(body)))
        return seq([elem_setup, once])

    # Conditional collection: branch once, then stream each side.
    if isinstance(iterable, ECond):
        guard_setup, guard = simplify_and_optimize_expression(iterable.cond)
        then_stm = stream(iterable.then_branch, loop_var, body)
        else_stm = stream(iterable.else_branch, loop_var, body)
        return seq([guard_setup, SIf(guard, then_stm, else_stm)])

    # Distinct: remember what has been seen and skip repeats.
    if isinstance(iterable, EUnaryOp) and iterable.op == UOp.Distinct:
        seen = fresh_var(TSet(iterable.type.elem_type), "distinct_elems")
        unseen = ENot(EBinOp(loop_var, BOp.In, seen).with_type(BOOL))
        first_time_only = SIf(
            unseen,
            seq([body, SCall(seen, "add", [loop_var])]),
            SNoOp())
        return seq([
            SDecl(seen, EEmptyList().with_type(seen.type)),
            stream(iterable.e, loop_var, first_time_only)])

    # Concatenation: stream the left operand, then the right one.
    if isinstance(iterable, EBinOp) and iterable.op == "+":
        left_stm = stream(iterable.e1, loop_var, body)
        right_stm = stream(iterable.e2, loop_var, body)
        return seq([left_stm, right_stm])

    # Bag difference: drop one occurrence from the left per occurrence on the
    # right.
    if isinstance(iterable, EBinOp) and iterable.op == "-":
        if is_hashable(iterable.type.elem_type):
            # Count the right-hand elements and consume those counts.
            hist_setup, hist = histogram(iterable.e2)
            count = fresh_var(INT, "count")
            setup_stm = simplify_and_optimize(hist_setup)
            consume_one = SMapUpdate(
                hist, loop_var, count,
                SAssign(count, EBinOp(count, "-", ONE).with_type(INT)))
            consume_or_run = SIf(
                EGt(EMapGet(hist, loop_var).with_type(INT), ZERO),
                consume_one,
                body)
            return seq([setup_stm, stream(iterable.e1, loop_var, consume_or_run)])
        # Unhashable elements: copy the right-hand side and remove matches.
        leftovers = fresh_var(iterable.e2.type, "bag_subtraction_right")
        setup_stm = simplify_and_optimize(SDecl(leftovers, iterable.e2))
        remove_or_run = SIf(
            EIn(loop_var, leftovers),
            SCall(leftovers, "remove", (loop_var,)),
            body)
        return seq([setup_stm, stream(iterable.e1, loop_var, remove_or_run)])

    # Filter == FlatMap that yields [x] when the predicate holds, else [].
    if isinstance(iterable, EFilter):
        pred = iterable.predicate
        coll_type = iterable.type
        keep_or_drop = ECond(
            pred.body,
            ESingleton(pred.arg).with_type(coll_type),
            EEmptyList().with_type(coll_type)).with_type(coll_type)
        as_flatmap = EFlatMap(
            iterable.e, ELambda(pred.arg, keep_or_drop)).with_type(coll_type)
        return stream(as_flatmap, loop_var, body)

    # Map == FlatMap that always yields [f(x)].
    if isinstance(iterable, EMap):
        fn = iterable.transform_function
        as_flatmap = EFlatMap(
            iterable.e,
            ELambda(fn.arg, ESingleton(fn.body).with_type(iterable.type))).with_type(iterable.type)
        return stream(as_flatmap, loop_var, body)

    # FlatMap: a nested loop -- outer over the source, inner over f(outer).
    if isinstance(iterable, EFlatMap):
        fn = iterable.transform_function
        outer_var = fresh_var(fn.arg.type, fn.arg.id)
        inner_stm = stream(fn.apply_to(outer_var), loop_var, body)
        return stream(iterable.e, outer_var, inner_stm)

    # List slice: walk the clamped index range of the list with a while loop.
    if isinstance(iterable, EListSlice):
        elem_type = iterable.type.elem_type
        lst = fresh_var(iterable.e.type, "list")
        idx = fresh_var(INT, "start")
        stop = fresh_var(INT, "end")
        return simplify_and_optimize(seq([
            SDecl(lst, iterable.e),
            SDecl(idx, max_of(iterable.start, ZERO)),
            SDecl(stop, min_of(iterable.end, ELen(lst))),
            SWhile(ELt(idx, stop), seq([
                SDecl(loop_var, EListGet(lst, idx).with_type(elem_type)),
                body,
                SAssign(idx, EBinOp(idx, "+", ONE).with_type(INT))]))]))

    # Let-binding: declare the bound value, then stream the body expression.
    if isinstance(iterable, ELet):
        fn = iterable.body_function
        bound = fresh_var(fn.arg.type, fn.arg.id)
        bind_stm = simplify_and_optimize(SDecl(bound, iterable.e))
        return seq([bind_stm, stream(fn.apply_to(bound), loop_var, body)])

    # A move does not affect iteration order or contents.
    if isinstance(iterable, EMove):
        return stream(iterable.e, loop_var, body)

    # Fallback: materialize the collection and loop over it directly.
    assert is_collection(iterable.type), repr(iterable)
    coll_setup, coll = simplify_and_optimize_expression(iterable)
    return seq([coll_setup, SForEach(loop_var, coll, simplify_and_optimize(body))])
Exemplo n.º 5
0
def possibly_useful_nonrecursive(
    solver,
    e: Exp,
    context: Context,
    pool=RUNTIME_POOL,
    assumptions: Exp = ETRUE,
    ops: [Op] = ()) -> bool:
    """Heuristic filter to ignore expressions that are almost certainly useless.

    Parameters:
      solver      - answers the `satisfiable` / `valid` queries used by the
                    more expensive checks below
      e           - the candidate expression; the checks look at `e` and, in
                    a few cases, its immediate operands
      context     - supplies the variables in scope (with their pools) and
                    the path condition
      pool        - the pool `e` would live in; many checks apply only when it
                    is not RUNTIME_POOL ("state position")
      assumptions - extra facts, conjoined with the context's path condition
      ops         - operations whose bodies are used to decide whether an
                    arg-min/arg-max in state position can be maintained

    Returns True if no heuristic rejects `e`.  Otherwise returns the rejecting
    value: a `No(reason)` from one of the checks below, or the falsy result
    of the extension handler registered for `type(e)`.
    """
    state_vars = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    assumptions = EAll([assumptions, context.path_condition()])
    at_runtime = pool == RUNTIME_POOL

    # Extension types are consulted first; only a negative answer is final.
    h = extension_handler(type(e))
    if h is not None:
        res = h.possibly_useful(e, context, pool, assumptions, ops, solver)
        if not res:
            return res

    if isinstance(e, EStateVar) and not free_vars(e.e):
        return No("constant value in state position")
    if isinstance(e, (EDropFront, EDropBack)) and not at_runtime:
        return No("EDrop* in state position")
    if not allow_big_sets.value and isinstance(e, EFlatMap) and not at_runtime:
        return No("EFlatMap in state position")
    if not allow_int_arithmetic_state.value and not at_runtime and isinstance(
            e, EBinOp) and e.type == INT:
        return No("integer arithmetic in state position")
    if is_collection(e.type) and not is_scalar(e.type.elem_type):
        return No("collection of nonscalar: e {}\n elem_type: {}\n".format(
            e, e.type.elem_type))
    if isinstance(e.type, TMap) and not is_scalar(e.type.k):
        return No("bad key type {}".format(pprint(e.type.k)))
    if isinstance(e.type, TMap) and isinstance(e.type.v, TMap):
        return No("map to map")
    # This check is probably a bad idea: whether `the` is legal may depend on
    # the context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    # if isinstance(e, EUnaryOp) and e.op == UOp.The:
    #     len = EUnaryOp(UOp.Length, e.e).with_type(INT)
    #     if not valid(EImplies(assumptions, EBinOp(len, "<=", ENum(1).with_type(INT)).with_type(BOOL))):
    #         return No("illegal application of 'the': could have >1 elems")
    if not at_runtime and isinstance(
            e, EBinOp) and e.op == "-" and is_collection(e.type):
        return No("collection subtraction in state position")
    # if not at_runtime and isinstance(e, ESingleton):
    #     return No("singleton in state position")
    if not allow_nonzero_state_constants.value and not at_runtime and isinstance(
            e, ENum) and e.val != 0:
        return No("nonzero integer constant in state position")
    if not allow_binop_state.value and at_runtime and isinstance(
            e, EStateVar) and isinstance(e.e, EBinOp) and is_scalar(
                e.e.e1.type) and is_scalar(e.e.e2.type):
        return No(
            "constant-time binary operator {!r} in state position".format(
                e.e.op))
    if not allow_conditional_state.value and not at_runtime and isinstance(
            e, ECond):
        return No("conditional in state position")
    if isinstance(e, EMakeMap2) and isinstance(e.e, EEmptyList):
        return No("trivially empty map")
    if isinstance(e, EMakeMap2) and isinstance(e.e, ESingleton):
        return No("really tiny map")
    if not at_runtime and isinstance(e, (EArgMin, EArgMax)):
        # Cozy has no way to efficiently implement mins/maxes when more than
        # one element may leave the collection.
        from cozy.state_maintenance import mutate
        elems = e.e
        for op in ops:
            elems_prime = mutate(elems, op.body)
            # Satisfiable iff `op` can remove two or more elements at once.
            formula = EAll([assumptions] + list(op.assumptions) + [
                EGt(
                    ELen(
                        EBinOp(elems, "-", elems_prime).with_type(elems.type)),
                    ONE)
            ])
            if solver.satisfiable(formula):
                return No(
                    "more than one element might be removed during {}".format(
                        op.name))
    if not allow_peels.value and not at_runtime and isinstance(e, EFilter):
        # catch "peels": removal of zero or one elements
        if solver.valid(
                EImplies(
                    assumptions,
                    ELe(
                        ELen(
                            EFilter(
                                e.e,
                                ELambda(e.predicate.arg, ENot(
                                    e.predicate.body))).with_type(e.type)),
                        ONE))):
            return No("filter is a peel")
    if not allow_big_maps.value and not at_runtime and isinstance(
            e, EMakeMap2) and is_collection(e.type.v):
        # Reject maps whose total value size may exceed the combined size of
        # all collection-typed state variables.
        all_collections = [sv for sv in state_vars if is_collection(sv.type)]
        total_size = ENum(0).with_type(INT)
        for c in all_collections:
            total_size = EBinOp(total_size, "+",
                                EUnaryOp(UOp.Length,
                                         c).with_type(INT)).with_type(INT)
        my_size = EUnaryOp(
            UOp.Length,
            EFlatMap(
                EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type),
                e.value_function).with_type(e.type.v)).with_type(INT)
        s = EImplies(assumptions,
                     EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            return No("non-polynomial-sized map")

    return True
Exemplo n.º 6
0
 def build_lambdas(bag, pool, body_size):
     """Yield lambdas over the elements of `bag` with enumerated bodies.

     The lambda argument is a fresh variable of the bag's element type that
     avoids every variable already in `context`.  Each body is produced by
     `self.enumerate` (with size `body_size`, in `pool`) under an
     `UnderBinder` context that binds that argument to `bag`.
     """
     taken = {var for var, _ in context.vars()}
     arg = fresh_var(bag.type.elem_type, omit=taken)
     inner_context = UnderBinder(context, v=arg, bag=bag, bag_pool=pool)
     for lam_body in self.enumerate(inner_context, body_size, pool):
         yield ELambda(arg, lam_body)
Exemplo n.º 7
0
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Remove "bad" expressions from `s`.

    This procedure returns a statement new_s that replaces every expression in
    `s` where `exp_is_bad` returns True with a fresh variable.  After running,
    `decls_out` contains definitions for the fresh variables so that the whole
    statement

        decls_out; new_s

    should return the same result as `s`.

    Inside a loop, temporaries that depend on the loop variable (directly, or
    through other such temporaries) are replaced by lookups into maps that are
    precomputed over the loop's collection.

    NOTE: every declaration in `decls_out` runs before all of `new_s`.  A bad
    expression that reads a variable declared or assigned earlier inside `s`
    is therefore evaluated before that declaration/assignment; callers must
    make sure this cannot happen or does not matter.

    Raises NotImplementedError for statement kinds not handled below.
    """
    def pull(e: Exp) -> Exp:
        """Pull an expression into a temporary if it is bad.

        If `exp_is_bad(e)`, creates a fresh variable for `e`, writes its
        declaration into `decls_out`, and returns the fresh variable.
        Otherwise returns `e` unchanged.
        """
        if exp_is_bad(e):
            v = fresh_var(e.type)
            decls_out.append(SDecl(v, e))
            return v
        return e

    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        s1 = pull_temps(s.s1, decls_out, exp_is_bad)
        s2 = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(s1, s2)
    if isinstance(s, SDecl):
        return SDecl(s.var, pull(s.val))
    if isinstance(s, SIf):
        cond = pull(s.cond)
        s1 = pull_temps(s.then_branch, decls_out, exp_is_bad)
        s2 = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(cond, s1, s2)
    if isinstance(s, SForEach):
        bag = pull(s.iter)
        d_tmp = []
        body = pull_temps(s.body, d_tmp, exp_is_bad)
        # A temporary pulled out of the body must be recomputed per iteration
        # if it reads the loop variable -- directly or through another such
        # temporary (e.g. the map built for a nested loop whose collection was
        # itself pulled out).  Everything else is hoisted above the loop.
        loop_dependent = [s.loop_var]
        to_fix = []
        ok = []
        for d in d_tmp:
            fvs = free_vars(d.val)
            if any(dep in fvs for dep in loop_dependent):
                to_fix.append(d)
                loop_dependent.append(d.var)
            else:
                ok.append(d)
        decls_out.extend(ok)
        # Precompute each loop-dependent temporary for every element of the
        # bag and replace it with a map lookup keyed on the loop variable.
        # Earlier replacements are substituted into later definitions so no
        # map refers to a temporary that no longer exists.
        remap = {}
        for d in to_fix:
            v = d.var
            val = subst(d.val, remap) if remap else d.val
            mt = TMap(s.loop_var.type, v.type)
            m = EMakeMap2(bag, ELambda(s.loop_var, val)).with_type(mt)
            mv = fresh_var(m.type)
            decls_out.append(SDecl(mv, m))
            remap[v.id] = EMapGet(mv, s.loop_var).with_type(v.type)
        if remap:
            body = subst(body, remap)
        return SForEach(s.loop_var, bag, body)
    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))
    if isinstance(s, SCall):
        return SCall(s.target, s.func, tuple(pull(arg) for arg in s.args))
    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))
    if isinstance(s, SMapPut):
        return SMapPut(s.map, pull(s.key), pull(s.value))
    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        d_tmp = []
        change = pull_temps(s.change, d_tmp, exp_is_bad)
        for d in d_tmp:
            # Temporaries that read the update's value variable are rewritten
            # to read the map entry directly, since they now run outside it.
            if s.val_var in free_vars(d.val):
                decls_out.append(
                    SDecl(
                        d.var,
                        subst(
                            d.val, {
                                s.val_var.id:
                                EMapGet(s.map, key).with_type(s.val_var.type)
                            })))
            else:
                decls_out.append(d)
        return SMapUpdate(s.map, key, s.val_var, change)
    raise NotImplementedError(s)