Example #1
    def _dump_storage(self, iet, storage):
        mapper = {}
        for k, v in storage.items():
            # Expr -> LocalExpr ?
            if k.is_Expression:
                mapper[k] = v
                continue

            # allocs/pallocs
            allocs = flatten(v.allocs)
            for tid, body in as_mapper(v.pallocs, itemgetter(0), itemgetter(1)).items():
                header = self.lang.Region._make_header(tid.symbolic_size)
                init = c.Initializer(c.Value(tid._C_typedata, tid.name),
                                     self.lang['thread-num'])
                allocs.append(c.Module((header, c.Block([init] + body))))
            if allocs:
                allocs.append(c.Line())

            # frees/pfrees
            frees = []
            for tid, body in as_mapper(v.pfrees, itemgetter(0), itemgetter(1)).items():
                header = self.lang.Region._make_header(tid.symbolic_size)
                init = c.Initializer(c.Value(tid._C_typedata, tid.name),
                                     self.lang['thread-num'])
                frees.append(c.Module((header, c.Block([init] + body))))
            frees.extend(flatten(v.frees))
            if frees:
                frees.insert(0, c.Line())

            mapper[k] = k._rebuild(body=List(header=allocs, body=k.body, footer=frees),
                                   **k.args_frozen)

        processed = Transformer(mapper, nested=True).visit(iet)

        return processed
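All of the snippets on this page revolve around devito's `as_mapper` utility. As a rough mental model (a minimal sketch of the assumed semantics, not devito's implementation), it groups an iterable into a dict of lists, optionally transforming keys and values; this is how the `(tid, body)` pairs in the pallocs/pfrees loops above get bucketed per thread id:

from operator import itemgetter

def as_mapper_sketch(iterable, key, value=lambda i: i):
    # Assumed behavior: {key(i): [value(i), ...] for each i in iterable}
    mapper = {}
    for i in iterable:
        mapper.setdefault(key(i), []).append(value(i))
    return mapper

# Hypothetical (tid, body) pairs, as consumed by the pallocs loop above
pallocs = [('tid0', 'allocA'), ('tid1', 'allocB'), ('tid0', 'allocC')]
print(as_mapper_sketch(pallocs, itemgetter(0), itemgetter(1)))
# {'tid0': ['allocA', 'allocC'], 'tid1': ['allocB']}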
Example #2
    def _inject_definitions(self, iet, storage):
        mapper = {}
        for k, v in storage.items():
            # Expr -> LocalExpr ?
            if k.is_Expression:
                mapper[k] = v
                continue

            # objects
            objs = flatten(v.objs)

            # allocs/pallocs
            allocs = flatten(v.allocs)
            for tid, body in as_mapper(v.pallocs, itemgetter(0),
                                       itemgetter(1)).items():
                header = self.lang.Region._make_header(tid.symbolic_size)
                init = self.lang['thread-num'](retobj=tid)
                allocs.append(Block(header=header, body=[init] + body))

            # frees/pfrees
            frees = []
            for tid, body in as_mapper(v.pfrees, itemgetter(0),
                                       itemgetter(1)).items():
                header = self.lang.Region._make_header(tid.symbolic_size)
                init = self.lang['thread-num'](retobj=tid)
                frees.append(Block(header=header, body=[init] + body))
            frees.extend(flatten(v.frees))

            mapper[k.body] = k.body._rebuild(allocs=allocs,
                                             objs=objs,
                                             frees=frees)

        processed = Transformer(mapper, nested=True).visit(iet)

        return processed
Example #3
def _lower_stepping_dims(iet):
    """
    Lower SteppingDimensions: index functions involving SteppingDimensions are
    turned into ModuloDimensions.

    Examples
    --------
    u[t+1, x] = u[t, x] + 1

    becomes

    u[t1, x] = u[t0, x] + 1
    """
    for i in FindNodes(Iteration).visit(iet):
        if not i.uindices:
            # Be quick: avoid uselessly reconstructing nodes
            continue
        # In an expression, there could be `u[t+1, ...]` and `v[t+1, ...]`, where
        # `u` and `v` are TimeFunctions with circular time buffers (save=None) *but*
        # different modulo extents. The `t+1` indices above are therefore conceptually
        # different, so they will be replaced with the proper ModuloDimension through
        # two different calls to `xreplace`
        mindices = [d for d in i.uindices if d.is_Modulo]
        groups = as_mapper(mindices, lambda d: d.modulo)
        for k, v in groups.items():
            mapper = {d.origin: d for d in v}
            rule = lambda i: i.function.is_TimeFunction and i.function._time_size == k
            replacer = lambda i: xreplace_indices(i, mapper, rule)
            iet = XSubs(replacer=replacer).visit(iet)

    return iet
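The substitution driven by `{d.origin: d}` is plain structural replacement. A small SymPy sketch, assuming each ModuloDimension records the index expression it originates from (`t0` and `t1` below are hypothetical stand-ins for ModuloDimensions with modulo 2):

from sympy import IndexedBase, symbols

t, x = symbols('t x')
t0, t1 = symbols('t0 t1')     # stand-ins for ModuloDimensions
mapper = {t: t0, t + 1: t1}   # the {d.origin: d} table built above

u = IndexedBase('u')
expr = u[t + 1, x] + u[t, x]
print(expr.xreplace(mapper))  # u[t0, x] + u[t1, x]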
Example #4
def _(expr, terms):
    derivs, others = split(terms, lambda i: i.deriv is not None)
    if not derivs:
        return expr, Term(expr)

    # Map by type of derivative
    mapper = as_mapper(derivs, lambda i: key(i.deriv))
    if len(mapper) == len(derivs):
        return expr, Term(expr)

    processed = []
    for v in mapper.values():
        fact, nonfact = split(v, lambda i: _is_const_coeff(i.other, i.deriv))
        if fact:
            # Finally factorize derivative arguments
            func = fact[0].deriv._new_from_self
            exprs = []
            for i in fact:
                if i.func:
                    exprs.append(i.func(i.other, i.deriv.expr))
                else:
                    assert i.other == 1
                    exprs.append(i.deriv.expr)
            fact = [Term(func(expr=expr.func(*exprs)))]

        for i in fact + nonfact:
            if i.func:
                processed.append(i.func(i.other, i.deriv))
            else:
                processed.append(i.other)

    others = [i.other for i in others]
    expr = expr.func(*(processed + others))

    return expr, Term(expr)
Example #5
File: aliases.py Project: ofmla/devito
    def callback(self, clusters, prefix, xtracted=None):
        if not prefix:
            return clusters
        # `xtracted` defaults to None but is membership-tested and extended below
        xtracted = [] if xtracted is None else xtracted
        d = prefix[-1].dim

        # Rule out extractions that would break data dependencies
        exclude = set().union(*[c.scope.writes for c in clusters])

        # Rule out extractions that depend on the Dimension currently investigated,
        # as they clearly wouldn't be invariants
        exclude.add(d)

        key = lambda c: self._lookup_key(c, d)
        processed = list(clusters)
        for ak, group in as_mapper(clusters, key=key).items():
            g = [c for c in group if c.is_dense and c not in xtracted]
            if not g:
                continue

            made = self._aliases_from_clusters(g, exclude, ak)

            if made:
                for n, c in enumerate(g, -len(g)):
                    processed[processed.index(c)] = made.pop(n)
                processed = made + processed

                xtracted.extend(made)

        return processed
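The `enumerate(g, -len(g))` idiom above is worth unpacking: starting the counter at `-len(g)` makes `n` index the last `len(g)` entries of `made`, which, by assumption here, hold the rebuilt counterparts of the clusters in `g`. A toy run with strings in place of Clusters:

g = ['c0', 'c1']                      # clusters to be replaced
made = ['extra', 'c0_new', 'c1_new']  # rebuilt clusters assumed at the tail
processed = ['c0', 'c1', 'c2']

for n, c in enumerate(g, -len(g)):    # n = -2, -1
    processed[processed.index(c)] = made.pop(n)
processed = made + processed

print(processed)  # ['extra', 'c0_new', 'c1_new', 'c2']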
Example #6
def _drop_halospots(iet):
    """
    Remove HaloSpots that:

        * Embed SEQUENTIAL Iterations
        * Would be used to compute Increments (in which case, a halo exchange
          is actually unnecessary)
    """
    mapper = defaultdict(set)

    # If a HaloSpot Dimension turns out to be SEQUENTIAL, then the HaloSpot is useless
    for hs, iterations in MapNodes(HaloSpot, Iteration).visit(iet).items():
        dmapper = as_mapper(iterations, lambda i: i.dim.root)
        for d, v in dmapper.items():
            if d in hs.dimensions and all(i.is_Sequential for i in v):
                mapper[hs].update(set(hs.functions))
                break

    # If all HaloSpot reads pertain to increments, then the HaloSpot is useless
    for hs, expressions in MapNodes(HaloSpot, Expression).visit(iet).items():
        for f in hs.fmapper:
            scope = Scope([i.expr for i in expressions])
            if all(i.is_increment for i in scope.reads.get(f, [])):
                mapper[hs].add(f)

    # Transform the IET introducing the "reduced" HaloSpots
    subs = {
        hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs]))
        for hs in FindNodes(HaloSpot).visit(iet)
    }
    iet = Transformer(subs, nested=True).visit(iet)

    return iet
Example #7
def iet_lower_dimensions(iet):
    """
    Replace all DerivedDimensions within the ``iet``'s expressions with
    lower-level symbolic objects (other Dimensions or Symbols).

        * Array indices involving SteppingDimensions are turned into ModuloDimensions.
          Example: ``u[t+1, x] = u[t, x] + 1 >>> u[t1, x] = u[t0, x] + 1``
        * Array indices involving ConditionalDimensions are turned into
          integer-division expressions.
          Example: ``u[t_sub, x] = u[time, x] >>> u[time / 4, x] = u[time, x]``
    """
    # Lower SteppingDimensions
    for i in FindNodes(Iteration).visit(iet):
        if not i.uindices:
            # Be quick: avoid uselessly reconstructing nodes
            continue
        # In an expression, there could be `u[t+1, ...]` and `v[t+1, ...]`, where
        # `u` and `v` are TimeFunctions with circular time buffers (save=None) *but*
        # different modulo extents. The `t+1` indices above are therefore conceptually
        # different, so they will be replaced with the proper ModuloDimension through
        # two different calls to `xreplace`
        groups = as_mapper(i.uindices, lambda d: d.modulo)
        for k, v in groups.items():
            mapper = {d.origin: d for d in v}
            rule = lambda i: i.function.is_TimeFunction and i.function._time_size == k
            replacer = lambda i: xreplace_indices(i, mapper, rule)
            iet = XSubs(replacer=replacer).visit(iet)

    # Lower ConditionalDimensions
    cdims = [d for d in FindSymbols('free-symbols').visit(iet)
             if isinstance(d, ConditionalDimension)]
    mapper = {d: IntDiv(d.index, d.factor) for d in cdims}
    iet = XSubs(mapper).visit(iet)

    return iet
Example #8
 def _eval_at(self, func):
     """
     Evaluates the derivative at the location of `func`. It is necessary for staggered
     setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
     has to be computed at x=x + h_x/2.
     """
     # If an x0 already exists do not overwrite it
     x0 = self.x0 or dict(func.indices_ref._getters)
     if self.expr.is_Add:
         # If `expr` has both staggered and non-staggered terms such as
         # `(u(x + h_x/2) + v(x)).dx` then we exploit linearity of FD to split
         # it into `u(x + h_x/2).dx` and `v(x).dx`, since they require
         # different FD indices
         mapper = as_mapper(self.expr._args_diff, lambda i: i.staggered)
         args = [self.expr.func(*v) for v in mapper.values()]
         args.extend(
             [a for a in self.expr.args if a not in self.expr._args_diff])
         args = [self._new_from_self(expr=a, x0=x0) for a in args]
         return self.expr.func(*args)
      elif self.expr.is_Mul:
          # For Mul, we treat the basic case `u(x + h_x/2) * v(x)`, which is what
          # appears in most equations, e.g. with div(a * u). The expression is
          # re-centered at the highest-priority index (see _gather_for_diff) to
          # compute the derivative at x0.
          return self._new_from_self(x0=x0, expr=self.expr._gather_for_diff)
      else:
          # For all other cases, with more functions or more complex arithmetic,
          # there is no reliable way to decide what to do, so it is safest to use
          # the expression as is.
          return self._new_from_self(x0=x0)
Example #9
def hs_comp_locindices(f, dims, ispace, scope):
    """
    Map the Dimensions in ``dims`` to the local indices necessary
    to perform a halo exchange, as described in HaloScheme.__doc__.

    Examples
    --------
    1) u[t+1, x] = f(u[t, x])   => shift == 1
    2) u[t-1, x] = f(u[t, x])   => shift == 1
    3) u[t+1, x] = f(u[t+1, x]) => shift == 0
    In the first and second cases, the x-halo should be inserted at `t`,
    while in the last case it should be inserted at `t+1`.
    """
    loc_indices = {}
    for d in dims:
        func = max if ispace.is_forward(d.root) else min
        loc_index = func([i[d] for i in scope.getreads(f)],
                         key=lambda i: i - d)
        if d.is_Stepping:
            subiters = ispace.sub_iterators.get(d.root, [])
            submap = as_mapper(subiters, lambda md: md.modulo)
            submap = {i.origin: i for i in submap[f._time_size]}
            try:
                loc_indices[d] = submap[loc_index]
            except KeyError:
                raise HaloSchemeException(
                    "Don't know how to build a HaloScheme as the "
                    "stepping index `%s` is undefined" % loc_index)
        else:
            loc_indices[d] = loc_index
    return loc_indices
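The `max`/`min` with `key=lambda i: i - d` picks the access furthest along the iteration direction; the difference with `d` reduces to a plain integer when the indices are affine in `d`. A SymPy sketch:

from sympy import Symbol

t = Symbol('t')
reads = [t - 1, t, t + 1]  # hypothetical read indices along `t`
print(max(reads, key=lambda i: i - t))  # t + 1  (forward direction)
print(min(reads, key=lambda i: i - t))  # t - 1  (backward direction)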
Example #10
def cross_cluster_cse(clusters):
    """
    Perform common sub-expressions elimination across an iterable of Clusters.
    """
    clusters = clusters.unfreeze()

    # Detect redundancies
    mapper = {}
    for c in clusters:
        candidates = [i for i in c.trace.values() if i.is_unbound_temporary]
        for v in as_mapper(candidates, lambda i: i.rhs).values():
            for i in v[:-1]:
                mapper[i.lhs.base] = v[-1].lhs.base

    if not mapper:
        # Do not waste time reconstructing identical expressions
        return clusters

    # Apply substitutions
    for c in clusters:
        c.exprs = [
            i.xreplace(mapper) for i in c.trace.values()
            if i.lhs.base not in mapper
        ]

    return clusters
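The redundancy detection boils down to grouping unbound temporaries by their right-hand side and redirecting all but the last onto the survivor. A self-contained sketch with hypothetical stand-ins for the temporaries:

from collections import namedtuple

Temp = namedtuple('Temp', 'lhs rhs')
candidates = [Temp('r0', 'a + b'), Temp('r1', 'c'), Temp('r2', 'a + b')]

groups = {}
for i in candidates:
    groups.setdefault(i.rhs, []).append(i)

mapper = {}
for v in groups.values():
    for i in v[:-1]:
        mapper[i.lhs] = v[-1].lhs  # `lhs` stands in for `lhs.base`

print(mapper)  # {'r0': 'r2'}: uses of r0 get rewritten as r2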
Example #11
File: operator.py Project: yuriyi/devito
    def _apply_substitutions(self, expressions, subs):
        """
        Transform ``expressions`` by: ::

            * Applying any user-provided symbolic substitution;
            * Replacing :class:`Dimension`s with :class:`SubDimension`s based
              on the expression :class:`Region`.
        """
        domain_subs = subs
        interior_subs = subs.copy()

        processed = []
        mapper = as_mapper(expressions, lambda i: i._region)
        for k, v in mapper.items():
            for e in v:
                if k is INTERIOR:
                    # Introduce SubDimensions to iterate over the INTERIOR region only
                    candidates = [
                        i for i in e.free_symbols
                        if isinstance(i, Dimension) and i.is_Space
                    ]
                    interior_subs.update({
                        i: SubDimension.middle("%si" % i, i, 1, 1)
                        for i in candidates if i not in interior_subs
                    })
                    processed.append(e.xreplace(interior_subs))
                elif k is DOMAIN:
                    processed.append(e.xreplace(domain_subs))
                else:
                    raise ValueError("Unsupported Region `%s`" % k)

        return processed
Example #12
    def process(self, iet):
        sync_spots = FindNodes(SyncSpot).visit(iet)
        if not sync_spots:
            return iet, {}

        def key(s):
            # The SyncOps are to be processed in the following order
            return [
                WaitLock, WithLock, Delete, FetchUpdate, FetchPrefetch,
                PrefetchUpdate, WaitPrefetch
            ].index(s)

        callbacks = {
            WaitLock: self._make_waitlock,
            WithLock: self._make_withlock,
            Delete: self._make_delete,
            FetchUpdate: self._make_fetchupdate,
            FetchPrefetch: self._make_fetchprefetch,
            PrefetchUpdate: self._make_prefetchupdate
        }
        postponed_callbacks = {WaitPrefetch: self._make_waitprefetch}
        all_callbacks = [callbacks, postponed_callbacks]

        pieces = namedtuple('Pieces', 'init finalize funcs objs')([], [], [],
                                                                  Objs())

        # The processing is a two-step procedure; first, we apply the `callbacks`;
        # then, the `postponed_callbacks`, as these depend on objects produced by the
        # `callbacks`
        subs = {}
        for cbks in all_callbacks:
            for n in sync_spots:
                mapper = as_mapper(n.sync_ops, lambda i: type(i))
                for _type in sorted(mapper, key=key):
                    try:
                        subs[n] = cbks[_type](subs.get(n, n), mapper[_type],
                                              pieces, iet)
                    except KeyError:
                        pass

        iet = Transformer(subs).visit(iet)

        # Add initialization and finalization code
        init = List(body=pieces.init, footer=c.Line())
        finalize = List(header=c.Line(), body=pieces.finalize)
        body = iet.body._rebuild(body=(init, ) + iet.body.body + (finalize, ))
        iet = iet._rebuild(body=body)

        return iet, {
            'efuncs': pieces.funcs,
            'includes': ['pthread.h'],
            'args':
            [i.size for i in pieces.objs.threads if not is_integer(i.size)]
        }
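The `key` function imposes a fixed processing order on the SyncOp types by their position in a priority list. A stripped-down sketch of that dispatch (the classes below are empty stand-ins):

class WaitLock: pass
class WithLock: pass
class Delete: pass

priority = [WaitLock, WithLock, Delete]
key = lambda t: priority.index(t)

sync_ops = [Delete(), WaitLock(), Delete(), WithLock()]
mapper = {}
for op in sync_ops:          # group ops by their type, as as_mapper does
    mapper.setdefault(type(op), []).append(op)

for _type in sorted(mapper, key=key):
    print(_type.__name__, len(mapper[_type]))
# WaitLock 1
# WithLock 1
# Delete 2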
Example #13
def iet_lower_steppers(iet):
    """
    Replace the :class:`SteppingDimension`s within ``iet``'s expressions with
    suitable :class:`ModuloDimension`s.
    """
    for i in FindNodes(Iteration).visit(iet):
        if not i.uindices:
            # Be quick: avoid uselessly reconstructing nodes
            continue
        # In an expression, there could be `u[t+1, ...]` and `v[t+1, ...]`, where
        # `u` and `v` are TimeFunctions with circular time buffers (save=None) *but*
        # different modulo extents. The `t+1` indices above are therefore conceptually
        # different, so they will be replaced with the proper ModuloDimension through
        # two different calls to `xreplace`
        groups = as_mapper(i.uindices, lambda d: d.modulo)
        for k, v in groups.items():
            mapper = {d.origin: d for d in v}
            rule = lambda i: i.function._time_size == k
            iet = ReplaceStepIndices(mapper, rule).visit(iet)
    return iet
Example #14
File: analysis.py Project: yuriyi/devito
def mark_halospots(analysis):
    """Update the ``analysis`` detecting the ``REDUNDANT`` HaloSpots within
    ``analysis.iet``."""
    properties = OrderedDict()

    def analyze(fmapper, scope):
        for f, hse in fmapper.items():
            if any(dep.cause & set(hse.loc_indices) for dep in scope.d_anti.project(f)):
                return False
        return True

    for i, scope in analysis.scopes.items():
        mapper = as_mapper(FindNodes(HaloSpot).visit(i), lambda hs: hs.halo_scheme)
        for k, v in mapper.items():
            if len(v) == 1:
                continue
            if analyze(k.fmapper, scope):
                properties.update({i: REDUNDANT for i in v[1:]})

    analysis.update(properties)
Example #15
    def process(self, iet):
        def key(s):
            # The SyncOps are to be processed in the following order
            return [WaitLock, WithLock, Delete, FetchWait,
                    FetchWaitPrefetch].index(s)

        callbacks = {
            WaitLock: self._make_waitlock,
            WithLock: self._make_withlock,
            FetchWait: self._make_fetchwait,
            FetchWaitPrefetch: self._make_fetchwaitprefetch,
            Delete: self._make_delete
        }

        sync_spots = FindNodes(SyncSpot).visit(iet)

        if not sync_spots:
            return iet, {}

        pieces = namedtuple('Pieces', 'init finalize funcs threads')([], [],
                                                                     [], [])

        subs = {}
        for n in sync_spots:
            mapper = as_mapper(n.sync_ops, lambda i: type(i))
            for _type in sorted(mapper, key=key):
                subs[n] = callbacks[_type](subs.get(n, n), mapper[_type],
                                           pieces, iet)

        iet = Transformer(subs).visit(iet)

        # Add initialization and finalization code
        init = List(body=pieces.init, footer=c.Line())
        finalize = List(header=c.Line(), body=pieces.finalize)
        iet = iet._rebuild(body=(init, ) + iet.body + (finalize, ))

        return iet, {
            'efuncs': pieces.funcs,
            'includes': ['pthread.h'],
            'args': [i.size for i in pieces.threads if not is_integer(i.size)]
        }
Example #16
File: aliases.py Project: ofmla/devito
    def _generate(self, exprs, exclude):
        # E.g., extract `u.dx*a*b` and `u.dx*a*c` from `[(u.dx*a*b).dy, (u.dx*a*c).dy]`
        def cbk_search(expr):
            if isinstance(expr, EvalDerivative) and not expr.base.is_Function:
                return expr.args
            else:
                return flatten(e for e in [cbk_search(a) for a in expr.args]
                               if e)

        cbk_compose = lambda e: split_coeff(e)[1]
        basextr = self._do_generate(exprs, exclude, cbk_search, cbk_compose)
        if not basextr:
            return
        yield basextr

        # E.g., extract `u.dx*a` from `[(u.dx*a*b).dy, (u.dx*a*c).dy]`
        # That is, attempt extracting the largest common derivative-induced subexprs
        mappers = [deindexify(e) for e in basextr.extracted]
        counter = Counter(flatten(m.keys() for m in mappers))
        groups = as_mapper(counter, key=counter.get)
        grank = {
            k: sorted(v, key=lambda e: estimate_cost(e), reverse=True)
            for k, v in groups.items()
        }

        def cbk_search2(expr, rank):
            ret = []
            for e in cbk_search(expr):
                mapper = deindexify(e)
                for i in rank:
                    if i in mapper:
                        ret.extend(mapper[i])
                        break
            return ret

        candidates = sorted(grank, reverse=True)[:2]
        for i in candidates:
            lower_pri_elems = flatten([grank[j] for j in candidates if j != i])
            cbk_search_i = lambda e: cbk_search2(e, grank[i] + lower_pri_elems)
            yield self._do_generate(exprs, exclude, cbk_search_i, cbk_compose)
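`as_mapper(counter, key=counter.get)` iterates over the Counter's keys and buckets them by occurrence count, so the most frequently shared subexpressions can be ranked first. A minimal sketch with hypothetical subexpression labels:

from collections import Counter

counter = Counter(['u.dx*a', 'u.dx*a', 'u.dx*b'])

groups = {}
for k in counter:  # keys grouped by their count
    groups.setdefault(counter[k], []).append(k)

print(groups)                        # {2: ['u.dx*a'], 1: ['u.dx*b']}
print(sorted(groups, reverse=True))  # [2, 1] -> most shared first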
Example #17
def cross_cluster_cse(clusters):
    """Apply CSE across an iterable of Clusters."""
    clusters = clusters.unfreeze()

    # Detect redundancies
    mapper = {}
    for c in clusters:
        candidates = [i for i in c.trace.values() if i.is_unbound_temporary]
        for v in as_mapper(candidates, lambda i: i.rhs).values():
            for i in v[:-1]:
                mapper[i.lhs.base] = v[-1].lhs.base

    if not mapper:
        # Do not waste time reconstructing identical expressions
        return clusters

    # Apply substitutions
    for c in clusters:
        c.exprs = [i.xreplace(mapper) for i in c.trace.values()
                   if i.lhs.base not in mapper]

    return clusters
Example #18
def _(expr):
    args = [factorize_derivatives(a) for a in expr.args]

    derivs, others = split(args, lambda a: isinstance(a, sympy.Derivative))
    if not derivs:
        return expr

    # Map by type of derivative
    # Note: `D0(a) + D1(b) == D(a + b)` <=> `D0` and `D1`'s metadata match,
    # i.e. they are the same type of derivative
    mapper = as_mapper(derivs, lambda i: i._metadata)
    if len(mapper) == len(derivs):
        return expr

    args = list(others)
    for v in mapper.values():
        c = v[0]
        if len(v) == 1:
            args.append(c)
        else:
            args.append(c._new_from_self(expr=expr.func(*[i.expr for i in v])))
    expr = expr.func(*args)

    return expr
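The grouping is justified by linearity: derivatives with matching metadata satisfy `D(a) + D(b) == D(a + b)`. A quick SymPy check of that identity:

from sympy import Derivative, Function, symbols

x = symbols('x')
f, g = Function('f')(x), Function('g')(x)

lhs = Derivative(f, x) + Derivative(g, x)
rhs = Derivative(f + g, x)
print((lhs - rhs).doit())  # 0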
Example #19
    def callback(self, clusters, prefix):
        if not prefix:
            return clusters

        d = prefix[-1].dim

        subiters = flatten(
            [c.ispace.sub_iterators.get(d, []) for c in clusters])
        subiters = {i for i in subiters if i.is_Stepping}
        if not subiters:
            return clusters

        # Collect the index access functions along `d`, e.g., `t + 1` where `t` is
        # a SteppingDimension for `d = time`
        mapper = DefaultOrderedDict(lambda: DefaultOrderedDict(set))
        for c in clusters:
            indexeds = [
                a.indexed for a in c.scope.accesses if a.function.is_Tensor
            ]

            for i in indexeds:
                try:
                    iaf = i.indices[d]
                except KeyError:
                    continue

                # Sanity checks
                sis = iaf.free_symbols & subiters
                if len(sis) == 0:
                    continue
                elif len(sis) == 1:
                    si = sis.pop()
                else:
                    raise InvalidOperator(
                        "Cannot use multiple SteppingDimensions "
                        "to index into a Function")
                size = i.function.shape_allocated[d]
                assert is_integer(size)

                mapper[size][si].add(iaf)

        # Construct the ModuloDimensions
        mds = []
        for size, v in mapper.items():
            for si, iafs in list(v.items()):
                # Offsets are sorted so that the semantic order (t0, t1, t2) follows
                # SymPy's index ordering (t, t-1, t+1) after modulo replacement, so
                # that associativity errors are consistent. This corresponds to
                # sorting the offsets {-1, 0, 1} as {0, -1, 1} by assigning -inf to 0
                siafs = sorted(iafs,
                               key=lambda i: -np.inf if i - si == 0 else (i - si))

                for iaf in siafs:
                    name = '%s%d' % (si.name, len(mds))
                    offset = uxreplace(iaf, {si: d.root})
                    mds.append(
                        ModuloDimension(name, si, offset, size, origin=iaf))

        # Replacement rule for ModuloDimensions
        def rule(size, e):
            try:
                return e.function.shape_allocated[d] == size
            except (AttributeError, KeyError):
                return False

        # Reconstruct the Clusters
        processed = []
        for c in clusters:
            # Apply substitutions to expressions
            # Note: In an expression, there could be `u[t+1, ...]` and `v[t+1,
            # ...]`, where `u` and `v` are TimeFunction with circular time
            # buffers (save=None) *but* different modulo extent. The `t+1`
            # indices above are therefore conceptually different, so they will
            # be replaced with the proper ModuloDimension through two different
            # calls to `xreplace_indices`
            exprs = c.exprs
            groups = as_mapper(mds, lambda md: md.modulo)
            for size, v in groups.items():
                mapper = {md.origin: md for md in v}

                func = partial(xreplace_indices,
                               mapper=mapper,
                               key=partial(rule, size))
                exprs = [e.apply(func) for e in exprs]

            # Augment IterationSpace
            ispace = IterationSpace(c.ispace.intervals,
                                    {**c.ispace.sub_iterators, **{d: tuple(mds)}},
                                    c.ispace.directions)

            processed.append(c.rebuild(exprs=exprs, ispace=ispace))

        return processed
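The offset ordering described in the comment above is easy to see in isolation: mapping the zero offset to `-inf` sorts `{-1, 0, 1}` as `{0, -1, 1}`, i.e. `t` first, then `t-1`, then `t+1`:

import numpy as np

offsets = [-1, 0, 1]  # i - si for accesses u[t-1], u[t], u[t+1]
print(sorted(offsets, key=lambda o: -np.inf if o == 0 else o))  # [0, -1, 1]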
Example #20
File: utils.py Project: tccw/devito
    def __init__(self, exprs):
        self._mapper = {}
        self._fixed = {}

        # What Functions actually need a halo exchange?
        need_halo = as_mapper(Scope(exprs).d_all, lambda i: i.function)
        need_halo = {k: v for k, v in need_halo.items() if k.is_TensorFunction}

        for i in exprs:
            for f, v in i.dspace.parts.items():
                if f not in need_halo:
                    continue
                if f.grid is None:
                    raise RuntimeError("`%s` needs a `Grid` for a HaloScheme" %
                                       f.name)
                for d in f.dimensions:
                    r = d.root
                    if v[r].is_Null:
                        continue
                    elif d in f.grid.distributor.dimensions:
                        # Found a distributed dimension, calculate what and how
                        # much halo is needed
                        lsize = f._offset_domain[d].left - v[r].lower
                        if lsize > 0:
                            self._mapper.setdefault(f, []).append(
                                (d, LEFT, lsize))
                        rsize = v[r].upper - f._offset_domain[d].right
                        if rsize > 0:
                            self._mapper.setdefault(f, []).append(
                                (d, RIGHT, rsize))
                    else:
                        # Found a serial dimension, we need to determine where,
                        # along this dimension, the halo will have to be placed
                        fixed = self._fixed.setdefault(f, OrderedDict())
                        shift = int(any(d in dep.cause
                                        for dep in need_halo[f]))
                        # Examples:
                        # u[t+1, x] = f(u[t, x])   => shift == 1
                        # u[t-1, x] = f(u[t, x])   => shift == 1
                        # u[t+1, x] = f(u[t+1, x]) => shift == 0
                        # In the first and second cases, the x-halo should be inserted
                        # at `t`, while in the last case it should be inserted at `t+1`
                        if i.ispace.directions[r] is Forward:
                            last = v[r].upper - shift
                        else:
                            last = v[r].lower + shift
                        if d.is_Stepping:
                            # TimeFunctions using modulo-buffered iteration require
                            # special handling
                            subiters = i.ispace.sub_iterators.get(r, [])
                            submap = as_mapper(subiters, lambda md: md.modulo)
                            submap = {
                                i.origin: i
                                for i in submap[f._time_size]
                            }
                            try:
                                handle = submap[d + last]
                            except KeyError:
                                raise HaloSchemeException
                        else:
                            handle = r + last
                        if handle is not None and handle != fixed.get(
                                d, handle):
                            raise HaloSchemeException
                        fixed[d] = handle or fixed.get(d)
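The `shift` computation encodes the rule spelled out in the comment: the halo is placed one step behind the last iteration whenever some dependence on `f` is caused by the dimension at hand. A toy version with plain sets standing in for `dep.cause`:

causes = [{'t'}, {'x'}]  # hypothetical dep.cause sets for deps on `f`
d = 't'
shift = int(any(d in cause for cause in causes))
print(shift)  # 1 -> as in u[t+1, x] = f(u[t, x]): halo inserted at `t`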