Ejemplo n.º 1
0
    def visit_FunctionDef(self, node):
        """Lower a Python ``FunctionDef`` node into a ``tvm.tir.PrimFunc``.

        Relevant AST abstract grammar:
            FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list,
                        expr? returns, string? type_comment)
            arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
                         expr* kw_defaults, arg? kwarg, expr* defaults)
            arg = (identifier arg, expr? annotation, string? type_comment)

        Only ``node.args.args`` (plain positional parameters) are turned into
        PrimFunc parameters; the other fields of ``node.args`` are not read here.
        The built function is stored in ``self.functions`` keyed by a
        ``GlobalVar`` named after the Python function, and is also returned.
        """
        self.init_function_parsing_env()

        # Every positional parameter becomes a typed TIR var that is bound in
        # the current scope and recorded as a PrimFunc parameter.
        for param in node.args.args:
            param_name = param.arg
            param_var = tvm.te.var(param_name, self.parse_type(param.annotation))
            self.scope_emitter.update_symbol(param_name, param_var)
            self.params.append(param_var)

        # Queue the body statements on the innermost node stack. They are pushed
        # reversed, presumably because the stack is consumed from its end
        # (LIFO) so statements come out in source order -- TODO confirm.
        pending_stmts = self.scope_emitter.node_stack[-1]
        pending_stmts.extend(reversed(node.body))

        # Keep this exact order: get_body() runs before buffer_map and dict_attr
        # are read (as in a single call expression). Visiting the body
        # presumably populates them -- TODO confirm.
        params = self.params
        body = self.get_body()
        ret_type = self.parse_type(node.returns)
        buffer_map = self.buffer_map
        attrs = tvm.ir.make_node("DictAttrs", **self.dict_attr)

        func = tvm.tir.PrimFunc(
            params,
            body,
            ret_type=ret_type,
            buffer_map=buffer_map,
            attrs=attrs,
        )

        func_symbol = GlobalVar(node.name)
        self.functions[func_symbol] = func
        return func
Ejemplo n.º 2
0
def test_global_var_supply_from_none():
    """A default-constructed GlobalVarSupply returns a reserved var by name,
    while ``fresh_global`` produces a var that differs from it."""
    supply = GlobalVarSupply()
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    looked_up = supply.unique_global_for("test")
    assert structural_equal(looked_up, reserved)

    minted = supply.fresh_global("test")
    assert not structural_equal(minted, reserved)
Ejemplo n.º 3
0
def test_global_var_supply_from_name_supply():
    """GlobalVarSupply built on a prefixed NameSupply: looking up the reserved
    name with the second argument ``False`` yields the reserved var, while the
    default lookup does not."""
    prefixed_names = NameSupply("prefix")
    supply = GlobalVarSupply(prefixed_names)
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    # Second positional argument False -- presumably disables applying the
    # NameSupply prefix; verify against the GlobalVarSupply API.
    unprefixed = supply.unique_global_for("test", False)
    assert structural_equal(unprefixed, reserved)

    default_lookup = supply.unique_global_for("test")
    assert not structural_equal(default_lookup, reserved)
Ejemplo n.º 4
0
def test_global_var_supply_from_ir_mod():
    """GlobalVarSupply seeded from an IRModule: the module's existing global is
    returned by an unprefixed lookup, and freshly minted or default-looked-up
    vars differ from it."""
    lhs = relay.var("x")
    rhs = relay.var("y")
    existing = GlobalVar("test")

    module = tvm.IRModule()
    module[existing] = relay.Function([lhs, rhs], relay.add(lhs, rhs))
    supply = GlobalVarSupply(module)

    # Minted before the lookups below, exactly as in the original ordering.
    minted = supply.fresh_global("test", False)

    unprefixed = supply.unique_global_for("test", False)
    assert structural_equal(unprefixed, existing)

    default_lookup = supply.unique_global_for("test")
    assert not structural_equal(default_lookup, existing)

    assert not structural_equal(minted, existing)
Ejemplo n.º 5
0
    def transform_Class(self, node):
        """Transform a class definition into an ``IRModule``.

        The class body may contain any number of function definitions and at
        most one top-level assignment, which must be ``__tvm_meta__ = ...``.
        Each function becomes an entry of the resulting module, keyed by a
        ``GlobalVar`` carrying the function's name.

        Example
        -------
        .. code-block:: python

            @tvm.script.ir_module
            class MyClass:
                __tvm_meta__ = {}
                def A():
                    T.evaluate(0)
        """
        assignments = node.assignments

        if len(assignments) > 1:
            # Point the diagnostic at every assignment after the first one.
            extra_span = ast.Span.union([extra.span for extra in assignments[1:]])
            self.report_error(
                "Only a single top level `__tvm_meta__` is allowed",
                extra_span,
            )
        elif len(assignments) == 1:
            meta_assign = assignments[0]
            targets = meta_assign.lhs
            is_meta_assign = (
                len(targets) == 1
                and isinstance(targets[0], ast.Var)
                and targets[0].id.name == "__tvm_meta__"
            )
            if not is_meta_assign:
                self.report_error(
                    "The only top level assignments allowed are `__tvm_meta__ = ...`",
                    meta_assign.span,
                )
            # Reached after report_error as well, if report_error returns
            # rather than raising (not visible here).
            meta_value = MetaUnparser().do_transform(meta_assign.rhs, self._diagnostic_context)
            self.init_meta(meta_value)

        functions = {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()}
        return IRModule(functions)