コード例 #1
0
ファイル: utils.py プロジェクト: lfengad/incubator-tvm
def get_global_func_on_rpc_session(
    session: RPCSession,
    name: str,
    extra_error_msg: Optional[str] = None,
) -> PackedFunc:
    """Look up a globally registered PackedFunc through an RPCSession.

    Parameters
    ----------
    session : RPCSession
        The session whose ``get_function`` is used for the lookup.
    name : str
        The registered name of the PackedFunc to fetch.
    extra_error_msg : Optional[str]
        Optional hint appended (after a single space) to the error message
        when the lookup fails. Ignored when ``None`` or empty.

    Returns
    -------
    result : PackedFunc
        Whatever ``session.get_function(name)`` returned.

    Raises
    ------
    AttributeError
        If ``session.get_function(name)`` raises ``AttributeError``. The
        original exception is chained as ``__cause__``.
    """
    try:
        func = session.get_function(name)
    except AttributeError as exc:
        # Build the message piecewise so the optional hint is only
        # appended when it actually carries text.
        pieces = [f'Unable to find function "{name}" on the remote RPC server.']
        if extra_error_msg:
            pieces.append(extra_error_msg)
        raise AttributeError(" ".join(pieces)) from exc
    else:
        return func
コード例 #2
0
ファイル: rpc_runner.py プロジェクト: junrushao1994/tvm
def default_upload_module(
    session: RPCSession,
    local_path: str,
    remote_path: str,
) -> Module:
    """Upload a module file over a session and load it on the remote side.

    Parameters
    ----------
    session: RPCSession
        The session used for both the upload and the load.
    local_path: str
        Path of the module file on the local machine.
    remote_path: str
        Destination path on the remote side; the module is loaded from here.

    Returns
    -------
    rt_mod : Module
        The result of ``session.load_module(remote_path)``.
    """
    # The upload must complete first: the load below reads from the same
    # remote path the file was just copied to.
    session.upload(local_path, remote_path)
    return session.load_module(remote_path)