Ejemplo n.º 1
0
    def compute(self) -> Tensor:
        """Reduce the accumulated state to the mean of the per-query metric values.

        The ``indexes``, ``preds`` and ``target`` states are lists of tensors, so each one is concatenated
        along dim 0 first. ``get_group_indexes`` is then used to obtain one selector per query, and each
        selector picks that query's predictions and targets.

        For a query whose targets sum to a non-zero value the score is ``self._metric(preds, target)``.
        Otherwise ``self.empty_target_action`` decides what happens:

        - ``"error"``: a ``ValueError`` is raised.
        - ``"pos"``: the query contributes ``1.0``.
        - ``"neg"``: the query contributes ``0.0``.
        - anything else: the query contributes nothing.

        Returns:
            The mean over all contributed scores, or ``0.0`` when no query contributed a score. In both
            cases the values are converted with ``.to(preds)`` using the concatenated predictions.

        Raises:
            ValueError: If a query has no positive target and ``self.empty_target_action`` is ``"error"``.
        """
        all_indexes = torch.cat(self.indexes, dim=0)
        all_preds = torch.cat(self.preds, dim=0)
        all_target = torch.cat(self.target, dim=0)

        scores = []
        for query_rows in get_group_indexes(all_indexes):
            query_preds = all_preds[query_rows]
            query_target = all_target[query_rows]

            if query_target.sum():
                scores.append(self._metric(query_preds, query_target))
                continue

            # No positive target for this query: fall back to the configured action.
            action = self.empty_target_action
            if action == "error":
                raise ValueError(
                    "`compute` method was provided with a query with no positive target."
                )
            if action == "pos":
                placeholder = 1.0
            elif action == "neg":
                placeholder = 0.0
            else:
                # Any other action leaves the query out of the average.
                continue
            scores.append(tensor(placeholder))

        if not scores:
            return tensor(0.0).to(all_preds)

        # Placeholders and metric outputs are all converted via ``.to(all_preds)`` before stacking.
        return torch.stack([score.to(all_preds) for score in scores]).mean()
Ejemplo n.º 2
0
def _compute_sklearn_metric(
    preds: Union[Tensor, array],
    target: Union[Tensor, array],
    indexes: Union[Tensor, np.ndarray, None] = None,
    metric: Callable = None,
    empty_target_action: str = "skip",
    **kwargs
) -> Union[np.floating, np.ndarray]:
    """Compute a reference metric by calling ``metric`` once per query and averaging the results.

    Args:
        preds: Predictions, either a tensor or a numpy array. Flattened before use.
        target: Targets, either a tensor or a numpy array. Flattened before use.
        indexes: Query identifier for every prediction, passed to ``get_group_indexes``. When
            ``None``, every element gets index ``0``, i.e. all predictions form a single query.
        metric: Callable invoked as ``metric(target_group, preds_group, **kwargs)`` for each
            query whose targets do not sum to zero.
        empty_target_action: What to do with a query whose targets sum to zero: ``'skip'``
            ignores the query, ``'pos'`` scores it ``1.0`` and any other value scores it ``0.0``.
        **kwargs: Extra keyword arguments forwarded to ``metric``.

    Returns:
        ``np.mean`` of the collected per-query results, or ``np.array(0.0)`` when no query
        produced a result. Note that these are numpy values, not tensors.

    Raises:
        AssertionError: If ``indexes``, ``preds`` or ``target`` is not a numpy array after the
            tensor conversion step.
    """
    # Convert tensors to numpy first so the default indexes below are derived from a numpy
    # array; handing a tensor straight to ``np.full_like`` relies on an implicit conversion
    # that is not available for tensors living on a non-CPU device.
    if isinstance(preds, Tensor):
        preds = preds.cpu().numpy()
    if isinstance(target, Tensor):
        target = target.cpu().numpy()

    if indexes is None:
        indexes = np.full_like(preds, fill_value=0, dtype=np.int64)
    elif isinstance(indexes, Tensor):
        indexes = indexes.cpu().numpy()

    assert isinstance(indexes, np.ndarray)
    assert isinstance(preds, np.ndarray)
    assert isinstance(target, np.ndarray)

    indexes = indexes.flatten()
    preds = preds.flatten()
    target = target.flatten()

    sk_results = []
    for group in get_group_indexes(indexes):
        trg, pds = target[group], preds[group]

        if trg.sum() != 0:
            sk_results.append(metric(trg, pds, **kwargs))
        elif empty_target_action == 'pos':
            sk_results.append(1.0)
        elif empty_target_action != 'skip':
            # Every action other than 'skip' and 'pos' scores the query as 0.0.
            sk_results.append(0.0)

    if sk_results:
        return np.mean(sk_results)
    return np.array(0.0)
Ejemplo n.º 3
0
    def compute(self) -> Tensor:
        r"""
        Reduce the accumulated state to the mean of the per-query metric values.

        The `idx`, `preds` and `target` states are lists of tensors, so each one is concatenated
        along dim 0 first. `get_group_indexes` then yields one selector per query, which is used
        to pick that query's predictions and targets.

        A query whose targets sum to a non-zero value is scored with `self._metric`. Otherwise
        `self.query_without_relevant_docs` decides: `'error'` raises a `ValueError`, `'pos'`
        contributes `1.0`, `'neg'` contributes `0.0` and any other value contributes nothing.

        Returns the mean of all contributed scores, or a float32 `0.0` tensor on the device of
        the concatenated `idx` when no query contributed a score.
        """

        flat_idx = torch.cat(self.idx, dim=0)
        flat_preds = torch.cat(self.preds, dim=0)
        flat_target = torch.cat(self.target, dim=0)

        # Constant scores are built as float32 tensors on the device of the concatenated indexes.
        const_opts = dict(device=flat_idx.device, dtype=torch.float32)

        scores = []
        for group in get_group_indexes(flat_idx):
            group_preds = flat_preds[group]
            group_target = flat_target[group]

            if group_target.sum():
                scores.append(self._metric(group_preds, group_target))
                continue

            # No positive target for this query: apply the configured policy.
            policy = self.query_without_relevant_docs
            if policy == 'error':
                raise ValueError(
                    f"`{self.__class__.__name__}.compute()` was provided with "
                    f"a query without positive targets, indexes: {group}"
                )
            if policy == 'pos':
                scores.append(torch.tensor(1.0, **const_opts))
            elif policy == 'neg':
                scores.append(torch.tensor(0.0, **const_opts))

        if not scores:
            return torch.tensor(0.0, **const_opts)
        return torch.stack(scores).mean()