def thread_binding(
    begin: PrimExpr,
    end: PrimExpr = None,
    thread: str = None,
    annotations: Optional[Mapping[str, Object]] = None,
):
    """Record a thread-bound loop, accepting several argument layouts.

    Accepted call shapes, as handled by the code below:

    * ``thread_binding(begin, end, thread)`` - explicit range and thread tag.
    * ``thread_binding(extent, thread=tag)`` - ``end`` omitted; the range
      becomes ``[0, extent)``.
    * ``thread_binding(extent, tag)`` - a string in the ``end`` slot is taken
      as the thread tag, and the range becomes ``[0, extent)``.

    Parameters
    ----------
    begin : PrimExpr
        Loop start, or the loop extent when ``end`` is not supplied.
    end : PrimExpr
        Loop end. May also carry the thread tag as a ``str`` when ``thread``
        is omitted.
    thread : str
        Thread tag passed to ``IterVar``.
    annotations : Optional[Mapping[str, Object]]
        Forwarded unchanged to ``self.create_loop_info``.

    Raises
    ------
    ValueError
        If ``thread`` is ``None`` and ``end`` is not a string.

    Notes
    -----
    ``self`` is not a parameter; it is resolved from the enclosing scope,
    which is not visible in this block.
    """
    if thread is None:
        if not isinstance(end, str):
            raise ValueError("Thread cannot be None for thread_binding")
        # Two-positional form, e.g. thread_binding(128, "threadIdx.x"):
        # the second argument is really the thread tag.
        thread, end = end, None

    if end is None:
        # Only an extent was given; iterate over [0, extent).
        begin, end = 0, begin

    bound_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread)
    self.create_loop_info(
        begin,
        end,
        ForKind.THREAD_BINDING,
        thread_binding=bound_iter_var,
        annotations=annotations,
    )
def thread_binding(
    begin: PrimExpr,
    end: PrimExpr,
    thread: str,
    annotations: Optional[Mapping[str, Object]] = None,
):
    """Record a thread-bound loop over ``begin``/``end`` for ``thread``.

    Parameters
    ----------
    begin : PrimExpr
        Forwarded to ``self.create_loop_info`` as the loop start.
    end : PrimExpr
        Forwarded to ``self.create_loop_info`` as the loop end.
    thread : str
        Thread tag used to build the ``IterVar`` bound to the loop.
    annotations : Optional[Mapping[str, Object]]
        Forwarded unchanged to ``self.create_loop_info``.

    Notes
    -----
    ``self`` is not a parameter; it is resolved from the enclosing scope,
    which is not visible in this block.
    """
    # Build the keyword arguments first so the IterVar is constructed before
    # the loop info is recorded.
    loop_kwargs = {
        "thread_binding": IterVar(None, None, IterVar.ThreadIndex, thread),
        "annotations": annotations,
    }
    self.create_loop_info(begin, end, ForKind.THREAD_BINDING, **loop_kwargs)
def launch_thread(env_var, extent, span):
    """Wrap ``self.body`` in a ``"thread_extent"`` attribute statement.

    Parameters
    ----------
    env_var
        Variable handed to ``IterVar`` and used as the key into
        ``self.context.func_var_env_dict`` to obtain the thread tag.
    extent
        Extent value; passed through ``tvm.runtime.convert`` before use.
    span
        Source span attached to the converted extent, the ``IterVar`` and the
        resulting statement.

    Returns
    -------
    tvm.tir.AttrStmt
        Attribute statement keyed ``"thread_extent"`` whose body is
        ``self.body``.

    Notes
    -----
    ``self`` is not a parameter; it is resolved from the enclosing scope,
    which is not visible in this block.
    """
    extent = tvm.runtime.convert(extent, span=span)
    thread_tag = self.context.func_var_env_dict[env_var]
    # The iteration variable is created without a domain (first argument None).
    thread_iter_var = IterVar(None, env_var, IterVar.ThreadIndex, thread_tag, span=span)
    return tvm.tir.AttrStmt(thread_iter_var, "thread_extent", extent, self.body, span=span)
def launch_thread(env_var, extent, span):
    """Wrap ``self.body`` in a thread-launch attribute statement.

    The attribute key is ``"virtual_thread"`` when the thread tag registered
    for ``env_var`` equals ``"vthread"``, and ``"thread_extent"`` otherwise.

    Parameters
    ----------
    env_var
        Variable handed to ``IterVar`` and used as the key into
        ``self.context.func_var_env_dict`` to obtain the thread tag.
    extent
        Extent value; passed through ``tvm.runtime.convert`` before use.
    span
        Source span attached to the converted extent, the ``IterVar`` and the
        resulting statement.

    Returns
    -------
    tvm.tir.AttrStmt
        Attribute statement whose body is ``self.body``.

    Notes
    -----
    ``self`` is not a parameter; it is resolved from the enclosing scope,
    which is not visible in this block.
    """
    extent = tvm.runtime.convert(extent, span=span)
    thread_id = self.context.func_var_env_dict[env_var]

    if thread_id == "vthread":
        attr_key = "virtual_thread"
    else:
        attr_key = "thread_extent"

    # The domain is given as a (0, extent) tuple.
    thread_iter_var = IterVar(
        (0, extent),
        env_var,
        IterVar.ThreadIndex,
        thread_id,
        span=span,
    )
    return tvm.tir.AttrStmt(thread_iter_var, attr_key, extent, self.body, span=span)
def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
    """Build a ``tvm.tir.BlockRealize`` around ``self.body``.

    Parameters
    ----------
    axes
        Optional sequence with one entry per block var. After
        ``tvm.runtime.convert``, each entry must be a ``PrimExpr`` (used as an
        extent starting at 0), a ``Range`` or an ``IterVar``.
    name_hint : str
        Name passed to ``tvm.tir.Block``.
    span : Optional[Span]
        Source span attached to the block and its realization.

    Returns
    -------
    tvm.tir.BlockRealize
        Realization of the created block with its binding values and predicate.

    Notes
    -----
    ``self`` is not a parameter; it is resolved from the enclosing scope,
    which is not visible in this block. Problems are reported through
    ``self.context.report_error``; whether that call returns is not visible
    from here, so the code after each call is written to tolerate either.
    """
    assert (
        self.node and self.context and self.body
    ), "call 'exit_scope' before 'enter_scope'"

    block_info = self.context.block_info_stack[-1]
    if axes is None:
        axes = []

    num_axes = len(axes)
    num_block_vars = len(self.block_vars)
    if num_axes != num_block_vars:
        self.context.report_error(
            "Inconsistent number of block vars, "
            f"there are {num_axes} axes but {num_block_vars} block vars. "
            "The number of block vars should match the number of axes.",
            self.node.span,
        )

    # Turn every axis description into an IterVar tied to its block var.
    block_iters: List[IterVar] = []
    for idx, raw_axis in enumerate(axes):
        converted = tvm.runtime.convert(raw_axis)
        if isinstance(converted, tvm.tir.PrimExpr):
            dom, iter_type = Range.from_min_extent(0, converted), 0
        elif isinstance(converted, Range):
            dom, iter_type = converted, 0
        elif isinstance(converted, IterVar):
            dom, iter_type = converted.dom, converted.iter_type
        else:
            self.context.report_error(
                "Invalid argument of tir.block(), "
                f"expected PrimExpr, Range or IterVar, but got {type(converted)}",
                self.node.span,
            )
            continue
        block_iters.append(IterVar(dom, self.block_vars[idx], iter_type))

    # Block read/write regions; a missing or empty declaration yields [].
    reads: List[BufferRegion] = [
        buffer_slice_to_region(read_slice) for read_slice in block_info.reads or []
    ]
    writes: List[BufferRegion] = [
        buffer_slice_to_region(write_slice) for write_slice in block_info.writes or []
    ]

    # Bit 0 is set when reads is None, bit 1 when writes is None.
    reads_unset = block_info.reads is None
    writes_unset = block_info.writes is None
    region_detect_mask: int = reads_unset | (writes_unset << 1)

    annotations = block_info.annotations
    if annotations is None:
        annotations = {}
    if region_detect_mask:
        annotations["tir.script_parsing_detect_access"] = region_detect_mask

    inner = tvm.tir.Block(
        block_iters,
        reads,
        writes,
        name_hint,
        self.body,
        block_info.init,
        block_info.alloc_buffers,
        block_info.match_buffers,
        annotations,
        span,
    )

    # Block var iter binding values.
    values: List[PrimExpr]
    if block_info.iter_bindings:
        for block_var in self.block_vars:
            if block_var not in block_info.iter_bindings:
                self.context.report_error(
                    "Missing block iter var binding for " + block_var.name,
                    self.node.span,
                )
        values = [block_info.iter_bindings[block_var] for block_var in self.block_vars]
    else:
        # No explicit bindings: start from a copy of the second-to-last
        # entry of the loop stack.
        values = self.context.loop_stack[-2].copy()
        num_iters = len(block_iters)
        if num_iters == 0:
            # It is an opaque block without any bindings
            values = []
        elif len(values) == 0:
            # One float32 NaN constant per block iter.
            values = [tvm.tir.const(float("nan"), dtype="float32")] * num_iters
        elif len(values) != num_iters:
            self.context.report_error(
                "Number of block iter var and outer loop nesting mismatch, "
                f"{num_iters} block iter vars but {len(values)} loops",
                self.node.span,
            )

    if block_info.predicate is None:
        predicate = tvm.tir.const(True, "bool")
    else:
        predicate = block_info.predicate

    return tvm.tir.BlockRealize(values, predicate, inner, span)