def get_device(libmod, device):
    """Normalize ``device`` and flatten it into (type, id) pairs.

    Parameters
    ----------
    libmod : tvm.runtime.Module
        The module of the corresponding function. It is only inspected when
        a device carries the RPC session mask, in which case it must be an
        RPC module whose session table index matches the device's session.

    device : Device or list of Device
        A single Device (wrapped into a one-element list) or a list/tuple of
        Device objects (returned as the same object, not copied).

    Returns
    -------
    device : list of Device
        The normalized device sequence.
    num_rpc_dev : int
        How many devices had a type at or above ``rpc_base.RPC_SESS_MASK``;
        always either 0 or ``len(device)``.
    device_type_id : list
        Flattened ``[type0, id0, type1, id1, ...]`` with the RPC session mask
        stripped from each device type.

    Raises
    ------
    ValueError
        If ``device`` is neither a Device nor a list/tuple of Device, or if
        only some of the devices are RPC devices.
    AssertionError
        If an RPC device does not match ``libmod`` (these checks are plain
        ``assert`` statements and are skipped under ``python -O``).
    """
    type_error = "dev has to be the type of Device or a list of Device"

    if isinstance(device, Device):
        device = [device]
    elif not isinstance(device, (list, tuple)):
        raise ValueError(type_error)
    # Validate every element before touching any RPC state.
    if not all(isinstance(dev, Device) for dev in device):
        raise ValueError(type_error)

    # The first (type, id) pair is the primary/fallback device; any further
    # pairs describe devices used for heterogeneous execution.
    rpc_count = 0
    flat_ids = []
    for dev in device:
        dev_type = dev.device_type
        if dev_type >= rpc_base.RPC_SESS_MASK:
            # Remote device: libmod must be an RPC module sharing the
            # device's session table index.
            assert libmod.type_key == "rpc"
            assert _rpc_ffi_api.SessTableIndex(
                libmod) == dev._rpc_sess._tbl_index
            rpc_count += 1
            # Strip the session mask to recover the underlying device type.
            dev_type %= rpc_base.RPC_SESS_MASK
        flat_ids.extend((dev_type, dev.device_id))

    # rpc_count never exceeds len(device), so this rejects a partial mix.
    if rpc_count and rpc_count != len(device):
        raise ValueError("Either all or none of the devices should be rpc.")
    return device, rpc_count, flat_ids
# Example #2
# 0
def get_device_ctx(libmod, ctx):
    """Parse and validate all the device context(s).

    Parameters
    ----------
    libmod : tvm.runtime.Module
        The module of the corresponding function. It is only inspected when
        a context carries the RPC session mask, in which case it must be an
        RPC module (``type_key == "rpc"``) whose session table index matches
        the context's session.

    ctx : TVMContext or list of TVMContext
        A single context (wrapped into a one-element list) or a list/tuple of
        contexts (returned as the same object, not copied).

    Returns
    -------
    ctx : list or tuple of TVMContext
        The normalized context sequence.
    num_rpc_ctx : int
        Number of RPC contexts; always either 0 or ``len(ctx)``.
    device_type_id : list
        Flattened ``[type0, id0, type1, id1, ...]`` with the RPC session mask
        stripped from each device type.

    Raises
    ------
    ValueError
        If ``ctx`` is neither a TVMContext nor a list/tuple of TVMContext, or
        if only some of the contexts are RPC contexts.
    AssertionError
        If an RPC context does not match ``libmod``. These checks are plain
        ``assert`` statements and are skipped under ``python -O``.
    """
    # One message for both type checks (the first one previously contained
    # the typo "TVMCTVMContext").
    type_msg = "ctx has to be the type of TVMContext or a list of TVMContext"

    if isinstance(ctx, TVMContext):
        ctx = [ctx]
    elif not isinstance(ctx, (list, tuple)):
        raise ValueError(type_msg)
    # Validate every element before touching any RPC state.
    for cur_ctx in ctx:
        if not isinstance(cur_ctx, TVMContext):
            raise ValueError(type_msg)

    # device_type_id[0], device_type_id[1] are used as the primary/fallback
    # context type and id. All other ones are used as device context for
    # heterogeneous execution.
    num_rpc_ctx = 0
    device_type_id = []
    for cur_ctx in ctx:
        device_type = cur_ctx.device_type
        if device_type >= rpc_base.RPC_SESS_MASK:
            # A remote context requires libmod to be an RPC module that
            # shares the context's session table index.
            assert libmod.type_key == "rpc", (
                "RPC context requires an RPC module, got type_key=%r"
                % (libmod.type_key,))
            assert _rpc_ffi_api.SessTableIndex(
                libmod) == cur_ctx._rpc_sess._tbl_index, (
                    "RPC context and libmod use different RPC sessions")
            num_rpc_ctx += 1
            # Strip the session mask to recover the underlying device type.
            device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
        device_type_id.append(device_type)
        device_type_id.append(cur_ctx.device_id)

    if 0 < num_rpc_ctx < len(ctx):
        raise ValueError("Either all or none of the contexts should be rpc.")
    return ctx, num_rpc_ctx, device_type_id