def get_device(libmod, device):
    """Normalize and validate the device argument(s) for a module.

    Parameters
    ----------
    libmod : tvm.runtime.Module
        The module of the corresponding function.

    device : Device or list of Device
        A single device, or a list/tuple of devices. A single device is
        wrapped into a one-element list.

    Returns
    -------
    device : list of Device
        The normalized device sequence (a tuple argument is returned as-is).
    num_rpc_dev : int
        How many of the devices are RPC devices.
    device_type_id : list of int
        Flattened ``[type0, id0, type1, id1, ...]`` pairs. For RPC devices the
        type is reduced modulo ``rpc_base.RPC_SESS_MASK``.

    Raises
    ------
    ValueError
        If ``device`` is not a Device / list / tuple, if any element is not a
        Device, or if RPC and non-RPC devices are mixed.
    """
    bad_type_msg = "dev has to be the type of Device or a list of Device"
    if isinstance(device, Device):
        device = [device]
    elif not isinstance(device, (list, tuple)):
        raise ValueError(bad_type_msg)
    if any(not isinstance(dev, Device) for dev in device):
        raise ValueError(bad_type_msg)

    # The first (type, id) pair is the primary/fallback device; any further
    # pairs are used for heterogeneous execution.
    rpc_count = 0
    type_and_id = []
    for dev in device:
        dev_type = dev.device_type
        if dev_type >= rpc_base.RPC_SESS_MASK:
            # An RPC device requires libmod to be an "rpc" module whose session
            # table index matches the device's RPC session.
            assert libmod.type_key == "rpc"
            assert _rpc_ffi_api.SessTableIndex(libmod) == dev._rpc_sess._tbl_index
            rpc_count += 1
            dev_type %= rpc_base.RPC_SESS_MASK
        type_and_id.extend((dev_type, dev.device_id))

    if 0 < rpc_count < len(device):
        raise ValueError("Either all or none of the devices should be rpc.")
    return device, rpc_count, type_and_id
def get_device_ctx(libmod, ctx):
    """Parse and validate all the device context(s).

    Parameters
    ----------
    libmod : tvm.runtime.Module
        The module of the corresponding function.

    ctx : TVMContext or list of TVMContext
        A single context, or a list/tuple of contexts. A single context is
        wrapped into a one-element list.

    Returns
    -------
    ctx : list of TVMContext
        The normalized contexts (a tuple argument is returned as-is).
    num_rpc_ctx : int
        Number of rpc contexts.
    device_type_id : list of int
        Flattened ``[type0, id0, type1, id1, ...]`` pairs. For RPC contexts the
        type is reduced modulo ``rpc_base.RPC_SESS_MASK``.

    Raises
    ------
    ValueError
        If ``ctx`` is not a TVMContext / list / tuple, if any element is not a
        TVMContext, or if RPC and non-RPC contexts are mixed.
    """
    # One shared message so both raise sites report the same, correctly
    # spelled type name (the previous first message read "TVMCTVMContext").
    type_error_msg = "ctx has to be the type of TVMContext or a list of TVMContext"
    if isinstance(ctx, TVMContext):
        ctx = [ctx]
    elif not isinstance(ctx, (list, tuple)):
        raise ValueError(type_error_msg)
    for cur_ctx in ctx:
        if not isinstance(cur_ctx, TVMContext):
            raise ValueError(type_error_msg)

    # device_type_id[0], device_type_id[1] are used as the primary/fallback
    # context type and id. All other ones are used as device context for
    # heterogeneous execution.
    num_rpc_ctx = 0
    device_type_id = []
    for cur_ctx in ctx:
        device_type = cur_ctx.device_type
        if device_type >= rpc_base.RPC_SESS_MASK:
            # An RPC context requires libmod to be an "rpc" module whose
            # session table index matches the context's RPC session.
            assert libmod.type_key == "rpc", "RPC contexts require an rpc module"
            assert (
                _rpc_ffi_api.SessTableIndex(libmod) == cur_ctx._rpc_sess._tbl_index
            ), "ctx RPC session table index does not match libmod's"
            num_rpc_ctx += 1
            device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
        device_type_id.append(device_type)
        device_type_id.append(cur_ctx.device_id)

    if 0 < num_rpc_ctx < len(ctx):
        raise ValueError("Either all or none of the contexts should be rpc.")
    return ctx, num_rpc_ctx, device_type_id