Exemplo n.º 1
0
    def message_mutation(self, message: Message) -> Message:
        """
        Randomly subsample the retrieved docs down to ``woi_doc_max_chunks``.

        A doc that was not sampled is still kept when it contains one of the
        gold sentences. The gold sentences are read from the selected-sentences
        field, falling back to ``labels``; when they amount to the "no selected
        sentences" token, nothing is rescued.

        :param message:
            incoming message; returned as-is when it has no retrieved docs

        :return:
            a copy of the message whose retrieved docs are the filtered list
        """
        if CONST.RETRIEVED_DOCS not in message:
            return message
        docs = message.get(CONST.RETRIEVED_DOCS)
        max_chunks = self.opt.get('woi_doc_max_chunks', 100)

        # Indices of the randomly sampled docs, which are kept unconditionally.
        sampled = set(torch.randperm(len(docs))[0:max_chunks].tolist())

        # The gold sentences do not depend on the doc, so resolve them once
        # instead of once per dropped doc.
        checked_sentences = message.get(
            CONST.SELECTED_SENTENCES,
            message.get('labels', [CONST.NO_SELECTED_SENTENCES_TOKEN]),
        )
        if ' '.join(checked_sentences) != CONST.NO_SELECTED_SENTENCES_TOKEN:
            gold_sentences = [sent.strip(' ') for sent in checked_sentences]
        else:
            gold_sentences = []

        # NOTE(review): only the docs list is filtered here; if titles/urls are
        # stored as parallel lists they keep their original length -- confirm
        # that this is intended.
        new_docs = [
            doc
            for i, doc in enumerate(docs)
            if i in sampled or any(sent in doc for sent in gold_sentences)
        ]

        new_message = message.copy()
        new_message.force_set(CONST.RETRIEVED_DOCS, new_docs)
        return new_message
Exemplo n.º 2
0
    def act(self):
        """
        Send new dialog message.

        Fetches the next example, wraps it in a ``Message`` tagged with this
        agent's id and stores the reference answer in ``self.lastY``. Unless
        the datatype is a plain training one (starts with ``train`` and does
        not contain ``evalmode``), ``labels`` is removed from the message and
        re-exposed as ``eval_labels`` -- or dropped when ``hide_labels`` is set.

        :return:
            the message to send
        """
        if not hasattr(self, 'epochDone'):
            # Nothing has been initialised yet, so reset first.
            self.reset()

        # When the examples are exhausted this is an episode_done dict.
        example, self.epochDone = self.next_example()
        # TODO: all teachers should eventually create messages
        # while setting up the data, so this won't be necessary
        action = Message(example)
        action.force_set('id', self.getID())

        # Keep the correct answer around if the example carries one.
        self.lastY = action.get('labels_1', action.get('eval_labels_1', None))

        plain_training = (
            self.datatype.startswith('train') and 'evalmode' not in self.datatype
        )
        if plain_training or 'labels' not in action:
            return action

        # Move the labels to the eval field so they are not used for training,
        # while the model can still use them for perplexity or loss.
        eval_action = action.copy()
        labels = eval_action.pop('labels')
        if not self.opt.get('hide_labels', False):
            eval_action['eval_labels'] = labels
        return eval_action
Exemplo n.º 3
0
 def message_mutation(self, message: Message) -> Message:
     """
     Split the text into retrieved docs and the dialogue context.

     The text is split on newlines: the last line becomes the new ``text``
     and every line before it is stored as the retrieved docs.

     :param message:
         incoming message with a ``text`` field

     :return:
         a copy of the message with the docs set and the text replaced
     """
     *doc_lines, context = message['text'].split('\n')
     new_message = message.copy()
     new_message[CONST.RETRIEVED_DOCS] = doc_lines
     new_message.force_set('text', context)
     return new_message
Exemplo n.º 4
0
 def message_mutation(self, message: Message) -> Message:
     """
     Expose the checked sentence as the single retrieved doc.

     The doc title comes from ``title``, the url is left blank, and
     ``checked_sentence`` is removed from the result. When the message has a
     non-empty ``history``, it is prepended to the text on its own line.

     :param message:
         incoming message with ``checked_sentence`` and ``title`` fields

     :return:
         a copy of the message with the retrieved-doc fields filled in
     """
     new_message = message.copy()
     gold_sentence = message['checked_sentence']
     new_message[CONST.RETRIEVED_DOCS] = [gold_sentence]
     new_message[CONST.RETRIEVED_DOCS_TITLES] = [message['title']]
     new_message[CONST.RETRIEVED_DOCS_URLS] = ['']
     new_message.pop('checked_sentence')

     history = message.get('history', '')
     if history == '':
         return new_message

     new_message.force_set('text', history + '\n' + message.get('text', ''))
     return new_message
Exemplo n.º 5
0
 def message_mutation(self, message: Message) -> Message:
     """
     Normalise the retrieved docs to a list and add blank titles and urls.

     A non-list docs value is wrapped in a one-element list. Titles and urls
     are set to lists of empty strings with one entry per doc.

     :param message:
         incoming message that already holds retrieved docs

     :return:
         a copy of the message with docs, titles and urls as lists
     """
     new_message = message.copy()
     docs = new_message[CONST.RETRIEVED_DOCS]
     if not isinstance(docs, list):
         docs = [docs]
         new_message.force_set(CONST.RETRIEVED_DOCS, docs)

     num_docs = len(docs)
     new_message[CONST.RETRIEVED_DOCS_TITLES] = [''] * num_docs
     new_message[CONST.RETRIEVED_DOCS_URLS] = [''] * num_docs
     return new_message
Exemplo n.º 6
0
 def message_mutation(self, message: Message) -> Message:
     """
     Turn the ``knowledge`` field into a single retrieved doc.

     Newlines in the knowledge are replaced by spaces so that it forms one
     doc; its title and url are blank. ``knowledge`` itself is removed from
     the result.

     :param message:
         incoming message with a ``knowledge`` field

     :return:
         a copy of the message with the retrieved-doc fields force-set
     """
     new_message = message.copy()
     new_message.pop('knowledge')
     single_doc = message['knowledge'].replace('\n', ' ')

     replacements = (
         (CONST.RETRIEVED_DOCS, [single_doc]),
         (CONST.RETRIEVED_DOCS_TITLES, ['']),
         (CONST.RETRIEVED_DOCS_URLS, ['']),
     )
     for key, value in replacements:
         new_message.force_set(key, value)
     return new_message
Exemplo n.º 7
0
 def test_message(self):
     """
     Check that a Message rejects key overrides and copies as a Message.

     Re-assigning an existing key through item assignment is expected to
     raise ``RuntimeError``, and ``copy()`` is expected to return an object
     whose type is exactly ``Message``.
     """
     message = Message()
     message['text'] = 'lol'

     try:
         message['text'] = 'rofl'
     except RuntimeError:
         override_rejected = True
     else:
         override_rejected = False
     assert override_rejected, 'Message allowed override'

     message_copy = message.copy()
     assert type(message_copy) is Message, 'Message did not copy properly'
Exemplo n.º 8
0
    def message_mutation(self, message: Message) -> Message:
        """
        Use the checked sentence as the label.

        The original labels are moved to ``dialogue_response`` and ``labels``
        becomes a one-element list holding the checked sentence (a list of
        sentences is joined with spaces; a missing one yields an empty string).

        :param message:
            incoming message; returned as-is when it has no text or no labels

        :return:
            a copy of the message with the labels swapped
        """
        usable = 'text' in message and 'labels' in message and message['labels']
        if not usable:
            return message

        new_message = message.copy()
        dialogue_response = new_message.pop('labels')
        knowledge = new_message.get(self.checked_sentence_kword, '')

        new_message['dialogue_response'] = dialogue_response
        new_message['labels'] = [
            ' '.join(knowledge) if isinstance(knowledge, list) else knowledge
        ]
        return new_message
Exemplo n.º 9
0
    def message_mutation(self, message: Message) -> Message:
        """
        Append the checked sentence to the text, wrapped in knowledge tokens.

        A list of checked sentences is joined with spaces; a missing one is
        treated as an empty string.

        :param message:
            incoming message; returned as-is when it has no text

        :return:
            a copy of the message whose text ends with the knowledge line
        """
        if 'text' not in message:
            return message

        new_message = message.copy()
        context = new_message.pop('text')
        knowledge = new_message.get(self.checked_sentence_kword, '')
        if isinstance(knowledge, list):
            knowledge = ' '.join(knowledge)

        knowledge_line = (
            f'\n{CONST.KNOWLEDGE_TOKEN} {knowledge} {CONST.END_KNOWLEDGE_TOKEN}'
        )
        context += knowledge_line
        new_message['text'] = context
        return new_message
Exemplo n.º 10
0
    def message_mutation(self, message: Message) -> Message:
        """
        Append the dialogue response to the text, wrapped in label tokens.

        The response is the first entry of ``dialogue_response`` when that
        field exists, otherwise the first entry of ``labels``.

        :param message:
            incoming message; returned as-is when it has no text or no labels

        :return:
            a copy of the message whose text ends with the label line
        """
        usable = 'text' in message and 'labels' in message and message['labels']
        if not usable:
            return message

        new_message = message.copy()
        # 'dialogue_response' is present when checked_sentence_as_label was
        # applied before; the real response then lives there.
        if 'dialogue_response' in new_message:
            source_key = 'dialogue_response'
        else:
            source_key = 'labels'
        dialogue_response = new_message[source_key][0]

        context = new_message.pop('text')
        context += (
            f'\n{CONST.TOKEN_LABEL} {dialogue_response} {CONST.TOKEN_END_LABEL}'
        )
        new_message['text'] = context
        return new_message
Exemplo n.º 11
0
    def message_mutation(self, message: Message) -> Message:
        """
        Append the dialogue response to the text, split at a random word.

        The response is the first entry of ``dialogue_response`` when that
        field exists, otherwise the first entry of ``labels``. It is split on
        whitespace at a random word index: the words before the split are
        appended as a plain line, the remaining words on a following line
        wrapped in label tokens.

        :param message:
            incoming message; returned as-is when it has no text or no labels

        :return:
            a copy of the message with the two lines appended to its text
        """
        if 'text' not in message or 'labels' not in message or not message['labels']:
            return message

        new_message = message.copy()
        if 'dialogue_response' in new_message:
            # checked_sentence_as_label was applied before
            labels = new_message['dialogue_response']
        else:
            labels = new_message['labels']
        dialogue_response = labels[0]
        text = new_message.pop('text')

        words = dialogue_response.split()
        # randint(0, -1) raises ValueError, so an empty or whitespace-only
        # response uses a fixed split point instead of crashing. The split
        # point may be 0, in which case the plain line is empty.
        split_at = random.randint(0, len(words) - 1) if words else 0
        label1 = ' '.join(words[:split_at])
        label2 = ' '.join(words[split_at:])

        text += f'\n{label1}\n{CONST.TOKEN_LABEL} {label2} {CONST.TOKEN_END_LABEL}'
        new_message['text'] = text

        return new_message
Exemplo n.º 12
0
    def observe(self, observation: Message) -> Dict[str, Message]:
        """
        Observe in 3 out of the 4 modules.

        The knowledge agent always observes. The search query agent and the
        search decision agent observe only when they are set (and, for the
        decision agent, only when the search decision is ``COMPUTE``);
        otherwise their entry in the result is ``None``.

        Note that the incoming observation is modified in place: the
        ``label_candidates`` and ``knowledge`` keys are removed and
        ``knowledge_response`` is added.

        :param observation:
            incoming message

        :return self.observation:
            returned observation is actually a dictionary mapping
            agent module name to the corresponding observation
        """
        if not isinstance(observation, Message):
            observation = Message(observation)
        for key in ['label_candidates', 'knowledge']:
            # Delete unnecessarily large keys
            observation.pop(key, '')
        # Expose the checked sentence (empty string if absent) under the
        # 'knowledge_response' key.
        observation['knowledge_response'] = observation.get('checked_sentence', '')

        # Snapshot taken before any mutator runs; returned under the 'raw' key.
        raw_observation = copy.deepcopy(observation)
        # This part is *specifically* for document chunking.
        if self.krm_mutators:
            observation = observation.copy()
            for mutator in self.krm_mutators:
                assert isinstance(mutator, MessageMutator), "not message mutator"
                # The mutator is called on a one-element list and only the
                # first item it yields is used.
                observation = next(mutator([observation]))

        knowledge_observation = self.knowledge_agent.observe(observation)
        # Attached after the knowledge agent has observed, so it is set on
        # whatever object that agent returned.
        knowledge_observation['prior_knowledge_responses'] = ' '.join(
            self.knowledge_responses
        )
        if observation.get('episode_done'):
            # End of episode: reset the stored knowledge responses.
            self.knowledge_responses = ['__SILENCE__']
        search_query_observation = None
        if self.search_query_agent:
            # Work on a deep copy so the fields set below do not leak into the
            # observation used elsewhere in this method.
            sqm_obs = copy.deepcopy(observation)
            if self.opt['search_query_control_token']:
                # The control token is stored with a leading space.
                sqm_obs.force_set(
                    'temp_history', f" {self.opt['search_query_control_token']}"
                )
            sqm_obs.force_set('skip_retrieval', True)
            search_query_observation = self.search_query_agent.observe(sqm_obs)

        search_decision_observation = None
        if (
            self.search_decision_agent
            and self.search_decision is SearchDecision.COMPUTE
        ):
            assert (
                self.search_decision_agent.history.size == 1
            ), "wrong history size! set --sdm-history-size 1"
            # Same preparation as for the search query agent, with the search
            # decision control token.
            sdm_obs = copy.deepcopy(observation)
            if self.opt['search_decision_control_token']:
                sdm_obs.force_set(
                    'temp_history', f" {self.opt['search_decision_control_token']}"
                )
            sdm_obs.force_set('skip_retrieval', True)
            search_decision_observation = self.search_decision_agent.observe(sdm_obs)

        observations = {
            'raw': raw_observation,
            'knowledge_agent': knowledge_observation,
            'search_query_agent': search_query_observation,
            'search_decision_agent': search_decision_observation,
        }
        # Kept on the instance as well as returned.
        self.observations = observations
        return observations