Exemplo n.º 1
0
    def make_parallel(self, iet):
        """
        Wrap the parallel :class:`Iteration`s of ``iet`` in ``#pragma omp ...``
        parallel regions for thread-level parallelism.

        Iteration trees are first partitioned into groups. A tree with at least
        one parallelizable candidate opens a new group, unless both it and the
        previously visited tree carry a tag, in which case it joins the group
        opened last. Each group then becomes a single parallel region.

        Returns
        -------
        (List, dict)
            The transformed IET and, when at least one region was built,
            ``{'input': [nt]}`` where ``nt`` is a fresh :class:`NThreads`;
            otherwise ``{}``.
        """
        # Phase 1: partition the iteration trees into parallel-region groups.
        # Group keys are the consecutive integers 0, 1, 2, ...
        groups = OrderedDict()
        prev_tagged = False
        for tree in retrieve_iteration_tree(iet):
            candidates = filter_iterations(tree, key=self.key, stop='asap')
            if candidates:
                tagged = any(it.tag is not None for it in tree)
                # Back-to-back tagged trees share the most recently opened group
                slot = len(groups) - 1 if (tagged and prev_tagged) else len(groups)
                groups.setdefault(slot, OrderedDict())[candidates[0]] = candidates
                prev_tagged = tagged
            else:
                # A tree with nothing to parallelize breaks any tagged sequence
                prev_tagged = False

        # Phase 2: build one parallel region per group
        mapper = OrderedDict()
        for group in groups.values():
            private_names = set()
            for root, candidates in group.items():
                mapper.update(self._make_parallel_tree(root, candidates))

                # Arrays flagged `_mem_stack` are listed in the `private` clause
                private_names.update(
                    s.name for s in FindSymbols('symbolics').visit(root)
                    if s.is_Array and s._mem_stack
                )

            clause = ('private(%s)' % ','.join(sorted(private_names))
                      if private_names else '')
            body = [node for key, node in mapper.items() if key in group]
            region = Block(header=self.lang['par-region'](clause), body=body)

            # Every entry still mapped to an Iteration is either dropped
            # (remainder loops) or replaced by this group's region
            mapper.update({
                key: (None if node.is_Remainder else region)
                for key, node in mapper.items() if isinstance(node, Iteration)
            })

        processed = Transformer(mapper).visit(iet)
        if not mapper:
            return List(body=processed), {}

        # Workaround: the OpenMP pragmas are not true IET nodes, so `nthreads`
        # would not be detected as a Callable parameter; a mock Expression
        # assigning it to `nt` makes it visible
        nthreads = NThreads()
        assignment = LocalExpression(DummyEq(Symbol(name='nt', dtype=np.int32),
                                             nthreads))
        return List(body=[assignment, processed]), {'input': [nthreads]}
Exemplo n.º 2
0
def mpi_gpu_direct(iet, **kwargs):
    """
    Modify MPI Callables so that multiple GPUs can perform GPU-Direct
    communication.

    Every ``IsendCall`` / ``IrecvCall`` node found in ``iet`` is wrapped in a
    block headed by ``#pragma omp target data use_device_ptr(<buf>)``, where
    ``<buf>`` is the name of the call's first argument. Under OpenMP
    ``use_device_ptr`` semantics, the buffer then refers to its device address
    inside the block.

    Parameters
    ----------
    iet : Node
        The IET to transform.
    **kwargs
        Unused by this pass.

    Returns
    -------
    (Node, dict)
        The transformed IET and an empty metadata dict.
    """
    template = 'omp target data use_device_ptr(%s)'
    calls = FindNodes((IsendCall, IrecvCall)).visit(iet)
    mapper = {
        call: Block(header=c.Pragma(template % call.arguments[0].name), body=call)
        for call in calls
    }
    return Transformer(mapper).visit(iet), {}
Exemplo n.º 3
0
def test_transformer_replace(exprs, block1, block2, block3):
    """
    Basic transformer test that replaces an expression.

    ``exprs[0]`` is swapped for a :class:`Block` holding only a comment line.
    The comment must appear in the generated code, and the statement printed
    for ``exprs[0]`` must be present before the transformation and gone after.
    """
    line1 = '// Replaced expression'
    # Printed form of exprs[0]. This matches the text `test_transformer_wrap`
    # expects to survive when the same expression is merely wrapped. The
    # previous `a[i0]...` spelling never occurs, so the `not in` check could
    # never fail.
    replaced = "a[i] = a[i] + b[i] + 5.0F;"
    replacer = Block(c.Line(line1))
    transformer = Transformer({exprs[0]: replacer})

    for block in [block1, block2, block3]:
        newblock = transformer.visit(block)
        newcode = str(newblock.ccode)
        oldcode = str(block.ccode)
        oldnumlines = len(oldcode.split('\n'))
        newnumlines = len(newcode.split('\n'))
        assert newnumlines >= oldnumlines
        assert line1 in newcode
        # Guard against a vacuous check: the statement must exist beforehand...
        assert replaced in oldcode
        # ...and must be gone once the expression has been replaced
        assert replaced not in newcode
Exemplo n.º 4
0
def test_transformer_wrap(exprs, block1, block2, block3):
    """
    Basic transformer test that wraps an expression in comments.

    ``exprs[0]`` is replaced by a :class:`Block` that surrounds it with an
    opening and a closing comment line. The generated code must grow by at
    least two lines and contain both comments and the original statement.
    """
    opening = '// This is the opening comment'
    closing = '// This is the closing comment'
    target = exprs[0]
    wrapped = Block(c.Line(opening), target, c.Line(closing))
    transformer = Transformer({target: wrapped})

    for block in (block1, block2, block3):
        newcode = str(transformer.visit(block).ccode)
        added = len(newcode.split('\n')) - len(str(block.ccode).split('\n'))
        assert added >= 2
        for expected in (opening, closing, "a[i] = a[i] + b[i] + 5.0F;"):
            assert expected in newcode
Exemplo n.º 5
0
    def make_parallel(self, iet):
        """
        Return ``iet`` with its parallel :class:`Iteration`s wrapped in
        ``#pragma omp ...`` parallel regions for thread-level parallelism.

        Consecutive tagged iteration trees are merged into a single region.
        Every other tree with parallelizable candidates gets its own region.
        """
        # Partition the iteration trees; group keys are 0, 1, 2, ...
        groups = OrderedDict()
        last_was_tagged = False
        for tree in retrieve_iteration_tree(iet):
            candidates = filter_iterations(tree, key=self.key, stop='asap')
            if not candidates:
                # Nothing to parallelize: this breaks any tagged sequence
                last_was_tagged = False
                continue

            tagged = any(it.tag is not None for it in tree)
            merge = tagged and last_was_tagged
            slot = len(groups) - 1 if merge else len(groups)
            if slot not in groups:
                groups[slot] = OrderedDict()
            groups[slot][candidates[0]] = candidates
            last_was_tagged = tagged

        mapper = OrderedDict()
        for group in groups.values():
            private_names = set()
            for root, candidates in group.items():
                mapper.update(self._make_parallel_tree(root, candidates))

                # Arrays flagged `_mem_stack` are listed in the `private` clause
                for sym in FindSymbols('symbolics').visit(root):
                    if sym.is_Array and sym._mem_stack:
                        private_names.add(sym.name)

            if private_names:
                clause = 'private(%s)' % ','.join(sorted(private_names))
            else:
                clause = ''

            rebuilt = [node for key, node in mapper.items() if key in group]
            region = Block(header=self.lang['par-region'](clause), body=rebuilt)

            # Entries still mapped to an Iteration are either dropped
            # (remainder loops) or replaced by this group's region
            for key in list(mapper):
                node = mapper[key]
                if isinstance(node, Iteration):
                    mapper[key] = None if node.is_Remainder else region

        return Transformer(mapper).visit(iet)
Exemplo n.º 6
0
    def make_parallel(self, iet):
        """
        Transform ``iet`` by introducing shared-memory parallelism.

        For each iteration tree, the first parallelizable :class:`Iteration`
        becomes the root of an ``omp-for`` tree (built by
        ``_make_parallel_tree``). That tree is wrapped in an ``omp-parallel``
        region parameterized by ``self.nthreads``.

        Returns
        -------
        (Node, dict)
            The transformed IET and ``{'input': [self.nthreads]}`` if at least
            one region was built; ``{'input': []}`` otherwise.
        """
        mapper = OrderedDict()
        for tree in retrieve_iteration_tree(iet):
            candidates = filter_iterations(tree, key=self.key, stop='asap')
            if not candidates:
                continue
            root = candidates[0]

            # `omp-for` tree rooted at the first parallelizable Iteration
            omp_for = self._make_parallel_tree(root, candidates)

            # Arrays flagged `_mem_stack` are listed in the `private` clause
            private_names = {s.name for s in FindSymbols().visit(omp_for)
                             if s.is_Array and s._mem_stack}
            if private_names:
                clause = 'private(%s)' % ','.join(sorted(private_names))
            else:
                clause = ''
            header = self.lang['par-region'](self.nthreads.name, clause)
            region = Block(header=header, body=omp_for)

            # A symbolic step that turns out to be 0 would raise a
            # `Floating point exception (core dumped)` in some OpenMP
            # implementations, and an OpenMP `if` clause does not help, so the
            # region is guarded explicitly.
            # NOTE(review): the emitted `return;` exits the whole enclosing
            # function, not just this region
            if isinstance(root.step, Symbol):
                zero_step = Conditional(CondEq(root.step, 0),
                                        Element(c.Statement('return')))
                region = List(body=[zero_step, region])

            mapper[root] = region

        processed = Transformer(mapper).visit(iet)
        return processed, {'input': [self.nthreads] if mapper else []}