# Example 1 (score: 0)
def test_whole_word_target_sampling_count_negative():
    """A negative ``word_count_to_sample`` must produce no constraints.

    Builds a 10-token sequence in which tokens 0, 5 and 9 are flagged as
    partial tokens (masker value 0) and every other token as a whole word
    (masker value 1). It then requests -10 words with the full sequence
    eligible for sampling.
    """
    partial_token_ids = (0, 5, 9)
    # Masker convention used by these tests: 1 = whole word, 0 = partial token.
    masker = {token_id: 0 if token_id in partial_token_ids else 1
              for token_id in range(10)}
    tokens = torch.arange(10).long()

    sampled = whole_word_sampling(
        tokens,
        masker,
        seq_sample_ratio=1.0,
        word_count_to_sample=-10,
        contains_eos=False,
    )

    assert len(sampled) == 0, "There should be no constraints"
# Example 2 (score: 0)
def test_whole_word_target_sampling_with_eos():
    """Sampling two words from an EOS-terminated sequence yields two constraints.

    The 10-token sequence has tokens 5 and 9 flagged as partial tokens
    (masker value 0). ``contains_eos=True`` signals that the final element
    is an EOS token.
    """
    # Per the original author, EOS is the last element and is considered a
    # whole word, even though index 9 is flagged 0 in the masker below.
    # Presumably contains_eos=True makes the sampler treat that position
    # specially. TODO confirm against whole_word_sampling.
    tokens = torch.arange(10).long()
    masker = {token_id: int(token_id not in (5, 9)) for token_id in range(10)}

    sampled = whole_word_sampling(
        tokens,
        masker,
        seq_sample_ratio=1.0,
        word_count_to_sample=2,
        contains_eos=True,
    )

    assert len(sampled) == 2, "There should be two constraints"
# Example 3 (score: 0)
def test_whole_word_target_sampling_all_partial():
    """Sampling two words from an alternating whole/partial sequence.

    Every odd-indexed token is flagged as a partial token (masker value 0)
    and every even-indexed token as a whole word (masker value 1).
    Requesting two words should yield two constraints, each with two
    elements. Presumably each element pair is a whole-word token plus the
    partial token that follows it; confirm against whole_word_sampling.
    """
    test = torch.arange(0, 10).long()
    whole_word_masker = {idx: 1 for idx in range(10)}
    # Make sure that every other word is a partial token (indices 1, 3, 5, 7, 9).
    for idx in range(1, 10, 2):
        whole_word_masker[idx] = 0

    result = whole_word_sampling(test,
                                 whole_word_masker,
                                 seq_sample_ratio=1.0,
                                 word_count_to_sample=2,
                                 contains_eos=False)

    assert len(result) == 2, "There should be two constraints"
    # Hand the generator straight to all(). Wrapping it in list() built a
    # throwaway list and evaluated every element before all() could
    # short-circuit.
    assert all(len(constraint) == 2
               for constraint in result), "All constraints should have two elements"