Example #1
def assign(f, rhs=0, options=None, name='assign', **kwargs):
    """
    Assign a list of RHSs to a list of Functions.

    Parameters
    ----------
    f : Function or list of Functions
        The left-hand side of the assignment.
    rhs : expr-like or list of expr-like, optional
        The right-hand side of the assignment.
    options : dict or list of dict, optional
        Dictionary, or list (of length ``len(f)``) of dictionaries, containing
        optional arguments to be passed to Eq.
    name : str, optional
        Name of the operator.

    Examples
    --------
    >>> import numpy as np
    >>> from devito import Grid, Function, assign
    >>> grid = Grid(shape=(4, 4))
    >>> f = Function(name='f', grid=grid, dtype=np.int32)
    >>> g = Function(name='g', grid=grid, dtype=np.int32)
    >>> h = Function(name='h', grid=grid, dtype=np.int32)
    >>> functions = [f, g, h]
    >>> scalars = [1, 2, 3]
    >>> assign(functions, scalars)
    >>> f.data
    Data([[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]], dtype=int32)
    >>> g.data
    Data([[2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2]], dtype=int32)
    >>> h.data
    Data([[3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3]], dtype=int32)
    """
    if not isinstance(rhs, list):
        rhs = len(as_list(f)) * [rhs]
    eqs = []
    if options:
        for i, j, k in zip(as_list(f), rhs, options):
            if k is not None:
                eqs.append(dv.Eq(i, j, **k))
            else:
                eqs.append(dv.Eq(i, j))
    else:
        for i, j in zip(as_list(f), rhs):
            eqs.append(dv.Eq(i, j))
    dv.Operator(eqs, name=name, **kwargs)()
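
A hedged usage sketch (assuming a working Devito installation) of the `options` branch above, attaching an Eq keyword argument to one of the assignments:

import numpy as np
from devito import Grid, Function, assign

grid = Grid(shape=(4, 4))
f = Function(name='f', grid=grid, dtype=np.int32)
g = Function(name='g', grid=grid, dtype=np.int32)

# Restrict the second assignment to the interior subdomain; the first
# entry is None, so its Eq is built without extra kwargs
options = [None, {'subdomain': grid.subdomains['interior']}]
assign([f, g], [1, 2], options=options)
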
Example #2
    def augment(self, sub_iterators):
        """
        Create a new IterationSpace with additional sub-iterators.
        """
        items = dict(self.sub_iterators)
        for k, v in sub_iterators.items():
            if k not in self.intervals:
                continue
            items[k] = as_list(items.get(k))
            for i in as_list(v):
                if i not in items[k]:
                    items[k].append(i)

        return IterationSpace(self.intervals, items, self.directions)
Example #3
 def augment(self, sub_iterators):
     """
     Create a new IterationSpace with additional sub-iterators.
     """
     v = {k: as_list(v) for k, v in sub_iterators.items() if k in self.intervals}
     sub_iterators = {**self.sub_iterators, **v}
     return IterationSpace(self.intervals, sub_iterators, self.directions)
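
The two variants differ subtly: the first appends new sub-iterators to any existing entries, while this one replaces the entry for a dimension outright. A minimal, self-contained sketch of the appending behaviour with plain dicts (all names hypothetical):

def as_list_demo(item):
    # Simplified stand-in for devito's as_list: None -> [], wrap scalars
    if item is None:
        return []
    return list(item) if isinstance(item, (list, tuple)) else [item]

intervals = {'x', 'y'}            # dimensions known to the space
current = {'x': ['i0']}           # pre-existing sub-iterators
extra = {'x': 'i1', 'z': 'i2'}    # 'z' is dropped: not in intervals

merged = dict(current)
for k, v in extra.items():
    if k not in intervals:
        continue
    merged[k] = as_list_demo(merged.get(k))
    for i in as_list_demo(v):
        if i not in merged[k]:
            merged[k].append(i)

print(merged)  # {'x': ['i0', 'i1']}
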
Example #4
    def add(self, expr, make, terms=None):
        """
        Without ``terms``: add ``expr`` to the mapper binding it to the symbol
        generated with the callback ``make``.
        With ``terms``: add the compound sub-expression made of ``terms`` to the
        mapper. ``terms`` is a list of one or more items in ``expr.args``.
        """
        if expr in self:
            return

        if not terms:
            self[expr] = self.extracted[expr] = make()
            return

        terms = as_list(terms)

        base = terms.pop(0)
        if terms:
            k = expr.func(base, *terms)
            try:
                symbol = self.extracted[k]
            except KeyError:
                symbol = self.extracted.setdefault(k, make())
            self[expr] = self.Uxsubmap.fromkeys(terms)
            self[expr][base] = symbol
        else:
            self[base] = self.extracted[base] = make()
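
The `expr.func(base, *terms)` call above rebuilds a compound sub-expression of the same node type from a subset of its arguments. A small SymPy sketch (hypothetical symbols):

from sympy import symbols

a, b, c = symbols('a b c')
expr = a + b + c                # expr.func is Add
terms = [a, b]                  # a subset of expr.args
base = terms.pop(0)
k = expr.func(base, *terms)     # rebuild the compound term
print(k)                        # a + b
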
Example #5
        def _(iet):
            # TODO: we need to pick the rank from `comm_shm`, not `comm`,
            # so that we have nranks == ngpus (as long as the user has launched
            # the right number of MPI processes per node given the available
            # number of GPUs per node)

            objcomm = None
            for i in iet.parameters:
                if isinstance(i, MPICommObject):
                    objcomm = i
                    break

            devicetype = as_list(self.lang[self.platform])

            try:
                lang_init = [self.lang['init'](devicetype)]
            except TypeError:
                # Not all target languages need to be explicitly initialized
                lang_init = []

            deviceid = DeviceID()
            if objcomm is not None:
                rank = Symbol(name='rank')
                rank_decl = LocalExpression(DummyEq(rank, 0))
                rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

                ngpus = Symbol(name='ngpus')
                call = self.lang['num-devices'](devicetype)
                ngpus_init = LocalExpression(DummyEq(ngpus, call))

                osdd_then = self.lang['set-device']([deviceid] + devicetype)
                osdd_else = self.lang['set-device']([rank % ngpus] + devicetype)

                body = lang_init + [
                    Conditional(
                        CondNe(deviceid, -1),
                        osdd_then,
                        List(body=[rank_decl, rank_init, ngpus_init, osdd_else])
                    )
                ]

                header = c.Comment('Begin of %s+MPI setup' % self.lang['name'])
                footer = c.Comment('End of %s+MPI setup' % self.lang['name'])
            else:
                body = lang_init + [
                    Conditional(
                        CondNe(deviceid, -1),
                        self.lang['set-device']([deviceid] + devicetype))
                ]

                header = c.Comment('Begin of %s setup' % self.lang['name'])
                footer = c.Comment('End of %s setup' % self.lang['name'])

            init = List(header=header, body=body, footer=(footer, c.Line()))
            iet = iet._rebuild(body=(init, ) + iet.body)

            return iet, {'args': deviceid}
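
Stripped of the IET machinery, the device-selection logic this pass emits can be sketched in plain Python (a hypothetical function; the generated code is C):

def select_device(deviceid, rank, ngpus):
    # Honour an explicitly requested device id; otherwise map MPI ranks
    # to GPUs round-robin, mirroring the `rank % ngpus` branch above
    if deviceid != -1:
        return deviceid
    return rank % ngpus

print(select_device(-1, rank=5, ngpus=4))  # 1
print(select_device(2, rank=5, ngpus=4))   # 2
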
Example #6
        def _(iet):
            devicetype = as_list(self.lang[self.platform])
            deviceid = self.deviceid

            init = Conditional(
                CondNe(deviceid, -1),
                self.lang['set-device']([deviceid] + devicetype))
            body = iet.body._rebuild(body=(init, BlankLine) + iet.body.body)
            iet = iet._rebuild(body=body)

            return iet, {}
Example #7
 def _set_global_idx(self, val, idx, val_idx):
     """
     Compute the global indices to which val (the locally stored data) correspond.
     """
     data_loc_idx = as_tuple(val._index_glb_to_loc(val_idx))
     data_glb_idx = []
     # Convert integers to slices so that shape dims are preserved
     if is_integer(as_tuple(idx)[0]):
         data_glb_idx.append(slice(0, 1, 1))
     for i, j in zip(data_loc_idx, val._decomposition):
         if not j.loc_empty:
             data_glb_idx.append(j.index_loc_to_glb(i))
         else:
             data_glb_idx.append(None)
     mapped_idx = []
     # Add any integer indices that were not present in `val_idx`.
     if len(as_list(idx)) > len(data_glb_idx):
         for index, value in enumerate(idx):
             if is_integer(value) and index > 0:
                 data_glb_idx.insert(index, value)
     # Based on `data_glb_idx` the indices to which the locally stored data
     # block correspond can now be computed:
     for i, j, k in zip(data_glb_idx, as_tuple(idx), self._decomposition):
         if is_integer(j):
             mapped_idx.append(j)
             continue
         elif isinstance(j, slice) and j.start is None:
             norm = 0
         elif isinstance(j, slice) and j.start is not None:
             if j.start >= 0:
                 norm = j.start
             else:
                 norm = j.start + k.glb_max + 1
         else:
             norm = j
          if i is not None:
              if isinstance(j, slice) and j.step is not None:
                  stop = j.step * i.stop + norm
                  mapped_idx.append(slice(j.step * i.start + norm, stop,
                                          j.step))
              else:
                  stop = i.stop + norm
                  mapped_idx.append(slice(i.start + norm, stop, i.step))
          else:
              mapped_idx.append(None)
     return as_tuple(mapped_idx)
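
A minimal illustration of the renormalisation step above: a locally owned block spanning local indices [start, stop) maps to global indices shifted by the slice origin `norm` (values hypothetical):

loc = slice(0, 3, 1)     # local indices of the owned block
norm = 4                 # global start of the requested slice
glb = slice(loc.start + norm, loc.stop + norm, loc.step)
print(glb)               # slice(4, 7, 1)
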
Example #8
        def _(iet):
            body = FindNodes(WhileAlive).visit(iet)
            assert len(body) == 1
            body = body.pop()

            devicetype = as_list(self.lang[self.platform])
            deviceid = self.deviceid

            init = Conditional(
                CondNe(deviceid, -1),
                self.lang['set-device']([deviceid] + devicetype)
            )

            mapper = {body: List(body=[init, BlankLine, body])}
            iet = Transformer(mapper).visit(iet)

            return iet, {}
Example #9
def initialize_function(function,
                        data,
                        nbl,
                        mapper=None,
                        mode='constant',
                        name='padfunc',
                        **kwargs):
    """
    Initialize a Function with the given ``data``. ``data``
    does *not* include the ``nbl`` outer/boundary layers; these are added via padding
    by this function.

    Parameters
    ----------
    function : Function
        The initialised object.
    data : ndarray or Function
        The data used for initialisation.
    nbl : int or tuple of int
        Number of outer layers (such as absorbing layers for boundary damping).
    mapper : dict, optional
        Dictionary containing, for each dimension of `function`, a sub-dictionary
        containing the following keys:
        1) 'lhs': List of additional expressions to be added to the LHS expressions list.
        2) 'rhs': List of additional expressions to be added to the RHS expressions list.
        3) 'options': Options pertaining to the additional equations that will be
        constructed.
    mode : str, optional
        The function initialisation mode. 'constant' and 'reflect' are
        accepted.
    name : str, optional
        The name assigned to the operator.

    Examples
    --------
    In the following example the `'interior'` of a function is set to one plus
    the value on the boundary.

    >>> import numpy as np
    >>> from devito import Grid, SubDomain, Function, initialize_function

    Create the computational domain:

    >>> grid = Grid(shape=(6, 6))
    >>> x, y = grid.dimensions

    Create the Function we wish to set along with the data to set it:

    >>> f = Function(name='f', grid=grid, dtype=np.int32)
    >>> data = np.full((4, 4), 2, dtype=np.int32)

    Now create the additional expressions and options required to set the value of
    the interior region to one greater than the boundary value. Note that the equation
    is specified on the second (final) grid dimension so that the additional equation
    is executed after padding is complete.

    >>> lhs = f
    >>> rhs = f+1
    >>> options = {'subdomain': grid.subdomains['interior']}
    >>> mapper = {}
    >>> mapper[y] = {'lhs': lhs, 'rhs': rhs, 'options': options}

    Call the initialize_function routine:

    >>> initialize_function(f, data, 1, mapper=mapper)
    >>> f.data
    Data([[2, 2, 2, 2, 2, 2],
          [2, 3, 3, 3, 3, 2],
          [2, 3, 3, 3, 3, 2],
          [2, 3, 3, 3, 3, 2],
          [2, 3, 3, 3, 3, 2],
          [2, 2, 2, 2, 2, 2]], dtype=int32)
    """
    if isinstance(function, dv.TimeFunction):
        raise NotImplementedError("TimeFunctions are not currently supported.")

    if nbl == 0:
        if isinstance(data, dv.Function):
            function.data[:] = data.data[:]
        else:
            function.data[:] = data[:]
        return

    if len(as_tuple(nbl)) == 1 and len(as_tuple(nbl)) < function.ndim:
        nbl = function.ndim * (as_tuple(nbl)[0],)
    elif len(as_tuple(nbl)) == function.ndim:
        pass
    else:
        raise ValueError("nbl must be an integer or a tuple of integers "
                         "of length function.ndim.")

    slices = tuple(slice(n, -n) for _, n in
                   zip(range(function.grid.dim), as_tuple(nbl)))
    if isinstance(data, dv.Function):
        function.data[slices] = data.data[:]
    else:
        function.data[slices] = data
    lhs = []
    rhs = []
    options = []

    if mode == 'reflect' and function.grid.distributor.is_parallel:
        # Check that HALO size is appropriate
        halo = function.halo
        local_size = function.shape

        def buff(i, j):
            return [(i + k - 2 * max(nbl)) for k in j]

        b = [min(buff(i, j)) for i, j in zip(local_size, halo)]
        if any(np.array(b) < 0):
            raise ValueError("Function `%s` halo is not sufficiently thick." %
                             function)

    for d, n in zip(function.space_dimensions, as_tuple(nbl)):
        dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name,
                                     parent=d,
                                     thickness=n)
        dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name,
                                      parent=d,
                                      thickness=n)
        if mode == 'constant':
            subsl = n
            subsr = d.symbolic_max - n
        elif mode == 'reflect':
            subsl = 2 * n - 1 - dim_l
            subsr = 2 * (d.symbolic_max - n) + 1 - dim_r
        else:
            raise ValueError("Mode not available")
        lhs.append(function.subs({d: dim_l}))
        lhs.append(function.subs({d: dim_r}))
        rhs.append(function.subs({d: subsl}))
        rhs.append(function.subs({d: subsr}))
        options.extend([None, None])

        if mapper and d in mapper.keys():
            exprs = mapper[d]
            lhs_extra = exprs['lhs']
            rhs_extra = exprs['rhs']
            lhs.extend(as_list(lhs_extra))
            rhs.extend(as_list(rhs_extra))
            options_extra = exprs.get('options',
                                      len(as_list(lhs_extra)) * [None])
            if isinstance(options_extra, list):
                options.extend(options_extra)
            else:
                options.extend([options_extra])

    if all(i is None for i in options):
        options = None

    assign(lhs, rhs, options=options, name=name, **kwargs)
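
For intuition, the two padding modes behave like NumPy padding on a plain array. A sketch of that correspondence (an assumption drawn from the subsl/subsr definitions above: 'constant' replicates the boundary value like np.pad's 'edge', while 'reflect' mirrors including the boundary sample, like np.pad's 'symmetric'):

import numpy as np

data = np.arange(1, 5)                           # interior values
padded_const = np.pad(data, 2, mode='edge')      # ~ mode='constant' above
padded_refl = np.pad(data, 2, mode='symmetric')  # ~ mode='reflect' above
print(padded_const)  # [1 1 1 2 3 4 4 4]
print(padded_refl)   # [2 1 1 2 3 4 4 3]
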
Example #10
 def add_include_dirs(self, dirs):
     self.include_dirs = filter_ordered(self.include_dirs + as_list(dirs))
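
This one-liner pattern recurs in Examples #13, #17 and #18: normalise the argument with as_list, concatenate, and de-duplicate while preserving order. A self-contained sketch with simplified stand-ins for the devito helpers:

def as_list_demo(item):
    # Simplified stand-in for devito's as_list
    if item is None:
        return []
    return list(item) if isinstance(item, (list, tuple)) else [item]

def filter_ordered_demo(items):
    # Order-preserving de-duplication, like devito's filter_ordered
    seen = set()
    return [i for i in items if i not in seen and not seen.add(i)]

include_dirs = ['/usr/include']
include_dirs = filter_ordered_demo(include_dirs + as_list_demo('/opt/include'))
print(include_dirs)  # ['/usr/include', '/opt/include']
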
Example #11
 def __pfields_setup__(cls, **kwargs):
     fields = as_list(kwargs.get('fields'))
     fields.extend(
         [cls._symbolic_id, cls._symbolic_deviceid, cls._symbolic_flag])
     return [(i._C_name, i._C_ctype) for i in fields]
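
The (name, ctype) pairs built here plausibly back a C struct; a hypothetical ctypes sketch of what such pfields feed into (field names invented for illustration):

import ctypes

pfields = [('id', ctypes.c_int),
           ('deviceid', ctypes.c_int),
           ('flag', ctypes.c_int)]

class SharedDataDemo(ctypes.Structure):
    _fields_ = pfields

obj = SharedDataDemo(id=0, deviceid=-1, flag=1)
print(obj.deviceid)  # -1
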
Example #12
    def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            if s.direction is Forward:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
                pfc = s.fetch + 1
                fc_cond = s.next_cbk(s.dim.symbolic_min)
                pfc_cond = s.next_cbk(s.dim + 1)
            else:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
                pfc = s.fetch - 1
                fc_cond = s.next_cbk(s.dim.symbolic_max)
                pfc_cond = s.next_cbk(s.dim - 1)

            # Construct init IET
            imask = [(fc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
            fetch = PragmaList(self.lang._map_to(s.function, imask),
                               {s.function} | fc.free_symbols)
            fetches.append(Conditional(fc_cond, fetch))

            # Construct present clauses
            imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            presents.extend(as_list(self.lang._map_present(s.function, imask)))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            prefetch = PragmaList(self.lang._map_to_wait(s.function, imask,
                                                         SharedData._field_id),
                                  {s.function} | pfc.free_symbols)
            prefetches.append(Conditional(pfc_cond, prefetch))

        # Turn init IET into a Callable
        functions = filter_ordered(s.function for s in sync_ops)
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')
        pieces.funcs.append(func)

        # Perform initial fetch by the main thread
        pieces.init.append(List(
            header=c.Comment("Initialize data stream"),
            body=[Call(name, parameters), BlankLine]
        ))

        # Turn prefetch IET into a ThreadFunction
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=prefetches)
        tctx = make_thread_ctx(name, body, root, None, sync_ops, self.sregistry)
        pieces.funcs.extend(tctx.funcs)

        # Glue together all the IET pieces, including the activation logic
        sdata = tctx.sdata
        threads = tctx.threads
        iet = List(body=[
            BlankLine,
            BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                               sdata[threads.index]), 1)),
            List(header=presents),
            iet,
            tctx.activate
        ])

        # Fire up the threads
        pieces.init.append(tctx.init)
        pieces.threads.append(threads)

        # Final wait before jumping back to Python land
        pieces.finalize.append(tctx.finalize)

        return iet
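
The index arithmetic at the top of this routine (and of the variants in Examples #14 and #20) reduces to: the initial fetch grabs the first slot of the sweep, and each prefetch targets the next slot in the sweep direction. A plain sketch (hypothetical values):

def fetch_indices(fetch, dim_min, dim_max, forward=True):
    # Returns (initial fetch index, next prefetch index)
    if forward:
        return dim_min, fetch + 1
    return dim_max, fetch - 1

print(fetch_indices(fetch=3, dim_min=0, dim_max=9, forward=True))   # (0, 4)
print(fetch_indices(fetch=3, dim_min=0, dim_max=9, forward=False))  # (9, 2)
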
Example #13
 def add_library_dirs(self, dirs):
     self.library_dirs = filter_ordered(self.library_dirs + as_list(dirs))
Example #14
    def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
        threads = self.__make_threads()

        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            if s.direction is Forward:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
                fsize = s.function._C_get_field(FULL, s.dim).size
                fc_cond = fc + (s.size - 1) < fsize
                pfc = s.fetch + 1
                pfc_cond = pfc + (s.size - 1) < fsize
            else:
                fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
                fc_cond = fc >= 0
                pfc = s.fetch - 1
                pfc_cond = pfc >= 0

            # Construct fetch IET
            imask = [(fc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            fetch = List(header=self._P._map_to(s.function, imask))
            fetches.append(Conditional(fc_cond, fetch))

            # Construct present clauses
            imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            presents.extend(as_list(self._P._map_present(s.function, imask)))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in s.dimensions]
            prefetch = List(header=self._P._map_to_wait(
                s.function, imask, SharedData._field_id))
            prefetches.append(Conditional(pfc_cond, prefetch))

        functions = filter_ordered(s.function for s in sync_ops)
        casts = [PointerCast(f) for f in functions]

        # Turn init IET into a Callable
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=casts + fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')
        pieces.funcs.append(func)

        # Perform initial fetch by the main thread
        pieces.init.append(
            List(header=c.Comment("Initialize data stream for `%s`" %
                                  threads.name),
                 body=[Call(name, func.parameters), BlankLine]))

        # Turn prefetch IET into a threaded Callable
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=casts + prefetches)
        tfunc, sdata = self.__make_tfunc(name, body, root, threads)
        pieces.funcs.append(tfunc)

        # Glue together all the IET pieces, including the activation bits
        iet = List(body=[
            BlankLine,
            BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                               sdata[threads.index]), 1)),
            List(header=presents),
            iet,
            self.__make_activate_thread(threads, sdata, sync_ops)
        ])

        # Fire up the threads
        pieces.init.append(
            self.__make_init_threads(threads, sdata, tfunc, pieces))
        pieces.threads.append(threads)

        # Final wait before jumping back to Python land
        pieces.finalize.append(self.__make_finalize_threads(threads, sdata))

        return iet
Example #15
def actions_from_update_memcpy(cluster, clusters, prefix, actions):
    it = prefix[-1]
    d = it.dim
    direction = it.direction

    # Prepare the data to instantiate a PrefetchUpdate SyncOp
    e = cluster.exprs[0]

    size = 1

    function = e.rhs.function
    fetch = e.rhs.indices[d]
    ifetch = fetch.subs(d, d.symbolic_min)
    if direction is Forward:
        fcond = make_cond(cluster.guards.get(d), d, direction, d.symbolic_min)
        pfetch = fetch + 1
        pcond = make_cond(cluster.guards.get(d), d, direction, d + 1)
    else:
        fcond = make_cond(cluster.guards.get(d), d, direction, d.symbolic_max)
        pfetch = fetch - 1
        pcond = make_cond(cluster.guards.get(d), d, direction, d - 1)

    target = e.lhs.function
    tstore0 = e.lhs.indices[d]

    # If fetching into e.g., `ub[sb1]`, we'll need to prefetch into e.g. `ub[sb0]`
    if is_integer(tstore0):
        tstore = tstore0
    else:
        assert tstore0.is_Modulo
        subiters = [md for md in cluster.sub_iterators[d] if md.parent is tstore0.parent]
        osubiters = sorted(subiters, key=lambda i: Vector(i.offset))
        n = osubiters.index(tstore0)
        if direction is Forward:
            tstore = osubiters[(n + 1) % len(osubiters)]
        else:
            tstore = osubiters[(n - 1) % len(osubiters)]

    # Turn `cluster` into a prefetch Cluster
    expr = uxreplace(e, {tstore0: tstore, fetch: pfetch})
    guards = {d: And(*([pcond] + as_list(cluster.guards.get(d))))}
    syncs = {d: [PrefetchUpdate(
        d, size,
        function, fetch, ifetch, fcond,
        pfetch, pcond,
        target, tstore
    )]}
    pcluster = cluster.rebuild(exprs=expr, guards=guards, syncs=syncs)

    # Since we're turning `e` into a prefetch, we need to:
    # 1) attach a WaitPrefetch SyncOp to the first Cluster accessing `target`
    # 2) insert the prefetch Cluster right after the last Cluster accessing `target`
    # 3) drop the original Cluster performing a memcpy-based fetch
    n = clusters.index(cluster)
    first = None
    last = None
    for c in clusters[n+1:]:
        if target in c.scope.reads:
            if first is None:
                first = c
            last = c
    assert first is not None
    assert last is not None
    actions[first].syncs[d].append(WaitPrefetch(
        d, size,
        function, fetch, ifetch, fcond,
        pfetch, pcond,
        target, tstore
    ))
    actions[last].insert.append(pcluster)
    actions[cluster].drop = True
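
The modulo-buffer rotation above picks the slot the prefetch must write into; a self-contained sketch (hypothetical slot names):

subiters = ['sb0', 'sb1']     # modulo sub-iterators, sorted by offset
n = subiters.index('sb1')     # slot the memcpy currently stores into
tstore_fwd = subiters[(n + 1) % len(subiters)]  # next slot, Forward sweep
tstore_bwd = subiters[(n - 1) % len(subiters)]  # previous slot, Backward
print(tstore_fwd, tstore_bwd)  # sb0 sb0
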
Example #16
 def __pfields_setup__(cls, **kwargs):
     fields = as_list(kwargs.get('fields')) + [cls._symbolic_id, cls._symbolic_flag]
     return [(i._C_name, i._C_ctype) for i in fields]
Example #17
 def add_ldflags(self, flags):
     self.ldflags = filter_ordered(self.ldflags + as_list(flags))
Example #18
 def add_libraries(self, libs):
     self.libraries = filter_ordered(self.libraries + as_list(libs))
Example #19
def evalrel(func=min, input=None, assumptions=None):
    """
    The purpose of this function is two-fold: (i) to reduce the `input` candidates
    of a MIN/MAX expression based on the given `assumptions`, and (ii) to return the
    nested MIN/MAX expression over the reduced set of candidates.

    Parameters
    ----------
    func : builtin function or method
        min or max. Defines the operation to simplify. Defaults to `min`.
    input : list
        A list of the candidate symbols to be simplified. Defaults to None.
    assumptions : list
        A list of assumptions formed as relationals between candidates, assumed to be
        True. Defaults to None.

    Examples
    --------
    Assuming no values are known for `a`, `b`, `c`, `d` but we know that `d<=a` and
    `c>=b` we can safely drop `a` and `c` from the candidate list.

    >>> from devito import Symbol
    >>> from sympy import Le, Ge
    >>> a = Symbol('a')
    >>> b = Symbol('b')
    >>> c = Symbol('c')
    >>> d = Symbol('d')
    >>> evalrel(max, [a, b, c, d], [Le(d, a), Ge(c, b)])
    MAX(a, c)
    """
    sfunc = (Min if func is min else Max)  # Choose SymPy's Min/Max

    # Form relationals so that RHS has more arguments than LHS:
    # i.e. (a + d >= b) ---> (b <= a + d)
    assumptions = [
        a.reversed if len(a.lhs.args) > len(a.rhs.args) else a
        for a in as_list(assumptions)
    ]

    # Break-down relations if possible
    processed = []
    for a in as_list(assumptions):
        if isinstance(a, (Ge, Gt)) and a.rhs.is_Add and a.lhs.is_positive:
            if all(i.is_positive for i in a.rhs.args):
                # If `c >= a + b, {a, b, c} >= 0` then add 'c>=a, c>=b'
                processed.extend(Ge(a.lhs, i) for i in a.rhs.args)
            elif len(a.rhs.args) == 2:
                # If `c >= a + b, a>=0, b<=0` then add 'c>=b, c<=a'
                processed.extend(
                    Ge(a.lhs, i) if not i.is_positive else Le(a.lhs, i)
                    for i in a.rhs.args)
            else:
                processed.append(a)
        else:
            processed.append(a)

    # Apply assumptions to fill a subs mapper,
    # e.g. when looking for 'max' and Gt(a, b), the mapper is filled with {b: a}
    # so that `b` is substituted by `a`
    mapper = {}
    for a in processed:
        if set(a.args).issubset(input):
            # If a.args=(a, b) and input=(a, b, c), condition is True,
            # if a.args=(a, d) and input=(a, b, c), condition is False
            assert len(a.args) == 2
            a0, a1 = a.args
            if ((isinstance(a, (Ge, Gt)) and func is max)
                    or (isinstance(a, (Le, Lt)) and func is min)):
                mapper[a1] = a0
            elif ((isinstance(a, (Le, Lt)) and func is max)
                  or (isinstance(a, (Ge, Gt)) and func is min)):
                mapper[a0] = a1

    # Collapse graph paths
    mapper = transitive_closure(mapper)
    input = [i.subs(mapper) for i in input]

    # Explore simplification opportunities that may have emerged and generate
    # the final MIN/MAX expression
    try:
        exp = sfunc(*input)  # Can it be evaluated or simplified?
        if exp.is_Function:
            # Use the new args only if evaluation managed to reduce the number
            # of candidates.
            input = min(input, exp.args, key=len)
        else:
            # Since we are here, exp is a simplified expression
            return exp
    except TypeError:
        pass
    return rfunc(func, *input)
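
The graph-collapse step relies on chasing substitution chains to a fixed point; a simplified, acyclic-only stand-in for devito's transitive_closure:

def transitive_closure_demo(mapper):
    # Assumes the mapper encodes an acyclic substitution graph
    out = {}
    for k, v in mapper.items():
        while v in mapper:
            v = mapper[v]
        out[k] = v
    return out

# b -> a and a -> c collapse to b -> c (and a -> c)
print(transitive_closure_demo({'b': 'a', 'a': 'c'}))  # {'b': 'c', 'a': 'c'}
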
Example #20
    def _make_fetchprefetch(self, iet, sync_ops, pieces, root):
        fid = SharedData._field_id

        fetches = []
        prefetches = []
        presents = []
        for s in sync_ops:
            f = s.function
            dimensions = s.dimensions
            fc = s.fetch
            ifc = s.ifetch
            pfc = s.pfetch
            fcond = s.fcond
            pcond = s.pcond

            # Construct init IET
            imask = [(ifc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            fetch = PragmaList(self.lang._map_to(f, imask), f,
                               ifc.free_symbols | {f.indexed})
            fetches.append(Conditional(fcond, fetch))

            # Construct present clauses
            imask = [(fc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            presents.extend(as_list(self.lang._map_present(f, imask)))

            # Construct prefetch IET
            imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                     for d in dimensions]
            prefetch = PragmaList(self.lang._map_to_wait(f, imask, fid), f,
                                  pfc.free_symbols | {f.indexed})
            prefetches.append(Conditional(pcond, prefetch))

        # Turn init IET into a Callable
        functions = filter_ordered(s.function for s in sync_ops)
        name = self.sregistry.make_name(prefix='init_device')
        body = List(body=fetches)
        parameters = filter_sorted(functions + derive_parameters(body))
        func = Callable(name, body, 'void', parameters, 'static')
        pieces.funcs.append(func)

        # Perform initial fetch by the main thread
        pieces.init.append(
            List(header=c.Comment("Initialize data stream"),
                 body=[Call(name, parameters), BlankLine]))

        # Turn prefetch IET into a ThreadFunction
        name = self.sregistry.make_name(prefix='prefetch_host_to_device')
        body = List(header=c.Line(), body=prefetches)
        tctx = make_thread_ctx(name, body, root, None, sync_ops,
                               self.sregistry)
        pieces.funcs.extend(tctx.funcs)

        # Glue together all the IET pieces, including the activation logic
        sdata = tctx.sdata
        threads = tctx.threads
        iet = List(body=[
            BlankLine,
            BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                               sdata[threads.index]), 1)),
            List(header=presents),
            iet,
            tctx.activate
        ])

        # Fire up the threads
        pieces.init.append(tctx.init)

        # Final wait before jumping back to Python land
        pieces.finalize.append(tctx.finalize)

        # Keep track of created objects
        pieces.objs.add(sync_ops, sdata, threads)

        return iet