예제 #1
0
 def thread_binding(
     begin: PrimExpr,
     end: PrimExpr = None,
     thread: str = None,
     annotations: Optional[Mapping[str, Object]] = None,
 ):
     """Register a ``ForKind.THREAD_BINDING`` loop via ``self.create_loop_info``.

     Supported call shapes:

     * ``thread_binding(begin, end, thread)`` -- everything given explicitly.
     * ``thread_binding(extent, thread=tag)`` -- ``begin`` becomes ``0`` and the
       single value is used as ``end``.
     * ``thread_binding(extent, tag)`` -- a string in the ``end`` position is
       taken as the thread tag, then handled like the previous shape.

     ``self`` is not a parameter; it is resolved from the enclosing scope,
     which is outside this snippet.

     Parameters
     ----------
     begin : PrimExpr
         Loop start, or the loop end when ``end`` is omitted.
     end : PrimExpr
         Loop end; may instead carry the thread tag string (see above).
     thread : str
         Thread tag handed to the ``IterVar`` that the loop is bound to.
     annotations : Optional[Mapping[str, Object]]
         Forwarded unchanged to ``self.create_loop_info``.

     Raises
     ------
     ValueError
         If ``thread`` is omitted and ``end`` is not a string.
     """
     if thread is None:
         # Shorthand like thread_binding(128, "threadIdx.x"): the tag arrives
         # in the `end` slot. Anything else leaves us without a thread tag.
         if not isinstance(end, str):
             raise ValueError("Thread cannot be None for thread_binding")
         thread, end = end, None
     if end is None:
         # Only one bound was supplied: treat it as the end and start from 0.
         begin, end = 0, begin
     bound_thread = IterVar(None, None, IterVar.ThreadIndex, thread)
     self.create_loop_info(
         begin,
         end,
         ForKind.THREAD_BINDING,
         thread_binding=bound_thread,
         annotations=annotations,
     )
예제 #2
0
 def thread_binding(
     begin: PrimExpr,
     end: PrimExpr,
     thread: str,
     annotations: Optional[Mapping[str, Object]] = None,
 ):
     """Register a ``ForKind.THREAD_BINDING`` loop via ``self.create_loop_info``.

     ``self`` is not a parameter; it is resolved from the enclosing scope,
     which is outside this snippet.

     Parameters
     ----------
     begin : PrimExpr
         Forwarded as the first argument of ``self.create_loop_info``.
     end : PrimExpr
         Forwarded as the second argument of ``self.create_loop_info``.
     thread : str
         Thread tag handed to the ``IterVar`` that the loop is bound to.
     annotations : Optional[Mapping[str, Object]]
         Forwarded unchanged to ``self.create_loop_info``.
     """
     # The IterVar carries only the iteration type and the tag here; its first
     # two positional arguments are left as None.
     bound_thread = IterVar(None, None, IterVar.ThreadIndex, thread)
     loop_options = {"thread_binding": bound_thread, "annotations": annotations}
     self.create_loop_info(begin, end, ForKind.THREAD_BINDING, **loop_options)
예제 #3
0
 def launch_thread(env_var, extent, span):
     """Wrap ``self.body`` in a ``"thread_extent"`` ``tvm.tir.AttrStmt``.

     ``self`` is not a parameter; it is resolved from the enclosing scope,
     which is outside this snippet.

     Parameters
     ----------
     env_var
         Passed as the ``IterVar``'s second argument and also used as the key
         into ``self.context.func_var_env_dict`` to fetch the value passed as
         its fourth argument (presumably the thread tag -- confirm against
         the TVM ``IterVar`` signature).
     extent
         Attribute value; run through ``tvm.runtime.convert`` first.
     span
         Source span attached to the converted extent, the ``IterVar`` and
         the resulting ``AttrStmt``.

     Returns
     -------
     tvm.tir.AttrStmt
         The attribute statement whose body is ``self.body``.
     """
     converted_extent = tvm.runtime.convert(extent, span=span)
     thread_tag = self.context.func_var_env_dict[env_var]
     thread_axis = IterVar(
         None,
         env_var,
         IterVar.ThreadIndex,
         thread_tag,
         span=span,
     )
     return tvm.tir.AttrStmt(
         thread_axis,
         "thread_extent",
         converted_extent,
         self.body,
         span=span,
     )
예제 #4
0
 def launch_thread(env_var, extent, span):
     """Wrap ``self.body`` in a thread-launch ``tvm.tir.AttrStmt``.

     The attribute key is ``"virtual_thread"`` when the value registered for
     ``env_var`` in ``self.context.func_var_env_dict`` equals ``"vthread"``,
     and ``"thread_extent"`` otherwise.

     ``self`` is not a parameter; it is resolved from the enclosing scope,
     which is outside this snippet.

     Parameters
     ----------
     env_var
         Passed as the ``IterVar``'s second argument and also used as the key
         into ``self.context.func_var_env_dict``.
     extent
         Attribute value; run through ``tvm.runtime.convert`` first. The
         converted value is also the second element of the ``(0, extent)``
         tuple given to the ``IterVar`` as its first argument.
     span
         Source span attached to the converted extent, the ``IterVar`` and
         the resulting ``AttrStmt``.

     Returns
     -------
     tvm.tir.AttrStmt
         The attribute statement whose body is ``self.body``.
     """
     converted_extent = tvm.runtime.convert(extent, span=span)
     thread_id = self.context.func_var_env_dict[env_var]
     if thread_id == "vthread":
         attr_key = "virtual_thread"
     else:
         attr_key = "thread_extent"
     thread_axis = IterVar(
         (0, converted_extent),
         env_var,
         IterVar.ThreadIndex,
         thread_id,
         span=span,
     )
     return tvm.tir.AttrStmt(
         thread_axis,
         attr_key,
         converted_extent,
         self.body,
         span=span,
     )
예제 #5
0
        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
            """Assemble a ``tvm.tir.BlockRealize`` whose block body is ``self.body``.

            ``self`` is not a parameter; it is resolved from the enclosing scope,
            which is outside this snippet. Block metadata (reads, writes,
            annotations, init, buffers, iter bindings, predicate) comes from the
            last entry of ``self.context.block_info_stack``.

            Parameters
            ----------
            axes
                One entry per element of ``self.block_vars``. After
                ``tvm.runtime.convert`` each entry must be a ``PrimExpr`` (used
                as the extent of a range starting at 0), a ``Range`` (used as
                the domain), or an ``IterVar`` (its ``dom`` and ``iter_type``
                are reused). ``None`` means no axes.
            name_hint : str
                Forwarded to ``tvm.tir.Block``.
            span : Optional[Span]
                Forwarded to both ``tvm.tir.Block`` and ``tvm.tir.BlockRealize``.

            Returns
            -------
            tvm.tir.BlockRealize
                The realized block with its iter-var binding values and predicate.
            """

            def _report(message):
                # All diagnostics from this function point at the current node.
                self.context.report_error(message, self.node.span)

            assert self.node and self.context and self.body, "call 'exit_scope' before 'enter_scope'"

            block_info = self.context.block_info_stack[-1]
            axes = [] if axes is None else axes
            if len(axes) != len(self.block_vars):
                _report(
                    "Inconsistent number of block vars, "
                    f"there are {len(axes)} axes but {len(self.block_vars)} block vars. "
                    "The number of block vars should match the number of axes."
                )

            # Turn every axis description into an IterVar over the matching block var.
            block_iters: List[IterVar] = []
            for position, raw_axis in enumerate(axes):
                axis = tvm.runtime.convert(raw_axis)
                if isinstance(axis, tvm.tir.PrimExpr):
                    dom, iter_type = Range.from_min_extent(0, axis), 0
                elif isinstance(axis, Range):
                    dom, iter_type = axis, 0
                elif isinstance(axis, IterVar):
                    dom, iter_type = axis.dom, axis.iter_type
                else:
                    _report(
                        "Invalid argument of tir.block(), "
                        f"expected PrimExpr, Range or IterVar, but got {type(axis)}"
                    )
                    continue
                block_iters.append(IterVar(dom, self.block_vars[position], iter_type))

            # Declared read/write regions; a missing or empty declaration yields [].
            reads: List[BufferRegion] = [
                buffer_slice_to_region(declared) for declared in (block_info.reads or [])
            ]
            writes: List[BufferRegion] = [
                buffer_slice_to_region(declared) for declared in (block_info.writes or [])
            ]

            # Bit 0 is set when reads were not declared, bit 1 when writes were not.
            region_detect_mask: int = 0
            if block_info.reads is None:
                region_detect_mask |= 1
            if block_info.writes is None:
                region_detect_mask |= 2

            # NOTE: when block_info.annotations is not None it is updated in place.
            annotations = block_info.annotations
            if annotations is None:
                annotations = {}
            if region_detect_mask != 0:
                annotations["tir.script_parsing_detect_access"] = region_detect_mask

            inner = tvm.tir.Block(
                block_iters,
                reads,
                writes,
                name_hint,
                self.body,
                block_info.init,
                block_info.alloc_buffers,
                block_info.match_buffers,
                annotations,
                span,
            )

            # Work out the value bound to each block iter var.
            values: List[PrimExpr]
            if block_info.iter_bindings:
                # Explicit bindings: every block var needs one.
                for block_var in self.block_vars:
                    if block_var not in block_info.iter_bindings:
                        _report("Missing block iter var binding for " + block_var.name)
                values = [block_info.iter_bindings[block_var] for block_var in self.block_vars]
            else:
                # No explicit bindings: fall back to the second-to-last loop_stack entry.
                values = self.context.loop_stack[-2].copy()
                if not block_iters:
                    # It is an opaque block without any bindings
                    values = []
                elif len(values) == 0:
                    # Nothing to bind to: use one NaN float32 constant as a
                    # placeholder for every iter var.
                    values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters)
                elif len(values) != len(block_iters):
                    _report(
                        "Number of block iter var and outer loop nesting mismatch, "
                        f"{len(block_iters)} block iter vars but {len(values)} loops"
                    )

            predicate = block_info.predicate
            if predicate is None:
                predicate = tvm.tir.const(True, "bool")
            return tvm.tir.BlockRealize(values, predicate, inner, span)