def test_real_with_simple_attn():
    """Merge an all-ones attention matrix over a tokenized real sentence.

    The expected row carries 3 and 4 at the positions of the two words that
    span three tokens (`_MATH_-th`) and four tokens (`discrepancy`), and 1
    for every word that maps to a single token.
    """
    tokens = [
        '[CLS]', 'for', 'the', '_MATH_', '-', 'th', 'disc', 're', 'pan', 'cy',
        'we', 'have', '_MATHDISP_', ',', 'where', '_MATH_', ',', '_MATH_', ',',
        '_MATH_', 'and', '_MATH_', '.', '[SEP]',
    ]
    words = [
        '[CLS]', 'For', 'the', '_MATH_-th', 'discrepancy', 'we', 'have',
        '_MATHDISP_', ',', 'where', '_MATH_', ',', '_MATH_', ',', '_MATH_',
        'and', '_MATH_', '.', '[SEP]',
    ]
    word_ends = [
        '[CLS]', 'for', 'the', 'th', 'cy', 'we', 'have', '_MATHDISP_', ',',
        'where', '_MATH_', ',', '_MATH_', ',', '_MATH_', 'and', '_MATH_', '.',
        '[SEP]',
    ]

    # Every token row is the same all-ones vector, so the matrix is uniform.
    n_tokens = len(tokens)
    attention = np.ones((n_tokens, n_tokens), dtype=np.float32)

    merged = merge(attention, tokens, words, word_ends)

    expected_row = [
        1, 1, 1, 3, 4, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1,
    ]
    expected = np.array([expected_row] * len(words), dtype=np.float32)
    np.testing.assert_allclose(merged, expected)

    # Consecutive merged rows must also agree with each other.
    for upper, lower in zip(merged, merged[1:]):
        np.testing.assert_allclose(upper, lower)
def test_paren_comma() -> None:
    """Check merging when the single token `),` corresponds to two words.

    Both words share the same word-end token; the expected matrix keeps the
    weight in the top-left cell and is zero everywhere else.
    """
    tokens = ["),"]
    words = [")", ","]
    word_ends = ["),", "),"]
    attention = np.array([[1.0]], dtype=np.float32)

    expected = np.array(
        [
            [1, 0],
            [0, 0],
        ],
        dtype=np.float32,
    )

    result = merge(attention, tokens, words, word_ends)
    np.testing.assert_allclose(result, expected, atol=1e-10)
def test_3x3():
    """With one token per word, both implementations return the input as-is."""
    tokens = list("ABC")
    words = list(tokens)
    word_ends = list(tokens)
    attention = np.ones((3, 3), dtype=np.float32)

    # Reference implementation first, then the implementation under test.
    for merge_impl in (reference.merge, merge):
        result = merge_impl(attention, tokens, words, word_ends)
        np.testing.assert_allclose(result, attention)
def test_paren_comma_with_many_words() -> None:
    """Check the `),` double-word case inside a longer sequence.

    Four tokens map to four words: `),` stands for both `)` and `,`, while
    `E` + `q` collapse into `Eq`. The expected matrix has an all-zero row and
    column for the second word of the shared token and a 2 in the `Eq` column.
    """
    tokens = ["at", "),", "E", "q"]
    words = ["at", ")", ",", "Eq"]
    word_ends = ["at", "),", "),", "q"]

    n_tokens = len(tokens)
    attention = np.ones((n_tokens, n_tokens), dtype=np.float32)

    result = merge(attention, tokens, words, word_ends)
    assert result.shape == (4, 4)

    live_row = [1, 1, 0, 2]
    zero_row = [0, 0, 0, 0]
    expected = np.array(
        [live_row, live_row, zero_row, live_row],
        dtype=np.float32,
    )
    np.testing.assert_allclose(result, expected, atol=1e-10)
def test_unbalanced():
    """Two unevenly weighted sub-word tokens collapse into one word.

    Each attention row is (0.2, 0.8); both implementations are expected to
    produce the 1x1 matrix [[1.0]].
    """
    tokens = ["straw", "##berries"]
    words = ["strawberries"]
    word_ends = ["##berries"]
    attention = np.array(
        [
            [0.2, 0.8],
            [0.2, 0.8],
        ],
        dtype=np.float32,
    )

    # Implementation under test first, then the reference implementation.
    for merge_impl in (merge, reference.merge):
        result = merge_impl(attention, tokens, words, word_ends)
        expected = np.array([[1.0]])
        np.testing.assert_allclose(result, expected)
def test_simple():
    """Two tokens forming one word merge to [[1.0]] for identity attention.

    NOTE(review): another `test_simple` is defined later in this source. If
    both live in the same module, the later definition rebinds the name and
    pytest will not collect this one -- confirm and rename one of them.
    """
    tokens = ["A", "B"]
    words = ["AB"]
    word_ends = ["B"]
    attention = np.array(
        [
            [1, 0],
            [0, 1],
        ],
        dtype=np.float32,
    )

    # Implementation under test first, then the reference implementation.
    for merge_impl in (merge, reference.merge):
        result = merge_impl(attention, tokens, words, word_ends)
        expected = np.array([[1.0]])
        np.testing.assert_allclose(result, expected)
def test_simple():
    """Merge identical saliency rows where `A` + `B` form the word `AB`.

    Each row is (0.1, 0.2, 0.3); the expected 2x2 result is 0.3 everywhere.

    NOTE(review): an earlier `test_simple` appears in this source. If both
    live in the same module, this definition shadows the earlier one --
    confirm and rename one of them.
    """
    tokens = ['A', 'B', 'C']
    words = ['AB', 'C']
    word_ends = ['B', 'C']

    # Stack the same saliency vector once per token to form the matrix.
    saliency = np.array([0.1, 0.2, 0.3], dtype=np.float32)
    attention = np.array([saliency] * len(tokens), dtype=np.float32)

    result = merge(attention, tokens, words, word_ends)

    expected = np.array(
        [
            [0.3, 0.3],
            [0.3, 0.3],
        ],
        dtype=np.float32,
    )
    np.testing.assert_allclose(result, expected)
def test_near_zero():
    """Tiny (1e-16) attention values pass through both implementations unchanged.

    Tokens and words are one-to-one, so the merged matrix is compared
    directly against the input.
    """
    tokens = list("ABC")
    words = list(tokens)
    word_ends = list(tokens)
    attention = np.full((3, 3), 1e-16, dtype=np.float32)

    # Reference implementation first, then the implementation under test.
    for merge_impl in (reference.merge, merge):
        result = merge_impl(attention, tokens, words, word_ends)
        np.testing.assert_allclose(result, attention)
def test_against_reference(args):
    """Check that `merge` and `reference.merge` agree on the same inputs.

    `args` unpacks into (attn, tokens, word_ends). It is presumably supplied
    by a parametrizing or property-based decorator outside this view -- TODO
    confirm. `word_ends` is passed as both the words and the word ends.
    """
    attn, tokens, word_ends = args

    actual = merge(attn, tokens, word_ends, word_ends)
    desired = reference.merge(attn, tokens, word_ends, word_ends)

    np.testing.assert_allclose(actual, desired, atol=1e-16)
def test_saliency_properties(args):
    """Check that all rows of the merged saliency matrix agree.

    `args` unpacks into (saliency, tokens, word_ends). It is presumably
    supplied by a parametrizing or property-based decorator outside this
    view -- TODO confirm. `word_ends` is passed as both the words and the
    word ends.
    """
    saliency, tokens, word_ends = args

    # `merged` must be the array itself, not a 1-tuple wrapping it. With a
    # tuple, `merged[1:]` is empty, the loop body never runs, and the test
    # passes without asserting anything.
    merged = merge(saliency, tokens, word_ends, word_ends)

    for row1, row2 in zip(merged, merged[1:]):
        np.testing.assert_allclose(row1, row2)