def visit_FunctionDef(self, node):
    """Translate a Python ``FunctionDef`` AST node into a ``tvm.tir.PrimFunc``.

    Relevant AST abstract grammar:
        FunctionDef(identifier name, arguments args, stmt* body,
                    expr* decorator_list, expr? returns, string? type_comment)
        arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
                     expr* kw_defaults, arg? kwarg, expr* defaults)
        arg = (identifier arg, expr? annotation, string? type_comment)

    Only the parameters in ``node.args.args`` become function parameters here;
    the other argument kinds listed in the grammar above are not read.

    The resulting PrimFunc is stored in ``self.functions`` under
    ``GlobalVar(node.name)`` and also returned.
    """
    self.init_function_parsing_env()

    # One TIR variable per parameter, typed from its annotation; it is both
    # bound in the current scope and recorded in the parameter list.
    for param in node.args.args:
        param_name = param.arg
        param_var = tvm.te.var(param_name, self.parse_type(param.annotation))
        self.scope_emitter.update_symbol(param_name, param_var)
        self.params.append(param_var)

    # Queue the body statements on the innermost node stack in reverse order
    # (presumably so they are popped back off in source order — confirm).
    self.scope_emitter.node_stack[-1].extend(reversed(node.body))

    # Keep this evaluation order: the body is collected via get_body() before
    # the return type, buffer map and attributes are read, since parsing the
    # body may populate buffer_map / dict_attr.
    params = self.params
    body = self.get_body()
    ret_type = self.parse_type(node.returns)
    buffer_map = self.buffer_map
    attrs = tvm.ir.make_node("DictAttrs", **self.dict_attr)

    prim_func = tvm.tir.PrimFunc(
        params,
        body,
        ret_type=ret_type,
        buffer_map=buffer_map,
        attrs=attrs,
    )
    self.functions[GlobalVar(node.name)] = prim_func
    return prim_func
def test_global_var_supply_from_none():
    """A default-constructed GlobalVarSupply returns reserved vars and mints new ones."""
    supply = GlobalVarSupply()
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    # Looking up the reserved name gives a var structurally equal to the reserved one.
    assert structural_equal(supply.unique_global_for("test"), reserved)

    # A fresh global requested under the same name must differ from the reserved one.
    fresh = supply.fresh_global("test")
    assert not structural_equal(fresh, reserved)
def test_global_var_supply_from_name_supply():
    """A GlobalVarSupply backed by a prefixed NameSupply honors the prefix flag."""
    prefixed_names = NameSupply("prefix")
    supply = GlobalVarSupply(prefixed_names)
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    # Passing False as the second argument (presumably the add-prefix flag)
    # resolves to the reserved var.
    unprefixed = supply.unique_global_for("test", False)
    assert structural_equal(unprefixed, reserved)

    # With default arguments, a different var is produced.
    default_lookup = supply.unique_global_for("test")
    assert not structural_equal(default_lookup, reserved)
def test_global_var_supply_from_ir_mod():
    """A GlobalVarSupply seeded from an IRModule knows the module's existing globals."""
    lhs = relay.var("x")
    rhs = relay.var("y")
    module = tvm.IRModule()
    existing = GlobalVar("test")
    module[existing] = relay.Function([lhs, rhs], relay.add(lhs, rhs))

    supply = GlobalVarSupply(module)
    # Mint the fresh global before the lookups below; the call order is kept as-is.
    minted = supply.fresh_global("test", False)

    assert structural_equal(supply.unique_global_for("test", False), existing)
    assert not structural_equal(supply.unique_global_for("test"), existing)
    assert not structural_equal(minted, existing)
def transform_Class(self, node):
    """Transform a class definition into an :code:`IRModule`.

    The class may contain several function definitions and at most one
    top-level assignment, which must be :code:`__tvm_meta__ = ...`. Every
    function in the class becomes an entry of the returned :code:`IRModule`,
    keyed by :code:`GlobalVar(name)`.

    Example
    -------
    .. code-block:: python

        @tvm.script.ir_module
        class MyClass:
            __tvm_meta__ = {}
            def A():
                T.evaluate(0)
    """
    assignments = node.assignments

    if len(assignments) > 1:
        # Report every assignment after the first one as an error.
        self.report_error(
            "Only a single top level `__tvm_meta__` is allowed",
            ast.Span.union([extra.span for extra in assignments[1:]]),
        )
    elif len(assignments) == 1:
        meta_assign = assignments[0]
        targets = meta_assign.lhs
        is_meta_assignment = (
            len(targets) == 1
            and isinstance(targets[0], ast.Var)
            and targets[0].id.name == "__tvm_meta__"
        )
        if not is_meta_assignment:
            self.report_error(
                "The only top level assignments allowed are `__tvm_meta__ = ...`",
                meta_assign.span,
            )
        # NOTE(review): no early return after report_error, so this relies on
        # report_error not returning normally for invalid input — confirm.
        self.init_meta(
            MetaUnparser().do_transform(meta_assign.rhs, self._diagnostic_context)
        )

    functions = {
        GlobalVar(name): self.transform(func) for name, func in node.funcs.items()
    }
    return IRModule(functions)