Example #1
0
def visit_list_to_block(visit, lst):
    """Visit every statement in ``lst`` and merge the results into one block.

    Statements for which ``util.is_docstring`` is true are skipped without
    being visited. Visited results that compare equal to a no-op statement
    are dropped. If nothing is left, a no-op statement is returned; otherwise
    the survivors are handed to ``concat_list_to_block``.
    """
    # Phase 1: visit every non-docstring statement, in order.
    visited = []
    for node in lst:
        if util.is_docstring(node):
            continue
        visited.append(visit(node))

    # Phase 2: discard results that are structurally equal to a no-op.
    kept = []
    for stmt in visited:
        if _ir_pass.Equal(stmt, util.make_nop()):
            continue
        kept.append(stmt)

    return concat_list_to_block(kept) if kept else util.make_nop()
Example #2
0
    def visit_For(self, node):
        """Lower a Python ``for`` statement.

        Visiting ``node.iter`` must yield ``(iter_var, low, ext, for_type)``.
        Three cases are then handled:

        * ``for_type`` is a tuple: ``low`` and ``ext`` must simplify to
          constants. The body is visited once per iteration with the loop
          name bound as a ``Symbol.ConstLoopVar``, and the per-iteration
          results are concatenated and returned.
        * ``iter_var`` is None: a fresh variable named after the loop target
          is created and registered as a ``Symbol.LoopVar`` (shifted by
          ``low`` when ``low`` is not zero). The visited body is wrapped in a
          ``For`` node starting at 0 with extent ``ext``.
        * otherwise: ``iter_var`` is registered as a ``Symbol.ThreadBind``
          and the body is returned without a ``For`` wrapper.
          ``self.device`` is incremented while the body is visited.

        The loop name is removed from ``self.symbols`` before returning.
        """
        iter_var, low, ext, for_type = self.visit(node.iter)
        _internal_assert(isinstance(node.target, ast.Name),
                         "The loop iterator should be a variable!")

        loop_name = node.target.id

        if isinstance(for_type, tuple):
            low = _ir_pass.CanonicalSimplify(low)
            ext = _ir_pass.CanonicalSimplify(ext)
            _internal_assert(
                isinstance(low, _expr.ConstExpr) and isinstance(ext, _expr.ConstExpr),
                "Const range should start from a const " + "and iterate const times")

            start, count = low.value, ext.value
            if count > 114514:
                logging.log(logging.CRITICAL,
                            '[Warning] Are you sure to unroll a large loop in Python?')

            unrolled = []
            for idx in range(start, start + count):
                # Bind the loop name to this iteration's constant, lower the
                # body, then unbind so the next iteration can rebind it.
                self.add_symbol(loop_name, Symbol.ConstLoopVar, idx)
                stmt = visit_list_to_block(self.visit, node.body)
                stmt = self.wrap_up_realize(node, stmt)
                unrolled.append(stmt)
                self.symbols.pop(loop_name)
            return concat_list_to_block(unrolled)

        if iter_var is not None:
            _internal_assert(for_type is None,
                             "The loop bind function parse error!")
            self.add_symbol(loop_name, Symbol.ThreadBind, iter_var)
            self.device += 1
            loop_body = visit_list_to_block(self.visit, node.body)
            self.device -= 1
        else:
            _internal_assert(for_type is not None,
                             "The loop iterating function parse error!")
            iter_var = _api.var(loop_name)
            # The emitted loop always starts at 0, so a non-zero lower bound
            # is folded into the value the loop name resolves to.
            shifted = iter_var
            if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                shifted = iter_var + low
            self.add_symbol(loop_name, Symbol.LoopVar, shifted)
            loop_body = visit_list_to_block(self.visit, node.body)

        loop_body = self.wrap_up_realize(node, loop_body)

        if for_type is not None:
            _internal_assert(not isinstance(for_type, tuple),
                             "Micro expansion should be handled before!")
            loop_body = tvm.tir.For(iter_var, _api.const(0, 'int32'), ext, for_type,
                                    0, loop_body)

        self.symbols.pop(loop_name)
        return loop_body
Example #3
0
def _range(annotation, args):
    """Handling TVM loop types.

    Parameters
    ----------
    annotation
        Key looked up in ``LOOP_INTRIN`` to obtain the loop type.
    args
        Sequence of one or two expressions. With one element it is the
        extent and the lower bound is the int32 constant 0. With two
        elements they are the lower bound and the upper bound.

    Returns
    -------
    tuple
        ``(iter_var, low, ext, for_type)``. ``iter_var`` is always ``None``
        and ``ext`` is the upper bound minus ``low`` when ``low`` is not
        equal to the int32 constant 0.
    """
    # Use the builtin rather than calling the dunder method directly.
    n = len(args)
    if n == 1:
        low, ext = _api.const(0, dtype='int32'), args[0]
    else:
        _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
        low, ext = args[0], args[1]
    # Turn the upper bound into an extent; skipped for a zero lower bound so
    # no redundant subtraction expression is built.
    if not ir_pass.Equal(low, _api.const(0, dtype='int32')):
        ext = ext - low
    for_type = LOOP_INTRIN[annotation]
    iter_var = None
    return iter_var, low, ext, for_type
Example #4
0
    def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
        """Add value to the symbol table context.

        Parameters
        ----------
        key
            Name to register. Registering a name that is already present in
            ``self.symbols`` is reported through ``_internal_assert``.
        ty
            Symbol kind stored alongside the value.
        val
            Value to store. When ``ty`` is ``Symbol.ThreadBind`` it must
            expose ``var.name`` and ``dom.extent``.

        Notes
        -----
        For ``Symbol.ThreadBind`` the first value seen for a given
        ``val.var.name`` is remembered in ``self.binds``. Later
        registrations under the same name must have an equal extent, and
        the symbol table entry is pointed at the remembered value instead
        of ``val``.
        """
        # Membership is tested on the dict itself; going through ``.keys()``
        # is redundant (and builds a list on Python 2).
        if key in self.symbols:
            old = str(self.symbols[key])
            new = str((ty, val))
            _internal_assert(False,
                             "Name conflict in symbol table! [%s] %s -> %s" % (key, old, new))

        self.symbols[key] = ty, val

        if ty == Symbol.ThreadBind:
            bind_name = val.var.name
            if bind_name not in self.binds:
                self.binds[bind_name] = val
                return
            val_ = self.binds[bind_name]
            _internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent),
                             "Thread extents should be uniform!")
            # Reuse the previously recorded binding for this name.
            self.symbols[key] = ty, val_