def test_token_insert(self):
    bart_noise = BARTNoising(
        vocab=[self.FAKE_VOCAB],
        mask_tok=self.MASK_TOK,
        insert_ratio=0.5,
        random_ratio=0.3,
        replace_length=0,  # valid value, does not raise an error
        # Default: full_stop_token=[".", "?", "!"]
    )
    tokens = ["This", "looks", "really", "good", "!"]
    inserted = bart_noise.apply(tokens)
    n_insert = math.ceil(len(tokens) * bart_noise.insert_ratio)
    inserted_len = n_insert + len(tokens)
    self.assertEqual(len(inserted), inserted_len)
    # random_ratio of the inserted tokens are drawn from the vocab
    n_random = math.ceil(n_insert * bart_noise.random_ratio)
    self.assertEqual(
        sum(1 if tok == self.FAKE_VOCAB else 0 for tok in inserted),
        n_random,
    )
    # the others are MASK_TOK
    self.assertEqual(
        sum(1 if tok == self.MASK_TOK else 0 for tok in inserted),
        n_insert - n_random,
    )
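# Worked example of the insertion arithmetic checked above (a sketch of
# the expected counts, assuming the ceil-based rounding used in the
# assertions): with 5 source tokens and insert_ratio=0.5,
# n_insert = ceil(5 * 0.5) = 3, so the output holds 5 + 3 = 8 tokens;
# with random_ratio=0.3, ceil(3 * 0.3) = 1 insertion is a random vocab
# token and the remaining 2 are MASK_TOK.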
def test_rotate(self):
    bart_noise = BARTNoising(
        vocab=[self.FAKE_VOCAB],
        rotate_ratio=1.0,
        replace_length=0,  # valid value, does not raise an error
    )
    tokens = ["This", "looks", "really", "good", "!"]
    rotated = bart_noise.apply(tokens)
    self.assertNotEqual(tokens, rotated)
    not_rotate = bart_noise.rolling_noise(tokens, p=0.0)
    self.assertEqual(tokens, not_rotate)
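# A minimal sketch of the rotation effect checked above (hypothetical
# helper, not part of BARTNoising; the real rolling_noise picks the
# offset at random when p > 0):
@staticmethod
def _example_roll(tokens, k):
    # move the first k tokens to the end of the sequence
    return tokens[k:] + tokens[:k]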
def test_sentence_permute(self):
    sent1 = ["Hello", "world", "."]
    sent2 = ["Sentence", "1", "!"]
    sent3 = ["Sentence", "2", "!"]
    sent4 = ["Sentence", "3", "!"]
    bart_noise = BARTNoising(
        vocab=[self.FAKE_VOCAB],
        permute_sent_ratio=0.5,
        replace_length=0,  # valid value, does not raise an error
        # Default: full_stop_token=[".", "?", "!"]
    )
    tokens = sent1 + sent2 + sent3 + sent4
    ends = bart_noise._get_sentence_borders(tokens).tolist()
    self.assertEqual(ends, [3, 6, 9, 12])
    tokens_perm = bart_noise.apply(tokens)
    expected_tokens = sent2 + sent1 + sent3 + sent4
    self.assertEqual(expected_tokens, tokens_perm)
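# A rough sketch of how the borders asserted above could be derived from
# the default full-stop tokens (an illustration, not the actual
# _get_sentence_borders implementation):
@staticmethod
def _example_borders(tokens, stops=(".", "?", "!")):
    # index just past each sentence-final token -> [3, 6, 9, 12] above
    return [i + 1 for i, tok in enumerate(tokens) if tok in stops]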
def test_span_infilling(self):
    bart_noise = BARTNoising(
        vocab=[self.FAKE_VOCAB],
        mask_tok=self.MASK_TOK,
        mask_ratio=0.5,
        mask_length="span-poisson",
        poisson_lambda=3.0,
        is_joiner=True,
        replace_length=1,
        # insert_ratio=0.5,
        # random_ratio=0.3,
        # Default: full_stop_token=[".", "?", "!"]
    )
    self.assertIsNotNone(bart_noise.mask_span_distribution)
    tokens = ["H■", "ell■", "o", "world", ".", "An■", "other", "!"]
    # word-start tokens are identified using the subword marker
    token_starts = [True, False, False, True, True, True, False, True]
    self.assertEqual(bart_noise._is_word_start(tokens), token_starts)
    bart_noise.apply(copy.copy(tokens))
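# With mask_length="span-poisson", span lengths are drawn from a Poisson
# distribution parameterized by poisson_lambda, which is why
# mask_span_distribution is populated above. A sketch of the sampling
# idea (assuming a plain Poisson(3.0) over span lengths, not the exact
# internal truncation logic):
#
#   import numpy as np
#   rng = np.random.default_rng(1234)
#   span_lengths = rng.poisson(lam=3.0, size=4)  # candidate span lengths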
def test_whole_word_mask(self):
    """Masking is done on whole words that may span multiple tokens.

    Conditions:
    * `mask_length` == "word";
    * a subword marker is specified so that word boundaries can be found.
    """
    bart_noise = BARTNoising(
        vocab=[self.FAKE_VOCAB],
        mask_tok=self.MASK_TOK,
        mask_ratio=0.5,
        mask_length="word",
        is_joiner=True,
        replace_length=0,  # 0 to drop masked words, 1 to replace them with MASK
        # insert_ratio=0.0,
        # random_ratio=0.0,
        # Default: full_stop_token=[".", "?", "!"]
    )
    tokens = ["H■", "ell■", "o", "wor■", "ld", "."]
    # word-start tokens are identified using the subword marker
    token_starts = [True, False, False, True, False, True]
    self.assertEqual(bart_noise._is_word_start(tokens), token_starts)

    # 1. replace_length 0: selected "words" are dropped
    masked = bart_noise.apply(copy.copy(tokens))
    n_words = sum(token_starts)
    n_masked = math.ceil(n_words * bart_noise.mask_ratio)
    # len(masked) depends on how many tokens the dropped words contain,
    # so no fixed length is asserted here:
    # self.assertEqual(len(masked), n_words - n_masked)

    # 2. replace_length 1: each selected "word" is replaced with a single MASK
    bart_noise.replace_length = 1
    masked = bart_noise.apply(copy.copy(tokens))
    # len(masked) depends on the token count of the selected words
    n_words = sum(token_starts)
    n_masked = math.ceil(n_words * bart_noise.mask_ratio)
    self.assertEqual(
        sum(1 if tok == self.MASK_TOK else 0 for tok in masked),
        n_masked)

    # 3. replace_length -1: every token of a selected "word" becomes MASK
    bart_noise.replace_length = -1
    masked = bart_noise.apply(copy.copy(tokens))
    self.assertEqual(len(masked), len(tokens))  # length won't change
    n_words = sum(token_starts)
    n_masked = math.ceil(n_words * bart_noise.mask_ratio)
    # the number of MASK_TOK depends on the token count of the selected
    # words, so it can be greater than n_masked
    self.assertTrue(
        sum(1 if tok == self.MASK_TOK else 0 for tok in masked) > n_masked)
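# Worked example for the three cases above: the sample has n_words = 3
# ("H■ ell■ o", "wor■ ld", "."), so n_masked = ceil(3 * 0.5) = 2 words
# are selected. With replace_length=1 the output contains exactly 2
# MASK_TOK; with replace_length=-1 each token of the selected words is
# masked, giving at least 3 MASK_TOK (always > n_masked) while the
# sequence keeps its original length of 6.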
def test_token_mask(self):
    """Masking is done at the token level.

    Conditions:
    * `mask_length` == "subword";
    * or no subword marker (joiner/spacer) is specified via `is_joiner`.
    """
    bart_noise = BARTNoising(
        vocab=[self.FAKE_VOCAB],
        mask_tok=self.MASK_TOK,
        mask_ratio=0.5,
        mask_length="subword",
        replace_length=0,  # 0 to drop masked tokens, 1 to replace them with MASK
        # insert_ratio=0.0,
        # random_ratio=0.0,
        # Default: full_stop_token=[".", "?", "!"]
    )
    tokens = ["H■", "ell■", "o", "world", "."]
    # every token is treated as an individual word
    self.assertTrue(all(bart_noise._is_word_start(tokens)))
    n_tokens = len(tokens)

    # 1. tokens are dropped when replace_length is 0
    masked = bart_noise.apply(copy.copy(tokens))
    n_masked = math.ceil(n_tokens * bart_noise.mask_ratio)
    self.assertEqual(len(masked), n_tokens - n_masked)

    # 2. tokens are replaced by MASK when replace_length is 1
    bart_noise.replace_length = 1
    masked = bart_noise.apply(copy.copy(tokens))
    n_masked = math.ceil(n_tokens * bart_noise.mask_ratio)
    self.assertEqual(len(masked), n_tokens)
    self.assertEqual(
        sum(1 if tok == self.MASK_TOK else 0 for tok in masked),
        n_masked)
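# Worked example for the token-level cases above: n_tokens = 5 and
# mask_ratio = 0.5, so n_masked = ceil(5 * 0.5) = 3. With
# replace_length=0 the output shrinks to 5 - 3 = 2 tokens; with
# replace_length=1 it keeps length 5 and contains exactly 3 MASK_TOK.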
def setUp(self):
    BARTNoising.set_random_seed(1234)
    self.MASK_TOK = "[MASK]"
    self.FAKE_VOCAB = "[TESTING]"
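# Note: the fixed seed set above makes the stochastic noising
# deterministic, which is what lets, e.g., test_sentence_permute assert
# an exact permutation rather than just a changed sequence.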