Example 1
def verify_batch_lengths(defaults):
    with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
        # Get processed act from agent
        parser = train_setup_args()
        defaults['model_file'] = os.path.join(tmpdir, 'model')
        defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        build_dict(opt)
        agent = create_agent(opt)
        world_data = create_task(opt, agent)
        batch_sort_acts = []
        # first epoch
        while len(batch_sort_acts) < 900 / 50:
            world_data.parley()
            batch_sort_acts.append(world_data.acts[0])
        teacher = world_data.world.get_agents()[0]
        teacher.reset_data()
        # second epoch
        while len(batch_sort_acts) < 1800 / 50:
            world_data.parley()
            batch_sort_acts.append(world_data.acts[0])
        world_data.shutdown()
    field = defaults['batch_sort_field']
    lengths = [[ep_length(b[field]) for b in bb if field in b]
               for bb in batch_sort_acts[:-2]]  # exclude the last, possibly smaller batches
    # verify batch lengths
    for batch_lens in lengths:
        self.assertLessEqual(max(batch_lens) - min(batch_lens), max_range,
                             'PytorchDataTeacher batching does not give '
                             'batches with similar sized examples, when '
                             'sorting by `{}` field.'.format(
                                 defaults['batch_sort_field']))
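This helper is excerpted from inside a test case, so `self`, `ep_length`, and `max_range` are free names supplied by the enclosing scope; the page does not show them. A minimal sketch of what they might look like (the tolerance value and the tensor branch are assumptions, not ParlAI's actual code):

import torch

max_range = 16  # assumed tolerance for the length spread within one batch

def ep_length(val):
    # Sortable "length" of a field value: elements for lists and 1-D
    # tensors, characters for strings (a sketch of the assumed helper)
    if isinstance(val, torch.Tensor):
        return val.size(0)
    return len(val)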
Example 2
    def test_pyt_preprocess_train(self):
        """
            Test that the preprocess functionality works with the PytorchDataTeacher
            with a sample TorchAgent (here, the Seq2seq model).

            This tests whether an agent can train to completion with
            these preprocessed examples
        """

        # Check that the model will train
        print('Testing test_pyt_preprocess training')
        f = io.StringIO()
        with redirect_stdout(f):
            parser = train_setup_args()
            defaults = parser_defaults.copy()
            set_model_file(defaults)
            defaults['datatype'] = 'train'
            defaults['pytorch_preprocess'] = True
            parser.set_defaults(**defaults)
            TrainLoop(parser.parse_args()).train()

        str_output = f.getvalue()
        self.assertTrue(
            solved_task(str_output),
            'Teacher could not teach seq2seq with preprocessed obs, '
            'output: {}'.format(str_output))
        print('\n------Passed `test_pyt_preprocess_train`------\n')
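Two module-level helpers, `set_model_file` and `solved_task`, are not shown on this page. A plausible sketch, assuming `solved_task` scans the captured training log for a sufficiently high accuracy (the regex and threshold are guesses, not the module's real implementation):

import os
import re
import tempfile

def set_model_file(defaults):
    # Assumed helper: point the model and dict files at a fresh temp directory
    tmpdir = tempfile.mkdtemp()
    defaults['model_file'] = os.path.join(tmpdir, 'model')
    defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')

def solved_task(str_output, threshold=0.95):
    # Assumed helper: true if any reported accuracy reaches the threshold
    accuracies = re.findall(r"'accuracy': ([0-9.]+)", str_output)
    return any(float(a) >= threshold for a in accuracies)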
Example 3
    def test_pyt_train(self):
        """
            Integration test: ensure that pytorch data teacher can successfully
            teach Seq2Seq model to fully solve the babi:task10k:1 task.

            The Seq2Seq model can solve the babi:task10k:1 task with the normal
            ParlAI setup, and thus should be able to with a PytorchDataTeacher

            This tests the following setups:
                1. -dt train
                2. -dt train:stream
                3. -dt train:stream:ordered
        """
        dts = ['train', 'train:stream', 'train:stream:ordered']
        for dt in dts:
            print('Testing test_pyt_train with dt: {}'.format(dt))
            f = io.StringIO()
            with redirect_stdout(f):
                parser = train_setup_args()
                defaults = parser_defaults.copy()
                set_model_file(defaults)
                defaults['datatype'] = dt
                defaults['shuffle'] = True  # for train:stream
                parser.set_defaults(**defaults)
                TrainLoop(parser.parse_args()).train()
            str_output = f.getvalue()
            self.assertTrue(
                solved_task(str_output),
                'Teacher could not teach seq2seq with args: '
                '{}; here is str_output: {}'.format(defaults, str_output))
        print('\n------Passed `test_pyt_train`------\n')
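Examples 2 and 3 (and several below) copy `parser_defaults`, a dict of shared option defaults defined elsewhere in the test module. A sketch of the kind of configuration they imply (every key and value below is an assumption, chosen so that Seq2seq could plausibly solve babi:task10k:1, not the module's real values):

# Assumed shared defaults; the test module's actual values may differ
parser_defaults = {
    'model': 'seq2seq',
    'pytorch_teacher_task': 'babi:task10k:1',
    'batchsize': 32,
    'num_epochs': 100,
    'validation_every_n_secs': 10,
    'validation_metric': 'accuracy',
}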
Example 4
def verify_batch_lengths(defaults):
    f = io.StringIO()

    with redirect_stdout(f):
        # Get processed act from agent
        parser = train_setup_args()
        set_model_file(defaults)
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        build_dict(opt)
        agent = create_agent(opt)
        world_data = create_task(opt, agent)
        batch_sort_acts = []
        # first epoch
        while len(batch_sort_acts) < 900 / 50:
            world_data.parley()
            batch_sort_acts.append(world_data.acts[0])
        teacher = world_data.world.get_agents()[0]
        teacher.reset_data()
        # second epoch
        while len(batch_sort_acts) < 1800 / 50:
            world_data.parley()
            batch_sort_acts.append(world_data.acts[0])
    field = defaults['batch_sort_field']
    lengths = [[ep_length(b[field]) for b in bb if field in b]
               for bb in batch_sort_acts[:-2]]  # exclude the last, possibly smaller batches
    # verify batch lengths
    for batch_lens in lengths:
        self.assertLessEqual(
            max(batch_lens) - min(batch_lens), max_range,
            'PytorchDataTeacher batching does not give '
            'batches with similar sized examples, when '
            'sorting by `{}` field.'.format(
                defaults['batch_sort_field']))
Example 5
def get_teacher_act(defaults, teacher_processed=False, agent_to=None):
    parser = train_setup_args()
    parser.set_defaults(**defaults)
    opt = parser.parse_args([])
    build_dict(opt)
    teacher = create_task_agent_from_taskname(opt)[0]
    agent = create_agent(opt)
    act = teacher.act()
    if teacher_processed:
        return act, agent
    return agent.observe(act), agent
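A hypothetical call site for `get_teacher_act`, contrasting the raw teacher act with the observation after the agent has seen it (`text_vec` follows TorchAgent's convention of adding vectorized fields in `observe`; treat the exact keys as an assumption):

defaults = parser_defaults.copy()
# Raw act straight from the teacher
raw_act, agent = get_teacher_act(defaults, teacher_processed=True)
# The same kind of act after the agent observes (and vectorizes) it
observed, agent = get_teacher_act(defaults)

assert 'text' in raw_act        # teacher acts carry the raw text
assert 'text_vec' in observed   # observe() adds tensorized fields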
Example 6
    def test_pyt_batchsort_train(self):
        """
            Tests the functionality of training with batchsort
            under the following conditions:

            1. -dt train --pytorch_preprocess False
            2. -dt train:stream --pytorch_preprocess False
            3. -dt train --pytorch_preprocess True --batch_sort_field text_vec
        """
        # Check that training works
        dt_and_preprocess = [('train', False), ('train:stream', False),
                             ('train', True)]
        for dt, preprocess in dt_and_preprocess:
            print('Testing test_pyt_batchsort with -dt {} and --preprocess {}'.
                  format(dt, preprocess))
            f = io.StringIO()
            with redirect_stdout(f):
                parser = train_setup_args()
                defaults = parser_defaults.copy()
                set_model_file(defaults)
                defaults['datatype'] = dt
                defaults['pytorch_preprocess'] = preprocess
                defaults['pytorch_teacher_batch_sort'] = True
                defaults['batchsize'] = 50
                if preprocess:
                    defaults['batch_sort_field'] = 'text_vec'
                parser.set_defaults(**defaults)
                TrainLoop(parser.parse_args()).train()

            str_output = f.getvalue()
            self.assertTrue(
                solved_task(str_output),
                'Teacher could not teach seq2seq with batch sort '
                'and args {} and output {}'.format((dt, preprocess),
                                                   str_output))
        print('\n------Passed `test_pyt_batchsort_train`------\n')
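The third configuration sorts on `text_vec` because, with `--pytorch_preprocess True`, examples are stored already vectorized, so batches should be grouped by token count rather than character count. An illustration with made-up values:

import torch

act = {'text': 'where is the milk ?',
       'text_vec': torch.LongTensor([5, 9, 2, 14, 3])}
len(act['text'])      # 19 -> characters, the key when sorting on 'text'
len(act['text_vec'])  # 5  -> tokens, the key when sorting on 'text_vec'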
Example 7
    def test_valid_pyt_batchsort(self):
        """
        Tests that batchsort *works* for two epochs; that is, that
        every example is seen in both epochs.
        """
        parser = train_setup_args()

        def get_acts_epochs_1_and_2(defaults):
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            build_dict(opt)
            agent = create_agent(opt)
            world_data = create_task(opt, agent)
            acts_epoch_1 = []
            acts_epoch_2 = []
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_1.append(world_data.acts[0])
            world_data.reset()
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_2.append(world_data.acts[0])
            acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
            acts_epoch_1 = sorted([b for b in acts_epoch_1 if 'text' in b],
                                  key=lambda x: x.get('text'))
            acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
            acts_epoch_2 = sorted([b for b in acts_epoch_2 if 'text' in b],
                                  key=lambda x: x.get('text'))
            world_data.shutdown()
            return acts_epoch_1, acts_epoch_2

        def check_equal_act_lists(acts1, acts2):
            for idx in range(len(acts1)):
                act1 = acts1[idx]
                act2 = acts2[idx]
                for key in act1:
                    val1 = act1[key]
                    val2 = act2[key]
                    if isinstance(val1, torch.Tensor):
                        self.assertTrue(bool(torch.all(torch.eq(val1, val2))))
                    else:
                        self.assertEqual(val1, val2)

        # First, check that batchsort itself works
        defaults = unit_test_parser_defaults.copy()
        defaults['datatype'] = 'train:stream:ordered'
        defaults['pytorch_teacher_batch_sort'] = True

        with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
            # Get processed act from agent
            defaults['pytorch_teacher_task'] = 'babi:task1k:1'
            defaults['batch_sort_cache_type'] = 'index'
            defaults['batchsize'] = 50
            defaults['model_file'] = os.path.join(tmpdir, 'model')
            defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
            bsrt_acts_ep1, bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)

            defaults['pytorch_teacher_batch_sort'] = False
            defaults['model_file'] = os.path.join(tmpdir, 'model')
            defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
            no_bsrt_acts_ep1, no_bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)

        check_equal_act_lists(bsrt_acts_ep1, no_bsrt_acts_ep1)
        check_equal_act_lists(bsrt_acts_ep2, no_bsrt_acts_ep2)
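The comparison works because both flattened act lists are sorted on `'text'` before `check_equal_act_lists` runs, so batch composition and arrival order drop out: element-wise equality then means the two runs saw exactly the same set of examples. In miniature:

# Sorting first makes the equality check order-insensitive
a = sorted([{'text': 'b'}, {'text': 'a'}], key=lambda x: x['text'])
b = sorted([{'text': 'a'}, {'text': 'b'}], key=lambda x: x['text'])
assert a == b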
Example 8
    def test_valid_pyt_batchsort(self):
        """
            Tests that batchsort *works* for two epochs; that is, that
            every example is seen both epochs
        """
        parser = train_setup_args()

        def get_acts_epochs_1_and_2(defaults):
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            build_dict(opt)
            agent = create_agent(opt)
            world_data = create_task(opt, agent)
            acts_epoch_1 = []
            acts_epoch_2 = []
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_1.append(world_data.acts[0])
            world_data.reset()
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_2.append(world_data.acts[0])
            acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
            acts_epoch_1 = sorted([b for b in acts_epoch_1 if 'text' in b],
                                  key=lambda x: x.get('text'))
            acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
            acts_epoch_2 = sorted([b for b in acts_epoch_2 if 'text' in b],
                                  key=lambda x: x.get('text'))
            return acts_epoch_1, acts_epoch_2

        def check_equal_act_lists(acts1, acts2):
            for idx in range(len(acts1)):
                act1 = acts1[idx]
                act2 = acts2[idx]
                for key in act1:
                    val1 = act1[key]
                    val2 = act2[key]
                    if isinstance(val1, torch.Tensor):
                        self.assertTrue(bool(torch.all(torch.eq(val1, val2))))
                    else:
                        self.assertTrue(
                            val1 == val2,
                            '{}\n\n --not equal to-- \n\n{}'.format(
                                val1, val2))

        # First, check that batchsort itself works
        defaults = parser_defaults.copy()
        defaults['datatype'] = 'train:stream:ordered'
        defaults['pytorch_teacher_batch_sort'] = True

        f = io.StringIO()

        with redirect_stdout(f):
            # Get processed act from agent
            defaults['pytorch_teacher_task'] = 'babi:task1k:1'
            defaults['batch_sort_cache_type'] = 'index'
            defaults['batchsize'] = 50
            set_model_file(defaults)
            bsrt_acts_ep1, bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)

            defaults['pytorch_teacher_batch_sort'] = False
            set_model_file(defaults)
            no_bsrt_acts_ep1, no_bsrt_acts_ep2 = get_acts_epochs_1_and_2(
                defaults)

        check_equal_act_lists(bsrt_acts_ep1, no_bsrt_acts_ep1)
        check_equal_act_lists(bsrt_acts_ep2, no_bsrt_acts_ep2)
        print('\n------Passed `test_pyt_batchsort`------\n')