def self_play_simulation_with_fixed_start_corpus(self, simulation_cnt, print_details=True):
    """Run self-play sessions for every utterance in the fixed start corpus.

    For each utterance in ``self.random_selected_sub_start_corpus`` this calls
    ``self.self_play_simulation`` with ``simulation_cnt`` sessions. It then logs
    each per-utterance result line plus an aggregate summary to
    ``self.simulation_save_path``.

    Args:
        simulation_cnt: number of sessions to simulate for each start utterance.
        print_details: forwarded to ``self.self_play_simulation``. The summary
            lines logged by this method itself are always printed.

    Aggregate statistics logged:
        * success rate = total successes / total sessions, as a percentage
          (reported as 0.0 when no session was run at all);
        * average turns = summed turn counts / total successes
          (reported as ``nan`` when no session succeeded).
    """
    start_corpus = self.random_selected_sub_start_corpus
    success_cnt_list, turns_cnt_list, simulation_result_list = [], [], []
    for start_utterance in start_corpus:
        success_cnt, turns_cnt, simulation_result_str = self.self_play_simulation(
            simulation_cnt, start_utterance, print_details)
        success_cnt_list.append(success_cnt)
        turns_cnt_list.append(turns_cnt)
        simulation_result_list.append(simulation_result_str)

    total_simulation_cnt = simulation_cnt * len(start_corpus)
    total_success_cnt = sum(success_cnt_list)
    total_turns_cnt = sum(turns_cnt_list)
    # Guard both divisions. An empty corpus / zero sessions, or a run in which
    # no session succeeded, previously raised ZeroDivisionError here and
    # discarded every result computed above before it could be logged.
    if total_simulation_cnt:
        total_success_rate = (total_success_cnt / total_simulation_cnt) * 100
    else:
        total_success_rate = 0.0
    if total_success_cnt:
        total_average_turns = total_turns_cnt / total_success_cnt
    else:
        total_average_turns = float('nan')

    for start_utterance, simulation_result in zip(start_corpus, simulation_result_list):
        add_log(self.simulation_save_path,
                'For start utterance: {}, {}'.format(start_utterance, simulation_result),
                print_details=True)
    add_log(
        self.simulation_save_path,
        'total success times / total sessions: {}/{}, total success rate: {:.1f}%, '
        'total average turns: {:.2f}'.format(
            total_success_cnt, total_simulation_cnt,
            total_success_rate, total_average_turns),
        print_details=True)
def train(self):
    """Train the response-retrieval model on top of the keyword predictor.

    Builds the keyword-prediction and response-retrieval graphs on one input
    batch. It then restores the keyword-prediction variables from
    ``self.model_config._kp_save_path`` and, for each of
    ``self.model_config._max_epoch`` epochs:

    * runs ``train_op`` over the training split, logging the loss and the
      mean accuracy of the last 200 batches every 200 steps;
    * evaluates retrieval and keyword accuracy over the validation split;
    * saves the session to ``self.model_config._retrieval_save_path`` whenever
      the mean validation retrieval accuracy beats the best seen so far.
    """
    batch = self.iterator.get_next()
    _, kw_acc, _ = self.predict_keywords(batch)
    # Constructed before the retrieval graph exists, so it only captures the
    # keyword-prediction variables that are restored from their checkpoint.
    kw_saver = tf.train.Saver()
    loss, acc, _ = self.forward_response_retrieval(batch)
    # The global step is a counter, not a model parameter: keep it out of the
    # trainable-variable collection the optimizer iterates over.
    op_step = tf.Variable(0, name='retrieval_step', trainable=False)
    train_op = tx.core.get_train_op(
        loss, global_step=op_step,
        hparams=self.model_config._retrieval_opt_hparams)

    max_val_acc = 0.
    with tf.Session(config=self.gpu_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        kw_saver.restore(sess, self.model_config._kp_save_path)
        saver = tf.train.Saver()

        for epoch_id in range(self.model_config._max_epoch):
            # ---- one pass over the training split ----
            self.iterator.switch_to_train_data(sess)
            cur_step = 0
            cnt_acc = []
            while True:
                cur_step += 1
                feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                try:
                    # Bind to new names so the `loss` tensor above is not
                    # shadowed; the train_op result is what gets logged as loss.
                    step_loss, step_acc = sess.run([train_op, acc], feed_dict=feed)
                except tf.errors.OutOfRangeError:
                    break  # training split exhausted
                cnt_acc.append(step_acc)
                if cur_step % 200 == 0:
                    logs_loss_acc = 'batch {}, loss={}, acc1={}'.format(
                        cur_step, step_loss, np.mean(cnt_acc[-200:]))
                    add_log(self.logs_save_path, logs_loss_acc)

            # ---- one pass over the validation split ----
            self.iterator.switch_to_val_data(sess)
            cnt_acc, cnt_kwacc = [], []
            while True:
                feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                try:
                    step_acc, step_kw_acc = sess.run([acc, kw_acc], feed_dict=feed)
                except tf.errors.OutOfRangeError:
                    break  # validation split exhausted
                cnt_acc.append(step_acc)
                cnt_kwacc.append(step_kw_acc)

            mean_acc = np.mean(cnt_acc)
            logs_loss_acc = 'epoch_id {}, valid acc1={}, kw_acc1={}'.format(
                epoch_id + 1, mean_acc, np.mean(cnt_kwacc))
            add_log(self.logs_save_path, logs_loss_acc)
            # Checkpoint only on improvement of validation retrieval accuracy.
            if mean_acc > max_val_acc:
                max_val_acc = mean_acc
                saver.save(sess, self.model_config._retrieval_save_path)
def self_play_simulation(self, simulation_cnt, start_utterance=None, print_details=True):
    """Run ``simulation_cnt`` self-play sessions from a single start utterance.

    Session ``i`` targets ``self.target_set[i]``, so ``self.target_set`` must
    hold at least ``simulation_cnt`` entries.

    Args:
        simulation_cnt: number of sessions to simulate.
        start_utterance: opening utterance shared by all sessions. When
            ``None``, one is drawn at random from ``self.start_corpus``.
        print_details: forwarded to ``add_log`` and ``self.simulate`` for the
            per-session output. The final summary line is always printed.

    Returns:
        ``(success_cnt, turns_cnt, simulation_result_str)``: the summed
        ``success`` and ``turns`` values returned by ``self.simulate`` over all
        sessions, and the formatted summary line that was logged. The summary
        reports a success rate of 0.0 when no session was run and an average
        of ``nan`` turns when no session succeeded.
    """
    if start_utterance is None:
        # NOTE(review): random.sample rejects sets on Python 3.11+; this assumes
        # self.start_corpus is a sequence — confirm.
        start_utterance = random.sample(self.start_corpus, 1)[0]
    simulation_start_str = 'start self-play simulation with start utterance: {} (total {} sessions)'.format(
        start_utterance, simulation_cnt)
    add_log(self.simulation_save_path, simulation_start_str, print_details)

    success_cnt, turns_cnt = 0, 0
    for i in tqdm(range(simulation_cnt)):
        add_log(self.simulation_save_path, '-------- Session {} --------'.format(i), print_details)
        success, turns = self.simulate(start_utterance=start_utterance,
                                       target_keyword=self.target_set[i],
                                       print_details=print_details)
        success_cnt += success
        turns_cnt += turns

    # Guard both divisions. Zero sessions, or no successful session, used to
    # raise ZeroDivisionError after all the simulations had already run.
    success_rate = (success_cnt / simulation_cnt) * 100 if simulation_cnt else 0.0
    # The average number of turns used to reach a target; undefined (NaN) when
    # no session succeeded.
    average_turns = turns_cnt / success_cnt if success_cnt else float('nan')
    simulation_result_str = '#success / #sessions: {}/{}, success rate: {:.1f}%, average turns: {:.2f}'.format(
        success_cnt, simulation_cnt, success_rate, average_turns)
    add_log(self.simulation_save_path, simulation_result_str, print_details=True)
    return success_cnt, turns_cnt, simulation_result_str
def chat(self, user_input=None):
    """Advance an interactive conversation by one agent turn.

    Passing ``user_input=None`` starts a new session. The conversation state
    is reset via ``self._reset()`` and ``self.start_utterance`` becomes the
    agent's opening line. Otherwise the human utterance is appended to the
    history, and the agent retrieves a reply conditioned on the preprocessed
    history.

    After the reply is recorded, the session is checked for an end condition:
    success when ``is_reach_goal`` finds the target keyword in the last two
    utterances, failure once ``self.current_turns`` exceeds ``self.max_turns``.

    Returns:
        list: the agent's reply, followed by an end-of-session message if the
        session has just succeeded or failed.
    """
    if user_input is not None:
        self.history.append(user_input)
        context = utter_preprocess(self.history, self.agent.data_config._max_seq_len)
        reply = self.agent.retrieve(context, self.sess)
        add_log(self.conversation_save_path, 'HUMAN: {}'.format(user_input), print_details=False)
        add_log(self.conversation_save_path, 'AGENT: {}'.format(reply))
    else:
        # Beginning of a new conversation.
        self._reset()
        reply = self.start_utterance
        add_log(self.conversation_save_path,
                '-------- Session {} --------'.format(self.current_sessions))
        add_log(self.conversation_save_path, 'START: {}'.format(reply))

    self.history.append(reply)
    outputs = [reply]
    self.current_turns += 1

    target = self.target_keyword
    end_message = None
    if is_reach_goal(' '.join(self.history[-2:]), target):
        end_message = '[SUCCESS] target: \'{}\'.'.format(target)
    elif self.current_turns > self.max_turns:
        end_message = '[FAIL] out of the max dialogue turns, target: \'{}\'.'.format(target)

    if end_message is not None:
        add_log(self.conversation_save_path, end_message)
        outputs.append(end_message)
    return outputs