示例#1
0
def is_tf32_env():
    """
    Report whether FP32 matrix multiplication on the CUDA device shows TF32-level precision loss.

    The environment variable NVIDIA_TF32_OVERRIDE=0 will override any defaults
    or programmatic configuration of NVIDIA libraries, and consequently,
    cuBLAS will not accelerate FP32 computations with TF32 tensor cores.

    The probe multiplies two seeded random 1024x1024 matrices once in float32 and once
    in float64 and compares the results; a maximum absolute difference above 0.001
    is taken as a sign that TF32 is in use. The outcome is stored in the module-level
    `_tf32_enabled` and printed the first time it is computed; later calls return the
    stored value without probing again.

    Returns:
        bool: the cached probe result. `False` when CUDA is unavailable, the CUDA version
        is not newer than "10.100", NVIDIA_TF32_OVERRIDE is "0", no device is present,
        or the probe itself raised an ordinary exception.
    """
    global _tf32_enabled
    if _tf32_enabled is not None:
        return _tf32_enabled

    enabled = False
    if (
        torch.cuda.is_available()
        and not version_leq(f"{torch.version.cuda}", "10.100")  # at least 11.0
        and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0"
        and torch.cuda.device_count() > 0
    ):
        try:
            # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result
            g_gpu = torch.Generator(device="cuda")
            g_gpu.manual_seed(2147483647)  # fixed seed so the probe is reproducible
            a_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu)
            b_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu)
            max_abs_diff = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item()
            # NOTE(review): the original annotation next to this threshold was "0.1713",
            # presumably the difference observed with TF32 active -- confirm.
            enabled = max_abs_diff > 0.001
        except Exception:
            # Best-effort probe: any ordinary failure (e.g. a CUDA runtime error) means
            # "not enabled". KeyboardInterrupt/SystemExit are deliberately not swallowed.
            enabled = False

    # Store the result only after the probe finished, so an interrupted probe is not
    # remembered as a definitive "False".
    _tf32_enabled = enabled
    print(f"tf32 enabled: {_tf32_enabled}")
    return _tf32_enabled
示例#2
0
def is_module_ver_at_least(module, version):
    """Determine if a module's version is at least equal to the given value.

    Args:
        module: imported module object exposing ``__version__``, e.g., `np` or `torch`.
        version: required version, given as a tuple, e.g., `(1, 8, 0)`.
    Returns:
        `True` if module is the given version or newer.
    """
    test_ver = ".".join(map(str, version))
    # An exact match must count as "at least"; the former `__version__ != test_ver`
    # guard wrongly reported False when the module was exactly the required version.
    # NOTE(review): relies on `version_leq(a, b)` meaning "a is earlier than or equal to b".
    return version_leq(test_ver, module.__version__)
示例#3
0
 def __init__(self, pytorch_version_tuple):
     """Store the version tuple and compare it against the installed PyTorch version.

     Args:
         pytorch_version_tuple: version components, joined with "." to build the
             version string handed to `version_leq`.
     """
     self.max_version = pytorch_version_tuple
     dotted_version = ".".join(str(component) for component in self.max_version)
     # NOTE(review): presumably True when torch is at or beyond `max_version`;
     # depends on the semantics of `version_leq`, which is defined elsewhere.
     self.version_too_new = version_leq(dotted_version, torch.__version__)
示例#4
0
def deprecated(
    since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val: str = __version__
):
    """
    Build a decorator that flags a function or class as deprecated and/or removed.

    `since`, when given, should be a version at or earlier than the current version and states when the
    definition was marked as deprecated. `removed`, when given, can be any version and states when the
    definition was removed.

    The check happens when the decorated definition is used, i.e. when the function is called or the class is
    instantiated: a deprecation warning is issued through `warn_deprecated` if `since` is given and the current
    version is at or later than it; a `DeprecatedError` is raised instead if `removed` is given and the current
    version is at or later than it, or if neither `since` nor `removed` is provided.

    If the current version differs from `since` and `version_leq(version_val, since)` holds, the definition is
    returned untouched.

    The docstring of the deprecated definition should be updated as well, using Sphinx directives such as
    `.. versionchanged:: version` and `.. deprecated:: version`.
    https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded

    Args:
        since: version at which the definition was marked deprecated but not removed.
        removed: version at which the definition was removed and no longer usable.
        msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
        version_val: (used for testing) version to compare since and removed against, default is MONAI version.

    Returns:
        Decorated definition which warns or raises exception when used

    Raises:
        ValueError: when both `since` and `removed` are given and `since` is not earlier than or equal to `removed`.
    """

    # NOTE(review): substituting a large value for an unknown version is disabled here -- confirm this is intended.
    # if version_val.startswith("0+"):
    #     # version unknown, set version_val to a large value (assuming the latest version)
    #     version_val = "100"
    has_since = since is not None
    has_removed = removed is not None

    if has_since and has_removed and not version_leq(since, removed):
        raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")

    if has_since and version_val != since and version_leq(version_val, since):
        # current version is smaller than `since`: hand the definition back unchanged
        return lambda obj: obj

    if not has_since and not has_removed:
        # no versions supplied at all: using the definition raises a DeprecatedError directly
        is_deprecated = is_removed = True
    else:
        # compare the supplied versions against the current one
        is_deprecated = has_since and version_leq(since, version_val)
        is_removed = has_removed and version_leq(removed, version_val)

    def _decorator(obj):
        is_func = isinstance(obj, FunctionType)
        # for a class, the constructor is the callable that gets wrapped
        target = obj if is_func else obj.__init__

        subject = f"{'Function' if is_func else 'Class'} `{obj.__name__}`"

        if is_removed:
            detail = f"was removed in version {removed}."
        elif not is_deprecated:
            detail = "has been deprecated."
        else:
            detail = f"has been deprecated since version {since}."
            if removed is not None:
                detail += f" It will be removed in version {removed}."

        msg = f"{subject} {detail} {msg_suffix}".strip()

        @wraps(target)
        def _wrapper(*args, **kwargs):
            if is_removed:
                raise DeprecatedError(msg)
            if is_deprecated:
                warn_deprecated(obj, msg)

            return target(*args, **kwargs)

        if not is_func:
            # patch the class in place so the class object itself is returned
            obj.__init__ = _wrapper
            return obj
        return _wrapper

    return _decorator
示例#5
0
def deprecated_arg(
    name,
    since: Optional[str] = None,
    removed: Optional[str] = None,
    msg_suffix: str = "",
    version_val: str = __version__,
    new_name: Optional[str] = None,
):
    """
    Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as
    described in the `deprecated` decorator.

    When the decorated definition is called, that is when the function is called or the class instantiated with args,
    a deprecation warning is issued through `warn_deprecated` if `since` is given and the current version is at or
    later than that given. A `DeprecatedError` exception is instead raised if `removed` is given and the current
    version is at or later than that, or if neither `since` nor `removed` is provided.

    The relevant docstring of the deprecating function should also be updated accordingly,
    using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`.
    https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded

    In the current implementation type annotations are not preserved.

    A `version_val` that is empty, starts with "0+" or does not start with a digit is treated as an unknown
    version and replaced by "100" (assumed to be newer than any `since`/`removed`).

    Args:
        name: name of position or keyword argument to mark as deprecated.
        since: version at which the argument was marked deprecated but not removed.
        removed: version at which the argument was removed and no longer usable.
        msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
        version_val: (used for testing) version to compare since and removed against, default is MONAI version.
        new_name: name of position or keyword argument to replace the deprecated argument.

    Returns:
        Decorated callable which warns or raises exception when deprecated argument used.

    Raises:
        ValueError: when both `since` and `removed` are given and `since` is not earlier than or equal to `removed`.
    """

    version_text = f"{version_val}".strip()
    # the emptiness check guards the `[0]` lookup, which used to raise IndexError for blank versions
    if not version_text or version_text.startswith("0+") or not version_text[0].isdigit():
        # version unknown, set version_val to a large value (assuming the latest version)
        version_val = "100"
    if since is not None and removed is not None and not version_leq(since, removed):
        raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
    is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
    if is_not_yet_deprecated:
        # smaller than `since`, do nothing
        return lambda obj: obj
    if since is None and removed is None:
        # raise a DeprecatedError directly
        is_removed = True
        is_deprecated = True
    else:
        # compare the numbers
        is_deprecated = since is not None and version_leq(since, version_val)
        is_removed = removed is not None and version_leq(removed, version_val)

    def _decorator(func):
        argname = f"{func.__name__}_{name}"

        msg_prefix = f"Argument `{name}`"

        if is_removed:
            msg_infix = f"was removed in version {removed}."
        elif is_deprecated:
            msg_infix = f"has been deprecated since version {since}."
            if removed is not None:
                msg_infix += f" It will be removed in version {removed}."
        else:
            msg_infix = "has been deprecated."

        msg = f"{msg_prefix} {msg_infix} {msg_suffix}".strip()

        sig = inspect.signature(func)
        # Name of the catch-all `**` parameter (if any), so detection does not depend on it
        # being literally called "kwargs".
        var_keyword_name = next(
            (param.name for param in sig.parameters.values() if param.kind == inspect.Parameter.VAR_KEYWORD),
            None,
        )

        @wraps(func)
        def _wrapper(*args, **kwargs):
            if new_name is not None and name in kwargs and new_name not in kwargs:
                # replace the deprecated arg "name" with "new_name"
                # if name is specified and new_name is not specified
                kwargs[new_name] = kwargs[name]
                try:
                    # trial bind only: checks that forwarding the value does not clash
                    sig.bind(*args, **kwargs)
                except TypeError:
                    # multiple values for new_name using both args and kwargs
                    kwargs.pop(new_name, None)
            binding = sig.bind(*args, **kwargs).arguments

            # the deprecated argument may be bound to its own parameter ...
            positional_found = name in binding
            # ... or may have been collected by the function's `**` catch-all parameter
            kw_found = var_keyword_name is not None and name in binding.get(var_keyword_name, {})

            if positional_found or kw_found:
                if is_removed:
                    raise DeprecatedError(msg)
                if is_deprecated:
                    warn_deprecated(argname, msg)

            return func(*args, **kwargs)

        return _wrapper

    return _decorator