Example #1
0
def build_soft_fusion_kernel(loops, loop_chain_index):
    """
    Build the AST and :class:`Kernel` for a sequence of loops amenable to
    soft fusion.

    The kernel of the first loop acts as the base: its body is wrapped in a
    scoped block, and the body of every following kernel is appended to it
    as another scoped block (preceded by a ``// Fused kernel:`` comment).
    Function names are joined with ``_`` and the formal arguments of all
    kernels are concatenated into one signature. References to the
    arguments of the i-th appended kernel (i >= 1) get the suffix ``_i`` to
    avoid identifier clashes. Redundant arguments are finally removed by
    ``Filter().kernel_args``.

    :arg loops: the loops whose kernels are fused, in order.
    :arg loop_chain_index: forwarded unchanged to :class:`Kernel`.
    :return: a :class:`Kernel` built from all the original kernels and the
        fused AST.
    """
    kernels = [loop.kernel for loop in loops]

    fused_ast = dcopy(kernels[0]._ast)
    fused_decl = Find(ast.FunDecl).visit(fused_ast)[ast.FunDecl][0]
    # Give the base body its own scope, like every appended body below
    fused_decl.body[:] = [ast.Block(fused_decl.body, open_scope=True)]

    for suffix, kernel in enumerate(kernels[1:], start=1):
        other_ast = dcopy(kernel._ast)
        other_decl = Find(ast.FunDecl).visit(other_ast)[ast.FunDecl][0]

        # Merge function names and signatures
        fused_decl.name = "%s_%s" % (fused_decl.name, other_decl.name)
        fused_decl.args.extend(other_decl.args)

        # Uniquify every reference to the appended kernel's arguments.
        # fused_decl.args now holds these very Decl objects, so renaming
        # their symbols (if SymbolReferences reports the declarations too —
        # presumably it does) also updates the fused signature
        references = SymbolReferences().visit(other_ast)
        for arg_decl in other_decl.args:
            for ref, _ in references[arg_decl.sym.symbol]:
                ref.symbol = "%s_%d" % (ref.symbol, suffix)

        # Append the body in a fresh scope
        fused_decl.body.extend([ast.FlatBlock("\n\n// Fused kernel: \n\n"),
                                ast.Block(other_decl.body, open_scope=True)])

    # Eliminate redundancies in the /fused/ kernel signature
    Filter().kernel_args(loops, fused_decl)

    return Kernel(kernels, fused_ast, loop_chain_index)
Example #2
0
 def find_save(target_expr, expr_info):
     """Estimate what pre-evaluating ``target_expr`` would save.

     :arg target_expr: the (sub)expression under analysis.
     :arg expr_info: metadata of the enclosing expression; its
         ``out_linear_loops`` and ``linear_dims`` are read.
     :return: a 2-tuple ``(save_factor, save)``. ``save_factor`` is the
         product of the sizes of ``expr_info.out_linear_loops`` (1 if there
         are none). ``save`` is ``save_factor`` times the number of distinct
         symbol names in ``target_expr`` whose rank references
         ``expr_info.linear_dims[-1]``.
     """
     # Product of the outer linear loop sizes; an empty product is 1
     save_factor = reduce(operator.mul,
                          [loop.size for loop in expr_info.out_linear_loops] or [1])

     # The save factor is scaled by the number of terms that will /not/ be
     # pre-evaluated. Exploiting linearity in the terms depending on the
     # linear loops, these are the distinct symbols indexed by the innermost
     # linear dimension
     inner_names = set()
     for sym in Find(Symbol).visit(target_expr)[Symbol]:
         if any(r == expr_info.linear_dims[-1] for r in sym.rank):
             inner_names.add(sym.symbol)

     return save_factor, len(inner_names) * save_factor
Example #3
0
    def _multiple_ast_to_c(self, kernels):
        """Glue together different ASTs (or strings) such that: ::

            * clashes due to identical function names are avoided;
            * duplicate functions (same name, same body) are avoided.
        """
        code = ""
        identifier = lambda k: k.cache_key[1:]
        unsorted_kernels = sorted(kernels, key=identifier)
        for i, (_, kernel_group) in enumerate(
                groupby(unsorted_kernels, identifier)):
            duplicates = list(kernel_group)
            main = duplicates[0]
            if main._ast:
                main_ast = dcopy(main._ast)
                found = Find((ast.FunDecl, ast.FunCall)).visit(main_ast)
                for fundecl in found[ast.FunDecl]:
                    new_name = "%s_%d" % (fundecl.name, i)
                    # Need to change the name of any inner functions too
                    for funcall in found[ast.FunCall]:
                        if fundecl.name == funcall.funcall.symbol:
                            funcall.funcall.symbol = new_name
                    fundecl.name = new_name
                function_name = "%s_%d" % (main._name, i)
                code += sequential.Kernel._ast_to_c(main, main_ast, main._opts)
            else:
                # AST not available so can't change the name, hopefully there
                # will not be compile time clashes.
                function_name = main._name
                code += main._code
            # Finally track the function name within this /fusion.Kernel/
            for k in duplicates:
                try:
                    k._function_names[self.cache_key] = function_name
                except AttributeError:
                    k._function_names = {
                        k.cache_key: k.name,
                        self.cache_key: function_name
                    }
            code += "\n"

        # Tiled kernels are C++, and C++ compilers don't recognize /restrict/
        code = """
#define restrict __restrict

%s
""" % code

        return code
Example #4
0
 def __contains__(self, val):
     """Return True if ``val`` occurs in this container.

     ``val`` may be:

     * a :class:`Node`: it is compared (by string representation) against
       every node of the same type that ``Find`` locates within the
       :class:`Node` items;
     * a ``str``: it is compared against the string representation of every
       :class:`Symbol` within the :class:`Node` items, and against the
       ``str`` items directly.

     Any other kind of ``val`` is never contained.
     """
     from coffee.visitors import Find
     if isinstance(val, Node):
         # Search for nodes of the same type as /val/; the type must be taken
         # before /val/ is replaced by its string form. (``type(Node)`` would
         # be Node's metaclass, which AST node instances never match.)
         target, search = str(val), type(val)
     elif isinstance(val, str):
         target, search = val, Symbol
     else:
         return False
     for item in self:
         if isinstance(item, Node):
             found = Find(search).visit(item)
             if any(target == str(node) for node in found[search]):
                 return True
         elif isinstance(item, str) and target == item:
             return True
     return False
Example #5
0
def generate_integral_code(ir, prefix, parameters):
    "Generate code for integral from intermediate representation."

    info("Generating code from tsfc representation")

    # Generic FFC snippets are built from the caller's prefix and parameters
    code = initialize_integral_code(ir, prefix, parameters)

    # From here on, the TSFC compilation arguments stored in the IR are used.
    # The IR's parameters are copied (so the IR is not modified), and TSFC
    # goes unoptimized ("vanilla") if its mode has not been set yet
    integral_data, form_data, _, ir_parameters = ir["compile_integral"]
    tsfc_parameters = ir_parameters.copy()
    tsfc_parameters.setdefault("mode", "vanilla")

    # Generate tabulate_tensor body
    kernel_ast = compile_integral(integral_data, form_data, None,
                                  tsfc_parameters, interface=ufc_interface)

    # COFFEE vectorize
    ASTKernel(kernel_ast).plan_cpu(dict(optlevel='Ov'))

    body_code = "".join(node.gencode() for node in kernel_ast.body)
    code["tabulate_tensor"] = body_code.replace("#pragma coffee",
                                                "//#pragma coffee")  # FIXME

    includes = set(ir.get("additional_includes_set", ()))
    includes.update(kernel_ast.headers)
    includes.add("#include <cstring>")  # memset
    funcalls = Find(coffee.FunCall).visit(kernel_ast)[coffee.FunCall]
    if any(call.funcall.symbol.startswith("boost::math::") for call in funcalls):
        includes.add("#include <boost/math/special_functions.hpp>")
    code["additional_includes_set"] = includes

    return code
Example #6
0
    def _transpose_layout(self, decls):
        """Transpose the storage layout of multi-dimensional symbols accessed
        along this loop's iteration dimension.

        The symbols in ``self.loop`` that have ``dim > 1`` and whose rank
        references ``self.loop.dim`` are collected. If any of them has
        ``dim > 2``, nothing is done. Otherwise, for each distinct
        declaration (``decls[symbol name]``):

        * its initializer values are transposed;
        * its declared rank becomes the transposed shape;
        * the rank of every collected access to it is reversed.

        :arg decls: presumably a mapping from symbol names to their
            declarations (it is indexed by ``Symbol.symbol``).
        """
        loop_dim = self.loop.dim
        candidates = []
        for sym in Find(Symbol).visit(self.loop)[Symbol]:
            if any(r == loop_dim for r in sym.rank) and sym.dim > 1:
                candidates.append(sym)

        # Only 2-dimensional arrays are handled
        for sym in candidates:
            if sym.dim > 2:
                return

        # Group the accesses by declaration, keeping first-seen order
        accesses = OrderedDict()
        for sym in candidates:
            decl = decls[sym.symbol]
            if decl not in accesses:
                accesses[decl] = []
            accesses[decl].append(sym)

        for decl, decl_accesses in accesses.items():
            # Adjust the declaration
            flipped = decl.init.values.transpose()
            decl.init.values = flipped
            decl.sym.rank = flipped.shape
            # Adjust the accesses
            for sym in decl_accesses:
                sym.rank = tuple(reversed(sym.rank))
Example #7
0
def build_hard_fusion_kernel(base_loop, fuse_loop, fusion_map, loop_chain_index):
    """
    Build AST and :class:`Kernel` for two loops suitable to hard fusion.

    The AST consists of three functions: fusion, base, fuse. base and fuse
    are respectively the ``base_loop`` and the ``fuse_loop`` kernels, whereas
    fusion is the orchestrator that invokes, for each ``base_loop`` iteration,
    base and, if still to be executed, fuse.

    The orchestrator has the following structure: ::

        fusion (buffer, ..., executed):
            base (buffer, ...)
            for i = 0 to arity:
                if not executed[i]:
                    additional pointer staging required by kernel2
                    fuse (sub_buffer, ...)
                    insertion into buffer

    The executed array tracks whether the i-th iteration (out of /arity/)
    adjacent to the main kernel1 iteration has been executed.

    :arg base_loop: the loop whose kernel becomes /base/.
    :arg fuse_loop: the loop whose kernel becomes /fuse/.
    :arg fusion_map: the map relating the two loops; /fusion/ contains one
        (guarded) invocation of /fuse/ per unit of ``fusion_map.arity``.
    :arg loop_chain_index: forwarded unchanged to :class:`Kernel`.
    :return: a 2-tuple ``(kernel, fargs)``. ``kernel`` is a :class:`Kernel`
        wrapping the three functions. ``fargs`` maps positions in the /fusion/
        signature to tags. Arguments whose gather is postponed into /fusion/
        are tagged ``('postponed', False)``. Position
        ``len(set(binding.values()))`` is tagged ``('onlymap', True)``.
    """

    finder = Find((ast.FunDecl, ast.PreprocessNode))

    base = base_loop.kernel
    base_ast = dcopy(base._ast)
    base_info = finder.visit(base_ast)
    base_headers = base_info[ast.PreprocessNode]
    base_fundecl = base_info[ast.FunDecl]
    assert len(base_fundecl) == 1
    base_fundecl = base_fundecl[0]

    fuse = fuse_loop.kernel
    fuse_ast = dcopy(fuse._ast)
    fuse_info = finder.visit(fuse_ast)
    fuse_headers = fuse_info[ast.PreprocessNode]
    fuse_fundecl = fuse_info[ast.FunDecl]
    assert len(fuse_fundecl) == 1
    fuse_fundecl = fuse_fundecl[0]

    # Create /fusion/ arguments and signature
    body = ast.Block([])
    fusion_name = '%s_%s' % (base_fundecl.name, fuse_fundecl.name)
    fusion_args = dcopy(base_fundecl.args + fuse_fundecl.args)
    fusion_fundecl = ast.FunDecl(base_fundecl.ret, fusion_name, fusion_args, body)

    # Make sure kernel and variable names are unique
    base_fundecl.name = "%s_base" % base_fundecl.name
    fuse_fundecl.name = "%s_fuse" % fuse_fundecl.name
    for i, decl in enumerate(fusion_args):
        decl.sym.symbol += '_%d' % i

    # Filter out duplicate arguments, and append extra arguments to the fundecl
    binding = WeakFilter().kernel_args([base_loop, fuse_loop], fusion_fundecl)
    fusion_args += [ast.Decl('int*', 'executed'),
                    ast.Decl('int*', 'fused_iters'),
                    ast.Decl('int', 'i')]

    # Which args are actually used in /fuse/, but not in /base/ ? The gather for
    # such arguments is moved to /fusion/, to avoid useless memory LOADs
    base_dats = set(a.data for a in base_loop.args)
    fuse_dats = set(a.data for a in fuse_loop.args)
    unshared = OrderedDict()
    for arg, decl in binding.items():
        if arg.data in fuse_dats - base_dats:
            unshared.setdefault(decl, arg)

    # Track position of Args that need a postponed gather
    # Can't track Args themselves as they change across different parloops
    fargs = {fusion_args.index(i): ('postponed', False) for i in unshared.keys()}
    fargs.update({len(set(binding.values())): ('onlymap', True)})

    # Add maps for arguments that need a postponed gather
    for decl, arg in unshared.items():
        decl_pos = fusion_args.index(decl)
        fusion_args[decl_pos].sym.symbol = arg.c_arg_name()
        if arg._is_indirect:
            fusion_args[decl_pos].sym.rank = ()
            fusion_args.insert(decl_pos + 1, ast.Decl('int*', arg.c_map_name(0, 0)))

    # Append the invocation of /base/; then, proceed with the invocation
    # of the /fuse/ kernels
    base_funcall_syms = [binding[a].sym.symbol for a in base_loop.args]
    body.children.append(ast.FunCall(base_fundecl.name, *base_funcall_syms))

    for idx in range(fusion_map.arity):

        fused_iter = ast.Assign('i', ast.Symbol('fused_iters', (idx,)))
        fuse_funcall = ast.FunCall(fuse_fundecl.name)
        if_cond = ast.Not(ast.Symbol('executed', ('i',)))
        if_update = ast.Assign(ast.Symbol('executed', ('i',)), 1)
        if_body = ast.Block([fuse_funcall, if_update], open_scope=True)
        if_exec = ast.If(if_cond, [if_body])
        body.children.extend([ast.FlatBlock('\n'), fused_iter, if_exec])

        # Modify the /fuse/ kernel
        # This is to take into account that many arguments are shared with
        # /base/, so they will only staged once for /base/. This requires
        # tweaking the way the arguments are declared and accessed in /fuse/.
        # For example, the shared incremented array (called /buffer/ in
        # the pseudocode in the comment above) now needs to take offsets
        # to be sure the locations that /base/ is supposed to increment are
        # actually accessed. The same concept apply to indirect arguments.
        init = lambda v: '{%s}' % ', '.join([str(j) for j in v])
        for i, fuse_loop_arg in enumerate(fuse_loop.args):
            fuse_kernel_arg = binding[fuse_loop_arg]

            buffer_name = '%s_vec' % fuse_kernel_arg.sym.symbol
            fuse_funcall_sym = ast.Symbol(buffer_name)

            # What kind of temporaries do we need ?
            # NOTE(review): for access modes other than INC and READ, op,
            # lvalue, rvalue, stager, indexer and pointers are left unset (or
            # stale from a previous argument); such arguments are only safe
            # in the global and direct branches below, which do not use them
            if fuse_loop_arg.access == INC:
                op, lvalue, rvalue = ast.Incr, fuse_kernel_arg.sym.symbol, buffer_name
                stager = lambda b, l: b.children.extend(l)
                indexer = lambda indices: [(k, j) for j, k in enumerate(indices)]
                pointers = []
            elif fuse_loop_arg.access == READ:
                op, lvalue, rvalue = ast.Assign, buffer_name, fuse_kernel_arg.sym.symbol
                stager = lambda b, l: [b.children.insert(0, j) for j in reversed(l)]
                indexer = lambda indices: [(j, k) for j, k in enumerate(indices)]
                pointers = list(fuse_kernel_arg.pointers)

            # Now gonna handle arguments depending on their type and rank ...

            if fuse_loop_arg._is_global:
                # ... Handle global arguments. These can be dropped in the
                # kernel without any particular fiddling
                fuse_funcall_sym = ast.Symbol(fuse_kernel_arg.sym.symbol)

            elif fuse_kernel_arg in unshared:
                # ... Handle arguments that appear only in /fuse/
                staging = unshared[fuse_kernel_arg].c_vec_init(False).split('\n')
                rvalues = [ast.FlatBlock(j.split('=')[1]) for j in staging]
                lvalues = [ast.Symbol(buffer_name, (j,)) for j in range(len(staging))]
                staging = [ast.Assign(j, k) for j, k in zip(lvalues, rvalues)]

                # Set up the temporary
                buffer_symbol = ast.Symbol(buffer_name, (len(staging),))
                buffer_decl = ast.Decl(fuse_kernel_arg.typ, buffer_symbol,
                                       qualifiers=fuse_kernel_arg.qual,
                                       pointers=list(pointers))

                # Update the if-then AST body
                stager(if_exec.children[0], staging)
                if_exec.children[0].children.insert(0, buffer_decl)

            elif fuse_loop_arg._is_mat:
                # ... Handle Mats
                # One staging loop nest per block is accumulated in /staging/.
                # Fresh names are used for the per-block Symbols so that the
                # base lvalue/rvalue names are not rebound (and nested) from
                # one block to the next
                staging = []
                for b in fuse_loop_arg._block_shape:
                    for rc in b:
                        mat_lvalue = ast.Symbol(lvalue, (idx, idx),
                                                ((rc[0], 'j'), (rc[1], 'k')))
                        mat_rvalue = ast.Symbol(rvalue, ('j', 'k'))
                        staging.extend(ItSpace(mode=0).to_for([(0, rc[0]), (0, rc[1])],
                                                              ('j', 'k'),
                                                              [op(mat_lvalue, mat_rvalue)])[:1])

                # Set up the temporary
                # NOTE(review): the rank below wraps sym.rank in a further
                # tuple, while the initializer is 2D; confirm this is intended
                buffer_symbol = ast.Symbol(buffer_name, (fuse_kernel_arg.sym.rank,))
                buffer_init = ast.ArrayInit(init([init([0.0])]))
                buffer_decl = ast.Decl(fuse_kernel_arg.typ, buffer_symbol, buffer_init,
                                       qualifiers=fuse_kernel_arg.qual, pointers=pointers)

                # Update the if-then AST body
                stager(if_exec.children[0], staging)
                if_exec.children[0].children.insert(0, buffer_decl)

            elif fuse_loop_arg._is_indirect:
                cdim = fuse_loop_arg.data.cdim

                if cdim == 1 and fuse_kernel_arg.sym.rank:
                    # [Special case]
                    # ... Handle rank 1 indirect arguments that appear in both
                    # /base/ and /fuse/: just point into the right location
                    rank = (idx,) if fusion_map.arity > 1 else ()
                    fuse_funcall_sym = ast.Symbol(fuse_kernel_arg.sym.symbol, rank)

                else:
                    # ... Handle indirect arguments. At the C level, these arguments
                    # are of pointer type, so simple pointer arithmetic is used
                    # to ensure the kernel accesses are to the correct locations
                    fuse_arity = fuse_loop_arg.map.arity
                    base_arity = fuse_arity*fusion_map.arity
                    size = fuse_arity*cdim

                    # Set the proper storage layout before invoking /fuse/
                    ofs_vals = [[base_arity*j + k for k in range(fuse_arity)]
                                for j in range(cdim)]
                    ofs_vals = [[fuse_arity*j + k for k in flatten(ofs_vals)]
                                for j in range(fusion_map.arity)]
                    ofs_vals = list(flatten(ofs_vals))
                    indices = [ofs_vals[idx*size + j] for j in range(size)]

                    staging = [op(ast.Symbol(lvalue, (j,)), ast.Symbol(rvalue, (k,)))
                               for j, k in indexer(indices)]

                    # Set up the temporary
                    buffer_symbol = ast.Symbol(buffer_name, (size,))
                    if fuse_loop_arg.access == INC:
                        buffer_init = ast.ArrayInit(init([0.0]))
                    else:
                        buffer_init = ast.EmptyStatement()
                        pointers.pop()
                    buffer_decl = ast.Decl(fuse_kernel_arg.typ, buffer_symbol, buffer_init,
                                           qualifiers=fuse_kernel_arg.qual,
                                           pointers=pointers)

                    # Update the if-then AST body
                    stager(if_exec.children[0], staging)
                    if_exec.children[0].children.insert(0, buffer_decl)

            else:
                # Nothing special to do for direct arguments
                pass

            # Finally update the /fuse/ funcall
            fuse_funcall.children.append(fuse_funcall_sym)

    fused_headers = set([str(h) for h in base_headers + fuse_headers])
    fused_ast = ast.Root([ast.PreprocessNode(h) for h in fused_headers] +
                         [base_fundecl, fuse_fundecl, fusion_fundecl])

    return Kernel([base, fuse], fused_ast, loop_chain_index), fargs
Example #8
0
    def _dissect(self, heuristics):
        """Analyze the set of expressions in the LoopOptimizer and infer an
        optimal rewrite mode for each of them.

        If an expression is embedded in a non-perfect loop nest, then injection
        may be performed. Injection consists of unrolling any loops outside of
        the expression iteration space into the expression itself.
        For example: ::

            for i
              for r
                a += B[r]*C[i][r]
              for j
                for k
                  A[j][k] += ...f(a)... // the expression at hand

        gets transformed into:

            for i
              for j
                for k
                  A[j][k] += ...f(B[0]*C[i][0] + B[1]*C[i][1] + ...)...

        Injection could be necessary to maximize the impact of rewrite mode=3,
        which tries to pre-evaluate subexpressions whose values are known at
        code generation time. Injection is essential to factorize such subexprs.

        On exit, every expression in ``self.exprs`` has ``mode`` 4. The
        exception is the sub-expressions split off from injected statements,
        which get ``mode`` 3 when matched by the fissioner.

        :arg heuristics: any value in ['greedy', 'aggressive']. With 'greedy', a greedy
            approach is used to decide which of the expressions for which injection
            looks beneficial should be dissected (e.g., injection increases the memory
            footprint, and some memory constraints must always be preserved).
            With 'aggressive', the whole space of possibilities is analyzed.
        """
        # The memory threshold. The total size of temporaries will not have to
        # be greater than this value. If we predict that injection will lead
        # to too much temporary space, we have to partially drop it
        threshold = system.architecture['cache_size'] * 1.2

        expr_graph = ExpressionGraph(self.header)

        # 1) Find out and unroll injectable loops. For unrolling we create new
        # expressions; that is, for now, we do not modify the AST in place.
        analyzed, injectable = [], {}
        for stmt, expr_info in self.exprs.items():
            # Get all loop nests, then discard the one enclosing the expression
            nests = [n for n in visit(expr_info.loops_parents[0])['fors']]
            injectable_nests = [n for n in nests if list(zip(*n))[0] != expr_info.loops]

            for nest in injectable_nests:
                to_unroll = [(l, p) for l, p in nest if l not in expr_info.loops]
                unroll_cost = reduce(operator.mul, (l.size for l, p in to_unroll))

                nest_writers = Find(Writer).visit(to_unroll[0][0])
                for op, i_stmts in nest_writers.items():
                    # Check safety of unrolling
                    if op in [Assign, IMul, IDiv]:
                        continue
                    assert op in [Incr, Decr]

                    for i_stmt in i_stmts:
                        i_sym, i_expr = i_stmt.children

                        # Avoid injecting twice the same loop
                        if i_stmt in analyzed + [l.incr for l, p in to_unroll]:
                            continue
                        analyzed.append(i_stmt)

                        # Create unrolled, injectable expressions
                        for l, p in reversed(to_unroll):
                            i_expr = [dcopy(i_expr) for i in range(l.size)]
                            for i, e in enumerate(i_expr):
                                e_syms = Find(Symbol).visit(e)[Symbol]
                                for s in e_syms:
                                    s.rank = tuple([r if r != l.dim else i for r in s.rank])
                            i_expr = ast_make_expr(Sum, i_expr)

                        # Track the unrolled, injectable expressions and their cost
                        if i_sym.symbol in injectable:
                            old_i_expr, old_cost = injectable[i_sym.symbol]
                            new_i_expr = ast_make_expr(Sum, [i_expr, old_i_expr])
                            new_cost = unroll_cost + old_cost
                            injectable[i_sym.symbol] = (new_i_expr, new_cost)
                        else:
                            injectable[i_sym.symbol] = (i_expr, unroll_cost)

        # 2) Will rewrite mode=3 be cheaper than rewrite mode=2?
        def find_save(target_expr, expr_info):
            save_factor = [l.size for l in expr_info.out_linear_loops] or [1]
            save_factor = reduce(operator.mul, save_factor)
            # The save factor should be multiplied by the number of terms
            # that will /not/ be pre-evaluated. To obtain this number, we
            # can exploit the linearity of the expression in the terms
            # depending on the linear loops.
            syms = Find(Symbol).visit(target_expr)[Symbol]
            inner = lambda s: any(r == expr_info.linear_dims[-1] for r in s.rank)
            nterms = len(set(s.symbol for s in syms if inner(s)))
            save = nterms * save_factor
            return save_factor, save

        should_unroll = True
        storage = 0
        # Snapshot of the injectable symbol names; it must stay unchanged
        # across all the statements analyzed below
        i_syms, injected = list(injectable.keys()), defaultdict(list)
        for stmt, expr_info in self.exprs.items():
            sym, expr = stmt.children

            # Divide /expr/ into subexpressions, each subexpression affected
            # differently by injection
            if i_syms:
                dissected = find_expression(expr, Prod, expr_info.linear_dims, i_syms)
                leftover = find_expression(expr, dims=expr_info.linear_dims, out_syms=i_syms)
                leftover = {(): list(flatten(leftover.values()))}
                # list() on both sides keeps this valid on Python 3, where
                # dict views do not support ``+``
                dissected = dict(list(dissected.items()) + list(leftover.items()))
            else:
                dissected = {(): [expr]}
            if any(i not in flatten(dissected.keys()) for i in i_syms):
                should_unroll = False
                continue

            # Apply the profitability model
            analysis = OrderedDict()
            # /target_syms/ must not rebind /i_syms/, which is reused by the
            # following iterations of the enclosing loop
            for target_syms, target_exprs in dissected.items():
                for target_expr in target_exprs:

                    # *** Save ***
                    save_factor, save = find_save(target_expr, expr_info)

                    # *** Cost ***
                    # The number of operations increases by a factor which
                    # corresponds to the number of possible /combinations with
                    # repetitions/ in the injected-values set. We consider
                    # combinations and not dispositions to take into account the
                    # (future) effect of factorization.
                    retval = ProjectExpansion.default_retval()
                    projection = ProjectExpansion(target_syms).visit(target_expr, ret=retval)
                    projection = [i for i in projection if i]
                    increase_factor = 0
                    for i in projection:
                        partial = 1
                        for j in expr_graph.shares(i):
                            # _n=number of unique elements, _k=group size
                            _n = injectable[j[0]][1]
                            _k = len(j)
                            partial *= fact(_n + _k - 1)//(fact(_k)*fact(_n - 1))
                        increase_factor += partial
                    increase_factor = increase_factor or 1
                    if increase_factor > save_factor:
                        # We immediately give up if this holds since it ensures
                        # that /cost > save/ (but not that cost <= save)
                        should_unroll = False
                        continue
                    # The increase factor should be multiplied by the number of
                    # terms that will be pre-evaluated. To obtain this number,
                    # we need to project the output of factorization.
                    fake_stmt = stmt.__class__(stmt.children[0], dcopy(target_expr))
                    fake_parent = expr_info.parent.children
                    fake_parent[fake_parent.index(stmt)] = fake_stmt
                    ew = ExpressionRewriter(fake_stmt, expr_info)
                    ew.expand(mode='all').factorize(mode='all').factorize(mode='linear')
                    nterms = ew.licm(mode='aggressive', look_ahead=True)
                    nterms = len(uniquify(nterms[expr_info.dims])) or 1
                    fake_parent[fake_parent.index(fake_stmt)] = stmt
                    cost = nterms * increase_factor

                    # Pre-evaluation will also increase the working set size by
                    # /cost/ * /sizeof(term)/.
                    size = [l.size for l in expr_info.linear_loops]
                    size = reduce(operator.mul, size, 1)
                    storage_increase = cost * size * system.architecture[expr_info.type]

                    # Track the injectable sub-expression and its cost/save. The
                    # final decision of whether to actually perform injection or not
                    # is postponed until all dissected expressions have been analyzed
                    analysis[target_expr] = (cost, save, storage_increase)

            # So what should we inject afterall ? Time to *use* the cost model
            if heuristics == 'greedy':
                for target_expr, (cost, save, storage_increase) in analysis.items():
                    if cost > save or storage_increase + storage > threshold:
                        should_unroll = False
                    else:
                        # Update the available storage
                        storage += storage_increase
                        # At this point, we can happily inject
                        to_replace = {k: v[0] for k, v in injectable.items()}
                        ast_replace(target_expr, to_replace, copy=True)
                        injected[stmt].append(target_expr)
            elif heuristics == 'aggressive':
                # A) Remove expression that we already know should never be injected
                # (iterate over a snapshot, since /analysis/ is shrunk in the loop)
                not_injected = []
                for target_expr, (cost, save, storage_increase) in list(analysis.items()):
                    if cost > save:
                        should_unroll = False
                        analysis.pop(target_expr)
                        not_injected.append(target_expr)
                # B) Find all possible bipartitions: each bipartition represents
                # the set of expressions that will be pre-evaluated and the set
                # of expressions that could also be pre-evaluated, but might not
                # (e.g. because of memory constraints)
                target_exprs = list(analysis.keys())
                bipartitions = []
                for i in range(len(target_exprs)+1):
                    for e1 in combinations(target_exprs, i):
                        bipartitions.append((e1, tuple(e2 for e2 in target_exprs
                                                       if e2 not in e1)))
                # C) Eliminate those bipartitions that would lead to exceeding
                # the memory threshold
                bipartitions = [(e1, e2) for e1, e2 in bipartitions if
                                sum(analysis[i][2] for i in e1) <= threshold]
                # D) Find out what is best to pre-evaluate (and therefore
                # what should be injected)
                totals = OrderedDict()
                for e1, e2 in bipartitions:
                    # Is there any value in actually not pre-evaluating the
                    # expressions in /e2/ ?
                    fake_expr = ast_make_expr(Sum, list(e2) + not_injected)
                    _, save = find_save(fake_expr, expr_info) if fake_expr else (0, 0)
                    cost = sum(analysis[i][0] for i in e1)
                    totals[(e1, e2)] = save + cost
                best = min(totals, key=totals.get)
                # At this point, we can happily inject
                to_replace = {k: v[0] for k, v in injectable.items()}
                for target_expr in best[0]:
                    ast_replace(target_expr, to_replace, copy=True)
                    injected[stmt].append(target_expr)
                if best[1]:
                    # At least one non-injected expressions, let's be sure we
                    # don't unroll everything
                    should_unroll = False

        # 3) Purge the AST from now useless symbols/expressions
        if should_unroll:
            decls = visit(self.header, info_items=['decls'])['decls']
            for stmt, expr_info in self.exprs.items():
                nests = [n for n in visit(expr_info.loops_parents[0])['fors']]
                injectable_nests = [n for n in nests if list(zip(*n))[0] != expr_info.loops]
                for nest in injectable_nests:
                    unrolled = [(l, p) for l, p in nest if l not in expr_info.loops]
                    for l, p in unrolled:
                        p.children.remove(l)
                        for i_sym in injectable.keys():
                            decl = decls.get(i_sym)
                            if decl and decl in p.children:
                                p.children.remove(decl)

        # 4) Split the expressions if injection has been performed
        # (iterate over a snapshot, since /self.exprs/ is modified in the loop)
        for stmt, expr_info in list(self.exprs.items()):
            expr_info.mode = 4
            inj_exprs = injected.get(stmt)
            if not inj_exprs:
                continue
            fissioner = ExpressionFissioner(match=inj_exprs, loops='all', perfect=True)
            new_exprs = fissioner.fission(stmt, self.exprs.pop(stmt))
            self.exprs.update(new_exprs)
            for new_stmt, new_info in new_exprs.items():
                new_info.mode = 3 if new_stmt in fissioner.matched else 4
Example #9
0
    def plan_cpu(self, opts):
        """Optimize this :class:`ASTKernel` for CPU execution.

        :param opts: a dictionary of optimizations to be applied. For a description
            of the recognized optimizations, please refer to the ``coffee.set_opt_level``
            documentation. If equal to ``None``, the default optimizations in
            ``coffee.options['optimizations']`` are applied; these are either the
            optimizations set when COFFEE was initialized or those changed through
            a call to ``set_opt_level``. In this way, a default set of optimizations
            is applied to all kernels, but users are also allowed to select
            specific transformations for individual kernels.

        Some combinations of optimizations are forbidden (dead-operations
        elimination together with splitting, or together with vectorization);
        in that case a warning is emitted and the method returns without
        transforming the AST.
        """

        start_time = time.time()

        kernels = Find(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl]

        if opts is None:
            opts = coffee.OptimizationLevel.retrieve(
                coffee.options['optimizations'])
        else:
            opts = coffee.OptimizationLevel.retrieve(opts.get('optlevel', {}))

        # The selected optimizations do not depend on the kernel, so they are
        # read (and validated) once rather than for every kernel
        rewrite_opt = opts.get('rewrite')
        # Tolerate both a missing key and an explicit ``None`` value, since
        # ``vectorize[0]`` is indexed below
        vectorize = opts.get('vectorize') or (None, None)
        align_pad = opts.get('align_pad')
        split = opts.get('split')
        dead_ops_elimination = opts.get('dead_ops_elimination')

        # Combining certain optimizations is forbidden.
        if dead_ops_elimination and split:
            warn("Split forbidden with dead-ops elimination")
            return
        if dead_ops_elimination and vectorize[0]:
            warn("Vect forbidden with dead-ops elimination")
            return

        flops_pre = EstimateFlops().visit(self.ast)

        for kernel in kernels:
            # 'auto' may be downgraded for this kernel only (see below)
            rewrite = rewrite_opt

            info = visit(kernel, info_items=['decls', 'exprs'])
            # Collect expressions and related metadata, grouped by the first
            # (loop, header) pair of the nest enclosing each expression
            nests = defaultdict(OrderedDict)
            for stmt, (parent, nest) in info['exprs'].items():
                if not nest:
                    continue
                if kernel.template:
                    typ = "double"
                else:
                    typ = check_type(stmt, info['decls'])
                nests[nest[0]][stmt] = MetaExpr(typ, parent, nest)
            loop_opts = [
                CPULoopOptimizer(loop, header, exprs)
                for (loop, header), exprs in nests.items()
            ]

            if rewrite == 'auto' and len(info['exprs']) > 1:
                warn("Rewrite auto forbidden with multiple exprs")
                rewrite = 4

            # Main optimization pipeline
            for loop_opt in loop_opts:

                # 0) Expression Rewriting
                if rewrite:
                    loop_opt.rewrite(rewrite)

                # 1) Dead-operations elimination
                if dead_ops_elimination:
                    loop_opt.eliminate_zeros()

                # 2) Code specialization
                if split:
                    loop_opt.split(split)
                if coffee.initialized and flatten(loop_opt.expr_linear_loops):
                    vect = LoopVectorizer(loop_opt, kernel)
                    if align_pad:
                        # Padding and data alignment
                        vect.autovectorize()
                    if vectorize[0] and vectorize[0] != VectStrategy.AUTO:
                        # Specialize vectorization for the memory access pattern
                        # of the expression
                        vect.specialize(*vectorize)

            # Ensure kernel is always marked static inline
            # Remove either or both of static and inline (so that we get the order right)
            kernel.pred = [
                q for q in kernel.pred if q not in ['static', 'inline']
            ]
            kernel.pred.insert(0, 'inline')
            kernel.pred.insert(0, 'static')

            # Post processing of the AST ensures higher-quality code
            postprocess(kernel)

        flops_post = EstimateFlops().visit(self.ast)

        tot_time = time.time() - start_time

        output = "COFFEE finished in %g seconds (flops: %d -> %d)" % \
            (tot_time, flops_pre, flops_post)
        log(output, PERF_OK if flops_post <= flops_pre else PERF_WARN)
Example #10
0
    def plan_gpu(self):
        """Transform the kernel suitably for GPU execution.

        Loops decorated with a ``pragma coffee itspace`` are hoisted out of
        the kernel. The list of arguments in the function signature is
        enriched by adding iteration variables of hoisted loops. The size of any
        kernel's non-constant tensor is modified accordingly. Finally, the
        ``static`` and ``const`` qualifiers are stripped from the kernel's
        declarations, and ``static``/``inline`` from the kernel itself.

        For example, consider the following function: ::

            void foo (int A[3]) {
              int B[3] = {...};
              #pragma coffee itspace
              for (int i = 0; i < 3; i++)
                A[i] = B[i];
            }

        plan_gpu modifies its AST such that the resulting output code is ::

            void foo(int A[1], int i) {
              int B[3] = {...};
              A[0] = B[i];
            }
        """

        # The optimization passes are performed individually (i.e., "locally") for
        # each function (or "kernel") found in the provided AST
        kernels = Find(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl]

        for kernel in kernels:
            info = visit(kernel, info_items=['decls', 'exprs'])
            # Maps symbol names to their Decl nodes
            decls = info['decls']
            # Collect expressions and related metadata
            nests = defaultdict(OrderedDict)
            for stmt, expr_info in info['exprs'].items():
                parent, nest = expr_info
                if not nest:
                    continue
                if kernel.template:
                    typ = "double"
                else:
                    typ = check_type(stmt, decls)
                metaexpr = MetaExpr(typ, parent, nest)
                nests[nest[0]].update({stmt: metaexpr})
            loop_opts = [
                GPULoopOptimizer(loop, header, exprs)
                for (loop, header), exprs in nests.items()
            ]

            for loop_opt in loop_opts:
                itspace_vrs, accessed_vrs = loop_opt.extract()

                for v in accessed_vrs:
                    # Change declaration of non-constant iteration space-dependent
                    # parameters by shrinking the size of the iteration space
                    # dimension to 1
                    matching = set(
                        [d for d in kernel.args if d.sym.symbol == v.symbol])
                    dsym = matching.pop().sym if len(matching) > 0 else None
                    if dsym and dsym.rank:
                        dsym.rank = tuple([
                            1 if i in itspace_vrs else j
                            for i, j in zip(v.rank, dsym.rank)
                        ])

                    # Remove indices of all iteration space-dependent and
                    # kernel-dependent variables that are accessed in an itspace
                    v.rank = tuple([
                        0 if i in itspace_vrs and dsym else i for i in v.rank
                    ])

                # Add iteration space arguments
                kernel.args.extend(
                    [Decl("int", Symbol("%s" % i)) for i in itspace_vrs])

            # Clean up the kernel removing variable qualifiers like 'static'.
            # /decls/ maps names directly to Decl nodes (no (decl, place) pairs)
            for decl in decls.values():
                decl.qual = [q for q in decl.qual if q not in ['static', 'const']]

            kernel.pred = [
                q for q in kernel.pred if q not in ['static', 'inline']
            ]
Example #11
0
    def _align_data(self, p_dim, decls):
        """Apply data alignment. This boils down to:

            * Decorate declarations of local arrays with qualifiers for data
            alignment
            * Round up the end point of inner loops such that all memory
            accesses get aligned to the vector length. Several checks ensure
            the correctness of the transformation; if any of them fails, the
            loop is left untouched.

        :arg p_dim: the index of the padded dimension, i.e. the fastest varying
            dimension that must be iterated by the rounded loop.
        :arg decls: a dictionary mapping symbol names to their declarations.
        """
        vector_length = system.isa["dp_reg"]
        align = system.compiler['align'](system.isa['alignment'])

        # Array alignment
        for decl in decls.values():
            if decl.sym.rank and decl.scope == LOCAL:
                decl.attr.append(align)

        def extra_iterations_are_safe(s, l):
            """Condition E for symbol /s/: the extra iterations induced by
            rounding the end of loop /l/ must not alter the computation."""
            s_decl = decls[s.symbol]
            index = s.rank.index(l.dim)
            stride = s.strides[index]
            extra = list(range(stride + l.size, stride + vect_roundup(l.size)))
            # Valid indices along /index/ are [0, size): reaching /size/ is
            # already outside of the legal region, abort
            if any(i >= s_decl.size[index] for i in extra):
                return False
            if all(i >= s_decl.core[index] for i in extra):
                # ... in the padded region, pass
                return True
            nz = list(self.nz_syms.get(s.symbol, []))
            if not nz:
                # ... lacks the non zero-valued entries mapping, abort
                return False
            # ... get the non zero-valued entries in the right dimension,
            # skipping those lying on a different outer (constant) dimension
            nz_index = []
            for region in nz:
                on_other_outer = any(
                    is_const_dim(r) and
                    not (region[j].ofs <= r < region[j].ofs + region[j].size)
                    for j, r in enumerate(s.rank[:index]))
                if not on_other_outer:
                    nz_index.append(region[index])
            # ... writing to a non-zero region, abort
            return not any(ofs <= i < ofs + size
                           for size, ofs in nz_index for i in extra)

        def can_round(l):
            """Return True if the end of loop /l/ can be safely rounded up."""
            for stmt in l.body:
                sym, expr = stmt.lvalue, stmt.rvalue
                lvalue_decl = decls[sym.symbol]

                # Condition A: the lvalue can be a scalar only if /stmt/ is not an
                # augmented assignment, otherwise the extra iterations would alter
                # its value
                if not sym.rank and isinstance(stmt, AugmentedAssign):
                    return False

                # Condition B: the fastest varying dimension of the lvalue must be /l/
                if sym.rank and sym.rank[p_dim] != l.dim:
                    return False

                # Condition C: the lvalue must have been padded
                if sym.rank and lvalue_decl.size[p_dim] != vect_roundup(
                        lvalue_decl.size[p_dim]):
                    return False

                symbols = [sym] + Find(Symbol).visit(expr)[Symbol]
                symbols = [
                    s for s in symbols
                    if s.rank and any(r == l.dim for r in s.rank)
                ]

                # Condition D: the access pattern must be inferable (unit period)
                if any(not s.is_unit_period for s in symbols):
                    return False

                # Condition E: extra iterations induced by bounds and offset
                # rounding must not alter the computation
                if not all(extra_iterations_are_safe(s, l) for s in symbols):
                    return False
            return True

        # Loop bounds adjustment
        for l in inner_loops(self.header):
            if not can_round(l):
                continue
            l.end = vect_roundup(l.end)
            if all(i % vector_length == 0 for i in [l.start, l.size]):
                l.pragma.add(system.compiler["align_forloop"])
                l.pragma.add(system.compiler['force_simdization'])
Example #12
0
    def _pad(self, p_dim, decls, fors, symbols_dep, symbols_mode, symbol_refs):
        """Apply padding.

        Local arrays are padded in place, rounding the size of dimension
        ``p_dim`` up with ``vect_roundup``. For non-local arrays (function
        arguments), if their accesses allow it, a padded buffer named after the
        argument (prefixed by ``_``) is declared at the top of ``self.header``,
        the accesses are redirected to it, and loop nests copying data into the
        buffer (READ mode) or back from it (WRITE mode) are inserted.

        :arg p_dim: the index of the dimension to be padded.
        :arg decls: a dictionary mapping symbol names to declarations; any
            buffer created here is added to it.
        :arg fors: loop nests, flattened into (loop, parent) pairs.
        :arg symbols_dep: maps each symbol to the loops it depends on.
        :arg symbols_mode: maps each symbol to a tuple whose first entry is
            its access mode.
        :arg symbol_refs: maps a symbol name to (symbol, parent) references.
        :return: the last buffer declaration created, or ``None`` if none was
            created (always ``None`` if a loop has a non-unit increment).
        """
        to_invert = Find(Invert).visit(self.header)[Invert]

        # Loop increments different than 1 are unsupported
        if any(l.increment != 1 for l, _ in flatten(fors)):
            return None

        DSpace = namedtuple('DSpace', ['region', 'nest', 'symbols'])
        ISpace = namedtuple('ISpace', ['region', 'nest', 'bag'])

        buf_decl = None
        # Iterate over a snapshot: buffers are added to /decls/ below, and
        # growing a dict while iterating over it is an error
        for decl_name, decl in list(decls.items()):
            if not decl.size or decl.is_pointer_type:
                continue

            p_rank = decl.size[:p_dim] + (vect_roundup(decl.size[p_dim]), )
            if decl.size[p_dim] == 1 or p_rank == decl.size:
                continue

            if decl.scope == LOCAL:
                decl.pad(p_rank)
                continue

            # At this point we are sure /decl/ is a FunDecl argument

            # A) Can a buffer actually be allocated ?
            symbols = [
                s for s, _ in symbol_refs[decl_name] if s is not decl.sym
            ]
            if not all(s.dim == decl.sym.dim and s.is_const_offset
                       for s in symbols):
                continue
            periods = flatten([s.periods for s in symbols])
            if not all(p == 1 for p in periods):
                continue

            # ... must be either READ or WRITE mode
            modes = [symbols_mode[s][0] for s in symbols]
            if not modes or any(m != modes[0] for m in modes):
                continue
            mode = modes[0]
            if mode not in [READ, WRITE]:
                continue

            # ... accesses to entries in /decl/ must be explicit in all loop nests
            deps = OrderedDict(
                (s, [l for l in symbols_dep[s] if l.dim in s.rank])
                for s in symbols)
            if not all(s.dim == len(n) for s, n in deps.items()):
                continue

            # ... organize symbols based on their dataspace: symbols accessing
            # the same region must end up in the same DSpace
            dspace_mapper = OrderedDict()
            for s, n in deps.items():
                n.sort(key=lambda l: s.rank.index(l.dim))
                region = tuple(
                    Region(l.size, l.start + i) for i, l in zip(s.strides, n))
                dspace = DSpace(region=region, nest=n, symbols=[])
                dspace = dspace_mapper.setdefault(dspace.region, dspace)
                dspace.symbols.append(s)

            # ... is there any overlap in the memory accesses? Memory accesses must:
            # - either completely overlap (they will be mapped to the same buffer)
            # - OR be disjoint
            regions = list(dspace_mapper.keys())
            if any(ItSpace(mode=1).intersect([r1, r2]) not in [(0, 0), r1]
                   for regions1, regions2 in product(regions, regions)
                   for r1, r2 in zip(regions1, regions2)):
                continue

            # ... initialize buffer-related data
            buf_name = '_' + decl_name
            buf_nz = self.nz_syms.setdefault(buf_name, [])

            # ... determine the non zero-valued region in the buffer
            for n, region in enumerate(dspace_mapper.keys()):
                p_region = (Region(region[p_dim].size, 0), )
                buf_nz.append((Region(1, n), ) + region[:p_dim] + p_region)

            # ... replace symbols in the AST with proper buffer instances
            itspace_mapper = OrderedDict()
            for n, dspace in enumerate(dspace_mapper.values()):
                itspace = ISpace(region=tuple(
                    (l.size, l.start) for l in dspace.nest),
                                 nest=dspace.nest,
                                 bag=OrderedDict())
                itspace = itspace_mapper.setdefault(itspace.region, itspace)
                for s in dspace.symbols:
                    original = Symbol(s.symbol, s.rank, s.offset)
                    s.symbol = buf_name
                    s.rank = (n, ) + s.rank
                    s.offset = ((1, 0), ) + s.offset[:p_dim] + ((1, 0), )
                    if s.urepr not in [i.urepr for i in itspace.bag.values()]:
                        itspace.bag[original] = Symbol(s.symbol, s.rank,
                                                       s.offset)

            # ... insert the buffer into the AST; its outermost dimension has
            # one slot per distinct dataspace
            buf_dim = len(dspace_mapper)
            buf_rank = (buf_dim, ) + decl.size
            init = ArrayInit(
                np.ndarray(shape=(1, ) * len(buf_rank), buffer=np.array(0.0)))
            buf_decl = Decl(decl.typ,
                            Symbol(buf_name, buf_rank),
                            init,
                            scope=BUFFER)
            buf_decl.pad((buf_dim, ) + p_rank)
            self.header.children.insert(0, buf_decl)

            # C) Create a loop nest for copying data into/from the buffer
            for itspace in itspace_mapper.values():

                if mode == READ:
                    stmts = [Assign(b, s) for s, b in itspace.bag.items()]
                    copy_back = ItSpace(mode=2).to_for(itspace.nest,
                                                       stmts=stmts)
                    insert_at_elem(self.header.children,
                                   buf_decl,
                                   copy_back[0],
                                   ofs=1)

                elif mode == WRITE:
                    # If extra information (a pragma) is present, telling that
                    # the argument does not need to be incremented because it does
                    # not contain any meaningful values, then we can safely write
                    # to it. This optimization may avoid useless increments
                    can_write = WRITE in decl.pragma and len(
                        itspace_mapper) == 1
                    op = Assign if can_write else Incr
                    stmts = [op(s, b) for s, b in itspace.bag.items()]
                    copy_back = ItSpace(mode=2).to_for(itspace.nest,
                                                       stmts=stmts)
                    if to_invert:
                        insert_at_elem(self.header.children, to_invert[0],
                                       copy_back[0])
                    else:
                        self.header.children.append(copy_back[0])

            # D) Update the global data structures
            decls[buf_name] = buf_decl

        return buf_decl
Example #13
0
    def _dissect(self, heuristics):
        """Analyze the set of expressions in the LoopOptimizer and infer an
        optimal rewrite mode for each of them.

        If an expression is embedded in a non-perfect loop nest, then injection
        may be performed. Injection consists of unrolling any loops outside of
        the expression iteration space into the expression itself.
        For example: ::

            for i
              for r
                a += B[r]*C[i][r]
              for j
                for k
                  A[j][k] += ...f(a)... // the expression at hand

        gets transformed into:

            for i
              for j
                for k
                  A[j][k] += ...f(B[0]*C[i][0] + B[1]*C[i][1] + ...)...

        Injection could be necessary to maximize the impact of rewrite mode=3,
        which tries to pre-evaluate subexpressions whose values are known at
        code generation time. Injection is essential to factorize such subexprs.

        On exit, each expression in ``self.exprs`` has ``mode`` set to 3
        (injected, to be pre-evaluated) or 4; expressions that underwent
        injection are split into separate expressions.

        :arg heuristics: any value in ['greedy', 'aggressive']. With 'greedy', a
            greedy approach is used to decide which of the expressions for which
            injection looks beneficial should be dissected (e.g., injection
            increases the memory footprint, and some memory constraints must
            always be preserved). With 'aggressive', the whole space of
            possibilities is analyzed.
        """
        # The memory threshold. The total size of temporaries must not be
        # greater than this value. If we predict that injection will lead
        # to too much temporary space, we have to partially drop it
        threshold = system.architecture['cache_size'] * 1.2

        expr_graph = ExpressionGraph(self.header)

        # 1) Find out and unroll injectable loops. For unrolling we create new
        # expressions; that is, for now, we do not modify the AST in place.
        analyzed, injectable = [], {}
        for stmt, expr_info in self.exprs.items():
            # Get all loop nests, then discard the one enclosing the expression
            nests = [n for n in visit(expr_info.loops_parents[0])['fors']]
            injectable_nests = [
                n for n in nests if list(zip(*n))[0] != expr_info.loops
            ]

            for nest in injectable_nests:
                to_unroll = [(l, p) for l, p in nest
                             if l not in expr_info.loops]
                if not to_unroll:
                    # No loop outside of the expression iteration space
                    continue
                unroll_cost = reduce(operator.mul,
                                     (l.size for l, p in to_unroll), 1)

                nest_writers = Find(Writer).visit(to_unroll[0][0])
                for op, i_stmts in nest_writers.items():
                    # Check safety of unrolling
                    if op in [Assign, IMul, IDiv]:
                        continue
                    assert op in [Incr, Decr]

                    for i_stmt in i_stmts:
                        i_sym, i_expr = i_stmt.children

                        # Avoid injecting twice the same loop
                        if i_stmt in analyzed + [l.incr for l, p in to_unroll]:
                            continue
                        analyzed.append(i_stmt)

                        # Create unrolled, injectable expressions
                        for l, p in reversed(to_unroll):
                            i_expr = [dcopy(i_expr) for i in range(l.size)]
                            for i, e in enumerate(i_expr):
                                e_syms = Find(Symbol).visit(e)[Symbol]
                                for s in e_syms:
                                    s.rank = tuple([
                                        r if r != l.dim else i for r in s.rank
                                    ])
                            i_expr = ast_make_expr(Sum, i_expr)

                        # Track the unrolled, injectable expressions and their cost
                        if i_sym.symbol in injectable:
                            old_i_expr, old_cost = injectable[i_sym.symbol]
                            new_i_expr = ast_make_expr(Sum,
                                                       [i_expr, old_i_expr])
                            new_cost = unroll_cost + old_cost
                            injectable[i_sym.symbol] = (new_i_expr, new_cost)
                        else:
                            injectable[i_sym.symbol] = (i_expr, unroll_cost)

        # 2) Will rewrite mode=3 be cheaper than rewrite mode=2?
        def find_save(target_expr, expr_info):
            save_factor = [l.size for l in expr_info.out_linear_loops] or [1]
            save_factor = reduce(operator.mul, save_factor)
            # The save factor should be multiplied by the number of terms
            # that will /not/ be pre-evaluated. To obtain this number, we
            # can exploit the linearity of the expression in the terms
            # depending on the linear loops.
            syms = Find(Symbol).visit(target_expr)[Symbol]
            inner = lambda s: any(r == expr_info.linear_dims[-1]
                                  for r in s.rank)
            nterms = len(set(s.symbol for s in syms if inner(s)))
            save = nterms * save_factor
            return save_factor, save

        should_unroll = True
        storage = 0
        # /i_syms/ holds all injectable symbols for every expression: it must
        # not be rebound while analyzing a single expression
        i_syms, injected = list(injectable.keys()), defaultdict(list)
        for stmt, expr_info in self.exprs.items():
            _, expr = stmt.children

            # Divide /expr/ into subexpressions, each subexpression affected
            # differently by injection
            if i_syms:
                dissected = find_expression(expr, Prod, expr_info.linear_dims,
                                            i_syms)
                leftover = find_expression(expr,
                                           dims=expr_info.linear_dims,
                                           out_syms=i_syms)
                dissected = dict(dissected)
                dissected[()] = list(flatten(leftover.values()))
            else:
                dissected = {(): [expr]}
            dissected_syms = list(flatten(dissected.keys()))
            if any(i not in dissected_syms for i in i_syms):
                should_unroll = False
                continue

            # Apply the profitability model
            analysis = OrderedDict()
            for d_syms, target_exprs in dissected.items():
                for target_expr in target_exprs:

                    # *** Save ***
                    save_factor, save = find_save(target_expr, expr_info)

                    # *** Cost ***
                    # The number of operations increases by a factor which
                    # corresponds to the number of possible /combinations with
                    # repetitions/ in the injected-values set. We consider
                    # combinations and not dispositions to take into account the
                    # (future) effect of factorization.
                    retval = ProjectExpansion.default_retval()
                    projection = ProjectExpansion(d_syms).visit(target_expr,
                                                                ret=retval)
                    projection = [i for i in projection if i]
                    increase_factor = 0
                    for i in projection:
                        partial = 1
                        for j in expr_graph.shares(i):
                            # _n=number of unique elements, _k=group size
                            _n = injectable[j[0]][1]
                            _k = len(j)
                            partial *= fact(_n + _k - 1) // (fact(_k) *
                                                             fact(_n - 1))
                        increase_factor += partial
                    increase_factor = increase_factor or 1
                    if increase_factor > save_factor:
                        # We immediately give up if this holds since it ensures
                        # that /cost > save/ (but not that cost <= save)
                        should_unroll = False
                        continue
                    # The increase factor should be multiplied by the number of
                    # terms that will be pre-evaluated. To obtain this number,
                    # we need to project the output of factorization.
                    fake_stmt = stmt.__class__(stmt.children[0],
                                               dcopy(target_expr))
                    fake_parent = expr_info.parent.children
                    fake_parent[fake_parent.index(stmt)] = fake_stmt
                    try:
                        ew = ExpressionRewriter(fake_stmt, expr_info)
                        ew.expand(mode='all').factorize(mode='all').factorize(
                            mode='linear')
                        nterms = ew.licm(mode='aggressive', look_ahead=True)
                        nterms = len(uniquify(nterms[expr_info.dims])) or 1
                    finally:
                        # Always put the original statement back in the AST
                        fake_parent[fake_parent.index(fake_stmt)] = stmt
                    cost = nterms * increase_factor

                    # Pre-evaluation will also increase the working set size by
                    # /cost/ * /sizeof(term)/.
                    size = [l.size for l in expr_info.linear_loops]
                    size = reduce(operator.mul, size, 1)
                    storage_increase = cost * size * system.architecture[
                        expr_info.type]

                    # Track the injectable sub-expression and its cost/save. The
                    # final decision of whether to actually perform injection or not
                    # is postponed until all dissected expressions have been analyzed
                    analysis[target_expr] = (cost, save, storage_increase)

            # So what should we inject afterall ? Time to *use* the cost model
            if heuristics == 'greedy':
                for target_expr, (cost, save,
                                  storage_increase) in analysis.items():
                    if cost > save or storage_increase + storage > threshold:
                        should_unroll = False
                    else:
                        # Update the available storage
                        storage += storage_increase
                        # At this point, we can happily inject
                        to_replace = {k: v[0] for k, v in injectable.items()}
                        ast_replace(target_expr, to_replace, copy=True)
                        injected[stmt].append(target_expr)
            elif heuristics == 'aggressive':
                # A) Remove expression that we already know should never be
                # injected (iterate over a snapshot since /analysis/ shrinks)
                not_injected = []
                for target_expr, (cost, save,
                                  storage_increase) in list(analysis.items()):
                    if cost > save:
                        should_unroll = False
                        analysis.pop(target_expr)
                        not_injected.append(target_expr)
                # B) Find all possible bipartitions: each bipartition represents
                # the set of expressions that will be pre-evaluated and the set
                # of expressions that could also be pre-evaluated, but might not
                # (e.g. because of memory constraints)
                target_exprs = list(analysis.keys())
                bipartitions = []
                for i in range(len(target_exprs) + 1):
                    for e1 in combinations(target_exprs, i):
                        bipartitions.append(
                            (e1,
                             tuple(e2 for e2 in target_exprs if e2 not in e1)))
                # C) Eliminate those bipartitions that would lead to exceeding
                # the memory threshold
                bipartitions = [(e1, e2) for e1, e2 in bipartitions
                                if sum(analysis[i][2]
                                       for i in e1) <= threshold]
                # D) Find out what is best to pre-evaluate (and therefore
                # what should be injected)
                totals = OrderedDict()
                for e1, e2 in bipartitions:
                    # Is there any value in actually not pre-evaluating the
                    # expressions in /e2/ ?
                    fake_expr = ast_make_expr(Sum, list(e2) + not_injected)
                    _, save = find_save(fake_expr,
                                        expr_info) if fake_expr else (0, 0)
                    cost = sum(analysis[i][0] for i in e1)
                    totals[(e1, e2)] = save + cost
                best = min(totals, key=totals.get)
                # At this point, we can happily inject
                to_replace = {k: v[0] for k, v in injectable.items()}
                for target_expr in best[0]:
                    ast_replace(target_expr, to_replace, copy=True)
                    injected[stmt].append(target_expr)
                if best[1]:
                    # At least one non-injected expressions, let's be sure we
                    # don't unroll everything
                    should_unroll = False

        # 3) Purge the AST from now useless symbols/expressions
        if should_unroll:
            decls = visit(self.header, info_items=['decls'])['decls']
            for stmt, expr_info in self.exprs.items():
                nests = [n for n in visit(expr_info.loops_parents[0])['fors']]
                injectable_nests = [
                    n for n in nests if list(zip(*n))[0] != expr_info.loops
                ]
                for nest in injectable_nests:
                    unrolled = [(l, p) for l, p in nest
                                if l not in expr_info.loops]
                    for l, p in unrolled:
                        # A loop shared by several nests is removed only once
                        if l in p.children:
                            p.children.remove(l)
                        for i_sym in injectable.keys():
                            decl = decls.get(i_sym)
                            if decl and decl in p.children:
                                p.children.remove(decl)

        # 4) Split the expressions if injection has been performed. Iterate
        # over a snapshot, since /self.exprs/ is modified in the loop
        for stmt, expr_info in list(self.exprs.items()):
            expr_info.mode = 4
            inj_exprs = injected.get(stmt)
            if not inj_exprs:
                continue
            fissioner = ExpressionFissioner(match=inj_exprs,
                                            loops='all',
                                            perfect=True)
            new_exprs = fissioner.fission(stmt, self.exprs.pop(stmt))
            self.exprs.update(new_exprs)
            for new_stmt, new_info in new_exprs.items():
                new_info.mode = 3 if new_stmt in fissioner.matched else 4