Exemple #1
0
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Replace "bad" expressions in `s` with fresh temporaries.

    Every expression held directly by `s` (or by one of its sub-statements)
    for which `exp_is_bad` returns True is swapped for a fresh variable, and
    an `SDecl` defining that variable is appended to `decls_out`.

    Args:
        s: the statement to rewrite.
        decls_out: output list that receives the temporaries' declarations.
        exp_is_bad: predicate over expressions; True means "pull this out".

    Returns:
        The rewritten statement.

    Raises:
        NotImplementedError: if `s` is a statement kind not handled here.
    """
    def pull(e: Exp) -> Exp:
        """Return `e`, or a fresh variable declared in `decls_out` if `e` is bad."""
        if exp_is_bad(e):
            v = fresh_var(e.type)
            decls_out.append(SDecl(v.id, e))
            return v
        return e

    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        s1 = pull_temps(s.s1, decls_out, exp_is_bad)
        s2 = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(s1, s2)
    if isinstance(s, SDecl):
        # A declaration is handled like an assignment: only its value is
        # eligible for pulling.  Without this case an SDecl raised
        # NotImplementedError.
        return SDecl(s.id, pull(s.val))
    if isinstance(s, SIf):
        cond = pull(s.cond)
        s1 = pull_temps(s.then_branch, decls_out, exp_is_bad)
        s2 = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(cond, s1, s2)
    if isinstance(s, SForEach):
        bag = pull(s.iter)
        # Collect the body's temporaries separately: those that mention the
        # loop variable cannot simply be declared before the loop.
        d_tmp = []
        body = pull_temps(s.body, d_tmp, exp_is_bad)
        to_fix, ok = partition(d_tmp, lambda d: s.id in free_vars(d.val))
        decls_out.extend(ok)
        for d in to_fix:
            # Precompute the temporary for every element as a map keyed by
            # the loop variable, then read it back inside the loop body.
            v = EVar(d.id).with_type(d.val.type)
            mt = TMap(s.id.type, v.type)
            m = EMakeMap2(bag, ELambda(s.id, d.val)).with_type(mt)
            mv = fresh_var(m.type)
            md = SDecl(mv.id, m)
            decls_out.append(md)
            body = subst(body, {v.id: EMapGet(mv, s.id).with_type(v.type)})
        return SForEach(s.id, bag, body)
    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))
    if isinstance(s, SCall):
        return SCall(s.target, s.func, tuple(pull(arg) for arg in s.args))
    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))
    if isinstance(s, SMapPut):
        return SMapPut(s.map, pull(s.key), pull(s.value))
    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        d_tmp = []
        change = pull_temps(s.change, d_tmp, exp_is_bad)
        for d in d_tmp:
            if s.val_var in free_vars(d.val):
                # The temporary mentions the update's value variable, which
                # is not in scope outside the update; substitute the
                # corresponding map lookup.
                decls_out.append(
                    SDecl(
                        d.id,
                        subst(
                            d.val, {
                                s.val_var.id:
                                EMapGet(s.map, key).with_type(s.val_var.type)
                            })))
            else:
                decls_out.append(d)
        return SMapUpdate(s.map, key, s.val_var, change)
    raise NotImplementedError(s)
Exemple #2
0
 def visit_EArrayIndexOf(self, e):
     """Emit a `std::find`-based search for `e.x` in `e.a`.

     Two declarations are visited: an iterator holding the result of
     `std::find`, and an integer that is -1 when the iterator equals
     `end()` and the iterator's offset from `begin()` otherwise.

     Returns the `id` of the integer result variable.

     Raises NotImplementedError unless `e.a` is a variable or a tuple field
     of a variable.
     """
     array = e.a
     array_is_var = isinstance(array, EVar)
     array_is_var_field = isinstance(array, ETupleGet) and isinstance(array.e, EVar)
     if not (array_is_var or array_is_var_field):
         # TODO: make this fast when this is false
         raise NotImplementedError("finding index of non-var array")

     iterator_type = TNative(
         "{}::const_iterator".format(self.visit(array.type, "").strip()))
     cursor = self.fv(iterator_type, "cursor")
     index = self.fv(INT, "index")

     find_call = EEscape(
         "std::find({a}.begin(), {a}.end(), {x})",
         ("a", "x"), (array, e.x)).with_type(cursor.type)
     end_call = EEscape("{a}.end()", ("a", ), (array, )).with_type(cursor.type)
     not_found = ENum(-1).with_type(INT)
     offset = EEscape(
         "({it} - {a}.begin())", ("it", "a"), (cursor, array)).with_type(INT)
     index_value = ECond(EEq(cursor, end_call), not_found, offset).with_type(INT)

     self.visit(seq([SDecl(cursor, find_call), SDecl(index, index_value)]))
     return index.id
Exemple #3
0
    def visit_EUnaryOp(self, e):
        """Lower a unary-operator expression.

        Depending on `e.op`, the expression is either rewritten in terms of
        another operator that is visited in turn, or turned into explicit
        statements appended to `self.stms`.  Operators with no case here are
        passed to `self.visit_Exp`.

        Returns whatever the selected lowering yields; for `Sum` and `Empty`
        that is the fresh variable holding the result.
        """
        op = e.op
        if op == UOp.Distinct:
            return self.visit_iterable(e)
        elif op == UOp.The:
            return self.find_one(e.e)
        elif op == UOp.Sum:
            sum_var = fresh_var(e.type, "sum")
            loop_var = fresh_var(e.e.type.elem_type, "x")
            # The addition carries the same type as the accumulator it is
            # assigned to, rather than a hard-coded INT.
            self.stms.append(simplify_and_optimize(seq([
                SDecl(sum_var, ENum(0).with_type(e.type)),
                SForEach(loop_var, e.e,
                    SAssign(sum_var, EBinOp(sum_var, "+", loop_var).with_type(e.type)))])))
            return sum_var
        elif op == UOp.Length:
            # len(xs) == sum of 1 for each x in xs
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Sum, EMap(e.e, ELambda(arg, ONE)).with_type(INT_BAG)).with_type(INT))
        elif op == UOp.All:
            # all(xs) == no element of xs is false.  The filtered collection
            # keeps the type of `e.e`, and the emptiness test is a boolean.
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Empty, EFilter(e.e, ELambda(arg, ENot(arg))).with_type(e.e.type)).with_type(BOOL))
        elif op == UOp.Any:
            # any(xs) == some element of xs is true.  Typed as for `All`.
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Exists, EFilter(e.e, ELambda(arg, arg)).with_type(e.e.type)).with_type(BOOL))
        elif op == UOp.Empty:
            # Assume empty, then flip the flag and leave the block as soon as
            # the loop body runs once.
            iterable = e.e
            v = fresh_var(BOOL, "v")
            label = fresh_name("label")
            x = fresh_var(iterable.type.elem_type, "x")
            decl = SDecl(v, ETRUE)
            find = SEscapableBlock(label,
                SForEach(x, iterable, seq([
                    SAssign(v, EFALSE),
                    SEscapeBlock(label)])))
            self.stms.append(simplify_and_optimize(seq([decl, find])))
            return v
        elif op == UOp.Exists:
            return self.visit(ENot(EUnaryOp(UOp.Empty, e.e).with_type(BOOL)))
        # elif op == UOp.AreUnique:
        #     s = fresh_var(TSet(e.e.type.elem_type), "unique_elems")
        #     u = fresh_var(BOOL, "is_unique")
        #     x = fresh_var(e.e.type.elem_type)
        #     label = fresh_name("label")
        #     self.visit(seq([
        #         SDecl(s, EEmptyList().with_type(s.type)),
        #         SDecl(u, ETRUE),
        #         SEscapableBlock(label,
        #             SForEach(x, e.e,
        #                 SIf(EEscape("{s}.find({x}) != {s}.end()", ("s", "x"), (s, x)).with_type(BOOL),
        #                     seq([SAssign(u, EFALSE), SEscapeBlock(label)]),
        #                     SEscape("{indent}{s}.insert({x});\n", ("s", "x"), (s, x)))))]))
        #     return u.id

        return self.visit_Exp(e)
Exemple #4
0
 def visit_SForEach(self, for_each):
     """Write a `for (<decl> : <source>)` loop for `for_each`.

     When the iterable is an `EMapKeys`, the loop runs over the visited map
     with a fresh `const auto &` variable, and the loop variable is declared
     from that variable's `.first` at the top of the body.  Otherwise the
     loop runs over the visited iterable using the loop variable itself.
     """
     loop_var = for_each.loop_var
     iterable = for_each.iter
     body = for_each.body

     if isinstance(iterable, EMapKeys):
         source = self.visit(iterable.e)
         entry = self.fv(TNative("const auto &"))
         header_type, header_name = entry.type, entry.id
         key_decl = SDecl(
             loop_var,
             EEscape("{p}.first", ("p", ), (entry, )).with_type(loop_var.type))
         body = SSeq(key_decl, body)
     else:
         source = self.visit(iterable)
         header_type, header_name = loop_var.type, loop_var.id

     self.begin_statement()
     self.write("for (", self.visit(header_type, header_name), " : ", source,
                ") ")
     with self.block():
         self.visit(body)
     self.end_statement()
Exemple #5
0
 def visit_EMakeMap2(self, e):
     """Lower an `EMakeMap2` into a declaration plus a populating loop.

     Appends to `self.stms` a declaration of a fresh, empty map and a loop
     over `e.e` that updates the map entry for each element with the body of
     `e.value_function`.  Returns the fresh map wrapped in `EMove`.
     """
     map_type = e.type
     result = fresh_var(map_type, "map")
     self.stms.append(SDecl(result, EEmptyMap().with_type(map_type)))

     key = e.value_function.arg
     slot = fresh_var(map_type.v, "v")
     fill_slot = SAssign(slot, e.value_function.body)
     populate = SForEach(key, e.e, SMapUpdate(result, key, slot, fill_slot))
     self.stms.append(simplify_and_optimize(populate))

     return EMove(result).with_type(result.type)
Exemple #6
0
 def compute_hash(self, fields: [Exp]) -> Stm:
     """Build a statement that hashes `fields` and returns the result.

     Two integer variables are declared and zeroed: a running total and a
     per-field scratch value.  For each field, `compute_hash_1` fills the
     scratch value and the total becomes `(total * 31) ^ scratch`.  The
     statement ends with an escaped `return` of the total.
     """
     total = self.fv(INT, "hash_code")
     scratch = self.fv(INT, "hash_code")
     zero_total = SDecl(total, ENum(0).with_type(total.type))
     zero_scratch = SDecl(scratch, ENum(0).with_type(scratch.type))

     result = SSeq(zero_total, zero_scratch)
     for field in fields:
         hash_field = self.compute_hash_1(scratch, field)
         combined = EEscape("({hc} * 31) ^ ({h})", ("hc", "h"),
                            (total, scratch)).with_type(INT)
         result = seq([result, hash_field, SAssign(total, combined)])

     return SSeq(result, SEscape("{indent}return {e};\n", ("e", ), (total, )))
Exemple #7
0
    def pull(e: Exp) -> Exp:
        """Replace `e` with a temporary when `exp_is_bad(e)` holds.

        A bad expression gets a fresh variable whose declaration is appended
        to `decls_out`; that variable is returned.  Any other expression is
        returned unchanged.
        """
        if not exp_is_bad(e):
            return e
        tmp = fresh_var(e.type)
        decls_out.append(SDecl(tmp, e))
        return tmp
Exemple #8
0
 def find_one(self, iterable):
     """Append statements that copy the first element of `iterable` into a fresh variable.

     The variable starts as `evaluation.construct_value` of the element type
     and is overwritten by the first loop iteration, which then escapes the
     enclosing block.  Returns the variable.
     """
     elem_type = iterable.type.elem_type
     result = fresh_var(elem_type, "v")
     exit_label = fresh_name("label")
     elem = fresh_var(elem_type, "x")

     init = SDecl(result, evaluation.construct_value(result.type))
     take_first = seq([SAssign(result, elem), SEscapeBlock(exit_label)])
     search = SEscapableBlock(exit_label, SForEach(elem, iterable, take_first))

     self.stms.append(simplify_and_optimize(seq([init, search])))
     return result
Exemple #9
0
 def min_or_max(self, op, e, f):
     """Lower an arg-min/arg-max of collection `e` under key function `f`.

     `op` is the comparison operator; "<" names the result variable "min",
     anything else "max".  A sum of two singletons becomes a conditional
     expression.  Otherwise a loop is appended to `self.stms` that keeps the
     best element seen so far, and the result variable is returned.
     """
     is_pair_of_singletons = (
         isinstance(e, EBinOp) and e.op == "+"
         and isinstance(e.e1, ESingleton) and isinstance(e.e2, ESingleton))
     if is_pair_of_singletons:
         # argmin_f ([a] + [b]) ---> f(a) < f(b) ? a : b
         a, b = e.e1.e, e.e2.e
         a_wins = EBinOp(f.apply_to(a), op, f.apply_to(b)).with_type(BOOL)
         return self.visit(ECond(a_wins, a, b).with_type(a.type))

     elem_type = e.type.elem_type
     best = fresh_var(elem_type, "min" if op == "<" else "max")
     first = fresh_var(BOOL, "first")
     current = fresh_var(elem_type, "x")

     init_best = SDecl(best, evaluation.construct_value(best.type))
     init_first = SDecl(first, ETRUE)

     # Take the current element if it is the first one or beats the best.
     improves = EBinOp(f.apply_to(current), op, f.apply_to(best)).with_type(BOOL)
     should_take = EBinOp(first, BOp.Or, improves).with_type(BOOL)
     take = seq([SAssign(first, EFALSE), SAssign(best, current)])
     scan = SForEach(current, e, SIf(should_take, take, SNoOp()))

     self.stms.append(simplify_and_optimize(seq([init_best, init_first, scan])))
     return best
Exemple #10
0
def histogram(e : Exp) -> (Stm, EVar):
    """Build a statement that counts occurrences of each element of `e`.

    The statement (not optimized here) declares a fresh, empty map from
    element to INT and loops over `e`, adding one to the entry of each
    element.  Returns that statement together with the map variable.
    """
    key_type = e.type.elem_type
    table = fresh_var(TMap(key_type, INT), "histogram")
    elem = fresh_var(key_type, "x")
    tally = fresh_var(INT, "count")

    declare = SDecl(table, EEmptyMap().with_type(table.type))
    bump = SAssign(tally, EBinOp(tally, "+", ONE).with_type(INT))
    fill = SForEach(elem, e, SMapUpdate(table, elem, tally, bump))

    return (seq([declare, fill]), table)
Exemple #11
0
 def visit_EBinOp(self, e):
     """Lower a binary-operator expression.

     "+" and "-" on collections go to `visit_iterable`.  A membership test
     whose right-hand side is not a `TSet` becomes a loop, appended to
     `self.stms`, that sets a fresh boolean and escapes on the first match;
     that boolean is returned.  Anything else goes to `visit_Exp`.
     """
     if e.op in ("+", "-") and is_collection(e.type):
         return self.visit_iterable(e)

     needs_search_loop = e.op == BOp.In and not isinstance(e.e2.type, TSet)
     if not needs_search_loop:
         return self.visit_Exp(e)

     found = fresh_var(BOOL, "found")
     candidate = fresh_var(e.e1.type, "x")
     exit_label = fresh_name("label")

     on_match = seq([SAssign(found, ETRUE), SEscapeBlock(exit_label)])
     is_match = EBinOp(candidate, "==", e.e1).with_type(BOOL)
     search = SEscapableBlock(
         exit_label,
         SForEach(candidate, e.e2, SIf(is_match, on_match, SNoOp())))

     self.stms.append(simplify_and_optimize(seq([SDecl(found, EFALSE), search])))
     return found
Exemple #12
0
 def visit_EMapGet(self, e):
     """Emit a map lookup for `e`.

     Maps for which `use_trove` is false are read with `getOrDefault`, using
     the visited default value of `e.type`.  Trove maps are read with `get`;
     when the value's Trove name is "Object" and its default value is not an
     `ENull`, the result is first stored in a variable and replaced by the
     default if it equals null.
     """
     map_type = e.map.type

     if not self.use_trove(map_type):
         emap = self.visit(e.map)
         ekey = self.visit(e.key)
         edefault = self.visit(evaluation.construct_value(e.type))
         return "{emap}.getOrDefault({ekey}, {edefault})".format(emap=emap, ekey=ekey, edefault=edefault)

     value_type = map_type.v
     needs_null_check = (
         self.trovename(value_type) == "Object"
         and not isinstance(evaluation.construct_value(value_type), ENull))

     emap = self.visit(e.map)
     ekey = self.visit(e.key)
     lookup = "{emap}.get({ekey})".format(emap=emap, ekey=ekey)

     if not needs_null_check:
         # For Trove, defaults are set at construction time
         return lookup

     # Object-valued map with a non-null default: compare the looked-up
     # value against null and substitute the default by hand.
     v = self.fv(self.box_if_boolean(value_type), hint="v")
     self.visit(SDecl(v, EEscape(lookup, [], []).with_type(e.type)))
     is_missing = EEq(v, ENull().with_type(v.type))
     return self.visit(ECond(is_missing, evaluation.construct_value(value_type), v).with_type(e.type))
Exemple #13
0
 def visit_ELet(self, e):
     """Lower a let-expression into a declaration plus its substituted body.

     The bound value is visited and stored in a fresh variable (named after
     the let's argument) via a declaration appended to `self.stms`.  The body
     is then visited with the argument replaced by that variable.
     """
     bound = e.body_function.arg
     value = self.visit(e.e)
     tmp = fresh_var(value.type, bound.id)
     self.stms.append(SDecl(tmp, value))
     rewritten_body = subst(e.body_function.body, { bound.id : tmp })
     return self.visit(rewritten_body)
Exemple #14
0
def simplify_and_optimize(s : Stm) -> Stm:
    """Simplify and optimize a statement.

    Input:
      s - a statement to optimize

    Output:
      A statement that is functionally equivalent to the input.

    Raises:
      ValueError - if the condition of an SWhile needs setup statements

    This function makes two big transformations:
      - "compile" many kinds of expressions (listed below) to simpler forms so
        that downstream code generation has less work to do
      - avoid creating short-lived intermediate objects (see `stream`)

    Expression types eliminated by this procedure:
      - EMap, EFilter, EFlatMap
      - EArg{Min,Max}
      - unary ops:
        Distinct,
        AreUnique,
        Length, Sum, All, Any,
        Exists, Empty,
        The
      - binary ops:
        In (where the collection is a Bag or List)
        "-" on collections
        "+" on collections
      - EMakeMap2
      - ELet
      - EListSlice
      - EStm
    """
    assert isinstance(s, Stm)

    # Most cases below follow one pattern: optimize the statement's
    # expression(s), which yields setup statements plus a simpler expression,
    # then emit the setup followed by the rebuilt statement.
    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        # TODO: while the first basic statement in s1 is an SDecl, we should
        # apply `re_use` to perhaps eliminate or inline the declaration.
        return seq([simplify_and_optimize(s.s1), simplify_and_optimize(s.s2)])
    if isinstance(s, SAssign):
        setup, e = simplify_and_optimize_expression(s.rhs)
        return seq([setup, SAssign(s.lhs, e)])
    if isinstance(s, SReturn):
        setup, e = simplify_and_optimize_expression(s.e)
        return seq([setup, SReturn(e)])
    if isinstance(s, SDecl):
        setup, e = simplify_and_optimize_expression(s.val)
        return seq([setup, SDecl(s.var, e)])
    if isinstance(s, SForEach):
        # Loops are handed to `stream`.
        return stream(s.iter, s.loop_var, s.body)
    if isinstance(s, SEscape):
        return s
    if isinstance(s, SIf):
        setup, test = simplify_and_optimize_expression(s.cond)
        # A constant condition selects one branch outright; note that `setup`
        # is not included in the result in those two cases.
        if test == ETRUE:
            return simplify_and_optimize(s.then_branch)
        if test == EFALSE:
            return simplify_and_optimize(s.else_branch)
        return seq([setup, SIf(test, simplify_and_optimize(s.then_branch), simplify_and_optimize(s.else_branch))])
    if isinstance(s, SWhile):
        setup, cond = simplify_and_optimize_expression(s.e)
        if setup != SNoOp():
            # This is a problem because we don't want to duplicate the setup
            # condition.
            # TODO: introduce an SEscapableBlock/SEscapeBlock to do it right
            raise ValueError("oops! setup for condition {} is very long:\n{}".format(pprint(s.e), pprint(setup)))
        return SWhile(cond, simplify_and_optimize(s.body))
    if isinstance(s, SScoped):
        return SScoped(simplify_and_optimize(s.s))
    if isinstance(s, SMapUpdate):
        # TODO: optimize s.map & s.key
        # TODO: s.map must be optimized as an lvalue
        mapsetup, map = simplify_and_optimize_lvalue(s.map)
        keysetup, key = simplify_and_optimize_expression(s.key)
        return seq([
            mapsetup,
            keysetup,
            SMapUpdate(map, key, s.val_var, simplify_and_optimize(s.change))])
    if isinstance(s, SMapDel):
        mapsetup, map = simplify_and_optimize_lvalue(s.map)
        keysetup, key = simplify_and_optimize_expression(s.key)
        return seq([
            mapsetup,
            keysetup,
            SMapDel(map, key)])
    if isinstance(s, SCall):
        # NOTE(review): when `s.args` is empty, `zip(*...)` yields nothing and
        # the two-name unpacking below raises ValueError -- confirm whether
        # zero-argument calls can reach this point.
        setups, args = zip(*(simplify_and_optimize_expression(a) for a in s.args))
        return seq(list(setups) + [SCall(s.target, s.func, tuple(args))])
    if isinstance(s, SEscapableBlock):
        return SEscapableBlock(s.label, simplify_and_optimize(s.body))
    if isinstance(s, SEscapeBlock):
        return s
    if isinstance(s, SArrayAlloc):
        setup, cap = simplify_and_optimize_expression(s.capacity)
        return seq([setup, SArrayAlloc(s.a, cap)])
    if isinstance(s, SArrayReAlloc):
        setup, cap = simplify_and_optimize_expression(s.new_capacity)
        return seq([setup, SArrayReAlloc(s.a, cap)])
    if isinstance(s, SEnsureCapacity):
        setup, cap = simplify_and_optimize_expression(s.capacity)
        return seq([setup, SEnsureCapacity(s.a, cap)])
    if isinstance(s, SSwap):
        # TODO: if we want to optimize the operands we will need a special
        # procedure that optimizes lvalues while preserving meaning... same
        # goes for SAssign case above.
        return s
    if isinstance(s, SSwitch):
        setup, e = simplify_and_optimize_expression(s.e)
        # NOTE(review): this listing ends right after `new_cases` is built;
        # the rest of the SSwitch case is not visible here.
        new_cases = [(case, simplify_and_optimize(stm)) for (case, stm) in s.cases]
Exemple #15
0
 def visit_SSwap(self, s):
     """Lower a swap of `s.lval1` and `s.lval2` into three statements.

     The first value is saved in a fresh temporary, the second is assigned
     to the first, and the temporary is assigned to the second.  Returns the
     result of visiting that sequence.
     """
     first, second = s.lval1, s.lval2
     saved = self.fv(first.type, "swap_tmp")
     steps = [
         SDecl(saved, first),
         SAssign(first, second),
         SAssign(second, saved),
     ]
     return self.visit(seq(steps))
Exemple #16
0
def re_use(value : Exp, v : EVar, s : Stm) -> Stm:
    """Make `v` stand for `value` within `s`.

    If `value` is efficiently reusable, or `v` occurs free in `s` at most
    once, `value` is substituted for `v` directly.  Otherwise `v` is declared
    with `value` ahead of `s`.
    """
    can_inline = efficiently_reuseable(value) or count_occurrences_of_free_var(s, v) <= 1
    if not can_inline:
        return seq([SDecl(v, value), s])
    return subst(s, {v.id : value})
Exemple #17
0
def stream(iterable : Exp, loop_var : EVar, body : Stm) -> Stm:
    """Convert an iterable expression to a streaming operation.

    Input:
      iterable - an expression with an iterable type (Bag, Set, or List), not
        yet optimized
      loop_var - a variable to use as the loop variable
      body - a statement to run on that variable, not yet optimized

    Output:
      A statement equivalent to
        for (loop_var in iterable) { body }
      that eliminates as many intermediate collections and objects as possible.

    NOTE: The output of this function will not be correct if the body modifies
    any free variable in the iterable expression or writes to any pointers
    that are read by the iterable expression.

    Generating code for the expression

        Map {func} (Filter {predicate} big_collection)

    might create two new collections as large as `big_collection`: one to hold
    the result of the filter and one to hold the result of the map.  If all the
    code needs to do is to iterate over the result, then there is no reason to
    make the two new collections.

    This function is mutually recursive with `simplify_and_optimize`, so any
    transformations performed by that method are also applied to the output of
    this one.
    """

    if isinstance(iterable, EEmptyList):
        return SNoOp()
    elif isinstance(iterable, ESingleton):
        setup, value = simplify_and_optimize_expression(iterable.e)
        # SScoped because if the iterable is e.g. [x] + [y], then the body
        # might appear in the same block twice.  If the body declares any
        # variables, that will cause problems in languages like Java or C++.
        return seq([setup, SScoped(re_use(value, loop_var, simplify_and_optimize(body)))])
    elif isinstance(iterable, ECond):
        # Branch once on the condition and stream whichever collection applies.
        cond_setup, cond = simplify_and_optimize_expression(iterable.cond)
        return seq([
            cond_setup,
            SIf(cond,
                stream(iterable.then_branch, loop_var, body),
                stream(iterable.else_branch, loop_var, body))])
    elif isinstance(iterable, EUnaryOp) and iterable.op == UOp.Distinct:
        # Track the elements seen so far; run the body only for new ones.
        tmp = fresh_var(TSet(iterable.type.elem_type), "distinct_elems")
        return seq([
            SDecl(tmp, EEmptyList().with_type(tmp.type)),
            stream(iterable.e, loop_var, SIf(
                ENot(EBinOp(loop_var, BOp.In, tmp).with_type(BOOL)),
                # Arguments are a tuple, matching every other SCall built in
                # this file.
                seq([body, SCall(tmp, "add", (loop_var,))]),
                SNoOp()))])
    elif isinstance(iterable, EBinOp) and iterable.op == "+":
        # Concatenation: stream the left operand, then the right one.
        return seq([
            stream(iterable.e1, loop_var, body),
            stream(iterable.e2, loop_var, body)])
    elif isinstance(iterable, EBinOp) and iterable.op == "-":
        if is_hashable(iterable.type.elem_type):
            # Count the right operand's elements.  Each left element either
            # uses up one count (and is skipped) or reaches the body.
            h_setup, h = histogram(iterable.e2)
            val_ref = fresh_var(INT, "count")
            return seq([
                simplify_and_optimize(h_setup),
                stream(
                    iterable.e1,
                    loop_var,
                    SIf(EGt(EMapGet(h, loop_var).with_type(INT), ZERO),
                        SMapUpdate(h, loop_var, val_ref, SAssign(val_ref, EBinOp(val_ref, "-", ONE).with_type(INT))),
                        body))])
        else:
            # Same idea with a copy of the right operand instead of a
            # histogram: matching elements are removed from the copy.
            rhs = fresh_var(iterable.e2.type, "bag_subtraction_right")
            return seq([
                simplify_and_optimize(SDecl(rhs, iterable.e2)),
                stream(
                    iterable.e1,
                    loop_var,
                    SIf(EIn(loop_var, rhs),
                        SCall(rhs, "remove", (loop_var,)),
                        body))])
    elif isinstance(iterable, EFilter):
        # filter(p, xs) == flatmap(x -> p(x) ? [x] : [], xs)
        return stream(
            EFlatMap(iterable.e, ELambda(iterable.predicate.arg,
                ECond(iterable.predicate.body,
                    ESingleton(iterable.predicate.arg).with_type(iterable.type),
                    EEmptyList().with_type(iterable.type)).with_type(iterable.type))).with_type(iterable.type),
            loop_var,
            body)
    elif isinstance(iterable, EMap):
        # map(f, xs) == flatmap(x -> [f(x)], xs)
        return stream(
            EFlatMap(iterable.e, ELambda(iterable.transform_function.arg,
                ESingleton(iterable.transform_function.body).with_type(iterable.type))).with_type(iterable.type),
            loop_var,
            body)
    elif isinstance(iterable, EFlatMap):
        # Nested streams: the outer one walks the source collection, the
        # inner one walks the collection produced for each source element.
        inner_loop_var = fresh_var(
            iterable.transform_function.arg.type,
            iterable.transform_function.arg.id)
        return stream(
            iterable.e,
            inner_loop_var,
            stream(iterable.transform_function.apply_to(inner_loop_var), loop_var, body))
    elif isinstance(iterable, EListSlice):
        # Index-based while loop over the clamped range [start, end).
        elem_type = iterable.type.elem_type
        lst = fresh_var(iterable.e.type, "list")
        start = fresh_var(INT, "start")
        end = fresh_var(INT, "end")
        return simplify_and_optimize(seq([
            SDecl(lst, iterable.e),
            SDecl(start, max_of(iterable.start, ZERO)),
            SDecl(end, min_of(iterable.end, ELen(lst))),
            SWhile(ELt(start, end), seq([
                SDecl(loop_var, EListGet(lst, start).with_type(elem_type)),
                body,
                SAssign(start, EBinOp(start, "+", ONE).with_type(INT))]))]))
    elif isinstance(iterable, ELet):
        # Declare the bound value once, then stream the let's body.
        v = fresh_var(
            iterable.body_function.arg.type,
            iterable.body_function.arg.id)
        return seq([
            simplify_and_optimize(SDecl(v, iterable.e)),
            stream(iterable.body_function.apply_to(v), loop_var, body)])
    elif isinstance(iterable, EMove):
        return stream(iterable.e, loop_var, body)
    else:
        # No streaming rule applies: emit an ordinary loop over the
        # optimized iterable.
        assert is_collection(iterable.type), repr(iterable)
        setup, e = simplify_and_optimize_expression(iterable)
        return seq([setup, SForEach(loop_var, e, simplify_and_optimize(body))])
Exemple #18
0
 def visit_ECond(self, e):
     """Lower a conditional expression into an if-statement.

     Appends to `self.stms` a declaration of a fresh variable, initialized
     with `evaluation.construct_value(e.type)`, followed by an `SIf` that
     assigns it the then- or else-branch.  Returns the variable.
     """
     result = fresh_var(e.type, "conditional_result")
     init = SDecl(result, evaluation.construct_value(e.type))
     choose = SIf(
         e.cond,
         SAssign(result, e.then_branch),
         SAssign(result, e.else_branch))
     self.stms.append(simplify_and_optimize(seq([init, choose])))
     return result
Exemple #19
0
 def visit_iterable(self, e):
     """Materialize the collection `e` into a fresh variable.

     Appends to `self.stms` a declaration of a fresh, empty collection of
     `e`'s type and a loop over `e` that calls "add" on it for each element.
     Returns the fresh variable wrapped in `EMove`.
     """
     collection_type = e.type
     result = fresh_var(collection_type)
     self.stms.append(SDecl(result, EEmptyList().with_type(collection_type)))

     elem = fresh_var(collection_type.elem_type)
     copy_loop = SForEach(elem, e, SCall(result, "add", (elem,)))
     self.stms.append(simplify_and_optimize(copy_loop))

     return EMove(result).with_type(result.type)
Exemple #20
0
 def pull(e: Exp) -> Exp:
     """Replace `e` with a temporary when `exp_is_bad(e)` holds.

     A bad expression gets a fresh variable; a declaration binding that
     variable's id to `e` is appended to `decls_out` and the variable is
     returned.  Any other expression is returned unchanged.
     """
     if not exp_is_bad(e):
         return e
     tmp = fresh_var(e.type)
     decls_out.append(SDecl(tmp.id, e))
     return tmp
Exemple #21
0
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Hoist "bad" expressions out of `s` into temporaries.

    Each expression of `s` for which `exp_is_bad` returns True is replaced by
    a fresh variable, and the declarations of those variables are appended to
    `decls_out`.  The intent is that

        decls_out; result

    produces the same result as the original `s`.

    Raises NotImplementedError for statement kinds not handled here.
    """
    def pull(e: Exp) -> Exp:
        """Return `e` unchanged, or a freshly declared temporary if it is bad."""
        if not exp_is_bad(e):
            return e
        tmp = fresh_var(e.type)
        decls_out.append(SDecl(tmp, e))
        return tmp

    def recurse(stm, out):
        """Apply `pull_temps` to a sub-statement with the same predicate."""
        return pull_temps(stm, out, exp_is_bad)

    if isinstance(s, SNoOp):
        return s

    if isinstance(s, SSeq):
        first = recurse(s.s1, decls_out)
        second = recurse(s.s2, decls_out)
        return SSeq(first, second)

    if isinstance(s, SDecl):
        return SDecl(s.var, pull(s.val))

    if isinstance(s, SIf):
        cond = pull(s.cond)
        then_branch = recurse(s.then_branch, decls_out)
        else_branch = recurse(s.else_branch, decls_out)
        return SIf(cond, then_branch, else_branch)

    if isinstance(s, SForEach):
        loop_var = s.loop_var
        bag = pull(s.iter)
        # Temporaries from the body are collected separately: the ones that
        # mention the loop variable cannot simply be declared before the loop.
        body_decls = []
        body = recurse(s.body, body_decls)
        dependent, independent = partition(
            body_decls, lambda d: loop_var in free_vars(d.val))
        decls_out.extend(independent)
        for decl in dependent:
            # Precompute the value for every element as a map keyed by the
            # loop variable, and read it back inside the loop body.
            tmp = decl.var
            table_type = TMap(loop_var.type, tmp.type)
            table_value = EMakeMap2(
                bag, ELambda(loop_var, decl.val)).with_type(table_type)
            table = fresh_var(table_value.type)
            decls_out.append(SDecl(table, table_value))
            lookup = EMapGet(table, loop_var).with_type(tmp.type)
            body = subst(body, {tmp.id: lookup})
        return SForEach(loop_var, bag, body)

    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))

    if isinstance(s, SCall):
        pulled_args = tuple(pull(arg) for arg in s.args)
        return SCall(s.target, s.func, pulled_args)

    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))

    if isinstance(s, SMapPut):
        key = pull(s.key)
        value = pull(s.value)
        return SMapPut(s.map, key, value)

    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        change_decls = []
        change = recurse(s.change, change_decls)
        for decl in change_decls:
            if s.val_var in free_vars(decl.val):
                # The temporary mentions the update's value variable;
                # substitute the corresponding map lookup for it.
                current = EMapGet(s.map, key).with_type(s.val_var.type)
                decl = SDecl(decl.var, subst(decl.val, {s.val_var.id: current}))
            decls_out.append(decl)
        return SMapUpdate(s.map, key, s.val_var, change)

    raise NotImplementedError(s)