def expand_subst(kernel, within=None):
    """Expand invocations of substitution rules in *kernel*.

    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`, selecting which
        invocation sites are expanded.
    :returns: the kernel produced by
        :class:`loopy.symbolic.RuleAwareSubstitutionRuleExpander` after the
        rule mapping context has been finished.
    """
    # lazy %-args: the message is only formatted if debug logging is enabled
    logger.debug("%s: expand subst", kernel.name)

    from loopy.symbolic import RuleAwareSubstitutionRuleExpander
    from loopy.context_matching import parse_stack_match
    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    submap = RuleAwareSubstitutionRuleExpander(
            rule_mapping_context,
            kernel.substitutions,
            parse_stack_match(within))

    return rule_mapping_context.finish_kernel(submap.map_kernel(kernel))
def _split_reduction(kernel, inames, direction, within=None):
    """Split reductions over *inames* in *kernel*.

    :arg inames: a collection of iname names, or a comma-separated string
        of them (whitespace around the commas is ignored).
    :arg direction: ``"in"`` or ``"out"``; handed to
        :class:`_ReductionSplitter`.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.
    :raises ValueError: if *direction* is neither ``"in"`` nor ``"out"``.
    """
    if direction not in ["in", "out"]:
        raise ValueError("invalid value for 'direction': %s" % direction)

    if isinstance(inames, str):
        # Strip so that "i, j" works, matching the other transforms here.
        inames = [iname.strip() for iname in inames.split(",")]
    inames = set(inames)

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    rsplit = _ReductionSplitter(rule_mapping_context,
            within, inames, direction)
    return rule_mapping_context.finish_kernel(
            rsplit.map_kernel(kernel))
def precompute(kernel, subst_use, sweep_inames=None, within=None,
        storage_axes=None, temporary_name=None, precompute_inames=None,
        storage_axis_to_tag=None, default_tag="l.auto", dtype=None,
        fetch_bounding_box=False, temporary_is_local=None,
        compute_insn_id=None):
    """Precompute the expression described in the substitution rule determined by
    *subst_use* and store it in a temporary array. A precomputation needs two
    things to operate, a list of *sweep_inames* (order irrelevant) and an
    ordered list of *storage_axes* (whose order will describe the axis ordering
    of the temporary array).

    :arg subst_use: Describes what to prefetch.

    The following objects may be given for *subst_use*:

    * The name of the substitution rule.

    * The tagged name ("name$tag") of the substitution rule.

    * A list of invocations of the substitution rule.
      This list of invocations, when swept across *sweep_inames*, then serves
      to define the footprint of the precomputation.

      Invocations may be tagged ("name$tag") to filter out a subset of the
      usage sites of the substitution rule. (Namely those usage sites that
      use the same tagged name.)

      Invocations may be given as a string or as a
      :class:`pymbolic.primitives.Expression` object.

      If only one invocation is to be given, then the only entry of the list
      may be given directly.

    If the list of invocations generating the footprint is not given,
    all (tag-matching, if desired) usage sites of the substitution rule
    are used to determine the footprint.

    The following cases can arise for each sweep axis:

    * The axis is an iname that occurs within arguments specified at
      usage sites of the substitution rule. This case is assumed covered
      by the storage axes provided for the argument.

    * The axis is an iname that occurs within the *value* of the rule, but not
      within its arguments. A new, dedicated storage axis is allocated for
      such an axis.

    :arg sweep_inames: A :class:`list` of inames and/or rule argument
        names to be swept.
        May also equivalently be a comma-separated string.
    :arg storage_axes: A :class:`list` of inames and/or rule argument
        names/indices to be used as storage axes.
        May also equivalently be a comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.
    :arg temporary_name:
        The temporary variable name to use for storing the precomputed data.
        If it does not exist, it will be created. If it does exist, its properties
        (such as size, type) are checked (and updated, if possible) to match
        its use.
    :arg precompute_inames:
        A tuple of inames to be used to carry out the precomputation.
        If the specified inames do not already exist, they will be
        created. If they do already exist, their loop domain is verified
        against the one required for this precomputation. This tuple may
        be shorter than the (provided or automatically found) *storage_axes*
        tuple, in which case names will be automatically created.
        May also equivalently be a comma-separated string.
    :arg storage_axis_to_tag: A mapping used to tag newly created storage
        inames. It is looked up by storage axis (iname or rule argument name),
        or by the corresponding entry of *precompute_inames* if one was given.
        Axes not found in it receive *default_tag*.
    :arg default_tag: The tag (parsed by :func:`loopy.kernel.data.parse_tag`)
        given to new storage inames not covered by *storage_axis_to_tag*.
    :arg dtype: The dtype of the temporary. *None* means :data:`loopy.auto`;
        anything else is converted with :func:`numpy.dtype`.
    :arg fetch_bounding_box: Passed as *boxify_sweep* when the loop domain is
        augmented with the sweep.
    :arg temporary_is_local: *None* means :data:`loopy.auto`. Must agree with
        an existing temporary's ``is_local`` unless one of them is auto.
    :arg compute_insn_id: The ID of the instruction performing the precomputation.

    If `storage_axes` is not specified, it defaults to the arrangement
    `<direct sweep axes><arguments>` with the direct sweep axes being the
    slower-varying indices.

    Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are
    eliminated.
    """

    # {{{ check, standardize arguments

    # None defaults avoid shared mutable default arguments.
    if sweep_inames is None:
        sweep_inames = []
    if storage_axis_to_tag is None:
        storage_axis_to_tag = {}

    if isinstance(sweep_inames, str):
        sweep_inames = [iname.strip() for iname in sweep_inames.split(",")]

    for iname in sweep_inames:
        if iname not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname"
                    % iname)

    sweep_inames = list(sweep_inames)
    sweep_inames_set = frozenset(sweep_inames)

    if isinstance(storage_axes, str):
        storage_axes = [ax.strip() for ax in storage_axes.split(",")]

    if isinstance(precompute_inames, str):
        precompute_inames = [iname.strip()
                for iname in precompute_inames.split(",")]

    from pymbolic.primitives import Variable, Call
    from loopy.symbolic import parse, TaggedVariable

    # A single use (string or expression) may be given without a list.
    # Expressions must be wrapped too: iterating over a pymbolic expression
    # would go through __getitem__ and yield subscripts.
    if isinstance(subst_use, (str, Variable, TaggedVariable, Call)):
        subst_use = [subst_use]

    footprint_generators = None

    subst_name = None
    subst_tag = None

    for use in subst_use:
        if isinstance(use, str):
            use = parse(use)

        if isinstance(use, Call):
            if footprint_generators is None:
                footprint_generators = []

            footprint_generators.append(use)
            subst_name_as_expr = use.function
        else:
            subst_name_as_expr = use

        if isinstance(subst_name_as_expr, TaggedVariable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = subst_name_as_expr.tag
        elif isinstance(subst_name_as_expr, Variable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = None
        else:
            raise ValueError("unexpected type of subst_name")

        if (subst_name, subst_tag) == (None, None):
            subst_name, subst_tag = new_subst_name, new_subst_tag
        else:
            if (subst_name, subst_tag) != (new_subst_name, new_subst_tag):
                raise ValueError("not all uses in subst_use agree "
                        "on rule name and tag")

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    from loopy.kernel.data import parse_tag
    default_tag = parse_tag(default_tag)

    subst = kernel.substitutions[subst_name]
    c_subst_name = subst_name.replace(".", "_")

    # }}}

    # {{{ process invocations in footprint generators, start access_descriptors

    if footprint_generators:
        access_descriptors = []

        for fpg in footprint_generators:
            if isinstance(fpg, Variable):
                args = ()
            elif isinstance(fpg, Call):
                args = fpg.parameters
            else:
                raise ValueError("footprint generator must "
                        "be substitution rule invocation")

            access_descriptors.append(
                    RuleAccessDescriptor(
                        identifier=access_descriptor_id(args, None),
                        args=args
                        ))

    # }}}

    # {{{ gather up invocations in kernel code, finish access_descriptors

    if not footprint_generators:
        rule_mapping_context = SubstitutionRuleMappingContext(
                kernel.substitutions, kernel.get_var_name_generator())
        invg = RuleInvocationGatherer(
                rule_mapping_context, kernel, subst_name, subst_tag, within)
        del rule_mapping_context

        import loopy as lp
        for insn in kernel.instructions:
            if isinstance(insn, lp.Assignment):
                invg(insn.assignee, kernel, insn)
                invg(insn.expression, kernel, insn)

        access_descriptors = invg.access_descriptors
        if not access_descriptors:
            raise RuntimeError("no invocations of '%s' found" % subst_name)

    # }}}

    # {{{ find inames used in arguments

    expanding_usage_arg_deps = set()

    for accdesc in access_descriptors:
        for arg in accdesc.args:
            expanding_usage_arg_deps.update(
                    get_dependencies(arg) & kernel.all_inames())

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    # {{{ use given / find new storage_axes

    # extra axes made necessary because they don't occur in the arguments
    extra_storage_axes = set(sweep_inames_set - expanding_usage_arg_deps)

    from loopy.symbolic import SubstitutionRuleExpander
    submap = SubstitutionRuleExpander(kernel.substitutions)

    value_inames = get_dependencies(
            submap(subst.expression)
            ) & kernel.all_inames()
    # NOTE(review): '<' is a strict-subset test, so this only fires when
    # the value inames (minus argument deps) are a proper subset of the
    # extra axes. An unreferenced sweep iname alongside an unswept value
    # iname is not caught. Kept as-is to avoid rejecting existing callers.
    if value_inames - expanding_usage_arg_deps < extra_storage_axes:
        raise RuntimeError("unreferenced sweep inames specified: "
                + ", ".join(extra_storage_axes
                    - value_inames - expanding_usage_arg_deps))

    new_iname_to_tag = {}

    if storage_axes is None:
        storage_axes = []

        # Add sweep_inames (in given--rather than arbitrary--order) to
        # storage_axes *if* they are part of extra_storage_axes.
        for iname in sweep_inames:
            if iname in extra_storage_axes:
                extra_storage_axes.remove(iname)
                storage_axes.append(iname)

        if extra_storage_axes:
            if (precompute_inames is not None
                    and len(storage_axes) < len(precompute_inames)):
                raise LoopyError("must specify a sufficient number of "
                        "storage_axes to uniquely determine the meaning "
                        "of the given precompute_inames. (%d storage_axes "
                        "needed)" % len(precompute_inames))
            storage_axes.extend(sorted(extra_storage_axes))

        storage_axes.extend(range(len(subst.arguments)))

    del extra_storage_axes

    prior_storage_axis_name_dict = {}

    storage_axis_names = []
    storage_axis_sources = []  # number for arg#, or iname

    # {{{ check for pre-existing precompute_inames

    if precompute_inames is not None:
        preexisting_precompute_inames = (
                set(precompute_inames) & kernel.all_inames())
    else:
        preexisting_precompute_inames = set()

    # }}}

    for i, saxis in enumerate(storage_axes):
        tag_lookup_saxis = saxis

        if saxis in subst.arguments:
            saxis = subst.arguments.index(saxis)

        storage_axis_sources.append(saxis)

        if isinstance(saxis, int):
            # argument index
            name = old_name = subst.arguments[saxis]
        else:
            old_name = saxis
            name = "%s_%s" % (c_subst_name, old_name)

        if (precompute_inames is not None
                and i < len(precompute_inames)
                and precompute_inames[i]):
            name = precompute_inames[i]
            tag_lookup_saxis = name
            if (name not in preexisting_precompute_inames
                    and var_name_gen.is_name_conflicting(name)):
                raise RuntimeError("new storage axis name '%s' "
                        "conflicts with existing name" % name)
        else:
            name = var_name_gen(name)

        storage_axis_names.append(name)
        # preexisting inames keep whatever tag they already have
        if name not in preexisting_precompute_inames:
            new_iname_to_tag[name] = storage_axis_to_tag.get(
                    tag_lookup_saxis, default_tag)

        prior_storage_axis_name_dict[name] = old_name

    del storage_axis_to_tag
    del storage_axes
    del precompute_inames

    # }}}

    # {{{ fill out access_descriptors[...].storage_axis_exprs

    access_descriptors = [
            accdesc.copy(
                storage_axis_exprs=storage_axis_exprs(
                    storage_axis_sources, accdesc.args))
            for accdesc in access_descriptors]

    # }}}

    expanding_inames = sweep_inames_set | frozenset(expanding_usage_arg_deps)
    assert expanding_inames <= kernel.all_inames()

    if storage_axis_names:
        # {{{ find domain to be changed

        change_inames = expanding_inames | preexisting_precompute_inames

        from loopy.kernel.tools import DomainChanger
        domch = DomainChanger(kernel, change_inames)

        if domch.leaf_domain_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.

            for iname in sweep_inames_set:
                if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
                    raise RuntimeError("sweep iname '%s' is not 'at home' in the "
                            "sweep's leaf domain" % iname)

        # }}}

        abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames,
                access_descriptors, len(storage_axis_names))

        non1_storage_axis_names = []
        for i, saxis in enumerate(storage_axis_names):
            if abm.non1_storage_axis_flags[i]:
                non1_storage_axis_names.append(saxis)
            else:
                # Check this first: preexisting inames were never entered
                # into new_iname_to_tag, so deleting first would raise a
                # KeyError and hide this error.
                if saxis in preexisting_precompute_inames:
                    raise LoopyError("precompute axis %d (1-based) was "
                            "eliminated as "
                            "having length 1 but also mapped to existing "
                            "iname '%s'" % (i+1, saxis))

                del new_iname_to_tag[saxis]

        mod_domain = domch.domain

        # {{{ modify the domain, taking into account preexisting inames

        # inames may already exist in mod_domain, add them primed to start
        primed_non1_saxis_names = [
                iname+"'" for iname in non1_storage_axis_names]

        mod_domain = abm.augment_domain_with_sweep(
            domch.domain, primed_non1_saxis_names,
            boxify_sweep=fetch_bounding_box)

        check_domain = mod_domain

        for i, saxis in enumerate(non1_storage_axis_names):
            var_dict = mod_domain.get_var_dict(isl.dim_type.set)

            if saxis in preexisting_precompute_inames:
                # add equality constraint between existing and new variable

                dt, dim_idx = var_dict[saxis]
                saxis_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                mod_domain = mod_domain.add_constraint(
                        isl.Constraint.equality_from_aff(new_var_aff - saxis_aff))

                # project out the new one
                mod_domain = mod_domain.project_out(dt, dim_idx, 1)

            else:
                # remove the prime from the new variable
                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis)

        # {{{ check that we got the desired domain

        check_domain = check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain = mod_domain

        # re-add the prime from the new variable
        var_dict = mod_check_domain.get_var_dict(isl.dim_type.set)

        for saxis in non1_storage_axis_names:
            dt, dim_idx = var_dict[saxis]
            mod_check_domain = mod_check_domain.set_dim_name(
                    dt, dim_idx, saxis+"'")

        mod_check_domain = mod_check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("domain of preexisting inames does not match "
                    "domain needed for precompute")

        # }}}

        # {{{ check that we didn't shrink the original domain

        # project out the new names from the modified domain
        orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set))
        mod_check_domain = mod_domain.project_out_except(
                orig_domain_inames, [isl.dim_type.set])

        check_domain = domch.domain

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("original domain got shrunk by applying "
                    "the precompute")

        # }}}

        # }}}

        new_kernel_domains = domch.get_domains_with(mod_domain)

    else:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains

        non1_storage_axis_names = []
        abm = NoOpArrayToBufferMap()

    kernel = kernel.copy(domains=new_kernel_domains)

    # {{{ set up compute insn

    if temporary_name is None:
        temporary_name = var_name_gen(based_on=c_subst_name)

    assignee = var(temporary_name)

    if non1_storage_axis_names:
        assignee = assignee.index(
                tuple(var(iname) for iname in non1_storage_axis_names))

    # {{{ process substitutions on compute instruction

    storage_axis_subst_dict = {}

    for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices):
        if arg_name in non1_storage_axis_names:
            arg = var(arg_name)
        else:
            # length-1 axes were eliminated: only the base index remains
            arg = 0

        storage_axis_subst_dict[
                prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi

    rule_mapping_context = SubstitutionRuleMappingContext(
                kernel.substitutions, kernel.get_var_name_generator())

    expr_subst_map = RuleAwareSubstitutionMapper(
            rule_mapping_context,
            make_subst_func(storage_axis_subst_dict),
            within=parse_stack_match(None))

    compute_expression = expr_subst_map(subst.expression, kernel, None)

    # }}}

    from loopy.kernel.data import Assignment
    if compute_insn_id is None:
        compute_insn_id = kernel.make_unique_instruction_id(
                based_on=c_subst_name)

    compute_insn = Assignment(
            id=compute_insn_id,
            assignee=assignee,
            expression=compute_expression)

    # }}}

    # {{{ substitute rule into expressions in kernel (if within footprint)

    invr = RuleInvocationReplacer(rule_mapping_context,
            subst_name, subst_tag, within,
            access_descriptors, abm,
            storage_axis_names, storage_axis_sources,
            non1_storage_axis_names,
            temporary_name, compute_insn_id)

    kernel = invr.map_kernel(kernel)
    kernel = kernel.copy(
            instructions=[compute_insn] + kernel.instructions)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # {{{ set up temp variable

    import loopy as lp
    if dtype is None:
        dtype = lp.auto
    else:
        dtype = np.dtype(dtype)

    if temporary_is_local is None:
        temporary_is_local = lp.auto

    new_temp_shape = tuple(abm.non1_storage_shape)

    new_temporary_variables = kernel.temporary_variables.copy()
    if temporary_name not in new_temporary_variables:
        temp_var = lp.TemporaryVariable(
                name=temporary_name,
                dtype=dtype,
                base_indices=(0,)*len(new_temp_shape),
                shape=tuple(abm.non1_storage_shape),
                is_local=temporary_is_local)

    else:
        temp_var = new_temporary_variables[temporary_name]

        # {{{ check and adapt existing temporary

        if temp_var.dtype is lp.auto:
            pass
        elif temp_var.dtype is not lp.auto and dtype is lp.auto:
            dtype = temp_var.dtype
        elif temp_var.dtype is not lp.auto and dtype is not lp.auto:
            if temp_var.dtype != dtype:
                raise LoopyError("Existing and new dtype of temporary '%s' "
                        "do not match (existing: %s, new: %s)"
                        % (temporary_name, temp_var.dtype, dtype))

        temp_var = temp_var.copy(dtype=dtype)

        if len(temp_var.shape) != len(new_temp_shape):
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching number of dimensions "
                    "(existing: %d, new: %d)"
                    % (temporary_name,
                        len(temp_var.shape), len(new_temp_shape)))

        if temp_var.base_indices != (0,) * len(new_temp_shape):
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching base indices "
                    "(existing: %s, new: %s)"
                    % (temporary_name,
                        temp_var.base_indices, (0,) * len(new_temp_shape)))

        new_temp_shape = tuple(
                max(i, ex_i)
                for i, ex_i in zip(new_temp_shape, temp_var.shape))

        temp_var = temp_var.copy(shape=new_temp_shape)

        if temporary_is_local == temp_var.is_local:
            pass
        elif temporary_is_local is lp.auto:
            temporary_is_local = temp_var.is_local
        elif temp_var.is_local is lp.auto:
            pass
        else:
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching values of 'is_local' "
                    "(existing: %s, new: %s)"
                    % (temporary_name,
                        temp_var.is_local, temporary_is_local))

        temp_var = temp_var.copy(is_local=temporary_is_local)

        # }}}

    new_temporary_variables[temporary_name] = temp_var

    kernel = kernel.copy(
            temporary_variables=new_temporary_variables)

    # }}}

    from loopy import tag_inames
    kernel = tag_inames(kernel, new_iname_to_tag)

    from loopy.kernel.data import AutoFitLocalIndexTag
    has_automatic_axes = any(
            isinstance(tag, AutoFitLocalIndexTag)
            for tag in new_iname_to_tag.values())

    if has_automatic_axes:
        from loopy.kernel.tools import assign_automatic_axes
        kernel = assign_automatic_axes(kernel)

    return kernel
def temporary_to_subst(kernel, temp_name, extra_arguments=(), within=None):
    """Extract an assignment to a temporary variable
    as a :ref:`substitution-rule`. The temporary may
    be an array, in which case the array indices will
    become arguments to the substitution rule.

    :arg extra_arguments: additional rule argument names, appended after
        the index names of each generated rule. May also be given as a
        comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.

    This operation will change all usage sites
    of *temp_name* matched by *within*. If there
    are further usage sites of *temp_name*, then
    the original assignment to *temp_name* as well
    as the temporary variable is left in place.

    :raises LoopyError: if the defining write for a usage site cannot be
        determined uniquely, if a defining instruction both reads and writes
        *temp_name*, or if an assignee index is not a plain variable.
    """

    if isinstance(extra_arguments, str):
        extra_arguments = tuple(s.strip() for s in extra_arguments.split(","))
    # normalize so that "tuple(arguments) + extra_arguments" below also
    # works for list input
    extra_arguments = tuple(extra_arguments)

    # {{{ establish the relevant definition of temp_name for each usage site

    dep_kernel = expand_subst(kernel)
    from loopy.preprocess import add_default_dependencies
    dep_kernel = add_default_dependencies(dep_kernel)

    id_to_insn = dep_kernel.id_to_insn

    # Results per instruction id. The dependency graph is a DAG whose
    # sub-graphs may be reachable along many paths; without memoization the
    # walk below revisits them once per path (exponential in the worst case).
    definition_cache = {}

    def get_relevant_definition_insn_id(usage_insn_id):
        """Return the id of the unique instruction (reached through
        dependencies of *usage_insn_id*) that writes *temp_name*, or *None*
        if there is none."""
        try:
            return definition_cache[usage_insn_id]
        except KeyError:
            pass

        insn = id_to_insn[usage_insn_id]

        def_id = set()
        for dep_id in insn.insn_deps:
            dep_insn = id_to_insn[dep_id]
            if temp_name in dep_insn.write_dependency_names():
                if temp_name in dep_insn.read_dependency_names():
                    raise LoopyError("instruction '%s' both reads *and* "
                            "writes '%s'--cannot transcribe to substitution "
                            "rule" % (dep_id, temp_name))

                def_id.add(dep_id)
            else:
                # not a writer: look further back through its dependencies
                rec_result = get_relevant_definition_insn_id(dep_id)
                if rec_result is not None:
                    def_id.add(rec_result)

        if len(def_id) > 1:
            raise LoopyError("more than one write to '%s' found in "
                    "depdendencies of '%s'--definition cannot be resolved "
                    "(writer instructions ids: %s)"
                    % (temp_name, usage_insn_id, ", ".join(def_id)))

        if not def_id:
            result = None
        else:
            result, = def_id

        definition_cache[usage_insn_id] = result
        return result

    usage_to_definition = {}

    for insn in kernel.instructions:
        if temp_name not in insn.read_dependency_names():
            continue

        def_id = get_relevant_definition_insn_id(insn.id)
        if def_id is None:
            raise LoopyError("no write to '%s' found in dependency tree "
                    "of '%s'--definition cannot be resolved"
                    % (temp_name, insn.id))

        usage_to_definition[insn.id] = def_id

    definition_insn_ids = set()
    for insn in kernel.instructions:
        if temp_name in insn.write_dependency_names():
            definition_insn_ids.add(insn.id)

    # }}}

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    tts = TemporaryToSubstChanger(rule_mapping_context,
            temp_name, definition_insn_ids,
            usage_to_definition, extra_arguments, within)

    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))

    from loopy.kernel.data import SubstitutionRule
    from pymbolic.primitives import Variable

    # {{{ create new substitution rules

    new_substs = kernel.substitutions.copy()
    for def_id, subst_name in six.iteritems(tts.definition_insn_id_to_subst_name):
        def_insn = kernel.id_to_insn[def_id]

        (_, indices), = def_insn.assignees_and_indices()

        arguments = []

        for i in indices:
            if not isinstance(i, Variable):
                raise LoopyError("In defining instruction '%s': "
                        "asignee index '%s' is not a plain variable. "
                        "Perhaps use loopy.affine_map_inames() "
                        "to perform substitution." % (def_id, i))

            arguments.append(i.name)

        new_substs[subst_name] = SubstitutionRule(
                name=subst_name,
                arguments=tuple(arguments) + extra_arguments,
                expression=def_insn.expression)

    # }}}

    # {{{ delete temporary variable if possible

    new_temp_vars = kernel.temporary_variables
    if not any(six.itervalues(tts.saw_unmatched_usage_sites)):
        # All usage sites matched--they're now substitution rules.
        # We can get rid of the variable.

        new_temp_vars = new_temp_vars.copy()
        del new_temp_vars[temp_name]

    # }}}

    import loopy as lp
    # drop defining instructions whose temporary is no longer read anywhere
    kernel = lp.remove_instructions(
            kernel,
            set(
                insn_id
                for insn_id, still_used in six.iteritems(
                    tts.saw_unmatched_usage_sites)
                if not still_used))

    return kernel.copy(
            substitutions=new_substs,
            temporary_variables=new_temp_vars,
            )
def buffer_array( kernel, var_name, buffer_inames, init_expression=None, store_expression=None, within=None, default_tag="l.auto", temporary_is_local=None, fetch_bounding_box=False, ): """ :arg init_expression: Either *None* (indicating the prior value of the buffered array should be read) or an expression optionally involving the variable 'base' (which references the associated location in the array being buffered). :arg store_expression: Either *None* or an expression involving variables 'base' and 'buffer' (without array indices). """ # {{{ process arguments if isinstance(init_expression, str): from loopy.symbolic import parse init_expression = parse(init_expression) if isinstance(store_expression, str): from loopy.symbolic import parse store_expression = parse(store_expression) if isinstance(buffer_inames, str): buffer_inames = [s.strip() for s in buffer_inames.split(",") if s.strip()] for iname in buffer_inames: if iname not in kernel.all_inames(): raise RuntimeError("sweep iname '%s' is not a known iname" % iname) buffer_inames = list(buffer_inames) buffer_inames_set = frozenset(buffer_inames) from loopy.context_matching import parse_stack_match within = parse_stack_match(within) if var_name in kernel.arg_dict: var_descr = kernel.arg_dict[var_name] elif var_name in kernel.temporary_variables: var_descr = kernel.temporary_variables[var_name] else: raise ValueError("variable '%s' not found" % var_name) from loopy.kernel.data import ArrayBase if isinstance(var_descr, ArrayBase): var_shape = var_descr.shape else: var_shape = () if temporary_is_local is None: import loopy as lp temporary_is_local = lp.auto # }}} var_name_gen = kernel.get_var_name_generator() within_inames = set() access_descriptors = [] for insn in kernel.instructions: if not within(kernel, insn.id, ()): continue for assignee, index in insn.assignees_and_indices(): if assignee == var_name: within_inames.update((get_dependencies(index) & kernel.all_inames()) - buffer_inames_set) 
access_descriptors.append(AccessDescriptor(identifier=insn.id, storage_axis_exprs=index)) # {{{ find fetch/store inames init_inames = [] store_inames = [] new_iname_to_tag = {} for i in range(len(var_shape)): init_iname = var_name_gen("%s_init_%d" % (var_name, i)) store_iname = var_name_gen("%s_store_%d" % (var_name, i)) new_iname_to_tag[init_iname] = default_tag new_iname_to_tag[store_iname] = default_tag init_inames.append(init_iname) store_inames.append(store_iname) # }}} # {{{ modify loop domain non1_init_inames = [] non1_store_inames = [] if var_shape: # {{{ find domain to be changed from loopy.kernel.tools import DomainChanger domch = DomainChanger(kernel, buffer_inames_set | within_inames) if domch.leaf_domain_index is not None: # If the sweep inames are at home in parent domains, then we'll add # fetches with loops over copies of these parent inames that will end # up being scheduled *within* loops over these parents. for iname in buffer_inames_set: if kernel.get_home_domain_index(iname) != domch.leaf_domain_index: raise RuntimeError("buffer iname '%s' is not 'at home' in the " "sweep's leaf domain" % iname) # }}} abm = ArrayToBufferMap(kernel, domch.domain, buffer_inames, access_descriptors, len(var_shape)) for i in range(len(var_shape)): if abm.non1_storage_axis_flags[i]: non1_init_inames.append(init_inames[i]) non1_store_inames.append(store_inames[i]) else: del new_iname_to_tag[init_inames[i]] del new_iname_to_tag[store_inames[i]] new_domain = domch.domain new_domain = abm.augment_domain_with_sweep(new_domain, non1_init_inames, boxify_sweep=fetch_bounding_box) new_domain = abm.augment_domain_with_sweep(new_domain, non1_store_inames, boxify_sweep=fetch_bounding_box) new_kernel_domains = domch.get_domains_with(new_domain) del new_domain else: # leave kernel domains unchanged new_kernel_domains = kernel.domains abm = NoOpArrayToBufferMap() # }}} # {{{ set up temp variable import loopy as lp buf_var_name = var_name_gen(based_on=var_name + "_buf") 
new_temporary_variables = kernel.temporary_variables.copy() temp_var = lp.TemporaryVariable( name=buf_var_name, dtype=var_descr.dtype, base_indices=(0,) * len(abm.non1_storage_shape), shape=tuple(abm.non1_storage_shape), is_local=temporary_is_local, ) new_temporary_variables[buf_var_name] = temp_var # }}} new_insns = [] buf_var = var(buf_var_name) # {{{ generate init instruction buf_var_init = buf_var if non1_init_inames: buf_var_init = buf_var_init.index(tuple(var(iname) for iname in non1_init_inames)) init_base = var(var_name) init_subscript = [] init_iname_idx = 0 if var_shape: for i in range(len(var_shape)): ax_subscript = abm.storage_base_indices[i] if abm.non1_storage_axis_flags[i]: ax_subscript += var(non1_init_inames[init_iname_idx]) init_iname_idx += 1 init_subscript.append(ax_subscript) if init_subscript: init_base = init_base.index(tuple(init_subscript)) if init_expression is None: init_expression = init_base else: init_expression = init_expression init_expression = SubstitutionMapper(make_subst_func({"base": init_base}))(init_expression) init_insn_id = kernel.make_unique_instruction_id(based_on="init_" + var_name) from loopy.kernel.data import ExpressionInstruction init_instruction = ExpressionInstruction( id=init_insn_id, assignee=buf_var_init, expression=init_expression, forced_iname_deps=frozenset(within_inames), insn_deps=frozenset(), insn_deps_is_final=True, ) # }}} rule_mapping_context = SubstitutionRuleMappingContext(kernel.substitutions, kernel.get_var_name_generator()) aar = ArrayAccessReplacer(rule_mapping_context, var_name, within, abm, buf_var) kernel = rule_mapping_context.finish_kernel(aar.map_kernel(kernel)) did_write = False for insn_id in aar.modified_insn_ids: insn = kernel.id_to_insn[insn_id] if any(assignee_name == buf_var_name for assignee_name, _ in insn.assignees_and_indices()): did_write = True # {{{ add init_insn_id to insn_deps new_insns = [] def none_to_empty_set(s): if s is None: return frozenset() else: return s for insn in 
kernel.instructions: if insn.id in aar.modified_insn_ids: new_insns.append(insn.copy(insn_deps=(none_to_empty_set(insn.insn_deps) | frozenset([init_insn_id])))) else: new_insns.append(insn) # }}} # {{{ generate store instruction buf_var_store = buf_var if non1_store_inames: buf_var_store = buf_var_store.index(tuple(var(iname) for iname in non1_store_inames)) store_subscript = [] store_iname_idx = 0 if var_shape: for i in range(len(var_shape)): ax_subscript = abm.storage_base_indices[i] if abm.non1_storage_axis_flags[i]: ax_subscript += var(non1_store_inames[store_iname_idx]) store_iname_idx += 1 store_subscript.append(ax_subscript) store_target = var(var_name) if store_subscript: store_target = store_target.index(tuple(store_subscript)) if store_expression is None: store_expression = buf_var_store else: store_expression = SubstitutionMapper(make_subst_func({"base": store_target, "buffer": buf_var_store}))( store_expression ) from loopy.kernel.data import ExpressionInstruction store_instruction = ExpressionInstruction( id=kernel.make_unique_instruction_id(based_on="store_" + var_name), insn_deps=frozenset(aar.modified_insn_ids), assignee=store_target, expression=store_expression, forced_iname_deps=frozenset(within_inames), ) # }}} new_insns.append(init_instruction) if did_write: new_insns.append(store_instruction) kernel = kernel.copy( domains=new_kernel_domains, instructions=new_insns, temporary_variables=new_temporary_variables ) from loopy import tag_inames kernel = tag_inames(kernel, new_iname_to_tag) return kernel
def _fuse_two_kernels(knla, knlb):
    """Fuse two kernels (both in ``INITIAL`` state) into one.

    Domains that :func:`_find_fusable_loop_domain_index` matches are
    intersected, after checking that both kernels agree on the domain of
    their shared inames. Arguments with equal names must be identical.
    Temporaries and arguments of *knlb* whose names collide with names in
    *knla* are renamed. Instruction ids of *knlb* are made unique, and its
    dependencies are remapped accordingly.

    :raises LoopyError: on non-INITIAL kernels, inconsistent argument
        definitions, or disagreeing domains or assumptions.
    """
    from loopy.kernel import kernel_state
    if knla.state != kernel_state.INITIAL or knlb.state != kernel_state.INITIAL:
        raise LoopyError("can only fuse kernels in INITIAL state")

    # {{{ fuse domains

    new_domains = knla.domains[:]

    for dom_b in knlb.domains:
        i_fuse = _find_fusable_loop_domain_index(dom_b, new_domains)
        if i_fuse is None:
            new_domains.append(dom_b)
        else:
            dom_a = new_domains[i_fuse]

            # Determine the shared inames *before* aligning. Alignment adds
            # each set's missing dimensions to the other, after which every
            # iname would look shared.
            shared_inames = list(
                    set(dom_a.get_var_dict(dim_type.set))
                    & set(dom_b.get_var_dict(dim_type.set)))

            dom_a, dom_b = isl.align_two(dom_a, dom_b)

            dom_a_s = dom_a.project_out_except(shared_inames, [dim_type.set])
            dom_b_s = dom_b.project_out_except(shared_inames, [dim_type.set])

            if not (dom_a_s <= dom_b_s and dom_b_s <= dom_a_s):
                raise LoopyError("kernels do not agree on domain of "
                        "inames '%s'" % (",".join(shared_inames)))

            new_domain = dom_a & dom_b

            new_domains[i_fuse] = new_domain

    # }}}

    vng = knla.get_var_name_generator()
    b_var_renames = {}

    # {{{ fuse args

    new_args = knla.args[:]
    for b_arg in knlb.args:
        if b_arg.name not in knla.arg_dict:
            new_arg_name = vng(b_arg.name)

            if new_arg_name != b_arg.name:
                b_var_renames[b_arg.name] = var(new_arg_name)

            new_args.append(b_arg.copy(name=new_arg_name))
        else:
            if b_arg != knla.arg_dict[b_arg.name]:
                raise LoopyError(
                        "argument '%s' has inconsistent definition between "
                        "the two kernels being merged" % b_arg.name)

    # }}}

    # {{{ fuse temporaries

    new_temporaries = knla.temporary_variables.copy()
    for b_name, b_tv in six.iteritems(knlb.temporary_variables):
        assert b_name == b_tv.name

        new_tv_name = vng(b_name)

        if new_tv_name != b_name:
            b_var_renames[b_name] = var(new_tv_name)

        assert new_tv_name not in new_temporaries
        new_temporaries[new_tv_name] = b_tv.copy(name=new_tv_name)

    # }}}

    # {{{ apply renames in kernel b

    from loopy.symbolic import (
            SubstitutionRuleMappingContext,
            RuleAwareSubstitutionMapper)
    from pymbolic.mapper.substitutor import make_subst_func
    from loopy.context_matching import parse_stack_match

    srmc = SubstitutionRuleMappingContext(
            knlb.substitutions, knlb.get_var_name_generator())
    subst_map = RuleAwareSubstitutionMapper(
            srmc, make_subst_func(b_var_renames),
            within=parse_stack_match(None))
    # finish_kernel also applies the renames inside knlb's substitution rules
    knlb = srmc.finish_kernel(subst_map.map_kernel(knlb))

    # }}}

    # {{{ fuse instructions

    new_instructions = knla.instructions[:]
    from pytools import UniqueNameGenerator
    insn_id_gen = UniqueNameGenerator(
            set([insna.id for insna in new_instructions]))

    # first pass: assign all new ids, since dependencies may point forward
    knl_b_instructions = []
    old_b_id_to_new_b_id = {}
    for insnb in knlb.instructions:
        old_id = insnb.id
        new_id = insn_id_gen(old_id)
        old_b_id_to_new_b_id[old_id] = new_id

        knl_b_instructions.append(
                insnb.copy(id=new_id))

    # second pass: remap dependencies (None is kept as "not specified")
    for insnb in knl_b_instructions:
        if insnb.insn_deps is None:
            new_deps = None
        else:
            new_deps = frozenset(
                    old_b_id_to_new_b_id[dep_id]
                    for dep_id in insnb.insn_deps)

        new_instructions.append(insnb.copy(insn_deps=new_deps))

    # }}}

    # {{{ fuse assumptions

    assump_a = knla.assumptions
    assump_b = knlb.assumptions

    # As for the domains, determine the shared parameters before alignment.
    shared_param_names = list(
            set(assump_a.get_var_dict(dim_type.param))
            & set(assump_b.get_var_dict(dim_type.param)))

    assump_a, assump_b = isl.align_two(assump_a, assump_b)

    assump_a_s = assump_a.project_out_except(shared_param_names, [dim_type.param])
    assump_b_s = assump_b.project_out_except(shared_param_names, [dim_type.param])

    if not (assump_a_s <= assump_b_s and assump_b_s <= assump_a_s):
        raise LoopyError("assumptions do not agree on kernels to be merged")

    new_assumptions = (assump_a & assump_b).params()

    # }}}

    from loopy.kernel import LoopKernel
    return LoopKernel(
            domains=new_domains,
            instructions=new_instructions,
            args=new_args,
            name="%s_and_%s" % (knla.name, knlb.name),
            preambles=_ordered_merge_lists(knla.preambles, knlb.preambles),
            preamble_generators=_ordered_merge_lists(
                knla.preamble_generators, knlb.preamble_generators),
            assumptions=new_assumptions,
            local_sizes=_merge_dicts(
                "local size", knla.local_sizes, knlb.local_sizes),
            temporary_variables=new_temporaries,
            iname_to_tag=_merge_dicts(
                "iname-to-tag mapping",
                knla.iname_to_tag,
                knlb.iname_to_tag),
            substitutions=_merge_dicts(
                "substitution",
                knla.substitutions,
                knlb.substitutions),
            function_manglers=_ordered_merge_lists(
                knla.function_manglers,
                knlb.function_manglers),
            symbol_manglers=_ordered_merge_lists(
                knla.symbol_manglers,
                knlb.symbol_manglers),

            iname_slab_increments=_merge_dicts(
                "iname slab increment",
                knla.iname_slab_increments,
                knlb.iname_slab_increments),
            loop_priority=_ordered_merge_lists(
                knla.loop_priority,
                knlb.loop_priority),
            silenced_warnings=_ordered_merge_lists(
                knla.silenced_warnings,
                knlb.silenced_warnings),
            applied_iname_rewrites=_ordered_merge_lists(
                knla.applied_iname_rewrites,
                knlb.applied_iname_rewrites),
            index_dtype=_merge_values(
                "index dtype",
                knla.index_dtype,
                knlb.index_dtype),
            target=_merge_values(
                "target",
                knla.target,
                knlb.target),
            options=knla.options)
def link_inames(knl, inames, new_iname, within=None, tag=None):
    """Replace uses of each iname in *inames* by one new, shared iname.

    The new iname's domain is a duplicate of the axis of ``inames[0]``.
    All inames to be linked must have identical domain constraints.

    :arg inames: a list of inames to be linked, or a string of
        comma-separated inames (surrounding whitespace is ignored).
    :arg new_iname: the desired name of the shared iname. It is passed
        through the kernel's variable name generator, so a different
        unique name is used if it conflicts with an existing one.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`, restricting where
        uses of *inames* are substituted.
    :arg tag: if not *None*, the shared iname is tagged with this tag.
    :raises LoopyError: if the domain constraints of *inames* differ.
    """
    # {{{ normalize arguments

    if isinstance(inames, str):
        # strip whitespace, consistent with duplicate_inames/affine_map_inames
        inames = [iname.strip() for iname in inames.split(",")]

    var_name_gen = knl.get_var_name_generator()
    new_iname = var_name_gen(new_iname)

    # }}}

    # {{{ ensure that each iname is used at most once in each instruction

    inames_set = set(inames)

    if 0:
        # FIXME!
        for insn in knl.instructions:
            insn_inames = knl.insn_inames(insn.id) | insn.reduction_inames()

            if len(insn_inames & inames_set) > 1:
                raise LoopyError("To-be-linked inames '%s' are used in "
                        "instruction '%s'. No more than one such iname can "
                        "be used in one instruction."
                        % (", ".join(insn_inames & inames_set), insn.id))

    # }}}

    from loopy.kernel.tools import DomainChanger
    domch = DomainChanger(knl, tuple(inames))

    # {{{ ensure that projections are identical

    unrelated_dom_inames = list(
            set(domch.domain.get_var_names(dim_type.set))
            - inames_set)

    domain = domch.domain

    # move all inames to be linked to end to prevent shuffly confusion
    for iname in inames:
        dt, index = domain.get_var_dict()[iname]

        assert dt == dim_type.set

        # move to tail of param dim_type
        domain = domain.move_dims(
                    dim_type.param, domain.dim(dim_type.param),
                    dt, index, 1)
        # move to tail of set dim_type
        domain = domain.move_dims(
                    dim_type.set, domain.dim(dim_type.set),
                    dim_type.param, domain.dim(dim_type.param)-1, 1)

    # NOTE: the projections are taken from the unmodified domch.domain;
    # the reordered `domain` above is discarded.
    projections = [
            domch.domain.project_out_except(
                unrelated_dom_inames + [iname], [dim_type.set])
            for iname in inames]

    all_equal = True
    first_proj = projections[0]
    for proj in projections[1:]:
        all_equal = all_equal and (proj <= first_proj and first_proj <= proj)

    if not all_equal:
        raise LoopyError("Inames cannot be linked because their domain "
                "constraints are not the same.")

    del domain  # messed up for testing, do not use

    # }}}

    # change the domain
    from loopy.isl_helpers import duplicate_axes
    knl = knl.copy(
            domains=domch.get_domains_with(
                duplicate_axes(domch.domain, [inames[0]], [new_iname])))

    # {{{ change the code

    from pymbolic import var
    subst_dict = dict((iname, var(new_iname)) for iname in inames)

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    from pymbolic.mapper.substitutor import make_subst_func
    rule_mapping_context = SubstitutionRuleMappingContext(
            knl.substitutions, var_name_gen)
    ijoin = RuleAwareSubstitutionMapper(rule_mapping_context,
                    make_subst_func(subst_dict), within)

    knl = rule_mapping_context.finish_kernel(
            ijoin.map_kernel(knl))

    # }}}

    knl = remove_unused_inames(knl, inames)

    if tag is not None:
        knl = tag_inames(knl, {new_iname: tag})

    return knl
def affine_map_inames(kernel, old_inames, new_inames, equations):
    """Return a new *kernel* where the affine transform
    specified by *equations* has been applied to the inames.

    :arg old_inames: A list of inames to be replaced by affine transforms
        of their values.
        May also be a string of comma-separated inames.

    :arg new_inames: A list of new inames that are not yet used in *kernel*,
        but have their values established in terms of *old_inames* by
        *equations*.
        May also be a string of comma-separated inames.
    :arg equations: A list of equations establishing a relationship
        between *old_inames* and *new_inames*. Each equation may be
        a tuple ``(lhs, rhs)`` of expressions or a string, with left and
        right hand side of the equation separated by ``=``.
        A single string is treated as a one-element list.

    :raises ValueError: if an equation is malformed (bad string syntax,
        wrong tuple length, or unsupported type), if the old inames of an
        equation do not share one dim_type within a domain, if new inames
        receive conflicting dim_types, or if a domain does not contain all
        old inames needed to define its new inames.
    :raises LoopyError: if a new iname is already used in *kernel* or an
        old iname is not an iname of *kernel*.
    """

    # {{{ check and parse arguments

    if isinstance(new_inames, str):
        new_inames = new_inames.split(",")
        new_inames = [iname.strip() for iname in new_inames]
    if isinstance(old_inames, str):
        old_inames = old_inames.split(",")
        old_inames = [iname.strip() for iname in old_inames]
    if isinstance(equations, str):
        equations = [equations]

    import re
    eqn_re = re.compile(r"^([^=]+)=([^=]+)$")

    def parse_equation(eqn):
        # Turn one equation spec into an (lhs, rhs) pair of expressions.
        if isinstance(eqn, str):
            eqn_match = eqn_re.match(eqn)
            if not eqn_match:
                raise ValueError("invalid equation: %s" % eqn)

            from loopy.symbolic import parse
            lhs = parse(eqn_match.group(1))
            rhs = parse(eqn_match.group(2))
            return (lhs, rhs)
        elif isinstance(eqn, tuple):
            if len(eqn) != 2:
                raise ValueError("unexpected length of equation tuple, "
                        "got %d, should be 2" % len(eqn))
            return eqn
        else:
            # %s, not %d: the argument is a type name (a string).
            raise ValueError("unexpected type of equation, "
                    "got %s, should be string or tuple"
                    % type(eqn).__name__)

    equations = [parse_equation(eqn) for eqn in equations]

    all_vars = kernel.all_variable_names()
    for iname in new_inames:
        if iname in all_vars:
            raise LoopyError("new iname '%s' is already used in kernel"
                    % iname)

    for iname in old_inames:
        if iname not in kernel.all_inames():
            raise LoopyError("old iname '%s' not known" % iname)

    # }}}

    # {{{ substitute iname use

    from pymbolic.algorithm import solve_affine_equations_for
    old_inames_to_expr = solve_affine_equations_for(old_inames, equations)

    subst_dict = dict(
            (v.name, expr)
            for v, expr in old_inames_to_expr.items())

    var_name_gen = kernel.get_var_name_generator()

    from pymbolic.mapper.substitutor import make_subst_func
    from loopy.context_matching import parse_stack_match

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, var_name_gen)
    old_to_new = RuleAwareSubstitutionMapper(rule_mapping_context,
            make_subst_func(subst_dict), within=parse_stack_match(None))

    kernel = (
            rule_mapping_context.finish_kernel(
                old_to_new.map_kernel(kernel))
            .copy(
                applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict]
                ))

    # }}}

    # {{{ change domains

    new_inames_set = set(new_inames)
    old_inames_set = set(old_inames)

    from loopy.symbolic import get_dependencies, aff_from_expr

    new_domains = []
    for idom, dom in enumerate(kernel.domains):
        dom_var_dict = dom.get_var_dict()
        old_iname_overlap = set(
                iname
                for iname in old_inames
                if iname in dom_var_dict)

        if not old_iname_overlap:
            new_domains.append(dom)
            continue

        dom_new_inames = set()
        dom_old_inames = set()

        # mapping for new inames to dim_types
        new_iname_dim_types = {}

        dom_equations = []
        # Visit each equation once, even if it mentions several of this
        # domain's old inames (avoids adding the same constraint twice).
        for ieqn, (lhs, rhs) in enumerate(equations):
            eqn_deps = get_dependencies(lhs) | get_dependencies(rhs)
            if not (eqn_deps & old_iname_overlap):
                continue

            eqn_old_inames = eqn_deps & old_inames_set
            eqn_new_inames = eqn_deps & new_inames_set

            dom_new_inames.update(eqn_new_inames)
            dom_old_inames.update(eqn_old_inames)
            dom_equations.append((lhs, rhs))

            # Only look up old inames this domain actually has; any that are
            # missing are reported by the explicit check after this loop
            # (previously a bare KeyError fired here first).
            # Non-empty, since the equation mentions an overlapping iname.
            this_eqn_old_iname_dim_types = set(
                    dom_var_dict[old_iname][0]
                    for old_iname in eqn_old_inames
                    if old_iname in dom_var_dict)

            if len(this_eqn_old_iname_dim_types) > 1:
                raise ValueError("inames '%s' (from equation %d (0-based)) "
                        "in domain %d (0-based) are not "
                        "of a uniform dim_type"
                        % (", ".join(eqn_deps & old_inames_set), ieqn, idom))

            this_eqn_new_iname_dim_type, = this_eqn_old_iname_dim_types

            for new_iname in eqn_new_inames:
                if new_iname in new_iname_dim_types:
                    if (this_eqn_new_iname_dim_type
                            != new_iname_dim_types[new_iname]):
                        raise ValueError("dim_type disagreement for "
                                "iname '%s' (from equation %d (0-based)) "
                                "in domain %d (0-based)"
                                % (new_iname, ieqn, idom))
                else:
                    new_iname_dim_types[new_iname] = \
                            this_eqn_new_iname_dim_type

        if not dom_old_inames <= set(dom_var_dict):
            raise ValueError("domain %d (0-based) does not know about "
                    "all old inames (specifically '%s') needed to define new inames"
                    % (idom, ", ".join(dom_old_inames - set(dom_var_dict))))

        # add inames to domain with correct dim_types
        dom_new_inames = list(dom_new_inames)
        for iname in dom_new_inames:
            dt = new_iname_dim_types[iname]
            iname_idx = dom.dim(dt)
            dom = dom.add_dims(dt, 1)
            dom = dom.set_dim_name(dt, iname_idx, iname)

        # add equations
        for lhs, rhs in dom_equations:
            dom = dom.add_constraint(
                    isl.Constraint.equality_from_aff(
                        aff_from_expr(dom.space, rhs - lhs)))

        # project out old inames
        for iname in dom_old_inames:
            dt, idx = dom.get_var_dict()[iname]
            dom = dom.project_out(dt, idx, 1)

        new_domains.append(dom)

    # }}}

    return kernel.copy(domains=new_domains)
def rename_iname(knl, old_iname, new_iname, existing_ok=False, within=None):
    """Rename *old_iname* to *new_iname*.

    If *new_iname* is not yet taken, *old_iname* is duplicated into it via
    :func:`duplicate_inames`. If it is taken (and *existing_ok* allows it),
    both inames must iterate over the same domain. Uses of *old_iname*
    matched by *within* are then substituted by *new_iname*, including in
    instructions' forced iname dependencies. Finally, *old_iname* is passed
    to :func:`remove_unused_inames`.

    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.
    :arg existing_ok: execute even if *new_iname* already exists
    :raises ValueError: if *new_iname* conflicts with an existing identifier
        and *existing_ok* is false.
    :raises LoopyError: if both inames exist but iterate over different
        domains.
    """
    var_name_gen = knl.get_var_name_generator()
    name_taken = var_name_gen.is_name_conflicting(new_iname)

    if name_taken and not existing_ok:
        raise ValueError("iname '%s' conflicts with an existing identifier"
                "--cannot rename" % new_iname)

    if not name_taken:
        knl = duplicate_inames(
                knl, [old_iname], within=within, new_inames=[new_iname])
    else:
        # {{{ check that the domains match up

        dom = knl.get_inames_domain(frozenset((old_iname, new_iname)))

        var_dict = dom.get_var_dict()
        old_idx = var_dict[old_iname][1]
        new_idx = var_dict[new_iname][1]
        n_params = dom.dim(dim_type.param)

        def keep_only(keep_idx, drop_idx):
            # Route the kept set dim through the param dims so it lands at
            # the tail of the set dims. Then drop the other dim, whose
            # index shifts down by one unless it preceded the kept one.
            rotated = dom.move_dims(
                    dim_type.param, n_params, dim_type.set, keep_idx, 1)
            rotated = rotated.move_dims(
                    dim_type.set, rotated.dim(dim_type.set),
                    dim_type.param, n_params, 1)
            if not drop_idx < keep_idx:
                drop_idx = drop_idx - 1
            return rotated.project_out(dim_type.set, drop_idx, 1)

        dom_old = keep_only(old_idx, new_idx)
        dom_new = keep_only(new_idx, old_idx)

        if not (dom_old <= dom_new and dom_new <= dom_old):
            raise LoopyError(
                    "inames {old} and {new} do not iterate over the same domain"
                    .format(old=old_iname, new=new_iname))

        # }}}

        from pymbolic import var
        from pymbolic.mapper.substitutor import make_subst_func
        from loopy.context_matching import parse_stack_match

        stack_match = parse_stack_match(within)

        rule_mapping_context = SubstitutionRuleMappingContext(
                knl.substitutions, var_name_gen)
        renamer = RuleAwareSubstitutionMapper(
                rule_mapping_context,
                make_subst_func({old_iname: var(new_iname)}),
                stack_match)
        knl = rule_mapping_context.finish_kernel(renamer.map_kernel(knl))

        dropped = frozenset([old_iname])
        added = frozenset([new_iname])

        def retarget_forced_deps(insn):
            # Swap old_iname for new_iname in forced deps of matched insns.
            if not (old_iname in insn.forced_iname_deps
                    and stack_match(knl, insn, ())):
                return insn
            return insn.copy(
                    forced_iname_deps=(
                        (insn.forced_iname_deps - dropped) | added))

        knl = knl.copy(
                instructions=[
                    retarget_forced_deps(insn) for insn in knl.instructions])

    return remove_unused_inames(knl, [old_iname])
def duplicate_inames(knl, inames, within, new_inames=None, suffix=None,
        tags=None):
    """Create a copy of each iname in *inames* and use the copies in place
    of the originals wherever *within* matches.

    :arg inames: a list of inames to duplicate, or a string of
        comma-separated inames.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`.
    :arg new_inames: optional list (or comma-separated string) of names for
        the duplicates, one per entry of *inames*. A *None* entry (or
        *new_inames* being *None*) means a unique name is generated from the
        old iname, with *suffix* appended if given. A list passed by the
        caller is not modified.
    :arg suffix: string appended to the old iname when generating a name.
    :arg tags: optional mapping from an *old* iname to a tag that is applied
        to its duplicate.
    :raises ValueError: if *new_inames* and *inames* differ in length, or
        if an explicitly given new iname conflicts with an existing name.
    """
    if tags is None:
        tags = {}

    # {{{ normalize arguments, find unique new_inames

    if isinstance(inames, str):
        inames = [iname.strip() for iname in inames.split(",")]

    if isinstance(new_inames, str):
        new_inames = [iname.strip() for iname in new_inames.split(",")]

    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(within)

    if new_inames is None:
        new_inames = [None] * len(inames)
    else:
        # Work on a private copy: generated names are written back below,
        # and that must not leak into the caller's list.
        new_inames = list(new_inames)

    if len(new_inames) != len(inames):
        raise ValueError("new_inames must have the same number of entries as inames")

    name_gen = knl.get_var_name_generator()

    for i, iname in enumerate(inames):
        new_iname = new_inames[i]

        if new_iname is None:
            new_iname = iname

            if suffix is not None:
                new_iname += suffix

            new_iname = name_gen(new_iname)

        else:
            if name_gen.is_name_conflicting(new_iname):
                raise ValueError("new iname '%s' conflicts with existing names"
                        % new_iname)

            name_gen.add_name(new_iname)

        new_inames[i] = new_iname

    # }}}

    # {{{ duplicate the inames

    from loopy.kernel.tools import DomainChanger
    from loopy.isl_helpers import duplicate_axes

    for old_iname, new_iname in zip(inames, new_inames):
        domch = DomainChanger(knl, frozenset([old_iname]))
        knl = knl.copy(
                domains=domch.get_domains_with(
                    duplicate_axes(domch.domain, [old_iname], [new_iname])))

    # }}}

    # {{{ change the inames in the code

    rule_mapping_context = SubstitutionRuleMappingContext(
            knl.substitutions, name_gen)
    indup = _InameDuplicator(rule_mapping_context,
            old_to_new=dict(zip(inames, new_inames)),
            within=within)

    knl = rule_mapping_context.finish_kernel(
            indup.map_kernel(knl))

    # }}}

    # {{{ realize tags

    for old_iname, new_iname in zip(inames, new_inames):
        new_tag = tags.get(old_iname)
        if new_tag is not None:
            knl = tag_inames(knl, {new_iname: new_tag})

    # }}}

    return knl
def join_inames(kernel, inames, new_iname=None, tag=None, within=None):
    """Replace *inames* by a single new iname enumerating their joint
    index space.

    The new iname satisfies ``new = sum_k base_k * iname_k``, where
    ``base_k`` is the product of the (constant) lengths of the faster
    varying inames. Each old iname is substituted by the corresponding
    div/mod expression in *new_iname*, offset by its lower bound.

    :arg inames: fastest varying last. May also be a string of
        comma-separated inames.
    :arg new_iname: name of the joined iname. If *None*, a unique name is
        generated from the joined inames.
    :arg tag: if not *None*, the joined iname is tagged with this tag.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`. If it is not
        *None*, the old inames are kept in the domain, and only instructions
        matched by *within* have their forced iname dependencies rewritten.
    :raises LoopyError: if one of *inames* is not 'at home' in the join's
        leaf domain.
    """
    if isinstance(inames, str):
        inames = [iname.strip() for iname in inames.split(",")]

    # now fastest varying first
    inames = inames[::-1]

    if new_iname is None:
        new_iname = kernel.get_var_name_generator()("_and_".join(inames))

    from loopy.kernel.tools import DomainChanger
    domch = DomainChanger(kernel, frozenset(inames))
    for iname in inames:
        if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
            raise LoopyError("iname '%s' is not 'at home' in the "
                    "join's leaf domain" % iname)

    new_domain = domch.domain
    new_dim_idx = new_domain.dim(dim_type.set)
    new_domain = new_domain.add_dims(dim_type.set, 1)
    new_domain = new_domain.set_dim_name(dim_type.set, new_dim_idx, new_iname)

    joint_aff = zero = isl.Aff.zero_on_domain(new_domain.space)
    subst_dict = {}
    base_divisor = 1

    from pymbolic import var
    from loopy.isl_helpers import (
            static_max_of_pw_aff, static_value_of_pw_aff)
    from loopy.symbolic import pw_aff_to_expr

    for i, iname in enumerate(inames):
        iname_dt, iname_idx = zero.get_space().get_var_dict()[iname]
        iname_aff = zero.add_coefficient_val(iname_dt, iname_idx, 1)

        joint_aff = joint_aff + base_divisor*iname_aff

        bounds = kernel.get_iname_bounds(iname, constants_only=True)

        length = int(pw_aff_to_expr(
            static_max_of_pw_aff(bounds.size, constants_only=True)))

        try:
            lower_bound_aff = static_value_of_pw_aff(
                    bounds.lower_bound_pw_aff.coalesce(),
                    constants_only=False)
        except Exception as e:
            # Keep the exception type, but retain the original message
            # (it used to be dropped, leaving a dangling ': ').
            raise type(e)("while finding lower bound of '%s': %s"
                    % (iname, e))

        my_val = var(new_iname) // base_divisor
        if i+1 < len(inames):
            my_val %= length
        my_val += pw_aff_to_expr(lower_bound_aff)
        subst_dict[iname] = my_val

        base_divisor *= length

    from loopy.isl_helpers import iname_rel_aff
    new_domain = new_domain.add_constraint(
            isl.Constraint.equality_from_aff(
                iname_rel_aff(new_domain.get_space(), new_iname, "==", joint_aff)))

    if within is None:
        # Joining everywhere: the old inames can be eliminated.
        for iname in inames:
            iname_dt, iname_idx = new_domain.get_space().get_var_dict()[iname]
            new_domain = new_domain.project_out(iname_dt, iname_idx, 1)

    from loopy.context_matching import parse_stack_match
    within_match = parse_stack_match(within)

    def subst_forced_iname_deps(fid):
        return frozenset(
                new_iname if iname in inames else iname
                for iname in fid)

    # Only rewrite forced deps where the join applies; unmatched
    # instructions keep using the (still present) old inames.
    new_insns = [
            insn.copy(
                forced_iname_deps=subst_forced_iname_deps(insn.forced_iname_deps))
            if within_match(kernel, insn, ()) else insn
            for insn in kernel.instructions]

    kernel = (kernel
            .copy(
                instructions=new_insns,
                domains=domch.get_domains_with(new_domain),
                applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict]
                ))

    from pymbolic.mapper.substitutor import make_subst_func
    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    ijoin = _InameJoiner(rule_mapping_context, within_match,
            make_subst_func(subst_dict),
            inames, new_iname)

    kernel = rule_mapping_context.finish_kernel(
            ijoin.map_kernel(kernel))

    if tag is not None:
        kernel = tag_inames(kernel, {new_iname: tag})

    return kernel
def split_iname(kernel, split_iname, inner_length,
        outer_iname=None, inner_iname=None,
        outer_tag=None, inner_tag=None,
        slabs=(0, 0), do_tagged_check=True,
        within=None):
    """Split *split_iname* into an outer and an inner iname such that
    ``split_iname == inner + outer*inner_length`` and
    ``0 <= inner < inner_length``.

    :arg inner_length: length of the inner loop.
    :arg outer_iname: name of the outer iname; generated if *None*.
    :arg inner_iname: name of the inner iname; generated if *None*.
    :arg outer_tag: tag passed to :func:`tag_inames` for the outer iname.
    :arg inner_tag: tag passed to :func:`tag_inames` for the inner iname.
    :arg slabs: stored as the slab increments of the outer iname.
    :arg do_tagged_check: if true, refuse to split an iname that already
        carries a tag other than :class:`loopy.kernel.data.ForceSequentialTag`.
    :arg within: a stack match as understood by
        :func:`loopy.context_matching.parse_stack_match`. If it is not
        *None*, *split_iname* is kept in the domain, and only instructions
        matched by *within* have their uses and forced iname dependencies
        rewritten.
    :raises LoopyError: if the iname is already tagged (see
        *do_tagged_check*).
    :raises ValueError: if *split_iname* is not an iname of *kernel*.
    """

    existing_tag = kernel.iname_to_tag.get(split_iname)
    from loopy.kernel.data import ForceSequentialTag
    if do_tagged_check and (
            existing_tag is not None
            and not isinstance(existing_tag, ForceSequentialTag)):
        raise LoopyError("cannot split already tagged iname '%s'" % split_iname)

    if split_iname not in kernel.all_inames():
        raise ValueError("cannot split loop for unknown variable '%s'" % split_iname)

    # Parsed separately: process_set below still needs to test the raw
    # `within` argument against None.
    from loopy.context_matching import parse_stack_match
    within_match = parse_stack_match(within)

    applied_iname_rewrites = kernel.applied_iname_rewrites[:]

    vng = kernel.get_var_name_generator()

    if outer_iname is None:
        outer_iname = vng(split_iname+"_outer")
    if inner_iname is None:
        inner_iname = vng(split_iname+"_inner")

    def process_set(s):
        # Add outer/inner dims to a set containing split_iname and relate
        # them to it; sets without split_iname are returned unchanged.
        var_dict = s.get_var_dict()

        if split_iname not in var_dict:
            return s

        orig_dim_type, _ = var_dict[split_iname]

        outer_var_nr = s.dim(orig_dim_type)
        inner_var_nr = s.dim(orig_dim_type)+1

        s = s.add_dims(orig_dim_type, 2)
        s = s.set_dim_name(orig_dim_type, outer_var_nr, outer_iname)
        s = s.set_dim_name(orig_dim_type, inner_var_nr, inner_iname)

        from loopy.isl_helpers import make_slab

        space = s.get_space()
        inner_constraint_set = (
                make_slab(space, inner_iname, 0, inner_length)
                # name = inner + length*outer
                .add_constraint(isl.Constraint.eq_from_names(
                    space, {
                        split_iname: 1,
                        inner_iname: -1,
                        outer_iname: -inner_length})))

        name_dim_type, name_idx = space.get_var_dict()[split_iname]
        s = s.intersect(inner_constraint_set)

        if within is None:
            s = s.project_out(name_dim_type, name_idx, 1)

        return s

    new_domains = [process_set(dom) for dom in kernel.domains]

    from pymbolic import var
    inner = var(inner_iname)
    outer = var(outer_iname)
    new_loop_index = inner + outer*inner_length

    subst_map = {var(split_iname): new_loop_index}
    applied_iname_rewrites.append(subst_map)

    # {{{ update forced_iname deps

    # Only instructions matched by *within* move into the outer/inner
    # loops; the others keep depending on split_iname, which remains in the
    # domain when *within* is given.
    split_only = frozenset([split_iname])
    outer_and_inner = frozenset([outer_iname, inner_iname])

    new_insns = []
    for insn in kernel.instructions:
        if (split_iname in insn.forced_iname_deps
                and within_match(kernel, insn, ())):
            insn = insn.copy(
                    forced_iname_deps=(
                        (insn.forced_iname_deps - split_only)
                        | outer_and_inner))

        new_insns.append(insn)

    # }}}

    iname_slab_increments = kernel.iname_slab_increments.copy()
    iname_slab_increments[outer_iname] = slabs

    new_loop_priority = []
    for prio_iname in kernel.loop_priority:
        if prio_iname == split_iname:
            new_loop_priority.append(outer_iname)
            new_loop_priority.append(inner_iname)
        else:
            new_loop_priority.append(prio_iname)

    kernel = kernel.copy(
            domains=new_domains,
            iname_slab_increments=iname_slab_increments,
            instructions=new_insns,
            applied_iname_rewrites=applied_iname_rewrites,
            loop_priority=new_loop_priority)

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    ins = _InameSplitter(rule_mapping_context, within_match,
            split_iname, outer_iname, inner_iname, new_loop_index)

    kernel = ins.map_kernel(kernel)
    kernel = rule_mapping_context.finish_kernel(kernel)

    if existing_tag is not None:
        kernel = tag_inames(kernel,
                {outer_iname: existing_tag, inner_iname: existing_tag})

    return tag_inames(kernel, {outer_iname: outer_tag, inner_iname: inner_tag})
def _fix_parameter(kernel, name, value):
    """Return a copy of *kernel* with the parameter *name* fixed to *value*.

    Every domain and the assumptions that mention *name* get an equality
    constraint ``name == value``, and then *name* is projected out. *name*
    is substituted by *value* in the kernel (via
    :class:`RuleAwareSubstitutionMapper`). The substitution is also applied,
    followed by partial evaluation, to the expressions of array arguments
    and temporaries. An argument called *name* is removed from the argument
    list.
    """
    def pin_and_eliminate(s):
        # Sets that do not mention *name* pass through untouched.
        var_dict = s.get_var_dict()
        if name not in var_dict:
            return s
        dt, idx = var_dict[name]

        from loopy.isl_helpers import iname_rel_aff
        value_aff = isl.Aff.zero_on_domain(s.space) + value
        equality = isl.Constraint.equality_from_aff(
                iname_rel_aff(s.space, name, "==", value_aff))

        constrained = s.add_constraint(equality)
        return constrained.project_out(dt, idx, 1)

    new_domains = [pin_and_eliminate(dom) for dom in kernel.domains]

    from pymbolic.mapper.substitutor import make_subst_func
    from loopy.symbolic import SubstitutionMapper, PartialEvaluationMapper

    subst_func = make_subst_func({name: value})
    subst_map = SubstitutionMapper(subst_func)
    ev_map = PartialEvaluationMapper()

    def map_expr(expr):
        return ev_map(subst_map(expr))

    from loopy.kernel.array import ArrayBase

    # The argument named *name* itself is dropped; array arguments have
    # their expressions rewritten; other arguments are kept as-is.
    new_args = [
            arg.map_exprs(map_expr) if isinstance(arg, ArrayBase) else arg
            for arg in kernel.args
            if arg.name != name]

    new_temp_vars = dict(
            (tv.name, tv.map_exprs(map_expr))
            for tv in six.itervalues(kernel.temporary_variables))

    from loopy.context_matching import parse_stack_match

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())
    esubst_map = RuleAwareSubstitutionMapper(
            rule_mapping_context, subst_func,
            within=parse_stack_match(None))

    substituted = rule_mapping_context.finish_kernel(
            esubst_map.map_kernel(kernel))

    return substituted.copy(
            domains=new_domains,
            args=new_args,
            temporary_variables=new_temp_vars,
            assumptions=pin_and_eliminate(kernel.assumptions),
            )