Example #1
0
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Hoist "bad" expressions out of a statement into temporaries.

    Input:
      s - the statement to rewrite
      decls_out - output list; a declaration is appended for each temporary
        this function introduces
      exp_is_bad - predicate on expressions; an expression held directly by
        a statement is replaced when this returns a true value for it

    Output:
      A new statement in which each replaced expression is a fresh variable
      (or, for temporaries that depend on a loop variable, a lookup in a
      freshly declared map).

    Raises NotImplementedError for statement types not handled below.
    """

    def pull(e: Exp) -> Exp:
        """Return `e` itself, or a fresh variable declared to hold it.

        Only expressions for which `exp_is_bad` holds are replaced; the
        declaration for the replacement is appended to `decls_out`.
        """
        if not exp_is_bad(e):
            return e
        tmp = fresh_var(e.type)
        decls_out.append(SDecl(tmp.id, e))
        return tmp

    if isinstance(s, SNoOp):
        return s

    if isinstance(s, SSeq):
        first = pull_temps(s.s1, decls_out, exp_is_bad)
        second = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(first, second)

    if isinstance(s, SIf):
        new_cond = pull(s.cond)
        new_then = pull_temps(s.then_branch, decls_out, exp_is_bad)
        new_else = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(new_cond, new_then, new_else)

    if isinstance(s, SForEach):
        bag = pull(s.iter)

        # Collect the body's temporaries separately: the ones that mention
        # the loop variable cannot simply be declared outside the loop.
        body_decls = []
        body = pull_temps(s.body, body_decls, exp_is_bad)

        def mentions_loop_var(decl):
            return s.id in free_vars(decl.val)

        dependent, independent = partition(body_decls, mentions_loop_var)
        decls_out.extend(independent)

        # A loop-dependent temporary becomes a map keyed by the loop
        # variable, and each use in the body becomes a lookup in that map.
        for decl in dependent:
            v = EVar(decl.id).with_type(decl.val.type)
            table_type = TMap(s.id.type, v.type)
            table = EMakeMap2(bag, ELambda(s.id, decl.val)).with_type(table_type)
            table_var = fresh_var(table.type)
            decls_out.append(SDecl(table_var.id, table))
            lookup = EMapGet(table_var, s.id).with_type(v.type)
            body = subst(body, {v.id: lookup})
        return SForEach(s.id, bag, body)

    if isinstance(s, SAssign):
        new_rhs = pull(s.rhs)
        return SAssign(s.lhs, new_rhs)

    if isinstance(s, SCall):
        new_args = tuple([pull(arg) for arg in s.args])
        return SCall(s.target, s.func, new_args)

    if isinstance(s, SMapDel):
        new_key = pull(s.key)
        return SMapDel(s.map, new_key)

    if isinstance(s, SMapPut):
        new_key = pull(s.key)
        new_value = pull(s.value)
        return SMapPut(s.map, new_key, new_value)

    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        change_decls = []
        change = pull_temps(s.change, change_decls, exp_is_bad)
        for decl in change_decls:
            if s.val_var not in free_vars(decl.val):
                decls_out.append(decl)
                continue
            # The temporary reads the value variable of the update, so its
            # declaration is rewritten to read the map entry directly.
            current_entry = EMapGet(s.map, key).with_type(s.val_var.type)
            rewritten = subst(decl.val, {s.val_var.id: current_entry})
            decls_out.append(SDecl(decl.id, rewritten))
        return SMapUpdate(s.map, key, s.val_var, change)

    raise NotImplementedError(s)
Example #2
0
def simplify_and_optimize_lvalue(e : Exp) -> (Stm, Exp):
    """Helper for simplify_and_optimize that handles L-values.

    Input:
      e - an L-value expression to optimize

    Output:
      A pair (s, e') such that executing s and then evaluating e' is the same
      as evaluating the input expression.

    In contrast to `simplify_and_optimize_expression`, the result keeps the
    meaning of `e` as an L-value; in particular `e` is never swapped for a
    fresh variable.

    Raises NotImplementedError for L-value forms not handled below.
    """
    assert is_lvalue(e), "not an L-value: {}".format(pprint(e))

    if isinstance(e, EVar):
        return (SNoOp(), e)

    if isinstance(e, EGetField):
        # A handle-typed base is simplified as an ordinary expression; any
        # other base has to remain an L-value itself.
        if isinstance(e.e.type, THandle):
            simplify_base = simplify_and_optimize_expression
        else:
            simplify_base = simplify_and_optimize_lvalue
        setup, base = simplify_base(e.e)
        return (setup, EGetField(base, e.field_name).with_type(e.type))

    if isinstance(e, EMapGet):
        map_setup, map_lvalue = simplify_and_optimize_lvalue(e.map)
        key_setup, key_exp = simplify_and_optimize_expression(e.key)
        setup = seq([map_setup, key_setup])
        return (setup, EMapGet(map_lvalue, key_exp).with_type(e.type))

    if isinstance(e, EListGet):
        list_setup, list_lvalue = simplify_and_optimize_lvalue(e.e)
        index_setup, index_exp = simplify_and_optimize_expression(e.index)
        setup = seq([list_setup, index_setup])
        return (setup, EListGet(list_lvalue, index_exp).with_type(e.type))

    raise NotImplementedError(repr(e))
Example #3
0
 def visit_SMapUpdate(self, update):
     """Emit code for a map-update statement.

     Declares `update.val_var` initialized to the current entry for
     `update.key`, emits the `update.change` statement, and finally writes
     the value variable back with a `put` call on the map.

     Returns "" without emitting anything when the change is a no-op;
     otherwise nothing is returned explicitly.
     """
     if isinstance(update.change, SNoOp):
         return ""
     # TODO: liveness analysis to avoid this map lookup in some cases
     current_entry = EMapGet(update.map, update.key).with_type(update.map.type.v)
     self.declare(update.val_var, current_entry)
     map_code = self.visit(update.map) # TODO: deduplicate
     key_code = self.visit(update.key) # TODO: deduplicate
     self.visit(update.change)
     self.begin_statement()
     put_call = "{map}.put({key}, {val});\n".format(
         map=map_code,
         key=key_code,
         val=update.val_var.id)
     self.write(put_call)
     self.end_statement()
Example #4
0
def stream(iterable : Exp, loop_var : EVar, body : Stm) -> Stm:
    """Convert an iterable expression to a streaming operation.

    Input:
      iterable - an expression with an iterable type (Bag, Set, or List), not
        yet optimized
      loop_var - a variable to use as the loop variable
      body - a statement to run on that variable, not yet optimized

    Output:
      A statement equivalent to
        for (loop_var in iterable) { body }
      that eliminates as many intermediate collections and objects as possible.

    NOTE: The output of this function will not be correct if the body
    modifies any free variable in the iterable expression or writes to any
    pointers that are read by the iterable expression.

    Motivation: generating code for the expression

        Map {func} (Filter {predicate} big_collection)

    might create two new collections as large as `big_collection`: one for
    the result of the filter and one for the result of the map.  When the
    code only needs to iterate over the result, neither collection has to be
    built.

    This function is mutually recursive with `simplify_and_optimize`, so any
    transformations performed by that function are also applied to the
    output of this one.
    """

    if isinstance(iterable, EEmptyList):
        # Nothing to iterate over.
        return SNoOp()

    if isinstance(iterable, ESingleton):
        setup, value = simplify_and_optimize_expression(iterable.e)
        optimized_body = simplify_and_optimize(body)
        # SScoped because if the iterable is e.g. [x] + [y], then the body
        # might appear in the same block twice.  If the body declares any
        # variables, that will cause problems in languages like Java or C++.
        return seq([setup, SScoped(re_use(value, loop_var, optimized_body))])

    if isinstance(iterable, ECond):
        # Stream each branch separately under a single runtime test.
        cond_setup, cond = simplify_and_optimize_expression(iterable.cond)
        then_stm = stream(iterable.then_branch, loop_var, body)
        else_stm = stream(iterable.else_branch, loop_var, body)
        return seq([cond_setup, SIf(cond, then_stm, else_stm)])

    if isinstance(iterable, EUnaryOp) and iterable.op == UOp.Distinct:
        # Track the elements seen so far and run the body only on new ones.
        seen = fresh_var(TSet(iterable.type.elem_type), "distinct_elems")
        declare_seen = SDecl(seen, EEmptyList().with_type(seen.type))
        not_seen_yet = ENot(EBinOp(loop_var, BOp.In, seen).with_type(BOOL))
        run_then_record = seq([body, SCall(seen, "add", [loop_var])])
        guarded_body = SIf(not_seen_yet, run_then_record, SNoOp())
        return seq([declare_seen, stream(iterable.e, loop_var, guarded_body)])

    if isinstance(iterable, EBinOp) and iterable.op == "+":
        # Concatenation: stream the left operand, then the right one.
        left_stm = stream(iterable.e1, loop_var, body)
        right_stm = stream(iterable.e2, loop_var, body)
        return seq([left_stm, right_stm])

    if isinstance(iterable, EBinOp) and iterable.op == "-":
        if is_hashable(iterable.type.elem_type):
            # Build a histogram of the right operand; each element of the
            # left operand either decrements a positive count or runs the
            # body.
            h_setup, h = histogram(iterable.e2)
            val_ref = fresh_var(INT, "count")
            optimized_setup = simplify_and_optimize(h_setup)
            has_remaining = EGt(EMapGet(h, loop_var).with_type(INT), ZERO)
            decrement = SAssign(val_ref, EBinOp(val_ref, "-", ONE).with_type(INT))
            consume_one = SMapUpdate(h, loop_var, val_ref, decrement)
            guarded_body = SIf(has_remaining, consume_one, body)
            return seq([
                optimized_setup,
                stream(iterable.e1, loop_var, guarded_body)])

        # Unhashable elements: copy the right operand into a variable and
        # remove matching elements from the copy as they are encountered.
        rhs = fresh_var(iterable.e2.type, "bag_subtraction_right")
        declare_rhs = simplify_and_optimize(SDecl(rhs, iterable.e2))
        guarded_body = SIf(
            EIn(loop_var, rhs),
            SCall(rhs, "remove", (loop_var,)),
            body)
        return seq([
            declare_rhs,
            stream(iterable.e1, loop_var, guarded_body)])

    if isinstance(iterable, EFilter):
        # Filter p xs  ==>  FlatMap (\x -> p(x) ? [x] : []) xs
        predicate = iterable.predicate
        keep_or_drop = ECond(
            predicate.body,
            ESingleton(predicate.arg).with_type(iterable.type),
            EEmptyList().with_type(iterable.type)).with_type(iterable.type)
        as_flat_map = EFlatMap(
            iterable.e,
            ELambda(predicate.arg, keep_or_drop)).with_type(iterable.type)
        return stream(as_flat_map, loop_var, body)

    if isinstance(iterable, EMap):
        # Map f xs  ==>  FlatMap (\x -> [f(x)]) xs
        transform = iterable.transform_function
        wrapped_body = ESingleton(transform.body).with_type(iterable.type)
        as_flat_map = EFlatMap(
            iterable.e,
            ELambda(transform.arg, wrapped_body)).with_type(iterable.type)
        return stream(as_flat_map, loop_var, body)

    if isinstance(iterable, EFlatMap):
        # Nested streaming: the outer stream binds a fresh variable, and the
        # inner stream iterates over the transform applied to that variable.
        transform = iterable.transform_function
        inner_loop_var = fresh_var(transform.arg.type, transform.arg.id)
        inner_stm = stream(transform.apply_to(inner_loop_var), loop_var, body)
        return stream(iterable.e, inner_loop_var, inner_stm)

    if isinstance(iterable, EListSlice):
        # Index-based while loop over the clamped range [start, end).
        elem_type = iterable.type.elem_type
        list_var = fresh_var(iterable.e.type, "list")
        start_var = fresh_var(INT, "start")
        end_var = fresh_var(INT, "end")
        declare_list = SDecl(list_var, iterable.e)
        declare_start = SDecl(start_var, max_of(iterable.start, ZERO))
        declare_end = SDecl(end_var, min_of(iterable.end, ELen(list_var)))
        in_range = ELt(start_var, end_var)
        step = seq([
            SDecl(loop_var, EListGet(list_var, start_var).with_type(elem_type)),
            body,
            SAssign(start_var, EBinOp(start_var, "+", ONE).with_type(INT))])
        return simplify_and_optimize(seq([
            declare_list,
            declare_start,
            declare_end,
            SWhile(in_range, step)]))

    if isinstance(iterable, ELet):
        # Declare the bound value once, then stream the let-body applied to it.
        body_function = iterable.body_function
        bound_var = fresh_var(body_function.arg.type, body_function.arg.id)
        declare_bound = simplify_and_optimize(SDecl(bound_var, iterable.e))
        return seq([
            declare_bound,
            stream(body_function.apply_to(bound_var), loop_var, body)])

    if isinstance(iterable, EMove):
        return stream(iterable.e, loop_var, body)

    # Fallback: an ordinary for-each loop over the simplified collection.
    assert is_collection(iterable.type), repr(iterable)
    setup, collection = simplify_and_optimize_expression(iterable)
    return seq([setup, SForEach(loop_var, collection, simplify_and_optimize(body))])
Example #5
0
    def _enumerate_core(self, context: Context, size: int,
                        pool: Pool) -> [Exp]:
        """Build new expressions of the given size.

        Arguments:
            context : a Context object describing the vars in scope
            size    : size of expressions to enumerate; each expression in
                      the output will have this size
            pool    : pool to enumerate

        Yields:
            Expressions, one at a time (this function is a generator).
            Nothing is yielded for a negative size.

        This function is not cached.  Clients should call `enumerate` instead.

        This function tries to be a clean description of the Cozy grammar.  It
        does not concern itself with deduplication (which is handled
        efficiently by equivalence class deduplication).
        """

        # No expression has a negative size.
        if size < 0:
            return

        # Size-0 expressions are the leaves: literals, values constructed for
        # each "interesting" type, in-scope variables, and hints.
        if size == 0:
            for e in LITERALS:
                yield e

            # Interesting types are the hint types plus every type reachable
            # (via `all_types`) from the type of a variable in scope.
            all_interesting_types = OrderedSet(self.hint_types)
            for v, _ in context.vars():
                all_interesting_types |= all_types(v.type)
            for t in all_interesting_types:
                l = construct_value(t)
                # Skip values that were already yielded as literals above.
                if l not in LITERALS:
                    yield l

            # Variables that live in the requested pool.
            for (v, p) in context.vars():
                if p == pool:
                    yield v
            # Hints for the requested pool whose context matches this one;
            # each is adapted from its own context into `context`.
            for (e, ctx, p) in self.hints:
                if p == pool:
                    fvs = free_vars(e)
                    if ctx.alpha_equivalent(context.generalize(fvs)):
                        yield context.adapt(e, ctx, e_fvs=fvs)
            return

        # When `do_enumerate` is off, only the size-0 expressions above are
        # ever produced.
        if not do_enumerate.value:
            return

        def build_lambdas(bag, pool, body_size):
            """Yield lambdas over elements of `bag` with bodies of `body_size`.

            The argument is a fresh variable (avoiding every variable in the
            enclosing context) and bodies are enumerated in a context that
            extends the enclosing one with that variable.
            """
            v = fresh_var(bag.type.elem_type,
                          omit=set(v for v, p in context.vars()))
            inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
            for lam_body in self.enumerate(inner_context, body_size, pool):
                yield ELambda(v, lam_body)

        # Load all smaller expressions in this context and pool.
        # cache[S] contains expressions of size S in this context and pool.
        cache = [list(self.enumerate(context, sz, pool)) for sz in range(size)]

        # Enable use of a state-pool expression at runtime
        if pool == RUNTIME_POOL:
            # State expressions are enumerated in the root context and wrapped
            # in EStateVar, which accounts for the one extra unit of size.
            for e in self.enumerate(context.root(), size - 1, STATE_POOL):
                yield EStateVar(e).with_type(e.type)

        # Arity-1 expressions
        for e in cache[size - 1]:
            if is_collection(e.type):
                elem_type = e.type.elem_type

                # This method of generating EEmptyList() ensures that we visit
                # empty collections of all possible types.
                yield EEmptyList().with_type(e.type)

                if is_numeric(elem_type):
                    yield EUnaryOp(UOp.Sum, e).with_type(elem_type)

                yield EUnaryOp(UOp.Length, e).with_type(INT)
                yield EUnaryOp(UOp.Empty, e).with_type(BOOL)
                yield EUnaryOp(UOp.Exists, e).with_type(BOOL)
                yield EUnaryOp(UOp.The, e).with_type(elem_type)
                yield EUnaryOp(UOp.Distinct, e).with_type(e.type)
                yield EUnaryOp(UOp.AreUnique, e).with_type(BOOL)

                if elem_type == BOOL:
                    yield EUnaryOp(UOp.Any, e).with_type(BOOL)
                    yield EUnaryOp(UOp.All, e).with_type(BOOL)

            # Any expression can be wrapped in a singleton bag.
            yield ESingleton(e).with_type(TBag(e.type))

            # Projections: record fields, the "val" field of a handle, and
            # tuple components.
            if isinstance(e.type, TRecord):
                for (f, t) in e.type.fields:
                    yield EGetField(e, f).with_type(t)

            if isinstance(e.type, THandle):
                yield EGetField(e, "val").with_type(e.type.value_type)

            if isinstance(e.type, TTuple):
                for n in range(len(e.type.ts)):
                    yield ETupleGet(e, n).with_type(e.type.ts[n])

            if e.type == BOOL:
                yield EUnaryOp(UOp.Not, e).with_type(BOOL)

            if is_numeric(e.type):
                yield EUnaryOp("-", e).with_type(e.type)

            if isinstance(e.type, TMap):
                yield EMapKeys(e).with_type(TBag(e.type.k))

        # Arity-2 expressions
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            # sz1 + sz2 = size - 1
            for e1 in cache[sz1]:
                t = e1.type

                # Arithmetic on two operands of the same numeric type.
                if is_numeric(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", a2).with_type(t)
                        yield EBinOp(e1, "-", a2).with_type(t)

                # Ordering comparisons on two operands of the same type.
                if is_ordered(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, ">", a2).with_type(BOOL)
                        yield EBinOp(e1, "<", a2).with_type(BOOL)
                        yield EBinOp(e1, ">=", a2).with_type(BOOL)
                        yield EBinOp(e1, "<=", a2).with_type(BOOL)

                if t == BOOL:
                    for a2 in of_type(cache[sz2], BOOL):
                        yield EBinOp(e1, BOp.And, a2).with_type(BOOL)
                        yield EBinOp(e1, BOp.Or, a2).with_type(BOOL)
                        # Cozy supports the implication operator "=>", but this
                        # function does not enumerate it because
                        #  - (a => b) is equivalent to ((not a) or b)
                        #  - there isn't an implication operator in any of our
                        #    current target languages, so we would need to
                        #    desugar it to ((not a) or b) anyway.

                # Equality and disequality are enumerated for every type
                # except maps.
                if not isinstance(t, TMap):
                    for a2 in of_type(cache[sz2], t):
                        yield EEq(e1, a2)
                        yield EBinOp(e1, "!=", a2).with_type(BOOL)

                # Map lookups and key-membership tests, with keys of the
                # map's key type.
                if isinstance(t, TMap):
                    for k in of_type(cache[sz2], t.k):
                        yield EMapGet(e1, k).with_type(t.v)
                        yield EHasKey(e1, k).with_type(BOOL)

                # List indexing with INT-typed indices.
                if isinstance(t, TList):
                    for i in of_type(cache[sz2], INT):
                        yield EListGet(e1, i).with_type(e1.type.elem_type)

                if is_collection(t):
                    elem_type = t.elem_type
                    # Collection addition and subtraction.
                    for e2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", e2).with_type(t)
                        yield EBinOp(e1, "-", e2).with_type(t)
                    # Membership tests for element-typed expressions.
                    for e2 in of_type(cache[sz2], elem_type):
                        yield EBinOp(e2, BOp.In, e1).with_type(BOOL)
                    # Higher-order operations; which ones apply depends on
                    # the type of the lambda body.
                    for f in build_lambdas(e1, pool, sz2):
                        body_type = f.body.type
                        yield EMap(e1, f).with_type(TBag(body_type))
                        if body_type == BOOL:
                            yield EFilter(e1, f).with_type(t)
                        if is_numeric(body_type):
                            yield EArgMin(e1, f).with_type(elem_type)
                            yield EArgMax(e1, f).with_type(elem_type)
                        if is_collection(body_type):
                            yield EFlatMap(e1, f).with_type(
                                TBag(body_type.elem_type))

                        # Maps are only built in the state pool, and only
                        # over hashable element types.
                        if pool == STATE_POOL and is_hashable(elem_type):
                            yield EMakeMap2(e1, f).with_type(
                                TMap(elem_type, body_type))

                # Let-expressions binding e1: the candidate bodies are the
                # lambdas enumerated over the singleton bag [e1].
                e1_singleton = ESingleton(e1).with_type(TBag(e1.type))
                for f in build_lambdas(e1_singleton, pool, sz2):
                    yield ELet(e1, f).with_type(f.body.type)

        # Arity-3 expressions
        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            # sz1 + sz2 + sz3 = size - 1
            for e1 in cache[sz1]:
                # Conditionals: both branches must have the same type.
                if e1.type == BOOL:
                    cond = e1
                    for then_branch in cache[sz2]:
                        for else_branch in of_type(cache[sz3],
                                                   then_branch.type):
                            yield ECond(cond, then_branch,
                                        else_branch).with_type(
                                            then_branch.type)
                if isinstance(e1.type, TList):
                    for start in of_type(cache[sz2], INT):
                        for end in of_type(cache[sz3], INT):
                            yield EListSlice(e1, start, end).with_type(e1.type)
                            # It is not necessary to create slice expressions of
                            # the form a[:i] or a[i:].  Those are desugared
                            # after parsing to a[0:i] and a[i:len(a)]
                            # respectively, and Cozy is perfectly capable of
                            # discovering these expanded forms as well.

        # Finally, let every extension handler contribute its own
        # expressions of this size.
        for h in all_extension_handlers():
            yield from h.enumerate(context, size, pool, self.enumerate,
                                   build_lambdas)
Example #6
0
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Hoist "bad" expressions out of `s` into temporaries.

    Returns a statement new_s in which the expressions of `s` that
    `exp_is_bad` flags have been replaced with fresh variables.  Definitions
    for those variables are appended to `decls_out`, so that afterwards the
    whole statement

        decls_out; new_s

    should return the same result as `s`.

    Raises NotImplementedError for statement types not handled below.
    """

    def pull(e: Exp) -> Exp:
        """Pull an expression into a temporary if it is bad.

        When `exp_is_bad(e)` holds, a fresh variable is created, its
        declaration is written into `decls_out`, and the variable is
        returned.  Otherwise `e` is returned unchanged.
        """
        if not exp_is_bad(e):
            return e
        tmp = fresh_var(e.type)
        decls_out.append(SDecl(tmp, e))
        return tmp

    if isinstance(s, SNoOp):
        return s

    if isinstance(s, SSeq):
        first = pull_temps(s.s1, decls_out, exp_is_bad)
        second = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(first, second)

    if isinstance(s, SDecl):
        new_val = pull(s.val)
        return SDecl(s.var, new_val)

    if isinstance(s, SIf):
        new_cond = pull(s.cond)
        new_then = pull_temps(s.then_branch, decls_out, exp_is_bad)
        new_else = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(new_cond, new_then, new_else)

    if isinstance(s, SForEach):
        bag = pull(s.iter)

        # Collect the body's temporaries separately: the ones that mention
        # the loop variable cannot simply be declared outside the loop.
        body_decls = []
        body = pull_temps(s.body, body_decls, exp_is_bad)

        def mentions_loop_var(decl):
            return s.loop_var in free_vars(decl.val)

        dependent, independent = partition(body_decls, mentions_loop_var)
        decls_out.extend(independent)

        # A loop-dependent temporary becomes a map keyed by the loop
        # variable, and each use in the body becomes a lookup in that map.
        for decl in dependent:
            v = decl.var
            table_type = TMap(s.loop_var.type, v.type)
            table = EMakeMap2(bag, ELambda(s.loop_var, decl.val)).with_type(table_type)
            table_var = fresh_var(table.type)
            decls_out.append(SDecl(table_var, table))
            lookup = EMapGet(table_var, s.loop_var).with_type(v.type)
            body = subst(body, {v.id: lookup})
        return SForEach(s.loop_var, bag, body)

    if isinstance(s, SAssign):
        new_rhs = pull(s.rhs)
        return SAssign(s.lhs, new_rhs)

    if isinstance(s, SCall):
        new_args = tuple([pull(arg) for arg in s.args])
        return SCall(s.target, s.func, new_args)

    if isinstance(s, SMapDel):
        new_key = pull(s.key)
        return SMapDel(s.map, new_key)

    if isinstance(s, SMapPut):
        new_key = pull(s.key)
        new_value = pull(s.value)
        return SMapPut(s.map, new_key, new_value)

    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        change_decls = []
        change = pull_temps(s.change, change_decls, exp_is_bad)
        for decl in change_decls:
            if s.val_var not in free_vars(decl.val):
                decls_out.append(decl)
                continue
            # The temporary reads the value variable of the update, so its
            # declaration is rewritten to read the map entry directly.
            current_entry = EMapGet(s.map, key).with_type(s.val_var.type)
            rewritten = subst(decl.val, {s.val_var.id: current_entry})
            decls_out.append(SDecl(decl.var, rewritten))
        return SMapUpdate(s.map, key, s.val_var, change)

    raise NotImplementedError(s)