Example #1
    def get_internal_state(self, cand_kb, last_turn):
        """Build the turn-level document distribution and the slot/inform
        classification targets from the last (agent_act, user_value) pair."""
        agent_act, user_value = last_turn

        # one-hot target for the slot the agent asked about
        agent_act_id = AgentActs.slot_to_id(agent_act)
        turn_slot_cls = torch.zeros(AgentActs.slot_size())
        turn_slot_cls[agent_act_id] = 1

        # the user has no answer: all candidate documents stay equally likely
        if user_value is None:
            turn_inform_cls = 1
            turn_level_doc_dist = [1 for _ in range(len(cand_kb))]
        else:
            turn_inform_cls = 0
            turn_level_doc_dist = []

            # a full answer may contain several comma-separated values
            if self.enable_full_answer and ', ' in user_value:
                user_value = user_value.split(', ')

            for kb in cand_kb:
                if not self.enable_full_answer:
                    # partial match: keep the document if the KB has no entry for
                    # the asked slot or the user's value is among its values
                    if agent_act not in kb or user_value in kb[agent_act]:
                        turn_level_doc_dist.append(1)
                        continue
                else:
                    # full-answer match: keep the document if the slot is absent
                    # or the user's full answer equals the whole KB entry
                    if (agent_act not in kb or kb[agent_act] == user_value
                            or kb[agent_act] == [user_value]):
                        turn_level_doc_dist.append(1)
                        continue

                # the document contradicts the user's answer: mask it out
                turn_level_doc_dist.append(float('-inf'))

        turn_level_doc_dist = torch.softmax(
            torch.tensor(turn_level_doc_dist, dtype=torch.float), dim=-1)
        return turn_level_doc_dist, turn_slot_cls, turn_inform_cls
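
For context, get_internal_state returns a softmax distribution over the candidate documents plus a one-hot slot target and an inform flag. A minimal usage sketch, assuming an instance tracker of the enclosing class with enable_full_answer=False; the slot and value strings below are illustrative, not from the source:

    # toy candidate knowledge bases: one dict of slot -> list of values per document
    cand_kb = [
        {'color': ['red', 'blue'], 'size': ['small']},
        {'color': ['green'], 'size': ['large']},
    ]
    last_turn = ('color', 'red')  # the agent asked 'color', the user answered 'red'

    doc_dist, slot_cls, inform_cls = tracker.get_internal_state(cand_kb, last_turn)
    # doc_dist: softmax over candidates; the second document contradicts the answer,
    #           so its -inf logit becomes (near-)zero probability
    # slot_cls: one-hot over agent acts, 1 at AgentActs.slot_to_id('color')
    # inform_cls: 0 because the user provided a value
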
Example #2
    def __getitem__(self, index):
        """Return one dialog turn as tensors: candidate document representations,
        dialog history, ground-truth agent act, target index, and document
        distribution."""
        turn = self.turn_data[index]

        tar_name = turn['tar_name']
        cand_docs_names = turn['cand_names']

        cand_docs_idx = [self.all_docs_names.index(name) for name in cand_docs_names]
        cand_doc_tensor = self.docs_rep[cand_docs_idx, :]

        # get the dialog history tensor
        dia_his_tensor = self.qa_to_tensor(turn['dialog_history']).long()

        # get the ground truth of the agent act, mapped to its slot id
        agent_act = AgentActs.slot_to_id(turn['agent_act'])

        # index of the target document and a binary distribution over the candidates
        tar_idx = cand_docs_names.index(tar_name)
        docs_dist = torch.tensor(turn['docs_dist'], dtype=torch.float).gt(0).long().squeeze(0)

        return cand_doc_tensor, dia_his_tensor, agent_act, tar_idx, docs_dist
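
The items returned by __getitem__ can be consumed with a standard torch.utils.data.DataLoader. A minimal sketch, assuming the enclosing class (called TurnDataset here, a hypothetical name) also implements __len__; batch_size=1 avoids padding the variable number of candidate documents per turn:

    from torch.utils.data import DataLoader

    dataset = TurnDataset(...)  # constructor arguments are not shown in the source
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    for cand_doc_tensor, dia_his_tensor, agent_act, tar_idx, docs_dist in loader:
        # cand_doc_tensor: (1, num_cand_docs, doc_dim) candidate document representations
        # dia_his_tensor:  (1, history_len) token ids of the dialog history
        # agent_act, tar_idx: ground-truth act id and index of the target document
        # docs_dist: binary mask over the candidate documents
        pass
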
Example #3
    def run_dialog(self, cand_docs, cand_docs_diff, cand_names, tar_kb,
                   tar_name):
        """
        Simulate a dialog with model
        :param cand_docs:
        :param cand_names:
        :param tar_kb:
        :param tar_name:
        :return:
        """
        self.agent.init_dialog(cand_docs, cand_docs_diff, cand_names)
        self.simulator.init_dialog(tar_kb, tar_name)

        tar_idx = cand_names.index(tar_name)

        # dialog status
        top_1_success = False
        top_3_success = False
        mrr = 0
        all_rewards = []
        saved_log_act_probs = []  # saved action log probabilities for the REINFORCE algorithm
        saved_log_doc_probs = []  # saved document log probabilities for the REINFORCE algorithm
        num_turns = self.max_turns

        # (disabled) alternative flow in which the dialog starts with a user turn:
        # user_act, user_value, user_nl = self.simulator.respond_act(agent_act=None, agent_value=None)
        # dialog_json = [{'turn_id': 0,
        #                 'agent_act': '',
        #                 'agent_value': '',
        #                 'agent_nl': '',
        #                 'user_act': user_act,
        #                 'user_value': user_value,
        #                 'user_nl': user_nl,
        #                 'tar_rank': 0,
        #                 'docs_entropy': 0}]

        # if 'rule-' in self.agent_type:
        #     last_turn = (user_act, user_value)
        # else:
        #     last_turn = self.dataset.turn_to_tensor(agent_nl='', user_nl=user_nl).unsqueeze(0)
        dialog_his = '_BOS_'
        dialog_json = []

        last_turn = None

        assert self.max_turns > 1
        for turn_i in range(1, self.max_turns + 1):
            agent_act, agent_act_prob, agent_value, agent_nl = self.agent.turn_act(
                last_turn)

            # feedback from the environment
            if len(dialog_json) == 0:
                is_end = agent_act == AgentActs.GUESS
            else:
                reward, top_1_success, top_3_success, is_end, mrr, tar_r, docs_entropy, tar_prob = \
                    self.env_feedback(agent_act,
                                      self.agent.dialog_level_doc_dist,
                                      tar_idx)
                all_rewards.append(reward)
                dialog_json[-1]['tar_rank'] = tar_r
                dialog_json[-1]['tar_prob'] = float('{:.2f}'.format(tar_prob))
                dialog_json[-1]['docs_entropy'] = float(
                    '{:.2f}'.format(docs_entropy))

            # user response
            user_act, user_value, user_nl = self.simulator.respond_act(
                agent_act, agent_value)

            # record current turn
            if 'rule-' in self.agent_type or 'mrc' in self.agent_type:
                last_turn = (user_act, user_value)
            else:
                last_turn = self.dataset.turn_to_tensor(
                    agent_nl=agent_nl, user_nl=user_nl).unsqueeze(0)

            # record data
            turn_json = {
                'turn_id': turn_i,
                'agent_act': agent_act,
                'agent_value': agent_value,
                'agent_nl': agent_nl,
                'user_act': user_act,
                'user_value': user_value,
                'user_nl': user_nl,
                'tar_rank': 0,
                'tar_prob': 0,
                'docs_entropy': 0,
                'docs_dist': self.agent.dialog_level_doc_dist.tolist()
            }
            dialog_json.append(turn_json)

            # early stop
            if is_end:
                num_turns = turn_i
                break

            # save log probabilities for the REINFORCE update during RL training
            if self.training:
                agent_act_id = AgentActs.slot_to_id(agent_act)
                saved_log_act_probs.append(
                    torch.log(agent_act_prob[agent_act_id]))

                pred_doc_id = cand_names.index(agent_value)
                saved_log_doc_probs.append(
                    torch.log(
                        self.agent.dialog_level_doc_dist[0][pred_doc_id]))

        # count this completed episode
        self.num_steps += 1

        # when training
        if self.training and len(all_rewards):
            loss = self.agent.update(all_rewards, saved_log_act_probs,
                                     saved_log_doc_probs)  # update the model
            self.on_training(loss)

        cur_rewards = sum(all_rewards)
        data_json = {
            'tar_name': tar_name,
            'cand_names': cand_names,
            'dialog': dialog_json,
            'top_1_success': top_1_success,
            'top_3_success': top_3_success,
            'mrr': mrr,
            'rewards': '{:.2f}'.format(cur_rewards)
        }

        return data_json, top_1_success, top_3_success, mrr, num_turns, cur_rewards
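
run_dialog only collects per-turn rewards and log probabilities; the actual policy-gradient step happens in self.agent.update(all_rewards, saved_log_act_probs, saved_log_doc_probs), whose internals are not shown here. A minimal sketch of the kind of REINFORCE update it might perform; the discount factor, the absence of a baseline, and the optimizer handling are assumptions, not taken from the source:

    import torch

    def reinforce_update(rewards, log_act_probs, log_doc_probs, optimizer, gamma=0.99):
        # discounted returns G_t = r_t + gamma * G_{t+1}
        returns, g = [], 0.0
        for r in reversed(rewards):
            g = r + gamma * g
            returns.insert(0, g)

        # policy-gradient loss: -sum_t G_t * (log pi(act_t) + log pi(doc_t));
        # zip() truncates to the shortest list if the dialog ended early
        loss = torch.stack([
            -g * (la + ld)
            for g, la, ld in zip(returns, log_act_probs, log_doc_probs)
        ]).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()
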