def numba_cuda_is_supported(min_version: str) -> bool:
    """
    Tests if an appropriate version of numba is installed, and if it is,
    if cuda is supported properly within it.

    Args:
        min_version: The minimum version of numba that is required.

    Returns:
        bool, whether cuda is supported with this current installation or not.
        False is returned when the numba version check does not return True,
        or when the imported ``numba.cuda`` does not expose
        ``is_supported_version``; otherwise the result of
        ``cuda.is_supported_version()`` is returned.
    """
    module_available, _ = model_utils.check_lib_version(
        'numba', checked_version=min_version, operator=operator.ge)

    # None means numba is not installed; any other non-True value (presumably
    # the installed version failing the `>= min_version` check) also means CUDA
    # support cannot be confirmed.
    if module_available is not True:
        return False

    from numba import cuda

    # this method first arrived in 0.53, and that's the minimum version required
    if hasattr(cuda, 'is_supported_version'):
        return cuda.is_supported_version()

    # Without `is_supported_version` there is no way to confirm that the CUDA
    # runtime is compatible, so conservatively report CUDA as unsupported.
    return False
def numba_cpu_is_supported(min_version: str) -> bool:
    """
    Tests if numba is installed (for CPU usage).

    Args:
        min_version: The minimum version of numba that is required. It is passed
            to the version check, but only a ``None`` result (numba not
            installed) makes this function return False.

    Returns:
        bool, False if numba is not installed, True otherwise.

    NOTE(review): an installed numba that fails the ``>= min_version`` check
    still yields True here, unlike ``numba_cuda_is_supported``. Confirm whether
    this is intentional (e.g. so callers can report a version-specific error)
    before tightening it.
    """
    module_available, _ = model_utils.check_lib_version(
        'numba', checked_version=min_version, operator=operator.ge)

    # Only the "not installed" (None) outcome is treated as unsupported.
    return module_available is not None
def is_dali_supported(min_version: str, verbose: bool = False) -> bool:
    """
    Check whether NVIDIA DALI is installed with a version >= ``min_version``.

    Args:
        min_version: A semver string giving the minimum acceptable DALI version.
        verbose: If True, log the DALI installation instructions when DALI is
            not installed.

    Returns:
        bool - the outcome of the version check, or False when DALI could not
        be found at all.
    """
    dali_status, _msg = model_utils.check_lib_version(
        'nvidia.dali', checked_version=min_version, operator=operator.ge
    )

    # Installed: report whether the version requirement is satisfied.
    if dali_status is not None:
        return dali_status

    # Not installed: optionally tell the user how to install it.
    if verbose:
        logging.info(DALI_INSTALLATION_MESSAGE)
    return False
def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) -> torch.nn.Module:
    """
    Instantiate the RNNT loss module registered under ``loss_name``.

    Args:
        loss_name: Key into ``RNNT_LOSS_RESOLVER``. ``'default'`` is mapped to the
            concrete ``loss_name`` stored in its resolver config.
        blank_idx: Index of the blank token, forwarded as ``blank`` to the loss.
        loss_kwargs: Optional extra keyword arguments for the loss (a plain dict
            or an OmegaConf ``DictConfig``). Recognised keys are consumed; any
            remaining keys are passed to ``_warn_unused_additional_kwargs``.
            The caller's object is not modified.

    Returns:
        The constructed loss module, created with ``reduction='none'``.

    Raises:
        ValueError: If ``loss_name`` is not a registered RNNT loss, or resolves
            to a name this function cannot construct.
        ImportError: If the selected loss is not available in this installation.
        RuntimeError: If the backing library fails the minimum version check.
    """
    loss_function_names = list(RNNT_LOSS_RESOLVER.keys())

    if loss_name not in loss_function_names:
        raise ValueError(
            f"Provided `loss_name` {loss_name} not in list of available RNNT losses \n"
            f"{loss_function_names}")

    loss_config = RNNT_LOSS_RESOLVER[loss_name]  # type: RNNTLossConfig

    # Re-raise import error with installation message
    if not loss_config.is_available:
        # Only needed to build the error message, so computed here.
        all_available_losses = {
            name: config for name, config in RNNT_LOSS_RESOLVER.items() if config.is_available
        }
        msg = (
            f"Installed RNNT losses are : {list(all_available_losses.keys())}.\n"
            f"****************************************************************\n"
            f"To install the selected loss function, please follow the steps below:\n"
            f"{loss_config.installation_msg}")
        raise ImportError(msg)

    # Library version check
    if loss_config.min_version is not None:
        ver_matched, msg = model_utils.check_lib_version(
            loss_config.lib_name, checked_version=loss_config.min_version, operator=operator.ge)

        if ver_matched is False:
            msg = (
                f"{msg}\n"
                f"****************************************************************\n"
                f"To update the selected loss function, please follow the steps below:\n"
                f"{loss_config.installation_msg}")
            raise RuntimeError(msg)

    # Normalise kwargs into a private plain dict: the `.pop` calls below must not
    # mutate the caller's mapping.
    if loss_kwargs is None:
        loss_kwargs = {}
    elif isinstance(loss_kwargs, DictConfig):
        loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True)
    else:
        loss_kwargs = dict(loss_kwargs)

    # Get actual loss name for `default`
    if loss_name == 'default':
        loss_name = loss_config.loss_name

    # Resolve RNNT loss functions
    if loss_name == 'warprnnt':
        loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    elif loss_name == 'warprnnt_numba':
        fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
        clamp = loss_kwargs.pop('clamp', -1.0)
        loss_func = RNNTLossNumba(blank=blank_idx, reduction='none',
                                  fastemit_lambda=fastemit_lambda, clamp=clamp)
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    else:
        raise ValueError(
            f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :"
            f"{loss_function_names}")

    return loss_func