Example #1
    def test_verify_data(self):
        parser = setup_args()
        opt = parser.parse_args(print_args=False)
        changed_files = testing_utils.git_changed_files()
        changed_task_files = []
        for file in changed_files:
            if ('parlai/tasks' in file and 'README' not in file
                    and 'task_list.py' not in file):
                changed_task_files.append(file)

        if not changed_task_files:
            return

        for file in changed_task_files:
            task = file.split('/')[-2]
            module_name = "%s.tasks.%s.agents" % ('parlai', task)
            task_module = importlib.import_module(module_name)
            subtasks = [
                ':'.join([task, x]) for x in dir(task_module)
                if ('teacher' in x.lower() and x not in BASE_TEACHERS)
            ]

            for subt in subtasks:
                opt['task'] = subt
                with testing_utils.capture_output() as _:
                    text, log = verify(opt, print_parser=False)
                for key in KEYS:
                    self.assertEqual(
                        log[key], 0,
                        'There are {} {} in this task.'.format(log[key], key))
Example #2
    def test_download_multiprocess_chunks(self):
        # Tests that the three finish downloading but may finish in any order
        urls = [
            'https://parl.ai/downloads/mnist/mnist.tar.gz',
            'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
            'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
        ]

        with testing_utils.capture_output() as stdout:
            download_results = build_data.download_multiprocess(
                urls,
                self.datapath,
                dest_filenames=self.dest_filenames,
                chunk_size=1)
        stdout = stdout.getvalue()

        output_filenames, output_statuses, output_errors = zip(
            *download_results)

        self.assertIn('mnist0.tar.gz', output_filenames,
                      f'missing file:\n{stdout}')
        self.assertIn('mnist1.tar.gz', output_filenames,
                      f'missing file:\n{stdout}')
        self.assertIn('mnist2.tar.gz', output_filenames,
                      f'missing file:\n{stdout}')
        self.assertIn(200, output_statuses,
                      f'unexpected error code:\n{stdout}')
        self.assertIn(403, output_statuses,
                      f'unexpected error code:\n{stdout}')
Example #3
 def verify_batch_lengths(defaults):
     with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
         # Get processed act from agent
         parser = train_setup_args()
         defaults['model_file'] = os.path.join(tmpdir, 'model')
         defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
         parser.set_defaults(**defaults)
         opt = parser.parse_args()
         build_dict(opt)
         agent = create_agent(opt)
         world_data = create_task(opt, agent)
         batch_sort_acts = []
         # first epoch
         while len(batch_sort_acts) < 900/50:
             world_data.parley()
             batch_sort_acts.append(world_data.acts[0])
         teacher = world_data.world.get_agents()[0]
         teacher.reset_data()
         # second epoch
         while len(batch_sort_acts) < 1800/50:
             world_data.parley()
             batch_sort_acts.append(world_data.acts[0])
         world_data.shutdown()
     field = defaults['batch_sort_field']
     lengths = [[ep_length(b[field]) for b in bb if field in b]
                for bb in batch_sort_acts[:-2]]  # last batches may be smaller
     # verify batch lengths
     for batch_lens in lengths:
         self.assertLessEqual(max(batch_lens) - min(batch_lens), max_range,
                              'PytorchDataTeacher batching does not give '
                              'batches with similar sized examples, when '
                              'sorting by `{}` field.'.format(
                                 defaults['batch_sort_field']))
Example #4
 def test_shuffle(self):
     """Simple test to ensure that dataloader is initialized with correct
         data sampler
     """
     dts = ['train', 'valid', 'test']
     exts = ['', ':stream', ':ordered', ':stream:ordered']
     shuffle_opts = [False, True]
     task = 'babi:task1k:1'
     for dt in dts:
         for ext in exts:
             datatype = dt + ext
             for shuffle in shuffle_opts:
                 opt_defaults = {
                     'pytorch_teacher_task': task,
                     'datatype': datatype,
                     'shuffle': shuffle
                 }
                 with testing_utils.capture_output() as _:
                     parser = display_setup_args()
                     parser.set_defaults(**opt_defaults)
                     opt = parser.parse_args()
                     teacher = create_task_agent_from_taskname(opt)[0]
                     if ('ordered' in datatype or
                         ('stream' in datatype and not opt.get('shuffle'))
                             or 'train' not in datatype):
                         self.assertIsInstance(
                             teacher.pytorch_dataloader.sampler, Sequential,
                             'PytorchDataTeacher failed with args: {}'.
                             format(opt))
                     else:
                         self.assertIsInstance(
                             teacher.pytorch_dataloader.sampler,
                             RandomSampler,
                             'PytorchDataTeacher failed with args: {}'.
                             format(opt))
Example #5
 def setUpClass(cls):
     """Set up the test by downloading the model/data."""
     with testing_utils.capture_output():
         parser = display_data.setup_args()
         parser.set_defaults(**MODEL_OPTIONS)
         opt = parser.parse_args(print_args=False)
         opt['num_examples'] = 1
         display_data.display_data(opt)
Example #6
 def setUpClass(cls):
     # go ahead and download things here
     with testing_utils.capture_output():
         parser = display_data.setup_args()
         parser.set_defaults(**END2END_OPTIONS)
         opt = parser.parse_args(print_args=False)
         opt['num_examples'] = 1
         display_data.display_data(opt)
Example #7
    def test_kvmemnn_f1(self):
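        # Run the ConvAI2 KvMemNN F1 evaluation end-to-end and compare the
        # reported F1 against the expected baseline value within a small delta.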
        import projects.convai2.baselines.kvmemnn.eval_f1 as eval_f1

        with testing_utils.capture_output() as stdout:
            report = eval_f1.main()
        self.assertAlmostEqual(report['f1'],
                               0.1173,
                               delta=0.0002,
                               msg=str(stdout))
Example #8
    def test_train_model_with_no_dict_file(self):
        """Ensure training a model requires a dict_file or model_file."""
        import parlai.scripts.train_model as tms

        with testing_utils.capture_output():
            parser = tms.setup_args()
            parser.set_params(task='babi:task1k:1', model='seq2seq')
            popt = parser.parse_args(print_args=False)
            with self.assertRaises(RuntimeError):
                tms.TrainLoop(popt)
Example #9
 def test_train_model_with_no_dict_file(self):
     """Check that attempting to train a model without specifying a dict_file
     or model_file fails
     """
     import parlai.scripts.train_model as tms
     with testing_utils.capture_output():
         parser = tms.setup_args()
         parser.set_params(task='babi:task1k:1', model='seq2seq')
         popt = parser.parse_args(print_args=False)
         with self.assertRaises(RuntimeError):
             tms.TrainLoop(popt)
Example #10
    def test_verify_data(self):
        parser = setup_args()
        opt = parser.parse_args(print_args=False)
        changed_files = testing_utils.git_changed_files()
        changed_task_files = []
        for file in changed_files:
            if ('parlai/tasks' in file and 'README' not in file
                    and 'task_list.py' not in file):
                changed_task_files.append(file)

        if not changed_task_files:
            return

        found_errors = False
        for file in changed_task_files:
            task = file.split('/')[-2]
            module_name = "%s.tasks.%s.agents" % ('parlai', task)
            task_module = importlib.import_module(module_name)
            subtasks = [
                ':'.join([task, x]) for x in dir(task_module)
                if ('teacher' in x.lower() and x not in BASE_TEACHERS)
            ]

            if testing_utils.is_this_circleci():
                if len(subtasks) == 0:
                    continue

                self.fail(
                    'test_verify_data plays poorly with CircleCI. Please run '
                    '`python tests/datatests/test_new_tasks.py` locally and '
                    'paste the output in your pull request.')

            for subt in subtasks:
                parser = setup_args()
                opt = parser.parse_args(args=['--task', subt],
                                        print_args=False)
                opt['task'] = subt
                try:
                    with testing_utils.capture_output():
                        text, log = verify(opt, print_parser=False)
                except Exception:
                    found_errors = True
                    traceback.print_exc()
                    print("Got above exception in {}".format(subt))
                for key in KEYS:
                    if log[key] != 0:
                        print('There are {} {} in {}.'.format(
                            log[key],
                            key,
                            subt,
                        ))
                        found_errors = True

        self.assertFalse(found_errors, "Errors were found.")
Example #11
 def get_teacher_act(defaults, teacher_processed=False, agent_to=None):
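     # Build the dict, create the teacher and agent, then return either the
     # teacher's raw act or the agent's processed observation of that act.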
     parser = train_setup_args()
     parser.set_defaults(**defaults)
     opt = parser.parse_args()
     build_dict(opt)
     with testing_utils.capture_output() as _:
         teacher = create_task_agent_from_taskname(opt)[0]
     agent = create_agent(opt)
     act = teacher.act()
     if teacher_processed:
         return act, agent
     return agent.observe(act), agent
Example #12
    def test_resume_checkpoint(self):
        """Make sure when resuming training that model uses appropriate mf.

        Copy train_model from testing_utils to directly access agent.
        """
        import parlai.scripts.train_model as tms

        def get_popt_and_tl(opt):
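            # Parse the given options, copy them directly into the parsed opt
            # so the overrides stick, and build a TrainLoop from the result.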
            parser = tms.setup_args()
            parser.set_params(**opt)
            popt = parser.parse_args(print_args=False)
            for k, v in opt.items():
                popt[k] = v
            return popt, tms.TrainLoop(popt)

        def get_opt(init_mf, mf):
            return {
                'task': 'integration_tests',
                'init_model': init_mf,
                'model':
                'parlai.agents.test_agents.dummy_torch_agent:MockTorchAgent',
                'model_file': mf,
                'num_epochs': 3,
                'validation_every_n_epochs': 1,
                'save_after_valid': True,
                'log_every_n_secs': 10,
            }

        with capture_output():
            with tempdir() as tmpdir:
                # First train model with init_model path set
                mf = os.path.join(tmpdir, 'model')
                init_mf = os.path.join(tmpdir, 'init_model')
                with open(init_mf, 'w') as f:
                    f.write(' ')
                opt = get_opt(init_mf, mf)
                popt, tl = get_popt_and_tl(opt)
                agent = tl.agent
                # init model file should be set appropriately
                init_model_file, is_finetune = agent._get_init_model(
                    popt, None)
                self.assertEqual(init_model_file, init_mf)
                self.assertTrue(is_finetune)
                valid, test = tl.train()
                # now, train the model for another epoch
                opt = get_opt('{}.checkpoint'.format(mf), mf)
                opt['load_from_checkpoint'] = True
                popt, tl = get_popt_and_tl(opt)
                agent = tl.agent
                init_model_file, is_finetune = agent._get_init_model(
                    popt, None)
                self.assertEqual(init_model_file, '{}.checkpoint'.format(mf))
                self.assertFalse(is_finetune)
Example #13
 def run_display_test(defaults, ep_and_ex_counts):
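     # Run display_data with the given defaults and capture its stdout so the
     # reported episode/example counts can be checked below.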
     with testing_utils.capture_output() as f:
         parser = display_setup_args()
         parser.set_defaults(**defaults)
         opt = parser.parse_args()
         display_data(opt)
     str_output = f.getvalue()
     self.assertTrue(
         '[ loaded {} episodes with a total of {} examples ]'.format(
             ep_and_ex_counts[0], ep_and_ex_counts[1]) in str_output,
         'PytorchDataTeacher multitasking failed with '
         'following args: {}'.format(opt))
Example #14
    def test_output(self):
        """Does display_data reach the end of the loop?"""
        with testing_utils.capture_output() as stdout:
            parser = ParlaiParser()
            opt = parser.parse_args(['--task', 'babi:task1k:1'],
                                    print_args=False)
            opt['num_examples'] = 1
            display_data(opt)

        str_output = stdout.getvalue()
        self.assertGreater(len(str_output), 0, "Output is empty")
        self.assertIn("[babi:task1k:1]:", str_output,
                      "Babi task did not print")
        self.assertIn("~~", str_output, "Example output did not complete")
    def test_pyt_preprocess(self):
        """
        Test that the preprocess functionality works with the PytorchDataTeacher
        with a sample TorchAgent (here, the Seq2seq model).

        This tests whether the action provided by the preprocessed teacher
        is equivalent to the agent's observation after the agent processes it.
        """
        def get_teacher_act(defaults, teacher_processed=False, agent_to=None):
            parser = train_setup_args()
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            build_dict(opt)
            with testing_utils.capture_output() as _:
                teacher = create_task_agent_from_taskname(opt)[0]
            agent = create_agent(opt)
            act = teacher.act()
            if teacher_processed:
                return act, agent
            return agent.observe(act), agent

        with testing_utils.capture_output() as _, \
                testing_utils.tempdir() as tmpdir:
            defaults = unit_test_parser_defaults.copy()
            defaults['batch_size'] = 1
            defaults['datatype'] = 'train:stream:ordered'

            # Get processed act from agent
            defaults['model_file'] = os.path.join(tmpdir, 'model')
            defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
            agent_processed_observation, agent1 = get_teacher_act(defaults)

            # Get preprocessed act from teacher
            defaults['model_file'] = os.path.join(tmpdir, 'model')
            defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
            defaults['pytorch_preprocess'] = True
            teacher_processed_act, agent2 = get_teacher_act(
                defaults, teacher_processed=True)  # noqa: E501

        for key in agent_processed_observation:
            val1 = agent_processed_observation[key]
            val2 = teacher_processed_act[key]
            if isinstance(val1, torch.Tensor):
                self.assertTrue(
                    bool(torch.all(torch.eq(val1, val2))),
                    '{} is not equal to {}'.format(val1, val2),
                )
            else:
                self.assertEqual(val1, val2)
Example #16
def get_agent(**kwargs):
    r"""
    Return opt-initialized agent.

    :param kwargs: any kwargs you want to set using parser.set_params(\*\*kwargs)
    """
    if 'no_cuda' not in kwargs:
        kwargs['no_cuda'] = True
    from parlai.core.params import ParlaiParser

    parser = ParlaiParser()
    MockTorchAgent.add_cmdline_args(parser)
    parser.set_params(**kwargs)
    opt = parser.parse_args(print_args=False)
    with testing_utils.capture_output():
        return MockTorchAgent(opt)
Example #17
 def test_train_model(self):
     """
     Check the training script doesn't crash.
     """
     import projects.controllable_dialogue.train_controllable_seq2seq as tcs2s
     parser = tcs2s.setup_args()
     # make it much smaller just for testing
     parser.set_params(
         max_train_time=120,
         validation_max_exs=128,
         batchsize=16,
         truncate=32,
         short_final_eval=True,
     )
     with testing_utils.capture_output():
         opt = parser.parse_args()
         tcs2s.TrainLoop(opt).train()
Example #18
    def test_upgrade_opt(self):
        """Test whether upgrade_opt works."""
        with testing_utils.tempdir() as tmp:
            with testing_utils.capture_output() as _:
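                # Write a stub model file and a matching .opt file so that
                # parsing --model-file picks up the saved options and
                # create_agent exercises the opt-upgrade path.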
                modfn = os.path.join(tmp, 'model')
                with open(modfn, 'w') as f:
                    f.write('Test.')
                optfn = modfn + '.opt'
                base_opt = {
                    'model': 'tests.test_params:_ExampleUpgradeOptAgent',
                    'dict_file': modfn + '.dict',
                    'model_file': modfn,
                }
                with open(optfn, 'w') as f:
                    json.dump(base_opt, f)

                pp = ParlaiParser(True, True)
                opt = pp.parse_args(['--model-file', modfn])
                agents.create_agent(opt)
Example #19
    def test_verify_data(self):
        parser = setup_args()
        opt = parser.parse_args(print_args=False)
        changed_files = testing_utils.git_changed_files()
        changed_task_files = []
        for file in changed_files:
            if ('parlai/tasks' in file and 'README' not in file
                    and 'task_list.py' not in file):
                changed_task_files.append(file)

        if not changed_task_files:
            return

        for file in changed_task_files:
            task = file.split('/')[-2]
            module_name = "%s.tasks.%s.agents" % ('parlai', task)
            task_module = importlib.import_module(module_name)
            subtasks = [
                ':'.join([task, x]) for x in dir(task_module)
                if ('teacher' in x.lower() and x not in BASE_TEACHERS)
            ]

            if testing_utils.is_this_circleci():
                if len(subtasks) == 0:
                    continue

                self.fail(
                    'test_verify_data plays poorly with CircleCI. Please run '
                    '`python tests/data/test_new_tasks.py` locally and paste the '
                    'output in your pull request.')

            for subt in subtasks:
                parser = setup_args()
                opt = parser.parse_args(args=['--task', subt],
                                        print_args=False)
                opt['task'] = subt
                with testing_utils.capture_output():
                    text, log = verify(opt, print_parser=False)
                for key in KEYS:
                    self.assertEqual(
                        log[key], 0,
                        'There are {} {} in this task.'.format(log[key], key))
Example #20
    def test_pytd_teacher(self):
        """
        Test that the pytorch teacher works with given Pytorch Datasets
        as well

        I'll be using the Flickr30k dataset to ensure that the observations
        are the same.
        """
        defaults = parser_defaults.copy()
        defaults['datatype'] = 'train:stream'
        defaults['image_mode'] = 'ascii'

        with testing_utils.capture_output():
            # Get processed act from agent
            parser = display_setup_args()
            defaults['pytorch_teacher_dataset'] = 'flickr30k'
            del defaults['pytorch_teacher_task']
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            teacher = create_task_agent_from_taskname(opt)[0]
            pytorch_teacher_act = teacher.act()

            parser = display_setup_args()
            defaults['task'] = 'flickr30k'
            del defaults['pytorch_teacher_dataset']
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            teacher = create_task_agent_from_taskname(opt)[0]
            regular_teacher_act = teacher.act()

        keys = set(pytorch_teacher_act.keys()).intersection(
            set(regular_teacher_act.keys()))
        self.assertTrue(len(keys) != 0)
        for key in keys:
            self.assertTrue(
                pytorch_teacher_act[key] == regular_teacher_act[key],
                'PytorchDataTeacher does not have the same value '
                'as regular teacher for act key: {}'.format(key))
Example #21
    def _distributed_train_model(self, opt):
        # we have to delay our import to here, because the set_spawn_method call
        # inside multiprocessing_train will break the multithreading tests, even
        # when we skip the test.
        import parlai.scripts.multiprocessing_train as mp_train

        with testing_utils.capture_output() as output:
            with testing_utils.tempdir() as tmpdir:
                if 'model_file' not in opt:
                    opt['model_file'] = os.path.join(tmpdir, 'model')
                if 'dict_file' not in opt:
                    opt['dict_file'] = os.path.join(tmpdir, 'model.dict')

                parser = mp_train.setup_args()
                popt = _forced_parse(parser, opt)

                # we need a prebuilt dictionary
                parser = build_dict.setup_args()
                build_dict.build_dict(popt)

                valid, test = mp_train.launch_and_train(popt, 31337)

        return (output.getvalue(), valid, test)
Example #22
    def test_download_multiprocess(self):
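        # One valid URL plus two intentionally broken ones: the download should
        # report HTTP 200 for the good file and 403 for each bad one, in order.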
        urls = [
            'https://parl.ai/downloads/mnist/mnist.tar.gz',
            'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
            'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
        ]

        with testing_utils.capture_output() as stdout:
            download_results = build_data.download_multiprocess(
                urls, self.datapath, dest_filenames=self.dest_filenames)
        stdout = stdout.getvalue()

        output_filenames, output_statuses, output_errors = zip(
            *download_results)
        self.assertEqual(
            output_filenames,
            self.dest_filenames,
            f'output filenames not correct\n{stdout}',
        )
        self.assertEqual(
            output_statuses,
            (200, 403, 403),
            f'output http statuses not correct\n{stdout}',
        )
Example #23
    def test_valid_pyt_batchsort(self):
        """
        Tests that batchsort *works* for two epochs; that is, that
        every example is seen both epochs
        """
        parser = train_setup_args()

        def get_acts_epochs_1_and_2(defaults):
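            # Run two full epochs, flatten the batched acts, and sort them by
            # text so the epochs can be compared independent of batch order.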
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            build_dict(opt)
            agent = create_agent(opt)
            world_data = create_task(opt, agent)
            acts_epoch_1 = []
            acts_epoch_2 = []
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_1.append(world_data.acts[0])
            world_data.reset()
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_2.append(world_data.acts[0])
            acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
            acts_epoch_1 = sorted([b for b in acts_epoch_1 if 'text' in b],
                                  key=lambda x: x.get('text'))
            acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
            acts_epoch_2 = sorted([b for b in acts_epoch_2 if 'text' in b],
                                  key=lambda x: x.get('text'))
            world_data.shutdown()
            return acts_epoch_1, acts_epoch_2

        def check_equal_act_lists(acts1, acts2):
            for idx in range(len(acts1)):
                act1 = acts1[idx]
                act2 = acts2[idx]
                for key in act1:
                    val1 = act1[key]
                    val2 = act2[key]
                    if type(val1) is torch.Tensor:
                        self.assertTrue(bool(torch.all(torch.eq(val1, val2))))
                    else:
                        self.assertEqual(val1, val2)
        # First, check that batchsort itself works
        defaults = unit_test_parser_defaults.copy()
        defaults['datatype'] = 'train:stream:ordered'
        defaults['pytorch_teacher_batch_sort'] = True

        with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
            # Get processed act from agent
            defaults['pytorch_teacher_task'] = 'babi:task1k:1'
            defaults['batch_sort_cache_type'] = 'index'
            defaults['batchsize'] = 50
            defaults['model_file'] = os.path.join(tmpdir, 'model')
            defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
            bsrt_acts_ep1, bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)

            defaults['pytorch_teacher_batch_sort'] = False
            defaults['model_file'] = os.path.join(tmpdir, 'model')
            defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
            no_bsrt_acts_ep1, no_bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)

        check_equal_act_lists(bsrt_acts_ep1, no_bsrt_acts_ep1)
        check_equal_act_lists(bsrt_acts_ep2, no_bsrt_acts_ep2)
Example #24
    def test_languagemodel_f1(self):
        import projects.convai2.baselines.language_model.eval_f1 as eval_f1

        with testing_utils.capture_output() as stdout:
            report = eval_f1.main()
        self.assertEqual(report['f1'], 0.1531, str(stdout))
Example #25
    def test_kvmemnn_hits1(self):
        import projects.convai2.baselines.kvmemnn.eval_hits as eval_hits

        with testing_utils.capture_output() as stdout:
            report = eval_hits.main()
        self.assertEqual(report['hits@1'], 0.5510, str(stdout))
Example #26
    def test_seq2seq_f1(self):
        import projects.convai2.baselines.seq2seq.eval_f1 as eval_f1

        with testing_utils.capture_output() as stdout:
            report = eval_f1.main()
        self.assertEqual(report['f1'], 0.1682, str(stdout))
    def test_seq2seq_hits1(self):
        import projects.convai2.baselines.seq2seq.eval_hits as eval_hits

        with testing_utils.capture_output() as stdout:
            report = eval_hits.main()
        self.assertEqual(report['hits@1'], .1250, str(stdout))