def test_resuming(self):
    """Train a transformer ranker twice against the same model file.

    The second run reuses ``model_file`` from the first, so its validation
    report is expected to show more updates and a lower learning rate.
    """
    with testing_utils.tempdir() as workdir:
        checkpoint = os.path.join(workdir, 'model')

        # Full configuration for the initial run (invsqrt LR schedule).
        initial_opt = {
            'model_file': checkpoint,
            'task': 'integration_tests:candidate',
            'model': 'transformer/ranker',
            'optimizer': 'adamax',
            'learningrate': 7e-3,
            'batchsize': 32,
            'num_epochs': 1,
            'n_layers': 1,
            'n_heads': 1,
            'ffn_size': 32,
            'embedding_size': 32,
            'warmup_updates': 1,
            'lr_scheduler': 'invsqrt',
        }
        _, first_valid, _ = testing_utils.train_model(initial_opt)

        # The follow-up run only names the model file, task, model and epochs.
        resume_opt = {
            'model_file': checkpoint,
            'task': 'integration_tests:candidate',
            'model': 'transformer/ranker',
            'num_epochs': 1,
        }
        _, second_valid, _ = testing_utils.train_model(resume_opt)

        # The update counter must keep growing across the two runs.
        self.assertGreater(
            second_valid['num_updates'],
            first_valid['num_updates'],
            'Number of updates is not increasing',
        )
        # The learning rate reported by the second run must be lower.
        self.assertLess(
            second_valid['lr'],
            first_valid['lr'],
            'Learning rate is not decreasing',
        )
def verify_batch_lengths(defaults):
    """Collect batches over two passes and check each batch's length spread.

    ``defaults`` is updated in place with ``model_file`` and ``dict_file``
    paths inside a temporary directory, and must contain
    ``batch_sort_field``.  ``self``, ``max_range`` and ``ep_length`` are
    taken from the enclosing scope.
    """
    with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
        arg_parser = train_setup_args()
        defaults['model_file'] = os.path.join(tmpdir, 'model')
        defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
        arg_parser.set_defaults(**defaults)
        opt = arg_parser.parse_args()
        build_dict(opt)
        agent = create_agent(opt)
        world = create_task(opt, agent)

        collected = []

        def parley_until(target_count):
            # Keep stepping the world until `collected` holds target_count batches.
            while len(collected) < target_count:
                world.parley()
                collected.append(world.acts[0])

        # first epoch
        parley_until(900 / 50)
        # Reset the first agent of the wrapped world before the second pass.
        world.world.get_agents()[0].reset_data()
        # second epoch
        parley_until(1800 / 50)
        world.shutdown()

        sort_field = defaults['batch_sort_field']
        # The slice skips the final two collected batches.
        per_batch_lengths = [
            [ep_length(act[sort_field]) for act in batch if sort_field in act]
            for batch in collected[:-2]
        ]

        failure_msg = (
            'PytorchDataTeacher batching does not give '
            'batches with similar sized examples, when '
            'sorting by `{}` field.'.format(defaults['batch_sort_field'])
        )
        # Within a batch, longest and shortest may differ by at most max_range.
        for batch_lens in per_batch_lengths:
            spread = max(batch_lens) - min(batch_lens)
            self.assertLessEqual(spread, max_range, failure_msg)
def test_learning_rate_resuming(self, args):
    """Check LR-scheduler state across resume, fine-tune and scheduler change.

    ``args`` is a dict of training options; it must contain ``'model'`` (used
    in failure messages) and is forwarded to every training run.

    Four runs are made:
      1. a fresh run with ``invsqrt``;
      2. a second run on the same model file (updates grow, LR shrinks);
      3. a run initialised from run 1's file into a new model file with the
         same scheduler (num_updates and LR must equal run 1's);
      4. as 3, but with ``reduceonplateau`` (num_updates must equal run 1's
         and LR must be 1e-3).
    """
    mdl = args['model']
    with testing_utils.tempdir() as tmpdir:
        base_model = os.path.join(tmpdir, 'model')

        fresh_opt = dict(model_file=base_model, lr_scheduler='invsqrt', **args)
        stdout1, valid1, test1 = testing_utils.train_model(fresh_opt)

        resumed_opt = dict(model_file=base_model, lr_scheduler='invsqrt', **args)
        stdout2, valid2, test2 = testing_utils.train_model(resumed_opt)

        # Resuming must continue counting updates ...
        self.assertGreater(
            valid2['num_updates'],
            valid1['num_updates'],
            '({}) Number of updates is not increasing'.format(mdl),
        )
        # ... and must report a lower learning rate than the first run.
        self.assertLess(
            valid2['lr'],
            valid1['lr'],
            '({}) Learning rate is not decreasing'.format(mdl),
        )

        # Fine-tuning (init_model + a new model_file) must not carry over
        # the scheduler state.
        finetune_opt = dict(
            init_model=os.path.join(tmpdir, 'model'),
            model_file=os.path.join(tmpdir, 'newmodel'),
            lr_scheduler='invsqrt',
            **args
        )
        stdout3, valid3, test3 = testing_utils.train_model(finetune_opt)
        self.assertEqual(
            valid3['num_updates'],
            valid1['num_updates'],
            '({}) Finetuning LR scheduler reset failed '
            '(num_updates).'.format(mdl),
        )
        self.assertEqual(
            valid3['lr'],
            valid1['lr'],
            '({}) Finetuning LR scheduler reset failed '
            '(lr).'.format(mdl),
        )

        # Switching to a different scheduler must not carry it over either.
        switched_opt = dict(
            init_model=os.path.join(tmpdir, 'model'),
            model_file=os.path.join(tmpdir, 'newmodel2'),
            lr_scheduler='reduceonplateau',
            **args
        )
        stdout4, valid4, test4 = testing_utils.train_model(switched_opt)
        self.assertEqual(
            valid4['num_updates'],
            valid1['num_updates'],
            '({}) LR scheduler change reset failed (num_updates).'
            '\n{}'.format(mdl, stdout4),
        )
        self.assertEqual(
            valid4['lr'],
            1e-3,
            '({}) LR is not correct in final resume.\n{}'.format(mdl, stdout4),
        )
def test_resume_checkpoint(self):
    """Check which init model the agent selects on a first run vs. a resume.

    The training loop is built here directly, rather than through a
    train-model helper, so that the test can reach ``TrainLoop.agent``.
    """
    import parlai.scripts.train_model as tms

    def get_opt(init_mf, mf):
        # Options shared by both runs; only the two file paths vary.
        return dict(
            task='integration_tests',
            init_model=init_mf,
            model='parlai.agents.test_agents.dummy_torch_agent:MockTorchAgent',
            model_file=mf,
            num_epochs=3,
            validation_every_n_epochs=1,
            save_after_valid=True,
            log_every_n_secs=10,
        )

    def get_popt_and_tl(opt):
        # Parse with `opt` as parameters, then write every entry of `opt`
        # onto the parsed options before building the training loop.
        parser = tms.setup_args()
        parser.set_params(**opt)
        parsed = parser.parse_args(print_args=False)
        for name, value in opt.items():
            parsed[name] = value
        return parsed, tms.TrainLoop(parsed)

    with capture_output(), tempdir() as tmpdir:
        mf = os.path.join(tmpdir, 'model')
        init_mf = os.path.join(tmpdir, 'init_model')
        checkpoint_mf = '{}.checkpoint'.format(mf)

        # The init model only has to exist on disk; its content is a
        # single space.
        with open(init_mf, 'w') as handle:
            handle.write(' ')

        # First run: init_model points at the placeholder file.
        popt, loop = get_popt_and_tl(get_opt(init_mf, mf))
        chosen_file, finetuning = loop.agent._get_init_model(popt, None)
        self.assertEqual(chosen_file, init_mf)
        self.assertTrue(finetuning)
        _valid, _test = loop.train()

        # Second setup: init_model is "<model_file>.checkpoint" and
        # load_from_checkpoint is switched on.
        resume_opt = get_opt(checkpoint_mf, mf)
        resume_opt['load_from_checkpoint'] = True
        popt, loop = get_popt_and_tl(resume_opt)
        chosen_file, finetuning = loop.agent._get_init_model(popt, None)
        self.assertEqual(chosen_file, checkpoint_mf)
        self.assertFalse(finetuning)
def test_pyt_preprocess(self):
    """Compare a preprocessed teacher act with the agent's own observation.

    One act is taken from the teacher and passed through ``agent.observe``;
    a second act is taken with ``pytorch_preprocess`` enabled and used as-is.
    Every key of the observed act must have an equal value in the
    preprocessed act.
    """

    def get_teacher_act(defaults, teacher_processed=False, agent_to=None):
        # Build teacher and agent from `defaults` and fetch one teacher act.
        # Returns (act, agent) when teacher_processed is true, otherwise
        # (agent.observe(act), agent).
        parser = train_setup_args()
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        build_dict(opt)
        with testing_utils.capture_output() as _:
            teacher = create_task_agent_from_taskname(opt)[0]
        agent = create_agent(opt)
        raw_act = teacher.act()
        result = raw_act if teacher_processed else agent.observe(raw_act)
        return result, agent

    with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
        defaults = unit_test_parser_defaults.copy()
        # NOTE(review): other tests in this file use the key 'batchsize';
        # confirm that 'batch_size' is the intended option name here.
        defaults['batch_size'] = 1
        defaults['datatype'] = 'train:stream:ordered'
        # Both runs share the same model and dictionary paths.
        defaults['model_file'] = os.path.join(tmpdir, 'model')
        defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')

        # Act as observed by the agent.
        agent_processed_observation, agent1 = get_teacher_act(defaults)

        # Act as preprocessed by the teacher.
        defaults['pytorch_preprocess'] = True
        teacher_processed_act, agent2 = get_teacher_act(
            defaults, teacher_processed=True
        )

        for key in agent_processed_observation:
            expected = agent_processed_observation[key]
            actual = teacher_processed_act[key]
            if not isinstance(expected, torch.Tensor):
                self.assertEqual(expected, actual)
                continue
            # Tensors are compared element-wise.
            self.assertTrue(
                bool(torch.all(torch.eq(expected, actual))),
                '{} is not equal to {}'.format(expected, actual),
            )
def test_upgrade_opt(self):
    """Create an agent from a model file that has a saved ``.opt`` next to it.

    The test has no explicit assertions: it passes as long as parsing the
    arguments and creating the agent do not raise.
    """
    with testing_utils.tempdir() as tmp, testing_utils.capture_output() as _:
        model_path = os.path.join(tmp, 'model')
        opt_path = model_path + '.opt'

        # Placeholder model file.
        with open(model_path, 'w') as model_handle:
            model_handle.write('Test.')

        # Saved options pointing at the example agent defined in tests.test_params.
        saved_opt = {
            'model': 'tests.test_params:_ExampleUpgradeOptAgent',
            'dict_file': model_path + '.dict',
            'model_file': model_path,
        }
        with open(opt_path, 'w') as opt_handle:
            json.dump(saved_opt, opt_handle)

        parser = ParlaiParser(True, True)
        opt = parser.parse_args(['--model-file', model_path])
        agents.create_agent(opt)
def test_resuming_reduce_on_plateau(self):
    """
    Resume training with the ``reduceonplateau`` scheduler and warmup.

    Reduce on Plateau can be tricky when combined with warmup.
    See: https://github.com/facebookresearch/ParlAI/pull/1812
    """
    with testing_utils.tempdir() as workdir:
        checkpoint = os.path.join(workdir, 'model')

        # Initial run with one warmup update.
        first_opt = {
            'model_file': checkpoint,
            'task': 'integration_tests:candidate',
            'model': 'transformer/ranker',
            'optimizer': 'adamax',
            'learningrate': 7e-3,
            'batchsize': 32,
            'num_epochs': 1,
            'n_layers': 1,
            'n_heads': 1,
            'ffn_size': 32,
            'embedding_size': 32,
            'warmup_updates': 1,
            'lr_scheduler': 'reduceonplateau',
        }
        _, first_valid, _ = testing_utils.train_model(first_opt)

        # Second run on the same model file with the same scheduler.
        second_opt = {
            'model_file': checkpoint,
            'task': 'integration_tests:candidate',
            'model': 'transformer/ranker',
            'num_epochs': 1,
            'lr_scheduler': 'reduceonplateau',
        }
        _, second_valid, _ = testing_utils.train_model(second_opt)

        # After resuming, the learning rate must stay above 1e-5.
        self.assertGreater(
            second_valid['lr'],
            1e-5,
            'Learning rate should not be that low when resuming',
        )
def _distributed_train_model(self, opt):
    """Run multiprocessing training with ``opt`` and return its results.

    Missing ``model_file`` / ``dict_file`` entries are filled in with paths
    inside a temporary directory.  The fill-in happens on a copy of ``opt``,
    so the caller's dict is never left holding paths into a directory that
    has already been removed.

    :param opt: mapping of training options.
    :return: tuple ``(captured_output, valid, test)`` where
        ``captured_output`` is the text captured while training and
        ``valid`` / ``test`` are the two values returned by
        ``mp_train.launch_and_train``.
    """
    # we have to delay our import to here, because the set_spawn_method call
    # inside multiprocessing_train will break the multithreading tests, even
    # when we skip the test.
    import parlai.scripts.multiprocessing_train as mp_train

    # Work on a copy so the caller's options are not modified.
    opt = dict(opt)

    with testing_utils.capture_output() as output:
        with testing_utils.tempdir() as tmpdir:
            opt.setdefault('model_file', os.path.join(tmpdir, 'model'))
            opt.setdefault('dict_file', os.path.join(tmpdir, 'model.dict'))

            parser = mp_train.setup_args()
            popt = _forced_parse(parser, opt)

            # we need a prebuilt dictionary
            build_dict.build_dict(popt)

            valid, test = mp_train.launch_and_train(popt, 31337)

    return (output.getvalue(), valid, test)
def test_eval_fixed(self):
    """Evaluate with ``eval_candidates='fixed'`` in two configurations.

    First, the fixed candidates are left to be built by the model; the test
    expects ``hits@1`` of exactly zero on validation.  Second, a candidate
    file containing the teacher's train, valid and test candidates is
    supplied, and ``hits@100`` is expected to reach at least 0.1 after 20
    epochs.
    """
    args = self._get_args()
    args['eval_candidates'] = 'fixed'
    args['encode_candidate_vecs'] = True
    args['ignore_bad_candidates'] = True
    stdout, valid, test = testing_utils.train_model(args)

    # none of the train candidates appear in evaluation, so should have
    # zero accuracy: this tests whether the fixed candidates were built
    # properly (i.e., only using candidates from the train set)
    self.assertEqual(
        valid['hits@1'],
        0,
        "valid hits@1 = {}\nLOG:\n{}".format(valid['hits@1'], stdout),
    )

    # now try again with a fixed candidate file that includes all possible
    # candidates
    teacher = CandidateTeacher({'datatype': 'train'})
    all_cands = teacher.train + teacher.val + teacher.test
    # One candidate per line, its tokens joined by single spaces.
    all_cands_str = '\n'.join([' '.join(x) for x in all_cands])

    with testing_utils.tempdir() as tmpdir:
        tmp_cands_file = os.path.join(tmpdir, 'all_cands.text')
        with open(tmp_cands_file, 'w') as f:
            f.write(all_cands_str)

        args['fixed_candidates_path'] = tmp_cands_file
        args['encode_candidate_vecs'] = False  # don't encode before training
        args['ignore_bad_candidates'] = False
        args['num_epochs'] = 20
        stdout, valid, test = testing_utils.train_model(args)

        # Report the metric that is actually asserted on (hits@100), so a
        # failure message shows the value that missed the threshold.
        self.assertGreaterEqual(
            valid['hits@100'],
            0.1,
            "valid hits@100 = {}\nLOG:\n{}".format(valid['hits@100'], stdout),
        )
def test_valid_pyt_batchsort(self):
    """
    Tests that batchsort *works* for two epochs; that is, that every example
    is seen both epochs.

    The acts of two epochs are collected once with batch sorting enabled and
    once with it disabled.  After flattening, dropping acts without a
    ``'text'`` key and sorting by text, both configurations must produce the
    same acts for each epoch.
    """
    parser = train_setup_args()

    def run_epoch(world):
        # Step the world until it reports the epoch as done, collecting the
        # first entry of `world.acts` after each parley.
        batches = []
        while not world.epoch_done():
            world.parley()
            batches.append(world.acts[0])
        return batches

    def flatten_and_sort(batches):
        # Flatten batches into single acts, keep only those that carry a
        # 'text' key, and order them by that text.
        acts = [act for batch in batches for act in batch if 'text' in act]
        return sorted(acts, key=lambda act: act.get('text'))

    def get_acts_epochs_1_and_2(defaults):
        parser.set_defaults(**defaults)
        opt = parser.parse_args()
        build_dict(opt)
        agent = create_agent(opt)
        world_data = create_task(opt, agent)

        epoch_1_batches = run_epoch(world_data)
        world_data.reset()
        epoch_2_batches = run_epoch(world_data)

        acts_epoch_1 = flatten_and_sort(epoch_1_batches)
        acts_epoch_2 = flatten_and_sort(epoch_2_batches)
        world_data.shutdown()
        return acts_epoch_1, acts_epoch_2

    def check_equal_act_lists(acts1, acts2):
        # Without this check, extra acts in `acts2` would go unnoticed
        # because only the first len(acts1) entries are compared.
        self.assertEqual(
            len(acts1),
            len(acts2),
            'Act lists differ in length: {} vs {}'.format(
                len(acts1), len(acts2)
            ),
        )
        for act1, act2 in zip(acts1, acts2):
            for key in act1:
                val1 = act1[key]
                val2 = act2[key]
                if isinstance(val1, torch.Tensor):
                    # Tensors are compared element-wise.
                    self.assertTrue(bool(torch.all(torch.eq(val1, val2))))
                else:
                    self.assertEqual(val1, val2)

    # First, check that batchsort itself works
    defaults = unit_test_parser_defaults.copy()
    defaults['datatype'] = 'train:stream:ordered'
    defaults['pytorch_teacher_batch_sort'] = True

    with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
        defaults['pytorch_teacher_task'] = 'babi:task1k:1'
        defaults['batch_sort_cache_type'] = 'index'
        defaults['batchsize'] = 50
        # Both configurations use the same model and dictionary paths.
        defaults['model_file'] = os.path.join(tmpdir, 'model')
        defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
        bsrt_acts_ep1, bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)

        # Same settings again, this time without batch sorting.
        defaults['pytorch_teacher_batch_sort'] = False
        no_bsrt_acts_ep1, no_bsrt_acts_ep2 = get_acts_epochs_1_and_2(defaults)

    check_equal_act_lists(bsrt_acts_ep1, no_bsrt_acts_ep1)
    check_equal_act_lists(bsrt_acts_ep2, no_bsrt_acts_ep2)
def test_resuming(self):
    """Check LR-scheduler state for a transformer generator across runs.

    Four runs are made with a shared set of base options:
      1. a fresh run with ``invsqrt``;
      2. a second run on the same model file (updates grow, LR shrinks);
      3. a run initialised from run 1's file into a new model file with the
         same scheduler (num_updates and LR must equal run 1's);
      4. as 3, but with ``reduceonplateau`` (num_updates must equal run 1's
         and LR must be 1e-3).
    """
    base_args = {
        'task': 'integration_tests:nocandidate',
        'model': 'transformer/generator',
        'optimizer': 'adamax',
        'learningrate': 1e-3,
        'batchsize': 32,
        'num_epochs': 1,
        'n_layers': 1,
        'n_heads': 1,
        'ffn_size': 32,
        'embedding_size': 32,
        'skip_generation': True,
        'warmup_updates': 1,
    }

    with testing_utils.tempdir() as tmpdir:
        base_model = os.path.join(tmpdir, 'model')

        fresh_opt = dict(model_file=base_model, lr_scheduler='invsqrt', **base_args)
        stdout1, valid1, test1 = testing_utils.train_model(fresh_opt)

        resumed_opt = dict(
            model_file=base_model, lr_scheduler='invsqrt', **base_args
        )
        stdout2, valid2, test2 = testing_utils.train_model(resumed_opt)

        # Resuming must continue counting updates ...
        self.assertGreater(
            valid2['num_updates'],
            valid1['num_updates'],
            'Number of updates is not increasing',
        )
        # ... and must report a lower learning rate than the first run.
        self.assertLess(
            valid2['lr'],
            valid1['lr'],
            'Learning rate is not decreasing',
        )

        # Fine-tuning (init_model + a new model_file) must not carry over
        # the scheduler state.
        finetune_opt = dict(
            init_model=os.path.join(tmpdir, 'model'),
            model_file=os.path.join(tmpdir, 'newmodel'),
            lr_scheduler='invsqrt',
            **base_args
        )
        stdout3, valid3, test3 = testing_utils.train_model(finetune_opt)
        self.assertEqual(
            valid3['num_updates'],
            valid1['num_updates'],
            'Finetuning LR scheduler reset failed (num_updates).',
        )
        self.assertEqual(
            valid3['lr'],
            valid1['lr'],
            'Finetuning LR scheduler reset failed (lr).',
        )

        # Switching to a different scheduler must not carry it over either.
        switched_opt = dict(
            init_model=os.path.join(tmpdir, 'model'),
            model_file=os.path.join(tmpdir, 'newmodel2'),
            lr_scheduler='reduceonplateau',
            **base_args
        )
        stdout4, valid4, test4 = testing_utils.train_model(switched_opt)
        self.assertEqual(
            valid4['num_updates'],
            valid1['num_updates'],
            'LR scheduler change reset failed (num_updates).\n' + stdout4,
        )
        self.assertEqual(
            valid4['lr'],
            1e-3,
            'LR is not correct in final resume.\n' + stdout4,
        )