Esempio n. 1
0
def test_precompute_nested_subst(ctx_factory):
    """Precompute rule ``D``, which is defined through the nested rule ``E``.

    After ``lp.precompute(knl, "D", "i_inner")`` the test asserts that exactly
    one substitution rule whose name starts with ``"E"`` is left and that its
    expression has no dependency on ``i_inner``. The transformed kernel is
    then compared against the untransformed one via ``lp.auto_test_vs_ref``.
    """
    from loopy.symbolic import get_dependencies

    ctx = ctx_factory()

    knl = lp.make_kernel(
        "{[i,j]: 0<=i<n and 0<=j<5}",
        """
        E:=a[i]
        D:=E*E
        b[i] = D
        """)
    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32})

    # Keep the untransformed kernel around as the reference.
    ref_knl = knl

    knl = lp.tag_inames(knl, {"j": "g.1"})
    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")

    d_rule_deps = get_dependencies(knl.substitutions["D"].expression)
    assert "i_inner" not in d_rule_deps
    knl = lp.precompute(knl, "D", "i_inner")

    # There's only one surviving 'E' rule.
    surviving_e_rules = [
        name for name in knl.substitutions if name.startswith("E")]
    assert len(surviving_e_rules) == 1

    # That rule should use the newly created prefetch inames,
    # not the prior 'i_inner'
    e_rule_deps = get_dependencies(knl.substitutions["E"].expression)
    assert "i_inner" not in e_rule_deps

    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters={"n": 12345})
Esempio n. 2
0
def _add_kernel_axis(kernel, axis_name, start, stop, base_inames):
    """Return a copy of *kernel* with a new set dimension *axis_name*.

    The domain selected by ``DomainChanger(kernel, base_inames)`` gets one
    additional set dimension named *axis_name*, is intersected with
    ``make_slab(space, axis_name, start, stop)``, and is put back into the
    kernel's domains. Every name that *start* or *stop* depends on must be a
    kernel parameter (checked with an assertion); any of them that is not
    yet a parameter of the domain is added as a parameter dimension.
    """
    from loopy.isl_helpers import make_slab
    from loopy.kernel.tools import DomainChanger
    from loopy.symbolic import get_dependencies

    changer = DomainChanger(kernel, base_inames)
    extended = changer.domain

    # Append the new axis after the existing set dimensions.
    axis_pos = extended.dim(dim_type.set)
    extended = extended.insert_dims(dim_type.set, axis_pos, 1)
    extended = extended.set_dim_name(dim_type.set, axis_pos, axis_name)

    bound_deps = get_dependencies(start) | get_dependencies(stop)
    assert bound_deps <= kernel.all_params()

    # Parameter names are sampled once, before any are added below.
    known_params = extended.get_var_names(dim_type.param)
    for dep_name in bound_deps:
        if dep_name in known_params:
            continue
        param_pos = extended.dim(dim_type.param)
        extended = extended.insert_dims(dim_type.param, param_pos, 1)
        extended = extended.set_dim_name(dim_type.param, param_pos, dep_name)

    bounded = extended & make_slab(
            extended.get_space(), axis_name, start, stop)

    return kernel.copy(domains=changer.get_domains_with(bounded))
Esempio n. 3
0
    def __init__(self, domains, instructions, temporary_variables,
            subst_rules, default_offset):
        """Store the given pieces and derive several name sets from them.

        Besides keeping the arguments as attributes, this sets up:

        * ``submap``: a ``SubstitutionRuleExpander`` over *subst_rules*,
        * ``all_inames``: set-dimension names of all *domains*,
        * ``all_params``: parameter-dimension names of all *domains* that are
          not also inames,
        * ``all_written_names``: the assignee variable name of each
          ``ExpressionInstruction``,
        * ``all_names``: dependencies of the rule-expanded assignee and
          expression of each ``ExpressionInstruction``.
        """
        from loopy.symbolic import SubstitutionRuleExpander, get_dependencies

        self.domains = domains
        self.instructions = instructions
        self.temporary_variables = temporary_variables
        self.subst_rules = subst_rules
        self.default_offset = default_offset

        self.submap = SubstitutionRuleExpander(subst_rules)

        self.all_inames = {
            iname
            for dom in domains
            for iname in dom.get_var_names(dim_type.set)}

        param_names = {
            param
            for dom in domains
            for param in dom.get_var_names(dim_type.param)}
        self.all_params = param_names - self.all_inames

        self.all_names = set()
        self.all_written_names = set()
        for insn in instructions:
            if not isinstance(insn, ExpressionInstruction):
                continue

            # Exactly one assignee is expected; anything else fails to unpack.
            (written_name, _), = insn.assignees_and_indices()
            self.all_written_names.add(written_name)

            for side in (insn.assignee, insn.expression):
                self.all_names.update(get_dependencies(self.submap(side)))
Esempio n. 4
0
File: data.py Progetto: shwina/loopy
def _add_kernel_axis(kernel, axis_name, start, stop, base_inames):
    """Add the axis *axis_name* to one of *kernel*'s domains; return the copy.

    ``DomainChanger(kernel, base_inames)`` picks the domain to modify. That
    domain receives a new trailing set dimension called *axis_name* and is
    intersected with ``make_slab(space, axis_name, start, stop)``. The names
    *start* and *stop* depend on are asserted to be kernel parameters, and
    those missing from the domain's parameters are inserted as new parameter
    dimensions.
    """
    from loopy.kernel.tools import DomainChanger
    from loopy.symbolic import get_dependencies
    from loopy.isl_helpers import make_slab

    domch = DomainChanger(kernel, base_inames)
    new_domain = domch.domain

    # New set dimension goes at the end of the existing ones.
    set_idx = new_domain.dim(dim_type.set)
    new_domain = new_domain.insert_dims(dim_type.set, set_idx, 1)
    new_domain = new_domain.set_dim_name(dim_type.set, set_idx, axis_name)

    needed_params = get_dependencies(start) | get_dependencies(stop)
    assert needed_params <= kernel.all_params()

    # Snapshot of the parameter names taken before the loop adds any.
    existing_params = new_domain.get_var_names(dim_type.param)
    for param in needed_params:
        if param in existing_params:
            continue
        param_idx = new_domain.dim(dim_type.param)
        new_domain = new_domain.insert_dims(dim_type.param, param_idx, 1)
        new_domain = new_domain.set_dim_name(dim_type.param, param_idx, param)

    new_domain = new_domain & make_slab(
            new_domain.get_space(), axis_name, start, stop)

    return kernel.copy(domains=domch.get_domains_with(new_domain))
Esempio n. 5
0
    def map_subscript(self, expr, domain, insn_id):
        """Check that a subscript stays within the declared shape of its array.

        After delegating to ``WalkMapper.map_subscript``, the shape of the
        subscripted variable is looked up among the kernel's arguments and
        temporaries. Nothing is checked when no shape is found, when the
        subscript or the shape depends on names that are not variables of
        *domain*, or when ``self._get_access_range`` returns ``None``.

        Raises ``LoopyError`` if the number of indices differs from the number
        of shape axes, or if the access range cannot be shown to be a subset
        of the shape's index set.
        """
        from pymbolic.primitives import Variable
        from loopy.symbolic import get_dependencies

        WalkMapper.map_subscript(self, expr, domain, insn_id)

        assert isinstance(expr.aggregate, Variable)

        array_name = expr.aggregate.name
        if array_name in self.kernel.arg_dict:
            shape = self.kernel.arg_dict[array_name].shape
        elif array_name in self.kernel.temporary_variables:
            shape = self.kernel.temporary_variables[array_name].shape
        else:
            shape = None

        if shape is None:
            return

        indices = expr.index
        if not isinstance(indices, tuple):
            indices = (indices, )

        known_vars = set(domain.get_var_dict())
        shape_deps = set()
        for axis_len in shape:
            if axis_len is not None:
                shape_deps.update(get_dependencies(axis_len))

        # Both the indices and the shape must be expressible in this domain.
        if (not get_dependencies(indices) <= known_vars
                or not shape_deps <= known_vars):
            return

        if len(indices) != len(shape):
            raise LoopyError(
                "subscript to '%s' in '%s' has the wrong "
                "number of indices (got: %d, expected: %d)" %
                (expr.aggregate.name, expr, len(indices), len(shape)))

        access_range = self._get_access_range(domain, indices)
        if access_range is None:
            # Likely: index was non-affine, nothing we can do.
            return

        # Start from the unconstrained set and bound every axis that has a
        # known length to [0, axis_len).
        shape_domain = isl.BasicSet.universe(access_range.get_space())
        for idim, axis_len in enumerate(shape):
            if axis_len is None:
                continue

            slab = self._make_slab(shape_domain.get_space(),
                                   (dim_type.in_, idim), 0, axis_len)
            shape_domain = shape_domain.intersect(slab)

        if access_range.is_subset(shape_domain):
            return

        raise LoopyError(
            "'%s' in instruction '%s' "
            "accesses out-of-bounds array element (could not"
            " establish '%s' is a subset of '%s')." %
            (expr, insn_id, access_range, shape_domain))
Esempio n. 6
0
def test_precompute_nested_subst(ctx_factory):
    """Check ``lp.precompute`` on a rule that itself uses another rule.

    Rule ``D`` is ``E*E`` and ``E`` is ``a[i]``. Once ``D`` has been
    precomputed with sweep iname ``i_inner``, a single rule with a name
    starting in ``"E"`` must remain, and its expression must not depend on
    ``i_inner``. The result is validated against the original kernel.
    """
    ctx = ctx_factory()

    instructions = """
        E:=a[i]
        D:=E*E
        b[i] = D
        """
    knl = lp.make_kernel("{[i,j]: 0<=i<n and 0<=j<5}", instructions)
    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32})

    # The kernel before any transformation serves as the reference.
    ref_knl = knl

    knl = lp.tag_inames(knl, {"j": "g.1"})
    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")

    from loopy.symbolic import get_dependencies

    deps_of_d = get_dependencies(knl.substitutions["D"].expression)
    assert "i_inner" not in deps_of_d
    knl = lp.precompute(knl, "D", "i_inner")

    # There's only one surviving 'E' rule.
    e_rule_count = sum(
        1 for rule_name in knl.substitutions if rule_name.startswith("E"))
    assert e_rule_count == 1

    # That rule should use the newly created prefetch inames,
    # not the prior 'i_inner'
    deps_of_e = get_dependencies(knl.substitutions["E"].expression)
    assert "i_inner" not in deps_of_e

    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters={"n": 12345})
Esempio n. 7
0
    def read_dependency_names(self):
        """Return the names read by this instruction.

        The result is the union of ``self.read_variables``, the dependencies
        of each expression in ``self.iname_exprs``, the dependencies of each
        assignee subscript, and ``self.predicates``.
        """
        from loopy.symbolic import get_dependencies

        names = set(self.read_variables)

        names.update(*(
            get_dependencies(iname_expr)
            for _iname, iname_expr in self.iname_exprs))

        names.update(*(
            get_dependencies(index_expr)
            for _assignee, index_expr in self.assignees_and_indices()))

        return frozenset(names) | self.predicates
Esempio n. 8
0
def _get_assignee_subscript_deps(expr):
    """Return the dependencies of the index of the lvalue *expr*.

    A plain ``Variable`` has no index and yields an empty frozenset;
    ``Subscript`` and ``LinearSubscript`` yield the dependencies of their
    ``index``.

    Raises ``RuntimeError`` for any other kind of expression.
    """
    from pymbolic.primitives import Variable, Subscript
    from loopy.symbolic import LinearSubscript, get_dependencies

    if isinstance(expr, Variable):
        return frozenset()

    if isinstance(expr, (Subscript, LinearSubscript)):
        return get_dependencies(expr.index)

    raise RuntimeError("invalid lvalue '%s'" % expr)
Esempio n. 9
0
    def read_dependency_names(self):
        """Return the names this instruction reads.

        This is the union of the dependencies of ``self.expression``, the
        dependencies of every assignee subscript, and the predicates with any
        leading ``"!"`` characters removed.
        """
        from loopy.symbolic import get_dependencies

        deps = get_dependencies(self.expression)
        for _assignee, index_expr in self.assignees_and_indices():
            deps = deps | get_dependencies(index_expr)

        # Predicates may carry a leading "!" (presumably a negation marker);
        # only the bare name is a dependency.
        predicate_names = frozenset(
            pred.lstrip("!") for pred in self.predicates)

        return deps | predicate_names
Esempio n. 10
0
def _get_assignee_subscript_deps(expr):
    """Return the dependencies of the index of the lvalue *expr*.

    A ``Lookup`` is first replaced by its ``aggregate``. A plain ``Variable``
    then yields an empty frozenset, while ``Subscript`` and
    ``LinearSubscript`` yield the dependencies of their ``index``.

    Raises ``RuntimeError`` for any other kind of expression.
    """
    from pymbolic.primitives import Variable, Subscript, Lookup
    from loopy.symbolic import LinearSubscript, get_dependencies

    # Attribute lookups are checked through the object they are applied to.
    target = expr.aggregate if isinstance(expr, Lookup) else expr

    if isinstance(target, Variable):
        return frozenset()

    if isinstance(target, (Subscript, LinearSubscript)):
        return get_dependencies(target.index)

    raise RuntimeError("invalid lvalue '%s'" % target)
Esempio n. 11
0
def simplify_via_aff(expr):
    """Simplify *expr* by converting it to an ISL affine form and back.

    The names *expr* depends on become the set dimensions of the ISL space
    in which the conversion happens. They are sorted first so that the
    dimension order does not depend on the iteration order of the
    dependency set and is therefore the same from run to run.

    Returns the expression produced by ``aff_to_expr``. Exceptions raised by
    ``aff_from_expr`` are propagated unchanged.
    """
    from loopy.symbolic import aff_from_expr, aff_to_expr, get_dependencies

    # get_dependencies yields an unordered set of names; fix the order.
    dep_names = sorted(get_dependencies(expr))
    space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, dep_names)

    return aff_to_expr(aff_from_expr(space, expr))
Esempio n. 12
0
    def get_function_declaration(self, codegen_state, codegen_result,
            schedule_index):
        """Return the function declaration, decorated for OpenCL entrypoints.

        The declaration from the superclass is returned untouched when
        ``codegen_state.is_entrypoint`` is false. Otherwise its inner
        declaration is wrapped in ``CLKernel`` and, if the local sizes have
        no dependencies on any names, additionally in
        ``CLRequiredWorkGroupSize``. The result is re-wrapped in a
        ``FunctionDeclarationWrapper``.
        """
        wrapped = super().get_function_declaration(
                codegen_state, codegen_result, schedule_index)

        from loopy.target.c import FunctionDeclarationWrapper
        assert isinstance(wrapped, FunctionDeclarationWrapper)

        if not codegen_state.is_entrypoint:
            # auxiliary kernels need not mention OpenCL-specific qualifiers
            # in a function's signature
            return wrapped

        from cgen.opencl import CLKernel, CLRequiredWorkGroupSize
        decl = CLKernel(wrapped.subdecl)

        from loopy.schedule import get_insn_ids_for_block_at
        kernel = codegen_state.kernel
        block_insn_ids = get_insn_ids_for_block_at(
                kernel.linearization, schedule_index)
        _, local_sizes = kernel.get_grid_sizes_for_insn_ids_as_exprs(
                block_insn_ids, codegen_state.callables_table)

        from loopy.symbolic import get_dependencies
        sizes_are_static = not get_dependencies(local_sizes)
        if sizes_are_static:
            # sizes can't have parameter dependencies if they are
            # to be used in static WG size.
            decl = CLRequiredWorkGroupSize(local_sizes, decl)

        return FunctionDeclarationWrapper(decl)
Esempio n. 13
0
    def get_function_declaration(self, codegen_state, codegen_result,
                                 schedule_index):
        """Return the function declaration decorated with CUDA qualifiers.

        The inner declaration obtained from the superclass is wrapped in
        ``CudaGlobal``, in ``Extern("C", ...)`` when ``self.target.extern_c``
        is set, and in ``CudaLaunchBounds`` when the local grid size has no
        dependencies on any names. The result is re-wrapped in a
        ``FunctionDeclarationWrapper``.
        """
        wrapped = super().get_function_declaration(
                codegen_state, codegen_result, schedule_index)

        from loopy.target.c import FunctionDeclarationWrapper
        assert isinstance(wrapped, FunctionDeclarationWrapper)

        from cgen.cuda import CudaGlobal, CudaLaunchBounds
        decl = CudaGlobal(wrapped.subdecl)

        if self.target.extern_c:
            from cgen import Extern
            decl = Extern("C", decl)

        from loopy.schedule import get_insn_ids_for_block_at
        kernel = codegen_state.kernel
        block_insn_ids = get_insn_ids_for_block_at(
                kernel.linearization, schedule_index)
        _, local_grid_size = kernel.get_grid_sizes_for_insn_ids_as_exprs(
                block_insn_ids, codegen_state.callables_table)

        from loopy.symbolic import get_dependencies
        if not get_dependencies(local_grid_size):
            # Sizes can't have parameter dependencies if they are
            # to be used in static thread block size.
            from pytools import product
            decl = CudaLaunchBounds(product(local_grid_size), decl)

        return FunctionDeclarationWrapper(decl)
Esempio n. 14
0
def test_precompute_confusing_subst_arguments(ctx_factory):
    """Precompute a rule whose formal argument is named like a kernel iname.

    Rule ``D`` takes an argument called ``i`` (also an iname of the kernel)
    but is invoked as ``D(j)``. The rule is precomputed with sweep iname
    ``j`` and the result is compared against the untransformed kernel.
    """
    from loopy.symbolic import get_dependencies

    ctx = ctx_factory()

    knl = lp.make_kernel(
        "{[i,j]: 0<=i<n and 0<=j<5}",
        """
        D(i):=a[i+1]-a[i]
        b[i,j] = D(j)
        """)
    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32})

    # Reference for the final comparison: the kernel before transformation.
    ref_knl = knl

    knl = lp.tag_inames(knl, {"j": "g.1"})
    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")

    d_rule_deps = get_dependencies(knl.substitutions["D"].expression)
    assert "i_inner" not in d_rule_deps

    knl = lp.precompute(
            knl, "D",
            sweep_inames="j",
            precompute_outer_inames="j, i_inner, i_outer")

    lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters={"n": 12345})
Esempio n. 15
0
    def get_function_declaration(self, codegen_state, codegen_result,
                                 schedule_index):
        """Return the function declaration decorated as an OpenCL kernel.

        The inner declaration obtained from the superclass is wrapped in
        ``CLKernel`` and, when the local sizes have no dependencies on any
        names, in ``CLRequiredWorkGroupSize``. The result is re-wrapped in a
        ``FunctionDeclarationWrapper``.
        """
        wrapped = super(OpenCLCASTBuilder, self).get_function_declaration(
            codegen_state, codegen_result, schedule_index)

        from loopy.target.c import FunctionDeclarationWrapper
        assert isinstance(wrapped, FunctionDeclarationWrapper)

        from cgen.opencl import CLKernel, CLRequiredWorkGroupSize
        decl = CLKernel(wrapped.subdecl)

        from loopy.schedule import get_insn_ids_for_block_at
        kernel = codegen_state.kernel
        block_insn_ids = get_insn_ids_for_block_at(
            kernel.schedule, schedule_index)
        _, local_sizes = kernel.get_grid_sizes_for_insn_ids_as_exprs(
            block_insn_ids)

        from loopy.symbolic import get_dependencies
        sizes_are_static = not get_dependencies(local_sizes)
        if sizes_are_static:
            # sizes can't have parameter dependencies if they are
            # to be used in static WG size.
            decl = CLRequiredWorkGroupSize(local_sizes, decl)

        return FunctionDeclarationWrapper(decl)
Esempio n. 16
0
    def map_floor_div(self, expr, type_context):
        """Translate a floor division, choosing the form by operand sign.

        ``is_nonnegative`` is asked about numerator and denominator over the
        domain of the inames *expr* depends on, intersected with the kernel's
        assumptions:

        * denominator not reported nonnegative: call ``int_floor_div``;
        * denominator nonnegative, numerator not: call
          ``int_floor_div_pos_b``;
        * both nonnegative: use the ``//`` operator directly.

        Whenever a helper function is called, it is also recorded in
        ``self.codegen_state.seen_functions``.
        """
        from loopy.symbolic import get_dependencies
        from loopy.isl_helpers import is_nonnegative

        inames_used = get_dependencies(expr) & self.kernel.all_inames()
        domain = self.kernel.get_inames_domain(inames_used)

        param_assumptions = isl.BasicSet.from_params(self.kernel.assumptions)
        param_assumptions, domain = isl.align_two(param_assumptions, domain)
        domain = domain & param_assumptions

        numerator_nonneg = is_nonnegative(expr.numerator, domain)
        denominator_nonneg = is_nonnegative(expr.denominator, domain)

        def call_helper(name):
            # Register the helper as used, then build the call to it with
            # both operands translated in the 'i' type context.
            from loopy.codegen import SeenFunction
            idt = self.kernel.index_dtype
            self.codegen_state.seen_functions.add(
                SeenFunction(name, name, (idt, idt)))

            return var(name)(
                self.rec(expr.numerator, 'i'),
                self.rec(expr.denominator, 'i'))

        if not denominator_nonneg:
            return call_helper("int_floor_div")

        if not numerator_nonneg:
            return call_helper("int_floor_div_pos_b")

        # parenthesize to avoid negative signs being dragged in from the
        # outside by associativity
        return (
            self.rec(expr.numerator, type_context)
            // self.rec(expr.denominator, type_context))
Esempio n. 17
0
    def get_function_declaration(self, codegen_state, codegen_result,
            schedule_index):
        """Return the superclass's declaration decorated with CUDA qualifiers.

        The declaration is wrapped in ``CudaGlobal``, in ``Extern("C", ...)``
        when ``self.target.extern_c`` is set, and in ``CudaLaunchBounds``
        when the local grid size has no dependencies on any names.
        """
        from cgen.cuda import CudaGlobal, CudaLaunchBounds
        from loopy.schedule import get_insn_ids_for_block_at
        from loopy.symbolic import get_dependencies

        base_decl = super(CUDACASTBuilder, self).get_function_declaration(
                codegen_state, codegen_result, schedule_index)
        decl = CudaGlobal(base_decl)

        if self.target.extern_c:
            from cgen import Extern
            decl = Extern("C", decl)

        kernel = codegen_state.kernel
        block_insn_ids = get_insn_ids_for_block_at(
                kernel.schedule, schedule_index)
        _, local_grid_size = kernel.get_grid_sizes_for_insn_ids_as_exprs(
                block_insn_ids)

        if get_dependencies(local_grid_size):
            return decl

        # Sizes can't have parameter dependencies if they are
        # to be used in static thread block size.
        from pytools import product
        return CudaLaunchBounds(product(local_grid_size), decl)
Esempio n. 18
0
    def get_function_declaration(self, codegen_state, codegen_result,
                                 schedule_index):
        """Decorate the superclass's function declaration for CUDA.

        Applies ``CudaGlobal``; applies ``Extern("C", ...)`` if
        ``self.target.extern_c`` is set; applies ``CudaLaunchBounds`` with
        the product of the local grid size if that size has no dependencies
        on any names.
        """
        parent = super(CUDACASTBuilder, self)
        declaration = parent.get_function_declaration(
            codegen_state, codegen_result, schedule_index)

        from cgen.cuda import CudaGlobal, CudaLaunchBounds
        declaration = CudaGlobal(declaration)

        if self.target.extern_c:
            from cgen import Extern
            declaration = Extern("C", declaration)

        from loopy.schedule import get_insn_ids_for_block_at
        insn_ids = get_insn_ids_for_block_at(
            codegen_state.kernel.schedule, schedule_index)
        grid_sizes = codegen_state.kernel.get_grid_sizes_for_insn_ids_as_exprs(
            insn_ids)
        _, local_grid_size = grid_sizes

        from loopy.symbolic import get_dependencies
        has_symbolic_size = bool(get_dependencies(local_grid_size))
        if not has_symbolic_size:
            # Sizes can't have parameter dependencies if they are
            # to be used in static thread block size.
            from pytools import product
            thread_count = product(local_grid_size)
            declaration = CudaLaunchBounds(thread_count, declaration)

        return declaration
Esempio n. 19
0
    def read_dependency_names(self):
        """Return the union of the dependencies of all ``self.predicates``.

        The result is an empty frozenset when there are no predicates.
        """
        from loopy.symbolic import get_dependencies

        return frozenset().union(
            *(get_dependencies(pred) for pred in self.predicates))
Esempio n. 20
0
    def read_dependency_names(self):
        """Collect the names that the predicates of this object depend on.

        Returns a frozenset, which is empty if ``self.predicates`` is empty.
        """
        from loopy.symbolic import get_dependencies

        collected = set()
        for predicate in self.predicates:
            collected.update(get_dependencies(predicate))

        return frozenset(collected)
Esempio n. 21
0
    def _is_access_descriptor_in_footprint_inner(self, storage_axis_exprs):
        """Test the access given by *storage_axis_exprs* against the footprint
        described by ``self.stor2sweep``.

        Returns ``False`` if ``self.kernel.get_inames_domain`` raises
        ``CannotBranchDomainTree`` for the inames involved, ``True`` if
        ``build_per_access_storage_to_domain_map`` returns ``None``, and
        otherwise the result of an ISL subset test between the domain derived
        from this access and the domain derived from ``self.stor2sweep``.
        """
        # Make all inames except the sweep parameters. (The footprint may depend on
        # those.) (I.e. only leave sweep inames as out parameters.)

        global_s2s_par_dom = move_to_par_from_out(
                self.stor2sweep,
                except_inames=frozenset(self.primed_sweep_inames)).domain()

        # Inames relevant to this access: kernel inames that appear as
        # parameters of the global domain ...
        arg_inames = (
                set(global_s2s_par_dom.get_var_names(dim_type.param))
                & self.kernel.all_inames())

        # ... plus every name the storage axis expressions depend on.
        for arg in storage_axis_exprs:
            arg_inames.update(get_dependencies(arg))
        arg_inames = frozenset(arg_inames)

        from loopy.kernel import CannotBranchDomainTree
        try:
            usage_domain = self.kernel.get_inames_domain(arg_inames)
        except CannotBranchDomainTree:
            return False

        # Rename sweep inames in the usage domain by appending a prime (').
        for i in range(usage_domain.dim(dim_type.set)):
            iname = usage_domain.get_dim_name(dim_type.set, i)
            if iname in self.sweep_inames:
                usage_domain = usage_domain.set_dim_name(
                        dim_type.set, i, iname+"'")

        # NOTE(review): `prime_sweep_inames` (passed here) and
        # `primed_sweep_inames` (used above and below) are two distinct
        # attributes -- confirm both are defined on this class.
        stor2sweep = build_per_access_storage_to_domain_map(
                storage_axis_exprs,
                usage_domain, self.storage_axis_names,
                self.prime_sweep_inames)

        if stor2sweep is None:
            # happens if there are no indices
            # -> yes, in footprint
            return True

        # The operations below are applied to an isl.Map, so promote a
        # BasicMap first.
        if isinstance(stor2sweep, isl.BasicMap):
            stor2sweep = isl.Map.from_basic_map(stor2sweep)

        stor2sweep = stor2sweep.intersect_range(usage_domain)

        # Same transformation as was applied to self.stor2sweep above.
        stor2sweep = move_to_par_from_out(stor2sweep,
                except_inames=frozenset(self.primed_sweep_inames))

        s2s_domain = stor2sweep.domain()
        s2s_domain, aligned_g_s2s_parm_dom = isl.align_two(
                s2s_domain, global_s2s_par_dom)

        # Eliminate all set dimensions of the global domain (and drop divs),
        # leaving only what it says about the parameters.
        arg_restrictions = (
                aligned_g_s2s_parm_dom
                .eliminate(dim_type.set, 0,
                    aligned_g_s2s_parm_dom.dim(dim_type.set))
                .remove_divs())

        # The access is accepted if, under those parameter restrictions, its
        # domain is contained in the global domain.
        return (arg_restrictions & s2s_domain).is_subset(
                aligned_g_s2s_parm_dom)
Esempio n. 22
0
    def _is_access_descriptor_in_footprint_inner(self, storage_axis_exprs):
        """Check the access described by *storage_axis_exprs* against the
        footprint given by ``self.stor2sweep``.

        Returns ``False`` when the kernel cannot produce a domain for the
        involved inames (``CannotBranchDomainTree``), ``True`` when
        ``build_per_access_storage_to_domain_map`` yields ``None``, and
        otherwise whether the access's domain, restricted by the parameter
        constraints of the global domain, is a subset of that global domain.
        """
        # Make all inames except the sweep parameters. (The footprint may depend on
        # those.) (I.e. only leave sweep inames as out parameters.)

        global_s2s_par_dom = move_to_par_from_out(
                self.stor2sweep,
                except_inames=frozenset(self.primed_sweep_inames)).domain()

        # Start with the kernel inames found among the global domain's
        # parameters.
        arg_inames = (
                set(global_s2s_par_dom.get_var_names(dim_type.param))
                & self.kernel.all_inames())

        # Add the names used by the storage axis expressions themselves.
        for arg in storage_axis_exprs:
            arg_inames.update(get_dependencies(arg))
        arg_inames = frozenset(arg_inames)

        from loopy.kernel import CannotBranchDomainTree
        try:
            usage_domain = self.kernel.get_inames_domain(arg_inames)
        except CannotBranchDomainTree:
            return False

        # Sweep inames get a trailing prime (') in the usage domain.
        for i in range(usage_domain.dim(dim_type.set)):
            iname = usage_domain.get_dim_name(dim_type.set, i)
            if iname in self.sweep_inames:
                usage_domain = usage_domain.set_dim_name(
                        dim_type.set, i, iname+"'")

        # NOTE(review): this passes `self.prime_sweep_inames`, whereas the
        # surrounding code uses `self.primed_sweep_inames` -- verify that the
        # two differently named attributes are both intended.
        stor2sweep = build_per_access_storage_to_domain_map(
                storage_axis_exprs,
                usage_domain, self.storage_axis_names,
                self.prime_sweep_inames)

        if stor2sweep is None:
            # happens if there are no indices
            # -> yes, in footprint
            return True

        # Work on an isl.Map from here on.
        if isinstance(stor2sweep, isl.BasicMap):
            stor2sweep = isl.Map.from_basic_map(stor2sweep)

        stor2sweep = stor2sweep.intersect_range(usage_domain)

        # Mirror the treatment given to self.stor2sweep at the top.
        stor2sweep = move_to_par_from_out(stor2sweep,
                except_inames=frozenset(self.primed_sweep_inames))

        s2s_domain = stor2sweep.domain()
        s2s_domain, aligned_g_s2s_parm_dom = isl.align_two(
                s2s_domain, global_s2s_par_dom)

        # Keep only the parameter constraints of the global domain by
        # eliminating all of its set dimensions and removing divs.
        arg_restrictions = (
                aligned_g_s2s_parm_dom
                .eliminate(dim_type.set, 0,
                    aligned_g_s2s_parm_dom.dim(dim_type.set))
                .remove_divs())

        # Final containment test against the global domain.
        return (arg_restrictions & s2s_domain).is_subset(
                aligned_g_s2s_parm_dom)
Esempio n. 23
0
    def read_dependency_names(self):
        """Return the names read by this instruction.

        Extends the superclass result with the dependencies of
        ``self.expression`` and the dependencies of each assignee subscript.
        """
        from loopy.symbolic import get_dependencies

        deps = super(MultiAssignmentBase, self).read_dependency_names()
        deps = deps | get_dependencies(self.expression)

        for index_deps in self.assignee_subscript_deps():
            deps = deps | index_deps

        return deps
Esempio n. 24
0
    def read_dependency_names(self):
        """Return every name this instruction reads.

        The superclass result is combined with the dependencies of
        ``self.expression`` and with each entry of
        ``self.assignee_subscript_deps()``.
        """
        from loopy.symbolic import get_dependencies

        inherited = super(MultiAssignmentBase, self).read_dependency_names()
        names = inherited | get_dependencies(self.expression)

        for per_assignee_deps in self.assignee_subscript_deps():
            names = names | per_assignee_deps

        return names
Esempio n. 25
0
def simplify_via_aff(expr):
    """Simplify *expr* by a round trip through an ISL affine form.

    The sorted names *expr* depends on define the ISL space used for the
    conversion. If the conversion raises
    ``ExpressionToAffineConversionError``, *expr* is returned unchanged.
    """
    from loopy.symbolic import aff_to_expr, guarded_aff_from_expr, get_dependencies
    from loopy.diagnostic import ExpressionToAffineConversionError

    dep_names = sorted(get_dependencies(expr))

    try:
        space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, dep_names)
        return aff_to_expr(guarded_aff_from_expr(space, expr))
    except ExpressionToAffineConversionError:
        # Not representable as an affine expression: leave it as it is.
        return expr
Esempio n. 26
0
    def rec_undiff(self, expr, *args):
        """Return *expr* unchanged after importing the variables it uses.

        Every dependency of *expr* that names a kernel argument or temporary
        variable of the diff context's kernel is passed to
        ``import_output_var`` on ``self.diff_context``.
        """
        from loopy.symbolic import get_dependencies

        dc = self.diff_context

        for name in get_dependencies(expr):
            assert isinstance(name, str)

            # dc.kernel is looked up afresh on every iteration on purpose:
            # import_output_var may change it.
            is_kernel_var = (
                name in dc.kernel.arg_dict
                or name in dc.kernel.temporary_variables)
            if is_kernel_var:
                dc.import_output_var(name)

        return expr
Esempio n. 27
0
    def read_dependency_names(self):
        """Return the names read by this instruction.

        Combines the superclass result with ``self.read_variables``, the
        dependencies of each expression in ``self.iname_exprs``, the
        dependencies of each assignee subscript, and ``self.predicates``.
        """
        from loopy.symbolic import get_dependencies

        deps = super(CInstruction, self).read_dependency_names()
        deps = deps | frozenset(self.read_variables)

        for _iname, iname_expr in self.iname_exprs:
            deps = deps | get_dependencies(iname_expr)

        for index_deps in self.assignee_subscript_deps():
            deps = deps | index_deps

        return frozenset(deps) | self.predicates
Esempio n. 28
0
    def rec_undiff(self, expr, *args):
        """Import the kernel variables used by *expr* and return *expr* as is.

        A dependency is imported via ``self.diff_context.import_output_var``
        if it names an argument or a temporary variable of the diff context's
        kernel.
        """
        from loopy.symbolic import get_dependencies

        diff_ctx = self.diff_context

        for dep_name in get_dependencies(expr):
            assert isinstance(dep_name, str)

            # The kernel is re-read from the context each time, since
            # import_output_var may replace it.
            if dep_name in diff_ctx.kernel.arg_dict:
                diff_ctx.import_output_var(dep_name)
            elif dep_name in diff_ctx.kernel.temporary_variables:
                diff_ctx.import_output_var(dep_name)

        return expr
Esempio n. 29
0
    def find_unifiable_cf_index(index_key):
        """Find the first entry of ``common_factors`` unifiable with *index_key*.

        Returns ``(position, unification)`` for the first key that unifies,
        or ``(None, None)`` if none does.
        """
        for position, (cf_key, _) in enumerate(common_factors):
            unifier = UnidirectionalUnifier(
                lhs_mapping_candidates=get_dependencies(cf_key))
            matches = unifier(cf_key, index_key)

            if not matches:
                continue

            # A successful unification is expected to be unique.
            assert len(matches) == 1
            return position, matches[0]

        return None, None
Esempio n. 30
0
def check_for_writes_to_predicates(kernel):
    """Reject instructions that write to variables used in their predicates.

    For every instruction of *kernel*, the names its predicates depend on
    are intersected with the names it assigns to.

    Raises ``LoopyError`` for the first instruction where that intersection
    is non-empty. The offending names are listed in sorted order so that the
    message is the same from run to run.
    """
    from loopy.symbolic import get_dependencies

    for insn in kernel.instructions:
        if insn.predicates:
            pred_vars = frozenset.union(
                    *(get_dependencies(pred) for pred in insn.predicates))
        else:
            # frozenset.union() needs at least one operand.
            pred_vars = frozenset()

        written_pred_vars = frozenset(insn.assignee_var_names()) & pred_vars
        if written_pred_vars:
            raise LoopyError("In instruction '%s': may not write to "
                    "variable(s) '%s' involved in the instruction's predicates"
                    % (insn.id, ", ".join(sorted(written_pred_vars))))
Esempio n. 31
0
    def find_unifiable_cf_index(index_key):
        """Look for a key in ``common_factors`` that unifies with *index_key*.

        Returns the index of the first such entry together with the single
        unification result, or ``(None, None)`` when no key unifies.
        """
        for cf_index, (candidate_key, _unused_val) in enumerate(common_factors):
            unify = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(candidate_key))
            unifications = unify(candidate_key, index_key)

            if unifications:
                # Exactly one way of unifying is expected.
                assert len(unifications) == 1
                return cf_index, unifications[0]

        return None, None
Esempio n. 32
0
def check_that_shapes_and_strides_are_arguments(kernel):
    """Verify that array shapes and strides only depend on integer arguments.

    For every ``ArrayBase`` argument of *kernel*, the names appearing in its
    shape (when the shape is a tuple) and in each fixed stride that is not
    ``lp.auto`` must be names of ``ValueArg`` arguments with an integral
    dtype.

    Raises ``LoopyError`` naming the array and the offending names. The
    names are listed in sorted order so that the message is the same from
    run to run.
    """
    from loopy.kernel.data import ValueArg
    from loopy.kernel.array import ArrayBase, FixedStrideArrayDimTag
    from loopy.symbolic import get_dependencies
    import loopy as lp

    integer_arg_names = {
        arg.name
        for arg in kernel.args
        if isinstance(arg, ValueArg) and arg.dtype.is_integral()
    }

    for arg in kernel.args:
        if not isinstance(arg, ArrayBase):
            continue

        if isinstance(arg.shape, tuple):
            shape_deps = set()
            for shape_axis in arg.shape:
                if shape_axis is not None:
                    shape_deps.update(get_dependencies(shape_axis))

            bad_shape_deps = shape_deps - integer_arg_names
            if bad_shape_deps:
                raise LoopyError(
                    "'%s' has a shape that depends on "
                    "non-argument(s): %s" %
                    (arg.name, ", ".join(sorted(bad_shape_deps))))

        if arg.dim_tags is None:
            continue

        for dim_tag in arg.dim_tags:
            if not isinstance(dim_tag, FixedStrideArrayDimTag):
                continue

            # Strides still marked 'auto' have no expression to check.
            if dim_tag.stride is lp.auto:
                continue

            bad_stride_deps = (
                get_dependencies(dim_tag.stride) - integer_arg_names)
            if bad_stride_deps:
                raise LoopyError(
                    "'%s' has a stride that depends on "
                    "non-argument(s): %s" %
                    (arg.name, ", ".join(sorted(bad_stride_deps))))
Esempio n. 33
0
    def write_dependency_names(self):
        """Return the names involved in the left-hand sides of this
        instruction: each assignee together with the dependencies of its
        indices, as a frozenset.
        """
        from loopy.symbolic import get_dependencies

        lhs_names = set()
        for assignee_name, index_exprs in self.assignees_and_indices():
            lhs_names.add(assignee_name)
            lhs_names.update(get_dependencies(index_exprs))

        return frozenset(lhs_names)
Esempio n. 34
0
    def read_dependency_names(self):
        """Return every name this instruction reads.

        Starts from the superclass result and adds ``self.read_variables``,
        the dependencies of the expressions in ``self.iname_exprs``, the
        entries of ``self.assignee_subscript_deps()`` and
        ``self.predicates``.
        """
        from loopy.symbolic import get_dependencies

        inherited = super(CInstruction, self).read_dependency_names()
        names = inherited | frozenset(self.read_variables)

        for _name, iname_expr in self.iname_exprs:
            names = names | get_dependencies(iname_expr)

        for per_assignee_deps in self.assignee_subscript_deps():
            names = names | per_assignee_deps

        return frozenset(names) | self.predicates
Esempio n. 35
0
def check_that_shapes_and_strides_are_arguments(kernel):
    """Check that array shapes and fixed strides only refer to integer value
    arguments of *kernel*.

    :arg kernel: the kernel whose ``args`` are inspected.
    :raises LoopyError: if a shape axis or a (non-``auto``) fixed stride of
        an array argument depends on a name that is not a
        :class:`loopy.kernel.data.ValueArg` whose dtype kind is ``"i"``.
    """
    from loopy.kernel.data import ValueArg
    from loopy.kernel.array import ArrayBase, FixedStrideArrayDimTag
    from loopy.symbolic import get_dependencies
    import loopy as lp

    # Names that shapes and strides are allowed to depend on.
    integer_arg_names = {
            candidate.name
            for candidate in kernel.args
            if isinstance(candidate, ValueArg)
            and candidate.dtype.kind == "i"}

    for arg in kernel.args:
        if not isinstance(arg, ArrayBase):
            continue

        if isinstance(arg.shape, tuple):
            shape_deps = set()
            for axis_len in arg.shape:
                if axis_len is None:
                    # Axis length not specified: nothing to check.
                    continue
                shape_deps.update(get_dependencies(axis_len))

            offending = shape_deps - integer_arg_names
            if offending:
                raise LoopyError("'%s' has a shape that depends on "
                        "non-argument(s): %s" % (
                            arg.name, ", ".join(offending)))

        if arg.dim_tags is None:
            continue

        for dim_tag in arg.dim_tags:
            if not isinstance(dim_tag, FixedStrideArrayDimTag):
                continue

            if dim_tag.stride is lp.auto:
                # Stride has not been determined yet.
                continue

            offending = get_dependencies(dim_tag.stride) - integer_arg_names
            if offending:
                raise LoopyError("'%s' has a stride that depends on "
                        "non-argument(s): %s" % (
                            arg.name, ", ".join(offending)))
Esempio n. 36
0
    def map_substitution(self, name, tag, arguments, expn_state):
        """Record an invocation of the substitution rule being gathered.

        Invocations that do not match ``self.subst_name`` / ``self.subst_tag``,
        that are rejected by ``self.within``, or whose arguments depend on
        anything other than inames are passed on to the base class
        implementation (the last case after emitting a warning).

        For every other invocation a :class:`RuleAccessDescriptor` is appended
        to ``self.access_descriptors``.

        :returns: the base class result for invocations that are passed on,
            otherwise ``0`` as a placeholder.
        """
        def defer_to_base():
            return super(RuleInvocationGatherer,
                         self).map_substitution(name, tag, arguments,
                                                expn_state)

        tag_excluded = self.subst_tag is not None and self.subst_tag != tag
        if name != self.subst_name or tag_excluded:
            return defer_to_base()

        if not self.within(
                expn_state.kernel, expn_state.instruction, expn_state.stack):
            return defer_to_base()

        rule = self.rule_mapping_context.old_subst_rules[name]
        arg_context = self.make_new_arg_context(name, rule.arguments,
                                                arguments,
                                                expn_state.arg_context)

        # Gather everything the (fully expanded) argument values depend on.
        arg_deps = set()
        for arg_val in six.itervalues(arg_context):
            arg_deps.update(get_dependencies(self.subst_expander(arg_val)))

        # FIXME: This is too strict--and the footprint machinery
        # needs to be taught how to deal with locally constant
        # variables.
        non_inames = arg_deps - self.kernel.all_inames()
        if non_inames:
            from warnings import warn
            warn("Precompute arguments in '%s(%s)' do not consist exclusively "
                 "of inames and constants--specifically, these are "
                 "not inames: %s. Ignoring." % (
                     name,
                     ", ".join(str(arg) for arg in arguments),
                     ", ".join(non_inames),
                 ))

            return defer_to_base()

        # Argument values in the order in which the rule declares them.
        args = [arg_context[arg_name] for arg_name in rule.arguments]

        descriptor = RuleAccessDescriptor(
            identifier=access_descriptor_id(args, expn_state.stack),
            args=args,
        )
        self.access_descriptors.append(descriptor)

        return 0  # exact value irrelevant
Esempio n. 37
0
    def map_substitution(self, name, tag, arguments, expn_state):
        """Gather one access descriptor per matching rule invocation.

        An invocation matches if its name equals ``self.subst_name``, its tag
        is compatible with ``self.subst_tag`` (when that is set) and
        ``self.within`` accepts the current kernel/instruction/stack. Matching
        invocations whose expanded arguments depend only on inames are
        recorded in ``self.access_descriptors``; everything else is handled by
        the base class.

        :returns: ``0`` (a placeholder) for recorded invocations, otherwise
            whatever the base class implementation returns.
        """
        base_impl = super(RuleInvocationGatherer, self).map_substitution

        is_target = name == self.subst_name
        if self.subst_tag is not None and self.subst_tag != tag:
            is_target = False

        if is_target:
            is_target = self.within(
                    expn_state.kernel,
                    expn_state.instruction,
                    expn_state.stack)

        if not is_target:
            return base_impl(name, tag, arguments, expn_state)

        rule = self.rule_mapping_context.old_subst_rules[name]
        arg_context = self.make_new_arg_context(
                    name, rule.arguments, arguments, expn_state.arg_context)

        # Union of the dependencies of all expanded argument values.
        arg_deps = set().union(*[
                get_dependencies(self.subst_expander(arg_val))
                for arg_val in six.itervalues(arg_context)])

        # FIXME: This is too strict--and the footprint machinery
        # needs to be taught how to deal with locally constant
        # variables.
        all_inames = self.kernel.all_inames()
        if not arg_deps <= all_inames:
            from warnings import warn
            warn("Precompute arguments in '%s(%s)' do not consist exclusively "
                    "of inames and constants--specifically, these are "
                    "not inames: %s. Ignoring." % (
                        name,
                        ", ".join(str(arg) for arg in arguments),
                        ", ".join(arg_deps - all_inames),
                        ))

            return base_impl(name, tag, arguments, expn_state)

        args = []
        for arg_name in rule.arguments:
            args.append(arg_context[arg_name])

        self.access_descriptors.append(
                RuleAccessDescriptor(
                    identifier=access_descriptor_id(args, expn_state.stack),
                    args=args,
                    ))

        return 0  # exact value irrelevant
Esempio n. 38
0
def check_identifiers_in_subst_rules(knl):
    """Substitution rules may only refer to kernel-global quantities or their
    own arguments.

    :arg knl: the kernel whose ``substitutions`` are checked.
    :raises LoopyError: if the expression of a substitution rule depends on
        an identifier that is neither one of the rule's arguments nor
        contained in ``knl.all_variable_names()``.
    """

    from loopy.symbolic import get_dependencies

    allowed_identifiers = knl.all_variable_names()

    for rule in six.itervalues(knl.substitutions):
        deps = get_dependencies(rule.expression)
        rule_allowed_identifiers = allowed_identifiers | frozenset(rule.arguments)

        if not deps <= rule_allowed_identifiers:
            # The message has three placeholders (kernel name, rule name,
            # offending identifiers); all three values must be supplied or
            # the formatting itself fails with a TypeError.
            raise LoopyError("kernel '%s': substitution rule '%s' refers to "
                    "identifier(s) '%s' which are neither rule arguments nor "
                    "kernel-global identifiers"
                    % (knl.name, rule.name,
                       ", ".join(deps-rule_allowed_identifiers)))
Esempio n. 39
0
    def map_subscript(self, expr):
        """Determine whether the array access *expr* can be vectorized along
        ``self.vec_iname``.

        :arg expr: a subscript expression whose aggregate names a kernel
            argument or temporary variable.
        :returns: *True* if at least one axis tagged with
            :class:`loopy.kernel.array.VectorArrayDimTag` is indexed by
            exactly the vectorizing iname, *False* otherwise.
        :raises LoopyError: if the subscripted name is unknown or does not
            refer to an array.
        :raises Unvectorizable: if the length of a matching vector axis
            differs from ``self.vec_iname_length``, or if the vectorizing
            iname occurs in the index of any other axis.
        """
        name = expr.aggregate.name

        # Arguments take precedence over temporaries of the same name.
        var = self.kernel.arg_dict.get(name)
        if var is None:
            var = self.kernel.temporary_variables.get(name)

        if var is None:
            raise LoopyError("unknown array variable in subscript: %s"
                    % name)

        from loopy.kernel.array import ArrayBase
        if not isinstance(var, ArrayBase):
            raise LoopyError("non-array subscript '%s'" % expr)

        # Normalize a single (non-tuple) index to a one-element tuple.
        index = expr.index
        if not isinstance(index, tuple):
            index = (index,)

        from loopy.symbolic import get_dependencies
        from loopy.kernel.array import VectorArrayDimTag
        from pymbolic.primitives import Variable

        possible = False
        for i, axis_len in enumerate(var.shape):
            if (
                    isinstance(var.dim_tags[i], VectorArrayDimTag)
                    and isinstance(index[i], Variable)
                    and index[i].name == self.vec_iname):
                if axis_len != self.vec_iname_length:
                    raise Unvectorizable("vector length was mismatched")

                possible = True

            elif self.vec_iname in get_dependencies(index[i]):
                raise Unvectorizable("vectorizing iname '%s' occurs in "
                        "unvectorized subscript axis %d (1-based) of "
                        "expression '%s'"
                        % (self.vec_iname, i+1, expr))

        return possible
Esempio n. 40
0
    def guess_kernel_args_if_requested(self, kernel_args):
        """Fill in guessed kernel arguments if *kernel_args* asks for it.

        Guessing is requested by including the string ``"..."`` or
        ``Ellipsis`` in *kernel_args*. Without such a marker *kernel_args* is
        returned unchanged. Otherwise the marker is dropped and one argument
        (built by ``self.make_new_arg``) is appended, in sorted name order,
        for each name in ``self.all_names`` or ``self.all_params`` that is
        not already an argument, a temporary variable or an iname.

        As a side effect, names occurring in tuple-valued shapes of array
        arguments are added to ``self.all_names``.

        :returns: the (possibly extended) list of kernel arguments.
        """
        # Ellipsis is syntactically allowed in Py3.
        if not ("..." in kernel_args or Ellipsis in kernel_args):
            return kernel_args

        kernel_args = [given for given in kernel_args
                if given is not Ellipsis and given != "..."]

        # {{{ find names that are *not* arguments

        temp_var_names = set(six.iterkeys(self.temporary_variables))

        for insn in self.instructions:
            if (isinstance(insn, ExpressionInstruction)
                    and insn.temp_var_type is not None):
                # The instruction declares its assignee as a temporary.
                (assignee_var_name, _), = insn.assignees_and_indices()
                temp_var_names.add(assignee_var_name)

        # }}}

        # {{{ find existing and new arg names

        existing_arg_names = {given.name for given in kernel_args}

        not_new_arg_names = existing_arg_names | temp_var_names | self.all_inames

        from loopy.kernel.data import ArrayBase
        from loopy.symbolic import get_dependencies
        for given in kernel_args:
            if isinstance(given, ArrayBase) and isinstance(given.shape, tuple):
                self.all_names.update(
                        get_dependencies(given.shape))

        new_arg_names = (self.all_names | self.all_params) - not_new_arg_names

        # }}}

        kernel_args.extend(
                [self.make_new_arg(arg_name)
                    for arg_name in sorted(new_arg_names)])

        return kernel_args
Esempio n. 41
0
def check_identifiers_in_subst_rules(knl):
    """Ensure that every substitution rule of *knl* only uses identifiers
    that are kernel-global or arguments of that rule.

    :raises LoopyError: if a rule's expression depends on any identifier
        outside ``knl.all_variable_names()`` and the rule's own arguments.
    """

    from loopy.symbolic import get_dependencies

    kernel_identifiers = knl.all_variable_names()

    for rule in six.itervalues(knl.substitutions):
        used = get_dependencies(rule.expression)
        permitted = kernel_identifiers | frozenset(rule.arguments)

        unknown = used - permitted
        if not unknown:
            continue

        raise LoopyError("kernel '%s': substitution rule '%s' refers to "
                "identifier(s) '%s' which are neither rule arguments nor "
                "kernel-global identifiers"
                % (knl.name, rule.name, ", ".join(unknown)))
Esempio n. 42
0
    def map_subscript(self, expr):
        """Determine whether the array access *expr* can be vectorized along
        ``self.vec_iname``.

        :arg expr: a subscript expression whose aggregate names a kernel
            argument or temporary variable.
        :returns: *True* if at least one axis tagged with
            :class:`loopy.kernel.array.VectorArrayDimTag` is indexed by
            exactly the vectorizing iname, *False* otherwise.
        :raises LoopyError: if the subscripted name is unknown or does not
            refer to an array.
        :raises Unvectorizable: if the length of a matching vector axis
            differs from ``self.vec_iname_length``, or if the vectorizing
            iname occurs in the index of any other axis.
        """
        name = expr.aggregate.name

        # Arguments take precedence over temporaries of the same name.
        var = self.kernel.arg_dict.get(name)
        if var is None:
            var = self.kernel.temporary_variables.get(name)

        if var is None:
            raise LoopyError("unknown array variable in subscript: %s"
                    % name)

        from loopy.kernel.array import ArrayBase
        if not isinstance(var, ArrayBase):
            raise LoopyError("non-array subscript '%s'" % expr)

        index = expr.index_tuple

        from loopy.symbolic import get_dependencies
        from loopy.kernel.array import VectorArrayDimTag
        from pymbolic.primitives import Variable

        possible = False
        for i, axis_len in enumerate(var.shape):
            if (
                    isinstance(var.dim_tags[i], VectorArrayDimTag)
                    and isinstance(index[i], Variable)
                    and index[i].name == self.vec_iname):
                if axis_len != self.vec_iname_length:
                    raise Unvectorizable("vector length was mismatched")

                possible = True

            elif self.vec_iname in get_dependencies(index[i]):
                raise Unvectorizable("vectorizing iname '%s' occurs in "
                        "unvectorized subscript axis %d (1-based) of "
                        "expression '%s'"
                        % (self.vec_iname, i+1, expr))

        return possible
Esempio n. 43
0
    def map_floor_div(self, expr, enclosing_prec, type_context):
        """Generate code text for the floor division *expr*.

        The signs of numerator and denominator are examined on the domain of
        the inames *expr* depends on, intersected with the kernel's
        assumptions. If both are found to be nonnegative, a plain ``/`` is
        emitted. Otherwise a call to ``int_floor_div_pos_b`` (denominator
        nonnegative) or ``int_floor_div`` is emitted and the helper is
        registered in ``self.codegen_state.seen_functions``.

        :returns: a string containing the generated expression.
        """
        from loopy.symbolic import get_dependencies
        from loopy.isl_helpers import is_nonnegative

        used_inames = get_dependencies(expr) & self.kernel.all_inames()
        domain = self.kernel.get_inames_domain(used_inames)

        # Bring the assumptions into the domain's space before intersecting.
        assumption_non_param = isl.BasicSet.from_params(self.kernel.assumptions)
        assumptions, domain = isl.align_two(assumption_non_param, domain)
        domain = domain & assumptions

        num_nonneg = is_nonnegative(expr.numerator, domain)
        den_nonneg = is_nonnegative(expr.denominator, domain)

        def seen_func(name):
            from loopy.codegen import SeenFunction

            idt = self.kernel.index_dtype
            self.codegen_state.seen_functions.add(SeenFunction(name, name, (idt, idt)))

        def emit_helper_call(func_name):
            # Note the helper as used, then generate the call to it.
            seen_func(func_name)
            numerator = self.rec(expr.numerator, PREC_NONE, "i")
            denominator = self.rec(expr.denominator, PREC_NONE, "i")
            return "%s(%s, %s)" % (func_name, numerator, denominator)

        if not den_nonneg:
            return emit_helper_call("int_floor_div")

        if not num_nonneg:
            return emit_helper_call("int_floor_div_pos_b")

        numerator = self.rec(expr.numerator, PREC_PRODUCT, type_context)
        # analogous to ^{-1}
        denominator = self.rec(expr.denominator, PREC_POWER, type_context)

        # parenthesize to avoid negative signs being dragged in from the
        # outside by associativity
        return "(%s / %s)" % (numerator, denominator)
Esempio n. 44
0
    def get_function_declaration(self, codegen_state, codegen_result,
            schedule_index):
        """Wrap the base class's function declaration for OpenCL.

        The declaration is wrapped in :class:`cgen.opencl.CLKernel`. If the
        local sizes of the schedule block at *schedule_index* have no
        dependencies, it is additionally wrapped in
        :class:`cgen.opencl.CLRequiredWorkGroupSize`.

        :returns: the wrapped function declaration.
        """
        base_decl = super(OpenCLCASTBuilder, self).get_function_declaration(
                codegen_state, codegen_result, schedule_index)

        from cgen.opencl import CLKernel, CLRequiredWorkGroupSize
        from loopy.schedule import get_insn_ids_for_block_at
        from loopy.symbolic import get_dependencies

        kernel_decl = CLKernel(base_decl)

        block_insn_ids = get_insn_ids_for_block_at(
                codegen_state.kernel.schedule, schedule_index)
        _, local_sizes = codegen_state.kernel.get_grid_sizes_for_insn_ids_as_exprs(
                block_insn_ids)

        if get_dependencies(local_sizes):
            return kernel_decl

        # sizes can't have parameter dependencies if they are
        # to be used in static WG size.
        return CLRequiredWorkGroupSize(local_sizes, kernel_decl)
Esempio n. 45
0
    def add_diff_inames(self):
        """Create one new iname per entry of ``self.additional_shape``.

        The names are generated from ``self.diff_iname_prefix`` followed by
        the axis number. A domain over these inames, restricted for each
        iname by ``make_slab(..., 0, <axis length>)`` and parametrized by
        the dependencies of the axis lengths, is appended to
        ``self.new_domains``.

        :returns: a tuple of the new iname names.
        """
        from loopy.symbolic import get_dependencies

        make_unique = self.rule_mapping_context.make_unique_var_name
        diff_inames = tuple(
            make_unique(self.diff_iname_prefix+str(axis))
            for axis in range(len(self.additional_shape)))

        # Everything the axis lengths depend on becomes a domain parameter.
        diff_parameters = set()
        for axis_len in self.additional_shape:
            diff_parameters.update(get_dependencies(axis_len))

        diff_domain = isl.BasicSet(
                "[%s] -> {[%s]}"
                % (", ".join(diff_parameters), ", ".join(diff_inames)))

        for diff_iname, axis_len in zip(diff_inames, self.additional_shape):
            diff_domain = diff_domain & make_slab(
                diff_domain.space, diff_iname, 0, axis_len)

        self.new_domains.append(diff_domain)

        return diff_inames
Esempio n. 46
0
    def add_diff_inames(self):
        """Introduce an iname for each axis in ``self.additional_shape`` and
        register the domain that bounds them.

        Each iname name is obtained from the rule mapping context's unique
        name generator, based on ``self.diff_iname_prefix`` plus the axis
        number. The resulting domain is appended to ``self.new_domains``.

        :returns: a tuple of the new iname names.
        """
        from loopy.symbolic import get_dependencies

        shape = self.additional_shape

        names = []
        for axis in range(len(shape)):
            names.append(
                self.rule_mapping_context.make_unique_var_name(
                    self.diff_iname_prefix + str(axis)))
        diff_inames = tuple(names)

        # Parameters of the new domain: all names the axis lengths refer to.
        diff_parameters = set().union(
            *[get_dependencies(axis_len) for axis_len in shape])

        space_description = "[%s] -> {[%s]}" % (
            ", ".join(diff_parameters), ", ".join(diff_inames))
        diff_domain = isl.BasicSet(space_description)

        for axis, diff_iname in enumerate(diff_inames):
            slab = make_slab(diff_domain.space, diff_iname, 0, shape[axis])
            diff_domain = diff_domain & slab

        self.new_domains.append(diff_domain)

        return diff_inames
Esempio n. 47
0
    def map_integer_div_operator(self, base_func_name, op_func, expr, type_context):
        """Map an integer division-like operation *expr*.

        The signs of numerator and denominator are examined on the domain of
        the inames *expr* depends on, intersected with the kernel's
        assumptions. If both are found to be nonnegative, *op_func* is
        applied to the mapped operands. Otherwise a call to a helper named
        after *base_func_name* and the result type is generated, and that
        helper is registered in ``self.codegen_state.seen_functions``.

        :arg base_func_name: base name of the helper functions.
        :arg op_func: callable combining the mapped numerator and denominator.
        :returns: the mapped expression.
        """
        from loopy.symbolic import get_dependencies
        from loopy.isl_helpers import is_nonnegative

        used_inames = get_dependencies(expr) & self.kernel.all_inames()
        domain = self.kernel.get_inames_domain(used_inames)

        # Bring the assumptions into the domain's space before intersecting.
        assumption_non_param = isl.BasicSet.from_params(self.kernel.assumptions)
        assumptions, domain = isl.align_two(assumption_non_param, domain)
        domain = domain & assumptions

        num_nonneg = is_nonnegative(expr.numerator, domain)
        den_nonneg = is_nonnegative(expr.denominator, domain)

        result_dtype = self.infer_type(expr)
        suffix = result_dtype.numpy_dtype.type.__name__

        def seen_func(name):
            from loopy.codegen import SeenFunction
            self.codegen_state.seen_functions.add(
                    SeenFunction(
                        name, "%s_%s" % (name, suffix),
                        (result_dtype, result_dtype)))

        def call_helper(func_name):
            # Register the helper, then build a call to its typed variant.
            seen_func(func_name)
            typed_func = var("%s_%s" % (func_name, suffix))
            return typed_func(
                    self.rec(expr.numerator, 'i'),
                    self.rec(expr.denominator, 'i'))

        if not den_nonneg:
            return call_helper(base_func_name)

        if not num_nonneg:
            return call_helper("%s_pos_b" % base_func_name)

        return op_func(
                self.rec(expr.numerator, type_context),
                self.rec(expr.denominator, type_context))
Esempio n. 48
0
def test_precompute_confusing_subst_arguments(ctx_factory):
    """Precompute a rule whose formal argument is named like a kernel iname
    (``i``) but which is invoked with a different iname (``j``), and compare
    the result against the untransformed kernel.
    """
    from loopy.symbolic import get_dependencies

    ctx = ctx_factory()

    ref_knl = lp.make_kernel(
        "{[i,j]: 0<=i<n and 0<=j<5}",
        """
        D(i):=a[i+1]-a[i]
        b[i,j] = D(j)
        """)
    ref_knl = lp.add_and_infer_dtypes(ref_knl, {"a": np.float32})

    transformed = lp.tag_inames(ref_knl, {"j": "g.1"})
    transformed = lp.split_iname(
            transformed, "i", 128, outer_tag="g.0", inner_tag="l.0")

    # Splitting 'i' must not have touched the rule's own argument 'i'.
    rule_deps = get_dependencies(transformed.substitutions["D"].expression)
    assert "i_inner" not in rule_deps

    transformed = lp.precompute(transformed, "D")

    lp.auto_test_vs_ref(
            ref_knl, ctx, transformed,
            parameters={"n": 12345})
Esempio n. 49
0
def substitute_into_domain(domain, param_name, expr, allowed_param_dims):
    """Constrain the dimension *param_name* of *domain* to equal *expr* and
    project that dimension out.

    :arg domain: the set to operate on. It is returned unchanged if it has
        no dimension named *param_name*.
    :arg param_name: name of the dimension to substitute.
    :arg expr: the expression substituted for *param_name*.
    :arg allowed_param_dims: names that may be added to *domain* as
        parameter dimensions. Every dependency of *expr* must be among them.
    :returns: the resulting domain.
    :raises ValueError: if *expr* depends on a name that is not in
        *allowed_param_dims*.
    """
    import pymbolic.primitives as prim
    from loopy.symbolic import get_dependencies, isl_set_from_expr

    try:
        dt, pos = domain.get_var_dict()[param_name]
    except KeyError:
        # param_name not in domain => domain will be unchanged
        return domain

    # {{{ rename 'param_name' to avoid namespace pollution with allowed_param_dims

    non_clashing_name = UniqueNameGenerator(set(allowed_param_dims))(param_name)
    domain = domain.set_dim_name(dt, pos, non_clashing_name)

    # }}}

    param_type = isl.dim_type.param
    for dep in get_dependencies(expr):
        if dep not in allowed_param_dims:
            raise ValueError("Augmenting caller's domain "
                             f"with '{dep}' is not allowed.")

        # Append a fresh parameter dimension and name it after the dependency.
        domain = domain.add_dims(param_type, 1)
        domain = domain.set_dim_name(
            param_type, domain.dim(param_type) - 1, dep)

    # NOTE(review): the constraint refers to 'param_name', not to the name
    # chosen by the generator above -- confirm this is intended for the case
    # in which the two differ.
    equality = prim.Comparison(prim.Variable(param_name), "==", expr)
    set_ = isl_set_from_expr(domain.space, equality)

    # Exactly one basic set is expected; anything else fails the unpacking.
    bset, = set_.get_basic_sets()
    domain = domain & bset

    return domain.project_out(dt, pos, 1)
Esempio n. 50
0
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Move factors common to all updates of *var_name* to its usage sites.

    The right-hand side of every assignment to *var_name* is viewed as a sum
    of products. Factors that occur in every term of every such assignment
    (not counting a term that is identical to the assignee) are removed from
    these assignments. In all other
    :class:`loopy.kernel.data.Assignment` instructions, accesses to
    *var_name* are multiplied by the removed factors instead.

    :arg kernel: the kernel to transform. Substitution rules, if present,
        are expanded first.
    :arg var_name: name of a temporary variable or an argument of *kernel*.
    :arg vary_by_axes: axes of *var_name* whose index expressions
        distinguish separately tracked sets of common factors. Entries may
        be axis indices or, if the array has ``dim_names``, axis names. A
        string is split at commas.
    :returns: a copy of *kernel* with the modified instructions.
    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: if an axis name is unknown, an axis index is out of
        bounds, *var_name* is written by a non-assignment instruction, or
        *var_name* occurs inside a factor on the right-hand side of one of
        its own assignments.
    """
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        vary_by_axes = vary_by_axes.split(",")

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = dict(
                    (name, idx)
                    for idx, name in enumerate(var_descr.dim_names))
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # Valid axis indices are 0 .. num_user_axes()-1, so an index equal
        # to the axis count is already out of bounds.
        if (
                vary_by_axes
                and
                (min(vary_by_axes) < 0
                or
                max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero,
            flattened_sum, flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
            UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found)
    common_factors = []

    def find_unifiable_cf_index(index_key):
        """Return ``(i, unification)`` for the first entry of
        *common_factors* whose key unifies with *index_key*, or
        ``(None, None)`` if there is none.
        """
        for i, (key, val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                    lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        """Return the index expressions of *access_expr* along
        *vary_by_axes* (an empty tuple for a plain variable).
        """
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return any(
                lhs == var_name
                for lhs, sbscript in insn.assignees_and_indices())

    def iterate_as(cls, expr):
        """Yield the children of *expr* if it is an instance of *cls*,
        otherwise yield *expr* itself.
        """
        if isinstance(expr, cls):
            for ch in expr.children:
                yield ch
        else:
            yield expr

    def checked_product_parts(term, insn):
        """Return the set of factors of *term*, rejecting factors that
        depend on *var_name*.
        """
        for part in iterate_as(Product, term):
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                        "in RHS of instruction '%s'"
                        % (var_name, insn.id))

        return set(iterate_as(Product, term))

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-expression instruction"
                    % var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    # The "previous value" term of an increment.
                    continue

                product_parts = checked_product_parts(term, insn)

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            _, my_common_factors = common_factors[cf_index]

            unif_subst_map = SubstitutionMapper(
                    make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                product_parts = checked_product_parts(term, insn)

                # Keep only those known factors that (after renaming per
                # the unification) also occur in this term.
                my_common_factors = set(
                        cf for cf in my_common_factors
                        if unif_subst_map(cf) in product_parts)

            common_factors[cf_index] = (index_key, my_common_factors)

            # }}}

    # }}}

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        mapped_my_common_factors = set(
                unif_subst_map(cf)
                for cf in my_common_factors)

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                new_sum_terms.append(term)
                continue

            new_sum_terms.append(
                    flattened_product([
                        part
                        for part in iterate_as(Product, term)
                        if part not in mapped_my_common_factors
                        ]))

        new_insns.append(
                insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # No common factors were collected for this access pattern
            # (and so none were removed above): leave the access alone.
            return expr

        unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        if my_common_factors is not None:
            return flattened_product(
                    [unif_subst_map(cf) for cf in my_common_factors]
                    + [expr])
        else:
            return expr

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
Esempio n. 51
0
def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()):
    """Pull factors common to all updates of *var_name* out to its usage sites.

    Every :class:`Assignment` that writes *var_name* has its right-hand side
    viewed as a sum of products. Terms equal to the left-hand side (the
    "increment" part of ``a = a + ...``) are left alone; among the remaining
    terms, the factors shared by all of them are determined. Those factors
    are removed from the writing instructions and multiplied onto every
    occurrence of *var_name* in the assignments that do not write it.

    :arg kernel: the kernel to transform. Substitution rules, if present,
        are expanded first.
    :arg var_name: name of a temporary variable or argument of *kernel*.
    :arg vary_by_axes: axes of *var_name* along which the common factors are
        allowed to differ, as an iterable of axis indices and/or axis names
        (for arrays carrying ``dim_names``), or a comma-separated string.

    :returns: a copy of *kernel* with updated instructions.

    :raises NameError: if *var_name* is neither a temporary nor an argument.
    :raises LoopyError: if an axis name is unknown, an axis index is out of
        bounds, *var_name* is written by something other than an
        :class:`Assignment`, a factor on a right-hand side depends on
        *var_name*, or no common factors exist.
    """
    assert isinstance(kernel, LoopKernel)
    # FIXME: Does not understand subst rules for now
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    if var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    elif var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    else:
        raise NameError("array '%s' was not found" % var_name)

    # {{{ check/normalize vary_by_axes

    if isinstance(vary_by_axes, str):
        # Tolerate whitespace around the commas ("ax0, ax1").
        vary_by_axes = [ax.strip() for ax in vary_by_axes.split(",")]

    from loopy.kernel.array import ArrayBase
    if isinstance(var_descr, ArrayBase):
        if var_descr.dim_names is not None:
            name_to_index = {
                name: idx
                for idx, name in enumerate(var_descr.dim_names)
            }
        else:
            name_to_index = {}

        def map_ax_name_to_index(ax):
            if isinstance(ax, str):
                try:
                    return name_to_index[ax]
                except KeyError:
                    raise LoopyError("axis name '%s' not understood " % ax)
            else:
                return ax

        vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes]

        # Valid axis indices are 0 .. num_user_axes()-1; an index equal to
        # the axis count would only fail later when indexing index_tuple.
        if (vary_by_axes
                and (min(vary_by_axes) < 0
                     or max(vary_by_axes) >= var_descr.num_user_axes())):
            raise LoopyError("vary_by_axes refers to out-of-bounds axis index")

    # }}}

    from pymbolic.mapper.substitutor import make_subst_func
    from pymbolic.primitives import (Sum, Product, is_zero, flattened_sum,
                                     flattened_product, Subscript, Variable)
    from loopy.symbolic import (get_dependencies, SubstitutionMapper,
                                UnidirectionalUnifier)

    # {{{ common factor key list maintenance

    # list of (index_key, common factors found); the factors of each entry
    # are expressed in terms of the variables occurring in its index_key
    common_factors = []

    def find_unifiable_cf_index(index_key):
        """Return ``(position, unification record)`` of the first entry whose
        key unifies with *index_key*, or ``(None, None)`` if there is none.
        """
        for i, (key, _val) in enumerate(common_factors):
            unif = UnidirectionalUnifier(
                lhs_mapping_candidates=get_dependencies(key))

            unif_result = unif(key, index_key)

            if unif_result:
                assert len(unif_result) == 1
                return i, unif_result[0]

        return None, None

    def extract_index_key(access_expr):
        """Return the tuple of index expressions along *vary_by_axes*."""
        if isinstance(access_expr, Variable):
            return ()

        elif isinstance(access_expr, Subscript):
            index = access_expr.index_tuple
            return tuple(index[ax] for ax in vary_by_axes)
        else:
            raise ValueError("unexpected type of access_expr")

    def is_assignee(insn):
        return var_name in insn.assignee_var_names()

    def iterate_as(cls, expr):
        """Yield the children of *expr* if it is a *cls*, else *expr* itself."""
        if isinstance(expr, cls):
            yield from expr.children
        else:
            yield expr

    def check_no_dependency_on_var(term, insn):
        """Reject terms in which any factor reads *var_name*."""
        for part in iterate_as(Product, term):
            if var_name in get_dependencies(part):
                raise LoopyError("unexpected dependency on '%s' "
                                 "in RHS of instruction '%s'" %
                                 (var_name, insn.id))

    # }}}

    # {{{ find common factors

    from loopy.kernel.data import Assignment

    for insn in kernel.instructions:
        if not is_assignee(insn):
            continue

        if not isinstance(insn, Assignment):
            raise LoopyError("'%s' modified by non-single-assignment" %
                             var_name)

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # {{{ doesn't exist yet

            assert unif_result is None

            my_common_factors = None

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                check_no_dependency_on_var(term, insn)

                product_parts = set(iterate_as(Product, term))

                if my_common_factors is None:
                    my_common_factors = product_parts
                else:
                    my_common_factors = my_common_factors & product_parts

            if my_common_factors is not None:
                common_factors.append((index_key, my_common_factors))

            # }}}
        else:
            # {{{ match, filter existing common factors

            existing_key, my_common_factors = common_factors[cf_index]

            # Maps variables of the stored key to this instruction's indices.
            unif_subst_map = SubstitutionMapper(
                make_subst_func(unif_result.lmap))

            for term in iterate_as(Sum, rhs):
                if term == lhs:
                    continue

                check_no_dependency_on_var(term, insn)

                product_parts = set(iterate_as(Product, term))

                my_common_factors = {
                    cf
                    for cf in my_common_factors
                    if unif_subst_map(cf) in product_parts
                }

            # The surviving factors are still written in terms of the stored
            # key's variables, so that key must stay associated with them.
            common_factors[cf_index] = (existing_key, my_common_factors)

            # }}}

    # }}}

    # Drop entries for which nothing in common remained.
    common_factors = [(ik, cf) for ik, cf in common_factors if cf]

    if not common_factors:
        raise LoopyError("no common factors found")

    # {{{ remove common factors

    new_insns = []

    for insn in kernel.instructions:
        if not isinstance(insn, Assignment) or not is_assignee(insn):
            new_insns.append(insn)
            continue

        lhs = insn.assignee
        rhs = insn.expression

        if is_zero(rhs):
            new_insns.append(insn)
            continue

        index_key = extract_index_key(lhs)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            new_insns.append(insn)
            continue

        _, my_common_factors = common_factors[cf_index]

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        mapped_my_common_factors = {
            unif_subst_map(cf)
            for cf in my_common_factors
        }

        new_sum_terms = []

        for term in iterate_as(Sum, rhs):
            if term == lhs:
                # the increment term itself carries no factors to strip
                new_sum_terms.append(term)
                continue

            new_sum_terms.append(
                flattened_product([
                    part for part in iterate_as(Product, term)
                    if part not in mapped_my_common_factors
                ]))

        new_insns.append(insn.copy(expression=flattened_sum(new_sum_terms)))

    # }}}

    # {{{ substitute common factors into usage sites

    def find_substitution(expr):
        if isinstance(expr, Subscript):
            v = expr.aggregate.name
        elif isinstance(expr, Variable):
            v = expr.name
        else:
            return expr

        if v != var_name:
            return expr

        index_key = extract_index_key(expr)
        cf_index, unif_result = find_unifiable_cf_index(index_key)

        if cf_index is None:
            # No factors were extracted for this index pattern, so this
            # access must stay as it is.
            return expr

        unif_subst_map = SubstitutionMapper(make_subst_func(unif_result.lmap))

        _, my_common_factors = common_factors[cf_index]

        return flattened_product(
            [unif_subst_map(cf) for cf in my_common_factors] + [expr])

    insns = new_insns
    new_insns = []

    subm = SubstitutionMapper(find_substitution)

    for insn in insns:
        # Writers of var_name were already handled above; only readers get
        # the factors multiplied back in.
        if not isinstance(insn, Assignment) or is_assignee(insn):
            new_insns.append(insn)
            continue

        new_insns.append(insn.with_transformed_expressions(subm))

    # }}}

    return kernel.copy(instructions=new_insns)
Esempio n. 52
0
def precompute(
        kernel,
        subst_use,
        sweep_inames=[],
        within=None,
        storage_axes=None,
        temporary_name=None,
        precompute_inames=None,
        precompute_outer_inames=None,
        storage_axis_to_tag={},

        # "None" is a valid value here, distinct from the default.
        default_tag=_not_provided,
        dtype=None,
        fetch_bounding_box=False,
        temporary_address_space=None,
        compute_insn_id=None,
        **kwargs):
    """Precompute the expression described in the substitution rule determined by
    *subst_use* and store it in a temporary array. A precomputation needs two
    things to operate, a list of *sweep_inames* (order irrelevant) and an
    ordered list of *storage_axes* (whose order will describe the axis ordering
    of the temporary array).

    :arg subst_use: Describes what to prefetch.

        The following objects may be given for *subst_use*:

        * The name of the substitution rule.

        * The tagged name ("name$tag") of the substitution rule.

        * A list of invocations of the substitution rule.
          This list of invocations, when swept across *sweep_inames*, then serves
          to define the footprint of the precomputation.

          Invocations may be tagged ("name$tag") to filter out a subset of the
          usage sites of the substitution rule. (Namely those usage sites that
          use the same tagged name.)

          Invocations may be given as a string or as a
          :class:`pymbolic.primitives.Expression` object.

          If only one invocation is to be given, then the only entry of the list
          may be given directly.

    If the list of invocations generating the footprint is not given,
    all (tag-matching, if desired) usage sites of the substitution rule
    are used to determine the footprint.

    The following cases can arise for each sweep axis:

    * The axis is an iname that occurs within arguments specified at
      usage sites of the substitution rule. This case is assumed covered
      by the storage axes provided for the argument.

    * The axis is an iname that occurs within the *value* of the rule, but not
      within its arguments. A new, dedicated storage axis is allocated for
      such an axis.

    :arg sweep_inames: A :class:`list` of inames to be swept.
        May also equivalently be a comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.match.parse_stack_match`.
    :arg storage_axes: A :class:`list` of inames and/or rule argument
        names/indices to be used as storage axes.
        May also equivalently be a comma-separated string.
    :arg temporary_name:
        The temporary variable name to use for storing the precomputed data.
        If it does not exist, it will be created. If it does exist, its properties
        (such as size, type) are checked (and updated, if possible) to match
        its use.
    :arg precompute_inames:
        A tuple of inames to be used to carry out the precomputation.
        If the specified inames do not already exist, they will be
        created. If they do already exist, their loop domain is verified
        against the one required for this precomputation. This tuple may
        be shorter than the (provided or automatically found) *storage_axes*
        tuple, in which case names will be automatically created.
        May also equivalently be a comma-separated string.

    :arg precompute_outer_inames: A :class:`frozenset` of inames within which
        the compute instruction is nested. If *None*, make an educated guess.
        May also be specified as a comma-separated string.

    :arg default_tag: The :ref:`iname tag <iname-tags>` to be applied to the
        inames created to perform the precomputation. The current default will
        make them local axes and automatically split them to fit the work
        group size, but this default will disappear in favor of simply leaving them
        untagged in 2019. For 2018, a warning will be issued if no *default_tag* is
        specified.

    :arg compute_insn_id: The ID of the instruction generated to perform the
        precomputation.

    If `storage_axes` is not specified, it defaults to the arrangement
    `<direct sweep axes><arguments>` with the direct sweep axes being the
    slower-varying indices.

    Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are
    eliminated.
    """

    # {{{ unify temporary_address_space / temporary_scope

    temporary_scope = kwargs.pop("temporary_scope", None)

    from loopy.kernel.data import AddressSpace
    if temporary_scope is not None:
        from warnings import warn
        warn(
            "temporary_scope is deprecated. Use temporary_address_space instead",
            DeprecationWarning,
            stacklevel=2)

        if temporary_address_space is not None:
            raise LoopyError(
                "may not specify both temporary_address_space and "
                "temporary_scope")

        temporary_address_space = temporary_scope

    del temporary_scope

    # }}}

    if kwargs:
        raise TypeError("unrecognized keyword arguments: %s" %
                        ", ".join(kwargs.keys()))

    # {{{ check, standardize arguments

    if isinstance(sweep_inames, str):
        sweep_inames = [iname.strip() for iname in sweep_inames.split(",")]

    for iname in sweep_inames:
        if iname not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname" % iname)

    sweep_inames = list(sweep_inames)
    sweep_inames_set = frozenset(sweep_inames)

    if isinstance(storage_axes, str):
        storage_axes = [ax.strip() for ax in storage_axes.split(",")]

    if isinstance(precompute_inames, str):
        precompute_inames = [
            iname.strip() for iname in precompute_inames.split(",")
        ]

    if isinstance(precompute_outer_inames, str):
        precompute_outer_inames = frozenset(
            iname.strip() for iname in precompute_outer_inames.split(","))

    if isinstance(subst_use, str):
        subst_use = [subst_use]

    footprint_generators = None

    subst_name = None
    subst_tag = None

    from pymbolic.primitives import Variable, Call
    from loopy.symbolic import parse, TaggedVariable

    for use in subst_use:
        if isinstance(use, str):
            use = parse(use)

        if isinstance(use, Call):
            if footprint_generators is None:
                footprint_generators = []

            footprint_generators.append(use)
            subst_name_as_expr = use.function
        else:
            subst_name_as_expr = use

        if isinstance(subst_name_as_expr, TaggedVariable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = subst_name_as_expr.tag
        elif isinstance(subst_name_as_expr, Variable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = None
        else:
            raise ValueError("unexpected type of subst_name")

        if (subst_name, subst_tag) == (None, None):
            subst_name, subst_tag = new_subst_name, new_subst_tag
        else:
            if (subst_name, subst_tag) != (new_subst_name, new_subst_tag):
                raise ValueError("not all uses in subst_use agree "
                                 "on rule name and tag")

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    try:
        subst = kernel.substitutions[subst_name]
    except KeyError:
        raise LoopyError("substitution rule '%s' not found" % subst_name)

    c_subst_name = subst_name.replace(".", "_")

    # {{{ handle default_tag

    from loopy.transform.data import _not_provided \
            as transform_data_not_provided

    if default_tag is _not_provided or default_tag is transform_data_not_provided:
        # no need to warn for scalar precomputes
        if sweep_inames:
            from warnings import warn
            warn(
                "Not specifying default_tag is deprecated, and default_tag "
                "will become mandatory in 2019.x. "
                "Pass 'default_tag=\"l.auto\" to match the current default, "
                "or Pass 'default_tag=None to leave the loops untagged, which "
                "is the recommended behavior.",
                DeprecationWarning,
                stacklevel=(

                    # In this case, we came here through add_prefetch. Increase
                    # the stacklevel.
                    3 if default_tag is transform_data_not_provided else 2))

        default_tag = "l.auto"

    from loopy.kernel.data import parse_tag
    default_tag = parse_tag(default_tag)

    # }}}

    # }}}

    # {{{ process invocations in footprint generators, start access_descriptors

    if footprint_generators:
        from pymbolic.primitives import Variable, Call

        access_descriptors = []

        for fpg in footprint_generators:
            if isinstance(fpg, Variable):
                args = ()
            elif isinstance(fpg, Call):
                args = fpg.parameters
            else:
                raise ValueError("footprint generator must "
                                 "be substitution rule invocation")

            access_descriptors.append(
                RuleAccessDescriptor(identifier=access_descriptor_id(
                    args, None),
                                     args=args))

    # }}}

    # {{{ gather up invocations in kernel code, finish access_descriptors

    if not footprint_generators:
        rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
        invg = RuleInvocationGatherer(rule_mapping_context, kernel, subst_name,
                                      subst_tag, within)
        del rule_mapping_context

        import loopy as lp
        for insn in kernel.instructions:
            if isinstance(insn, lp.MultiAssignmentBase):
                for assignee in insn.assignees:
                    invg(assignee, kernel, insn)
                invg(insn.expression, kernel, insn)

        access_descriptors = invg.access_descriptors
        if not access_descriptors:
            raise RuntimeError("no invocations of '%s' found" % subst_name)

    # }}}

    # {{{ find inames used in arguments

    expanding_usage_arg_deps = set()

    for accdesc in access_descriptors:
        for arg in accdesc.args:
            expanding_usage_arg_deps.update(
                get_dependencies(arg) & kernel.all_inames())

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    # {{{ use given / find new storage_axes

    # extra axes made necessary because they don't occur in the arguments
    extra_storage_axes = set(sweep_inames_set - expanding_usage_arg_deps)

    from loopy.symbolic import SubstitutionRuleExpander
    submap = SubstitutionRuleExpander(kernel.substitutions)

    value_inames = (get_dependencies(submap(subst.expression)) -
                    frozenset(subst.arguments)) & kernel.all_inames()
    if value_inames - expanding_usage_arg_deps < extra_storage_axes:
        raise RuntimeError("unreferenced sweep inames specified: " +
                           ", ".join(extra_storage_axes - value_inames -
                                     expanding_usage_arg_deps))

    new_iname_to_tag = {}

    if storage_axes is None:
        storage_axes = []

        # Add sweep_inames (in given--rather than arbitrary--order) to
        # storage_axes *if* they are part of extra_storage_axes.
        for iname in sweep_inames:
            if iname in extra_storage_axes:
                extra_storage_axes.remove(iname)
                storage_axes.append(iname)

        if extra_storage_axes:
            if (precompute_inames is not None
                    and len(storage_axes) < len(precompute_inames)):
                raise LoopyError(
                    "must specify a sufficient number of "
                    "storage_axes to uniquely determine the meaning "
                    "of the given precompute_inames. (%d storage_axes "
                    "needed)" % len(precompute_inames))
            storage_axes.extend(sorted(extra_storage_axes))

        storage_axes.extend(range(len(subst.arguments)))

    del extra_storage_axes

    prior_storage_axis_name_dict = {}

    storage_axis_names = []
    storage_axis_sources = []  # number for arg#, or iname

    # {{{ check for pre-existing precompute_inames

    if precompute_inames is not None:
        preexisting_precompute_inames = (set(precompute_inames)
                                         & kernel.all_inames())
    else:
        preexisting_precompute_inames = set()

    # }}}

    for i, saxis in enumerate(storage_axes):
        tag_lookup_saxis = saxis

        if saxis in subst.arguments:
            saxis = subst.arguments.index(saxis)

        storage_axis_sources.append(saxis)

        if isinstance(saxis, int):
            # argument index
            name = old_name = subst.arguments[saxis]
        else:
            old_name = saxis
            name = "%s_%s" % (c_subst_name, old_name)

        if (precompute_inames is not None and i < len(precompute_inames)
                and precompute_inames[i]):
            name = precompute_inames[i]
            tag_lookup_saxis = name
            if (name not in preexisting_precompute_inames
                    and var_name_gen.is_name_conflicting(name)):
                raise RuntimeError("new storage axis name '%s' "
                                   "conflicts with existing name" % name)
        else:
            name = var_name_gen(name)

        storage_axis_names.append(name)
        if name not in preexisting_precompute_inames:
            new_iname_to_tag[name] = storage_axis_to_tag.get(
                tag_lookup_saxis, default_tag)

        prior_storage_axis_name_dict[name] = old_name

    del storage_axis_to_tag
    del storage_axes
    del precompute_inames

    # }}}

    # {{{ fill out access_descriptors[...].storage_axis_exprs

    access_descriptors = [
        accdesc.copy(storage_axis_exprs=storage_axis_exprs(
            storage_axis_sources, accdesc.args))
        for accdesc in access_descriptors
    ]

    # }}}

    expanding_inames = sweep_inames_set | frozenset(expanding_usage_arg_deps)
    assert expanding_inames <= kernel.all_inames()

    if storage_axis_names:
        # {{{ find domain to be changed

        change_inames = expanding_inames | preexisting_precompute_inames

        from loopy.kernel.tools import DomainChanger
        domch = DomainChanger(kernel, change_inames)

        if domch.leaf_domain_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.

            for iname in sweep_inames_set:
                if kernel.get_home_domain_index(
                        iname) != domch.leaf_domain_index:
                    raise RuntimeError(
                        "sweep iname '%s' is not 'at home' in the "
                        "sweep's leaf domain" % iname)

        # }}}

        abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames,
                               access_descriptors, len(storage_axis_names))

        non1_storage_axis_names = []
        for i, saxis in enumerate(storage_axis_names):
            if abm.non1_storage_axis_flags[i]:
                non1_storage_axis_names.append(saxis)
            else:
                del new_iname_to_tag[saxis]

                if saxis in preexisting_precompute_inames:
                    raise LoopyError(
                        "precompute axis %d (1-based) was "
                        "eliminated as "
                        "having length 1 but also mapped to existing "
                        "iname '%s'" % (i + 1, saxis))

        mod_domain = domch.domain

        # {{{ modify the domain, taking into account preexisting inames

        # inames may already exist in mod_domain, add them primed to start
        primed_non1_saxis_names = [
            iname + "'" for iname in non1_storage_axis_names
        ]

        mod_domain = abm.augment_domain_with_sweep(
            domch.domain,
            primed_non1_saxis_names,
            boxify_sweep=fetch_bounding_box)

        check_domain = mod_domain

        for i, saxis in enumerate(non1_storage_axis_names):
            var_dict = mod_domain.get_var_dict(isl.dim_type.set)

            if saxis in preexisting_precompute_inames:
                # add equality constraint between existing and new variable

                dt, dim_idx = var_dict[saxis]
                saxis_aff = isl.Aff.var_on_domain(mod_domain.space, dt,
                                                  dim_idx)

                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt,
                                                    dim_idx)

                mod_domain = mod_domain.add_constraint(
                    isl.Constraint.equality_from_aff(new_var_aff - saxis_aff))

                # project out the new one
                mod_domain = mod_domain.project_out(dt, dim_idx, 1)

            else:
                # remove the prime from the new variable
                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis)

        def add_assumptions(d):
            assumption_non_param = isl.BasicSet.from_params(kernel.assumptions)
            assumptions, domain = isl.align_two(assumption_non_param, d)
            return assumptions & domain

        # {{{ check that we got the desired domain

        check_domain = add_assumptions(
            check_domain.project_out_except(primed_non1_saxis_names,
                                            [isl.dim_type.set]))

        mod_check_domain = add_assumptions(mod_domain)

        # re-add the prime from the new variable
        var_dict = mod_check_domain.get_var_dict(isl.dim_type.set)

        for saxis in non1_storage_axis_names:
            dt, dim_idx = var_dict[saxis]
            mod_check_domain = mod_check_domain.set_dim_name(
                dt, dim_idx, saxis + "'")

        mod_check_domain = mod_check_domain.project_out_except(
            primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain, check_domain = isl.align_two(mod_check_domain,
                                                       check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("domain of preexisting inames does not match "
                             "domain needed for precompute")

        # }}}

        # {{{ check that we didn't shrink the original domain

        # project out the new names from the modified domain
        orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set))
        mod_check_domain = add_assumptions(
            mod_domain.project_out_except(orig_domain_inames,
                                          [isl.dim_type.set]))

        check_domain = add_assumptions(domch.domain)

        mod_check_domain, check_domain = isl.align_two(mod_check_domain,
                                                       check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError(
                "original domain got shrunk by applying the precompute")

        # }}}

        # }}}

        new_kernel_domains = domch.get_domains_with(mod_domain)

    else:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains

        non1_storage_axis_names = []
        abm = NoOpArrayToBufferMap()

    kernel = kernel.copy(domains=new_kernel_domains)

    # {{{ set up compute insn

    if temporary_name is None:
        temporary_name = var_name_gen(based_on=c_subst_name)

    assignee = var(temporary_name)

    if non1_storage_axis_names:
        assignee = assignee[tuple(
            var(iname) for iname in non1_storage_axis_names)]

    # {{{ process substitutions on compute instruction

    storage_axis_subst_dict = {}

    for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices):
        if arg_name in non1_storage_axis_names:
            arg = var(arg_name)
        else:
            arg = 0

        storage_axis_subst_dict[prior_storage_axis_name_dict.get(
            arg_name, arg_name)] = arg + bi

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())

    from loopy.match import parse_stack_match
    expr_subst_map = RuleAwareSubstitutionMapper(
        rule_mapping_context,
        make_subst_func(storage_axis_subst_dict),
        within=parse_stack_match(None))

    compute_expression = expr_subst_map(subst.expression, kernel, None)

    # }}}

    from loopy.kernel.data import Assignment
    if compute_insn_id is None:
        compute_insn_id = kernel.make_unique_instruction_id(
            based_on=c_subst_name)

    compute_insn = Assignment(
        id=compute_insn_id,
        assignee=assignee,
        expression=compute_expression,
        # within_inames determined below
    )
    compute_dep_id = compute_insn_id
    added_compute_insns = [compute_insn]

    if temporary_address_space == AddressSpace.GLOBAL:
        barrier_insn_id = kernel.make_unique_instruction_id(
            based_on=c_subst_name + "_barrier")
        from loopy.kernel.instruction import BarrierInstruction
        barrier_insn = BarrierInstruction(id=barrier_insn_id,
                                          depends_on=frozenset(
                                              [compute_insn_id]),
                                          synchronization_kind="global",
                                          mem_kind="global")
        compute_dep_id = barrier_insn_id

        added_compute_insns.append(barrier_insn)

    # }}}

    # {{{ substitute rule into expressions in kernel (if within footprint)

    from loopy.symbolic import SubstitutionRuleExpander
    expander = SubstitutionRuleExpander(kernel.substitutions)

    invr = RuleInvocationReplacer(rule_mapping_context,
                                  subst_name,
                                  subst_tag,
                                  within,
                                  access_descriptors,
                                  abm,
                                  storage_axis_names,
                                  storage_axis_sources,
                                  non1_storage_axis_names,
                                  temporary_name,
                                  compute_insn_id,
                                  compute_dep_id,
                                  compute_read_variables=get_dependencies(
                                      expander(compute_expression)))

    kernel = invr.map_kernel(kernel)
    kernel = kernel.copy(instructions=added_compute_insns +
                         kernel.instructions)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # {{{ add dependencies to compute insn

    kernel = kernel.copy(instructions=[
        insn.copy(depends_on=frozenset(invr.compute_insn_depends_on)) if insn.
        id == compute_insn_id else insn for insn in kernel.instructions
    ])

    # }}}

    # {{{ propagate storage iname subst to dependencies of compute instructions

    from loopy.kernel.tools import find_recursive_dependencies
    compute_deps = find_recursive_dependencies(kernel,
                                               frozenset([compute_insn_id]))

    # FIXME: Need to verify that there are no outside dependencies
    # on compute_deps

    prior_storage_axis_names = frozenset(storage_axis_subst_dict)

    new_insns = []
    for insn in kernel.instructions:
        if (insn.id in compute_deps
                and insn.within_inames & prior_storage_axis_names):
            insn = (insn.with_transformed_expressions(
                lambda expr: expr_subst_map(expr, kernel, insn)).copy(
                    within_inames=frozenset(
                        storage_axis_subst_dict.get(iname, var(iname)).name
                        for iname in insn.within_inames)))

            new_insns.append(insn)
        else:
            new_insns.append(insn)

    kernel = kernel.copy(instructions=new_insns)

    # }}}

    # {{{ determine inames for compute insn

    if precompute_outer_inames is None:
        from loopy.kernel.tools import guess_iname_deps_based_on_var_use
        precompute_outer_inames = (
            frozenset(non1_storage_axis_names)
            | frozenset((expanding_usage_arg_deps | value_inames) -
                        sweep_inames_set)
            | guess_iname_deps_based_on_var_use(kernel, compute_insn))
    else:
        if not isinstance(precompute_outer_inames, frozenset):
            raise TypeError("precompute_outer_inames must be a frozenset")

        precompute_outer_inames = precompute_outer_inames \
                | frozenset(non1_storage_axis_names)

    kernel = kernel.copy(instructions=[
        insn.copy(within_inames=precompute_outer_inames) if insn.id ==
        compute_insn_id else insn for insn in kernel.instructions
    ])

    # }}}

    # {{{ set up temp variable

    import loopy as lp
    if dtype is not None:
        dtype = np.dtype(dtype)

    if temporary_address_space is None:
        temporary_address_space = lp.auto

    new_temp_shape = tuple(abm.non1_storage_shape)

    new_temporary_variables = kernel.temporary_variables.copy()
    if temporary_name not in new_temporary_variables:
        temp_var = lp.TemporaryVariable(
            name=temporary_name,
            dtype=dtype,
            base_indices=(0, ) * len(new_temp_shape),
            shape=tuple(abm.non1_storage_shape),
            address_space=temporary_address_space,
            dim_names=tuple(non1_storage_axis_names))

    else:
        temp_var = new_temporary_variables[temporary_name]

        # {{{ check and adapt existing temporary

        if temp_var.dtype is lp.auto:
            pass
        elif temp_var.dtype is not lp.auto and dtype is lp.auto:
            dtype = temp_var.dtype
        elif temp_var.dtype is not lp.auto and dtype is not lp.auto:
            if temp_var.dtype != dtype:
                raise LoopyError("Existing and new dtype of temporary '%s' "
                                 "do not match (existing: %s, new: %s)" %
                                 (temporary_name, temp_var.dtype, dtype))

        temp_var = temp_var.copy(dtype=dtype)

        if len(temp_var.shape) != len(new_temp_shape):
            raise LoopyError(
                "Existing and new temporary '%s' do not "
                "have matching number of dimensions ('%d' vs. '%d') " %
                (temporary_name, len(temp_var.shape), len(new_temp_shape)))

        if temp_var.base_indices != (0, ) * len(new_temp_shape):
            raise LoopyError(
                "Existing and new temporary '%s' do not "
                "have matching number of dimensions ('%d' vs. '%d') " %
                (temporary_name, len(temp_var.shape), len(new_temp_shape)))

        new_temp_shape = tuple(
            max(i, ex_i) for i, ex_i in zip(new_temp_shape, temp_var.shape))

        temp_var = temp_var.copy(shape=new_temp_shape)

        if temporary_address_space == temp_var.address_space:
            pass
        elif temporary_address_space is lp.auto:
            temporary_address_space = temp_var.address_space
        elif temp_var.address_space is lp.auto:
            pass
        else:
            raise LoopyError("Existing and new temporary '%s' do not "
                             "have matching scopes (existing: %s, new: %s)" %
                             (temporary_name,
                              AddressSpace.stringify(temp_var.address_space),
                              AddressSpace.stringify(temporary_address_space)))

        temp_var = temp_var.copy(address_space=temporary_address_space)

        # }}}

    new_temporary_variables[temporary_name] = temp_var

    kernel = kernel.copy(temporary_variables=new_temporary_variables)

    # }}}

    from loopy import tag_inames
    kernel = tag_inames(kernel, new_iname_to_tag)

    from loopy.kernel.data import AutoFitLocalIndexTag, filter_iname_tags_by_type

    if filter_iname_tags_by_type(new_iname_to_tag.values(),
                                 AutoFitLocalIndexTag):
        from loopy.kernel.tools import assign_automatic_axes
        kernel = assign_automatic_axes(kernel)

    return kernel
Esempio n. 53
0
def precompute(kernel, subst_use, sweep_inames=[], within=None,
        storage_axes=None, temporary_name=None, precompute_inames=None,
        storage_axis_to_tag={}, default_tag="l.auto", dtype=None,
        fetch_bounding_box=False, temporary_is_local=None,
        compute_insn_id=None):
    """Precompute the expression described in the substitution rule determined by
    *subst_use* and store it in a temporary array. A precomputation needs two
    things to operate, a list of *sweep_inames* (order irrelevant) and an
    ordered list of *storage_axes* (whose order will describe the axis ordering
    of the temporary array).

    :arg subst_use: Describes what to prefetch.

    The following objects may be given for *subst_use*:

    * The name of the substitution rule.

    * The tagged name ("name$tag") of the substitution rule.

    * A list of invocations of the substitution rule.
      This list of invocations, when swept across *sweep_inames*, then serves
      to define the footprint of the precomputation.

      Invocations may be tagged ("name$tag") to filter out a subset of the
      usage sites of the substitution rule. (Namely those usage sites that
      use the same tagged name.)

      Invocations may be given as a string or as a
      :class:`pymbolic.primitives.Expression` object.

      If only one invocation is to be given, then the only entry of the list
      may be given directly.

    If the list of invocations generating the footprint is not given,
    all (tag-matching, if desired) usage sites of the substitution rule
    are used to determine the footprint.

    The following cases can arise for each sweep axis:

    * The axis is an iname that occurs within arguments specified at
      usage sites of the substitution rule. This case is assumed covered
      by the storage axes provided for the argument.

    * The axis is an iname that occurs within the *value* of the rule, but not
      within its arguments. A new, dedicated storage axis is allocated for
      such an axis.

    :arg sweep_inames: A :class:`list` of inames and/or rule argument
        names to be swept.
        May also equivalently be a comma-separated string.
    :arg storage_axes: A :class:`list` of inames and/or rule argument
        names/indices to be used as storage axes.
        May also equivalently be a comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.
    :arg temporary_name:
        The temporary variable name to use for storing the precomputed data.
        If it does not exist, it will be created. If it does exist, its properties
        (such as size, type) are checked (and updated, if possible) to match
        its use.
    :arg precompute_inames:
        A tuple of inames to be used to carry out the precomputation.
        If the specified inames do not already exist, they will be
        created. If they do already exist, their loop domain is verified
        against the one required for this precomputation. This tuple may
        be shorter than the (provided or automatically found) *storage_axes*
        tuple, in which case names will be automatically created.
        May also equivalently be a comma-separated string.
    :arg storage_axis_to_tag:
        A :class:`dict` consulted for the tag of each newly created storage
        iname. It is keyed by the storage axis as given in *storage_axes* or,
        if a name was supplied through *precompute_inames*, by that name.
    :arg default_tag:
        The tag (parsed with :func:`loopy.kernel.data.parse_tag`) given to
        newly created storage inames that have no entry in
        *storage_axis_to_tag*.
    :arg dtype:
        The dtype of the temporary. *None* is turned into ``loopy.auto``,
        anything else is passed through :class:`numpy.dtype`.
    :arg fetch_bounding_box:
        Passed on as *boxify_sweep* when the domain is augmented with the
        storage inames.
    :arg temporary_is_local:
        The *is_local* attribute of the temporary. *None* is turned into
        ``loopy.auto``.

    :arg compute_insn_id: The ID of the instruction performing the precomputation.

    :returns: a copy of *kernel* that contains the compute instruction, the
        (new or adapted) temporary variable, and the rewritten usage sites.

    :raises RuntimeError: for unknown or unreferenced sweep inames, when no
        invocation of the rule is found, when a requested storage axis name
        conflicts with an existing name, or when a sweep iname is not at
        home in the sweep's leaf domain.
    :raises ValueError: if the entries of *subst_use* are not rule
        invocations/names or disagree on rule name and tag.
    :raises LoopyError: if pre-existing precompute inames or a pre-existing
        temporary are incompatible with this precomputation.

    If `storage_axes` is not specified, it defaults to the arrangement
    `<direct sweep axes><arguments>` with the direct sweep axes being the
    slower-varying indices.

    Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are
    eliminated.
    """

    # NOTE: The mutable defaults of *sweep_inames* and *storage_axis_to_tag*
    # are only ever read below, never modified in place.

    # {{{ check, standardize arguments

    if isinstance(sweep_inames, str):
        sweep_inames = [iname.strip() for iname in sweep_inames.split(",")]

    for iname in sweep_inames:
        if iname not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname"
                    % iname)

    sweep_inames = list(sweep_inames)
    sweep_inames_set = frozenset(sweep_inames)

    if isinstance(storage_axes, str):
        storage_axes = [ax.strip() for ax in storage_axes.split(",")]

    if isinstance(precompute_inames, str):
        precompute_inames = [iname.strip() for iname in precompute_inames.split(",")]

    if isinstance(subst_use, str):
        subst_use = [subst_use]

    # Remains None unless at least one explicit rule invocation (a Call) is
    # found in subst_use.
    footprint_generators = None

    subst_name = None
    subst_tag = None

    from pymbolic.primitives import Variable, Call
    from loopy.symbolic import parse, TaggedVariable

    for use in subst_use:
        if isinstance(use, str):
            use = parse(use)

        if isinstance(use, Call):
            if footprint_generators is None:
                footprint_generators = []

            footprint_generators.append(use)
            subst_name_as_expr = use.function
        else:
            subst_name_as_expr = use

        if isinstance(subst_name_as_expr, TaggedVariable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = subst_name_as_expr.tag
        elif isinstance(subst_name_as_expr, Variable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = None
        else:
            raise ValueError("unexpected type of subst_name")

        # All entries of subst_use must refer to the same (name, tag) pair.
        if (subst_name, subst_tag) == (None, None):
            subst_name, subst_tag = new_subst_name, new_subst_tag
        else:
            if (subst_name, subst_tag) != (new_subst_name, new_subst_tag):
                raise ValueError("not all uses in subst_use agree "
                        "on rule name and tag")

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    from loopy.kernel.data import parse_tag
    default_tag = parse_tag(default_tag)

    subst = kernel.substitutions[subst_name]
    # Rule names may contain dots; replace them for use in derived names.
    c_subst_name = subst_name.replace(".", "_")

    # }}}

    # {{{ process invocations in footprint generators, start access_descriptors

    if footprint_generators:
        access_descriptors = []

        for fpg in footprint_generators:
            if isinstance(fpg, Variable):
                args = ()
            elif isinstance(fpg, Call):
                args = fpg.parameters
            else:
                raise ValueError("footprint generator must "
                        "be substitution rule invocation")

            access_descriptors.append(
                    RuleAccessDescriptor(
                        identifier=access_descriptor_id(args, None),
                        args=args
                        ))

    # }}}

    # {{{ gather up invocations in kernel code, finish access_descriptors

    if not footprint_generators:
        rule_mapping_context = SubstitutionRuleMappingContext(
                kernel.substitutions, kernel.get_var_name_generator())
        invg = RuleInvocationGatherer(
                rule_mapping_context, kernel, subst_name, subst_tag, within)
        del rule_mapping_context

        import loopy as lp
        for insn in kernel.instructions:
            if isinstance(insn, lp.Assignment):
                invg(insn.assignee, kernel, insn)
                invg(insn.expression, kernel, insn)

        access_descriptors = invg.access_descriptors
        if not access_descriptors:
            raise RuntimeError("no invocations of '%s' found" % subst_name)

    # }}}

    # {{{ find inames used in arguments

    expanding_usage_arg_deps = set()

    for accdesc in access_descriptors:
        for arg in accdesc.args:
            expanding_usage_arg_deps.update(
                    get_dependencies(arg) & kernel.all_inames())

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    # {{{ use given / find new storage_axes

    # extra axes made necessary because they don't occur in the arguments
    extra_storage_axes = set(sweep_inames_set - expanding_usage_arg_deps)

    from loopy.symbolic import SubstitutionRuleExpander
    submap = SubstitutionRuleExpander(kernel.substitutions)

    value_inames = get_dependencies(
            submap(subst.expression)
            ) & kernel.all_inames()
    if value_inames - expanding_usage_arg_deps < extra_storage_axes:
        raise RuntimeError("unreferenced sweep inames specified: "
                + ", ".join(extra_storage_axes
                    - value_inames - expanding_usage_arg_deps))

    new_iname_to_tag = {}

    if storage_axes is None:
        storage_axes = []

        # Add sweep_inames (in given--rather than arbitrary--order) to
        # storage_axes *if* they are part of extra_storage_axes.
        for iname in sweep_inames:
            if iname in extra_storage_axes:
                extra_storage_axes.remove(iname)
                storage_axes.append(iname)

        if extra_storage_axes:
            if (precompute_inames is not None
                    and len(storage_axes) < len(precompute_inames)):
                raise LoopyError("must specify a sufficient number of "
                        "storage_axes to uniquely determine the meaning "
                        "of the given precompute_inames. (%d storage_axes "
                        "needed)" % len(precompute_inames))
            storage_axes.extend(sorted(extra_storage_axes))

        # Rule arguments follow, referred to by their index.
        storage_axes.extend(range(len(subst.arguments)))

    del extra_storage_axes

    # Maps each storage axis name to the name it had before renaming.
    prior_storage_axis_name_dict = {}

    storage_axis_names = []
    storage_axis_sources = []  # number for arg#, or iname

    # {{{ check for pre-existing precompute_inames

    if precompute_inames is not None:
        preexisting_precompute_inames = (
                set(precompute_inames) & kernel.all_inames())
    else:
        preexisting_precompute_inames = set()

    # }}}

    for i, saxis in enumerate(storage_axes):
        tag_lookup_saxis = saxis

        if saxis in subst.arguments:
            saxis = subst.arguments.index(saxis)

        storage_axis_sources.append(saxis)

        if isinstance(saxis, int):
            # argument index
            name = old_name = subst.arguments[saxis]
        else:
            old_name = saxis
            name = "%s_%s" % (c_subst_name, old_name)

        if (precompute_inames is not None
                and i < len(precompute_inames)
                and precompute_inames[i]):
            # A user-supplied name takes precedence over the generated one.
            name = precompute_inames[i]
            tag_lookup_saxis = name
            if (name not in preexisting_precompute_inames
                    and var_name_gen.is_name_conflicting(name)):
                raise RuntimeError("new storage axis name '%s' "
                        "conflicts with existing name" % name)
        else:
            name = var_name_gen(name)

        storage_axis_names.append(name)
        # Only newly created inames get tagged by this function.
        if name not in preexisting_precompute_inames:
            new_iname_to_tag[name] = storage_axis_to_tag.get(
                    tag_lookup_saxis, default_tag)

        prior_storage_axis_name_dict[name] = old_name

    del storage_axis_to_tag
    del storage_axes
    del precompute_inames

    # }}}

    # {{{ fill out access_descriptors[...].storage_axis_exprs

    access_descriptors = [
            accdesc.copy(
                storage_axis_exprs=storage_axis_exprs(
                    storage_axis_sources, accdesc.args))
            for accdesc in access_descriptors]

    # }}}

    expanding_inames = sweep_inames_set | frozenset(expanding_usage_arg_deps)
    assert expanding_inames <= kernel.all_inames()

    if storage_axis_names:
        # {{{ find domain to be changed

        change_inames = expanding_inames | preexisting_precompute_inames

        from loopy.kernel.tools import DomainChanger
        domch = DomainChanger(kernel, change_inames)

        if domch.leaf_domain_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.

            for iname in sweep_inames_set:
                if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
                    raise RuntimeError("sweep iname '%s' is not 'at home' in the "
                            "sweep's leaf domain" % iname)

        # }}}

        abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames,
                access_descriptors, len(storage_axis_names))

        non1_storage_axis_names = []
        for i, saxis in enumerate(storage_axis_names):
            if abm.non1_storage_axis_flags[i]:
                non1_storage_axis_names.append(saxis)
            else:
                # Pre-existing inames never received an entry in
                # new_iname_to_tag, so tolerate a missing key here. A plain
                # 'del' would raise KeyError and preempt the LoopyError
                # below.
                new_iname_to_tag.pop(saxis, None)

                if saxis in preexisting_precompute_inames:
                    raise LoopyError("precompute axis %d (1-based) was "
                            "eliminated as "
                            "having length 1 but also mapped to existing "
                            "iname '%s'" % (i+1, saxis))

        mod_domain = domch.domain

        # {{{ modify the domain, taking into account preexisting inames

        # inames may already exist in mod_domain, add them primed to start
        primed_non1_saxis_names = [
                iname+"'" for iname in non1_storage_axis_names]

        mod_domain = abm.augment_domain_with_sweep(
            domch.domain, primed_non1_saxis_names,
            boxify_sweep=fetch_bounding_box)

        # Snapshot taken before the primed names are merged/renamed; used for
        # the consistency check further down.
        check_domain = mod_domain

        for i, saxis in enumerate(non1_storage_axis_names):
            var_dict = mod_domain.get_var_dict(isl.dim_type.set)

            if saxis in preexisting_precompute_inames:
                # add equality constraint between existing and new variable

                dt, dim_idx = var_dict[saxis]
                saxis_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                mod_domain = mod_domain.add_constraint(
                        isl.Constraint.equality_from_aff(new_var_aff - saxis_aff))

                # project out the new one
                mod_domain = mod_domain.project_out(dt, dim_idx, 1)

            else:
                # remove the prime from the new variable
                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis)

        # {{{ check that we got the desired domain

        check_domain = check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain = mod_domain

        # re-add the prime from the new variable
        var_dict = mod_check_domain.get_var_dict(isl.dim_type.set)

        for saxis in non1_storage_axis_names:
            dt, dim_idx = var_dict[saxis]
            mod_check_domain = mod_check_domain.set_dim_name(dt, dim_idx, saxis+"'")

        mod_check_domain = mod_check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("domain of preexisting inames does not match "
                    "domain needed for precompute")

        # }}}

        # {{{ check that we didn't shrink the original domain

        # project out the new names from the modified domain
        orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set))
        mod_check_domain = mod_domain.project_out_except(
                orig_domain_inames, [isl.dim_type.set])

        check_domain = domch.domain

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("original domain got shrunk by applying the precompute")

        # }}}

        # }}}

        new_kernel_domains = domch.get_domains_with(mod_domain)

    else:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains

        non1_storage_axis_names = []
        abm = NoOpArrayToBufferMap()

    kernel = kernel.copy(domains=new_kernel_domains)

    # {{{ set up compute insn

    if temporary_name is None:
        temporary_name = var_name_gen(based_on=c_subst_name)

    assignee = var(temporary_name)

    if non1_storage_axis_names:
        assignee = assignee.index(
                tuple(var(iname) for iname in non1_storage_axis_names))

    # {{{ process substitutions on compute instruction

    # Maps each prior axis name to the expression replacing it in the compute
    # expression: the storage iname (or 0 for eliminated axes) plus the
    # storage base index.
    storage_axis_subst_dict = {}

    for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices):
        if arg_name in non1_storage_axis_names:
            arg = var(arg_name)
        else:
            arg = 0

        storage_axis_subst_dict[
                prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())

    expr_subst_map = RuleAwareSubstitutionMapper(
            rule_mapping_context,
            make_subst_func(storage_axis_subst_dict),
            within=parse_stack_match(None))

    compute_expression = expr_subst_map(subst.expression, kernel, None)

    # }}}

    from loopy.kernel.data import Assignment
    if compute_insn_id is None:
        compute_insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name)

    compute_insn = Assignment(
            id=compute_insn_id,
            assignee=assignee,
            expression=compute_expression)

    # }}}

    # {{{ substitute rule into expressions in kernel (if within footprint)

    invr = RuleInvocationReplacer(rule_mapping_context,
            subst_name, subst_tag, within,
            access_descriptors, abm,
            storage_axis_names, storage_axis_sources,
            non1_storage_axis_names,
            temporary_name, compute_insn_id)

    kernel = invr.map_kernel(kernel)
    kernel = kernel.copy(
            instructions=[compute_insn] + kernel.instructions)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # {{{ set up temp variable

    import loopy as lp
    if dtype is None:
        dtype = lp.auto
    else:
        dtype = np.dtype(dtype)

    if temporary_is_local is None:
        temporary_is_local = lp.auto

    new_temp_shape = tuple(abm.non1_storage_shape)

    new_temporary_variables = kernel.temporary_variables.copy()
    if temporary_name not in new_temporary_variables:
        temp_var = lp.TemporaryVariable(
                name=temporary_name,
                dtype=dtype,
                base_indices=(0,)*len(new_temp_shape),
                shape=tuple(abm.non1_storage_shape),
                is_local=temporary_is_local)

    else:
        temp_var = new_temporary_variables[temporary_name]

        # {{{ check and adapt existing temporary

        if temp_var.dtype is lp.auto:
            pass
        elif temp_var.dtype is not lp.auto and dtype is lp.auto:
            dtype = temp_var.dtype
        elif temp_var.dtype is not lp.auto and dtype is not lp.auto:
            if temp_var.dtype != dtype:
                raise LoopyError("Existing and new dtype of temporary '%s' "
                        "do not match (existing: %s, new: %s)"
                        % (temporary_name, temp_var.dtype, dtype))

        temp_var = temp_var.copy(dtype=dtype)

        if len(temp_var.shape) != len(new_temp_shape):
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching number of dimensions ('%d' vs. '%d') "
                    % (temporary_name,
                        len(temp_var.shape), len(new_temp_shape)))

        if temp_var.base_indices != (0,) * len(new_temp_shape):
            raise LoopyError("Existing temporary '%s' does not have "
                    "all-zero base indices (found: %s)"
                    % (temporary_name, temp_var.base_indices))

        # Grow the existing temporary as needed so that it fits both uses.
        new_temp_shape = tuple(
                max(i, ex_i)
                for i, ex_i in zip(new_temp_shape, temp_var.shape))

        temp_var = temp_var.copy(shape=new_temp_shape)

        if temporary_is_local == temp_var.is_local:
            pass
        elif temporary_is_local is lp.auto:
            temporary_is_local = temp_var.is_local
        elif temp_var.is_local is lp.auto:
            pass
        else:
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching values of 'is_local' "
                    "(existing: %s, new: %s)"
                    % (temporary_name,
                        temp_var.is_local, temporary_is_local))

        temp_var = temp_var.copy(is_local=temporary_is_local)

        # }}}

    new_temporary_variables[temporary_name] = temp_var

    kernel = kernel.copy(
            temporary_variables=new_temporary_variables)

    # }}}

    from loopy import tag_inames
    kernel = tag_inames(kernel, new_iname_to_tag)

    from loopy.kernel.data import AutoFitLocalIndexTag
    has_automatic_axes = any(
            isinstance(tag, AutoFitLocalIndexTag)
            for tag in new_iname_to_tag.values())

    if has_automatic_axes:
        from loopy.kernel.tools import assign_automatic_axes
        kernel = assign_automatic_axes(kernel)

    return kernel
Esempio n. 54
0
File: data.py Progetto: shwina/loopy
def _process_footprint_subscripts(kernel, rule_name, sweep_inames,
        footprint_subscripts, arg):
    """Turn user-supplied footprint subscripts into rule invocations.

    Every subscript is parsed if it is a string, wrapped into a tuple if it
    is not one already, and rewritten with each entry of
    ``kernel.applied_iname_rewrites``. Each full slice (``:``) in a subscript
    is replaced by a newly named iname that is added to the kernel via
    ``_add_kernel_axis`` with bounds ``0`` and ``arg.shape[axis_nr]`` and
    appended to the sweep inames.

    :returns: a tuple ``(kernel, subst_use, sweep_inames,
        inames_to_be_removed)``. If *footprint_subscripts* is *None*,
        *kernel*, *rule_name* and *sweep_inames* are returned unchanged
        together with an empty list.
    :raises ValueError: if a subscript does not have
        ``arg.num_user_axes()`` entries.
    :raises NotImplementedError: if a slice other than a full slice is given.
    """

    make_unique_name = kernel.get_var_name_generator()

    from pymbolic.primitives import Variable

    if footprint_subscripts is None:
        return kernel, rule_name, sweep_inames, []

    # A single subscript may be given without an enclosing list.
    if not isinstance(footprint_subscripts, (list, tuple)):
        footprint_subscripts = [footprint_subscripts]

    slice_inames = []
    processed_subscripts = []

    for subscript in footprint_subscripts:
        if isinstance(subscript, str):
            from loopy.symbolic import parse
            subscript = parse(subscript)

        if not isinstance(subscript, tuple):
            subscript = (subscript,)

        if len(subscript) != arg.num_user_axes():
            raise ValueError("sweep index '%s' has the wrong number of dimensions"
                    % str(subscript))

        # Replay the iname rewrites already applied to the kernel so that the
        # subscript refers to current inames.
        for rewrite in kernel.applied_iname_rewrites:
            from loopy.symbolic import SubstitutionMapper
            from pymbolic.mapper.substitutor import make_subst_func
            apply_rewrite = SubstitutionMapper(make_subst_func(rewrite))
            subscript = apply_rewrite(subscript)

        from loopy.symbolic import get_dependencies
        subscript_deps = get_dependencies(subscript)

        from pymbolic.primitives import Slice

        axes = []
        for axis_nr, axis_expr in enumerate(subscript):
            if not isinstance(axis_expr, Slice):
                axes.append(axis_expr)
                continue

            if axis_expr.children != (None,):
                raise NotImplementedError("add_prefetch only "
                        "supports full slices")

            fetch_iname = make_unique_name(
                    based_on="%s_fetch_axis_%d" % (arg.name, axis_nr))

            kernel = _add_kernel_axis(
                    kernel, fetch_iname, 0, arg.shape[axis_nr],
                    frozenset(sweep_inames) | subscript_deps)

            # Build a new list rather than appending in place, so that the
            # caller's list is left untouched.
            sweep_inames = sweep_inames + [fetch_iname]
            slice_inames.append(fetch_iname)
            axes.append(Variable(fetch_iname))

        processed_subscripts.append(tuple(axes))

    invocations = [
            Variable(rule_name)(*indices)
            for indices in processed_subscripts]

    return kernel, invocations, sweep_inames, slice_inames
Esempio n. 55
0
File: data.py Progetto: shwina/loopy
 def tolerant_get_deps(expr):
     """Return ``get_dependencies(expr)``, or an empty set if *expr* is
     *None* or ``lp.auto``.
     """
     is_unset = expr is None or expr is lp.auto
     return set() if is_unset else get_dependencies(expr)
Esempio n. 56
0
def mark_local_temporaries(kernel):
    """Resolve ``is_local`` for temporaries where it is still ``lp.auto``.

    Every instruction writing to such a temporary "votes": it wants the
    temporary to be local if it runs under at least one iname tagged with a
    :class:`LocalIndexTagBase` tag and exactly those inames also occur in the
    assignee indices.

    Temporaries whose ``is_local`` is already decided are passed through
    unchanged. Temporaries without any writer are dropped (with a
    ``LoopyAdvisory`` warning).

    :returns: a copy of *kernel* with the updated temporary variables.
    :raises LoopyError: if the writers of one temporary disagree.
    """
    logger.debug("%s: mark local temporaries" % kernel.name)

    import loopy as lp
    from loopy.kernel.data import LocalIndexTagBase
    from loopy.symbolic import get_dependencies

    def is_local_parallel(iname):
        # True if *iname* carries a local-index tag.
        return isinstance(kernel.iname_to_tag.get(iname), LocalIndexTagBase)

    writer_map = kernel.writer_map()
    updated_temporaries = {}

    for tv in six.itervalues(kernel.temporary_variables):
        if tv.is_local is not lp.auto:
            # Locality is already known; only implicitly generated
            # temporaries still need a decision.
            updated_temporaries[tv.name] = tv
            continue

        local_votes = []
        for writer_id in writer_map.get(tv.name, []):
            writer = kernel.id_to_insn[writer_id]

            # Local-parallel inames the instruction executes under ...
            compute_inames = set(
                    iname
                    for iname in kernel.insn_inames(writer_id)
                    if is_local_parallel(iname))

            # ... versus the local-parallel inames visible in the indices
            # of what it writes to.
            assignee_inames = set()
            for _, indices in writer.assignees_and_indices():
                index_inames = get_dependencies(indices) & kernel.all_inames()
                assignee_inames.update(
                        iname
                        for iname in index_inames
                        if is_local_parallel(iname))

            assert assignee_inames <= compute_inames

            inames_agree = assignee_inames == compute_inames

            # A write race emerges if the variable is local and the
            # instruction runs across more local-parallel inames than are
            # reflected in the assignee indices.
            if assignee_inames and not inames_agree:
                warn(kernel, "write_race_local(%s)" % writer_id,
                        "instruction '%s' looks invalid: "
                        "it assigns to indices based on local IDs, but "
                        "its temporary '%s' cannot be made local because "
                        "a write race across the iname(s) '%s' would emerge. "
                        "(Do you need to add an extra iname to your prefetch?)"
                        % (writer_id, tv.name, ", ".join(
                            compute_inames - assignee_inames)),
                        WriteRaceConditionWarning)

            # Without any local-parallel iname there is no reason to be local.
            local_votes.append(inames_agree and bool(compute_inames))

        if not local_votes:
            warn(kernel, "temp_to_write(%s)" % tv.name,
                    "temporary variable '%s' never written, eliminating"
                    % tv.name, LoopyAdvisory)

            continue

        is_local = any(local_votes)

        from pytools import all
        if not all(vote == is_local for vote in local_votes):
            raise LoopyError("not all instructions agree on whether "
                    "temporary '%s' should be in local memory" % tv.name)

        updated_temporaries[tv.name] = tv.copy(is_local=is_local)

    return kernel.copy(temporary_variables=updated_temporaries)
Esempio n. 57
0
    def map_subscript(self, expr):
        """Verify that the array access *expr* stays inside the array's shape.

        The shape is looked up among the kernel's arguments first, then among
        its temporaries. The check is silently skipped when no shape is known,
        when the subscript or the shape depends on variables that are not part
        of ``self.domain``, or when the access range cannot be computed.

        :raises LoopyError: if the number of indices does not match the
            number of shape axes, or if the access range is not contained in
            the box described by the shape.
        """
        WalkMapper.map_subscript(self, expr)

        from pymbolic.primitives import Variable
        assert isinstance(expr.aggregate, Variable)

        array_name = expr.aggregate.name
        if array_name in self.kernel.arg_dict:
            shape = self.kernel.arg_dict[array_name].shape
        elif array_name in self.kernel.temporary_variables:
            shape = self.kernel.temporary_variables[array_name].shape
        else:
            shape = None

        if shape is None:
            # Nothing to compare the access against.
            return

        indices = expr.index
        if not isinstance(indices, tuple):
            indices = (indices,)

        from loopy.symbolic import get_dependencies, get_access_range
        from loopy.isl_helpers import make_slab

        known_vars = set(self.domain.get_var_dict())
        shape_deps = set().union(*(
                get_dependencies(axis_len)
                for axis_len in shape
                if axis_len is not None))

        if (not (get_dependencies(indices) <= known_vars)
                or not (shape_deps <= known_vars)):
            return

        if len(indices) != len(shape):
            raise LoopyError("subscript to '%s' in '%s' has the wrong "
                    "number of indices (got: %d, expected: %d)" % (
                        array_name, expr,
                        len(indices), len(shape)))

        try:
            access_range = get_access_range(self.domain, indices,
                    self.kernel.assumptions)
        except (isl.Error, TypeError):
            # Likely: index was non-linear, nothing we can do.
            return

        # Intersect one slab 0 <= index < axis_len per axis of known length.
        shape_domain = isl.BasicSet.universe(access_range.get_space())
        for axis_nr, axis_len in enumerate(shape):
            if axis_len is None:
                continue

            axis_slab = make_slab(
                    shape_domain.get_space(), (dim_type.in_, axis_nr),
                    0, axis_len)
            shape_domain = shape_domain.intersect(axis_slab)

        if access_range.is_subset(shape_domain):
            return

        raise LoopyError("'%s' in instruction '%s' "
                "accesses out-of-bounds array element"
                % (expr, self.insn_id))
Esempio n. 58
0
    def map_Do(self, node):
        """Process a do-loop node and everything nested inside it.

        The loop control ``var = start, stop[, step]`` is parsed, a loopy-side
        loop name that is not yet a set dimension of any index set in the
        current scope is chosen, and an index set ``0 <= name <= stop - start``
        is appended to ``scope.index_sets``. While the loop body
        (``node.content``) is processed, the original loop variable is aliased
        to ``name + start``.

        :raises NotImplementedError: for loops without loop control and for
            steps other than 1.
        :raises LoopyError: if the loop variable's type differs from the
            index type seen earlier.
        :raises RuntimeError: if the loop control does not have two or three
            entries.
        :raises TranslationError: if the step is not an ``int``.
        """
        scope = self.scope_stack[-1]

        if not node.loopcontrol:
            raise NotImplementedError("unbounded do loop")

        loop_var, bounds_text = node.loopcontrol.split("=")
        loop_var = loop_var.strip()

        # All loop variables must share a single index dtype.
        iname_dtype = scope.get_type(loop_var)
        if self.index_dtype is None:
            self.index_dtype = iname_dtype
        elif self.index_dtype != iname_dtype:
            raise LoopyError("type of '%s' (%s) does not agree with prior "
                    "index type (%s)"
                    % (loop_var, iname_dtype, self.index_dtype))

        scope.use_name(loop_var)
        bounds = self.parse_expr(
                node,
                bounds_text, min_precedence=self.expr_parser._PREC_FUNC_ARGS)

        bound_count = len(bounds)
        if bound_count == 2:
            start, stop = bounds
            step = 1
        elif bound_count == 3:
            start, stop, step = bounds
        else:
            raise RuntimeError("loop bounds not understood: %s"
                    % node.loopcontrol)

        if step != 1:
            raise NotImplementedError(
                    "do loops with non-unit stride")

        # NOTE(review): because the non-unit check comes first, this only
        # triggers for a step that compares equal to 1 but is not an int.
        if not isinstance(step, int):
            raise TranslationError(
                    "non-constant steps not supported: %s" % step)

        from loopy.symbolic import get_dependencies
        loop_bound_deps = get_dependencies(start)
        loop_bound_deps = loop_bound_deps | get_dependencies(stop)
        loop_bound_deps = loop_bound_deps | get_dependencies(step)

        # {{{ find a usable loopy-side loop name

        def is_taken(candidate):
            return any(
                    candidate in iset.get_var_dict(dim_type.set)
                    for iset in scope.index_sets)

        loopy_loop_var = loop_var
        suffix = 0
        while is_taken(loopy_loop_var):
            suffix += 1
            loopy_loop_var = loop_var + "_%d" % suffix

        # }}}

        space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT,
                set=[loopy_loop_var], params=list(loop_bound_deps))

        from loopy.isl_helpers import iname_rel_aff
        from loopy.symbolic import aff_from_expr

        def bound_constraint(comparison, bound_expr):
            # Constraint "loopy_loop_var <comparison> bound_expr" on *space*.
            return isl.Constraint.inequality_from_aff(
                    iname_rel_aff(space,
                        loopy_loop_var, comparison,
                        aff_from_expr(space, bound_expr)))

        # The loopy-side loop is zero-based; the offset is reintroduced
        # through the alias below.
        index_set = isl.BasicSet.universe(space)
        index_set = index_set.add_constraint(bound_constraint(">=", 0))
        index_set = index_set.add_constraint(
                bound_constraint("<=", stop-start))

        from pymbolic import var
        scope.active_iname_aliases[loop_var] = var(loopy_loop_var) + start
        scope.active_loopy_inames.add(loopy_loop_var)

        scope.index_sets.append(index_set)

        self.block_nest.append("do")

        for child in node.content:
            self.rec(child)

        # The alias and the active iname only apply inside the loop body.
        del scope.active_iname_aliases[loop_var]
        scope.active_loopy_inames.remove(loopy_loop_var)
Esempio n. 59
0
def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
        store_expression=None, within=None, default_tag="l.auto",
        temporary_scope=None, temporary_is_local=None,
        fetch_bounding_box=False):
    """Route accesses to *var_name* through a newly created temporary buffer.

    A temporary named after *var_name* (suffix ``_buf``) is created, an
    instruction initializing it is added, matching instructions are made to
    depend on that initialization, and -- if one of the rewritten instructions
    writes to the buffer and *store_expression* is not *False* -- an
    instruction storing the buffer back into *var_name* is added.

    :arg kernel: the kernel to transform.
    :arg var_name: name of a kernel argument or temporary variable.
    :arg buffer_inames: an iterable of iname names, or a comma-separated
        string of the same. Each must be a known iname of *kernel*.
    :arg init_expression: Either *None* (indicating the prior value of the buffered
        array should be read) or an expression optionally involving the
        variable 'base' (which references the associated location in the array
        being buffered).
    :arg store_expression: Either *None*, *False*, or an expression involving
        variables 'base' and 'buffer' (without array indices).
        (*None* indicates that a default storage instruction should be used,
        *False* indicates that no storing of the temporary should occur
        at all.)
    :arg within: passed to :func:`loopy.match.parse_stack_match`; only
        instructions matched by the result are considered.
    :arg default_tag: tag given to the newly created init/store inames.
    :arg temporary_scope: scope of the buffer temporary. *None* means
        ``lp.auto``.
    :arg temporary_is_local: deprecated boolean alternative to
        *temporary_scope*; may not be combined with it.
    :arg fetch_bounding_box: forwarded as ``boxify_sweep`` when the domain is
        augmented with the init/store inames.

    :returns: the transformed kernel (or a cached result if caching is
        enabled and an identical request was seen before).

    :raises LoopyError: if both *temporary_is_local* and *temporary_scope*
        are given, if *var_name* is written through linear indexing, or if an
        assignee is not a valid lvalue.
    :raises RuntimeError: if a buffer iname is unknown or not at home in the
        leaf domain being changed.
    :raises ValueError: if *var_name* is neither an argument nor a temporary.
    """

    # {{{ unify temporary_scope / temporary_is_local

    from loopy.kernel.data import temp_var_scope
    if temporary_is_local is not None:
        from warnings import warn
        warn("temporary_is_local is deprecated. Use temporary_scope instead",
                DeprecationWarning, stacklevel=2)

        if temporary_scope is not None:
            raise LoopyError("may not specify both temporary_is_local and "
                    "temporary_scope")

        if temporary_is_local:
            temporary_scope = temp_var_scope.LOCAL
        else:
            temporary_scope = temp_var_scope.PRIVATE

    del temporary_is_local

    # }}}

    # {{{ process arguments

    if isinstance(init_expression, str):
        from loopy.symbolic import parse
        init_expression = parse(init_expression)

    if isinstance(store_expression, str):
        from loopy.symbolic import parse
        store_expression = parse(store_expression)

    if isinstance(buffer_inames, str):
        buffer_inames = [s.strip()
                for s in buffer_inames.split(",") if s.strip()]

    for iname in buffer_inames:
        if iname not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname"
                    % iname)

    buffer_inames = list(buffer_inames)
    buffer_inames_set = frozenset(buffer_inames)

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    if var_name in kernel.arg_dict:
        var_descr = kernel.arg_dict[var_name]
    elif var_name in kernel.temporary_variables:
        var_descr = kernel.temporary_variables[var_name]
    else:
        raise ValueError("variable '%s' not found" % var_name)

    from loopy.kernel.data import ArrayBase
    if isinstance(var_descr, ArrayBase):
        var_shape = var_descr.shape
    else:
        # Non-array variables are treated as zero-dimensional.
        var_shape = ()

    if temporary_scope is None:
        import loopy as lp
        temporary_scope = lp.auto

    # }}}

    # {{{ caching

    from loopy import CACHING_ENABLED

    from loopy.preprocess import prepare_for_caching
    key_kernel = prepare_for_caching(kernel)
    cache_key = (key_kernel, var_name, tuple(buffer_inames),
            PymbolicExpressionHashWrapper(init_expression),
            PymbolicExpressionHashWrapper(store_expression), within,
            default_tag, temporary_scope, fetch_bounding_box)

    if CACHING_ENABLED:
        try:
            result = buffer_array_cache[cache_key]
            logger.info("%s: buffer_array cache hit" % kernel.name)
            return result
        except KeyError:
            pass

    # }}}

    var_name_gen = kernel.get_var_name_generator()
    within_inames = set()

    # {{{ collect write accesses to var_name

    access_descriptors = []
    for insn in kernel.instructions:
        if not within(kernel, insn.id, ()):
            continue

        from pymbolic.primitives import Variable, Subscript
        from loopy.symbolic import LinearSubscript

        for assignee in insn.assignees:
            if isinstance(assignee, Variable):
                assignee_name = assignee.name
                index = ()

            elif isinstance(assignee, Subscript):
                assignee_name = assignee.aggregate.name
                index = assignee.index_tuple

            elif isinstance(assignee, LinearSubscript):
                if assignee.aggregate.name == var_name:
                    raise LoopyError("buffer_array may not be applied in the "
                            "presence of linear write indexing into '%s'" % var_name)

                # A linear write into some other array is of no interest
                # here. Skip it explicitly: falling through would reuse
                # 'assignee_name'/'index' from a previous assignee (or hit
                # unbound names if this is the first one).
                continue

            else:
                raise LoopyError("invalid lvalue '%s'" % assignee)

            if assignee_name == var_name:
                within_inames.update(
                        (get_dependencies(index) & kernel.all_inames())
                        - buffer_inames_set)
                access_descriptors.append(
                        AccessDescriptor(
                            identifier=insn.id,
                            storage_axis_exprs=index))

    # }}}

    # {{{ find fetch/store inames

    init_inames = []
    store_inames = []
    new_iname_to_tag = {}

    for i in range(len(var_shape)):
        dim_name = str(i)
        if isinstance(var_descr, ArrayBase) and var_descr.dim_names is not None:
            dim_name = var_descr.dim_names[i]

        init_iname = var_name_gen("%s_init_%s" % (var_name, dim_name))
        store_iname = var_name_gen("%s_store_%s" % (var_name, dim_name))

        new_iname_to_tag[init_iname] = default_tag
        new_iname_to_tag[store_iname] = default_tag

        init_inames.append(init_iname)
        store_inames.append(store_iname)

    # }}}

    # {{{ modify loop domain

    non1_init_inames = []
    non1_store_inames = []

    if var_shape:
        # {{{ find domain to be changed

        from loopy.kernel.tools import DomainChanger
        domch = DomainChanger(kernel, buffer_inames_set | within_inames)

        if domch.leaf_domain_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.

            for iname in buffer_inames_set:
                if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
                    raise RuntimeError("buffer iname '%s' is not 'at home' in the "
                            "sweep's leaf domain" % iname)

        # }}}

        abm = ArrayToBufferMap(kernel, domch.domain, buffer_inames,
                access_descriptors, len(var_shape))

        for i in range(len(var_shape)):
            if abm.non1_storage_axis_flags[i]:
                non1_init_inames.append(init_inames[i])
                non1_store_inames.append(store_inames[i])
            else:
                # Axes not flagged as "non-1" get no loop, hence no tag.
                del new_iname_to_tag[init_inames[i]]
                del new_iname_to_tag[store_inames[i]]

        new_domain = domch.domain
        new_domain = abm.augment_domain_with_sweep(
                    new_domain, non1_init_inames,
                    boxify_sweep=fetch_bounding_box)
        new_domain = abm.augment_domain_with_sweep(
                    new_domain, non1_store_inames,
                    boxify_sweep=fetch_bounding_box)
        new_kernel_domains = domch.get_domains_with(new_domain)
        del new_domain

    else:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains

        abm = NoOpArrayToBufferMap()

    # }}}

    # {{{ set up temp variable

    import loopy as lp

    buf_var_name = var_name_gen(based_on=var_name+"_buf")

    new_temporary_variables = kernel.temporary_variables.copy()
    temp_var = lp.TemporaryVariable(
            name=buf_var_name,
            dtype=var_descr.dtype,
            base_indices=(0,)*len(abm.non1_storage_shape),
            shape=tuple(abm.non1_storage_shape),
            scope=temporary_scope)

    new_temporary_variables[buf_var_name] = temp_var

    # }}}

    buf_var = var(buf_var_name)

    # {{{ generate init instruction

    buf_var_init = buf_var
    if non1_init_inames:
        buf_var_init = buf_var_init.index(
                tuple(var(iname) for iname in non1_init_inames))

    init_base = var(var_name)

    init_subscript = []
    init_iname_idx = 0
    if var_shape:
        for i in range(len(var_shape)):
            ax_subscript = abm.storage_base_indices[i]
            if abm.non1_storage_axis_flags[i]:
                ax_subscript += var(non1_init_inames[init_iname_idx])
                init_iname_idx += 1
            init_subscript.append(ax_subscript)

    if init_subscript:
        init_base = init_base.index(tuple(init_subscript))

    if init_expression is None:
        init_expression = init_base
    else:
        # Resolve the placeholder 'base' to the buffered array location.
        init_expression = SubstitutionMapper(
                make_subst_func({
                    "base": init_base,
                    }))(init_expression)

    init_insn_id = kernel.make_unique_instruction_id(based_on="init_"+var_name)
    from loopy.kernel.data import Assignment
    init_instruction = Assignment(id=init_insn_id,
                assignee=buf_var_init,
                expression=init_expression,
                forced_iname_deps=(
                    frozenset(within_inames)
                    | frozenset(non1_init_inames)),
                depends_on=frozenset(),
                depends_on_is_final=True)

    # }}}

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    aar = ArrayAccessReplacer(rule_mapping_context, var_name,
            within, abm, buf_var)
    kernel = rule_mapping_context.finish_kernel(aar.map_kernel(kernel))

    # A store is only needed if a rewritten instruction writes the buffer.
    did_write = False
    for insn_id in aar.modified_insn_ids:
        insn = kernel.id_to_insn[insn_id]
        if buf_var_name in insn.assignee_var_names():
            did_write = True

    # {{{ add init_insn_id to depends_on

    new_insns = []

    def none_to_empty_set(s):
        if s is None:
            return frozenset()
        else:
            return s

    for insn in kernel.instructions:
        if insn.id in aar.modified_insn_ids:
            new_insns.append(
                    insn.copy(
                        depends_on=(
                            none_to_empty_set(insn.depends_on)
                            | frozenset([init_insn_id]))))
        else:
            new_insns.append(insn)

    # }}}

    # {{{ generate store instruction

    buf_var_store = buf_var
    if non1_store_inames:
        buf_var_store = buf_var_store.index(
                tuple(var(iname) for iname in non1_store_inames))

    store_subscript = []
    store_iname_idx = 0
    if var_shape:
        for i in range(len(var_shape)):
            ax_subscript = abm.storage_base_indices[i]
            if abm.non1_storage_axis_flags[i]:
                ax_subscript += var(non1_store_inames[store_iname_idx])
                store_iname_idx += 1
            store_subscript.append(ax_subscript)

    store_target = var(var_name)
    if store_subscript:
        store_target = store_target.index(tuple(store_subscript))

    if store_expression is None:
        store_expression = buf_var_store
    elif store_expression is not False:
        # Resolve the placeholders 'base' and 'buffer'. *False* ("do not
        # store") is deliberately left untouched.
        store_expression = SubstitutionMapper(
                make_subst_func({
                    "base": store_target,
                    "buffer": buf_var_store,
                    }))(store_expression)

    if store_expression is not False:
        from loopy.kernel.data import Assignment
        store_instruction = Assignment(
                    id=kernel.make_unique_instruction_id(based_on="store_"+var_name),
                    depends_on=frozenset(aar.modified_insn_ids),
                    no_sync_with=frozenset([init_insn_id]),
                    assignee=store_target,
                    expression=store_expression,
                    forced_iname_deps=(
                        frozenset(within_inames)
                        | frozenset(non1_store_inames)))
    else:
        did_write = False

    # }}}

    new_insns.append(init_instruction)
    if did_write:
        new_insns.append(store_instruction)
    else:
        # No store instruction, so the store inames must not be tagged.
        # Tags of axes not flagged as "non-1" were already removed above,
        # hence pop() rather than del.
        for iname in store_inames:
            new_iname_to_tag.pop(iname, None)

    kernel = kernel.copy(
            domains=new_kernel_domains,
            instructions=new_insns,
            temporary_variables=new_temporary_variables)

    from loopy import tag_inames
    kernel = tag_inames(kernel, new_iname_to_tag)

    from loopy.kernel.tools import assign_automatic_axes
    kernel = assign_automatic_axes(kernel)

    if CACHING_ENABLED:
        from loopy.preprocess import prepare_for_caching
        buffer_array_cache[cache_key] = prepare_for_caching(kernel)

    return kernel
Esempio n. 60
0
def extract_subst(kernel, subst_name, template, parameters=()):
    """Replace expressions matching *template* by calls to a new
    substitution rule *subst_name* whose body is *template*.

    :arg kernel: the kernel to transform.
    :arg subst_name: The name of the substitution rule to be created.
    :arg template: Unification template expression, or a string that is
        parsed into one.
    :arg parameters: An iterable of parameters used in
        *template*, or a comma-separated string of the same. Empty names
        (e.g. from an empty string or a trailing comma) are ignored.

    All targeted subexpressions must match ('unify with') *template*
    The template may contain '*' wildcards that will have to match exactly across all
    unifications.

    :returns: a copy of *kernel* with rewritten instructions and
        substitution rules, including the new rule.
    :raises RuntimeError: if an expression unifies with *template* in more
        than one way, or if no expression matches at all.
    """

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        parameters = tuple(
                s.strip() for s in parameters.split(",") if s.strip())
    else:
        # *parameters* is iterated several times below; materialize it so
        # that one-shot iterators (e.g. generators) are not exhausted after
        # the first use.
        parameters = tuple(parameters)

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        based_on = subst_name+"_wc"

        result = var_name_gen(based_on)
        return result

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ deal with iname deps of template that are not independent_inames

    # (We call these 'matching_vars', because they have to match exactly in
    # every CSE. As above, they might need to be renamed to make them unique
    # within the kernel.)

    matching_vars = []
    old_to_new = {}

    for iname in (get_dependencies(template)
            - set(parameters)
            - kernel.non_iname_variable_names()):
        if iname in kernel.all_inames():
            # need to rename to be unique
            new_iname = var_name_gen(iname)
            old_to_new[iname] = var(new_iname)
            matching_vars.append(new_iname)
        else:
            matching_vars.append(iname)

    if old_to_new:
        template = (
                SubstitutionMapper(make_subst_func(old_to_new))
                (template))

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(
            lhs_mapping_candidates=set(parameters) | set(matching_vars))

    def gather_exprs(expr, mapper):
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError("ambiguous unification of '%s' with template '%s'"
                        % (expr, template))

            urec, = urecs

            # NOTE(review): 'insn' is the loop variable of the instruction
            # loop below; for matches found inside substitution rules it
            # still refers to whatever that loop bound last.
            expr_descriptors.append(
                    ExprDescriptor(
                        insn=insn,
                        expr=expr,
                        unif_var_dict=dict((lhs.name, rhs)
                            for lhs, rhs in urec.equations)))
        else:
            mapper.fallback_mapper(expr)
            # can't nest, don't recurse

    from loopy.symbolic import (
            CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    for insn in kernel.instructions:
        dfmapper(insn.assignees)
        dfmapper(insn.expression)

    for sr in six.itervalues(kernel.substitutions):
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    def replace_exprs(expr, mapper):
        # Matches are recognized by object identity with a gathered
        # expression.
        found = False
        for exprd in expr_descriptors:
            if expr is exprd.expr:
                found = True
                break

        if not found:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name]
                for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        return result
        # can't nest, don't recurse

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    new_insns = []

    for insn in kernel.instructions:
        new_insns.append(insn.with_transformed_expressions(cbmapper))

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
            subst_name: SubstitutionRule(
                name=subst_name,
                arguments=parameters,
                expression=template,
                )}

    for subst in six.itervalues(kernel.substitutions):
        new_substs[subst.name] = subst.copy(
                expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(
            instructions=new_insns,
            substitutions=new_substs)