Beispiel #1
0
def get_accuracy(targets, outputs, k=1, ignore_index=None):
    """ Get the top-k accuracy between two tensors.

    Args:
      targets (1 - 2D :class:`torch.Tensor`): Target or true vector against which to measure
          accuracy
      outputs (1 - 3D :class:`torch.Tensor`): Prediction or output vector
      k (int, optional): Number of top candidates taken from each output (via ``topk`` along its
          first dimension); a sample counts as correct as soon as one candidate matches its
          target
      ignore_index (int, optional): Specifies a target index that is ignored

    Returns:
      :class:`tuple` consisting of accuracy (:class:`float`), number correct (:class:`float`) and
      total (:class:`int`). For empty ``targets`` the result is ``(0.0, 0.0, 0)``.

    Example:

        >>> import torch
        >>> from torchnlp.metrics import get_accuracy
        >>> targets = torch.LongTensor([1, 2, 3, 4, 5])
        >>> outputs = torch.LongTensor([1, 2, 2, 3, 5])
        >>> accuracy, n_correct, n_total = get_accuracy(targets, outputs, ignore_index=3)
        >>> accuracy
        0.8
        >>> n_correct
        4.0
        >>> n_total
        5
    """
    n_total = len(targets)
    if n_total == 0:
        # Nothing to score: return a well-defined zero instead of dividing by zero below.
        return 0.0, 0.0, 0

    # Kept as a float so the division below is a true division and the return type is unchanged.
    n_correct = 0.0
    for target, output in zip(targets, outputs):
        # Iterating a 1-D tensor yields plain numbers on old PyTorch and 0-d tensors on
        # PyTorch >= 0.4; both are normalised so that ``len`` and ``topk`` work.
        target = _wrap_scalar(target, 1)
        output = _wrap_scalar(output, 2)

        # ``[0]`` selects the values returned by ``topk`` (not the indices).
        predictions = output.topk(k=min(k, len(output)), dim=0)[0]
        for prediction in predictions:
            prediction = _wrap_scalar(prediction, 1)

            if torch_equals_ignore_index(target,
                                         prediction,
                                         ignore_index=ignore_index):
                n_correct += 1
                break  # One matching candidate is enough for this sample.

    return n_correct / n_total, n_correct, n_total


def _wrap_scalar(value, ndim):
    """ Return ``value`` as a tensor, giving scalar-like input ``ndim`` leading dimensions.

    Non-tensors are nested in ``ndim`` lists and converted to a :class:`torch.LongTensor`;
    0-d tensors are reshaped to ``ndim`` singleton dimensions; other tensors pass through.
    """
    if torch.is_tensor(value):
        if value.dim() == 0:
            return value.reshape((1,) * ndim)
        return value

    for _ in range(ndim):
        value = [value]
    return torch.LongTensor(value)
Beispiel #2
0
def test_torch_equals_ignore_index():
    """Check ``torch_equals_ignore_index`` with and without ``ignore_index``.

    The two tensors differ only in their last element, where the first tensor
    holds the value passed as ``ignore_index``.
    """
    lhs = torch.LongTensor([1, 2, 3])
    rhs = torch.LongTensor([1, 2, 4])

    # With index 3 ignored the tensors are expected to compare equal.
    equal_when_ignoring = torch_equals_ignore_index(lhs, rhs, ignore_index=3)
    assert equal_when_ignoring

    # Without ``ignore_index`` the differing last element must make them unequal.
    equal_strict = torch_equals_ignore_index(lhs, rhs)
    assert not equal_strict