Exemplo n.º 1
0
def simplify_via_aff(expr):
    """Round-trip *expr* through an isl affine expression.

    A space is created from the names *expr* depends on, *expr* is converted
    to an affine expression on that space with ``aff_from_expr``, and the
    result is converted back with ``aff_to_expr``.

    :arg expr: the expression to convert.
    :returns: whatever ``aff_to_expr`` produces for the converted expression.
    """
    from loopy.symbolic import aff_from_expr, aff_to_expr, get_dependencies

    dep_names = list(get_dependencies(expr))
    space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, dep_names)
    as_aff = aff_from_expr(space, expr)
    return aff_to_expr(as_aff)
Exemplo n.º 2
0
def augment_domain_for_temporary_promotion(
        kernel, domain, promoted_temporary, mode, name_gen):
    """
    Add new axes to the domain corresponding to the dimensions of
    `promoted_temporary`.

    One new set dimension is appended per entry of
    ``promoted_temporary.orig_temporary.shape``, constrained to the half-open
    range ``0 <= iname < size``.  In addition, every iname in
    ``promoted_temporary.hw_inames`` is duplicated under a freshly generated
    name that inherits the original iname's tag from ``kernel.iname_to_tag``.

    :arg kernel: provides ``iname_to_tag`` for the hardware inames.
    :arg domain: the isl set to extend.
    :arg promoted_temporary: provides ``orig_temporary`` (with ``name`` and
        ``shape``) and ``hw_inames``.
    :arg mode: string embedded in the generated iname names.
    :arg name_gen: callable mapping a proposed name to the name to use.
    :returns: a tuple ``(domain, hw_inames, dim_inames, iname_to_tag)`` where
        *domain* is the single basic set making up the augmented domain,
        *hw_inames* are the duplicated hardware inames, *dim_inames* are the
        new per-dimension inames and *iname_to_tag* maps the duplicated
        hardware inames to their tags.
    """
    import islpy as isl
    from loopy.isl_helpers import duplicate_axes
    from loopy.symbolic import aff_from_expr

    orig_temporary = promoted_temporary.orig_temporary
    orig_dim = domain.dim(isl.dim_type.set)
    dims_to_insert = len(orig_temporary.shape)

    iname_to_tag = {}

    # Add dimension-dependent inames.
    dim_inames = []

    domain = domain.add(isl.dim_type.set, dims_to_insert)
    for t_idx, size in enumerate(orig_temporary.shape):
        new_iname = name_gen("{name}_{mode}_dim_{dim}".
            format(name=orig_temporary.name,
                   mode=mode,
                   dim=orig_dim + t_idx))
        domain = domain.set_dim_name(
            isl.dim_type.set, orig_dim + t_idx, new_iname)
        dim_inames.append(new_iname)

        # Add size information: 0 <= new_iname < size.  The upper bound has
        # to be strict, otherwise the axis would cover size + 1 indices.
        aff = isl.affs_from_space(domain.space)
        domain &= aff[0].le_set(aff[new_iname])
        domain &= aff[new_iname].lt_set(aff_from_expr(domain.space, size))

    hw_inames = []

    # Add hardware inames duplicates.
    for t_idx, hw_iname in enumerate(promoted_temporary.hw_inames):
        new_iname = name_gen("{name}_{mode}_hw_dim_{dim}".
            format(name=orig_temporary.name,
                   mode=mode,
                   dim=t_idx))
        hw_inames.append(new_iname)
        iname_to_tag[new_iname] = kernel.iname_to_tag[hw_iname]

    domain = duplicate_axes(
        domain, promoted_temporary.hw_inames, hw_inames)

    # The operations on the domain above return a Set object, but the
    # underlying domain should be expressible as a single BasicSet.
    domain_list = domain.get_basic_set_list()
    assert domain_list.n_basic_set() == 1
    domain = domain_list.get_basic_set(0)
    return domain, hw_inames, dim_inames, iname_to_tag
Exemplo n.º 3
0
def augment_domain_for_temporary_promotion(kernel, domain, promoted_temporary,
                                           mode, name_gen):
    """
    Extend *domain* with one axis per dimension of the temporary wrapped by
    `promoted_temporary`, plus duplicates of its hardware inames.

    Each new per-dimension iname is bounded by ``0 <= iname < size`` and,
    when the temporary reports ``is_local``, tagged with
    ``AutoFitLocalIndexTag``.  Each duplicated hardware iname takes over the
    tag of the iname it was copied from.

    :returns: a tuple ``(domain, hw_inames, dim_inames, iname_to_tag)``.
    """
    import islpy as isl
    from loopy.symbolic import aff_from_expr

    temporary = promoted_temporary.orig_temporary
    first_new_dim = domain.dim(isl.dim_type.set)

    tags = {}
    per_dim_inames = []

    # Append all per-dimension axes at once, then name and bound them.
    domain = domain.add(isl.dim_type.set, len(temporary.shape))
    for axis, extent in enumerate(temporary.shape):
        proposed = "{name}_{mode}_dim_{dim}".format(
            name=temporary.name, mode=mode, dim=axis)
        axis_iname = name_gen(proposed)
        domain = domain.set_dim_name(
            isl.dim_type.set, first_new_dim + axis, axis_iname)

        if temporary.is_local:
            # A temporary with local scope allows loads / stores to be
            # done in parallel.
            from loopy.kernel.data import AutoFitLocalIndexTag
            tags[axis_iname] = AutoFitLocalIndexTag()

        per_dim_inames.append(axis_iname)

        # Bound the new axis: 0 <= axis_iname < extent.
        affs = isl.affs_from_space(domain.space)
        domain &= affs[0].le_set(affs[axis_iname])
        domain &= affs[axis_iname].lt_set(
            aff_from_expr(domain.space, extent))

    # Generate names for the hardware iname duplicates and carry over tags.
    duplicated_hw_inames = []
    for position, source_iname in enumerate(promoted_temporary.hw_inames):
        proposed = "{name}_{mode}_hw_dim_{dim}".format(
            name=temporary.name, mode=mode, dim=position)
        duplicate = name_gen(proposed)
        duplicated_hw_inames.append(duplicate)
        tags[duplicate] = kernel.iname_to_tag[source_iname]

    from loopy.isl_helpers import duplicate_axes
    domain = duplicate_axes(
        domain, promoted_temporary.hw_inames, duplicated_hw_inames)

    # The steps above yield a Set; it is expected to consist of exactly one
    # BasicSet, which is what gets returned.
    basic_sets = domain.get_basic_set_list()
    assert basic_sets.n_basic_set() == 1
    return (basic_sets.get_basic_set(0), duplicated_hw_inames,
            per_dim_inames, tags)
Exemplo n.º 4
0
def is_nonnegative(expr, over_set):
    """Check whether *expr* is non-negative everywhere on *over_set*.

    The region where ``-expr - 1 >= 0`` (i.e. where *expr* is negative) is
    intersected with *over_set*; *expr* is non-negative iff that
    intersection is empty.

    :returns: the emptiness result of that intersection, or *None* if
        *expr* could not be converted to an affine expression on the space
        of *over_set*.
    """
    space = over_set.get_space()
    from loopy.symbolic import aff_from_expr

    try:
        negative_aff = aff_from_expr(space, -expr - 1)
    except Exception:
        return None

    universe = isl.BasicSet.universe(space)
    negative_region = universe.add_constraint(
        isl.Constraint.inequality_from_aff(negative_aff))

    overlap = over_set.intersect(negative_region)
    return overlap.is_empty()
Exemplo n.º 5
0
def is_nonnegative(expr, over_set):
    """Check whether *expr* is non-negative everywhere on *over_set*.

    Builds the set on which ``-expr - 1 >= 0`` (i.e. *expr* is negative) and
    tests whether its intersection with *over_set* is empty.

    :arg expr: expression to test.
    :arg over_set: isl set providing the space and the region to test over.
    :returns: the emptiness result of the intersection, or *None* if *expr*
        could not be converted to an affine expression on that space.
    """
    space = over_set.get_space()
    from loopy.symbolic import aff_from_expr
    try:
        aff = aff_from_expr(space, -expr-1)
    except Exception:
        # Conversion failures are reported as "unknown".  Catching
        # Exception rather than everything lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
    expr_neg_set = isl.BasicSet.universe(space).add_constraint(
            isl.Constraint.inequality_from_aff(aff))

    return over_set.intersect(expr_neg_set).is_empty()
Exemplo n.º 6
0
def make_slab(space, iname, start, stop):
    """Return the basic set described by ``start <= iname < stop``.

    :arg space: isl space the bounds live on.
    :arg iname: either the name of a dimension or a ``(dim_type, index)``
        pair identifying it.
    :arg start: lower bound (inclusive); an isl ``Aff``/``PwAff``, a pymbolic
        expression or an :class:`int`.
    :arg stop: upper bound (exclusive); same kinds of values as *start*.
    """
    from pymbolic.primitives import Expression
    from loopy.symbolic import aff_from_expr

    zero = isl.Aff.zero_on_domain(space)

    # isl-valued bounds are aligned with the zero aff, which may change the
    # space everything below has to be expressed on.
    isl_aff_types = (isl.Aff, isl.PwAff)
    if isinstance(start, isl_aff_types):
        start, zero = isl.align_two(pw_aff_to_aff(start), zero)
    if isinstance(stop, isl_aff_types):
        stop, zero = isl.align_two(pw_aff_to_aff(stop), zero)

    space = zero.get_domain_space()

    def as_aff(bound):
        # Expressions are converted on the (aligned) space, plain integers
        # become constant affs.
        if isinstance(bound, Expression):
            bound = aff_from_expr(space, bound)
        if isinstance(bound, int):
            bound = zero + bound
        return bound

    start = as_aff(start)
    stop = as_aff(stop)

    if isinstance(iname, str):
        iname_dt, iname_idx = zero.get_space().get_var_dict()[iname]
    else:
        iname_dt, iname_idx = iname

    iname_aff = zero.add_coefficient_val(iname_dt, iname_idx, 1)

    # start <= iname
    lower_cns = isl.Constraint.inequality_from_aff(iname_aff - start)
    # iname < stop
    upper_cns = isl.Constraint.inequality_from_aff(stop-1 - iname_aff)

    slab = isl.BasicSet.universe(space)
    slab = slab.add_constraint(lower_cns)
    slab = slab.add_constraint(upper_cns)
    return slab
Exemplo n.º 7
0
def make_slab(space, iname, start, stop):
    """Build the basic set ``{start <= iname < stop}`` on *space*.

    *start* and *stop* may each be an isl ``Aff``/``PwAff``, a pymbolic
    expression or an :class:`int`.  *iname* is either a dimension name or a
    ``(dim_type, index)`` pair.
    """
    base_aff = isl.Aff.zero_on_domain(space)

    # Align isl-valued bounds with the base aff first; the aligned base aff
    # determines the space used from here on.
    if isinstance(start, (isl.Aff, isl.PwAff)):
        start, base_aff = isl.align_two(pw_aff_to_aff(start), base_aff)
    if isinstance(stop, (isl.Aff, isl.PwAff)):
        stop, base_aff = isl.align_two(pw_aff_to_aff(stop), base_aff)

    space = base_aff.get_domain_space()

    from pymbolic.primitives import Expression
    from loopy.symbolic import aff_from_expr

    # Symbolic bounds become affs on the aligned space.
    start = aff_from_expr(space, start) if isinstance(
        start, Expression) else start
    stop = aff_from_expr(space, stop) if isinstance(
        stop, Expression) else stop

    # Integer bounds become constant affs.
    start = base_aff + start if isinstance(start, int) else start
    stop = base_aff + stop if isinstance(stop, int) else stop

    # Resolve the dimension either by name or from the given pair.
    if not isinstance(iname, str):
        iname_dt, iname_idx = iname
    else:
        iname_dt, iname_idx = base_aff.get_space().get_var_dict()[iname]

    iname_aff = base_aff.add_coefficient_val(iname_dt, iname_idx, 1)

    at_least_start = isl.Constraint.inequality_from_aff(iname_aff - start)
    below_stop = isl.Constraint.inequality_from_aff(stop - 1 - iname_aff)

    return (isl.BasicSet.universe(space)
            .add_constraint(at_least_start)
            .add_constraint(below_stop))
Exemplo n.º 8
0
def solve_constraints(
    variables: Iterable[str],
    parameters: Iterable[str],
    constraints: Sequence[Tuple[ScalarExpression, ScalarExpression]],
) -> Mapping[str, ScalarExpression]:
    """
    Solve a system of equality constraints for *variables*.

    :arg variables: Names of the variables to solve for.
    :arg parameters: Names of the parameters that are allowed to appear in
        the solution expressions.
    :arg constraints: A sequence of constraints. Each constraint is a tuple
        ``(lhs, rhs)`` standing for ``lhs = rhs``, where ``lhs`` and ``rhs``
        are expressions in *variables* and *parameters*.
    :returns: A mapping from each name in *variables* to the value obtained
        for it from the constrained set.
    :raises ShapeInferenceFailure: if the constraints cannot all be
        satisfied, i.e. the constrained set is empty.
    """
    from loopy.symbolic import aff_from_expr

    space = isl.Space.create_from_names(
        isl.DEFAULT_CONTEXT, set=variables, params=parameters)

    # Start from the universe and add one equality ``lhs - rhs = 0`` per
    # constraint.
    feasible = isl.BasicSet.universe(space)
    for lhs, rhs in constraints:
        # type-ignored reason: no "(-)" support for Number
        difference = aff_from_expr(space, lhs - rhs)  # type: ignore
        equality = isl.Constraint.equality_from_aff(difference)
        feasible = feasible.add_constraint(equality)

    if feasible.is_empty():
        raise ShapeInferenceFailure

    # Read off the value of every set dimension.
    n_unknowns = feasible.dim(isl.dim_type.set)
    return {
        feasible.get_dim_name(isl.dim_type.set, idim):
            _get_val_in_bset(feasible, idim)
        for idim in range(n_unknowns)
    }
Exemplo n.º 9
0
def build_per_access_storage_to_domain_map(storage_axis_exprs, domain,
        storage_axis_names,
        prime_sweep_inames):
    """Relate primed storage axes to the dimensions of *domain*.

    For every pair of storage axis name and expression, the equality
    ``name' = prime_sweep_inames(expr)`` is imposed.  The result is a map
    whose input dimensions are the primed storage axes and whose output
    dimensions are those of *domain*.

    :returns: the intersection of all per-axis maps, or *None* if there are
        no (name, expression) pairs.
    """
    from loopy.symbolic import aff_from_expr

    n_stor_axes = len(storage_axis_names)
    n_domain_dims = domain.space.dim(dim_type.out)

    # Storage axes start out primed; unprimed base-0 versions are introduced
    # by the caller later on.
    map_space = domain.space.add_dims(dim_type.in_, n_stor_axes)
    for axis_nr, axis_name in enumerate(storage_axis_names):
        map_space = map_space.set_dim_name(
            dim_type.in_, axis_nr, axis_name+"'")

    # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn)

    # Temporarily append the storage axes to the output dims so that each
    # equality can be written on a plain set space.
    set_space = map_space.move_dims(
        dim_type.out, n_domain_dims,
        dim_type.in_, 0, n_stor_axes).range()

    # set_space: [domain](dup_sweep_index)[dup_sweep](rn)[stor_axes']

    per_axis_maps = [
        isl.BasicMap.from_constraint(
            isl.Constraint.equality_from_aff(
                aff_from_expr(
                    set_space,
                    var(axis_name+"'") - prime_sweep_inames(axis_expr))))
        for axis_name, axis_expr in zip(storage_axis_names,
                                        storage_axis_exprs)
    ]

    if not per_axis_maps:
        return None

    stor2sweep = per_axis_maps[0]
    for axis_map in per_axis_maps[1:]:
        stor2sweep = stor2sweep.intersect(axis_map)

    # Move the storage axes back to the input side: back in map_space.
    return stor2sweep.move_dims(
        dim_type.in_, 0,
        dim_type.out, n_domain_dims, n_stor_axes)
Exemplo n.º 10
0
def build_per_access_storage_to_domain_map(storage_axis_exprs, domain,
        storage_axis_names,
        prime_sweep_inames):
    """Build the map from primed storage axes to the dimensions of *domain*.

    Each storage axis ``s`` with expression ``e`` contributes the equality
    ``s' = prime_sweep_inames(e)``; all contributions are intersected.

    :returns: the resulting map with the primed storage axes as input
        dimensions, or *None* when no axis/expression pair exists.
    """
    stor_dim = len(storage_axis_names)
    base_space = domain.space
    rn = base_space.dim(dim_type.out)

    # Add the storage axes as (primed) input dimensions.  Unprimed, base-0
    # versions replace them further down the pipeline.
    primed_names = [saxis+"'" for saxis in storage_axis_names]
    map_space = base_space.add_dims(dim_type.in_, stor_dim)
    for pos, primed in enumerate(primed_names):
        map_space = map_space.set_dim_name(dim_type.in_, pos, primed)

    # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn)

    moved = map_space.move_dims(dim_type.out, rn, dim_type.in_, 0, stor_dim)
    set_space = moved.range()

    # set_space: [domain](dup_sweep_index)[dup_sweep](rn)[stor_axes']

    from loopy.symbolic import aff_from_expr

    def axis_map(saxis, sa_expr):
        # saxis' == sa_expr with sweep inames primed
        difference = var(saxis+"'") - prime_sweep_inames(sa_expr)
        cns = isl.Constraint.equality_from_aff(
            aff_from_expr(set_space, difference))
        return isl.BasicMap.from_constraint(cns)

    # Lazily produce one map per axis and fold them together by
    # intersection.
    axis_maps = (axis_map(saxis, sa_expr)
                 for saxis, sa_expr in zip(storage_axis_names,
                                           storage_axis_exprs))

    stor2sweep = next(axis_maps, None)
    if stor2sweep is None:
        return None

    for other in axis_maps:
        stor2sweep = stor2sweep.intersect(other)

    # stor2sweep is back in map_space after this move
    return stor2sweep.move_dims(dim_type.in_, 0, dim_type.out, rn, stor_dim)
Exemplo n.º 11
0
def subst_into_pwaff(new_space, pwaff, subst_dict):
    """
    Return an :class:`islpy.PwAff` obtained by applying the substitutions in
    *subst_dict* to *pwaff*.

    :arg pwaff: an instance of :class:`islpy.PwAff`
    :arg subst_dict: a mapping from parameters of *pwaff* to
        :class:`pymbolic.primitives.Expression` made up of terms comprising
        the parameters of *new_space*. The expression must be affine in the
        param dims of *new_space*.
    :raises ValueError: if every piece of *pwaff* has an empty domain after
        restriction to the substitution domain.
    """
    from pymbolic.mapper.substitutor import (SubstitutionMapper,
                                             make_subst_func)
    from loopy.symbolic import aff_from_expr, aff_to_expr

    # Number of parameters of the incoming pwaff; these are the leading
    # parameter dims projected out of every piece below.
    n_orig_params = pwaff.dim(dim_type.param)
    pwaff, subst_domain, subst_dict = get_param_subst_domain(
        new_space, pwaff, subst_dict)
    substitute = SubstitutionMapper(make_subst_func(subst_dict))

    surviving = []
    for piece_domain, piece_aff in pwaff.get_pieces():
        piece_domain = piece_domain & subst_domain
        if not piece_domain.plain_is_empty():
            piece_domain = piece_domain.project_out(
                dim_type.param, 0, n_orig_params)
            substituted = substitute(aff_to_expr(piece_aff))
            new_aff = aff_from_expr(piece_domain.space, substituted)
            surviving.append(isl.PwAff.alloc(piece_domain, new_aff))

    if len(surviving) == 0:
        raise ValueError("no pieces of PwAff survived the substitution")

    # Combine the pieces left to right.
    combined = surviving[0]
    for piece in surviving[1:]:
        combined = combined.union_add(piece)

    return combined.coalesce()
Exemplo n.º 12
0
    def __init__(self, kernel, domain, sweep_inames, access_descriptors,
                 storage_axis_count):
        """Compute the storage-to-sweep map and the augmented domain.

        :arg kernel: stored as ``self.kernel`` and passed on to
            ``build_global_storage_to_sweep_map`` and ``compute_bounds``.
        :arg domain: the isl domain whose *sweep_inames* get duplicated.
        :arg sweep_inames: names of the inames being swept; stored as
            ``self.sweep_inames``.
        :arg access_descriptors: passed through to
            ``build_global_storage_to_sweep_map``.
        :arg storage_axis_count: number of storage axes; they are named
            ``_loopy_storage_<i>``.

        Sets ``self.storage_axis_names``, ``self.primed_sweep_inames``,
        ``self.prime_sweep_inames``, ``self.stor2sweep``,
        ``self.non1_storage_axis_flags``, ``self.aug_domain``,
        ``self.storage_base_indices`` and ``self.non1_storage_shape``.
        """
        self.kernel = kernel
        self.sweep_inames = sweep_inames

        storage_axis_names = self.storage_axis_names = [
            "_loopy_storage_%d" % i for i in range(storage_axis_count)
        ]

        # {{{ duplicate sweep inames

        # The duplication is necessary, otherwise the storage fetch
        # inames remain weirdly tied to the original sweep inames.

        self.primed_sweep_inames = [psin + "'" for psin in sweep_inames]

        from loopy.isl_helpers import duplicate_axes
        # Number of dims of the original domain, recorded before duplication;
        # used below as the position of the duplicated sweep inames when
        # projecting them out (presumably duplicate_axes appends them there).
        dup_sweep_index = domain.space.dim(dim_type.out)
        domain_dup_sweep = duplicate_axes(domain, sweep_inames,
                                          self.primed_sweep_inames)

        # Mapper replacing each sweep iname by a variable named after its
        # primed duplicate.
        self.prime_sweep_inames = SubstitutionMapper(
            make_subst_func({
                sin: var(psin)
                for sin, psin in zip(sweep_inames, self.primed_sweep_inames)
            }))

        # }}}

        self.stor2sweep = build_global_storage_to_sweep_map(
            kernel, access_descriptors, domain_dup_sweep, dup_sweep_index,
            storage_axis_names, sweep_inames, self.primed_sweep_inames,
            self.prime_sweep_inames)

        storage_base_indices, storage_shape = compute_bounds(
            kernel, domain, self.stor2sweep, self.primed_sweep_inames,
            storage_axis_names)

        # compute augmented domain

        # {{{ filter out unit-length dimensions

        # One flag per storage axis (True if its length is not 1), plus the
        # lengths of just the flagged axes.
        non1_storage_axis_flags = []
        non1_storage_shape = []

        for saxis_len in storage_shape:
            has_length_non1 = saxis_len != 1

            non1_storage_axis_flags.append(has_length_non1)

            if has_length_non1:
                non1_storage_shape.append(saxis_len)

        # }}}

        # {{{ subtract off the base indices
        # add the new, base-0 indices as new in dimensions

        sp = self.stor2sweep.get_space()
        stor_idx = sp.dim(dim_type.out)

        n_stor = storage_axis_count
        nn1_stor = len(non1_storage_shape)

        # Turn the map into a set by appending its input dims (the primed
        # storage axes) to the output dims and taking the range.
        aug_domain = self.stor2sweep.move_dims(dim_type.out, stor_idx,
                                               dim_type.in_, 0,
                                               n_stor).range()

        # aug_domain space now:
        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes']

        # The new dims are inserted at stor_idx, i.e. in front of the primed
        # storage axes.
        aug_domain = aug_domain.insert_dims(dim_type.set, stor_idx, nn1_stor)

        # Name the inserted dims after the non-unit-length storage axes, in
        # order.
        inew = 0
        for i, name in enumerate(storage_axis_names):
            if non1_storage_axis_flags[i]:
                aug_domain = aug_domain.set_dim_name(dim_type.set,
                                                     stor_idx + inew, name)
                inew += 1

        # aug_domain space now:
        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[n1_stor_axes][stor_axes']

        from loopy.symbolic import aff_from_expr
        for saxis, bi, s in zip(storage_axis_names, storage_base_indices,
                                storage_shape):
            if s != 1:
                # saxis == saxis' - base_index
                cns = isl.Constraint.equality_from_aff(
                    aff_from_expr(aug_domain.get_space(),
                                  var(saxis) - (var(saxis + "'") - bi)))

                aug_domain = aug_domain.add_constraint(cns)

        # }}}

        # eliminate (primed) storage axes with non-zero base indices; they
        # sit right after the nn1_stor newly inserted dims
        aug_domain = aug_domain.project_out(dim_type.set, stor_idx + nn1_stor,
                                            n_stor)

        # eliminate duplicated sweep_inames
        nsweep = len(sweep_inames)
        aug_domain = aug_domain.project_out(dim_type.set, dup_sweep_index,
                                            nsweep)

        self.non1_storage_axis_flags = non1_storage_axis_flags
        self.aug_domain = aug_domain
        self.storage_base_indices = storage_base_indices
        self.non1_storage_shape = non1_storage_shape
Exemplo n.º 13
0
    def map_Do(self, node):
        """Translate a ``do`` loop node into a loopy iname and index set.

        Parses ``node.loopcontrol`` (``var = start, stop[, step]``), picks a
        loopy-side iname not yet present in the scope's index sets, appends
        the index set ``0 <= iname <= stop - start`` to the scope, and
        aliases the loop variable to ``iname + start`` while the loop body
        (``node.content``) is processed.

        :raises NotImplementedError: for a loop without loop control or with
            a step that does not compare equal to 1.
        :raises LoopyError: if the loop variable's type disagrees with the
            index type recorded from an earlier loop.
        :raises RuntimeError: if the bounds are neither two nor three
            expressions.
        :raises TranslationError: if the step is not an :class:`int`.
        """
        scope = self.scope_stack[-1]

        if not node.loopcontrol:
            raise NotImplementedError("unbounded do loop")

        loop_var, loop_bounds = node.loopcontrol.split("=")
        loop_var = loop_var.strip()

        # All loop variables have to share one index dtype; the first loop
        # encountered determines it.
        iname_dtype = scope.get_type(loop_var)
        if self.index_dtype is None:
            self.index_dtype = iname_dtype
        else:
            if self.index_dtype != iname_dtype:
                raise LoopyError("type of '%s' (%s) does not agree with prior "
                        "index type (%s)"
                        % (loop_var, iname_dtype, self.index_dtype))

        scope.use_name(loop_var)
        loop_bounds = self.parse_expr(
                node,
                loop_bounds, min_precedence=self.expr_parser._PREC_FUNC_ARGS)

        if len(loop_bounds) == 2:
            start, stop = loop_bounds
            step = 1
        elif len(loop_bounds) == 3:
            start, stop, step = loop_bounds
        else:
            raise RuntimeError("loop bounds not understood: %s"
                    % node.loopcontrol)

        if step != 1:
            raise NotImplementedError(
                    "do loops with non-unit stride")

        # NOTE(review): since the non-unit-stride check comes first, this is
        # only reached for a step that compares equal to 1 without being an
        # int.
        if not isinstance(step, int):
            raise TranslationError(
                    "non-constant steps not supported: %s" % step)

        # Names the bounds depend on become parameters of the index set.
        from loopy.symbolic import get_dependencies
        loop_bound_deps = (
                get_dependencies(start)
                | get_dependencies(stop)
                | get_dependencies(step))

        # {{{ find a usable loopy-side loop name

        # Try loop_var, then loop_var_1, loop_var_2, ... until a name is
        # found that no index set of this scope uses as a set dimension.
        loopy_loop_var = loop_var
        loop_var_suffix = None
        while True:
            already_used = False
            for iset in scope.index_sets:
                if loopy_loop_var in iset.get_var_dict(dim_type.set):
                    already_used = True
                    break

            if not already_used:
                break

            if loop_var_suffix is None:
                loop_var_suffix = 0

            loop_var_suffix += 1
            loopy_loop_var = loop_var + "_%d" % loop_var_suffix

        # NOTE(review): `intern` is not a Python 3 builtin; presumably
        # imported elsewhere in the file (e.g. from sys) -- confirm.
        loopy_loop_var = intern(loopy_loop_var)

        # }}}

        space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT,
                set=[loopy_loop_var], params=list(loop_bound_deps))

        # The loopy iname is zero-based: 0 <= iname <= stop - start.
        from loopy.isl_helpers import iname_rel_aff
        from loopy.symbolic import aff_from_expr
        index_set = (
                isl.BasicSet.universe(space)
                .add_constraint(
                    isl.Constraint.inequality_from_aff(
                        iname_rel_aff(space,
                            loopy_loop_var, ">=",
                            aff_from_expr(space, 0))))
                .add_constraint(
                    isl.Constraint.inequality_from_aff(
                        iname_rel_aff(space,
                            loopy_loop_var, "<=",
                            aff_from_expr(space, stop-start)))))

        # Within the loop body, the original loop variable stands for the
        # zero-based iname shifted by the loop's start value.
        from pymbolic import var
        scope.active_iname_aliases[loop_var] = \
                var(loopy_loop_var) + start
        scope.active_loopy_inames.add(loopy_loop_var)

        scope.index_sets.append(index_set)

        self.block_nest.append("do")

        for c in node.content:
            self.rec(c)

        # The alias and the active iname only apply inside the loop body.
        del scope.active_iname_aliases[loop_var]
        scope.active_loopy_inames.remove(loopy_loop_var)
Exemplo n.º 14
0
    def map_Do(self, node):
        """Translate a ``do`` loop node into a loopy iname and index set.

        The loop control ``var = start, stop[, step]`` is parsed, a loopy
        iname unused by the scope's index sets is chosen, and the index set
        ``0 <= iname <= stop - start`` is appended to the scope.  While the
        loop body is processed, the loop variable is aliased to
        ``iname + start``.
        """
        scope = self.scope_stack[-1]

        if not node.loopcontrol:
            raise NotImplementedError("unbounded do loop")

        loop_var, bounds_text = node.loopcontrol.split("=")
        loop_var = loop_var.strip()

        # The first loop seen fixes the index dtype; later ones must agree.
        iname_dtype = scope.get_type(loop_var)
        if self.index_dtype is None:
            self.index_dtype = iname_dtype
        elif self.index_dtype != iname_dtype:
            raise LoopyError("type of '%s' (%s) does not agree with prior "
                    "index type (%s)"
                    % (loop_var, iname_dtype, self.index_dtype))

        scope.use_name(loop_var)
        bounds = self.parse_expr(
                node,
                bounds_text,
                min_precedence=self.expr_parser._PREC_FUNC_ARGS)

        n_bounds = len(bounds)
        if n_bounds == 2:
            start, stop = bounds
            step = 1
        elif n_bounds == 3:
            start, stop, step = bounds
        else:
            raise RuntimeError("loop bounds not understood: %s"
                    % node.loopcontrol)

        if step != 1:
            raise NotImplementedError(
                    "do loops with non-unit stride")

        if not isinstance(step, int):
            raise TranslationError(
                    "non-constant steps not supported: %s" % step)

        from loopy.symbolic import get_dependencies
        loop_bound_deps = (
                get_dependencies(start)
                | get_dependencies(stop)
                | get_dependencies(step))

        # {{{ find a usable loopy-side loop name

        def name_is_taken(candidate):
            # A name is taken if any index set of this scope has it as a
            # set dimension.
            return any(
                    candidate in iset.get_var_dict(dim_type.set)
                    for iset in scope.index_sets)

        # Try loop_var itself, then loop_var_1, loop_var_2, ...
        loopy_loop_var = loop_var
        suffix = 0
        while name_is_taken(loopy_loop_var):
            suffix += 1
            loopy_loop_var = loop_var + "_%d" % suffix

        # }}}

        space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT,
                set=[loopy_loop_var], params=list(loop_bound_deps))

        from loopy.isl_helpers import iname_rel_aff
        from loopy.symbolic import aff_from_expr

        # 0 <= iname
        lower_cns = isl.Constraint.inequality_from_aff(
                iname_rel_aff(space,
                    loopy_loop_var, ">=",
                    aff_from_expr(space, 0)))
        # iname <= stop - start
        upper_cns = isl.Constraint.inequality_from_aff(
                iname_rel_aff(space,
                    loopy_loop_var, "<=",
                    aff_from_expr(space, stop-start)))

        index_set = (
                isl.BasicSet.universe(space)
                .add_constraint(lower_cns)
                .add_constraint(upper_cns))

        # Inside the body, the loop variable means "zero-based iname plus
        # start".
        from pymbolic import var
        scope.active_iname_aliases[loop_var] = var(loopy_loop_var) + start
        scope.active_loopy_inames.add(loopy_loop_var)

        scope.index_sets.append(index_set)

        self.block_nest.append("do")

        for child in node.content:
            self.rec(child)

        # Leaving the loop: the alias and the active iname no longer apply.
        del scope.active_iname_aliases[loop_var]
        scope.active_loopy_inames.remove(loopy_loop_var)
Exemplo n.º 15
0
def domain_for_shape(
    dim_names: Tuple[str, ...],
    shape: ShapeType,
    reductions: Dict[str, Tuple[ScalarExpression, ScalarExpression]],
) -> isl.BasicSet:  # noqa
    """Create an :class:`islpy.BasicSet` describing the index domain of an
    array of (potentially symbolic) shape *shape* with reduction dimensions
    *reductions*.

    :param dim_names: A tuple of strings, the names of the axes. These become
        set dimensions in the returned domain.

    :param shape: A tuple of constant or quasi-affine :mod:`pymbolic`
        expressions. The variables in these expressions become parameter
        dimensions in the returned set.  Must have the same length as
        *dim_names*.

    :arg reductions: A map from reduction inames to (lower, upper) bounds
        (as half-open integer ranges). The variables in the bounds become
        parameter dimensions in the returned set.
    """
    assert len(dim_names) == len(shape)

    # Every name a shape entry or reduction bound depends on becomes a
    # parameter.
    param_names_set: Set[str] = set()
    for axis_len in shape:
        param_names_set |= scalar_expr.get_dependencies(axis_len)

    for bounds in reductions.values():
        for bound in bounds:
            # FIXME: Assumes that reduction bounds are not data-dependent.
            param_names_set |= scalar_expr.get_dependencies(bound)

    # Set and parameter dimensions are both ordered by name.
    space = isl.Space.create_from_names(
        isl.DEFAULT_CONTEXT,
        set=sorted(tuple(dim_names) + tuple(reductions)),
        params=sorted(param_names_set))
    dom = isl.BasicSet.universe(space)

    from loopy.symbolic import aff_from_expr
    affs = isl.affs_from_space(dom.space)

    # Array axes: 0 <= iname < axis length.
    for iname, axis_len in zip(dim_names, shape):
        dom &= affs[0].le_set(affs[iname])
        dom &= affs[iname].lt_set(aff_from_expr(dom.space, axis_len))

    # Reduction axes: lower <= iname < upper.
    for iname, (lower, upper) in reductions.items():
        dom &= aff_from_expr(dom.space, lower).le_set(affs[iname])
        dom &= affs[iname].lt_set(aff_from_expr(dom.space, upper))

    pieces = dom.get_basic_sets()

    if len(pieces) == 0:
        # empty set
        return isl.BasicSet.empty(dom.get_space())

    # Exactly one basic set is expected at this point.
    only_piece, = pieces
    return only_piece
Exemplo n.º 16
0
def affine_map_inames(kernel, old_inames, new_inames, equations):
    """Return a new *kernel* where the affine transform
    specified by *equations* has been applied to the inames.

    :arg old_inames: A list of inames to be replaced by affine transforms
        of their values.
        May also be a string of comma-separated inames.

    :arg new_inames: A list of new inames that are not yet used in *kernel*,
        but have their values established in terms of *old_inames* by
        *equations*.
        May also be a string of comma-separated inames.
    :arg equations: A list of equations establishing a relationship
        between *old_inames* and *new_inames*. Each equation may be
        a tuple ``(lhs, rhs)`` of expressions or a string, with left and
        right hand side of the equation separated by ``=``.
        A single string is treated as a one-element list.

    :raises ValueError: if an equation is malformed (bad string, tuple of
        the wrong length, or neither string nor tuple), or if the old inames
        of an equation disagree in their ``dim_type`` within a domain.
    :raises LoopyError: if a new iname is already used in *kernel*, an old
        iname is unknown, or an instruction uses only some of the old
        inames.
    """

    # {{{ check and parse arguments

    if isinstance(new_inames, str):
        new_inames = [iname.strip() for iname in new_inames.split(",")]
    if isinstance(old_inames, str):
        old_inames = [iname.strip() for iname in old_inames.split(",")]
    if isinstance(equations, str):
        equations = [equations]

    import re
    # Exactly one "=" with non-empty text on both sides.
    eqn_re = re.compile(r"^([^=]+)=([^=]+)$")

    def parse_equation(eqn):
        if isinstance(eqn, str):
            eqn_match = eqn_re.match(eqn)
            if not eqn_match:
                raise ValueError("invalid equation: %s" % eqn)

            from loopy.symbolic import parse
            lhs = parse(eqn_match.group(1))
            rhs = parse(eqn_match.group(2))
            return (lhs, rhs)
        elif isinstance(eqn, tuple):
            if len(eqn) != 2:
                raise ValueError("unexpected length of equation tuple, "
                                 "got %d, should be 2" % len(eqn))
            return eqn
        else:
            # The type name is a string, so it has to be formatted with %s.
            raise ValueError("unexpected type of equation, "
                             "got %s, should be string or tuple" %
                             type(eqn).__name__)

    equations = [parse_equation(eqn) for eqn in equations]

    all_vars = kernel.all_variable_names()
    for iname in new_inames:
        if iname in all_vars:
            raise LoopyError("new iname '%s' is already used in kernel" %
                             iname)

    all_inames = kernel.all_inames()
    for iname in old_inames:
        if iname not in all_inames:
            raise LoopyError("old iname '%s' not known" % iname)

    # }}}

    # {{{ substitute iname use

    # Express each old iname in terms of the new ones and rewrite all
    # expressions in the kernel accordingly.
    from pymbolic.algorithm import solve_affine_equations_for
    old_inames_to_expr = solve_affine_equations_for(old_inames, equations)

    subst_dict = {v.name: expr for v, expr in old_inames_to_expr.items()}

    var_name_gen = kernel.get_var_name_generator()

    from pymbolic.mapper.substitutor import make_subst_func
    from loopy.match import parse_stack_match

    rule_mapping_context = SubstitutionRuleMappingContext(
        kernel.substitutions, var_name_gen)
    old_to_new = RuleAwareSubstitutionMapper(rule_mapping_context,
                                             make_subst_func(subst_dict),
                                             within=parse_stack_match(None))

    kernel = (rule_mapping_context.finish_kernel(
        old_to_new.map_kernel(kernel)).copy(
            applied_iname_rewrites=kernel.applied_iname_rewrites +
            [subst_dict]))

    # }}}

    # {{{ change domains

    new_inames_set = frozenset(new_inames)
    old_inames_set = frozenset(old_inames)

    new_domains = []
    for idom, dom in enumerate(kernel.domains):
        dom_var_dict = dom.get_var_dict()
        old_iname_overlap = [
            iname for iname in old_inames if iname in dom_var_dict
        ]

        # Domains that mention none of the old inames stay as they are.
        if not old_iname_overlap:
            new_domains.append(dom)
            continue

        from loopy.symbolic import get_dependencies
        dom_new_inames = set()
        dom_old_inames = set()

        # mapping for new inames to dim_types
        new_iname_dim_types = {}

        dom_equations = []
        for iname in old_iname_overlap:
            for ieqn, (lhs, rhs) in enumerate(equations):
                eqn_deps = get_dependencies(lhs) | get_dependencies(rhs)
                if iname in eqn_deps:
                    dom_new_inames.update(eqn_deps & new_inames_set)
                    dom_old_inames.update(eqn_deps & old_inames_set)

                if dom_old_inames:
                    dom_equations.append((lhs, rhs))

                # A new iname takes on the dim_type shared by the old inames
                # of the equation(s) it appears in.
                this_eqn_old_iname_dim_types = {
                    dom_var_dict[old_iname][0]
                    for old_iname in eqn_deps & old_inames_set}

                if this_eqn_old_iname_dim_types:
                    if len(this_eqn_old_iname_dim_types) > 1:
                        raise ValueError(
                            "inames '%s' (from equation %d (0-based)) "
                            "in domain %d (0-based) are not "
                            "of a uniform dim_type" %
                            (", ".join(eqn_deps & old_inames_set), ieqn, idom))

                    this_eqn_new_iname_dim_type, = this_eqn_old_iname_dim_types

                    for new_iname in eqn_deps & new_inames_set:
                        if new_iname in new_iname_dim_types:
                            if (this_eqn_new_iname_dim_type !=
                                    new_iname_dim_types[new_iname]):
                                raise ValueError(
                                    "dim_type disagreement for "
                                    "iname '%s' (from equation %d (0-based)) "
                                    "in domain %d (0-based)" %
                                    (new_iname, ieqn, idom))
                        else:
                            new_iname_dim_types[new_iname] = \
                                    this_eqn_new_iname_dim_type

        if not dom_old_inames <= set(dom_var_dict):
            raise ValueError(
                "domain %d (0-based) does not know about "
                "all old inames (specifically '%s') needed to define new inames"
                % (idom, ", ".join(dom_old_inames - set(dom_var_dict))))

        # add inames to domain with correct dim_types
        dom_new_inames = list(dom_new_inames)
        for iname in dom_new_inames:
            dt = new_iname_dim_types[iname]
            iname_idx = dom.dim(dt)
            dom = dom.add_dims(dt, 1)
            dom = dom.set_dim_name(dt, iname_idx, iname)

        # add equations
        from loopy.symbolic import aff_from_expr
        for lhs, rhs in dom_equations:
            dom = dom.add_constraint(
                isl.Constraint.equality_from_aff(
                    aff_from_expr(dom.space, rhs - lhs)))

        # project out old inames
        for iname in dom_old_inames:
            dt, idx = dom.get_var_dict()[iname]
            dom = dom.project_out(dt, idx, 1)

        new_domains.append(dom)

    # }}}

    # {{{ switch iname refs in instructions

    def fix_iname_set(insn_id, inames):
        # An instruction must be nested in either all or none of the old
        # inames.
        if old_inames_set <= inames:
            return (inames - old_inames_set) | new_inames_set
        elif old_inames_set & inames:
            raise LoopyError(
                "instruction '%s' uses only a part (%s), not all, "
                "of the old inames" %
                (insn_id, ", ".join(old_inames_set & inames)))
        else:
            return inames

    new_instructions = [
        insn.copy(within_inames=fix_iname_set(insn.id, insn.within_inames))
        for insn in kernel.instructions
    ]

    # }}}

    return kernel.copy(domains=new_domains, instructions=new_instructions)
Exemplo n.º 17
0
def simplify_via_aff(expr):
    """Round-trip *expr* through an isl affine expression and back.

    The isl space used for the conversion is created from the names
    that ``get_dependencies`` reports for *expr*.

    :arg expr: the expression to convert.
    :returns: whatever ``aff_to_expr`` produces for the affine form
        of *expr*.
    """
    from loopy.symbolic import aff_from_expr, aff_to_expr, get_dependencies

    dep_names = list(get_dependencies(expr))
    space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, dep_names)
    as_aff = aff_from_expr(space, expr)
    return aff_to_expr(as_aff)
Exemplo n.º 18
0
def affine_map_inames(kernel, old_inames, new_inames, equations):
    """Return a new *kernel* where the affine transform
    specified by *equations* has been applied to the inames.

    :arg old_inames: A list of inames to be replaced by affine transforms
        of their values.
        May also be a string of comma-separated inames.

    :arg new_inames: A list of new inames that are not yet used in *kernel*,
        but have their values established in terms of *old_inames* by
        *equations*.
        May also be a string of comma-separated inames.
    :arg equations: A list of equations establishing a relationship
        between *old_inames* and *new_inames*. Each equation may be
        a tuple ``(lhs, rhs)`` of expressions or a string, with left and
        right hand side of the equation separated by ``=``.
        A single string is treated as a one-element list.

    :raises ValueError: if an equation is neither a string nor a tuple,
        is a tuple whose length is not 2, is a string that does not
        contain exactly one ``=``, or if the dim_types of the inames
        involved in a domain are inconsistent.
    :raises LoopyError: if a new iname is already used in *kernel* or an
        old iname is not known to it.
    """

    # {{{ check and parse arguments

    if isinstance(new_inames, str):
        new_inames = [iname.strip() for iname in new_inames.split(",")]
    if isinstance(old_inames, str):
        old_inames = [iname.strip() for iname in old_inames.split(",")]
    if isinstance(equations, str):
        equations = [equations]

    import re
    # exactly one '=' with non-empty text on both sides
    eqn_re = re.compile(r"^([^=]+)=([^=]+)$")

    def parse_equation(eqn):
        # Normalize one equation to an (lhs, rhs) tuple of expressions.
        if isinstance(eqn, str):
            eqn_match = eqn_re.match(eqn)
            if not eqn_match:
                raise ValueError("invalid equation: %s" % eqn)

            from loopy.symbolic import parse
            lhs = parse(eqn_match.group(1))
            rhs = parse(eqn_match.group(2))
            return (lhs, rhs)
        elif isinstance(eqn, tuple):
            if len(eqn) != 2:
                raise ValueError("unexpected length of equation tuple, "
                        "got %d, should be 2" % len(eqn))
            return eqn
        else:
            # The type name is a string, so it must be formatted with %s;
            # %d would raise TypeError instead of the intended ValueError.
            raise ValueError("unexpected type of equation, "
                    "got %s, should be string or tuple"
                    % type(eqn).__name__)

    equations = [parse_equation(eqn) for eqn in equations]

    all_vars = kernel.all_variable_names()
    for iname in new_inames:
        if iname in all_vars:
            raise LoopyError("new iname '%s' is already used in kernel"
                    % iname)

    known_inames = kernel.all_inames()
    for iname in old_inames:
        if iname not in known_inames:
            raise LoopyError("old iname '%s' not known" % iname)

    # }}}

    # {{{ substitute iname use

    from pymbolic.algorithm import solve_affine_equations_for
    old_inames_to_expr = solve_affine_equations_for(old_inames, equations)

    subst_dict = {
            v.name: expr
            for v, expr in old_inames_to_expr.items()}

    var_name_gen = kernel.get_var_name_generator()

    from pymbolic.mapper.substitutor import make_subst_func
    from loopy.context_matching import parse_stack_match

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, var_name_gen)
    old_to_new = RuleAwareSubstitutionMapper(rule_mapping_context,
            make_subst_func(subst_dict), within=parse_stack_match(None))

    kernel = (
            rule_mapping_context.finish_kernel(
                old_to_new.map_kernel(kernel))
            .copy(
                applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict]
                ))

    # }}}

    # {{{ change domains

    from loopy.symbolic import aff_from_expr, get_dependencies

    new_inames_set = set(new_inames)
    old_inames_set = set(old_inames)

    new_domains = []
    for idom, dom in enumerate(kernel.domains):
        dom_var_dict = dom.get_var_dict()
        old_iname_overlap = [
                iname
                for iname in old_inames
                if iname in dom_var_dict]

        # domains that mention none of the old inames pass through untouched
        if not old_iname_overlap:
            new_domains.append(dom)
            continue

        dom_new_inames = set()
        dom_old_inames = set()

        # mapping for new inames to dim_types
        new_iname_dim_types = {}

        dom_equations = []
        for iname in old_iname_overlap:
            for ieqn, (lhs, rhs) in enumerate(equations):
                eqn_deps = get_dependencies(lhs) | get_dependencies(rhs)
                if iname in eqn_deps:
                    dom_new_inames.update(eqn_deps & new_inames_set)
                    dom_old_inames.update(eqn_deps & old_inames_set)

                # NOTE(review): this test is not nested under the membership
                # check above, so once any old iname has been recorded every
                # later equation is appended (possibly repeatedly) -- confirm
                # that this is intended.
                if dom_old_inames:
                    dom_equations.append((lhs, rhs))

                # NOTE(review): this indexes dom_var_dict before the
                # "does not know about all old inames" check further down,
                # so an old iname missing from this domain would surface
                # here as a KeyError -- confirm.
                this_eqn_old_iname_dim_types = {
                        dom_var_dict[old_iname][0]
                        for old_iname in eqn_deps & old_inames_set}

                if this_eqn_old_iname_dim_types:
                    if len(this_eqn_old_iname_dim_types) > 1:
                        raise ValueError("inames '%s' (from equation %d (0-based)) "
                                "in domain %d (0-based) are not "
                                "of a uniform dim_type"
                                % (", ".join(eqn_deps & old_inames_set), ieqn, idom))

                    this_eqn_new_iname_dim_type, = this_eqn_old_iname_dim_types

                    # every new iname inherits the dim_type of the old inames
                    # it shares an equation with; all equations must agree
                    for new_iname in eqn_deps & new_inames_set:
                        if new_iname in new_iname_dim_types:
                            if (this_eqn_new_iname_dim_type
                                    != new_iname_dim_types[new_iname]):
                                raise ValueError("dim_type disagreement for "
                                        "iname '%s' (from equation %d (0-based)) "
                                        "in domain %d (0-based)"
                                        % (new_iname, ieqn, idom))
                        else:
                            new_iname_dim_types[new_iname] = \
                                    this_eqn_new_iname_dim_type

        if not dom_old_inames <= set(dom_var_dict):
            raise ValueError("domain %d (0-based) does not know about "
                    "all old inames (specifically '%s') needed to define new inames"
                    % (idom, ", ".join(dom_old_inames - set(dom_var_dict))))

        # add inames to domain with correct dim_types
        dom_new_inames = list(dom_new_inames)
        for iname in dom_new_inames:
            dt = new_iname_dim_types[iname]
            iname_idx = dom.dim(dt)
            dom = dom.add_dims(dt, 1)
            dom = dom.set_dim_name(dt, iname_idx, iname)

        # add equations
        for lhs, rhs in dom_equations:
            dom = dom.add_constraint(
                    isl.Constraint.equality_from_aff(
                        aff_from_expr(dom.space, rhs - lhs)))

        # project out old inames
        for iname in dom_old_inames:
            # indices shift after each projection, so look them up afresh
            dt, idx = dom.get_var_dict()[iname]
            dom = dom.project_out(dt, idx, 1)

        new_domains.append(dom)

    # }}}

    return kernel.copy(domains=new_domains)
Exemplo n.º 19
0
    def __init__(self, kernel, domain, sweep_inames, access_descriptors,
            storage_axis_count):
        """Build the storage-axis-augmented version of *domain*.

        Sets the following attributes: ``kernel``, ``sweep_inames``,
        ``storage_axis_names``, ``primed_sweep_inames``,
        ``prime_sweep_inames``, ``stor2sweep``, ``non1_storage_axis_flags``,
        ``aug_domain``, ``storage_base_indices`` and ``non1_storage_shape``.

        :arg kernel: passed through to the storage-to-sweep map builder
            and to ``compute_bounds``.
        :arg domain: the isl domain to augment.
        :arg sweep_inames: names of the sweep inames; a primed duplicate
            (name with ``'`` appended) is created for each.
        :arg access_descriptors: passed through to
            ``build_global_storage_to_sweep_map``.
        :arg storage_axis_count: number of storage axes; axis *i* is named
            ``_loopy_storage_<i>``.
        """
        self.kernel = kernel
        self.sweep_inames = sweep_inames

        storage_axis_names = [
                "_loopy_storage_%d" % axis_nr
                for axis_nr in range(storage_axis_count)]
        self.storage_axis_names = storage_axis_names

        # {{{ duplicate sweep inames

        # Without the duplicates, the storage fetch inames would stay
        # coupled to the original sweep inames.

        primed_sweep_inames = [name + "'" for name in sweep_inames]
        self.primed_sweep_inames = primed_sweep_inames

        from loopy.isl_helpers import duplicate_axes
        dup_sweep_index = domain.space.dim(dim_type.out)
        domain_dup_sweep = duplicate_axes(
                domain, sweep_inames, primed_sweep_inames)

        unprimed_to_primed = {}
        for name, primed_name in zip(sweep_inames, primed_sweep_inames):
            unprimed_to_primed[name] = var(primed_name)

        self.prime_sweep_inames = SubstitutionMapper(
                make_subst_func(unprimed_to_primed))

        # }}}

        self.stor2sweep = build_global_storage_to_sweep_map(
                kernel, access_descriptors,
                domain_dup_sweep, dup_sweep_index,
                storage_axis_names,
                sweep_inames, primed_sweep_inames, self.prime_sweep_inames)

        storage_base_indices, storage_shape = compute_bounds(
                kernel, domain, self.stor2sweep, primed_sweep_inames,
                storage_axis_names)

        # compute augmented domain

        # {{{ filter out unit-length dimensions

        # one flag per storage axis; only non-unit lengths are kept in
        # non1_storage_shape
        non1_storage_axis_flags = []
        non1_storage_shape = []

        for _, _, axis_len in zip(
                storage_axis_names, storage_base_indices, storage_shape):
            is_non_unit = axis_len != 1
            non1_storage_axis_flags.append(is_non_unit)
            if is_non_unit:
                non1_storage_shape.append(axis_len)

        # }}}

        # {{{ subtract off the base indices
        # add the new, base-0 indices as new in dimensions

        stor_idx = self.stor2sweep.get_space().dim(dim_type.out)

        n_stor = storage_axis_count
        nn1_stor = len(non1_storage_shape)

        moved = self.stor2sweep.move_dims(
                dim_type.out, stor_idx,
                dim_type.in_, 0,
                n_stor)
        aug_domain = moved.range()

        # aug_domain space now:
        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes']

        aug_domain = aug_domain.insert_dims(dim_type.set, stor_idx, nn1_stor)

        # give the freshly inserted dims the names of the non-unit axes
        next_slot = stor_idx
        for axis_nr, axis_name in enumerate(storage_axis_names):
            if not non1_storage_axis_flags[axis_nr]:
                continue
            aug_domain = aug_domain.set_dim_name(
                    dim_type.set, next_slot, axis_name)
            next_slot += 1

        # aug_domain space now:
        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'][n1_stor_axes]

        from loopy.symbolic import aff_from_expr
        for axis_name, base_index, axis_len in zip(
                storage_axis_names, storage_base_indices, storage_shape):
            if axis_len != 1:
                # unprimed axis == primed axis - base index
                shifted_eq = aff_from_expr(
                        aug_domain.get_space(),
                        var(axis_name) - (var(axis_name+"'") - base_index))
                aug_domain = aug_domain.add_constraint(
                        isl.Constraint.equality_from_aff(shifted_eq))

        # }}}

        # eliminate (primed) storage axes with non-zero base indices
        aug_domain = aug_domain.project_out(
                dim_type.set, stor_idx+nn1_stor, n_stor)

        # eliminate duplicated sweep_inames
        aug_domain = aug_domain.project_out(
                dim_type.set, dup_sweep_index, len(sweep_inames))

        self.non1_storage_axis_flags = non1_storage_axis_flags
        self.aug_domain = aug_domain
        self.storage_base_indices = storage_base_indices
        self.non1_storage_shape = non1_storage_shape
Exemplo n.º 20
0
    def augment_domain_for_save_or_reload(self,
            domain, promoted_temporary, mode, subkernel):
        """
        Extend *domain* with one set dimension for every entry of
        ``promoted_temporary.non_hw_dims`` followed by one for every entry
        of ``promoted_temporary.hw_dims``. The new dimensions are named at
        set indices starting from the domain's original set-dimension
        count, and each is bounded by ``0 <= iname < size``.

        :arg mode: ``"save"`` or ``"reload"``; used in the generated names.
        :arg subkernel: used in the generated names.
        :returns: a tuple ``(domain, hw_inames, dim_inames, iname_to_tags)``
            where *iname_to_tags* holds the tags assigned to the
            non-hardware inames. Tags of hardware inames are recorded in
            ``self.updated_iname_to_tags`` instead.
        """
        assert mode in ("save", "reload")
        import islpy as isl

        temporaries = self.kernel.temporary_variables
        orig_temporary = temporaries[promoted_temporary.orig_temporary_name]
        orig_dim = domain.dim(isl.dim_type.set)

        # Tags for newly added inames
        iname_to_tags = {}

        from loopy.symbolic import aff_from_expr

        def make_iname(template, idx):
            # Fill in the naming template and run it through the generator.
            return self.insn_name_gen(
                template.format(
                    name=orig_temporary.name,
                    mode=mode,
                    dim=idx,
                    sk=subkernel))

        # FIXME: Restrict size of new inames to access footprint.

        # Add dimension-dependent inames.
        dim_inames = []
        n_new_dims = (
            len(promoted_temporary.non_hw_dims)
            + len(promoted_temporary.hw_dims))
        domain = domain.add_dims(isl.dim_type.set, n_new_dims)

        for dim_idx, dim_size in enumerate(promoted_temporary.non_hw_dims):
            new_iname = make_iname("{name}_{mode}_axis_{dim}_{sk}", dim_idx)
            domain = domain.set_dim_name(
                isl.dim_type.set, orig_dim + dim_idx, new_iname)

            if orig_temporary.address_space == AddressSpace.LOCAL:
                # If the temporary has local scope, then loads / stores can
                # be done in parallel.
                from loopy.kernel.data import AutoFitLocalIndexTag
                iname_to_tags[new_iname] = frozenset([AutoFitLocalIndexTag()])

            dim_inames.append(new_iname)

            # Add size information: 0 <= new_iname < dim_size.
            aff = isl.affs_from_space(domain.space)
            iname_aff = aff[new_iname]
            domain &= aff[0].le_set(iname_aff)
            size_aff = aff_from_expr(domain.space, dim_size)
            domain &= iname_aff.lt_set(size_aff)

        dim_offset = orig_dim + len(promoted_temporary.non_hw_dims)

        hw_inames = []
        # Add hardware dims.
        hw_pairs = zip(promoted_temporary.hw_tags, promoted_temporary.hw_dims)
        for hw_iname_idx, (hw_tag, dim) in enumerate(hw_pairs):
            new_iname = make_iname(
                "{name}_{mode}_hw_dim_{dim}_{sk}", hw_iname_idx)
            domain = domain.set_dim_name(
                isl.dim_type.set, dim_offset + hw_iname_idx, new_iname)

            # 0 <= new_iname < dim
            aff = isl.affs_from_space(domain.space)
            nonneg = aff[0].le_set(aff[new_iname])
            below_size = aff[new_iname].lt_set(
                aff_from_expr(domain.space, dim))
            domain = domain & nonneg & below_size

            self.updated_iname_to_tags[new_iname] = frozenset([hw_tag])
            hw_inames.append(new_iname)

        # The operations on the domain above return a Set object, but the
        # underlying domain should be expressible as a single BasicSet.
        basic_sets = domain.get_basic_set_list()
        assert basic_sets.n_basic_set() == 1
        return basic_sets.get_basic_set(0), hw_inames, dim_inames, iname_to_tags
Exemplo n.º 21
0
    def augment_domain_for_save_or_reload(self, domain, promoted_temporary,
                                          mode, subkernel):
        """
        Grow *domain* by one set dimension per entry of
        ``promoted_temporary.non_hw_dims`` and then one per entry of
        ``promoted_temporary.hw_dims``. The added dimensions are named at
        set indices starting from the domain's original set-dimension
        count, and each one is constrained to ``0 <= iname < size``.

        :arg mode: ``"save"`` or ``"reload"``; becomes part of the
            generated iname names.
        :arg subkernel: becomes part of the generated iname names.
        :returns: a tuple ``(domain, hw_inames, dim_inames, iname_to_tag)``
            where *iname_to_tag* maps non-hardware inames to the tag
            assigned here. Hardware iname tags are stored in
            ``self.updated_iname_to_tag`` instead.
        """
        assert mode in ("save", "reload")
        import islpy as isl

        temporary_name = promoted_temporary.orig_temporary_name
        orig_temporary = self.kernel.temporary_variables[temporary_name]
        set_dt = isl.dim_type.set
        orig_dim = domain.dim(set_dt)

        # Tags for newly added inames
        iname_to_tag = {}

        from loopy.symbolic import aff_from_expr

        # FIXME: Restrict size of new inames to access footprint.

        # Add dimension-dependent inames.
        dim_inames = []
        n_non_hw = len(promoted_temporary.non_hw_dims)
        n_hw = len(promoted_temporary.hw_dims)
        domain = domain.add(set_dt, n_non_hw + n_hw)

        for dim_idx, dim_size in enumerate(promoted_temporary.non_hw_dims):
            raw_name = "{name}_{mode}_axis_{dim}_{sk}".format(
                name=orig_temporary.name,
                mode=mode,
                dim=dim_idx,
                sk=subkernel)
            new_iname = self.insn_name_gen(raw_name)
            domain = domain.set_dim_name(set_dt, orig_dim + dim_idx, new_iname)

            if orig_temporary.is_local:
                # If the temporary has local scope, then loads / stores can
                # be done in parallel.
                from loopy.kernel.data import AutoFitLocalIndexTag
                iname_to_tag[new_iname] = AutoFitLocalIndexTag()

            dim_inames.append(new_iname)

            # Add size information: 0 <= new_iname < dim_size.
            aff = isl.affs_from_space(domain.space)
            domain &= aff[0].le_set(aff[new_iname])
            upper_bound = aff_from_expr(domain.space, dim_size)
            domain &= aff[new_iname].lt_set(upper_bound)

        # hardware dims are named after all the non-hardware ones
        dim_offset = orig_dim + n_non_hw

        hw_inames = []
        # Add hardware dims.
        for hw_iname_idx, (hw_tag, dim) in enumerate(
                zip(promoted_temporary.hw_tags, promoted_temporary.hw_dims)):
            raw_name = "{name}_{mode}_hw_dim_{dim}_{sk}".format(
                name=orig_temporary.name,
                mode=mode,
                dim=hw_iname_idx,
                sk=subkernel)
            new_iname = self.insn_name_gen(raw_name)
            domain = domain.set_dim_name(
                set_dt, dim_offset + hw_iname_idx, new_iname)

            # 0 <= new_iname < dim
            aff = isl.affs_from_space(domain.space)
            at_least_zero = aff[0].le_set(aff[new_iname])
            less_than_dim = aff[new_iname].lt_set(
                aff_from_expr(domain.space, dim))
            domain = domain & at_least_zero & less_than_dim

            self.updated_iname_to_tag[new_iname] = hw_tag
            hw_inames.append(new_iname)

        # The operations on the domain above return a Set object, but the
        # underlying domain should be expressible as a single BasicSet.
        pieces = domain.get_basic_set_list()
        assert pieces.n_basic_set() == 1
        domain = pieces.get_basic_set(0)
        return domain, hw_inames, dim_inames, iname_to_tag
Exemplo n.º 22
0
def make_slab(space, iname, start, stop, iname_multiplier=1):
    """
    Build the :class:`islpy._isl.BasicSet` described by the constraint
    ``start <= iname_multiplier*iname < stop``.

    :arg space: An instance of :class:`islpy._isl.Space`.

    :arg iname:

        Either a :class:`str` naming the iname, or a tuple
        ``(iname_dt, iname_idx)`` giving its dim type and index in the
        space.

    :arg start:

        Inclusive lower bound of ``iname_multiplier*iname``. May be an
        :class:`int`, an :class:`islpy._isl.Aff` or
        :class:`islpy._isl.PwAff` (aligned with the space before use), or
        a :class:`pymbolic.primitives.Expression`.

    :arg stop:

        Exclusive upper bound of ``iname_multiplier*iname``. Accepts the
        same types as *start*.

    :arg iname_multiplier:

        A strictly positive :class:`int` used as the coefficient of
        *iname* in the inequalities above.

    :raises LoopyError: if *iname_multiplier* is not greater than zero.
    """
    zero = isl.Aff.zero_on_domain(space)

    # isl-valued bounds are aligned with the zero aff first; this may
    # change the space, so it is re-read from `zero` afterwards.
    isl_aff_types = (isl.Aff, isl.PwAff)
    if isinstance(start, isl_aff_types):
        start, zero = isl.align_two(pw_aff_to_aff(start), zero)
    if isinstance(stop, isl_aff_types):
        stop, zero = isl.align_two(pw_aff_to_aff(stop), zero)

    space = zero.get_domain_space()

    from pymbolic.primitives import Expression
    from loopy.symbolic import aff_from_expr

    # symbolic bounds become affs on the (possibly re-aligned) space
    if isinstance(start, Expression):
        start = aff_from_expr(space, start)
    if isinstance(stop, Expression):
        stop = aff_from_expr(space, stop)

    # plain integers become constant affs
    if isinstance(start, int):
        start = zero + start
    if isinstance(stop, int):
        stop = zero + stop

    if isinstance(iname, str):
        dim_location = zero.get_space().get_var_dict()[iname]
    else:
        dim_location = iname
    iname_dt, iname_idx = dim_location

    iname_aff = zero.add_coefficient_val(iname_dt, iname_idx, 1)

    if not iname_multiplier > 0:
        raise LoopyError("iname_multiplier must be strictly positive")

    scaled_iname = iname_multiplier * iname_aff

    # start <= iname_multiplier*iname
    lower_cns = isl.Constraint.inequality_from_aff(scaled_iname - start)
    # iname_multiplier*iname < stop
    upper_cns = isl.Constraint.inequality_from_aff(stop - 1 - scaled_iname)

    slab = isl.BasicSet.universe(space)
    slab = slab.add_constraint(lower_cns)
    slab = slab.add_constraint(upper_cns)
    return slab