Example #1
def test_nvidia_device(idx: int):
    from py3nvml import py3nvml as nvml

    handle = nvml.nvmlDeviceGetHandleByIndex(idx)

    pciInfo = nvml.nvmlDeviceGetPciInfo(handle)

    brands = {
        nvml.NVML_BRAND_UNKNOWN: "Unknown",
        nvml.NVML_BRAND_QUADRO: "Quadro",
        nvml.NVML_BRAND_TESLA: "Tesla",
        nvml.NVML_BRAND_NVS: "NVS",
        nvml.NVML_BRAND_GRID: "Grid",
        nvml.NVML_BRAND_GEFORCE: "GeForce"
    }

    inspect(
        idx=idx,
        # id=pciInfo.busId,
        # uuid=nvml.nvmlDeviceGetUUID(handle),
        name=nvml.nvmlDeviceGetName(handle),
        # brand=brands[nvml.nvmlDeviceGetBrand(handle)],
        # multi_gpu=nvml.nvmlDeviceGetMultiGpuBoard(handle),
        # pcie_link=nvml.nvmlDeviceGetCurrPcieLinkWidth(handle),
        fan=nvml.nvmlDeviceGetFanSpeed(handle),
        # power=nvml.nvmlDeviceGetPowerState(handle),
        mem_total=nvml.nvmlDeviceGetMemoryInfo(handle).total,
        mem_used=nvml.nvmlDeviceGetMemoryInfo(handle).used,
        util_gpu=nvml.nvmlDeviceGetUtilizationRates(handle).gpu,
        # util_mem=nvml.nvmlDeviceGetUtilizationRates(handle).memory,
        temp=nvml.nvmlDeviceGetTemperature(handle, nvml.NVML_TEMPERATURE_GPU),
        power=nvml.nvmlDeviceGetPowerUsage(handle),
        power_limit=nvml.nvmlDeviceGetPowerManagementLimit(handle),

        # display=nvml.nvmlDeviceGetDisplayMode(handle),
        display_active=nvml.nvmlDeviceGetDisplayActive(handle),
    )

    logger.log()

    procs = nvml.nvmlDeviceGetGraphicsRunningProcesses(handle)
    for p in procs:
        inspect(name=nvml.nvmlSystemGetProcessName(p.pid),
                pid=p.pid,
                mem=p.usedGpuMemory)

    procs = nvml.nvmlDeviceGetComputeRunningProcesses(handle)
    for p in procs:
        inspect(name=nvml.nvmlSystemGetProcessName(p.pid),
                pid=p.pid,
                mem=p.usedGpuMemory)

    logger.log()
Example #2
def get_mem(device_handle):
    """Get GPU device memory consumption in percent."""
    try:
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(device_handle)
        return memory_info.used * 100.0 / memory_info.total
    except pynvml.NVMLError:
        return None
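A minimal usage sketch, not part of the original example: get_mem expects an NVML device handle, so NVML has to be initialized first and a handle obtained for the device of interest.

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # first visible GPU; adjust the index as needed
used_percent = get_mem(handle)
if used_percent is not None:
    print("GPU 0 memory used: {:.1f}%".format(used_percent))
pynvml.nvmlShutdown()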
Example #3
def getCUDAEnvironment():
    """ Get the CUDA runtime environment parameters (number of cards etc.). """

    rdict = dict()
    rdict['first_available_device_index'] = None
    rdict['device_count'] = 0

    try:
        nvml.nvmlInit()
        rdict['device_count'] = nvml.nvmlDeviceGetCount()

    except Exception:
        print(
            'WARNING: At least one of (py3nvml.nvml, CUDA) is not available. Will continue without GPU.'
        )
        return rdict

    for i in range(rdict['device_count']):
        memory_info = nvml.nvmlDeviceGetMemoryInfo(
            nvml.nvmlDeviceGetHandleByIndex(i))
        memory_usage_fraction = memory_info.used / memory_info.total

        if memory_usage_fraction <= 0.1:
            rdict['first_available_device_index'] = i
            break

    nvml.nvmlShutdown()

    return rdict
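A hypothetical way to consume the returned dictionary, assuming the usual os import; it pins the process to the first device that looked idle, mirroring how the function is intended to be used.

import os

cuda_env = getCUDAEnvironment()
if cuda_env['first_available_device_index'] is not None:
    # Restrict this process to the first (almost) idle device found above.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cuda_env['first_available_device_index'])
else:
    print('No idle GPU found; continuing on CPU.')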
Example #4
 def inference_speed_memory(self, batch_size, seq_length):
     # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
     key = jax.random.PRNGKey(0)
     input_ids = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
     @jax.jit
     def ref_step():
         out = self.model(input_ids=input_ids)
         return out[0]
     if jax.local_devices()[0].platform == 'gpu':
         nvml.nvmlInit()
         ref_step().block_until_ready()
         handle = nvml.nvmlDeviceGetHandleByIndex(0)
         meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
         max_bytes_in_use = meminfo.used
         memory = Memory(max_bytes_in_use)
         # shutdown nvml
         nvml.nvmlShutdown()
     else:
         memory = None
     timeit.repeat("ref_step().block_until_ready()", repeat=1, number=2,globals=locals())
     if self.jit:
         runtimes = timeit.repeat("ref_step().block_until_ready()", repeat=self.repeat,number=3,globals=locals())
     else:
         with jax.disable_jit():
             runtimes = timeit.repeat("ref_step().block_until_ready()",repeat=self.repeat,number=3,globals=locals())
     return float(np.min(runtimes)/3.0), memory
Example #5
 def _get_framebuffer_memory_stats(gpu):
     mem_info = pynvml.nvmlDeviceGetMemoryInfo(gpu)
     return {
         'memory_fb_total_bytes': mem_info.total,
         'memory_fb_used_bytes': mem_info.used,
         'memory_fb_free_bytes': (mem_info.total - mem_info.used)
     }
Example #6
 def __init__(self, handle, cpu_to_node):
     node = None
     # TODO: use number of CPU cores to determine cpuset size
     # This is very hacky at the moment
     affinity = pynvml.nvmlDeviceGetCpuAffinity(handle, 1)
     n_cpus = max(cpu_to_node.keys()) + 1
     for j in range(n_cpus):
         if affinity[0] & (1 << j):
             cur_node = cpu_to_node[j]
             if node is not None and node != cur_node:
                 node = -1  # Sentinel to indicate unknown affinity
             else:
                 node = cur_node
     if node == -1:
         node = None
     self.node = node
     self.mem = pynvml.nvmlDeviceGetMemoryInfo(handle).total
     self.name = pynvml.nvmlDeviceGetName(handle)
     # NVML doesn't report compute capability, so we need CUDA
     pci_bus_id = pynvml.nvmlDeviceGetPciInfo(handle).busId
     # In Python 3 pci_bus_id is bytes but pycuda wants str
     if not isinstance(pci_bus_id, str):
         pci_bus_id = pci_bus_id.decode('ascii')
     cuda_device = pycuda.driver.Device(pci_bus_id)
     self.compute_capability = cuda_device.compute_capability()
     self.device_attributes = {}
     self.uuid = pynvml.nvmlDeviceGetUUID(handle)
     for key, value in cuda_device.get_attributes().items():
         if isinstance(value, (int, float, str)):
             # Some of the attributes use Boost.Python's enum, which is
             # derived from int but which leads to invalid JSON when passed
             # to json.dumps.
             if isinstance(value, int) and type(value) != int:
                 value = str(value)
             self.device_attributes[str(key)] = value
Example #7
def gpu_profile(frame, event, arg):
    # it is _about to_ execute (!)
    global last_tensor_sizes
    global lineno, func_name, filename, module_name

    if event == 'line':
        try:
            # about _previous_ line (!)
            if lineno is not None:
                py3nvml.nvmlInit()
                handle = py3nvml.nvmlDeviceGetHandleByIndex(
                    int(os.environ['GPU_DEBUG']))
                meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
                line = linecache.getline(filename, lineno)
                where_str = module_name + ' ' + func_name + ':' + str(lineno)

                with open(gpu_profile_fn, 'a+') as f:
                    f.write(f"{where_str:<50}"
                            f":{meminfo.used/1024**2:<7.1f}Mb "
                            f"{line.rstrip()}\n")

                    if print_tensor_sizes is True:
                        for tensor in get_tensors():
                            if not hasattr(tensor, 'dbg_alloc_where'):
                                tensor.dbg_alloc_where = where_str
                        new_tensor_sizes = {(type(x), tuple(x.size()),
                                             x.dbg_alloc_where)
                                            for x in get_tensors()}
                        for t, s, loc in new_tensor_sizes - last_tensor_sizes:
                            f.write(f'+ {loc:<50} {str(s):<20} {str(t):<10}\n')
                        for t, s, loc in last_tensor_sizes - new_tensor_sizes:
                            f.write(f'- {loc:<50} {str(s):<20} {str(t):<10}\n')
                        last_tensor_sizes = new_tensor_sizes
                py3nvml.nvmlShutdown()

            # save details about line _to be_ executed
            lineno = None

            func_name = frame.f_code.co_name
            filename = frame.f_globals["__file__"]
            if (filename.endswith(".pyc") or filename.endswith(".pyo")):
                filename = filename[:-1]
            module_name = frame.f_globals["__name__"]
            lineno = frame.f_lineno

            if 'gmwda-pytorch' not in os.path.dirname(
                    os.path.abspath(filename)):
                lineno = None  # skip current line evaluation

            if ('car_datasets' in filename or '_exec_config' in func_name
                    or 'gpu_profile' in module_name
                    or 'tee_stdout' in module_name):
                lineno = None  # skip current

            return gpu_profile

        except (KeyError, AttributeError) as e:
            print(e)

    return gpu_profile
Example #8
 def measure_gpu_usage(self):
     from py3nvml.py3nvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, \
                          nvmlDeviceGetMemoryInfo, nvmlDeviceGetName, nvmlShutdown, NVMLError
     max_gpu_usage = []
     gpu_name = []
     try:
         nvmlInit()
         deviceCount = nvmlDeviceGetCount()
         max_gpu_usage = [0 for i in range(deviceCount)]
         gpu_name = [
             nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i))
             for i in range(deviceCount)
         ]
         while True:
             for i in range(deviceCount):
                 info = nvmlDeviceGetMemoryInfo(
                     nvmlDeviceGetHandleByIndex(i))
                 max_gpu_usage[i] = max(max_gpu_usage[i],
                                        info.used / 1024**2)
             sleep(0.005)  # 5ms
             if not self.keep_measuring:
                 break
         nvmlShutdown()
         return [{
             "device_id": i,
             "name": gpu_name[i],
             "max_used_MB": max_gpu_usage[i]
         } for i in range(deviceCount)]
     except NVMLError as error:
         if not self.silent:
             self.logger.error(
                 "Error fetching GPU information using nvml: %s", error)
         return None
Example #9
    def get_gpu_info_by_nvml(self) -> Dict:
        """Get GPU info using nvml"""
        gpu_info_list = []
        driver_version = None
        try:
            nvmlInit()
            driver_version = nvmlSystemGetDriverVersion()
            deviceCount = nvmlDeviceGetCount()
            for i in range(deviceCount):
                handle = nvmlDeviceGetHandleByIndex(i)
                info = nvmlDeviceGetMemoryInfo(handle)
                gpu_info = {}
                gpu_info["memory_total"] = info.total
                gpu_info["memory_available"] = info.free
                gpu_info["name"] = nvmlDeviceGetName(handle)
                gpu_info_list.append(gpu_info)
            nvmlShutdown()
        except NVMLError as error:
            if not self.silent:
                self.logger.error(
                    "Error fetching GPU information using nvml: %s", error)
            return None

        result = {"driver_version": driver_version, "devices": gpu_info_list}

        if 'CUDA_VISIBLE_DEVICES' in environ:
            result["cuda_visible"] = environ['CUDA_VISIBLE_DEVICES']
        return result
Example #10
 def read_top_card_memory_in_bytes():
     # pylint: disable=no-member
     # pylint incorrectly detects that function nvmlDeviceGetMemoryInfo returns str
     return self.__nvml_get_or_else(lambda: [
         nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(card_index))
         .total for card_index in range(nvmlDeviceGetCount())
     ],
                                    default=0)
Example #11
    def environment_info(self):
        if self._environment_info is None:
            info = {}
            info["transformers_version"] = version
            info["framework"] = self.framework
            if self.framework == "PyTorch":
                info["use_torchscript"] = self.args.torchscript
            if self.framework == "TensorFlow":
                info["eager_mode"] = self.args.eager_mode
                info["use_xla"] = self.args.use_xla
            info["framework_version"] = self.framework_version
            info["python_version"] = platform.python_version()
            info["system"] = platform.system()
            info["cpu"] = platform.processor()
            info["architecture"] = platform.architecture()[0]
            info["date"] = datetime.date(datetime.now())
            info["time"] = datetime.time(datetime.now())
            info["fp16"] = self.args.fp16
            info["use_multiprocessing"] = self.args.do_multi_processing
            info["only_pretrain_model"] = self.args.only_pretrain_model

            if is_psutil_available():
                info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
            else:
                logger.warning(
                    "Psutil not installed, we won't log available CPU memory. "
                    "Install psutil (pip install psutil) to log available CPU memory."
                )
                info["cpu_ram_mb"] = "N/A"

            info["use_gpu"] = self.args.is_gpu
            if self.args.is_gpu:
                info["num_gpus"] = 1  # TODO(PVP) Currently only single GPU is supported
                if is_py3nvml_available():
                    nvml.nvmlInit()
                    handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
                    info["gpu"] = nvml.nvmlDeviceGetName(handle)
                    info["gpu_ram_mb"] = bytes_to_mega_bytes(nvml.nvmlDeviceGetMemoryInfo(handle).total)
                    info["gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
                    info["gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(handle)
                    nvml.nvmlShutdown()
                else:
                    logger.warning(
                        "py3nvml not installed, we won't log GPU memory usage. "
                        "Install py3nvml (pip install py3nvml) to log information about GPU."
                    )
                    info["gpu"] = "N/A"
                    info["gpu_ram_mb"] = "N/A"
                    info["gpu_power_watts"] = "N/A"
                    info["gpu_performance_state"] = "N/A"

            info["use_tpu"] = self.args.is_tpu
            # TODO(PVP): See if we can add more information about TPU
            # see: https://github.com/pytorch/xla/issues/2180

            self._environment_info = info
        return self._environment_info
Example #12
def gpu_profile(frame, event, arg):
    global last_meminfo_used, last_tensor_sizes
    global lineno, func_name, filename, module_name

    if event == 'line':
        try:
            if lineno:
                py3nvml.nvmlInit()
                handle = py3nvml.nvmlDeviceGetHandleByIndex(
                    int(os.environ["GPU_DEBUG"]))
                meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
                line = linecache.getline(filename, lineno)
                where_str = module_name + ' ' + func_name + ' ' + str(lineno)

                new_meminfo_used = meminfo.used
                mem_display = new_meminfo_used - last_meminfo_used if use_incremental else new_meminfo_used
                with open(gpu_profile_fn, "a+") as f:
                    f.write(f"{where_str:<50}"
                            f":{(mem_display) / 1024 ** 2:<7.1f}Mb "
                            f"{line.rstrip()}\n")

                    last_meminfo_used = new_meminfo_used
                    if print_tensor_sizes:
                        for tensor in get_tensors():
                            if not hasattr(tensor, 'dbg_alloc_where'):
                                tensor.dbg_alloc_where = where_str
                        new_tensor_sizes = {(type(x), tuple(x.size()),
                                             x.dbg_alloc_where)
                                            for x in get_tensors()}

                        for t, s, loc in new_tensor_sizes - last_tensor_sizes:
                            f.write(f'+ {loc:<50} {str(s):<20} {str(t):<10}\n')

                        for t, s, loc in last_tensor_sizes - new_tensor_sizes:
                            f.write(f'- {loc:<50} {str(s):<20} {str(t):<10}\n')

                        last_tensor_sizes = new_tensor_sizes
                py3nvml.nvmlShutdown()

            lineno = None

            func_name = frame.f_code.co_name
            filename = frame.f_globals["__file__"]
            module_name = frame.f_globals["__name__"]
            lineno = frame.f_lineno

            if 'Beta' not in os.path.dirname(os.path.abspath(filename)):
                lineno = None

            return gpu_profile

        except (KeyError, AttributeError):
            pass

    return gpu_profile
Example #13
def get_available_memory(device, clear_before=False):
    if not isinstance(device, torch.device):
        device = torch.device(device)
    if device.type == 'cpu':
        return psutil.virtual_memory().available
    if clear_before:
        torch.cuda.empty_cache()
    index = device.index if device.index is not None else 0
    mem = py3nvml.nvmlDeviceGetMemoryInfo(iu._NVML_MAP[index])
    torch_mem = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)
    return mem.free + torch_mem
Example #14
        def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
            from py3nvml.py3nvml import (
                NVMLError,
                nvmlDeviceGetCount,
                nvmlDeviceGetHandleByIndex,
                nvmlDeviceGetMemoryInfo,
                nvmlDeviceGetName,
                nvmlInit,
                nvmlShutdown,
            )

            max_gpu_usage = []
            gpu_name = []
            try:
                nvmlInit()
                device_count = nvmlDeviceGetCount()
                if not isinstance(device_count, int):
                    logger.error(
                        f"nvmlDeviceGetCount result is not integer: {device_count}"
                    )
                    return None

                max_gpu_usage = [0 for i in range(device_count)]
                gpu_name = [
                    nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i))
                    for i in range(device_count)
                ]
                while True:
                    for i in range(device_count):
                        info = nvmlDeviceGetMemoryInfo(
                            nvmlDeviceGetHandleByIndex(i))
                        if isinstance(info, str):
                            logger.error(
                                f"nvmlDeviceGetMemoryInfo returns str: {info}")
                            return None
                        max_gpu_usage[i] = max(max_gpu_usage[i],
                                               info.used / 1024**2)
                    sleep(0.005)  # 5ms
                    if not self.keep_measuring:
                        break
                nvmlShutdown()
                return [{
                    "device_id": i,
                    "name": gpu_name[i],
                    "max_used_MB": max_gpu_usage[i],
                } for i in range(device_count)]
            except NVMLError as error:
                logger.error("Error fetching GPU information using nvml: %s",
                             error)
                return None
Example #15
    def environment_info(self):
        if self._environment_info is None:
            info = {}
            info["gluonnlp_version"] = gluonnlp.__version__
            info["framework_version"] = mxnet.__version__
            info["python_version"] = platform.python_version()
            info["system"] = platform.system()
            info["cpu"] = platform.processor()
            info["architecture"] = platform.architecture()[0]
            info["date"] = datetime.date(datetime.now())
            info["time"] = datetime.time(datetime.now())
            info["fp16"] = self._use_fp16

            if is_psutil_available():
                info["cpu_ram_mb"] = bytes_to_mega_bytes(
                    psutil.virtual_memory().total)
            else:
                logger.warning(
                    "Psutil not installed, we won't log available CPU memory."
                    "Install psutil (pip install psutil) to log available CPU memory."
                )
                info["cpu_ram_mb"] = "N/A"

            info["use_gpu"] = self._use_gpu
            if self._use_gpu:
                info["num_gpus"] = 1
                if is_py3nvml_available():
                    nvml.nvmlInit()
                    handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
                    info["gpu"] = nvml.nvmlDeviceGetName(handle)
                    info["gpu_ram_mb"] = bytes_to_mega_bytes(
                        nvml.nvmlDeviceGetMemoryInfo(handle).total)
                    info[
                        "gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(
                            handle) / 1000
                    info[
                        "gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(
                            handle)
                    nvml.nvmlShutdown()
                else:
                    logger.warning(
                        "py3nvml not installed, we won't log GPU memory usage. "
                        "Install py3nvml (pip install py3nvml) to log information about GPU."
                    )
                    info["gpu"] = "N/A"
                    info["gpu_ram_mb"] = "N/A"
                    info["gpu_power_watts"] = "N/A"
                    info["gpu_performance_state"] = "N/A"
            self._environment_info = info
        return self._environment_info
Example #16
    def get_gpu_stats(self):
        """
        Return some statistics for the gpu associated with handle.

        The statistics returned are:
        - used memory in MB
        - gpu utilization percentage
        - temperature in Celsius degrees
        """
        mem = nvmlDeviceGetMemoryInfo(self.handle)
        rates = nvmlDeviceGetUtilizationRates(self.handle)
        temp = nvmlDeviceGetTemperature(self.handle, NVML_TEMPERATURE_GPU)

        return (mem.used / 1024 / 1024, rates.gpu, temp)
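The method reads self.handle, which is set up elsewhere in the class; a minimal sketch of such an initializer, assuming py3nvml's flat API as used in the method body (the class name and constructor are illustrative only):

from py3nvml.py3nvml import nvmlInit, nvmlDeviceGetHandleByIndex

class GpuStatsReader:
    def __init__(self, device_index=0):
        nvmlInit()
        # Handle reused by get_gpu_stats() for every query.
        self.handle = nvmlDeviceGetHandleByIndex(device_index)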
Example #17
    def train_speed_memory(self, batch_size, seq_length):
        key = jax.random.PRNGKey(0)
        input_ids = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
        targets = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
        labels = jax.random.randint(key, (batch_size, seq_length), 0, 2)
        # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
        # targets = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
        # labels = np.random.randint(0,2, (batch_size, seq_length))
        @jax.jit
        def train_step():

            def loss_fn(params):
                token_mask = jnp.where(labels > 0, 1.0, 0.0).astype(self.dtype)
                logits = self.model(input_ids=input_ids, train=True, params=params, dropout_rng=jax.random.PRNGKey(0))[0]
                loss, normalizing_factor = cross_entropy(logits,targets, token_mask)
                jax.profiler.save_device_memory_profile(f"memory/{workload[0]}_{workload[1]}_memory.prof", "gpu")
                return loss / normalizing_factor
            if self.fp16 and jax.local_devices()[0].platform == 'gpu':
                grad_fn = self.dynamic_scale.value_and_grad(loss_fn)
                dyn_scale, is_fin, loss, grad = grad_fn(self.model.params)
            else:
                grad_fn = jax.value_and_grad(loss_fn)
                loss, grad = grad_fn(self.model.params)
            return tree_flatten(grad)[0]


        if jax.local_devices()[0].platform == 'gpu':
            nvml.nvmlInit()
            train_step()
            handle = nvml.nvmlDeviceGetHandleByIndex(0)
            meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
            max_bytes_in_use = meminfo.used
            memory = Memory(max_bytes_in_use)
            # shutdown nvml
            nvml.nvmlShutdown()
        else:
            memory = None
        # timeit.repeat(train_step,repeat=1,number=2)
        timeit.repeat("for i in train_step():i.block_until_ready()", repeat=1, number=2,globals=locals())
        if self.jit:
            # runtimes = timeit.repeat(train_step,repeat=self.repeat,number=3)
            runtimes = timeit.repeat("for i in train_step():i.block_until_ready()", repeat=self.repeat, number=3,globals=locals())
        else:
            with jax.disable_jit():
                # runtimes = timeit.repeat(train_step, repeat=self.repeat, number=3)
                runtimes = timeit.repeat("for i in train_step():i.block_until_ready()", repeat=self.repeat, number=3,globals=locals())


        return float(np.min(runtimes)/3.0), memory
Example #18
def run_gpu_mem_counter(do_shutdown=False):
    # Sum used memory for all GPUs
    if not torch.cuda.is_available(): return 0
    if do_shutdown:
        py3nvml.nvmlInit()
    devices = list(range(py3nvml.nvmlDeviceGetCount()))  # if gpus_to_trace is None else gpus_to_trace
    gpu_mem = 0
    for i in devices:
        handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
        meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        gpu_mem += meminfo.used
    if do_shutdown:
        py3nvml.nvmlShutdown()
    return gpu_mem
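Hedged usage note: with do_shutdown=False the caller is expected to have initialized NVML already; passing True makes the call self-contained, for example:

total_used = run_gpu_mem_counter(do_shutdown=True)  # bytes summed over all visible GPUs
print("GPU memory in use: {:.1f} MiB".format(total_used / 1024**2))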
Example #19
    def get_device_memory(self, idx):
        """Get the memory information of device, unit: byte.

        Args:
            idx (int): device index.

        Return:
            used (float): the used device memory, None means failed to get the data.
            total (float): the total device memory, None means failed to get the data.
        """
        try:
            mem = nvml.nvmlDeviceGetMemoryInfo(self._device_handlers[idx])
        except Exception as err:
            logger.error('Get device memory failed: {}'.format(str(err)))
            return None, None
        return mem.used, mem.total
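self._device_handlers is not shown in this example; presumably it is a list of NVML handles built once at startup, along these lines (an assumption, using the same py3nvml alias as the method):

from py3nvml import py3nvml as nvml

nvml.nvmlInit()
device_handlers = [nvml.nvmlDeviceGetHandleByIndex(i)
                   for i in range(nvml.nvmlDeviceGetCount())]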
Example #20
    def _measure_memory(self, func: Callable[[],
                                             None]) -> [Memory, MemorySummary]:
        try:
            if self.args.trace_memory_line_by_line:
                trace = start_memory_tracing("transformers")

            if self.args.is_tpu:
                # tpu
                raise NotImplementedError(
                    "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `--no-memory` or `args.memory=False`"
                )
            elif self.args.is_gpu:
                if not is_py3nvml_available():
                    logger.warning(
                        "py3nvml not installed, we won't log GPU memory usage. "
                        "Install py3nvml (pip install py3nvml) to log information about GPU."
                    )
                    memory = "N/A"
                else:
                    logger.info(
                        "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU."
                    )
                    # init nvml
                    nvml.nvmlInit()
                    func()
                    handle = nvml.nvmlDeviceGetHandleByIndex(
                        self.args.device_idx)
                    meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                    max_bytes_in_use = meminfo.used
                    memory = Memory(max_bytes_in_use)
                    # shutdown nvml
                    nvml.nvmlShutdown()
            else:
                # cpu
                memory_bytes = measure_peak_memory_cpu(func)
                memory = Memory(memory_bytes) if isinstance(
                    memory_bytes, int) else memory_bytes

            if self.args.trace_memory_line_by_line:
                summary = stop_memory_tracing(trace)
            else:
                summary = None

            return memory, summary
        except RuntimeError as e:
            self.print_fn(f"Doesn't fit on GPU. {e}")
            return "N/A", None
Example #21
def memory_status(msg="", reset_max=True, sync=True):

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    if rdp_rank != 0:
        return

    if py3nvml is not None:
        py3nvml.nvmlInit()
        handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank)
        info = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        total_used = info.used / 1024**3
        total_used_str = f"Totally used GPU memory: {total_used}"
    else:
        total_used_str = ""

    alloced = torch.cuda.memory_allocated(device=local_rank)
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank)
    cached = torch.cuda.memory_reserved(device=local_rank)
    max_cached = torch.cuda.max_memory_reserved(device=local_rank)

    # convert to GB for printing
    alloced /= 1024**3
    cached /= 1024**3
    max_alloced /= 1024**3
    max_cached /= 1024**3

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
        f'cache {cached:0.4f} max_cached {max_cached:0.4f} '
        f'{total_used_str}')
    if reset_max:
        torch.cuda.reset_max_memory_cached()
        torch.cuda.reset_max_memory_allocated()
    if py3nvml is not None:
        py3nvml.nvmlShutdown()
Example #22
def gpustats():
    import py3nvml.py3nvml as pynvml

    if '__gpuhandler__' not in globals():
        globals()['__gpuhandler__'] = True
        pynvml.nvmlInit()

    usage = []
    util = []
    deviceCount = pynvml.nvmlDeviceGetCount()
    for i in range(deviceCount):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        usage.append(info.used / info.total)
        info = pynvml.nvmlDeviceGetUtilizationRates(handle)
        util.append(info.gpu / 100.)

    return {'maxmemusage': max(usage), 'maxutil': max(util)}
Example #23
def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
    from py3nvml.py3nvml import (
        NVMLError,
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlDeviceGetName,
        nvmlInit,
        nvmlShutdown,
    )

    try:
        nvmlInit()
        result = []
        device_count = nvmlDeviceGetCount()
        if not isinstance(device_count, int):
            return None

        for i in range(device_count):
            info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
            if isinstance(info, str):
                return None
            result.append({
                "id": i,
                "name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
                "total": info.total,
                "free": info.free,
                "used": info.used,
            })
        nvmlShutdown()
        return result
    except NVMLError as error:
        print("Error fetching GPU information using nvml: %s", error)
        return None
Example #24
def _get_gpu_mem_used():
    handle = py3nvml.nvmlDeviceGetHandleByIndex(int(os.environ['GPU_DEBUG']))
    meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
    return meminfo.used / 1024**2
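Usage sketch (an assumption, not from the original source): the helper reads the GPU index from the GPU_DEBUG environment variable and relies on NVML being initialized by the caller.

import os
from py3nvml import py3nvml

os.environ['GPU_DEBUG'] = '0'
py3nvml.nvmlInit()
print('{:.1f} MiB in use on GPU 0'.format(_get_gpu_mem_used()))
py3nvml.nvmlShutdown()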
Example #25
    def _measure_memory(self, func: Callable[[],
                                             None]) -> [Memory, MemorySummary]:
        logger.info("Note that Tensorflow allocates more memory than"
                    "it might need to speed up computation."
                    "The memory reported here corresponds to the memory"
                    "reported by `nvidia-smi`, which can vary depending"
                    "on total available memory on the GPU that is used.")
        with self.args.strategy.scope():
            try:
                if self.args.trace_memory_line_by_line:
                    assert (
                        self.args.eager_mode
                    ), "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory consumption line by line."
                    trace = start_memory_tracing("transformers")

                if self.args.is_tpu:
                    # tpu
                    raise NotImplementedError(
                        "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
                    )
                elif self.args.is_gpu:
                    # gpu
                    if not is_py3nvml_available():
                        logger.warning(
                            "py3nvml not installed, we won't log GPU memory usage. "
                            "Install py3nvml (pip install py3nvml) to log information about GPU."
                        )
                        memory = "N/A"
                    else:
                        logger.info(
                            "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU."
                        )
                        # init nvml
                        nvml.nvmlInit()
                        func()
                        handle = nvml.nvmlDeviceGetHandleByIndex(
                            self.args.device_idx)
                        meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                        max_bytes_in_use = meminfo.used
                        memory = Memory(max_bytes_in_use)
                        # shutdown nvml
                        nvml.nvmlShutdown()
                else:
                    # cpu
                    if self.args.trace_memory_line_by_line:
                        logger.info(
                            "When enabling line by line tracing, the max peak memory for CPU is inaccurate in Tensorflow."
                        )
                        memory = None
                    else:
                        memory_bytes = measure_peak_memory_cpu(func)
                        memory = Memory(memory_bytes) if isinstance(
                            memory_bytes, int) else memory_bytes
                if self.args.trace_memory_line_by_line:
                    summary = stop_memory_tracing(trace)
                    if memory is None:
                        memory = summary.total
                else:
                    summary = None

                return memory, summary
            except ResourceExhaustedError as e:
                self.print_fn("Doesn't fit on GPU. {}".format(e))
                return "N/A", None
Example #26
def grab_gpus(num_gpus=1, gpu_select=None, gpu_fraction=1.0):
    """
    Checks for gpu availability and sets CUDA_VISIBLE_DEVICES as such.

    Note that this function does not do anything to 'reserve' gpus, it only
    limits what GPUS your program can see by altering the CUDA_VISIBLE_DEVICES
    variable. Other programs can still come along and snatch your gpu. This
    function is more about preventing **you** from stealing someone else's GPU.

    If more than one GPU is requested but not the full amount is available, then it
    will set the CUDA_VISIBLE_DEVICES variable to see all the available GPUs.
    A warning is generated in this case.

    If one or more GPUs were requested and none were available, a Warning
    will be raised. Before raising it, the CUDA_VISIBLE_DEVICES will be set to a
    blank string. This means the calling function can ignore this warning and
    proceed if it chooses to only use the CPU, and it should still be protected
    against putting processes on a busy GPU.

    You can call this function with num_gpus=0 to blank out the
    CUDA_VISIBLE_DEVICES environment variable.

    Parameters
    ----------
    num_gpus : int
        How many gpus your job needs (optional)
    gpu_select : iterable
        A single int or an iterable of ints indicating gpu numbers to
        search through.  If left blank, will search through all gpus.
    gpu_fraction : float
        The fraction of a gpu's memory that must be free for the script to see
        the gpu as free. Defaults to 1. Useful if someone has grabbed a tiny
        amount of memory on a gpu but isn't using it.

    Returns
    -------
    success : int
        Number of gpus 'grabbed'

    Raises
    ------
    RuntimeWarning
        If a connection with the NVIDIA drivers could not be made.
        If 1 or more gpus were requested and none were available.
    ValueError
        If the gpu_select option was not understood (can fix by leaving this
        field blank, providing an int or an iterable of ints).
    """
    # Set the visible devices to blank.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""

    if num_gpus == 0:
        return 0

    # Try connect with NVIDIA drivers
    logger = logging.getLogger(__name__)
    try:
        py3nvml.nvmlInit()
    except Exception:
        str_ = ("Couldn't connect to nvml drivers. Check they are installed correctly. "
                "Proceeding on cpu only...")
        warnings.warn(str_, RuntimeWarning)
        logger.warning(str_)
        return 0

    numDevices = py3nvml.nvmlDeviceGetCount()
    gpu_free = [False] * numDevices

    # Flag which gpus we can check
    if gpu_select is None:
        gpu_check = [True] * numDevices
    else:
        gpu_check = [False] * numDevices
        try:
            gpu_check[gpu_select] = True
        except TypeError:
            try:
                for i in gpu_select:
                    gpu_check[i] = True
            except:
                raise ValueError(
                    '''Please provide an int or an iterable of ints
                    for gpu_select''')

    # Print out GPU device info. Useful for debugging.
    for i in range(numDevices):
        # If the gpu was specified, examine it
        if not gpu_check[i]:
            continue

        handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
        info = py3nvml.nvmlDeviceGetMemoryInfo(handle)

        str_ = "GPU {}:\t".format(i) + \
               "Used Mem: {:>6}MB\t".format(info.used/(1024*1024)) + \
               "Total Mem: {:>6}MB".format(info.total/(1024*1024))
        logger.debug(str_)

    # Now check if any devices are suitable
    for i in range(numDevices):
        # If the gpu was specified, examine it
        if not gpu_check[i]:
            continue

        handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
        info = py3nvml.nvmlDeviceGetMemoryInfo(handle)

        # Sometimes GPU has a few MB used when it is actually free
        if (info.free + 10) / info.total >= gpu_fraction:
            gpu_free[i] = True
        else:
            logger.info('GPU {} has processes on it. Skipping.'.format(i))

    py3nvml.nvmlShutdown()

    # Now check whether we can create the session
    if sum(gpu_free) == 0:
        s = "Could not find enough GPUs for your job"
        warnings.warn(s, RuntimeWarning)
        logger.warning(s)
        return 0
    else:
        if sum(gpu_free) >= num_gpus:
            # only use the first num_gpus gpus. Hide the rest from greedy
            # tensorflow
            available_gpus = [i for i, x in enumerate(gpu_free) if x]
            use_gpus = ','.join(list(
                str(s) for s in available_gpus[:num_gpus]))
            logger.debug('{} Gpus found free'.format(sum(gpu_free)))
            logger.info('Using {}'.format(use_gpus))
            os.environ['CUDA_VISIBLE_DEVICES'] = use_gpus
            return num_gpus
        else:
            # use everything we can.
            s = "Only {} GPUs found but {}".format(sum(gpu_free), num_gpus) + \
                "requested. Allocating these and continuing."
            warnings.warn(s, RuntimeWarning)
            logger.warn(s)
            available_gpus = [i for i, x in enumerate(gpu_free) if x]
            use_gpus = ','.join(list(str(s) for s in available_gpus))
            logger.debug('{} Gpus found free'.format(sum(gpu_free)))
            logger.info('Using {}'.format(use_gpus))
            os.environ['CUDA_VISIBLE_DEVICES'] = use_gpus
            return sum(gpu_free)
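A short usage sketch, not part of the original: request one fully free GPU before building the model and fall back to CPU when nothing could be grabbed.

n_grabbed = grab_gpus(num_gpus=1, gpu_fraction=0.95)
if n_grabbed == 0:
    print('No free GPU available; continuing on CPU.')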
Example #27
    def traceit(frame, event, args):
        """
        Tracing method executed before running each line in a module or sub-module Record memory allocated in a list
        with debugging information
        """
        global _is_memory_tracing_enabled

        if not _is_memory_tracing_enabled:
            return traceit

        # Filter events
        if events_to_trace is not None:
            if isinstance(events_to_trace, str) and event != events_to_trace:
                return traceit
            elif isinstance(events_to_trace,
                            (list, tuple)) and event not in events_to_trace:
                return traceit

        if "__name__" not in frame.f_globals:
            return traceit

        # Filter modules
        name = frame.f_globals["__name__"]
        if not isinstance(name, str):
            return traceit
        else:
            # Filter whitelist of modules to trace
            if modules_to_trace is not None:
                if isinstance(modules_to_trace,
                              str) and modules_to_trace not in name:
                    return traceit
                elif isinstance(modules_to_trace, (list, tuple)) and all(
                        m not in name for m in modules_to_trace):
                    return traceit

            # Filter blacklist of modules not to trace
            if modules_not_to_trace is not None:
                if isinstance(modules_not_to_trace,
                              str) and modules_not_to_trace in name:
                    return traceit
                elif isinstance(modules_not_to_trace, (list, tuple)) and any(
                        m in name for m in modules_not_to_trace):
                    return traceit

        # Record current tracing state (file, location in file...)
        lineno = frame.f_lineno
        filename = frame.f_globals["__file__"]
        if filename.endswith(".pyc") or filename.endswith(".pyo"):
            filename = filename[:-1]
        line = linecache.getline(filename, lineno).rstrip()
        traced_state = Frame(filename, name, lineno, event, line)

        # Record current memory state (rss memory) and compute difference with previous memory state
        cpu_mem = 0
        if process is not None:
            mem = process.memory_info()
            cpu_mem = mem.rss

        gpu_mem = 0
        if log_gpu:
            # Clear GPU caches
            if is_torch_available():
                torch_empty_cache()
            if is_tf_available():
                tf_context.context()._clear_caches()  # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802

            # Sum used memory for all GPUs
            nvml.nvmlInit()

            for i in devices:
                handle = nvml.nvmlDeviceGetHandleByIndex(i)
                meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                gpu_mem += meminfo.used

            nvml.nvmlShutdown()

        mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
        memory_trace.append(mem_state)

        return traceit
Example #28
    def environment_info(self):
        if self._environment_info is None:
            info = {}
            info["transformers_version"] = version
            info["framework"] = self.framework
            info["framework_version"] = self.framework_version
            info["python_version"] = platform.python_version()
            info["system"] = platform.system()
            info["cpu"] = platform.processor()
            info["architecture"] = platform.architecture()[0]
            info["date"] = datetime.date(datetime.now())
            info["time"] = datetime.time(datetime.now())

            try:
                import psutil
            except ImportError:
                logger.warning(
                    "Psutil not installed, we won't log available CPU memory. "
                    "Install psutil (pip install psutil) to log available CPU memory."
                )
                info["cpu_ram_mb"] = "N/A"
            else:
                info["cpu_ram_mb"] = bytes_to_mega_bytes(
                    psutil.virtual_memory().total)

            info["use_gpu"] = self.is_gpu
            if self.is_gpu:
                info["num_gpus"] = self.args.n_gpu
                try:
                    from py3nvml import py3nvml

                    py3nvml.nvmlInit()
                    handle = py3nvml.nvmlDeviceGetHandleByIndex(
                        self.args.device_idx)
                except ImportError:
                    logger.warning(
                        "py3nvml not installed, we won't log GPU memory usage. "
                        "Install py3nvml (pip install py3nvml) to log information about GPU."
                    )
                    info["gpu"] = "N/A"
                    info["gpu_ram_mb"] = "N/A"
                    info["gpu_power_watts"] = "N/A"
                    info["gpu_performance_state"] = "N/A"
                except (OSError, py3nvml.NVMLError):
                    logger.warning(
                        "Error while initializing comunication with GPU. "
                        "We won't log information about GPU.")
                    info["gpu"] = "N/A"
                    info["gpu_ram_mb"] = "N/A"
                    info["gpu_power_watts"] = "N/A"
                    info["gpu_performance_state"] = "N/A"
                    py3nvml.nvmlShutdown()
                else:
                    info["gpu"] = py3nvml.nvmlDeviceGetName(handle)
                    info["gpu_ram_mb"] = bytes_to_mega_bytes(
                        py3nvml.nvmlDeviceGetMemoryInfo(handle).total)
                    info[
                        "gpu_power_watts"] = py3nvml.nvmlDeviceGetPowerManagementLimit(
                            handle) / 1000
                    info[
                        "gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(
                            handle)
                    py3nvml.nvmlShutdown()

            self._environment_info = info
        return self._environment_info
Example #29
def gpu_profile(frame, event, arg):
    # it is _about to_ execute (!)
    global last_tensor_sizes
    global last_meminfo_used
    global lineno, func_name, filename, module_name

    if event == "line":
        try:
            # about _previous_ line (!)
            if lineno is not None:
                py3nvml.nvmlInit()
                handle = py3nvml.nvmlDeviceGetHandleByIndex(
                    int(os.environ["GPU_DEBUG"]))
                meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
                line = linecache.getline(filename, lineno)
                where_str = module_name + " " + func_name + ":" + str(lineno)

                new_meminfo_used = meminfo.used
                mem_display = new_meminfo_used - last_meminfo_used if use_incremental else new_meminfo_used
                if abs(new_meminfo_used - last_meminfo_used) / 1024**2 > 256:
                    with open(gpu_profile_fn, "a+") as f:
                        f.write(f"{where_str:<50}"
                                f":{(mem_display)/1024**2:<7.1f}Mb "
                                f"{line.rstrip()}\n")

                        last_meminfo_used = new_meminfo_used
                        if print_tensor_sizes is True:
                            for tensor in get_tensors():
                                if not hasattr(tensor, "dbg_alloc_where"):
                                    tensor.dbg_alloc_where = where_str
                            new_tensor_sizes = {(type(x), tuple(x.size()),
                                                 x.dbg_alloc_where)
                                                for x in get_tensors()}
                            for t, s, loc in new_tensor_sizes - last_tensor_sizes:
                                f.write(
                                    f"+ {loc:<50} {str(s):<20} {str(t):<10}\n")
                            for t, s, loc in last_tensor_sizes - new_tensor_sizes:
                                f.write(
                                    f"- {loc:<50} {str(s):<20} {str(t):<10}\n")
                            last_tensor_sizes = new_tensor_sizes
                py3nvml.nvmlShutdown()

            # save details about line _to be_ executed
            lineno = None

            func_name = frame.f_code.co_name
            filename = frame.f_globals["__file__"]
            if filename.endswith(".pyc") or filename.endswith(".pyo"):
                filename = filename[:-1]
            module_name = frame.f_globals["__name__"]
            lineno = frame.f_lineno

            # only profile codes within the parent folder, otherwise there are too many function calls into other pytorch scripts
            # need to modify the key words below to suit your case.
            if "maua-stylegan2" not in os.path.dirname(
                    os.path.abspath(filename)):
                lineno = None  # skip current line evaluation

            if ("car_datasets" in filename or "_exec_config" in func_name
                    or "gpu_profile" in module_name
                    or "tee_stdout" in module_name or "PIL" in module_name):
                    lineno = None  # skip other unnecessary lines

            return gpu_profile

        except (KeyError, AttributeError):
            pass

    return gpu_profile
Example #30
    def _train_speed_memory(self, model_name: str, batch_size: int, sequence_length: int)\
            -> Tuple[float, Memory]:
        if self._use_fp16:
            from mxnet import amp
            amp.init()

        if self._use_gpu:
            ctx = mxnet.gpu()
        else:
            ctx = mxnet.cpu()
        model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
        cfg.defrost()
        cfg.MODEL.layout = self._layout
        if model_cls.__name__ not in ['BartModel']:
            cfg.MODEL.compute_layout = self._compute_layout
        cfg.freeze()
        if model_cls.__name__ in ['BartModel']:
            model = model_cls.from_cfg(cfg, extract_feature=True)
        else:
            model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path, ctx=ctx)
        model.hybridize(static_alloc=True)
        vocab_size = cfg.MODEL.vocab_size
        if hasattr(cfg.MODEL, 'units'):
            out_units = cfg.MODEL.units
        else:
            out_units = cfg.MODEL.DECODER.units
        if self._layout == 'NT':
            input_ids = mxnet.np.random.randint(0, vocab_size, (batch_size, sequence_length),
                                                dtype=np.int32, ctx=ctx)
            token_types = mxnet.np.zeros((batch_size, sequence_length), dtype=np.int32, ctx=ctx)
            valid_length = mxnet.np.full((batch_size,), sequence_length,
                                         dtype=np.int32, ctx=ctx)
            contextual_embedding_ograd = mxnet.np.random.normal(
                0, 1, (batch_size, sequence_length, out_units),
                dtype=np.float32, ctx=ctx)
            pooled_out_ograd = mxnet.np.random.normal(
                0, 1, (batch_size, out_units), dtype=np.float32, ctx=ctx)
        elif self._layout == 'TN':
            input_ids = mxnet.np.random.randint(0, vocab_size, (sequence_length, batch_size),
                                                dtype=np.int32, ctx=ctx)
            token_types = mxnet.np.zeros((sequence_length, batch_size), dtype=np.int32, ctx=ctx)
            valid_length = mxnet.np.full((batch_size,), sequence_length,
                                         dtype=np.int32, ctx=ctx)
            contextual_embedding_ograd = mxnet.np.random.normal(
                0, 1, (sequence_length, batch_size, out_units),
                dtype=np.float32, ctx=ctx)
            pooled_out_ograd = mxnet.np.random.normal(0, 1, (batch_size, out_units),
                                                      dtype=np.float32,
                                                      ctx=ctx)
        else:
            raise NotImplementedError
        if model_cls.__name__ in ['BertModel', 'AlbertModel', 'ElectraModel', 'MobileBertModel']:
            def train_step():
                with mxnet.autograd.record():
                    contextual_embedding, pooled_out = model(input_ids, token_types, valid_length)
                    # We'd like to set the head gradient of
                    # contextual_embedding to contextual_embedding_ograd
                    # and the head gradient of pooled_out to pooled_out_ograd
                    # Thus, we simply do two Hadamard products and sum up the results.
                    fake_loss = mxnet.np.sum(contextual_embedding * contextual_embedding_ograd)\
                                + mxnet.np.sum(pooled_out * pooled_out_ograd)
                    fake_loss.backward()
                mxnet.npx.waitall()
        elif model_cls.__name__ in ['BartModel']:
            def train_step():
                with mxnet.autograd.record():
                    contextual_embedding, pooled_out = model(input_ids, valid_length,
                                                             input_ids, valid_length)
                    fake_loss = (contextual_embedding * contextual_embedding_ograd).sum() \
                                + (pooled_out * pooled_out_ograd).sum()
                    fake_loss.backward()
                mxnet.npx.waitall()
        else:
            raise NotImplementedError
        timeit.repeat(train_step, repeat=1, number=5)
        mxnet.npx.waitall()
        runtimes = timeit.repeat(train_step, repeat=self._repeat, number=3)
        mxnet.npx.waitall()
        ctx.empty_cache()
        mxnet.npx.waitall()
        # Profile memory
        if self._use_gpu:
            nvml.nvmlInit()
            train_step()
            mxnet.npx.waitall()
            handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
            meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
            max_bytes_in_use = meminfo.used
            memory = Memory(max_bytes_in_use)
            # shutdown nvml
            nvml.nvmlShutdown()
        else:
            # cpu
            memory_bytes = measure_peak_memory_cpu(train_step)
            memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
        return float(np.min(runtimes) / 3.0), memory