def compute(self) -> Tensor:
    """Aggregate the stored states and average the per-query metric.

    The ``indexes``, ``preds`` and ``target`` states are kept as lists of
    tensors, so they are concatenated first. The entries are then split into
    per-query groups with ``get_group_indexes``, and ``self._metric`` is
    evaluated on every group whose target sum is non-zero.

    Groups whose target sums to zero are handled according to
    ``self.empty_target_action``:

    * ``"error"`` raises,
    * ``"pos"`` scores the group as ``1.0``,
    * ``"neg"`` scores it as ``0.0``,
    * any other value skips the group.

    Returns:
        The mean of the per-group scores, each cast to the dtype/device of
        the concatenated ``preds``. If no group produced a score, returns
        ``tensor(0.0)`` cast the same way.

    Raises:
        ValueError: if a group has no positive target and
            ``self.empty_target_action == "error"``.
    """
    all_indexes = torch.cat(self.indexes, dim=0)
    all_preds = torch.cat(self.preds, dim=0)
    all_target = torch.cat(self.target, dim=0)

    scores = []
    for group in get_group_indexes(all_indexes):
        group_preds, group_target = all_preds[group], all_target[group]

        if group_target.sum():
            scores.append(self._metric(group_preds, group_target))
            continue

        # The group has no positive target: apply the configured policy.
        action = self.empty_target_action
        if action == "error":
            raise ValueError(
                "`compute` method was provided with a query with no positive target."
            )
        if action == "pos":
            scores.append(tensor(1.0))
        elif action == "neg":
            scores.append(tensor(0.0))

    if not scores:
        return tensor(0.0).to(all_preds)
    # Cast every score to the predictions' dtype/device so the stack is uniform.
    return torch.stack([score.to(all_preds) for score in scores]).mean()
def _compute_sklearn_metric(
    preds: Union[Tensor, array],
    target: Union[Tensor, array],
    indexes: np.ndarray = None,
    metric: Callable = None,
    empty_target_action: str = "skip",
    **kwargs,
) -> Union[np.ndarray, np.floating]:
    """Compute a reference metric query-by-query and average the results.

    ``preds``, ``target`` and ``indexes`` are converted to NumPy and
    flattened. They are then split into groups of entries that share the same
    query index (via ``get_group_indexes``). ``metric(trg, pds, **kwargs)`` is
    evaluated on every group with a non-zero target sum.

    Groups whose target sums to zero are handled by ``empty_target_action``:

    * ``"skip"`` ignores them,
    * ``"pos"`` scores them ``1.0``,
    * any other value (including ``"error"``) scores them ``0.0``.

    Args:
        preds: predicted scores, as a torch tensor or NumPy array.
        target: ground-truth labels, flattened alongside ``preds``.
        indexes: query index of each entry. If ``None``, every entry is
            assigned to query ``0``.
        metric: callable invoked as ``metric(target_group, preds_group, **kwargs)``.
        empty_target_action: policy for queries without positive targets.
        **kwargs: forwarded to ``metric``.

    Returns:
        ``np.mean`` of the per-query scores, or ``np.array(0.0)`` if no query
        was scored.
    """
    # Convert tensors BEFORE building the default ``indexes``. Calling
    # ``np.full_like`` on a CUDA tensor fails because it goes through
    # ``Tensor.__array__``.
    if isinstance(preds, Tensor):
        preds = preds.cpu().numpy()
    if isinstance(target, Tensor):
        target = target.cpu().numpy()
    if indexes is None:
        indexes = np.full_like(preds, fill_value=0, dtype=np.int64)
    if isinstance(indexes, Tensor):
        indexes = indexes.cpu().numpy()

    assert isinstance(indexes, np.ndarray)
    assert isinstance(preds, np.ndarray)
    assert isinstance(target, np.ndarray)

    indexes = indexes.flatten()
    preds = preds.flatten()
    target = target.flatten()

    sk_results = []
    for group in get_group_indexes(indexes):
        trg, pds = target[group], preds[group]
        if trg.sum() == 0:
            if empty_target_action == "pos":
                sk_results.append(1.0)
            elif empty_target_action != "skip":
                # NOTE: "error" is not raised here; such queries score 0.0,
                # exactly like "neg".
                sk_results.append(0.0)
        else:
            sk_results.append(metric(trg, pds, **kwargs))

    if sk_results:
        return np.mean(sk_results)
    return np.array(0.0)
def compute(self) -> Tensor:
    r"""Aggregate the stored states and average the per-query metric.

    The ``idx``, ``preds`` and ``target`` states are stored as lists of tensors,
    so they are concatenated first. Entries are then grouped per query via
    ``get_group_indexes``, and ``self._metric`` is evaluated on each group whose
    target sum is non-zero.

    Groups without positive targets follow ``self.query_without_relevant_docs``:

    * ``'error'`` raises,
    * ``'pos'`` scores the group ``1.0``,
    * ``'neg'`` scores it ``0.0``,
    * anything else skips the group.

    Returns:
        Mean of the per-group scores. Scores are moved to the device of the
        concatenated ``idx``, and non-floating scores are promoted to float32
        so ``mean`` is defined. If no group was scored, returns a float32
        ``0.0`` on that device.

    Raises:
        ValueError: if a group has no positive target and
            ``self.query_without_relevant_docs == 'error'``.
    """
    idx = torch.cat(self.idx, dim=0)
    preds = torch.cat(self.preds, dim=0)
    target = torch.cat(self.target, dim=0)

    res = []
    # dtype/device used for the placeholder scores and for the empty result.
    kwargs = {'device': idx.device, 'dtype': torch.float32}
    groups = get_group_indexes(idx)

    for group in groups:
        mini_preds = preds[group]
        mini_target = target[group]

        if not mini_target.sum():
            if self.query_without_relevant_docs == 'error':
                raise ValueError(
                    f"`{self.__class__.__name__}.compute()` was provided with "
                    f"a query without positive targets, indexes: {group}"
                )
            if self.query_without_relevant_docs == 'pos':
                res.append(torch.tensor(1.0, **kwargs))
            elif self.query_without_relevant_docs == 'neg':
                res.append(torch.tensor(0.0, **kwargs))
        else:
            res.append(self._metric(mini_preds, mini_target))

    if not res:
        return torch.tensor(0.0, **kwargs)

    # Nothing here guarantees ``_metric`` returns tensors on ``idx.device``,
    # so align devices before stacking.
    scores = torch.stack([score.to(device=idx.device) for score in res])
    # ``mean`` raises on integer/bool tensors. Promote those to float32, the
    # dtype already used for the placeholder scores; floating scores keep
    # their dtype.
    if not scores.is_floating_point():
        scores = scores.to(dtype=torch.float32)
    return scores.mean()