def test_pyt_preprocess_train(self):
    """
    Test that the preprocess functionality works with the
    PytorchDataTeacher with a sample TorchAgent (here, the Seq2seq
    model).

    This tests whether an agent can train to completion with these
    preprocessed examples.
    """
    # Check that the model will train to completion
    print('Testing test_pyt_preprocess training')
    f = io.StringIO()
    with redirect_stdout(f):
        parser = train_setup_args()
        defaults = parser_defaults.copy()
        set_model_file(defaults)
        defaults['datatype'] = 'train'
        defaults['pytorch_preprocess'] = True
        parser.set_defaults(**defaults)
        TrainLoop(parser.parse_args()).train()
    str_output = f.getvalue()
    self.assertTrue(
        solved_task(str_output),
        'Teacher could not teach seq2seq with preprocessed obs, '
        'output: {}'.format(str_output))
    print('\n------Passed `test_pyt_preprocess_train`------\n')

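# `solved_task` is used by the training tests in this file but is not
# defined in this excerpt. Below is a minimal sketch, assuming success
# is judged from the validation accuracy that TrainLoop prints to the
# captured stdout; the log format and the 0.95 threshold here are
# assumptions, not the actual helper.
import re

def solved_task(str_output, threshold=0.95):
    # Hypothetical reconstruction: pull every accuracy value reported
    # in the training log and check that the last one clears the
    # threshold.
    accs = re.findall(r"'accuracy': ([0-9.]+)", str_output)
    return bool(accs) and float(accs[-1]) >= threshold
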
def test_pyt_train(self):
    """
    Integration test: ensure that the PytorchDataTeacher can
    successfully teach the Seq2Seq model to fully solve the
    babi:task10k:1 task.

    The Seq2Seq model can solve the babi:task10k:1 task with the
    normal ParlAI setup, and thus should be able to with a
    PytorchDataTeacher.

    This tests the following setups:
        1. -dt train
        2. -dt train:stream
        3. -dt train:stream:ordered
    """
    dts = ['train', 'train:stream', 'train:stream:ordered']
    for dt in dts:
        print('Testing test_pyt_train with dt: {}'.format(dt))
        f = io.StringIO()
        with redirect_stdout(f):
            parser = train_setup_args()
            defaults = parser_defaults.copy()
            set_model_file(defaults)
            defaults['datatype'] = dt
            defaults['shuffle'] = True  # for train:stream
            parser.set_defaults(**defaults)
            TrainLoop(parser.parse_args()).train()
        str_output = f.getvalue()
        self.assertTrue(
            solved_task(str_output),
            'Teacher could not teach seq2seq with args: '
            '{}; here is str_output: {}'.format(defaults, str_output))
    print('\n------Passed `test_pyt_train`------\n')

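# `train_setup_args` is likewise not shown in this excerpt. A minimal
# sketch, assuming it simply wraps the standard training-script parser
# (TrainLoop is imported from the same module), so that parse_args()
# sees all the flags TrainLoop expects:
from parlai.scripts.train_model import setup_args

def train_setup_args():
    # Hypothetical reconstruction; the real helper may add extra
    # PytorchDataTeacher-specific arguments on top of this.
    return setup_args()
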
def verify_batch_lengths(defaults):
    # note: relies on `self` and `max_range` from the enclosing test scope
    f = io.StringIO()
    with redirect_stdout(f):
        # Get processed act from agent
        parser = train_setup_args()
        set_model_file(defaults)
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        build_dict(opt)
        agent = create_agent(opt)
        world_data = create_task(opt, agent)
        batch_sort_acts = []
        # first epoch: 900 examples at batchsize 50 -> 18 batches
        while len(batch_sort_acts) < 900 / 50:
            world_data.parley()
            batch_sort_acts.append(world_data.acts[0])
        teacher = world_data.world.get_agents()[0]
        teacher.reset_data()
        # second epoch
        while len(batch_sort_acts) < 1800 / 50:
            world_data.parley()
            batch_sort_acts.append(world_data.acts[0])
        world_data.shutdown()
    field = defaults['batch_sort_field']
    # exclude the final two batches, which may be underfilled
    lengths = [[ep_length(b[field]) for b in bb if field in b]
               for bb in batch_sort_acts[:-2]]
    # verify that each batch contains similarly sized examples
    for batch_lens in lengths:
        self.assertLessEqual(
            max(batch_lens) - min(batch_lens), max_range,
            'PytorchDataTeacher batching does not give '
            'batches with similar sized examples when '
            'sorting by the `{}` field.'.format(
                defaults['batch_sort_field']))

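# `ep_length`, used above to measure the batch_sort_field of each act,
# is not defined in this excerpt. A plausible sketch, assuming string
# fields such as 'text' are measured in whitespace-separated tokens and
# vector fields such as 'text_vec' by element count:
def ep_length(val):
    # Hypothetical helper: length of an example along the sort field.
    if isinstance(val, str):
        return len(val.split(' '))
    # tensors and lists support len() directly
    return len(val)
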
def get_teacher_act(defaults, teacher_processed=False, agent_to=None):
    parser = train_setup_args()
    parser.set_defaults(**defaults)
    opt = parser.parse_args([])
    build_dict(opt)
    teacher = create_task_agent_from_taskname(opt)[0]
    agent = create_agent(opt)
    act = teacher.act()
    if teacher_processed:
        return act, agent
    return agent.observe(act), agent

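# Illustrative (hypothetical) usage of get_teacher_act, not one of the
# original tests: compare the raw teacher act with the observation the
# agent builds from it. The 'text_vec' key is what a TorchAgent
# typically adds when vectorizing; treat that key as an assumption
# about the agent used in these tests.
def _example_teacher_act_usage(defaults):
    raw_act, _ = get_teacher_act(defaults, teacher_processed=True)
    agent_obs, _ = get_teacher_act(defaults)
    assert raw_act['text'] == agent_obs['text']
    assert 'text_vec' in agent_obs
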
def test_pyt_batchsort_train(self):
    """
    Tests the functionality of training with batchsort
    under the following conditions:

    1. -dt train --pytorch_preprocess False
    2. -dt train:stream --pytorch_preprocess False
    3. -dt train --pytorch_preprocess True --batch_sort_field text_vec
    """
    # Check that training works for each (datatype, preprocess) pair
    dt_and_preprocess = [
        ('train', False),
        ('train:stream', False),
        ('train', True),
    ]
    for dt, preprocess in dt_and_preprocess:
        print('Testing test_pyt_batchsort with -dt {} and '
              '--pytorch_preprocess {}'.format(dt, preprocess))
        f = io.StringIO()
        with redirect_stdout(f):
            parser = train_setup_args()
            defaults = parser_defaults.copy()
            set_model_file(defaults)
            defaults['datatype'] = dt
            defaults['pytorch_preprocess'] = preprocess
            defaults['pytorch_teacher_batch_sort'] = True
            defaults['batchsize'] = 50
            if preprocess:
                defaults['batch_sort_field'] = 'text_vec'
            parser.set_defaults(**defaults)
            TrainLoop(parser.parse_args()).train()
        str_output = f.getvalue()
        self.assertTrue(
            solved_task(str_output),
            'Teacher could not teach seq2seq with batch sort '
            'and args {} and output {}'.format((dt, preprocess),
                                               str_output))
    print('\n------Passed `test_pyt_batchsort_train`------\n')

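# `set_model_file` is not defined in this excerpt either. A minimal
# sketch, assuming it points the model file (and its dict file, as the
# tests below set them) at a fresh temporary directory so separate runs
# do not clobber each other:
import os
import tempfile

def set_model_file(defaults):
    # Hypothetical reconstruction of the helper used throughout above.
    tmpdir = tempfile.mkdtemp()
    defaults['model_file'] = os.path.join(tmpdir, 'model')
    defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
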
def test_valid_pyt_batchsort(self):
    """
    Tests that batchsort *works* for two epochs; that is,
    that every example is seen in both epochs.
    """
    parser = train_setup_args()

    def get_acts_epochs_1_and_2(defaults):
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        build_dict(opt)
        agent = create_agent(opt)
        world_data = create_task(opt, agent)
        acts_epoch_1 = []
        acts_epoch_2 = []
        while not world_data.epoch_done():
            world_data.parley()
            acts_epoch_1.append(world_data.acts[0])
        world_data.reset()
        while not world_data.epoch_done():
            world_data.parley()
            acts_epoch_2.append(world_data.acts[0])
        # flatten the batched acts and sort by text so the two epochs
        # can be compared example-by-example
        acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
        acts_epoch_1 = sorted([b for b in acts_epoch_1 if 'text' in b],
                              key=lambda x: x.get('text'))
        acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
        acts_epoch_2 = sorted([b for b in acts_epoch_2 if 'text' in b],
                              key=lambda x: x.get('text'))
        world_data.shutdown()
        return acts_epoch_1, acts_epoch_2

    def check_equal_act_lists(acts1, acts2):
        for idx in range(len(acts1)):
            act1 = acts1[idx]
            act2 = acts2[idx]
            for key in act1:
                val1 = act1[key]
                val2 = act2[key]
                if isinstance(val1, torch.Tensor):
                    self.assertTrue(bool(torch.all(torch.eq(val1, val2))))
                else:
                    self.assertTrue(
                        val1 == val2,
                        '{}\n\n --not equal to-- \n\n{}'.format(
                            val1, val2))

    # Check that batchsort yields the same examples as no batchsort
    defaults = parser_defaults.copy()
    defaults['datatype'] = 'train:stream:ordered'
    defaults['pytorch_teacher_batch_sort'] = True
    f = io.StringIO()
    with redirect_stdout(f):
        # Collect acts for two epochs, with and without batchsort
        defaults['pytorch_teacher_task'] = 'babi:task1k:1'
        defaults['batch_sort_cache_type'] = 'index'
        defaults['batchsize'] = 50
        set_model_file(defaults)
        bsrt_acts_ep1, bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)
        defaults['pytorch_teacher_batch_sort'] = False
        set_model_file(defaults)
        no_bsrt_acts_ep1, no_bsrt_acts_ep2 = get_acts_epochs_1_and_2(
            defaults)
    check_equal_act_lists(bsrt_acts_ep1, no_bsrt_acts_ep1)
    check_equal_act_lists(bsrt_acts_ep2, no_bsrt_acts_ep2)
    print('\n------Passed `test_pyt_batchsort`------\n')

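# The shared `parser_defaults` dict is copied by every test above but
# is not defined in this excerpt. A plausible sketch, using the model
# and task named in test_pyt_train; every value here is an assumption,
# not the actual configuration:
parser_defaults = {
    'model': 'seq2seq',                        # the agent under test
    'pytorch_teacher_task': 'babi:task10k:1',  # task from test_pyt_train
    'batchsize': 32,                           # assumed default
    'num_epochs': 10,                          # assumed training budget
}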