def is_tf32_env():
    """
    Return whether FP32 matrix multiplication on CUDA appears to run with TF32 (reduced) precision.

    The environment variable NVIDIA_TF32_OVERRIDE=0 will override any defaults or programmatic configuration
    of NVIDIA libraries, and consequently, cuBLAS will not accelerate FP32 computations with TF32 tensor cores.

    The result is computed on the first call and cached in the module-level ``_tf32_enabled``. Later calls
    return the cached value without probing again.

    Detection is empirical. Two random 1024x1024 double-precision matrices are multiplied both in float64
    and after casting to float32. A max absolute difference above 0.001 is taken to mean that the float32
    matmul ran in TF32.

    Returns:
        ``True`` if TF32 was detected. ``False`` otherwise, including when CUDA is unavailable, the CUDA
        version is not newer than 10.100, ``NVIDIA_TF32_OVERRIDE`` is ``"0"``, no CUDA device is visible,
        or the probe raises an exception.
    """
    global _tf32_enabled
    if _tf32_enabled is None:
        # default to False so that any early exit or probe failure is cached as "disabled"
        _tf32_enabled = False
        if (
            torch.cuda.is_available()
            and not version_leq(f"{torch.version.cuda}", "10.100")  # at least 11.0
            and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0"
            and torch.cuda.device_count() > 0
        ):
            try:
                # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result
                g_gpu = torch.Generator(device="cuda")
                g_gpu.manual_seed(2147483647)
                a_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu)
                b_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu)
                _tf32_enabled = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item() > 0.001  # 0.1713
            except Exception:
                # Best-effort probe: a runtime failure leaves TF32 reported as disabled. Catching Exception
                # (not BaseException) lets KeyboardInterrupt/SystemExit propagate instead of being swallowed.
                pass
        print(f"tf32 enabled: {_tf32_enabled}")
    return _tf32_enabled
def is_module_ver_at_least(module, version):
    """Determine if a module's version is at least equal to the given value.

    Args:
        module: imported module object, e.g., `np` or `torch`; its ``__version__`` string is compared.
        version: required version, given as a tuple, e.g., `(1, 8, 0)`.

    Returns:
        `True` if module is the given version or newer.
    """
    test_ver = ".".join(map(str, version))
    # An exact match already satisfies "at least". Any other version is ordered by version_leq.
    return module.__version__ == test_ver or version_leq(test_ver, module.__version__)
def __init__(self, pytorch_version_tuple):
    """
    Store a PyTorch version threshold and check the installed PyTorch against it.

    Args:
        pytorch_version_tuple: version threshold as a sequence of components, e.g. ``(1, 8)``.
            It is kept as ``self.max_version``.

    Sets ``self.version_too_new`` to the result of ``version_leq(threshold, torch.__version__)``, i.e. whether
    the installed PyTorch is at or above the threshold.
    """
    self.max_version = pytorch_version_tuple
    threshold = ".".join(str(component) for component in self.max_version)
    self.version_too_new = version_leq(threshold, torch.__version__)
def deprecated(
    since: Optional[str] = None,
    removed: Optional[str] = None,
    msg_suffix: str = "",
    version_val: str = __version__,
):
    """
    Decorator that marks a function or class as deprecated.

    If `since` is given, it should be a version at or earlier than the current version. It records the version
    in which the definition was marked as deprecated. If `removed` is given, it can be any version and records
    the version in which the definition was removed.

    The decision is made against `version_val`:

    - if the current version is earlier than `since`, the definition is returned unchanged;
    - otherwise the definition is wrapped. When it is used (the function is called or the class instantiated),
      a `DeprecatedError` is raised if `removed` is given and the current version is at or later than it, or
      if neither `since` nor `removed` is provided;
    - failing that, a deprecation warning is emitted through `warn_deprecated` if `since` is given and the
      current version is at or later than it.

    The relevant docstring of the deprecating function should also be updated accordingly, using the Sphinx
    directives such as `.. versionchanged:: version` and `.. deprecated:: version`.
    https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded

    Args:
        since: version at which the definition was marked deprecated but not removed.
        removed: version at which the definition was removed and no longer usable.
        msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
        version_val: (used for testing) version to compare since and removed against, default is MONAI version.

    Raises:
        ValueError: if both `since` and `removed` are given and `since` is not at or before `removed`.

    Returns:
        Decorated definition which warns or raises exception when used
    """
    # NOTE(review): version_val is compared exactly as given. An unknown/dev version string (e.g. "0+...")
    # is not normalised here — confirm this is intended.
    if since is not None and removed is not None and not version_leq(since, removed):
        raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")

    # The current version predates `since`: hand the definition back untouched.
    if since is not None and version_val != since and version_leq(version_val, since):
        return lambda obj: obj

    if since is None and removed is None:
        # Neither version supplied: treat the definition as already removed (raise on use).
        is_deprecated = is_removed = True
    else:
        is_deprecated = since is not None and version_leq(since, version_val)
        is_removed = removed is not None and version_leq(removed, version_val)

    # The status part of the message does not depend on the decorated object, so build it once.
    if is_removed:
        status = f"was removed in version {removed}."
    elif is_deprecated:
        status = f"has been deprecated since version {since}."
        if removed is not None:
            status += f" It will be removed in version {removed}."
    else:
        status = "has been deprecated."

    def _decorator(obj):
        is_func = isinstance(obj, FunctionType)
        # Classes are handled by wrapping their __init__, so instantiation triggers the check.
        target = obj if is_func else obj.__init__
        kind = "Function" if is_func else "Class"
        msg = f"{kind} `{obj.__name__}` {status} {msg_suffix}".strip()

        @wraps(target)
        def _wrapper(*args, **kwargs):
            if is_removed:
                raise DeprecatedError(msg)
            if is_deprecated:
                warn_deprecated(obj, msg)
            return target(*args, **kwargs)

        if not is_func:
            obj.__init__ = _wrapper
            return obj
        return _wrapper

    return _decorator
def deprecated_arg(
    name,
    since: Optional[str] = None,
    removed: Optional[str] = None,
    msg_suffix: str = "",
    version_val: str = __version__,
    new_name: Optional[str] = None,
):
    """
    Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as
    described in the `deprecated` decorator.

    When the decorated definition is called, that is when the function is called or the class instantiated with args,
    a `DeprecationWarning` is issued if `since` is given and the current version is at or later than that given.
    a `DeprecatedError` exception is instead raised if `removed` is given and the current version is at or later
    than that, or if neither `since` nor `removed` is provided.

    The argument counts as "used" whether it binds to a named parameter of the callable or is collected by the
    callable's ``**`` var-keyword parameter, whatever that parameter is called.

    If `version_val` is empty, starts with ``"0+"``, or does not start with a digit, it is treated as unknown and
    replaced by ``"100"`` (i.e. assumed to be the latest version).

    The relevant docstring of the deprecating function should also be updated accordingly, using the Sphinx
    directives such as `.. versionchanged:: version` and `.. deprecated:: version`.
    https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded

    In the current implementation type annotations are not preserved.

    Args:
        name: name of position or keyword argument to mark as deprecated.
        since: version at which the argument was marked deprecated but not removed.
        removed: version at which the argument was removed and no longer usable.
        msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
        version_val: (used for testing) version to compare since and removed against, default is MONAI version.
        new_name: name of position or keyword argument to replace the deprecated argument. When the caller passes
            `name` by keyword but not `new_name`, the value is also passed as `new_name` (the original `name`
            keyword is still forwarded).

    Raises:
        ValueError: if both `since` and `removed` are given and `since` is not at or before `removed`.

    Returns:
        Decorated callable which warns or raises exception when deprecated argument used.
    """
    # Slicing with [:1] (instead of indexing [0]) makes an empty version string count as "unknown"
    # instead of raising IndexError.
    if version_val.startswith("0+") or not f"{version_val}".strip()[:1].isdigit():
        # version unknown, set version_val to a large value (assuming the latest version)
        version_val = "100"

    if since is not None and removed is not None and not version_leq(since, removed):
        raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
    is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
    if is_not_yet_deprecated:
        # smaller than `since`, do nothing
        return lambda obj: obj

    if since is None and removed is None:
        # raise a DeprecatedError directly
        is_removed = True
        is_deprecated = True
    else:
        # compare the numbers
        is_deprecated = since is not None and version_leq(since, version_val)
        is_removed = removed is not None and version_leq(removed, version_val)

    def _decorator(func):
        argname = f"{func.__name__}_{name}"

        msg_prefix = f"Argument `{name}`"
        if is_removed:
            msg_infix = f"was removed in version {removed}."
        elif is_deprecated:
            msg_infix = f"has been deprecated since version {since}."
            if removed is not None:
                msg_infix += f" It will be removed in version {removed}."
        else:
            msg_infix = "has been deprecated."
        msg = f"{msg_prefix} {msg_infix} {msg_suffix}".strip()

        sig = inspect.signature(func)
        # Name of the callable's ``**`` parameter, if any. A keyword argument that matches no named parameter
        # lands there, so look it up by kind instead of assuming it is spelled "kwargs".
        var_kw_name = next(
            (p.name for p in sig.parameters.values() if p.kind is inspect.Parameter.VAR_KEYWORD),
            None,
        )

        @wraps(func)
        def _wrapper(*args, **kwargs):
            if new_name is not None and name in kwargs and new_name not in kwargs:
                # replace the deprecated arg "name" with "new_name"
                # if name is specified and new_name is not specified
                kwargs[new_name] = kwargs[name]
                try:
                    sig.bind(*args, **kwargs)
                except TypeError:
                    # multiple values for new_name using both args and kwargs
                    kwargs.pop(new_name, None)
            binding = sig.bind(*args, **kwargs).arguments
            positional_found = name in binding
            # an empty **var-keyword parameter may be absent from `arguments`, hence .get(..., {})
            kw_found = var_kw_name is not None and name in binding.get(var_kw_name, {})

            if positional_found or kw_found:
                if is_removed:
                    raise DeprecatedError(msg)
                if is_deprecated:
                    warn_deprecated(argname, msg)

            return func(*args, **kwargs)

        return _wrapper

    return _decorator