Example 1
    def _run_through(self, task, mutators):
        pp = ParlaiParser(True, False)
        opt = pp.parse_kwargs(task=task, mutators=mutators)
        teacher = create_task_agent_from_taskname(opt)[0]
        outputs = []
        for _ in range(5):
            outputs.append(teacher.act())
        return outputs
Example 2
    def test_parse_kwargs(self):
        parser = ParlaiParser(True, True)

        # implied args from the model
        opt = parser.parse_kwargs(model='transformer/generator',
                                  relu_dropout=0.3)
        assert opt['relu_dropout'] == 0.3
        assert opt['model'] == 'transformer/generator'
        assert 'n_heads' in opt

        # bad types
        with self.assertRaises(ValueError):
            parser = ParlaiParser(True, True)
            parser.parse_kwargs(model='transformer/generator',
                                relu_dropout='foo')

        # nonexistent args without model
        with self.assertRaises(KeyError):
            parser = ParlaiParser(True, True)
            parser.parse_kwargs(fake_arg='foo')

        # nonexistent args with model
        with self.assertRaises(KeyError):
            parser = ParlaiParser(True, True)
            parser.parse_kwargs(model='transformer/generator', fake_arg='foo')
Example 3
    def test_parse_kwargs_nargsplus(self):
        """
        Test parse_kwargs when provided an argument with more than one item.
        """
        parser = ParlaiParser(False, False)
        parser.add_argument('--example', nargs='+', choices=['a', 'b', 'c'])
        opt = parser.parse_args(['--example', 'a', 'b'])
        assert opt['example'] == ['a', 'b']

        parser = ParlaiParser(False, False)
        parser.add_argument('--example', nargs='+', choices=['a', 'b', 'c'])
        opt = parser.parse_kwargs(example=['a', 'b'])
        assert opt['example'] == ['a', 'b']

        parser = ParlaiParser(False, False)
        parser.add_argument('--example', nargs='+')
        opt = parser.parse_kwargs(example=['x', 'y'])
        assert opt['example'] == ['x', 'y']
Example 4
    def test_not_sticky(self):
        pp = ParlaiParser(True, False)
        opt = pp.parse_kwargs(
            task='integration_tests:multiturn',
            mutators='flatten',
            datatype='train:ordered',
        )
        teacher = create_task_agent_from_taskname(opt)[0]
        first_epoch = []
        second_epoch = []
        for _ in range(teacher.num_examples()):
            first_epoch.append(teacher.act())
        teacher.reset()
        for _ in range(teacher.num_examples()):
            second_epoch.append(teacher.act())

        assert all(f == s for f, s in zip(first_epoch, second_epoch))
Example 5
    def test_custom_special_tokens(self):
        from parlai.agents.hugging_face.dict import Gpt2DictionaryAgent
        from parlai.core.params import ParlaiParser

        parser = ParlaiParser(False, False)
        parser.set_defaults(gpt2_size="small", add_special_tokens=True)
        Gpt2DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
        with testing_utils.tempdir() as tmpdir:
            opt = parser.parse_kwargs(dict_file=os.path.join(tmpdir, 'dict'))
            dict_agent = Gpt2DictionaryAgent(opt)
            oldtokens = dict_agent.txt2vec("Hi VOLDEMORT")
            prevlen = len(dict_agent)
            dict_agent.add_additional_special_tokens(["VOLDEMORT"])
            newlen = len(dict_agent)
            assert newlen == prevlen + 1
            tokens = dict_agent.txt2vec("Hi VOLDEMORT")
            assert tokens != oldtokens
            assert len(tokens) < len(oldtokens)
Example 6
    def test_dataset_integrity(self):
        """
        Check that the self-feeding data loads.
        """
        pp = ParlaiParser(True, False)
        opt = pp.parse_kwargs(task='self_feeding:all', datatype='train:ordered')
        teacher = AllTeacher(opt)
        assert teacher.num_examples() == 193777
        assert teacher.num_episodes() == 193777
        assert 'some cheetah chasing to stay in shape' in teacher.act()['text']

        opt['datatype'] = 'valid'
        teacher = AllTeacher(opt)
        assert teacher.num_examples() == 3500

        opt['datatype'] = 'test'
        teacher = AllTeacher(opt)
        assert teacher.num_examples() == 7801
Example 7
    def _test_bpe_dropout(self, **dict_args):
        pp = ParlaiParser(False, False)
        DictionaryAgent.add_cmdline_args(pp, partial_opt=None)
        opt = pp.parse_kwargs(bpe_dropout=0.5, **dict_args)
        da = DictionaryAgent(opt)
        da.set_tokenization_mode(TokenizationMode.TEST_TIME_TEXT)
        s = (
            "Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
            "Donec vitae metus sollicitudin, ullamcorper tortor ut, rhoncus lacus. "
            "Praesent sollicitudin commodo turpis, ut pharetra tortor gravida nec."
        )
        # at test time, BPE dropout is disabled, so this segmentation is deterministic
        no_dropout = da.txt2vec(s)
        # at train time, dropout should occasionally yield a different segmentation
        da.set_tokenization_mode(TokenizationMode.TRAIN_TIME_TEXT)
        not_the_same = 0
        for _ in range(30):
            r = da.txt2vec(s)
            assert da.vec2txt(r) == s
            if r != no_dropout:
                not_the_same += 1
        assert not_the_same > 0
Example 8
    def _run_specialtok_test(self, **kwargs):
        for special_token in ['SPECIAL TOKENS', '[SPECIAL; TOKENS]']:
            with testing_utils.tempdir() as tmpdir:
                if 'dict_file' not in kwargs:
                    kwargs['dict_file'] = os.path.join(tmpdir, 'dict')
                string = f"This is a test of {special_token}"
                parser = ParlaiParser(False, False)
                DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
                opt = parser.parse_kwargs(**kwargs)
                da = DictionaryAgent(opt)
                before = da.tokenize(string)
                da.add_additional_special_tokens([special_token])
                after = da.tokenize(string)
                assert before != after
                assert len(before) > len(after)
                assert after[-1] == special_token
                assert before[:5] == after[:5]
                if opt['dict_tokenizer'] in (
                    'bytelevelbpe',
                    'gpt2',
                    'slow_bytelevel_bpe',
                ):
                    # we need to let the dictionary handle the tokenid mappings
                    assert da.vec2txt(da.txt2vec(string)) == string
Example 9
    def test_parse_kwargs_multirounds(self):
        """Test parse_kwargs when we have options that depend on options."""
        parser = ParlaiParser(True, False)
        opt = parser.parse_kwargs(task='integration_tests',
                                  mutators='episode_shuffle',
                                  preserve_context=True)
        assert opt['preserve_context'] is True
        opt = parser.parse_kwargs(task='integration_tests',
                                  mutators='episode_shuffle',
                                  preserve_context=False)
        assert opt['preserve_context'] is False

        with self.assertRaises(KeyError):
            parser.parse_kwargs(task='integration_tests',
                                mutators='episode_shuffle',
                                fake_option=False)

        with self.assertRaises(KeyError):
            parser.parse_kwargs(task='integration_tests', fake_option=False)
Example 10
    def test_special_tokenization(self):
        from parlai.core.dict import DictionaryAgent
        from parlai.core.params import ParlaiParser
        from parlai.torchscript.modules import ScriptableDictionaryAgent

        SPECIAL = ['Q00', 'Q01']
        text = "Don't have a Q00, man! Have a Q01 instead."

        parser = ParlaiParser(False, False)
        DictionaryAgent.add_cmdline_args(parser, partial_opt=None)
        with testing_utils.tempdir() as tmp:
            opt = parser.parse_kwargs(
                dict_tokenizer='gpt2', dict_file=os.path.join(tmp, 'dict')
            )

            orig_dict = DictionaryAgent(opt)

            orig_bpe = orig_dict.bpe
            # TorchScript dicts cannot use tuple keys, so fuse each BPE pair
            # into a single newline-joined string
            fused_key_bpe_ranks = {
                '\n'.join(key): float(val) for key, val in orig_bpe.bpe_ranks.items()
            }

            sda = ScriptableDictionaryAgent(
                null_token=orig_dict.null_token,
                end_token=orig_dict.end_token,
                unk_token=orig_dict.unk_token,
                start_token=orig_dict.start_token,
                freq=orig_dict.freq,
                tok2ind=orig_dict.tok2ind,
                ind2tok=orig_dict.ind2tok,
                bpe_add_prefix_space=False,
                bpe_encoder=orig_bpe.encoder,
                bpe_byte_encoder=orig_bpe.byte_encoder,
                fused_key_bpe_ranks=fused_key_bpe_ranks,
                special_tokens=[],
            )

            tokenized = sda.txt2vec(text)
            assert len(tokenized) == 15
            assert sda.vec2txt(tokenized) == text
            nice_tok = [sda.ind2tok[i] for i in tokenized]

            orig_dict = DictionaryAgent(opt)
            orig_dict.add_additional_special_tokens(SPECIAL)
            orig_bpe = orig_dict.bpe
            sda = ScriptableDictionaryAgent(
                null_token=orig_dict.null_token,
                end_token=orig_dict.end_token,
                unk_token=orig_dict.unk_token,
                start_token=orig_dict.start_token,
                freq=orig_dict.freq,
                tok2ind=orig_dict.tok2ind,
                ind2tok=orig_dict.ind2tok,
                bpe_add_prefix_space=False,
                bpe_encoder=orig_bpe.encoder,
                bpe_byte_encoder=orig_bpe.byte_encoder,
                fused_key_bpe_ranks=fused_key_bpe_ranks,
                special_tokens=SPECIAL,
            )

            special_tokenized = sda.txt2vec(text)
            assert len(special_tokenized) == 15
            assert sda.vec2txt(special_tokenized) == text
            assert special_tokenized != tokenized
            nice_specialtok = [sda.ind2tok[i] for i in special_tokenized]
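
The examples above all follow the same pattern: construct a ParlaiParser, pass options to parse_kwargs exactly as they would be spelled on the command line, and hand the resulting opt dict to a teacher or dictionary agent. Below is a minimal, self-contained sketch of that pattern; the task and datatype values are purely illustrative, and as Example 2 shows, unknown keys raise KeyError while badly typed values raise ValueError.

    from parlai.core.params import ParlaiParser

    # Core ParlAI arguments enabled, model arguments disabled (as in Examples 1, 4, and 6).
    parser = ParlaiParser(True, False)

    # Keyword arguments mirror command-line flags; these values are illustrative.
    opt = parser.parse_kwargs(task='integration_tests', datatype='train:ordered')

    assert opt['task'] == 'integration_tests'
    assert opt['datatype'] == 'train:ordered'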