def test_sequence_length_2D(input_sequence: List[List[int]], expected_output: List[int]):
    """Check ``sequence_length_2D`` against a known per-row length vector.

    ``input_sequence`` is converted with ``torch.tensor`` and passed to
    ``sequence_length_2D``. The result must compare equal, via ``torch.equal``,
    to ``torch.tensor(expected_output)``.

    NOTE(review): the arguments are presumably supplied by a pytest
    ``parametrize`` decorator outside this view — confirm.
    """
    batch = torch.tensor(input_sequence)
    lengths = sequence_length_2D(batch)
    expected_lengths = torch.tensor(expected_output)
    assert torch.equal(expected_lengths, lengths)
def get_current_value(self, preds: Tensor, target: Tensor) -> Tensor:
    """Return the mean of ``masked_correct_predictions`` for one batch.

    Steps:
      1. Cast ``target`` to ``preds.dtype``.
      2. Compute its lengths with ``sequence_length_2D``.
      3. Pass the cast target, ``preds`` and those lengths to
         ``masked_correct_predictions``.
      4. Return ``torch.mean`` of that result.

    ``self`` is not read here.

    NOTE(review): the lengths are presumably used to mask out padding
    positions — confirm against ``masked_correct_predictions``.
    """
    # Align dtypes before the helper runs; presumably the helper compares
    # target and preds element-wise — confirm.
    cast_target = target.type(preds.dtype)
    lengths = sequence_length_2D(cast_target)
    selected = masked_correct_predictions(cast_target, preds, lengths)
    return torch.mean(selected)