def map_subscript(self, expr):
    """Shift subscripts of dimensioned arrays so that indexing starts at 0.

    Arrays that have an entry in ``self.scope.dim_map`` get each index
    reduced by that dimension's base index:

    - for a two-element dimension spec, the base is its first entry;
    - for a one-element spec, the base is implicitly 1.

    Dimension specs of any other length leave their index unshifted. The
    shifted index tuple is then mapped recursively. Subscripts of arrays
    without a ``dim_map`` entry go through the default
    :class:`IdentityMapper` traversal.

    :raises TranslationError: if the number of indices differs from the
        number of recorded dimensions.
    """
    from pymbolic.primitives import Variable
    assert isinstance(expr.aggregate, Variable)

    array_name = expr.aggregate.name
    dim_specs = self.scope.dim_map.get(array_name)
    if dim_specs is None:
        # No recorded dimensions: nothing to shift.
        return IdentityMapper.map_subscript(self, expr)

    indices = expr.index
    if not isinstance(indices, tuple):
        indices = (indices,)

    if len(dim_specs) != len(indices):
        raise TranslationError("inconsistent number of indices "
                "to '%s'" % array_name)

    shifted = []
    for dim_spec, idx in zip(dim_specs, indices):
        if len(dim_spec) == 2:
            # explicit base index is the first entry of the spec
            idx = idx - dim_spec[0]
        elif len(dim_spec) == 1:
            # base index is 1 implicitly
            idx = idx - 1
        shifted.append(idx)

    return expr.aggregate[self.rec(tuple(shifted))]
def map_subscript(self, expr):
    """Append new inames to subscripts of variables listed in
    ``self.var_to_new_inames``.

    The existing indices are mapped recursively with ``self.rec``. The new
    inames looked up for the aggregate's name are appended after them
    without being mapped. Subscripts of unlisted variables go through the
    default :class:`IdentityMapper` traversal.
    """
    try:
        extra_inames = self.var_to_new_inames[expr.aggregate.name]
    except KeyError:
        return IdentityMapper.map_subscript(self, expr)

    raw_index = expr.index
    old_index = raw_index if isinstance(raw_index, tuple) else (raw_index,)
    mapped_index = tuple(map(self.rec, old_index))

    return expr.aggregate.index(mapped_index + extra_inames)
def map_subscript(self, expr):
    """Extend subscripts of privatized variables with their new inames.

    If ``expr.aggregate.name`` is a key of ``self.var_to_new_inames``:

    - every existing index is passed through ``self.rec``;
    - the stored inames are appended, unmapped, after the existing indices.

    Any other subscript is handed to :class:`IdentityMapper`.
    """
    try:
        appended_inames = self.var_to_new_inames[expr.aggregate.name]
    except KeyError:
        return IdentityMapper.map_subscript(self, expr)

    subscript = expr.index
    if isinstance(subscript, tuple):
        components = subscript
    else:
        components = (subscript, )

    rewritten = []
    for component in components:
        rewritten.append(self.rec(component))

    return expr.aggregate.index(tuple(rewritten) + appended_inames)
def map_subscript(self, expr):
    """Extend subscripts of privatized variables with lower-bound-shifted
    private-axis inames.

    For an aggregate whose name is a key of ``self.var_to_new_inames``:

    - the existing indices are mapped with ``self.rec``;
    - the names of the stored inames are added to
      ``self.seen_priv_axis_inames``;
    - each stored iname, minus ``self.iname_to_lbound[iname.name]``, is
      appended.

    Other subscripts are handled by :class:`IdentityMapper`.
    """
    try:
        priv_inames = self.var_to_new_inames[expr.aggregate.name]
    except KeyError:
        return IdentityMapper.map_subscript(self, expr)

    raw_index = expr.index
    existing = raw_index if isinstance(raw_index, tuple) else (raw_index,)
    mapped_existing = tuple(map(self.rec, existing))

    # Record which private-axis inames were used before shifting them.
    self.seen_priv_axis_inames.update(iname.name for iname in priv_inames)
    shifted = tuple(
            iname - self.iname_to_lbound[iname.name]
            for iname in priv_inames)

    return expr.aggregate.index(mapped_existing + shifted)
def extract_subst(kernel, subst_name, template, parameters=()):
    """
    :arg subst_name: The name of the substitution rule to be created.
    :arg template: Unification template expression.
    :arg parameters: An iterable of parameters used in
        *template*, or a comma-separated string of the same.

    All targeted subexpressions must match ('unify with') *template*
    The template may contain '*' wildcards that will have to match exactly across all
    unifications.

    Every subexpression of the instructions (assignees and expressions) and
    of the existing substitution rules that unifies with *template* is
    replaced by a call to the new rule *subst_name*. The call's arguments
    are the values unified with *parameters*.

    :raises RuntimeError: if a subexpression unifies with *template* in
        more than one way, or if no subexpression matches at all.
    """

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        parameters = tuple(
                s.strip() for s in parameters.split(","))

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        based_on = subst_name + "_wc"

        result = var_name_gen(based_on)
        return result

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(
            lhs_mapping_candidates=set(parameters))

    # The instruction currently being scanned; stored in each
    # ExprDescriptor. This is an explicit variable rather than the leaked
    # loop variable. Matches found in substitution rules therefore get
    # insn=None instead of whichever instruction happened to be last, and a
    # kernel without instructions does not hit a NameError.
    current_insn = None

    def gather_exprs(expr, mapper):
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError("ambiguous unification of '%s' with template '%s'"
                        % (expr, template))

            urec, = urecs

            # A match is replaced wholesale, so matches can't nest:
            # don't recurse into the matched expression.
            expr_descriptors.append(
                    ExprDescriptor(
                        insn=current_insn,
                        expr=expr,
                        unif_var_dict={
                            lhs.name: rhs
                            for lhs, rhs in urec.equations}))
        else:
            mapper.fallback_mapper(expr)

    from loopy.symbolic import (
            CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    for insn in kernel.instructions:
        current_insn = insn
        dfmapper(insn.assignees)
        dfmapper(insn.expression)

    # Substitution rules do not belong to any instruction.
    current_insn = None
    for sr in six.itervalues(kernel.substitutions):
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    # Matched subexpressions are recognized by object identity. The
    # descriptors keep those objects alive, so their id()s are unique.
    # setdefault keeps the first descriptor for an object, matching the
    # first-hit behavior of a linear scan.
    exprd_by_id = {}
    for exprd in expr_descriptors:
        exprd_by_id.setdefault(id(exprd.expr), exprd)

    def replace_exprs(expr, mapper):
        exprd = exprd_by_id.get(id(expr))
        if exprd is None:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name]
                for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        # can't nest, don't recurse
        return result

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    new_insns = [
            insn.with_transformed_expressions(cbmapper)
            for insn in kernel.instructions]

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
            subst_name: SubstitutionRule(
                name=subst_name,
                arguments=tuple(parameters),
                expression=template,
                )}

    for subst in six.itervalues(kernel.substitutions):
        new_substs[subst.name] = subst.copy(
                expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(
            instructions=new_insns,
            substitutions=new_substs)
def extract_subst(kernel, subst_name, template, parameters=(), within=None):
    """
    :arg subst_name: The name of the substitution rule to be created.
    :arg template: Unification template expression.
    :arg parameters: An iterable of parameters used in
        *template*, or a comma-separated string of the same.
    :arg within: An instance of :class:`loopy.match.MatchExpressionBase` or
        :class:`str` as understood by :func:`loopy.match.parse_match`.

    All targeted subexpressions must match ('unify with') *template*
    The template may contain '*' wildcards that will have to match exactly across all
    unifications.

    Subexpressions are searched in two places:

    - the assignees and expressions of the :class:`MultiAssignmentBase`
      instructions selected by *within*;
    - all existing substitution rules.

    Matches are replaced by a call to the new rule *subst_name*, whose
    arguments are the values unified with *parameters*. In assignees, only
    subscript indices are rewritten; the assigned variable itself is kept.

    If *kernel* is a :class:`TranslationUnit`, it must contain exactly one
    :class:`CallableKernel`. That kernel is transformed, with *within*
    applied to it.

    :raises LoopyError: if a translation unit does not contain exactly one
        kernel.
    :raises RuntimeError: if a subexpression unifies with *template* in
        more than one way, or if no subexpression matches at all.
    """

    if isinstance(kernel, TranslationUnit):
        kernel_names = [
                i for i, clbl in kernel.callables_table.items()
                if isinstance(clbl, CallableKernel)
                ]
        if len(kernel_names) != 1:
            raise LoopyError(
                    "extract_subst: expected exactly one kernel in the "
                    "translation unit, found %d" % len(kernel_names))

        # Forward *within* too; dropping it would silently widen the
        # transformation to every instruction.
        return kernel.with_kernel(
                extract_subst(kernel[kernel_names[0]], subst_name, template,
                    parameters, within=within))

    if isinstance(template, str):
        from pymbolic import parse
        template = parse(template)

    if isinstance(parameters, str):
        parameters = tuple(
                s.strip() for s in parameters.split(","))

    from loopy.match import parse_match
    within = parse_match(within)

    var_name_gen = kernel.get_var_name_generator()

    # {{{ replace any wildcards in template with new variables

    def get_unique_var_name():
        based_on = subst_name + "_wc"

        result = var_name_gen(based_on)
        return result

    from loopy.symbolic import WildcardToUniqueVariableMapper
    wc_map = WildcardToUniqueVariableMapper(get_unique_var_name)
    template = wc_map(template)

    # }}}

    # {{{ gather up expressions

    expr_descriptors = []

    from loopy.symbolic import UnidirectionalUnifier
    unif = UnidirectionalUnifier(
            lhs_mapping_candidates=set(parameters))

    # The instruction currently being scanned; stored in each
    # ExprDescriptor. This is an explicit variable rather than the leaked
    # loop variable. Matches found in substitution rules therefore get
    # insn=None instead of whichever instruction happened to be last, and a
    # kernel without instructions does not hit a NameError.
    current_insn = None

    def gather_exprs(expr, mapper):
        urecs = unif(template, expr)

        if urecs:
            if len(urecs) > 1:
                raise RuntimeError("ambiguous unification of '%s' with template '%s'"
                        % (expr, template))

            urec, = urecs

            # A match is replaced wholesale, so matches can't nest:
            # don't recurse into the matched expression.
            expr_descriptors.append(
                    ExprDescriptor(
                        insn=current_insn,
                        expr=expr,
                        unif_var_dict={
                            lhs.name: rhs
                            for lhs, rhs in urec.equations}))
        else:
            mapper.fallback_mapper(expr)

    from loopy.symbolic import (
            CallbackMapper, WalkMapper, IdentityMapper)
    dfmapper = CallbackMapper(gather_exprs, WalkMapper())

    from loopy.kernel.instruction import MultiAssignmentBase
    for insn in kernel.instructions:
        if isinstance(insn, MultiAssignmentBase) and within(kernel, insn):
            current_insn = insn
            dfmapper(insn.assignees)
            dfmapper(insn.expression)

    # Substitution rules do not belong to any instruction.
    current_insn = None
    for sr in kernel.substitutions.values():
        dfmapper(sr.expression)

    # }}}

    if not expr_descriptors:
        raise RuntimeError("no expressions matching '%s'" % template)

    # {{{ substitute rule into instructions

    # Matched subexpressions are recognized by object identity. The
    # descriptors keep those objects alive, so their id()s are unique.
    # setdefault keeps the first descriptor for an object, matching the
    # first-hit behavior of a linear scan.
    exprd_by_id = {}
    for exprd in expr_descriptors:
        exprd_by_id.setdefault(id(exprd.expr), exprd)

    def replace_exprs(expr, mapper):
        exprd = exprd_by_id.get(id(expr))
        if exprd is None:
            return mapper.fallback_mapper(expr)

        args = [exprd.unif_var_dict[arg_name]
                for arg_name in parameters]

        result = var(subst_name)
        if args:
            result = result(*args)

        # can't nest, don't recurse
        return result

    cbmapper = CallbackMapper(replace_exprs, IdentityMapper())

    def transform_assignee(expr):
        # Assignment LHS's cannot be subst rules. Treat them
        # specially: only the subscript indices may be rewritten.
        import pymbolic.primitives as prim
        if isinstance(expr, tuple):
            return tuple(
                    transform_assignee(expr_i)
                    for expr_i in expr)

        elif isinstance(expr, prim.Subscript):
            return type(expr)(expr.aggregate, cbmapper(expr.index))

        elif isinstance(expr, prim.Variable):
            return expr
        else:
            raise ValueError("assignment LHS not understood")

    new_insns = []
    for insn in kernel.instructions:
        if within(kernel, insn):
            new_insns.append(
                    insn.with_transformed_expressions(
                        cbmapper, assignee_f=transform_assignee))
        else:
            new_insns.append(insn)

    from loopy.kernel.data import SubstitutionRule
    new_substs = {
            subst_name: SubstitutionRule(
                name=subst_name,
                arguments=tuple(parameters),
                expression=template,
                )}

    for subst in kernel.substitutions.values():
        new_substs[subst.name] = subst.copy(
                expression=cbmapper(subst.expression))

    # }}}

    return kernel.copy(
            instructions=new_insns,
            substitutions=new_substs)
def map_subscript(self, expr):
    """Rewrite a subscript of a dimensioned array into 0-based form.

    For arrays with an entry in ``self.scope.dim_map``, each dimension spec
    gives a base index:

    - a two-element spec ``(base, end)`` uses its first entry as the base;
    - a one-element spec ``(end,)`` has an implicit base of 1.

    Plain indices are shifted by the base. :class:`Slice` subscripts have an
    omitted start/stop filled with the base/end index and an omitted step
    set to 1. The start is then shifted by the base. The stop is shifted by
    the base and adjusted by +1 (step 1) or -1 (step -1), presumably to turn
    the inclusive Fortran upper bound into an exclusive stop.

    Subscripts of arrays without a ``dim_map`` entry go through the default
    :class:`IdentityMapper` traversal.

    :raises TranslationError: if the number of indices does not match the
        number of dimensions, or a dimension spec has neither one nor two
        entries.
    :raises NotImplementedError: for slices with a non-unit stride.
    """
    from pymbolic.primitives import Variable
    assert isinstance(expr.aggregate, Variable)

    name = expr.aggregate.name
    dims = self.scope.dim_map.get(name)
    if dims is None:
        return IdentityMapper.map_subscript(self, expr)

    subscript = expr.index
    if not isinstance(subscript, tuple):
        subscript = (subscript, )

    if len(dims) != len(subscript):
        raise TranslationError("inconsistent number of indices "
                "to '%s'" % name)

    new_subscript = []
    for i, (dim, sub_i) in enumerate(zip(dims, subscript)):
        if len(dim) == 2:
            # has an explicit base index
            base_index, end_index = dim
        elif len(dim) == 1:
            # base index is 1 implicitly
            base_index = 1
            end_index, = dim
        else:
            # Without this check, base_index/end_index would silently keep
            # the previous dimension's values (or be unbound for the first
            # dimension).
            raise TranslationError(
                    "unexpected dimension specification for index %d "
                    "of '%s'" % (i, name))

        if isinstance(sub_i, Slice):
            start = sub_i.start
            if start is None:
                start = base_index

            step = sub_i.step
            if step is None:
                step = 1

            stop = sub_i.stop
            if stop is None:
                stop = end_index

            if step == 1:
                sub_i = Slice((
                        start - base_index,

                        # FIXME This is only correct for unit strides
                        stop - base_index + 1,

                        step))
            elif step == -1:
                sub_i = Slice((
                        start - base_index,

                        # FIXME This is only correct for unit strides
                        stop - base_index - 1,

                        step))

            else:
                # FIXME
                raise NotImplementedError("Fortran slice processing for "
                        "non-unit strides")

        else:
            sub_i = sub_i - base_index

        new_subscript.append(sub_i)

    return expr.aggregate[self.rec(tuple(new_subscript))]