Пример #1
0
        def _(iet):
            """
            Map the DiscreteFunctions accessed by the offloaded Iteration trees
            of `iet` through `self._map_function_on_high_bw_mem`, then dump the
            resulting storage into `iet` and return it.
            """
            # Gather the DiscreteFunctions written and read by offloaded Expressions
            written = set()
            accessed = set()
            for stmt, enclosing in MapExprStmts().visit(iet).items():
                if not stmt.is_Expression:
                    # Nothing to do for this kind of node
                    continue

                offloaded = any(isinstance(node, self._Parallelizer._Iteration)
                                for node in enclosing)
                if not offloaded:
                    # The Expression does not sit in an offloaded Iteration tree
                    continue

                if stmt.write.is_DiscreteFunction:
                    written.add(stmt.write)
                accessed.update(r for r in stmt.reads if r.is_DiscreteFunction)

            # Anything written at some point cannot be treated as read-only
            read_only = accessed - written

            # Fill `storage`, written symbols first
            storage = Storage()
            for f in filter_sorted(written):
                self._map_function_on_high_bw_mem(iet, f, storage)
            for f in filter_sorted(read_only):
                self._map_function_on_high_bw_mem(iet, f, storage, read_only=True)

            return self._dump_storage(iet, storage)
Пример #2
0
    def place_ondevice(self, iet, **kwargs):
        """
        Map the DiscreteFunctions accessed by the offloaded Iteration trees
        found in ``kwargs['efuncs']`` through `self._map_function_on_high_bw_mem`
        and dump the resulting storage into `iet`.

        An ElementalFunction is returned untouched. In both cases the return
        value is the pair ``(iet, {})``.
        """
        efuncs = kwargs['efuncs']

        storage = Storage()

        if iet.is_ElementalFunction:
            return iet, {}

        # Gather the DiscreteFunctions written and read by offloaded Expressions
        written = set()
        accessed = set()
        for efunc in efuncs:
            for stmt, enclosing in MapExprStmts().visit(efunc).items():
                # Skip non-Expressions as well as Expressions that do not sit
                # in an offloaded Iteration tree
                if not (stmt.is_Expression and
                        any(isinstance(n, DeviceParallelIteration)
                            for n in enclosing)):
                    continue

                if stmt.write.is_DiscreteFunction:
                    written.add(stmt.write)
                accessed.update(r for r in stmt.reads if r.is_DiscreteFunction)

        # Anything written at some point cannot be treated as read-only
        read_only = accessed - written

        # Fill `storage`, written symbols first
        for f in filter_sorted(written):
            self._map_function_on_high_bw_mem(f, storage)
        for f in filter_sorted(read_only):
            self._map_function_on_high_bw_mem(f, storage, read_only=True)

        return self._dump_storage(iet, storage), {}
Пример #3
0
def iet_insert_decls(iet, external=None):
    """
    Transform the input IET inserting the necessary symbol declarations.
    Declarations are placed as close as possible to the first symbol occurrence.

    Parameters
    ----------
    iet : Node
        The input Iteration/Expression tree.
    external : tuple, optional
        The symbols defined in some outer Callable, which therefore must not
        be re-defined. Defaults to None, meaning no such symbols.

    Returns
    -------
    The transformed IET. If at least one Array was scheduled on the heap, the
    result is a List whose header carries the declarations and allocations,
    whose body is the transformed IET, and whose footer carries the frees.
    """
    iet = as_tuple(iet)

    # Normalise once, ahead of the loop, rather than at every Array occurrence
    external = () if external is None else as_tuple(external)

    # Classify and then schedule declarations to stack/heap
    allocator = Allocator()
    for k, v in MapExprStmts().visit(iet).items():
        if k.is_Expression:
            if k.is_definition:
                # On the stack
                site = v if v else iet
                allocator.push_scalar_on_stack(site[-1], k)
                continue
            objs = [k.write]
        elif k.is_Call:
            objs = k.arguments
        else:
            # Neither an Expression nor a Call: nothing to declare. Without
            # this branch `objs` would be unbound on the first iteration, or
            # would still hold the objects of the previously visited node
            continue

        for i in objs:
            try:
                if i.is_LocalObject:
                    # On the stack
                    site = v if v else iet
                    allocator.push_object_on_stack(site[-1], i)
                elif i.is_Array:
                    if i in external:
                        # The Array is defined in some other IET
                        continue
                    elif i._mem_stack:
                        # On the stack
                        allocator.push_object_on_stack(iet[0], i)
                    else:
                        # On the heap
                        allocator.push_array_on_heap(i)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce declarations on the stack
    mapper = dict(allocator.onstack)
    iet = Transformer(mapper, nested=True).visit(iet)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet