# Supervised pre-training entry point (fragment): set up output dirs and
# logging, persist the config, build the MultiWOZ corpus, data loaders and
# the model, then run supervised training.
os.makedirs(saved_path)

config.saved_path = saved_path
prepare_dirs_loggers(config)
logger = logging.getLogger()
logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30))

# save configuration
with open(os.path.join(saved_path, 'config.json'), 'w') as f:
    json.dump(config, f, indent=4)  # sort_keys=True

corpus = corpora.NormMultiWozCorpus(config)
train_dial, val_dial, test_dial = corpus.get_corpus()

train_data = BeliefDbDataLoaders('Train', train_dial, config)
val_data = BeliefDbDataLoaders('Val', val_dial, config)
test_data = BeliefDbDataLoaders('Test', test_dial, config)

evaluator = MultiWozEvaluator('SysWoz')

model = SysPerfectBD2Cat(corpus, config)
if config.use_gpu:
    model.cuda()

best_epoch = None
if not config.forward_only:
    try:
        best_epoch = train(model, train_data,
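# ---------------------------------------------------------------------------
# Sketch, not part of the original file: the config.json written above can be
# reloaded later (e.g. before RL fine-tuning) to rebuild identical data
# loaders.  Assumption: the repo's config object is attribute-accessed, so a
# SimpleNamespace stands in for it here; substitute the project's own config
# type if it provides one.
# ---------------------------------------------------------------------------
import json
import os
from types import SimpleNamespace


def reload_config(saved_path):
    # Read back the JSON dump produced by the supervised-training run above.
    with open(os.path.join(saved_path, 'config.json')) as f:
        return SimpleNamespace(**json.load(f))

# Usage with the classes shown in this file:
# sv_config = reload_config(saved_path)
# corpus = corpora.NormMultiWozCorpus(sv_config)
# train_dial, val_dial, test_dial = corpus.get_corpus()
# val_data = BeliefDbDataLoaders('Val', val_dial, sv_config)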
import os

import numpy as np
import torch as th

# Project-local helpers (task_train_single_batch, record_task, validate,
# BeliefDbDataLoaders, evaluators) are assumed to be imported from the
# repo's own modules.


class OfflineTaskReinforce(object):
    def __init__(self, agent, corpus, sv_config, sys_model, rl_config, generate_func):
        self.agent = agent
        self.corpus = corpus
        self.sv_config = sv_config
        self.sys_model = sys_model
        self.rl_config = rl_config
        # training func for supervised learning
        self.train_func = task_train_single_batch
        self.record_func = record_task
        self.validate_func = validate

        # prepare data loader
        train_dial, val_dial, test_dial = self.corpus.get_corpus()
        self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
        self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
        self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config)
        self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config)

        # create log files
        if self.rl_config.record_freq > 0:
            self.learning_exp_file = open(
                os.path.join(self.rl_config.record_path, 'offline-learning.tsv'), 'w')
            self.ppl_val_file = open(
                os.path.join(self.rl_config.record_path, 'val-ppl.tsv'), 'w')
            self.rl_val_file = open(
                os.path.join(self.rl_config.record_path, 'val-rl.tsv'), 'w')
            self.ppl_test_file = open(
                os.path.join(self.rl_config.record_path, 'test-ppl.tsv'), 'w')
            self.rl_test_file = open(
                os.path.join(self.rl_config.record_path, 'test-rl.tsv'), 'w')

        # evaluation
        self.evaluator = evaluators.MultiWozEvaluator('SYS_WOZ')
        self.generate_func = generate_func

    def run(self):
        n = 0
        best_valid_loss = np.inf
        best_rewards = -1 * np.inf

        # before RL, record the initial performance on the test set
        test_loss = self.validate_func(self.sys_model, self.test_data,
                                       self.sv_config, use_py=True)
        t_success, t_match, t_bleu, t_f1 = self.generate_func(
            self.sys_model, self.test_data, self.sv_config,
            self.evaluator, None, verbose=False)

        self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(
            n, np.exp(test_loss), t_bleu, t_f1))
        self.ppl_test_file.flush()
        self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(
            n, (t_success + t_match), t_success, t_match))
        self.rl_test_file.flush()

        self.sys_model.train()
        try:
            for epoch_id in range(self.rl_config.nepoch):
                self.train_data.epoch_init(self.sv_config, shuffle=True,
                                           verbose=epoch_id == 0, fix_batch=True)
                while True:
                    if n % self.rl_config.episode_repeat == 0:
                        batch = self.train_data.next_batch()

                    if batch is None:
                        break

                    n += 1
                    if n % 50 == 0:
                        print("Reinforcement Learning {}/{} episode".format(
                            n, self.train_data.num_batch * self.rl_config.nepoch))
                        self.learning_exp_file.write('{}\t{}\n'.format(
                            n, np.mean(self.agent.all_rewards[-50:])))
                        self.learning_exp_file.flush()

                    # reinforcement learning
                    # make sure every turn in the batch comes from the same dialog
                    assert len(set(batch['keys'])) == 1
                    task_report, success, match = self.agent.run(
                        batch, self.evaluator,
                        max_words=self.rl_config.max_words,
                        temp=self.rl_config.temperature)
                    reward = float(success)  # + float(match)
                    stats = {'Match': match, 'Success': success}
                    self.agent.update(reward, stats)

                    # supervised learning
                    if self.rl_config.sv_train_freq > 0 and n % self.rl_config.sv_train_freq == 0:
                        self.train_func(self.sys_model, self.sl_train_data, self.sv_config)

                    # record model performance in terms of several evaluation metrics
                    if self.rl_config.record_freq > 0 and n % self.rl_config.record_freq == 0:
                        self.agent.print_dialog(self.agent.dlg_history, reward, stats)
                        print('-' * 15, 'Recording start', '-' * 15)

                        # save train reward
                        self.learning_exp_file.write('{}\t{}\n'.format(
                            n, np.mean(self.agent.all_rewards[-self.rl_config.record_freq:])))
                        self.learning_exp_file.flush()

                        # PPL & reward on validation
                        valid_loss = self.validate_func(self.sys_model, self.val_data,
                                                        self.sv_config, use_py=True)
                        v_success, v_match, v_bleu, v_f1 = self.generate_func(
                            self.sys_model, self.val_data, self.sv_config,
                            self.evaluator, None, verbose=False)
                        self.ppl_val_file.write('{}\t{}\t{}\t{}\n'.format(
                            n, np.exp(valid_loss), v_bleu, v_f1))
                        self.ppl_val_file.flush()
                        self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(
                            n, (v_success + v_match), v_success, v_match))
                        self.rl_val_file.flush()

                        # PPL & reward on test
                        test_loss = self.validate_func(self.sys_model, self.test_data,
                                                       self.sv_config, use_py=True)
                        t_success, t_match, t_bleu, t_f1 = self.generate_func(
                            self.sys_model, self.test_data, self.sv_config,
                            self.evaluator, None, verbose=False)
                        self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(
                            n, np.exp(test_loss), t_bleu, t_f1))
                        self.ppl_test_file.flush()
                        self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(
                            n, (t_success + t_match), t_success, t_match))
                        self.rl_test_file.flush()

                        # save model if needed
                        if v_success + v_match > best_rewards:
                            print("Model saved with success {} match {}".format(
                                v_success, v_match))
                            th.save(self.sys_model.state_dict(),
                                    self.rl_config.reward_best_model_path)
                            best_rewards = v_success + v_match

                        self.sys_model.train()
                        print('-' * 15, 'Recording end', '-' * 15)
        except KeyboardInterrupt:
            print("RL training stopped from keyboard")

        print("$$$ Load {}-model".format(self.rl_config.reward_best_model_path))
        self.sv_config.batch_size = 32
        self.sys_model.load_state_dict(th.load(self.rl_config.reward_best_model_path))

        validate(self.sys_model, self.val_data, self.sv_config, use_py=True)
        validate(self.sys_model, self.test_data, self.sv_config, use_py=True)

        with open(os.path.join(self.rl_config.record_path, 'valid_file.txt'), 'w') as f:
            self.generate_func(self.sys_model, self.val_data, self.sv_config,
                               self.evaluator, num_batch=None, dest_f=f)

        with open(os.path.join(self.rl_config.record_path, 'test_file.txt'), 'w') as f:
            self.generate_func(self.sys_model, self.test_data, self.sv_config,
                               self.evaluator, num_batch=None, dest_f=f)
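# ---------------------------------------------------------------------------
# Sketch, not part of the original file: one way to wire OfflineTaskReinforce
# together after supervised pre-training.  The agent, configs and
# generate_func are the repo's own objects (an RL agent exposing
# run()/update()/all_rewards/print_dialog, and config objects with the fields
# read in run() above); the helper name run_offline_rl is hypothetical.
# ---------------------------------------------------------------------------
def run_offline_rl(agent, corpus, sv_config, sys_model, rl_config, generate_func):
    # rl_config is expected to provide record_freq, record_path, nepoch,
    # episode_repeat, max_words, temperature, sv_train_freq and
    # reward_best_model_path, exactly as consumed by run() above.
    reinforce = OfflineTaskReinforce(agent, corpus, sv_config, sys_model,
                                     rl_config, generate_func)
    reinforce.run()
    return reinforce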