def test_minimal_reference_with_simple_attn():
    """Uniform attention over a realistic subword sequence merges to per-word
    token counts, identical in every row of the reference output."""
    tokens = [
        '[CLS]', 'for', 'the', '_MATH_', '-', 'th', 'disc', 're', 'pan',
        'cy', 'have', '_MATHDISP_', ',', 'where', '_MATH_', ',', '_MATH_',
        ',', '_MATH_'
    ]
    words = [
        '[CLS]', 'For', 'the', '_MATH_-th', 'discrepancy', 'have',
        '_MATHDISP_', ',', 'where', '_MATH_', ',', '_MATH_', ',', '_MATH_'
    ]
    word_ends = [
        '[CLS]', 'for', 'the', 'th', 'cy', 'have', '_MATHDISP_', ',',
        'where', '_MATH_', ',', '_MATH_', ',', '_MATH_'
    ]
    # All-ones attention: every token attends to every token equally.
    n_tokens = len(tokens)
    attention = np.ones((n_tokens, n_tokens), dtype=np.float32)

    merged = reference.merge(attention, tokens, words, word_ends)

    # Each merged column should equal the number of subword tokens in that
    # word ('_MATH_-th' -> 3 tokens, 'discrepancy' -> 4 tokens).
    row = [1, 1, 1, 3, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    expected = np.array([row] * len(words), dtype=np.float32)
    np.testing.assert_allclose(merged, expected)

    # With uniform input, all output rows must be identical.
    for prev_row, next_row in zip(merged, merged[1:]):
        np.testing.assert_allclose(prev_row, next_row)
def test_simple():
    """Degenerate case: one token splitting into two words.

    The single unit of attention lands on the first word; the remaining
    cells of the 2x2 output stay zero.
    """
    tokens = ["AB"]
    words = ["A", "B"]
    word_ends = ["AB", "AB"]
    attention = np.array([[1]], dtype=np.float32)

    merged = merge(attention, tokens, words, word_ends)

    expected = np.array([[1, 0], [0, 0]], dtype=np.float32)
    np.testing.assert_allclose(merged, expected)
def test_3x3():
    """Identity mapping: one token per word leaves attention untouched,
    for both the reference and the optimized implementation."""
    tokens = ["A", "B", "C"]
    words = ["A", "B", "C"]
    word_ends = ["A", "B", "C"]
    attention = np.ones((3, 3), dtype=np.float32)

    # Both implementations must be a no-op on an already word-aligned input.
    for merge_fn in (reference.merge, merge):
        result = merge_fn(attention, tokens, words, word_ends)
        np.testing.assert_allclose(result, attention)
def test_minimal_example_with_repeat():
    """A repeated subword ('a' appears twice) must still be assigned to
    the correct word by the reference implementation."""
    tokens = ['a', 'b', 'a']
    words = ['ab', 'a']
    word_ends = ['b', 'a']
    size = len(tokens)
    attention = np.ones((size, size), dtype=np.float32)

    merged = reference.merge(attention, tokens, words, word_ends, verbosity=2)

    # 'ab' absorbs two tokens worth of attention, 'a' absorbs one.
    expected = np.full((len(words), len(words)), 1, dtype=np.float32)
    expected[:, 0] = 2
    np.testing.assert_allclose(merged, expected)
def test_unbalanced():
    """Unevenly distributed attention over one word's subwords collapses
    to a single cell summing to 1, in both implementations."""
    tokens = ["straw", "##berries"]
    words = ["strawberries"]
    word_ends = ["##berries"]
    attention = np.array([[0.2, 0.8], [0.2, 0.8]], dtype=np.float32)
    expected = np.array([[1.0]])

    # The optimized and reference implementations must agree on the sum.
    for merge_fn in (merge, reference.merge):
        result = merge_fn(attention, tokens, words, word_ends)
        np.testing.assert_allclose(result, expected)
def test_near_zero():
    """Tiny attention values (1e-16) must pass through unchanged when
    tokens and words already line up — guards against division blow-ups."""
    tokens = ["A", "B", "C"]
    words = ["A", "B", "C"]
    word_ends = ["A", "B", "C"]
    attention = np.full((3, 3), 1e-16, dtype=np.float32)

    # Word-aligned input: both implementations should be a no-op even at
    # the edge of float32 precision.
    for merge_fn in (reference.merge, merge):
        result = merge_fn(attention, tokens, words, word_ends)
        np.testing.assert_allclose(result, attention)
def test_the_smallest_so_far_reference():
    """Smallest fixture with a multi-token word: '_MATH_th' spans two
    tokens, so uniform attention merges to [2, 1] in every row."""
    tokens = ['_MATH_', 'th', '_MATH_']
    words = ['_MATH_th', '_MATH_']
    word_ends = ['th', '_MATH_']
    size = len(tokens)
    attention = np.ones((size, size), dtype=np.float32)

    merged = reference.merge(attention, tokens, words, word_ends)

    expected = np.array([[2, 1]] * len(words), dtype=np.float32)
    np.testing.assert_allclose(merged, expected)

    # Uniform input implies identical output rows.
    for prev_row, next_row in zip(merged, merged[1:]):
        np.testing.assert_allclose(prev_row, next_row)
def test_against_reference(args):
    """Property check: the optimized merge matches the reference merge on a
    generated case.

    ``args`` is an ``(attention, tokens, word_ends)`` triple produced by the
    caller. NOTE(review): ``word_ends`` is deliberately passed as both the
    ``words`` and ``word_ends`` arguments — presumably the generator emits
    words that are their own endings; confirm against the fixture generator.
    """
    attention, tokens, word_ends = args
    actual = merge(attention, tokens, word_ends, word_ends)
    expected = reference.merge(attention, tokens, word_ends, word_ends)
    np.testing.assert_allclose(actual, expected, atol=1e-16)