def augment_domain_for_temporary_promotion(
        kernel, domain, promoted_temporary, mode, name_gen):
    """
    Extend `domain` with the inames used to copy `promoted_temporary`.

    One new set dimension is added for every axis of
    ``promoted_temporary.orig_temporary`` and constrained to
    ``0 <= iname < shape[axis]``. The new dimensions are named at positions
    following the dimensions that `domain` already had. Afterwards, every
    iname in ``promoted_temporary.hw_inames`` is duplicated under a freshly
    generated name that receives the same tag as the iname it was copied
    from.

    :arg kernel: object whose ``iname_to_tag`` mapping is consulted for the
        tags of the hardware inames being duplicated.
    :arg domain: the set to extend (manipulated through the islpy API).
    :arg promoted_temporary: supplies ``orig_temporary`` (with ``name``,
        ``shape`` and ``is_local``) and ``hw_inames``.
    :arg mode: string that is embedded in the generated iname names.
    :arg name_gen: callable that receives a proposed name and returns the
        name to actually use.

    :returns: a tuple ``(domain, hw_inames, dim_inames, iname_to_tag)``:
        the extended domain as a single basic set, the names of the
        duplicated hardware inames, the names of the per-axis inames, and a
        mapping from new iname names to their tags.
    """
    import islpy as isl
    from loopy.isl_helpers import duplicate_axes
    from loopy.symbolic import aff_from_expr

    temporary = promoted_temporary.orig_temporary
    set_dim = isl.dim_type.set
    first_new_pos = domain.dim(set_dim)

    tags = {}
    axis_inames = []

    # One new dimension per axis of the temporary.
    domain = domain.add(set_dim, len(temporary.shape))

    for axis, extent in enumerate(temporary.shape):
        iname = name_gen(
            "{name}_{mode}_dim_{dim}".format(
                name=temporary.name, mode=mode, dim=axis))
        domain = domain.set_dim_name(set_dim, first_new_pos + axis, iname)

        if temporary.is_local:
            # If the temporary has local scope, then loads / stores can be
            # done in parallel.
            from loopy.kernel.data import AutoFitLocalIndexTag
            tags[iname] = AutoFitLocalIndexTag()

        axis_inames.append(iname)

        # Bound the new iname: 0 <= iname < extent.
        affs = isl.affs_from_space(domain.space)
        domain &= affs[0].le_set(affs[iname])
        domain &= affs[iname].lt_set(aff_from_expr(domain.space, extent))

    # Duplicate the hardware inames, carrying over their tags.
    hw_copies = []
    for position, hw_iname in enumerate(promoted_temporary.hw_inames):
        copy_name = name_gen(
            "{name}_{mode}_hw_dim_{dim}".format(
                name=temporary.name, mode=mode, dim=position))
        hw_copies.append(copy_name)
        tags[copy_name] = kernel.iname_to_tag[hw_iname]

    domain = duplicate_axes(domain, promoted_temporary.hw_inames, hw_copies)

    # The operations on the domain above return a Set object, but the
    # underlying domain should be expressible as a single BasicSet.
    pieces = domain.get_basic_set_list()
    assert pieces.n_basic_set() == 1

    return pieces.get_basic_set(0), hw_copies, axis_inames, tags
def augment_domain_for_save_or_reload(
        self, domain, promoted_temporary, mode, subkernel):
    """
    Extend `domain` with the axes used by the save / reload stage for
    `promoted_temporary`.

    New set dimensions are added for every entry of
    ``promoted_temporary.non_hw_dims`` followed by every entry of
    ``promoted_temporary.hw_dims``. They are named at positions following
    the dimensions `domain` already had, and each new iname is constrained
    to ``0 <= iname < size``.

    Tags for the non-hardware inames (assigned only when the original
    temporary ``is_local``) are returned in ``iname_to_tag``. Tags for the
    hardware inames are instead written into ``self.updated_iname_to_tag``.

    :arg domain: the set to extend (manipulated through the islpy API).
    :arg promoted_temporary: supplies ``orig_temporary_name``,
        ``non_hw_dims``, ``hw_dims`` and ``hw_tags``.
    :arg mode: either ``"save"`` or ``"reload"``; embedded in the generated
        iname names.
    :arg subkernel: value embedded in the generated iname names.

    :returns: a tuple ``(domain, hw_inames, dim_inames, iname_to_tag)``:
        the extended domain as a single basic set, the names of the new
        hardware inames, the names of the new non-hardware inames, and the
        tags assigned to the latter.
    """
    assert mode in ("save", "reload")

    import islpy as isl
    from loopy.symbolic import aff_from_expr

    temporary = self.kernel.temporary_variables[
        promoted_temporary.orig_temporary_name]
    set_dim = isl.dim_type.set
    n_existing = domain.dim(set_dim)
    n_non_hw = len(promoted_temporary.non_hw_dims)

    # Tags for the newly added non-hardware inames.
    tags = {}
    axis_inames = []

    # FIXME: Restrict size of new inames to access footprint.

    domain = domain.add(set_dim, n_non_hw + len(promoted_temporary.hw_dims))

    # Non-hardware axes come first among the new dimensions.
    for axis, extent in enumerate(promoted_temporary.non_hw_dims):
        iname = self.insn_name_gen(
            "{name}_{mode}_axis_{dim}_{sk}".format(
                name=temporary.name, mode=mode, dim=axis, sk=subkernel))
        domain = domain.set_dim_name(set_dim, n_existing + axis, iname)

        if temporary.is_local:
            # If the temporary has local scope, then loads / stores can
            # be done in parallel.
            from loopy.kernel.data import AutoFitLocalIndexTag
            tags[iname] = AutoFitLocalIndexTag()

        axis_inames.append(iname)

        # Bound the new iname: 0 <= iname < extent.
        affs = isl.affs_from_space(domain.space)
        domain &= affs[0].le_set(affs[iname])
        domain &= affs[iname].lt_set(aff_from_expr(domain.space, extent))

    # Hardware axes follow the non-hardware ones.
    hw_start = n_existing + n_non_hw
    hw_axis_inames = []

    hw_axes = zip(promoted_temporary.hw_tags, promoted_temporary.hw_dims)
    for position, (tag, extent) in enumerate(hw_axes):
        iname = self.insn_name_gen(
            "{name}_{mode}_hw_dim_{dim}_{sk}".format(
                name=temporary.name, mode=mode, dim=position, sk=subkernel))
        domain = domain.set_dim_name(set_dim, hw_start + position, iname)

        # Bound the new iname: 0 <= iname < extent.
        affs = isl.affs_from_space(domain.space)
        lower_bound = affs[0].le_set(affs[iname])
        upper_bound = affs[iname].lt_set(aff_from_expr(domain.space, extent))
        domain = domain & lower_bound & upper_bound

        # Hardware tags are recorded on the instance, not in `tags`.
        self.updated_iname_to_tag[iname] = tag
        hw_axis_inames.append(iname)

    # The operations on the domain above return a Set object, but the
    # underlying domain should be expressible as a single BasicSet.
    pieces = domain.get_basic_set_list()
    assert pieces.n_basic_set() == 1

    return pieces.get_basic_set(0), hw_axis_inames, axis_inames, tags