def get_mapped_subsets_dicts(self, interval: Interval, section: dace.SDFG):
    """Return the subset strings through which ``section`` reads and writes its fields.

    The method works in two phases.

    1. Every HorizontalExecution / VerticalLoop library node yielded by
       ``section.all_nodes_recursive()`` is scanned. For each field it
       accesses, the section's I/J origin is the component-wise maximum of
       ``-offset - <start offset of the node's I/J iteration interval>``,
       and the smallest K offset is recorded.
    2. For every field seen, a comma-separated subset string is built
       relative to ``self.origins[name]``. It contains an I range, a J range
       and a K range, each only if ``array_dimensions`` reports that axis.
       These are followed by full ``0:<extent>`` ranges for any trailing
       data dimensions of the array.

    Args:
        interval: Not used by this method.
        section: SDFG of one section. Its ``arrays`` provide the shapes, and
            its access collection decides the read/write classification.

    Returns:
        ``(in_subsets, out_subsets)``: dicts mapping field name to subset
        string, for fields read and fields written by ``section``
        respectively. A field may appear in both.
    """
    library_node_types = (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode)
    section_origins: Dict[str, Tuple[int, int]] = {}
    min_k_offsets: Dict[str, int] = {}

    # Phase 1: per-field I/J origin (max) and K offset (min) over all library nodes.
    for node, _ in section.all_nodes_recursive():
        if not isinstance(node, library_node_types):
            continue
        node_accesses: AccessCollector.CartesianAccessCollection = get_access_collection(node)
        for name, offsets in node_accesses.offsets().items():
            for offset in offsets:
                candidate_i = -offset[0] - node.iteration_space.i_interval.start.offset
                candidate_j = -offset[1] - node.iteration_space.j_interval.start.offset
                known_i, known_j = section_origins.get(name, (candidate_i, candidate_j))
                section_origins[name] = (max(known_i, candidate_i), max(known_j, candidate_j))
                min_k_offsets[name] = min(min_k_offsets.get(name, offset[2]), offset[2])

    # Phase 2: turn the origins into subset strings and classify them as read/write.
    section_accesses = get_access_collection(section)
    in_subsets = {}
    out_subsets = {}
    for name, (origin_i, origin_j) in section_origins.items():
        vl_origin = self.origins[name]
        array = section.arrays[name]
        shape = array.shape
        dimensions = array_dimensions(array)

        ranges = []
        # Index into ``shape`` of the next Cartesian axis that is actually present.
        shape_pos = 0
        if dimensions[0]:
            start_i = vl_origin[0] - origin_i
            ranges.append(f"{start_i:+d}:{start_i:+d}+({shape[shape_pos]})")
            shape_pos += 1
        if dimensions[1]:
            start_j = vl_origin[1] - origin_j
            ranges.append(f"{start_j:+d}:{start_j:+d}+({shape[shape_pos]})")
            shape_pos += 1
        if dimensions[2]:
            # NOTE(review): vl_origin[2] is presumably an oir.AxisBound rendered
            # symbolically against "__K" — confirm with get_axis_bound_str.
            k_origin = get_axis_bound_str(vl_origin[2], "__K")
            k_offset = min_k_offsets[name]
            ranges.append(
                f"k-({k_origin}){k_offset:+d}:k-({k_origin}){k_offset:+d}{shape[shape_pos]:+d}"
            )
        # Trailing (non-Cartesian) data dimensions are always taken whole.
        ranges.extend(f"0:{extent}" for extent in shape[sum(dimensions):])

        subset_str = ",".join(ranges)
        if name in section_accesses.read_fields():
            in_subsets[name] = subset_str
        if name in section_accesses.write_fields():
            out_subsets[name] = subset_str
    return in_subsets, out_subsets
def get_ij_origins(self):
    """Return, per field, the I/J origin required by the horizontal executions of ``self.node``.

    The method visits every ``HorizontalExecutionLibraryNode`` reachable via
    ``all_nodes_recursive`` in every section of ``self.node.sections``. For
    each access offset ``off`` of a field, the candidate origin is
    ``(-off[0] - i_start, -off[1] - j_start)``, where ``i_start``/``j_start``
    are the start offsets of that node's I/J iteration intervals. The result
    keeps the component-wise maximum over all candidates.

    Returns:
        Dict mapping field name to an ``(i, j)`` origin tuple.
    """
    ij_origins: Dict[str, Tuple[int, int]] = {}
    for _, section in self.node.sections:
        for node, _ in section.all_nodes_recursive():
            if not isinstance(node, HorizontalExecutionLibraryNode):
                continue
            for name, offsets in get_access_collection(node).offsets().items():
                for off in offsets:
                    cand_i = -off[0] - node.iteration_space.i_interval.start.offset
                    cand_j = -off[1] - node.iteration_space.j_interval.start.offset
                    # First candidate seeds the entry; later ones widen it.
                    best_i, best_j = ij_origins.get(name, (cand_i, cand_j))
                    ij_origins[name] = (max(best_i, cand_i), max(best_j, cand_j))
    return ij_origins
def get_k_origins(self):
    """Return, per field, the lowest K level accessed across all sections of ``self.node``.

    For each ``(interval, section)`` in ``self.node.sections``, and for each
    offset ``off`` with which the section accesses a field, the candidate is
    ``oir.AxisBound(level=interval.start.level, offset=interval.start.offset + off[2])``.
    Per field, the smallest candidate as ordered by ``min`` is kept.

    Returns:
        Dict mapping field name to ``oir.AxisBound``.
    """
    k_origins: Dict[str, oir.AxisBound] = {}
    for interval, section in self.node.sections:
        start = interval.start
        for name, offsets in get_access_collection(section).offsets().items():
            for off in offsets:
                candidate = oir.AxisBound(level=start.level, offset=start.offset + off[2])
                if name in k_origins:
                    # NOTE(review): relies on oir.AxisBound defining an ordering
                    # (including across levels) — confirm in oir.
                    k_origins[name] = min(k_origins[name], candidate)
                else:
                    k_origins[name] = candidate
    return k_origins
def get_origins(self):
    """Compute, per field, the origin of its accessed region relative to ``self.node``.

    For every field accessed by ``self.node``, the method takes the
    component-wise minimum over all access offsets and negates it. The I and
    J components are additionally shifted by the start offsets of the node's
    I/J iteration intervals.

    Returns:
        Dict mapping field name to an ``(i, j, k)`` origin tuple.
    """
    access_collection: AccessCollector.CartesianAccessCollection = get_access_collection(
        self.node
    )
    # Loop-invariant: the same iteration space applies to every field.
    i_start = self.node.iteration_space.i_interval.start.offset
    j_start = self.node.iteration_space.j_interval.start.offset

    origins = dict()
    for name, offsets in access_collection.offsets().items():
        # Component-wise minimum over all accesses to this field. Reducing over
        # ``offsets`` directly avoids re-querying the collection per field and
        # leaves the returned offset sets unmodified.
        min_i = min(off[0] for off in offsets)
        min_j = min(off[1] for off in offsets)
        min_k = min(off[2] for off in offsets)
        origins[name] = (-min_i - i_start, -min_j - j_start, -min_k)
    return origins
def get_innermost_memlets(self):
    """Build the memlets that connect the innermost tasklet of ``self.node`` to its fields.

    Read memlets: one per ``(field, offset)`` pair in the node's read offsets,
    keyed by ``get_tasklet_symbol(name, offset, is_target=False)``.
    Write memlets: one per written field, keyed by ``"__" + name``, at zero
    offset.

    Each subset indexes the Cartesian axes present in the array (per
    ``array_dimensions``) as ``i`` / ``j`` / ``0``. Each axis index is offset by
    ``self.origins[name]`` plus the access offset. The Cartesian indices are
    followed by full ``0:<extent>`` ranges for trailing data dimensions.

    Returns:
        ``(in_memlets, out_memlets)``: dicts mapping connector name to
        ``dace.memlet.Memlet``.
    """
    access_collection: AccessCollector.CartesianAccessCollection = get_access_collection(
        self.node
    )

    def field_layout(name):
        # Which Cartesian axes the array has, plus the extents of its trailing data dims.
        array = self.parent_sdfg.arrays[name]
        dimensions = array_dimensions(array)
        return dimensions, array.shape[sum(dimensions):]

    def subset_str(name, dimensions, data_dims, offset):
        # Comma-joined point index on the Cartesian axes followed by full data-dim ranges.
        origin = self.origins[name]
        parts = [
            f"{var}{origin[dim] + offset[dim]:+d}"
            for dim, var in enumerate("ij0")
            if dimensions[dim]
        ]
        parts.extend(f"0:{extent}" for extent in data_dims)
        return ",".join(parts)

    in_memlets = dict()
    for name, offsets in access_collection.read_offsets().items():
        dimensions, data_dims = field_layout(name)
        for off in offsets:
            acc_name = get_tasklet_symbol(name, off, is_target=False)
            in_memlets[acc_name] = dace.memlet.Memlet.simple(
                name, subset_str(name, dimensions, data_dims, off)
            )

    # Loop-invariant: writes are marked dynamic when any top-level statement of
    # the node body is a MaskStmt (presumably because masked writes may not
    # happen at every point). Nested statements are not inspected here.
    has_masked_writes = any(isinstance(stmt, oir.MaskStmt) for stmt in self.node.oir_node.body)

    out_memlets = dict()
    for name in access_collection.write_fields():
        dimensions, data_dims = field_layout(name)
        acc_name = "__" + name
        out_memlets[acc_name] = dace.memlet.Memlet.simple(
            name,
            subset_str(name, dimensions, data_dims, (0, 0, 0)),
            dynamic=has_masked_writes,
        )
    return in_memlets, out_memlets