Code Example #1
 def normalize_replies(self, x):
     xs = x.split('\n')
     your_personas = []
     partner_personas = []
     non_personas = []
     for x in xs:
         if x.startswith('your persona: '):
             # Normalize the sentence appearing after 'your persona:'
             x = x[len('your persona: ') :]
             x = normalize_reply(x)
             your_personas.append(x)
         elif x.startswith("partner's persona: "):
             x = x[len("partner's persona: ") :]
             x = normalize_reply(x)
             partner_personas.append(x)
         else:
             x = normalize_reply(x)
             non_personas.append(x)
     xs2 = []
     if not self.is_convai2_session_level:
         your_personas = ['your persona: ' + yx for yx in your_personas]
         partner_personas = ["partner's persona: " + px for px in partner_personas]
     else:
         if your_personas:
             your_personas = ['your persona: ' + " ".join(your_personas)]
         if partner_personas:
             partner_personas = ["partner's persona: " + " ".join(partner_personas)]
     if self.your_persona_first:
         xs2.extend(your_personas)
         xs2.extend(partner_personas)
     else:
         xs2.extend(partner_personas)
         xs2.extend(your_personas)
     xs2.extend(non_personas)
     return '\n'.join(xs2)
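
A hypothetical driver for the method above (assumptions: the method is callable as a plain function with a stand-in for self, normalize_reply is importable e.g. from parlai.utils.strings, and the expected output follows the test behaviour shown in Code Example #2):

from types import SimpleNamespace

# Stand-in for `self`; both attribute names are taken from the code above.
cfg = SimpleNamespace(is_convai2_session_level=True, your_persona_first=True)
episode = "your persona: i like dogs .\nyour persona: i am a vet .\nhello !"
print(normalize_replies(cfg, episode))
# Expected (assumption): "your persona: I like dogs. I am a vet.\nHello!"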
Code Example #2
 def test_normalize_reply_version1(self):
     assert string_utils.normalize_reply("I ' ve a cat .") == "I've a cat."
     assert (
         string_utils.normalize_reply("do you think i can dance?")
         == "Do you think I can dance?"
     )
     assert string_utils.normalize_reply("I ' m silly '") == "I'm silly'"
Code Example #3
 def normalize_replies(self, x):
     xs = x.split('\n')
     your_personas = []
     partner_personas = []
     non_personas = []
     for x in xs:
         if x.startswith('your persona: '):
             # Normalize the sentence appearing after 'your persona:'
             x = x[len('your persona: '):]
             x = normalize_reply(x)
             x = 'your persona: ' + x
             your_personas.append(x)
         elif x.startswith("partner's persona: "):
             x = x[len("partner's persona: "):]
             x = normalize_reply(x)
             x = "partner's persona: " + x
             partner_personas.append(x)
         else:
             x = normalize_reply(x)
             non_personas.append(x)
     xs2 = []
     if self.your_persona_first:
         xs2.extend(your_personas)
         xs2.extend(partner_personas)
     else:
         xs2.extend(partner_personas)
         xs2.extend(your_personas)
     xs2.extend(non_personas)
     return '\n'.join(xs2)
Code Example #4
 def test_normalize_reply_version2(self):
     assert string_utils.normalize_reply("Add a period",
                                         2) == "Add a period."
     assert string_utils.normalize_reply("Add a period?",
                                         2) == "Add a period?"
     assert string_utils.normalize_reply("Add a period!",
                                         2) == "Add a period!"
     assert string_utils.normalize_reply('"Add a period"',
                                         2) == '"add a period"'
Code Example #5
    def normalize_replies(self, x):
        xs = x.split('\n')
        xs2 = []
        for x in xs:
            if x.startswith('your persona: '):
                # Normalize the sentence appearing after 'your persona:'
                x = x[len('your persona: '):]
                x = normalize_reply(x)
                x = 'your persona: ' + x
            else:
                x = normalize_reply(x)

            xs2.append(x)
        return '\n'.join(xs2)
Code Example #6
 def normalize_replies(self, x):
     xs = x.split('\n')
     xs2 = []
     for x in xs:
         if x.startswith('your persona: '):
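             # drop 'your persona: ' lines entirely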
             continue
         elif x.startswith("partner's persona: "):
             x = x[len("partner's persona: "):]
             x = normalize_reply(x)
             x = "partner's persona: " + x
         else:
             x = normalize_reply(x)
         xs2.append(x)
     return '\n'.join(xs2)
Code Example #7
File: fast_eval.py  Project: liesenf/ParlAI
    def _acutify_convo(
        self, conversation: Conversation, config_id: str
    ) -> Dict[str, List]:
        """
        Format world-logged conversation to be ACUTE format.

        :param conversation:
            dictionary containing the dialogue for a config_id
        :param config_id:
            config_id string

        :return conversation:
            An ACUTE-Readable conversation
        """
        config = CONFIG[config_id]
        is_selfchat = 'model' in config
        acute_conversation: Dict[str, List] = {
            'context': [],
            'dialogue': [],
            'speakers': [],
        }
        for i, ex in enumerate(conversation):
            if ex['id'] == 'context':
                acute_conversation['context'].append(ex)
                continue
            speaker_id = ex['id']
            if is_selfchat:
                speaker_id = config_id if i % 2 == 0 else f'other_{config_id}'
            if speaker_id not in acute_conversation['speakers']:
                acute_conversation['speakers'].append(speaker_id)
            acute_conversation['dialogue'].append(
                {'id': speaker_id, 'text': normalize_reply(ex['text'])}
            )
        return acute_conversation
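
For orientation, the returned structure looks like this (the field names come from the code; the example values and speaker names are hypothetical):

acute_conversation = {
    'context': [{'id': 'context', 'text': 'your persona: I have a dog.'}],
    'speakers': ['model_a', 'other_model_a'],
    'dialogue': [
        {'id': 'model_a', 'text': 'Hi! How are you today?'},
        {'id': 'other_model_a', 'text': "I'm great, thanks for asking."},
    ],
}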
Code Example #8
def check_negative_sentiment(sent_eval, text):
    norm_text = normalize_reply(text)
    sent_scores = sent_eval.polarity_scores(norm_text)
    if sent_scores["compound"] >= 0:
        return False

    return True
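
A usage sketch, assuming sent_eval is NLTK's VADER analyzer (the function only requires a polarity_scores method returning a 'compound' score) and that normalize_reply is importable:

from nltk.sentiment.vader import SentimentIntensityAnalyzer  # requires nltk.download('vader_lexicon')

sent_eval = SentimentIntensityAnalyzer()
# True only when the compound score is strictly negative.
print(check_negative_sentiment(sent_eval, "i ' m having a terrible day ."))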
Code Example #9
    def act(self, timeout=None):
        """
        Same as model chat's bot_agent.py, except that the self_observe function is
        removed; a custom observe is written in worlds.py instead.

        This is so that the two bots can read each other's messages via observe,
        keeping the conversation history consistent between them.
        """

        _ = timeout  # The model doesn't care about the timeout

        if self.semaphore:
            with self.semaphore:
                act_out = self.model_agent.batch_act([self.model_agent.observation])[0]
        else:
            act_out = self.model_agent.batch_act([self.model_agent.observation])[0]
        act_out = Message(act_out).json_safe_payload()

        if 'dict_lower' in self.opt and not self.opt['dict_lower']:
            # model is cased so we don't want to normalize the reply like below
            final_message_text = act_out['text']
        else:
            final_message_text = normalize_reply(act_out['text'])

        act_out['text'] = final_message_text
        assert ('episode_done' not in act_out) or (not act_out['episode_done'])
        self.turn_idx += 1
        return {**act_out, 'episode_done': False}
Code Example #10
    def _acutify_convo(
        self, dialogue_dict: Dict[str, Any], model: str
    ) -> Dict[str, List]:
        """
        Format world-logged conversation to be ACUTE format.

        :param dialogue_dict:
            dictionary containing the dialogue for a model
        :param model:
            model string

        :return conversation:
            An ACUTE-Readable conversation
        """
        conversation = {
            'context': [],
            'dialogue': [],
            'speakers': [model, f'other_{model}'],
        }
        dialog = dialogue_dict['dialog']
        for act_pair in dialog:
            for i, ex in enumerate(act_pair):
                if ex['id'] == 'context':
                    conversation['context'].append(ex)
                    continue
                conversation['dialogue'].append(
                    {
                        'id': model if i == 0 else f'other_{model}',
                        'text': normalize_reply(ex['text']),
                    }
                )
        return conversation
Code Example #11
def check_negation(spacy_nlp, text):
    norm_text = normalize_reply(text)
    doc = spacy_nlp(norm_text)
    for token in doc:
        if token.dep_ == "neg":
            return True

    return False
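
A usage sketch, assuming an English spaCy pipeline with a dependency parser (e.g. en_core_web_sm) and that normalize_reply is importable:

import spacy

spacy_nlp = spacy.load("en_core_web_sm")  # requires: python -m spacy download en_core_web_sm
print(check_negation(spacy_nlp, "i do not like it ."))  # True: "not" is typically parsed with dep_ == "neg"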
Code Example #12
 def normalize_replies(self, x):
     xs = [xt.strip() for xt in x.split('\n')]
     xs2 = []
     for x in xs:
         if x.startswith('your persona: '):
             # Normalize the sentence appearing after 'your persona:'
             x = x[len('your persona: '):]
             x = normalize_reply(x)
             x = 'your persona: ' + x
         elif "partner's persona: " in x:
             x = x[len("partner's persona: "):]
             x = normalize_reply(x)
             x = "partner's persona: " + x
         elif x != DUMMY_TEXT:
             x = normalize_reply(x)
         xs2.append(x)
     return "\n".join(xs2)
Code Example #13
File: run_no_self_chat.py  Project: tonirubass/ParlAI
    def _acutify_convo(self, dialogue_dict: Dict[str, Any],
                       model: str) -> Dict[str, List]:
        """
        Format world-logged conversation to be ACUTE format.

        :param dialogue_dict:
            dictionary containing the dialogue for a model
        :param model:
            model string

        :return conversation:
            An ACUTE-Readable conversation
        """
        conversation = {
            'context': [],
            'dialogue': [],
            'speakers': ['human_evaluator', model],
        }
        if ('is_selfchat' in self.model_config[model]
                and self.model_config[model]['is_selfchat']):
            if 'flip' in self.model_config[model]:
                conversation['speakers'] = ['other_speaker', model]
            else:
                conversation['speakers'] = [model, 'other_speaker']
        dialog = dialogue_dict['dialog']
        for act_pair in dialog:
            for i, ex in enumerate(act_pair):
                if ex['id'] == 'context':
                    conversation['context'].append(ex)
                    continue
                else:
                    # agent 1 is the model, agent 0 is human
                    convo = {
                        'id': ex['id'],
                        'text': normalize_reply(ex['text'])
                    }
                    if ('is_selfchat' in self.model_config[model]
                            and self.model_config[model]['is_selfchat']):
                        # if selfchat, override the agent id
                        if 'flip' in self.model_config[model]:
                            if i % 2 == 1:
                                convo['id'] = model
                            else:
                                convo['id'] = 'other_speaker'
                        else:
                            if i % 2 == 0:
                                convo['id'] = model
                            else:
                                convo['id'] = 'other_speaker'
                    conversation['dialogue'].append(convo)

        return conversation
Code Example #14
File: fast_eval.py  Project: khanhgithead/ParlAI
    def _acutify_convo(self, dialogue_dict: Dict[str, Any],
                       model: str) -> Dict[str, List]:
        """
        Format world-logged conversation to be ACUTE format.

        :param dialogue_dict:
            dictionary containing the dialogue for a model
        :param model:
            model string

        :return conversation:
            An ACUTE-Readable conversation
        """
        is_selfchat = ('model' in self.model_config[model]
                       or self.model_config[model].get('is_selfchat', False))
        # It's a self-chat if one of the following are true:
        # (1) a model is specified in the config, meaning that we're collecting
        #   self-chats with that model
        # (2) we manually set 'is_selfchat' to True in the config
        if is_selfchat:
            # Set which speaker we will evaluate the conversation turns of
            speaker_idx = self.model_config[model].get('speaker_idx', 0)
            assert speaker_idx in [0, 1]
        conversation = {'context': [], 'dialogue': [], 'speakers': []}
        dialog = dialogue_dict['dialog']
        for act_pair in dialog:
            for i, ex in enumerate(act_pair):
                if ex['id'] == 'context':
                    conversation['context'].append(ex)
                    continue
                if is_selfchat:
                    speaker_id = model if i == speaker_idx else 'other_speaker'
                else:
                    speaker_id = ex['id']
                if speaker_id not in conversation['speakers']:
                    conversation['speakers'].append(speaker_id)
                conversation['dialogue'].append({
                    'id': speaker_id,
                    'text': normalize_reply(ex['text'])
                })
        return conversation
Code Example #15
    def act(self, timeout=None):
        _ = timeout  # The model doesn't care about the timeout
        if self.semaphore:
            with self.semaphore:
                act_out = self.model_agent.act()
        else:
            act_out = self.model_agent.act()
        act_out = Message(act_out).json_safe_payload()

        if 'dict_lower' in self.opt and not self.opt['dict_lower']:
            # model is cased so we don't want to normalize the reply like below
            final_message_text = act_out['text']
        else:
            final_message_text = normalize_reply(act_out['text'])

        act_out['text'] = final_message_text
        assert ('episode_done' not in act_out) or (not act_out['episode_done'])
        self.turn_idx += 1
        return {**act_out, 'episode_done': False}
Code Example #16
    def act(self, timeout=None):
        _ = timeout  # The model doesn't care about the timeout
        if self.semaphore:
            with self.semaphore:
                act_out = self.model_agent.act()
        else:
            act_out = self.model_agent.act()
        # Wrap as a Message for compatibility with older ParlAI models
        act_out = Message(act_out)

        if 'dict_lower' in self.opt and not self.opt['dict_lower']:
            # model is cased so we don't want to normalize the reply like below
            final_message_text = act_out['text']
        else:
            final_message_text = normalize_reply(act_out['text'])

        act_out.force_set('text', final_message_text)
        assert ('episode_done' not in act_out) or (not act_out['episode_done'])
        self.turn_idx += 1
        return {**act_out, 'episode_done': False}
Code Example #17
File: bot_agent.py  Project: advi1012/ParlAITest
    def act(self, timeout=None):
        _ = timeout  # The model doesn't care about the timeout
        if self.semaphore:
            with self.semaphore:
                act_out = self.model_agent.act()
        else:
            act_out = self.model_agent.act()

        annotations_html = construct_annotations_html(
            annotations_intro=self.opt['annotations_intro'],
            annotations_config=self.opt['annotations_config'],
            turn_idx=self.turn_idx,
        )

        if 'dict_lower' in self.opt and not self.opt['dict_lower']:
            # model is cased so we don't want to normalize the reply like below
            final_message_text = act_out['text']
        else:
            normalized_act_text = normalize_reply(act_out['text'])
            final_message_text = normalized_act_text + annotations_html

        if self.turn_idx >= self.num_turns * 2:
            radio_css_style = 'margin-left:5px;margin-right:15px;'
            radio_buttons_html = ''
            for i in range(1, 6):
                radio_buttons_html += f"""<input type="radio" id="radio_rating_{i}" name="radio_final_rating_group" value="{i}" /><span style={radio_css_style}>{i}</span>"""
            final_scoring_question = self.opt['final_rating_question']
            exceeds_min_turns = f"""<br><br><div>{self.num_turns} chat turns finished! {final_scoring_question}</div>
            {radio_buttons_html}
            <br>Then, please click the "Done" button to end the chat."""
            final_message_text += exceeds_min_turns
            act_out = Compatibility.backward_compatible_force_set(
                act_out, 'exceed_min_turns', True)

        act_out = Compatibility.backward_compatible_force_set(
            act_out, 'text', final_message_text)
        assert ('episode_done' not in act_out) or (not act_out['episode_done'])
        self.turn_idx += 1
        return {**act_out, 'episode_done': False, 'checked_radio_name_id': ''}
Code Example #18
File: bot_agent.py  Project: rhamnett/ParlAI
    def act(self, timeout=None):
        _ = timeout  # The model doesn't care about the timeout
        if self.semaphore:
            with self.semaphore:
                act_out = self.model_agent.act()
        else:
            act_out = self.model_agent.act()

        annotations_html = TurkLikeAgent.construct_annotations_html(self.turn_idx)

        if 'dict_lower' in self.opt and not self.opt['dict_lower']:
            # model is cased so we don't want to normalize the reply like below
            final_message_text = act_out['text']
        else:
            normalized_act_text = normalize_reply(act_out['text'])
            final_message_text = normalized_act_text + annotations_html

        if self.turn_idx >= self.num_turns * 2:
            radio_css_style = 'margin-left:5px;margin-right:15px;'
            radio_buttons_html = ''
            for i in range(1, 6):
                radio_buttons_html += f"""<input type="radio" id="radio_rating_{i}" name="radio_final_rating_group" value="{i}" /><span style={radio_css_style}>{i}</span>"""
            exceeds_min_turns = f"""<br><br><div>{self.num_turns} chat turns finished! Please rate your partner on a scale of 1-5, how much would you enjoy talking to this partner over the course of a long conversation? (1: not at all, 5: a lot)</div>
            {radio_buttons_html}
            <br>Then, please click the "Done" button to end the chat."""
            final_message_text += exceeds_min_turns
            act_out = Compatibility.backward_compatible_force_set(
                act_out, 'exceed_min_turns', True
            )

        act_out = Compatibility.backward_compatible_force_set(
            act_out, 'text', final_message_text
        )
        assert ('episode_done' not in act_out) or (not act_out['episode_done'])
        self.turn_idx += 1
        return {**act_out, 'episode_done': False, 'checked_radio_name_id': ''}
Code Example #19
    def rerank(
        self,
        observation: Message,
        response_cands: List[str],
        response_cand_scores: torch.Tensor,
    ) -> Tuple[List[str], List[int]]:
        """
        Re-rank candidates according to predictor score.

        :param observation:
            Message object that includes the dialogue history
        :param response_cands:
            ranked list of model response candidates
        :param response_cand_scores:
            list of model response candidates' scores

        :return (candidates, indices):
            candidates: a re-ranked list of candidates
            indices: list of indices into response_cands corresponding to re-rank order
        """
        full_context = observation['full_text']

        # 0) Normalize the replies if the opt is passed in
        if self.normalize_candidates:
            response_cands = [normalize_reply(c) for c in response_cands]

        # 1) Augment context with response candidates
        if not self.include_label_cand_only:
            contexts = [
                self.augment_context(full_context,
                                     cand,
                                     include_context=self.include_context)
                for cand in response_cands
            ]
            contexts = [self.delimiter.join(utts) for utts in contexts]
        else:
            # This variant only passes the label candidates (with no dialogue history
            # whatsoever) into the ranker. Can be useful for, e.g., simple
            # utterance-based safety classifiers.
            contexts = response_cands

        # 2) Predict with augmented context
        label_candidates = self.get_predictor_label_candidates(
            observation, full_context)
        reranker_outputs = self.batch_predict(contexts, label_candidates)

        # 3) Rerank
        rerank_for_class = self.get_class_to_rerank_for(
            observation, full_context)
        reranked_candidates, indices = self._rerank_candidates(
            reranker_outputs,
            response_cands,
            response_cand_scores,
            rerank_for_class=rerank_for_class,
        )

        if self.show_debug_logging(observation):
            debug_str = self._construct_debug_logging(
                observation,
                response_cands,
                response_cand_scores,
                reranker_outputs,
                reranked_candidates,
                rerank_for_class,
            )
            # Need print because even logging.WARN swallowed during eval
            print(debug_str)
        return reranked_candidates, indices
Code Example #20
def user_sent_message_generate(message):
    """
    Called when the participant sends a message to the generative model.
    """
    if len(message["data"].strip()) > 0:
        session_info[request.sid]['num_written'] += 1
    agent_choice = message["agent"]
    
    pid = session_info[request.sid]['pid']
    condition = session_info[request.sid]['condition']
    exchange_num = session_info[request.sid]['num_written']
    last_flow = session_info[request.sid]['last_flow']
    allow_continue = CONDITION_MESSAGES_NO_ACK['allow_continue'][last_flow] == 'continue'
    
    
    # Extract a string of the user's message
    raw_user_input_text = message["data"]
    raw_user_input_len = len(raw_user_input_text.split(' '))
    if message['agent'] == 'therapybot_other':
        flow_has_more = last_flow + 1 < len(CONDITION_MESSAGES_NO_ACK['flow_fixed'])
    else:
        flow_has_more = last_flow + 1 < len(CONDITION_MESSAGES_NO_ACK['flow'])
    input_not_empty = len(raw_user_input_text.strip()) > 0
    
    emit('render_usr_message', message, room=request.sid)
    
        
    if flow_has_more and input_not_empty:
        
        input_text = raw_user_input_text
        
        if message['agent'] == 'therapybot_other':

            # use the generated acknowledgement and add some flow text
            flow_num = last_flow + 1
            flow_text = CONDITION_MESSAGES_NO_ACK['flow_fixed'][flow_num]
            strategy = CONDITION_MESSAGES_NO_ACK['strategy_fixed'][flow_num]
        
            bot_response = flow_text
        
            # updating tracking of flow text usage
            session_info[request.sid]['last_flow'] = flow_num 
            session_info[request.sid]['continue_cnt'] = 0 # reset because a new flow message was used
            
            
        else:
            torch.cuda.set_device(agents[message['agent']].opt['gpu'])
        
            # make sure model history is reset
            agents[agent_choice].reset()
        
            # set model history as user's conversation history
            agents[agent_choice].history = copy.deepcopy(session_info[request.sid]['parlai_history'])
        
            # observe user input        
            agents[agent_choice].observe(package_text(input_text))
        
            # generate bot response
            agent_output = agents[agent_choice].act()
        
            ack = normalize_reply(agent_output['text'])
        
            # decide if let the user chat only with the generative model w/o any flow
            generate_only = allow_continue and (session_info[request.sid]['continue_cnt'] < MAX_GEN_TURNS) and (raw_user_input_len >= GEN_MIN_LEN)
            if message["agent"] == 'ethics_base':
                generate_only = True
        
            if generate_only: 
                print('Generating the full message without any flow text')
            
                flow_num = -1
                strategy = 'generation'
            
                bot_response = ack
            
                # note that message was only generated
                session_info[request.sid]['continue_cnt'] += 1 # increment the number of turns since the last flow message
            
            else: 
                # use the generated acknowledgement and add some flow text
                flow_num = last_flow + 1
                flow_text = CONDITION_MESSAGES_NO_ACK['flow'][flow_num]
                strategy = CONDITION_MESSAGES_NO_ACK['strategy'][flow_num]
            
                bot_response = ack + " " + flow_text
            
                # updating tracking of flow text usage
                session_info[request.sid]['last_flow'] = flow_num 
                session_info[request.sid]['continue_cnt'] = 0 # reset because a new flow message was used
        
        
            replace_last_reply(agents[agent_choice].history, bot_response)
        
            # make sure to store updated user's history
            session_info[request.sid]['parlai_history'] = copy.deepcopy(agents[agent_choice].history)
        
            # make sure model history is reset
            agents[agent_choice].reset()
        
        output_text = bot_response                                       
   
        with sqlite3.connect('data/session_info.db') as conn:
            cur = conn.cursor()
            # message_pairs columns: sid, pid, condition, message, response, exchange_num, strategy, flow_num
            db_input = (request.sid, pid, condition, 
                        input_text, output_text, 
                        exchange_num,
                        strategy, 
                        flow_num)
            cur.executemany("INSERT INTO message_pairs VALUES (?, ?, ?, ?, ?, ?, ?, ?)", [db_input])
            conn.commit()
        
        print('Into message_pairs: ', db_input)
        
        
    
        # Pause if only a flow message to be more natural
        if message['agent'] == 'therapybot_other':
            time.sleep(2)
        
        # Render our response
        emit('render_sys_message', {"data": output_text}, room=request.sid)
        
        session_info[request.sid]['convo'].append((raw_user_input_text, output_text))
        
    elif not input_not_empty:
        emit('render_sys_message', {"data": '[oops, please enter message text]'}, room=request.sid)

    else:
        emit('render_sys_message', {"data": '[Chat completed. Please click the “->” button to move on to the next step.]'}, room=request.sid)
Code Example #21
    def parley(self):
        self.turn_idx += 1
        print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))
        """If at first turn, we need to give each agent their persona"""
        if self.turn_idx == 1:
            for idx, agent in enumerate(self.agents):
                persona_text = ''
                for s in self.personas[idx]:
                    persona_text += ('<b><span style="color:blue">'
                                     '{}\n</span></b>'.format(s.strip()))
                control_msg = self.get_control_msg()
                control_msg['persona_text'] = persona_text
                control_msg['text'] = self.get_instruction(tag='start',
                                                           agent_id=agent.id)
                # TODO: check that get instruction actually exists?
                agent.observe(validate(control_msg))
                if idx == 0:
                    time.sleep(3)
        """If we get to the min turns, inform turker that they can end if they
        want.
        """
        if self.turn_idx == self.n_turn + 1:
            for idx, agent in enumerate(self.agents):
                control_msg = self.get_control_msg()
                control_msg['text'] = self.get_instruction(
                    idx, tag='exceed_min_turns')
                control_msg['exceed_min_turns'] = True
                agent.observe(validate(control_msg))
        """Otherwise, we proceed accordingly."""
        # Other agent first
        if self.other_first and self.turn_idx == 1:
            if self.model_agent is not None:
                # Model must observe its persona
                persona_act = {
                    'text': '\n'.join([self.model_persona_text,
                                       '__SILENCE__']),
                    'episode_done': False,
                }
                self.model_agent.observe(persona_act)
                self.bot_seen_persona = True
                model_act = copy.deepcopy(self.model_agent.act())
                model_act.force_set('text', normalize_reply(model_act['text']))
                model_act.force_set('id', 'PERSON_2')
                self.dialog.append((1, model_act.get('text')))
                _random_delay()
                self.eval_agent.observe(_strip_tensors(model_act))
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return
                else:
                    self.dialog.append((1, act.get('text')))
                    act = copy.deepcopy(act)
                    act.force_set('text', normalize_reply(act['text']))
                    self.eval_agent.observe(act)

        # Eval agent turn
        act = Message(self.get_human_agent_act(self.eval_agent))
        timeout = self.check_timeout(act)
        if timeout:
            if self.model_agent is None:
                control_msg = self.get_control_msg()
                control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                self.other_agent.observe(validate(control_msg))
            return

        if act['episode_done']:
            if self.turn_idx >= self.n_turn:
                if not self.other_first:
                    self.dialog_list = [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(0, len(self.dialog), 2)
                    ]
                else:
                    self.dialog_list = [' \n' + self.dialog[0][1]] + [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(1, len(self.dialog) - 1, 2)
                    ]
                self.parallel_eval_mode()

                self.chat_done = True
                for ag in self.agents:
                    control_msg = self.get_control_msg()
                    control_msg['text'] = CHAT_ENDED_MSG
                    ag.observe(validate(control_msg))
            return

        self.dialog.append((0, act['text']))

        if not self.bot_seen_persona and self.model_agent is not None:
            # Add persona for model to observe
            act.force_set('text',
                          '\n'.join([self.model_persona_text, act['text']]))
            self.bot_seen_persona = True
        if self.model_agent is not None:
            self.model_agent.observe(act)
        else:
            act = copy.deepcopy(act)
            act.force_set('text', normalize_reply(act['text']))
            self.other_agent.observe(act)

        # Model_agent turn
        if not self.other_first or self.turn_idx < self.n_turn:
            if self.model_agent is not None:
                _random_delay()
                act = _strip_tensors(copy.deepcopy(self.model_agent.act()))
                act.force_set('text', normalize_reply(act['text']))
                act.force_set('id', 'PERSON_2')
                # NOTE: your model may or may not need to observe itself here
                # If it does, call model_observes_itself or some other specialized
                # function
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return

            self.dialog.append((1, act.get('text')))
            act = copy.deepcopy(act)
            act.force_set('text', normalize_reply(act['text']))
            self.eval_agent.observe(act)
Code Example #22
 def normalize_replies(self, x):
     xs = x.split('\n')
     xs2 = []
     for x in xs:
         xs2.append(normalize_reply(x))
     return '\n'.join(xs2)
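
This per-line loop is behaviour-equivalent to a single generator expression:

 def normalize_replies(self, x):
     return '\n'.join(normalize_reply(line) for line in x.split('\n'))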
Code Example #23
File: eval.py  Project: zqhfpjlswsqy/google-research
def _run_conversation(conversation_id, conversation, tgt_agent, ref_agent):
    tgt_agent.reset()
    ref_agent.reset()
    model_persona = conversation['model_persona']
    model_persona = PERSONA_PREFIX + model_persona.replace(
        '\n', '\n' + PERSONA_PREFIX)

    new_dialog = []
    for turn_id, turn in enumerate(conversation['dialog']):
        speaker = turn['speaker']
        reference_text = turn['text']

        if turn_id == 0 and speaker == 'model':
            silenced_text = model_persona + '\n' + SILENCE
            observed = ref_agent.observe({
                'id': 'SPEAKER_2',
                'text': silenced_text,
                'episode_done': False
            })
            observed = tgt_agent.observe({
                'id': 'SPEAKER_2',
                'text': silenced_text,
                'episode_done': False
            })
            new_dialog.append({'speaker': 'human_evaluator', 'text': SILENCE})
        elif turn_id == 0 and speaker == 'human_evaluator':
            reference_text = model_persona + '\n' + reference_text

        if speaker == 'human_evaluator':
            observed = ref_agent.observe({
                'id': 'SPEAKER_2',
                'text': reference_text,
                'episode_done': False
            })
            observed = tgt_agent.observe({
                'id': 'SPEAKER_2',
                'text': reference_text,
                'episode_done': False
            })
            new_dialog.append({
                'speaker': 'human_evaluator',
                'text': turn['text']
            })
        if speaker == 'model':
            ref_response = ref_agent.batch_act([ref_agent.observation])[0]
            ref_agent.self_observe(ref_response)
            tgt_response = tgt_agent.batch_act([tgt_agent.observation])[0]
            tgt_agent.self_observe(deepcopy(ref_response))
            assert tgt_response['id'] == ref_response['id']

            response_normalized = normalize_reply(ref_response['text'])
            if response_normalized != reference_text:
                logging.error(
                    f'{conversation_id}:{turn_id}: ref {repr(reference_text)} '
                    f'!= resp {repr(response_normalized)}. Context:\n{repr(observed)}'
                )
                return False
            response_normalized = normalize_reply(tgt_response['text'])
            new_dialog.append({'speaker': 'model', 'text': turn['text']})
            new_dialog.append({
                'speaker': 'tgt_model',
                'text': response_normalized
            })

    conversation['dialog'] = new_dialog
    return True