Code example #1
File: test_dict.py  Project: Taekyung2/MichinAI
    def test_add_special_tokens(self):
        """
        Add a list of special tokens to the dictionary.
        """
        special_toks_lst = ['MY', 'NAME', 'IS', 'EMILY']
        # create Dictionary Agent
        parser = ParlaiParser()
        parser.set_params(
            dict_tokenizer='bytelevelbpe',
            bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
            bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
            hf_skip_special_tokens=False,
        )
        opt = parser.parse_args([])

        agent = DictionaryAgent(opt)
        agent.add_additional_special_tokens(special_toks_lst)

        self.assertEqual(agent.additional_special_tokens, special_toks_lst)
        phrases = [
            'Hi what is up EMILY', 'What IS your NAME', 'That is MY dog'
        ]
        for phrase in phrases:
            vec = agent.txt2vec(phrase)
            text = agent.vec2txt(vec)
            self.assertEqual(phrase, text)
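As listed, this method relies on module-level imports and byte-level BPE fixture constants from the surrounding test file. A minimal sketch of that assumed context (the fixture paths below are placeholders, not the project's actual values):

import unittest

from parlai.core.dict import DictionaryAgent
from parlai.core.params import ParlaiParser

# Placeholder paths: the real test module points these at checked-in
# byte-level BPE vocab/merge files.
DEFAULT_BYTELEVEL_BPE_VOCAB = 'data/bytelevelbpe/vocab.json'
DEFAULT_BYTELEVEL_BPE_MERGE = 'data/bytelevelbpe/merges.txt'


class TestDictionary(unittest.TestCase):
    ...  # test_add_special_tokens from the snippet above is defined here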
Code example #2
File: test_dict.py  Project: simplecoka/cortx
    def _run_specialtok_test(self, **kwargs):
        for special_token in ['SPECIAL TOKENS', '[SPECIAL; TOKENS]']:
            with testing_utils.tempdir() as tmpdir:
                if 'dict_file' not in kwargs:
                    kwargs['dict_file'] = os.path.join(tmpdir, 'dict')
                string = f"This is a test of {special_token}"
                parser = ParlaiParser(False, False)
                DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
                opt = parser.parse_kwargs(**kwargs)
                da = DictionaryAgent(opt)
                before = da.tokenize(string)
                da.add_additional_special_tokens([special_token])
                after = da.tokenize(string)
                assert before != after
                assert len(before) > len(after)
                assert after[-1] == special_token
                assert before[:5] == after[:5]
                if opt['dict_tokenizer'] in (
                    'bytelevelbpe',
                    'gpt2',
                    'slow_bytelevel_bpe',
                ):
                    # we need to let the dictionary handle the tokenid mappings
                    assert da.vec2txt(da.txt2vec(string)) == string
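In the source file this helper is driven by small per-tokenizer test methods. A hedged sketch of what such callers could look like, assuming the helper sits in the same unittest.TestCase (the method names and tokenizer choices here are illustrative, not necessarily the project's):

    # Illustrative callers; each exercises the helper with one tokenizer.
    def test_specialtok_split(self):
        self._run_specialtok_test(dict_tokenizer='split')

    def test_specialtok_gpt2(self):
        self._run_specialtok_test(dict_tokenizer='gpt2')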
Code example #3
    def test_special_tokenization(self):
        from parlai.core.dict import DictionaryAgent
        from parlai.core.params import ParlaiParser
        from parlai.torchscript.modules import ScriptableDictionaryAgent

        SPECIAL = ['Q00', 'Q01']
        text = "Don't have a Q00, man! Have a Q01 instead."

        parser = ParlaiParser(False, False)
        DictionaryAgent.add_cmdline_args(parser)
        with testing_utils.tempdir() as tmp:
            opt = parser.parse_kwargs(
                dict_tokenizer='gpt2', dict_file=os.path.join(tmp, 'dict')
            )

            orig_dict = DictionaryAgent(opt)

            orig_bpe = orig_dict.bpe
            # TorchScript dicts cannot use tuple keys, so each BPE rank key is
            # fused into a single newline-joined string before being passed in.
            fused_key_bpe_ranks = {
                '\n'.join(key): float(val) for key, val in orig_bpe.bpe_ranks.items()
            }

            sda = ScriptableDictionaryAgent(
                null_token=orig_dict.null_token,
                end_token=orig_dict.end_token,
                unk_token=orig_dict.unk_token,
                start_token=orig_dict.start_token,
                freq=orig_dict.freq,
                tok2ind=orig_dict.tok2ind,
                ind2tok=orig_dict.ind2tok,
                bpe_add_prefix_space=False,
                bpe_encoder=orig_bpe.encoder,
                bpe_byte_encoder=orig_bpe.byte_encoder,
                fused_key_bpe_ranks=fused_key_bpe_ranks,
                special_tokens=[],
            )

            # Baseline: no special tokens registered, so 'Q00' and 'Q01' are
            # split into ordinary BPE pieces.
            tokenized = sda.txt2vec(text)
            assert len(tokenized) == 15
            assert sda.vec2txt(tokenized) == text
            nice_tok = [sda.ind2tok[i] for i in tokenized]

            # Rebuild the dictionary, this time registering the special tokens.
            orig_dict = DictionaryAgent(opt)
            orig_dict.add_additional_special_tokens(SPECIAL)
            orig_bpe = orig_dict.bpe
            sda = ScriptableDictionaryAgent(
                null_token=orig_dict.null_token,
                end_token=orig_dict.end_token,
                unk_token=orig_dict.unk_token,
                start_token=orig_dict.start_token,
                freq=orig_dict.freq,
                tok2ind=orig_dict.tok2ind,
                ind2tok=orig_dict.ind2tok,
                bpe_add_prefix_space=False,
                bpe_encoder=orig_bpe.encoder,
                bpe_byte_encoder=orig_bpe.byte_encoder,
                fused_key_bpe_ranks=fused_key_bpe_ranks,
                special_tokens=SPECIAL,
            )

            special_tokenized = sda.txt2vec(text)
            assert len(special_tokenized) == 15
            assert sda.vec2txt(special_tokenized) == text
            assert special_tokenized != tokenized
            nice_specialtok = [sda.ind2tok[i] for i in special_tokenized]
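The nice_tok and nice_specialtok lists are only built for inspection. A small illustrative follow-up, using the variables from the snippet above, that prints the two tokenizations side by side so the effect of registering 'Q00' and 'Q01' is visible:

            # Both vectors hold 15 ids here (per the assertions above), so a
            # plain zip lines them up token by token.
            for plain, special in zip(nice_tok, nice_specialtok):
                print(f"{plain!r:>12}  |  {special!r}")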