Example #1
def numba_cuda_is_supported(min_version: str) -> bool:
    """
    Tests if an appropriate version of numba is installed, and if so,
    whether cuda is properly supported within it.
    
    Args:
        min_version: The minimum version of numba that is required.

    Returns:
        bool, whether cuda is supported with the current installation or not.
    """
    module_available, _ = model_utils.check_lib_version(
        'numba', checked_version=min_version, operator=operator.ge)

    # If numba is not installed
    if module_available is None:
        return False

    # If numba version is installed and available
    if module_available is True:
        from numba import cuda

        # this method first arrived in 0.53, and that's the minimum version required
        if hasattr(cuda, 'is_supported_version'):
            return cuda.is_supported_version()
        else:
            # cuda support cannot be verified on numba < 0.53, so
            # conservatively report it as unsupported
            return False

    else:
        return False
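
Every example on this page hinges on the tri-state return of model_utils.check_lib_version: None when the library is not installed, True when the installed version satisfies the comparison, and False when it does not. The following is a minimal sketch of that assumed contract, not NeMo's actual implementation; only the call signature and the (status, message) return shape are taken from the examples above, everything else is illustrative.

def check_lib_version(lib_name, checked_version, operator):
    # Sketch of the assumed contract: (None, msg) if the library cannot be
    # imported, otherwise (operator(installed, required), msg).
    import importlib

    from packaging.version import Version

    try:
        module = importlib.import_module(lib_name)
    except ImportError:
        return None, f"{lib_name} is not installed."

    # Assumes the module exposes __version__, as most libraries do.
    installed = Version(module.__version__)
    required = Version(checked_version)
    match = operator(installed, required)
    return match, f"{lib_name} {installed} vs. required {required}: {match}"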
Example #2
def numba_cpu_is_supported(min_version: str) -> bool:
    """
    Tests if an appropriate version of numba is installed.

    Args:
        min_version: The minimum version of numba that is required.

    Returns:
        bool, whether numba CPU execution is supported with the current installation or not.
    """
    module_available, _ = model_utils.check_lib_version(
        'numba', checked_version=min_version, operator=operator.ge)

    # If numba is not installed
    if module_available is None:
        return False

    # Any installed numba version is treated as CPU-capable
    return True
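
A hypothetical call site, sketched under the assumption that both helpers are importable from the same module: gate an optional numba fast path on these checks, preferring CUDA, then the CPU JIT, then a pure-Python fallback.

NUMBA_MIN_VERSION = '0.53'  # illustrative; matches the cuda.is_supported_version note above

def pick_numba_backend() -> str:
    if numba_cuda_is_supported(NUMBA_MIN_VERSION):
        return 'cuda'     # dispatch to the CUDA-accelerated kernel
    if numba_cpu_is_supported(NUMBA_MIN_VERSION):
        return 'cpu-jit'  # fall back to the CPU JIT path
    return 'python'       # no usable numba; use the unaccelerated path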
Example #3
def is_dali_supported(min_version: str, verbose: bool = False) -> bool:
    """
    Checks if DALI is installed, and its version is >= min_version.

    Args:
        min_version: A semver str that is the minimum requirement.
        verbose: Whether to log the installation instructions if DALI is not found.

    Returns:
        bool - whether DALI is available and satisfies the minimum version.
    """
    module_available, _ = model_utils.check_lib_version(
        'nvidia.dali', checked_version=min_version, operator=operator.ge)

    # If DALI is not installed
    if module_available is None:
        if verbose:
            logging.info(DALI_INSTALLATION_MESSAGE)

        return False

    return module_available
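
A hypothetical call site for the DALI check; the minimum version string is illustrative, and verbose=True is what triggers logging of the (assumed) DALI_INSTALLATION_MESSAGE when DALI is missing.

if is_dali_supported('0.27.0', verbose=True):
    pipeline = 'dali'   # build the nvidia.dali data pipeline
else:
    pipeline = 'torch'  # fall back to a torch.utils.data.DataLoader pipeline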
Example #4
def resolve_rnnt_loss(loss_name: str,
                      blank_idx: int,
                      loss_kwargs: dict = None) -> torch.nn.Module:
    """
    Resolves `loss_name` against RNNT_LOSS_RESOLVER, verifies that the
    backing library is installed and meets its minimum version, and
    constructs the loss module with the given blank index and any
    implementation-specific kwargs.
    """
    loss_function_names = list(RNNT_LOSS_RESOLVER.keys())

    if loss_name not in loss_function_names:
        raise ValueError(
            f"Provided `loss_name` {loss_name} not in list of available RNNT losses \n"
            f"{loss_function_names}")

    all_available_losses = {
        name: config
        for name, config in RNNT_LOSS_RESOLVER.items() if config.is_available
    }

    loss_config = RNNT_LOSS_RESOLVER[loss_name]  # type: RNNTLossConfig

    # Re-raise import error with installation message
    if not loss_config.is_available:
        msg = (
            f"Installed RNNT losses are : {list(all_available_losses.keys())}.\n"
            f"****************************************************************\n"
            f"To install the selected loss function, please follow the steps below:\n"
            f"{loss_config.installation_msg}")
        raise ImportError(msg)

    # Library version check
    if loss_config.min_version is not None:
        ver_matched, msg = model_utils.check_lib_version(
            loss_config.lib_name,
            checked_version=loss_config.min_version,
            operator=operator.ge)

        if ver_matched is False:
            msg = (
                f"{msg}\n"
                f"****************************************************************\n"
                f"To update the selected loss function, please follow the steps below:\n"
                f"{loss_config.installation_msg}")
            raise RuntimeError(msg)

    # Normalize loss kwargs into a plain dict before constructing the loss
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs

    if isinstance(loss_kwargs, DictConfig):
        loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True)

    # Get actual loss name for `default`
    if loss_name == 'default':
        loss_name = loss_config.loss_name
    """
    Resolve RNNT loss functions
    """
    if loss_name == 'warprnnt':
        loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    elif loss_name == 'warprnnt_numba':
        fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
        clamp = loss_kwargs.pop('clamp', -1.0)
        loss_func = RNNTLossNumba(blank=blank_idx,
                                  reduction='none',
                                  fastemit_lambda=fastemit_lambda,
                                  clamp=clamp)
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    else:
        raise ValueError(
            f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are: "
            f"{loss_function_names}")

    return loss_func
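
A minimal usage sketch for the resolver, assuming the warprnnt_numba backend is installed and a vocabulary whose blank token is appended at the end; the fastemit_lambda and clamp kwargs come straight from the warprnnt_numba branch above.

vocab_size = 28  # illustrative character vocabulary
loss_fn = resolve_rnnt_loss(
    'warprnnt_numba',
    blank_idx=vocab_size,  # blank sits one past the last real token
    loss_kwargs={'fastemit_lambda': 0.001, 'clamp': -1.0},
)
# loss_fn is a torch.nn.Module built with reduction='none', so per-sample
# losses must be reduced (e.g. averaged) by the caller.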