Beispiel #1
0
        def _(iet):
            """
            Prepend a device-selection preamble to the body of ``iet``.

            The preamble consists of the (optional) language initialization
            followed by a ``Conditional`` on ``deviceid != -1`` whose first
            branch sets the device from ``deviceid``. If one of the IET
            parameters is an ``MPICommObject``, the conditional also gets a
            second branch (presumably the "else" branch) that queries the MPI
            rank and the number of devices, and sets the device to
            ``rank % ngpus``.

            Returns the rebuilt IET together with ``{'args': deviceid}``.
            """
            # TODO: we need to pick the rank from `comm_shm`, not `comm`,
            # so that we have nranks == ngpus (as long as the user has launched
            # the right number of MPI processes per node given the available
            # number of GPUs per node)

            # The first MPI communicator among the IET parameters, if any
            objcomm = next(
                (p for p in iet.parameters if isinstance(p, MPICommObject)), None
            )

            devicetype = as_list(self.lang[self.platform])

            try:
                lang_init = [self.lang['init'](devicetype)]
            except TypeError:
                # Not all target languages need to be explicitly initialized
                lang_init = []

            deviceid = DeviceID()
            if objcomm is None:
                # No communicator available: only an explicit device id is honored
                select = Conditional(
                    CondNe(deviceid, -1),
                    self.lang['set-device']([deviceid] + devicetype)
                )
                body = lang_init + [select]

                header = c.Comment('Begin of %s setup' % self.lang['name'])
                footer = c.Comment('End of %s setup' % self.lang['name'])
            else:
                # `rank` is declared (initialized to 0) and then filled in
                # through a call to `MPI_Comm_rank`
                rank = Symbol(name='rank')
                declare_rank = LocalExpression(DummyEq(rank, 0))
                fetch_rank = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

                # `ngpus` is initialized with the language's device-count query
                ngpus = Symbol(name='ngpus')
                count_devices = self.lang['num-devices'](devicetype)
                declare_ngpus = LocalExpression(DummyEq(ngpus, count_devices))

                from_deviceid = self.lang['set-device']([deviceid] + devicetype)
                from_rank = self.lang['set-device']([rank % ngpus] + devicetype)

                fallback = List(
                    body=[declare_rank, fetch_rank, declare_ngpus, from_rank]
                )
                select = Conditional(CondNe(deviceid, -1), from_deviceid, fallback)
                body = lang_init + [select]

                header = c.Comment('Begin of %s+MPI setup' % self.lang['name'])
                footer = c.Comment('End of %s+MPI setup' % self.lang['name'])

            # Wrap the preamble between the two comments and put it first
            init = List(header=header, body=body, footer=(footer, c.Line()))
            rebuilt = iet._rebuild(body=(init, ) + iet.body)

            return rebuilt, {'args': deviceid}
Beispiel #2
0
 def __init__(self, prodder):
     """
     Build, from an existing ``prodder``, a conditional node that calls
     ``prodder.name`` with ``prodder.arguments`` only when
     ``omp_get_thread_num() == 0``.

     Both base initializers are invoked explicitly: ``Conditional`` receives
     the condition and the call, ``Prodder`` receives the name, arguments
     and ``periodic`` flag copied from ``prodder``.
     """
     name = prodder.name
     arguments = prodder.arguments

     # Guard: only the thread whose id is 0 performs the call
     on_thread_zero = CondEq(Function('omp_get_thread_num')(), 0)
     prod_call = Call(name, arguments)

     Conditional.__init__(self, on_thread_zero, prod_call)
     Prodder.__init__(self, name, arguments, periodic=prodder.periodic)
Beispiel #3
0
    def __init__(self, prodder):
        """
        Build, from an existing ``prodder``, a conditional node guarded by
        ``Ompizer.lang['thread-num'] == 0`` whose body is a ``While`` node.

        The loop condition is the negation of a call to ``prodder.name`` with
        the names of ``prodder.arguments``; the ``Prodder`` base is then
        initialized with the name, arguments and ``periodic`` flag of
        ``prodder``.
        """
        # Only the thread whose id is 0 gets to do the prodding
        on_thread_zero = CondEq(Ompizer.lang['thread-num'], 0)

        # The prodding thread keeps looping for as long as the negated call
        # evaluates to true, i.e., it stays entrapped until the communications
        # have completed
        arg_names = [a.name for a in prodder.arguments]
        keep_prodding = Not(DefFunction(prodder.name, arg_names))
        entrap = List(
            header=c.Comment('Entrap thread until comms have completed'),
            body=While(keep_prodding)
        )

        Conditional.__init__(self, on_thread_zero, entrap)
        Prodder.__init__(self, prodder.name, prodder.arguments,
                         periodic=prodder.periodic)
Beispiel #4
0
    def _make_guard(self, parregion, *args):
        """
        Wrap ``parregion`` in a ``Conditional`` protecting it from being
        entered with zero-size iteration spaces or zero-size SparseFunctions.

        If none of the ``ParallelTree`` nodes found in ``parregion`` is rooted
        in a ``self.DeviceIteration``, the request is delegated to the parent
        class. If no guard condition survives, ``parregion`` is returned
        unchanged.
        """
        device_rooted = any(isinstance(t.root, self.DeviceIteration)
                            for t in FindNodes(ParallelTree).visit(parregion))
        if not device_rooted:
            return super()._make_guard(parregion, *args)

        guards = []

        # There must be at least one iteration or potential crash
        if not parregion.is_Affine:
            first_tree = retrieve_iteration_tree(parregion.root)[0]
            for it in first_tree[:parregion.ncollapsed]:
                guards.append(it.symbolic_size > 0)

        # SparseFunctions may occasionally degenerate to zero-size arrays. In
        # such a case, a copy-in produces a `nil` pointer on the device. To
        # fire up a parallel loop we must ensure none of the SparseFunction
        # pointers are `nil`
        for f in FindSymbols().visit(parregion):
            if not f.is_SparseFunction:
                continue
            nitems = prod(f._C_get_field(FULL, d).size for d in f.dimensions)
            guards.append(nitems > 0)

        # Conditions that already evaluated to `true` (e.g. because the
        # `symbolic_size` is an integer value rather than a symbol) are
        # dropped, so as to avoid ugly and unnecessary conditionals such as
        # `if (true) { ...}`
        guards = [g for g in guards if g != true]
        if not guards:
            return parregion

        # All surviving conditions must hold at once
        return List(body=[Conditional(And(*guards), parregion)])
 def _make_guard(self, partree, collapsed):
     """
     Prepend to ``partree`` an early ``Return`` taken when any of the
     ``collapsed`` iterations has a symbolic step equal to 0.

     Only steps that are ``Symbol`` instances contribute a condition; if the
     combined condition is False, ``partree`` is returned unchanged.
     """
     # Do not enter the parallel region if the step increment is 0; this
     # would raise a `Floating point exception (core dumped)` in some OpenMP
     # implementations. Note that using an OpenMP `if` clause won't work
     zero_step = Or(*[CondEq(it.step, 0) for it in collapsed
                      if isinstance(it.step, Symbol)])

     if zero_step != False:  # noqa: may be a sympy.False which would be == False
         return List(body=[Conditional(zero_step, Return()), partree])

     return partree
Beispiel #6
0
        def _(iet):
            """
            Prepend to the inner body of ``iet`` a ``Conditional`` that sets
            the device from ``self.deviceid`` when ``deviceid != -1``.

            Returns the rebuilt IET together with an empty dict.
            """
            devicetype = as_list(self.lang[self.platform])
            deviceid = self.deviceid

            set_device = self.lang['set-device']([deviceid] + devicetype)
            init = Conditional(CondNe(deviceid, -1), set_device)

            # The conditional, followed by a blank line, goes ahead of the
            # existing content of the IET body
            new_body = iet.body._rebuild(body=(init, BlankLine) + iet.body.body)

            return iet._rebuild(body=new_body), {}
Beispiel #7
0
        def _(iet):
            """
            Place, right before the unique ``WhileAlive`` node of ``iet``, a
            ``Conditional`` that sets the device from ``self.deviceid`` when
            ``deviceid != -1``.

            Returns the transformed IET together with an empty dict.
            """
            found = FindNodes(WhileAlive).visit(iet)
            # Exactly one `WhileAlive` node is expected
            assert len(found) == 1
            while_alive = found.pop()

            devicetype = as_list(self.lang[self.platform])
            deviceid = self.deviceid

            set_device = self.lang['set-device']([deviceid] + devicetype)
            init = Conditional(CondNe(deviceid, -1), set_device)

            # The `WhileAlive` node gets replaced by a list containing the
            # conditional, a blank line and the node itself
            replacement = List(body=[init, BlankLine, while_alive])
            transformed = Transformer({while_alive: replacement}).visit(iet)

            return transformed, {}