def test_add_special_tokens(self):
    """
    Add a list of special tokens to the dictionary.
    """
    special_toks_lst = ['MY', 'NAME', 'IS', 'EMILY']

    # create Dictionary Agent
    parser = ParlaiParser()
    parser.set_params(
        dict_tokenizer='bytelevelbpe',
        bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
        bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
        hf_skip_special_tokens=False,
    )
    opt = parser.parse_args([])

    agent = DictionaryAgent(opt)
    agent.add_additional_special_tokens(special_toks_lst)

    self.assertEqual(agent.additional_special_tokens, special_toks_lst)

    phrases = ['Hi what is up EMILY', 'What IS your NAME', 'That is MY dog']
    for phrase in phrases:
        vec = agent.txt2vec(phrase)
        text = agent.vec2txt(vec)
        self.assertEqual(phrase, text)
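# --- Illustrative sketch (not part of the original test) ---
# The round-trip above only holds because hf_skip_special_tokens=False:
# with HuggingFace's usual behavior of skipping special tokens on decode,
# vec2txt would presumably drop 'EMILY' and the assertEqual would fail.
# A minimal standalone version of the same round-trip (the helper name is
# hypothetical):
def _example_special_token_roundtrip(self):
    parser = ParlaiParser()
    parser.set_params(
        dict_tokenizer='bytelevelbpe',
        bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
        bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
        hf_skip_special_tokens=False,  # keep special tokens when decoding
    )
    agent = DictionaryAgent(parser.parse_args([]))
    agent.add_additional_special_tokens(['EMILY'])
    vec = agent.txt2vec('Hi what is up EMILY')  # 'EMILY' maps to one id
    assert agent.vec2txt(vec) == 'Hi what is up EMILY'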
def _run_specialtok_test(self, **kwargs):
    for special_token in ['SPECIAL TOKENS', '[SPECIAL; TOKENS]']:
        with testing_utils.tempdir() as tmpdir:
            if 'dict_file' not in kwargs:
                kwargs['dict_file'] = os.path.join(tmpdir, 'dict')
            string = f"This is a test of {special_token}"
            parser = ParlaiParser(False, False)
            DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
            opt = parser.parse_kwargs(**kwargs)
            da = DictionaryAgent(opt)
            before = da.tokenize(string)
            da.add_additional_special_tokens([special_token])
            after = da.tokenize(string)

            # after registration, the special token is kept whole, so the
            # output shrinks but the shared prefix is tokenized identically
            assert before != after
            assert len(before) > len(after)
            assert after[-1] == special_token
            assert before[:5] == after[:5]

            if opt['dict_tokenizer'] in (
                'bytelevelbpe',
                'gpt2',
                'slow_bytelevel_bpe',
            ):
                # we need to let the dictionary handle the tokenid mappings
                assert da.vec2txt(da.txt2vec(string)) == string
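# Example drivers for the helper above (a sketch: the byte-level BPE
# settings mirror test_add_special_tokens, and the method names here are
# illustrative rather than taken from the original suite):
def test_specialtok_split(self):
    self._run_specialtok_test(dict_tokenizer='split')

def test_specialtok_bytelevelbpe(self):
    self._run_specialtok_test(
        dict_tokenizer='bytelevelbpe',
        bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
        bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
    )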
def test_special_tokenization(self):
    from parlai.core.dict import DictionaryAgent
    from parlai.core.params import ParlaiParser
    from parlai.torchscript.modules import ScriptableDictionaryAgent

    SPECIAL = ['Q00', 'Q01']
    text = "Don't have a Q00, man! Have a Q01 instead."

    parser = ParlaiParser(False, False)
    DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
    with testing_utils.tempdir() as tmp:
        opt = parser.parse_kwargs(
            dict_tokenizer='gpt2', dict_file=os.path.join(tmp, 'dict')
        )
        orig_dict = DictionaryAgent(opt)
        orig_bpe = orig_dict.bpe
        # TorchScript dicts cannot be keyed by tuples, so fuse each BPE
        # merge pair into a single newline-joined string key
        fused_key_bpe_ranks = {
            '\n'.join(key): float(val) for key, val in orig_bpe.bpe_ranks.items()
        }

        # first, a scriptable agent with no special tokens registered
        sda = ScriptableDictionaryAgent(
            null_token=orig_dict.null_token,
            end_token=orig_dict.end_token,
            unk_token=orig_dict.unk_token,
            start_token=orig_dict.start_token,
            freq=orig_dict.freq,
            tok2ind=orig_dict.tok2ind,
            ind2tok=orig_dict.ind2tok,
            bpe_add_prefix_space=False,
            bpe_encoder=orig_bpe.encoder,
            bpe_byte_encoder=orig_bpe.byte_encoder,
            fused_key_bpe_ranks=fused_key_bpe_ranks,
            special_tokens=[],
        )
        tokenized = sda.txt2vec(text)
        assert len(tokenized) == 15
        assert sda.vec2txt(tokenized) == text
        # expand ids to tokens for easier debugging of failures
        nice_tok = [sda.ind2tok[i] for i in tokenized]  # noqa: F841

        # rebuild the dictionary with the special tokens registered
        orig_dict = DictionaryAgent(opt)
        orig_dict.add_additional_special_tokens(SPECIAL)
        orig_bpe = orig_dict.bpe
        sda = ScriptableDictionaryAgent(
            null_token=orig_dict.null_token,
            end_token=orig_dict.end_token,
            unk_token=orig_dict.unk_token,
            start_token=orig_dict.start_token,
            freq=orig_dict.freq,
            tok2ind=orig_dict.tok2ind,
            ind2tok=orig_dict.ind2tok,
            bpe_add_prefix_space=False,
            bpe_encoder=orig_bpe.encoder,
            bpe_byte_encoder=orig_bpe.byte_encoder,
            fused_key_bpe_ranks=fused_key_bpe_ranks,
            special_tokens=SPECIAL,
        )
        special_tokenized = sda.txt2vec(text)
        assert len(special_tokenized) == 15
        assert sda.vec2txt(special_tokenized) == text
        assert special_tokenized != tokenized
        nice_specialtok = [sda.ind2tok[i] for i in special_tokenized]  # noqa: F841
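# --- Illustrative sketch (assumption; helper names are hypothetical) ---
# TorchScript dicts cannot use tuple keys, which is why bpe_ranks is
# "fused" into newline-joined string keys above. The encoding round-trips
# cleanly because byte-level BPE merge symbols should never contain a raw
# '\n' (the byte encoder maps it to a printable character):
def _fuse_bpe_ranks(bpe_ranks):
    # {('Ġt', 'he'): 1, ...} -> {'Ġt\nhe': 1.0, ...}
    return {'\n'.join(key): float(val) for key, val in bpe_ranks.items()}

def _unfuse_bpe_key(fused_key):
    # 'Ġt\nhe' -> ('Ġt', 'he')
    first, second = fused_key.split('\n')
    return first, second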