Beispiel #1
0
 def __call__(self, loop_chain):
     """Assemble the :class:`TilingParLoop` standing in for ``loop_chain``.

     :arg loop_chain: the loops to combine. It is first passed through
         ``self._schedule``; every resulting loop must expose ``it_space``,
         ``args``, ``reads``, ``writes``, ``incs`` and ``_reduced_globals``.
     :returns: a one-element list containing the new :class:`TilingParLoop`.
     """
     loop_chain = self._schedule(loop_chain)

     # Per-loop bookkeeping: the iteration space and the wrapped arguments
     # of every individual kernel in the chain
     all_itspaces = tuple(loop.it_space for loop in loop_chain)
     all_args = tuple(
         [TilingArg(arg, idx, None if self._opt_glb_maps else maps)
          for arg in loop.args]
         for idx, (loop, maps) in enumerate(zip(loop_chain,
                                                self._executor.gtl_maps)))

     def gather(attr):
         # Union, over the whole chain, of the per-loop collections ``attr``
         return set(flatten([getattr(loop, attr) for loop in loop_chain]))

     # Inputs of the ParLoop that is actually returned
     tiled_itspace = TilingIterationSpace(all_itspaces)
     parloop_args = self._filter(loop_chain)
     reduced_globals = [loop._reduced_globals for loop in loop_chain]
     read_args = gather('reads')
     written_args = gather('writes')
     inc_args = gather('incs')
     options = {
         'all_kernels': self._kernel._kernels,
         'all_itspaces': all_itspaces,
         'all_args': all_args,
         'read_args': read_args,
         'written_args': written_args,
         'reduced_globals': reduced_globals,
         'inc_args': inc_args,
         'insp_name': self._insp_name,
         'use_glb_maps': self._opt_glb_maps,
         'use_prefetch': self._opt_prefetch,
         'inspection': self._inspection,
         'executor': self._executor
     }
     return [TilingParLoop(self._kernel, tiled_itspace, *parloop_args, **options)]
Beispiel #2
0
 def __call__(self, loop_chain):
     """Turn ``loop_chain`` into a single :class:`TilingParLoop`.

     :arg loop_chain: the loops to combine; rescheduled through
         ``self._schedule`` before anything else happens.
     :returns: a list whose only item is the resulting :class:`TilingParLoop`.
     """
     loop_chain = self._schedule(loop_chain)

     # One iteration space and one list of TilingArg per loop in the chain
     all_itspaces = tuple(loop.it_space for loop in loop_chain)
     all_args = tuple(
         [TilingArg(arg, position, None if self._opt_glb_maps else maps)
          for arg in loop.args]
         for position, (loop, maps) in enumerate(zip(loop_chain,
                                                     self._executor.gtl_maps)))

     def union_of(attr):
         # Set of all items found in ``loop.<attr>`` across the chain
         return set(flatten([getattr(loop, attr) for loop in loop_chain]))

     # Everything the ParLoop constructor needs
     fused_itspace = TilingIterationSpace(all_itspaces)
     fused_args = self._filter(loop_chain)
     reduced_globals = [loop._reduced_globals for loop in loop_chain]
     read_args = union_of('reads')
     written_args = union_of('writes')
     inc_args = union_of('incs')
     parloop_kwargs = {
         'all_kernels': self._kernel._kernels,
         'all_itspaces': all_itspaces,
         'all_args': all_args,
         'read_args': read_args,
         'written_args': written_args,
         'reduced_globals': reduced_globals,
         'inc_args': inc_args,
         'insp_name': self._insp_name,
         'use_glb_maps': self._opt_glb_maps,
         'use_prefetch': self._opt_prefetch,
         'inspection': self._inspection,
         'executor': self._executor
     }
     return [TilingParLoop(self._kernel, fused_itspace, *fused_args, **parloop_kwargs)]
Beispiel #3
0
    def _needs_reassembly(self):
        """Tell whether this :class:`Matrix` has to be assembled again.

        Reassembly is required when the set of subdomains carrying boundary
        conditions at the time of the last assembly differs from the set of
        subdomains of the boundary conditions currently attached.
        """
        def subdomains(bcs):
            # Every subdomain referenced by the boundary conditions in ``bcs``
            return set(flatten(as_tuple(bc.sub_domain) for bc in bcs))

        return subdomains(self._bcs_at_point_of_assembly) != subdomains(self.bcs)
Beispiel #4
0
    def _needs_reassembly(self):
        """Report whether the :class:`Matrix` must be reassembled.

        This is the case when the subdomains on which boundary conditions
        were applied at the last assembly are not the same as the subdomains
        of the boundary conditions currently set.
        """
        assembled_with = self._bcs_at_point_of_assembly
        previous = set(flatten(as_tuple(bc.sub_domain) for bc in assembled_with))
        current = set(flatten(as_tuple(bc.sub_domain) for bc in self.bcs))
        if previous == current:
            return False
        return True
Beispiel #5
0
    def __init__(self, spaces, name=None):
        r"""
        :param spaces: a list (or tuple) of :class:`FunctionSpace`\s
        :param name: optional name of the mixed space; when omitted (or
            otherwise falsy) the names of the constituent spaces are joined
            with ``_``.

        The function space may be created as ::

            V = MixedFunctionSpace(spaces)

        ``spaces`` may consist of multiple occurrences of the same space: ::

            P1  = FunctionSpace(mesh, "CG", 1)
            P2v = VectorFunctionSpace(mesh, "Lagrange", 2)

            ME  = MixedFunctionSpace([P2v, P1, P1, P1])
        """
        # NOTE: the docstring is a raw string so that the reST plural marker
        # ``\s`` is not parsed as an (invalid) Python escape sequence.

        # Nothing to do if this object has already been set up
        if self._initialized:
            return
        # Wrap every (flattened) constituent space, recording its position
        self._spaces = [IndexedFunctionSpace(s, i, self)
                        for i, s in enumerate(flatten(spaces))]
        # The mesh is taken from the first constituent space
        self._mesh = self._spaces[0].mesh()
        self._ufl_element = ufl.MixedElement(*[fs.ufl_element() for fs in self._spaces])
        self.name = name or '_'.join(str(s.name) for s in self._spaces)
        self.rank = 1
        self._index = None
        self._initialized = True
Beispiel #6
0
    def __init__(self, spaces, name=None):
        r"""
        :param spaces: a list (or tuple) of :class:`FunctionSpace`\s
        :param name: optional name of the mixed space; when omitted (or
            otherwise falsy) the names of the constituent spaces are joined
            with ``_``.

        The function space may be created as ::

            V = MixedFunctionSpace(spaces)

        ``spaces`` may consist of multiple occurrences of the same space: ::

            P1  = FunctionSpace(mesh, "CG", 1)
            P2v = VectorFunctionSpace(mesh, "Lagrange", 2)

            ME  = MixedFunctionSpace([P2v, P1, P1, P1])
        """
        # NOTE: the docstring is a raw string so that the reST plural marker
        # ``\s`` is not parsed as an (invalid) Python escape sequence.

        # Nothing to do if this object has already been set up
        if self._initialized:
            return
        # Wrap every (flattened) constituent space, recording its position
        self._spaces = [IndexedFunctionSpace(s, i, self)
                        for i, s in enumerate(flatten(spaces))]
        # The mesh is taken from the first constituent space
        self._mesh = self._spaces[0].mesh()
        self._ufl_element = ufl.MixedElement(*[fs.ufl_element() for fs in self._spaces])
        self.name = name or '_'.join(str(s.name) for s in self._spaces)
        self.rank = 1
        self._index = None
        # Flag the object as set up before the DM below is built from ``self``
        self._initialized = True
        # Shell DM describing this space to PETSc
        dm = PETSc.DMShell().create()
        with function.Function(self).dat.vec_ro as v:
            dm.setGlobalVector(v.duplicate())
        # Only a weak reference to the space is stored on the DM
        dm.setAttr('__fs__', weakref.ref(self))
        dm.setCreateFieldDecomposition(self.create_field_decomp)
        dm.setCreateSubDM(self.create_subdm)
        self._dm = dm
        self._ises = self.dof_dset.field_ises
        self._subspaces = []
Beispiel #7
0
    def __new__(cls, spaces, name=None):
        r"""
        :param spaces: a list (or tuple) of :class:`FunctionSpace`\s
        :param name: optional name of the mixed space; when omitted (or
            otherwise falsy) the names of the constituent spaces are joined
            with ``_``.
        :raises ValueError: if the spaces are not all defined on the very
            same mesh object.

        The function space may be created as ::

            V = MixedFunctionSpace(spaces)

        ``spaces`` may consist of multiple occurrences of the same space: ::

            P1  = FunctionSpace(mesh, "CG", 1)
            P2v = VectorFunctionSpace(mesh, "Lagrange", 2)

            ME  = MixedFunctionSpace([P2v, P1, P1, P1])
        """
        # NOTE: the docstring is a raw string so that the reST plural marker
        # ``\s`` is not parsed as an (invalid) Python escape sequence.

        # Check that function spaces are on the same mesh (identity, not
        # equality). Written without ``xrange`` so that it runs unchanged on
        # both Python 2 and Python 3.
        meshes = [space.mesh() for space in spaces]
        if any(other is not meshes[0] for other in meshes[1:]):
            raise ValueError(
                "All function spaces must be defined on the same mesh!")

        # Select mesh
        mesh = meshes[0]

        # Get topological spaces
        spaces = flatten(spaces)
        if mesh is mesh.topology:
            spaces = tuple(spaces)
        else:
            spaces = tuple(space.topological for space in spaces)

        # Ask object from cache
        self = ObjectCached.__new__(cls, mesh, spaces, name)
        if not self._initialized:
            self._spaces = [
                IndexedFunctionSpace(s, i, self) for i, s in enumerate(spaces)
            ]
            self._mesh = mesh.topology
            self._ufl_element = ufl.MixedElement(
                *[fs.ufl_element() for fs in spaces])
            self.name = name or '_'.join(str(s.name) for s in spaces)
            self._initialized = True
            # Shell DM describing this space to PETSc
            dm = PETSc.DMShell().create()
            with self.make_dat().vec_ro as v:
                dm.setGlobalVector(v.duplicate())
            # Only a weak reference to the space is stored on the DM
            dm.setAttr('__fs__', weakref.ref(self))
            dm.setCreateFieldDecomposition(self.create_field_decomp)
            dm.setCreateSubDM(self.create_subdm)
            self._dm = dm
            self._ises = self.dof_dset.field_ises
            self._subspaces = []

        # A mesh that is not its own topology gets the geometric wrapper
        if mesh is not mesh.topology:
            self = WithGeometry(self, mesh)
        return self
Beispiel #8
0
    def __init__(self, kernels, fused_ast=None, loop_chain_index=None):
        """Initialize a :class:`fusion.Kernel` object.

        :arg kernels: the :class:`Kernel` objects to combine (of class
            `fusion.Kernel` or of any superclass). The collection is iterated
            several times, so it must be re-iterable (e.g. a list or tuple).
        :arg fused_ast: the abstract syntax tree of the fused kernel. If not
            provided, objects in ``kernels`` are considered "isolated C functions".
        :arg loop_chain_index: index (i.e., position) of the kernel in a loop chain.
            Meaningful only if ``fused_ast`` is specified; not read by this
            initializer.
        """
        if self._initialized:
            # Object retrieved from cache: it must not be initialized twice
            return
        Kernel._globalcount += 1

        def merged(attr):
            # Duplicate-free list of the items stored in ``attr`` over all kernels
            return list(set(flatten([getattr(k, attr) for k in kernels])))

        # The kernel name and the function name(s) are kept apart: a
        # /fusion.Kernel/ is in general a collection of functions, and the same
        # function (itself associated with a Kernel) may appear in several
        # /fusion.Kernel/ objects, possibly under different names (to avoid
        # name clashes)
        self._name = "_".join([k.name for k in kernels])
        self._function_names = {self.cache_key: self._name}

        self._cpp = any(k._cpp for k in kernels)
        self._opts = dict(flatten([k._opts.items() for k in kernels]))
        self._include_dirs = merged('_include_dirs')
        self._ldargs = merged('_ldargs')
        self._headers = merged('_headers')
        self._user_code = "\n".join(set(k._user_code for k in kernels))
        self._attached_info = {'fundecl': None, 'attached': False}

        # What sort of Kernel do I have?
        if not fused_ast:
            # Multiple functions (AST or strings, as a result of tiling)
            self._ast = None
            self._code = self._multiple_ast_to_c(kernels)
        else:
            # A single AST (as a result of soft or hard fusion)
            self._ast = fused_ast
            self._code = self._ast_to_c(fused_ast)
        self._kernels = kernels

        self._initialized = True
Beispiel #9
0
    def __init__(self, kernels, fused_ast=None, loop_chain_index=None):
        """Set up a :class:`fusion.Kernel` out of existing kernels.

        :arg kernels: the :class:`Kernel` objects (`fusion.Kernel` or any
            superclass) being combined. Traversed more than once here, hence
            it has to be re-iterable.
        :arg fused_ast: the abstract syntax tree of the fused kernel. When
            absent, the objects in ``kernels`` are treated as "isolated C
            functions".
        :arg loop_chain_index: position of the kernel in a loop chain; only
            meaningful together with ``fused_ast`` and not read here.
        """
        # Guard against a second initialization of a cached object
        if self._initialized:
            return
        Kernel._globalcount += 1

        def unique_items(attr):
            # Items of ``attr`` collected over all kernels, duplicates removed
            collected = flatten([getattr(k, attr) for k in kernels])
            return list(set(collected))

        # Kernel name versus function name(s): a /fusion.Kernel/ generally
        # groups several functions, and one function may belong to different
        # /fusion.Kernel/ objects under different names (avoiding clashes)
        self._name = "_".join([k.name for k in kernels])
        self._function_names = {self.cache_key: self._name}

        self._cpp = any(k._cpp for k in kernels)
        self._opts = dict(flatten([k._opts.items() for k in kernels]))
        self._include_dirs = unique_items('_include_dirs')
        self._ldargs = unique_items('_ldargs')
        self._headers = unique_items('_headers')
        distinct_user_code = set(k._user_code for k in kernels)
        self._user_code = "\n".join(distinct_user_code)
        self._attached_info = {'fundecl': None, 'attached': False}

        # What sort of Kernel do I have?
        if not fused_ast:
            # Multiple functions (AST or strings, as a result of tiling)
            self._ast = None
            self._code = self._multiple_ast_to_c(kernels)
        else:
            # A single AST (as a result of soft or hard fusion)
            self._ast = fused_ast
            self._code = self._ast_to_c(fused_ast)
        self._kernels = kernels

        self._initialized = True
Beispiel #10
0
    def __new__(cls, spaces, name=None):
        r"""
        :param spaces: a list (or tuple) of :class:`FunctionSpace`\s
        :param name: optional name of the mixed space; when omitted (or
            otherwise falsy) the names of the constituent spaces are joined
            with ``_``.
        :raises ValueError: if the spaces are not all defined on the very
            same mesh object.

        The function space may be created as ::

            V = MixedFunctionSpace(spaces)

        ``spaces`` may consist of multiple occurrences of the same space: ::

            P1  = FunctionSpace(mesh, "CG", 1)
            P2v = VectorFunctionSpace(mesh, "Lagrange", 2)

            ME  = MixedFunctionSpace([P2v, P1, P1, P1])
        """
        # NOTE: the docstring is a raw string so that the reST plural marker
        # ``\s`` is not parsed as an (invalid) Python escape sequence.

        # Check that function spaces are on the same mesh (identity, not
        # equality). Written without ``xrange`` so that it runs unchanged on
        # both Python 2 and Python 3.
        meshes = [space.mesh() for space in spaces]
        if any(other is not meshes[0] for other in meshes[1:]):
            raise ValueError("All function spaces must be defined on the same mesh!")

        # Select mesh
        mesh = meshes[0]

        # Get topological spaces
        spaces = flatten(spaces)
        if mesh is mesh.topology:
            spaces = tuple(spaces)
        else:
            spaces = tuple(space.topological for space in spaces)

        # Ask object from cache
        self = ObjectCached.__new__(cls, mesh, spaces, name)
        if not self._initialized:
            self._spaces = [IndexedFunctionSpace(s, i, self)
                            for i, s in enumerate(spaces)]
            self._mesh = mesh.topology
            self._ufl_element = ufl.MixedElement(*[fs.ufl_element() for fs in spaces])
            self.name = name or '_'.join(str(s.name) for s in spaces)
            self._initialized = True
            # Shell DM describing this space to PETSc
            dm = PETSc.DMShell().create()
            with self.make_dat().vec_ro as v:
                dm.setGlobalVector(v.duplicate())
            # Only a weak reference to the space is stored on the DM
            dm.setAttr('__fs__', weakref.ref(self))
            dm.setCreateFieldDecomposition(self.create_field_decomp)
            dm.setCreateSubDM(self.create_subdm)
            self._dm = dm
            self._ises = self.dof_dset.field_ises
            self._subspaces = []

        # A mesh that is not its own topology gets the geometric wrapper
        if mesh is not mesh.topology:
            self = WithGeometry(self, mesh)
        return self
Beispiel #11
0
def MixedFunctionSpace(spaces, name=None, mesh=None):
    """Create a :class:`.MixedFunctionSpace`.

    :arg spaces: An iterable of constituent spaces, or a
        :class:`~ufl.classes.MixedElement`.
    :arg name: An optional name for the mixed function space.
    :arg mesh: An optional mesh.  Must be provided if spaces is a
        :class:`~ufl.classes.MixedElement`, ignored otherwise.
    :raises ValueError: if the spaces live on different mesh objects, or if
        one of them is of a kind that cannot be part of a mixed space.
    """
    if isinstance(spaces, ufl.FiniteElementBase):
        # A mixed element was supplied: build one space per leaf element
        assert type(spaces) is ufl.MixedElement and mesh is not None

        def leaves(elements):
            # Depth-first walk, descending only into (exact) MixedElements
            for element in elements:
                if type(element) is ufl.MixedElement:
                    for leaf in leaves(element.sub_elements()):
                        yield leaf
                else:
                    yield element

        sub_elements = list(leaves(spaces.sub_elements()))
        spaces = [FunctionSpace(mesh, element) for element in sub_elements]

    # All spaces must share the very same mesh object
    meshes = [space.mesh() for space in spaces]
    if any(other is not meshes[0] for other in meshes[1:]):
        raise ValueError(
            "All function spaces must be defined on the same mesh!")
    mesh = meshes[0]

    # Work on the topological spaces from here on
    spaces = tuple(s.topological for s in flatten(spaces))

    # Reject anything that cannot be a constituent of a mixed space
    for space in spaces:
        kind = type(space)
        if kind in (impl.FunctionSpace, impl.RealFunctionSpace):
            continue
        if kind is not impl.ProxyFunctionSpace:
            raise ValueError("Can't make mixed space with %s" % kind)
        # Proxy spaces are fine unless they select a single component
        if space.component is not None:
            raise ValueError("Can't make mixed space with %s" % space)

    new = impl.MixedFunctionSpace(spaces, name=name)
    # Attach the geometry when the mesh is not its own topology
    return impl.WithGeometry(new, mesh) if mesh is not mesh.topology else new
Beispiel #12
0
 def __init__(self, spaces, name=None):
     r"""
     :arg spaces: A list of :class:`FunctionSpaceHierarchy`\s
     :arg name: accepted but not read by this initializer.
     """
     # NOTE: the docstring is a raw string so that the reST plural marker
     # ``\s`` is not parsed as an (invalid) Python escape sequence.

     # Split every input and collect the pieces in one flat list
     spaces = list(flatten([s.split() for s in spaces]))
     assert all(isinstance(s, BaseHierarchy) for s in spaces)
     # One mixed space per tuple produced by zip(*spaces) -- presumably one
     # tuple per refinement level; confirm against BaseHierarchy iteration
     self._hierarchy = tuple([set_level(functionspace.MixedFunctionSpace(s), self, lvl)
                             for lvl, s in enumerate(zip(*spaces))])
     self._spaces = tuple(spaces)
     self._ufl_element = self._hierarchy[0].ufl_element()
     # Install the coarsen/refine hooks on the DM of every space in ``self``
     for V in self:
         dm = V._dm
         dm.setCoarsen(coarsen)
         dm.setRefine(refine)
def MixedFunctionSpace(spaces, name=None, mesh=None):
    """Create a :class:`.MixedFunctionSpace`.

    :arg spaces: An iterable of constituent spaces, or a
        :class:`~ufl.classes.MixedElement`.
    :arg name: An optional name for the mixed function space.
    :arg mesh: An optional mesh.  Must be provided if spaces is a
        :class:`~ufl.classes.MixedElement`, ignored otherwise.
    :raises ValueError: if the spaces live on different mesh objects, or if
        one of them is of a kind that cannot be part of a mixed space.
    """
    if isinstance(spaces, ufl.FiniteElementBase):
        # Build the spaces if we got a mixed element
        assert type(spaces) is ufl.MixedElement and mesh is not None
        sub_elements = []

        def rec(eles):
            for ele in eles:
                # Only want to recurse into MixedElements
                if type(ele) is ufl.MixedElement:
                    rec(ele.sub_elements())
                else:
                    sub_elements.append(ele)
        rec(spaces.sub_elements())
        spaces = [FunctionSpace(mesh, element) for element in sub_elements]

    # Check that function spaces are on the same mesh (identity, not
    # equality). Written without ``xrange`` so that it runs unchanged on
    # both Python 2 and Python 3.
    meshes = [space.mesh() for space in spaces]
    if any(other is not meshes[0] for other in meshes[1:]):
        raise ValueError("All function spaces must be defined on the same mesh!")

    # Select mesh
    mesh = meshes[0]
    # Get topological spaces
    spaces = tuple(s.topological for s in flatten(spaces))
    # Error checking: only plain function spaces and component-free proxy
    # spaces are accepted
    for space in spaces:
        if type(space) is impl.FunctionSpace:
            continue
        elif type(space) is impl.ProxyFunctionSpace:
            if space.component is not None:
                raise ValueError("Can't make mixed space with %s" % space)
            continue
        else:
            raise ValueError("Can't make mixed space with %s" % type(space))

    new = impl.MixedFunctionSpace(spaces, name=name)
    # Attach the geometry when the mesh is not its own topology
    if mesh is not mesh.topology:
        return impl.WithGeometry(new, mesh)
    return new
Beispiel #14
0
 def __init__(self, spaces, name=None):
     r"""
     :arg spaces: A list of :class:`FunctionSpaceHierarchy`\s
     :arg name: accepted but not read by this initializer.
     """
     # NOTE: the docstring is a raw string so that the reST plural marker
     # ``\s`` is not parsed as an (invalid) Python escape sequence.

     # Split every input and collect the pieces in one flat list
     spaces = list(flatten([s.split() for s in spaces]))
     assert all(isinstance(s, BaseHierarchy) for s in spaces)
     # One mixed space per tuple produced by zip(*spaces) -- presumably one
     # tuple per refinement level; confirm against BaseHierarchy iteration
     self._hierarchy = tuple([
         set_level(functionspace.MixedFunctionSpace(s), self, lvl)
         for lvl, s in enumerate(zip(*spaces))
     ])
     self._spaces = tuple(spaces)
     self._ufl_element = self._hierarchy[0].ufl_element()
     # Install the coarsen/refine hooks on the DM of every space in ``self``
     for V in self:
         dm = V._dm
         dm.setCoarsen(coarsen)
         dm.setRefine(refine)
Beispiel #15
0
    def kernel_args(self, loops, fundecl):
        """Drop duplicated kernel parameters from ``fundecl``.

        The parameters of ``fundecl`` are paired positionally with the
        flattened :class:`base.Arg` objects of ``loops``. Every Arg that is
        not the representative returned by ``self.loop_args`` loses its
        parameter in the signature and is rebound to the representative's
        parameter; an alias is prepended to the body when the symbols differ.

        :arg loops: the loops whose Args drive the filtering.
        :arg fundecl: the function declaration; its ``args`` and ``body`` are
            modified in place.
        :returns: an :class:`OrderedDict` from loop Args to the kernel
            parameter each one is bound to.
        """
        flat_loop_args = list(flatten([l.args for l in loops]))
        representatives = self.loop_args(loops)
        binding = OrderedDict(zip(flat_loop_args, fundecl.args))

        kept_params, aliases = [], []
        for loop_arg, param in binding.items():
            representative = representatives[self._key(loop_arg)]

            # This Arg is the representative itself: keep its parameter
            if loop_arg is representative:
                kept_params.append(param)
                continue

            # Duplicate: redirect it to the parameter of the representative
            target = binding[representative]
            if target.is_const:
                # The /const/ qualifier has to go from the C declaration if
                # the same argument is now written in the fused kernel;
                # otherwise /const/ is added where still missing
                if loop_arg._is_written:
                    target.qual.remove('const')
                elif 'const' not in param.qual:
                    param.qual.append('const')
            binding[loop_arg] = target

            # Different symbols for the same data: an alias is required
            symbol = param.sym.symbol
            if symbol != target.sym.symbol:
                aliases.append(ast_make_alias(target, symbol))

        fundecl.args[:] = kept_params
        if aliases:
            # Frame the aliases with a leading comment and a trailing newline
            aliases = ([ast.FlatBlock('// Args aliases\n')] + aliases +
                       [ast.FlatBlock('\n')])
        fundecl.body = aliases + fundecl.body

        return binding
Beispiel #16
0
    def kernel_args(self, loops, fundecl):
        """Remove repeated kernel parameters from ``fundecl``, guided by the
        :class:`base.Arg` objects used in ``loops``.

        :arg loops: the loops providing the Args; flattened, they are matched
            positionally with ``fundecl.args``.
        :arg fundecl: the function declaration, whose ``args`` and ``body``
            are updated in place.
        :returns: the :class:`OrderedDict` binding every loop Arg to a kernel
            parameter.
        """
        every_loop_arg = list(flatten([l.args for l in loops]))
        unique_of = self.loop_args(loops)
        binding = OrderedDict(zip(every_loop_arg, fundecl.args))

        surviving, alias_decls = [], []
        for loop_arg, kernel_arg in binding.items():
            unique = unique_of[self._key(loop_arg)]

            if loop_arg is not unique:
                # Rebind this duplicate to the parameter of the unique Arg
                bound_to = binding[unique]
                if bound_to.is_const:
                    # Written in the fused kernel: /const/ must be dropped.
                    # Otherwise make sure /const/ is present.
                    if loop_arg._is_written:
                        bound_to.qual.remove('const')
                    elif 'const' not in kernel_arg.qual:
                        kernel_arg.qual.append('const')
                binding[loop_arg] = bound_to

                # The old symbol may still be referenced: alias it if needed
                old_symbol = kernel_arg.sym.symbol
                if old_symbol != bound_to.sym.symbol:
                    alias_decls.append(ast_make_alias(bound_to, old_symbol))
            else:
                # A single instance of this Arg: the parameter is kept as is
                surviving.append(kernel_arg)

        fundecl.args[:] = surviving
        if alias_decls:
            header = ast.FlatBlock('// Args aliases\n')
            footer = ast.FlatBlock('\n')
            alias_decls = [header] + alias_decls + [footer]
        fundecl.body = alias_decls + fundecl.body

        return binding
Beispiel #17
0
    def __init__(self, spaces, name=None):
        r"""
        :param spaces: a list (or tuple) of :class:`FunctionSpace`\s
        :param name: optional name of the mixed space; when omitted (or
            otherwise falsy) the names of the constituent spaces are joined
            with ``_``.

        The function space may be created as ::

            V = MixedFunctionSpace(spaces)

        ``spaces`` may consist of multiple occurrences of the same space: ::

            P1  = FunctionSpace(mesh, "CG", 1)
            P2v = VectorFunctionSpace(mesh, "Lagrange", 2)

            ME  = MixedFunctionSpace([P2v, P1, P1, P1])
        """
        # NOTE: the docstring is a raw string so that the reST plural marker
        # ``\s`` is not parsed as an (invalid) Python escape sequence.

        # Nothing to do if this object has already been set up
        if self._initialized:
            return
        # Wrap every (flattened) constituent space, recording its position
        self._spaces = [
            IndexedFunctionSpace(s, i, self)
            for i, s in enumerate(flatten(spaces))
        ]
        # The mesh is taken from the first constituent space
        self._mesh = self._spaces[0].mesh()
        self._ufl_element = ufl.MixedElement(
            *[fs.ufl_element() for fs in self._spaces])
        self.name = name or '_'.join(str(s.name) for s in self._spaces)
        self.rank = 1
        self._index = None
        # Flag the object as set up before the DM below is built from ``self``
        self._initialized = True
        # Shell DM describing this space to PETSc
        dm = PETSc.DMShell().create()
        # Imported here rather than at module level
        from firedrake.function import Function
        with Function(self).dat.vec_ro as v:
            dm.setGlobalVector(v.duplicate())
        # Only a weak reference to the space is stored on the DM
        dm.setAttr('__fs__', weakref.ref(self))
        dm.setCreateFieldDecomposition(self.create_field_decomp)
        dm.setCreateSubDM(self.create_subdm)
        self._dm = dm
        self._ises = self.dof_dset.field_ises
        self._subspaces = []
Beispiel #18
0
def loop_chain(name, **kwargs):
    """Analyze the sub-trace of loops lazily evaluated in this contextmanager ::

        [loop_0, loop_1, ..., loop_n-1]

    and produce a new sub-trace (``m <= n``) ::

        [fused_loops_0, fused_loops_1, ..., fused_loops_m-1, peel_loops]

    which is eventually inserted in the global trace of :class:`ParLoop` objects.

    That is, sub-sequences of :class:`ParLoop` objects are potentially replaced by
    new :class:`ParLoop` objects representing the fusion or the tiling of the
    original trace slice.

    This is a generator with a single ``yield``: the code before it runs on
    entry to the context, the code after it on exit.

    :arg name: identifier of the loop chain
    :arg kwargs:
        * mode (default='hard'): the fusion/tiling mode (accepted: soft, hard,
            tile, only_tile, only_omp): ::
            * soft: consecutive loops over the same iteration set that do
                not present RAW or WAR dependencies through indirections
                are fused.
            * hard: fuse consecutive loops presenting inc-after-inc
                dependencies, on top of soft fusion.
            * tile: apply tiling through the SLOPE library, on top of soft
                and hard fusion.
            * only_tile: apply tiling through the SLOPE library, but do not
                apply soft or hard fusion
            * only_omp: ompize individual parloops through the SLOPE library
                (i.e., no fusion takes place)
        * tile_size: (default=1) suggest a starting average tile size.
        * num_unroll (default=1): in a time stepping loop, the length of the loop
            chain is given by ``num_loops * num_unroll``, where ``num_loops`` is the
            number of loops per time loop iteration. Setting this value to something
            greater than 1 may enable fusing longer chains.
        * seed_loop (default=0): the seed loop from which tiles are derived. Ignored
            in case of MPI execution, in which case the seed loop is enforced to 0.
        * force_glb (default=False): force tiling even in presence of global
            reductions. In this case, the user becomes responsible of semantic
            correctness.
        * coloring (default='default'): set a coloring scheme for tiling. The ``default``
            coloring should be used because it ensures correctness by construction,
            based on the execution mode (sequential, openmp, mpi, mixed). So this
            should be changed only if totally confident with what is going on.
            Possible values are default, rand, omp; these are documented in detail
            in the documentation of the SLOPE library.
        * explicit (default=None): an iterator of 3-tuples (f, l, ts), each 3-tuple
            indicating a sub-sequence of loops to be inspected. ``f`` and ``l``
            represent, respectively, the first and last loop index of the sequence;
            ``ts`` is the tile size for the sequence.
        * ignore_war: (default=False) inform SLOPE that inspection doesn't need
            to care about write-after-read dependencies.
        * log (default=False): output inspector and loop chain info to a file.
        * use_glb_maps (default=False): when tiling, use the global maps provided by
            PyOP2, rather than the ones constructed by SLOPE.
        * use_prefetch (default=0): when tiling, try to prefetch the next iteration.
    """
    # The loop chain must not reuse the name of the lazy-evaluation trace
    assert name != lazy_trace_name, "Loop chain name must differ from %s" % lazy_trace_name

    # Fill in the defaults of the options that are handled here. Only these
    # keys get a default in this function: ``mode``, ``force_glb`` and ``log``
    # are forwarded untouched through ``kwargs`` when supplied.
    num_unroll = kwargs.setdefault('num_unroll', 1)
    tile_size = kwargs.setdefault('tile_size', 1)
    kwargs.setdefault('seed_loop', 0)
    kwargs.setdefault('use_glb_maps', False)
    kwargs.setdefault('use_prefetch', 0)
    kwargs.setdefault('coloring', 'default')
    kwargs.setdefault('ignore_war', False)
    # ``explicit`` is removed from kwargs, hence it is never forwarded to fuse()
    explicit = kwargs.pop('explicit', None)

    # Get a snapshot of the trace before new par loops are added within this
    # context manager
    from pyop2.base import _trace
    stamp = list(_trace._trace)

    # Hand control back to the body of the context; everything below runs
    # once that body has completed
    yield

    trace = _trace._trace
    # The trace is exactly as it was on entry: nothing to transform
    if trace == stamp:
        return

    # What's the first item /B/ that appeared in the trace /before/ entering the
    # context manager and that still has to be executed ?
    # The loop chain will be (B, end_of_current_trace]
    bottom = 0
    for i in reversed(stamp):
        if i in trace:
            bottom = trace.index(i) + 1
            break
    extracted_trace = trace[bottom:]

    # Three possibilities:
    if num_unroll < 1:
        # 1) No tiling requested, but the openmp backend was set, so we still try to
        # omp-ize the loops with SLOPE
        # (if the condition below does not hold, the trace is left untouched
        # and is not evaluated here)
        if slope and slope.get_exec_mode() in ['OMP', 'OMP_MPI'] and tile_size > 0:
            block_size = tile_size    # This is rather a 'block' size (no tiling)
            options = {'mode': 'only_omp',
                       'tile_size': block_size}
            # Each loop is inspected on its own, as a one-loop chain
            new_trace = [Inspector(name, [loop], **options).inspect()([loop])
                         for loop in extracted_trace]
            trace[bottom:] = list(flatten(new_trace))
            _trace.evaluate_all()
    elif explicit:
        # 2) Tile over subsets of loops in the loop chain, as specified
        # by the user through the /explicit/ list
        prev_last = 0
        transformed = []
        for i, (first, last, tile_size) in enumerate(explicit):
            sub_name = "%s_sub%d" % (name, i)
            kwargs['tile_size'] = tile_size
            # Loops between two explicit sub-sequences are kept as they are
            transformed.extend(extracted_trace[prev_last:first])
            # ``last`` is inclusive, hence the ``+ 1``
            transformed.extend(fuse(sub_name, extracted_trace[first:last+1], **kwargs))
            prev_last = last + 1
        # Loops following the last explicit sub-sequence are kept as they are
        transformed.extend(extracted_trace[prev_last:])
        trace[bottom:] = transformed
        _trace.evaluate_all()
    else:
        # 3) Tile over the entire loop chain, possibly unrolled as by user
        # request of a factor equals to /num_unroll/
        # ``loop_chain.unrolled_loop_chain`` is an attribute of this function,
        # initialised outside this view, that accumulates loops across uses
        total_loop_chain = loop_chain.unrolled_loop_chain + extracted_trace
        # NOTE(review): ``/`` is floor division on Python 2 but true division
        # on Python 3, and it raises ZeroDivisionError if ``extracted_trace``
        # is empty -- confirm which semantics are intended
        if len(total_loop_chain) / len(extracted_trace) == num_unroll:
            bottom = trace.index(total_loop_chain[0])
            trace[bottom:] = fuse(name, total_loop_chain, **kwargs)
            loop_chain.unrolled_loop_chain = []
            _trace.evaluate_all()
        else:
            # Not enough loops collected yet: keep accumulating
            loop_chain.unrolled_loop_chain.extend(extracted_trace)
Beispiel #19
0
    def generate_code(self):
        """Build the dictionary of C snippets used to fill the executor wrapper.

        For every (kernel, iteration space, args) triple tracked by this
        module, the code dictionary of a plain ``sequential.JITModule`` is
        generated and then adjusted for tiled execution (argument binding,
        index expressions, optional prefetching, subset indirection and tile
        bounds) before being rendered through
        ``TilingJITModule._kernel_wrapper``.

        :returns: a dict with the keys ``wrapper_name``, ``executor_arg``,
            ``wrapper_args``, ``wrapper_decs``, ``rank``, ``region_flag``,
            ``user_code``, ``ssinds_arg`` and ``executor_code``.
        """
        def indent(text, level):
            # Prefix every line but the first with /level/ two-space steps
            return ('\n' + '  ' * level).join(text.split('\n'))

        def prefetch(addr):
            return '_mm_prefetch ((char*)(%s), _MM_HINT_T0)' % addr

        # 1) Wrapper signature and declarations
        code_dict = {
            'wrapper_name': 'wrap_executor',
            'executor_arg': "%s %s" % (slope.Executor.meta['ctype_exec'],
                                       slope.Executor.meta['name_param_exec']),
        }
        wrapper_args = ', '.join([arg.c_wrapper_arg() for arg in self._args])
        wrapper_decs = ';\n'.join([arg.c_wrapper_dec() for arg in self._args])
        code_dict['wrapper_args'] = wrapper_args
        code_dict['wrapper_decs'] = indent(wrapper_decs, 1)
        code_dict['rank'] = ", %s %s" % (slope.Executor.meta['ctype_rank'],
                                         slope.Executor.meta['rank'])
        code_dict['region_flag'] = ", %s %s" % (slope.Executor.meta['ctype_region_flag'],
                                                slope.Executor.meta['region_flag'])

        # 2) One kernel invocation per loop in the chain
        loop_bodies, user_codes, ssind_decls = [], [], []
        per_loop = zip(self._all_kernels, self._all_itspaces, self._all_args)
        for pos, (kernel, it_space, args) in enumerate(per_loop):
            # Bind each kernel argument to the first Executor argument that
            # shares both its data and its map
            bindings = []
            for loop_arg in args:
                for exec_arg in self._args:
                    if loop_arg.data is exec_arg.data and loop_arg.map is exec_arg.map:
                        loop_arg.ref_arg = exec_arg
                        break
                bindings.append(loop_arg.c_arg_bindto())
            bindings = ";\n".join(bindings)

            # Reuse the code generation of a stand-alone sequential loop
            loop_code_dict = sequential.JITModule(kernel, it_space, *args,
                                                  delay=True).generate_code()

            # Global or local maps for the scatter ?
            if self._use_glb_maps:
                direct_map = self._executor.gtl_maps[pos]['DIRECT']
                loop_code_dict['index_expr'] = '%s[n]' % direct_map
                prefetch_var = 'int p = %s[n + %d]' % (direct_map, self._use_prefetch)
            else:
                prefetch_var = 'int p = n + %d' % self._use_prefetch

            # Prefetch intrinsics, only if requested
            maps_prefetch, vecs_prefetch = '', ''
            if self._use_prefetch:
                map_entries = [a.c_map_entry('p') for a in args if a._is_indirect]
                # can save some instructions since prefetching targets chunks of 32 bytes
                map_entries = flatten([j for j in pm if pm.index(j) % 2 == 0]
                                      for pm in map_entries)
                map_entries = list(OrderedDict.fromkeys(map_entries))
                maps_prefetch = ';\n'.join([prefetch_var] +
                                           [prefetch('&(%s)' % pm) for pm in map_entries])
                vec_entries = flatten(a.c_vec_entry('p', True) for a in args
                                      if a._is_indirect)
                vecs_prefetch = ';\n'.join([prefetch(pv) for pv in vec_entries])
            loop_code_dict['prefetch_maps'] = maps_prefetch
            loop_code_dict['prefetch_vecs'] = vecs_prefetch

            # Subset indirection array, when the loop needs one
            ssind_arg, ssind_decl = '', ''
            if loop_code_dict['ssinds_arg']:
                ssind_arg = 'ssinds_%d' % pos
                ssind_decl = 'int* %s' % ssind_arg
                loop_code_dict['index_expr'] = '%s[n]' % ssind_arg

            # The kernel's function name within *this* specific loop chain
            loop_code_dict['kernel_name'] = kernel._function_names[self._kernel.cache_key]

            # Iterations of a tile go through the DIRECT map and, if a subset
            # is involved, through the subset indirection as well
            tile_iter = '%s[n]' % self._executor.gtl_maps[pos]['DIRECT']
            if ssind_arg:
                tile_iter = '%s[%s]' % (ssind_arg, tile_iter)

            # Remaining entries of the per-loop /code_dict/
            loop_code_dict['args_binding'] = bindings
            loop_code_dict['tile_init'] = self._executor.c_loop_init[pos]
            loop_code_dict['tile_finish'] = self._executor.c_loop_end[pos]
            loop_code_dict['tile_start'] = slope.Executor.meta['tile_start']
            loop_code_dict['tile_end'] = slope.Executor.meta['tile_end']
            loop_code_dict['tile_iter'] = tile_iter

            # Accumulate body, user code and subset declaration of this loop
            loop_bodies.append(strip(TilingJITModule._kernel_wrapper % loop_code_dict))
            user_codes.append(kernel._user_code)
            ssind_decls.append(ssind_decl)

        chain_body = indent("\n\n".join(loop_bodies), 2)
        code_dict['user_code'] = indent("\n".join(user_codes), 1)
        code_dict['ssinds_arg'] = "".join(["%s," % decl for decl in ssind_decls if decl])
        code_dict['executor_code'] = indent(self._executor.c_code(chain_body), 1)

        return code_dict
Beispiel #20
0
def loop_chain(name, **kwargs):
    """Analyze the sub-trace of loops lazily evaluated in this contextmanager ::

        [loop_0, loop_1, ..., loop_n-1]

    and produce a new sub-trace (``m <= n``) ::

        [fused_loops_0, fused_loops_1, ..., fused_loops_m-1, peel_loops]

    which is eventually inserted in the global trace of :class:`ParLoop` objects.

    That is, sub-sequences of :class:`ParLoop` objects are potentially replaced by
    new :class:`ParLoop` objects representing the fusion or the tiling of the
    original trace slice.

    If, on exit from the context, the trace is unchanged, or (when tiling the
    whole chain) no loop recorded inside the context is still pending, nothing
    is done.

    :arg name: identifier of the loop chain
    :arg kwargs:
        * mode (default='hard'): the fusion/tiling mode (accepted: soft, hard,
            tile, only_tile, only_omp): ::
            * soft: consecutive loops over the same iteration set that do
                not present RAW or WAR dependencies through indirections
                are fused.
            * hard: fuse consecutive loops presenting inc-after-inc
                dependencies, on top of soft fusion.
            * tile: apply tiling through the SLOPE library, on top of soft
                and hard fusion.
            * only_tile: apply tiling through the SLOPE library, but do not
                apply soft or hard fusion
            * only_omp: ompize individual parloops through the SLOPE library
                (i.e., no fusion takes place)
        * tile_size: (default=1) suggest a starting average tile size.
        * num_unroll (default=1): in a time stepping loop, the length of the loop
            chain is given by ``num_loops * num_unroll``, where ``num_loops`` is the
            number of loops per time loop iteration. Setting this value to something
            greater than 1 may enable fusing longer chains.
        * seed_loop (default=0): the seed loop from which tiles are derived. Ignored
            in case of MPI execution, in which case the seed loop is enforced to 0.
        * force_glb (default=False): force tiling even in presence of global
            reductions. In this case, the user becomes responsible of semantic
            correctness.
        * coloring (default='default'): set a coloring scheme for tiling. The ``default``
            coloring should be used because it ensures correctness by construction,
            based on the execution mode (sequential, openmp, mpi, mixed). So this
            should be changed only if totally confident with what is going on.
            Possible values are default, rand, omp; these are documented in detail
            in the documentation of the SLOPE library.
        * explicit (default=None): an iterator of 3-tuples (f, l, ts), each 3-tuple
            indicating a sub-sequence of loops to be inspected. ``f`` and ``l``
            represent, respectively, the first and last loop index of the sequence;
            ``ts`` is the tile size for the sequence.
        * ignore_war: (default=False) inform SLOPE that inspection doesn't need
            to care about write-after-read dependencies.
        * log (default=False): output inspector and loop chain info to a file.
        * use_glb_maps (default=False): when tiling, use the global maps provided by
            PyOP2, rather than the ones constructed by SLOPE.
        * use_prefetch (default=0, i.e. disabled): when tiling, try to prefetch
            the next iteration.
    """
    assert name != lazy_trace_name, "Loop chain name must differ from %s" % lazy_trace_name

    num_unroll = kwargs.setdefault('num_unroll', 1)
    tile_size = kwargs.setdefault('tile_size', 1)
    kwargs.setdefault('seed_loop', 0)
    kwargs.setdefault('use_glb_maps', False)
    kwargs.setdefault('use_prefetch', 0)
    kwargs.setdefault('coloring', 'default')
    kwargs.setdefault('ignore_war', False)
    # /explicit/ is consumed here; it must not be forwarded to /fuse/
    explicit = kwargs.pop('explicit', None)

    # Get a snapshot of the trace before new par loops are added within this
    # context manager
    from pyop2.base import _trace
    stamp = list(_trace._trace)

    yield

    trace = _trace._trace
    if trace == stamp:
        return

    # What's the first item /B/ that appeared in the trace /before/ entering the
    # context manager and that still has to be executed ?
    # The loop chain will be (B, end_of_current_trace]
    bottom = 0
    for i in reversed(stamp):
        if i in trace:
            bottom = trace.index(i) + 1
            break
    extracted_trace = trace[bottom:]

    # Three possibilities:
    if num_unroll < 1:
        # 1) No tiling requested, but the openmp backend was set, so we still try to
        # omp-ize the loops with SLOPE
        if slope and slope.get_exec_mode() in ['OMP', 'OMP_MPI'] and tile_size > 0:
            block_size = tile_size  # This is rather a 'block' size (no tiling)
            options = {'mode': 'only_omp', 'tile_size': block_size}
            new_trace = [Inspector(name, [loop], **options).inspect()([loop])
                         for loop in extracted_trace]
            trace[bottom:] = list(flatten(new_trace))
            _trace.evaluate_all()
    elif explicit:
        # 2) Tile over subsets of loops in the loop chain, as specified
        # by the user through the /explicit/ list
        prev_last = 0
        transformed = []
        for i, (first, last, sub_tile_size) in enumerate(explicit):
            sub_name = "%s_sub%d" % (name, i)
            kwargs['tile_size'] = sub_tile_size
            # Loops between two explicit sub-sequences are kept untouched
            transformed.extend(extracted_trace[prev_last:first])
            transformed.extend(fuse(sub_name, extracted_trace[first:last+1], **kwargs))
            prev_last = last + 1
        transformed.extend(extracted_trace[prev_last:])
        trace[bottom:] = transformed
        _trace.evaluate_all()
    else:
        # 3) Tile over the entire loop chain, possibly unrolled as by user
        # request of a factor equals to /num_unroll/
        if not extracted_trace:
            # The trace changed, but no loop recorded in this context is still
            # pending (e.g., everything was already evaluated): there is
            # nothing to fuse or to accumulate, and the ratio below would
            # otherwise divide by zero
            return
        total_loop_chain = loop_chain.unrolled_loop_chain + extracted_trace
        if len(total_loop_chain) / len(extracted_trace) == num_unroll:
            bottom = trace.index(total_loop_chain[0])
            trace[bottom:] = fuse(name, total_loop_chain, **kwargs)
            loop_chain.unrolled_loop_chain = []
            _trace.evaluate_all()
        else:
            # Not enough unrolled iterations yet: keep accumulating
            loop_chain.unrolled_loop_chain.extend(extracted_trace)
Beispiel #21
0
def fuse(name, loop_chain, **kwargs):
    """Apply fusion (and possibly tiling) to an iterator of :class:`ParLoop`
    objects, which we refer to as ``loop_chain``. Return an iterator of
    :class:`ParLoop` objects, in which some loops may have been fused or tiled.
    If fusion could not be applied, return the unmodified ``loop_chain``.

    .. note::
       At the moment, the following features are not supported, in which
       case the unmodified ``loop_chain`` is returned.

        * mixed ``Datasets`` and ``Maps``;
        * extruded ``Sets``

    .. note::
       Tiling cannot be applied if any of the following conditions verifies:

        * a global reduction/write occurs in ``loop_chain``
    """
    # Fewer than two loops: there is nothing to fuse
    if len(loop_chain) < 2:
        return loop_chain

    # _LazyMatOp objects are synch points preventing fusion: only the
    # sub-sequence before the first of them is considered
    remainder = []
    first_synch = next((l for l in loop_chain if isinstance(l, _LazyMatOp)), None)
    if first_synch is not None:
        cut = loop_chain.index(first_synch)
        remainder, loop_chain = loop_chain[cut:], loop_chain[:cut]

    def unfused():
        # The original loops, followed by whatever came after the synch point
        return loop_chain + remainder

    # Possibly nothing is left to fuse (e.g. only _LazyMatOp objects were present)
    if len(loop_chain) < 2:
        return unfused()

    # Build the inspector options, falling back to defaults for missing entries.
    # On a cache hit the inspector is already initialized, so the fused par
    # loops are returned straight away; otherwise an inspection is attempted.
    defaults = (
        ('log', False),
        ('mode', 'hard'),
        ('ignore_war', False),
        ('use_glb_maps', False),
        ('use_prefetch', 0),
        ('tile_size', 1),
        ('seed_loop', 0),
        ('extra_halo', False),
        ('coloring', 'default'),
    )
    options = dict((key, kwargs.get(key, default)) for key, default in defaults)
    inspector = Inspector(name, loop_chain, **options)
    if inspector._initialized:
        return inspector.schedule(loop_chain) + remainder

    # From here on: is the inspection legal ?
    mode = options['mode']
    force_glb = kwargs.get('force_glb', False)

    # Loops that are already /fusion/ objects (i.e., fused in an earlier
    # /loop_chain/ context) are left alone
    if any(isinstance(l, extended.ParLoop) for l in loop_chain):
        return unfused()

    # Global reductions threaten correctness: fuse only if the user insists
    if not force_glb and any(l._reduced_globals for l in loop_chain):
        return unfused()

    # Fusing means rewriting kernels, hence their ASTs must be available
    if mode != 'only_tile':
        if any(not l.kernel._ast or l.kernel._attached_info['flatblocks']
               for l in loop_chain):
            return unfused()

    # Mixed arguments are not supported yet
    if any(a._is_mixed for a in flatten([l.args for l in loop_chain])):
        return unfused()

    # Extruded iteration sets are not supported yet
    if any(l.is_layered for l in loop_chain):
        return unfused()

    # Tiling needs the SLOPE library
    if mode in ['tile', 'only_tile'] and not slope:
        warning("Couldn't locate SLOPE. Falling back to plain op2.ParLoops.")
        return unfused()

    schedule = inspector.inspect()
    return schedule(loop_chain) + remainder
Beispiel #22
0
    def _hard_fuse(self):
        """Fuse consecutive loops over different iteration sets that do not
        present RAW, WAR or WAW dependencies. For examples, two loops like: ::

            par_loop(kernel_1, it_space_1,
                     dat_1_1(INC, ...),
                     dat_1_2(READ, ...),
                     ...)

            par_loop(kernel_2, it_space_2,
                     dat_2_1(INC, ...),
                     dat_2_2(READ, ...),
                     ...)

        where ``dat_1_1 == dat_2_1`` and, possibly (but not necessarily),
        ``it_space_1 != it_space_2``, can be hard fused. Note, in fact, that
        the presence of ``INC`` does not imply a real WAR dependency, because
        increments are associative."""

        loop_chain = self._loop_chain

        if len(loop_chain) == 1:
            # Nothing more to try fusing after soft fusion
            return

        # Search pairs of hard-fusible loops
        fusible = []
        base_loop_index = 0
        while base_loop_index < len(loop_chain):
            base_loop = loop_chain[base_loop_index]

            for i, loop in enumerate(loop_chain[base_loop_index+1:], 1):
                info = loops_analyzer(base_loop, loop)

                if info['homogeneous']:
                    # Hard fusion is meaningless if same iteration space
                    continue

                if not info['pure_iai']:
                    # Can't fuse across loops presenting RAW or WAR dependencies
                    break

                base_inc_dats = set(a.data for a in incs(base_loop))
                loop_inc_dats = set(a.data for a in incs(loop))
                common_inc_dats = base_inc_dats | loop_inc_dats
                common_incs = [a for a in incs(base_loop) | incs(loop)
                               if a.data in common_inc_dats]
                if not common_incs:
                    # Is there an overlap in any of the incremented dats? If
                    # that's not the case, fusion is fruitless
                    break

                # Hard fusion requires a map between the iteration spaces involved
                maps = set(a.map for a in common_incs if a._is_indirect)
                maps |= set(flatten(m.factors for m in maps if hasattr(m, 'factors')))
                set1, set2 = base_loop.it_space.iterset, loop.it_space.iterset
                fusion_map_1 = [m for m in maps if set1 == m.iterset and set2 == m.toset]
                fusion_map_2 = [m for m in maps if set1 == m.toset and set2 == m.iterset]
                if fusion_map_1:
                    fuse_loop = loop
                    fusion_map = fusion_map_1[0]
                elif fusion_map_2:
                    fuse_loop = base_loop
                    base_loop = loop
                    fusion_map = fusion_map_2[0]
                else:
                    continue

                if any(a._is_direct for a in fuse_loop.args):
                    # Cannot perform direct reads in a /fuse/ kernel
                    break

                common_inc = [a for a in common_incs if a in base_loop.args][0]
                fusible.append((base_loop, fuse_loop, fusion_map, common_inc))
                break

            # Set next starting point of the search
            base_loop_index += i

        # For each pair of hard-fusible loops, create a suitable Kernel
        fused = []
        for base_loop, fuse_loop, fusion_map, fused_inc_arg in fusible:
            loop_chain_index = (loop_chain.index(base_loop), loop_chain.index(fuse_loop))
            fused_kernel, fargs = build_hard_fusion_kernel(base_loop, fuse_loop,
                                                           fusion_map, loop_chain_index)
            fused.append((fused_kernel, fusion_map, fargs))

        # Finally, generate a new schedule
        self._schedule = HardFusionSchedule(self._name, self._schedule, fused)
        self._loop_chain = self._schedule(loop_chain, only_hard=True)
Beispiel #23
0
def build_hard_fusion_kernel(base_loop, fuse_loop, fusion_map, loop_chain_index):
    """
    Build AST and :class:`Kernel` for two loops suitable to hard fusion.

    The AST consists of three functions: fusion, base, fuse. base and fuse
    are respectively the ``base_loop`` and the ``fuse_loop`` kernels, whereas
    fusion is the orchestrator that invokes, for each ``base_loop`` iteration,
    base and, if still to be executed, fuse.

    The orchestrator has the following structure: ::

        fusion (buffer, ..., executed):
            base (buffer, ...)
            for i = 0 to arity:
                if not executed[i]:
                    additional pointer staging required by kernel2
                    fuse (sub_buffer, ...)
                    insertion into buffer

    The executed array tracks whether the i-th iteration (out of /arity/)
    adjacent to the main kernel1 iteration has been executed.

    :arg base_loop: the loop whose kernel AST is copied to become /base/.
    :arg fuse_loop: the loop whose kernel AST is copied to become /fuse/.
    :arg fusion_map: its ``arity`` determines how many guarded invocations
        of /fuse/ are emitted in the orchestrator.
    :arg loop_chain_index: forwarded unchanged to the :class:`Kernel`
        constructor.
    :returns: a 2-tuple ``(kernel, fargs)``, where ``kernel`` is the
        :class:`Kernel` built from the fused AST and ``fargs`` is a dict
        mapping positions in the orchestrator's argument list to
        ``(tag, flag)`` tuples (``('postponed', False)`` for arguments used
        only by /fuse/, ``('onlymap', True)`` for one extra entry).
    """

    finder = Find((ast.FunDecl, ast.PreprocessNode))

    # Work on deep copies of the kernels' ASTs, so the originals stay untouched
    base = base_loop.kernel
    base_ast = dcopy(base._ast)
    base_info = finder.visit(base_ast)
    base_headers = base_info[ast.PreprocessNode]
    base_fundecl = base_info[ast.FunDecl]
    # Exactly one function declaration is expected in each kernel
    assert len(base_fundecl) == 1
    base_fundecl = base_fundecl[0]

    fuse = fuse_loop.kernel
    fuse_ast = dcopy(fuse._ast)
    fuse_info = finder.visit(fuse_ast)
    fuse_headers = fuse_info[ast.PreprocessNode]
    fuse_fundecl = fuse_info[ast.FunDecl]
    assert len(fuse_fundecl) == 1
    fuse_fundecl = fuse_fundecl[0]

    # Create /fusion/ arguments and signature
    body = ast.Block([])
    fusion_name = '%s_%s' % (base_fundecl.name, fuse_fundecl.name)
    fusion_args = dcopy(base_fundecl.args + fuse_fundecl.args)
    fusion_fundecl = ast.FunDecl(base_fundecl.ret, fusion_name, fusion_args, body)

    # Make sure kernel and variable names are unique
    # (note that /fusion_name/ was derived above from the original names)
    base_fundecl.name = "%s_base" % base_fundecl.name
    fuse_fundecl.name = "%s_fuse" % fuse_fundecl.name
    for i, decl in enumerate(fusion_args):
        decl.sym.symbol += '_%d' % i

    # Filter out duplicate arguments, and append extra arguments to the fundecl
    binding = WeakFilter().kernel_args([base_loop, fuse_loop], fusion_fundecl)
    fusion_args += [ast.Decl('int*', 'executed'),
                    ast.Decl('int*', 'fused_iters'),
                    ast.Decl('int', 'i')]

    # Which args are actually used in /fuse/, but not in /base/ ? The gather for
    # such arguments is moved to /fusion/, to avoid useless memory LOADs
    base_dats = set(a.data for a in base_loop.args)
    fuse_dats = set(a.data for a in fuse_loop.args)
    unshared = OrderedDict()
    for arg, decl in binding.items():
        if arg.data in fuse_dats - base_dats:
            # Only the first Arg bound to a given declaration is recorded
            unshared.setdefault(decl, arg)

    # Track position of Args that need a postponed gather
    # Can't track Args themselves as they change across different parloops
    fargs = {fusion_args.index(i): ('postponed', False) for i in unshared.keys()}
    fargs.update({len(set(binding.values())): ('onlymap', True)})

    # Add maps for arguments that need a postponed gather
    for decl, arg in unshared.items():
        decl_pos = fusion_args.index(decl)
        fusion_args[decl_pos].sym.symbol = arg.c_arg_name()
        if arg._is_indirect:
            fusion_args[decl_pos].sym.rank = ()
            fusion_args.insert(decl_pos + 1, ast.Decl('int*', arg.c_map_name(0, 0)))

    # Append the invocation of /base/; then, proceed with the invocation
    # of the /fuse/ kernels
    base_funcall_syms = [binding[a].sym.symbol for a in base_loop.args]
    body.children.append(ast.FunCall(base_fundecl.name, *base_funcall_syms))

    # One guarded invocation of /fuse/ per entry of the fusion map
    for idx in range(fusion_map.arity):

        # i = fused_iters[idx]; if (!executed[i]) { fuse(...); executed[i] = 1; }
        fused_iter = ast.Assign('i', ast.Symbol('fused_iters', (idx,)))
        fuse_funcall = ast.FunCall(fuse_fundecl.name)
        if_cond = ast.Not(ast.Symbol('executed', ('i',)))
        if_update = ast.Assign(ast.Symbol('executed', ('i',)), 1)
        if_body = ast.Block([fuse_funcall, if_update], open_scope=True)
        if_exec = ast.If(if_cond, [if_body])
        body.children.extend([ast.FlatBlock('\n'), fused_iter, if_exec])

        # Modify the /fuse/ kernel
        # This is to take into account that many arguments are shared with
        # /base/, so they will only be staged once for /base/. This requires
        # tweaking the way the arguments are declared and accessed in /fuse/.
        # For example, the shared incremented array (called /buffer/ in
        # the pseudocode in the comment above) now needs to take offsets
        # to be sure the locations that /base/ is supposed to increment are
        # actually accessed. The same concept applies to indirect arguments.
        init = lambda v: '{%s}' % ', '.join([str(j) for j in v])
        for i, fuse_loop_arg in enumerate(fuse_loop.args):
            fuse_kernel_arg = binding[fuse_loop_arg]

            # By default, /fuse/ receives a temporary named after the argument
            buffer_name = '%s_vec' % fuse_kernel_arg.sym.symbol
            fuse_funcall_sym = ast.Symbol(buffer_name)

            # What kind of temporaries do we need ?
            # INC: the temporary is accumulated into the argument, with the
            # staging statements appended to the if-body.
            # READ: the argument is copied into the temporary, with the
            # staging statements inserted at the top of the if-body.
            # NOTE(review): for any other access kind, /op/, /lvalue/,
            # /rvalue/, /stager/, /indexer/ and /pointers/ keep the values of
            # a previous iteration (or are unbound) -- confirm that only INC
            # and READ arguments can reach this point.
            if fuse_loop_arg.access == INC:
                op, lvalue, rvalue = ast.Incr, fuse_kernel_arg.sym.symbol, buffer_name
                stager = lambda b, l: b.children.extend(l)
                indexer = lambda indices: [(k, j) for j, k in enumerate(indices)]
                pointers = []
            elif fuse_loop_arg.access == READ:
                op, lvalue, rvalue = ast.Assign, buffer_name, fuse_kernel_arg.sym.symbol
                stager = lambda b, l: [b.children.insert(0, j) for j in reversed(l)]
                indexer = lambda indices: [(j, k) for j, k in enumerate(indices)]
                pointers = list(fuse_kernel_arg.pointers)

            # Now gonna handle arguments depending on their type and rank ...

            if fuse_loop_arg._is_global:
                # ... Handle global arguments. These can be dropped in the
                # kernel without any particular fiddling
                fuse_funcall_sym = ast.Symbol(fuse_kernel_arg.sym.symbol)

            elif fuse_kernel_arg in unshared:
                # ... Handle arguments that appear only in /fuse/
                # The right-hand sides are taken from the text following '='
                # in each line produced by /c_vec_init/
                staging = unshared[fuse_kernel_arg].c_vec_init(False).split('\n')
                rvalues = [ast.FlatBlock(j.split('=')[1]) for j in staging]
                lvalues = [ast.Symbol(buffer_name, (j,)) for j in range(len(staging))]
                staging = [ast.Assign(j, k) for j, k in zip(lvalues, rvalues)]

                # Set up the temporary
                buffer_symbol = ast.Symbol(buffer_name, (len(staging),))
                buffer_decl = ast.Decl(fuse_kernel_arg.typ, buffer_symbol,
                                       qualifiers=fuse_kernel_arg.qual,
                                       pointers=list(pointers))

                # Update the if-then AST body
                stager(if_exec.children[0], staging)
                if_exec.children[0].children.insert(0, buffer_decl)

            elif fuse_loop_arg._is_mat:
                # ... Handle Mats
                # NOTE(review): /fused_inc_arg/ is neither a parameter nor a
                # local of this function; unless it is defined at module
                # level, this branch raises NameError -- confirm.
                # NOTE(review): /lvalue/, /rvalue/ and /staging/ are rebound at
                # every iteration, so only the last block contributes -- confirm
                # this is intended.
                staging = []
                for b in fused_inc_arg._block_shape:
                    for rc in b:
                        lvalue = ast.Symbol(lvalue, (idx, idx),
                                            ((rc[0], 'j'), (rc[1], 'k')))
                        rvalue = ast.Symbol(rvalue, ('j', 'k'))
                        staging = ItSpace(mode=0).to_for([(0, rc[0]), (0, rc[1])],
                                                         ('j', 'k'),
                                                         [op(lvalue, rvalue)])[:1]

                # Set up the temporary
                buffer_symbol = ast.Symbol(buffer_name, (fuse_kernel_arg.sym.rank,))
                buffer_init = ast.ArrayInit(init([init([0.0])]))
                buffer_decl = ast.Decl(fuse_kernel_arg.typ, buffer_symbol, buffer_init,
                                       qualifiers=fuse_kernel_arg.qual, pointers=pointers)

                # Update the if-then AST body
                stager(if_exec.children[0], staging)
                if_exec.children[0].children.insert(0, buffer_decl)

            elif fuse_loop_arg._is_indirect:
                cdim = fuse_loop_arg.data.cdim

                if cdim == 1 and fuse_kernel_arg.sym.rank:
                    # [Special case]
                    # ... Handle rank 1 indirect arguments that appear in both
                    # /base/ and /fuse/: just point into the right location
                    rank = (idx,) if fusion_map.arity > 1 else ()
                    fuse_funcall_sym = ast.Symbol(fuse_kernel_arg.sym.symbol, rank)

                else:
                    # ... Handle indirect arguments. At the C level, these arguments
                    # are of pointer type, so simple pointer arithmetic is used
                    # to ensure the kernel accesses are to the correct locations
                    fuse_arity = fuse_loop_arg.map.arity
                    base_arity = fuse_arity*fusion_map.arity
                    size = fuse_arity*cdim

                    # Set the proper storage layout before invoking /fuse/
                    # (/indices/ selects the /size/ offsets belonging to the
                    # /idx/-th fused iteration)
                    ofs_vals = [[base_arity*j + k for k in range(fuse_arity)]
                                for j in range(cdim)]
                    ofs_vals = [[fuse_arity*j + k for k in flatten(ofs_vals)]
                                for j in range(fusion_map.arity)]
                    ofs_vals = list(flatten(ofs_vals))
                    indices = [ofs_vals[idx*size + j] for j in range(size)]

                    staging = [op(ast.Symbol(lvalue, (j,)), ast.Symbol(rvalue, (k,)))
                               for j, k in indexer(indices)]

                    # Set up the temporary
                    # (zero-initialized when incrementing; otherwise declared
                    # without initializer and with one pointer level dropped)
                    buffer_symbol = ast.Symbol(buffer_name, (size,))
                    if fuse_loop_arg.access == INC:
                        buffer_init = ast.ArrayInit(init([0.0]))
                    else:
                        buffer_init = ast.EmptyStatement()
                        pointers.pop()
                    buffer_decl = ast.Decl(fuse_kernel_arg.typ, buffer_symbol, buffer_init,
                                           qualifiers=fuse_kernel_arg.qual,
                                           pointers=pointers)

                    # Update the if-then AST body
                    stager(if_exec.children[0], staging)
                    if_exec.children[0].children.insert(0, buffer_decl)

            else:
                # Nothing special to do for direct arguments
                pass

            # Finally update the /fuse/ funcall
            fuse_funcall.children.append(fuse_funcall_sym)

    # Headers of the two kernels are merged, dropping textual duplicates
    fused_headers = set([str(h) for h in base_headers + fuse_headers])
    fused_ast = ast.Root([ast.PreprocessNode(h) for h in fused_headers] +
                         [base_fundecl, fuse_fundecl, fusion_fundecl])

    return Kernel([base, fuse], fused_ast, loop_chain_index), fargs
Beispiel #24
0
def fuse(name, loop_chain, **kwargs):
    """Try to fuse (and possibly tile) the :class:`ParLoop` objects found in
    ``loop_chain`` and return the resulting sequence of loops.

    Whenever fusion cannot be applied, the loops are handed back unchanged.

    :arg name: identifier forwarded to the :class:`Inspector`.
    :arg loop_chain: the sequence of loops to process. It is measured with
        ``len``, sliced, and concatenated with a list.
    :arg kwargs: optional settings; the ones read here, with their defaults,
        are ``log`` (False), ``mode`` ('hard'), ``ignore_war`` (False),
        ``use_glb_maps`` (False), ``use_prefetch`` (0), ``tile_size`` (1),
        ``seed_loop`` (0), ``extra_halo`` (False), ``coloring`` ('default')
        and ``force_glb`` (False). All but ``force_glb`` are forwarded to the
        :class:`Inspector`.

    .. note::
       Only the loops preceding the first ``_LazyMatOp`` (a synch point) are
       considered for fusion; that object and everything after it are
       appended, untouched, to the result.

    .. note::
       The unmodified loops are returned when any of the following holds:

        * some loop is already an ``extended.ParLoop``;
        * some loop has reduced globals and ``force_glb`` is not set;
        * ``mode`` is not ``'only_tile'`` and some kernel lacks an AST or is
          flagged with ``'flatblocks'``;
        * some argument is mixed (mixed ``Datasets`` and ``Maps``);
        * some loop is layered (extruded ``Sets``);
        * ``mode`` is ``'tile'`` or ``'only_tile'`` but SLOPE is unavailable.
    """
    # With fewer than two loops there is nothing to fuse
    if len(loop_chain) <= 1:
        return loop_chain

    # A _LazyMatOp is a synch point that prevents fusion: split the chain
    # there and only work on the leading sub-sequence
    remainder = []
    first_synch = next((loop for loop in loop_chain
                        if isinstance(loop, _LazyMatOp)), None)
    if first_synch is not None:
        cut = loop_chain.index(first_synch)
        remainder = loop_chain[cut:]
        loop_chain = loop_chain[:cut]

    # The leading sub-sequence may itself be too short to be worth fusing
    # (e.g. the chain started with a _LazyMatOp)
    if len(loop_chain) <= 1:
        return loop_chain + remainder

    # Collect the inspector options, falling back to the defaults below for
    # anything the caller did not provide
    defaults = (
        ('log', False),
        ('mode', 'hard'),
        ('ignore_war', False),
        ('use_glb_maps', False),
        ('use_prefetch', 0),
        ('tile_size', 1),
        ('seed_loop', 0),
        ('extra_halo', False),
        ('coloring', 'default'),
    )
    options = {key: kwargs.get(key, default) for key, default in defaults}

    # An already-initialized inspector (cache hit) can schedule the loops
    # straight away; otherwise an inspection has to be attempted
    inspector = Inspector(name, loop_chain, **options)
    if inspector._initialized:
        return inspector.schedule(loop_chain) + remainder

    # From here on, check whether running the inspection is legal
    mode = options['mode']
    force_glb = kwargs.get('force_glb', False)

    # Loops that are already /fusion/ objects were fused previously in some
    # /loop_chain/ context: leave them alone
    if any(isinstance(loop, extended.ParLoop) for loop in loop_chain):
        return loop_chain + remainder

    # Global reductions endanger correctness; only proceed if the user
    # explicitly forces it
    if not force_glb and any(loop._reduced_globals for loop in loop_chain):
        return loop_chain + remainder

    # Fusing means rewriting kernels, which requires their ASTs
    if mode != 'only_tile' and any(not loop.kernel._ast or
                                   loop.kernel._attached_info['flatblocks']
                                   for loop in loop_chain):
        return loop_chain + remainder

    # Mixed arguments are not supported yet
    if any(arg._is_mixed for arg in flatten([loop.args for loop in loop_chain])):
        return loop_chain + remainder

    # Extruded (layered) iteration is not supported yet
    if any(loop.is_layered for loop in loop_chain):
        return loop_chain + remainder

    # Tiling relies on SLOPE being available
    if mode in ('tile', 'only_tile') and not slope:
        warning("Couldn't locate SLOPE. Falling back to plain op2.ParLoops.")
        return loop_chain + remainder

    # All checks passed: inspect, then apply the resulting schedule
    fused_schedule = inspector.inspect()
    return fused_schedule(loop_chain) + remainder