Ejemplo n.º 1
        def _(iet):
            """
            Prepend a device-initialization section to the body of ``iet``.

            The section may start with the target language's ``init`` call. It
            is omitted when calling ``self.lang['init']`` raises ``TypeError``.
            Device selection follows. If one of ``iet.parameters`` is an
            ``MPICommObject``, the generated code also queries the MPI rank and
            the number of devices. The ``set-device`` call then uses either
            ``deviceid`` or ``rank % ngpus``, presumably selected by
            ``CondNe(deviceid, -1)``.
            NOTE(review): that wiring is incomplete in this excerpt.

            Returns the rebuilt IET and ``{'args': deviceid}``.
            """
            # TODO: we need to pick the rank from `comm_shm`, not `comm`,
            # so that we have nranks == ngpus (as long as the user has launched
            # the right number of MPI processes per node given the available
            # number of GPUs per node)

            # Pick the MPI communicator among the parameters, if any (the last
            # match wins)
            objcomm = None
            for i in iet.parameters:
                if isinstance(i, MPICommObject):
                    objcomm = i

            devicetype = as_list(self.lang[self.platform])

            # NOTE(review): a ``try:`` line appears to be missing here; the
            # ``except TypeError`` below has no matching ``try``.
                lang_init = [self.lang['init'](devicetype)]
            except TypeError:
                # Not all target languages need to be explicitly initialized
                lang_init = []

            deviceid = DeviceID()
            if objcomm is not None:
                # Generated code: `rank = 0; MPI_Comm_rank(comm, &rank)`
                rank = Symbol(name='rank')
                rank_decl = LocalExpression(DummyEq(rank, 0))
                rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

                # Generated code: `ngpus = <num-devices>(devicetype)`
                ngpus = Symbol(name='ngpus')
                call = self.lang['num-devices'](devicetype)
                ngpus_init = LocalExpression(DummyEq(ngpus, call))

                # Explicit device id vs. rank-based round-robin device choice
                osdd_then = self.lang['set-device']([deviceid] + devicetype)
                osdd_else = self.lang['set-device']([rank % ngpus] +
                # NOTE(review): the call above is truncated (unbalanced bracket).

                # NOTE(review): the Conditional/List wrapping of this body is
                # incomplete in this excerpt (``osdd_then`` is never used).
                body = lang_init + [
                        CondNe(deviceid, -1),
                            body=[rank_decl, rank_init, ngpus_init, osdd_else

                header = c.Comment('Begin of %s+MPI setup' % self.lang['name'])
                footer = c.Comment('End of %s+MPI setup' % self.lang['name'])
                # NOTE(review): an ``else:`` (no MPI communicator) and the
                # ``Conditional(`` opener appear to be missing before this body.
                body = lang_init + [
                        CondNe(deviceid, -1),
                        self.lang['set-device']([deviceid] + devicetype))

                header = c.Comment('Begin of %s setup' % self.lang['name'])
                footer = c.Comment('End of %s setup' % self.lang['name'])

            init = List(header=header, body=body, footer=(footer, c.Line()))
            iet = iet._rebuild(body=(init, ) + iet.body)

            return iet, {'args': deviceid}
Ejemplo n.º 2
def _make_thread_activate(threads, sdata, sync_ops, sregistry):
    """
    Build the IET that activates a thread in ``threads``.

    The generated code spins while any SyncLock in ``sync_ops`` has a handle
    different from 2, or while the flag field of ``sdata[d]`` differs from 1.
    It then copies ``sdata.dynamic_fields`` into ``sdata[d]``, resets every
    lock handle to 0 and sets the flag of ``sdata[d]`` to 2.

    With a single thread, ``d`` is ``threads.index``. Otherwise a fresh Symbol
    ``d`` cycles through the threads as ``(d + 1) % threads.size``.
    NOTE(review): the two ``else:`` lines and the tail of the final
    ``List(...)`` call are missing from this excerpt.
    """
    if threads.size == 1:
        d = threads.index
    # NOTE(review): an ``else:`` appears to be missing here.
        d = Symbol(name=sregistry.make_name(prefix=threads.index.name))

    sync_locks = [s for s in sync_ops if s.is_SyncLock]
    # Busy condition: some lock not yet released, or the target thread not idle
    condition = Or(*([CondNe(s.handle, 2) for s in sync_locks] +
                     [CondNe(FieldFromComposite(sdata._field_flag, sdata[d]), 1)]))

    if threads.size == 1:
        activation = [While(condition)]
    # NOTE(review): an ``else:`` appears to be missing here.
        activation = [DummyExpr(d, 0),
                      While(condition, DummyExpr(d, (d + 1) % threads.size))]

    # Hand over the dynamic fields, reset the locks, then raise the flag
    activation.extend([DummyExpr(FieldFromComposite(i.name, sdata[d]), i)
                       for i in sdata.dynamic_fields])
    activation.extend([DummyExpr(s.handle, 0) for s in sync_locks])
    activation.append(DummyExpr(FieldFromComposite(sdata._field_flag, sdata[d]), 2))
    activation = List(
        header=[c.Line(), c.Comment("Activate `%s`" % threads.name)],
    # NOTE(review): the ``body=`` argument and closing ``)`` are missing above.

    return activation
Ejemplo n.º 3
    def _make_sendrecv(self, f, hse, key, **kwargs):
        """
        Build a static ``sendrecv_<key>`` Callable performing one halo exchange.

        The generated routine does the following, in order:

        - posts an ``MPI_Irecv`` into ``bufs``;
        - gathers data from ``f`` into ``bufg``, skipped when ``torank`` is
          ``MPI_PROC_NULL``;
        - posts an ``MPI_Isend`` from ``bufg``;
        - waits on both requests;
        - scatters ``bufs`` back into ``f``, skipped when ``fromrank`` is
          ``MPI_PROC_NULL``.

        ``kwargs`` is not used in the visible code.
        """
        comm = f.grid.distributor._obj_comm

        # Buffers span every Dimension of `f` except the fixed ones
        buf_dims = [
            Dimension(name='buf_%s' % d.root) for d in f.dimensions
            if d not in hse.loc_indices
        # NOTE(review): the closing ``]`` of ``buf_dims`` and the remaining
        # arguments of the two ``Array(...)`` calls below are missing here.
        bufg = Array(name='bufg',
        bufs = Array(name='bufs',

        # Per-Dimension gather/scatter offsets
        ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
        ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

        fromrank = Symbol(name='fromrank')
        torank = Symbol(name='torank')

        gather = Call('gather_%s' % key,
                      [bufg] + list(bufg.shape) + [f] + ofsg)
        scatter = Call('scatter_%s' % key,
                       [bufs] + list(bufs.shape) + [f] + ofss)

        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')),
        # NOTE(review): the call above is truncated (second argument missing).

        # Message size is the product of the buffer's extents
        count = reduce(mul, bufs.shape, 1)
        rrecv = MPIRequestObject(name='rrecv')
        rsend = MPIRequestObject(name='rsend')
        # Message tag 13 is used for both the receive and the send
        recv = Call('MPI_Irecv', [
            bufs, count,
            Macro(dtype_to_mpitype(f.dtype)), fromrank,
            Integer(13), comm, rrecv
        # NOTE(review): the closing ``])`` of the call above is missing.
        send = Call('MPI_Isend', [
            bufg, count,
            Macro(dtype_to_mpitype(f.dtype)), torank,
            Integer(13), comm, rsend
        # NOTE(review): the closing ``])`` of the call above is missing.

        waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
        waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

        iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
        parameters = ([f] + list(bufs.shape) + ofsg + ofss +
                      [fromrank, torank, comm])
        return Callable('sendrecv_%s' % key, iet, 'void', parameters,
                        ('static', ))
Ejemplo n.º 4
def _make_thread_func(name, iet, root, threads, sregistry):
    """
    Build the thread function wrapping ``iet`` together with its SharedData.

    Returns
    -------
    tfunc : ThreadFunction
        A static function taking a single ``void*`` and returning ``NULL``.
        It spins while its SharedData flag is not 0, runs ``iet`` when the
        flag is 2, and then sets the flag back to 1.
    isdata : Callable
        ``init_<sdata type name>``. It stores the known parameters into a
        SharedData entry. It presumably also stores the id and flag 1; those
        statements are partly missing from this excerpt.
    sdata : SharedData
        The data structure shared by the main and child thread(s).
    """
    # Create the SharedData, that is the data structure that will be used by the
    # main thread to pass information down to the child thread(s)
    required, parameters, dynamic_parameters = diff_parameters(iet, root)
    parameters = sorted(parameters, key=lambda i: i.is_Function)  # Allow casting
    sdata = SharedData(name=sregistry.make_name(prefix='sdata'), npthreads=threads.size,
                       fields=required, dynamic_fields=dynamic_parameters)

    sbase = sdata.symbolic_base
    sid = sdata.symbolic_id

    # Create a Callable to initialize `sdata` with the known const values
    iname = 'init_%s' % sdata.dtype._type_.__name__
    ibody = [DummyExpr(FieldFromPointer(i._C_name, sbase), i._C_symbol)
             for i in parameters]
    # NOTE(review): the statement that presumably appends the two expressions
    # below to ``ibody`` (and its closing bracket) is missing from this excerpt.
        DummyExpr(FieldFromPointer(sdata._field_id, sbase), sid),
        DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1)
    iparameters = parameters + [sdata, sid]
    isdata = Callable(iname, ibody, 'void', iparameters, 'static')

    # Prepend the SharedData fields available upon thread activation
    preactions = [DummyExpr(i, FieldFromPointer(i.name, sbase))
                  for i in dynamic_parameters]

    # Append the flag reset
    postactions = [List(body=[
        DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1)
    # NOTE(review): the closing ``])]`` of ``postactions`` is missing.

    iet = List(body=preactions + [iet] + postactions)

    # The thread has work to do when it receives the signal that all locks have
    # been set to 0 by the main thread
    iet = Conditional(CondEq(FieldFromPointer(sdata._field_flag, sbase), 2), iet)

    # The thread keeps spinning until the alive flag is set to 0 by the main thread
    iet = While(CondNe(FieldFromPointer(sdata._field_flag, sbase), 0), iet)

    # pthread functions expect exactly one argument, a void*, and must return void*
    tretval = 'void*'
    tparameter = VoidPointer('_%s' % sdata.name)

    # Unpack `sdata`
    unpack = [PointerCast(sdata, tparameter), BlankLine]
    for i in parameters:
        if i.is_AbstractFunction:
            unpack.extend([Dereference(i, sdata), PointerCast(i)])
        # NOTE(review): an ``else:`` appears to be missing here; scalar
        # parameters are presumably read straight from `sdata`.
            unpack.append(DummyExpr(i, FieldFromPointer(i.name, sbase)))
    unpack.append(DummyExpr(sid, FieldFromPointer(sdata._field_id, sbase)))
    iet = List(body=unpack + [iet, BlankLine, Return(Macro('NULL'))])

    tfunc = ThreadFunction(name, iet, tretval, tparameter, 'static')

    return tfunc, isdata, sdata
def guard(clusters):
    """
    Return a new :class:`ClusterGroup` including new :class:`PartialCluster`s
    for each conditional expression encountered in ``clusters``.

    For each subset ``dims`` of the conditional Dimensions found in a cluster:

    - ``dims`` is paired with ``ndims``, taken from the reversed powerset and
      presumably its complement.
    - Expressions depending on a conditional Dimension outside ``dims`` are
      dropped.
    - In the remaining expressions, each conditional Dimension ``d`` is
      replaced by ``IntDiv(d.parent, d.factor)``.
    - Guards ``d.parent % d.factor == 0`` (``d`` in ``dims``) and
      ``d.parent % d.factor != 0`` (``d`` in ``ndims``) are built.

    NOTE(review): the statements that close ``exprs``, finish the
    ``PartialCluster`` call and add it to ``processed`` are missing from this
    excerpt.
    """
    processed = ClusterGroup()
    for c in clusters:
        # Find out what expressions in /c/ should be guarded
        mapper = {}
        for e in c.exprs:
            for k, v in e.ispace.sub_iterators.items():
                for i in v:
                    if i.dim.is_Conditional:
                        mapper.setdefault(i.dim, []).append(e)

        # Build conditional expressions to guard clusters
        conditions = {d: CondEq(d.parent % d.factor, 0) for d in mapper}
        negated = {d: CondNe(d.parent % d.factor, 0) for d in mapper}

        # Expand with guarded clusters
        combs = list(powerset(mapper))
        for dims, ndims in zip(combs, reversed(combs)):
            banned = flatten(v for k, v in mapper.items() if k not in dims)
            exprs = [
                e.xreplace({i: IntDiv(i.parent, i.factor)
                            for i in mapper}) for e in c.exprs
                if e not in banned
            # NOTE(review): the closing ``]`` of ``exprs`` is missing.
            guards = [(i.parent, conditions[i]) for i in dims]
            guards.extend([(i.parent, negated[i]) for i in ndims])
            cluster = PartialCluster(exprs, c.ispace, c.dspace, c.atomics,
            # NOTE(review): the call above is truncated, and nothing visible
            # appends ``cluster`` to ``processed``.

    return processed
Ejemplo n.º 6
    def _make_halowait(self, f, hse, key, msg=None):
        """
        Build the static ``halowait<key>`` Callable.

        For each of the ``ncomms`` entries of ``msg``, the routine waits on the
        stored send and receive requests. It then scatters the entry's receive
        buffer into ``f``, unless the entry's source rank is ``MPI_PROC_NULL``.

        Parameters
        ----------
        f : Function
            The distributed Function whose halo is filled.
        hse : object
            Provides ``loc_indices``, the Dimensions whose offset is passed in
            as an explicit ``o<dim>`` argument.
        key : int
            Suffix of the generated routine name.
        msg : object
            Message array whose fields hold buffers, ranks, sizes, offsets
            and MPI requests.

        Returns
        -------
        Callable
            ``halowait<key>(f, o<dims>..., msg, ncomms)``.
        """
        def field(name):
            # Access field `name` of the current (i-th) message entry
            return FieldFromComposite(name, entry)

        to_ptr = cast_mapper[(f.dtype, '*')]
        pinned = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

        counter = Dimension(name='i')
        entry = IndexedPointer(msg, counter)

        recvbuf = field(msg._C_field_bufs)
        source = field(msg._C_field_from)

        ndist = len(f._dist_dimensions)
        sizes = [field('%s[%d]' % (msg._C_field_sizes, n)) for n in range(ndist)]
        dist_ofs = [field('%s[%d]' % (msg._C_field_ofss, n)) for n in range(ndist)]
        # Pinned Dimensions use their explicit offset; the others consume the
        # message offsets in order
        offsets = [pinned.get(d) or dist_ofs.pop(0) for d in f.dimensions]

        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Conditional(
            CondNe(source, Macro('MPI_PROC_NULL')),
            Call('scatter%s' % key, [to_ptr(recvbuf)] + sizes + [f] + offsets)
        )

        # Wait on the send request first, then on the receive request
        waits = [Call('MPI_Wait', [Byref(field(req)), Macro('MPI_STATUS_IGNORE')])
                 for req in (msg._C_field_rsend, msg._C_field_rrecv)]

        # The -1 below is because an Iteration, by default, generates <=
        ncomms = Symbol(name='ncomms')
        loop = Iteration(waits + [scatter], counter, ncomms - 1)
        params = [f, *pinned.values(), msg, ncomms]
        return Callable('halowait%d' % key, loop, 'void', params, ('static',))
Ejemplo n.º 7
    def _make_sendrecv(self, f, hse, key, msg=None):
        """
        Build the ``SendRecv`` routine that starts one halo exchange for ``f``.

        The generated body does the following, in order:

        - posts an ``MPI_Irecv`` into ``msg``'s receive buffer;
        - gathers ``f`` into ``msg``'s send buffer, unless ``torank`` is
          ``MPI_PROC_NULL``;
        - posts an ``MPI_Isend`` from the send buffer.

        Both requests are stored in ``msg``, and no wait is issued here.
        ``hse`` is not used by this method.
        """
        def field(name):
            # Dereference field `name` of the `msg` struct pointer
            return FieldFromPointer(name, msg)

        comm = f.grid.distributor._obj_comm

        sendbuf = field(msg._C_field_bufg)
        recvbuf = field(msg._C_field_bufs)

        gather_ofs = [Symbol(name='og%s' % d.root) for d in f.dimensions]

        src = Symbol(name='fromrank')
        dst = Symbol(name='torank')

        shape = [field('%s[%d]' % (msg._C_field_sizes, n))
                 for n in range(len(f._dist_dimensions))]

        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Conditional(
            CondNe(dst, Macro('MPI_PROC_NULL')),
            Call('gather%s' % key, [sendbuf] + shape + [f] + gather_ofs)
        )

        # Number of elements exchanged: product of the buffer sizes
        count = reduce(mul, shape, 1)
        recv_req = Byref(field(msg._C_field_rrecv))
        send_req = Byref(field(msg._C_field_rsend))
        recv = IrecvCall([recvbuf, count, Macro(dtype_to_mpitype(f.dtype)),
                          src, Integer(13), comm, recv_req])
        send = IsendCall([sendbuf, count, Macro(dtype_to_mpitype(f.dtype)),
                          dst, Integer(13), comm, send_req])

        body = List(body=[recv, gather, send])
        params = [f, *gather_ofs, src, dst, comm, msg]
        return SendRecv(key, body, params, sendbuf, recvbuf)
Ejemplo n.º 8
    def _make_wait(self, f, hse, key, msg=None):
        """
        Build a static ``wait_<key>`` Callable that completes one halo exchange.

        The routine waits on the send and receive requests stored in ``msg``.
        It then scatters ``msg``'s receive buffer into ``f``, skipped when
        ``fromrank`` is ``MPI_PROC_NULL``.
        ``hse`` is not used in the visible code.
        """
        bufs = FieldFromPointer(msg._C_field_bufs, msg)

        ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

        fromrank = Symbol(name='fromrank')

        sizes = [
            FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
            for i in range(len(f._dist_dimensions))
        # NOTE(review): the closing ``]`` of ``sizes`` is missing.
        scatter = Call('scatter_%s' % key, [bufs] + sizes + [f] + ofss)

        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')),
        # NOTE(review): the call above is truncated (second argument missing).

        rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
        waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
        rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
        waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

        iet = List(body=[waitsend, waitrecv, scatter])
        parameters = ([f] + ofss + [fromrank, msg])
        return Callable('wait_%s' % key, iet, 'void', parameters, ('static', ))
Ejemplo n.º 9
    def _(iet):
        """
        Prepend an OpenACC device-setup section to the body of ``iet``.

        The section calls ``acc_init(acc_device_nvidia)`` and then selects a
        device with ``acc_set_device_num``, using ``deviceid`` when it is not
        -1. If one of ``iet.parameters`` is an ``MPICommObject``, the fallback
        device is ``rank % ngpus``, where ``ngpus`` comes from
        ``acc_get_num_devices``.
        NOTE(review): the ``Conditional`` wiring, the ``else:`` and the ``body``
        passed to ``List`` are missing from this excerpt.

        Returns the rebuilt IET and ``{'args': deviceid}``.
        """
        # TODO: we need to pick the rank from `comm_shm`, not `comm`,
        # so that we have nranks == ngpus (as long as the user has launched
        # the right number of MPI processes per node given the available
        # number of GPUs per node)

        # Pick the MPI communicator among the parameters, if any (last wins)
        objcomm = None
        for i in iet.parameters:
            if isinstance(i, MPICommObject):
                objcomm = i

        deviceid = DeviceID()
        device_nvidia = Macro('acc_device_nvidia')
        if objcomm is not None:
            rank = Symbol(name='rank')
            rank_decl = LocalExpression(DummyEq(rank, 0))
            rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

            ngpus = Symbol(name='ngpus')
            call = DefFunction('acc_get_num_devices', device_nvidia)
            ngpus_init = LocalExpression(DummyEq(ngpus, call))

            # Explicit device id vs. rank-based round-robin device choice
            asdn_then = Call('acc_set_device_num', [deviceid, device_nvidia])
            asdn_else = Call('acc_set_device_num',
                             [rank % ngpus, device_nvidia])

            # NOTE(review): the ``Conditional(`` opener and the closing ``]``
            # of this list are missing.
            body = [
                Call('acc_init', [device_nvidia]),
                    CondNe(deviceid, -1), asdn_then,
                    List(body=[rank_decl, rank_init, ngpus_init, asdn_else]))
        # NOTE(review): an ``else:`` (no MPI communicator) appears to be missing.
            body = [
                Call('acc_init', [device_nvidia]),
                    CondNe(deviceid, -1),
                    Call('acc_set_device_num', [deviceid, device_nvidia]))

        # NOTE(review): the ``body=body`` argument seems to be missing here.
        init = List(header=c.Comment('Begin of OpenACC+MPI setup'),
                    footer=(c.Comment('End of OpenACC+MPI setup'), c.Line()))
        iet = iet._rebuild(body=(init, ) + iet.body)

        return iet, {'args': deviceid}
Ejemplo n.º 10
        def _(iet):
            """
            Prepend a device-selection guard to the body of ``iet``.

            The guard calls ``self.lang['set-device']`` with ``self.deviceid``
            followed by the entries of ``self.lang[self.platform]``. It runs
            only when ``deviceid != -1``. A blank line separates it from the
            original body.

            Returns the rebuilt IET and an empty dict.
            """
            kind = as_list(self.lang[self.platform])
            devid = self.deviceid

            select = self.lang['set-device']([devid] + kind)
            guard = Conditional(CondNe(devid, -1), select)

            inner = iet.body
            new_inner = inner._rebuild(body=(guard, BlankLine) + inner.body)
            return iet._rebuild(body=new_inner), {}
Ejemplo n.º 11
    def _make_waitprefetch(self, iet, sync_ops, pieces, *args):
        """
        Prepend busy-waits on the prefetch threads to ``iet``.

        For every ``(sdata, threads)`` pair found in ``pieces.objs`` for
        ``sync_ops``, emit a ``BusyWait`` that spins while the flag field of
        ``sdata[threads.index]`` differs from 1.

        Parameters
        ----------
        iet : Node
            The IET to be preceded by the waits.
        sync_ops : iterable
            The synchronization operations whose objects are looked up.
        pieces : object
            Holds ``objs``, presumably mapping each sync op to an
            ``(sdata, threads)`` pair.
        *args
            Unused by this method.

        Returns
        -------
        List
            ``iet`` preceded by one busy-wait per ``(sdata, threads)`` pair
            and a blank line.
        """
        ff = SharedData._field_flag

        waits = []
        objs = filter_ordered(pieces.objs.get(s) for s in sync_ops)
        for sdata, threads in objs:
            wait = BusyWait(
                CondNe(FieldFromComposite(ff, sdata[threads.index]), 1))
            # Collect the wait; without this no wait would ever be emitted
            waits.append(wait)

        iet = List(header=c.Comment("Wait for the arrival of prefetched data"),
                   body=waits + [BlankLine, iet])

        return iet
Ejemplo n.º 12
    def _make_haloupdate(self, f, hse, key, msg=None):
        """
        Build a static ``haloupdate<key>`` Callable that starts ``ncomms`` halo
        exchanges.

        For each message entry ``i`` of ``msg``, the routine does the following:

        - posts an ``MPI_Irecv`` into the entry's receive buffer;
        - gathers ``f`` into its send buffer, skipped when the destination rank
          is ``MPI_PROC_NULL``;
        - posts an ``MPI_Isend``.

        The routine does not wait on the requests.
        NOTE(review): several closing brackets are missing from this excerpt.
        """
        comm = f.grid.distributor._obj_comm

        # Explicit offsets for the Dimensions in `hse.loc_indices`
        fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

        dim = Dimension(name='i')

        msgi = IndexedPointer(msg, dim)

        bufg = FieldFromComposite(msg._C_field_bufg, msgi)
        bufs = FieldFromComposite(msg._C_field_bufs, msgi)

        fromrank = FieldFromComposite(msg._C_field_from, msgi)
        torank = FieldFromComposite(msg._C_field_to, msgi)

        sizes = [
            FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
            for i in range(len(f._dist_dimensions))
        # NOTE(review): the closing ``]`` of ``sizes`` is missing.
        ofsg = [
            FieldFromComposite('%s[%d]' % (msg._C_field_ofsg, i), msgi)
            for i in range(len(f._dist_dimensions))
        # NOTE(review): the closing ``]`` of ``ofsg`` is missing.
        # Fixed Dimensions use their explicit offset; the others consume the
        # message offsets in order
        ofsg = [fixed.get(d) or ofsg.pop(0) for d in f.dimensions]

        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Call('gather_%s' % key, [bufg] + sizes + [f] + ofsg)
        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

        # Make Irecv/Isend
        count = reduce(mul, sizes, 1)
        rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
        rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
        recv = Call('MPI_Irecv', [
            bufs, count,
            Macro(dtype_to_mpitype(f.dtype)), fromrank,
            Integer(13), comm, rrecv
        # NOTE(review): the closing ``])`` of the call above is missing.
        send = Call('MPI_Isend', [
            bufg, count,
            Macro(dtype_to_mpitype(f.dtype)), torank,
            Integer(13), comm, rsend
        # NOTE(review): the closing ``])`` of the call above is missing.

        # The -1 below is because an Iteration, by default, generates <=
        ncomms = Symbol(name='ncomms')
        iet = Iteration([recv, gather, send], dim, ncomms - 1)
        parameters = ([f, comm, msg, ncomms]) + list(fixed.values())
        return Callable('haloupdate%d' % key, iet, 'void', parameters,
                        ('static', ))
Ejemplo n.º 13
def sendrecv(f, fixed):
    """
    Build a static ``sendrecv_<f.name>`` Callable performing one non-blocking
    halo exchange for ``f`` along an arbitrary dimension and side.

    The generated routine does the following, in order:

    - posts ``MPI_Irecv``;
    - gathers ``dat`` into the send buffer;
    - posts ``MPI_Isend``;
    - waits on the send and then on the receive request;
    - scatters the receive buffer back into ``dat``, unless the source rank
      is ``MPI_PROC_NULL``.

    Parameters
    ----------
    f : Function
        A Function defined over a Grid.
    fixed : container of Dimension
        Dimensions of ``f`` excluded from the buffer shape.
    """
    assert f.is_Function and f.grid is not None

    mpi_comm = f.grid.distributor._C_comm

    # Send (gather) and receive (scatter) buffers share the very same Dimensions
    bdims = [Dimension(name='buf_%s' % d.root)
             for d in f.dimensions if d not in fixed]
    sendbuf = Array(name='bufg', dimensions=bdims, dtype=f.dtype, scope='heap')
    recvbuf = Array(name='bufs', dimensions=bdims, dtype=f.dtype, scope='heap')

    ddims = [Dimension(name='dat_%s' % d.root) for d in f.dimensions]
    dat = Array(name='dat', dimensions=ddims, dtype=f.dtype, scope='external')

    gofs = [Symbol(name='og%s' % d.root) for d in f.dimensions]
    sofs = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    src = Symbol(name='fromrank')
    dst = Symbol(name='torank')

    gather = Call('gather_%s' % f.name,
                  [sendbuf, *sendbuf.shape, dat, *dat.shape, *gofs])
    # The scatter must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(
        CondNe(src, Macro('MPI_PROC_NULL')),
        Call('scatter_%s' % f.name,
             [recvbuf, *recvbuf.shape, dat, *dat.shape, *sofs])
    )

    status = MPIStatusObject(name='srecv')
    rreq = MPIRequestObject(name='rrecv')
    sreq = MPIRequestObject(name='rsend')

    count = reduce(mul, recvbuf.shape, 1)
    recv = Call('MPI_Irecv', [recvbuf, count, Macro(numpy_to_mpitypes(f.dtype)),
                              src, '13', mpi_comm, rreq])
    send = Call('MPI_Isend', [sendbuf, count, Macro(numpy_to_mpitypes(f.dtype)),
                              dst, '13', mpi_comm, sreq])

    exchange = List(body=[
        recv, gather, send,
        Call('MPI_Wait', [sreq, Macro('MPI_STATUS_IGNORE')]),
        Call('MPI_Wait', [rreq, status]),
        scatter,
    ])
    body = List(body=[ArrayCast(dat), iet_insert_C_decls(exchange)])

    params = ([dat, *dat.shape, *recvbuf.shape] + gofs + sofs +
              [src, dst, mpi_comm])
    return Callable('sendrecv_%s' % f.name, body, 'void', params, ('static',))
Ejemplo n.º 14
    def __make_tfunc(self, name, iet, root, threads):
        """
        Build the thread Callable ``name`` wrapping ``iet``, plus its SharedData.

        The Callable spins while the SharedData flag is not 0. It executes
        ``iet``, preceded by the unpacking of the dynamic parameters from the
        SharedData, when the flag is 2.
        NOTE(review): several lines (SharedData arguments, preaction/
        postaction tails) are missing from this excerpt.

        Returns
        -------
        (Callable, SharedData)
        """
        # Create the SharedData
        required = derive_parameters(iet)
        # Parameters already known to `root`, plus shared-memory Arrays
        known = (root.parameters +
                 tuple(i for i in required if i.is_Array and i._mem_shared))
        parameters, dynamic_parameters = split(required, lambda i: i in known)

        sdata = SharedData(name=self.sregistry.make_name(prefix='sdata'),
        # NOTE(review): the remaining ``SharedData`` arguments are missing.

        # Prepend the unwound SharedData fields, available upon thread activation
        preactions = [
            DummyExpr(i, FieldFromPointer(i.name, sdata.symbolic_base))
            for i in dynamic_parameters
        # NOTE(review): the closing of ``preactions`` and the start of the
        # statement reading the id field below are missing.
                      FieldFromPointer(sdata._field_id, sdata.symbolic_base)))

        # Append the flag reset
        postactions = [
        # NOTE(review): the opening of the flag-reset expression and the
        # closing ``]`` are missing around the line below.
                    FieldFromPointer(sdata._field_flag, sdata.symbolic_base),

        iet = List(body=preactions + [iet] + postactions)

        # NOTE(review): stray duplicate of the comment above; no code follows it.
        # Append the flag reset

        # The thread has work to do when it receives the signal that all locks have
        # been set to 0 by the main thread
        iet = Conditional(
            CondEq(FieldFromPointer(sdata._field_flag, sdata.symbolic_base),
                   2), iet)

        # The thread keeps spinning until the alive flag is set to 0 by the main thread
        iet = While(
            CondNe(FieldFromPointer(sdata._field_flag, sdata.symbolic_base),
                   0), iet)

        return Callable(name, iet, 'void', parameters, 'static'), sdata
Ejemplo n.º 15
        def _(iet):
            """
            Insert a device-selection guard right before the unique
            ``WhileAlive`` node of ``iet``.

            The guard calls ``self.lang['set-device']`` with ``self.deviceid``
            followed by the entries of ``self.lang[self.platform]``. It runs
            only when ``deviceid != -1``.

            Returns the transformed IET and an empty dict.
            """
            loops = FindNodes(WhileAlive).visit(iet)
            # Exactly one WhileAlive is expected in `iet`
            assert len(loops) == 1
            loop = loops.pop()

            devicetype = as_list(self.lang[self.platform])
            deviceid = self.deviceid

            # The closing parenthesis of `Conditional(` was missing (SyntaxError)
            init = Conditional(
                CondNe(deviceid, -1),
                self.lang['set-device']([deviceid] + devicetype)
            )

            mapper = {loop: List(body=[init, BlankLine, loop])}
            iet = Transformer(mapper).visit(iet)

            return iet, {}
Ejemplo n.º 16
    def _make_sendrecv(self, f, fixed, extra=None):
        """
        Build a static ``sendrecv<ndim>d`` Callable performing one halo exchange.

        The generated routine does the following, in order:

        - posts an ``MPI_Irecv`` into ``bufs``;
        - calls ``gather<ndim>d``, skipped when ``torank`` is
          ``MPI_PROC_NULL``;
        - posts an ``MPI_Isend`` from ``bufg``;
        - waits on both requests;
        - calls ``scatter<ndim>d``, skipped when ``fromrank`` is
          ``MPI_PROC_NULL``.

        ``extra`` (default: empty) is appended to the gather/scatter arguments
        and to the routine's parameters.
        NOTE(review): several closing brackets and the ``Array`` arguments are
        missing from this excerpt.
        """
        extra = extra or []
        comm = f.grid.distributor._obj_comm

        # Buffers span every Dimension of `f` except the fixed ones
        buf_dims = [
            Dimension(name='buf_%s' % d.root) for d in f.dimensions
            if d not in fixed
        # NOTE(review): the closing ``]`` of ``buf_dims`` and the remaining
        # arguments of the two ``Array(...)`` calls below are missing.
        bufg = Array(name='bufg',
        bufs = Array(name='bufs',

        ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
        ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

        fromrank = Symbol(name='fromrank')
        torank = Symbol(name='torank')

        args = [bufg] + list(bufg.shape) + [f] + ofsg + extra
        gather = Call('gather%dd' % f.ndim, args)
        args = [bufs] + list(bufs.shape) + [f] + ofss + extra
        scatter = Call('scatter%dd' % f.ndim, args)

        # The `gather` is unnecessary if sending to MPI.PROC_NULL
        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
        # The `scatter` must be guarded as we must not alter the halo values along
        # the domain boundary, where the sender is actually MPI.PROC_NULL
        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')),
        # NOTE(review): the call above is truncated (second argument missing).

        srecv = MPIStatusObject(name='srecv')
        ssend = MPIStatusObject(name='ssend')
        rrecv = MPIRequestObject(name='rrecv')
        rsend = MPIRequestObject(name='rsend')

        count = reduce(mul, bufs.shape, 1)
        recv = Call('MPI_Irecv', [
            bufs, count,
            Macro(dtype_to_mpitype(f.dtype)), fromrank,
            Integer(13), comm, rrecv
        # NOTE(review): the closing ``])`` of the call above is missing.
        send = Call('MPI_Isend', [
            bufg, count,
            Macro(dtype_to_mpitype(f.dtype)), torank,
            Integer(13), comm, rsend
        # NOTE(review): the closing ``])`` of the call above is missing.

        waitrecv = Call('MPI_Wait', [rrecv, srecv])
        waitsend = Call('MPI_Wait', [rsend, ssend])

        iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
        iet = List(body=iet_insert_C_decls(iet))
        parameters = ([f] + list(bufs.shape) + ofsg + ofss +
                      [fromrank, torank, comm] + extra)
        return Callable('sendrecv%dd' % f.ndim, iet, 'void', parameters,
                        ('static', ))
Ejemplo n.º 17
    def _make_fetchprefetch(self, iet, sync_ops, pieces, root):
        """
        Wrap ``iet`` with the data-streaming logic for ``sync_ops``.

        The method builds the following:

        - an ``init_device`` Callable performing the conditional initial
          fetches (``s.ifetch``, guarded by ``s.fcond``);
        - a ``prefetch_host_to_device`` thread context performing the
          conditional prefetches (``s.pfetch``, guarded by ``s.pcond``);
        - present clauses for ``s.fetch``;
        - the thread activation code.

        The created ``(sdata, threads)`` are recorded in ``pieces.objs``.
        NOTE(review): several statements (present append, init/finalize
        registration, busy-wait) are missing from this excerpt.
        """
        # NOTE(review): `fid` is not used in the visible code.
        fid = SharedData._field_id

        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            f = s.function
            dimensions = s.dimensions
            fc = s.fetch
            ifc = s.ifetch
            pfc = s.pfetch
            fcond = s.fcond
            pcond = s.pcond

            # Construct init IET
            imask = [(ifc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            fetch = PragmaTransfer(self.lang._map_to, f, imask=imask)
            fetches.append(Conditional(fcond, fetch))

            # Construct present clauses
            imask = [(fc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            # NOTE(review): the start of this statement (presumably appending
            # to ``presents``) is missing.
                PragmaTransfer(self.lang._map_present, f, imask=imask))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            prefetch = PragmaTransfer(self.lang._map_to_wait,
            # NOTE(review): the call above is truncated.
            prefetches.append(Conditional(pcond, prefetch))

        # Turn init IET into a Callable
        functions = filter_ordered(s.function for s in sync_ops)
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')

        # Perform initial fetch by the main thread
        # NOTE(review): the start of the statement registering this List
        # is missing.
            List(header=c.Comment("Initialize data stream"),
                 body=[Call(name, parameters), BlankLine]))

        # Turn prefetch IET into a ThreadFunction
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=prefetches)
        tctx = make_thread_ctx(name, body, root, None, sync_ops,
        # NOTE(review): the call above is truncated.

        # Glue together all the IET pieces, including the activation logic
        sdata = tctx.sdata
        threads = tctx.threads
        iet = List(body=[
        # NOTE(review): the opening of the wait on the flag field below is
        # missing.
                    FieldFromComposite(sdata._field_flag, sdata[
                        threads.index]), 1))
        ] + presents + [iet, tctx.activate])

        # Fire up the threads
        # NOTE(review): no code follows this comment in the excerpt.

        # Final wait before jumping back to Python land
        # NOTE(review): no code follows this comment in the excerpt.

        # Keep track of created objects
        pieces.objs.add(sync_ops, sdata, threads)

        return iet
Ejemplo n.º 18
    def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
        """
        Wrap ``iet`` with fetch/wait/prefetch logic for ``sync_ops``.

        For a ``Forward`` sync op, the initial fetch is at ``s.dim.symbolic_min``
        and the prefetch is at ``s.fetch + 1``. Both are guarded so that
        ``index + s.size - 1`` stays below the field size along ``s.dim``.
        Otherwise the fetch is at ``s.dim.symbolic_max`` and the prefetch at
        ``s.fetch - 1``, both guarded by ``>= 0``.
        The method also builds the ``init_device`` Callable, the
        ``prefetch_host_to_device`` thread function, the present clauses and
        the activation code.
        NOTE(review): an ``else:`` and several statements are missing from
        this excerpt.
        """
        threads = self.__make_threads()

        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            if s.direction is Forward:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
                fsize = s.function._C_get_field(FULL, s.dim).size
                fc_cond = fc + (s.size - 1) < fsize
                pfc = s.fetch + 1
                pfc_cond = pfc + (s.size - 1) < fsize
            # NOTE(review): an ``else:`` appears to be missing here.
                fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
                fc_cond = fc >= 0
                pfc = s.fetch - 1
                pfc_cond = pfc >= 0

            # Construct fetch IET
            imask = [(fc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            fetch = List(header=self._P._map_to(s.function, imask))
            fetches.append(Conditional(fc_cond, fetch))

            # Construct present clauses
            imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            presents.extend(as_list(self._P._map_present(s.function, imask)))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            prefetch = List(header=self._P._map_to_wait(
                s.function, imask, SharedData._field_id))
            prefetches.append(Conditional(pfc_cond, prefetch))

        functions = filter_ordered(s.function for s in sync_ops)
        casts = [PointerCast(f) for f in functions]

        # Turn init IET into a Callable
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=casts + fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')

        # Perform initial fetch by the main thread
        # NOTE(review): the start of the statement registering this List and
        # the right operand of ``%`` are missing.
            List(header=c.Comment("Initialize data stream for `%s`" %
                 body=[Call(name, func.parameters), BlankLine]))

        # Turn prefetch IET into a threaded Callable
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=casts + prefetches)
        tfunc, sdata = self.__make_tfunc(name, body, root, threads)

        # Glue together all the IET pieces, including the activation bits
        iet = List(body=[
        # NOTE(review): the opening of the wait on the flag field below and
        # the closing ``])`` of this List are missing.
                    FieldFromComposite(sdata._field_flag, sdata[
                        threads.index]), 1)),
            List(header=presents), iet,
            self.__make_activate_thread(threads, sdata, sync_ops)

        # Fire up the threads
        # NOTE(review): the start of the statement registering the init call
        # below is missing.
            self.__make_init_threads(threads, sdata, tfunc, pieces))

        # Final wait before jumping back to Python land
        pieces.finalize.append(self.__make_finalize_threads(threads, sdata))

        return iet
Ejemplo n.º 19
    def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
        """
        Wrap ``iet`` with fetch/wait/prefetch logic for ``sync_ops``.

        For a ``Forward`` sync op, the initial fetch is at ``s.dim.symbolic_min``
        and the prefetch is at ``s.fetch + 1``. Otherwise the fetch is at
        ``s.dim.symbolic_max`` and the prefetch at ``s.fetch - 1``. The guards
        come from ``s.next_cbk``.
        The method also builds an ``init_device`` Callable, a
        ``prefetch_host_to_device`` thread context, the present clauses and
        the activation code.
        NOTE(review): an ``else:`` and several statements are missing from
        this excerpt.
        """
        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            if s.direction is Forward:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
                pfc = s.fetch + 1
                fc_cond = s.next_cbk(s.dim.symbolic_min)
                pfc_cond = s.next_cbk(s.dim + 1)
            # NOTE(review): an ``else:`` appears to be missing here.
                fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
                pfc = s.fetch - 1
                fc_cond = s.next_cbk(s.dim.symbolic_max)
                pfc_cond = s.next_cbk(s.dim - 1)

            # Construct init IET
            imask = [(fc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
            fetch = PragmaList(self.lang._map_to(s.function, imask),
                               {s.function} | fc.free_symbols)
            fetches.append(Conditional(fc_cond, fetch))

            # Construct present clauses
            imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            presents.extend(as_list(self.lang._map_present(s.function, imask)))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            # NOTE(review): this call looks incomplete (argument list of
            # ``_map_to_wait`` and closing parenthesis).
            prefetch = PragmaList(self.lang._map_to_wait(s.function, imask,
                                  {s.function} | pfc.free_symbols)
            prefetches.append(Conditional(pfc_cond, prefetch))

        # Turn init IET into a Callable
        functions = filter_ordered(s.function for s in sync_ops)
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')

        # Perform initial fetch by the main thread
        # NOTE(review): the enclosing statement around these two arguments
        # is missing.
            header=c.Comment("Initialize data stream"),
            body=[Call(name, parameters), BlankLine]

        # Turn prefetch IET into a ThreadFunction
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=prefetches)
        tctx = make_thread_ctx(name, body, root, None, sync_ops, self.sregistry)

        # Glue together all the IET pieces, including the activation logic
        sdata = tctx.sdata
        threads = tctx.threads
        iet = List(body=[
        # NOTE(review): the opening of the wait on the flag field below and
        # the remaining List items/closing are missing.
                                               sdata[threads.index]), 1)),

        # Fire up the threads
        # NOTE(review): no code follows this comment in the excerpt.

        # Final wait before jumping back to Python land
        # NOTE(review): no code follows this comment in the excerpt.

        return iet