Ejemplo n.º 1
0
def st_make_halo(stree):
    """
    Add :class:`NodeHalo`s to a :class:`ScheduleTree`. A NodeHalo captures
    the halo exchanges that should take place before executing the sub-tree;
    these are described by means of a :class:`HaloScheme`.

    Parameters
    ----------
    stree : ScheduleTree
        The schedule tree to be decorated. NodeHalos are inserted into it
        via ``insert``.

    Returns
    -------
    ScheduleTree
        The same ``stree`` object that was passed in.

    Raises
    ------
    RuntimeError
        If a HaloScheme cannot be built for some expression bundle while
        ``configuration['mpi']`` is truthy. With MPI disabled, such bundles
        are skipped and receive no HaloScheme.
    """
    # Build a HaloScheme for each expression bundle
    halo_schemes = {}
    for n in findall(stree, lambda i: i.is_Exprs):
        try:
            halo_schemes[n] = HaloScheme(n.exprs, n.ispace, n.dspace)
        except HaloSchemeException as e:
            # The failure is only fatal when MPI is enabled; chain the cause
            # so the original traceback is not lost
            if configuration['mpi']:
                raise RuntimeError(str(e)) from e

    # Insert the HaloScheme at a suitable level in the ScheduleTree
    mapper = {}
    for k, hs in halo_schemes.items():
        # The Iteration ancestors depend only on the bundle `k`, not on the
        # Function being processed, so compute them once per bundle
        ancestors = [n for n in k.ancestors if n.is_Iteration]
        for f, v in hs.fmapper.items():
            # Root Dimensions of the Function's loc_indices; invariant
            # across the ancestor scan below
            loc_roots = [i.root for i in v.loc_indices]
            spot = k
            for n in ancestors:
                # Stop at the first Iteration ancestor (in `k.ancestors`
                # order) that iterates over one of the halo Dimensions, or
                # whose Dimension is not a loc_indices root
                in_halo = any(n.dim is i.dim for i in v.halos)
                if in_halo or n.dim not in loc_roots:
                    spot = n
                    break
            mapper.setdefault(spot, []).append((f, v))

    # Merge all the (Function, halo info) entries sharing the same spot into
    # a single HaloScheme and insert it there
    for spot, entries in mapper.items():
        insert(NodeHalo(HaloScheme(fmapper=dict(entries))), spot.parent,
               [spot])

    return stree
Ejemplo n.º 2
0
def stree_make_halo(stree):
    """
    Add NodeHalos to a ScheduleTree. A NodeHalo captures the halo exchanges
    that should take place before executing the sub-tree; these are described
    by means of a HaloScheme.

    Parameters
    ----------
    stree : ScheduleTree
        The schedule tree to be decorated. NodeHalos are inserted into it
        via ``insert``.

    Returns
    -------
    ScheduleTree
        The same ``stree`` object that was passed in.

    Raises
    ------
    RuntimeError
        If a HaloScheme cannot be built for some expression bundle while
        ``configuration['mpi']`` is truthy. With MPI disabled, such bundles
        are skipped and receive no HaloScheme.
    """
    # Build a HaloScheme for each expression bundle
    halo_schemes = {}
    for n in findall(stree, lambda i: i.is_Exprs):
        try:
            halo_schemes[n] = HaloScheme(n.exprs, n.ispace)
        except HaloSchemeException as e:
            # Only fatal under MPI; chain the cause so the original traceback
            # is preserved
            if configuration['mpi']:
                raise RuntimeError(str(e)) from e

    # Split a HaloScheme based on where it should be inserted
    # For example, it's possible that, for a given HaloScheme, a Function's
    # halo needs to be exchanged at a certain `stree` depth, while another
    # Function's halo needs to be exchanged before some other nodes
    mapper = {}
    for k, hs in halo_schemes.items():
        # The Iteration ancestors depend only on the bundle `k`, so compute
        # them once rather than once per Function
        ancestors = [n for n in k.ancestors if n.is_Iteration]
        for f, v in hs.fmapper.items():
            spot = k
            for n in ancestors:
                # Place the halo exchange right before the first
                # distributed Dimension which requires it
                if any(i.dim in n.dim._defines for i in v.halos):
                    spot = n
                    break
            mapper.setdefault(spot, []).append(hs.project(f))

    # Now fuse the HaloSchemes at the same `stree` depth and perform the
    # insertion (`schemes` is deliberately not named `halo_schemes`, to avoid
    # shadowing the dict built above)
    for spot, schemes in mapper.items():
        insert(NodeHalo(HaloScheme.union(schemes)), spot.parent, [spot])

    return stree
Ejemplo n.º 3
0
def st_make_halo(stree):
    """
    Add :class:`NodeHalo` to a :class:`ScheduleTree`. A halo node describes
    what halo exchanges should take place before executing the sub-tree.

    The tree is visited in level order. For each Iteration, a HaloScheme is
    built from all the expressions nested within it. If the Iteration's
    Dimension appears in that scheme's ``dmapper``, a NodeHalo is recorded
    for it and later attached via ``insert(node, k.parent, [k])``. Children
    of an Iteration that received a NodeHalo are excluded from the visit
    through the ``stop`` predicate.

    Parameters
    ----------
    stree : ScheduleTree
        The schedule tree to be decorated.

    Returns
    -------
    ScheduleTree
        The same ``stree`` object. It is returned untouched when
        ``configuration['mpi']`` is falsy.

    Raises
    ------
    RuntimeError
        Any RuntimeError raised while building a HaloScheme or NodeHalo is
        propagated, since MPI is known to be enabled at that point.
    """
    if not configuration['mpi']:
        # TODO: This will be dropped as soon as stronger analysis will have
        # been implemented
        return stree

    processed = {}
    # `stop` reads `processed` lazily, so it reflects the NodeHalos recorded
    # so far during the visit
    for n in LevelOrderIter(stree, stop=lambda i: i.parent in processed):
        if not n.is_Iteration:
            continue
        exprs = flatten(i.exprs for i in findall(n, lambda i: i.is_Exprs))
        try:
            halo_scheme = HaloScheme(exprs)
            if n.dim in halo_scheme.dmapper:
                processed[n] = NodeHalo(halo_scheme)
        except HaloSchemeException:
            # We should get here only when trying to compute a halo
            # scheme for a group of expressions that belong to different
            # iteration spaces. We expect proper halo schemes to be built
            # as the `stree` visit proceeds.
            # TODO: However, at the end, we should check that a halo scheme,
            # possibly even a "void" one, has been built for *all* of the
            # expressions, and error out otherwise.
            continue
        # NOTE: RuntimeErrors are deliberately not caught. MPI is enabled
        # here (see the early return above), so they must always propagate.
        # The previous `configuration['mpi'] is True` check silently
        # swallowed them for any truthy MPI mode other than the `True`
        # singleton.

    for k, v in processed.items():
        insert(v, k.parent, [k])

    return stree