Exemplo n.º 1
0
def iet_insert_decls(iet, external=None):
    """
    Transform the input IET inserting the necessary symbol declarations.
    Declarations are placed as close as possible to the first symbol occurrence.

    Parameters
    ----------
    iet : Node
        The input Iteration/Expression tree.
    external : tuple, optional
        The symbols defined in some outer Callable, which therefore must not
        be re-defined. Defaults to no external symbols.

    Returns
    -------
    The transformed tree. If at least one Array was scheduled on the heap,
    the tree is wrapped in a List whose header carries the declarations and
    allocations and whose footer carries the frees.
    """
    iet = as_tuple(iet)
    # Normalize once, rather than at every Array encountered in the loop below
    external = () if external is None else as_tuple(external)

    def is_not_parallel(iteration):
        return not iteration.is_Parallel

    # Classify and then schedule declarations to stack/heap
    allocator = Allocator()
    for k, v in MapSections().visit(iet).items():
        if k.is_Expression:
            if k.is_scalar_assign:
                # On the stack
                site = v if v else iet
                allocator.push_scalar_on_stack(site[-1], k)
                continue
            objs = [k.write]
        elif k.is_Call:
            objs = k.arguments
        else:
            # Neither an Expression nor a Call: there is nothing to declare.
            # Without this branch `objs` would be either unbound or a
            # leftover from the previously visited node
            continue

        for i in objs:
            try:
                if i.is_LocalObject:
                    # On the stack
                    site = v if v else iet
                    allocator.push_object_on_stack(site[-1], i)
                elif i.is_Array:
                    if i in external:
                        # The Array is defined in some other IET
                        continue
                    elif i._mem_stack:
                        # On the stack
                        site = filter_iterations(v, key=is_not_parallel) or iet
                        allocator.push_object_on_stack(site[-1], i)
                    else:
                        # On the heap
                        allocator.push_array_on_heap(i)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce declarations on the stack
    mapper = dict(allocator.onstack)
    iet = Transformer(mapper, nested=True).visit(iet)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Exemplo n.º 2
0
def find_offloadable_trees(nodes):
    """
    Collect the Iteration trees within ``nodes`` that YASK can compute.

    A tree is "offloadable to YASK" if it is embedded in a time stepping loop
    *and* all of the grids accessed by the enclosed equations are homogeneous
    (i.e., same dimensions and data type).

    Each entry of the returned list is a ``(tree, grid, dtype)`` tuple.
    """
    found = []
    for tree in retrieve_iteration_tree(nodes):
        # Non-parallel loops cannot be offloaded
        if not filter_iterations(tree, lambda i: i.is_Parallel):
            continue

        # Only perfect nests whose innermost body consists of Expressions
        # are known to be offloadable
        if not IsPerfectIteration().visit(tree):
            continue
        if not all(i.is_Expression for i in tree[-1].nodes):
            continue

        functions = flatten(i.functions for i in tree[-1].nodes)
        keys = {(f.grid, f.dtype) for f in functions if f.is_TimeFunction}
        if not keys:
            continue
        if len(keys) > 1:
            exit("Cannot handle Operators w/ heterogeneous grids")
        grid, dtype = keys.pop()

        # A "complete" tree iterates over the whole grid and over (at least)
        # one of the time dimensions
        dims = [i.dim for i in tree]
        if not all(d in dims for d in grid.dimensions):
            continue
        if any(d in dims for d in (grid.time_dim, grid.stepping_dim)):
            found.append((tree, grid, dtype))
    return found
Exemplo n.º 3
0
def find_offloadable_trees(nodes):
    """
    Collect the Iteration trees within ``nodes`` that YASK can compute.

    A tree is "offloadable to YASK" if it is embedded in a time stepping loop
    *and* all of the grids accessed by the enclosed equations are homogeneous
    (i.e., same dimensions, shape, data type).

    Each entry of the returned list is a ``(tree, dimensions, shape, dtype)``
    tuple.
    """
    found = []
    for tree in retrieve_iteration_tree(nodes):
        # Non-parallel loops cannot be offloaded
        if not filter_iterations(tree, lambda i: i.is_Parallel):
            continue

        # Only perfect nests whose innermost body consists of Expressions
        # are known to be offloadable
        if not IsPerfectIteration().visit(tree):
            continue
        if not all(i.is_Expression for i in tree[-1].nodes):
            continue

        functions = flatten(i.functions for i in tree[-1].nodes)
        keys = {(f.indices, f.shape, f.dtype) for f in functions
                if f.is_TimeFunction}
        if not keys:
            continue
        if len(keys) > 1:
            exit("Cannot handle Operators w/ heterogeneous grids")
        dimensions, shape, dtype = keys.pop()

        # A "full" Iteration/Expression tree spans exactly the time and space
        # dimensions of the functions, in the same order
        if len(dimensions) != len(tree):
            continue
        if all(it.dim == d for it, d in zip(tree, dimensions)):
            found.append((tree, dimensions, shape, dtype))
    return found
Exemplo n.º 4
0
    def _ompize(self, nodes, state):
        """
        Add OpenMP pragmas to the Iteration/Expression tree to emit parallel code.

        :param nodes: The Iteration/Expression tree to be transformed.
        :param state: Not used by this method.

        Return a 2-tuple: the transformed tree and an empty dict.
        """
        def is_candidate(i):
            return i.is_Parallel and not (i.is_Elementizable or i.is_Vectorizable)

        # Group by outer loop so that we can embed within the same parallel region
        was_tagged = False
        groups = OrderedDict()
        for tree in retrieve_iteration_tree(nodes):
            # Determine the number of consecutive parallelizable Iterations
            candidates = filter_iterations(tree, key=is_candidate, stop='asap')
            if not candidates:
                was_tagged = False
                continue
            # Consecutive tagged Iteration go in the same group
            is_tagged = any(i.tag is not None for i in tree)
            group_id = len(groups) - (is_tagged & was_tagged)
            handle = groups.setdefault(group_id, OrderedDict())
            handle[candidates[0]] = candidates
            was_tagged = is_tagged

        # `psutil.cpu_count(logical=False)` returns None when the number of
        # physical cores cannot be determined, and comparing None with an int
        # raises TypeError on Python 3. Fall back to the logical core count
        # and, failing that, to a single core. The value does not change
        # across loops, so it is queried only once
        ncores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1

        # Handle parallelizable loops
        mapper = OrderedDict()
        for group in groups.values():
            private = []
            for root, tree in group.items():
                # Heuristic: if at least two parallel loops are available and the
                # physical core count is greater than self.thresholds['collapse'],
                # then omp-collapse the loops
                nparallel = len(tree)
                if ncores < self.thresholds['collapse'] or nparallel < 2:
                    parallel = omplang['for']
                else:
                    parallel = omplang['collapse'](nparallel)

                mapper[root] = root._rebuild(pragmas=root.pragmas +
                                             (parallel, ))

                # Track the thread-private and thread-shared variables
                private.extend([
                    i for i in FindSymbols('symbolics').visit(root)
                    if i.is_Array and i._mem_stack
                ])

            # Build the parallel region
            private = sorted(set([i.name for i in private]))
            private = ('private(%s)' % ','.join(private)) if private else ''
            rebuilt = [v for k, v in mapper.items() if k in group]
            par_region = Block(header=omplang['par-region'](private),
                               body=rebuilt)
            # Entries from earlier groups have already been replaced, so only
            # the Iterations of the current group are still found here
            for k, v in list(mapper.items()):
                if isinstance(v, Iteration):
                    mapper[k] = None if v.is_Remainder else par_region

        processed = Transformer(mapper).visit(nodes)

        return processed, {}
Exemplo n.º 5
0
def iet_insert_C_decls(iet, func_table):
    """
    Build a new Iteration/Expression tree out of ``iet`` in which the
    necessary symbol declarations have been inserted. Declarations are placed
    as close as possible to the first symbol use.

    :param iet: The input Iteration/Expression tree.
    :param func_table: A mapper from callable names to :class:`Callable`s
                       called from within ``iet``. Entries flagged as local
                       are replaced in-place with their transformed version.
    """
    def not_parallel(iteration):
        return not iteration.is_Parallel

    # Function calls are resolved first: local callables contribute the
    # expressions found in their own body, any other call is dropped
    scopes = []
    me = MapExpressions()
    for node, iters in me.visit(iet).items():
        if not node.is_Call:
            scopes.append((node, iters))
            continue
        func = func_table[node.name]
        if func.local:
            scopes.extend(me.visit(func.root, queue=list(iters)).items())

    # Work out which declarations are required and where they belong
    allocator = Allocator()
    mapper = OrderedDict()
    for node, iters in scopes:
        if node.is_scalar:
            # Inline declaration
            mapper[node] = LocalExpression(**node.args)
            continue
        target = node.write
        if target is None or target._mem_external:
            # Nothing to do, e.g., variable passed as kernel argument
            continue
        if target._mem_stack:
            # On the stack
            site = filter_iterations(iters, key=not_parallel, stop='asap') or [iet]
            allocator.push_stack(site[-1], target)
        else:
            # On the heap, as a tensor that must be globally accessible
            allocator.push_heap(target)

    # Stack declarations are materialized as Elements
    for site, decls in allocator.onstack:
        mapper[site] = tuple(Element(d) for d in decls)
    iet = NestedTransformer(mapper).visit(iet)
    for name, meta in list(func_table.items()):
        if not meta.local:
            continue
        func_table[name] = MetaCall(Transformer(mapper).visit(meta.root), meta.local)

    # Heap declarations (if any) wrap the whole tree
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Exemplo n.º 6
0
    def test_make_efuncs(self, exprs, nfuncs, ntimeiters, nests):
        """Check that ElementalFunctions are constructed as expected."""
        grid = Grid(shape=(10, 10))
        t = grid.stepping_dim  # noqa
        x, y = grid.dimensions  # noqa

        u = Function(name='u', grid=grid)  # noqa
        v = TimeFunction(name='v', grid=grid)  # noqa

        # The strings are evaluated in a plain loop, since within a list
        # comprehension `eval` would need explicit locals/globals mappings
        evaluated = []
        for e in as_tuple(exprs):
            evaluated.append(eval(e))

        op = Operator(evaluated)

        # One ElementalFunction for each Iteration nest over space dimensions
        efuncs = [
            make_efunc('f%d' % n,
                       filter_iterations(tree, key=lambda i: i.dim.is_Space,
                                         stop='asap'))
            for n, tree in enumerate(retrieve_iteration_tree(op))
        ]

        assert len(efuncs) == len(nfuncs) == len(ntimeiters) == len(nests)

        for efunc, nf, nt, nest in zip(efuncs, nfuncs, ntimeiters, nests):
            params = efunc.parameters

            # The bounds of both space dimensions must be `efunc` parameters
            for dim in (x, y):
                for bound in (dim.symbolic_min, dim.symbolic_max):
                    assert bound in params

            # So must the accessed functions ...
            functions = FindSymbols().visit(efunc)
            assert len(functions) == nf
            for f in functions:
                assert f in params

            # ... and the time dimensions appearing as free symbols
            timeiters = []
            for s in FindSymbols('free-symbols').visit(efunc):
                if s.is_Dimension and s.is_Time:
                    timeiters.append(s)
            assert len(timeiters) == nt
            for d in timeiters:
                assert d in params

            assert len(efunc.parameters) == 4 + len(functions) + len(timeiters)

            # Exactly one ArrayCast for each Function
            casts = FindNodes(ArrayCast).visit(efunc)
            assert len(casts) == nf

            # The `efunc` must consist of a single loop nest with the
            # expected structure
            trees = retrieve_iteration_tree(efunc)
            assert len(trees) == 1
            for iteration, name in zip(trees[0], nest):
                assert iteration.dim.name == name

            assert efunc.make_call()
Exemplo n.º 7
0
    def make_parallel(self, iet):
        """
        Decorate the parallel :class:`Iteration`s in ``iet`` with suitable
        ``#pragma omp ...`` for thread-level parallelism.

        Return a 2-tuple: a List wrapping the transformed tree, and a dict
        which carries the ``NThreads`` object under ``'input'`` when at least
        one Iteration was transformed, and is empty otherwise.
        """
        # Sequences of loops that should share a parallel region are grouped
        groups = OrderedDict()
        was_tagged = False
        for tree in retrieve_iteration_tree(iet):
            # The consecutive parallelizable Iterations, if any
            candidates = filter_iterations(tree, key=self.key, stop='asap')
            if not candidates:
                was_tagged = False
                continue
            # A tagged tree following another tagged tree joins its group;
            # any other tree starts a new group
            is_tagged = any(i.tag is not None for i in tree)
            if is_tagged and was_tagged:
                group_id = len(groups) - 1
            else:
                group_id = len(groups)
            groups.setdefault(group_id, OrderedDict())[candidates[0]] = candidates
            was_tagged = is_tagged

        mapper = OrderedDict()
        for group in groups.values():
            stack_arrays = []
            for root, candidates in group.items():
                mapper.update(self._make_parallel_tree(root, candidates))

                # Arrays on the stack must be thread-private
                for s in FindSymbols('symbolics').visit(root):
                    if s.is_Array and s._mem_stack:
                        stack_arrays.append(s)

            # Build the parallel region
            names = sorted({a.name for a in stack_arrays})
            clause = ('private(%s)' % ','.join(names)) if names else ''
            rebuilt = [new for old, new in mapper.items() if old in group]
            par_region = Block(header=self.lang['par-region'](clause),
                               body=rebuilt)
            for old, new in list(mapper.items()):
                if not isinstance(new, Iteration):
                    continue
                mapper[old] = None if new.is_Remainder else par_region
        processed = Transformer(mapper).visit(iet)

        if not mapper:
            return List(body=processed), {}

        # Hack/workaround: the OpenMP pragmas are not true IET nodes, hence
        # the `nthreads` variable wouldn't be detected as a Callable parameter
        # unless it appeared in a mock Expression
        nt = NThreads()
        mock = LocalExpression(DummyEq(Symbol(name='nt', dtype=np.int32), nt))
        return List(body=[mock, processed]), {'input': [nt]}
Exemplo n.º 8
0
    def _insert_declarations(self, nodes):
        """
        Populate the Operator's body with the necessary variable declarations.

        :param nodes: The Iteration/Expression tree to be transformed.

        Return the transformed tree. As a side effect, the local entries of
        ``self.func_table`` are replaced with their transformed version.
        """
        def not_parallel(iteration):
            return not iteration.is_Parallel

        # Resolve function calls first
        scopes = []
        me = MapExpressions()
        for k, v in me.visit(nodes).items():
            if k.is_Call:
                func = self.func_table[k.name]
                if func.local:
                    scopes.extend(me.visit(func.root, queue=list(v)).items())
            else:
                scopes.append((k, v))

        # Determine all required declarations
        allocator = Allocator()
        mapper = OrderedDict()
        for k, v in scopes:
            if k.is_scalar:
                # Inline declaration
                mapper[k] = LocalExpression(**k.args)
            elif k.write is None or k.write._mem_external:
                # Nothing to do, e.g., variable passed as kernel argument.
                # The `None` check guards against expressions having no
                # written object, which would otherwise raise AttributeError
                continue
            elif k.write._mem_stack:
                # On the stack, as established by the DLE
                site = filter_iterations(v, key=not_parallel, stop='asap') or [nodes]
                allocator.push_stack(site[-1], k.write)
            else:
                # On the heap, as a tensor that must be globally accessible
                allocator.push_heap(k.write)

        # Introduce declarations on the stack
        for k, v in allocator.onstack:
            mapper[k] = tuple(Element(i) for i in v)
        nodes = NestedTransformer(mapper).visit(nodes)
        for k, v in list(self.func_table.items()):
            if v.local:
                self.func_table[k] = FunMeta(
                    Transformer(mapper).visit(v.root), v.local)

        # Introduce declarations on the heap (if any)
        if allocator.onheap:
            decls, allocs, frees = zip(*allocator.onheap)
            nodes = List(header=decls + allocs, body=nodes, footer=frees)

        return nodes
Exemplo n.º 9
0
    def test_make_efuncs(self, exprs, nfuncs, ntimeiters, nests):
        """Check that ElementalFunctions are constructed as expected."""
        grid = Grid(shape=(10, 10))
        t = grid.stepping_dim  # noqa
        x, y = grid.dimensions  # noqa

        u = Function(name='u', grid=grid)  # noqa
        v = TimeFunction(name='v', grid=grid)  # noqa

        # The strings are evaluated in a plain loop, since within a list
        # comprehension `eval` would need explicit locals/globals mappings
        evaluated = []
        for e in as_tuple(exprs):
            evaluated.append(eval(e))

        op = Operator(evaluated)

        # One ElementalFunction for each Iteration nest, rooted at the first
        # Iteration over a space dimension
        efuncs = [
            make_efunc('f%d' % n,
                       filter_iterations(tree, key=lambda i: i.dim.is_Space)[0])
            for n, tree in enumerate(retrieve_iteration_tree(op))
        ]

        assert len(efuncs) == len(nfuncs) == len(ntimeiters) == len(nests)

        for efunc, nf, nt, nest in zip(efuncs, nfuncs, ntimeiters, nests):
            params = efunc.parameters

            # The bounds of both space dimensions must be `efunc` parameters
            for dim in (x, y):
                for bound in (dim.symbolic_min, dim.symbolic_max):
                    assert bound in params

            # So must the accessed functions ...
            functions = FindSymbols().visit(efunc)
            assert len(functions) == nf
            for f in functions:
                assert f in params

            # ... and the time dimensions appearing as free symbols
            timeiters = []
            for s in FindSymbols('free-symbols').visit(efunc):
                if isinstance(s, Dimension) and s.is_Time:
                    timeiters.append(s)
            assert len(timeiters) == nt
            for d in timeiters:
                assert d in params

            assert len(efunc.parameters) == 4 + len(functions) + len(timeiters)

            # The `efunc` must consist of a single loop nest with the
            # expected structure
            trees = retrieve_iteration_tree(efunc)
            assert len(trees) == 1
            for iteration, name in zip(trees[0], nest):
                assert iteration.dim.name == name

            assert efunc.make_call()
Exemplo n.º 10
0
    def make_parallel(self, iet):
        """
        Decorate the parallel :class:`Iteration`s in ``iet`` with suitable
        ``#pragma omp ...`` triggering thread-level parallelism, and return
        the transformed tree.
        """
        # Sequences of loops that should share a parallel region are grouped
        groups = OrderedDict()
        was_tagged = False
        for tree in retrieve_iteration_tree(iet):
            # The consecutive parallelizable Iterations, if any
            candidates = filter_iterations(tree, key=self.key, stop='asap')
            if not candidates:
                was_tagged = False
                continue
            # A tagged tree following another tagged tree joins its group;
            # any other tree starts a new group
            is_tagged = any(i.tag is not None for i in tree)
            if is_tagged and was_tagged:
                group_id = len(groups) - 1
            else:
                group_id = len(groups)
            groups.setdefault(group_id, OrderedDict())[candidates[0]] = candidates
            was_tagged = is_tagged

        mapper = OrderedDict()
        for group in groups.values():
            stack_arrays = []
            for root, candidates in group.items():
                mapper.update(self._make_parallel_tree(root, candidates))

                # Arrays on the stack must be thread-private
                for s in FindSymbols('symbolics').visit(root):
                    if s.is_Array and s._mem_stack:
                        stack_arrays.append(s)

            # Build the parallel region
            names = sorted({a.name for a in stack_arrays})
            clause = ('private(%s)' % ','.join(names)) if names else ''
            rebuilt = [new for old, new in mapper.items() if old in group]
            par_region = Block(header=self.lang['par-region'](clause), body=rebuilt)
            for old, new in list(mapper.items()):
                if not isinstance(new, Iteration):
                    continue
                mapper[old] = None if new.is_Remainder else par_region

        return Transformer(mapper).visit(iet)
Exemplo n.º 11
0
    def _hoist_prodders(self, iet):
        """
        Relocate the periodic Prodders to an outer level of their Iteration
        tree, placing them at the beginning of the chosen Iteration's body.

        Return a 2-tuple: the transformed tree and an empty dict.
        """
        def over_block_dimension(iteration):
            return isinstance(iteration.dim, BlockDimension)

        mapper = {}
        for tree in retrieve_iteration_tree(iet):
            for prodder in FindNodes(Prodder).visit(tree.root):
                if not prodder._periodic:
                    continue
                # The last Iteration over a BlockDimension is the preferred
                # landing site
                try:
                    target = filter_iterations(tree, over_block_dimension)[-1]
                except IndexError:
                    # Fallback: use the outermost Iteration
                    target = tree.root
                body = (prodder._rebuild(),) + target.nodes
                mapper[target] = target._rebuild(nodes=body)
                # The Prodder is dropped from its original position
                mapper[prodder] = None

        return Transformer(mapper, nested=True).visit(iet), {}
Exemplo n.º 12
0
    def _hoist_prodders(self, iet):
        """
        Relocate the periodic Prodders to an outer level of their Iteration
        tree, placing them at the end of the chosen Iteration's body.

        Return a 2-tuple: the transformed tree and an empty dict.
        """
        def over_block_dimension(iteration):
            return isinstance(iteration.dim, BlockDimension)

        mapper = {}
        for tree in retrieve_iteration_tree(iet):
            for prodder in FindNodes(Prodder).visit(tree.root):
                if not prodder._periodic:
                    continue
                # The last Iteration over a BlockDimension is the preferred
                # landing site
                try:
                    target = filter_iterations(tree, over_block_dimension)[-1]
                except IndexError:
                    # Fallback: use the outermost Iteration
                    target = tree.root
                body = target.nodes + (prodder._rebuild(), )
                mapper[target] = target._rebuild(nodes=body)
                # The Prodder is dropped from its original position
                mapper[prodder] = None

        return Transformer(mapper, nested=True).visit(iet), {}
Exemplo n.º 13
0
    def make_parallel(self, iet):
        """
        Introduce shared-memory parallelism into ``iet``.

        Return a 2-tuple: the transformed tree, and a dict whose ``'input'``
        entry lists ``self.nthreads`` if at least one Iteration was
        transformed, and is empty otherwise.
        """
        mapper = OrderedDict()
        for tree in retrieve_iteration_tree(iet):
            # The first omp-parallelizable Iteration in `tree` is the root of
            # the parallel tree
            candidates = filter_iterations(tree, key=self.key, stop='asap')
            if not candidates:
                continue
            root = candidates[0]

            # The `omp-for` tree
            partree = self._make_parallel_tree(root, candidates)

            # Arrays on the stack must be thread-private
            names = sorted({s.name for s in FindSymbols().visit(partree)
                            if s.is_Array and s._mem_stack})
            clause = ('private(%s)' % ','.join(names)) if names else ''

            # The `omp-parallel` region
            header = self.lang['par-region'](self.nthreads.name, clause)
            partree = Block(header=header, body=partree)

            # Do not enter the parallel region if the step increment might be 0; this
            # would raise a `Floating point exception (core dumped)` in some OpenMP
            # implementation. Note that using an OpenMP `if` clause won't work
            if isinstance(root.step, Symbol):
                guard = Conditional(CondEq(root.step, 0),
                                    Element(c.Statement('return')))
                partree = List(body=[guard, partree])

            mapper[root] = partree

        processed = Transformer(mapper).visit(iet)
        inputs = [self.nthreads] if mapper else []

        return processed, {'input': inputs}
Exemplo n.º 14
0
    def __init__(self, expressions, **kwargs):
        """
        Build the Operator through the parent constructor, then instrument
        its body: sequential loops get reduced limits and each non-sequential
        Iteration tree is preceded by a print mark.
        """
        super(OperatorDebug, self).__init__(expressions, **kwargs)
        self._includes.append('stdio.h')

        # Minimize the trip count of the sequential loops
        shrunk = {}
        for i in set(flatten(retrieve_iteration_tree(self.body))):
            if i.is_Sequential:
                shrunk[i] = i._rebuild(limits=(max(i.offsets) + 2))
        self.body = Transformer(shrunk).visit(self.body)

        # Collect, for each Iteration tree in the body, the first
        # non-sequential Iteration (if any)
        heads = []
        for tree in retrieve_iteration_tree(self.body):
            nonseq = filter_iterations(tree, lambda i: not i.is_Sequential, 'any')
            if nonseq:
                heads.append(nonseq[0])

        # Mark the entry point of each such Iteration
        marked = {}
        for n, head in enumerate(heads):
            marked[head] = List(header=printmark('In nest %d' % n), body=head)
        self.body = Transformer(marked).visit(self.body)
Exemplo n.º 15
0
def iet_insert_C_decls(iet, func_table=None):
    """
    Build a new Iteration/Expression tree out of ``iet`` in which the
    necessary symbol declarations have been inserted. Declarations are placed
    as close as possible to the first symbol use.

    :param iet: The input Iteration/Expression tree.
    :param func_table: (Optional) a mapper from callable names within ``iet``
                       to :class:`Callable`s. Entries flagged as local are
                       replaced in-place with their transformed version.
    """
    def not_parallel(iteration):
        return not iteration.is_Parallel

    func_table = func_table or {}
    allocator = Allocator()
    mapper = OrderedDict()

    # Gather all IET nodes accessing symbols that need to be declared; a call
    # to a known local callable also contributes the nodes in its body
    scopes = []
    me = MapExpressions()
    for node, iters in me.visit(iet).items():
        func = func_table.get(node.name) if node.is_Call else None
        if func is not None and func.local:
            scopes.extend(me.visit(func.root, queue=list(iters)).items())
        scopes.append((node, iters))

    # Classify, and then schedule declarations to stack/heap
    for node, iters in scopes:
        if node.is_Expression:
            if node.is_scalar:
                # Inline declaration
                mapper[node] = LocalExpression(**node.args)
                continue
            objs = [node.write]
        elif node.is_Call:
            objs = node.params
        else:
            raise NotImplementedError("Cannot schedule declarations for IET "
                                      "node of type `%s`" % type(node))
        for obj in objs:
            try:
                if obj.is_LocalObject:
                    # On the stack
                    site = iters[-1] if iters else iet
                    allocator.push_stack(site, obj)
                elif not obj.is_Array:
                    # Neither a local object nor an Array: nothing to declare
                    continue
                elif obj._mem_external:
                    # Nothing to do; e.g., a user-provided Function
                    continue
                elif obj._mem_stack:
                    # On the stack
                    site = filter_iterations(iters, key=not_parallel,
                                             stop='asap') or [iet]
                    allocator.push_stack(site[-1], obj)
                else:
                    # On the heap, as a tensor that must be globally accessible
                    allocator.push_heap(obj)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Stack declarations are materialized as Elements
    for site, decls in allocator.onstack:
        mapper[site] = tuple(Element(d) for d in decls)
    iet = Transformer(mapper, nested=True).visit(iet)
    for name, meta in list(func_table.items()):
        if not meta.local:
            continue
        func_table[name] = MetaCall(Transformer(mapper).visit(meta.root), meta.local)

    # Heap declarations (if any) wrap the whole tree
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Exemplo n.º 16
0
def iet_insert_C_decls(iet, func_table=None):
    """
    Build a new Iteration/Expression tree out of ``iet`` in which the
    necessary symbol declarations have been inserted. Each declaration is
    placed as close as possible to the first use of the symbol.

    :param iet: The input Iteration/Expression tree.
    :param func_table: (Optional) a mapper from callable names within ``iet``
                       to :class:`Callable`s. Entries flagged as ``local``
                       are rebuilt (in place, within ``func_table``) so that
                       they carry the stack declarations as well.
    """
    def not_parallel(node):
        # Predicate selecting the non-PARALLEL Iterations of a section
        return not node.is_Parallel

    func_table = func_table or {}
    allocator = Allocator()
    mapper = OrderedDict()
    visitor = MapExpressions()

    # Step 1: gather the Expressions to be processed. Calls to local
    # callables are expanded into the Expressions found in the callee body
    scopes = []
    for node, section in visitor.visit(iet).items():
        if not node.is_Call:
            scopes.append((node, section))
            continue
        callee = func_table.get(node.name)
        if callee is not None and callee.local:
            nested = visitor.visit(callee.root, queue=list(section))
            scopes.extend(nested.items())

    # Step 2: schedule a declaration for each gathered Expression
    for expr, section in scopes:
        if expr.is_scalar:
            # Scalars are declared inline, where they are assigned
            mapper[expr] = LocalExpression(**expr.args)
            continue
        written = expr.write
        if written is None or written._mem_external:
            # E.g., a variable passed in as kernel argument; nothing to do
            continue
        if written._mem_stack:
            # Stack allocation, attached to the last Iteration returned by
            # the filter (or to the whole tree if there is none)
            candidates = filter_iterations(section, key=not_parallel, stop='asap')
            site = candidates or [iet]
            allocator.push_stack(site[-1], written)
        else:
            # Heap allocation, for tensors that must be globally accessible
            allocator.push_heap(written)

    # Step 3: schedule declarations for the parameters of Calls, which are
    # passed by reference/pointer and modified within the callable
    calls = [(node, section) for node, section in visitor.visit(iet).items()
             if node.is_Call]
    for call, section in calls:
        site = section[-1] if section else iet
        for param in call.params:
            try:
                if param.is_LocalObject:
                    allocator.push_stack(site, param)
                elif not param.is_Array:
                    continue
                elif param._mem_stack:
                    allocator.push_stack(site, param)
                elif param._mem_heap:
                    allocator.push_heap(param)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Step 4: materialize the stack declarations, both in the main tree and
    # in the local callables tracked by `func_table`
    for site, objs in allocator.onstack:
        mapper[site] = tuple(Element(obj) for obj in objs)
    iet = NestedTransformer(mapper).visit(iet)
    for name, meta in list(func_table.items()):
        if not meta.local:
            continue
        rebuilt = Transformer(mapper).visit(meta.root)
        func_table[name] = MetaCall(rebuilt, meta.local)

    # Step 5: materialize the heap declarations (if any) by wrapping the
    # whole tree with declarations/allocations on top and frees at the bottom
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Exemplo n.º 17
0
    def _loop_blocking(self, iet):
        """
        Block the PARALLEL Iteration trees found in ``iet``.

        Each blocked nest is moved into a separate Callable and replaced by
        a list of Calls, one for each combination of main/remainder ranges.
        Return the transformed tree together with a dict carrying the new
        block dimensions, the generated Callables, and the block steps.
        """
        blockinner = bool(self.params.get('blockinner'))
        blockalways = bool(self.params.get('blockalways'))

        # Fold first, so that blocking spans as many Iterations as possible
        iet = fold_blockable_tree(iet, blockinner)

        mapper = {}
        efuncs = []
        block_dims = []
        for tree in retrieve_iteration_tree(iet):
            # Collect the candidate Iterations; give up if too few of them
            candidates = filter_iterations(tree, lambda i: i.is_Parallel)
            if not blockinner:
                candidates = candidates[:-1]
            if len(candidates) <= 1:
                continue
            root = candidates[0]

            # Heuristic: unless forced through `blockalways`, leave alone
            # the nests that aren't embedded in a sequential loop (e.g., a
            # timestepping loop), as blocking them would only increase JIT
            # compilation time and hurt readability of the generated code
            embedded = tree.root.is_Sequential or iet.is_Callable
            if not embedded and not blockalways:
                continue

            # Number used to make names unique across blocked nests
            nest_id = len(mapper)

            # One Iteration over blocks and one within a block per candidate
            interb, intrab = [], []
            for it in candidates:
                bdim = BlockDimension(it.dim,
                                      name="%s%d_blk" % (it.dim.name, nest_id))
                block_dims.append(bdim)
                if it.is_Affine:
                    properties = (PARALLEL, AFFINE)
                else:
                    properties = (PARALLEL,)
                interb.append(Iteration([], bdim, bdim.symbolic_max,
                                        properties=properties))
                intrab.append(it._rebuild([],
                                          limits=(bdim, bdim + bdim.step - 1, 1),
                                          offsets=(0, 0)))

            # Assemble the blocked tree
            nest = interb + intrab + [candidates[-1].nodes]
            blocked = unfold_blocked_tree(compose_nodes(nest))

            # Move the blocked tree into its own Callable
            dynamic_parameters = flatten((bi.dim, bi.dim.symbolic_size)
                                         for bi in interb)
            efunc = make_efunc("bf%d" % nest_id, blocked, dynamic_parameters)
            efuncs.append(efunc)

            # For each blocked Iteration, a main range (stepping by the block
            # size) and a remainder range (covered by a single block)
            ranges = []
            for it, bi in zip(candidates, interb):
                step = bi.dim.step
                maxb = it.symbolic_max - (it.symbolic_size % step)
                main = (it.symbolic_min, maxb, step)
                remainder = (maxb + 1, it.symbolic_max, it.symbolic_max - maxb)
                ranges.append((main, remainder))

            # One Call to the `efunc` per combination of ranges
            body = []
            for combo in product(*ranges):
                dynamic_args_mapper = {}
                for bi, (lower, upper, size) in zip(interb, combo):
                    dynamic_args_mapper[bi.dim] = (lower, upper)
                    dynamic_args_mapper[bi.dim.step] = (size,)
                body.append(List(body=efunc.make_call(dynamic_args_mapper)))

            mapper[root] = List(body=body)

        iet = Transformer(mapper).visit(iet)

        summary = {'dimensions': block_dims, 'efuncs': efuncs,
                   'args': [d.step for d in block_dims]}
        return iet, summary
Exemplo n.º 18
0
    def make_blocking(self, iet):
        """
        Apply (hierarchical) loop blocking to PARALLEL Iteration trees.

        The method reads ``self.blockinner``, ``self.blockalways`` and
        ``self.nlevels``; ``self.nblocked`` is used to generate unique names
        and is incremented once per blocked nest.

        Parameters
        ----------
        iet : Node
            The input Iteration/Expression tree.

        Returns
        -------
        tuple
            The transformed tree and a dict with three entries:
            ``'dimensions'`` (the BlockDimensions created), ``'efuncs'``
            (the Callables the blocked nests were moved into) and ``'args'``
            (the ``step`` of each created BlockDimension).
        """
        # Make sure loop blocking will span as many Iterations as possible
        iet = fold_blockable_tree(iet, self.blockinner)

        mapper = {}
        efuncs = []
        block_dims = []
        for tree in retrieve_iteration_tree(iet):
            # Is the Iteration tree blockable ?
            iterations = filter_iterations(tree, lambda i: i.is_Parallel and i.is_Affine)
            if not self.blockinner:
                # The last candidate Iteration is left unblocked
                iterations = iterations[:-1]
            if len(iterations) <= 1:
                continue
            root = iterations[0]
            if not self.blockalways:
                # Heuristically bypass loop blocking if we think `tree`
                # won't be computationally expensive. This will help with code
                # size/readability, JIT time, and auto-tuning time
                if not (tree.root.is_Sequential or iet.is_Callable):
                    # E.g., not inside a time-stepping Iteration
                    continue
                if any(i.dim.is_Sub and i.dim.local for i in tree):
                    # At least an outer Iteration is over a local SubDimension,
                    # which suggests the computational cost of this Iteration
                    # nest will be negligible w.r.t. the "core" Iteration nest
                    # (making use of non-local (Sub)Dimensions only)
                    continue
            if not IsPerfectIteration().visit(root):
                # Don't know how to block non-perfect nests
                continue

            # Apply hierarchical loop blocking to `tree`
            level_0 = []  # Outermost level of blocking
            level_i = [[] for i in range(1, self.nlevels)]  # Inner levels of blocking
            intra = []  # Within the smallest block
            for i in iterations:
                # The trailing '%d' is kept as a placeholder, to be filled in
                # with the blocking level when each BlockDimension is created
                template = "%s%d_blk%s" % (i.dim.name, self.nblocked, '%d')
                properties = (PARALLEL,) + ((AFFINE,) if i.is_Affine else ())

                # Build Iteration across `level_0` blocks
                d = BlockDimension(i.dim, name=template % 0)
                level_0.append(Iteration([], d, d.symbolic_max, properties=properties))

                # Build Iteration across all `level_i` blocks, `i` in [1, self.nlevels)
                for n, li in enumerate(level_i, 1):
                    di = BlockDimension(d, name=template % n)
                    li.append(Iteration([], di, limits=(d, d+d.step-1, di.step),
                                        properties=properties))
                    # The next (inner) level is built on top of this one
                    d = di

                # Build Iteration within the smallest block; here `d` is the
                # innermost BlockDimension created for `i`
                intra.append(i._rebuild([], limits=(d, d+d.step-1, 1), offsets=(0, 0)))
            level_i = flatten(level_i)

            # Track all constructed BlockDimensions
            block_dims.extend(i.dim for i in level_0 + level_i)

            # Construct the blocked tree
            blocked = compose_nodes(level_0 + level_i + intra + [iterations[-1].nodes])
            blocked = unfold_blocked_tree(blocked)

            # Promote to a separate Callable
            dynamic_parameters = flatten((l0.dim, l0.step) for l0 in level_0)
            dynamic_parameters.extend([li.step for li in level_i])
            efunc = make_efunc("bf%d" % self.nblocked, blocked, dynamic_parameters)
            efuncs.append(efunc)

            # Compute the iteration ranges. For each blocked Iteration there
            # are two of them: a main range, stepping by the block size, and
            # a remainder range, covered with a single block
            ranges = []
            for i, l0 in zip(iterations, level_0):
                maxb = i.symbolic_max - (i.symbolic_size % l0.step)
                ranges.append(((i.symbolic_min, maxb, l0.step),
                               (maxb + 1, i.symbolic_max, i.symbolic_max - maxb)))

            # Build Calls to the `efunc`, one per combination of ranges
            body = []
            for p in product(*ranges):
                dynamic_args_mapper = {}
                for l0, (m, M, b) in zip(level_0, p):
                    dynamic_args_mapper[l0.dim] = (m, M)
                    dynamic_args_mapper[l0.step] = (b,)
                    for li in level_i:
                        if li.dim.root is l0.dim.root:
                            # In a main range (`b` is the `level_0` step) the
                            # inner level keeps its own step; in a remainder
                            # range it gets the remainder size instead
                            value = li.step if b is l0.step else b
                            dynamic_args_mapper[li.step] = (value,)
                call = efunc.make_call(dynamic_args_mapper)
                body.append(List(body=call))

            mapper[root] = List(body=body)

            # Next blockable nest, use different (unique) variable/function names
            self.nblocked += 1

        iet = Transformer(mapper).visit(iet)

        # Force-unfold if some folded Iterations haven't been blocked in the end
        iet = unfold_blocked_tree(iet)

        return iet, {'dimensions': block_dims,
                     'efuncs': efuncs,
                     'args': [i.step for i in block_dims]}
Exemplo n.º 19
0
    def _create_elemental_functions(self, nodes, state):
        """
        Move :class:`Iteration` sub-trees into standalone :class:`Callable`s,
        replacing each of them with a :class:`Call`.

        Only sub-trees rooted at a tagged, elementizable Iteration are
        considered. The transformed tree is returned together with a dict
        carrying the generated Callables under ``'elemental_functions'``.
        """
        noinline = self._compiler_decoration('noinline',
                                             c.Comment('noinline?'))

        mapper = {}
        functions = OrderedDict()
        for tree in retrieve_iteration_tree(nodes, mode='superset'):
            # Pick the first tagged Iteration, if any, and make sure it can
            # be turned into an elemental function
            tagged = filter_iterations(tree, lambda i: i.tag is not None,
                                       'asap')
            if not tagged:
                continue
            root = tagged[0]
            if not root.is_Elementizable:
                continue
            target = tree[tree.index(root):]

            # Values, at the call site, of the arguments introduced below
            # to replace loop bounds and unbounded index offsets
            defined_args = {}

            # Rebuild each Iteration in `target` with free bounds
            rebuilt = []
            for it in target:
                dname, bounds = it.dim.name, it.bounds_symbolic

                # Symbols standing for the Iteration bounds
                start = Scalar(name='%s_start' % dname, dtype=np.int32)
                finish = Scalar(name='%s_finish' % dname, dtype=np.int32)
                defined_args[start.name] = bounds[0]
                defined_args[finish.name] = bounds[1]

                # Symbols standing for the unbounded index offsets
                ub_scalars = [
                    Scalar(name='%s_ub%d' % (dname, n), dtype=np.int32)
                    for n, _ in enumerate(it.uindices)
                ]
                for ub, uindex in zip(ub_scalars, it.uindices):
                    defined_args[ub.name] = uindex.start

                limits = [
                    Scalar(name=start.name, dtype=np.int32),
                    Scalar(name=finish.name, dtype=np.int32), 1
                ]
                uindices = [
                    UnboundedIndex(uindex.index, it.dim + as_symbol(ub))
                    for uindex, ub in zip(it.uindices, ub_scalars)
                ]
                rebuilt.append(
                    it._rebuild(limits=limits, offsets=None, uindices=uindices))

            # Elemental function body
            free = NestedTransformer(dict(zip(target, rebuilt))).visit(root)

            # Prepend an array cast for each tensor not defined in the body
            f_symbols = FindSymbols('symbolics').visit(free)
            defines = [s.name for s in FindSymbols('defines').visit(free)]
            casts = []
            for f in f_symbols:
                if f.is_Tensor and f.name not in defines:
                    casts.append(ArrayCast(f))
            free = (List(body=casts), free)

            # Pair each parameter with the value to be passed at the call site
            args = []
            for param in derive_parameters(free):
                if param.name in defined_args:
                    pair = (defined_args[param.name], param)
                elif param.is_Dimension:
                    scalar = Scalar(name=param.name, dtype=param.dtype)
                    pair = (scalar, scalar)
                else:
                    pair = (param, param)
                args.append(pair)

            call, params = zip(*args)
            name = "f_%d" % root.tag

            # The Call replacing the sub-tree in the main tree
            mapper[root] = List(header=noinline, body=Call(name, call))

            # The Callable, registered only the first time `name` is seen
            efunc = Callable(name, free, 'void', flatten(params), ('static', ))
            functions.setdefault(name, efunc)

        # Transform the main tree
        processed = Transformer(mapper).visit(nodes)

        return processed, {'elemental_functions': functions.values()}
Exemplo n.º 20
0
Arquivo: basic.py Projeto: nw0/devito
    def _create_elemental_functions(self, nodes, state):
        """
        Extract :class:`Iteration` sub-trees and move them into :class:`Callable`s.

        Currently, only tagged, elementizable Iteration objects are targeted.

        :param nodes: The Iteration/Expression tree to be transformed.
        :param state: Not used within this method.

        :returns: A 2-tuple ``(processed, metadata)``: the transformed tree, in
                  which each extracted sub-tree is replaced by a :class:`Call`,
                  and a dict mapping ``'elemental_functions'`` to the generated
                  :class:`Callable`s.
        """
        noinline = self._compiler_decoration('noinline',
                                             c.Comment('noinline?'))

        functions = OrderedDict()
        mapper = {}
        for tree in retrieve_iteration_tree(nodes, mode='superset'):
            # Search an elementizable sub-tree (if any)
            tagged = filter_iterations(tree, lambda i: i.tag is not None,
                                       'asap')
            if not tagged:
                continue
            root = tagged[0]
            if not root.is_Elementizable:
                continue
            # The sub-tree to be extracted goes from `root` to the end of `tree`
            target = tree[tree.index(root):]

            # Elemental function arguments, as (call-site value, parameter) pairs
            args = []  # Found so far (scalars, tensors)
            maybe_required = set()  # Scalars that *may* have to be passed in
            not_required = set()  # Elemental function locally declared scalars

            # Build a new Iteration/Expression tree with free bounds
            free = []
            for i in target:
                name, bounds = i.dim.name, i.bounds_symbolic
                # Iteration bounds: the original bounds, turned into C code,
                # become the values passed at the call site
                start = Scalar(name='%s_start' % name, dtype=np.int32)
                finish = Scalar(name='%s_finish' % name, dtype=np.int32)
                args.extend(zip([ccode(j) for j in bounds], (start, finish)))
                # Iteration unbounded indices: one new scalar each, whose
                # call-site value is the start of the original index
                ufunc = [
                    Scalar(name='%s_ub%d' % (name, j), dtype=np.int32)
                    for j in range(len(i.uindices))
                ]
                args.extend(zip([ccode(j.start) for j in i.uindices], ufunc))
                limits = [Symbol(start.name), Symbol(finish.name), 1]
                uindices = [
                    UnboundedIndex(j.index, i.dim + as_symbol(k))
                    for j, k in zip(i.uindices, ufunc)
                ]
                free.append(
                    i._rebuild(limits=limits, offsets=None, uindices=uindices))
                # The loop dimension and the unbounded indices are not to be
                # passed in as arguments
                not_required.update({i.dim}, set(j.index for j in i.uindices))

            # Construct elemental function body, and inspect it
            free = NestedTransformer(dict((zip(target, free)))).visit(root)
            expressions = FindNodes(Expression).visit(free)
            fsymbols = FindSymbols('symbolics').visit(free)

            # Add all definitely-required arguments
            not_required.update({i.output for i in expressions if i.is_scalar})
            for i in fsymbols:
                if i in not_required:
                    continue
                elif i.is_Array:
                    # Passed in with a cast to a pointer to the data type
                    args.append(
                        ("(%s*)%s" % (c.dtype_to_ctype(i.dtype), i.name), i))
                elif i.is_TensorFunction:
                    args.append(("%s_vec" % i.name, i))
                elif i.is_Scalar:
                    args.append((i.name, i))

            # Add all maybe-required arguments that turn out to be required
            maybe_required.update(
                set(FindSymbols(mode='free-symbols').visit(free)))
            for i in fsymbols:
                # The symbols in `fsymbols` were already handled above, while
                # the free symbols in their shape may have to be passed in
                not_required.update({as_symbol(i), i.indexify()})
                for j in i.symbolic_shape:
                    maybe_required.update(j.free_symbols)
            # Sorted by name
            required = filter_sorted(maybe_required - not_required,
                                     key=attrgetter('name'))
            args.extend([(i.name, Scalar(name=i.name, dtype=i.dtype))
                         for i in required])

            # Split call-site values from callee parameters
            call, params = zip(*args)
            handle = flatten([p.rtargs for p in params])
            name = "f_%d" % root.tag

            # Produce the new Call
            mapper[root] = List(header=noinline, body=Call(name, call))

            # Produce the new Callable; an existing entry for `name` is kept
            functions.setdefault(
                name, Callable(name, free, 'void', handle, ('static', )))

        # Transform the main tree
        processed = Transformer(mapper).visit(nodes)

        return processed, {'elemental_functions': functions.values()}
Exemplo n.º 21
0
def iet_insert_C_decls(iet, external=None):
    """
    Given an IET, build a new tree with the necessary symbol declarations.
    Declarations are placed as close as possible to the first symbol occurrence.

    Parameters
    ----------
    iet : Node
        The input Iteration/Expression tree.
    external : tuple, optional
        The symbols defined in some outer Callable, which therefore must not
        be re-defined.

    Returns
    -------
    Node
        The transformed tree. If any heap allocation was scheduled, this is
        a List wrapping the tree, with declarations and allocations in the
        header and frees in the footer.

    Raises
    ------
    NotImplementedError
        If a mapped node is neither an Expression nor a Call.
    """
    def not_parallel(node):
        # Predicate selecting the non-PARALLEL Iterations of a section
        return not node.is_Parallel

    external = external or []

    # Classify and then schedule declarations to stack/heap
    allocator = Allocator()
    mapper = OrderedDict()
    for k, v in MapExpressions().visit(iet).items():
        if k.is_Expression:
            if k.is_scalar_assign:
                # Inline declaration
                mapper[k] = LocalExpression(**k.args)
                continue
            objs = [k.write]
        elif k.is_Call:
            objs = k.params
        else:
            # Without this branch `objs` would be either unbound or, worse,
            # silently inherited from the previous iteration
            raise NotImplementedError("Cannot schedule declarations for IET "
                                      "node of type `%s`" % type(k))

        for i in objs:
            try:
                if i.is_LocalObject:
                    # On the stack
                    site = v[-1] if v else iet
                    allocator.push_stack(site, i)
                elif i.is_Array:
                    if i in external:
                        # The Array is to be defined in some foreign IET
                        continue
                    elif i._mem_stack:
                        # On the stack
                        site = filter_iterations(v, key=not_parallel,
                                                 stop='asap') or [iet]
                        allocator.push_stack(site[-1], i)
                    else:
                        # On the heap, as a tensor that must be globally accessible
                        allocator.push_heap(i)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce declarations on the stack
    for k, v in allocator.onstack:
        mapper[k] = tuple(Element(i) for i in v)
    iet = Transformer(mapper, nested=True).visit(iet)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Exemplo n.º 22
0
def iet_insert_decls(iet, external):
    """
    Transform the input IET inserting the necessary symbol declarations.
    Declarations are placed as close as possible to the first symbol occurrence.

    Parameters
    ----------
    iet : Node
        The input Iteration/Expression tree.
    external : tuple
        The symbols defined in some outer Callable, which therefore must not
        be re-defined. The value is normalized through ``as_tuple``.

    Returns
    -------
    The transformed tree. If any heap allocation was scheduled, this is a
    List wrapping the tree, with declarations and allocations in the header
    and frees in the footer.

    Raises
    ------
    NotImplementedError
        If a mapped node is neither an Expression nor a Call.
    """
    def not_parallel(node):
        # Predicate selecting the non-PARALLEL Iterations of a section
        return not node.is_Parallel

    iet = as_tuple(iet)
    # Normalized once, rather than for each Array encountered below
    external = as_tuple(external)

    # Classify and then schedule declarations to stack/heap
    allocator = Allocator()
    mapper = OrderedDict()
    for k, v in MapSections().visit(iet).items():
        if k.is_Expression:
            if k.is_scalar_assign:
                # Inline declaration
                mapper[k] = LocalExpression(**k.args)
                continue
            objs = [k.write]
        elif k.is_Call:
            objs = k.arguments
        else:
            # Without this branch `objs` would be either unbound or, worse,
            # silently inherited from the previous iteration
            raise NotImplementedError("Cannot schedule declarations for IET "
                                      "node of type `%s`" % type(k))

        for i in objs:
            try:
                if i.is_LocalObject:
                    # On the stack
                    site = v if v else iet
                    allocator.push_stack(site[-1], i)
                elif i.is_Array:
                    if i in external:
                        # The Array is defined in some other IET
                        continue
                    elif i._mem_stack:
                        # On the stack
                        site = filter_iterations(v, key=not_parallel) or iet
                        allocator.push_stack(site[-1], i)
                    else:
                        # On the heap, as a tensor that must be globally accessible
                        allocator.push_heap(i)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce declarations on the stack
    for k, v in allocator.onstack:
        mapper[k] = tuple(Element(i) for i in v)
    iet = Transformer(mapper, nested=True).visit(iet)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Exemplo n.º 23
0
    def _loop_blocking(self, iet):
        """
        Block the PARALLEL Iteration trees found in ``iet``.

        Each blocked nest is moved into a separate Callable and replaced by
        a list of Calls, one for each combination of main/remainder ranges.
        Return the transformed tree together with a dict carrying the new
        block dimensions, the generated Callables, and the block steps.
        """
        blockinner = bool(self.params.get('blockinner'))
        blockalways = bool(self.params.get('blockalways'))

        # Fold first, so that blocking spans as many Iterations as possible
        iet = fold_blockable_tree(iet, blockinner)

        mapper = {}
        efuncs = []
        block_dims = []
        for tree in retrieve_iteration_tree(iet):
            # Collect the candidate Iterations; give up if too few of them
            candidates = filter_iterations(tree, lambda i: i.is_Parallel)
            if not blockinner:
                candidates = candidates[:-1]
            if len(candidates) <= 1:
                continue
            root = candidates[0]

            if not blockalways:
                # Heuristics to skip the nests that are unlikely to be
                # computationally expensive, which helps with code
                # size/readability, JIT time, and auto-tuning time
                embedded = tree.root.is_Sequential or iet.is_Callable
                if not embedded:
                    # E.g., not inside a time-stepping Iteration
                    continue
                over_local_sub = any(it.dim.is_Sub and it.dim.local
                                     for it in tree)
                if over_local_sub:
                    # An Iteration over a local SubDimension suggests that
                    # the cost of this nest is negligible w.r.t. the "core"
                    # nest (the one using non-local (Sub)Dimensions only)
                    continue
            if not IsPerfectIteration().visit(root):
                # Non-perfect nests cannot be blocked
                continue

            # Number used to make names unique across blocked nests
            nest_id = len(mapper)

            # One Iteration over blocks and one within a block per candidate
            interb, intrab = [], []
            for it in candidates:
                bdim = BlockDimension(it.dim,
                                      name="%s%d_blk" % (it.dim.name, nest_id))
                block_dims.append(bdim)
                if it.is_Affine:
                    properties = (PARALLEL, AFFINE)
                else:
                    properties = (PARALLEL,)
                interb.append(Iteration([], bdim, bdim.symbolic_max,
                                        properties=properties))
                intrab.append(it._rebuild([],
                                          limits=(bdim, bdim + bdim.step - 1, 1),
                                          offsets=(0, 0)))

            # Assemble the blocked tree
            nest = interb + intrab + [candidates[-1].nodes]
            blocked = unfold_blocked_tree(compose_nodes(nest))

            # Move the blocked tree into its own Callable
            dynamic_parameters = flatten((bi.dim, bi.dim.symbolic_size)
                                         for bi in interb)
            efunc = make_efunc("bf%d" % nest_id, blocked, dynamic_parameters)
            efuncs.append(efunc)

            # For each blocked Iteration, a main range (stepping by the block
            # size) and a remainder range (covered by a single block)
            ranges = []
            for it, bi in zip(candidates, interb):
                step = bi.dim.step
                maxb = it.symbolic_max - (it.symbolic_size % step)
                main = (it.symbolic_min, maxb, step)
                remainder = (maxb + 1, it.symbolic_max, it.symbolic_max - maxb)
                ranges.append((main, remainder))

            # One Call to the `efunc` per combination of ranges
            body = []
            for combo in product(*ranges):
                dynamic_args_mapper = {}
                for bi, (lower, upper, size) in zip(interb, combo):
                    dynamic_args_mapper[bi.dim] = (lower, upper)
                    dynamic_args_mapper[bi.dim.step] = (size,)
                body.append(List(body=efunc.make_call(dynamic_args_mapper)))

            mapper[root] = List(body=body)

        iet = Transformer(mapper).visit(iet)

        summary = {'dimensions': block_dims, 'efuncs': efuncs,
                   'args': [d.step for d in block_dims]}
        return iet, summary
Exemplo n.º 24
0
    def _loop_blocking(self, iet):
        """
        Block the PARALLEL Iteration trees found in ``iet``.

        Each blocked nest is moved into a separate Callable and replaced by
        a list of Calls, one for each combination of main/remainder ranges.
        Return the transformed tree together with a dict carrying the new
        block dimensions, the generated Callables, and the block steps.
        """
        blockinner = bool(self.params.get('blockinner'))
        blockalways = bool(self.params.get('blockalways'))

        # Fold first, so that blocking spans as many Iterations as possible
        iet = fold_blockable_tree(iet, blockinner)

        mapper = {}
        efuncs = []
        block_dims = []
        for tree in retrieve_iteration_tree(iet):
            # Collect the candidate Iterations; give up if too few of them
            candidates = filter_iterations(tree, lambda i: i.is_Parallel)
            if not blockinner:
                candidates = candidates[:-1]
            if len(candidates) <= 1:
                continue
            root = candidates[0]

            # Heuristic: unless forced through `blockalways`, leave alone
            # the nests that aren't embedded in a sequential loop (e.g., a
            # timestepping loop), as blocking them would only increase JIT
            # compilation time and hurt readability of the generated code
            embedded = tree.root.is_Sequential or iet.is_Callable
            if not embedded and not blockalways:
                continue

            # Number used to make names unique across blocked nests
            nest_id = len(mapper)

            # One Iteration over blocks and one within a block per candidate
            interb, intrab = [], []
            for it in candidates:
                bdim = BlockDimension(it.dim,
                                      name="%s%d_blk" % (it.dim.name, nest_id))
                block_dims.append(bdim)
                over_blocks = Iteration([], bdim, bdim.symbolic_max,
                                        properties=PARALLEL)
                interb.append(over_blocks)
                within_block = it._rebuild([],
                                           limits=(bdim, bdim + bdim.step - 1, 1),
                                           offsets=(0, 0))
                intrab.append(within_block)

            # Assemble the blocked tree
            nest = interb + intrab + [candidates[-1].nodes]
            blocked = unfold_blocked_tree(compose_nodes(nest))

            # Move the blocked tree into its own Callable
            dynamic_parameters = flatten((bi.dim, bi.dim.symbolic_size)
                                         for bi in interb)
            efunc = make_efunc("bf%d" % nest_id, blocked, dynamic_parameters)
            efuncs.append(efunc)

            # For each blocked Iteration, a main range (stepping by the block
            # size) and a remainder range (covered by a single block)
            ranges = []
            for it, bi in zip(candidates, interb):
                step = bi.dim.step
                maxb = it.symbolic_max - (it.symbolic_size % step)
                main = (it.symbolic_min, maxb, step)
                remainder = (maxb + 1, it.symbolic_max, it.symbolic_max - maxb)
                ranges.append((main, remainder))

            # One Call to the `efunc` per combination of ranges
            body = []
            for combo in product(*ranges):
                dynamic_args_mapper = {}
                for bi, (lower, upper, size) in zip(interb, combo):
                    dynamic_args_mapper[bi.dim] = (lower, upper)
                    dynamic_args_mapper[bi.dim.step] = (size, )
                body.append(List(body=efunc.make_call(dynamic_args_mapper)))

            mapper[root] = List(body=body)

        iet = Transformer(mapper).visit(iet)

        summary = {
            'dimensions': block_dims,
            'efuncs': efuncs,
            'args': [d.step for d in block_dims]
        }
        return iet, summary
Exemplo n.º 25
0
    def _create_efuncs(self, nodes, state):
        """
        Turn Iteration sub-trees into Calls to newly created Callables.

        Only sub-trees rooted at a tagged, elementizable Iteration are
        processed. The transformed tree is returned together with a dict
        carrying the generated Callables under the ``'efuncs'`` key.
        """
        noinline = self._compiler_decoration('noinline',
                                             c.Comment('noinline?'))

        mapper = {}
        efuncs = OrderedDict()
        for tree in retrieve_iteration_tree(nodes, mode='superset'):
            # Pick the first tagged Iteration, if any, and make sure it can
            # be turned into an elemental function
            tagged = filter_iterations(tree, lambda i: i.tag is not None,
                                       'asap')
            if not tagged:
                continue
            root = tagged[0]
            if not root.is_Elementizable:
                continue
            target = tree[tree.index(root):]

            # Values, at the call site, of the arguments introduced below
            # to replace loop bounds and unbounded index offsets
            defined_args = {}

            # Rebuild each Iteration in `target` with free bounds
            rebuilt = []
            for it in target:
                dname, bounds = it.dim.name, it.symbolic_bounds

                # Symbols standing for the Iteration bounds
                lower = Scalar(name='%sf_m' % dname,
                               dtype=np.int32,
                               is_const=True)
                upper = Scalar(name='%sf_M' % dname,
                               dtype=np.int32,
                               is_const=True)
                defined_args[lower.name] = bounds[0]
                defined_args[upper.name] = bounds[1]

                # Symbols standing for the unbounded index offsets
                ub_scalars = [
                    Scalar(name='%s_ub%d' % (dname, n), dtype=np.int32)
                    for n in range(len(it.uindices))
                ]
                for ub, uindex in zip(ub_scalars, it.uindices):
                    defined_args[ub.name] = uindex.symbolic_min

                uindices = [
                    IncrDimension(uindex.parent, it.dim + as_symbol(ub), 1,
                                  uindex.name)
                    for uindex, ub in zip(it.uindices, ub_scalars)
                ]
                rebuilt.append(
                    it._rebuild(limits=(lower, upper, 1),
                                offsets=None,
                                uindices=uindices))

            # Elemental function body
            body = Transformer(dict(zip(target, rebuilt)),
                               nested=True).visit(root)
            symbols = FindSymbols().visit(body)

            # Prepend an array cast for each tensor
            casts = [ArrayCast(s) for s in symbols if s.is_Tensor]
            body = List(body=casts + [body])

            # Insert declarations; the Arrays found in the body are flagged
            # as external, so they are not declared again here
            external = [s for s in symbols if s.is_Array]
            body = iet_insert_C_decls(body, external)

            # The Callable, registered only the first time `name` is seen
            name = "f_%d" % root.tag
            params = derive_parameters(body)
            efunc = Callable(name, body, 'void', params, 'static')
            efuncs.setdefault(name, efunc)

            # The Call replacing the sub-tree in the main tree; parameters
            # introduced above are bound to the original bounds/offsets
            call_args = [defined_args.get(p.name, p) for p in params]
            mapper[root] = List(header=noinline, body=Call(name, call_args))

        # Transform the main tree
        processed = Transformer(mapper).visit(nodes)

        return processed, {'efuncs': efuncs.values()}