def retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
    """Computes the r-precision metric (for information retrieval).

    R-Precision is the fraction of relevant documents among all the top ``k`` retrieved documents where ``k`` is equal
    to the total number of relevant documents.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
    otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.

    Returns:
        a single-value tensor with the r-precision of the predictions ``preds`` w.r.t. the labels ``target``.

    Example:
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_r_precision(preds, target)
        tensor(0.5000)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    # R = number of relevant documents; it is both the cut-off and the denominator.
    relevant_number = target.sum()
    if not relevant_number:
        return tensor(0.0, device=preds.device)

    # Order the labels by decreasing score, then count the relevant ones inside the top-R positions.
    ranked_target = target[torch.argsort(preds, dim=-1, descending=True)]
    relevant = ranked_target[:relevant_number].sum().float()
    return relevant / relevant_number
def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
    """Computes reciprocal rank (for information retrieval).

    See `Mean Reciprocal Rank`_

    The reciprocal rank is ``1 / r``, where ``r`` is the 1-based position of the first relevant document once the
    documents are ordered by decreasing ``preds``.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    0 is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be ``float``,
    otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.

    Return:
        a single-value tensor with the reciprocal rank (RR) of the predictions ``preds`` wrt the labels ``target``.

    Example:
        >>> from torchmetrics.functional import retrieval_reciprocal_rank
        >>> preds = torch.tensor([0.2, 0.3, 0.5])
        >>> target = torch.tensor([False, True, False])
        >>> retrieval_reciprocal_rank(preds, target)
        tensor(0.5000)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    if target.sum() == 0:
        return tensor(0.0, device=preds.device)

    order = torch.argsort(preds, dim=-1, descending=True)
    hit_indices = torch.nonzero(target[order]).view(-1)
    # 0-based index of the first hit -> 1-based (float) rank.
    first_rank = hit_indices[0] + 1.0
    return 1.0 / first_rank
def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor:
    """Computes average precision (for information retrieval), as explained in `IR Average precision`_.

    Average precision is the mean, over every relevant document, of the precision measured at that document's rank
    in the ordering induced by decreasing ``preds``.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
    otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.

    Return:
        a single-value tensor with the average precision (AP) of the predictions ``preds`` w.r.t. the labels ``target``.

    Example:
        >>> from torchmetrics.functional import retrieval_average_precision
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_average_precision(preds, target)
        tensor(0.8333)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    if target.sum() == 0:
        return tensor(0.0, device=preds.device)

    ranked = target[torch.argsort(preds, dim=-1, descending=True)]
    # 1-based rank of every relevant document in the ranking.
    ranks = torch.arange(1, len(ranked) + 1, device=ranked.device, dtype=torch.float32)[ranked > 0]
    # The i-th relevant document (1-based) contributes precision i / rank_i.
    hits_so_far = torch.arange(1, len(ranks) + 1, device=ranks.device, dtype=torch.float32)
    return (hits_so_far / ranks).mean()
def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None, adaptive_k: bool = False) -> Tensor:
    """Computes the precision metric (for information retrieval).

    Precision is the fraction of relevant documents among all the retrieved documents.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be ``float``,
    otherwise an error is raised. If you want to measure Precision@K, ``k`` must be a positive integer.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.
        k: consider only the top k elements (default: ``None``, which considers them all)
        adaptive_k: adjust `k` to `min(k, number of documents)` for each query

    Returns:
        a single-value tensor with the precision (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

    Raises:
        ValueError:
            If ``k`` is not `None` or an integer larger than 0.
        ValueError:
            If ``adaptive_k`` is not boolean.

    Example:
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_precision(preds, target, k=2)
        tensor(0.5000)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    if not isinstance(adaptive_k, bool):
        raise ValueError("`adaptive_k` has to be a boolean")

    num_docs = preds.shape[-1]
    if k is None:
        k = num_docs
    # Validate ``k`` before comparing it with ``num_docs`` so an invalid ``k`` always raises the documented
    # ValueError instead of a TypeError (non-numeric ``k``) or being silently replaced (float ``k``).
    if not (isinstance(k, int) and k > 0):
        raise ValueError("`k` has to be a positive integer or None")
    if adaptive_k and k > num_docs:
        k = num_docs

    if not target.sum():
        return tensor(0.0, device=preds.device)

    # When k exceeds the number of documents (and adaptive_k is off), the missing slots count as non-relevant:
    # only the available documents are retrieved, but the denominator stays k.
    relevant = target[preds.topk(min(k, num_docs), dim=-1)[1]].sum().float()
    return relevant / k
def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
    """Computes `Normalized Discounted Cumulative Gain`_ (for information retrieval).

    The DCG of the ranking induced by decreasing ``preds`` is divided by the DCG of the ideal ranking (``target``
    sorted in decreasing order). When the ideal DCG is zero the score is undefined and ``0`` is reported instead.

    ``preds`` and ``target`` should be of the same shape and live on the same device.
    ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
    otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document relevance.
        k: consider only the top k elements (default: `None`, which considers them all)

    Return:
        a single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.

    Raises:
        ValueError:
            If ``k`` parameter is not `None` or an integer larger than 0

    Example:
        >>> from torchmetrics.functional import retrieval_normalized_dcg
        >>> preds = torch.tensor([.1, .2, .3, 4, 70])
        >>> target = torch.tensor([10, 0, 0, 1, 5])
        >>> retrieval_normalized_dcg(preds, target)
        tensor(0.6957)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)

    if k is None:
        k = preds.shape[-1]
    if not (isinstance(k, int) and k > 0):
        raise ValueError("`k` has to be a positive integer or None")

    order = torch.argsort(preds, dim=-1, descending=True)
    ranked_gains = target[order][:k]
    best_gains, _ = torch.sort(target, descending=True)

    ideal = _dcg(best_gains[:k])
    actual = _dcg(ranked_gains)

    # nDCG is undefined where the ideal DCG is zero: report 0 there, normalize everywhere else.
    undefined = ideal == 0
    actual[undefined] = 0
    actual[~undefined] /= ideal[~undefined]
    return actual.mean()
def retrieval_fall_out(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
    """Computes the Fall-out (for information retrieval), as explained in `IR Fall-out`_

    Fall-out is the fraction of non-relevant documents retrieved among all the non-relevant documents.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If every ``target`` is ``True``
    (i.e. there are no non-relevant documents), ``0`` is returned. ``target`` must be either `bool` or `integers` and
    ``preds`` must be ``float``, otherwise an error is raised. If you want to measure Fall-out@K, ``k`` must be a
    positive integer.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.
        k: consider only the top k elements (default: ``None``, which considers them all)

    Returns:
        a single-value tensor with the fall-out (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

    Raises:
        ValueError:
            If ``k`` parameter is not `None` or an integer larger than 0

    Example:
        >>> from torchmetrics.functional import retrieval_fall_out
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_fall_out(preds, target, k=2)
        tensor(1.)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    k = preds.shape[-1] if k is None else k
    if not (isinstance(k, int) and k > 0):
        raise ValueError("`k` has to be a positive integer or None")

    # Fall-out is measured over the non-relevant documents, so invert the binary relevance labels.
    # NOTE(review): `1 - target` is not defined for bool tensors in torch; this relies on the input check
    # returning an integer-typed `target` (the bool docstring example depends on that too) — confirm.
    non_relevant = 1 - target
    total_non_relevant = non_relevant.sum()
    if not total_non_relevant:
        return tensor(0.0, device=preds.device)

    retrieved_non_relevant = non_relevant[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float()
    return retrieved_non_relevant / total_non_relevant
def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
    """Computes Normalized Discounted Cumulative Gain (for information retrieval), as explained
    `here <https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`__.

    The DCG of the ranking induced by decreasing ``preds`` is divided by the DCG of the ideal ranking (``target``
    sorted in decreasing order). If the ideal DCG is ``0`` (e.g. every relevance is ``0``), ``0`` is returned.

    ``preds`` and ``target`` should be of the same shape and live on the same device.
    ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
    otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document relevance.
        k: consider only the top k elements (default: ``None``, which considers them all)

    Return:
        a single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.

    Raises:
        ValueError:
            If ``k`` parameter is not `None` or an integer larger than 0

    Example:
        >>> from torchmetrics.functional import retrieval_normalized_dcg
        >>> preds = torch.tensor([.1, .2, .3, 4, 70])
        >>> target = torch.tensor([10, 0, 0, 1, 5])
        >>> retrieval_normalized_dcg(preds, target)
        tensor(0.6957)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)

    k = preds.shape[-1] if k is None else k
    if not (isinstance(k, int) and k > 0):
        raise ValueError("`k` has to be a positive integer or None")

    sorted_target = target[torch.argsort(preds, dim=-1, descending=True)][:k]
    ideal_target = torch.sort(target, descending=True)[0][:k]

    ideal_dcg = _dcg(ideal_target)
    target_dcg = _dcg(sorted_target)

    # nDCG is undefined where the ideal DCG is zero. Guard on the ideal DCG itself rather than on `target.sum()`:
    # with non-binary (possibly negative) gains the sum can be non-zero while the ideal DCG is zero (division by
    # zero), or zero while the ideal DCG is not (a valid score wrongly reported as 0).
    all_irrelevant = ideal_dcg == 0
    target_dcg[all_irrelevant] = 0
    target_dcg[~all_irrelevant] /= ideal_dcg[~all_irrelevant]
    return target_dcg.mean()
def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None, adaptive_k: bool = False) -> Tensor:
    """Computes the precision metric (for information retrieval), as explained
    `here <https://en.wikipedia.org/wiki/Precision_and_recall#Precision>`__.

    Precision is the fraction of relevant documents among all the retrieved documents.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
    otherwise an error is raised. If you want to measure Precision@K, ``k`` must be a positive integer.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.
        k: consider only the top k elements (default: ``None``, which considers them all)
        adaptive_k: adjust `k` to `min(k, number of documents)` for each query (default: ``False``)

    Returns:
        a single-value tensor with the precision (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

    Raises:
        ValueError:
            If ``k`` is not `None` or an integer larger than 0.
        ValueError:
            If ``adaptive_k`` is not boolean.

    Example:
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_precision(preds, target, k=2)
        tensor(0.5000)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    if not isinstance(adaptive_k, bool):
        raise ValueError("`adaptive_k` has to be a boolean")

    num_docs = preds.shape[-1]
    if k is None:
        k = num_docs
    if not (isinstance(k, int) and k > 0):
        raise ValueError("`k` has to be a positive integer or None")
    if adaptive_k:
        k = min(k, num_docs)

    if not target.sum():
        return tensor(0.0, device=preds.device)

    # Without adaptive_k, k may exceed the number of documents: the slice then keeps every document but the
    # denominator stays k (missing slots count as non-relevant).
    relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float()
    return relevant / k
def retrieval_hit_rate(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
    """Computes the hit rate (for information retrieval).

    The hit rate is 1.0 if at least one relevant document appears among the top ``k`` documents ranked by
    decreasing ``preds``, and 0.0 otherwise.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be ``float``,
    otherwise an error is raised. If you want to measure HitRate@K, ``k`` must be a positive integer.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.
        k: consider only the top k elements (default: `None`, which considers them all)

    Returns:
        a single-value tensor with the hit rate (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

    Raises:
        ValueError:
            If ``k`` parameter is not `None` or an integer larger than 0

    Example:
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_hit_rate(preds, target, k=2)
        tensor(1.)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    k = preds.shape[-1] if k is None else k
    if not isinstance(k, int) or k <= 0:
        raise ValueError("`k` has to be a positive integer or None")

    top_k_target = target[torch.argsort(preds, dim=-1, descending=True)][:k]
    hit = top_k_target.sum() > 0
    return hit.float()