Example #1
        def _(iet):
            # TODO: we need to pick the rank from `comm_shm`, not `comm`,
            # so that we have nranks == ngpus (as long as the user has launched
            # the right number of MPI processes per node given the available
            # number of GPUs per node)

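            # Look for an MPI communicator object among the IET parameters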
            objcomm = None
            for i in iet.parameters:
                if isinstance(i, MPICommObject):
                    objcomm = i
                    break

            devicetype = as_list(self.lang[self.platform])

            try:
                lang_init = [self.lang['init'](devicetype)]
            except TypeError:
                # Not all target languages need to be explicitly initialized
                lang_init = []

            deviceid = DeviceID()
            if objcomm is not None:
                rank = Symbol(name='rank')
                rank_decl = LocalExpression(DummyEq(rank, 0))
                rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

                ngpus = Symbol(name='ngpus')
                call = self.lang['num-devices'](devicetype)
                ngpus_init = LocalExpression(DummyEq(ngpus, call))

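                # Honour an explicit device id if provided (deviceid != -1);
                # otherwise assign devices round-robin via rank % ngpus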
                osdd_then = self.lang['set-device']([deviceid] + devicetype)
                osdd_else = self.lang['set-device']([rank % ngpus] +
                                                    devicetype)

                body = lang_init + [
                    Conditional(
                        CondNe(deviceid, -1),
                        osdd_then,
                        List(body=[rank_decl, rank_init, ngpus_init,
                                   osdd_else]),
                    )
                ]

                header = c.Comment('Begin of %s+MPI setup' % self.lang['name'])
                footer = c.Comment('End of %s+MPI setup' % self.lang['name'])
            else:
                body = lang_init + [
                    Conditional(
                        CondNe(deviceid, -1),
                        self.lang['set-device']([deviceid] + devicetype))
                ]

                header = c.Comment('Begin of %s setup' % self.lang['name'])
                footer = c.Comment('End of %s setup' % self.lang['name'])

            init = List(header=header, body=body, footer=(footer, c.Line()))
            iet = iet._rebuild(body=(init, ) + iet.body)

            return iet, {'args': deviceid}
Example #2
File: efunc.py  Project: speglich/devito
def _make_thread_activate(threads, sdata, sync_ops, sregistry):
    if threads.size == 1:
        d = threads.index
    else:
        d = Symbol(name=sregistry.make_name(prefix=threads.index.name))

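    # Spin condition: keep looping while any lock handle differs from 2 or the
    # flag field of the candidate thread `sdata[d]` differs from 1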
    sync_locks = [s for s in sync_ops if s.is_SyncLock]
    condition = Or(*([CondNe(s.handle, 2) for s in sync_locks] +
                     [CondNe(FieldFromComposite(sdata._field_flag, sdata[d]), 1)]))

    if threads.size == 1:
        activation = [While(condition)]
    else:
        activation = [DummyExpr(d, 0),
                      While(condition, DummyExpr(d, (d + 1) % threads.size))]

    activation.extend([DummyExpr(FieldFromComposite(i.name, sdata[d]), i)
                       for i in sdata.dynamic_fields])
    activation.extend([DummyExpr(s.handle, 0) for s in sync_locks])
    activation.append(DummyExpr(FieldFromComposite(sdata._field_flag, sdata[d]), 2))
    activation = List(
        header=[c.Line(), c.Comment("Activate `%s`" % threads.name)],
        body=activation,
        footer=c.Line()
    )

    return activation
Example #3
    def _make_sendrecv(self, f, hse, key, **kwargs):
        comm = f.grid.distributor._obj_comm

        buf_dims = [
            Dimension(name='buf_%s' % d.root) for d in f.dimensions
            if d not in hse.loc_indices
        ]
        bufg = Array(name='bufg',
                     dimensions=buf_dims,
                     dtype=f.dtype,
                     padding=0,
                     scope='heap')
        bufs = Array(name='bufs',
                     dimensions=buf_dims,
                     dtype=f.dtype,
                     padding=0,
                     scope='heap')

        ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
        ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

        fromrank = Symbol(name='fromrank')
        torank = Symbol(name='torank')

        gather = Call('gather_%s' % key,
                      [bufg] + list(bufg.shape) + [f] + ofsg)
        scatter = Call('scatter_%s' % key,
                       [bufs] + list(bufs.shape) + [f] + ofss)

        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')),
                              scatter)

        count = reduce(mul, bufs.shape, 1)
        rrecv = MPIRequestObject(name='rrecv')
        rsend = MPIRequestObject(name='rsend')
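        # Post the non-blocking receive/send; `Integer(13)` fills the tag slot
        # of MPI_Irecv/MPI_Isend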
        recv = Call('MPI_Irecv', [
            bufs, count,
            Macro(dtype_to_mpitype(f.dtype)), fromrank,
            Integer(13), comm, rrecv
        ])
        send = Call('MPI_Isend', [
            bufg, count,
            Macro(dtype_to_mpitype(f.dtype)), torank,
            Integer(13), comm, rsend
        ])

        waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
        waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

        iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
        parameters = ([f] + list(bufs.shape) + ofsg + ofss +
                      [fromrank, torank, comm])
        return Callable('sendrecv_%s' % key, iet, 'void', parameters,
                        ('static', ))
Example #4
def _make_thread_func(name, iet, root, threads, sregistry):
    # Create the SharedData, that is the data structure that will be used by the
    # main thread to pass information down to the child thread(s)
    required, parameters, dynamic_parameters = diff_parameters(iet, root)
    parameters = sorted(parameters, key=lambda i: i.is_Function)  # Allow casting
    sdata = SharedData(name=sregistry.make_name(prefix='sdata'), npthreads=threads.size,
                       fields=required, dynamic_fields=dynamic_parameters)

    sbase = sdata.symbolic_base
    sid = sdata.symbolic_id

    # Create a Callable to initialize `sdata` with the known const values
    iname = 'init_%s' % sdata.dtype._type_.__name__
    ibody = [DummyExpr(FieldFromPointer(i._C_name, sbase), i._C_symbol)
             for i in parameters]
    ibody.extend([
        BlankLine,
        DummyExpr(FieldFromPointer(sdata._field_id, sbase), sid),
        DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1)
    ])
    iparameters = parameters + [sdata, sid]
    isdata = Callable(iname, ibody, 'void', iparameters, 'static')

    # Prepend the SharedData fields available upon thread activation
    preactions = [DummyExpr(i, FieldFromPointer(i.name, sbase))
                  for i in dynamic_parameters]

    # Append the flag reset
    postactions = [List(body=[
        BlankLine,
        DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1)
    ])]

    iet = List(body=preactions + [iet] + postactions)

    # The thread has work to do when it receives the signal that all locks have
    # been set to 0 by the main thread
    iet = Conditional(CondEq(FieldFromPointer(sdata._field_flag, sbase), 2), iet)

    # The thread keeps spinning until the alive flag is set to 0 by the main thread
    iet = While(CondNe(FieldFromPointer(sdata._field_flag, sbase), 0), iet)

    # pthread functions expect exactly one argument, a void*, and must return void*
    tretval = 'void*'
    tparameter = VoidPointer('_%s' % sdata.name)

    # Unpack `sdata`
    unpack = [PointerCast(sdata, tparameter), BlankLine]
    for i in parameters:
        if i.is_AbstractFunction:
            unpack.extend([Dereference(i, sdata), PointerCast(i)])
        else:
            unpack.append(DummyExpr(i, FieldFromPointer(i.name, sbase)))
    unpack.append(DummyExpr(sid, FieldFromPointer(sdata._field_id, sbase)))
    unpack.append(BlankLine)
    iet = List(body=unpack + [iet, BlankLine, Return(Macro('NULL'))])

    tfunc = ThreadFunction(name, iet, tretval, tparameter, 'static')

    return tfunc, isdata, sdata
Example #5
def guard(clusters):
    """
    Return a new :class:`ClusterGroup` including new :class:`PartialCluster`s
    for each conditional expression encountered in ``clusters``.
    """
    processed = ClusterGroup()
    for c in clusters:
        # Find out what expressions in /c/ should be guarded
        mapper = {}
        for e in c.exprs:
            for k, v in e.ispace.sub_iterators.items():
                for i in v:
                    if i.dim.is_Conditional:
                        mapper.setdefault(i.dim, []).append(e)

        # Build conditional expressions to guard clusters
        conditions = {d: CondEq(d.parent % d.factor, 0) for d in mapper}
        negated = {d: CondNe(d.parent % d.factor, 0) for d in mapper}

        # Expand with guarded clusters
        combs = list(powerset(mapper))
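        # Zipping `combs` with its reverse pairs each subset of conditional
        # Dimensions with its complement, covering every combination of guarded
        # and negated conditions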
        for dims, ndims in zip(combs, reversed(combs)):
            banned = flatten(v for k, v in mapper.items() if k not in dims)
            exprs = [e.xreplace({i: IntDiv(i.parent, i.factor) for i in mapper})
                     for e in c.exprs if e not in banned]
            guards = [(i.parent, conditions[i]) for i in dims]
            guards.extend([(i.parent, negated[i]) for i in ndims])
            cluster = PartialCluster(exprs, c.ispace, c.dspace, c.atomics,
                                     dict(guards))
            processed.append(cluster)

    return processed
Example #6
    def _make_halowait(self, f, hse, key, msg=None):
        cast = cast_mapper[(f.dtype, '*')]

        fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

        dim = Dimension(name='i')

        msgi = IndexedPointer(msg, dim)

        bufs = FieldFromComposite(msg._C_field_bufs, msgi)

        fromrank = FieldFromComposite(msg._C_field_from, msgi)

        sizes = [FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
                 for i in range(len(f._dist_dimensions))]
        ofss = [FieldFromComposite('%s[%d]' % (msg._C_field_ofss, i), msgi)
                for i in range(len(f._dist_dimensions))]
        ofss = [fixed.get(d) or ofss.pop(0) for d in f.dimensions]

        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Call('scatter%s' % key, [cast(bufs)] + sizes + [f] + ofss)
        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

        rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
        waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
        rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
        waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

        # The -1 below is because an Iteration, by default, generates <=
        ncomms = Symbol(name='ncomms')
        iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1)
        parameters = ([f] + list(fixed.values()) + [msg, ncomms])
        return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
Example #7
    def _make_sendrecv(self, f, hse, key, msg=None):
        comm = f.grid.distributor._obj_comm

        bufg = FieldFromPointer(msg._C_field_bufg, msg)
        bufs = FieldFromPointer(msg._C_field_bufs, msg)

        ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]

        fromrank = Symbol(name='fromrank')
        torank = Symbol(name='torank')

        sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
                 for i in range(len(f._dist_dimensions))]

        gather = Call('gather%s' % key, [bufg] + sizes + [f] + ofsg)
        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

        count = reduce(mul, sizes, 1)
        rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
        rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
        recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)),
                         fromrank, Integer(13), comm, rrecv])
        send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)),
                         torank, Integer(13), comm, rsend])

        iet = List(body=[recv, gather, send])
        parameters = ([f] + ofsg + [fromrank, torank, comm, msg])
        return SendRecv(key, iet, parameters, bufg, bufs)
Example #8
    def _make_wait(self, f, hse, key, msg=None):
        bufs = FieldFromPointer(msg._C_field_bufs, msg)

        ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

        fromrank = Symbol(name='fromrank')

        sizes = [
            FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
            for i in range(len(f._dist_dimensions))
        ]
        scatter = Call('scatter_%s' % key, [bufs] + sizes + [f] + ofss)

        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')),
                              scatter)

        rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
        waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
        rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
        waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

        iet = List(body=[waitsend, waitrecv, scatter])
        parameters = ([f] + ofss + [fromrank, msg])
        return Callable('wait_%s' % key, iet, 'void', parameters, ('static', ))
Example #9
    def _(iet):
        # TODO: we need to pick the rank from `comm_shm`, not `comm`,
        # so that we have nranks == ngpus (as long as the user has launched
        # the right number of MPI processes per node given the available
        # number of GPUs per node)

        objcomm = None
        for i in iet.parameters:
            if isinstance(i, MPICommObject):
                objcomm = i
                break

        deviceid = DeviceID()
        device_nvidia = Macro('acc_device_nvidia')
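        # Device setup goes through the OpenACC runtime API directly:
        # acc_init, acc_get_num_devices and acc_set_device_num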
        if objcomm is not None:
            rank = Symbol(name='rank')
            rank_decl = LocalExpression(DummyEq(rank, 0))
            rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

            ngpus = Symbol(name='ngpus')
            call = DefFunction('acc_get_num_devices', device_nvidia)
            ngpus_init = LocalExpression(DummyEq(ngpus, call))

            asdn_then = Call('acc_set_device_num', [deviceid, device_nvidia])
            asdn_else = Call('acc_set_device_num',
                             [rank % ngpus, device_nvidia])

            body = [
                Call('acc_init', [device_nvidia]),
                Conditional(
                    CondNe(deviceid, -1), asdn_then,
                    List(body=[rank_decl, rank_init, ngpus_init, asdn_else]))
            ]
        else:
            body = [
                Call('acc_init', [device_nvidia]),
                Conditional(
                    CondNe(deviceid, -1),
                    Call('acc_set_device_num', [deviceid, device_nvidia]))
            ]

        init = List(header=c.Comment('Begin of OpenACC+MPI setup'),
                    body=body,
                    footer=(c.Comment('End of OpenACC+MPI setup'), c.Line()))
        iet = iet._rebuild(body=(init, ) + iet.body)

        return iet, {'args': deviceid}
Example #10
File: langbase.py  Project: ofmla/devito
        def _(iet):
            devicetype = as_list(self.lang[self.platform])
            deviceid = self.deviceid

            init = Conditional(
                CondNe(deviceid, -1),
                self.lang['set-device']([deviceid] + devicetype))
            body = iet.body._rebuild(body=(init, BlankLine) + iet.body.body)
            iet = iet._rebuild(body=body)

            return iet, {}
Example #11
    def _make_waitprefetch(self, iet, sync_ops, pieces, *args):
        ff = SharedData._field_flag

        waits = []
        objs = filter_ordered(pieces.objs.get(s) for s in sync_ops)
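        # One BusyWait per (sdata, threads) pair: spin until the flag field of
        # thread `threads.index` reads 1, i.e. its prefetched data has arrived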
        for sdata, threads in objs:
            wait = BusyWait(
                CondNe(FieldFromComposite(ff, sdata[threads.index]), 1))
            waits.append(wait)

        iet = List(header=c.Comment("Wait for the arrival of prefetched data"),
                   body=waits + [BlankLine, iet])

        return iet
Example #12
    def _make_haloupdate(self, f, hse, key, msg=None):
        comm = f.grid.distributor._obj_comm

        fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

        dim = Dimension(name='i')

        msgi = IndexedPointer(msg, dim)

        bufg = FieldFromComposite(msg._C_field_bufg, msgi)
        bufs = FieldFromComposite(msg._C_field_bufs, msgi)

        fromrank = FieldFromComposite(msg._C_field_from, msgi)
        torank = FieldFromComposite(msg._C_field_to, msgi)

        sizes = [
            FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
            for i in range(len(f._dist_dimensions))
        ]
        ofsg = [
            FieldFromComposite('%s[%d]' % (msg._C_field_ofsg, i), msgi)
            for i in range(len(f._dist_dimensions))
        ]
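        # For the fixed (loc_indices) Dimensions use the corresponding offset
        # symbol; for all others consume the per-message offsets in order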
        ofsg = [fixed.get(d) or ofsg.pop(0) for d in f.dimensions]

        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Call('gather_%s' % key, [bufg] + sizes + [f] + ofsg)
        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

        # Make Irecv/Isend
        count = reduce(mul, sizes, 1)
        rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
        rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
        recv = Call('MPI_Irecv', [
            bufs, count,
            Macro(dtype_to_mpitype(f.dtype)), fromrank,
            Integer(13), comm, rrecv
        ])
        send = Call('MPI_Isend', [
            bufg, count,
            Macro(dtype_to_mpitype(f.dtype)), torank,
            Integer(13), comm, rsend
        ])

        # The -1 below is because an Iteration, by default, generates <=
        ncomms = Symbol(name='ncomms')
        iet = Iteration([recv, gather, send], dim, ncomms - 1)
        parameters = ([f, comm, msg, ncomms]) + list(fixed.values())
        return Callable('haloupdate%d' % key, iet, 'void', parameters,
                        ('static', ))
Example #13
File: routines.py  Project: ponykid/SNIST
def sendrecv(f, fixed):
    """Construct an IET performing a halo exchange along an arbitrary
    dimension and side."""
    assert f.is_Function
    assert f.grid is not None

    comm = f.grid.distributor._C_comm

    buf_dims = [Dimension(name='buf_%s' % d.root) for d in f.dimensions if d not in fixed]
    bufg = Array(name='bufg', dimensions=buf_dims, dtype=f.dtype, scope='heap')
    bufs = Array(name='bufs', dimensions=buf_dims, dtype=f.dtype, scope='heap')

    dat_dims = [Dimension(name='dat_%s' % d.root) for d in f.dimensions]
    dat = Array(name='dat', dimensions=dat_dims, dtype=f.dtype, scope='external')

    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    parameters = [bufg] + list(bufg.shape) + [dat] + list(dat.shape) + ofsg
    gather = Call('gather_%s' % f.name, parameters)
    parameters = [bufs] + list(bufs.shape) + [dat] + list(dat.shape) + ofss
    scatter = Call('scatter_%s' % f.name, parameters)

    # The scatter must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

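    # Request/status objects for the non-blocking communication below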
    srecv = MPIStatusObject(name='srecv')
    rrecv = MPIRequestObject(name='rrecv')
    rsend = MPIRequestObject(name='rsend')

    count = reduce(mul, bufs.shape, 1)
    recv = Call('MPI_Irecv', [bufs, count, Macro(numpy_to_mpitypes(f.dtype)),
                              fromrank, '13', comm, rrecv])
    send = Call('MPI_Isend', [bufg, count, Macro(numpy_to_mpitypes(f.dtype)),
                              torank, '13', comm, rsend])

    waitrecv = Call('MPI_Wait', [rrecv, srecv])
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
    iet = List(body=[ArrayCast(dat), iet_insert_C_decls(iet)])
    parameters = ([dat] + list(dat.shape) + list(bufs.shape) +
                  ofsg + ofss + [fromrank, torank, comm])
    return Callable('sendrecv_%s' % f.name, iet, 'void', parameters, ('static',))
Example #14
    def __make_tfunc(self, name, iet, root, threads):
        # Create the SharedData
        required = derive_parameters(iet)
        known = (root.parameters +
                 tuple(i for i in required if i.is_Array and i._mem_shared))
        parameters, dynamic_parameters = split(required, lambda i: i in known)

        sdata = SharedData(name=self.sregistry.make_name(prefix='sdata'),
                           nthreads_std=threads.size,
                           fields=dynamic_parameters)
        parameters.append(sdata)

        # Prepend the unwound SharedData fields, available upon thread activation
        preactions = [
            DummyExpr(i, FieldFromPointer(i.name, sdata.symbolic_base))
            for i in dynamic_parameters
        ]
        preactions.append(
            DummyExpr(sdata.symbolic_id,
                      FieldFromPointer(sdata._field_id, sdata.symbolic_base)))

        # Append the flag reset
        postactions = [
            List(body=[
                BlankLine,
                DummyExpr(
                    FieldFromPointer(sdata._field_flag, sdata.symbolic_base),
                    1)
            ])
        ]

        iet = List(body=preactions + [iet] + postactions)

        # The thread has work to do when it receives the signal that all locks have
        # been set to 0 by the main thread
        iet = Conditional(
            CondEq(FieldFromPointer(sdata._field_flag, sdata.symbolic_base),
                   2), iet)

        # The thread keeps spinning until the alive flag is set to 0 by the main thread
        iet = While(
            CondNe(FieldFromPointer(sdata._field_flag, sdata.symbolic_base),
                   0), iet)

        return Callable(name, iet, 'void', parameters, 'static'), sdata
Example #15
        def _(iet):
            body = FindNodes(WhileAlive).visit(iet)
            assert len(body) == 1
            body = body.pop()

            devicetype = as_list(self.lang[self.platform])
            deviceid = self.deviceid

            init = Conditional(
                CondNe(deviceid, -1),
                self.lang['set-device']([deviceid] + devicetype)
            )

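            # Prepend the device-selection conditional to the WhileAlive node found above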
            mapper = {body: List(body=[init, BlankLine, body])}
            iet = Transformer(mapper).visit(iet)

            return iet, {}
Example #16
    def _make_sendrecv(self, f, fixed, extra=None):
        extra = extra or []
        comm = f.grid.distributor._obj_comm

        buf_dims = [
            Dimension(name='buf_%s' % d.root) for d in f.dimensions
            if d not in fixed
        ]
        bufg = Array(name='bufg',
                     dimensions=buf_dims,
                     dtype=f.dtype,
                     scope='heap')
        bufs = Array(name='bufs',
                     dimensions=buf_dims,
                     dtype=f.dtype,
                     scope='heap')

        ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
        ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

        fromrank = Symbol(name='fromrank')
        torank = Symbol(name='torank')

        args = [bufg] + list(bufg.shape) + [f] + ofsg + extra
        gather = Call('gather%dd' % f.ndim, args)
        args = [bufs] + list(bufs.shape) + [f] + ofss + extra
        scatter = Call('scatter%dd' % f.ndim, args)

        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')),
                              scatter)

        srecv = MPIStatusObject(name='srecv')
        ssend = MPIStatusObject(name='ssend')
        rrecv = MPIRequestObject(name='rrecv')
        rsend = MPIRequestObject(name='rsend')

        count = reduce(mul, bufs.shape, 1)
        recv = Call('MPI_Irecv', [
            bufs, count,
            Macro(dtype_to_mpitype(f.dtype)), fromrank,
            Integer(13), comm, rrecv
        ])
        send = Call('MPI_Isend', [
            bufg, count,
            Macro(dtype_to_mpitype(f.dtype)), torank,
            Integer(13), comm, rsend
        ])

        waitrecv = Call('MPI_Wait', [rrecv, srecv])
        waitsend = Call('MPI_Wait', [rsend, ssend])

        iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
        iet = List(body=iet_insert_C_decls(iet))
        parameters = ([f] + list(bufs.shape) + ofsg + ofss +
                      [fromrank, torank, comm] + extra)
        return Callable('sendrecv%dd' % f.ndim, iet, 'void', parameters,
                        ('static', ))
Example #17
    def _make_fetchprefetch(self, iet, sync_ops, pieces, root):
        fid = SharedData._field_id
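        # `fid` is reused below as the queue id of the asynchronous prefetch transfers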

        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            f = s.function
            dimensions = s.dimensions
            fc = s.fetch
            ifc = s.ifetch
            pfc = s.pfetch
            fcond = s.fcond
            pcond = s.pcond

            # Construct init IET
            imask = [(ifc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            fetch = PragmaTransfer(self.lang._map_to, f, imask=imask)
            fetches.append(Conditional(fcond, fetch))

            # Construct present clauses
            imask = [(fc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            presents.append(
                PragmaTransfer(self.lang._map_present, f, imask=imask))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            prefetch = PragmaTransfer(self.lang._map_to_wait,
                                      f,
                                      imask=imask,
                                      queueid=fid)
            prefetches.append(Conditional(pcond, prefetch))

        # Turn init IET into a Callable
        functions = filter_ordered(s.function for s in sync_ops)
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')
        pieces.funcs.append(func)

        # Perform initial fetch by the main thread
        pieces.init.append(
            List(header=c.Comment("Initialize data stream"),
                 body=[Call(name, parameters), BlankLine]))

        # Turn prefetch IET into a ThreadFunction
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=prefetches)
        tctx = make_thread_ctx(name, body, root, None, sync_ops,
                               self.sregistry)
        pieces.funcs.extend(tctx.funcs)

        # Glue together all the IET pieces, including the activation logic
        sdata = tctx.sdata
        threads = tctx.threads
        iet = List(body=[
            BlankLine,
            BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                               sdata[threads.index]), 1))
        ] + presents + [iet, tctx.activate])

        # Fire up the threads
        pieces.init.append(tctx.init)

        # Final wait before jumping back to Python land
        pieces.finalize.append(tctx.finalize)

        # Keep track of created objects
        pieces.objs.add(sync_ops, sdata, threads)

        return iet
Example #18
    def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
        threads = self.__make_threads()

        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
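            # Compute the fetch (fc) and prefetch (pfc) indices, together with the
            # conditions keeping them within the bounds of the fetched dimension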
            if s.direction is Forward:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
                fsize = s.function._C_get_field(FULL, s.dim).size
                fc_cond = fc + (s.size - 1) < fsize
                pfc = s.fetch + 1
                pfc_cond = pfc + (s.size - 1) < fsize
            else:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
                fc_cond = fc >= 0
                pfc = s.fetch - 1
                pfc_cond = pfc >= 0

            # Construct fetch IET
            imask = [(fc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            fetch = List(header=self._P._map_to(s.function, imask))
            fetches.append(Conditional(fc_cond, fetch))

            # Construct present clauses
            imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            presents.extend(as_list(self._P._map_present(s.function, imask)))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            prefetch = List(header=self._P._map_to_wait(
                s.function, imask, SharedData._field_id))
            prefetches.append(Conditional(pfc_cond, prefetch))

        functions = filter_ordered(s.function for s in sync_ops)
        casts = [PointerCast(f) for f in functions]

        # Turn init IET into a Callable
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=casts + fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')
        pieces.funcs.append(func)

        # Perform initial fetch by the main thread
        pieces.init.append(
            List(header=c.Comment("Initialize data stream for `%s`" %
                                  threads.name),
                 body=[Call(name, func.parameters), BlankLine]))

        # Turn prefetch IET into a threaded Callable
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=casts + prefetches)
        tfunc, sdata = self.__make_tfunc(name, body, root, threads)
        pieces.funcs.append(tfunc)

        # Glue together all the IET pieces, including the activation bits
        iet = List(body=[
            BlankLine,
            BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                               sdata[threads.index]), 1)),
            List(header=presents), iet,
            self.__make_activate_thread(threads, sdata, sync_ops)
        ])

        # Fire up the threads
        pieces.init.append(
            self.__make_init_threads(threads, sdata, tfunc, pieces))
        pieces.threads.append(threads)

        # Final wait before jumping back to Python land
        pieces.finalize.append(self.__make_finalize_threads(threads, sdata))

        return iet
Example #19
    def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            if s.direction is Forward:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
                pfc = s.fetch + 1
                fc_cond = s.next_cbk(s.dim.symbolic_min)
                pfc_cond = s.next_cbk(s.dim + 1)
            else:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
                pfc = s.fetch - 1
                fc_cond = s.next_cbk(s.dim.symbolic_max)
                pfc_cond = s.next_cbk(s.dim - 1)

            # Construct init IET
            imask = [(fc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
            fetch = PragmaList(self.lang._map_to(s.function, imask),
                               {s.function} | fc.free_symbols)
            fetches.append(Conditional(fc_cond, fetch))

            # Construct present clauses
            imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            presents.extend(as_list(self.lang._map_present(s.function, imask)))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            prefetch = PragmaList(self.lang._map_to_wait(s.function, imask,
                                                         SharedData._field_id),
                                  {s.function} | pfc.free_symbols)
            prefetches.append(Conditional(pfc_cond, prefetch))

        # Turn init IET into a Callable
        functions = filter_ordered(s.function for s in sync_ops)
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')
        pieces.funcs.append(func)

        # Perform initial fetch by the main thread
        pieces.init.append(List(
            header=c.Comment("Initialize data stream"),
            body=[Call(name, parameters), BlankLine]
        ))

        # Turn prefetch IET into a ThreadFunction
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=prefetches)
        tctx = make_thread_ctx(name, body, root, None, sync_ops, self.sregistry)
        pieces.funcs.extend(tctx.funcs)

        # Glue together all the IET pieces, including the activation logic
        sdata = tctx.sdata
        threads = tctx.threads
        iet = List(body=[
            BlankLine,
            BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                               sdata[threads.index]), 1)),
            List(header=presents),
            iet,
            tctx.activate
        ])

        # Fire up the threads
        pieces.init.append(tctx.init)
        pieces.threads.append(threads)

        # Final wait before jumping back to Python land
        pieces.finalize.append(tctx.finalize)

        return iet