def talk(self, max_diag_length, batch_input_data, batch_input_kb, agent1,
         agent2, worker_step, batch_size, speaker=None):
  """The main procedure to generate a single self play conversation."""
  # Parse the input data and knowledge base.
  bs_intent, bs_truth_action, bs_kb = self.parse_input(batch_input_data,
                                                       batch_input_kb)
  # Remember the roles of the agents.
  self.agents = [agent1, agent2]
  # In self play training the speaker will be None, so we randomly choose an
  # initial speaker and initialize the utterance. In self play evaluation the
  # speaker is specified, so we use it as is. Test `is None` explicitly,
  # since speaker 0 is a valid specified value.
  if speaker is None:
    speaker = int(np.random.random() < 0.5)
  # Generate the Conversation instance for this conversation.
  conv = Conversation(max_diag_length, self.turn1_token, self.turn2_token,
                      batch_size, speaker)
  # Generate the conversation turn by turn in batch mode until all
  # conversations have terminated (finished = True) or the number of turns
  # reaches the maximum.
  turn = 0
  finished = False
  while (not finished) and turn < self.max_dialogue_turns:
    finished = self.generate_utterance(bs_intent, conv, bs_kb, speaker, turn,
                                       batch_size)
    # Change the speaker as we move to the next turn.
    speaker = (speaker + 1) % 2
    turn += 1
  all_rewards = dialogue_utils.compute_reward_batch(
      conv.utt_arr, conv.action_arr, bs_truth_action, bs_kb, self.hparams)
  metrics = dialogue_utils.calculate_reward_metrics(all_rewards)
  metrics['num_turns'] = turn
  # Write out step stats only in debug mode.
  if self.summary_writer and self.hparams.debug:
    for key in metrics:
      utils.add_summary(self.summary_writer, worker_step,
                        self.dialogue_mode + '_' + key + '_ws', metrics[key])
  utt_arr, bs_pred_action = conv.get_train_data()
  if self.hparams.debug:
    print('self_play debug: ' + bs_intent[0])
    print('self_play debug: all_rewards', all_rewards[0])
    print('self_play debug: ' + ' '.join(utt_arr[0]))
    print('self_play debug: ' + ' '.join(bs_pred_action[0]))
    sys.stdout.flush()
  return (bs_intent, bs_pred_action, bs_truth_action, utt_arr,
          bs_kb), turn, metrics
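# Hypothetical usage sketch (not part of the original file): in a self-play
# worker loop, the first tuple returned by talk() is exactly the `sample`
# consumed by maybe_train() below. The names `batch_iterator`,
# `mutable_agent`, `immutable_agent`, and the hparams fields shown are
# illustrative assumptions.
#
#   for worker_step, (data, kb) in enumerate(batch_iterator):
#     sample, num_turns, metrics = self.talk(
#         hparams.max_diag_length, data, kb, mutable_agent, immutable_agent,
#         worker_step, hparams.batch_size)
#     self.maybe_train(sample, speaker=0, global_step=worker_step)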
def maybe_train(self, sample, speaker, global_step, force=False):
  self.train_samples.append(sample)
  if force or len(self.train_samples) >= self.train_threadhold:
    # First generate training examples from the accumulated samples.
    data_arr = []
    kb_arr = []
    for sample in self.train_samples:  # Each sample is a batch of data.
      intent, pred_action, truth_action, utterance, kb = sample
      # Compute rewards for the whole batch.
      all_rewards = dialogue_utils.compute_reward_batch(
          utterance, pred_action, truth_action, kb, self.hparams)
      train_reward, _, _, _, _, _, _, _, _ = all_rewards
      final_reward = train_reward
      # Scale the rewards for the whole batch.
      reward_diag, reward_action = self.scale_reward_batch(
          final_reward, self.gamma, utterance)
      flat_pred_action = []
      for k in range(len(pred_action)):
        flat_pred_action.append(' '.join(pred_action[k]))
      new_data_arr = self.format_samples_batch(
          batch_intent=intent,
          batch_pred_action=flat_pred_action,
          batch_truth_action=truth_action,
          batch_utterance=utterance,
          batch_reward_diag=reward_diag,
          batch_reward_action=reward_action,
          batch_size=self.update_batch_size)
      data_arr.extend(new_data_arr)
      kb_arr.extend(kb)
    data_output, kb_output = data_arr, kb_arr
    new_global_step = None
    self.train_samples = []  # Clean up the buffered samples.
    self_play_handle = self.mutable_handles[self.iterator_mode]
    if self.hparams.rl_training:
      new_global_step = self.do_rl_training(
          data_output, kb_output, self.update_batch_size, self.mutable_model,
          self.mutable_sess, speaker, global_step, self_play_handle)
    print('self.hparams.self_play_sl_multiplier=',
          self.hparams.self_play_sl_multiplier)
    if self.hparams.self_play_sl_multiplier >= 0:
      # Do multiple supervised updates per call, or none at all.
      print('do', self.hparams.self_play_sl_multiplier, 'supervised training')
      for _ in range(self.hparams.self_play_sl_multiplier):
        new_global_step = self.do_SL_training(self.mutable_model,
                                              self.mutable_sess, global_step,
                                              self.mutable_handles[0])
    else:
      # Do one supervised update every |self_play_sl_multiplier| calls.
      print('do one supervised training')
      if self.train_counter >= abs(self.hparams.self_play_sl_multiplier):
        new_global_step = self.do_SL_training(self.mutable_model,
                                              self.mutable_sess, global_step,
                                              self.mutable_handles[0])
        self.train_counter = 0
      else:
        self.train_counter += 1
    if self.summary_writer:
      utils.add_summary(
          self.summary_writer, new_global_step,
          self.dialogue_mode + '_' + 'sl_rl',
          self.num_sl_updates * 1.0 / (self.num_rl_updates + 0.0001))
    return new_global_step
  return None
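# A minimal sketch, assuming scale_reward_batch spreads the scalar episode
# reward over dialogue tokens with discount factor gamma, the last token
# discounted least. This is an illustrative assumption, not the actual
# implementation:
#
#   def discount_over_tokens(final_reward, gamma, utterance):
#     num_tokens = len(utterance)
#     # Token t receives gamma^(num_tokens - 1 - t) * final_reward.
#     return [final_reward * gamma**(num_tokens - 1 - t)
#             for t in range(num_tokens)]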