def get_tensor_access(self, stmt_isl_repr) -> Tuple[Dict[str, Dict[str, isl.union_map]], Dict[str, Set[Tensor]]]:
    """Build isl access maps and raw tensor sets from the last access record.

    Args:
        stmt_isl_repr: isl textual representation of the statement's domain
            tuple; used as the domain side of every produced map.

    Returns:
        A pair ``(isl_mapping, vanilla)`` where
        ``isl_mapping[access_type][tensor_name]`` is the union of all
        ``statement -> tensor[indices]`` maps recorded for that tensor, and
        ``vanilla[access_type]`` is the set of Tensor objects touched.
    """
    mapping = defaultdict(lambda: defaultdict(lambda: isl.union_map('{}')))
    touched = defaultdict(set)
    # Only the last entry of self.record is consulted.
    for access_type, tensor, indices in self.record[-1]:
        ind = ', '.join(str(axis) for axis in indices)
        one_access = isl.union_map(
            f'{{ {stmt_isl_repr} -> {tensor.name}[{ind}] }}')
        previous = mapping[access_type][tensor.name]
        mapping[access_type][tensor.name] = previous.union(one_access)
        touched[access_type].add(tensor)
    return mapping, touched
def __init__(self, expansion=None, contraction=None, **kwargs):
    """Initialize, coercing string/None arguments into isl objects.

    Args:
        expansion: an ``isl.union_map``, an isl string for one, or None
            (treated as the empty relation ``'{}'``).
        contraction: an ``isl.union_pw_multi_aff``, an isl string for one, or
            None (treated as ``'{}'``).
        **kwargs: forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)
    # Accept ready-made isl objects as-is; parse strings, default None to '{}'.
    if isinstance(expansion, str) or expansion is None:
        expansion = isl.union_map(expansion if expansion else '{}')
    if isinstance(contraction, str) or contraction is None:
        contraction = isl.union_pw_multi_aff(contraction if contraction else '{}')
    self.expansion: isl.union_map = expansion
    self.contraction: isl.union_pw_multi_aff = contraction
def get_tagged_must_kills(self):
    """Return the scop's tagged must-kill accesses wrapped as ``isl.union_map``."""
    raw_ptr = pet.pet_scop_get_tagged_must_kills(self.ptr)
    return isl.union_map(ptr=raw_ptr, ctx=self.ctx)
def get_tagged_may_writes(self):
    """Return the scop's tagged may-write accesses wrapped as ``isl.union_map``."""
    raw_ptr = pet.pet_scop_get_tagged_may_writes(self.ptr)
    return isl.union_map(ptr=raw_ptr, ctx=self.ctx)
def get_may_reads(self):
    """Return the scop's may-read accesses wrapped as ``isl.union_map``."""
    raw_ptr = pet.pet_scop_get_may_reads(self.ptr)
    return isl.union_map(ptr=raw_ptr, ctx=self.ctx)
def __init__(self, extension=None, **kwargs):
    """Initialize, coercing a string/None extension into an ``isl.union_map``.

    Args:
        extension: an ``isl.union_map``, an isl string for one, or None
            (treated as the empty relation ``'{}'``).
        **kwargs: forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)
    # Accept a ready-made isl object as-is; parse strings, default None to '{}'.
    if isinstance(extension, str) or extension is None:
        extension = isl.union_map(extension if extension else '{}')
    self.extension: isl.union_map = extension
def cuda_find_sharable_tensors(tree, statements, tensors, max_shared_memory=None):
    """Select tensors worth staging in CUDA shared memory for one thread block.

    Walks the schedule tree down to the 'threadIdx' mark, computes each
    tensor's per-block footprint (a fixed bounding box of its accesses) and an
    estimated access count, then greedily keeps the tensors whose footprint is
    re-accessed often enough to justify the shared-memory cost.

    Args:
        tree: schedule tree; its root chain must reach a MarkNode whose mark
            contains 'threadIdx'.
        statements: mapping of statement name -> statement object providing
            ``get_access(tensors)``.
        tensors: mapping of tensor name -> tensor object.
        max_shared_memory: byte budget; defaults to
            ``cuda_settings['max_shared_memory']`` when None.

    Returns:
        List of BlockTensorUsage chosen for shared memory, smallest first,
        whose total size fits within ``max_shared_memory``.
    """
    # Descend through single-child/mark nodes until the 'threadIdx' mark;
    # everything scheduled outside that mark is per-block.
    node = tree.root
    while node and isinstance(node, (NodeWithSingleChild, MarkNode)):
        if isinstance(node, MarkNode) and 'threadIdx' in node.mark:
            break
        node = node.child
    assert isinstance(node, MarkNode) and 'threadIdx' in node.mark
    if max_shared_memory is None:
        max_shared_memory = cuda_settings['max_shared_memory']
    # prefix: schedule-prefix relation at the mark, restricted to the tree's
    # domain and reversed so it maps prefix values -> statement instances.
    prefix = node.to_isl().prefix_schedule_relation()
    prefix = prefix.intersect_domain(tree.domain()).reverse()
    # tensor_access[name]: single isl.map prefix -> tensor elements (reads and
    # writes merged); tensor_stmts[k][name]: tensor elements -> statement
    # instances for access type k; tensor_access_types[name]: {'read','write'}.
    tensor_access = dict()
    tensor_stmts = defaultdict(
        lambda: defaultdict(lambda: isl.union_map('{}')))
    tensor_access_types = defaultdict(set)
    for _, stmt in statements.items():
        stmt_tensor_access, _ = stmt.get_access(tensors)
        for k in ('read', 'write'):
            for name, access in stmt_tensor_access[k].items():
                tensor_access_types[name].add(k)
                # Compose prefix with this access: prefix -> tensor elements.
                new_map = prefix.apply_range(access)
                assert new_map.isa_map()
                new_map = isl.map.from_union_map(new_map)
                if name in tensor_access:
                    # Merge with the map accumulated so far; the merge must
                    # still collapse to a single map in one space.
                    new_map = new_map.add_map(tensor_access[name])
                    assert new_map.isa_map()
                    new_map = isl.map.from_union_map(new_map)
                tensor_access[name] = new_map
                tensor_stmts[k][name] = tensor_stmts[k][name].union(
                    access.reverse())
    # Estimate how many element accesses each tensor sees per block: for each
    # (tensor, access-type, statement) map, take the bounding box of its range
    # and count stride-adjusted points, summing into access_count[name].
    access_count = defaultdict(int)
    for name in tensor_access:
        for access_type in ('read', 'write'):
            ts_maps = tensor_stmts[access_type][name].intersect_range(
                tree.domain())
            stmt_maps = tensor_access[name].apply_range(ts_maps)
            stmts = list()
            stmt_maps.foreach_map(stmts.append)
            for stmt in stmts:
                box = stmt.range_simple_fixed_box_hull()
                box_size, _ = structure_unnamed_fixed_box(box)
                s = stmt.range()
                # stride(i) comes back as an isl value; str->int converts it.
                strides = [int(str(s.stride(i)))
                           for i in range(len(box_size))]
                # Ceil-divide each box extent by its stride; product = points.
                total = reduce(int.__mul__,
                               [-(-i // j) for i, j in zip(box_size, strides)])
                access_count[name] += total
    # Build one BlockTensorUsage per tensor from the bounding box of its
    # merged per-block access map.
    usages = []
    for name in tensor_access:
        box = tensor_access[name].range_simple_fixed_box_hull()
        box_size, offset = structure_unnamed_fixed_box(box)
        s = tensor_access[name].range()
        strides = [int(str(s.stride(i))) for i in range(len(box_size))]
        usages.append(
            BlockTensorUsage(tensors[name], box_size, strides, offset,
                             tensor_access_types[name]))
    # Greedy selection, smallest footprint first.
    usages.sort(key=lambda x: x.size_in_bytes)
    res = []
    shared_total_usage = 0
    for i in usages:
        name = i.origin.name
        bytes_usage = i.size_in_bytes
        # Skip tensors accessed too rarely to pay off (heuristic: fewer than
        # 8 accesses per byte of footprint — TODO confirm intent of factor 8).
        if bytes_usage * 8 >= access_count[name]:
            continue
        # Stop once the budget would be exceeded; remaining usages are larger.
        if bytes_usage + shared_total_usage > max_shared_memory:
            break
        shared_total_usage += bytes_usage
        res.append(i)
    return res