Beispiel #1
0
    def __getitem__(self, indices):
        """Index or slice the buffer.

        Parameters
        ----------
        indices : index or slice, or a tuple/list of them
            One entry per buffer dimension. A bare (non-tuple/list) value is
            treated as a single-element index list.

        Returns
        -------
        result : BufferRegion or BufferLoad
            A ``BufferRegion`` when at least one entry is a slice and no slice
            has an explicit step. Otherwise a ``BufferLoad``, in which each
            slice becomes a ``Ramp``, or its start when it spans a single lane.
            A slice without a start begins at 0. A slice without a stop runs to
            ``self.shape`` of that dimension.
        """
        from ..arith import Analyzer  # pylint: disable=import-outside-toplevel
        from .expr import BufferLoad, Ramp  # pylint: disable=import-outside-toplevel
        from .stmt import BufferRegion  # pylint: disable=import-outside-toplevel

        if not isinstance(indices, (tuple, list)):
            indices = [indices]

        def _bounds(dim, index):
            # Open-ended slices (``A[:n]``, ``A[i:]``) default to the full axis
            # instead of failing on ``None - start``.
            start = 0 if index.start is None else index.start
            stop = self.shape[dim] if index.stop is None else index.stop
            return start, stop

        has_slice = any(isinstance(index, slice) for index in indices)
        has_step = any(
            isinstance(index, slice) and index.step is not None for index in indices
        )
        analyzer = Analyzer()

        if has_slice and not has_step:
            # A BufferRegion has no notion of stride, so only produce one when
            # every slice is contiguous; any explicit step falls through to the
            # BufferLoad/Ramp path below instead of being silently dropped.
            region = []
            for dim, index in enumerate(indices):
                if isinstance(index, slice):
                    start, stop = _bounds(dim, index)
                    region.append(
                        Range.from_min_extent(start, analyzer.simplify(stop - start))
                    )
                else:
                    region.append(Range.from_min_extent(index, 1))
            return BufferRegion(self, region)

        expr_indices = []
        for dim, index in enumerate(indices):
            if isinstance(index, slice):
                start, stop = _bounds(dim, index)
                step = 1 if index.step is None else index.step
                # Number of elements selected: ceil((stop - start) / step).
                lanes = analyzer.simplify((stop - start + step - 1) // step)
                if lanes == 1:
                    expr_indices.append(start)
                else:
                    expr_indices.append(Ramp(start, step, int(lanes)))
            else:
                expr_indices.append(index)
        return BufferLoad(self, expr_indices)
def test_block_access_region_detector():
    """Check ``get_block_access_region`` against the block's declared regions.

    The detected reads and writes must match the block's own ``reads`` and
    ``writes``. The third detected entry must be the full 128x128 region of the
    last buffer allocated in the root block.
    """
    root = func.body.block
    target = root.body.block
    buffers = root.alloc_buffers
    var_to_buffer = {buffer.data: buffer for buffer in buffers}
    detected = tir.analysis.get_block_access_region(target, var_to_buffer)

    tvm.ir.assert_structural_equal(target.reads, detected[0])
    tvm.ir.assert_structural_equal(target.writes, detected[1])

    last_buffer = buffers[-1]
    expected = [tvm.tir.BufferRegion(last_buffer, [Range(0, 128), Range(0, 128)])]
    tvm.ir.assert_structural_equal(expected, detected[2])
Beispiel #3
0
def _get_region(tslice):
    """Convert the indices of ``tslice`` into a list of Ranges.

    A slice entry must have no step and becomes ``Range(start, stop)``. Any
    other entry becomes a unit-extent range starting at that entry. For a
    ``tvm.tir.IterVar`` entry, its ``var`` is used as the start.
    """

    def _to_range(idx):
        if isinstance(idx, slice):
            assert idx.step is None
            return Range(idx.start, idx.stop)
        begin = idx.var if isinstance(idx, tvm.tir.IterVar) else idx
        return Range.make_by_min_extent(begin, 1)

    return [_to_range(idx) for idx in tslice.indices]
Beispiel #4
0
def test_complete_matmul_original():
    """Check the completed read/write regions of ``matmul_original``.

    Both blocks access 4x4 tiles indexed by their block iteration variables.
    The first block only writes C. The second block reads C, A and B (in that
    order) and writes C.
    """
    func = matmul_original
    A, B, C = [func.buffer_map[param] for param in func.params]

    def tile(buf, row, col):
        # 4x4 region whose corner is at (row * 4, col * 4).
        return tvm.tir.BufferRegion(
            buf,
            [Range.from_min_extent(row * 4, 4), Range.from_min_extent(col * 4, 4)],
        )

    stmts = func.body.block.body.body.body

    init_block = stmts[0].block
    assert isinstance(init_block, tvm.tir.Block)
    vi, vj = [iv.var for iv in init_block.iter_vars]
    tvm.ir.assert_structural_equal(init_block.reads, [])
    tvm.ir.assert_structural_equal(init_block.writes, [tile(C, vi, vj)])

    update_block = stmts[1].body.block
    assert isinstance(update_block, tvm.tir.Block)
    vi, vj, vk = [iv.var for iv in update_block.iter_vars]
    tvm.ir.assert_structural_equal(
        update_block.reads, [tile(C, vi, vj), tile(A, vi, vk), tile(B, vj, vk)]
    )
    tvm.ir.assert_structural_equal(update_block.writes, [tile(C, vi, vj)])
Beispiel #5
0
def test_complete_with_root():
    """Check the completed access regions of the two blocks in ``elementwise_with_root``.

    Each block touches a single element at (vi, vj). The first block reads A
    and writes B. The second block reads B and writes C.
    """
    func = elementwise_with_root
    A, B, C = [func.buffer_map[param] for param in func.params]

    def element(buf, row, col):
        # Unit-extent region at (row, col).
        return tir.BufferRegion(
            buf, [Range.from_min_extent(row, 1), Range.from_min_extent(col, 1)]
        )

    stages = func.body.block.body
    for position, (src, dst) in enumerate([(A, B), (B, C)]):
        blk = stages[position].body.body.block
        assert isinstance(blk, tvm.tir.Block)
        vi, vj = [iv.var for iv in blk.iter_vars]
        tvm.ir.assert_structural_equal(blk.reads, [element(src, vi, vj)])
        tvm.ir.assert_structural_equal(blk.writes, [element(dst, vi, vj)])
Beispiel #6
0
    def wrap_up_realize(self, node, body):
        """Wrap up all the variables which will no longer be used.

        Each buffer symbol whose recorded usage level is ``node`` has ``body``
        wrapped in a ``tvm.tir.ProducerRealize`` over the buffer's full shape.
        The symbol is then removed from ``self.symbols``. The following are left
        untouched:

        - symbols that were never visited (absent from ``self.symbols``);
        - ``Symbol.Input`` / ``Symbol.OutputBuffer`` entries;
        - non-buffer symbols.

        Parameters
        ----------
        node :
            The scope being closed; compared against the level stored in
            ``self.usage``.
        body :
            The statement to wrap.

        Returns
        -------
        body :
            ``body`` wrapped in one ``ProducerRealize`` per realized buffer.
            For global-scope buffers, ``self.wrap_up_binds`` is applied first.
        """
        to_pop = []
        for key, (_, level, _) in self.usage.items():
            # Don't realize symbols that were never visited, nor symbols owned
            # by a different scope.
            if key not in self.symbols or level != node:
                continue

            ty, entry = self.symbols[key]  # pylint: disable=invalid-name
            if ty in [Symbol.Input, Symbol.OutputBuffer] or "Buffer" not in ty.name:
                continue

            _buf = entry
            # Strip the 6-character suffix (presumably "Buffer") to get the scope name.
            _scope = "global" if ty is Symbol.BufferVar else ty.name[:-6].lower()
            to_pop.append(key)

            if _scope == "global":
                body = self.wrap_up_binds(body)

            _domain = [Range.from_min_extent(0, i) for i in _buf.shape]
            _true = tvm.runtime.convert(True)
            body = tvm.tir.ProducerRealize(
                _buf, _domain, _true, body, tvm.runtime.convert(_scope)
            )

        for elem in to_pop:
            self.symbols.pop(elem)

        return body
        def realize(buffer_slice: BufferSlice,
                    scope: str,
                    condition: bool = True,
                    span: "Optional[Span]" = None):
            """Realize the region described by ``buffer_slice`` around ``self.body``.

            Parameters
            ----------
            buffer_slice : BufferSlice
                The buffer and one slice per dimension. A slice without ``stop``
                covers a single element starting at ``start``.
            scope : str
                Storage scope; converted with ``tvm.runtime.convert``.
            condition : bool
                Realize condition passed to ``BufferRealize``.
            span : Optional[Span]
                Source span attached to the generated statements.

            Returns
            -------
            stmt : tvm.tir.AttrStmt
                A ``"realize_scope"`` attribute statement wrapping a
                ``tvm.tir.BufferRealize`` of ``self.body``.
            """
            assert self.context, "call 'exit_scope' before 'enter_scope'"
            buffer: Buffer = buffer_slice.buffer
            bounds: List[Range] = []
            for s in buffer_slice.slices:
                # Named ``start`` rather than ``min`` to avoid shadowing the builtin.
                start: Union[PrimExpr, int] = s.start
                extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start
                if isinstance(extent, PrimExpr):
                    extent = self.context.analyzer.simplify(extent)
                bounds.append(Range.from_min_extent(start, extent, span=s.span))

            scope = tvm.runtime.convert(scope, span=span)
            return tvm.tir.AttrStmt(
                buffer,
                "realize_scope",
                scope,
                tvm.tir.BufferRealize(buffer,
                                      bounds,
                                      condition,
                                      self.body,
                                      span=span),
                span=span,
            )
Beispiel #8
0
def buffer_slice_to_region(
        buffer_slice: BufferSlice,
        analyzer: Optional[Analyzer] = None) -> BufferRegion:
    """Build a BufferRegion that covers ``buffer_slice``.

    Parameters
    ----------
    buffer_slice : BufferSlice
        The slice to convert. Each of its slices yields one Range.

    analyzer : Optional[tvm.arith.Analyzer]
        Used to simplify each extent. A fresh Analyzer is created when this is
        not supplied.

    Returns
    -------
    buffer_region : BufferRegion
        A region over ``buffer_slice.buffer``. A slice without ``stop`` yields
        a unit extent; otherwise the extent is ``stop - start``. Integer starts
        are wrapped as int32 ``IntImm``.
    """
    if not analyzer:
        analyzer = Analyzer()

    ranges: List[Range] = []
    for sl in buffer_slice.slices:
        if isinstance(sl.start, PrimExpr):
            begin = sl.start
        else:
            begin = IntImm("int32", sl.start)

        if sl.stop is None:
            extent = IntImm(begin.dtype, 1)
        else:
            extent = sl.stop - sl.start

        if isinstance(extent, PrimExpr):
            extent = analyzer.simplify(extent)
        ranges.append(Range.from_min_extent(begin, extent, span=sl.span))
    return BufferRegion(buffer_slice.buffer, ranges)
Beispiel #9
0
    def as_buffer_region(self,
                         analyzer: Optional[Analyzer] = None) -> BufferRegion:
        """Build a BufferRegion that covers this BufferSlice.

        Parameters
        ----------
        analyzer : Optional[tvm.arith.Analyzer]
            Used to simplify each extent. A fresh Analyzer is created when this
            is not supplied.

        Returns
        -------
        buffer_region : BufferRegion
            A region over ``self.buffer``, with one Range per slice. If a slice
            has a step other than 1, ``self.report_error`` is called.
        """
        if not analyzer:
            analyzer = Analyzer()

        ranges: List[Range] = []
        for sl in self.slices:
            if isinstance(sl.start, PrimExpr):
                begin = sl.start
            else:
                begin = IntImm("int32", sl.start)

            if sl.stop is None:
                extent = IntImm(begin.dtype, 1)
            else:
                extent = sl.stop - sl.start

            if isinstance(extent, PrimExpr):
                extent = analyzer.simplify(extent)
            if sl.step != 1:
                self.report_error(
                    "BufferRegion do not support non-trivial stride", sl.span)
            ranges.append(Range.from_min_extent(begin, extent, span=sl.span))
        return BufferRegion(self.buffer, ranges)
Beispiel #10
0
    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Set up a for-loop scope: create loop vars and register them.

        The loop target of ``node`` may be a single var or a list of vars.
        Calling ``self.func(*arg_list)`` fills ``self.loop_info``. Each loop var
        is created with dtype int64 if its begin or extent is int64, otherwise
        int32. Each loop var is registered as a symbol, and its range is
        recorded in ``context.loop_stack``.

        Raises
        ------
        NotImplementedError
            If a loop begin or extent dtype does not start with ``"int"``.
        """
        assert isinstance(
            node, synr.ast.For
        ), f"ForScopeHandler expected synr.ast.For but got {type(node)}"

        target = node.lhs
        if isinstance(target, synr.ast.Var):
            var_nodes = [target]
        elif isinstance(target, list):
            var_nodes = target
        else:
            var_nodes = []
            context.report_error(
                f"Invalid loop var. Expected var or list of vars as lhs, but got {type(node.lhs)}",
                span,
            )

        names = []
        name_spans = []
        for var_node in var_nodes:
            if not isinstance(var_node, synr.ast.Var):
                context.report_error(
                    f"Invalid loop var. Expected a var, but got {type(var_node)}",
                    var_node.span)
            names.append(var_node.id.name)
            name_spans.append(tvm_span_from_synr(var_node.id.span))

        self.node = node
        self.context = context
        # Running self.func populates self.loop_info.
        call_with_error_reporting(context.report_error, span, self.func,
                                  *arg_list)
        if len(names) != len(self.loop_info):
            self.context.report_error(
                f"Inconsistent number of vars and loops, got {len(names)} "
                + f"vs {len(self.loop_info)}",
                self.node.span,
            )

        self.loop_vars = []
        for name, var_span, info in zip(names, name_spans, self.loop_info):
            if not info.begin.dtype.startswith("int"):
                raise NotImplementedError(
                    f"Unsupported dtype in loop begin: {info.begin.dtype}")
            if not info.extent.dtype.startswith("int"):
                raise NotImplementedError(
                    f"Unsupported dtype in loop extent: {info.extent.dtype}")
            wide = "int64" in [info.begin.dtype, info.extent.dtype]
            self.loop_vars.append(
                tvm.te.var(name, dtype="int64" if wide else "int32", span=var_span))

        for loop_var, info in zip(self.loop_vars, self.loop_info):
            context.update_symbol(loop_var.name, loop_var, node)
            context.loop_stack[loop_var] = Range.from_min_extent(
                info.begin, info.extent)
Beispiel #11
0
def test_complete_matmul():
    """Check the completed access regions of the innermost ``matmul`` block.

    The block reads C, A and B (in that order) and writes C. Each access is a
    single element indexed by the block iteration variables.
    """
    func = matmul
    A, B, C = [func.buffer_map[param] for param in func.params]

    blk = func.body.block.body.body.body.body.block
    assert isinstance(blk, tvm.tir.Block)
    vi, vj, vk = [iv.var for iv in blk.iter_vars]

    def element(buf, row, col):
        # Unit-extent region at (row, col).
        return tir.BufferRegion(
            buf, [Range.from_min_extent(row, 1), Range.from_min_extent(col, 1)]
        )

    c_region = element(C, vi, vj)
    tvm.ir.assert_structural_equal(
        blk.reads, [c_region, element(A, vi, vk), element(B, vj, vk)]
    )
    tvm.ir.assert_structural_equal(blk.writes, [c_region])
Beispiel #12
0
    def non_surjective_inverse(
            self, shape: List[Union[Range,
                                    PrimExpr]]) -> Tuple["IndexMap", PrimExpr]:
        """Return the inverse of the map

        Can be applied to transformations that introduce padding.

        Parameters
        ----------
        shape: List[Union[Range,PrimExpr]]

            The region over which the inverse should be determined.
            Used for determining the predicate. A plain ``PrimExpr`` entry
            ``n`` is treated as ``Range(0, n)``.

        Returns
        -------
        result : Tuple[IndexMap, PrimExpr]

            The inverse, and a predicate over the inverse map's input axes.
            NOTE(review): the example below shows the predicate holding on the
            padding points, e.g. j == 3, k >= 2 maps to 14 and 15, which lie
            outside [0, 14). That is, it does not hold where the inverse maps
            to a valid index. Confirm against ``IndexMapNonSurjectiveInverse``.

        Examples
        --------

        .. code-block:: python

            index_map = IndexMap.from_func(lambda i: [i//4, i%4])
            inverse_map, predicate = index_map.non_surjective_inverse([14])
            assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k]))
            print(predicate) # Prints "(axis0==3) && (axis1 >= 2)"
        """

        shape = [
            dim if isinstance(dim, Range) else Range(0, dim) for dim in shape
        ]
        return _ffi_api.IndexMapNonSurjectiveInverse(self, shape)
Beispiel #13
0
    def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap":
        """Compute the inverse of this index map.

        The mapping must be bijective over ``shape``; otherwise an error is
        raised.

        Parameters
        ----------
        shape: List[Union[Range,PrimExpr]]

            The domain over which to invert and to check bijectivity. A plain
            ``PrimExpr`` entry ``n`` stands for ``Range(0, n)``.

        Returns
        -------
        inverse : IndexMap

            The inverse mapping.
        """
        domain = []
        for dim in shape:
            domain.append(dim if isinstance(dim, Range) else Range(0, dim))
        return _ffi_api.IndexMapInverse(self, domain)
Beispiel #14
0
        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
            """Build the ``BlockRealize`` for the innermost block being parsed.

            The pieces come from ``self.context.block_info_stack[-1]`` and
            ``self.body``.

            Each axis becomes an ``IterVar`` bound to the matching block var.
            A ``PrimExpr`` extent gives domain ``[0, extent)``; a ``Range`` is
            used as-is; an ``IterVar`` contributes its ``dom`` and
            ``iter_type``.

            Iteration values come from explicit ``iter_bindings`` when present.
            Otherwise they are copied from ``self.context.loop_stack[-2]``:
            left empty for an opaque block, or filled with NaN placeholders
            when no loops enclose the block.

            Returns
            -------
            body : tvm.tir.BlockRealize
            """
            assert (
                self.node and self.context and self.body
            ), "call 'exit_scope' before 'enter_scope'"
            block_info = self.context.block_info_stack[-1]
            axes = [] if axes is None else axes
            if len(axes) != len(self.block_vars):
                self.context.report_error(
                    "Inconsistent number of block vars, "
                    + f"there are {len(axes)} axes but {len(self.block_vars)} block vars. "
                    + "The number of block vars should match the number of axes.",
                    self.node.span,
                )

            block_iters: List[IterVar] = []
            for pos, raw_axis in enumerate(axes):
                axis = tvm.runtime.convert(raw_axis)
                if isinstance(axis, tvm.tir.PrimExpr):
                    dom, iter_type = Range.from_min_extent(0, axis), 0
                elif isinstance(axis, Range):
                    dom, iter_type = axis, 0
                elif isinstance(axis, IterVar):
                    dom, iter_type = axis.dom, axis.iter_type
                else:
                    self.context.report_error(
                        "Invalid argument of tir.block(), "
                        + f"expected PrimExpr, Range or IterVar, but got {type(axis)}",
                        self.node.span,
                    )
                    continue
                block_iters.append(IterVar(dom, self.block_vars[pos], iter_type))

            # Read/write regions declared by the user (empty when not given).
            reads: List[BufferRegion] = []
            if block_info.reads:
                reads = [buffer_slice_to_region(item) for item in block_info.reads]
            writes: List[BufferRegion] = []
            if block_info.writes:
                writes = [buffer_slice_to_region(item) for item in block_info.writes]

            # Bit 0: reads must be auto-detected; bit 1: writes must be auto-detected.
            region_detect_mask: int = int(block_info.reads is None) | (
                int(block_info.writes is None) << 1
            )
            annotations = block_info.annotations
            if annotations is None:
                annotations = {}
            if region_detect_mask != 0:
                annotations["tir.script_parsing_detect_access"] = region_detect_mask
            inner = tvm.tir.Block(
                block_iters,
                reads,
                writes,
                name_hint,
                self.body,
                block_info.init,
                block_info.alloc_buffers,
                block_info.match_buffers,
                annotations,
                span,
            )

            if block_info.iter_bindings:
                for block_var in self.block_vars:
                    if block_var not in block_info.iter_bindings:
                        self.context.report_error(
                            "Missing block iter var binding for " + block_var.name,
                            self.node.span,
                        )
                values = [block_info.iter_bindings[var] for var in self.block_vars]
            else:
                values = self.context.loop_stack[-2].copy()
                if not block_iters:
                    # It is an opaque block without any bindings
                    values = []
                elif not values:
                    values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters)
                elif len(values) != len(block_iters):
                    self.context.report_error(
                        "Number of block iter var and outer loop nesting mismatch, "
                        + f"{len(block_iters)} block iter vars but {len(values)} loops",
                        self.node.span,
                    )

            if block_info.predicate is None:
                predicate = tvm.tir.const(True, "bool")
            else:
                predicate = block_info.predicate
            return tvm.tir.BlockRealize(values, predicate, inner, span)