Example #1
0
    def push_object_on_stack(self, scope, obj):
        """
        Schedule the stack declaration of ``obj`` within ``scope``.

        A LocalObject is declared as a plain ``<typename> <name>`` value. Any
        other object (e.g., an Array or a composite type) is declared as an
        aligned POD whose symbolic shape is spelled out as C array extents.
        """
        declarations = self.stack.setdefault(scope, OrderedDict())

        if not obj.is_LocalObject:
            extents = "".join("[%s]" % ccode(d) for d in obj.symbolic_shape)
            attribute = "__attribute__((aligned(%d)))" % obj._data_alignment
            decl = c.POD(obj.dtype, "%s%s %s" % (obj.name, extents, attribute))
        else:
            decl = c.Value(obj._C_typename, obj.name)
        declarations[obj] = Element(decl)

    def _specialize_iet(self, nodes):
        """
        Transform the Iteration/Expression tree to offload the computation of
        one or more loop nests onto YASK. This involves calling the YASK compiler
        to generate YASK code. Such YASK code is then called from within the
        transformed Iteration/Expression tree.

        The offloading is all-or-nothing: the transformed tree, the
        ``self.func_table`` entry for the YASK call and ``self.yk_soln`` are
        only committed once the YASK kernel solution has been built. On
        failure, the tree is left untouched and ``self.yk_soln`` stays a
        :class:`YaskNullKernel`.
        """
        log("Specializing a Devito Operator for YASK...")

        self.context = YaskNullContext()
        self.yk_soln = YaskNullKernel()

        offloadable = find_offloadable_trees(nodes)
        if not offloadable:
            log("No offloadable trees found")
        elif len(offloadable) == 1:
            tree, grid, dtype = offloadable[0]
            self.context = contexts.fetch(grid, dtype)

            # Create a YASK compiler solution for this Operator
            yc_soln = self.context.make_yc_solution(namespace['jit-yc-soln'])

            transform = sympy2yask(self.context, yc_soln)
            try:
                for i in tree[-1].nodes:
                    transform(i.expr)

                funcall = make_sharedptr_funcall(namespace['code-soln-run'],
                                                 ['time'],
                                                 namespace['code-soln-name'])
                funcall = Element(c.Statement(ccode(funcall)))
                # Build the offloaded tree aside; it is only committed if the
                # YASK kernel solution below is successfully created
                offloaded = Transformer({tree[1]: funcall}).visit(nodes)

                # JIT-compile the newly-created YASK kernel
                local_grids = [i for i in transform.mapper if i.is_Array]
                yk_soln = self.context.make_yk_solution(
                    namespace['jit-yk-soln'], yc_soln, local_grids)
            except Exception as e:
                # `Exception` rather than a bare `except`, so that e.g.
                # KeyboardInterrupt and SystemExit still propagate
                log("Unable to offload a candidate tree. Reason: [%s]" % str(e))
            else:
                nodes = offloaded
                self.yk_soln = yk_soln

                # Track /funcall/ as an external function call
                self.func_table[namespace['code-soln-run']] = MetaCall(
                    None, False)

                # Print some useful information about the newly constructed solution
                log("Solution '%s' contains %d grid(s) and %d equation(s)." %
                    (yc_soln.get_name(), yc_soln.get_num_grids(),
                     yc_soln.get_num_equations()))
        else:
            exit("Found more than one offloadable trees in a single Operator")

        # Some Iteration/Expression trees are not offloaded to YASK and may
        # require further processing to be executed in YASK, due to the differences
        # in storage layout employed by Devito and YASK
        nodes = make_grid_accesses(nodes)

        log("Specialization successfully performed!")

        return nodes
Example #3
0
def iet_insert_C_decls(iet, func_table):
    """
    Given an Iteration/Expression tree ``iet``, build a new tree with the
    necessary symbol declarations. Declarations are placed as close as
    possible to the first symbol use.

    :param iet: The input Iteration/Expression tree.
    :param func_table: A mapper from callable names to :class:`Callable`s
                       called from within ``iet``. Entries of local callables
                       are replaced in-place with their declaration-augmented
                       version. Calls to names missing from ``func_table`` are
                       treated as external, hence nothing is declared for them.
    :returns: The new Iteration/Expression tree.
    """
    # Resolve function calls first
    scopes = []
    me = MapExpressions()
    for k, v in me.visit(iet).items():
        if k.is_Call:
            # An untracked callable (e.g., an external function) has nothing
            # to declare; previously this raised a KeyError
            func = func_table.get(k.name)
            if func is not None and func.local:
                scopes.extend(me.visit(func.root, queue=list(v)).items())
        else:
            scopes.append((k, v))

    # Key used by `filter_iterations` to pick the site of stack declarations
    def not_parallel(iteration):
        return not iteration.is_Parallel

    # Determine all required declarations
    allocator = Allocator()
    mapper = OrderedDict()
    for k, v in scopes:
        if k.is_scalar:
            # Inline declaration
            mapper[k] = LocalExpression(**k.args)
        elif k.write is None or k.write._mem_external:
            # Nothing to do, e.g., variable passed as kernel argument
            continue
        elif k.write._mem_stack:
            # On the stack; fall back to the root of `iet` if no Iteration
            # is selected
            site = filter_iterations(v, key=not_parallel, stop='asap') or [iet]
            allocator.push_stack(site[-1], k.write)
        else:
            # On the heap, as a tensor that must be globally accessible
            allocator.push_heap(k.write)

    # Introduce declarations on the stack
    for k, v in allocator.onstack:
        mapper[k] = tuple(Element(i) for i in v)
    iet = NestedTransformer(mapper).visit(iet)
    for k, v in list(func_table.items()):
        if v.local:
            func_table[k] = MetaCall(
                Transformer(mapper).visit(v.root), v.local)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Example #4
0
    def push_array_on_stack(self, scope, obj):
        """
        Define an Array on the stack, within ``scope``.

        Nothing happens if ``obj`` has already been scheduled for a stack
        declaration in any scope, so each Array is declared at most once.
        """
        # Check for duplicates before touching `self.stack`, so that the early
        # return has no side effects (it used to leave behind an empty entry
        # for `scope`). Dict membership avoids flattening every declaration
        # on each push.
        if any(obj in declared for declared in self.stack.values()):
            return

        handle = self.stack.setdefault(scope, OrderedDict())

        shape = "".join("[%s]" % ccode(i) for i in obj.symbolic_shape)
        alignment = "__attribute__((aligned(%d)))" % obj._data_alignment
        value = "%s%s %s" % (obj.name, shape, alignment)
        handle[obj] = Element(c.POD(obj.dtype, value))
Example #5
0
    def _insert_declarations(self, nodes):
        """
        Populate the Operator's body with the necessary variable declarations.

        Scalars are declared inline. Arrays are declared on the stack or on
        the heap according to their memory attributes. Symbols whose memory
        is external (e.g., kernel arguments) are not declared. Entries of
        local callables in ``self.func_table`` are replaced with their
        declaration-augmented version.

        :param nodes: The input Iteration/Expression tree.
        :returns: The Iteration/Expression tree with the declarations in place.
        """
        # Resolve function calls first
        scopes = []
        me = MapExpressions()
        for k, v in me.visit(nodes).items():
            if k.is_Call:
                # An untracked callable (e.g., an external function) has
                # nothing to declare; previously this raised a KeyError
                func = self.func_table.get(k.name)
                if func is not None and func.local:
                    scopes.extend(me.visit(func.root, queue=list(v)).items())
            else:
                scopes.append((k, v))

        # Key used by `filter_iterations` to pick the site of stack declarations
        def not_parallel(iteration):
            return not iteration.is_Parallel

        # Determine all required declarations
        allocator = Allocator()
        mapper = OrderedDict()
        for k, v in scopes:
            if k.is_scalar:
                # Inline declaration
                mapper[k] = LocalExpression(**k.args)
            elif k.write is None or k.write._mem_external:
                # Nothing to do, e.g., variable passed as kernel argument
                # (a missing `write` used to raise AttributeError here)
                continue
            elif k.write._mem_stack:
                # On the stack, as established by the DLE
                site = filter_iterations(v, key=not_parallel,
                                         stop='asap') or [nodes]
                allocator.push_stack(site[-1], k.write)
            else:
                # On the heap, as a tensor that must be globally accessible
                allocator.push_heap(k.write)

        # Introduce declarations on the stack
        for k, v in allocator.onstack:
            mapper[k] = tuple(Element(i) for i in v)
        nodes = NestedTransformer(mapper).visit(nodes)
        for k, v in list(self.func_table.items()):
            if v.local:
                self.func_table[k] = FunMeta(
                    Transformer(mapper).visit(v.root), v.local)

        # Introduce declarations on the heap (if any)
        if allocator.onheap:
            decls, allocs, frees = zip(*allocator.onheap)
            nodes = List(header=decls + allocs, body=nodes, footer=frees)

        return nodes
Example #6
0
def make_grid_accesses(node):
    """
    Build a new Iteration/Expression tree from ``node``, in which every
    :class:`types.Indexed` access to a YASK-backed function is turned into a
    YASK grid access: reads become ``code-grid-get`` calls and writes to
    YASK-backed functions become ``code-grid-put`` statements.
    """

    def make_grid_gets(expr):
        # Rewrite YASK reads in `expr`; indices are rewritten recursively,
        # as they may themselves contain YASK reads
        replacements = {}
        for indexed in retrieve_indexed(expr):
            function = indexed.base.function
            if not function.from_YASK:
                continue
            grid = namespace['code-grid-name'](function.name)
            point = ListInitializer([INT(make_grid_gets(idx))
                                     for idx in indexed.indices])
            replacements[indexed] = make_sharedptr_funcall(
                namespace['code-grid-get'], [point], grid)
        return expr.xreplace(replacements)

    rewritten = {}
    for expression in FindNodes(Expression).visit(node):
        lhs, rhs = expression.expr.args
        rhs = make_grid_gets(rhs)

        if not expression.write.from_YASK:
            # Writing to a scalar temporary
            rewritten[expression] = Expression(expression.expr.func(lhs, rhs))
            continue

        grid = namespace['code-grid-name'](expression.write.name)
        point = ListInitializer([INT(make_grid_gets(idx))
                                 for idx in lhs.indices])
        put = make_sharedptr_funcall(namespace['code-grid-put'],
                                     [rhs, point], grid)
        rewritten[expression] = Element(c.Statement(ccode(put)))

    return Transformer(rewritten).visit(node)
Example #7
0
    def make_parallel(self, iet):
        """
        Transform ``iet`` by introducing shared-memory parallelism.

        In each Iteration tree, the first Iteration selected by ``self.key``
        becomes the root of an ``omp-for`` tree, which is then wrapped in an
        ``omp-parallel`` region. Stack-allocated Arrays within the region are
        made thread-private.

        :returns: A 2-tuple ``(iet, metadata)``. ``metadata['input']`` holds
                  ``self.nthreads`` if at least one region was introduced, and
                  is empty otherwise.
        """
        mapper = OrderedDict()
        for tree in retrieve_iteration_tree(iet):
            candidates = filter_iterations(tree, key=self.key, stop='asap')
            if not candidates:
                continue
            root = candidates[0]

            # The `omp-for` tree
            partree = self._make_parallel_tree(root, candidates)

            # Stack-allocated Arrays accessed within the region are private
            names = sorted({s.name for s in FindSymbols().visit(partree)
                            if s.is_Array and s._mem_stack})
            clause = ('private(%s)' % ','.join(names)) if names else ''

            # The `omp-parallel` region
            header = self.lang['par-region'](self.nthreads.name, clause)
            partree = Block(header=header, body=partree)

            # A zero step increment would trigger a `Floating point exception
            # (core dumped)` in some OpenMP implementations, so in that case we
            # must not enter the region at all. An OpenMP `if` clause does not
            # help here
            if isinstance(root.step, Symbol):
                guard = Conditional(CondEq(root.step, 0),
                                    Element(c.Statement('return')))
                partree = List(body=[guard, partree])

            mapper[root] = partree

        return (Transformer(mapper).visit(iet),
                {'input': [self.nthreads] if mapper else []})
Example #8
0
    def _specialize(self, nodes, parameters):
        """
        Create a YASK representation of this Iteration/Expression tree.

        ``parameters`` is modified in-place adding YASK-related arguments.

        The offloading is all-or-nothing. The transformed tree, the
        ``self.func_table`` entry for the YASK call, ``self.yk_soln``, the
        solution pointer in ``parameters`` and the local YASK grids are only
        committed once the YASK kernel solution has been built. On failure,
        the tree is left untouched and ``self.yk_soln`` stays a
        :class:`YaskNullKernel`.
        """
        log("Specializing a Devito Operator for YASK...")

        self.context = YaskNullContext()
        self.yk_soln = YaskNullKernel()
        local_grids = []

        offloadable = find_offloadable_trees(nodes)
        if not offloadable:
            log("No offloadable trees found")
        elif len(offloadable) == 1:
            tree, grid, dtype = offloadable[0]
            self.context = contexts.fetch(grid, dtype)

            # Create a YASK compiler solution for this Operator
            yc_soln = self.context.make_yc_solution(namespace['jit-yc-soln'])

            transform = sympy2yask(self.context, yc_soln)
            try:
                for i in tree[-1].nodes:
                    transform(i.expr)

                funcall = make_sharedptr_funcall(namespace['code-soln-run'],
                                                 ['time'],
                                                 namespace['code-soln-name'])
                funcall = Element(c.Statement(ccode(funcall)))
                # Build the offloaded tree aside; only committed on success
                offloaded = Transformer({tree[1]: funcall}).visit(nodes)

                # JIT-compile the newly-created YASK kernel
                yask_arrays = [i for i in transform.mapper if i.is_Array]
                yk_soln = self.context.make_yk_solution(
                    namespace['jit-yk-soln'], yc_soln, yask_arrays)

                # A pointer to the YASK solution must be dropped down to C-land
                soln_obj = Object(namespace['code-soln-name'],
                                  namespace['type-solution'],
                                  yk_soln.rawpointer)
            except Exception as e:
                # `Exception` rather than a bare `except`, so that e.g.
                # KeyboardInterrupt and SystemExit still propagate
                log("Unable to offload a candidate tree. Reason: [%s]" % str(e))
            else:
                nodes = offloaded
                local_grids.extend(yask_arrays)
                self.yk_soln = yk_soln

                # Track /funcall/ as an external function call
                self.func_table[namespace['code-soln-run']] = MetaCall(
                    None, False)

                parameters.append(soln_obj)

                # Print some useful information about the newly constructed solution
                log("Solution '%s' contains %d grid(s) and %d equation(s)." %
                    (yc_soln.get_name(), yc_soln.get_num_grids(),
                     yc_soln.get_num_equations()))
        else:
            exit("Found more than one offloadable trees in a single Operator")

        # Some Iteration/Expression trees are not offloaded to YASK and may
        # require further processing to be executed in YASK, due to the differences
        # in storage layout employed by Devito and YASK
        nodes = make_grid_accesses(nodes)

        # Update the parameters list adding all necessary YASK grids; iterate
        # over a snapshot since `parameters` is extended in the loop
        for i in list(parameters) + local_grids:
            try:
                if i.from_YASK:
                    parameters.append(
                        Object(namespace['code-grid-name'](i.name),
                               namespace['type-grid']))
            except AttributeError:
                # Ignore e.g. Dimensions
                pass

        log("Specialization successfully performed!")

        return nodes
Example #9
0
def iet_insert_C_decls(iet, external=None):
    """
    Given an IET, build a new tree with the necessary symbol declarations.
    Declarations are placed as close as possible to the first symbol occurrence.

    Parameters
    ----------
    iet : Node
        The input Iteration/Expression tree.
    external : tuple, optional
        The symbols defined in some outer Callable, which therefore must not
        be re-defined.

    Raises
    ------
    NotImplementedError
        If a node mapped by MapExpressions is neither an Expression nor a
        Call, since it is unknown which symbols it may need declared.
    """
    external = external or []

    # Key used by `filter_iterations` to pick the site of stack-allocated Arrays
    def not_parallel(iteration):
        return not iteration.is_Parallel

    # Classify and then schedule declarations to stack/heap
    allocator = Allocator()
    mapper = OrderedDict()
    for k, v in MapExpressions().visit(iet).items():
        if k.is_Expression:
            if k.is_scalar_assign:
                # Inline declaration
                mapper[k] = LocalExpression(**k.args)
                continue
            objs = [k.write]
        elif k.is_Call:
            objs = k.params
        else:
            # Without this branch, `objs` would be unbound or, worse, silently
            # reused from the previous iteration
            raise NotImplementedError("Cannot schedule declarations for IET "
                                      "node of type `%s`" % type(k))

        for i in objs:
            try:
                if i.is_LocalObject:
                    # On the stack
                    site = v[-1] if v else iet
                    allocator.push_stack(site, i)
                elif i.is_Array:
                    if i in external:
                        # The Array is to be defined in some foreign IET
                        continue
                    elif i._mem_stack:
                        # On the stack
                        site = filter_iterations(v, key=not_parallel,
                                                 stop='asap') or [iet]
                        allocator.push_stack(site[-1], i)
                    else:
                        # On the heap, as a tensor that must be globally accessible
                        allocator.push_heap(i)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce declarations on the stack
    for k, v in allocator.onstack:
        mapper[k] = tuple(Element(i) for i in v)
    iet = Transformer(mapper, nested=True).visit(iet)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Example #10
0
def opsit(trees, count):
    """
    Lower the expressions contained in ``trees`` to an OPS kernel.

    ``count`` is used to derive unique names for the OPS block, the iteration
    range array and the kernel.

    Returns a 4-tuple:

    * the kernel :class:`Callable`;
    * the list of host-side setup nodes (range initialization, block
      declaration, stencil initializers, dat declarations, ``ops_partition``);
    * a :class:`List` wrapping the ``ops_par_loop`` call;
    * the number of iteration dimensions.
    """
    node_factory = OPSNodeFactory()

    expressions = []
    for tree in trees:
        expressions.extend(FindNodes(Expression).visit(tree.inner))

    # The iteration space is the one of the last IterationTree in `trees`
    it_range, it_dims = [], 0
    for tree in trees:
        if not isinstance(tree, IterationTree):
            continue
        it_range = [it.bounds() for it in tree]
        it_dims = len(tree)

    block = OPSBlock(namespace['ops_block'](count))
    decl_block = Call("ops_decl_block", [it_dims, String(block.name)], False)
    block_init = Element(cgen.Initializer(block, decl_block))

    # Translate to the OPS AST walking the expressions last-to-first (the order
    # in which accesses are accumulated), while keeping the output in the
    # original order
    ops_expressions = []
    accesses = defaultdict(set)
    for expr in reversed(expressions):
        extend_accesses(accesses, get_accesses(expr.expr))
        ops_expressions.append(Expression(make_ops_ast(expr.expr, node_factory)))
    ops_expressions.reverse()

    ops_stencils_initializers, ops_stencils = generate_ops_stencils(accesses)

    # Symbols defined by the expressions themselves are neither kernel
    # parameters nor host-side arguments
    to_remove = {f.name for f in
                 FindSymbols('defines').visit(List(body=expressions))}

    def by_constness_then_name(s):
        return (s.is_Constant, s.name)

    parameters = sorted(
        (p for p in FindSymbols('symbolics').visit(List(body=ops_expressions))
         if p.name != 'OPS_ACC_size' and p.name not in to_remove),
        key=by_constness_then_name)

    arguments = sorted(
        (a for a in FindSymbols('symbolics').visit(List(body=expressions))
         if a.name not in to_remove),
        key=by_constness_then_name)

    ops_expressions = [
        Expression(fix_ops_acc(e.expr, [p.name for p in parameters]))
        for e in ops_expressions
    ]

    callable_kernel = Callable(namespace['ops_kernel'](count), ops_expressions,
                               "void", parameters)

    # One (or more, for TimeFunctions) ops_dat per non-constant argument
    dat_declarations = []
    argname_to_dat = {}
    for arg in (a for a in arguments if not a.is_Constant):
        dat_dec, dat_sym = to_ops_dat(arg, block)
        dat_declarations.extend(dat_dec)
        argname_to_dat.update(dat_sym)

    par_loop_range_arr = SymbolicArray(name=namespace['ops_range'](count),
                                       dimensions=(len(it_range) * 2, ),
                                       dtype=np.int32)
    # Flattened as [min_0, max_0, min_1, max_1, ...]
    range_vals = [bound for mn, mx in it_range for bound in (mn, mx)]
    par_loop_range_init = Expression(
        ClusterizedEq(Eq(par_loop_range_arr, ListInitializer(range_vals))))

    ops_args = get_ops_args(list(parameters), ops_stencils, argname_to_dat)

    par_loop = Call("ops_par_loop", [
        FunctionPointer(callable_kernel.name),
        String(callable_kernel.name), block, it_dims, par_loop_range_arr,
        *ops_args
    ])

    setup = ([par_loop_range_init, block_init] + ops_stencils_initializers +
             dat_declarations + [Call("ops_partition", [String("")])])

    return callable_kernel, setup, List(body=[par_loop]), it_dims
Example #11
0
def to_ops_dat(function, block):
    """
    Build the OPS data declarations for ``function`` on the OPS ``block``.

    The int32 arrays ``<name>_dim``, ``<name>_base``, ``<name>_d_p`` and
    ``<name>_d_m`` are initialized first, followed by the ``ops_decl_dat``
    call(s). For a TimeFunction, the time dimension is excluded from the OPS
    shape. One ``ops_dat`` is then declared per time buffer, stored in a C
    array of ``ops_dat``.

    Returns
    -------
    res : list
        The IET nodes performing the initializations and declarations.
    dats : dict
        Maps the name of each declared dat to the object used to access it.
    """
    is_time = function.is_TimeFunction
    ndim = function.ndim - (1 if is_time else 0)

    def int_array(suffix):
        return SymbolicArray(name="%s_%s" % (function.name, suffix),
                             dimensions=(ndim, ),
                             dtype=np.int32)

    dim = int_array("dim")
    base = int_array("base")
    d_p = int_array("d_p")
    d_m = int_array("d_m")

    # NOTE(review): `d_p` is built from the first (index 0) padding/halo entry
    # and `d_m` from the second; this only matters for asymmetric halos --
    # confirm against the OPS `ops_decl_dat` convention
    def depths(padding, halo):
        pairs = list(zip(padding, halo))
        return (tuple([p[0] + h[0] for p, h in pairs]),
                tuple([-(p[1] + h[1]) for p, h in pairs]))

    dats = {}
    decl_calls = []

    if is_time:
        time_pos = function._time_position
        time_index = function.indices[time_pos]
        time_dims = function.shape[time_pos]

        def drop_time(values):
            return values[:time_pos] + values[time_pos + 1:]

        dim_shape = drop_time(function.shape)
        base_val = [0] * ndim
        d_p_val, d_m_val = depths(drop_time(function.padding),
                                  drop_time(function.halo))

        ops_dat_array = SymbolicArray(
            name="%s_dat" % function.name,
            dimensions=[time_dims],
            dtype="ops_dat",
        )
        decl_calls.append(Element(cgen.Statement(
            "%s %s[%s]" % (ops_dat_array.dtype, ops_dat_array.name, time_dims))))

        # One ops_dat per time buffer
        for t in range(time_dims):
            access = FunctionTimeAccess(function, t)
            slot = ArrayAccess(ops_dat_array, t)
            dat_name = "%s%s%s" % (function.name, time_index, t)
            decl = Call("ops_decl_dat", [
                block, 1, dim, base, d_m, d_p, access,
                String(function._C_typedata),
                String(dat_name)
            ], False)
            dats[dat_name] = ArrayAccess(ops_dat_array,
                                         Symbol("%s%s" % (time_index, t)))
            decl_calls.append(Element(cgen.Assign(slot, decl)))
    else:
        ops_dat = OPSDat("%s_dat" % function.name)
        dats[function.name] = ops_dat

        d_p_val, d_m_val = depths(function.padding, function.halo)
        dim_shape = function.shape
        base_val = [0] * len(function.shape)

        decl = Call("ops_decl_dat", [
            block, 1, dim, base, d_m, d_p,
            FunctionTimeAccess(function, 0),
            String(function._C_typedata),
            String(function.name)
        ], False)
        decl_calls.append(Element(cgen.Initializer(ops_dat, decl)))

    initializations = ((dim, dim_shape), (base, base_val),
                       (d_p, d_p_val), (d_m, d_m_val))
    res = [Expression(ClusterizedEq(Eq(array, ListInitializer(values))))
           for array, values in initializations]
    res.extend(decl_calls)

    return res, dats
Example #12
0
    def _specialize_iet(self, iet, **kwargs):
        """
        Transform the Iteration/Expression tree to offload the computation of
        one or more loop nests onto YASK. This involves calling the YASK compiler
        to generate YASK code. Such YASK code is then called from within the
        transformed Iteration/Expression tree.

        The offloading is all-or-nothing: the transformed tree, the
        ``self.func_table`` entry for the YASK call and ``self.yk_soln`` are
        only committed once the YASK kernel solution has been built. If a
        ``NotImplementedError`` is raised, the tree is left untouched and
        ``self.yk_soln`` is a :class:`YaskNullKernel`. All loops that are not
        offloaded are then optimized by the parent Operator.
        """
        offloadable = find_offloadable_trees(iet)

        if len(offloadable.trees) == 0:
            self.yk_soln = YaskNullKernel()

            log("No offloadable trees found")
        else:
            context = contexts.fetch(offloadable.grid, offloadable.dtype)

            # A unique name for the 'real' compiler and kernel solutions
            name = namespace['jit-soln'](Signer._digest(iet, configuration))

            # Create a YASK compiler solution for this Operator
            yc_soln = context.make_yc_solution(name)

            try:
                trees = offloadable.trees

                # Generate YASK grids and populate `yc_soln` with equations
                mapper = yaskizer(trees, yc_soln)
                local_grids = [i for i in mapper if i.is_Array]

                # Replace the first offloaded tree with a call to the YASK
                # solution, and drop all the others
                funcall = make_sharedptr_funcall(namespace['code-soln-run'],
                                                 ['time'],
                                                 namespace['code-soln-name'])
                funcall = Element(c.Statement(ccode(funcall)))
                mapper = {i.root: None for i in trees}
                mapper[trees[0].root] = funcall
                # Built aside; only committed once the kernel is ready
                offloaded = Transformer(mapper).visit(iet)

                # JIT-compile the newly-created YASK kernel
                yk_soln = context.make_yk_solution(name, yc_soln, local_grids)
            except NotImplementedError as e:
                self.yk_soln = YaskNullKernel()

                log("Unable to offload a candidate tree. Reason: [%s]" %
                    str(e))
            else:
                iet = offloaded
                self.yk_soln = yk_soln

                # Mark `funcall` as an external function call
                self.func_table[namespace['code-soln-run']] = MetaCall(
                    None, False)

                # Print some useful information about the newly constructed solution
                log("Solution '%s' contains %d grid(s) and %d equation(s)." %
                    (yc_soln.get_name(), yc_soln.get_num_grids(),
                     yc_soln.get_num_equations()))

        # Some Iteration/Expression trees are not offloaded to YASK and may
        # require further processing to be executed in YASK, due to the differences
        # in storage layout employed by Devito and YASK
        iet = make_grid_accesses(iet)

        # Finally optimize all non-yaskized loops
        iet = super(Operator, self)._specialize_iet(iet, **kwargs)

        return iet
Example #13
0
def iet_insert_C_decls(iet, func_table=None):
    """
    Given an Iteration/Expression tree ``iet``, build a new tree with the
    necessary symbol declarations. Declarations are placed as close as
    possible to the first symbol use.

    :param iet: The input Iteration/Expression tree.
    :param func_table: (Optional) a mapper from callable names within ``iet``
                       to :class:`Callable`s. Entries of local callables are
                       replaced in-place with their declaration-augmented
                       version.
    :raises NotImplementedError: If a node to be inspected is neither an
                                 Expression nor a Call.
    """
    func_table = func_table or {}
    allocator = Allocator()
    mapper = OrderedDict()

    # Gather (node, enclosing Iterations) pairs for all IET nodes that may
    # access symbols requiring a declaration. The bodies of local callables
    # are inspected as well, using the call site's Iterations as the queue.
    me = MapExpressions()
    scopes = []
    for node, iterations in me.visit(iet).items():
        if node.is_Call:
            callee = func_table.get(node.name)
            if callee is not None and callee.local:
                scopes.extend(me.visit(callee.root,
                                       queue=list(iterations)).items())
        scopes.append((node, iterations))

    # Key used by `filter_iterations` to pick the site of stack-allocated Arrays
    def not_parallel(iteration):
        return not iteration.is_Parallel

    # Classify, and then schedule declarations to stack/heap
    for node, iterations in scopes:
        if node.is_Expression:
            if node.is_scalar:
                # Inline declaration
                mapper[node] = LocalExpression(**node.args)
                continue
            objs = [node.write]
        elif node.is_Call:
            objs = node.params
        else:
            raise NotImplementedError("Cannot schedule declarations for IET "
                                      "node of type `%s`" % type(node))

        for obj in objs:
            try:
                if obj.is_LocalObject:
                    # On the stack, in the innermost enclosing Iteration
                    allocator.push_stack(iterations[-1] if iterations else iet,
                                         obj)
                elif obj.is_Array and not obj._mem_external:
                    # (Arrays with external memory, e.g. user-provided
                    # Functions, need no declaration)
                    if obj._mem_stack:
                        site = filter_iterations(iterations, key=not_parallel,
                                                 stop='asap') or [iet]
                        allocator.push_stack(site[-1], obj)
                    else:
                        # On the heap, as a tensor that must be globally accessible
                        allocator.push_heap(obj)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce declarations on the stack
    for scope, decls in allocator.onstack:
        mapper[scope] = tuple(Element(d) for d in decls)
    iet = Transformer(mapper, nested=True).visit(iet)
    for name, meta in list(func_table.items()):
        if meta.local:
            func_table[name] = MetaCall(Transformer(mapper).visit(meta.root),
                                        meta.local)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet
Example #14
0
 def push_object_on_stack(self, scope, obj):
     """Schedule the declaration of the LocalObject ``obj`` within ``scope``."""
     scope_decls = self.stack.setdefault(scope, OrderedDict())
     scope_decls[obj] = Element(c.Value(obj._C_typename, obj.name))
Example #15
0
def iet_insert_C_decls(iet, func_table=None):
    """
    Given an Iteration/Expression tree ``iet``, build a new tree with the
    necessary symbol declarations. Declarations are placed as close as
    possible to the first symbol use.

    :param iet: The input Iteration/Expression tree.
    :param func_table: (Optional) a mapper from callable names within ``iet``
                       to :class:`Callable`s. Entries of local callables are
                       replaced in-place with their declaration-augmented
                       version.
    :returns: The new Iteration/Expression tree.
    """
    func_table = func_table or {}
    allocator = Allocator()
    mapper = OrderedDict()

    # Map each Expression/Call in `iet` to its enclosing Iterations. This is
    # computed once and shared by both scheduling passes below (it used to be
    # recomputed for the second pass).
    me = MapExpressions()
    exprs_map = me.visit(iet)

    # Key used by `filter_iterations` to pick the site of stack-allocated Arrays
    def not_parallel(iteration):
        return not iteration.is_Parallel

    # First, schedule declarations for Expressions
    scopes = []
    for k, v in exprs_map.items():
        if k.is_Call:
            func = func_table.get(k.name)
            if func is not None and func.local:
                scopes.extend(me.visit(func.root, queue=list(v)).items())
        else:
            scopes.append((k, v))
    for k, v in scopes:
        if k.is_scalar:
            # Inline declaration
            mapper[k] = LocalExpression(**k.args)
        elif k.write is None or k.write._mem_external:
            # Nothing to do, e.g., variable passed as kernel argument
            continue
        elif k.write._mem_stack:
            # On the stack
            site = filter_iterations(v, key=not_parallel, stop='asap') or [iet]
            allocator.push_stack(site[-1], k.write)
        else:
            # On the heap, as a tensor that must be globally accessible
            allocator.push_heap(k.write)

    # Then, schedule declarations for the arguments that callables receive by
    # reference/pointer (as they are modified internally by the callable)
    for k, v in exprs_map.items():
        if not k.is_Call:
            continue
        site = v[-1] if v else iet
        for i in k.params:
            try:
                if i.is_LocalObject:
                    # On the stack
                    allocator.push_stack(site, i)
                elif i.is_Array:
                    if i._mem_stack:
                        # On the stack
                        allocator.push_stack(site, i)
                    elif i._mem_heap:
                        # On the heap
                        allocator.push_heap(i)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce declarations on the stack
    for k, v in allocator.onstack:
        mapper[k] = tuple(Element(i) for i in v)
    iet = NestedTransformer(mapper).visit(iet)
    for k, v in list(func_table.items()):
        if v.local:
            func_table[k] = MetaCall(Transformer(mapper).visit(v.root), v.local)

    # Introduce declarations on the heap (if any)
    if allocator.onheap:
        decls, allocs, frees = zip(*allocator.onheap)
        iet = List(header=decls + allocs, body=iet, footer=frees)

    return iet