Exemplo n.º 1
0
def get_arch(device_id=0):
    """Return the CUDA compute capability of a GPU as a float.

    The capability is encoded as ``major + minor / 10``, e.g. ``(8, 6)``
    becomes ``8.6``. If the ``pynvml`` module cannot be imported,
    "NVML not found" is printed and ``0`` is returned.

    Args:
        device_id: NVML index of the device to query.

    Returns:
        The compute capability as a float, or ``0`` when pynvml is not
        available.

    Exceptions raised by the pynvml calls themselves are not caught and
    propagate to the caller.
    """
    compute_cap = 0
    try:
        import pynvml
    except ModuleNotFoundError:
        print("NVML not found")
        return compute_cap

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        compute_cap = major + minor / 10.
    finally:
        # nvmlInit() is reference counted; pair it with nvmlShutdown() so
        # repeated calls do not leave NVML initialized indefinitely.
        pynvml.nvmlShutdown()
    return compute_cap
Exemplo n.º 2
0
def is_gds_supported(device_id=0):
    """Report whether GDS can be used on this host and device.

    GDS (presumably GPUDirect Storage -- confirm) is reported as supported
    when the host architecture is x86_64 and the device's CUDA compute
    capability is at least 6.0. If ``pynvml`` cannot be imported, the
    compute capability is treated as 0, so the result is ``False``.

    The result is memoised in the module-level ``is_gds_supported_var``;
    ``None`` there means "not computed yet".

    NOTE(review): the cache is not keyed by ``device_id``. After the first
    call, later calls return the first queried device's answer regardless
    of ``device_id``.

    Args:
        device_id: NVML index of the device to query.

    Returns:
        ``True`` if GDS is considered supported, ``False`` otherwise.
    """
    global is_gds_supported_var
    if is_gds_supported_var is not None:
        return is_gds_supported_var

    compute_cap = 0
    try:
        import pynvml
    except ModuleNotFoundError:
        pass  # Best effort: without NVML the capability stays 0.
    else:
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            compute_cap = major + minor / 10.
        finally:
            # nvmlInit() is reference counted; always release our reference.
            pynvml.nvmlShutdown()

    # platform.machine() reports the architecture reliably.
    # platform.processor() returns '' on many Linux distributions (where
    # `uname -p` prints "unknown"), which wrongly disabled GDS on x86_64.
    is_gds_supported_var = (platform.machine() == "x86_64"
                            and compute_cap >= 6.0)
    return is_gds_supported_var
Exemplo n.º 3
0
def is_of_supported(device_id=0):
    """Report whether OpticalFlow can be used on this host and device.

    OpticalFlow is reported as supported when the device's CUDA compute
    capability is at least 7.5 and either the host architecture is x86_64
    or the NVIDIA driver major version is below 495. If ``pynvml`` cannot be
    imported, "NVML not found" is printed and the result is ``False``.

    The result is memoised in the module-level ``is_of_supported_var``;
    ``None`` there means "not computed yet".

    NOTE(review): the cache is not keyed by ``device_id``. After the first
    call, later calls return the first queried device's answer regardless
    of ``device_id``.

    Args:
        device_id: NVML index of the device to query.

    Returns:
        ``True`` if OpticalFlow is considered supported, ``False`` otherwise.
    """
    global is_of_supported_var
    if is_of_supported_var is not None:
        return is_of_supported_var

    compute_cap = 0
    driver_version_major = 0
    try:
        import pynvml
    except ModuleNotFoundError:
        print("NVML not found")
    else:
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            compute_cap = major + minor / 10.
            driver_version = pynvml.nvmlSystemGetDriverVersion()
            # Older pynvml releases return bytes; newer nvidia-ml-py releases
            # return str, on which .decode() would raise AttributeError.
            if isinstance(driver_version, bytes):
                driver_version = driver_version.decode('utf-8')
            driver_version_major = int(driver_version.split('.')[0])
        finally:
            # nvmlInit() is reference counted; always release our reference.
            pynvml.nvmlShutdown()

    # There is an issue with the OpticalFlow driver in R495 and newer on
    # non-x86_64 platforms (e.g. aarch64), so only older drivers are
    # accepted there.
    is_of_supported_var = compute_cap >= 7.5 and (
        platform.machine() == "x86_64" or driver_version_major < 495)
    return is_of_supported_var