def preprocess_kernel(kernel, device=None):
    if device is not None:
        from warnings import warn
        warn("passing 'device' to preprocess_kernel() is deprecated",
                DeprecationWarning, stacklevel=2)

    from loopy.kernel import kernel_state
    if kernel.state != kernel_state.INITIAL:
        raise LoopyError("cannot re-preprocess an already preprocessed "
                "kernel")

    # {{{ cache retrieval

    from loopy import CACHING_ENABLED
    if CACHING_ENABLED:
        input_kernel = kernel

        try:
            result = preprocess_cache[kernel]
            logger.info("%s: preprocess cache hit" % kernel.name)
            return result
        except KeyError:
            pass

    # }}}

    logger.info("%s: preprocess start" % kernel.name)

    from loopy.subst import expand_subst
    kernel = expand_subst(kernel)

    # Ordering restriction:
    # Type inference doesn't handle substitutions. Get them out of the
    # way.

    kernel = infer_unknown_types(kernel, expect_completion=False)

    kernel = add_default_dependencies(kernel)

    # Ordering restrictions:
    #
    # - realize_reduction must happen after type inference because it needs
    #   to be able to determine the types of the reduced expressions.
    #
    # - realize_reduction must happen after default dependencies are added
    #   because it manipulates the insn_deps field, which could prevent
    #   defaults from being applied.

    kernel = realize_reduction(kernel)

    # Ordering restriction:
    # realize_reduction must happen before
    # duplicate_private_temporaries_for_ilp_and_vec because reduction
    # accumulators need to be duplicated by it.

    kernel = duplicate_private_temporaries_for_ilp_and_vec(kernel)
    kernel = mark_local_temporaries(kernel)
    kernel = assign_automatic_axes(kernel)
    kernel = find_boostability(kernel)
    kernel = limit_boostability(kernel)

    kernel = kernel.target.preprocess(kernel)

    logger.info("%s: preprocess done" % kernel.name)

    kernel = kernel.copy(
            state=kernel_state.PREPROCESSED)

    # {{{ prepare for caching

    # PicklableDtype instances for example need to know the target they're
    # working towards in order to pickle and unpickle them. This is the first
    # pass that uses caching, so we need to be ready to pickle. This means
    # propagating this target information.

    if CACHING_ENABLED:
        input_kernel = prepare_for_caching(input_kernel)

    kernel = prepare_for_caching(kernel)

    # }}}

    if CACHING_ENABLED:
        preprocess_cache[input_kernel] = kernel

    return kernel
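
# The function above is a linear pipeline of kernel-to-kernel passes whose
# order matters, wrapped in a cache keyed on the *unprocessed* input kernel so
# that preprocessing the same kernel twice is free. The following is a
# minimal, self-contained sketch of that shape only; ``_PASS_CACHE`` and
# ``_run_pipeline`` are illustrative names and not part of loopy. It assumes
# the objects being transformed are hashable, as kernels are here.

_PASS_CACHE = {}


def _run_pipeline(obj, passes, use_cache=True):
    """Apply *passes* (callables obj -> obj) in order, memoizing on *obj*."""
    if use_cache:
        # try the cache first, keyed on the untransformed input
        try:
            return _PASS_CACHE[obj]
        except KeyError:
            pass

    input_obj = obj
    for p in passes:
        obj = p(obj)

    if use_cache:
        # store under the original input so an identical input hits next time
        _PASS_CACHE[input_obj] = obj

    return obj

# Usage sketch: _run_pipeline(kernel, [expand_subst, infer_unknown_types, ...])
# mirrors the hand-written sequence above, but the explicit sequence is kept
# in preprocess_kernel because several passes take extra arguments and the
# ordering constraints deserve inline comments.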
def infer_unknown_types(kernel, expect_completion=False):
    """Infer types on temporaries and arguments."""

    logger.debug("%s: infer types" % kernel.name)

    def debug(s):
        logger.debug("%s: %s" % (kernel.name, s))

    unexpanded_kernel = kernel
    if kernel.substitutions:
        from loopy.subst import expand_subst
        kernel = expand_subst(kernel)

    new_temp_vars = kernel.temporary_variables.copy()
    new_arg_dict = kernel.arg_dict.copy()

    # {{{ fill queue

    # queue contains temporary variables
    queue = []

    import loopy as lp
    for tv in six.itervalues(kernel.temporary_variables):
        if tv.dtype is lp.auto:
            queue.append(tv)

    for arg in kernel.args:
        if arg.dtype is None:
            queue.append(arg)

    # }}}

    from loopy.expression import TypeInferenceMapper
    type_inf_mapper = TypeInferenceMapper(kernel,
            _DictUnionView([
                new_temp_vars,
                new_arg_dict
                ]))

    from loopy.symbolic import SubstitutionRuleExpander
    subst_expander = SubstitutionRuleExpander(kernel.substitutions)

    # {{{ work on type inference queue

    from loopy.kernel.data import TemporaryVariable, KernelArgument

    failed_names = set()
    while queue:
        item = queue.pop(0)

        debug("inferring type for %s %s" % (type(item).__name__, item.name))

        result, symbols_with_unavailable_types = \
                _infer_var_type(kernel, item.name, type_inf_mapper, subst_expander)

        failed = result is None
        if not failed:
            debug("    success: %s" % result)
            if isinstance(item, TemporaryVariable):
                new_temp_vars[item.name] = item.copy(dtype=result)
            elif isinstance(item, KernelArgument):
                new_arg_dict[item.name] = item.copy(dtype=result)
            else:
                raise LoopyError("unexpected item type in type inference")
        else:
            debug("    failure")

        if failed:
            if item.name in failed_names:
                # this item has failed before, give up.
                advice = ""
                if symbols_with_unavailable_types:
                    advice += (
                            " (need type of '%s'--check for missing arguments)"
                            % ", ".join(symbols_with_unavailable_types))

                if expect_completion:
                    raise LoopyError(
                            "could not determine type of '%s'%s"
                            % (item.name, advice))
                else:
                    # We're done here.
                    break

            # remember that this item failed
            failed_names.add(item.name)

            queue_names = set(qi.name for qi in queue)

            if queue_names == failed_names:
                # We did what we could...
                print(queue_names, failed_names, item.name)
                assert not expect_completion
                break

            # can't infer type yet, put back into queue
            queue.append(item)
        else:
            # we've made progress, reset failure markers
            failed_names = set()

    # }}}

    return unexpanded_kernel.copy(
            temporary_variables=new_temp_vars,
            args=[new_arg_dict[arg.name] for arg in kernel.args],
            )
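
# The loop above is a standard work-list scheme: pop an item, try to infer its
# type, and on failure put it back, giving up only once every remaining item
# has failed since the last successful inference. The following is a minimal,
# self-contained sketch of that control flow under simplifying assumptions:
# ``try_infer`` is an illustrative callback (name -> result or None), not
# loopy API, and items are plain hashable names rather than kernel objects.

def _worklist_infer(items, try_infer, expect_completion=False):
    """Resolve *items* using *try_infer*, retrying until a fixed point."""
    resolved = {}
    queue = list(items)
    failed_names = set()

    while queue:
        item = queue.pop(0)
        result = try_infer(item, resolved)

        if result is not None:
            resolved[item] = result
            # progress was made: earlier failures may now succeed, so retry them
            failed_names = set()
            continue

        if item in failed_names:
            # this item already failed once since the last success: give up
            if expect_completion:
                raise ValueError("could not determine type of '%s'" % item)
            break

        failed_names.add(item)

        if set(queue) == failed_names:
            # everything still queued has failed since the last success
            break

        # can't resolve yet, put back into the queue
        queue.append(item)

    return resolved

# The design choice mirrored here is that one extra pass over the queue is
# cheap compared with raising too eagerly: an item is only declared
# unresolvable after it has failed twice with no intervening progress.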
def find_all_insn_inames(kernel):
    logger.debug("%s: find_all_insn_inames: start" % kernel.name)

    writer_map = kernel.writer_map()

    insn_id_to_inames = {}
    insn_assignee_inames = {}

    all_read_deps = {}
    all_write_deps = {}

    from loopy.subst import expand_subst
    kernel = expand_subst(kernel)

    for insn in kernel.instructions:
        all_read_deps[insn.id] = read_deps = insn.read_dependency_names()
        all_write_deps[insn.id] = write_deps = insn.write_dependency_names()
        deps = read_deps | write_deps

        if insn.forced_iname_deps_is_final:
            iname_deps = insn.forced_iname_deps
        else:
            iname_deps = (
                    deps & kernel.all_inames()
                    | insn.forced_iname_deps)

        assert isinstance(read_deps, frozenset), type(insn)
        assert isinstance(write_deps, frozenset), type(insn)
        assert isinstance(iname_deps, frozenset), type(insn)

        logger.debug("%s: find_all_insn_inames: %s (init): %s - "
                "read deps: %s - write deps: %s" % (
                    kernel.name, insn.id, ", ".join(sorted(iname_deps)),
                    ", ".join(sorted(read_deps)),
                    ", ".join(sorted(write_deps)),
                    ))

        insn_id_to_inames[insn.id] = iname_deps
        insn_assignee_inames[insn.id] = write_deps & kernel.all_inames()

    written_vars = kernel.get_written_variables()

    # fixed point iteration until all iname dep sets have converged

    # Why is fixed point iteration necessary here? Consider the following
    # scenario:
    #
    # z = expr(iname)
    # y = expr(z)
    # x = expr(y)
    #
    # x clearly has a dependency on iname, but this is not found until that
    # dependency has propagated all the way up. Doing this recursively is
    # not guaranteed to terminate because of circular dependencies.

    while True:
        did_something = False
        for insn in kernel.instructions:

            if insn.forced_iname_deps_is_final:
                continue

            # {{{ dependency-based propagation

            # For all variables that insn depends on, find the intersection
            # of iname deps of all writers, and add those to insn's
            # dependencies.

            for tv_name in (all_read_deps[insn.id] & written_vars):
                implicit_inames = None
                for writer_id in writer_map[tv_name]:
                    writer_implicit_inames = (
                            insn_id_to_inames[writer_id]
                            - insn_assignee_inames[writer_id])
                    if implicit_inames is None:
                        implicit_inames = writer_implicit_inames
                    else:
                        implicit_inames = (implicit_inames
                                & writer_implicit_inames)

                inames_old = insn_id_to_inames[insn.id]
                inames_new = (inames_old | implicit_inames) \
                        - insn.reduction_inames()

                insn_id_to_inames[insn.id] = inames_new

                if inames_new != inames_old:
                    did_something = True

                    logger.debug("%s: find_all_insn_inames: %s -> %s (dep-based)"
                            % (kernel.name, insn.id,
                                ", ".join(sorted(inames_new))))

            # }}}

            # {{{ domain-based propagation

            inames_old = insn_id_to_inames[insn.id]
            inames_new = set(insn_id_to_inames[insn.id])

            for iname in inames_old:
                home_domain = kernel.domains[kernel.get_home_domain_index(iname)]

                for par in home_domain.get_var_names(dim_type.param):
                    # Add all inames occurring in parameters of domains that my
                    # current inames refer to.
                    if par in kernel.all_inames():
                        inames_new.add(par)

                    # If something writes the bounds of a loop in which I'm
                    # sitting, I had better be in the inames that the writer
                    # is in.
                    if par in kernel.temporary_variables:
                        for writer_id in writer_map.get(par, []):
                            inames_new.update(insn_id_to_inames[writer_id])

            if inames_new != inames_old:
                did_something = True
                insn_id_to_inames[insn.id] = frozenset(inames_new)

                logger.debug("%s: find_all_insn_inames: %s -> %s (domain-based)"
                        % (kernel.name, insn.id,
                            ", ".join(sorted(inames_new))))

            # }}}

        if not did_something:
            break

    logger.debug("%s: find_all_insn_inames: done" % kernel.name)

    for v in six.itervalues(insn_id_to_inames):
        assert isinstance(v, frozenset)

    return insn_id_to_inames
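
# The comment in the loop above explains why a fixed point is needed: in a
# chain like z = expr(iname); y = expr(z); x = expr(y), the iname only reaches
# x after several sweeps, and a naive recursion could cycle on circular
# dependencies. The following is a minimal, self-contained sketch of the same
# propagate-until-converged idea over toy dictionaries; it is deliberately
# simplified (it unions in all writer inames instead of intersecting across
# writers and subtracting assignee/reduction inames as the code above does),
# and ``_propagate_to_fixed_point`` is an illustrative name, not loopy API.

def _propagate_to_fixed_point(insn_reads, insn_writes, initial_inames):
    """Union writer inames into readers until no iname set changes.

    insn_reads: map insn id -> set of variables read
    insn_writes: map insn id -> set of variables written
    initial_inames: map insn id -> set of inames directly referenced
    """
    inames = {insn: set(s) for insn, s in initial_inames.items()}

    # map variable -> ids of instructions that write it
    writer_map = {}
    for insn, written in insn_writes.items():
        for var in written:
            writer_map.setdefault(var, set()).add(insn)

    while True:
        did_something = False
        for insn, read_vars in insn_reads.items():
            for var in read_vars:
                for writer in writer_map.get(var, ()):
                    new = inames[writer] - inames[insn]
                    if new:
                        # reader inherits the writer's inames
                        inames[insn] |= new
                        did_something = True
        if not did_something:
            # sets only ever grow and are bounded, so this terminates
            break

    return inames

# For the z -> y -> x example from the comment above:
#   _propagate_to_fixed_point(
#       insn_reads={"z": set(), "y": {"z"}, "x": {"y"}},
#       insn_writes={"z": {"z"}, "y": {"y"}, "x": {"x"}},
#       initial_inames={"z": {"i"}, "y": set(), "x": set()})
# returns {"z": {"i"}, "y": {"i"}, "x": {"i"}}.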