"""Tests for TokenAligner's projection of source token indices onto retokenized targets."""

import numpy as np

from jiant.utils.retokenize import TokenAligner  # assumed import path; adjust to this repo's layout


def test_token_aligner_project_to_empty_target_token_sequence():
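    """Projecting onto an empty target sequence should yield an empty index array."""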
    source_tokens = ["abc", "def", "ghi", "jkl"]
    target_tokens = []
    ta = TokenAligner(source_tokens, target_tokens)
    m = ta.project_tokens([1, 3])
    m_expected = np.array([])
    assert (m == m_expected).all()


def test_token_aligner_project_to_mismatched_token_sequence():
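    """Projection yields no target indices when source and target share no text."""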
    source_tokens = ["abc", "def", "ghi", "jkl"]
    target_tokens = ["qrs", "tuv", "wxy", "z"]
    ta = TokenAligner(source_tokens, target_tokens)
    m = ta.project_tokens([1])
    m_expected = np.array([])
    assert (m == m_expected).all()


def test_moses_tok_idx_proj_3():
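    """Moses detaches the sentence-final period, so "rules." spans two target tokens."""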
    src_tokens = [
        "Mr.",
        "Immelt",
        "chose",
        "to",
        "focus",
        "on",
        "the",
        "incomprehensibility",
        "of",
        "accounting",
        "rules.",
    ]
    tgt_tokens = [
        "Mr.",
        "Immelt",
        "chose",
        "to",
        "focus",
        "on",
        "the",
        "incomprehensibility",
        "of",
        "accounting",
        "rules",
        ".",
    ]
    tgt_token_index = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9],
                       [10, 11]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_token_aligner_project_multiple_token_indices():
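    """A list of source indices projects to the union of their target spans."""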
    source_tokens = ["abc", "def", "ghi", "jkl"]
    target_tokens = ["abc", "d", "ef", "ghi", "jkl"]
    ta = TokenAligner(source_tokens, target_tokens)
    m = ta.project_tokens([1, 3])
    m_expected = np.array([1, 2, 4])
    assert (m == m_expected).all()


def test_token_aligner_project_single_token_index():
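    """A single source index projects to every target subtoken it covers."""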
    source_tokens = ["abc", "def", "ghi", "jkl"]
    target_tokens = ["abc", "d", "ef", "ghi", "jkl"]
    ta = TokenAligner(source_tokens, target_tokens)
    m = ta.project_tokens(1)
    m_expected = np.array([1, 2])
    assert (m == m_expected).all()


def test_wpm_tok_idx_proj_1():
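    """Identical source and target tokenizations project one-to-one."""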
    src_tokens = ["Members", "of", "the", "House", "clapped", "their", "hands"]
    tgt_tokens = ["Members", "of", "the", "House", "clapped", "their", "hands"]
    tgt_token_index = [[0], [1], [2], [3], [4], [5], [6]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_bytebpe_tok_idx_proj_4():
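    """Byte-level BPE detaches trailing punctuation, so "What?" spans two target tokens."""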
    src_tokens = ["What?"]
    tgt_tokens = ["What", "?"]
    tgt_token_index = [[0, 1]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_sentencepiece_tok_idx_proj_3():
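    """SentencePiece prefixes word-initial pieces with ▁ and splits long words into several pieces."""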
    src_tokens = [
        "Mr.",
        "Immelt",
        "chose",
        "to",
        "focus",
        "on",
        "the",
        "incomprehensibility",
        "of",
        "accounting",
        "rules.",
    ]
    tgt_tokens = [
        "▁Mr",
        ".",
        "▁I",
        "m",
        "mel",
        "t",
        "▁chose",
        "▁to",
        "▁focus",
        "▁on",
        "▁the",
        "▁in",
        "comp",
        "re",
        "hen",
        "s",
        "ibility",
        "▁of",
        "▁accounting",
        "▁rules",
        ".",
    ]
    tgt_token_index = [
        [0, 1],
        [2, 3, 4, 5],
        [6],
        [7],
        [8],
        [9],
        [10],
        [11, 12, 13, 14, 15, 16],
        [17],
        [18],
        [19, 20],
    ]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_moses_tok_idx_proj_2():
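    """Moses separates clitics and punctuation ("Sarah's" becomes "Sarah" + "'s")."""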
    src_tokens = ["I", "look", "at", "Sarah's", "dog.", "It", "was", "cute.!"]
    tgt_tokens = [
        "I", "look", "at", "Sarah", "'s", "dog", ".", "It", "was", "cute",
        ".", "!"
    ]
    tgt_token_index = [[0], [1], [2], [3, 4], [5, 6], [7], [8], [9, 10, 11]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_wpm_tok_idx_proj_3():
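    """WordPiece marks continuations with "##"; multi-piece words map to index spans."""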
    src_tokens = [
        "Mr.",
        "Immelt",
        "chose",
        "to",
        "focus",
        "on",
        "the",
        "incomprehensibility",
        "of",
        "accounting",
        "rules.",
    ]
    tgt_tokens = [
        "Mr",
        ".",
        "I",
        "##mme",
        "##lt",
        "chose",
        "to",
        "focus",
        "on",
        "the",
        "in",
        "##com",
        "##p",
        "##re",
        "##hen",
        "##si",
        "##bility",
        "of",
        "accounting",
        "rules",
        ".",
    ]
    tgt_token_index = [
        [0, 1],
        [2, 3, 4],
        [5],
        [6],
        [7],
        [8],
        [9],
        [10, 11, 12, 13, 14, 15, 16],
        [17],
        [18],
        [19, 20],
    ]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_wpm_tok_idx_proj_2():
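    """WordPiece splits the clitic into three pieces ("Sarah", "'", "s")."""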
    src_tokens = ["I", "look", "at", "Sarah's", "dog.", "It", "was", "cute.!"]
    tgt_tokens = [
        "I", "look", "at", "Sarah", "'", "s", "dog", ".", "It", "was", "cute",
        ".", "!"
    ]
    tgt_token_index = [[0], [1], [2], [3, 4, 5], [6, 7], [8], [9],
                       [10, 11, 12]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_bpe_tok_idx_proj_2():
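    """Lowercasing and "</w>" end-of-word markers should not disturb the alignment."""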
    src_tokens = ["I", "look", "at", "Sarah's", "dog.", "It", "was", "cute.!"]
    tgt_tokens = [
        "i</w>",
        "look</w>",
        "at</w>",
        "sarah</w>",
        "'s</w>",
        "dog</w>",
        ".</w>",
        "it</w>",
        "was</w>",
        "cute</w>",
        ".</w>",
        "!</w>",
    ]
    tgt_token_index = [[0], [1], [2], [3, 4], [5, 6], [7], [8], [9, 10, 11]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_bytebpe_tok_idx_proj_3():
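    """Byte-level BPE marks word-initial pieces with "Ġ"; subword spans map back to whole source tokens."""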
    src_tokens = [
        "Mr.",
        "Immelt",
        "chose",
        "to",
        "focus",
        "on",
        "the",
        "incomprehensibility",
        "of",
        "accounting",
        "rules.",
    ]
    tgt_tokens = [
        "Mr",
        ".",
        "ĠImm",
        "elt",
        "Ġchose",
        "Ġto",
        "Ġfocus",
        "Ġon",
        "Ġthe",
        "Ġincomp",
        "rehens",
        "ibility",
        "Ġof",
        "Ġaccounting",
        "Ġrules",
        ".",
    ]
    tgt_token_index = [[0, 1], [2, 3], [4], [5], [6], [7], [8], [9, 10, 11],
                       [12], [13], [14, 15]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()


def test_bpe_tok_idx_proj_3():
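    """BPE keeps "mr.</w>" whole but splits "incomprehensibility" into four pieces."""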
    src_tokens = [
        "Mr.",
        "Immelt",
        "chose",
        "to",
        "focus",
        "on",
        "the",
        "incomprehensibility",
        "of",
        "accounting",
        "rules.",
    ]
    tgt_tokens = [
        "mr.</w>",
        "im",
        "melt</w>",
        "chose</w>",
        "to</w>",
        "focus</w>",
        "on</w>",
        "the</w>",
        "in",
        "comprehen",
        "si",
        "bility</w>",
        "of</w>",
        "accounting</w>",
        "rules</w>",
        ".</w>",
    ]
    tgt_token_index = [[0], [1, 2], [3], [4], [5], [6], [7], [8, 9, 10, 11],
                       [12], [13], [14, 15]]
    ta = TokenAligner(src_tokens, tgt_tokens)
    for src_token_idx in range(len(src_tokens)):
        projected_tgt_tok_idx = ta.project_tokens(src_token_idx)
        assert (tgt_token_index[src_token_idx] == projected_tgt_tok_idx).all()