Example #1
0
def test_masked_correct_predictions():
    """Check masked_correct_predictions on a small two-row batch.

    Both rows hold 9 predictions/targets but are declared to have length 8.
    The expected result has 16 entries (2 rows x 8), presumably the flattened
    per-position correctness of the in-length positions only.
    """
    preds = torch.tensor([
        [1, 5, 1, 5, 1, 5, 12, 12, 12],
        [10, 1, 5, 1, 5, 12, 12, 12, 12],
    ])
    targets = torch.tensor([
        [1, 9, 5, 7, 5, 9, 13, 6, 0],
        [1, 9, 7, 13, 4, 7, 7, 7, 0],
    ])
    lengths = torch.tensor([8, 8])

    actual = metric_utils.masked_correct_predictions(targets, preds, lengths)

    # Within the first 8 columns of each row, only row 0 / column 0
    # (pred 1 == target 1) agrees, so exactly one entry is 1.0.
    expected = torch.zeros(16)
    expected[0] = 1.0

    assert torch.equal(actual, expected)
Example #2
0
 def get_current_value(self, preds: Tensor, target: Tensor) -> Tensor:
     """Return the mean of the masked per-position correctness of ``preds``.

     Args:
         preds: Predicted values.
         target: Ground-truth values; cast to ``preds.dtype`` before use.

     Returns:
         A tensor holding the mean of the vector produced by
         ``masked_correct_predictions`` (presumably the fraction of correct
         in-length positions -- confirm against that helper).
     """
     # Align dtypes so the comparison inside the helper is like-for-like.
     aligned_target = target.type(preds.dtype)
     # Lengths are derived from the (cast) target, not from the predictions.
     lengths = sequence_length_2D(aligned_target)
     return masked_correct_predictions(aligned_target, preds, lengths).mean()