Example #1
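A driver for batch self-play evaluation. This method appears to run `simulation_cnt` self-play sessions for every start utterance in a pre-sampled subset (`self.random_selected_sub_start_corpus`), collect the per-utterance results returned by `self_play_simulation` (shown in Example #3), and log both the per-utterance summaries and the overall success rate and average turn count through the project's `add_log` helper.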
    def self_play_simulation_with_fixed_start_corpus(self,
                                                     simulation_cnt,
                                                     print_details=True):
        success_cnt_list, turns_cnt_list, simulation_result_list = [], [], []
        # run `simulation_cnt` self-play sessions for each pre-sampled start utterance
        for start_utterance in self.random_selected_sub_start_corpus:
            success_cnt, turns_cnt, simulation_result_str = \
                self.self_play_simulation(simulation_cnt, start_utterance, print_details)
            success_cnt_list.append(success_cnt)
            turns_cnt_list.append(turns_cnt)
            simulation_result_list.append(simulation_result_str)

        # aggregate the per-utterance statistics into overall numbers
        total_simulation_cnt = simulation_cnt * len(
            self.random_selected_sub_start_corpus)
        total_success_cnt = sum(success_cnt_list)
        total_turns_cnt = sum(turns_cnt_list)
        total_success_rate = (total_success_cnt / total_simulation_cnt) * 100
        # guard against division by zero when no session succeeds
        total_average_turns = (total_turns_cnt / total_success_cnt
                               if total_success_cnt else float('nan'))

        # log the per-start-utterance results followed by the overall summary
        for start_utterance, simulation_result in \
                zip(self.random_selected_sub_start_corpus, simulation_result_list):
            add_log(self.simulation_save_path,
                    'For start utterance: {}, {}'.format(
                        start_utterance, simulation_result),
                    print_details=True)
        add_log(
            self.simulation_save_path,
            'total success times / total sessions: {}/{}, total success rate: {:.1f}%, total average turns: {:.2f}'
            .format(total_success_cnt, total_simulation_cnt,
                    total_success_rate, total_average_turns),
            print_details=True)
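A minimal usage sketch, assuming a hypothetical, already-initialized `simulator` instance of the (unshown) simulation class:

# `simulator` is hypothetical; its constructor is not part of these examples
simulator.self_play_simulation_with_fixed_start_corpus(simulation_cnt=50)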
Example #2
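A TensorFlow 1.x / Texar training loop for the response-retrieval module. The method appears to build the keyword-prediction and retrieval graphs, restore pretrained keyword-predictor weights from `self.model_config._kp_save_path`, then alternate one training pass and one validation pass per epoch, saving a checkpoint whenever validation accuracy improves.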
    def train(self):
        batch = self.iterator.get_next()
        # keyword-prediction branch; its Saver is created before the retrieval
        # graph so that it captures (and later restores) only these variables
        kw_loss, kw_acc, _ = self.predict_keywords(batch)
        kw_saver = tf.train.Saver()
        # response-retrieval branch, which is the part trained in this method
        loss, acc, rank = self.forward_response_retrieval(batch)
        op_step = tf.Variable(0, name='retrieval_step')
        train_op = tx.core.get_train_op(
            loss,
            global_step=op_step,
            hparams=self.model_config._retrieval_opt_hparams)
        # best validation accuracy observed so far
        max_val_acc = 0.
        with tf.Session(config=self.gpu_config) as sess:
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # restore the pretrained keyword predictor, then create a Saver
            # covering the full graph for checkpointing the retrieval model
            kw_saver.restore(sess, self.model_config._kp_save_path)
            saver = tf.train.Saver()
            for epoch_id in range(self.model_config._max_epoch):
                self.iterator.switch_to_train_data(sess)
                cur_step = 0
                cnt_acc = []
                while True:
                    try:
                        cur_step += 1
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
                        # running train_op returns the batch loss value; use a
                        # new name so the `loss` tensor above is not shadowed
                        loss_, acc_ = sess.run([train_op, acc], feed_dict=feed)
                        cnt_acc.append(acc_)
                        if cur_step % 200 == 0:
                            # log the running accuracy over the last 200 batches
                            logs_loss_acc = 'batch {}, loss={}, acc1={}'.format(
                                cur_step, loss_, np.mean(cnt_acc[-200:]))
                            add_log(self.logs_save_path, logs_loss_acc)
                    except tf.errors.OutOfRangeError:
                        # reached the end of the training split for this epoch
                        break

                # run validation on the validation split at the end of each epoch
                self.iterator.switch_to_val_data(sess)
                cnt_acc, cnt_kwacc = [], []
                while True:
                    try:
                        feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
                        acc_, kw_acc_ = sess.run([acc, kw_acc], feed_dict=feed)
                        cnt_acc.append(acc_)
                        cnt_kwacc.append(kw_acc_)
                    except tf.errors.OutOfRangeError:
                        mean_acc = np.mean(cnt_acc)
                        logs_loss_acc = 'epoch_id {}, valid acc1={}, kw_acc1={}'.format(
                            epoch_id + 1, mean_acc, np.mean(cnt_kwacc))
                        add_log(self.logs_save_path, logs_loss_acc)
                        # checkpoint whenever validation accuracy improves
                        if mean_acc > max_val_acc:
                            max_val_acc = mean_acc
                            saver.save(sess,
                                       self.model_config._retrieval_save_path)
                        break
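Note that two separate `tf.train.Saver` objects are used: `kw_saver` is constructed before the retrieval graph exists, so restoring it only loads the pretrained keyword-predictor variables, while `saver` is constructed after the full graph has been built and checkpoints everything to `self.model_config._retrieval_save_path`.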
Example #3
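The per-start-utterance simulation routine that Example #1 calls. It runs `simulation_cnt` self-play sessions with a fixed start utterance (sampling one from `self.start_corpus` if none is given), pairs each session with a target keyword from `self.target_set`, and returns the success count, total turn count, and a formatted summary string.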
    def self_play_simulation(self,
                             simulation_cnt,
                             start_utterance=None,
                             print_details=True):
        # fall back to a randomly sampled start utterance from the corpus
        if start_utterance is None:
            start_utterance = random.sample(self.start_corpus, 1)[0]
        simulation_start_str = 'start self-play simulation with start utterance: {} (total {} sessions)'.format(
            start_utterance, simulation_cnt)
        add_log(self.simulation_save_path, simulation_start_str, print_details)
        success_cnt, turns_cnt = 0, 0
        for i in tqdm(range(simulation_cnt)):
            add_log(self.simulation_save_path,
                    '-------- Session {} --------'.format(i), print_details)
            # each session pairs the fixed start utterance with one target keyword
            success, turns = self.simulate(start_utterance=start_utterance,
                                           target_keyword=self.target_set[i],
                                           print_details=print_details)
            success_cnt += success
            turns_cnt += turns
        success_rate = (success_cnt / simulation_cnt) * 100
        # average number of turns used to reach a target; guard against
        # division by zero in case no session succeeds
        average_turns = turns_cnt / success_cnt if success_cnt else float('nan')
        simulation_result_str = '#success / #sessions: {}/{}, success rate: {:.1f}%, average turns: {:.2f}'.format(
            success_cnt, simulation_cnt, success_rate, average_turns)
        add_log(self.simulation_save_path,
                simulation_result_str,
                print_details=True)
        return success_cnt, turns_cnt, simulation_result_str
Example #4
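A single-turn chat interface around the retrieval agent. Calling `chat()` with no argument appears to reset the session and return the fixed start utterance; calling it with `user_input` appends the input to the dialogue history, retrieves a reply with `self.agent.retrieve`, and appends a `[SUCCESS]` or `[FAIL]` end message once the target keyword is reached or the turn limit is exceeded.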
    def chat(self, user_input=None):
        responses = []
        # if this is the beginning of a conversation, reset state and let the agent open
        if user_input is None:
            self._reset()
            reply = self.start_utterance
            add_log(self.conversation_save_path, '-------- Session {} --------'.format(self.current_sessions))
            add_log(self.conversation_save_path, 'START: {}'.format(reply))
        else:
            self.history.append(user_input)
            source = utter_preprocess(self.history, self.agent.data_config._max_seq_len)
            reply = self.agent.retrieve(source, self.sess)
            add_log(self.conversation_save_path, 'HUMAN: {}'.format(user_input), print_details=False)
            add_log(self.conversation_save_path, 'AGENT: {}'.format(reply))
        self.history.append(reply)
        responses.append(reply)
        self.current_turns += 1

        # if the last two utterances contain the target keyword, the session succeeds
        if is_reach_goal(' '.join(self.history[-2:]), self.target_keyword):
            end_message = '[SUCCESS] target: \'{}\'.'.format(self.target_keyword)
            add_log(self.conversation_save_path, end_message)
            responses.append(end_message)
        # if the maximum number of dialogue turns has been exceeded, the session fails
        elif self.current_turns > self.max_turns:
            end_message = '[FAIL] out of the max dialogue turns, target: \'{}\'.'.format(self.target_keyword)
            add_log(self.conversation_save_path, end_message)
            responses.append(end_message)

        return responses
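A minimal sketch of how `chat` might be driven interactively. The surrounding class and its constructor are not part of these examples, so `bot` below is a hypothetical, already-initialized instance:

# `bot` is a hypothetical instance of the (unshown) chat class
responses = bot.chat()                      # start a session; the agent opens
print('AGENT: {}'.format(responses[0]))
while True:
    responses = bot.chat(input('HUMAN: '))
    for reply in responses:
        print('AGENT: {}'.format(reply))
    # chat() appends an end message once the target is hit or turns run out
    if any(r.startswith(('[SUCCESS]', '[FAIL]')) for r in responses):
        break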