def __init__(self, opt, agent, bot): super().__init__(opt, agent) # num_turns turns for a single side, and really it appears to be # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance num_turns = opt['num_turns'] max_resp_time = opt['max_resp_time'] self.opt = opt self.bot = bot self.task_turn_idx = 0 self.num_turns = num_turns self.agent.agent_id = 'Speaker 1' self.bot.agent_id = 'Speaker 2' self.dialog = [] self.tag = f'conversation_id {agent.mephisto_agent.db_id}' self.task_type = 'sandbox' if opt['is_sandbox'] else 'live' self.chat_done = False self.check_acceptability = opt['check_acceptability'] self.acceptability_checker = AcceptabilityChecker() self.block_qualification = opt['block_qualification'] self.final_chat_data = None # TODO: remove this attribute once chat data is only stored in the Mephisto # TaskRun for this HIT (see .get_custom_task_data() docstring for more # information) # below are timeout protocols self.max_resp_time = max_resp_time # in secs print( f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.' )
def __init__(self, opt: Dict[str, Any]): AbstractTurnAnnotationResultsCompiler.__init__(self, opt) # Input args self.model_nickname = opt['model_nickname'] assert len(self.results_folders) > 0 for folder in self.results_folders: assert os.path.isdir(folder), f'{folder} is not a valid folder!' os.makedirs(self.output_folder, exist_ok=True) self.start_date = opt['start_date'] self.max_convos_per_worker = opt['max_convos_per_worker'] self.min_word_count = opt['min_word_count'] self.hit_block_list = opt['hit_block_list'].split(',') self.worker_block_list = opt['worker_block_list'].split(',') # Setting up problem buckets if self.use_problem_buckets: self.regular_buckets = [ bucket for bucket in self.problem_buckets if bucket not in ['other', 'none_all_good'] ] # Remove the buckets that are special cases self.acceptability_checker = AcceptabilityChecker() self.completed_run_stats_path = opt['completed_run_stats_path']
def __init__(self, opt: Dict[str, Any]): super().__init__(opt) # Validate problem buckets if self.use_problem_buckets and 'none_all_good' not in self.problem_buckets: # The code relies on a catchall "none" category if the user selects no other # annotation bucket raise ValueError( 'There must be a "none_all_good" category in self.problem_buckets!' ) # Input args assert len(self.results_folders) > 0 for folder in self.results_folders: assert os.path.isdir(folder), f'{folder} is not a valid folder!' os.makedirs(self.output_folder, exist_ok=True) self.start_date = opt['start_date'] self.max_convos_per_worker = opt['max_convos_per_worker'] self.min_word_count = opt['min_word_count'] self.hit_block_list = opt['hit_block_list'].split(',') self.worker_block_list = opt['worker_block_list'].split(',') # Setting up problem buckets if self.use_problem_buckets: self.regular_buckets = [ bucket for bucket in self.problem_buckets if bucket not in ['other', 'none_all_good'] ] # Remove the buckets that are special cases self.acceptability_checker = AcceptabilityChecker()
def __init__(self, opt, agent, bot):
    """
    Set up a human/bot chat world.

    A single side speaks ``opt['num_turns']`` times; with the "Hi!" opener and
    the first bot utterance the full conversation is roughly
    (num_turns + 1) * 2 messages.
    """
    super().__init__(opt, agent)

    num_turns = opt['num_turns']

    self.opt = opt
    self.bot = bot

    # Conversation bookkeeping
    self.task_turn_idx = 0
    self.num_turns = num_turns
    self.dialog = []
    self.chat_done = False
    self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
    self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'

    # Acceptability checking and worker soft-blocking
    self.check_acceptability = opt['check_acceptability']
    self.acceptability_checker = AcceptabilityChecker()
    self.block_qualification = opt['block_qualification']

    # Timeout protocol: maximum seconds to wait for a response
    self.max_resp_time = opt['max_resp_time']

    print(
        f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
    )
def __init__(self, opt, agent, bots, context_info: Optional[dict] = None):
    """
    Set up a chat between a human worker and multiple bots.

    :param opt: task options ('num_turns', 'max_resp_time', 'is_sandbox',
        'check_acceptability', 'block_qualification').
    :param agent: the human worker's agent (wraps a Mephisto agent).
    :param bots: the model agents.
    :param context_info: optional dict supplying 'persona_1_strings' and
        'persona_2_strings' for the two speakers.
    """
    # num_turns turns for a single side, and really it appears to be
    # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance
    # TODO: this logic is very close to that of BaseModelChatWorld.__init__(). Can
    # any of this be deduplicated?
    # NOTE: super(BaseModelChatWorld, self) deliberately skips
    # BaseModelChatWorld.__init__ and calls the next class up in the MRO.
    super(BaseModelChatWorld, self).__init__(opt, agent)

    num_turns = opt['num_turns']
    max_resp_time = opt['max_resp_time']

    self.opt = opt
    self.bots = bots
    self.task_turn_idx = 0
    self.num_turns = num_turns

    self.dialog = []
    self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
    self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
    self.chat_done = False
    self.check_acceptability = opt['check_acceptability']
    self.acceptability_checker = AcceptabilityChecker()
    self.block_qualification = opt['block_qualification']

    self.final_chat_data = None
    # TODO: remove this attribute once chat data is only stored in the Mephisto
    # TaskRun for this HIT (see .get_custom_task_data() docstring for more
    # information)

    # below are timeout protocols
    self.max_resp_time = max_resp_time  # in secs

    logging.info(
        f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
    )

    # Personas come from the optional context info; without it there are none
    if context_info is not None:
        self.context_info = context_info
        self.personas = [
            self.context_info['persona_1_strings'],
            self.context_info['persona_2_strings'],
        ]
    else:
        self.context_info = {}
        self.personas = None
def __init__(self, opt: Dict[str, Any]):
    """
    Initialize the compiler: create the output folder and record all output
    file paths plus the acceptability checker.
    """
    super().__init__(opt)

    # Make sure the output location exists before anything is written there.
    os.makedirs(self.output_folder, exist_ok=True)
    # TODO: see if this can be moved to the superclass
    self.filter_uniform_hits = opt['filter_uniform_hits']

    # Output file paths, all rooted at the output folder
    folder = self.output_folder
    self.unacceptable_worker_ids_path = os.path.join(folder, 'unacceptable_worker_ids.txt')
    self.annotation_selection_rate_path = os.path.join(folder, 'annotation_selection_rates.csv')
    self.likert_score_stat_path = os.path.join(folder, 'likert_score_stats.csv')

    self.acceptability_checker = AcceptabilityChecker()
def __init__(self, opt: Dict[str, Any]):
    """
    Initialize the per-turn-eval results compiler.

    Creates the output folder, records all output CSV paths, the worker block
    list, column defaults, and the acceptability checker.
    """
    # TODO: deduplicate init from ModelChatResultsCompiler
    super().__init__(opt)

    # Input args
    os.makedirs(self.output_folder, exist_ok=True)
    # Comma-separated worker ID block list (note: an empty string yields [''])
    self.worker_block_list = opt['worker_block_list'].split(',')

    # Save paths
    self.worker_results_path = os.path.join(
        self.output_folder, 'worker_results.csv'
    )
    self.unacceptable_worker_ids_path = os.path.join(
        self.output_folder, 'unacceptable_worker_ids.txt'
    )
    self.win_rate_by_date_path = os.path.join(
        self.output_folder, 'win_rates_by_date.csv'
    )
    self.stat_mean_length_by_date_path = os.path.join(
        self.output_folder, 'stat_mean_length_by_date.csv'
    )
    self.completion_time_by_model_pair_path = os.path.join(
        self.output_folder, 'mean_completion_times.csv'
    )

    self.acceptability_checker = AcceptabilityChecker()

    # Set fields that should be empty strings if the relevant information is not
    # present
    blank_field_columns = [
        'human_text',
        'human_choice',
        'human_justification',
        'accepted_bot_text',
        'not_accepted_bot_text',
    ]
    self.blank_fields = {field: '' for field in blank_field_columns}

    # Results attributes (populated during compilation)
    self.stat_counts = {}
    self.mean_completion_time = None
def __init__(
    self,
    opt,
    agents=None,
    shared=None,
    num_turns=6,
    tag=None,
    max_resp_time=120,
    agent_timeout_shutdown=120,
    context_info: Optional[dict] = None,
):
    """
    Set up a two-sided chat world.

    With the default of 6 turns per side (12 total), the conversation comes to
    14 messages once the "Hi!" opener and the first bot utterance are counted.
    """
    self.agents = agents
    self.task_turn_idx = 0
    self.num_turns = num_turns
    self.dialog = []
    self.tag = tag
    self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
    self.chat_done = False

    # Personas come from the optional context info; without it there are none.
    self.context_info = context_info if context_info is not None else {}
    if context_info is None:
        self.personas = None
    else:
        self.personas = [
            context_info['persona_1_strings'],
            context_info['persona_2_strings'],
        ]

    self.check_acceptability = opt['check_acceptability']
    self.acceptability_checker = AcceptabilityChecker()

    # Timeout protocols (both in seconds)
    self.max_resp_time = max_resp_time
    self.agent_timeout_shutdown = agent_timeout_shutdown

    print(
        f'Creating {self.__class__.__name__} for tag {tag} with {num_turns} turns.'
    )

    # Parent initializer runs last, as in the original implementation.
    super().__init__(opt, agents, shared)
def __init__(self, opt: Dict[str, Any]): super().__init__(opt) # Input args assert len(self.results_folders) > 0 for folder in self.results_folders: assert os.path.isdir(folder), f'{folder} is not a valid folder!' os.makedirs(self.output_folder, exist_ok=True) self.start_date = opt['start_date'] self.max_convos_per_worker = opt['max_convos_per_worker'] self.min_word_count = opt['min_word_count'] self.hit_block_list = opt['hit_block_list'].split(',') self.worker_block_list = opt['worker_block_list'].split(',') # Setting up problem buckets self.regular_buckets = [ bucket for bucket in self.problem_buckets if bucket not in ['other', 'none_all_good'] ] # Remove the buckets that are special cases self.acceptability_checker = AcceptabilityChecker()
class ModelChatResultsCompiler(AbstractResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as
    aggregated up to the level of each worker as a whole.
    """

    @classmethod
    def setup_args(cls):
        parser = super().setup_args()
        parser.add_argument(
            '--filter-uniform-hits',
            action='store_true',
            help='Filter out any HITs in which the worker\'s annotations were the exact same on each turn of the conversation',
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):
        super().__init__(opt)

        # Input args
        os.makedirs(self.output_folder, exist_ok=True)
        # TODO: see if this can be moved to the superclass
        self.filter_uniform_hits = opt['filter_uniform_hits']

        # Save paths
        self.unacceptable_worker_ids_path = os.path.join(
            self.output_folder, 'unacceptable_worker_ids.txt')
        self.annotation_selection_rate_path = os.path.join(
            self.output_folder, 'annotation_selection_rates.csv')
        self.likert_score_stat_path = os.path.join(self.output_folder,
                                                   'likert_score_stats.csv')

        self.acceptability_checker = AcceptabilityChecker()

    def get_results_path_base(self) -> str:
        """Return the base path (no extension) under which results are saved."""
        return os.path.join(self.output_folder, 'results')
        # TODO: see if this can be moved to the superclass

    def compile_results(self) -> pd.DataFrame:
        """
        Compile all acceptable conversations into one long-format DataFrame.

        Loads task units from Mephisto, filters out units with missing save
        data, wrong statuses, acceptability violations, and (optionally)
        uniform annotations; writes auxiliary stats files; and returns one row
        per conversation turn.

        :raises ValueError: if no acceptable conversations are found, if a HIT
            has no annotations, or if partial annotation/rating data exists.
        """
        # Load task data
        logging.info('Retrieving task data from Mephisto.')
        task_units_data = self.get_task_data()
        logging.info(
            f'Data for {len(task_units_data)} units loaded successfully.')

        num_convos_with_no_save_data = 0
        num_wrong_status_convos = 0
        num_complete_convos = 0
        unacceptable_task_units = []
        unacceptable_worker_ids = []
        conversation_idx = 0
        conversation_dfs = []
        for task_unit in task_units_data:

            worker_id = task_unit['worker_id']
            assignment_id = task_unit['assignment_id']

            # Skipping this conversation if save data is not found or the status is
            # invalid
            if task_unit['data']['save_data'] is None:
                logging.info('Found a task unit with no save data! Skipping.')
                num_convos_with_no_save_data += 1
                continue
            elif task_unit['status'] not in ['completed', 'approved']:
                logging.info(
                    f'Found a HIT with the status "{task_unit["status"]}"! '
                    f'Skipping.')
                num_wrong_status_convos += 1
                continue
            else:
                num_complete_convos += 1

            # Extract out useful conversation-level data
            custom_data = task_unit['data']['save_data']['custom_data']
            mturk_worker_id = Worker.get(self.get_mephisto_db(),
                                         worker_id).worker_name
            task_start = datetime.utcfromtimestamp(task_unit['task_start'])
            task_end = datetime.utcfromtimestamp(task_unit['task_end'])
            info_dict = {
                ('worker_id', ''): worker_id,
                ('mturk_worker_id', ''): mturk_worker_id,
                ('unit_id', ''): task_unit['unit_id'],
                ('assignment_id', ''): assignment_id,
                ('conversation_idx', ''): conversation_idx,
                ('date', ''): task_start.strftime('%Y-%m-%d'),
                ('completion_time', ''):
                (task_end - task_start).total_seconds(),
            }

            # Check that the conversation consists of pairs of comments between
            # Speaker 1 and Speaker 2, with Speaker 1 speaking first
            assert 'final_rating' in task_unit['data']['messages'][-1][
                'task_data']
            convo_messages = [m for m in task_unit['data']['messages'][:-1]]
            # The final message is just a final rating
            # BUGFIX: the conditional must be parenthesized. Previously this read
            # `message['id'] == 'Speaker 2' if message_idx % 2 else 'Speaker 1'`,
            # which for even indices evaluated to the truthy string 'Speaker 1'
            # instead of a comparison, so Speaker 1 turns were never checked.
            assert all(
                message['id'] == ('Speaker 2' if message_idx % 2 else 'Speaker 1')
                for message_idx, message in enumerate(convo_messages)
            )
            messages_1 = [m for m in convo_messages if m['id'] == 'Speaker 1']
            messages_2 = [m for m in convo_messages if m['id'] == 'Speaker 2']
            assert len(messages_1) + len(messages_2) == len(convo_messages)

            # Determine whether the HIT contains unacceptable messages. (We do this for
            # every HIT, even if acceptability violation info was already saved, because
            # the violation criteria may have changed since the HIT was collected.)
            utterances_1 = [m['text'] for m in messages_1]
            assert utterances_1[0] == 'Hi!', (
                'This script assumes that the first human message is "Hi!", which is '
                'set by default and cannot be changed by the crowdsourcing worker.'
            )
            acceptability_violations = self.acceptability_checker.check_messages(
                messages=utterances_1[1:],  # Don't use the initial "Hi!"
                is_worker_0=True,
                violation_types=self.acceptability_checker.ALL_VIOLATION_TYPES,
            )
            # Here, "worker 0" refers to Speaker 1, because we mix 0- and 1-indexing
            if acceptability_violations != '':
                logging.info(
                    f'Conversation fails acceptability checks with a violation of '
                    f'"{acceptability_violations}", given the following utterances: '
                    f'{utterances_1[1:]}. Skipping.')
                unacceptable_task_units.append(task_unit)
                assert (
                    mturk_worker_id is not None
                ), "MTurk worker ID cannot be determined for this unacceptable conversation!"
                unacceptable_worker_ids.append(mturk_worker_id)
                continue

            # Ignore the conversation if ratings for all turns are the same, because
            # it's somewhat implausible that *all* turns in a conversation should garner
            # the same rating of engagingness, humanness, interestingness, or none.
            # (However, don't put these workers on the "unacceptable worker IDs" list,
            # to give them a little bit of the benefit of the doubt: i.e. maybe the
            # worker just didn't try hard enough to find which responses were more
            # engaging, etc. than others, but that doesn't mean that all of their HITs
            # across all evals are bad and should be removed.)
            if self.filter_uniform_hits:
                annotations = [
                    m['task_data']['problem_data_for_prior_message']
                    for m in task_unit['data']['messages']
                    if 'problem_data_for_prior_message' in m.get(
                        'task_data', {})
                ]
                hashable_annotations = [
                    tuple(a[key] for key in sorted(a.keys()))
                    for a in annotations
                ]
                unique_annotations = set(hashable_annotations)
                if len(unique_annotations) < 1:
                    raise ValueError('No annotations found for this HIT!')
                elif len(unique_annotations) == 1:
                    logging.info(
                        f'All model responses in the conversation received the same '
                        f'annotation: {hashable_annotations[0]}. Skipping.')
                    unacceptable_task_units.append(task_unit)
                    continue

            single_turn_dicts = []

            # Compile personas and previous utterances
            text_parts = []
            if custom_data['personas'] is not None and len(
                    custom_data['personas']) > 0:
                assert len(custom_data['personas']) == 2
                text_parts += [
                    'HUMAN PERSONA: ' + ' '.join(custom_data['personas'][0]),
                    'BOT PERSONA: ' + ' '.join(custom_data['personas'][1]),
                ]
            if (custom_data['additional_context'] is not None
                    and len(custom_data['additional_context']) > 0):
                text_parts.append('ADDITIONAL CONTEXT: ' +
                                  custom_data['additional_context'])
            single_turn_dicts.append({
                **info_dict,
                ('context', ''): ' '.join(text_parts)
            })

            # Loop over conversation turns
            turns_per_speaker = defaultdict(int)
            for message in task_unit['data']['messages']:
                if 'text' in message:

                    speaker_id = message['id']

                    # Add in annotation results, if they exist
                    if 'problem_data_for_prior_message' in message.get(
                            'task_data', {}):
                        bucket_data = {
                            ('annotation_bucket', bucket): value
                            for bucket, value in message['task_data']
                            ['problem_data_for_prior_message'].items()
                        }
                    else:
                        bucket_data = {}

                    # Add in results from the final rating(s), if they exist
                    if 'final_rating' in message.get('task_data', {}):
                        ratings = message['task_data']['final_rating'].split(
                            '|')
                        final_rating_data = {
                            ('final_rating', str(idx)): value
                            for idx, value in enumerate(ratings)
                        }
                    else:
                        final_rating_data = {}

                    turns_per_speaker[speaker_id] += 1

                    single_turn_dicts.append({
                        **info_dict,
                        ('speaker_id', ''): speaker_id,
                        ('speaker_turn_idx', ''): turns_per_speaker[speaker_id],
                        ('text', ''): message['text'].replace('\n', '__newline__'),
                        **bucket_data,
                        **final_rating_data,
                    })

            # Adding the full conversation to the list of conversations
            single_turn_series = [
                pd.Series(dict_).to_frame().transpose()
                for dict_ in single_turn_dicts
            ]
            single_convo_df = pd.concat(single_turn_series, axis=0, sort=False)
            conversation_dfs.append(single_convo_df)
            conversation_idx += 1

        logging.info(
            f'{num_convos_with_no_save_data:d} conversations found with no save data.'
        )
        logging.info(
            f'{num_wrong_status_convos:d} conversations found with the wrong status.'
        )
        logging.info(f'{num_complete_convos:d} complete conversations found:')
        logging.info(
            f'\t{len(unacceptable_task_units):d} unacceptable conversations.')
        logging.info(f'\t{len(conversation_dfs):d} acceptable conversations.')

        # Compile results across all conversations
        if len(conversation_dfs) == 0:
            raise ValueError('No acceptable conversations found!')
        unordered_conversation_df = pd.concat(conversation_dfs, axis=0)
        initial_ordered_columns = list(info_dict.keys()) + [
            ('context', ''),
            ('speaker_id', ''),
            ('speaker_turn_idx', ''),
            ('text', ''),
        ]
        all_ordered_columns = initial_ordered_columns + [
            col for col in unordered_conversation_df.columns
            if col not in initial_ordered_columns
        ]
        conversation_df = unordered_conversation_df[all_ordered_columns]
        # TODO: is there a less hacky way than this, which relies on the most recent
        # value of `info_dict`, to put the columns back into the right order?

        # Calculate and save auxiliary stats
        logging.info(
            f'Saving MTurk IDs of workers with unacceptable conversations to '
            f'{self.unacceptable_worker_ids_path}.')
        with open(self.unacceptable_worker_ids_path, 'w') as f:
            for worker_id in unacceptable_worker_ids:
                f.write(worker_id + '\n')

        # Calculate rates of selecting various annotation buckets
        annotation_bucket_df = conversation_df['annotation_bucket'].dropna(
            axis=0, how='any')
        if annotation_bucket_df.isna().sum().sum() > 0:
            raise ValueError(
                'There is at least one row in which only partial annotation bucket data exists!'
            )
        annotation_selection_rate_df = annotation_bucket_df.mean().to_frame(
            'selection_rate')
        annotation_selection_rate_df.to_csv(
            self.annotation_selection_rate_path)
        logging.info(
            f'Annotation bucket selection rates saved to {self.annotation_selection_rate_path}.'
        )
        output_strings = [
            f'{series.name}: {100*series["selection_rate"]:0.0f}%'
            for _, series in annotation_selection_rate_df.iterrows()
        ]
        logging.info('Annotation bucket selection rates:\n' +
                     '\n'.join(output_strings))

        # Calculate Likert score stats
        final_rating_df = conversation_df['final_rating'].dropna(axis=0,
                                                                 how='any')
        if final_rating_df.isna().sum().sum() > 0:
            raise ValueError(
                'There is at least one row in which only partial final rating data exists!'
            )
        likert_score_stat_df = final_rating_df.astype(int).describe()
        likert_score_stat_df.to_csv(self.likert_score_stat_path)
        logging.info(
            f'Likert score statistics saved to {self.likert_score_stat_path}.')
        logging.info(
            f'Mean Likert scores:\n{likert_score_stat_df.loc["mean"]}')

        return conversation_df
class BaseModelChatWorld(CrowdTaskWorld, ABC):
    """
    Abstract world for a chat between one human crowdsourcing worker and one
    model ("bot"), with per-turn problem annotations and a final rating.
    """

    def __init__(self, opt, agent, bot):
        """
        :param opt: task options; keys read here: 'num_turns', 'max_resp_time',
            'is_sandbox', 'check_acceptability', 'block_qualification'.
        :param agent: the human worker's agent (wraps a Mephisto agent).
        :param bot: the model agent.
        """
        super().__init__(opt, agent)

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance
        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bot = bot
        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs

        print(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

    def __add_problem_data_to_utterance(self, p, turn_idx: int):
        """
        Attach problem data to the bot's prior utterance, given by turn_idx.
        """
        print(p)
        assert (self.dialog[turn_idx]['agent_idx'] == 1
                ), 'Problem data must be attached to a bot utterance.'
        assert ('problem_data' not in self.dialog[turn_idx]
                ), "Don't overwrite existing problem data!"
        self.dialog[turn_idx]['problem_data'] = p

    def parley(self):
        """
        Run one round of the conversation: the initial turn on the first call,
        then one human act and one bot act per call until a final rating is
        received, at which point chat data is saved and the world ends.
        """
        print(
            f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:
            self._run_initial_turn()
            self.task_turn_idx += 1
            return

        # Otherwise, we proceed accordingly.
        # (This was previously a bare string expression, i.e. a no-op statement;
        # converted to a real comment.)
        print(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}'
        )
        acts = [None, None]
        for idx, agent in enumerate([self.agent, self.bot]):
            if not self.chat_done:
                acts[idx] = agent.act(timeout=self.max_resp_time)
                acts[idx] = Compatibility.maybe_fix_act(acts[idx])
                if 'metrics' in acts[idx]:
                    del acts[idx]['metrics']
                    # Metrics can't be saved to JSON and are not needed here
                print(
                    f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
                )

            # A 'final_rating' in the act signals that the human has ended the chat
            if acts[idx].get('task_data', {}).get('final_rating') is not None:

                self.chat_done = True
                # agent ends chat after exceeding minimum number of turns

                if self.task_turn_idx > self.num_turns:
                    # Human has just responded. Any problem data received now will be
                    # regarding the bot's prior utterance
                    p = acts[idx]['task_data'].get(
                        'problem_data_for_prior_message')
                    if p is not None:
                        turn_idx = -1
                        # Attach the problem data to the last utterance, since the human
                        # hasn't said anything since then
                        self.__add_problem_data_to_utterance(p,
                                                             turn_idx=turn_idx)

                # Save the final chat data
                time_string = time.strftime('%Y%m%d_%H%M%S')
                chat_data_folder = self.opt['chat_data_folder']
                os.makedirs(chat_data_folder, exist_ok=True)
                chat_data_path = os.path.join(
                    chat_data_folder,
                    f'{time_string}_{np.random.randint(0, 1000)}_{self.task_type}.json',
                )
                final_chat_data = self.get_final_chat_data()
                self.agent.mephisto_agent.state.messages.append(
                    {'final_chat_data': final_chat_data})
                # Append the chat data directly to the agent state's message list in
                # order to prevent the worker from seeing a new text response in the UI
                with open(chat_data_path, 'w+') as f_json:
                    data_str = json.dumps(final_chat_data)
                    f_json.write(data_str)
                print(f'{self.__class__.__name__}:{self.tag}: Data saved at '
                      f'{chat_data_path} for model: {self.bot.worker_id}.')

                # Soft-block the worker if there were acceptability violations
                acceptability_violations = final_chat_data[
                    'acceptability_violations'][0]
                if (acceptability_violations is not None
                        and acceptability_violations != ''):
                    print(
                        f'**NOTE** Acceptability violations detected: {acceptability_violations}'
                    )
                    # Grant the failed qualification
                    self.agent.mephisto_agent.get_worker().grant_qualification(
                        self.block_qualification, 1)

                return

            else:

                utterance_data = {
                    'agent_idx': idx,
                    # Get rid of annotations HTML if it's the bot response
                    'text': acts[idx]['text'].split('<br>')[0],
                    'id': acts[idx]['id'] if 'id' in acts[idx] else 'NULL_ID',
                    # Person1 or Polyencoder
                }
                self.dialog.append(utterance_data)
                if idx == 0:
                    # Human has just responded. Any problem data received now will be
                    # regarding the bot's prior utterance
                    p = acts[idx]['task_data'].get(
                        'problem_data_for_prior_message')
                    if p is not None:
                        turn_idx = -2
                        # Attach the problem data to the second-to-last utterance, since
                        # the last utterance is what the human just said
                        self.__add_problem_data_to_utterance(p,
                                                             turn_idx=turn_idx)

                self._postprocess_acts(acts=acts, agent_idx=idx)
                for other_agent in [self.agent, self.bot]:
                    if other_agent != agent:
                        other_agent.observe(validate(acts[idx]))
                print(
                    f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}'
                )
                self.task_turn_idx += 1

    @abstractmethod
    def _run_initial_turn(self) -> None:
        """
        Runs logic for the first turn of the human and the bot.
        """

    def _postprocess_acts(self, acts: List[dict], agent_idx: int):
        """
        Optionally perform further processing of the acts.

        Useful for subclasses. Will be executed after saving act data to
        self.dialog but before showing the act to the other agent.
        """

    def shutdown(self):
        """Record a completed run for this model and shut the human agent down."""
        if self.chat_done:
            self.opt['run_statistics'][self.bot.worker_id] += 1
            print('Runs completed per model: ' + ', '.join(
                f'{model}: {count:d}'
                for model, count in self.opt['run_statistics'].items()))

        self.agent.shutdown()

    def episode_done(self):
        """Return True once the final rating has been received."""
        return self.chat_done

    def get_final_chat_data(self) -> Dict[str, Any]:
        """
        Return specific info about the conversation, the context, acceptability, etc.
        """
        if self.check_acceptability:
            human_messages, violation_types = self._prepare_acceptability_checking(
            )
            violations_string = self.acceptability_checker.check_messages(
                messages=human_messages,
                is_worker_0=False,
                violation_types=violation_types,
            )
        else:
            violations_string = None

        data = {
            'dialog': self.dialog,
            'workers': [get_mturk_id_from_mephisto_wrapper(self.agent)],
            'bad_workers': [],
            # Stored as a tuple for compatibility with a human/human conversation
            # in which we check both sides for acceptability. (Previously this
            # key was also redundantly re-assigned the same tuple when
            # check_acceptability was set; the duplicate assignment is removed.)
            'acceptability_violations': (violations_string, ),
            'hit_ids': [self.agent.mephisto_agent.task_run_id],
            'assignment_ids': [self.agent.mephisto_agent.assignment_id],
            'task_description': {
                'annotations_config': self.opt['annotations_config'],
                'model_nickname': self.bot.worker_id,
                'model_file': self.bot.model_agent.opt.get('model_file'),
                'model_opt': self.bot.model_agent.opt,
            },
        }
        # 'bad_workers' is for compatibility. Before, it was only non-empty if a
        # worker abandoned, returned, etc. a HIT, but now we don't even save chat
        # data in that case
        return data

    def _prepare_acceptability_checking(self) -> Tuple[List[str], List[str]]:
        """
        Return the list of human messages and the list of acceptability types to check.
        """
        human_messages = [
            message['text'] for message in self.dialog
            if message['agent_idx'] == 0
        ]
        violation_types = ['min_words', 'all_caps', 'exact_match', 'safety']
        return human_messages, violation_types
def test_sample_inputs(self):
    """
    Test sample inputs/outputs for the acceptability checker.

    Each test case supplies a list of messages, whether the speaker is the
    conversation opener ("worker 0"), and the violation string expected from
    check_messages(). A final set of cases checks that unknown violation types
    raise ValueError.
    """
    # Define test cases
    test_cases = [
        {
            # Should pass
            'messages': [
                'Hi - how are you?',
                'What? Whatever for?',
                'Wow, that sounds like a lot of work.',
                "No, I don't expect he would be too happy about that either.",
                "I don't even know where you would find that many squirrels.",
                'Well, let me know if you need an extra hand.',
            ],
            'is_worker_0': False,
            'expected_violations': '',
        },
        {
            # Should fail: messages are too short
            'messages':
            ['Hi', 'What?', 'Wow', "No", "I don't even know", 'Well,'],
            'is_worker_0': False,
            'expected_violations': 'under_min_length',
        },
        {
            # Should fail, because the first worker shouldn't start with a greeting
            'messages': [
                'Hi - how are you?',
                'What? Whatever for?',
                'Wow, that sounds like a lot of work.',
                "No, I don't expect he would be too happy about that either.",
                "I don't even know where you would find that many squirrels.",
                'Well, let me know if you need an extra hand.',
            ],
            'is_worker_0': True,
            'expected_violations': 'starts_with_greeting',
        },
        {
            # Should fail: too many all-caps messages
            'messages': [
                'HEYYYYYYY',
                'What? Whatever for?',
                'Wow, that sounds like a lot of work.',
                "No, I don't expect he would be too happy about that either.",
                "I don't even know where you would find that many squirrels.",
                'WELLLLL LEMME KNOOOOOO',
            ],
            'is_worker_0': False,
            'expected_violations': 'too_much_all_caps',
        },
        {
            # Should fail: the last message exactly repeats the first
            'messages': [
                'Hi - how are you?',
                'What? Whatever for?',
                'Wow, that sounds like a lot of work.',
                "No, I don't expect he would be too happy about that either.",
                "I don't even know where you would find that many squirrels.",
                'Hi - how are you?',
            ],
            'is_worker_0': False,
            'expected_violations': 'exact_match',
        },
        {
            # Should fail: unsafe content. The ':7' presumably marks the
            # position of the offending message — confirm against
            # AcceptabilityChecker's reporting format.
            'messages': [
                'Hi - how are you?',
                'What? Whatever for?',
                'Wow, that sounds like a lot of work.',
                "No, I don't expect he would be too happy about that either.",
                "I don't even know where you would find that many squirrels.",
                'Well, let me know if you need an extra hand.',
                "I'm gonna say something that's totally XXX!",
            ],
            'is_worker_0': False,
            'expected_violations': 'unsafe:7',
        },
    ]
    # Cases where check_messages itself should raise
    test_cases_with_errors = [{
        'messages': ['Message 1', 'Message 2'],
        'is_worker_0': True,
        'violation_types': ['non_existent_violation_type'],
        'expected_exception': ValueError,
    }]

    # Create checker
    acceptability_checker = AcceptabilityChecker()

    # Run through violation test cases
    for test_case in test_cases:
        actual_violations = acceptability_checker.check_messages(
            messages=test_case['messages'],
            is_worker_0=test_case['is_worker_0'],
            violation_types=acceptability_checker.possible_violation_types,
        )
        self.assertEqual(actual_violations,
                         test_case['expected_violations'])

    # Run through test cases that should raise an error
    for test_case in test_cases_with_errors:
        with self.assertRaises(test_case['expected_exception']):
            acceptability_checker.check_messages(
                messages=test_case['messages'],
                is_worker_0=test_case['is_worker_0'],
                violation_types=test_case['violation_types'],
            )
class PerTurnEvalWorld(ModelChatWorld):
    """
    Chat world in which a human talks to two bots simultaneously: on every pair of
    turns, both bots produce a candidate response, the human picks one (with a
    justification), and both bots' histories are kept in sync with the chosen
    response.

    See parley() for the turn-index state machine that drives the task.
    """

    def __init__(self, opt, agent, bots, context_info: Optional[dict] = None):
        """
        Initialize conversation state, acceptability checking, and timeouts.

        :param opt: task options dict (num_turns, max_resp_time, etc.)
        :param agent: the human Mephisto agent
        :param bots: the two bot agents being compared
        :param context_info: optional dict of persona/context strings; when absent,
            personas are disabled
        """
        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance
        # TODO: this logic is very close to that of BaseModelChatWorld.__init__(). Can
        # any of this be deduplicated?
        # Skips BaseModelChatWorld.__init__ and calls its parent's directly
        super(BaseModelChatWorld, self).__init__(opt, agent)

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']
        self.opt = opt
        self.bots = bots
        self.task_turn_idx = 0
        self.num_turns = num_turns
        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        self.final_chat_data = None
        # TODO: remove this attribute once chat data is only stored in the Mephisto
        # TaskRun for this HIT (see .get_custom_task_data() docstring for more
        # information)

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        logging.info(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None

    def __add_problem_data_to_utterance(self, p, turn_idx: int) -> None:
        """
        Attach problem data to the bot's prior utterance, given by turn_idx.
        """
        logging.info(p)
        assert (
            self.dialog[turn_idx]['agent_idx'] == 1
        ), 'Problem data must be attached to a bot utterance.'
        assert (
            'problem_data' not in self.dialog[turn_idx]
        ), "Don't overwrite existing problem data!"
        self.dialog[turn_idx]['problem_data'] = p

    def parley(self) -> None:
        """
        The main function that controls the logic of the task. Specifically, when
        self.task_turn_idx is even, we know that the bots just gave their potential
        responses, and that it is the human's turn to choose one of the responses and
        give a justification value.

        When self.task_turn_idx is odd, we know that the human just chose one of the
        bots' responses, and now needs to respond to that response.

        self.task_turn_idx is initially 0, and during _run_initial_turn() the UI is
        redrawn to have the human select between the bots' responses. Then,
        self.task_turn_idx is incremented to 1.

        During self.agent.observe(), the UI is redrawn for the following human input,
        and during self.agent.act(), the code awaits the human input.
        """
        logging.info(
            f'{self.__class__.__name__}:{self.tag}: is at task_turn_idx '
            f'{self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:
            self._run_initial_turn()
            self.task_turn_idx += 1
            return

        logging.info(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: '
            f'{self.task_turn_idx}'
        )

        # At this point, we know that the human now needs to respond to the bot's
        # response that the human just chose
        # We retrieve information regarding the human's choice and justification using
        # self.agent.act()
        human_choose_bot_response_act = self.agent.act(timeout=self.max_resp_time)
        human_choose_bot_response_act = Message(
            Compatibility.maybe_fix_act(human_choose_bot_response_act)
        ).json_safe_payload()

        logging.info(
            f'Got act for human, act was: {human_choose_bot_response_act} and '
            f'self.task_turn_idx: {self.task_turn_idx}.'
        )

        # Unpack the human's selection and justification from the frontend payload
        accepted_bot_response = human_choose_bot_response_act['task_data'][
            'accepted_bot_response'
        ]
        accepted_bot_id = human_choose_bot_response_act['task_data']['accepted_bot_id']
        accepted_bot_justification_value = human_choose_bot_response_act['task_data'][
            'justification_value'
        ]
        not_accepted_bot_response = human_choose_bot_response_act['task_data'][
            'not_accepted_bot_response'
        ]
        not_accepted_bot_id = human_choose_bot_response_act['task_data'][
            'not_accepted_bot_id'
        ]

        # We have both bots observe the accepted bot's response so that the conversation
        # history stays the same
        self.bots[0].observe(accepted_bot_response)
        self.bots[1].observe(accepted_bot_response)

        task_data = {}

        accepted_bot_utterance_data = {
            # Strip any annotations HTML appended after the first <br>
            'text': accepted_bot_response['text'].split('<br>')[0],
            'id': accepted_bot_id,
        }
        not_accepted_bot_utterance_data = {
            'text': not_accepted_bot_response['text'].split('<br>')[0],
            'id': not_accepted_bot_id,
        }
        bot_utterance_data = {
            'agent_idx': 1,
            'accepted_bot_data': accepted_bot_utterance_data,
            'not_accepted_bot_data': not_accepted_bot_utterance_data,
            'human_choice': accepted_bot_id,
            'human_justification': accepted_bot_justification_value,
        }
        self.dialog.append(bot_utterance_data)

        self._postprocess_acts(acts=None, agent_idx=0)

        # All logic and processing for this step has now been done, so we do
        # self.agent.observe() to send the accepted response back to the frontend to
        # display and update task turn index, as well as await for the next action,
        # which is the human typing their response
        task_data['task_turn_idx'] = self.task_turn_idx

        # The UI will ask the human to respond to the chosen bot response
        self.agent.observe(
            {'text': accepted_bot_response['text'], 'task_data': task_data}
        )

        # Make self.task_turn_idx even now
        self.task_turn_idx += 1

        # Check for whether 6 pairs of turns has been done, since the last message of a
        # conversation should always be the bot's response
        if (
            human_choose_bot_response_act is not None
            and human_choose_bot_response_act.get('task_data', {}).get('finished')
            is not None
        ):
            self.chat_done = True
            # agent ends chat after exceeding minimum number of turns

            # Bot has just responded. Any problem data received now will be
            # regarding this bot's utterance

            # Get the final chat data
            self.final_chat_data = self.get_final_chat_data()

            # Soft-block the worker if there were acceptability violations
            acceptability_violations = self.final_chat_data[
                'acceptability_violations'
            ][0]
            if acceptability_violations is not None and acceptability_violations != '':
                logging.info(
                    f'**NOTE** Acceptability violations detected: '
                    f'{acceptability_violations}'
                )
                # Grant the failed qualification
                self.agent.mephisto_agent.get_worker().grant_qualification(
                    self.block_qualification, 1
                )

            return

        logging.info(
            f'[human agent] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: '
            f'{self.dialog}'
        )

        # NOTE(review): this log line duplicates the "Got act for human" message
        # logged above for the same act — presumably leftover from copy/paste
        logging.info(
            f'Got act for human, act was: {human_choose_bot_response_act} and '
            f'self.task_turn_idx: {self.task_turn_idx}.'
        )

        # At this point, we know that the human now needs to respond to the bot's
        # response that the human just chose
        # We retrieve information regarding the human's response using self.agent.act()
        human_response_act = self.agent.act(timeout=self.max_resp_time)

        # Again, we have both bots observe the human response so that the conversation
        # history stays the same
        self.bots[0].observe(validate(human_response_act))
        self.bots[1].observe(validate(human_response_act))

        # Check that the models' conversation histories are the same
        bot_1_history = self.bots[0].model_agent.history.history_strings
        bot_2_history = self.bots[1].model_agent.history.history_strings
        assert (
            bot_1_history == bot_2_history
        ), f"The two bots' conversation histories are different.\nBot 1 history: {bot_1_history}\nBot 2 history: {bot_2_history}"

        # After the bots have observed the human response, it's time for them to produce
        # their response to the human using self.bots.act()
        bot_1_response = self.bots[0].act()
        bot_1_response = Compatibility.maybe_fix_act(bot_1_response)

        bot_2_response = self.bots[1].act()
        bot_2_response = Compatibility.maybe_fix_act(bot_2_response)

        # We display the result to the frontend randomly so there is no selection bias.
        # Also, we attach our result to task_data to send arbitrary data to the frontend
        if random.random() > 0.5:
            task_data = {
                'top_bot_data': {
                    'top_bot_id': self.bots[0].worker_id,
                    'top_bot_response': bot_1_response,
                },
                'bottom_bot_data': {
                    'bottom_bot_id': self.bots[1].worker_id,
                    'bottom_bot_response': bot_2_response,
                },
                'task_turn_idx': self.task_turn_idx,
            }
        else:
            task_data = {
                'top_bot_data': {
                    'top_bot_id': self.bots[1].worker_id,
                    'top_bot_response': bot_2_response,
                },
                'bottom_bot_data': {
                    'bottom_bot_id': self.bots[0].worker_id,
                    'bottom_bot_response': bot_1_response,
                },
                'task_turn_idx': self.task_turn_idx,
            }

        human_utterance_data = {
            'agent_idx': 0,
            # Get rid of annotations HTML if it's the bot response
            'text': human_response_act['text'].split('<br>')[0],
            'id': human_response_act['id']
            if 'id' in human_response_act
            else 'NULL_ID',  # Person1 or Polyencoder
        }

        self.dialog.append(human_utterance_data)

        # Human has just responded. Any problem data received now will be regarding the
        # bot's prior utterance
        p = human_response_act['task_data'].get('problem_data_for_prior_message')
        if p is not None:
            turn_idx = -2
            # Attach the problem data to the second-to-last utterance, since the last
            # utterance is what the human just said
            self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)

        self._postprocess_acts(acts=None, agent_idx=0)

        task_data['task_turn_idx'] = self.task_turn_idx

        # All logic and processing for this step has now been done, so we do
        # self.agent.observe() to send the two bots' responses back to the frontend to
        # display and update task turn index, as well as await for the next action,
        # which is the human choosing from the two responses and providing a
        # justification value

        # The UI will ask the human to choose between two bot responses and give a
        # justification
        logging.info(f'*** self.task_turn_idx: {self.task_turn_idx} ***')
        self.agent.observe({'text': '', 'task_data': task_data})

        # Make self.task_turn_idx odd now
        self.task_turn_idx += 1

        logging.info(
            f'[bot agent] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: '
            f'{self.dialog}'
        )

    def get_final_chat_data(self) -> Dict[str, Any]:
        """
        Add relevant fields to the final chat data.

        Returns a dict of dialog, worker, acceptability, and per-model metadata,
        merged with the persona/context information stored on this world.
        """
        if self.check_acceptability:
            human_messages, violation_types = self._prepare_acceptability_checking()
            violations_string = self.acceptability_checker.check_messages(
                messages=human_messages,
                is_worker_0=False,
                violation_types=violation_types,
            )
        else:
            violations_string = None

        data = {
            'dialog': self.dialog,
            'workers': [get_mturk_id_from_mephisto_wrapper(self.agent)],
            'bad_workers': [],
            'acceptability_violations': (violations_string,),
            # NOTE(review): stores the Mephisto task_run_id under 'hit_ids' —
            # presumably for legacy schema compatibility; confirm downstream readers
            'hit_ids': [self.agent.mephisto_agent.task_run_id],
            'assignment_ids': [self.agent.mephisto_agent.assignment_id],
            'task_description': {
                'annotations_config': self.opt['annotations_config'],
                'model_1_nickname': self.bots[0].worker_id,
                'model_1_file': self.bots[0].model_agent.opt.get('model_file'),
                'model_1_opt': self.bots[0].model_agent.opt,
                'model_2_nickname': self.bots[1].worker_id,
                'model_2_file': self.bots[1].model_agent.opt.get('model_file'),
                'model_2_opt': self.bots[1].model_agent.opt,
            },
        }
        # 'bad_workers' is for compatibility. Before, it was only non-empty if a
        # worker abandoned, returned, etc. a HIT, but now we don't even save chat
        # data in that case
        if self.check_acceptability:
            # NOTE(review): redundant — 'acceptability_violations' is already set to
            # this same tuple in the dict literal above
            data['acceptability_violations'] = (violations_string,)
            # Make a tuple for compatibility with a human/human conversation in
            # which we check both sides for acceptability

        context_data = {
            'personas': self.personas,
            'context_dataset': self.context_info.get('context_dataset'),
            'person1_seed_utterance': self.context_info.get('person1_seed_utterance'),
            'person2_seed_utterance': self.context_info.get('person2_seed_utterance'),
            'additional_context': self.context_info.get('additional_context'),
        }
        data.update(context_data)
        return data

    def _run_initial_turn(self) -> None:
        """
        Run the initial turn for both the human and the bot.

        Optionally show the bot its persona. If we are in Meena-like conversation mode
        show "Hi!" to the human and the bot and let the bot respond accordingly.

        Check parley() function for more information on the main logic.
        """
        control_msg = {"episode_done": False}

        if self.opt['include_persona']:
            # The Bot agent
            # We add the personas and 1/3 of the time WoW topic as the
            # first utterance in the history.
            # Previously for BST task, we also had a big first utterance
            # that gave instructions. Removing that for this task.
            persona_strings = [s.strip() for s in self.personas[1]]
            persona_utterance = self._get_persona_utterance(
                persona_strings=persona_strings,
                context_dataset=self.context_info['context_dataset'],
                additional_context=self.context_info['additional_context'],
                is_bot=True,
            )
            message = control_msg.copy()
            message['text'] = persona_utterance
            # The bot seeing its persona does not count as a "turn"
            self.bots[0].observe(validate(message), increment_turn=False)
            self.bots[1].observe(validate(message), increment_turn=False)

        if self.opt['conversation_start_mode'] == 'hi':
            logging.info('[Displaying "Hi!" only as per Meena task.]')
            if self.personas is not None:
                human_persona_strings = [s.strip() for s in self.personas[0]]
            else:
                human_persona_strings = ['', '']
            human_first_msg = {
                'episode_done': False,
                'id': self.agent.id,
                'text': 'Hi!',
                'fake_start': True,
                'agent_idx': 0,
                'task_data': {
                    'human_persona_string_1': human_persona_strings[0],
                    'human_persona_string_2': human_persona_strings[1],
                    'prompt_instruction': self.opt['task_question'],
                },
            }
            for k, v in control_msg.items():
                human_first_msg[k] = v

            # The first message is always "Hi", so we have both bots observe the message
            self.dialog.append(human_first_msg)
            self.agent.observe(validate(human_first_msg))
            self.bots[0].observe(validate(human_first_msg))
            self.bots[1].observe(validate(human_first_msg))

            bot_1_response = self.bots[0].act()
            bot_1_response = Compatibility.maybe_fix_act(bot_1_response)

            bot_2_response = self.bots[1].act()
            bot_2_response = Compatibility.maybe_fix_act(bot_2_response)

            # Randomize which bot appears on top to avoid positional selection bias
            if random.random() > 0.5:
                task_data = {
                    'top_bot_data': {
                        'top_bot_id': self.bots[0].worker_id,
                        'top_bot_response': bot_1_response,
                    },
                    'bottom_bot_data': {
                        'bottom_bot_id': self.bots[1].worker_id,
                        'bottom_bot_response': bot_2_response,
                    },
                    'task_turn_idx': self.task_turn_idx,
                }
            else:
                task_data = {
                    'top_bot_data': {
                        'top_bot_id': self.bots[1].worker_id,
                        'top_bot_response': bot_2_response,
                    },
                    'bottom_bot_data': {
                        'bottom_bot_id': self.bots[0].worker_id,
                        'bottom_bot_response': bot_1_response,
                    },
                    'task_turn_idx': self.task_turn_idx,
                }

            # Need an initial human's observe to observe the two choices from the bot
            self.agent.observe({'text': '', 'task_data': task_data})
        else:
            raise ValueError(
                f"Conversation start mode {self.opt['conversation_start_mode']} "
                f"not recognized!"
            )

    def shutdown(self) -> None:
        """
        On a completed chat, bump the run count for this bot pair, then shut the
        human agent down.
        """
        if self.chat_done:
            self.opt['run_statistics'][
                f'{self.bots[0].worker_id}:{self.bots[1].worker_id}'
            ] += 1
            logging.info(
                'Runs completed per model: '
                + ', '.join(
                    f'{model}: {count:d}'
                    for model, count in self.opt['run_statistics'].items()
                )
            )

        self.agent.shutdown()
class TurnAnnotationsChatWorld(CrowdTaskWorld):
    """
    Chat world in which a human converses with a single bot and annotates each bot
    utterance with per-turn problem data; the finished conversation is written to
    disk and attached to the Mephisto agent state.
    """

    def __init__(self, opt, agent=None, bot=None, context_info: Optional[dict] = None):
        """
        Initialize conversation state, acceptability checking, and timeouts.

        :param opt: task options dict (num_turns, max_resp_time, etc.)
        :param agent: the human Mephisto agent
        :param bot: the bot agent
        :param context_info: optional dict of persona/context strings; when absent,
            personas are disabled
        """
        super().__init__(opt, agent)

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance
        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']
        self.opt = opt
        self.bot = bot
        self.task_turn_idx = 0
        self.num_turns = num_turns
        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        print(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

    def __add_problem_data_to_utterance(self, p, turn_idx: int) -> None:
        """
        Attach problem data to the bot's prior utterance, given by turn_idx.
        """
        print(p)
        assert (
            self.dialog[turn_idx]['agent_idx'] == 1
        ), 'Problem data must be attached to a bot utterance.'
        assert (
            'problem_data' not in self.dialog[turn_idx]
        ), "Don't overwrite existing problem data!"
        self.dialog[turn_idx]['problem_data'] = p

    def parley(self) -> None:
        """
        Run one step of the conversation.

        On the first call (task_turn_idx == 0) this seeds the conversation per
        'conversation_start_mode' ('bst' or 'hi'); afterwards it collects one act
        from the human and one from the bot, records annotations, and — once the
        human submits a final rating — saves the chat data and soft-blocks on
        acceptability violations.
        """
        print(
            f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        control_msg = {"episode_done": False}

        if self.task_turn_idx == 0:
            if self.opt['include_persona']:
                # The Bot agent
                # We add the personas and 1/3 of the time WoW topic as the
                # first utterance in the history.
                # Previously for BST task, we also had a big first utterance
                # that gave instructions. Removing that for this task.
                persona_strings = [s.strip() for s in self.personas[1]]
                persona_utterance = self._get_persona_utterance(
                    persona_strings=persona_strings,
                    context_dataset=self.context_info['context_dataset'],
                    additional_context=self.context_info['additional_context'],
                    is_bot=True,
                )
                message = control_msg.copy()
                message['text'] = persona_utterance
                # The bot seeing its persona does not count as a "turn"
                self.bot.observe(validate(message), increment_turn=False)

            if self.opt['conversation_start_mode'] == 'bst':
                print('[Displaying first utterances as per BST task.]')
                # Display the previous two utterances
                human_first_msg = {
                    'episode_done': False,
                    'id': self.agent.id,
                    'text': self.context_info['person1_seed_utterance'],
                    'fake_start': True,
                    'agent_idx': 0,
                }
                for k, v in control_msg.items():
                    human_first_msg[k] = v
                bot_first_msg = {
                    'episode_done': False,
                    'id': self.bot.id,
                    'text': self.context_info['person2_seed_utterance'],
                    'fake_start': True,
                    'agent_idx': 1,
                }
                print(
                    f'human_first_msg: {human_first_msg}, bot_first_msg: {bot_first_msg}'
                )

                self.dialog.append(human_first_msg)
                self.dialog.append(bot_first_msg)

                for observer in [self.agent, self.bot]:
                    observer.observe(validate(human_first_msg))
                    observer.observe(validate(bot_first_msg))

            elif self.opt['conversation_start_mode'] == 'hi':
                print('[Displaying "Hi!" only as per Meena task.]')
                human_first_msg = {
                    'episode_done': False,
                    'id': self.agent.id,
                    'text': 'Hi!',
                    'fake_start': True,
                    'agent_idx': 0,
                }
                for k, v in control_msg.items():
                    human_first_msg[k] = v

                self.dialog.append(human_first_msg)
                self.agent.observe(validate(human_first_msg))
                self.bot.observe(validate(human_first_msg))

                first_bot_act = self.bot.act()
                first_bot_act = Compatibility.maybe_fix_act(first_bot_act)

                self.agent.observe(validate(first_bot_act))

                bot_utterance_data = {
                    'agent_idx': 1,
                    'text': first_bot_act['text'],
                    'id': first_bot_act['id'],
                }
                self.dialog.append(bot_utterance_data)
            else:
                raise ValueError(
                    f"Conversation start mode {self.opt['conversation_start_mode']} "
                    f"not recognized!"
                )

            self.task_turn_idx += 1
            return

        # NOTE(review): the string below is a no-op bare-string statement, not a
        # docstring — kept verbatim for byte-compatibility
        """Otherwise, we proceed accordingly"""
        print(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}'
        )
        acts = [None, None]
        for idx, agent in enumerate([self.agent, self.bot]):
            if not self.chat_done:
                acts[idx] = agent.act(timeout=self.max_resp_time)
                acts[idx] = Compatibility.maybe_fix_act(acts[idx])
                print(
                    f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
                )

            if acts[idx].get('task_data', {}).get('final_rating') is not None:
                self.chat_done = True
                # agent ends chat after exceeding minimum number of turns
                if self.task_turn_idx > self.num_turns:
                    # Human has just responded. Problem data received
                    # now will be regarding the bot's prior utterance
                    p = acts[idx]['task_data']['problem_data_for_prior_message']
                    turn_idx = -1
                    # Attach the problem data to the last utterance, since the human
                    # hasn't said anything since then
                    self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)

                # Save the final chat data
                time_string = time.strftime('%Y%m%d_%H%M%S')
                chat_data_folder = self.opt['chat_data_folder']
                os.makedirs(chat_data_folder, exist_ok=True)
                chat_data_path = os.path.join(
                    chat_data_folder,
                    f'{time_string}_{np.random.randint(0, 1000)}_{self.task_type}.json',
                )
                final_chat_data = self.get_final_chat_data()
                self.agent.mephisto_agent.state.messages.append(
                    {'final_chat_data': final_chat_data}
                )
                # Append the chat data directly to the agent state's message list in
                # order to prevent the worker from seeing a new text response in the UI
                with open(chat_data_path, 'w+') as f_json:
                    data_str = json.dumps(final_chat_data)
                    f_json.write(data_str)
                print(
                    f'{self.__class__.__name__}:{self.tag}: Data saved at '
                    f'{chat_data_path} for model: {self.bot.worker_id}.'
                )

                # Soft-block the worker if there were acceptability violations
                acceptability_violations = final_chat_data['acceptability_violations'][
                    0
                ]
                if (
                    acceptability_violations is not None
                    and acceptability_violations != ''
                ):
                    print(
                        f'**NOTE** Acceptability violations detected: {acceptability_violations}'
                    )
                    # Grant the failed qualification
                    self.agent.mephisto_agent.get_worker().grant_qualification(
                        self.block_qualification, 1
                    )

                return
            else:
                utterance_data = {
                    'agent_idx': idx,
                    # Get rid of annotations HTML if it's the bot response
                    'text': acts[idx]['text'].split('<br>')[0],
                    'id': acts[idx]['id']
                    if 'id' in acts[idx]
                    else 'NULL_ID',  # Person1 or Polyencoder
                }
                self.dialog.append(utterance_data)
                if idx == 0:
                    # Human has just responded. Problem data received
                    # now will be regarding the bot's prior utterance
                    p = acts[idx]['task_data']['problem_data_for_prior_message']
                    turn_idx = -2
                    # Attach the problem data to the second-to-last utterance, since the
                    # last utterance is what the human just said
                    self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)

                # Forward the act to the other party in the conversation
                for other_agent in [self.agent, self.bot]:
                    if other_agent != agent:
                        other_agent.observe(validate(acts[idx]))

                print(
                    f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}'
                )
                self.task_turn_idx += 1

    def shutdown(self) -> None:
        """
        On a completed chat, bump the run count for this model, then shut the human
        agent down.
        """
        if self.chat_done:
            self.opt['run_statistics'][self.bot.worker_id] += 1
            print(
                'Runs completed per model: '
                + ', '.join(
                    f'{model}: {count:d}'
                    for model, count in self.opt['run_statistics'].items()
                )
            )

        self.agent.shutdown()

    def episode_done(self) -> bool:
        """
        Return whether the conversation has finished.
        """
        return self.chat_done

    def _get_persona_utterance(
        self,
        persona_strings: Optional[List[str]] = None,
        context_dataset: Optional[str] = None,
        additional_context: Optional[str] = None,
        is_bot: bool = False,
    ):
        """
        Build the persona/context string to show either side of the conversation.

        For the bot, returns the raw persona lines (plus WoW topic when applicable);
        for the human, returns HTML-formatted task instructions that embed the
        persona and a dataset-specific closing sentence.
        """
        if is_bot:
            # Pass back the original context
            persona_pieces = [f"your persona: {str_}" for str_ in persona_strings]
            if context_dataset == 'wizard_of_wikipedia':
                additional_context_pieces = [additional_context]
            else:
                additional_context_pieces = []
            full_context = '\n'.join(persona_pieces + additional_context_pieces)
            print(f'FULL CONTEXT: {full_context}')
            return full_context
        else:
            if context_dataset == 'convai2':
                last_sentence = 'Pretend that the conversation has already begun.'
            elif context_dataset == 'empathetic_dialogues':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about the following situation: '
                    f'<b>"{additional_context}"</b>'
                )
            elif context_dataset == 'wizard_of_wikipedia':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about <b>{additional_context}</b>.'
                )
            else:
                raise ValueError('Context dataset unrecognized!')
            joined_personas = '\n'.join(persona_strings)
            return (
                f'\nSuccessfully matched with another user! Now let\'s get to know '
                f'each other through the chat. You need to finish at least '
                f'<b>{self.num_turns} chat turns</b>, and after that you can click the '
                f'"Done" button to end the chat.\n\n'
                f'<b>Your character description is:\n<span style="color:blue">{joined_personas}</span></b> '
                '\n\n<b>Remember that you can get to know each '
                'other as your characters, talk about any topic, or talk about a '
                'situation that might have happened to your character.</b>'
                '\n<b>Do not trivially copy the '
                'character descriptions into the message.</b><br><br>'
                f'{last_sentence}'
            )

    def get_final_chat_data(self) -> Dict[str, Any]:
        """
        Return specific info about the conversation, the context, acceptability, etc.
        """
        if self.check_acceptability:
            human_texts = [
                message['text'] for message in self.dialog if message['agent_idx'] == 0
            ]
            violation_types = ['min_words', 'all_caps', 'exact_match', 'safety']
            if self.opt['conversation_start_mode'] == 'bst':
                # The BST mode starts the conversation with two previous utterances, so
                # there should be no new greeting. Also, the first human response is one
                # of the previous utterances, so it shouldn't get checked.
                violation_types.append('penalize_greetings')
                human_texts = human_texts[1:]

            violations_string = self.acceptability_checker.check_messages(
                messages=human_texts,
                is_worker_0=False,
                violation_types=violation_types,
            )
        else:
            violations_string = None

        data = {
            'personas': self.personas,
            'context_dataset': self.context_info.get('context_dataset'),
            'person1_seed_utterance': self.context_info.get('person1_seed_utterance'),
            'person2_seed_utterance': self.context_info.get('person2_seed_utterance'),
            'additional_context': self.context_info.get('additional_context'),
            'dialog': self.dialog,
            'workers': [get_mturk_id_from_mephisto_wrapper(self.agent)],
            'bad_workers': [],
            'acceptability_violations': (violations_string,),
            # NOTE(review): stores the Mephisto task_run_id under 'hit_ids' —
            # presumably for legacy schema compatibility; confirm downstream readers
            'hit_ids': [self.agent.mephisto_agent.task_run_id],
            'assignment_ids': [self.agent.mephisto_agent.assignment_id],
            'task_description': {
                'annotations_config': self.opt['annotations_config'],
                'model_nickname': self.bot.worker_id,
                'model_file': self.bot.model_agent.opt.get('model_file'),
                'model_opt': self.bot.model_agent.opt,
            },
        }
        # 'bad_workers' is for compatibility. Before, it was only non-empty if a
        # worker abandoned, returned, etc. a HIT, but now we don't even save chat
        # data in that case
        if self.check_acceptability:
            # NOTE(review): redundant — 'acceptability_violations' is already set to
            # this same tuple in the dict literal above
            data['acceptability_violations'] = (violations_string,)
            # Make a tuple for compatibility with a human/human conversation in
            # which we check both sides for acceptability
        return data
class ModelChatResultsCompiler(BaseModelChatResultsCompiler): """ Compile and save results of human+model chats. Results will be saved on the level of specific conversations, as well as aggregated up the level of each worker as a whole. """ @classmethod def setup_args(cls): parser = super().setup_args() parser.add_argument('--model-nickname', type=str, default='', help='name of the model') parser.add_argument( '--completed-run-stats-path', type=str, default='', help='path of the task run stats file', ) return parser def __init__(self, opt: Dict[str, Any]): AbstractTurnAnnotationResultsCompiler.__init__(self, opt) # Input args self.model_nickname = opt['model_nickname'] assert len(self.results_folders) > 0 for folder in self.results_folders: assert os.path.isdir(folder), f'{folder} is not a valid folder!' os.makedirs(self.output_folder, exist_ok=True) self.start_date = opt['start_date'] self.max_convos_per_worker = opt['max_convos_per_worker'] self.min_word_count = opt['min_word_count'] self.hit_block_list = opt['hit_block_list'].split(',') self.worker_block_list = opt['worker_block_list'].split(',') # Setting up problem buckets if self.use_problem_buckets: self.regular_buckets = [ bucket for bucket in self.problem_buckets if bucket not in ['other', 'none_all_good'] ] # Remove the buckets that are special cases self.acceptability_checker = AcceptabilityChecker() self.completed_run_stats_path = opt['completed_run_stats_path'] def compile_results(self) -> pd.DataFrame: # TODO modularize the shared components to dedup the code read_folders = [] date_strings = [] import ipdb ipdb.set_trace() for folder in self.results_folders: # Load paths # TODO load this data in using DataBrowser date_strings = sorted([ obj for obj in os.listdir(folder) if os.path.isdir(os.path.join(folder, obj)) and re.fullmatch(r'\d\d\d\d_\d\d_\d\d', obj) ]) if self.start_date != '': date_strings = [ str_ for str_ in date_strings if str_ >= self.start_date ] folders = [os.path.join(folder, str_) for 
str_ in date_strings] read_folders.extend(folders) print(f'Date folders: ' + ', '.join(date_strings)) # Read in each file num_incomplete_convos = 0 num_complete_convos = 0 complete_convos_per_model = {} bad_conversations = [] worker_stats = {} worker_conversation_counts = {} conversation_idx = 0 conversation_dfs = [] stat_counts = {} for read_folder in read_folders: read_folder_name = os.path.split(read_folder)[-1] for file_name in sorted(os.listdir(read_folder)): if file_name in self.hit_block_list or 'sandbox' in file_name: continue if 'incomplete' in file_name: num_incomplete_convos += 1 continue else: num_complete_convos += 1 # Read in file with open(os.path.join(read_folder, file_name), 'rb') as f: data = json.load(f) # Only include the first max_convos_per_worker conversations from a # worker to avoid biasing worker_id = data['workers'][0] worker_id = worker_id.split('-')[-1] assignment_id = data['assignment_ids'][0] if worker_id in worker_conversation_counts: conversations_so_far = worker_conversation_counts[ worker_id] else: conversations_so_far = 0 worker_conversation_counts[ worker_id] = conversations_so_far + 1 if (self.max_convos_per_worker != -1 and conversations_so_far >= self.max_convos_per_worker): print( f'Had {conversations_so_far} conversation(s) already from this worker {worker_id}. Skipping {assignment_id}.' ) continue # Check if need to block the turker word_counts = [ len(d['text'].split(' ')) for d in data['dialog'] if d['agent_idx'] == 0 ] utterances = [ d['text'] for d in data['dialog'] if d['agent_idx'] == 0 ] if np.average(word_counts) < self.min_word_count: bad_conversations.append(data) print( f'Bad complete conversation, words from human: {utterances}. Skipping.' 
) continue if not all(bucket in data['dialog'][0]['problem_data'] for bucket in self.problem_buckets): raise ValueError( 'Bucket(s) are missing from the problem data!') model_nickname = data['model_name'] assert self.model_nickname == model_nickname initial_data_id = data['context_info']['observation_for_bot'][ 'initial_data_id'] if model_nickname not in stat_counts: stat_counts[model_nickname] = {} if model_nickname in complete_convos_per_model: complete_convos_per_model[model_nickname].append( initial_data_id) else: complete_convos_per_model[model_nickname] = [ initial_data_id ] # Extract non-message info info_dict = { 'read_folder_name': read_folder_name, 'file_name': file_name, 'worker': worker_id, 'model_nickname': model_nickname, 'bad_workers': ','.join(data['bad_workers']), 'hit_id': data['hit_ids'][0], 'assignment_id': assignment_id, 'is_incomplete': 'incomplete' in file_name, 'context_info': data['context_info'], 'bot_persona_strings': data['bot_persona_strings'], 'human_persona_strings': data['human_persona_strings'], 'initial_task_data': data['initial_task_data'], 'initial_data_id': initial_data_id, } # Check that the conversation consists of pairs of comments between # agents 0 and 1, with 1(bot) speaking first assert all([ utterance_data['agent_idx'] == (utterance_idx + 1) % 2 for utterance_idx, utterance_data in enumerate(data['dialog']) ]) # Determine whether the HIT contains unacceptable messages. # (We do this for every HIT, even if acceptability violation info # was already saved, because the violation criteria may have # changed since the HIT was collected.) 
messages_0 = [ utt for utt in data['dialog'] if utt['agent_idx'] == 0 ] messages_1 = [ utt for utt in data['dialog'] if utt['agent_idx'] == 1 ] assert len(messages_0) + len(messages_1) == len(data['dialog']) # Check the human utterances for safety utterances_0 = [m['text'] for m in messages_0] info_dict[ 'acceptability_violations_0'] = self.acceptability_checker.check_messages( messages=utterances_0, is_worker_0=True, violation_types=self.acceptability_checker. ALL_VIOLATION_TYPES, ) # Compile personas and previous utterances df = pd.DataFrame( [], columns=[ 'folder', 'file_name' 'worker_id', 'hit_id', 'is_incomplete', 'context_info', 'initial_data_id', 'acceptability_violations_0', 'model_nickname', 'conversation_idx', 'turn_idx', 'agent_idx', 'text', ] + self.problem_buckets, ) df = df.append( { 'folder': info_dict['read_folder_name'], 'file_name': info_dict['file_name'], 'worker_id': info_dict['worker'], 'hit_id': info_dict['hit_id'], 'is_incomplete': info_dict['is_incomplete'], 'context_info': info_dict['context_info'], 'initial_data_id': info_dict['initial_task_data'], 'acceptability_violations_0': info_dict['acceptability_violations_0'], 'model_nickname': model_nickname, 'conversation_idx': conversation_idx, 'turn_idx': -1, 'agent_idx': 0, 'text': info_dict['context_info']['observation_for_bot'] ['text'], **{bucket: '' for bucket in self.problem_buckets}, }, ignore_index=True, ) for utterance_idx, utt in enumerate(data['dialog']): d = { 'folder': info_dict['read_folder_name'], 'file_name': info_dict['file_name'], 'worker_id': info_dict['worker'], 'hit_id': info_dict['hit_id'], 'is_incomplete': info_dict['is_incomplete'], 'context_info': info_dict['context_info'], 'initial_data_id': info_dict['initial_task_data'], 'acceptability_violations_0': info_dict['acceptability_violations_0'], 'model_nickname': model_nickname, 'conversation_idx': conversation_idx, 'turn_idx': utterance_idx, 'agent_idx': utt['agent_idx'], 'text': utt['text'], **{bucket: '' for bucket in 
self.problem_buckets}, } if utt['agent_idx'] == 1: if 'problem_data' not in utt: for bucket in self.problem_buckets: d[bucket] = 'MALFORMED' print( f'Warning got MALFORMED utterance problem data inside complete convo: {utt}. Skipping.' ) continue else: for bucket in self.regular_buckets: d[bucket] = utt['problem_data'][bucket] d['final_rating'] = (utt['final_rating'] if 'final_rating' in utt else None) for k in self.regular_buckets: if k not in stat_counts[model_nickname]: stat_counts[model_nickname][k] = 0 stat_counts[model_nickname][k] += d[k] if 'total' not in stat_counts[model_nickname]: stat_counts[model_nickname]['total'] = 0 if d['agent_idx'] == 1: stat_counts[model_nickname]['total'] += 1 if d['final_rating'] is not None: # Only one the last utterance (agent idx == 1) if 'count_ratings' not in stat_counts[ model_nickname]: stat_counts[model_nickname][ 'count_ratings'] = 0 stat_counts[model_nickname]['count_ratings'] += 1 if 'ratings' not in stat_counts[model_nickname]: stat_counts[model_nickname]['ratings'] = [] if 'pairwise_ratings' not in stat_counts[ model_nickname]: stat_counts[model_nickname][ 'pairwise_ratings'] = {} stat_counts[model_nickname]['ratings'].append( int(d['final_rating'])) stat_counts[model_nickname]['pairwise_ratings'][ info_dict['initial_data_id']] = int( d['final_rating']) if 'bot_word_count' not in stat_counts[model_nickname]: stat_counts[model_nickname]['bot_word_count'] = 0 stat_counts[model_nickname]['bot_word_count'] += len( d['text'].strip().split(' ')) else: # Counting some aspects of the human's utterances if 'human_utterance_count' not in stat_counts[ model_nickname]: stat_counts[model_nickname][ 'human_utterance_count'] = 0 stat_counts[model_nickname][ 'human_utterance_count'] += 1 if 'human_word_count' not in stat_counts[ model_nickname]: stat_counts[model_nickname]['human_word_count'] = 0 stat_counts[model_nickname]['human_word_count'] += len( d['text'].strip().split(' ')) if 'human_question_count' not in stat_counts[ 
model_nickname]: stat_counts[model_nickname][ 'human_question_count'] = 0 stat_counts[model_nickname][ 'human_question_count'] += d['text'].count('?') # Only want to count bot utterances but human ones, while included, # won't be False if info_dict['worker'] not in worker_stats: worker_stats[info_dict['worker']] = {'conversations': 0} worker_stats[info_dict['worker']]['conversations'] += 1 # Logic for calculating percent of conversations that are clean if 'count_convos' not in stat_counts[model_nickname]: stat_counts[model_nickname]['count_convos'] = 0 stat_counts[model_nickname]['count_convos'] += 1 # Adding the full conversation to the list of conversations conversation_dfs.append(df) conversation_idx += 1 for m, conversations_completed in complete_convos_per_model.items(): print( f'Got {len(conversations_completed)} complete conversations for model: {m}' ) print(f"{m} completed: {conversations_completed}") print(f'{num_complete_convos:d} complete conversation(s) collected.') print(f'{len(bad_conversations):d} bad conversation(s).') num_approved_convos = num_complete_convos - len(bad_conversations) print(f'{num_approved_convos:d} approved conversation(s).') print( f'({num_incomplete_convos:d} incomplete conversation(s) collected.)' ) for model_nickname, model_stats_dict in stat_counts.items(): print(f'---{model_nickname}---') for p, v in model_stats_dict.items(): if p == 'count_ratings' or p == 'pairwise_ratings': continue if p == 'ratings': print( f'Average Engaging-ness Rating: {np.average(model_stats_dict["ratings"])} ({model_stats_dict["count_ratings"]} ratings)' ) continue if p == 'human_word_count' or p == 'human_question_count': print( f'{p}: {v} ({v/model_stats_dict["human_utterance_count"]:.3})' ) elif p == 'bot_word_count': print(f'{p}: {v} ({v/model_stats_dict["total"]:.3})') elif p == 'human_utterance_count': print(f'{p}: {v}') elif p == 'count_convos': print(f'{p}: {v}') else: print(f'{p}: {v} ({v/model_stats_dict["total"]:.2%})') print('Printing 
worker IDs not already in block list to add...') for b in bad_conversations: worker_id = b['workers'][0] if worker_id not in self.worker_block_list: print(f"""'{worker_id}',""") print('Done printing bad workers.') worker_df = pd.DataFrame([], columns=['worker_id', 'conversations']) for worker_id, data in worker_stats.items(): stat = { 'worker_id': worker_id, 'conversations': data['conversations'] } worker_df = worker_df.append(stat, ignore_index=True) with open(self.completed_run_stats_path, 'r') as f: completed_run_stats = json.load(f) assert completed_run_stats['bot_model_name'] == self.model_nickname completed_run_stats['context_done_statistics'][ self.model_nickname] = complete_convos_per_model[ self.model_nickname] completed_run_stats['context_done_counts'] = len( complete_convos_per_model[self.model_nickname]) with open(self.completed_run_stats_path, 'w') as fw: json.dump(completed_run_stats, fw) print(f'Wrote override opt to: {self.completed_run_stats_path}') rating_path = os.path.join(self.output_folder, f'pairwise_ratings.json') with open(rating_path, 'w') as fw: json.dump(stat_counts[self.model_nickname]['pairwise_ratings'], fw) print(f'Wrote pairwise ratings to: {rating_path}') # Save full results all_conversations_df = pd.DataFrame() for df in conversation_dfs: all_conversations_df = all_conversations_df.append(df) return all_conversations_df
class ModelChatResultsCompiler(AbstractTurnAnnotationResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as aggregated
    up the level of each worker as a whole.
    """

    @classmethod
    def setup_args(cls):
        """
        Add command-line arguments for date filtering, per-worker caps, and
        HIT/worker block lists, on top of the parent compiler's arguments.
        """
        parser = super().setup_args()
        parser.add_argument(
            '--start-date',
            type=str,
            default='',
            help='The earliest date to analyze results from',
        )
        parser.add_argument(
            '--max-convos-per-worker',
            type=int,
            default=100,
            help='The most conversations to analyze from any one user. Set to -1 for no limit.',
        )
        parser.add_argument(
            '--min-word-count',
            type=int,
            default=4,
            help='The minimum acceptable mean number of words per human utterance',
        )
        parser.add_argument(
            '--hit-block-list',
            type=str,
            default='',
            help='Comma-separated list of all hits to block',
        )
        parser.add_argument(
            '--worker-block-list',
            type=str,
            default='',
            help='Comma-separated list of all workers to block',
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):
        """
        Validate inputs and store the filtering/blocking options from ``opt``.

        :raises ValueError: if problem buckets are enabled but the catchall
            'none_all_good' bucket is missing.
        """
        super().__init__(opt)

        # Validate problem buckets
        if self.use_problem_buckets and 'none_all_good' not in self.problem_buckets:
            # The code relies on a catchall "none" category if the user selects no other
            # annotation bucket
            raise ValueError(
                'There must be a "none_all_good" category in self.problem_buckets!'
            )

        # Input args
        assert len(self.results_folders) > 0
        for folder in self.results_folders:
            assert os.path.isdir(folder), f'{folder} is not a valid folder!'
        os.makedirs(self.output_folder, exist_ok=True)
        self.start_date = opt['start_date']
        self.max_convos_per_worker = opt['max_convos_per_worker']
        self.min_word_count = opt['min_word_count']
        # NOTE(review): ''.split(',') yields [''], so an unset block list still
        # contains one empty-string entry — presumably harmless since real IDs
        # are non-empty, but verify
        self.hit_block_list = opt['hit_block_list'].split(',')
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Setting up problem buckets
        if self.use_problem_buckets:
            # Remove the buckets that are special cases
            self.regular_buckets = [
                bucket
                for bucket in self.problem_buckets
                if bucket not in ['other', 'none_all_good']
            ]

        self.acceptability_checker = AcceptabilityChecker()

    def get_results_path_base(self) -> str:
        """
        Return the timestamped base path (no extension) for saving compiled results.
        """
        now = datetime.now()
        return os.path.join(
            self.output_folder, f'results_{now.strftime("%Y%m%d_%H%M%S")}'
        )

    def compile_results(self) -> pd.DataFrame:
        """
        Read every saved chat JSON file and compile it into a single DataFrame.

        Returns a DataFrame with one row per conversation turn, plus one context
        row per conversation at turn_idx == -1. As a side effect, prints
        aggregate per-model and per-worker statistics and writes a per-worker
        results CSV into the output folder.

        NOTE(review): uses pd.DataFrame.append, which was deprecated in pandas
        1.4 and removed in 2.0 — this method assumes an older pandas.
        """
        # Gather all date-stamped (YYYY_MM_DD) subfolders of each results folder,
        # optionally filtered to those on or after self.start_date
        read_folders = []
        date_strings = []
        for folder in self.results_folders:
            # Load paths
            date_strings = sorted(
                [
                    obj
                    for obj in os.listdir(folder)
                    if os.path.isdir(os.path.join(folder, obj))
                    and re.fullmatch(r'\d\d\d\d_\d\d_\d\d', obj)
                ]
            )
            if self.start_date != '':
                date_strings = [
                    str_ for str_ in date_strings if str_ >= self.start_date
                ]
            folders = [os.path.join(folder, str_) for str_ in date_strings]
            read_folders.extend(folders)
        # NOTE(review): date_strings is overwritten per folder, so this print only
        # reflects the last results folder
        print(f'Date folders: ' + ', '.join(date_strings))

        now = datetime.now()
        worker_results_file = os.path.join(
            self.output_folder, f'worker_results_{now.strftime("%Y%m%d_%H%M%S")}.csv'
        )

        # Read in each file
        num_incomplete_convos = 0
        num_complete_convos = 0
        complete_convos_per_model = {}
        bad_conversations = []
        stat_counts = {}
        worker_stats = {}
        worker_conversation_counts = {}
        # NOTE(review): total_utterances is accumulated below but never reported
        total_utterances = 0

        conversation_idx = 0
        conversation_dfs = []
        for read_folder in read_folders:
            read_folder_name = os.path.split(read_folder)[-1]
            for file_name in sorted(os.listdir(read_folder)):
                if file_name in self.hit_block_list:
                    continue
                # Files with 'incomplete' in the name are partial chats: counted,
                # then skipped
                if 'incomplete' in file_name:
                    num_incomplete_convos += 1
                    continue
                else:
                    num_complete_convos += 1

                # Read in file
                with open(os.path.join(read_folder, file_name), 'rb') as f:
                    data = json.load(f)

                # Only include the first max_convos_per_worker conversations from a
                # worker to avoid biasing
                worker_id = data['workers'][0]
                assignment_id = data['assignment_ids'][0]
                if worker_id in worker_conversation_counts:
                    conversations_so_far = worker_conversation_counts[worker_id]
                else:
                    conversations_so_far = 0
                worker_conversation_counts[worker_id] = conversations_so_far + 1
                if (
                    self.max_convos_per_worker != -1
                    and conversations_so_far >= self.max_convos_per_worker
                ):
                    print(
                        f'Had {conversations_so_far} conversation(s) already from this worker {worker_id}. Skipping {assignment_id}.'
                    )
                    continue

                # Check if need to block the turker
                word_counts = [
                    len(d['text'].split(' '))
                    for d in data['dialog']
                    if d['agent_idx'] == 0
                ]
                utterances = [
                    d['text'] for d in data['dialog'] if d['agent_idx'] == 0
                ]
                if np.average(word_counts) < self.min_word_count:
                    bad_conversations.append(data)
                    print(
                        f'Bad complete conversation, words from human: {utterances}. Skipping.'
                    )
                    continue

                if self.use_problem_buckets:
                    # The first bot turn (index 1) must carry every expected bucket
                    if not all(
                        bucket in data['dialog'][1]['problem_data']
                        for bucket in self.problem_buckets
                    ):
                        raise ValueError(
                            'Bucket(s) are missing from the problem data!'
                        )

                model_nickname = data['task_description']['model_nickname']
                if model_nickname not in stat_counts:
                    stat_counts[model_nickname] = {}
                if model_nickname in complete_convos_per_model:
                    complete_convos_per_model[model_nickname] += 1
                else:
                    complete_convos_per_model[model_nickname] = 1

                # Extract non-message info
                info_dict = {
                    'read_folder_name': read_folder_name,
                    'file_name': file_name,
                    'worker': worker_id,
                    'model_nickname': model_nickname,
                    'bad_workers': ','.join(data['bad_workers']),
                    'hit_id': data['hit_ids'][0],
                    'assignment_id': assignment_id,
                    'is_incomplete': 'incomplete' in file_name,
                    'context_dataset': data['context_dataset'],
                    'additional_context': data['additional_context'],
                }

                # Check that the conversation consists of pairs of comments between
                # agents 0 and 1, with 0 speaking first
                assert all(
                    [
                        utterance_data['agent_idx'] == utterance_idx % 2
                        for utterance_idx, utterance_data in enumerate(data['dialog'])
                    ]
                )

                # Determine whether the HIT contains unacceptable messages.
                # (We do this for every HIT, even if acceptability violation info
                # was already saved, because the violation criteria may have
                # changed since the HIT was collected.)
                messages_0 = [
                    utt for utt in data['dialog'] if utt['agent_idx'] == 0
                ]
                messages_1 = [
                    utt for utt in data['dialog'] if utt['agent_idx'] == 1
                ]
                assert len(messages_0) + len(messages_1) == len(data['dialog'])

                # Check the human utterances for safety
                utterances_0 = [m['text'] for m in messages_0]
                info_dict[
                    'acceptability_violations_0'
                ] = self.acceptability_checker.check_messages(
                    messages=utterances_0,
                    is_worker_0=True,
                    violation_types=self.acceptability_checker.ALL_VIOLATION_TYPES,
                )

                # Compile personas and previous utterances
                df = pd.DataFrame(
                    [],
                    columns=[
                        'folder',
                        'worker_id',
                        'hit_id',
                        'model_nickname',
                        'conversation_idx',
                        'turn_idx',
                        'agent_idx',
                        'text',
                    ] + self.problem_buckets,
                )
                text_parts = []
                if data['personas'] is not None and len(data['personas']) > 0:
                    # Only the bot's (index 1) persona strings are recorded here
                    text_parts += [
                        'your persona: ' + data['personas'][1][0],
                        'your persona: ' + data['personas'][1][1],
                    ]
                if (
                    data['additional_context'] is not None
                    and len(data['additional_context']) > 0
                ):
                    text_parts.append(data['additional_context'])
                # Context row: turn_idx == -1 marks it as preceding the dialogue
                df = df.append(
                    {
                        'folder': info_dict['read_folder_name'],
                        'worker_id': info_dict['worker'],
                        'hit_id': info_dict['hit_id'],
                        'model_nickname': model_nickname,
                        'conversation_idx': conversation_idx,
                        'turn_idx': -1,
                        'agent_idx': 1,
                        'text': '\n'.join(text_parts),
                        **{bucket: '' for bucket in self.problem_buckets},
                    },
                    ignore_index=True,
                )

                total_utterances += len(
                    [d for d in data["dialog"] if d["agent_idx"] == 1]
                )
                if len(data['dialog']) > 20:
                    print(
                        f'Got long dialogue of {len(data["dialog"])} utterances, hit id: {info_dict["hit_id"]}, model_nickname: {model_nickname}.'
                    )

                if self.use_problem_buckets:
                    dialog_has_problems = False
                for utterance_idx, utt in enumerate(data['dialog']):

                    d = {
                        'folder': info_dict['read_folder_name'],
                        'worker_id': info_dict['worker'],
                        'hit_id': info_dict['hit_id'],
                        'model_nickname': model_nickname,
                        'conversation_idx': conversation_idx,
                        'turn_idx': utterance_idx,
                        'agent_idx': utt['agent_idx'],
                        'text': utt['text'],
                        **{bucket: '' for bucket in self.problem_buckets},
                    }

                    if utt['agent_idx'] == 1:
                        # Bot turn: record annotation buckets and per-model stats
                        d['final_rating'] = utt.get('final_rating')
                        if self.use_problem_buckets:
                            if 'problem_data' not in utt:
                                # Malformed bot turn: mark every bucket and skip
                                # appending this row to the DataFrame
                                for bucket in self.problem_buckets:
                                    d[bucket] = 'MALFORMED'
                                print(
                                    f'Warning got MALFORMED utterance problem data inside complete convo: {utt}. Skipping.'
                                )
                                continue
                            else:
                                for bucket in self.regular_buckets + [
                                    'none_all_good'
                                ]:
                                    d[bucket] = utt['problem_data'][bucket]
                                for k in self.regular_buckets + ['none_all_good']:
                                    if k not in stat_counts[model_nickname]:
                                        stat_counts[model_nickname][k] = 0
                                    stat_counts[model_nickname][k] += d[k]
                                    # Any non-catchall bucket flags the whole
                                    # dialogue as having problems
                                    if k != 'none_all_good' and d[k]:
                                        dialog_has_problems = True

                        if 'total' not in stat_counts[model_nickname]:
                            stat_counts[model_nickname]['total'] = 0
                        if d['agent_idx'] == 1:
                            stat_counts[model_nickname]['total'] += 1
                        if d['final_rating'] is not None:
                            # Only on the last utterance (agent idx == 1)
                            if 'count_ratings' not in stat_counts[model_nickname]:
                                stat_counts[model_nickname]['count_ratings'] = 0
                            stat_counts[model_nickname]['count_ratings'] += 1
                            if 'ratings' not in stat_counts[model_nickname]:
                                stat_counts[model_nickname]['ratings'] = []
                            stat_counts[model_nickname]['ratings'].append(
                                int(d['final_rating'])
                            )
                    else:
                        # Counting some aspects of the human's utterances
                        if 'human_utterance_count' not in stat_counts[model_nickname]:
                            stat_counts[model_nickname]['human_utterance_count'] = 0
                        stat_counts[model_nickname]['human_utterance_count'] += 1
                        if 'human_word_count' not in stat_counts[model_nickname]:
                            stat_counts[model_nickname]['human_word_count'] = 0
                        stat_counts[model_nickname]['human_word_count'] += len(
                            d['text'].strip().split(' ')
                        )
                        if 'human_question_count' not in stat_counts[model_nickname]:
                            stat_counts[model_nickname]['human_question_count'] = 0
                        stat_counts[model_nickname]['human_question_count'] += d['text'].count('?')
                    # Hook for subclasses to attach extra per-turn columns
                    d = self._add_additional_per_turn_stats(d=d, utt=utt)
                    df = df.append(d, ignore_index=True)

                if info_dict['worker'] not in worker_stats:
                    worker_stats[info_dict['worker']] = {'conversations': 0}
                    if self.use_problem_buckets:
                        worker_stats[info_dict['worker']]['problems_found'] = 0
                worker_stats[info_dict['worker']]['conversations'] += 1

                if self.use_problem_buckets:
                    # Count the number of problems the worker got: human rows have
                    # '' in 'none_all_good', so replace('', True) makes them True
                    # and the ~ negation leaves them False
                    is_problem = ~df['none_all_good'].replace('', True)
                    # Only want to count bot utterances but human ones, while included,
                    # won't be False
                    count = is_problem.sum()
                    worker_stats[info_dict['worker']]['problems_found'] += count

                # Logic for calculating percent of conversations that are clean
                if 'count_convos' not in stat_counts[model_nickname]:
                    stat_counts[model_nickname]['count_convos'] = 0
                stat_counts[model_nickname]['count_convos'] += 1

                if self.use_problem_buckets and not dialog_has_problems:
                    if 'convo_clean' not in stat_counts[model_nickname]:
                        stat_counts[model_nickname]['convo_clean'] = 0
                    stat_counts[model_nickname]['convo_clean'] += 1

                # Adding the full conversation to the list of conversations
                conversation_dfs.append(df)
                conversation_idx += 1

        # Print aggregate conversation counts
        for m, conversation_count in complete_convos_per_model.items():
            print(
                f'Got {conversation_count} complete conversation(s) for model: {m}'
            )

        print(f'{num_complete_convos:d} complete conversation(s) collected.')
        print(f'{len(bad_conversations):d} bad conversation(s).')
        num_approved_convos = num_complete_convos - len(bad_conversations)
        print(f'{num_approved_convos:d} approved conversation(s).')
        print(
            f'({num_incomplete_convos:d} incomplete conversation(s) collected.)'
        )
        # Print per-model statistics, normalizing counters where a denominator
        # ('human_utterance_count', 'count_convos', or 'total') applies
        for model_nickname, model_stats_dict in stat_counts.items():
            print(f'---{model_nickname}---')
            for p, v in model_stats_dict.items():
                if p == 'count_ratings':
                    continue
                if p == 'ratings':
                    print(
                        f'Average Engaging-ness Rating: {np.average(model_stats_dict["ratings"])} ({model_stats_dict["count_ratings"]} ratings)'
                    )
                    continue
                if p == 'human_word_count' or p == 'human_question_count':
                    print(
                        f'{p}: {v} ({v/model_stats_dict["human_utterance_count"]:.3})'
                    )
                elif p == 'human_utterance_count':
                    print(f'{p}: {v}')
                elif p == 'count_convos':
                    print(f'{p}: {v}')
                elif self.use_problem_buckets and p == 'convo_clean':
                    print(
                        f'{p}: {v} ({v/model_stats_dict["count_convos"]:.2%})'
                    )
                else:
                    print(f'{p}: {v} ({v/model_stats_dict["total"]:.2%})')

        print('Printing worker IDs not already in block list to add...')
        for b in bad_conversations:
            worker_id = b['workers'][0]
            if worker_id not in self.worker_block_list:
                print(f"""'{worker_id}',""")
        print('Done printing bad workers.')

        print('Worker stats:')
        worker_columns = ['worker_id', 'conversations']
        if self.use_problem_buckets:
            worker_columns += ['problems_found', 'avg_problems_per_convo']
        worker_df = pd.DataFrame([], columns=worker_columns)
        for worker_id, data in worker_stats.items():
            print(worker_id)
            stat = {
                'worker_id': worker_id,
                'conversations': data['conversations']
            }
            if self.use_problem_buckets:
                avg_problems_per_convo = data['problems_found'] / data['conversations']
                stat.update(
                    {
                        'problems_found': data['problems_found'],
                        'avg_problems_per_convo': avg_problems_per_convo,
                    }
                )
            worker_df = worker_df.append(stat, ignore_index=True)
        if self.use_problem_buckets:
            worker_df = worker_df.sort_values(
                'avg_problems_per_convo', ascending=0
            )
        worker_df.to_csv(worker_results_file, index=False)
        print(worker_df)
        print(f'Wrote worker statistical results to: {worker_results_file}')

        # Save full results
        all_conversations_df = pd.DataFrame()
        for df in conversation_dfs:
            all_conversations_df = all_conversations_df.append(df)
        print(f'\nWorker conversation counts: {worker_conversation_counts}')

        return all_conversations_df

    def _add_additional_per_turn_stats(self, d: dict, utt: dict) -> dict:
        """
        Add in additional statistics on the level of each conversation turn.

        Useful for subclasses; this base implementation returns ``d`` unchanged.
        """
        _ = utt  # utt is ignored in this passthrough method
        return d
class PerTurnEvalResultsCompiler(AbstractResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as aggregated
    up the level of each worker as a whole.
    """

    # TODO: deduplicate setup_args from ModelChatResultsCompiler
    @classmethod
    def setup_args(cls):
        """
        Add a command-line argument for blocking specific workers from analysis.
        """
        parser = super().setup_args()
        parser.add_argument(
            '--worker-block-list',
            type=str,
            default='',
            help='Comma-separated list of all workers to block',
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):
        """
        Store the worker block list and the paths of all analysis files to save.
        """
        # TODO: deduplicate init from ModelChatResultsCompiler
        super().__init__(opt)

        # Input args
        os.makedirs(self.output_folder, exist_ok=True)
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Save paths
        self.worker_results_path = os.path.join(
            self.output_folder, 'worker_results.csv'
        )
        self.unacceptable_worker_ids_path = os.path.join(
            self.output_folder, 'unacceptable_worker_ids.txt'
        )
        self.win_rate_by_date_path = os.path.join(
            self.output_folder, 'win_rates_by_date.csv'
        )
        self.stat_mean_length_by_date_path = os.path.join(
            self.output_folder, 'stat_mean_length_by_date.csv'
        )
        self.completion_time_by_model_pair_path = os.path.join(
            self.output_folder, 'mean_completion_times.csv'
        )

        self.acceptability_checker = AcceptabilityChecker()

        # Set fields that should be empty strings if the relevant information is not
        # present
        blank_field_columns = [
            'human_text',
            'human_choice',
            'human_justification',
            'accepted_bot_text',
            'not_accepted_bot_text',
        ]
        self.blank_fields = {field: '' for field in blank_field_columns}

        # Results attributes
        self.stat_counts = {}
        self.mean_completion_time = None
        # Useful for subclasses, to compare with other eval techniques

    def get_results_path_base(self) -> str:
        """
        Return the (un-timestamped) base path for saving compiled results.
        """
        return os.path.join(self.output_folder, 'results')

    def compile_results(self) -> pd.DataFrame:
        """
        Pull per-turn-eval task units from Mephisto and compile them into one
        DataFrame with one row per conversation turn (plus a context row per
        conversation at turn_idx == -1).

        Skips units with no save data, wrong status, blocked workers, or
        acceptability violations. As side effects, logs aggregate statistics
        and writes several analysis CSVs (worker stats, win rates by date,
        mean text lengths by date, mean completion times) plus a text file of
        unacceptable worker IDs.

        NOTE(review): relies on pd.DataFrame.append (removed in pandas 2.0) and
        datetime.utcfromtimestamp (deprecated since Python 3.12).
        """

        # Load task data
        logging.info('Retrieving task data from Mephisto.')
        task_units_data = self.get_task_data()
        logging.info(f'Data for {len(task_units_data)} units loaded successfully.')

        # Read in each file
        num_convos_with_no_save_data = 0
        num_wrong_status_convos = 0
        num_complete_convos = 0
        worker_stats = {}
        worker_conversation_counts = {}
        # NOTE(review): total_utterances is accumulated below but never reported
        total_utterances = 0
        unacceptable_task_units = []
        unacceptable_worker_ids = []

        conversation_idx = 0
        conversation_dfs = []
        for task_unit in task_units_data:

            worker_id = task_unit['worker_id']
            assignment_id = task_unit['assignment_id']

            # Determining whether the task unit should be skipped

            # Extract out custom data
            if task_unit['data']['save_data'] is None:
                logging.info('Found a task unit with no save data! Skipping.')
                num_convos_with_no_save_data += 1
                continue
            elif task_unit['status'] not in ['completed', 'approved']:
                # NOTE(review): this message renders without a space between
                # "!." and "Skipping." — left as-is to preserve behavior
                logging.info(
                    f'Found a HIT with the status "{task_unit["status"]}"!.'
                    f'Skipping.'
                )
                num_wrong_status_convos += 1
                continue
            else:
                num_complete_convos += 1

            # Check if the Turker is on the list of blocked Turkers
            if worker_id in self.worker_block_list:
                logging.info(
                    f'Found a HIT with the worker {worker_id}, on the blocklist. '
                    f'Skipping.'
                )
                continue

            if worker_id in worker_conversation_counts:
                conversations_so_far = worker_conversation_counts[worker_id]
            else:
                conversations_so_far = 0
            worker_conversation_counts[worker_id] = conversations_so_far + 1

            data = task_unit['data']['save_data']['custom_data']

            # Extract out information about this conversation
            model_1_nickname = data['task_description']['model_1_nickname']
            model_2_nickname = data['task_description']['model_2_nickname']
            # Since we have two models, we use the format model_1_name:model_2_name
            model_pair_nickname = f"{model_1_nickname}:{model_2_nickname}"
            if model_pair_nickname not in self.stat_counts:
                # 'per_turn' maps 1-indexed human turn -> per-model win counts
                self.stat_counts[model_pair_nickname] = {
                    'per_turn': defaultdict(
                        lambda: {model_1_nickname: 0, model_2_nickname: 0}
                    )
                }

            # Extract non-message info
            mturk_worker_id_match = re.fullmatch(
                r'--NOT-MTURK-AGENT-(.*)', data['workers'][0]
            )
            # TODO: figure out why --NOT-MTURK-AGENT appears at the beginning of this
            # field, and remove it; then, remove this re.fullmatch() call
            if mturk_worker_id_match is not None:
                mturk_worker_id = mturk_worker_id_match.group(1)
            else:
                mturk_worker_id = None
            task_start = datetime.utcfromtimestamp(task_unit['task_start'])
            task_end = datetime.utcfromtimestamp(task_unit['task_end'])
            single_convo_info_dict = {
                'worker': worker_id,
                'mturk_worker_id': mturk_worker_id,
                'model_pair_nickname': model_pair_nickname,
                'bad_workers': ','.join(data['bad_workers']),
                'hit_id': data['hit_ids'][0],
                'assignment_id': assignment_id,
                'context_dataset': data['context_dataset'],
                'additional_context': data['additional_context'],
                'date': task_start.strftime('%Y-%m-%d'),
                'task_start': task_start,
                'task_end': task_end,
                'completion_time': (task_end - task_start).total_seconds(),
            }
            # TODO: 'task_start' and 'task_end' assume that the original datetime floats
            # are stored in UTC. Check this!

            # Check that the conversation consists of pairs of comments between
            # agents 0 and 1, with 0 speaking first
            assert all(
                [
                    utterance_data['agent_idx'] == utterance_idx % 2
                    for utterance_idx, utterance_data in enumerate(data['dialog'])
                ]
            )
            messages_0 = [utt for utt in data['dialog'] if utt['agent_idx'] == 0]
            messages_1 = [utt for utt in data['dialog'] if utt['agent_idx'] == 1]
            assert len(messages_0) + len(messages_1) == len(data['dialog'])

            # Determine whether the HIT contains unacceptable messages.
            # (We do this for every HIT, even if acceptability violation info
            # was already saved, because the violation criteria may have
            # changed since the HIT was collected.)
            utterances_0 = [m['text'] for m in messages_0]
            assert utterances_0[0] == 'Hi!', (
                'This script assumes that the first human message is "Hi!", which is '
                'set by default and cannot be changed by the crowdsourcing worker.'
            )
            acceptability_violations = self.acceptability_checker.check_messages(
                messages=utterances_0[1:],  # Don't use the initial "Hi!"
                is_worker_0=True,
                violation_types=self.acceptability_checker.ALL_VIOLATION_TYPES,
            )
            if acceptability_violations != '':
                logging.info(
                    f'Conversation fails acceptability checks with a violation of '
                    f'"{acceptability_violations}", given the following utterances: '
                    f'{utterances_0[1:]}. Skipping.'
                )
                unacceptable_task_units.append(task_unit)
                assert (
                    mturk_worker_id is not None
                ), "MTurk worker ID cannot be determined for this unacceptable conversation!"
                unacceptable_worker_ids.append(mturk_worker_id)
                continue
            single_convo_info_dict[
                'acceptability_violations'
            ] = acceptability_violations

            # Identify information to put in each line of the output DataFrame
            info_for_each_turn = {
                'worker_id': single_convo_info_dict['worker'],
                'mturk_worker_id': single_convo_info_dict['mturk_worker_id'],
                'hit_id': single_convo_info_dict['hit_id'],
                'model_pair_nickname': model_pair_nickname,
                'conversation_idx': conversation_idx,
                'date': single_convo_info_dict['date'],
                'completion_time': single_convo_info_dict['completion_time'],
            }

            single_turn_dicts = []

            # Compile personas and previous utterances
            text_parts = []
            if data['personas'] is not None and len(data['personas']) > 0:
                assert len(data['personas']) == 2
                text_parts += [
                    'human persona: ' + ' '.join(data['personas'][0]),
                    'bot persona: ' + ' '.join(data['personas'][1]),
                ]
            if (
                data['additional_context'] is not None
                and len(data['additional_context']) > 0
            ):
                text_parts.append(data['additional_context'])
            # Context row: turn_idx/agent_idx of -1 mark it as pre-dialogue
            single_turn_dicts.append(
                {
                    **info_for_each_turn,
                    'turn_idx': -1,
                    'agent_idx': -1,
                    'context': '\n'.join(text_parts),
                    **self.blank_fields,
                }
            )

            total_utterances += len([d for d in data["dialog"] if d["agent_idx"] == 1])
            if len(data['dialog']) > 20:
                logging.info(
                    f'Got long dialogue of {len(data["dialog"])} utterances, hit id: '
                    f'{single_convo_info_dict["hit_id"]}, model_pair_nickname: '
                    f'{model_pair_nickname}.'
                )

            # Loop over conversation turns
            for utterance_idx, utt in enumerate(data['dialog']):

                this_turn_dict = {
                    **info_for_each_turn,
                    'turn_idx': utterance_idx,
                    'agent_idx': utt['agent_idx'],
                    **self.blank_fields,
                }

                if utt['agent_idx'] == 1:
                    # This is a turn in which the bots have responded
                    human_turn_idx = int(utterance_idx / 2) + 1  # Turns are 1-indexed
                    # TODO: maybe clean up some of this logic
                    this_turn_dict['human_choice'] = utt['human_choice']
                    this_turn_dict['human_justification'] = utt['human_justification']
                    # Escape newlines/carriage returns so each row stays on one
                    # line in the saved output
                    this_turn_dict['accepted_bot_text'] = (
                        utt['accepted_bot_data']['text']
                        .replace('\n', '__newline__')
                        .replace('\r', '__CR__')
                    )
                    this_turn_dict['not_accepted_bot_text'] = (
                        utt['not_accepted_bot_data']['text']
                        .replace('\n', '__newline__')
                        .replace('\r', '__CR__')
                    )
                    if 'total' not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname]['total'] = 0
                    if this_turn_dict['agent_idx'] == 1:
                        self.stat_counts[model_pair_nickname]['total'] += 1

                    # Calculating overall human choice statistics
                    if model_1_nickname not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname][model_1_nickname] = 0
                    if model_2_nickname not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname][model_2_nickname] = 0
                    # Calculating per-turn human choice statistics
                    if utt['human_choice'] == model_1_nickname:
                        self.stat_counts[model_pair_nickname][model_1_nickname] += 1
                        self.stat_counts[model_pair_nickname]['per_turn'][
                            human_turn_idx
                        ][model_1_nickname] += 1
                    elif utt['human_choice'] == model_2_nickname:
                        self.stat_counts[model_pair_nickname][model_2_nickname] += 1
                        self.stat_counts[model_pair_nickname]['per_turn'][
                            human_turn_idx
                        ][model_2_nickname] += 1
                    else:
                        raise Exception(
                            'Something wrong has occurred: human choice is not equal '
                            'to either of the two models!'
                        )
                else:
                    # This is a turn in which the human has responded
                    this_turn_dict['human_text'] = utt['text']

                    # Counting some aspects of the human's utterances
                    if (
                        'human_utterance_count'
                        not in self.stat_counts[model_pair_nickname]
                    ):
                        self.stat_counts[model_pair_nickname][
                            'human_utterance_count'
                        ] = 0
                    self.stat_counts[model_pair_nickname]['human_utterance_count'] += 1

                    if 'human_word_count' not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname]['human_word_count'] = 0
                    self.stat_counts[model_pair_nickname]['human_word_count'] += len(
                        this_turn_dict['human_text'].strip().split(' ')
                    )

                    if (
                        'human_question_count'
                        not in self.stat_counts[model_pair_nickname]
                    ):
                        self.stat_counts[model_pair_nickname][
                            'human_question_count'
                        ] = 0
                    self.stat_counts[model_pair_nickname][
                        'human_question_count'
                    ] += this_turn_dict['human_text'].count('?')

                single_turn_dicts.append(this_turn_dict)

            # Finish up collecting per-conversation stats
            if single_convo_info_dict['worker'] not in worker_stats:
                worker_stats[single_convo_info_dict['worker']] = {'conversations': 0}
            worker_stats[single_convo_info_dict['worker']]['conversations'] += 1

            # Logic for calculating percent of conversations that are clean
            if 'acceptable_convos' not in self.stat_counts[model_pair_nickname]:
                self.stat_counts[model_pair_nickname]['acceptable_convos'] = 0
            self.stat_counts[model_pair_nickname]['acceptable_convos'] += 1

            # Adding the full conversation to the list of conversations
            single_convo_df = pd.DataFrame(single_turn_dicts)
            conversation_dfs.append(single_convo_df)
            conversation_idx += 1

        # Print results
        # TODO: all of this would be cleaner if saved as CSVs, so we don't have to
        # re-run to get the results
        logging.info(
            f'{num_convos_with_no_save_data:d} conversations found with no save data.'
        )
        logging.info(
            f'{num_wrong_status_convos:d} conversations found with the wrong status.'
        )
        logging.info(f'{num_complete_convos:d} complete conversations found:')
        logging.info(f'\t{len(unacceptable_task_units):d} unacceptable conversations.')
        logging.info(f'\t{len(conversation_dfs):d} acceptable conversations.')
        for model_pair_nickname, model_stats_dict in self.stat_counts.items():
            logging.info(f'---{model_pair_nickname}---')
            model_1_nickname = model_pair_nickname.split(":")[0]
            model_2_nickname = model_pair_nickname.split(":")[1]
            for p, v in model_stats_dict.items():
                if p == 'per_turn':
                    # Per-turn win rates for each model in the pair
                    for human_turn_idx in model_stats_dict['per_turn']:
                        per_turn_model_1 = model_stats_dict['per_turn'][
                            human_turn_idx
                        ][model_1_nickname]
                        per_turn_model_2 = model_stats_dict['per_turn'][
                            human_turn_idx
                        ][model_2_nickname]
                        per_turn_model_total = per_turn_model_1 + per_turn_model_2
                        logging.info(
                            f"Turn {human_turn_idx}, {model_1_nickname}: {per_turn_model_1} "
                            f"({per_turn_model_1/per_turn_model_total:.2%})"
                            f", {model_2_nickname}: {per_turn_model_2} "
                            f"({per_turn_model_2/per_turn_model_total:.2%})"
                        )
                    continue
                if p == 'human_word_count' or p == 'human_question_count':
                    logging.info(
                        f'{p}: {v} ({v/model_stats_dict["human_utterance_count"]:.3})'
                    )
                elif p == 'human_utterance_count':
                    logging.info(f'{p}: {v}')
                elif p == 'acceptable_convos':
                    logging.info(f'{p}: {v}')
                else:
                    logging.info(f'{p}: {v} ({v/model_stats_dict["total"]:.2%})')
        logging.info('Printing worker IDs not already in block list to add...')
        for b in unacceptable_task_units:
            worker_id = b['worker_id']
            if worker_id not in self.worker_block_list:
                logging.info(f"""'{worker_id}',""")
        logging.info('Done printing bad workers.')

        logging.info(f'\nWorker conversation counts: {worker_conversation_counts}')

        # Compile full results
        all_conversations_df = pd.DataFrame()
        for single_convo_df in conversation_dfs:
            all_conversations_df = all_conversations_df.append(single_convo_df)
        # Every blank-able field must exist in every row (as '' if absent)
        for field in self.blank_fields.keys():
            assert all_conversations_df[field].isna().sum() == 0, (
                f'Some rows of the "{field}" column have NaNs in them, making them '
                f'hard to calculate statistics on!'
            )

        # Save analysis files
        logging.info(
            f'Saving worker statistical results to {self.worker_results_path}.'
        )
        worker_columns = ['worker_id', 'conversations']
        worker_df = pd.DataFrame([], columns=worker_columns)
        for worker_id, data in worker_stats.items():
            stat = {'worker_id': worker_id, 'conversations': data['conversations']}
            worker_df = worker_df.append(stat, ignore_index=True)
        worker_df.to_csv(self.worker_results_path, index=False)

        logging.info(
            f'Saving MTurk IDs of workers with unacceptable conversations to '
            f'{self.unacceptable_worker_ids_path}.'
        )
        with open(self.unacceptable_worker_ids_path, 'w') as f:
            for worker_id in unacceptable_worker_ids:
                f.write(worker_id + '\n')

        logging.info(f'Saving win rates cut by date to {self.win_rate_by_date_path}.')
        # Count each human choice per (model pair, date), then pivot so each
        # chosen model becomes its own column
        pivoted_win_rate_df = (
            all_conversations_df[lambda df: df['human_choice'].notna()]
            .assign(count=1)
            .groupby(['model_pair_nickname', 'date', 'human_choice'])
            .agg({'count': 'sum'})
            .reset_index()
            .pivot(
                index=['model_pair_nickname', 'date'],
                columns='human_choice',
                values='count',
            )
        )
        model_names = pivoted_win_rate_df.columns
        pivoted_win_rate_df.loc[:, 'total_count'] = pivoted_win_rate_df[
            model_names
        ].sum(axis=1)
        for model_name in model_names:
            pivoted_win_rate_df.loc[:, f'frac_{model_name}'] = (
                pivoted_win_rate_df[model_name] / pivoted_win_rate_df['total_count']
            )
        pivoted_win_rate_df.to_csv(self.win_rate_by_date_path)

        logging.info(
            f'Saving mean word count of different stats, cut by date, to '
            f'{self.stat_mean_length_by_date_path}.'
        )
        stats_to_calculate_mean_length_of = ['human_text', 'human_justification']
        assert (
            len(stats_to_calculate_mean_length_of) == 2
        ), 'This section of the code won\'t work with more than 2 stats!'
        stat_mean_length_dfs = []
        for stat in stats_to_calculate_mean_length_of:
            # Mean whitespace-delimited word count per (model pair, date),
            # ignoring rows where the field is blank
            stat_mean_length_dfs.append(
                all_conversations_df[lambda df: df[stat] != '']
                .assign(word_count=lambda df: df[stat].str.split().str.len())
                .groupby(['model_pair_nickname', 'date'])['word_count']
                .mean()
                .to_frame(stat)
            )
        joined_stat_mean_length_df = stat_mean_length_dfs[0].join(
            stat_mean_length_dfs[1]
        )
        joined_stat_mean_length_df.to_csv(self.stat_mean_length_by_date_path)

        logging.info(
            f'Saving mean completion time stats to '
            f'{self.completion_time_by_model_pair_path}.'
        )
        completion_time_by_convo_df = all_conversations_df[
            ['model_pair_nickname', 'conversation_idx', 'completion_time']
        ].drop_duplicates()
        # Sanity-check: one deduplicated row per acceptable conversation
        for model_pair_nickname in completion_time_by_convo_df[
            'model_pair_nickname'
        ].unique():
            assert (
                completion_time_by_convo_df[
                    lambda df: df['model_pair_nickname'] == model_pair_nickname
                ].index.size
                == self.stat_counts[model_pair_nickname]['acceptable_convos']
            ), (
                f"The count of convos for the model pair {model_pair_nickname} is "
                f"inconsistent!"
            )
        completion_time_by_model_pair_df = (
            completion_time_by_convo_df.groupby('model_pair_nickname')[
                'completion_time'
            ]
            .mean()
            .to_frame('mean_completion_time')
        )
        completion_time_by_model_pair_df.to_csv(self.completion_time_by_model_pair_path)

        return all_conversations_df
class TurnAnnotationsChatWorld(MultiAgentDialogWorld):
    """
    World in which a human worker chats with a model ("bot") and annotates
    problems in each bot utterance as the conversation progresses.

    Agent 0 is the human Turker; agent 1 is the bot (an agent wrapping a
    model, e.g. a TurkLikeAgent). The finished conversation, including
    per-utterance problem annotations, is written to disk as JSON by
    save_data().
    """

    def __init__(
        self,
        opt,
        agents=None,
        shared=None,
        num_turns=6,
        tag=None,
        max_resp_time=120,
        agent_timeout_shutdown=120,
        context_info: Optional[dict] = None,
    ):
        # num_turns is turns for a single side (so 12 total), and really it
        # appears to be 14 total b/c of the "Hi!" and first bot utterance
        self.agents = agents
        self.task_turn_idx = 0
        self.num_turns = num_turns
        self.dialog = []
        self.tag = tag
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        if context_info is not None:
            self.context_info = context_info
            # Personas for [human, bot], indexed by agent_idx
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        # Below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        self.agent_timeout_shutdown = agent_timeout_shutdown
        print(
            f'Creating {self.__class__.__name__} for tag {tag} with {num_turns} turns.'
        )
        # super().__init__ sets self.opt, which parley() and save_data() rely on
        super().__init__(opt, agents, shared)

    def __add_problem_data_to_utterance(self, p):
        """
        Attach the received problem-annotation data ``p`` to the dialog turn it
        refers to.

        The human has just responded, so problem data received now will be from
        the bot's prior utterance (``turn_idx`` is also present in ``p`` to be
        safe that the data matches). Fake "seed" utterances are allowed to
        carry no annotation buckets.
        """
        turn = self.dialog[p['turn_idx']]
        is_fake_utterance = 'fake_start' in turn and turn['fake_start']
        annotations = [p[a['value']] for a in self.opt['annotations_config']]
        # Every real bot utterance must have at least one bucket selected
        assert any(annotations) or is_fake_utterance
        turn['problem_data'] = p

    def _seed_conversation(self, control_msg):
        """
        Handle task turn 0: show the bot its persona (if enabled) and display
        the opening utterance(s) per opt['conversation_start_mode'].

        Raises ValueError on an unrecognized conversation start mode.
        """
        for agent_idx, agent in enumerate(self.agents):
            if agent_idx == 1 and self.opt['include_persona']:
                print('Including persona for the bot.')
                # The bot agent: we add the personas and, 1/3 of the time, the
                # WoW topic as the first utterance in the history. Previously
                # for the BST task we also had a big first utterance that gave
                # instructions; that is removed for this task.
                persona_strings = [s.strip() for s in self.personas[agent_idx]]
                persona_utterance = self._get_persona_utterance(
                    persona_strings=persona_strings,
                    context_dataset=self.context_info['context_dataset'],
                    additional_context=self.context_info['additional_context'],
                    is_bot=(agent_idx == 1),
                )
                message = control_msg.copy()
                message['text'] = persona_utterance
                # The bot seeing its persona does not count as a "turn"
                agent.observe(validate(message), increment_turn=False)
            if agent_idx == 0:
                time.sleep(3)

        if self.opt['conversation_start_mode'] == 'bst':
            print('[Displaying first utterances as per BST task.]')
            # Display the previous two utterances
            human_first_msg = {
                'left_pane_text': self.opt['left_pane_text'],
                'episode_done': False,
                'id': self.agents[0].id,
                'text': self.context_info['person1_seed_utterance'],
                'fake_start': True,
                'agent_idx': 0,
            }
            for k, v in control_msg.items():
                human_first_msg[k] = v
            bot_first_msg = {
                'episode_done': False,
                'id': self.agents[1].id,
                'text': self.context_info['person2_seed_utterance'],
                'fake_start': True,
                'agent_idx': 1,
            }
            print(
                f'human_first_msg: {human_first_msg}, bot_first_msg: {bot_first_msg}'
            )
            self.dialog.append(human_first_msg)
            self.dialog.append(bot_first_msg)
            for agent in self.agents:
                agent.observe(validate(human_first_msg))
                agent.observe(validate(bot_first_msg))
        elif self.opt['conversation_start_mode'] == 'hi':
            print('[Displaying "Hi!" only as per Meena task.]')
            human_first_msg = {
                'left_pane_text': self.opt['left_pane_text'],
                'episode_done': False,
                'id': self.agents[0].id,
                'text': 'Hi!',
                'fake_start': True,
                'agent_idx': 0,
            }
            for k, v in control_msg.items():
                human_first_msg[k] = v
            self.dialog.append(human_first_msg)
            self.agents[0].observe(validate(human_first_msg))
            self.agents[1].observe(validate(human_first_msg))
            # The bot responds to "Hi!" immediately on turn 0
            first_bot_act = self.agents[1].act()
            first_bot_act = Compatibility.maybe_fix_act(first_bot_act)
            self.agents[0].observe(validate(first_bot_act))
            bot_utterance_data = {
                'agent_idx': 1,
                # Get rid of annotations HTML from bot response
                'text': first_bot_act['text'].split('<br>')[0],
                'id': first_bot_act['id'],
            }
            self.dialog.append(bot_utterance_data)
        else:
            raise ValueError(
                f"Conversation start mode {self.opt['conversation_start_mode']} "
                f"not recognized!"
            )

    def parley(self):
        """
        Run one step of the conversation: seed it on turn 0, then collect an
        act from each agent, record dialog turns and per-utterance problem
        annotations, and handle timeouts, disconnects, and chat termination.
        """
        control_msg = {
            'episode_done': False,
            'config': {
                'min_num_turns': self.num_turns,
                'annotations_config': self.opt['annotations_config'],
            },
            'left_pane_text': self.opt['left_pane_text'],
        }
        print(
            f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:
            self._seed_conversation(control_msg)
            self.task_turn_idx += 1
            return

        # Otherwise, we proceed with a normal turn
        print(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}'
        )
        acts = [None, None]
        for idx, agent in enumerate(self.agents):
            if not self.chat_done:
                acts[idx] = agent.act(timeout=self.max_resp_time)
                acts[idx] = Compatibility.maybe_fix_act(acts[idx])
                print(
                    f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
                )
            if self.check_timeout(acts[idx]):
                return
            if acts[idx]['episode_done']:
                self.chat_done = True
                for ag in self.agents:
                    # If the other agent disconnected, tell the human (agent 0)
                    # to wrap up the HIT
                    if ag != agent and ag.some_agent_disconnected:
                        if idx == 0:
                            # Human
                            message = control_msg.copy()
                            message['text'] = (
                                'The other worker unexpectedly disconnected. '
                                'Please click "Done with this HIT" button below to finish this HIT.'
                            )
                            message['episode_done'] = True
                            ag.observe(validate(message))
                            return
                # Agent ends chat after exceeding the minimum number of turns
                if self.task_turn_idx > self.num_turns:
                    for ag in self.agents:
                        if idx == 0:
                            print('One of you ended the chat.')
                        message = control_msg.copy()
                        message['text'] = (
                            'One of you ended the chat. Thanks for your '
                            'time! Please click "Done with this HIT" '
                            'button below to finish this HIT.'
                        )
                        message['episode_done'] = True
                        ag.observe(validate(message))
                    if idx == 0:
                        # Human has just responded. Problem data received
                        # now will be from bot's prior utterance (turn_idx
                        # is also present to be safe that data matches)
                        p = acts[idx]['problem_data_for_prior_message']
                        self.__add_problem_data_to_utterance(p)
                return
            else:
                utterance_data = {
                    'agent_idx': idx,
                    # Get rid of annotations HTML if it's the bot response
                    'text': acts[idx]['text'].split('<br>')[0],
                    # 'id' is Person1 or the model name (e.g. Polyencoder)
                    'id': acts[idx]['id'] if 'id' in acts[idx] else 'NULL_ID',
                }
                self.dialog.append(utterance_data)
                if idx == 0:
                    # Human has just responded. Problem data received
                    # now will be from bot's prior utterance (turn_idx
                    # is also present to be safe that data matches)
                    p = acts[idx]['problem_data_for_prior_message']
                    self.__add_problem_data_to_utterance(p)
                for other_agent in self.agents:
                    if other_agent != agent:
                        other_agent.observe(validate(acts[idx]))
                print(
                    f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}'
                )
                self.task_turn_idx += 1

    def shutdown(self):
        """Shut down all agents in parallel threads."""
        global shutdown_agent

        def shutdown_agent(mturk_agent):
            mturk_agent.shutdown()

        Parallel(n_jobs=len(self.agents), backend='threading')(
            delayed(shutdown_agent)(agent) for agent in self.agents
        )

    def episode_done(self):
        """Return True once one side has ended the chat."""
        return self.chat_done

    def _get_persona_utterance(
        self,
        persona_strings: Optional[List[str]] = None,
        context_dataset: Optional[str] = None,
        additional_context: Optional[str] = None,
        is_bot: bool = False,
    ):
        """
        Build the persona/context message shown at the start of the chat.

        For the bot, returns the raw context (persona lines, plus the WoW topic
        when context_dataset == 'wizard_of_wikipedia'). For the human, returns
        HTML-formatted instructions including the persona and a dataset-specific
        closing sentence. Raises ValueError on an unrecognized context dataset.
        """
        if is_bot:
            # Pass back the original context
            persona_pieces = [f"your persona: {str_}" for str_ in persona_strings]
            if context_dataset == 'wizard_of_wikipedia':
                additional_context_pieces = [additional_context]
            else:
                additional_context_pieces = []
            full_context = '\n'.join(persona_pieces + additional_context_pieces)
            print(f'FULL CONTEXT: {full_context}')
            return full_context
        else:
            if context_dataset == 'convai2':
                last_sentence = 'Pretend that the conversation has already begun.'
            elif context_dataset == 'empathetic_dialogues':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about the following situation: '
                    f'<b>"{additional_context}"</b>'
                )
            elif context_dataset == 'wizard_of_wikipedia':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about <b>{additional_context}</b>.'
                )
            else:
                raise ValueError('Context dataset unrecognized!')
            joined_personas = '\n'.join(persona_strings)
            return (
                f'\nSuccessfully matched with another user! Now let\'s get to know '
                f'each other through the chat. You need to finish at least '
                f'<b>{self.num_turns} chat turns</b>, and after that you can click the '
                f'"Done" button to end the chat.\n\n'
                f'<b>Your character description is:\n<span style="color:blue">{joined_personas}</span></b> '
                '\n\n<b>Remember that you can get to know each '
                'other as your characters, talk about any topic, or talk about a '
                'situation that might have happened to your character.</b>'
                '\n<b>Do not trivially copy the '
                'character descriptions into the message.</b><br><br>'
                f'{last_sentence}'
            )

    def save_data(self):
        """
        Write the conversation (and metadata) to a JSON file in
        opt['save_folder'].

        Returns a tuple of (model nickname, whether acceptability violations
        were found, whether the conversation finished cleanly).
        """
        convo_finished = True
        bad_workers = []
        for ag in self.agents:
            # Any abandon/return/disconnect/expiry marks the convo incomplete
            if (
                ag.hit_is_abandoned
                or ag.hit_is_returned
                or ag.disconnected
                or ag.hit_is_expired
            ):
                bad_workers.append(ag.worker_id)
                convo_finished = False
                ag.not_approve = True

        if self.check_acceptability:
            human_texts = [
                message['text']
                for message in self.dialog
                if message['agent_idx'] == 0
            ]
            violation_types = ['min_words', 'all_caps', 'exact_match', 'safety']
            if self.opt['conversation_start_mode'] == 'bst':
                # The BST mode starts the conversation with two previous
                # utterances, so there should be no new greeting
                violation_types.append('penalize_greetings')
            violations_agent_0 = self.acceptability_checker.check_messages(
                messages=human_texts,
                is_worker_0=False,
                violation_types=violation_types,
            )
        else:
            violations_agent_0 = None

        time_string = time.strftime('%Y%m%d_%H%M%S')
        data_path = self.opt['save_folder']
        if convo_finished:
            filename = os.path.join(
                data_path,
                '{}_{}_{}.json'.format(
                    time_string, np.random.randint(0, 1000), self.task_type
                ),
            )
        else:
            filename = os.path.join(
                data_path,
                '{}_{}_{}_incomplete.json'.format(
                    time_string, np.random.randint(0, 1000), self.task_type
                ),
            )
        with open(filename, 'w+') as f_json:
            data = {
                'personas': self.personas,
                'context_dataset': self.context_info.get('context_dataset'),
                'person1_seed_utterance': self.context_info.get(
                    'person1_seed_utterance'
                ),
                'person2_seed_utterance': self.context_info.get(
                    'person2_seed_utterance'
                ),
                'additional_context': self.context_info.get('additional_context'),
                'dialog': self.dialog,
                'workers': [ag.worker_id for ag in self.agents],
                'bad_workers': bad_workers,
                # A tuple for compatibility with a human/human conversation in
                # which we check both sides for acceptability
                'acceptability_violations': (violations_agent_0,),
                'hit_ids': [ag.hit_id for ag in self.agents],
                'assignment_ids': [ag.assignment_id for ag in self.agents],
                'task_description': {
                    'annotations_config': self.opt['annotations_config'],
                    'model_nickname': self.agents[1].worker_id,
                    'model_file': self.agents[1].model_agent.opt.get('model_file'),
                    'model_opt': self.agents[1].model_agent.opt,
                },
            }
            f_json.write(json.dumps(data))

        print(
            f'{self.__class__.__name__}:{self.tag}: Data successfully saved at '
            f'{filename} for model: {self.agents[1].worker_id}.'
        )
        if self.check_acceptability:
            print(f'Acceptability violations for agent 0: {violations_agent_0}')
            # check_messages returns '' when there are no violations
            return self.agents[1].worker_id, violations_agent_0 != '', convo_finished
        else:
            return self.agents[1].worker_id, False, convo_finished

    def check_timeout(self, act):
        """
        If ``act`` is a timeout act, notify the remaining human agent, mark the
        chat done, and return True; otherwise return False.
        """
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {
                'episode_done': True,
                'id': 'SYSTEM',
                'text': 'HIT has timed out. Please click the "Done with this HIT" button below to exit this HIT. No rejections.',
            }
            for ag in self.agents:
                # Don't notify the agent that timed out, nor the bot
                if ag.id != act['id'] and ag.id != AGENT_1:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        else:
            return False

    def review_work(self):
        # No manual review step for this task
        pass