예제 #1
0
    def act(self):
        """
        Receive one message from the server socket and wrap it as a reply.

        Reads up to ``BUF_SIZE`` bytes from the module-level ``connectionSock``,
        decodes them as UTF-8 and converts literal ``\\n`` sequences into real
        newlines.

        :return:
            the populated reply message, or ``{'episode_done': True}`` once
            the connection has ended (``self.finished`` is set in that case).

        :raises StopIteration:
            when the received text contains ``[DONE]`` or ``[EXIT]``; for
            ``[EXIT]`` the agent is also marked as finished.
        """
        reply = Message()
        reply['id'] = self.getID()
        print("act start!!!!!!!!!!!!!!1\n")
        try:
            print("before recv\n")
            text = connectionSock.recv(BUF_SIZE)  # receive data from the server
            print("after recv\n")
            if not text:
                # A stream socket signals "peer closed" by returning b'' from
                # recv(); it never raises EOFError on its own. Without this
                # check the agent would keep producing empty replies.
                # NOTE(review): assumes connectionSock is a stream socket.
                raise EOFError('connection closed by peer')
            reply_text = text.decode('utf-8')
            print(reply_text)
        except (EOFError, ConnectionError) as err:
            print("@@@@@")
            # report the actual error instance rather than the exception class
            print(repr(err))
            self.finished = True
            return {'episode_done': True}

        reply_text = reply_text.replace('\\n', '\n')
        reply['episode_done'] = False
        if self.opt.get('single_turn', False):
            reply.force_set('episode_done', True)
        reply['label_candidates'] = self.fixedCands_txt
        if '[DONE]' in reply_text:
            # let interactive know we're resetting
            raise StopIteration
        reply['text'] = reply_text
        if '[EXIT]' in reply_text:
            self.finished = True
            raise StopIteration
        print(reply)
        return reply
예제 #2
0
    def act(self):
        """
        Prompt the user on stdin and package the typed line as a reply.

        Literal ``\\n`` sequences in the input are turned into real newlines.

        :return:
            the reply message, or ``{'episode_done': True}`` when stdin hits
            end-of-file (``self.finished`` is set in that case).

        :raises StopIteration:
            when the input contains ``[DONE]`` or ``[EXIT]``; ``[EXIT]`` also
            marks the agent as finished.
        """
        message = Message()
        message['id'] = self.getID()
        try:
            typed = input(colorize("Enter Your Message:", 'text') + ' ')
        except EOFError:
            # stdin was closed: flag the agent as finished and end the episode
            self.finished = True
            return {'episode_done': True}

        typed = typed.replace('\\n', '\n')
        message['episode_done'] = False
        single_turn = self.opt.get('single_turn', False)
        if single_turn:
            message.force_set('episode_done', True)
        message['label_candidates'] = self.fixedCands_txt

        if '[DONE]' in typed:
            # signal the interactive loop that the episode is being reset
            raise StopIteration

        message['text'] = typed
        wants_exit = '[EXIT]' in typed
        if wants_exit:
            self.finished = True
            raise StopIteration
        return message
    def act(self):
        """
        Return the next scripted line from ``self.query_txt`` as a reply.

        ``self.line_no`` is the cursor into ``self.query_txt``; it is advanced
        on every call and reset to 0 whenever the script signals a reset, an
        exit, or runs out of lines.

        :return:
            the reply message, or ``{'episode_done': True}`` when there is no
            line to read at the current cursor (``self.finished`` is set).

        :raises StopIteration:
            when the line contains ``[DONE]`` or ``[EXIT]``, or when the
            cursor has reached the end of ``self.query_txt``.

        NOTE(review): the exhaustion check runs after the cursor has been
        advanced, so the final entry of ``query_txt`` raises StopIteration
        instead of being returned. Left unchanged; confirm this is intended.
        """
        reply = Message()
        reply['id'] = self.getID()
        try:
            reply_text = self.query_txt[self.line_no]
            self.line_no += 1
        except (EOFError, IndexError):
            # Indexing past the end (e.g. an empty script) raises IndexError,
            # not EOFError; treat both as "no more input".
            self.finished = True
            return {'episode_done': True}

        reply_text = reply_text.replace('\\n', '\n')
        reply['episode_done'] = False
        if self.opt.get('single_turn', False):
            reply.force_set('episode_done', True)
        reply['label_candidates'] = self.fixedCands_txt
        if '[DONE]' in reply_text:
            # let interactive know we're resetting
            self.line_no = 0
            raise StopIteration
        reply['text'] = reply_text
        if '[EXIT]' in reply_text:
            self.line_no = 0
            self.finished = True
            raise StopIteration
        if self.line_no >= len(self.query_txt):
            # script exhausted: rewind and stop
            self.line_no = 0
            self.finished = True
            raise StopIteration

        return reply
예제 #4
0
    def act(self):
        """
        Produce the next example as a dialog message.

        Resets the teacher first if ``epochDone`` has never been set. Outside
        of plain training (a datatype that does not start with ``train``, or
        one that contains ``evalmode``) any ``labels`` field is removed from a
        copy of the message and, unless the ``hide_labels`` option is set,
        re-exposed as ``eval_labels``.
        """
        if not hasattr(self, 'epochDone'):
            # first call: nothing has been initialised yet
            self.reset()

        # next_example yields an episode_done dict once examples run out
        example, self.epochDone = self.next_example()
        # TODO: all teachers should eventually create messages
        # while setting up the data, so this won't be necessary
        msg = Message(example)
        msg.force_set('id', self.getID())

        # remember the correct answer if one is available
        fallback_answer = msg.get('eval_labels_1', None)
        self.lastY = msg.get('labels_1', fallback_answer)

        plain_training = (
            self.datatype.startswith('train')
            and 'evalmode' not in self.datatype
        )
        if plain_training or 'labels' not in msg:
            return msg

        # Move labels to the eval field so they are not used for training,
        # while still letting the model use them for perplexity or loss.
        msg = msg.copy()
        gold = msg.pop('labels')
        if not self.opt.get('hide_labels', False):
            msg['eval_labels'] = gold
        return msg
예제 #5
0
    def set_input_turn_cnt_vec(self, observation: Message, model: RagModel,
                               query_str: str) -> Message:
        """
        Split the query into turns and store per-turn vectors on the observation.

        The query string is split on the retriever delimiter. When more turns
        are present than ``self.n_turns`` (and ``n_turns`` is greater than 1),
        the oldest turns are re-joined into a single leading entry so that
        exactly ``n_turns`` entries remain.

        :param observation:
            observation in which to set the vectors
        :param model:
            model providing the retriever delimiter and query tokenizer
        :param query_str:
            the query string from which the input turns are computed

        :return observation:
            the same observation, with ``query_vec`` overridden by the list
            of per-turn vectors and ``input_turn_cnt_vec`` set to the number
            of turns.
        """
        sep = model.get_retriever_delimiter()
        turns = query_str.split(sep)
        if self.n_turns > 1 and len(turns) > self.n_turns:
            tail_len = self.n_turns - 1
            merged_head = sep.join(turns[:-tail_len])
            turns = [merged_head, *turns[-tail_len:]]

        turn_count = torch.LongTensor([len(turns)])
        per_turn_vecs = []
        for turn in turns:
            per_turn_vecs.append(model.tokenize_query(turn))

        # Override query vec
        observation.force_set('query_vec', per_turn_vecs)
        observation['input_turn_cnt_vec'] = turn_count
        return observation
예제 #6
0
파일: agents.py 프로젝트: xujiameng/ParlAI
 def _edit_action(self, act: Message) -> Message:
     """
     Keep only the passage portion of the action's text.

     SQuAD returns both the passage and the question; only the passage (the
     text before the first newline) is required for this task.
     """
     first_line, _, _ = act['text'].partition('\n')
     act.force_set('text', first_line)
     return act
예제 #7
0
 def message_mutation(self, message: Message) -> Message:
     """
     Normalise the message into a "last line + prompt" text with one label.

     If the text does not already end with ``self.PROMPT``, it is replaced by
     its final newline-separated line followed by the prompt. The labels are
     then forced to the single value returned by ``self.get_label``.
     """
     assert self.get_label(message)
     current_text = message['text']
     if not current_text.endswith(self.PROMPT):
         final_line = current_text.rsplit('\n', 1)[-1]
         message.force_set('text', '{} {}'.format(final_line, self.PROMPT))
     labels_match = message['labels'] == [self.get_label(message)]
     if not labels_match:
         message.force_set('labels', [self.get_label(message)])
     return message
예제 #8
0
 def act(self):
     """
     Read one line from stdin and return it as a reply message.

     Literal ``\\n`` sequences become newlines. A ``[DONE]`` marker (appended
     automatically when the ``single_turn`` option is set) ends the episode:
     the marker is stripped from the text, ``episode_done`` is forced to True
     and ``self.episodeDone`` is set.
     """
     message = Message()
     message['id'] = self.getID()
     typed = input("Enter Your Message: ").replace('\\n', '\n')
     if self.opt.get('single_turn', False):
         typed = typed + '[DONE]'
     message['episode_done'] = False
     message['label_candidates'] = self.fixedCands_txt
     episode_over = '[DONE]' in typed
     if episode_over:
         typed = typed.replace('[DONE]', '')
         message.force_set('episode_done', True)
         self.episodeDone = True
     message['text'] = typed
     return message
예제 #9
0
    def _opening_message_text(self, parlai_message: Message, action: Dict):
        """
        Prepend the persona to the opening message, in place.

        Does nothing unless ``self.include_persona`` is set. Otherwise the
        persona from ``action`` is joined to the current message text with
        ``self.text_flatten_delimeter``; if the message has no text, the
        persona alone becomes the text.
        """
        if self.include_persona:
            persona = action[CONST.PERSONA]
            curr_text = parlai_message[CONST.MESSAGE_TEXT]
            combined = (
                f'{persona}{self.text_flatten_delimeter}{curr_text}'
                if curr_text
                else persona
            )
            parlai_message.force_set(CONST.MESSAGE_TEXT, combined)
예제 #10
0
    def observe(self, observation):
        """Save the observation for a later call to act.

        An observation that carries ``labels`` is treated as training data:
        its text and first label are parsed, terminated with ``END_IDX`` and
        appended to the running ``self.next_observe`` buffer. Once the buffer
        holds at least ``seq_len + 1`` items it is cut into chunks of that
        length, which are returned under ``text2vec``; any remainder stays in
        the buffer. Observations without ``labels`` are stored and returned
        as they are, apart from optional person-token prefixes.
        """
        # shallow copy observation (deep copy can be expensive)
        obs = Message(observation.copy())  # TODO: all teachers should return
        # messages, so this should be eventually unecessary
        chunk_len = self.opt['seq_len'] + 1

        if 'labels' not in obs:
            # evaluation path: only add the speaker prefixes
            if 'text' in obs and self.use_person_tokens:
                obs.force_set('text', 'PERSON1 ' + obs['text'])
            if 'eval_labels' in obs and self.use_person_tokens:
                tagged = tuple(
                    'PERSON2 ' + lbl for lbl in obs['eval_labels'] if lbl != ''
                )
                obs.force_set('eval_labels', tagged)
            self.observation = obs
            return obs

        # training path: append text and first label to the token buffer
        if 'text' in obs:
            if self.use_person_tokens:
                obs.force_set('text', 'PERSON1 ' + obs['text'])
            text_vec = self.parse(obs['text'])
            text_vec.append(self.END_IDX)
            self.next_observe += text_vec

        if self.use_person_tokens:
            tagged = tuple('PERSON2 ' + lbl for lbl in obs['labels'] if lbl != '')
            obs.force_set('labels', tagged)
        label_vec = self.parse(obs['labels'][0])
        label_vec.append(self.END_IDX)
        self.next_observe += label_vec

        if len(self.next_observe) < chunk_len:
            # not enough to return to make a batch
            # we handle this case in vectorize
            # labels indicates that we are training
            self.observation = {'labels': ''}
            return self.observation

        buffered = self.next_observe
        n_chunks = len(buffered) // chunk_len
        chunks = [
            buffered[k * chunk_len: (k + 1) * chunk_len] for k in range(n_chunks)
        ]
        # keep whatever did not fill a complete chunk for the next call
        self.next_observe = buffered[n_chunks * chunk_len:]
        batch = {'text': '', 'labels': '', 'text2vec': chunks}
        self.observation = Message(batch)
        return batch
예제 #11
0
 def _set_text_vec(self, obs: Message, history: History,
                   truncate: Optional[int]) -> Message:
     """
     Vectorize the text, then wrap the vector in start and end tokens.

     Delegates to the parent implementation first. If the result lacks
     ``text`` or ``text_vec`` it is returned untouched. Otherwise the vector
     is truncated to ``truncate - 2`` (when ``truncate`` is given) before the
     start and end tokens are added.
     """
     obs = super()._set_text_vec(obs, history, truncate)
     if not ('text' in obs and 'text_vec' in obs):
         return obs

     if truncate is None:
         vec = obs['text_vec']
     else:
         # leave two positions free for the start and end tokens added below
         shortened = self._check_truncate(obs['text_vec'], truncate - 2, True)
         vec = torch.LongTensor(shortened)  # type: ignore

     wrapped = self._add_start_end_tokens(vec, add_start=True, add_end=True)
     obs.force_set('text_vec', wrapped)
     return obs
예제 #12
0
    def act(self, msg):
        """
        Wrap the supplied text ``msg`` in a reply message.

        Literal ``\\n`` sequences in ``msg`` are turned into real newlines.

        :raises StopIteration:
            when the text contains ``[DONE]`` or ``[EXIT]``; ``[EXIT]`` also
            marks the agent as finished.
        """
        out = Message()
        out['id'] = self.getID()

        text = msg.replace('\\n', '\n')
        out['episode_done'] = False
        if self.opt.get('single_turn', False):
            out.force_set('episode_done', True)
        out['label_candidates'] = self.fixedCands_txt

        if '[DONE]' in text:
            # let interactive know we're resetting
            raise StopIteration

        out['text'] = text
        if '[EXIT]' not in text:
            return out
        self.finished = True
        raise StopIteration
예제 #13
0
파일: retnref.py 프로젝트: convobox/ParlAI
 def observe(self, observation: Message) -> Message:
     """
     Optionally append knowledge to the text, then run the general observe.

     When both the ``use_knowledge`` and ``add_knowledge_to_history`` options
     are set, a non-empty string taken from ``checked_sentence`` (if the
     ``chosen_sentence`` option is on, which is the default) or from
     ``knowledge`` is appended to the observation text, separated by the
     history delimiter.
     """
     wants_knowledge = self.opt.get('use_knowledge', False)
     into_history = self.opt.get('add_knowledge_to_history', False)
     if not (wants_knowledge and into_history):
         return super().observe(observation)

     if self.opt.get('chosen_sentence', True):
         source_field = 'checked_sentence'
     else:
         source_field = 'knowledge'
     extra = observation.get(source_field, None)
     if isinstance(extra, str) and extra != '':
         extended = observation['text'] + self.history.delimiter + extra
         observation.force_set('text', extended)
     return super().observe(observation)
예제 #14
0
    def act(self, timeout=None):
        """
        Ask the wrapped model agent for its next act and tidy up the reply.

        The model call is made while holding ``self.semaphore`` when one is
        set. The reply text is passed through ``normalize_reply`` unless the
        ``dict_lower`` option is present and falsy. ``self.turn_idx`` is
        incremented, and the returned dict always has ``episode_done`` False.

        :param timeout:
            accepted for interface compatibility and ignored.
        """
        _ = timeout  # The model doesn't care about the timeout
        if not self.semaphore:
            raw_act = self.model_agent.act()
        else:
            with self.semaphore:
                raw_act = self.model_agent.act()
        # Wrap as a Message for compatibility with older ParlAI models
        act_out = Message(raw_act)

        reply_text = act_out['text']
        cased_model = 'dict_lower' in self.opt and not self.opt['dict_lower']
        if not cased_model:
            # a cased model keeps its text untouched; others get normalized
            reply_text = normalize_reply(reply_text)
        act_out.force_set('text', reply_text)

        assert ('episode_done' not in act_out) or (not act_out['episode_done'])
        self.turn_idx += 1
        return {**act_out, 'episode_done': False}
예제 #15
0
    def act(self):
        """
        Read one line from stdin and return it as a reply message.

        :return:
            the reply, or ``{'episode_done': True}`` on end-of-file (the
            agent is then marked as finished).

        :raises StopIteration:
            when the input contains ``[DONE]`` or ``[EXIT]``; ``[EXIT]`` also
            marks the agent as finished.
        """
        out = Message()
        out['id'] = self.getID()
        try:
            line = input("Enter Your Message(set by me):")
        except EOFError:
            self.finished = True
            return {'episode_done': True}

        line = line.replace('\\n', '\n')
        out['episode_done'] = False
        if self.opt.get('single_turn', False):
            out.force_set('episode_done', True)
        out['label_candidates'] = self.fixedCands_txt

        if '[DONE]' in line:
            # let interactive know we're resetting
            raise StopIteration

        out['text'] = line
        if '[EXIT]' not in line:
            return out
        self.finished = True
        raise StopIteration
 def act(self):
     """
     Read one line from stdin, screen it, and return it as a reply message.

     The text is checked with ``self.offensive``; the outcome is stored in
     ``self.self_offensive`` and a notice is printed when the check fires.
     A ``[DONE]`` marker (appended automatically in ``single_turn`` mode) is
     stripped and ends the episode; ``[EXIT]`` marks the agent as finished.
     """
     message = Message()
     message['id'] = self.getID()
     typed = input(colorize("Enter Your Message:", 'field') + ' ')
     typed = typed.replace('\\n', '\n')

     flagged = self.offensive(typed)
     if flagged:
         print("[ Sorry, could not process that message. ]")
     self.self_offensive = True if flagged else False

     if self.opt.get('single_turn', False):
         typed = typed + '[DONE]'
     message['episode_done'] = False
     message['label_candidates'] = self.fixedCands_txt
     if '[DONE]' in typed:
         typed = typed.replace('[DONE]', '')
         message.force_set('episode_done', True)
         self.episodeDone = True
     message['text'] = typed
     if '[EXIT]' in typed:
         self.finished = True
     return message
예제 #17
0
def observe_samp_expanded_observation(observation, multi_turn=False):
    """
    Unpack the sample annotations embedded in an observation's text.

    The text is split on ``__SAMP__``. The first segment becomes the real
    text. The optional second segment holds ``__EOD__``-separated samples,
    each a ``context __EOC__ response`` pair; with ``multi_turn`` every
    context is further split on ``__EOT__``. The optional third and fourth
    segments are whitespace-separated float scores.

    :return:
        the observation as a Message. If it has a ``text`` field, that text
        is replaced and ``samp_cs``, ``samp_rs``, ``c_vs_samp_r_scores`` and
        ``samp_c_vs_r_scores`` are set (None where the segment is absent).
    """
    # TODO: Migration plan: TorchAgent currently supports being passed
    # observations as vanilla dicts for legacy interop; eventually we
    # want to remove this behavior and demand that teachers return Messages
    observation = Message(observation)
    if 'text' not in observation:
        return observation

    segments = observation['text'].split('__SAMP__')
    n_segments = len(segments)
    real_text = segments[0].strip()

    samp_cs = samp_rs = None
    c_vs_samp_r_scores = samp_c_vs_r_scores = None

    if n_segments > 1:
        samp_cs, samp_rs = [], []
        for raw_sample in segments[1].split('__EOD__'):
            pieces = raw_sample.strip().split('__EOC__')
            samp_cs.append(pieces[0].strip())
            samp_rs.append(pieces[1].strip())
        if multi_turn:
            samp_cs = [
                [turn.strip() for turn in context.split('__EOT__')]
                for context in samp_cs
            ]
    if n_segments > 2:
        c_vs_samp_r_scores = list(map(float, segments[2].split()))
    if n_segments > 3:
        samp_c_vs_r_scores = list(map(float, segments[3].split()))

    observation.force_set('text', real_text)
    observation['samp_cs'] = samp_cs
    observation['samp_rs'] = samp_rs
    observation['c_vs_samp_r_scores'] = c_vs_samp_r_scores
    observation['samp_c_vs_r_scores'] = samp_c_vs_r_scores
    return observation
예제 #18
0
파일: agents.py 프로젝트: simplecoka/cortx
    def get(self, episode_idx: int, entry_idx: int = 0) -> Message:
        """
        Return a flattened example.

        When the ``fixed_control`` option is a non-empty value, everything
        after the last space in the text is replaced by that control value.

        :param episode_idx:
            index of the episode in the data
        :param entry_idx:
            index of the example in the episode (not used by this method)

        :return ex:
            the example as a Message
        """
        ex = Message(self.data[episode_idx])

        fixed_control = self.opt['fixed_control']
        if fixed_control == '':
            return ex

        # keep everything up to (not including) the last space
        kept_text, _, _ = ex['text'].rpartition(' ')
        ex.force_set('text', '{} {}'.format(kept_text, fixed_control))
        return ex
예제 #19
0
    def act(self):
        """
        Read one line from stdin and return it as a reply message.

        A ``[DONE]`` marker (appended automatically in ``single_turn`` mode)
        is stripped from the text and ends the episode.

        :return:
            the reply, or ``{'episode_done': True}`` on end-of-file or when
            the input contains ``[EXIT]``; in both cases ``self.finished`` is
            set.
        """
        message = Message()
        message['id'] = self.getID()
        try:
            typed = input(colorize("Enter Your Message:", 'text') + ' ')
        except EOFError:
            self.finished = True
            return {'episode_done': True}

        typed = typed.replace('\\n', '\n')
        if self.opt.get('single_turn', False):
            typed = typed + '[DONE]'
        message['episode_done'] = False
        message['label_candidates'] = self.fixedCands_txt

        if '[DONE]' in typed:
            typed = typed.replace('[DONE]', '')
            message.force_set('episode_done', True)
            self.episodeDone = True
        message['text'] = typed

        if '[EXIT]' not in typed:
            return message
        self.finished = True
        return {'episode_done': True}
예제 #20
0
파일: agents.py 프로젝트: J-Douglas/PodBot
 def _edit_action(self, act: Message) -> Message:
     """
     Turn the single label into the text and blank out the labels.

     An action without ``labels`` must have no ``text`` and must already be
     marked ``episode_done``. In every case ``episode_done`` is forced True.

     :raises ValueError:
         if ``labels`` is present but does not hold exactly one entry.
     """
     if 'labels' not in act:
         assert 'text' not in act and act['episode_done'] is True
     else:
         labels = act['labels']
         if len(labels) != 1:
             raise ValueError(
                 f'{type(self).__name__} can only be used with one label!')
         sole_label = labels[0]
         act.force_set('text', sole_label)
         act.force_set('labels', [''])
     # Clear the dialogue history
     act.force_set('episode_done', True)
     return act
예제 #21
0
파일: agents.py 프로젝트: sagar-spkt/ParlAI
 def _edit_action(self, act: Message) -> Message:
     """
     Rewrite the action as "last context line + label" with the personality as label.

     The new text is the final newline-separated line of the old text, a
     newline, and the single label; the labels become ``[act['personality']]``.
     An action without ``labels`` must have no ``text`` and must already be
     marked ``episode_done``. In every case ``episode_done`` is forced True.

     :raises ValueError:
         if ``labels`` is present but does not hold exactly one entry.
     """
     if 'labels' not in act:
         assert 'text' not in act and act['episode_done'] is True
     else:
         labels = act['labels']
         if len(labels) != 1:
             raise ValueError(
                 f'{type(self).__name__} can only be used with one label!')
         assert '\n' not in labels[0]
         # Classifier will not expect more than 1 newline in context
         last_context_line = act['text'].rsplit('\n', 1)[-1]
         act.force_set('text', last_context_line + '\n' + labels[0])
         act.force_set('labels', [act['personality']])
     # Clear the dialogue history
     act.force_set('episode_done', True)
     return act
예제 #22
0
 def message_mutation(self, message: Message) -> Message:
     """
     Overwrite every retrieved/selected document field with ``['']``.
     """
     doc_fields = (
         CONST.RETRIEVED_DOCS,
         CONST.RETRIEVED_SENTENCES,
         CONST.RETRIEVED_DOCS_TITLES,
         CONST.RETRIEVED_DOCS_URLS,
         CONST.SELECTED_DOCS,
         CONST.SELECTED_DOCS_TITLES,
         CONST.SELECTED_DOCS_URLS,
     )
     for field in doc_fields:
         # a fresh list per field, so the fields never share one object
         message.force_set(field, [''])
     return message
예제 #23
0
파일: worlds.py 프로젝트: lpschaub/ParlAI
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return
                else:
                    self.dialog.append((1, act.get('text')))
                    act = copy.deepcopy(act)
<<<<<<< HEAD
                    act.force_set('text', self.format_model_reply(act['text']))
                    self.eval_agent.observe(act)

        # Eval agent turn
        act = Message(self.get_human_agent_act(self.eval_agent))
=======
<<<<<<< HEAD
<<<<<<< HEAD
                    act.force_set('text', self.format_model_reply(act['text']))
                    self.eval_agent.observe(act)

        # Eval agent turn
        act = Message(self.get_human_agent_act(self.eval_agent))
=======
=======
>>>>>>> ef574cebef2a8d5aa38b73176b1e71a919d6670f
                    act['text'] = self.format_model_reply(act['text'])
                    self.eval_agent.observe(act)

        # Eval agent turn
        act = self.get_human_agent_act(self.eval_agent)
>>>>>>> 4f6b99642d60aff1a41b9eae8bd2ccd9e40ebba4
<<<<<<< HEAD
>>>>>>> origin/master
예제 #24
0
    def setup_data(self, path):
        """
        Read the comma-separated dialogue file at ``path`` and yield examples.

        Each line is compared with the one before it. While the first field
        stays the same, the previous line's utterance becomes the ``text``
        and the current line's utterance the label. Examples are collected
        into two lists according to the parity of the current turn number,
        and ``self._select_dialogues_to_add`` decides which of them are kept
        once the first field changes (and once more after the last line).

        :raises ValueError:
            if a line does not have 8 or 9 fields.

        Yields ``(entry, new_episode)`` pairs, where ``new_episode`` is True
        for the first entry of each episode.
        """
        logging.debug('loading: ' + path)
        with PathManager.open(path) as f:
            lines = f.readlines()

        expected_turn = 1
        responder_text_dialogue = []
        experiencer_text_dialogue = []
        episodes = []
        for i, (prev_line, cur_line) in enumerate(zip(lines, lines[1:]), start=1):
            cparts = prev_line.strip().split(",")
            sparts = cur_line.strip().split(",")

            if cparts[0] != sparts[0]:
                # We've finished the previous episode, so add it to the data
                expected_turn = 1
                episodes += self._select_dialogues_to_add(
                    experiencer_text_dialogue, responder_text_dialogue)
                experiencer_text_dialogue = []
                responder_text_dialogue = []
                continue

            # Check that the turn number has incremented correctly
            expected_turn += 1
            assert (int(cparts[1]) + 1 == int(sparts[1])
                    and int(sparts[1]) == expected_turn)

            context = cparts[5].replace("_comma_", ",")
            label = sparts[5].replace("_comma_", ",").strip()
            emotion = sparts[2]
            situation = sparts[3].replace("_comma_", ",")

            n_fields = len(sparts)
            if n_fields not in (8, 9):
                raise ValueError(
                    f'Line {i:d} has the wrong number of fields!')
            candidates = None
            if n_fields == 9 and sparts[8] != '':
                # the optional ninth field holds pipe-separated candidates
                candidates = [
                    raw.replace("_comma_", ",").replace("_pipe_", "|").strip()
                    for raw in sparts[8].split('|')
                ]

            example = Message({
                'text': context,
                'labels': [label],
                'emotion': emotion,
                'situation': situation,
            })
            if candidates is not None:
                example.force_set('label_candidates', candidates)

            if int(sparts[1]) % 2 == 0:
                # experiencer is the "text" and responder is the "label"
                experiencer_text_dialogue.append(example)
            else:
                # responder is the "text" and experiencer is the "label"
                responder_text_dialogue.append(example)

        # Add in the final episode
        episodes += self._select_dialogues_to_add(experiencer_text_dialogue,
                                                  responder_text_dialogue)

        for episode in episodes:
            for position, entry in enumerate(episode):
                yield entry, position == 0
예제 #25
0
파일: agents.py 프로젝트: sagar-spkt/ParlAI
def remove_selected_docs_from_message(message: Message):
    """
    Replace the selected-document fields of ``message`` with placeholders.

    Each field is overwritten in place with a one-element list holding the
    matching "nothing selected" constant.
    """
    placeholders = (
        (CONST.SELECTED_DOCS, CONST.NO_SELECTED_DOCS_TOKEN),
        (CONST.SELECTED_SENTENCES, CONST.NO_SELECTED_SENTENCES_TOKEN),
        (CONST.SELECTED_DOCS_URLS, CONST.NO_URLS),
        (CONST.SELECTED_DOCS_TITLES, CONST.NO_TITLE),
    )
    for field, empty_value in placeholders:
        message.force_set(field, [empty_value])
예제 #26
0
파일: agents.py 프로젝트: sagar-spkt/ParlAI
def remove_retrieved_docs_from_message(message: Message):
    """
    Replace the retrieved-document fields of ``message`` with placeholders.

    Each field is overwritten in place with a one-element list holding the
    matching "nothing retrieved" constant.
    """
    placeholders = (
        (CONST.RETRIEVED_DOCS, CONST.NO_RETRIEVED_DOCS_TOKEN),
        (CONST.RETRIEVED_DOCS_URLS, CONST.NO_URLS),
        (CONST.RETRIEVED_DOCS_TITLES, CONST.NO_TITLE),
    )
    for field, empty_value in placeholders:
        message.force_set(field, [empty_value])
예제 #27
0
    def parley(self):
        """
        Run one turn of the conversation between the eval agent and its partner.

        The partner is ``self.model_agent`` when one is set, otherwise the
        human ``self.other_agent``. In order, one call:

        1. increments ``self.turn_idx`` and prints a progress line;
        2. on turn 1, sends every agent its persona and the start instruction;
        3. on turn ``n_turn + 1``, tells every agent the minimum number of
           turns has been exceeded;
        4. on turn 1 with ``other_first`` set, lets the partner speak first;
        5. collects the eval agent's act; on an ``episode_done`` act it
           finishes the chat once at least ``n_turn`` turns have passed;
        6. forwards the eval agent's text to the partner and, unless
           ``other_first`` is set and ``n_turn`` has been reached, relays the
           partner's answer back to the eval agent.

        Returns early (with None) whenever ``check_timeout`` reports a
        timeout or the eval agent ends the episode.
        """
        self.turn_idx += 1
        print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))
        """If at first turn, we need to give each agent their persona"""
        if self.turn_idx == 1:
            for idx, agent in enumerate(self.agents):
                persona_text = ''
                for s in self.personas[idx]:
                    persona_text += ('<b><span style="color:blue">'
                                     '{}\n</span></b>'.format(s.strip()))
                control_msg = self.get_control_msg()
                control_msg['persona_text'] = persona_text
                # NOTE(review): get_instruction is called here with
                # agent_id=agent.id but with a positional idx further down —
                # confirm both call shapes match its signature.
                control_msg['text'] = self.get_instruction(tag='start',
                                                           agent_id=agent.id)
                # TODO: check that get instruction actually exists?
                agent.observe(validate(control_msg))
                if idx == 0:
                    # pause after the first agent has been given its persona
                    time.sleep(3)
        """If we get to the min turns, inform turker that they can end if they
        want.
        """
        if self.turn_idx == self.n_turn + 1:
            for idx, agent in enumerate(self.agents):
                control_msg = self.get_control_msg()
                control_msg['text'] = self.get_instruction(
                    idx, tag='exceed_min_turns')
                control_msg['exceed_min_turns'] = True
                agent.observe(validate(control_msg))
        """Otherwise, we proceed accordingly."""
        # Other agent first
        if self.other_first and self.turn_idx == 1:
            if self.model_agent is not None:
                # Model must observe its persona
                persona_act = {
                    'text': '\n'.join([self.model_persona_text,
                                       '__SILENCE__']),
                    'episode_done': False,
                }
                self.model_agent.observe(persona_act)
                self.bot_seen_persona = True
                # deep copy so the edits below do not touch the model's own act
                model_act = copy.deepcopy(self.model_agent.act())
                model_act.force_set('text', normalize_reply(model_act['text']))
                model_act.force_set('id', 'PERSON_2')
                # dialog entries are (speaker, text); 1 marks the partner
                self.dialog.append((1, model_act.get('text')))
                _random_delay()
                self.eval_agent.observe(_strip_tensors(model_act))
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return
                else:
                    self.dialog.append((1, act.get('text')))
                    act = copy.deepcopy(act)
                    act.force_set('text', normalize_reply(act['text']))
                    self.eval_agent.observe(act)

        # Eval agent turn
        act = Message(self.get_human_agent_act(self.eval_agent))
        timeout = self.check_timeout(act)
        if timeout:
            # only a human partner needs to be told about the disconnection
            if self.model_agent is None:
                control_msg = self.get_control_msg()
                control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                self.other_agent.observe(validate(control_msg))
            return

        if act['episode_done']:
            if self.turn_idx >= self.n_turn:
                # pair consecutive dialog entries into newline-joined strings
                # NOTE(review): both comprehensions index dialog[i + 1];
                # presumably the dialog length always lines up with the
                # step — confirm, otherwise this raises IndexError.
                if not self.other_first:
                    self.dialog_list = [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(0, len(self.dialog), 2)
                    ]
                else:
                    # the partner's opening utterance has nothing before it,
                    # so it is stored on its own after a ' \n' prefix
                    self.dialog_list = [' \n' + self.dialog[0][1]] + [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(1,
                                       len(self.dialog) - 1, 2)
                    ]
                self.parallel_eval_mode()

                self.chat_done = True
                for ag in self.agents:
                    control_msg = self.get_control_msg()
                    control_msg['text'] = CHAT_ENDED_MSG
                    ag.observe(validate(control_msg))
            # an episode_done act before n_turn is dropped without a reply
            return

        # 0 marks the eval agent as the speaker
        self.dialog.append((0, act['text']))

        if not self.bot_seen_persona and self.model_agent is not None:
            # Add persona for model to observe
            act.force_set('text',
                          '\n'.join([self.model_persona_text, act['text']]))
            self.bot_seen_persona = True
        if self.model_agent is not None:
            self.model_agent.observe(act)
        else:
            act = copy.deepcopy(act)
            act.force_set('text', normalize_reply(act['text']))
            self.other_agent.observe(act)

        # Model_agent turn
        if not self.other_first or self.turn_idx < self.n_turn:
            if self.model_agent is not None:
                _random_delay()
                act = _strip_tensors(copy.deepcopy(self.model_agent.act()))
                act.force_set('text', normalize_reply(act['text']))
                act.force_set('id', 'PERSON_2')
                # NOTE: your model may or may not need to observe itself here
                # If it does, call model_observes_itself or some other specialized
                # function
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return

            self.dialog.append((1, act.get('text')))
            act = copy.deepcopy(act)
            # NOTE(review): on the model path the text was already passed
            # through normalize_reply above, so it is normalized twice here.
            act.force_set('text', normalize_reply(act['text']))
            self.eval_agent.observe(act)
예제 #28
0
 def message_mutation(self, message: Message) -> Message:
     """
     Reduce the message text to its final newline-separated line.
     """
     final_line = message['text'].rsplit('\n', 1)[-1]
     message.force_set('text', final_line)
     return message
예제 #29
0
 def message_mutation(self, message: Message) -> Message:
     """
     Set the ``skip_retrieval`` field of the message to True.
     """
     field, enabled = 'skip_retrieval', True
     message.force_set(field, enabled)
     return message
예제 #30
0
 def message_mutation(self, message: Message) -> Message:
     """
     Append ``self.PROMPT`` to the message text unless it already ends with it.
     """
     if message['text'].endswith(self.PROMPT):
         return message
     prompted = '{} {}'.format(message['text'], self.PROMPT)
     message.force_set('text', prompted)
     return message