def normalize_replies(self, x):
    xs = x.split('\n')
    your_personas = []
    partner_personas = []
    non_personas = []
    for x in xs:
        if x.startswith('your persona: '):
            # Normalize the sentence appearing after 'your persona:'
            x = x[len('your persona: ') :]
            x = normalize_reply(x)
            your_personas.append(x)
        elif x.startswith("partner's persona: "):
            x = x[len("partner's persona: ") :]
            x = normalize_reply(x)
            partner_personas.append(x)
        else:
            x = normalize_reply(x)
            non_personas.append(x)
    xs2 = []
    if not self.is_convai2_session_level:
        your_personas = ['your persona: ' + yx for yx in your_personas]
        partner_personas = ["partner's persona: " + px for px in partner_personas]
    else:
        if your_personas:
            your_personas = ['your persona: ' + " ".join(your_personas)]
        if partner_personas:
            partner_personas = ["partner's persona: " + " ".join(partner_personas)]
    if self.your_persona_first:
        xs2.extend(your_personas)
        xs2.extend(partner_personas)
    else:
        xs2.extend(partner_personas)
        xs2.extend(your_personas)
    xs2.extend(non_personas)
    return '\n'.join(xs2)
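To make the session-level branch concrete, here is a hypothetical input and the output it would produce, assuming is_convai2_session_level=True and your_persona_first=True (the persona strings are invented for illustration):

raw = (
    "your persona: i have a cat.\n"
    "your persona: i like jazz.\n"
    "hello , how are you ?"
)
# normalize_replies(raw) would return, roughly:
#   your persona: I have a cat. I like jazz.
#   Hello, how are you?
# i.e. the persona sentences are normalized individually, then merged onto a
# single prefixed line, while the non-persona turn is only normalized.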
def test_normalize_reply_version1(self):
    assert string_utils.normalize_reply("I ' ve a cat .") == "I've a cat."
    assert (
        string_utils.normalize_reply("do you think i can dance?")
        == "Do you think I can dance?"
    )
    assert string_utils.normalize_reply("I ' m silly '") == "I'm silly'"
def normalize_replies(self, x):
    xs = x.split('\n')
    your_personas = []
    partner_personas = []
    non_personas = []
    for x in xs:
        if x.startswith('your persona: '):
            # Normalize the sentence appearing after 'your persona:'
            x = x[len('your persona: '):]
            x = normalize_reply(x)
            x = 'your persona: ' + x
            your_personas.append(x)
        elif x.startswith("partner's persona: "):
            x = x[len("partner's persona: "):]
            x = normalize_reply(x)
            x = "partner's persona: " + x
            partner_personas.append(x)
        else:
            x = normalize_reply(x)
            non_personas.append(x)
    xs2 = []
    if self.your_persona_first:
        xs2.extend(your_personas)
        xs2.extend(partner_personas)
    else:
        xs2.extend(partner_personas)
        xs2.extend(your_personas)
    xs2.extend(non_personas)
    return '\n'.join(xs2)
def test_normalize_reply_version2(self):
    assert string_utils.normalize_reply("Add a period", 2) == "Add a period."
    assert string_utils.normalize_reply("Add a period?", 2) == "Add a period?"
    assert string_utils.normalize_reply("Add a period!", 2) == "Add a period!"
    assert string_utils.normalize_reply('"Add a period"', 2) == '"add a period"'
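The two tests above pin down most of the observable behavior of normalize_reply. For readers without the ParlAI source at hand, the following is a minimal sketch that reproduces exactly these assertions; it is inferred from the tests alone and is not the real implementation, which handles more edge cases:

import re


def normalize_reply_sketch(text, version=1):
    # Hypothetical re-implementation inferred from the tests above; the real
    # parlai.utils.strings.normalize_reply differs in edge cases.
    s = text.strip().lower()
    # Re-attach tokenized apostrophes: "i ' ve" -> "i've", "silly '" -> "silly'"
    s = re.sub(r"\s*'\s*", "'", s)
    # Remove spaces before sentence-final punctuation: "cat ." -> "cat."
    s = re.sub(r"\s+([.!?,])", r"\1", s)
    # Capitalize the standalone pronoun "i"
    s = re.sub(r"\bi\b", "I", s)
    # Capitalize the first character (a no-op when it isn't a letter, e.g. '"')
    if s:
        s = s[0].upper() + s[1:]
    # Version 2 also appends a period when no terminal punctuation is present
    if version > 1 and s and s[-1] not in '.!?"':
        s += '.'
    return s


assert normalize_reply_sketch("I ' ve a cat .") == "I've a cat."
assert normalize_reply_sketch("Add a period", 2) == "Add a period."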
def normalize_replies(self, x):
    xs = x.split('\n')
    xs2 = []
    for x in xs:
        if 'your persona:' in x:
            # Normalize the sentence appearing after 'your persona:'
            x = x[len('your persona: '):]
            x = normalize_reply(x)
            x = 'your persona: ' + x
        else:
            x = normalize_reply(x)
        xs2.append(x)
    return '\n'.join(xs2)
def normalize_replies(self, x):
    xs = x.split('\n')
    xs2 = []
    for x in xs:
        if x.startswith('your persona: '):
            continue
        elif x.startswith("partner's persona: "):
            x = x[len("partner's persona: "):]
            x = normalize_reply(x)
            x = "partner's persona: " + x
        else:
            x = normalize_reply(x)
        xs2.append(x)
    return '\n'.join(xs2)
def _acutify_convo(
    self, conversation: Conversation, config_id: str
) -> Dict[str, List]:
    """
    Format world-logged conversation to be ACUTE format.

    :param conversation: dictionary containing the dialogue for a config_id
    :param config_id: config_id string

    :return conversation: An ACUTE-Readable conversation
    """
    config = CONFIG[config_id]
    is_selfchat = 'model' in config
    acute_conversation: Dict[str, List] = {
        'context': [],
        'dialogue': [],
        'speakers': [],
    }
    for i, ex in enumerate(conversation):
        if ex['id'] == 'context':
            acute_conversation['context'].append(ex)
            continue
        speaker_id = ex['id']
        if is_selfchat:
            speaker_id = config_id if i % 2 == 0 else f'other_{config_id}'
        if speaker_id not in acute_conversation['speakers']:
            acute_conversation['speakers'].append(speaker_id)
        acute_conversation['dialogue'].append(
            {'id': speaker_id, 'text': normalize_reply(ex['text'])}
        )
    return acute_conversation
def check_negative_sentiment(sent_eval, text):
    norm_text = normalize_reply(text)
    sent_scores = sent_eval.polarity_scores(norm_text)
    if sent_scores["compound"] >= 0:
        return False
    return True
def act(self, timeout=None): """ Same as model chat's bot_agent.py except the self_observe function is removed, a custom observe is instead written in worlds.py. This is so that the two bots can read each others' messages using observe so that the conversation history stays the same. """ _ = timeout # The model doesn't care about the timeout if self.semaphore: with self.semaphore: act_out = self.model_agent.batch_act([self.model_agent.observation])[0] else: act_out = self.model_agent.batch_act([self.model_agent.observation])[0] act_out = Message(act_out).json_safe_payload() if 'dict_lower' in self.opt and not self.opt['dict_lower']: # model is cased so we don't want to normalize the reply like below final_message_text = act_out['text'] else: final_message_text = normalize_reply(act_out['text']) act_out['text'] = final_message_text assert ('episode_done' not in act_out) or (not act_out['episode_done']) self.turn_idx += 1 return {**act_out, 'episode_done': False}
def _acutify_convo(
    self, dialogue_dict: Dict[str, Any], model: str
) -> Dict[str, List]:
    """
    Format world-logged conversation to be ACUTE format.

    :param dialogue_dict: dictionary containing the dialogue for a model
    :param model: model string

    :return conversation: An ACUTE-Readable conversation
    """
    conversation = {
        'context': [],
        'dialogue': [],
        'speakers': [model, f'other_{model}'],
    }
    dialog = dialogue_dict['dialog']
    for act_pair in dialog:
        for i, ex in enumerate(act_pair):
            if ex['id'] == 'context':
                conversation['context'].append(ex)
                continue
            conversation['dialogue'].append(
                {
                    'id': model if i == 0 else f'other_{model}',
                    'text': normalize_reply(ex['text']),
                }
            )
    return conversation
def check_negation(spacy_nlp, text):
    norm_text = normalize_reply(text)
    doc = spacy_nlp(norm_text)
    for token in doc:
        if token.dep_ == "neg":
            return True
    return False
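A hedged usage sketch for the two helpers above: judging by how the parameters are used (polarity_scores and dependency labels), sent_eval is assumed to be a VADER SentimentIntensityAnalyzer and spacy_nlp a loaded spaCy pipeline. The setup below is an assumption, not taken from the source:

# Assumed setup; requires `pip install vaderSentiment spacy` plus the
# en_core_web_sm model. The argument names match the functions above.
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
import spacy

sent_eval = SentimentIntensityAnalyzer()
spacy_nlp = spacy.load("en_core_web_sm")

# Expected True in both cases, given VADER's negative compound score for
# "hate" and spaCy's "neg" dependency label on "not".
print(check_negative_sentiment(sent_eval, "i really hate mondays ."))
print(check_negation(spacy_nlp, "i do not like green eggs and ham ."))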
def normalize_replies(self, x):
    xs = [xt.strip() for xt in x.split('\n')]
    xs2 = []
    for x in xs:
        if 'your persona:' in x:
            # Normalize the sentence appearing after 'your persona:'
            x = x[len('your persona: '):]
            x = normalize_reply(x)
            x = 'your persona: ' + x
        elif "partner's persona: " in x:
            x = x[len("partner's persona: "):]
            x = normalize_reply(x)
            x = "partner's persona: " + x
        elif x != DUMMY_TEXT:
            x = normalize_reply(x)
        xs2.append(x)
    return "\n".join(xs2)
def _acutify_convo(
    self, dialogue_dict: Dict[str, Any], model: str
) -> Dict[str, List]:
    """
    Format world-logged conversation to be ACUTE format.

    :param dialogue_dict: dictionary containing the dialogue for a model
    :param model: model string

    :return conversation: An ACUTE-Readable conversation
    """
    conversation = {
        'context': [],
        'dialogue': [],
        'speakers': ['human_evaluator', model],
    }
    if (
        'is_selfchat' in self.model_config[model]
        and self.model_config[model]['is_selfchat']
    ):
        if 'flip' in self.model_config[model]:
            conversation['speakers'] = ['other_speaker', model]
        else:
            conversation['speakers'] = [model, 'other_speaker']
    dialog = dialogue_dict['dialog']
    for act_pair in dialog:
        for i, ex in enumerate(act_pair):
            if ex['id'] == 'context':
                conversation['context'].append(ex)
                continue
            # agent 1 is the model, agent 0 is human
            convo = {'id': ex['id'], 'text': normalize_reply(ex['text'])}
            if (
                'is_selfchat' in self.model_config[model]
                and self.model_config[model]['is_selfchat']
            ):
                # is_selfchat: override agent_id
                if 'flip' in self.model_config[model]:
                    if i % 2 == 1:
                        convo['id'] = model
                    else:
                        convo['id'] = 'other_speaker'
                else:
                    if i % 2 == 0:
                        convo['id'] = model
                    else:
                        convo['id'] = 'other_speaker'
            conversation['dialogue'].append(convo)
    return conversation
def _acutify_convo(
    self, dialogue_dict: Dict[str, Any], model: str
) -> Dict[str, List]:
    """
    Format world-logged conversation to be ACUTE format.

    :param dialogue_dict: dictionary containing the dialogue for a model
    :param model: model string

    :return conversation: An ACUTE-Readable conversation
    """
    # It's a self-chat if one of the following is true:
    # (1) a model is specified in the config, meaning that we're collecting
    #     self-chats with that model
    # (2) we manually set 'is_selfchat' to True in the config
    is_selfchat = 'model' in self.model_config[model] or self.model_config[
        model
    ].get('is_selfchat', False)
    if is_selfchat:
        # Set which speaker we will evaluate the conversation turns of
        speaker_idx = self.model_config[model].get('speaker_idx', 0)
        assert speaker_idx in [0, 1]
    conversation = {'context': [], 'dialogue': [], 'speakers': []}
    dialog = dialogue_dict['dialog']
    for act_pair in dialog:
        for i, ex in enumerate(act_pair):
            if ex['id'] == 'context':
                conversation['context'].append(ex)
                continue
            if is_selfchat:
                speaker_id = model if i == speaker_idx else 'other_speaker'
            else:
                speaker_id = ex['id']
            if speaker_id not in conversation['speakers']:
                conversation['speakers'].append(speaker_id)
            conversation['dialogue'].append(
                {'id': speaker_id, 'text': normalize_reply(ex['text'])}
            )
    return conversation
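For orientation, here is a hypothetical self-chat log and the shape the method above would acutify it into (all field values are invented):

# Input world log (abbreviated):
dialogue_dict = {
    'dialog': [
        [
            {'id': 'model_a', 'text': "i ' ve a cat ."},
            {'id': 'model_a', 'text': 'do you like cats?'},
        ]
    ]
}
# With self.model_config['model_a'] = {'model': ..., 'speaker_idx': 0},
# the return value would be roughly:
# {
#     'context': [],
#     'dialogue': [
#         {'id': 'model_a', 'text': "I've a cat."},
#         {'id': 'other_speaker', 'text': 'Do you like cats?'},
#     ],
#     'speakers': ['model_a', 'other_speaker'],
# }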
def act(self, timeout=None):
    _ = timeout  # The model doesn't care about the timeout
    if self.semaphore:
        with self.semaphore:
            act_out = self.model_agent.act()
    else:
        act_out = self.model_agent.act()
    act_out = Message(act_out).json_safe_payload()

    if 'dict_lower' in self.opt and not self.opt['dict_lower']:
        # model is cased so we don't want to normalize the reply like below
        final_message_text = act_out['text']
    else:
        final_message_text = normalize_reply(act_out['text'])
    act_out['text'] = final_message_text
    assert ('episode_done' not in act_out) or (not act_out['episode_done'])
    self.turn_idx += 1
    return {**act_out, 'episode_done': False}
def act(self, timeout=None):
    _ = timeout  # The model doesn't care about the timeout
    if self.semaphore:
        with self.semaphore:
            act_out = self.model_agent.act()
    else:
        act_out = self.model_agent.act()
    # Wrap as a Message for compatibility with older ParlAI models
    act_out = Message(act_out)
    if 'dict_lower' in self.opt and not self.opt['dict_lower']:
        # model is cased so we don't want to normalize the reply like below
        final_message_text = act_out['text']
    else:
        final_message_text = normalize_reply(act_out['text'])
    act_out.force_set('text', final_message_text)
    assert ('episode_done' not in act_out) or (not act_out['episode_done'])
    self.turn_idx += 1
    return {**act_out, 'episode_done': False}
def act(self, timeout=None):
    _ = timeout  # The model doesn't care about the timeout
    if self.semaphore:
        with self.semaphore:
            act_out = self.model_agent.act()
    else:
        act_out = self.model_agent.act()
    annotations_html = construct_annotations_html(
        annotations_intro=self.opt['annotations_intro'],
        annotations_config=self.opt['annotations_config'],
        turn_idx=self.turn_idx,
    )
    if 'dict_lower' in self.opt and not self.opt['dict_lower']:
        # model is cased so we don't want to normalize the reply like below
        final_message_text = act_out['text']
    else:
        normalized_act_text = normalize_reply(act_out['text'])
        final_message_text = normalized_act_text + annotations_html
    if self.turn_idx >= self.num_turns * 2:
        radio_css_style = 'margin-left:5px;margin-right:15px;'
        radio_buttons_html = ''
        for i in range(1, 6):
            radio_buttons_html += f"""<input type="radio" id="radio_rating_{i}" name="radio_final_rating_group" value="{i}" /><span style={radio_css_style}>{i}</span>"""
        final_scoring_question = self.opt['final_rating_question']
        exceeds_min_turns = f"""<br><br><div>{self.num_turns} chat turns finished! {final_scoring_question}</div> {radio_buttons_html} <br>Then, please click the "Done" button to end the chat."""
        final_message_text += exceeds_min_turns
        act_out = Compatibility.backward_compatible_force_set(
            act_out, 'exceed_min_turns', True
        )
    act_out = Compatibility.backward_compatible_force_set(
        act_out, 'text', final_message_text
    )
    assert ('episode_done' not in act_out) or (not act_out['episode_done'])
    self.turn_idx += 1
    return {**act_out, 'episode_done': False, 'checked_radio_name_id': ''}
def act(self, timeout=None):
    _ = timeout  # The model doesn't care about the timeout
    if self.semaphore:
        with self.semaphore:
            act_out = self.model_agent.act()
    else:
        act_out = self.model_agent.act()
    annotations_html = TurkLikeAgent.construct_annotations_html(self.turn_idx)
    if 'dict_lower' in self.opt and not self.opt['dict_lower']:
        # model is cased so we don't want to normalize the reply like below
        final_message_text = act_out['text']
    else:
        normalized_act_text = normalize_reply(act_out['text'])
        final_message_text = normalized_act_text + annotations_html
    if self.turn_idx >= self.num_turns * 2:
        radio_css_style = 'margin-left:5px;margin-right:15px;'
        radio_buttons_html = ''
        for i in range(1, 6):
            radio_buttons_html += f"""<input type="radio" id="radio_rating_{i}" name="radio_final_rating_group" value="{i}" /><span style={radio_css_style}>{i}</span>"""
        exceeds_min_turns = f"""<br><br><div>{self.num_turns} chat turns finished! Please rate your partner on a scale of 1-5, how much would you enjoy talking to this partner over the course of a long conversation? (1: not at all, 5: a lot)</div> {radio_buttons_html} <br>Then, please click the "Done" button to end the chat."""
        final_message_text += exceeds_min_turns
        act_out = Compatibility.backward_compatible_force_set(
            act_out, 'exceed_min_turns', True
        )
    act_out = Compatibility.backward_compatible_force_set(
        act_out, 'text', final_message_text
    )
    assert ('episode_done' not in act_out) or (not act_out['episode_done'])
    self.turn_idx += 1
    return {**act_out, 'episode_done': False, 'checked_radio_name_id': ''}
def rerank(
    self,
    observation: Message,
    response_cands: List[str],
    response_cand_scores: torch.Tensor,
) -> Tuple[List[str], List[int]]:
    """
    Re-rank candidates according to predictor score.

    :param observation: Message object that includes the dialogue history
    :param response_cands: ranked list of model response candidates
    :param response_cand_scores: list of model response candidates' scores

    :return (candidates, indices):
        candidates: a re-ranked list of candidates
        indices: list of indices into response_cands corresponding to re-rank order
    """
    full_context = observation['full_text']
    # 0) Normalize the replies if the opt is passed in
    if self.normalize_candidates:
        response_cands = [normalize_reply(c) for c in response_cands]
    # 1) Augment context with response candidates
    if not self.include_label_cand_only:
        contexts = [
            self.augment_context(full_context, cand, include_context=self.include_context)
            for cand in response_cands
        ]
        contexts = [self.delimiter.join(utts) for utts in contexts]
    else:
        # This variant passes only the label candidates (with no dialogue history
        # whatsoever) into the ranker. Can be useful for things like, e.g., simple
        # utterance-based safety classifiers.
        contexts = response_cands
    # 2) Predict with augmented context
    label_candidates = self.get_predictor_label_candidates(observation, full_context)
    reranker_outputs = self.batch_predict(contexts, label_candidates)
    # 3) Rerank
    rerank_for_class = self.get_class_to_rerank_for(observation, full_context)
    reranked_candidates, indices = self._rerank_candidates(
        reranker_outputs,
        response_cands,
        response_cand_scores,
        rerank_for_class=rerank_for_class,
    )
    if self.show_debug_logging(observation):
        debug_str = self._construct_debug_logging(
            observation,
            response_cands,
            response_cand_scores,
            reranker_outputs,
            reranked_candidates,
            rerank_for_class,
        )
        # Need print because even logging.WARN is swallowed during eval
        print(debug_str)
    return reranked_candidates, indices
def user_sent_message_generate(message):
    """
    Called when the participant sends a message to the generative model.
    """
    if len(message["data"].strip()) > 0:
        session_info[request.sid]['num_written'] += 1
    # message['count'] = session_info[request.sid]['num_written']
    # session_info[request.sid]['condition'] = 'generation'
    agent_choice = message["agent"]
    pid = session_info[request.sid]['pid']
    condition = session_info[request.sid]['condition']  # +'-'+agent_choice
    # condition = agent_choice
    exchange_num = session_info[request.sid]['num_written']
    last_flow = session_info[request.sid]['last_flow']
    allow_continue = CONDITION_MESSAGES_NO_ACK['allow_continue'][last_flow] == 'continue'
    # Extract a string of the user's message
    raw_user_input_text = message["data"]
    raw_user_input_len = len(raw_user_input_text.split(' '))
    if message['agent'] == 'therapybot_other':
        flow_has_more = last_flow + 1 < len(CONDITION_MESSAGES_NO_ACK['flow_fixed'])
    else:
        flow_has_more = last_flow + 1 < len(CONDITION_MESSAGES_NO_ACK['flow'])
    input_not_empty = len(raw_user_input_text.strip()) > 0
    # message['is_done'] = 'false' if flow_has_more else 'true'
    emit('render_usr_message', message, room=request.sid)
    if flow_has_more and input_not_empty:
        input_text = raw_user_input_text
        if message['agent'] == 'therapybot_other':
            # use the generated acknowledgement and add some flow text
            flow_num = last_flow + 1
            flow_text = CONDITION_MESSAGES_NO_ACK['flow_fixed'][flow_num]
            strategy = CONDITION_MESSAGES_NO_ACK['strategy_fixed'][flow_num]
            bot_response = flow_text
            # updating tracking of flow text usage
            session_info[request.sid]['last_flow'] = flow_num
            session_info[request.sid]['continue_cnt'] = 0  # reset because a new flow message was used
        else:
            torch.cuda.set_device(agents[message['agent']].opt['gpu'])
            # make sure model history is reset
            agents[agent_choice].reset()
            # set model history as user's conversation history
            agents[agent_choice].history = copy.deepcopy(session_info[request.sid]['parlai_history'])
            # print('PRE-GEN: ')
            # print(agents[agent_choice].history.history_strings)
            # observe user input
            agents[agent_choice].observe(package_text(input_text))
            # generate bot response
            agent_output = agents[agent_choice].act()
            ack = normalize_reply(agent_output['text'])
            # decide whether to let the user chat only with the generative model w/o any flow
            generate_only = (
                allow_continue
                and (session_info[request.sid]['continue_cnt'] < MAX_GEN_TURNS)
                and (raw_user_input_len >= GEN_MIN_LEN)
            )
            if message["agent"] == 'ethics_base':
                generate_only = True
            if generate_only:
                print('Generating the full message without any flow text')
                flow_num = -1
                strategy = 'generation'
                bot_response = ack  # note that message was only generated
                session_info[request.sid]['continue_cnt'] += 1  # increment the number of turns since the last flow message
            else:
                # use the generated acknowledgement and add some flow text
                flow_num = last_flow + 1
                flow_text = CONDITION_MESSAGES_NO_ACK['flow'][flow_num]
                strategy = CONDITION_MESSAGES_NO_ACK['strategy'][flow_num]
                bot_response = ack + " " + flow_text
                # updating tracking of flow text usage
                session_info[request.sid]['last_flow'] = flow_num
                session_info[request.sid]['continue_cnt'] = 0  # reset because a new flow message was used
            replace_last_reply(agents[agent_choice].history, bot_response)
            # print(message['agent'])
            # print(agent_output)
            # print('POST-GEN: ')
            # print(agents[agent_choice].history.history_strings)
            # make sure to store updated user's history
            session_info[request.sid]['parlai_history'] = copy.deepcopy(agents[agent_choice].history)
            # make sure model history is reset
            agents[agent_choice].reset()
            # print('POST-RESET: ')
            # print(agents[agent_choice].history.history_strings)
        # input_text = raw_user_input_text.lower()
        # debug: review the preprocessing of the raw input text.
        # output_text, response_info = MODEL_DICT[selected_model].chat(
        #     input_text, compound_sid, '')  # assignmentid
        # bot_response = CONDITION_MESSAGES[condition][exchange_num]
        output_text = bot_response
        with sqlite3.connect('data/session_info.db') as conn:
            cur = conn.cursor()
            # sid text, pid text, condition text, message text, response text, exchange_num int
            db_input = (request.sid, pid, condition, input_text, output_text, exchange_num, strategy, flow_num)
            cur.executemany("INSERT INTO message_pairs VALUES (?, ?, ?, ?, ?, ?, ?, ?)", [db_input])
            conn.commit()
        print('Into message_pairs: ', db_input)
        # output_text = output_text.replace('f*****g', '<EXPLETIVE>').replace('f**k', '<EXPLETIVE>')
        # Pause if the reply is only a flow message, to feel more natural
        if message['agent'] == 'therapybot_other':
            time.sleep(2)
        # Render our response
        emit('render_sys_message', {"data": output_text}, room=request.sid)
        # response_info_str = json.dumps(response_info)
        # debug: uncomment storing to sql below.
        # with sqlite3.connect('data/session_info.db') as conn:
        #     cur = conn.cursor()
        #     cur.executemany("INSERT INTO response_info VALUES (?, ?)", [(compound_sid, response_info_str,)])
        #     conn.commit()
        session_info[request.sid]['convo'].append((raw_user_input_text, output_text))
    # elif not flow_has_more:
    #     emit('render_sys_message', {"data": '[Chat completed. Please continue to survey.]'}, room=request.sid)
    elif not input_not_empty:
        emit('render_sys_message', {"data": '[oops, please enter message text]'}, room=request.sid)
    else:
        emit('render_sys_message', {"data": '[Chat completed. Please click the “->” button to move on to the next step.]'}, room=request.sid)
def parley(self):
    self.turn_idx += 1
    print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))

    """If at first turn, we need to give each agent their persona"""
    if self.turn_idx == 1:
        for idx, agent in enumerate(self.agents):
            persona_text = ''
            for s in self.personas[idx]:
                persona_text += '<b><span style="color:blue">{}\n</span></b>'.format(
                    s.strip()
                )
            control_msg = self.get_control_msg()
            control_msg['persona_text'] = persona_text
            control_msg['text'] = self.get_instruction(tag='start', agent_id=agent.id)
            # TODO: check that get instruction actually exists?
            agent.observe(validate(control_msg))
            if idx == 0:
                time.sleep(3)

    """If we get to the min turns, inform turker that they can end if they want."""
    if self.turn_idx == self.n_turn + 1:
        for idx, agent in enumerate(self.agents):
            control_msg = self.get_control_msg()
            control_msg['text'] = self.get_instruction(idx, tag='exceed_min_turns')
            control_msg['exceed_min_turns'] = True
            agent.observe(validate(control_msg))

    """Otherwise, we proceed accordingly."""
    # Other agent first
    if self.other_first and self.turn_idx == 1:
        if self.model_agent is not None:
            # Model must observe its persona
            persona_act = {
                'text': '\n'.join([self.model_persona_text, '__SILENCE__']),
                'episode_done': False,
            }
            self.model_agent.observe(persona_act)
            self.bot_seen_persona = True
            model_act = copy.deepcopy(self.model_agent.act())
            model_act.force_set('text', normalize_reply(model_act['text']))
            model_act.force_set('id', 'PERSON_2')
            self.dialog.append((1, model_act.get('text')))
            _random_delay()
            self.eval_agent.observe(_strip_tensors(model_act))
        else:
            act = self.get_human_agent_act(self.other_agent)
            timeout = self.check_timeout(act)
            if timeout:
                # eval agent early disconnect
                control_msg = self.get_control_msg()
                control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                self.eval_agent.observe(validate(control_msg))
                return
            else:
                self.dialog.append((1, act.get('text')))
                act = copy.deepcopy(act)
                act.force_set('text', normalize_reply(act['text']))
                self.eval_agent.observe(act)

    # Eval agent turn
    act = Message(self.get_human_agent_act(self.eval_agent))
    timeout = self.check_timeout(act)
    if timeout:
        if self.model_agent is None:
            control_msg = self.get_control_msg()
            control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
            self.other_agent.observe(validate(control_msg))
        return

    if act['episode_done']:
        if self.turn_idx >= self.n_turn:
            if not self.other_first:
                self.dialog_list = [
                    '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                    for i in range(0, len(self.dialog), 2)
                ]
            else:
                self.dialog_list = [' \n' + self.dialog[0][1]] + [
                    '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                    for i in range(1, len(self.dialog) - 1, 2)
                ]
            self.parallel_eval_mode()

            self.chat_done = True
            for ag in self.agents:
                control_msg = self.get_control_msg()
                control_msg['text'] = CHAT_ENDED_MSG
                ag.observe(validate(control_msg))
        return

    self.dialog.append((0, act['text']))
    if not self.bot_seen_persona and self.model_agent is not None:
        # Add persona for model to observe
        act.force_set('text', '\n'.join([self.model_persona_text, act['text']]))
        self.bot_seen_persona = True
    if self.model_agent is not None:
        self.model_agent.observe(act)
    else:
        act = copy.deepcopy(act)
        act.force_set('text', normalize_reply(act['text']))
        self.other_agent.observe(act)

    # Model_agent turn
    if not self.other_first or self.turn_idx < self.n_turn:
        if self.model_agent is not None:
            _random_delay()
            act = _strip_tensors(copy.deepcopy(self.model_agent.act()))
            act.force_set('text', normalize_reply(act['text']))
            act.force_set('id', 'PERSON_2')
            # NOTE: your model may or may not need to observe itself here.
            # If it does, call model_observes_itself or some other specialized
            # function
        else:
            act = self.get_human_agent_act(self.other_agent)
            timeout = self.check_timeout(act)
            if timeout:
                # eval agent early disconnect
                control_msg = self.get_control_msg()
                control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                self.eval_agent.observe(validate(control_msg))
                return

        self.dialog.append((1, act.get('text')))
        act = copy.deepcopy(act)
        act.force_set('text', normalize_reply(act['text']))
        self.eval_agent.observe(act)
def normalize_replies(self, x):
    xs = x.split('\n')
    xs2 = []
    for x in xs:
        xs2.append(normalize_reply(x))
    return '\n'.join(xs2)
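As a quick illustration of the per-line behavior (the `teacher` instance below is hypothetical; in practice the method lives on a ParlAI teacher class):

# Assuming `teacher` is an instance of the class defining the method above:
raw = "i ' ve a cat .\ndo you think i can dance?"
print(teacher.normalize_replies(raw))
# I've a cat.
# Do you think I can dance?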
def _run_conversation(conversation_id, conversation, tgt_agent, ref_agent):
    tgt_agent.reset()
    ref_agent.reset()
    model_persona = conversation['model_persona']
    model_persona = PERSONA_PREFIX + model_persona.replace('\n', '\n' + PERSONA_PREFIX)
    new_dialog = []
    for turn_id, turn in enumerate(conversation['dialog']):
        speaker = turn['speaker']
        reference_text = turn['text']
        if turn_id == 0 and speaker == 'model':
            silenced_text = model_persona + '\n' + SILENCE
            observed = ref_agent.observe(
                {'id': 'SPEAKER_2', 'text': silenced_text, 'episode_done': False}
            )
            observed = tgt_agent.observe(
                {'id': 'SPEAKER_2', 'text': silenced_text, 'episode_done': False}
            )
            new_dialog.append({'speaker': 'human_evaluator', 'text': SILENCE})
        elif turn_id == 0 and speaker == 'human_evaluator':
            reference_text = model_persona + '\n' + reference_text
        if speaker == 'human_evaluator':
            observed = ref_agent.observe(
                {'id': 'SPEAKER_2', 'text': reference_text, 'episode_done': False}
            )
            observed = tgt_agent.observe(
                {'id': 'SPEAKER_2', 'text': reference_text, 'episode_done': False}
            )
            new_dialog.append({'speaker': 'human_evaluator', 'text': turn['text']})
        if speaker == 'model':
            ref_response = ref_agent.batch_act([ref_agent.observation])[0]
            ref_agent.self_observe(ref_response)
            tgt_response = tgt_agent.batch_act([tgt_agent.observation])[0]
            tgt_agent.self_observe(deepcopy(ref_response))
            assert tgt_response['id'] == ref_response['id']
            response_normalized = normalize_reply(ref_response['text'])
            if response_normalized != reference_text:
                logging.error(
                    f'{conversation_id}:{turn_id}: ref {repr(reference_text)} '
                    f'!= resp {repr(response_normalized)}. Context:\n{repr(observed)}'
                )
                return False
            response_normalized = normalize_reply(tgt_response['text'])
            new_dialog.append({'speaker': 'model', 'text': turn['text']})
            new_dialog.append({'speaker': 'tgt_model', 'text': response_normalized})
        # else:
        #     logging.info(f'{conversation_id}:{turn_id} OK')
    conversation['dialog'] = new_dialog
    return True