def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag,
        within):
    """Set up the gatherer's state.

    :arg rule_mapping_context: forwarded unchanged to the base class
        constructor.
    :arg kernel: the kernel being processed; its ``substitutions`` are used
        to build :attr:`subst_expander`.
    :arg subst_name: stored as :attr:`subst_name`.
    :arg subst_tag: stored as :attr:`subst_tag`.
    :arg within: stored as :attr:`within`.
        NOTE(review): presumably a parsed stack match -- confirm with caller.
    """
    super(RuleInvocationGatherer, self).__init__(rule_mapping_context)

    # Expander over all substitution rules known to the kernel.
    from loopy.symbolic import SubstitutionRuleExpander
    self.subst_expander = SubstitutionRuleExpander(kernel.substitutions)

    # Plain bookkeeping of the constructor arguments.
    self.kernel = kernel
    self.subst_name, self.subst_tag = subst_name, subst_tag
    self.within = within

    # Starts out empty.
    self.access_descriptors = []
def precompute(kernel, subst_use, sweep_inames=[], within=None,
        storage_axes=None, temporary_name=None, precompute_inames=None,
        precompute_outer_inames=None,
        storage_axis_to_tag={},

        # "None" is a valid value here, distinct from the default.
        default_tag=_not_provided,

        dtype=None,
        fetch_bounding_box=False,
        temporary_address_space=None,
        compute_insn_id=None,
        **kwargs):
    """Precompute the expression described in the substitution rule determined by
    *subst_use* and store it in a temporary array. A precomputation needs two
    things to operate, a list of *sweep_inames* (order irrelevant) and an
    ordered list of *storage_axes* (whose order will describe the axis ordering
    of the temporary array).

    :arg subst_use: Describes what to prefetch.

        The following objects may be given for *subst_use*:

        * The name of the substitution rule.

        * The tagged name ("name$tag") of the substitution rule.

        * A list of invocations of the substitution rule.
          This list of invocations, when swept across *sweep_inames*, then
          serves to define the footprint of the precomputation.

          Invocations may be tagged ("name$tag") to filter out a subset of the
          usage sites of the substitution rule. (Namely those usage sites that
          use the same tagged name.)

          Invocations may be given as a string or as a
          :class:`pymbolic.primitives.Expression` object.

          If only one invocation is to be given, then the only entry of the list
          may be given directly.

        If the list of invocations generating the footprint is not given,
        all (tag-matching, if desired) usage sites of the substitution rule
        are used to determine the footprint.

        The following cases can arise for each sweep axis:

        * The axis is an iname that occurs within arguments specified at
          usage sites of the substitution rule. This case is assumed covered
          by the storage axes provided for the argument.

        * The axis is an iname that occurs within the *value* of the rule, but not
          within its arguments. A new, dedicated storage axis is allocated for
          such an axis.

    :arg sweep_inames: A :class:`list` of inames to be swept.
        May also equivalently be a comma-separated string.
    :arg within: a stack match as understood by
        :func:`loopy.match.parse_stack_match`.
    :arg storage_axes: A :class:`list` of inames and/or rule argument
        names/indices to be used as storage axes.
        May also equivalently be a comma-separated string.
    :arg temporary_name:
        The temporary variable name to use for storing the precomputed data.
        If it does not exist, it will be created. If it does exist, its properties
        (such as size, type) are checked (and updated, if possible) to match
        its use.
    :arg precompute_inames:
        A tuple of inames to be used to carry out the precomputation.
        If the specified inames do not already exist, they will be
        created. If they do already exist, their loop domain is verified
        against the one required for this precomputation. This tuple may
        be shorter than the (provided or automatically found) *storage_axes*
        tuple, in which case names will be automatically created.
        May also equivalently be a comma-separated string.

    :arg precompute_outer_inames: A :class:`frozenset` of inames within which
        the compute instruction is nested. If *None*, make an educated guess.
        May also be specified as a comma-separated string.

    :arg storage_axis_to_tag: A mapping consulted (via ``.get``) for each
        newly created precompute iname, keyed by the storage axis as given
        (or by the explicitly supplied precompute iname name). Axes without
        an entry receive *default_tag*. The mapping is only read, never
        modified.

    :arg default_tag: The :ref:`iname tag <iname-tags>` to be applied to the
        inames created to perform the precomputation. The current default will
        make them local axes and automatically split them to fit the work
        group size, but this default will disappear in favor of simply leaving them
        untagged in 2019. For 2018, a warning will be issued if no *default_tag* is
        specified.

    :arg dtype: If not *None*, converted with :class:`numpy.dtype` and used
        as the data type of the temporary.

    :arg fetch_bounding_box: Passed on as ``boxify_sweep`` when the domain
        is augmented with the sweep.

    :arg temporary_address_space: The address space of the temporary. If
        *None*, :data:`loopy.auto` is used. If it equals
        ``AddressSpace.GLOBAL``, a global barrier instruction depending on
        the compute instruction is added as well. The deprecated keyword
        argument *temporary_scope* is accepted as a synonym; specifying both
        is an error.

    :arg compute_insn_id: The ID of the instruction generated to perform the
        precomputation.

    If `storage_axes` is not specified, it defaults to the arrangement
    `<direct sweep axes><arguments>` with the direct sweep axes being the
    slower-varying indices.

    Trivial storage axes (i.e. axes of length 1 with respect to the sweep) are
    eliminated.

    :returns: the transformed kernel.

    :raises TypeError: for unrecognized keyword arguments, or if
        *precompute_outer_inames* is given but is not a :class:`frozenset`
        (after string conversion).
    :raises LoopyError: among others, if the substitution rule is not found
        or a pre-existing temporary is incompatible with the precomputation.
    :raises RuntimeError: among others, if a sweep iname is unknown or no
        invocations of the rule are found.
    :raises ValueError: if the entries of *subst_use* disagree on rule name
        and tag or have an unexpected type.
    """

    # {{{ unify temporary_address_space / temporary_scope

    temporary_scope = kwargs.pop("temporary_scope", None)

    from loopy.kernel.data import AddressSpace
    if temporary_scope is not None:
        from warnings import warn
        warn("temporary_scope is deprecated. Use temporary_address_space instead",
                DeprecationWarning, stacklevel=2)

        if temporary_address_space is not None:
            raise LoopyError("may not specify both temporary_address_space and "
                    "temporary_scope")

        temporary_address_space = temporary_scope

    del temporary_scope

    # }}}

    if kwargs:
        raise TypeError("unrecognized keyword arguments: %s"
                % ", ".join(kwargs.keys()))

    # {{{ check, standardize arguments

    if isinstance(sweep_inames, str):
        sweep_inames = [iname.strip() for iname in sweep_inames.split(",")]

    for iname in sweep_inames:
        if iname not in kernel.all_inames():
            raise RuntimeError("sweep iname '%s' is not a known iname"
                    % iname)

    # Copying here also means the (mutable) default argument is never
    # modified.
    sweep_inames = list(sweep_inames)
    sweep_inames_set = frozenset(sweep_inames)

    if isinstance(storage_axes, str):
        storage_axes = [ax.strip() for ax in storage_axes.split(",")]

    if isinstance(precompute_inames, str):
        precompute_inames = [
                iname.strip() for iname in precompute_inames.split(",")]

    if isinstance(precompute_outer_inames, str):
        precompute_outer_inames = frozenset(
                iname.strip() for iname in precompute_outer_inames.split(","))

    if isinstance(subst_use, str):
        subst_use = [subst_use]

    footprint_generators = None

    subst_name = None
    subst_tag = None

    from pymbolic.primitives import Variable, Call
    from loopy.symbolic import parse, TaggedVariable

    for use in subst_use:
        if isinstance(use, str):
            use = parse(use)

        if isinstance(use, Call):
            # Invocations (as opposed to bare rule names) define the footprint.
            if footprint_generators is None:
                footprint_generators = []

            footprint_generators.append(use)
            subst_name_as_expr = use.function
        else:
            subst_name_as_expr = use

        if isinstance(subst_name_as_expr, TaggedVariable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = subst_name_as_expr.tag
        elif isinstance(subst_name_as_expr, Variable):
            new_subst_name = subst_name_as_expr.name
            new_subst_tag = None
        else:
            raise ValueError("unexpected type of subst_name")

        if (subst_name, subst_tag) == (None, None):
            subst_name, subst_tag = new_subst_name, new_subst_tag
        else:
            if (subst_name, subst_tag) != (new_subst_name, new_subst_tag):
                raise ValueError("not all uses in subst_use agree "
                        "on rule name and tag")

    from loopy.match import parse_stack_match
    within = parse_stack_match(within)

    try:
        subst = kernel.substitutions[subst_name]
    except KeyError:
        raise LoopyError("substitution rule '%s' not found"
                % subst_name)

    # Rule names may contain dots; derived identifiers use underscores instead.
    c_subst_name = subst_name.replace(".", "_")

    # {{{ handle default_tag

    from loopy.transform.data import _not_provided \
            as transform_data_not_provided

    if default_tag is _not_provided or default_tag is transform_data_not_provided:
        # no need to warn for scalar precomputes
        if sweep_inames:
            from warnings import warn
            warn(
                    "Not specifying default_tag is deprecated, and default_tag "
                    "will become mandatory in 2019.x. "
                    "Pass 'default_tag=\"l.auto\" to match the current default, "
                    "or Pass 'default_tag=None to leave the loops untagged, which "
                    "is the recommended behavior.",
                    DeprecationWarning, stacklevel=(
                        # In this case, we came here through add_prefetch. Increase
                        # the stacklevel.
                        3 if default_tag is transform_data_not_provided
                        else 2))

        default_tag = "l.auto"

    from loopy.kernel.data import parse_tag
    default_tag = parse_tag(default_tag)

    # }}}

    # }}}

    # {{{ process invocations in footprint generators, start access_descriptors

    if footprint_generators:
        from pymbolic.primitives import Variable, Call

        access_descriptors = []

        for fpg in footprint_generators:
            if isinstance(fpg, Variable):
                args = ()
            elif isinstance(fpg, Call):
                args = fpg.parameters
            else:
                raise ValueError("footprint generator must "
                        "be substitution rule invocation")

            access_descriptors.append(
                    RuleAccessDescriptor(
                        identifier=access_descriptor_id(args, None),
                        args=args))

    # }}}

    # {{{ gather up invocations in kernel code, finish access_descriptors

    if not footprint_generators:
        rule_mapping_context = SubstitutionRuleMappingContext(
                kernel.substitutions, kernel.get_var_name_generator())
        invg = RuleInvocationGatherer(
                rule_mapping_context, kernel, subst_name, subst_tag, within)
        del rule_mapping_context

        import loopy as lp
        for insn in kernel.instructions:
            if isinstance(insn, lp.MultiAssignmentBase):
                for assignee in insn.assignees:
                    invg(assignee, kernel, insn)
                invg(insn.expression, kernel, insn)

        access_descriptors = invg.access_descriptors
        if not access_descriptors:
            raise RuntimeError("no invocations of '%s' found" % subst_name)

    # }}}

    # {{{ find inames used in arguments

    expanding_usage_arg_deps = set()

    for accdesc in access_descriptors:
        for arg in accdesc.args:
            expanding_usage_arg_deps.update(
                    get_dependencies(arg) & kernel.all_inames())

    # }}}

    var_name_gen = kernel.get_var_name_generator()

    # {{{ use given / find new storage_axes

    # extra axes made necessary because they don't occur in the arguments
    extra_storage_axes = set(sweep_inames_set - expanding_usage_arg_deps)

    from loopy.symbolic import SubstitutionRuleExpander
    submap = SubstitutionRuleExpander(kernel.substitutions)

    value_inames = (
            get_dependencies(submap(subst.expression))
            - frozenset(subst.arguments)
            ) & kernel.all_inames()
    if value_inames - expanding_usage_arg_deps < extra_storage_axes:
        raise RuntimeError("unreferenced sweep inames specified: "
                + ", ".join(
                    extra_storage_axes
                    - value_inames - expanding_usage_arg_deps))

    new_iname_to_tag = {}

    if storage_axes is None:
        storage_axes = []

        # Add sweep_inames (in given--rather than arbitrary--order) to
        # storage_axes *if* they are part of extra_storage_axes.
        for iname in sweep_inames:
            if iname in extra_storage_axes:
                extra_storage_axes.remove(iname)
                storage_axes.append(iname)

        if extra_storage_axes:
            if (precompute_inames is not None
                    and len(storage_axes) < len(precompute_inames)):
                raise LoopyError("must specify a sufficient number of "
                        "storage_axes to uniquely determine the meaning "
                        "of the given precompute_inames. (%d storage_axes "
                        "needed)" % len(precompute_inames))
            storage_axes.extend(sorted(extra_storage_axes))

        storage_axes.extend(range(len(subst.arguments)))

    del extra_storage_axes

    # maps new storage axis name -> name it replaces in the rule's value
    prior_storage_axis_name_dict = {}

    storage_axis_names = []
    storage_axis_sources = []  # number for arg#, or iname

    # {{{ check for pre-existing precompute_inames

    if precompute_inames is not None:
        preexisting_precompute_inames = (
                set(precompute_inames) & kernel.all_inames())
    else:
        preexisting_precompute_inames = set()

    # }}}

    for i, saxis in enumerate(storage_axes):
        tag_lookup_saxis = saxis

        if saxis in subst.arguments:
            saxis = subst.arguments.index(saxis)

        storage_axis_sources.append(saxis)

        if isinstance(saxis, int):
            # argument index
            name = old_name = subst.arguments[saxis]
        else:
            old_name = saxis
            name = "%s_%s" % (c_subst_name, old_name)

        if (precompute_inames is not None
                and i < len(precompute_inames)
                and precompute_inames[i]):
            # The user asked for a specific name: use it verbatim, but refuse
            # to clobber an unrelated existing identifier.
            name = precompute_inames[i]
            tag_lookup_saxis = name
            if (name not in preexisting_precompute_inames
                    and var_name_gen.is_name_conflicting(name)):
                raise RuntimeError("new storage axis name '%s' "
                        "conflicts with existing name" % name)
        else:
            name = var_name_gen(name)

        storage_axis_names.append(name)
        if name not in preexisting_precompute_inames:
            new_iname_to_tag[name] = storage_axis_to_tag.get(
                    tag_lookup_saxis, default_tag)

        prior_storage_axis_name_dict[name] = old_name

    del storage_axis_to_tag
    del storage_axes
    del precompute_inames

    # }}}

    # {{{ fill out access_descriptors[...].storage_axis_exprs

    access_descriptors = [
            accdesc.copy(
                storage_axis_exprs=storage_axis_exprs(
                    storage_axis_sources, accdesc.args))
            for accdesc in access_descriptors]

    # }}}

    expanding_inames = sweep_inames_set | frozenset(expanding_usage_arg_deps)
    assert expanding_inames <= kernel.all_inames()

    if storage_axis_names:
        # {{{ find domain to be changed

        change_inames = expanding_inames | preexisting_precompute_inames

        from loopy.kernel.tools import DomainChanger
        domch = DomainChanger(kernel, change_inames)

        if domch.leaf_domain_index is not None:
            # If the sweep inames are at home in parent domains, then we'll add
            # fetches with loops over copies of these parent inames that will end
            # up being scheduled *within* loops over these parents.

            for iname in sweep_inames_set:
                if kernel.get_home_domain_index(iname) != domch.leaf_domain_index:
                    raise RuntimeError("sweep iname '%s' is not 'at home' in the "
                            "sweep's leaf domain" % iname)

        # }}}

        abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames,
                access_descriptors, len(storage_axis_names))

        non1_storage_axis_names = []
        for i, saxis in enumerate(storage_axis_names):
            if abm.non1_storage_axis_flags[i]:
                non1_storage_axis_names.append(saxis)
            else:
                del new_iname_to_tag[saxis]

                if saxis in preexisting_precompute_inames:
                    raise LoopyError("precompute axis %d (1-based) was "
                            "eliminated as "
                            "having length 1 but also mapped to existing "
                            "iname '%s'" % (i+1, saxis))

        mod_domain = domch.domain

        # {{{ modify the domain, taking into account preexisting inames

        # inames may already exist in mod_domain, add them primed to start
        primed_non1_saxis_names = [
                iname+"'" for iname in non1_storage_axis_names]

        mod_domain = abm.augment_domain_with_sweep(
            domch.domain, primed_non1_saxis_names,
            boxify_sweep=fetch_bounding_box)

        check_domain = mod_domain

        for i, saxis in enumerate(non1_storage_axis_names):
            var_dict = mod_domain.get_var_dict(isl.dim_type.set)

            if saxis in preexisting_precompute_inames:
                # add equality constraint between existing and new variable

                dt, dim_idx = var_dict[saxis]
                saxis_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)

                mod_domain = mod_domain.add_constraint(
                        isl.Constraint.equality_from_aff(
                            new_var_aff - saxis_aff))

                # project out the new one
                mod_domain = mod_domain.project_out(dt, dim_idx, 1)

            else:
                # remove the prime from the new variable
                dt, dim_idx = var_dict[primed_non1_saxis_names[i]]
                mod_domain = mod_domain.set_dim_name(dt, dim_idx, saxis)

        def add_assumptions(d):
            # Intersect *d* with the kernel's assumptions.
            assumption_non_param = isl.BasicSet.from_params(kernel.assumptions)
            assumptions, domain = isl.align_two(assumption_non_param, d)
            return assumptions & domain

        # {{{ check that we got the desired domain

        check_domain = add_assumptions(
            check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set]))

        mod_check_domain = add_assumptions(mod_domain)

        # re-add the prime from the new variable
        var_dict = mod_check_domain.get_var_dict(isl.dim_type.set)

        for saxis in non1_storage_axis_names:
            dt, dim_idx = var_dict[saxis]
            mod_check_domain = mod_check_domain.set_dim_name(
                    dt, dim_idx, saxis+"'")

        mod_check_domain = mod_check_domain.project_out_except(
                primed_non1_saxis_names, [isl.dim_type.set])

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError("domain of preexisting inames does not match "
                    "domain needed for precompute")

        # }}}

        # {{{ check that we didn't shrink the original domain

        # project out the new names from the modified domain
        orig_domain_inames = list(domch.domain.get_var_dict(isl.dim_type.set))
        mod_check_domain = add_assumptions(
                mod_domain.project_out_except(
                    orig_domain_inames, [isl.dim_type.set]))

        check_domain = add_assumptions(domch.domain)

        mod_check_domain, check_domain = isl.align_two(
                mod_check_domain, check_domain)

        # The modified domain can't get bigger by adding constraints
        assert mod_check_domain <= check_domain

        if not check_domain <= mod_check_domain:
            print(check_domain)
            print(mod_check_domain)
            raise LoopyError(
                    "original domain got shrunk by applying the precompute")

        # }}}

        # }}}

        new_kernel_domains = domch.get_domains_with(mod_domain)

    else:
        # leave kernel domains unchanged
        new_kernel_domains = kernel.domains

        non1_storage_axis_names = []
        abm = NoOpArrayToBufferMap()

    kernel = kernel.copy(domains=new_kernel_domains)

    # {{{ set up compute insn

    if temporary_name is None:
        temporary_name = var_name_gen(based_on=c_subst_name)

    assignee = var(temporary_name)

    if non1_storage_axis_names:
        assignee = assignee[
                tuple(var(iname) for iname in non1_storage_axis_names)]

    # {{{ process substitutions on compute instruction

    storage_axis_subst_dict = {}

    for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices):
        if arg_name in non1_storage_axis_names:
            arg = var(arg_name)
        else:
            # eliminated (length-1) axis: only the base index remains
            arg = 0

        storage_axis_subst_dict[
                prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi

    rule_mapping_context = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())

    from loopy.match import parse_stack_match
    expr_subst_map = RuleAwareSubstitutionMapper(
            rule_mapping_context,
            make_subst_func(storage_axis_subst_dict),
            within=parse_stack_match(None))

    compute_expression = expr_subst_map(subst.expression, kernel, None)

    # }}}

    from loopy.kernel.data import Assignment
    if compute_insn_id is None:
        compute_insn_id = kernel.make_unique_instruction_id(
                based_on=c_subst_name)

    compute_insn = Assignment(
            id=compute_insn_id,
            assignee=assignee,
            expression=compute_expression,
            # within_inames determined below
            )
    compute_dep_id = compute_insn_id
    added_compute_insns = [compute_insn]

    if temporary_address_space == AddressSpace.GLOBAL:
        # Users of the temporary depend on a global barrier placed after the
        # compute instruction rather than on the compute instruction itself.
        barrier_insn_id = kernel.make_unique_instruction_id(
                based_on=c_subst_name+"_barrier")
        from loopy.kernel.instruction import BarrierInstruction
        barrier_insn = BarrierInstruction(
                id=barrier_insn_id,
                depends_on=frozenset([compute_insn_id]),
                synchronization_kind="global",
                mem_kind="global")
        compute_dep_id = barrier_insn_id

        added_compute_insns.append(barrier_insn)

    # }}}

    # {{{ substitute rule into expressions in kernel (if within footprint)

    from loopy.symbolic import SubstitutionRuleExpander
    expander = SubstitutionRuleExpander(kernel.substitutions)

    invr = RuleInvocationReplacer(rule_mapping_context,
            subst_name, subst_tag, within,
            access_descriptors, abm,
            storage_axis_names, storage_axis_sources,
            non1_storage_axis_names,
            temporary_name, compute_insn_id, compute_dep_id,
            compute_read_variables=get_dependencies(
                expander(compute_expression)))

    kernel = invr.map_kernel(kernel)
    kernel = kernel.copy(
            instructions=added_compute_insns + kernel.instructions)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # {{{ add dependencies to compute insn

    kernel = kernel.copy(
            instructions=[
                insn.copy(depends_on=frozenset(invr.compute_insn_depends_on))
                if insn.id == compute_insn_id
                else insn
                for insn in kernel.instructions])

    # }}}

    # {{{ propagate storage iname subst to dependencies of compute instructions

    from loopy.kernel.tools import find_recursive_dependencies
    compute_deps = find_recursive_dependencies(
            kernel, frozenset([compute_insn_id]))

    # FIXME: Need to verify that there are no outside dependencies
    # on compute_deps

    prior_storage_axis_names = frozenset(storage_axis_subst_dict)

    new_insns = []
    for insn in kernel.instructions:
        if (insn.id in compute_deps
                and insn.within_inames & prior_storage_axis_names):
            insn = (insn
                    .with_transformed_expressions(
                        lambda expr: expr_subst_map(expr, kernel, insn))
                    .copy(within_inames=frozenset(
                        storage_axis_subst_dict.get(iname, var(iname)).name
                        for iname in insn.within_inames)))

            new_insns.append(insn)
        else:
            new_insns.append(insn)

    kernel = kernel.copy(instructions=new_insns)

    # }}}

    # {{{ determine inames for compute insn

    if precompute_outer_inames is None:
        from loopy.kernel.tools import guess_iname_deps_based_on_var_use
        precompute_outer_inames = (
                frozenset(non1_storage_axis_names)
                | frozenset(
                    (expanding_usage_arg_deps | value_inames)
                    - sweep_inames_set)
                | guess_iname_deps_based_on_var_use(kernel, compute_insn))
    else:
        if not isinstance(precompute_outer_inames, frozenset):
            raise TypeError("precompute_outer_inames must be a frozenset")

        precompute_outer_inames = precompute_outer_inames \
                | frozenset(non1_storage_axis_names)

    kernel = kernel.copy(
            instructions=[
                insn.copy(within_inames=precompute_outer_inames)
                if insn.id == compute_insn_id
                else insn
                for insn in kernel.instructions])

    # }}}

    # {{{ set up temp variable

    import loopy as lp
    if dtype is not None:
        dtype = np.dtype(dtype)

    if temporary_address_space is None:
        temporary_address_space = lp.auto

    new_temp_shape = tuple(abm.non1_storage_shape)

    new_temporary_variables = kernel.temporary_variables.copy()
    if temporary_name not in new_temporary_variables:
        temp_var = lp.TemporaryVariable(
                name=temporary_name,
                dtype=dtype,
                base_indices=(0,)*len(new_temp_shape),
                shape=tuple(abm.non1_storage_shape),
                address_space=temporary_address_space,
                dim_names=tuple(non1_storage_axis_names))

    else:
        temp_var = new_temporary_variables[temporary_name]

        # {{{ check and adapt existing temporary

        if temp_var.dtype is lp.auto:
            pass
        elif temp_var.dtype is not lp.auto and dtype is lp.auto:
            dtype = temp_var.dtype
        elif temp_var.dtype is not lp.auto and dtype is not lp.auto:
            if temp_var.dtype != dtype:
                raise LoopyError("Existing and new dtype of temporary '%s' "
                        "do not match (existing: %s, new: %s)"
                        % (temporary_name, temp_var.dtype, dtype))

        temp_var = temp_var.copy(dtype=dtype)

        if len(temp_var.shape) != len(new_temp_shape):
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching number of dimensions ('%d' vs. '%d') "
                    % (temporary_name,
                        len(temp_var.shape), len(new_temp_shape)))

        if temp_var.base_indices != (0,) * len(new_temp_shape):
            # The number of dimensions is known to match at this point, so
            # report the actual problem: the base indices.
            raise LoopyError("Existing temporary '%s' has base indices %s, "
                    "but precompute requires all base indices to be zero"
                    % (temporary_name, temp_var.base_indices))

        # Grow the existing temporary as needed, never shrink it.
        new_temp_shape = tuple(
                max(i, ex_i)
                for i, ex_i in zip(new_temp_shape, temp_var.shape))

        temp_var = temp_var.copy(shape=new_temp_shape)

        if temporary_address_space == temp_var.address_space:
            pass
        elif temporary_address_space is lp.auto:
            temporary_address_space = temp_var.address_space
        elif temp_var.address_space is lp.auto:
            pass
        else:
            raise LoopyError("Existing and new temporary '%s' do not "
                    "have matching scopes (existing: %s, new: %s)"
                    % (temporary_name,
                        AddressSpace.stringify(temp_var.address_space),
                        AddressSpace.stringify(temporary_address_space)))

        temp_var = temp_var.copy(address_space=temporary_address_space)

        # }}}

    new_temporary_variables[temporary_name] = temp_var

    kernel = kernel.copy(
            temporary_variables=new_temporary_variables)

    # }}}

    from loopy import tag_inames
    kernel = tag_inames(kernel, new_iname_to_tag)

    from loopy.kernel.data import AutoFitLocalIndexTag, filter_iname_tags_by_type

    if filter_iname_tags_by_type(new_iname_to_tag.values(), AutoFitLocalIndexTag):
        from loopy.kernel.tools import assign_automatic_axes
        kernel = assign_automatic_axes(kernel)

    return kernel
def infer_unknown_types(kernel, expect_completion=False):
    """Infer types on temporaries and arguments.

    Temporaries whose dtype is :data:`loopy.auto` and arguments whose dtype
    is *None* are collected, grouped into strongly connected components of
    their read/write dependency graph, and processed component by component.

    :arg kernel: the kernel to process. Substitution rules, if any, are
        expanded on a working copy; the returned kernel is derived from the
        kernel as passed in.
    :arg expect_completion: if *True*, a :class:`LoopyError` is raised when
        the type of a name cannot be determined. If *False*, inference for
        the affected component is silently abandoned.
    :returns: a copy of *kernel* with updated ``temporary_variables`` and
        ``args``.
    :raises LoopyError: if *expect_completion* is set and a type could not be
        inferred, or if an item of unexpected kind is encountered.
    """

    logger.debug("%s: infer types" % kernel.name)

    from functools import partial
    debug = partial(_debug, kernel)

    import time
    start_time = time.time()

    unexpanded_kernel = kernel
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    new_temp_vars = kernel.temporary_variables.copy()
    new_arg_dict = kernel.arg_dict.copy()

    # {{{ find names_with_unknown_types

    # contains both arguments and temporaries
    names_for_type_inference = []

    import loopy as lp
    for tv in six.itervalues(kernel.temporary_variables):
        if tv.dtype is lp.auto:
            names_for_type_inference.append(tv.name)

    for arg in kernel.args:
        if arg.dtype is None:
            names_for_type_inference.append(arg.name)

    # }}}

    logger.debug("finding types for {count:d} names".format(
            count=len(names_for_type_inference)))

    writer_map = kernel.writer_map()

    # Membership is tested once per read variable of every writer, so use a
    # set for that. The list is still what gets iterated, keeping the order
    # in which dep_graph is built unchanged.
    unknown_name_set = frozenset(names_for_type_inference)

    dep_graph = dict(
            (written_var, set(
                read_var
                for insn_id in writer_map.get(written_var, [])
                for read_var in kernel.id_to_insn[insn_id].read_dependency_names()
                if read_var in unknown_name_set))
            for written_var in names_for_type_inference)

    from loopy.tools import compute_sccs

    # To speed up processing, we sort the variables by computing the SCCs of the
    # type dependency graph. Each SCC represents a set of variables whose types
    # mutually depend on themselves. The SCCs are returned and processed in
    # topological order.
    sccs = compute_sccs(dep_graph)

    item_lookup = _DictUnionView([
            new_temp_vars,
            new_arg_dict
            ])

    type_inf_mapper = TypeInferenceMapper(kernel, item_lookup)

    from loopy.symbolic import SubstitutionRuleExpander
    subst_expander = SubstitutionRuleExpander(kernel.substitutions)

    # {{{ work on type inference queue

    from loopy.kernel.data import TemporaryVariable, KernelArgument

    for var_chain in sccs:
        changed_during_last_queue_run = False
        queue = var_chain[:]
        failed_names = set()

        while queue or changed_during_last_queue_run:
            if not queue and changed_during_last_queue_run:
                changed_during_last_queue_run = False
                # Optimization: If there's a single variable in the SCC without
                # a self-referential dependency, then the type is known after a
                # single iteration (we don't need to look at the expressions
                # again).
                if len(var_chain) == 1:
                    single_var, = var_chain
                    if single_var not in dep_graph[single_var]:
                        break
                queue = var_chain[:]

            name = queue.pop(0)
            item = item_lookup[name]

            debug("inferring type for %s %s", type(item).__name__, item.name)

            result, symbols_with_unavailable_types = (
                    _infer_var_type(
                        kernel, item.name, type_inf_mapper, subst_expander))

            failed = not result
            if not failed:
                new_dtype, = result
                debug(" success: %s", new_dtype)
                if new_dtype != item.dtype:
                    debug(" changed from: %s", item.dtype)
                    changed_during_last_queue_run = True

                    if isinstance(item, TemporaryVariable):
                        new_temp_vars[name] = item.copy(dtype=new_dtype)
                    elif isinstance(item, KernelArgument):
                        new_arg_dict[name] = item.copy(dtype=new_dtype)
                    else:
                        raise LoopyError(
                                "unexpected item type in type inference")
            else:
                debug(" failure")

            if failed:
                if item.name in failed_names:
                    # this item has failed before, give up.
                    advice = ""
                    if symbols_with_unavailable_types:
                        advice += (
                                " (need type of '%s'--check for missing arguments)"
                                % ", ".join(symbols_with_unavailable_types))

                    if expect_completion:
                        raise LoopyError(
                                "could not determine type of '%s'%s"
                                % (item.name, advice))

                    else:
                        # We're done here.
                        break

                # remember that this item failed
                failed_names.add(item.name)

                if set(queue) == failed_names:
                    # We did what we could...
                    # (Reported through the debug log rather than stdout.)
                    debug(" giving up: queue=%s failed_names=%s item=%s",
                            queue, failed_names, item.name)
                    assert not expect_completion
                    break

                # can't infer type yet, put back into queue
                queue.append(name)
            else:
                # we've made progress, reset failure markers
                failed_names = set()

    # }}}

    end_time = time.time()
    logger.debug("type inference took {dur:.2f} seconds".format(
            dur=end_time - start_time))

    return unexpanded_kernel.copy(
            temporary_variables=new_temp_vars,
            args=[new_arg_dict[arg.name] for arg in kernel.args],
            )
def guess_var_shape(kernel, var_name):
    """Guess the shape of *var_name* from the way it is accessed in *kernel*.

    Every expression of every instruction (with substitution rules expanded)
    is run through an :class:`loopy.symbolic.AccessRangeMapper` for
    *var_name*.

    :returns: a :class:`tuple`:

        * ``()`` if no access range and no problematic subscripts were found
          (the variable is treated as a scalar),
        * ``(None,)`` if the access range could not be determined but all
          problematic subscripts are one-dimensional,
        * otherwise one entry per axis of the access range, each being the
          static maximum of the axis' largest index plus one, converted to
          an expression.
    :raises LoopyError: if gathering the access range fails with a
        :class:`TypeError`, or if the access range cannot be determined
        because of linear subscripts or undetermined multi-dimensional
        indices.
    :raises RuntimeError: if the problematic subscripts disagree on their
        number of axes.
    """
    from loopy.symbolic import SubstitutionRuleExpander, AccessRangeMapper

    armap = AccessRangeMapper(kernel, var_name)

    submap = SubstitutionRuleExpander(kernel.substitutions)

    def run_through_armap(expr):
        # 'insn' is the loop variable of the loop below; this callback is
        # only ever invoked from inside that loop.
        armap(submap(expr), kernel.insn_inames(insn))
        return expr

    try:
        for insn in kernel.instructions:
            insn.with_transformed_expressions(run_through_armap)
    except TypeError as e:
        from traceback import print_exc
        print_exc()

        raise LoopyError(
                "Failed to (automatically, as requested) find "
                "shape/strides for variable '%s'. "
                "Specifying the shape manually should get rid of this. "
                "The following error occurred: %s"
                % (var_name, str(e)))

    if armap.access_range is None:
        if armap.bad_subscripts:
            from loopy.symbolic import LinearSubscript
            if any(isinstance(sub, LinearSubscript)
                    for sub in armap.bad_subscripts):
                raise LoopyError(
                        "cannot determine access range for '%s': "
                        "linear subscript(s) in '%s'"
                        % (var_name, ", ".join(
                            str(i) for i in armap.bad_subscripts)))

            n_axes_in_subscripts = set(
                    len(sub.index_tuple) for sub in armap.bad_subscripts)

            if len(n_axes_in_subscripts) != 1:
                raise RuntimeError("subscripts of '%s' with differing "
                        "numbers of axes were found" % var_name)

            n_axes, = n_axes_in_subscripts

            if n_axes == 1:
                # Leave shape undetermined--we can live with that for 1D.
                shape = (None,)
            else:
                raise LoopyError(
                        "cannot determine access range for '%s': "
                        "undetermined index in subscript(s) '%s'"
                        % (var_name, ", ".join(
                            str(i) for i in armap.bad_subscripts)))

        else:
            # no subscripts found, let's call it a scalar
            shape = ()
    else:
        from loopy.isl_helpers import static_max_of_pw_aff
        from loopy.symbolic import pw_aff_to_expr

        shape = []
        for i in range(armap.access_range.dim(dim_type.set)):
            try:
                shape.append(
                        pw_aff_to_expr(static_max_of_pw_aff(
                            kernel.cache_manager.dim_max(
                                armap.access_range, i) + 1,
                            constants_only=False)))
            except Exception:
                # Only genuine errors get the advice below; interrupts and
                # interpreter exit pass through untouched. The error itself
                # is always re-raised.
                print("While trying to find shape axis %d of "
                        "variable '%s', the following "
                        "exception occurred:" % (i, var_name),
                        file=sys.stderr)
                print("*** ADVICE: You may need to manually specify the "
                        "shape of argument '%s'." % (var_name),
                        file=sys.stderr)
                raise

        shape = tuple(shape)

    return shape
def infer_unknown_types(kernel, expect_completion=False):
    """Infer types on temporaries and arguments.

    Temporaries whose dtype is :data:`loopy.auto` and arguments whose dtype
    is *None* are placed in a work queue. Items whose type cannot be
    determined yet are re-queued until either every item is resolved or no
    further progress is possible.

    :arg expect_completion: if *True*, failure to determine a type raises
        :class:`LoopyError`; otherwise inference simply stops.
    :returns: a copy of the kernel as passed in, with updated
        ``temporary_variables`` and ``args``.
    """

    logger.debug("%s: infer types" % kernel.name)

    def debug(s):
        logger.debug("%s: %s" % (kernel.name, s))

    # Work on an expanded copy, but hand back a kernel derived from the
    # original one.
    unexpanded_kernel = kernel
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    new_temp_vars = kernel.temporary_variables.copy()
    new_arg_dict = kernel.arg_dict.copy()

    # {{{ fill queue

    import loopy as lp

    # temporaries with unknown type come first, then arguments
    queue = [
            tv for tv in six.itervalues(kernel.temporary_variables)
            if tv.dtype is lp.auto]
    queue.extend(arg for arg in kernel.args if arg.dtype is None)

    # }}}

    from loopy.expression import TypeInferenceMapper
    type_inf_mapper = TypeInferenceMapper(
            kernel, _DictUnionView([new_temp_vars, new_arg_dict]))

    from loopy.symbolic import SubstitutionRuleExpander
    subst_expander = SubstitutionRuleExpander(kernel.substitutions)

    # {{{ work on type inference queue

    from loopy.kernel.data import TemporaryVariable, KernelArgument

    # names that failed since the last successful inference
    stalled_names = set()

    while queue:
        item = queue.pop(0)

        debug("inferring type for %s %s" % (type(item).__name__, item.name))

        result, symbols_with_unavailable_types = _infer_var_type(
                kernel, item.name, type_inf_mapper, subst_expander)

        if result is not None:
            debug(" success: %s" % result)

            if isinstance(item, TemporaryVariable):
                new_temp_vars[item.name] = item.copy(dtype=result)
            elif isinstance(item, KernelArgument):
                new_arg_dict[item.name] = item.copy(dtype=result)
            else:
                raise LoopyError("unexpected item type in type inference")

            # progress was made, so earlier failures get another chance
            stalled_names = set()
            continue

        debug(" failure")

        if item.name in stalled_names:
            # second failure without intervening progress: give up
            advice = ""
            if symbols_with_unavailable_types:
                advice += (
                        " (need type of '%s'--check for missing arguments)"
                        % ", ".join(symbols_with_unavailable_types))

            if expect_completion:
                raise LoopyError(
                        "could not determine type of '%s'%s"
                        % (item.name, advice))

            # nothing more to be done
            break

        stalled_names.add(item.name)

        remaining_names = set(qi.name for qi in queue)
        if remaining_names == stalled_names:
            # everything still queued has already failed
            print(remaining_names, stalled_names, item.name)
            assert not expect_completion
            break

        # try again once the other items have been looked at
        queue.append(item)

    # }}}

    return unexpanded_kernel.copy(
            temporary_variables=new_temp_vars,
            args=[new_arg_dict[arg.name] for arg in kernel.args],
            )
def infer_unknown_types_for_a_single_kernel(kernel, clbl_inf_ctx):
    """Infer types on temporaries and arguments.

    Variables whose dtype is *None* are grouped into strongly connected
    components of their type dependency graph and processed component by
    component. Afterwards, instructions all of whose assignees already had a
    known type are run through the type inference mapper once so that the
    function calls in their expressions are visited as well.

    :arg kernel: the kernel to process. Substitution rules are expanded on a
        working copy; the returned kernel is derived from the unexpanded
        original.
    :arg clbl_inf_ctx: the callables inference context handed to
        ``TypeInferenceMapper``; it is replaced by the updated context
        obtained during inference.
    :returns: a tuple ``(type_specialized_kernel, clbl_inf_ctx)``.
    :raises LoopyError: if an item subject to inference is neither a
        temporary nor an argument.
    :raises NotImplementedError: if the kernel contains an instruction of an
        unknown type.
    """

    logger.debug("%s: infer types", kernel.name)

    from functools import partial
    debug = partial(_debug, kernel)

    import time
    start_time = time.time()

    unexpanded_kernel = kernel
    if kernel.substitutions:
        from loopy.transform.subst import expand_subst
        kernel = expand_subst(kernel)

    new_temp_vars = kernel.temporary_variables.copy()
    new_arg_dict = kernel.arg_dict.copy()

    # {{{ find names_with_unknown_types

    # contains both arguments and temporaries
    names_for_type_inference = []

    import loopy as lp
    for tv in kernel.temporary_variables.values():
        assert tv.dtype is not lp.auto
        if tv.dtype is None:
            names_for_type_inference.append(tv.name)

    for arg in kernel.args:
        assert arg.dtype is not lp.auto
        if arg.dtype is None:
            names_for_type_inference.append(arg.name)

    # }}}

    logger.debug("finding types for %d names", len(names_for_type_inference))

    writer_map = kernel.writer_map()

    # The list keeps the iteration order; the set gives O(1) membership tests
    # inside the nested comprehension below.
    names_with_unknown_types = frozenset(names_for_type_inference)

    dep_graph = {
            written_var: {
                read_var
                for insn_id in writer_map.get(written_var, [])
                for read_var in kernel.id_to_insn[insn_id].read_dependency_names()
                if read_var in names_with_unknown_types}
            for written_var in names_for_type_inference}

    from pytools.graph import compute_sccs

    # To speed up processing, we sort the variables by computing the SCCs of the
    # type dependency graph. Each SCC represents a set of variables whose types
    # mutually depend on themselves. The SCCs are returned and processed in
    # topological order.
    sccs = compute_sccs(dep_graph)

    item_lookup = _DictUnionView([new_temp_vars, new_arg_dict])
    type_inf_mapper = TypeInferenceMapper(kernel, clbl_inf_ctx, item_lookup)

    from loopy.symbolic import SubstitutionRuleExpander
    subst_expander = SubstitutionRuleExpander(kernel.substitutions)

    # {{{ work on type inference queue

    from loopy.kernel.data import TemporaryVariable, KernelArgument

    old_calls_to_new_calls = {}

    for var_chain in sccs:
        changed_during_last_queue_run = False
        var_queue = var_chain[:]
        failed_names = set()

        while var_queue or changed_during_last_queue_run:
            if not var_queue and changed_during_last_queue_run:
                changed_during_last_queue_run = False
                # Optimization: If there's a single variable in the SCC without
                # a self-referential dependency, then the type is known after a
                # single iteration (we don't need to look at the expressions
                # again).
                if len(var_chain) == 1:
                    single_var, = var_chain
                    if single_var not in dep_graph[single_var]:
                        break
                var_queue = var_chain[:]

            name = var_queue.pop(0)
            item = item_lookup[name]

            debug("inferring type for %s %s", type(item).__name__, item.name)
            try:
                (result, symbols_with_unknown_types,
                        new_old_calls_to_new_calls, clbl_inf_ctx) = (
                        _infer_var_type(
                            kernel, item.name, type_inf_mapper, subst_expander))
            except DependencyTypeInferenceFailure:
                # clbl_inf_ctx keeps its previous value in this case.
                result = ()
                symbols_with_unknown_types = ()

            type_inf_mapper = type_inf_mapper.copy(clbl_inf_ctx=clbl_inf_ctx)

            if result:
                new_dtype, = result
                debug(" success: %s", new_dtype)
                if new_dtype != item.dtype:
                    debug(" changed from: %s", item.dtype)
                    changed_during_last_queue_run = True

                    if isinstance(item, TemporaryVariable):
                        new_temp_vars[name] = item.copy(dtype=new_dtype)
                    elif isinstance(item, KernelArgument):
                        new_arg_dict[name] = item.copy(dtype=new_dtype)
                    else:
                        raise LoopyError(
                                "unexpected item type in type inference")

                old_calls_to_new_calls.update(new_old_calls_to_new_calls)
                # we've made progress, reset failure markers
                failed_names = set()
            else:
                debug(" failure")

                if item.name in failed_names:
                    # this item has failed before, give up.
                    advice = ""
                    if symbols_with_unknown_types:
                        advice += (
                                " (need type of '%s'--check for missing arguments)"
                                % ", ".join(symbols_with_unknown_types))

                    debug("could not determine type of '%s'%s",
                            item.name, advice)
                    # We're done here
                    break

                # remember that this item failed
                failed_names.add(item.name)

                if set(var_queue) == failed_names:
                    # We did what we could...
                    # (Reported via the debug log instead of stdout.)
                    debug("giving up: queue=%s failed=%s last=%s",
                            var_queue, failed_names, item.name)
                    break

                # can't infer type yet, put back into var_queue
                var_queue.append(name)

    # }}}

    # {{{ check if insn missed during type inference

    def _dtype_is_unknown(name):
        # Look the name up among the arguments first, then the temporaries, of
        # the (expanded) input kernel, i.e. before any inferred types.
        if name in kernel.arg_dict:
            return kernel.arg_dict[name].dtype is None

        assert name in kernel.temporary_variables
        return kernel.temporary_variables[name].dtype is None

    def _instruction_missed_during_inference(insn):
        # Returns False as soon as one assignee had an unknown dtype in the
        # input kernel, True if all assignee dtypes were already known.
        for assignee in insn.assignees:
            if isinstance(assignee, Lookup):
                assignee = assignee.aggregate

            if isinstance(assignee, Variable):
                assignee_name = assignee.name
            elif isinstance(assignee, (Subscript, LinearSubscript)):
                assignee_name = assignee.aggregate.name
            else:
                assert isinstance(assignee, SubArrayRef)
                assignee_name = assignee.subscript.aggregate.name

            # The dtype (not the variable object itself) is what gets compared
            # against None, uniformly for all kinds of assignees.
            if _dtype_is_unknown(assignee_name):
                return False

        return True

    # }}}

    for insn in kernel.instructions:
        if isinstance(insn, lp.MultiAssignmentBase):
            # just a dummy run over the expression, to pass over all the
            # functions
            if _instruction_missed_during_inference(insn):
                type_inf_mapper(insn.expression,
                        return_tuple=len(insn.assignees) != 1,
                        return_dtype_set=True)
        elif isinstance(insn, (_DataObliviousInstruction, lp.CInstruction)):
            pass
        else:
            raise NotImplementedError("Unknown instructions type %s." % (
                type(insn).__name__))

    clbl_inf_ctx = type_inf_mapper.clbl_inf_ctx
    old_calls_to_new_calls.update(type_inf_mapper.old_calls_to_new_calls)

    end_time = time.time()
    logger.debug("type inference took %.2f seconds", end_time - start_time)

    pre_type_specialized_knl = unexpanded_kernel.copy(
            temporary_variables=new_temp_vars,
            args=[new_arg_dict[arg.name] for arg in kernel.args],
            )

    type_specialized_kernel = change_names_of_pymbolic_calls(
            pre_type_specialized_knl, old_calls_to_new_calls)

    return type_specialized_kernel, clbl_inf_ctx
def change_names_of_pymbolic_calls(kernel, pymbolic_calls_to_new_names):
    """
    Return a copy of *kernel* in which the pymbolic calls are renamed as
    prescribed by the mapping *pymbolic_calls_to_new_names*.

    :arg pymbolic_calls_to_new_names: A mapping from instances of
        :class:`pymbolic.primitives.Call` to :class:`str`.

    **Example:**

    - Given a *kernel* --

    .. code::

        -------------------------------------------------------------
        KERNEL: loopy_kernel
        -------------------------------------------------------------
        ARGUMENTS:
        x: type: <auto/runtime>, shape: (10), dim_tags: (N0:stride:1)
        y: type: <auto/runtime>, shape: (10), dim_tags: (N0:stride:1)
        -------------------------------------------------------------
        DOMAINS:
        { [i] : 0 <= i <= 9 }
        -------------------------------------------------------------
        INAME IMPLEMENTATION TAGS:
        i: None
        -------------------------------------------------------------
        INSTRUCTIONS:
        for i
            y[i] = ResolvedFunction('sin')(x[i])
        end i
        -------------------------------------------------------------

    - And given a *pymbolic_calls_to_new_names* --

    .. code::

        {Call(ResolvedFunction(Variable('sin')),
              (Subscript(Variable('x'), Variable('i')),)): 'sin_1'}

    - The following *kernel* is returned --

    .. code::

        -------------------------------------------------------------
        KERNEL: loopy_kernel
        -------------------------------------------------------------
        ARGUMENTS:
        x: type: <auto/runtime>, shape: (10), dim_tags: (N0:stride:1)
        y: type: <auto/runtime>, shape: (10), dim_tags: (N0:stride:1)
        -------------------------------------------------------------
        DOMAINS:
        { [i] : 0 <= i <= 9 }
        -------------------------------------------------------------
        INAME IMPLEMENTATION TAGS:
        i: None
        -------------------------------------------------------------
        INSTRUCTIONS:
        for i
            y[i] = ResolvedFunction('sin_1')(x[i])
        end i
        -------------------------------------------------------------
    """
    mapping_ctx = SubstitutionRuleMappingContext(
            kernel.substitutions, kernel.get_var_name_generator())

    # The renamer is given an expander over the kernel's substitution rules
    # in addition to the mapping context and the renaming table.
    renamer = FunctionNameChanger(
            mapping_ctx,
            pymbolic_calls_to_new_names,
            SubstitutionRuleExpander(kernel.substitutions))

    renamed_kernel = renamer.map_kernel(kernel)
    return mapping_ctx.finish_kernel(renamed_kernel)