Example #1
0
 def get_tensor_access(
     self, stmt_isl_repr
 ) -> Tuple[Dict[str, Dict[str, isl.union_map]], Dict[str, Set[Tensor]]]:
     """Build isl access relations from the latest entry of ``self.record``.

     Every item of ``self.record[-1]`` is unpacked as
     ``(kind, tensor, indices)``.  For each item a relation
     ``{ <stmt_isl_repr> -> <tensor.name>[<indices>] }`` is parsed and
     merged into the relation already accumulated for that kind and
     tensor name.

     Args:
         stmt_isl_repr: Text used as the left-hand side of every generated
             relation.

     Returns:
         A pair ``(relations, tensors)`` where
         ``relations[kind][tensor.name]`` is the union of all relations
         built for that kind and tensor name, and ``tensors[kind]`` is the
         set of tensor objects recorded under that kind.
     """
     relations = defaultdict(
         lambda: defaultdict(lambda: isl.union_map('{}')))
     tensors_by_kind = defaultdict(set)

     for kind, tensor, indices in self.record[-1]:
         subscript = ', '.join(str(index) for index in indices)
         relation = isl.union_map(
             f'{{ {stmt_isl_repr} -> {tensor.name}[{subscript}] }}')
         # Accumulate into the per-kind table; a missing entry starts out as
         # the empty union map supplied by the default factory.
         per_tensor = relations[kind]
         per_tensor[tensor.name] = per_tensor[tensor.name].union(relation)
         tensors_by_kind[kind].add(tensor)

     return relations, tensors_by_kind
Example #2
0
 def __init__(self, expansion=None, contraction=None, **kwargs):
     """Store the expansion and contraction, parsing them when given as text.

     Args:
         expansion: Stored as-is unless it is ``None`` or a ``str``; in that
             case ``isl.union_map`` is built from the string, using ``'{}'``
             when the value is ``None`` or empty.
         contraction: Same convention, built with
             ``isl.union_pw_multi_aff``.
         **kwargs: Forwarded unchanged to the parent initializer.
     """
     super().__init__(**kwargs)

     textual = (str, type(None))
     # Resolve both values before touching the instance so that a parse
     # failure leaves neither attribute assigned.
     expansion_value = (isl.union_map(expansion or '{}')
                        if isinstance(expansion, textual) else expansion)
     contraction_value = (isl.union_pw_multi_aff(contraction or '{}')
                          if isinstance(contraction, textual)
                          else contraction)

     self.expansion: isl.union_map = expansion_value
     self.contraction: isl.union_pw_multi_aff = contraction_value
Example #3
0
 def get_tagged_must_kills(self):
     """Return ``pet_scop_get_tagged_must_kills`` wrapped as a union map.

     The raw result obtained for ``self.ptr`` is handed to
     ``isl.union_map`` together with ``self.ctx``.
     """
     ctx = self.ctx
     raw = pet.pet_scop_get_tagged_must_kills(self.ptr)
     return isl.union_map(ctx=ctx, ptr=raw)
Example #4
0
 def get_tagged_may_writes(self):
     """Return ``pet_scop_get_tagged_may_writes`` wrapped as a union map.

     The raw result obtained for ``self.ptr`` is handed to
     ``isl.union_map`` together with ``self.ctx``.
     """
     ctx = self.ctx
     raw = pet.pet_scop_get_tagged_may_writes(self.ptr)
     return isl.union_map(ctx=ctx, ptr=raw)
Example #5
0
 def get_may_reads(self):
     """Return ``pet_scop_get_may_reads`` wrapped as a union map.

     The raw result obtained for ``self.ptr`` is handed to
     ``isl.union_map`` together with ``self.ctx``.
     """
     ctx = self.ctx
     raw = pet.pet_scop_get_may_reads(self.ptr)
     return isl.union_map(ctx=ctx, ptr=raw)
Example #6
0
 def get_tagged_must_kills(self):
     """Wrap the result of ``pet_scop_get_tagged_must_kills`` for this object.

     Returns:
         An ``isl.union_map`` constructed from ``self.ctx`` and the raw
         value returned by pet for ``self.ptr``.
     """
     wrapper_args = {
         'ctx': self.ctx,
         'ptr': pet.pet_scop_get_tagged_must_kills(self.ptr),
     }
     return isl.union_map(**wrapper_args)
Example #7
0
 def get_tagged_may_writes(self):
     """Wrap the result of ``pet_scop_get_tagged_may_writes`` for this object.

     Returns:
         An ``isl.union_map`` constructed from ``self.ctx`` and the raw
         value returned by pet for ``self.ptr``.
     """
     wrapper_args = {
         'ctx': self.ctx,
         'ptr': pet.pet_scop_get_tagged_may_writes(self.ptr),
     }
     return isl.union_map(**wrapper_args)
Example #8
0
 def get_may_reads(self):
     """Wrap the result of ``pet_scop_get_may_reads`` for this object.

     Returns:
         An ``isl.union_map`` constructed from ``self.ctx`` and the raw
         value returned by pet for ``self.ptr``.
     """
     wrapper_args = {
         'ctx': self.ctx,
         'ptr': pet.pet_scop_get_may_reads(self.ptr),
     }
     return isl.union_map(**wrapper_args)
Example #9
0
 def __init__(self, extension=None, **kwargs):
     """Store the extension, parsing it when it is given as text.

     Args:
         extension: Stored as-is unless it is ``None`` or a ``str``; in that
             case ``isl.union_map`` is built from the string, using ``'{}'``
             when the value is ``None`` or empty.
         **kwargs: Forwarded unchanged to the parent initializer.
     """
     super().__init__(**kwargs)

     textual = (str, type(None))
     extension_value = (isl.union_map(extension or '{}')
                        if isinstance(extension, textual) else extension)
     self.extension: isl.union_map = extension_value
Example #10
0
def cuda_find_sharable_tensors(tree,
                               statements,
                               tensors,
                               max_shared_memory=None):
    """Select tensor footprints to stage, subject to a byte budget.

    The schedule tree is walked from its root through single-child and mark
    nodes down to the first mark containing ``'threadIdx'``.  For every
    tensor accessed by ``statements`` the accesses are expressed relative to
    the prefix schedule at that mark, a fixed box hull of the accessed
    region is computed, and a ``BlockTensorUsage`` is created from it.
    Usages are then considered in ascending ``size_in_bytes`` order.

    Args:
        tree: Schedule tree exposing ``root`` and ``domain()``.
        statements: Mapping whose values provide ``get_access(tensors)``.
        tensors: Mapping from tensor name to tensor object.
        max_shared_memory: Byte budget; when ``None`` the value of
            ``cuda_settings['max_shared_memory']`` is used.

    Returns:
        A list of ``BlockTensorUsage`` objects, in ascending size order.  A
        usage is skipped when ``size_in_bytes * 8`` is not smaller than the
        access count accumulated for its tensor, and selection stops at the
        first usage that would exceed the budget.

    Raises:
        AssertionError: If no ``'threadIdx'`` mark is reached, or if a
            combined access relation is not a single map.
    """
    access_kinds = ('read', 'write')

    def strides_along(point_set, rank):
        # Stride of each of the first `rank` dimensions, converted to int
        # through its textual form.
        return [int(str(point_set.stride(dim))) for dim in range(rank)]

    # Locate the 'threadIdx' mark below the root.
    cursor = tree.root
    while cursor and isinstance(cursor, (NodeWithSingleChild, MarkNode)):
        if isinstance(cursor, MarkNode) and 'threadIdx' in cursor.mark:
            break
        cursor = cursor.child
    assert isinstance(cursor, MarkNode) and 'threadIdx' in cursor.mark

    if max_shared_memory is None:
        max_shared_memory = cuda_settings['max_shared_memory']

    schedule_relation = cursor.to_isl().prefix_schedule_relation()
    from_prefix = schedule_relation.intersect_domain(tree.domain()).reverse()

    # merged_access[name]: single map combining every access to the tensor,
    #   expressed from the reversed prefix schedule.
    # reverse_access[kind][name]: union of the reversed access relations.
    # kinds_by_tensor[name]: which of 'read' / 'write' touch the tensor.
    merged_access = {}
    reverse_access = defaultdict(
        lambda: defaultdict(lambda: isl.union_map('{}')))
    kinds_by_tensor = defaultdict(set)
    for statement in statements.values():
        per_kind_access, _ = statement.get_access(tensors)
        for kind in access_kinds:
            for tensor_name, relation in per_kind_access[kind].items():
                kinds_by_tensor[tensor_name].add(kind)

                scoped = from_prefix.apply_range(relation)
                assert scoped.isa_map()
                scoped = isl.map.from_union_map(scoped)
                if tensor_name in merged_access:
                    scoped = scoped.add_map(merged_access[tensor_name])
                    assert scoped.isa_map()
                    scoped = isl.map.from_union_map(scoped)
                merged_access[tensor_name] = scoped

                by_tensor = reverse_access[kind]
                by_tensor[tensor_name] = by_tensor[tensor_name].union(
                    relation.reverse())

    # Accumulate, per tensor, the stride-adjusted box volume of every map
    # obtained by composing the merged access with the reversed accesses.
    access_count = defaultdict(int)
    for tensor_name, merged in merged_access.items():
        for kind in access_kinds:
            to_statements = reverse_access[kind][tensor_name].intersect_range(
                tree.domain())
            composed = merged.apply_range(to_statements)
            pieces = []
            composed.foreach_map(pieces.append)
            for piece in pieces:
                hull = piece.range_simple_fixed_box_hull()
                extent, _ = structure_unnamed_fixed_box(hull)
                strides = strides_along(piece.range(), len(extent))
                # -(-a // b) is ceiling division.
                volume = reduce(
                    int.__mul__,
                    [-(-size // step) for size, step in zip(extent, strides)])
                access_count[tensor_name] += volume

    candidates = []
    for tensor_name, merged in merged_access.items():
        hull = merged.range_simple_fixed_box_hull()
        extent, offset = structure_unnamed_fixed_box(hull)
        strides = strides_along(merged.range(), len(extent))
        candidates.append(
            BlockTensorUsage(tensors[tensor_name], extent, strides, offset,
                             kinds_by_tensor[tensor_name]))
    candidates.sort(key=lambda usage: usage.size_in_bytes)

    selected = []
    bytes_reserved = 0
    for usage in candidates:
        tensor_name = usage.origin.name
        footprint = usage.size_in_bytes
        # Skip footprints that are large relative to the accumulated access
        # count for the tensor.
        if footprint * 8 >= access_count[tensor_name]:
            continue
        # Candidates are sorted by size, so nothing later fits either.
        if footprint + bytes_reserved > max_shared_memory:
            break
        bytes_reserved += footprint
        selected.append(usage)

    return selected