Beispiel #1
0
    def wrap_up_realize(self, node, body):
        """Wrap up all the variables which will no longer be used.

        Walks ``self.usage`` and picks out each symbol that:

        * is present in ``self.symbols``,
        * has a recorded level equal to ``node``, and
        * has a buffer-like type other than ``Symbol.Input`` /
          ``Symbol.OutputBuffer``.

        For each such symbol, ``body`` is wrapped in a ``tvm.tir.Realize``
        of that buffer inside a ``'realize_scope'`` ``tvm.tir.AttrStmt``.
        Every symbol realized this way is removed from ``self.symbols``
        once the walk is complete.

        Parameters
        ----------
        node
            The level being closed; compared against the second field of
            each ``self.usage`` entry.
        body
            The statement to wrap.

        Returns
        -------
        The wrapped statement, or ``body`` itself when nothing is realized.
        """
        realized = []
        for key, (_, level, _) in self.usage.items():
            # Symbols that were never visited have no table entry, so there
            # is nothing to realize for them.  This guard also guarantees the
            # ``self.symbols[key]`` lookup below succeeds, so no separate
            # existence assertion is needed.
            if key not in self.symbols or level != node:
                continue

            ty, buf = self.symbols[key]  # pylint: disable=invalid-name
            # Inputs and output buffers are not realized here.
            if ty in [Symbol.Input, Symbol.OutputBuffer]:
                continue
            # Only symbols whose type name mentions 'Buffer' are realized.
            if 'Buffer' not in ty.name:
                continue

            # BufferVar is realized in global scope.  Other buffer kinds take
            # their scope from the type name minus its last len('Buffer') == 6
            # characters, lower-cased.
            # NOTE(review): assumes such type names end in 'Buffer' — confirm.
            scope = 'global' if ty is Symbol.BufferVar else ty.name[:-len('Buffer')].lower()
            realized.append(key)

            if scope == 'global':
                body = self.wrap_up_binds(body)

            domain = [Range.make_by_min_extent(0, extent) for extent in buf.shape]
            dtype = buf.dtype
            true_cond = tvm.runtime.convert(True)
            body = tvm.tir.Realize(buf.op, 0, dtype, domain, true_cond, body)
            body = tvm.tir.AttrStmt(buf.op, 'realize_scope', tvm.runtime.convert(scope), body)

        # Pop after the walk so the table is not mutated mid-iteration.
        for key in realized:
            self.symbols.pop(key)

        return body
Beispiel #2
0
def _get_region(tslice):
    region = []
    for idx in tslice.indices:
        if isinstance(idx, slice):
            assert idx.step is None
            region.append(Range(idx.start, idx.stop))
        else:
            if isinstance(idx, tvm.tir.IterVar):
                begin = idx.var
            else:
                begin = idx
            region.append(Range.make_by_min_extent(begin, 1))
    return region