Exemple #1
0
def from_source(
        input_func: Union[str, Callable],
        tir_prefix: Optional[List[str]] = None) -> Union[PrimFunc, IRModule]:
    """Parse function or string into PrimFunc or IRModule.

    If possible, pass the TVM script in as a function so that line numbers and
    filename will be accurate.

    Parameters
    ----------
    input_func : Union[str, Callable]
        The TVM script to parse, given either as source text or as a python
        function.

    tir_prefix : Optional[List[str]]
        The tir prefix list. Only used for str input, where it defaults to
        ["T", "tir"]. For function input it is ignored: the prefixes are the
        names in the function's globals that are bound to ``tir``.

    Returns
    -------
    output : Union[PrimFunc, IRModule]
        The PrimFunc or IRModule produced by ``to_ast``.

    Raises
    ------
    TypeError
        If ``input_func`` is neither a str nor a python function.
    """
    if isinstance(input_func, str):
        # Source text carries no location information, so parsing starts at
        # line 0 and the caller-supplied (or default) prefixes are used.
        prefixes = ["T", "tir"] if tir_prefix is None else tir_prefix
        return to_ast(input_func, TVMDiagnosticCtx(),
                      TVMScriptParser(0, prefixes))

    if inspect.isfunction(input_func):
        _, start_line = inspect.getsourcelines(input_func)
        env: Dict[str, Any] = input_func.__globals__
        # Compare by identity rather than ``==``: the globals may hold
        # arbitrary objects whose ``__eq__`` raises, returns a non-bool
        # (e.g. array types), or spuriously reports equality.
        namespace = [name for name, value in env.items() if value is tir]
        parser = TVMScriptParser(start_line, namespace)
        return to_ast(input_func, TVMDiagnosticCtx(), parser)

    raise TypeError("Only function definitions are supported.")
Exemple #2
0
def to_ast(program: Any) -> Any:
    """Run ``synr.to_ast`` on ``program`` and fail loudly on a string result.

    Parameters
    ----------
    program : Any
        The object handed to ``synr.to_ast`` as the program to parse.

    Returns
    -------
    Any
        Whatever ``synr.to_ast`` returned, provided it is not a ``str``.

    Raises
    ------
    RuntimeError
        If ``synr.to_ast`` returns a ``str``; the string becomes the
        exception message.
    """
    printer_ctx = synr.PrinterDiagnosticContext()
    # No transformer is supplied, so ``None`` is passed in that position.
    parsed = synr.to_ast(program, printer_ctx, None)
    if not isinstance(parsed, str):
        return parsed
    raise RuntimeError(parsed)
Exemple #3
0
def from_source(src):
    """Parse a TVM script, given as text or as a python object, via ``to_ast``.

    Passing the script as a python object rather than a string lets the
    parser start from the object's real source line.

    Parameters
    ----------
    src : [str, function, class]
        The script to parse. A ``str`` is parsed starting at line 0; for any
        other object the starting line is taken from
        ``inspect.getsourcelines``.

    Returns
    -------
    functions : PrimFunc or IRModule
        The result of ``to_ast`` for ``src``.
    """
    # Plain text has no known position in a file, so it starts at line 0.
    first_line = 0 if isinstance(src, str) else inspect.getsourcelines(src)[1]
    script_parser = TVMScriptParser(first_line)
    return to_ast(src, TVMDiagnosticCtx(), script_parser)
Exemple #4
0
def to_ast_err(program: Any) -> Any:
    """Run ``synr.to_ast`` on ``program`` using an ``ErrorAccumulator``.

    Parameters
    ----------
    program : Any
        The object handed to ``synr.to_ast`` as the program to parse.

    Returns
    -------
    Any
        The value returned by ``synr.to_ast``, passed through unchanged.
    """
    accumulator = ErrorAccumulator()
    # No transformer is supplied, so ``None`` is passed in that position.
    return synr.to_ast(program, accumulator, None)