Example 1
    def __init__(self, opt, agent, bot):
        super().__init__(opt, agent)

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bot = bot
        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.agent.agent_id = 'Speaker 1'
        self.bot.agent_id = 'Speaker 2'

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        self.final_chat_data = None
        # TODO: remove this attribute once chat data is only stored in the Mephisto
        #  TaskRun for this HIT (see .get_custom_task_data() docstring for more
        #  information)

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        print(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )
Example 2
    def __init__(self, opt: Dict[str, Any]):

        AbstractTurnAnnotationResultsCompiler.__init__(self, opt)

        # Input args
        self.model_nickname = opt['model_nickname']
        assert len(self.results_folders) > 0
        for folder in self.results_folders:
            assert os.path.isdir(folder), f'{folder} is not a valid folder!'
        os.makedirs(self.output_folder, exist_ok=True)
        self.start_date = opt['start_date']
        self.max_convos_per_worker = opt['max_convos_per_worker']
        self.min_word_count = opt['min_word_count']
        self.hit_block_list = opt['hit_block_list'].split(',')
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Setting up problem buckets
        if self.use_problem_buckets:
            self.regular_buckets = [
                bucket for bucket in self.problem_buckets
                if bucket not in ['other', 'none_all_good']
            ]
            # Remove the buckets that are special cases

        self.acceptability_checker = AcceptabilityChecker()
        self.completed_run_stats_path = opt['completed_run_stats_path']
Example 3
    def __init__(self, opt: Dict[str, Any]):

        super().__init__(opt)
        # Validate problem buckets
        if self.use_problem_buckets and 'none_all_good' not in self.problem_buckets:
            # The code relies on a catchall "none" category if the user selects no other
            # annotation bucket
            raise ValueError(
                'There must be a "none_all_good" category in self.problem_buckets!'
            )

        # Input args
        assert len(self.results_folders) > 0
        for folder in self.results_folders:
            assert os.path.isdir(folder), f'{folder} is not a valid folder!'
        os.makedirs(self.output_folder, exist_ok=True)
        self.start_date = opt['start_date']
        self.max_convos_per_worker = opt['max_convos_per_worker']
        self.min_word_count = opt['min_word_count']
        self.hit_block_list = opt['hit_block_list'].split(',')
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Setting up problem buckets
        if self.use_problem_buckets:
            self.regular_buckets = [
                bucket for bucket in self.problem_buckets
                if bucket not in ['other', 'none_all_good']
            ]
            # Remove the buckets that are special cases

        self.acceptability_checker = AcceptabilityChecker()
Example 4
    def __init__(self, opt, agent, bot):
        super().__init__(opt, agent)

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bot = bot
        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        print(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )
Example 5
    def __init__(self, opt, agent, bots, context_info: Optional[dict] = None):

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance

        # TODO: this logic is very close to that of BaseModelChatWorld.__init__(). Can
        #  any of this be deduplicated?

        super(BaseModelChatWorld, self).__init__(opt, agent)

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bots = bots

        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        self.final_chat_data = None
        # TODO: remove this attribute once chat data is only stored in the Mephisto
        #  TaskRun for this HIT (see .get_custom_task_data() docstring for more
        #  information)

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        logging.info(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None
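
For reference, a hedged illustration of the context_info dict these constructors expect. Only the key names are taken from the snippets (persona_1_strings and persona_2_strings here, plus the optional fields read via .get() in Example 13); every value below is a made-up placeholder.

# Hypothetical illustration (not taken from the snippet above): the context_info
# dict passed into these worlds. Key names follow the code; values are placeholders.
example_context_info = {
    'persona_1_strings': ['I have two dogs.', 'I work as a teacher.'],
    'persona_2_strings': ['I love hiking.', 'I am learning to cook.'],
    'context_dataset': 'wizard_of_wikipedia',  # read via .get() in get_final_chat_data()
    'person1_seed_utterance': 'Have you ever been hiking in the mountains?',
    'person2_seed_utterance': 'Not yet, but I would love to try it sometime.',
    'additional_context': 'Hiking',
}
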
Example 6
    def __init__(self, opt: Dict[str, Any]):

        super().__init__(opt)

        # Input args
        os.makedirs(self.output_folder, exist_ok=True)
        # TODO: see if this can be moved to the superclass
        self.filter_uniform_hits = opt['filter_uniform_hits']

        # Save paths
        self.unacceptable_worker_ids_path = os.path.join(
            self.output_folder, 'unacceptable_worker_ids.txt')
        self.annotation_selection_rate_path = os.path.join(
            self.output_folder, 'annotation_selection_rates.csv')
        self.likert_score_stat_path = os.path.join(self.output_folder,
                                                   'likert_score_stats.csv')

        self.acceptability_checker = AcceptabilityChecker()
Example 7
    def __init__(self, opt: Dict[str, Any]):
        # TODO: deduplicate init from ModelChatResultsCompiler

        super().__init__(opt)

        # Input args
        os.makedirs(self.output_folder, exist_ok=True)
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Save paths
        self.worker_results_path = os.path.join(
            self.output_folder, 'worker_results.csv'
        )
        self.unacceptable_worker_ids_path = os.path.join(
            self.output_folder, 'unacceptable_worker_ids.txt'
        )
        self.win_rate_by_date_path = os.path.join(
            self.output_folder, 'win_rates_by_date.csv'
        )
        self.stat_mean_length_by_date_path = os.path.join(
            self.output_folder, 'stat_mean_length_by_date.csv'
        )
        self.completion_time_by_model_pair_path = os.path.join(
            self.output_folder, 'mean_completion_times.csv'
        )

        self.acceptability_checker = AcceptabilityChecker()

        # Set fields that should be empty strings if the relevant information is not
        # present
        blank_field_columns = [
            'human_text',
            'human_choice',
            'human_justification',
            'accepted_bot_text',
            'not_accepted_bot_text',
        ]
        self.blank_fields = {field: '' for field in blank_field_columns}

        # Results attributes
        self.stat_counts = {}
        self.mean_completion_time = None
Example 8
    def __init__(
        self,
        opt,
        agents=None,
        shared=None,
        num_turns=6,
        tag=None,
        max_resp_time=120,
        agent_timeout_shutdown=120,
        context_info: Optional[dict] = None,
    ):
        # 6 turns for a single side (so 12 total), and really it appears to be
        # 14 total b/c of the "Hi!" and first bot utterance

        self.agents = agents
        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = tag
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        self.agent_timeout_shutdown = agent_timeout_shutdown
        print(
            f'Creating {self.__class__.__name__} for tag {tag} with {num_turns} turns.'
        )
        super().__init__(opt, agents, shared)
Example 9
    def __init__(self, opt: Dict[str, Any]):

        super().__init__(opt)

        # Input args
        assert len(self.results_folders) > 0
        for folder in self.results_folders:
            assert os.path.isdir(folder), f'{folder} is not a valid folder!'
        os.makedirs(self.output_folder, exist_ok=True)
        self.start_date = opt['start_date']
        self.max_convos_per_worker = opt['max_convos_per_worker']
        self.min_word_count = opt['min_word_count']
        self.hit_block_list = opt['hit_block_list'].split(',')
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Setting up problem buckets
        self.regular_buckets = [
            bucket for bucket in self.problem_buckets
            if bucket not in ['other', 'none_all_good']
        ]
        # Remove the buckets that are special cases

        self.acceptability_checker = AcceptabilityChecker()
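
As a quick illustration of the bucket filtering above, here is a minimal sketch; only 'other' and 'none_all_good' come from the snippet, and the remaining bucket names are hypothetical.

# Sketch of the regular-bucket filtering with made-up bucket names; only
# 'other' and 'none_all_good' are taken from the snippet itself.
problem_buckets = ['bucket_0', 'bucket_4', 'other', 'none_all_good']
regular_buckets = [
    bucket for bucket in problem_buckets
    if bucket not in ['other', 'none_all_good']
]
assert regular_buckets == ['bucket_0', 'bucket_4']
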
Example 10
class ModelChatResultsCompiler(AbstractResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as aggregated
    up to the level of each worker as a whole.
    """
    @classmethod
    def setup_args(cls):
        parser = super().setup_args()
        parser.add_argument(
            '--filter-uniform-hits',
            action='store_true',
            help=(
                "Filter out any HITs in which the worker's annotations were the exact "
                "same on each turn of the conversation"
            ),
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):

        super().__init__(opt)

        # Input args
        os.makedirs(self.output_folder, exist_ok=True)
        # TODO: see if this can be moved to the superclass
        self.filter_uniform_hits = opt['filter_uniform_hits']

        # Save paths
        self.unacceptable_worker_ids_path = os.path.join(
            self.output_folder, 'unacceptable_worker_ids.txt')
        self.annotation_selection_rate_path = os.path.join(
            self.output_folder, 'annotation_selection_rates.csv')
        self.likert_score_stat_path = os.path.join(self.output_folder,
                                                   'likert_score_stats.csv')

        self.acceptability_checker = AcceptabilityChecker()

    def get_results_path_base(self) -> str:
        return os.path.join(self.output_folder, 'results')
        # TODO: see if this can be moved to the superclass

    def compile_results(self) -> pd.DataFrame:

        # Load task data
        logging.info('Retrieving task data from Mephisto.')
        task_units_data = self.get_task_data()
        logging.info(
            f'Data for {len(task_units_data)} units loaded successfully.')

        num_convos_with_no_save_data = 0
        num_wrong_status_convos = 0
        num_complete_convos = 0

        unacceptable_task_units = []
        unacceptable_worker_ids = []
        conversation_idx = 0
        conversation_dfs = []

        for task_unit in task_units_data:

            worker_id = task_unit['worker_id']
            assignment_id = task_unit['assignment_id']

            # Skipping this conversation if save data is not found or the status is
            # invalid
            if task_unit['data']['save_data'] is None:
                logging.info('Found a task unit with no save data! Skipping.')
                num_convos_with_no_save_data += 1
                continue
            elif task_unit['status'] not in ['completed', 'approved']:
                logging.info(
                    f'Found a HIT with the status "{task_unit["status"]}"! '
                    f'Skipping.')
                num_wrong_status_convos += 1
                continue
            else:
                num_complete_convos += 1

            # Extract out useful conversation-level data
            custom_data = task_unit['data']['save_data']['custom_data']
            mturk_worker_id = Worker.get(self.get_mephisto_db(),
                                         worker_id).worker_name
            task_start = datetime.utcfromtimestamp(task_unit['task_start'])
            task_end = datetime.utcfromtimestamp(task_unit['task_end'])
            info_dict = {
                ('worker_id', ''): worker_id,
                ('mturk_worker_id', ''): mturk_worker_id,
                ('unit_id', ''): task_unit['unit_id'],
                ('assignment_id', ''): assignment_id,
                ('conversation_idx', ''): conversation_idx,
                ('date', ''): task_start.strftime('%Y-%m-%d'),
                ('completion_time', ''): (task_end - task_start).total_seconds(),
            }

            # Check that the conversation consists of pairs of comments between
            # Speaker 1 and Speaker 2, with Speaker 1 speaking first
            assert 'final_rating' in task_unit['data']['messages'][-1][
                'task_data']
            convo_messages = [m for m in task_unit['data']['messages'][:-1]]
            # The final message is just a final rating
            assert all(
                message['id'] == ('Speaker 2' if message_idx % 2 else 'Speaker 1')
                for message_idx, message in enumerate(convo_messages)
            )
            messages_1 = [m for m in convo_messages if m['id'] == 'Speaker 1']
            messages_2 = [m for m in convo_messages if m['id'] == 'Speaker 2']
            assert len(messages_1) + len(messages_2) == len(convo_messages)

            # Determine whether the HIT contains unacceptable messages. (We do this for
            # every HIT, even if acceptability violation info was already saved, because
            # the violation criteria may have changed since the HIT was collected.)
            utterances_1 = [m['text'] for m in messages_1]
            assert utterances_1[0] == 'Hi!', (
                'This script assumes that the first human message is "Hi!", which is '
                'set by default and cannot be changed by the crowdsourcing worker.'
            )
            acceptability_violations = self.acceptability_checker.check_messages(
                messages=utterances_1[1:],  # Don't use the initial "Hi!"
                is_worker_0=True,
                violation_types=self.acceptability_checker.ALL_VIOLATION_TYPES,
            )
            # Here, "worker 0" refers to Speaker 1, because we mix 0- and 1-indexing
            if acceptability_violations != '':
                logging.info(
                    f'Conversation fails acceptability checks with a violation of '
                    f'"{acceptability_violations}", given the following utterances: '
                    f'{utterances_1[1:]}. Skipping.')
                unacceptable_task_units.append(task_unit)
                assert (
                    mturk_worker_id is not None
                ), "MTurk worker ID cannot be determined for this unacceptable conversation!"
                unacceptable_worker_ids.append(mturk_worker_id)
                continue

            # Ignore the conversation if ratings for all turns are the same, because
            # it's somewhat implausible that *all* turns in a conversation should garner
            # the same rating of engagingness, humanness, interestingness, or none.
            # (However, don't put these workers on the "unacceptable worker IDs" list,
            # to give them a little bit of the benefit of the doubt: i.e. maybe the
            # worker just didn't try hard enough to find which responses were more
            # engaging, etc. than others, but that doesn't mean that all of their HITs
            # across all evals are bad and should be removed.)
            if self.filter_uniform_hits:
                annotations = [
                    m['task_data']['problem_data_for_prior_message']
                    for m in task_unit['data']['messages']
                    if 'problem_data_for_prior_message' in m.get(
                        'task_data', {})
                ]
                hashable_annotations = [
                    tuple(a[key] for key in sorted(a.keys()))
                    for a in annotations
                ]
                unique_annotations = set(hashable_annotations)
                if len(unique_annotations) < 1:
                    raise ValueError('No annotations found for this HIT!')
                elif len(unique_annotations) == 1:
                    logging.info(
                        f'All model responses in the conversation received the same '
                        f'annotation: {hashable_annotations[0]}. Skipping.')
                    unacceptable_task_units.append(task_unit)
                    continue

            single_turn_dicts = []

            # Compile personas and previous utterances
            text_parts = []
            if custom_data['personas'] is not None and len(
                    custom_data['personas']) > 0:
                assert len(custom_data['personas']) == 2
                text_parts += [
                    'HUMAN PERSONA: ' + ' '.join(custom_data['personas'][0]),
                    'BOT PERSONA: ' + ' '.join(custom_data['personas'][1]),
                ]
            if (custom_data['additional_context'] is not None
                    and len(custom_data['additional_context']) > 0):
                text_parts.append('ADDITIONAL CONTEXT: ' +
                                  custom_data['additional_context'])
            single_turn_dicts.append(
                {**info_dict, ('context', ''): ' '.join(text_parts)}
            )

            # Loop over conversation turns
            turns_per_speaker = defaultdict(int)
            for message in task_unit['data']['messages']:
                if 'text' in message:

                    speaker_id = message['id']

                    # Add in annotation results, if they exist
                    if 'problem_data_for_prior_message' in message.get(
                            'task_data', {}):
                        bucket_data = {
                            ('annotation_bucket', bucket): value
                            for bucket, value in message['task_data']
                            ['problem_data_for_prior_message'].items()
                        }
                    else:
                        bucket_data = {}

                    # Add in results from the final rating(s), if they exist
                    if 'final_rating' in message.get('task_data', {}):
                        ratings = message['task_data']['final_rating'].split(
                            '|')
                        final_rating_data = {
                            ('final_rating', str(idx)): value
                            for idx, value in enumerate(ratings)
                        }
                    else:
                        final_rating_data = {}

                    turns_per_speaker[speaker_id] += 1

                    single_turn_dicts.append({
                        **info_dict,
                        ('speaker_id', ''): speaker_id,
                        ('speaker_turn_idx', ''): turns_per_speaker[speaker_id],
                        ('text', ''): message['text'].replace('\n', '__newline__'),
                        **bucket_data,
                        **final_rating_data,
                    })

            # Adding the full conversation to the list of conversations
            single_turn_series = [
                pd.Series(dict_).to_frame().transpose()
                for dict_ in single_turn_dicts
            ]
            single_convo_df = pd.concat(single_turn_series, axis=0, sort=False)
            conversation_dfs.append(single_convo_df)
            conversation_idx += 1

        logging.info(
            f'{num_convos_with_no_save_data:d} conversations found with no save data.'
        )
        logging.info(
            f'{num_wrong_status_convos:d} conversations found with the wrong status.'
        )
        logging.info(f'{num_complete_convos:d} complete conversations found:')
        logging.info(
            f'\t{len(unacceptable_task_units):d} unacceptable conversations.')
        logging.info(f'\t{len(conversation_dfs):d} acceptable conversations.')

        # # Compile results across all conversations

        if len(conversation_dfs) == 0:
            raise ValueError('No acceptable conversations found!')
        unordered_conversation_df = pd.concat(conversation_dfs, axis=0)
        initial_ordered_columns = list(info_dict.keys()) + [
            ('context', ''),
            ('speaker_id', ''),
            ('speaker_turn_idx', ''),
            ('text', ''),
        ]
        all_ordered_columns = initial_ordered_columns + [
            col for col in unordered_conversation_df.columns
            if col not in initial_ordered_columns
        ]
        conversation_df = unordered_conversation_df[all_ordered_columns]
        # TODO: is there a less hacky way than this, which relies on the most recent
        #  value of `info_dict`, to put the columns back into the right order?

        # # Calculate and save auxiliary stats

        logging.info(
            f'Saving MTurk IDs of workers with unacceptable conversations to '
            f'{self.unacceptable_worker_ids_path}.')
        with open(self.unacceptable_worker_ids_path, 'w') as f:
            for worker_id in unacceptable_worker_ids:
                f.write(worker_id + '\n')

        # Calculate rates of selecting various annotation buckets
        annotation_bucket_df = conversation_df['annotation_bucket'].dropna(
            axis=0, how='any')
        if annotation_bucket_df.isna().sum().sum() > 0:
            raise ValueError(
                'There is at least one row in which only partial annotation bucket data exists!'
            )
        annotation_selection_rate_df = annotation_bucket_df.mean().to_frame(
            'selection_rate')
        annotation_selection_rate_df.to_csv(
            self.annotation_selection_rate_path)
        logging.info(
            f'Annotation bucket selection rates saved to {self.annotation_selection_rate_path}.'
        )
        output_strings = [
            f'{series.name}: {100*series["selection_rate"]:0.0f}%'
            for _, series in annotation_selection_rate_df.iterrows()
        ]
        logging.info('Annotation bucket selection rates:\n' +
                     '\n'.join(output_strings))

        # Calculate Likert score stats
        final_rating_df = conversation_df['final_rating'].dropna(axis=0,
                                                                 how='any')
        if final_rating_df.isna().sum().sum() > 0:
            raise ValueError(
                'There is at least one row in which only partial final rating data exists!'
            )
        likert_score_stat_df = final_rating_df.astype(int).describe()
        likert_score_stat_df.to_csv(self.likert_score_stat_path)
        logging.info(
            f'Likert score statistics saved to {self.likert_score_stat_path}.')
        logging.info(
            f'Mean Likert scores:\n{likert_score_stat_df.loc["mean"]}')

        return conversation_df
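
A hedged sketch of how a compiler like this is typically driven. It assumes the abstract base class exposes a compile_and_save_results() entry point, which is not shown in the snippet above; only setup_args(), __init__(), and compile_results() appear there.

# Hypothetical driver script; compile_and_save_results() is assumed to live on the
# abstract base class and is not shown in the example above.
if __name__ == '__main__':
    parser_ = ModelChatResultsCompiler.setup_args()
    args_ = parser_.parse_args()
    ModelChatResultsCompiler(vars(args_)).compile_and_save_results()
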
Example 11
class BaseModelChatWorld(CrowdTaskWorld, ABC):
    def __init__(self, opt, agent, bot):
        super().__init__(opt, agent)

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bot = bot
        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        print(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

    def __add_problem_data_to_utterance(self, p, turn_idx: int):
        """
        Attach problem data to the bot's prior utterance, given by turn_idx.
        """
        print(p)
        assert (self.dialog[turn_idx]['agent_idx'] == 1
                ), 'Problem data must be attached to a bot utterance.'
        assert ('problem_data' not in self.dialog[turn_idx]
                ), "Don't overwrite existing problem data!"
        self.dialog[turn_idx]['problem_data'] = p

    def parley(self):
        print(
            f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:
            self._run_initial_turn()
            self.task_turn_idx += 1
            return
        """Otherwise, we proceed accordingly"""
        print(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}'
        )
        acts = [None, None]
        for idx, agent in enumerate([self.agent, self.bot]):
            if not self.chat_done:
                acts[idx] = agent.act(timeout=self.max_resp_time)
                acts[idx] = Compatibility.maybe_fix_act(acts[idx])
                if 'metrics' in acts[idx]:
                    del acts[idx]['metrics']
                    # Metrics can't be saved to JSON and are not needed here
                print(
                    f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
                )

            if acts[idx].get('task_data', {}).get('final_rating') is not None:

                self.chat_done = True
                # agent ends chat after exceeding minimum number of turns

                if self.task_turn_idx > self.num_turns:
                    # Human has just responded. Any problem data received now will be
                    # regarding the bot's prior utterance
                    p = acts[idx]['task_data'].get(
                        'problem_data_for_prior_message')
                    if p is not None:
                        turn_idx = -1
                        # Attach the problem data to the last utterance, since the human
                        # hasn't said anything since then
                        self.__add_problem_data_to_utterance(p,
                                                             turn_idx=turn_idx)

                # Save the final chat data
                time_string = time.strftime('%Y%m%d_%H%M%S')
                chat_data_folder = self.opt['chat_data_folder']
                os.makedirs(chat_data_folder, exist_ok=True)
                chat_data_path = os.path.join(
                    chat_data_folder,
                    f'{time_string}_{np.random.randint(0, 1000)}_{self.task_type}.json',
                )
                final_chat_data = self.get_final_chat_data()
                self.agent.mephisto_agent.state.messages.append(
                    {'final_chat_data': final_chat_data})
                # Append the chat data directly to the agent state's message list in
                # order to prevent the worker from seeing a new text response in the UI
                with open(chat_data_path, 'w+') as f_json:
                    data_str = json.dumps(final_chat_data)
                    f_json.write(data_str)
                print(f'{self.__class__.__name__}:{self.tag}: Data saved at '
                      f'{chat_data_path} for model: {self.bot.worker_id}.')

                # Soft-block the worker if there were acceptability violations
                acceptability_violations = final_chat_data[
                    'acceptability_violations'][0]
                if (acceptability_violations is not None
                        and acceptability_violations != ''):
                    print(
                        f'**NOTE** Acceptability violations detected: {acceptability_violations}'
                    )
                    # Grant the failed qualification
                    self.agent.mephisto_agent.get_worker().grant_qualification(
                        self.block_qualification, 1)

                return

            else:
                utterance_data = {
                    'agent_idx': idx,
                    # Get rid of annotations HTML if it's the bot response
                    'text': acts[idx]['text'].split('<br>')[0],
                    'id': acts[idx].get('id', 'NULL_ID'),  # Person1 or Polyencoder
                }
                self.dialog.append(utterance_data)
                if idx == 0:
                    # Human has just responded. Any problem data received now will be
                    # regarding the bot's prior utterance
                    p = acts[idx]['task_data'].get(
                        'problem_data_for_prior_message')
                    if p is not None:
                        turn_idx = -2
                        # Attach the problem data to the second-to-last utterance, since
                        # the last utterance is what the human just said
                        self.__add_problem_data_to_utterance(p,
                                                             turn_idx=turn_idx)

                self._postprocess_acts(acts=acts, agent_idx=idx)
                for other_agent in [self.agent, self.bot]:
                    if other_agent != agent:
                        other_agent.observe(validate(acts[idx]))

                print(
                    f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}'
                )
                self.task_turn_idx += 1

    @abstractmethod
    def _run_initial_turn(self) -> None:
        """
        Runs logic for the first turn of the human and the bot.
        """

    def _postprocess_acts(self, acts: List[dict], agent_idx: int):
        """
        Optionally perform further processing of the acts.

        Useful for subclasses. Will be executed after saving act data to self.dialog but
        before showing the act to the other agent.
        """

    def shutdown(self):

        if self.chat_done:
            self.opt['run_statistics'][self.bot.worker_id] += 1
            print('Runs completed per model: ' + ', '.join(
                f'{model}: {count:d}'
                for model, count in self.opt['run_statistics'].items()))

        self.agent.shutdown()

    def episode_done(self):
        return self.chat_done

    def get_final_chat_data(self) -> Dict[str, Any]:
        """
        Return specific info about the conversation, the context, acceptability, etc.
        """

        if self.check_acceptability:
            human_messages, violation_types = self._prepare_acceptability_checking(
            )
            violations_string = self.acceptability_checker.check_messages(
                messages=human_messages,
                is_worker_0=False,
                violation_types=violation_types,
            )
        else:
            violations_string = None

        data = {
            'dialog': self.dialog,
            'workers': [get_mturk_id_from_mephisto_wrapper(self.agent)],
            'bad_workers': [],
            'acceptability_violations': (violations_string, ),
            'hit_ids': [self.agent.mephisto_agent.task_run_id],
            'assignment_ids': [self.agent.mephisto_agent.assignment_id],
            'task_description': {
                'annotations_config': self.opt['annotations_config'],
                'model_nickname': self.bot.worker_id,
                'model_file': self.bot.model_agent.opt.get('model_file'),
                'model_opt': self.bot.model_agent.opt,
            },
        }
        # 'bad_workers' is for compatibility. Before, it was only non-empty if a
        # worker abandoned, returned, etc. a HIT, but now we don't even save chat
        # data in that case
        if self.check_acceptability:
            data['acceptability_violations'] = (violations_string, )
            # Make a tuple for compatibility with a human/human conversation in
            # which we check both sides for acceptability

        return data

    def _prepare_acceptability_checking(self) -> Tuple[List[str], List[str]]:
        """
        Return the list of human messages and the list of acceptability types to check.
        """
        human_messages = [
            message['text'] for message in self.dialog
            if message['agent_idx'] == 0
        ]
        violation_types = ['min_words', 'all_caps', 'exact_match', 'safety']
        return human_messages, violation_types
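
For reference, a hedged illustration of the shape of the dict returned by get_final_chat_data() above. The field names follow the snippet; every value is a hypothetical placeholder.

# Hypothetical example of the saved chat data; field names follow
# get_final_chat_data() above, values are placeholders only.
example_final_chat_data = {
    'dialog': [
        {'agent_idx': 0, 'text': 'Hi!', 'id': 'Person1'},
        {'agent_idx': 1, 'text': 'Hello! How is your day going?', 'id': 'Polyencoder'},
    ],
    'workers': ['EXAMPLE_MTURK_WORKER_ID'],
    'bad_workers': [],
    'acceptability_violations': ('',),
    'hit_ids': ['example_task_run_id'],
    'assignment_ids': ['example_assignment_id'],
    'task_description': {
        'annotations_config': None,
        'model_nickname': 'example_model',
        'model_file': '/path/to/model',
        'model_opt': {},
    },
}
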
Example 12
    def test_sample_inputs(self):
        """
        Test sample inputs/outputs for the acceptability checker.
        """

        # Define test cases
        test_cases = [
            {  # Should pass
                'messages': [
                    'Hi - how are you?',
                    'What? Whatever for?',
                    'Wow, that sounds like a lot of work.',
                    "No, I don't expect he would be too happy about that either.",
                    "I don't even know where you would find that many squirrels.",
                    'Well, let me know if you need an extra hand.',
                ],
                'is_worker_0': False,
                'expected_violations': '',
            },
            {
                'messages': ['Hi', 'What?', 'Wow', "No", "I don't even know", 'Well,'],
                'is_worker_0': False,
                'expected_violations': 'under_min_length',
            },
            {  # Should fail, because the first worker shouldn't start with a greeting
                'messages': [
                    'Hi - how are you?',
                    'What? Whatever for?',
                    'Wow, that sounds like a lot of work.',
                    "No, I don't expect he would be too happy about that either.",
                    "I don't even know where you would find that many squirrels.",
                    'Well, let me know if you need an extra hand.',
                ],
                'is_worker_0': True,
                'expected_violations': 'starts_with_greeting',
            },
            {
                'messages': [
                    'HEYYYYYYY',
                    'What? Whatever for?',
                    'Wow, that sounds like a lot of work.',
                    "No, I don't expect he would be too happy about that either.",
                    "I don't even know where you would find that many squirrels.",
                    'WELLLLL LEMME KNOOOOOO',
                ],
                'is_worker_0': False,
                'expected_violations': 'too_much_all_caps',
            },
            {
                'messages': [
                    'Hi - how are you?',
                    'What? Whatever for?',
                    'Wow, that sounds like a lot of work.',
                    "No, I don't expect he would be too happy about that either.",
                    "I don't even know where you would find that many squirrels.",
                    'Hi - how are you?',
                ],
                'is_worker_0': False,
                'expected_violations': 'exact_match',
            },
            {
                'messages': [
                    'Hi - how are you?',
                    'What? Whatever for?',
                    'Wow, that sounds like a lot of work.',
                    "No, I don't expect he would be too happy about that either.",
                    "I don't even know where you would find that many squirrels.",
                    'Well, let me know if you need an extra hand.',
                    "I'm gonna say something that's totally XXX!",
                ],
                'is_worker_0': False,
                'expected_violations': 'unsafe:7',
            },
        ]
        test_cases_with_errors = [{
            'messages': ['Message 1', 'Message 2'],
            'is_worker_0': True,
            'violation_types': ['non_existent_violation_type'],
            'expected_exception': ValueError,
        }]

        # Create checker
        acceptability_checker = AcceptabilityChecker()

        # Run through violation test cases
        for test_case in test_cases:
            actual_violations = acceptability_checker.check_messages(
                messages=test_case['messages'],
                is_worker_0=test_case['is_worker_0'],
                violation_types=acceptability_checker.possible_violation_types,
            )
            self.assertEqual(actual_violations,
                             test_case['expected_violations'])

        # Run through test cases that should raise an error
        for test_case in test_cases_with_errors:
            with self.assertRaises(test_case['expected_exception']):
                acceptability_checker.check_messages(
                    messages=test_case['messages'],
                    is_worker_0=test_case['is_worker_0'],
                    violation_types=test_case['violation_types'],
                )
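
A minimal standalone sketch of calling the checker outside a unit test. The import path is an assumption (it is not shown in these snippets), and the violation type names are taken from _prepare_acceptability_checking() in Example 11.

# Minimal sketch; the import path is assumed rather than shown in the snippets above.
from parlai.crowdsourcing.utils.acceptability import AcceptabilityChecker

checker = AcceptabilityChecker()
violations = checker.check_messages(
    messages=[
        'What? Whatever for?',
        'Wow, that sounds like a lot of work.',
        'Well, let me know if you need an extra hand.',
    ],
    is_worker_0=False,  # True would additionally flag a greeting in the first message
    violation_types=['min_words', 'all_caps', 'exact_match'],  # skipping the 'safety' check here
)
print(violations or 'no violations')
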
Example 13
class PerTurnEvalWorld(ModelChatWorld):
    def __init__(self, opt, agent, bots, context_info: Optional[dict] = None):

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance

        # TODO: this logic is very close to that of BaseModelChatWorld.__init__(). Can
        #  any of this be deduplicated?

        super(BaseModelChatWorld, self).__init__(opt, agent)

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bots = bots

        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        self.final_chat_data = None
        # TODO: remove this attribute once chat data is only stored in the Mephisto
        #  TaskRun for this HIT (see .get_custom_task_data() docstring for more
        #  information)

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        logging.info(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None

    def __add_problem_data_to_utterance(self, p, turn_idx: int):
        """
        Attach problem data to the bot's prior utterance, given by turn_idx.
        """
        logging.info(p)
        assert (self.dialog[turn_idx]['agent_idx'] == 1
                ), 'Problem data must be attached to a bot utterance.'
        assert ('problem_data' not in self.dialog[turn_idx]
                ), "Don't overwrite existing problem data!"
        self.dialog[turn_idx]['problem_data'] = p

    def parley(self):
        """
        The main function that controls the logic of the task. Uses self.task_turn_idx
        to control the sequence of the conversation.

        Specifically, when self.task_turn_idx is even, we know that the bots just gave
        their potential responses, and that it is the human's turn to choose one of the
        responses and give a justification value.

        When self.task_turn_idx is odd, we know that the human just chose one of the
        bots' responses, and now needs to respond to that response.

        self.task_turn_idx is initially 0, and during _run_initial_turn() the UI is
        redrawn to have the human select between the bots' responses. Then,
        self.task_turn_idx is incremented to 1.

        During self.agent.observe(), the UI is redrawn for the following human input,
        and during self.agent.act(), the code awaits the human input.
        """

        logging.info(
            f'{self.__class__.__name__}:{self.tag}: is at task_turn_idx '
            f'{self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:
            self._run_initial_turn()
            self.task_turn_idx += 1
            return

        logging.info(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: '
            f'{self.task_turn_idx}')

        # At this point, we know that the human now needs to respond to the bot's
        # response that the human just chose

        # We retrieve information regarding the human's choice and justification using
        # self.agent.act()

        human_choose_bot_response_act = self.agent.act(
            timeout=self.max_resp_time)
        human_choose_bot_response_act = Message(
            Compatibility.maybe_fix_act(
                human_choose_bot_response_act)).json_safe_payload()

        logging.info(
            f'Got act for human, act was: {human_choose_bot_response_act} and '
            f'self.task_turn_idx: {self.task_turn_idx}.')

        accepted_bot_response = human_choose_bot_response_act['task_data'][
            'accepted_bot_response']
        accepted_bot_id = human_choose_bot_response_act['task_data'][
            'accepted_bot_id']
        accepted_bot_justification_value = human_choose_bot_response_act[
            'task_data']['justification_value']

        not_accepted_bot_response = human_choose_bot_response_act['task_data'][
            'not_accepted_bot_response']
        not_accepted_bot_id = human_choose_bot_response_act['task_data'][
            'not_accepted_bot_id']

        # We have both bots observe the accepted bot's response so that the conversation
        # history stays the same

        self.bots[0].observe(accepted_bot_response)
        self.bots[1].observe(accepted_bot_response)

        task_data = {}

        accepted_bot_utterance_data = {
            'text': accepted_bot_response['text'].split('<br>')[0],
            'id': accepted_bot_id,
        }
        not_accepted_bot_utterance_data = {
            'text': not_accepted_bot_response['text'].split('<br>')[0],
            'id': not_accepted_bot_id,
        }
        bot_utterance_data = {
            'agent_idx': 1,
            'accepted_bot_data': accepted_bot_utterance_data,
            'not_accepted_bot_data': not_accepted_bot_utterance_data,
            'human_choice': accepted_bot_id,
            'human_justification': accepted_bot_justification_value,
        }
        self.dialog.append(bot_utterance_data)

        self._postprocess_acts(acts=None, agent_idx=0)

        # All logic and processing for this step has now been done, so we do
        # self.agent.observe() to send the accepted response back to the frontend to
        # display and update task turn index, as well as await for the next action,
        # which is the human typing their response

        task_data['task_turn_idx'] = self.task_turn_idx
        # The UI will ask the human to respond to the chosen bot response
        self.agent.observe({
            'text': accepted_bot_response['text'],
            'task_data': task_data
        })

        # Make self.task_turn_idx even now
        self.task_turn_idx += 1

        # Check whether 6 pairs of turns have been completed, since the last message of
        # a conversation should always be the bot's response

        if (human_choose_bot_response_act is not None
                and human_choose_bot_response_act.get(
                    'task_data', {}).get('finished') is not None):
            self.chat_done = True
            # agent ends chat after exceeding minimum number of turns

            # Bot has just responded. Any problem data received now will be
            # regarding this bot's utterance

            # Get the final chat data
            self.final_chat_data = self.get_final_chat_data()

            # Soft-block the worker if there were acceptability violations
            acceptability_violations = self.final_chat_data[
                'acceptability_violations'][0]
            if acceptability_violations is not None and acceptability_violations != '':
                logging.info(f'**NOTE** Acceptability violations detected: '
                             f'{acceptability_violations}')
                # Grant the failed qualification
                self.agent.mephisto_agent.get_worker().grant_qualification(
                    self.block_qualification, 1)

            return

        logging.info(
            f'[human agent] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: '
            f'{self.dialog}')

        logging.info(
            f'Got act for human, act was: {human_choose_bot_response_act} and '
            f'self.task_turn_idx: {self.task_turn_idx}.')

        # At this point, we know that the human now needs to respond to the bot's
        # response that the human just chose

        # We retrieve information regarding the human's response using self.agent.act()

        human_response_act = self.agent.act(timeout=self.max_resp_time)

        # Again, we have both bots observe the human response so that the conversation
        # history stays the same
        self.bots[0].observe(validate(human_response_act))
        self.bots[1].observe(validate(human_response_act))

        # Check that the models' conversation histories are the same
        bot_1_history = self.bots[0].model_agent.history.history_strings
        bot_2_history = self.bots[1].model_agent.history.history_strings

        assert (
            bot_1_history == bot_2_history
        ), f"The two bots' conversation histories are different.\nBot 1 history: {bot_1_history}\nBot 2 history: {bot_2_history}"

        # After the bots have observed the human response, it's time for them to produce
        # their response to the human using self.bots.act()

        bot_1_response = self.bots[0].act()
        bot_1_response = Compatibility.maybe_fix_act(bot_1_response)

        bot_2_response = self.bots[1].act()
        bot_2_response = Compatibility.maybe_fix_act(bot_2_response)

        # We display the result to the frontend randomly so there is no selection bias.
        # Also, we attach our result to task_data to send arbitrary data to the frontend

        if random.random() > 0.5:
            task_data = {
                'top_bot_data': {
                    'top_bot_id': self.bots[0].worker_id,
                    'top_bot_response': bot_1_response,
                },
                'bottom_bot_data': {
                    'bottom_bot_id': self.bots[1].worker_id,
                    'bottom_bot_response': bot_2_response,
                },
                'task_turn_idx': self.task_turn_idx,
            }
        else:
            task_data = {
                'top_bot_data': {
                    'top_bot_id': self.bots[1].worker_id,
                    'top_bot_response': bot_2_response,
                },
                'bottom_bot_data': {
                    'bottom_bot_id': self.bots[0].worker_id,
                    'bottom_bot_response': bot_1_response,
                },
                'task_turn_idx': self.task_turn_idx,
            }

        human_utterance_data = {
            'agent_idx': 0,
            # Get rid of annotations HTML if it's the bot response
            'text': human_response_act['text'].split('<br>')[0],
            'id': human_response_act.get('id', 'NULL_ID'),  # Person1 or Polyencoder
        }

        self.dialog.append(human_utterance_data)

        # Human has just responded. Any problem data received now will be regarding the
        # bot's prior utterance
        p = human_response_act['task_data'].get(
            'problem_data_for_prior_message')
        if p is not None:
            turn_idx = -2
            # Attach the problem data to the second-to-last utterance, since the last
            # utterance is what the human just said
            self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)

        self._postprocess_acts(acts=None, agent_idx=0)

        task_data['task_turn_idx'] = self.task_turn_idx

        # All logic and processing for this step has now been done, so we do
        # self.agent.observe() to send the two bots' responses back to the frontend to
        # display and update task turn index, as well as await for the next action,
        # which is the human choosing from the two responses and providing a
        # justification value

        # The UI will ask the human to choose between two bot responses and give a
        # justification
        logging.info(f'*** self.task_turn_idx: {self.task_turn_idx} ***')
        self.agent.observe({'text': '', 'task_data': task_data})

        # Make self.task_turn_idx odd now
        self.task_turn_idx += 1

        logging.info(
            f'[bot agent] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: '
            f'{self.dialog}')

    def get_final_chat_data(self) -> Dict[str, Any]:
        """
        Add relevant fields to the final chat data.
        """

        if self.check_acceptability:
            human_messages, violation_types = self._prepare_acceptability_checking(
            )
            violations_string = self.acceptability_checker.check_messages(
                messages=human_messages,
                is_worker_0=False,
                violation_types=violation_types,
            )
        else:
            violations_string = None

        data = {
            'dialog': self.dialog,
            'workers': [get_mturk_id_from_mephisto_wrapper(self.agent)],
            'bad_workers': [],
            'acceptability_violations': (violations_string, ),
            'hit_ids': [self.agent.mephisto_agent.task_run_id],
            'assignment_ids': [self.agent.mephisto_agent.assignment_id],
            'task_description': {
                'annotations_config': self.opt['annotations_config'],
                'model_1_nickname': self.bots[0].worker_id,
                'model_1_file': self.bots[0].model_agent.opt.get('model_file'),
                'model_1_opt': self.bots[0].model_agent.opt,
                'model_2_nickname': self.bots[1].worker_id,
                'model_2_file': self.bots[1].model_agent.opt.get('model_file'),
                'model_2_opt': self.bots[1].model_agent.opt,
            },
        }
        # 'bad_workers' is for compatibility. Before, it was only non-empty if a
        # worker abandoned, returned, etc. a HIT, but now we don't even save chat
        # data in that case
        if self.check_acceptability:
            data['acceptability_violations'] = (violations_string, )
            # Make a tuple for compatibility with a human/human conversation in
            # which we check both sides for acceptability

        context_data = {
            'personas': self.personas,
            'context_dataset': self.context_info.get('context_dataset'),
            'person1_seed_utterance': self.context_info.get('person1_seed_utterance'),
            'person2_seed_utterance': self.context_info.get('person2_seed_utterance'),
            'additional_context': self.context_info.get('additional_context'),
        }
        data.update(context_data)
        return data

    def _run_initial_turn(self) -> None:
        """
        Run the initial turn for both the human and the bot.

        Optionally show the bot its persona. If we are in Meena-like conversation mode
        show "Hi!" to the human and the bot and let the bot respond accordingly.

        Check parley() function for more information on the main logic.
        """
        control_msg = {"episode_done": False}

        if self.opt['include_persona']:
            # The Bot agent
            # We add the personas and 1/3 of the time WoW topic as the
            # first utterance in the history.
            # Previously for BST task, we also had a big first utterance
            # that gave instructions. Removing that for this task.
            persona_strings = [s.strip() for s in self.personas[1]]
            persona_utterance = self._get_persona_utterance(
                persona_strings=persona_strings,
                context_dataset=self.context_info['context_dataset'],
                additional_context=self.context_info['additional_context'],
                is_bot=True,
            )
            message = control_msg.copy()
            message['text'] = persona_utterance
            # The bot seeing its persona does not count as a "turn"
            self.bots[0].observe(validate(message), increment_turn=False)
            self.bots[1].observe(validate(message), increment_turn=False)

        if self.opt['conversation_start_mode'] == 'hi':
            logging.info('[Displaying "Hi!" only as per Meena task.]')
            if self.personas is not None:
                human_persona_strings = [s.strip() for s in self.personas[0]]
            else:
                human_persona_strings = ['', '']
            human_first_msg = {
                'episode_done': False,
                'id': self.agent.id,
                'text': 'Hi!',
                'fake_start': True,
                'agent_idx': 0,
                'task_data': {
                    'human_persona_string_1': human_persona_strings[0],
                    'human_persona_string_2': human_persona_strings[1],
                    'prompt_instruction': self.opt['task_question'],
                },
            }
            for k, v in control_msg.items():
                human_first_msg[k] = v

            # The first message is always "Hi", so we have both bots observe the message

            self.dialog.append(human_first_msg)
            self.agent.observe(validate(human_first_msg))
            self.bots[0].observe(validate(human_first_msg))
            self.bots[1].observe(validate(human_first_msg))

            bot_1_response = self.bots[0].act()
            bot_1_response = Compatibility.maybe_fix_act(bot_1_response)

            bot_2_response = self.bots[1].act()
            bot_2_response = Compatibility.maybe_fix_act(bot_2_response)

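            # Randomize which bot's response is shown in the top vs. bottom position
            # for this turn (presumably so that screen position does not bias the
            # worker's choice)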
            if random.random() > 0.5:
                task_data = {
                    'top_bot_data': {
                        'top_bot_id': self.bots[0].worker_id,
                        'top_bot_response': bot_1_response,
                    },
                    'bottom_bot_data': {
                        'bottom_bot_id': self.bots[1].worker_id,
                        'bottom_bot_response': bot_2_response,
                    },
                    'task_turn_idx': self.task_turn_idx,
                }
            else:
                task_data = {
                    'top_bot_data': {
                        'top_bot_id': self.bots[1].worker_id,
                        'top_bot_response': bot_2_response,
                    },
                    'bottom_bot_data': {
                        'bottom_bot_id': self.bots[0].worker_id,
                        'bottom_bot_response': bot_1_response,
                    },
                    'task_turn_idx': self.task_turn_idx,
                }

            # The human needs an initial observe() here to receive the two bot
            # responses to choose between
            self.agent.observe({'text': '', 'task_data': task_data})

        else:
            raise ValueError(
                f"Conversation start mode {self.opt['conversation_start_mode']} "
                f"not recognized!")

    def shutdown(self):

        if self.chat_done:
            self.opt['run_statistics'][
                f'{self.bots[0].worker_id}:{self.bots[1].worker_id}'] += 1
            logging.info('Runs completed per model: ' + ', '.join(
                f'{model}: {count:d}'
                for model, count in self.opt['run_statistics'].items()))

        self.agent.shutdown()
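
# A minimal sketch (an assumption, not part of the scraped example): shutdown()
# above does opt['run_statistics'][f'{bot_0}:{bot_1}'] += 1, so the shared opt
# needs a counter keyed by model-pair strings, e.g. a defaultdict. The model
# nicknames below are hypothetical.
from collections import defaultdict

shared_opt = {'run_statistics': defaultdict(int)}
shared_opt['run_statistics']['model_a:model_b'] += 1  # as done in shutdown() above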
Ejemplo n.º 14
0
class TurnAnnotationsChatWorld(CrowdTaskWorld):
    def __init__(self,
                 opt,
                 agent=None,
                 bot=None,
                 context_info: Optional[dict] = None):
        super().__init__(opt, agent)

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bot = bot
        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        print(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

    def __add_problem_data_to_utterance(self, p, turn_idx: int):
        """
        Attach problem data to the bot's prior utterance, given by turn_idx.
        """
        print(p)
        assert (self.dialog[turn_idx]['agent_idx'] == 1
                ), 'Problem data must be attached to a bot utterance.'
        assert ('problem_data' not in self.dialog[turn_idx]
                ), "Don't overwrite existing problem data!"
        self.dialog[turn_idx]['problem_data'] = p

    def parley(self):
        print(
            f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        control_msg = {"episode_done": False}

        if self.task_turn_idx == 0:
            if self.opt['include_persona']:
                # The Bot agent
                # We add the personas and 1/3 of the time WoW topic as the
                # first utterance in the history.
                # Previously for BST task, we also had a big first utterance
                # that gave instructions. Removing that for this task.
                persona_strings = [s.strip() for s in self.personas[1]]
                persona_utterance = self._get_persona_utterance(
                    persona_strings=persona_strings,
                    context_dataset=self.context_info['context_dataset'],
                    additional_context=self.context_info['additional_context'],
                    is_bot=True,
                )
                message = control_msg.copy()
                message['text'] = persona_utterance
                # The bot seeing its persona does not count as a "turn"
                self.bot.observe(validate(message), increment_turn=False)

            if self.opt['conversation_start_mode'] == 'bst':
                print('[Displaying first utterances as per BST task.]')
                # Display the previous two utterances
                human_first_msg = {
                    'episode_done': False,
                    'id': self.agent.id,
                    'text': self.context_info['person1_seed_utterance'],
                    'fake_start': True,
                    'agent_idx': 0,
                }
                for k, v in control_msg.items():
                    human_first_msg[k] = v
                bot_first_msg = {
                    'episode_done': False,
                    'id': self.bot.id,
                    'text': self.context_info['person2_seed_utterance'],
                    'fake_start': True,
                    'agent_idx': 1,
                }
                print(
                    f'human_first_msg: {human_first_msg}, bot_first_msg: {bot_first_msg}'
                )

                self.dialog.append(human_first_msg)
                self.dialog.append(bot_first_msg)

                for observer in [self.agent, self.bot]:
                    observer.observe(validate(human_first_msg))
                    observer.observe(validate(bot_first_msg))

            elif self.opt['conversation_start_mode'] == 'hi':
                print('[Displaying "Hi!" only as per Meena task.]')
                human_first_msg = {
                    'episode_done': False,
                    'id': self.agent.id,
                    'text': 'Hi!',
                    'fake_start': True,
                    'agent_idx': 0,
                }
                for k, v in control_msg.items():
                    human_first_msg[k] = v

                self.dialog.append(human_first_msg)
                self.agent.observe(validate(human_first_msg))
                self.bot.observe(validate(human_first_msg))

                first_bot_act = self.bot.act()
                first_bot_act = Compatibility.maybe_fix_act(first_bot_act)

                self.agent.observe(validate(first_bot_act))

                bot_utterance_data = {
                    'agent_idx': 1,
                    'text': first_bot_act['text'],
                    'id': first_bot_act['id'],
                }
                self.dialog.append(bot_utterance_data)

            else:
                raise ValueError(
                    f"Conversation start mode {self.opt['conversation_start_mode']} "
                    f"not recognized!")

            self.task_turn_idx += 1
            return

        # Otherwise, we proceed accordingly
        print(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}'
        )
        acts = [None, None]
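        # One full exchange per parley(): first collect the human's act (idx 0),
        # then the bot's (idx 1), updating the dialog and problem annotations as
        # we go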
        for idx, agent in enumerate([self.agent, self.bot]):
            if not self.chat_done:
                acts[idx] = agent.act(timeout=self.max_resp_time)
                acts[idx] = Compatibility.maybe_fix_act(acts[idx])
                print(
                    f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
                )

            if acts[idx].get('task_data', {}).get('final_rating') is not None:

                self.chat_done = True
                # agent ends chat after exceeding minimum number of turns

                if self.task_turn_idx > self.num_turns:
                    # Human has just responded. Problem data received
                    # now will be regarding the bot's prior utterance
                    p = acts[idx]['task_data'][
                        'problem_data_for_prior_message']
                    turn_idx = -1
                    # Attach the problem data to the last utterance, since the human
                    # hasn't said anything since then
                    self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)

                # Save the final chat data
                time_string = time.strftime('%Y%m%d_%H%M%S')
                chat_data_folder = self.opt['chat_data_folder']
                os.makedirs(chat_data_folder, exist_ok=True)
                chat_data_path = os.path.join(
                    chat_data_folder,
                    f'{time_string}_{np.random.randint(0, 1000)}_{self.task_type}.json',
                )
                final_chat_data = self.get_final_chat_data()
                self.agent.mephisto_agent.state.messages.append(
                    {'final_chat_data': final_chat_data})
                # Append the chat data directly to the agent state's message list in
                # order to prevent the worker from seeing a new text response in the UI
                with open(chat_data_path, 'w+') as f_json:
                    data_str = json.dumps(final_chat_data)
                    f_json.write(data_str)
                print(f'{self.__class__.__name__}:{self.tag}: Data saved at '
                      f'{chat_data_path} for model: {self.bot.worker_id}.')

                # Soft-block the worker if there were acceptability violations
                acceptability_violations = final_chat_data[
                    'acceptability_violations'][0]
                if (acceptability_violations is not None
                        and acceptability_violations != ''):
                    print(
                        f'**NOTE** Acceptability violations detected: {acceptability_violations}'
                    )
                    # Grant the failed qualification
                    self.agent.mephisto_agent.get_worker().grant_qualification(
                        self.block_qualification, 1)

                return

            else:
                utterance_data = {
                    'agent_idx': idx,
                    # Get rid of annotations HTML if it's the bot response
                    'text': acts[idx]['text'].split('<br>')[0],
                    'id': acts[idx]['id'] if 'id' in acts[idx] else
                    'NULL_ID',  # Person1 or Polyencoder
                }
                self.dialog.append(utterance_data)
                if idx == 0:
                    # Human has just responded. Problem data received
                    # now will be regarding the bot's prior utterance
                    p = acts[idx]['task_data'][
                        'problem_data_for_prior_message']
                    turn_idx = -2
                    # Attach the problem data to the second-to-last utterance, since the
                    # last utterance is what the human just said
                    self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)

                for other_agent in [self.agent, self.bot]:
                    if other_agent != agent:
                        other_agent.observe(validate(acts[idx]))

                print(
                    f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}'
                )
                self.task_turn_idx += 1

    def shutdown(self):

        if self.chat_done:
            self.opt['run_statistics'][self.bot.worker_id] += 1
            print('Runs completed per model: ' + ', '.join(
                f'{model}: {count:d}'
                for model, count in self.opt['run_statistics'].items()))

        self.agent.shutdown()

    def episode_done(self):
        return self.chat_done

    def _get_persona_utterance(
        self,
        persona_strings: Optional[List[str]] = None,
        context_dataset: Optional[str] = None,
        additional_context: Optional[str] = None,
        is_bot: bool = False,
    ):
        if is_bot:
            # Pass back the original context
            persona_pieces = [
                f"your persona: {str_}" for str_ in persona_strings
            ]
            if context_dataset == 'wizard_of_wikipedia':
                additional_context_pieces = [additional_context]
            else:
                additional_context_pieces = []
            full_context = '\n'.join(persona_pieces +
                                     additional_context_pieces)
            print(f'FULL CONTEXT: {full_context}')
            return full_context
        else:
            if context_dataset == 'convai2':
                last_sentence = 'Pretend that the conversation has already begun.'
            elif context_dataset == 'empathetic_dialogues':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about the following situation: '
                    f'<b>"{additional_context}"</b>')
            elif context_dataset == 'wizard_of_wikipedia':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about <b>{additional_context}</b>.')
            else:
                raise ValueError('Context dataset unrecognized!')
            joined_personas = '\n'.join(persona_strings)
            return (
                f'\nSuccessfully matched with another user! Now let\'s get to know '
                f'each other through the chat. You need to finish at least '
                f'<b>{self.num_turns} chat turns</b>, and after that you can click the '
                f'"Done" button to end the chat.\n\n'
                f'<b>Your character description is:\n<span style="color:blue">{joined_personas}</span></b> '
                '\n\n<b>Remember that you can get to know each '
                'other as your characters, talk about any topic, or talk about a '
                'situation that might have happened to your character.</b>'
                '\n<b>Do not trivially copy the '
                'character descriptions into the message.</b><br><br>'
                f'{last_sentence}')

    def get_final_chat_data(self) -> Dict[str, Any]:
        """
        Return specific info about the conversation, the context, acceptability, etc.
        """

        if self.check_acceptability:
            human_texts = [
                message['text'] for message in self.dialog
                if message['agent_idx'] == 0
            ]
            violation_types = [
                'min_words', 'all_caps', 'exact_match', 'safety'
            ]
            if self.opt['conversation_start_mode'] == 'bst':
                # The BST mode starts the conversation with two previous utterances, so
                # there should be no new greeting. Also, the first human response is one
                # of the previous utterances, so it shouldn't get checked.
                violation_types.append('penalize_greetings')
                human_texts = human_texts[1:]

            violations_string = self.acceptability_checker.check_messages(
                messages=human_texts,
                is_worker_0=False,
                violation_types=violation_types)
        else:
            violations_string = None

        data = {
            'personas':
            self.personas,
            'context_dataset':
            self.context_info.get('context_dataset'),
            'person1_seed_utterance':
            self.context_info.get('person1_seed_utterance'),
            'person2_seed_utterance':
            self.context_info.get('person2_seed_utterance'),
            'additional_context':
            self.context_info.get('additional_context'),
            'dialog':
            self.dialog,
            'workers': [get_mturk_id_from_mephisto_wrapper(self.agent)],
            'bad_workers': [],
            'acceptability_violations': (violations_string, ),
            'hit_ids': [self.agent.mephisto_agent.task_run_id],
            'assignment_ids': [self.agent.mephisto_agent.assignment_id],
            'task_description': {
                'annotations_config': self.opt['annotations_config'],
                'model_nickname': self.bot.worker_id,
                'model_file': self.bot.model_agent.opt.get('model_file'),
                'model_opt': self.bot.model_agent.opt,
            },
        }
        # 'bad_workers' is kept for compatibility. Previously it was only non-empty
        # if a worker abandoned or returned a HIT; chat data is no longer saved at
        # all in those cases
        if self.check_acceptability:
            data['acceptability_violations'] = (violations_string, )
            # Make a tuple for compatibility with a human/human conversation in
            # which we check both sides for acceptability

        return data
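
# Sketch of the per-turn payload that parley() above reads from the human act's
# 'task_data' (an assumption inferred from the keys accessed there; the bucket
# names are hypothetical and would come from the task's annotations_config).
example_human_task_data = {
    'problem_data_for_prior_message': {
        'bucket_0': False,  # one boolean per annotation bucket
        'none_all_good': True,
    },
    # Only present on the final human turn; its presence sets self.chat_done
    'final_rating': '4',
}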
Ejemplo n.º 15
0
class ModelChatResultsCompiler(BaseModelChatResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as aggregated
    up the level of each worker as a whole.
    """
    @classmethod
    def setup_args(cls):
        parser = super().setup_args()
        parser.add_argument('--model-nickname',
                            type=str,
                            default='',
                            help='name of the model')
        parser.add_argument(
            '--completed-run-stats-path',
            type=str,
            default='',
            help='path of the task run stats file',
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):

        AbstractTurnAnnotationResultsCompiler.__init__(self, opt)

        # Input args
        self.model_nickname = opt['model_nickname']
        assert len(self.results_folders) > 0
        for folder in self.results_folders:
            assert os.path.isdir(folder), f'{folder} is not a valid folder!'
        os.makedirs(self.output_folder, exist_ok=True)
        self.start_date = opt['start_date']
        self.max_convos_per_worker = opt['max_convos_per_worker']
        self.min_word_count = opt['min_word_count']
        self.hit_block_list = opt['hit_block_list'].split(',')
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Setting up problem buckets
        if self.use_problem_buckets:
            self.regular_buckets = [
                bucket for bucket in self.problem_buckets
                if bucket not in ['other', 'none_all_good']
            ]
            # Remove the buckets that are special cases

        self.acceptability_checker = AcceptabilityChecker()
        self.completed_run_stats_path = opt['completed_run_stats_path']

    def compile_results(self) -> pd.DataFrame:
        # TODO modularize the shared components to dedup the code
        read_folders = []
        date_strings = []
        for folder in self.results_folders:
            # Load paths
            # TODO load this data in using DataBrowser
            date_strings = sorted([
                obj for obj in os.listdir(folder)
                if os.path.isdir(os.path.join(folder, obj))
                and re.fullmatch(r'\d\d\d\d_\d\d_\d\d', obj)
            ])
            if self.start_date != '':
                date_strings = [
                    str_ for str_ in date_strings if str_ >= self.start_date
                ]
            folders = [os.path.join(folder, str_) for str_ in date_strings]
            read_folders.extend(folders)
        print(f'Date folders: ' + ', '.join(date_strings))

        # Read in each file
        num_incomplete_convos = 0
        num_complete_convos = 0
        complete_convos_per_model = {}
        bad_conversations = []
        worker_stats = {}
        worker_conversation_counts = {}

        conversation_idx = 0
        conversation_dfs = []
        stat_counts = {}
        for read_folder in read_folders:
            read_folder_name = os.path.split(read_folder)[-1]
            for file_name in sorted(os.listdir(read_folder)):
                if file_name in self.hit_block_list or 'sandbox' in file_name:
                    continue

                if 'incomplete' in file_name:
                    num_incomplete_convos += 1
                    continue
                else:
                    num_complete_convos += 1

                # Read in file
                with open(os.path.join(read_folder, file_name), 'rb') as f:
                    data = json.load(f)

                # Only include the first max_convos_per_worker conversations from a
                # worker to avoid biasing
                worker_id = data['workers'][0]
                worker_id = worker_id.split('-')[-1]
                assignment_id = data['assignment_ids'][0]
                if worker_id in worker_conversation_counts:
                    conversations_so_far = worker_conversation_counts[
                        worker_id]
                else:
                    conversations_so_far = 0
                worker_conversation_counts[
                    worker_id] = conversations_so_far + 1
                if (self.max_convos_per_worker != -1 and
                        conversations_so_far >= self.max_convos_per_worker):
                    print(
                        f'Had {conversations_so_far} conversation(s) already from this worker {worker_id}. Skipping {assignment_id}.'
                    )
                    continue

                # Check if need to block the turker
                word_counts = [
                    len(d['text'].split(' ')) for d in data['dialog']
                    if d['agent_idx'] == 0
                ]
                utterances = [
                    d['text'] for d in data['dialog'] if d['agent_idx'] == 0
                ]
                if np.average(word_counts) < self.min_word_count:
                    bad_conversations.append(data)
                    print(
                        f'Bad complete conversation, words from human: {utterances}. Skipping.'
                    )
                    continue

                if not all(bucket in data['dialog'][0]['problem_data']
                           for bucket in self.problem_buckets):
                    raise ValueError(
                        'Bucket(s) are missing from the problem data!')

                model_nickname = data['model_name']
                assert self.model_nickname == model_nickname
                initial_data_id = data['context_info']['observation_for_bot'][
                    'initial_data_id']
                if model_nickname not in stat_counts:
                    stat_counts[model_nickname] = {}
                if model_nickname in complete_convos_per_model:
                    complete_convos_per_model[model_nickname].append(
                        initial_data_id)
                else:
                    complete_convos_per_model[model_nickname] = [
                        initial_data_id
                    ]

                # Extract non-message info
                info_dict = {
                    'read_folder_name': read_folder_name,
                    'file_name': file_name,
                    'worker': worker_id,
                    'model_nickname': model_nickname,
                    'bad_workers': ','.join(data['bad_workers']),
                    'hit_id': data['hit_ids'][0],
                    'assignment_id': assignment_id,
                    'is_incomplete': 'incomplete' in file_name,
                    'context_info': data['context_info'],
                    'bot_persona_strings': data['bot_persona_strings'],
                    'human_persona_strings': data['human_persona_strings'],
                    'initial_task_data': data['initial_task_data'],
                    'initial_data_id': initial_data_id,
                }

                # Check that the conversation consists of alternating turns between
                # agents 0 and 1, with agent 1 (the bot) speaking first
                assert all([
                    utterance_data['agent_idx'] == (utterance_idx + 1) % 2 for
                    utterance_idx, utterance_data in enumerate(data['dialog'])
                ])

                # Determine whether the HIT contains unacceptable messages.
                # (We do this for every HIT, even if acceptability violation info
                # was already saved, because the violation criteria may have
                # changed since the HIT was collected.)
                messages_0 = [
                    utt for utt in data['dialog'] if utt['agent_idx'] == 0
                ]
                messages_1 = [
                    utt for utt in data['dialog'] if utt['agent_idx'] == 1
                ]
                assert len(messages_0) + len(messages_1) == len(data['dialog'])

                # Check the human utterances for safety
                utterances_0 = [m['text'] for m in messages_0]
                info_dict[
                    'acceptability_violations_0'] = self.acceptability_checker.check_messages(
                        messages=utterances_0,
                        is_worker_0=True,
                        violation_types=self.acceptability_checker.
                        ALL_VIOLATION_TYPES,
                    )

                # Compile personas and previous utterances
                df = pd.DataFrame(
                    [],
                    columns=[
                        'folder',
                        'file_name',
                        'worker_id',
                        'hit_id',
                        'is_incomplete',
                        'context_info',
                        'initial_data_id',
                        'acceptability_violations_0',
                        'model_nickname',
                        'conversation_idx',
                        'turn_idx',
                        'agent_idx',
                        'text',
                    ] + self.problem_buckets,
                )
                df = df.append(
                    {
                        'folder':
                        info_dict['read_folder_name'],
                        'file_name':
                        info_dict['file_name'],
                        'worker_id':
                        info_dict['worker'],
                        'hit_id':
                        info_dict['hit_id'],
                        'is_incomplete':
                        info_dict['is_incomplete'],
                        'context_info':
                        info_dict['context_info'],
                        'initial_data_id':
                        info_dict['initial_data_id'],
                        'acceptability_violations_0':
                        info_dict['acceptability_violations_0'],
                        'model_nickname':
                        model_nickname,
                        'conversation_idx':
                        conversation_idx,
                        'turn_idx':
                        -1,
                        'agent_idx':
                        0,
                        'text':
                        info_dict['context_info']['observation_for_bot']
                        ['text'],
                        **{bucket: ''
                           for bucket in self.problem_buckets},
                    },
                    ignore_index=True,
                )

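                # Walk through every utterance, recording per-turn problem
                # annotations and per-model aggregate statistics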
                for utterance_idx, utt in enumerate(data['dialog']):

                    d = {
                        'folder':
                        info_dict['read_folder_name'],
                        'file_name':
                        info_dict['file_name'],
                        'worker_id':
                        info_dict['worker'],
                        'hit_id':
                        info_dict['hit_id'],
                        'is_incomplete':
                        info_dict['is_incomplete'],
                        'context_info':
                        info_dict['context_info'],
                        'initial_data_id':
                        info_dict['initial_data_id'],
                        'acceptability_violations_0':
                        info_dict['acceptability_violations_0'],
                        'model_nickname':
                        model_nickname,
                        'conversation_idx':
                        conversation_idx,
                        'turn_idx':
                        utterance_idx,
                        'agent_idx':
                        utt['agent_idx'],
                        'text':
                        utt['text'],
                        **{bucket: ''
                           for bucket in self.problem_buckets},
                    }

                    if utt['agent_idx'] == 1:
                        if 'problem_data' not in utt:
                            for bucket in self.problem_buckets:
                                d[bucket] = 'MALFORMED'
                            print(
                                f'Warning: got MALFORMED utterance problem data inside complete convo: {utt}. Skipping.'
                            )
                            continue
                        else:
                            for bucket in self.regular_buckets:
                                d[bucket] = utt['problem_data'][bucket]
                            d['final_rating'] = (utt['final_rating']
                                                 if 'final_rating' in utt else
                                                 None)
                        for k in self.regular_buckets:
                            if k not in stat_counts[model_nickname]:
                                stat_counts[model_nickname][k] = 0
                            stat_counts[model_nickname][k] += d[k]

                        if 'total' not in stat_counts[model_nickname]:
                            stat_counts[model_nickname]['total'] = 0
                        if d['agent_idx'] == 1:
                            stat_counts[model_nickname]['total'] += 1
                        if d['final_rating'] is not None:
                            # Only on the last utterance (agent_idx == 1)
                            if 'count_ratings' not in stat_counts[
                                    model_nickname]:
                                stat_counts[model_nickname][
                                    'count_ratings'] = 0
                            stat_counts[model_nickname]['count_ratings'] += 1
                            if 'ratings' not in stat_counts[model_nickname]:
                                stat_counts[model_nickname]['ratings'] = []
                            if 'pairwise_ratings' not in stat_counts[
                                    model_nickname]:
                                stat_counts[model_nickname][
                                    'pairwise_ratings'] = {}
                            stat_counts[model_nickname]['ratings'].append(
                                int(d['final_rating']))
                            stat_counts[model_nickname]['pairwise_ratings'][
                                info_dict['initial_data_id']] = int(
                                    d['final_rating'])

                        if 'bot_word_count' not in stat_counts[model_nickname]:
                            stat_counts[model_nickname]['bot_word_count'] = 0
                        stat_counts[model_nickname]['bot_word_count'] += len(
                            d['text'].strip().split(' '))
                    else:

                        # Counting some aspects of the human's utterances
                        if 'human_utterance_count' not in stat_counts[
                                model_nickname]:
                            stat_counts[model_nickname][
                                'human_utterance_count'] = 0
                        stat_counts[model_nickname][
                            'human_utterance_count'] += 1

                        if 'human_word_count' not in stat_counts[
                                model_nickname]:
                            stat_counts[model_nickname]['human_word_count'] = 0
                        stat_counts[model_nickname]['human_word_count'] += len(
                            d['text'].strip().split(' '))

                        if 'human_question_count' not in stat_counts[
                                model_nickname]:
                            stat_counts[model_nickname][
                                'human_question_count'] = 0
                        stat_counts[model_nickname][
                            'human_question_count'] += d['text'].count('?')

                # Track how many conversations each worker has completed
                if info_dict['worker'] not in worker_stats:
                    worker_stats[info_dict['worker']] = {'conversations': 0}
                worker_stats[info_dict['worker']]['conversations'] += 1

                # Count the total number of complete conversations per model
                if 'count_convos' not in stat_counts[model_nickname]:
                    stat_counts[model_nickname]['count_convos'] = 0
                stat_counts[model_nickname]['count_convos'] += 1

                # Adding the full conversation to the list of conversations
                conversation_dfs.append(df)
                conversation_idx += 1

        for m, conversations_completed in complete_convos_per_model.items():
            print(
                f'Got {len(conversations_completed)} complete conversations for model: {m}'
            )
            print(f"{m} completed: {conversations_completed}")

        print(f'{num_complete_convos:d} complete conversation(s) collected.')
        print(f'{len(bad_conversations):d} bad conversation(s).')
        num_approved_convos = num_complete_convos - len(bad_conversations)
        print(f'{num_approved_convos:d} approved conversation(s).')
        print(
            f'({num_incomplete_convos:d} incomplete conversation(s) collected.)'
        )
        for model_nickname, model_stats_dict in stat_counts.items():
            print(f'---{model_nickname}---')
            for p, v in model_stats_dict.items():
                if p == 'count_ratings' or p == 'pairwise_ratings':
                    continue
                if p == 'ratings':
                    print(
                        f'Average Engaging-ness Rating: {np.average(model_stats_dict["ratings"])} ({model_stats_dict["count_ratings"]} ratings)'
                    )
                    continue
                if p == 'human_word_count' or p == 'human_question_count':
                    print(
                        f'{p}: {v} ({v/model_stats_dict["human_utterance_count"]:.3})'
                    )
                elif p == 'bot_word_count':
                    print(f'{p}: {v} ({v/model_stats_dict["total"]:.3})')
                elif p == 'human_utterance_count':
                    print(f'{p}: {v}')
                elif p == 'count_convos':
                    print(f'{p}: {v}')
                else:
                    print(f'{p}: {v} ({v/model_stats_dict["total"]:.2%})')

        print('Printing worker IDs not already in block list to add...')
        for b in bad_conversations:
            worker_id = b['workers'][0]
            if worker_id not in self.worker_block_list:
                print(f"""'{worker_id}',""")
        print('Done printing bad workers.')

        worker_df = pd.DataFrame([], columns=['worker_id', 'conversations'])

        for worker_id, data in worker_stats.items():
            stat = {
                'worker_id': worker_id,
                'conversations': data['conversations']
            }
            worker_df = worker_df.append(stat, ignore_index=True)

        with open(self.completed_run_stats_path, 'r') as f:
            completed_run_stats = json.load(f)
        assert completed_run_stats['bot_model_name'] == self.model_nickname
        completed_run_stats['context_done_statistics'][
            self.model_nickname] = complete_convos_per_model[
                self.model_nickname]
        completed_run_stats['context_done_counts'] = len(
            complete_convos_per_model[self.model_nickname])
        with open(self.completed_run_stats_path, 'w') as fw:
            json.dump(completed_run_stats, fw)
        print(f'Wrote override opt to: {self.completed_run_stats_path}')

        rating_path = os.path.join(self.output_folder,
                                   f'pairwise_ratings.json')
        with open(rating_path, 'w') as fw:
            json.dump(stat_counts[self.model_nickname]['pairwise_ratings'], fw)
        print(f'Wrote pairwise ratings to: {rating_path}')

        # Save full results
        all_conversations_df = pd.DataFrame()
        for df in conversation_dfs:
            all_conversations_df = all_conversations_df.append(df)

        return all_conversations_df
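
# Sketch (an assumption): the shape of the completed-run-stats JSON file that
# compile_results() above reads and rewrites; all values shown are hypothetical.
example_completed_run_stats = {
    'bot_model_name': 'my_model_nickname',  # must equal --model-nickname
    'context_done_statistics': {
        # model nickname -> list of initial_data_ids with a completed conversation
        'my_model_nickname': ['context_012', 'context_057'],
    },
    'context_done_counts': 2,  # length of the list above
}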
Ejemplo n.º 16
0
class ModelChatResultsCompiler(AbstractTurnAnnotationResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as aggregated
    up the level of each worker as a whole.
    """
    @classmethod
    def setup_args(cls):
        parser = super().setup_args()
        parser.add_argument(
            '--start-date',
            type=str,
            default='',
            help='The earliest date to analyze results from',
        )
        parser.add_argument(
            '--max-convos-per-worker',
            type=int,
            default=100,
            help=
            'The most conversations to analyze from any one user. Set to -1 for no limit.',
        )
        parser.add_argument(
            '--min-word-count',
            type=int,
            default=4,
            help=
            'The minimum acceptable mean number of words per human utterance',
        )
        parser.add_argument(
            '--hit-block-list',
            type=str,
            default='',
            help='Comma-separated list of all hits to block',
        )
        parser.add_argument(
            '--worker-block-list',
            type=str,
            default='',
            help='Comma-separated list of all workers to block',
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):

        super().__init__(opt)
        # Validate problem buckets
        if self.use_problem_buckets and 'none_all_good' not in self.problem_buckets:
            # The code relies on a catchall "none" category if the user selects no other
            # annotation bucket
            raise ValueError(
                'There must be a "none_all_good" category in self.problem_buckets!'
            )

        # Input args
        assert len(self.results_folders) > 0
        for folder in self.results_folders:
            assert os.path.isdir(folder), f'{folder} is not a valid folder!'
        os.makedirs(self.output_folder, exist_ok=True)
        self.start_date = opt['start_date']
        self.max_convos_per_worker = opt['max_convos_per_worker']
        self.min_word_count = opt['min_word_count']
        self.hit_block_list = opt['hit_block_list'].split(',')
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Setting up problem buckets
        if self.use_problem_buckets:
            self.regular_buckets = [
                bucket for bucket in self.problem_buckets
                if bucket not in ['other', 'none_all_good']
            ]
            # Remove the buckets that are special cases

        self.acceptability_checker = AcceptabilityChecker()

    def get_results_path_base(self) -> str:
        now = datetime.now()
        return os.path.join(self.output_folder,
                            f'results_{now.strftime("%Y%m%d_%H%M%S")}')

    def compile_results(self) -> pd.DataFrame:

        read_folders = []
        date_strings = []
        for folder in self.results_folders:
            # Load paths
            date_strings = sorted([
                obj for obj in os.listdir(folder)
                if os.path.isdir(os.path.join(folder, obj))
                and re.fullmatch(r'\d\d\d\d_\d\d_\d\d', obj)
            ])
            if self.start_date != '':
                date_strings = [
                    str_ for str_ in date_strings if str_ >= self.start_date
                ]
            folders = [os.path.join(folder, str_) for str_ in date_strings]
            read_folders.extend(folders)
        print(f'Date folders: ' + ', '.join(date_strings))

        now = datetime.now()
        worker_results_file = os.path.join(
            self.output_folder,
            f'worker_results_{now.strftime("%Y%m%d_%H%M%S")}.csv')
        # Read in each file
        num_incomplete_convos = 0
        num_complete_convos = 0
        complete_convos_per_model = {}
        bad_conversations = []
        stat_counts = {}
        worker_stats = {}
        worker_conversation_counts = {}
        total_utterances = 0

        conversation_idx = 0
        conversation_dfs = []
        for read_folder in read_folders:
            read_folder_name = os.path.split(read_folder)[-1]
            for file_name in sorted(os.listdir(read_folder)):
                if file_name in self.hit_block_list:
                    continue

                if 'incomplete' in file_name:
                    num_incomplete_convos += 1
                    continue
                else:
                    num_complete_convos += 1

                # Read in file
                with open(os.path.join(read_folder, file_name), 'rb') as f:
                    data = json.load(f)

                # Only include the first max_convos_per_worker conversations from a
                # worker to avoid biasing
                worker_id = data['workers'][0]
                assignment_id = data['assignment_ids'][0]
                if worker_id in worker_conversation_counts:
                    conversations_so_far = worker_conversation_counts[
                        worker_id]
                else:
                    conversations_so_far = 0
                worker_conversation_counts[
                    worker_id] = conversations_so_far + 1
                if (self.max_convos_per_worker != -1 and
                        conversations_so_far >= self.max_convos_per_worker):
                    print(
                        f'Had {conversations_so_far} conversation(s) already from this worker {worker_id}. Skipping {assignment_id}.'
                    )
                    continue

                # Check if need to block the turker
                word_counts = [
                    len(d['text'].split(' ')) for d in data['dialog']
                    if d['agent_idx'] == 0
                ]
                utterances = [
                    d['text'] for d in data['dialog'] if d['agent_idx'] == 0
                ]
                if np.average(word_counts) < self.min_word_count:
                    bad_conversations.append(data)
                    print(
                        f'Bad complete conversation, words from human: {utterances}. Skipping.'
                    )
                    continue

                if self.use_problem_buckets:
                    if not all(bucket in data['dialog'][1]['problem_data']
                               for bucket in self.problem_buckets):
                        raise ValueError(
                            'Bucket(s) are missing from the problem data!')

                model_nickname = data['task_description']['model_nickname']
                if model_nickname not in stat_counts:
                    stat_counts[model_nickname] = {}
                if model_nickname in complete_convos_per_model:
                    complete_convos_per_model[model_nickname] += 1
                else:
                    complete_convos_per_model[model_nickname] = 1

                # Extract non-message info
                info_dict = {
                    'read_folder_name': read_folder_name,
                    'file_name': file_name,
                    'worker': worker_id,
                    'model_nickname': model_nickname,
                    'bad_workers': ','.join(data['bad_workers']),
                    'hit_id': data['hit_ids'][0],
                    'assignment_id': assignment_id,
                    'is_incomplete': 'incomplete' in file_name,
                    'context_dataset': data['context_dataset'],
                    'additional_context': data['additional_context'],
                }

                # Check that the conversation consists of pairs of comments between
                # agents 0 and 1, with 0 speaking first
                assert all([
                    utterance_data['agent_idx'] == utterance_idx % 2 for
                    utterance_idx, utterance_data in enumerate(data['dialog'])
                ])

                # Determine whether the HIT contains unacceptable messages.
                # (We do this for every HIT, even if acceptability violation info
                # was already saved, because the violation criteria may have
                # changed since the HIT was collected.)
                messages_0 = [
                    utt for utt in data['dialog'] if utt['agent_idx'] == 0
                ]
                messages_1 = [
                    utt for utt in data['dialog'] if utt['agent_idx'] == 1
                ]
                assert len(messages_0) + len(messages_1) == len(data['dialog'])

                # Check the human utterances for safety
                utterances_0 = [m['text'] for m in messages_0]
                info_dict[
                    'acceptability_violations_0'] = self.acceptability_checker.check_messages(
                        messages=utterances_0,
                        is_worker_0=True,
                        violation_types=self.acceptability_checker.
                        ALL_VIOLATION_TYPES,
                    )

                # Compile personas and previous utterances
                df = pd.DataFrame(
                    [],
                    columns=[
                        'folder',
                        'worker_id',
                        'hit_id',
                        'model_nickname',
                        'conversation_idx',
                        'turn_idx',
                        'agent_idx',
                        'text',
                    ] + self.problem_buckets,
                )
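                # Build a turn_idx == -1 context row holding the bot's personas and
                # any additional context, shown before the dialogue turns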
                text_parts = []
                if data['personas'] is not None and len(data['personas']) > 0:
                    text_parts += [
                        'your persona: ' + data['personas'][1][0],
                        'your persona: ' + data['personas'][1][1],
                    ]
                if (data['additional_context'] is not None
                        and len(data['additional_context']) > 0):
                    text_parts.append(data['additional_context'])
                df = df.append(
                    {
                        'folder': info_dict['read_folder_name'],
                        'worker_id': info_dict['worker'],
                        'hit_id': info_dict['hit_id'],
                        'model_nickname': model_nickname,
                        'conversation_idx': conversation_idx,
                        'turn_idx': -1,
                        'agent_idx': 1,
                        'text': '\n'.join(text_parts),
                        **{bucket: ''
                           for bucket in self.problem_buckets},
                    },
                    ignore_index=True,
                )

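                # Track the total number of bot utterances seen across all
                # conversations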
                total_utterances += len(
                    [d for d in data["dialog"] if d["agent_idx"] == 1])
                if len(data['dialog']) > 20:
                    print(
                        f'Got long dialogue of {len(data["dialog"])} utterances, hit id: {info_dict["hit_id"]}, model_nickname: {model_nickname}.'
                    )

                if self.use_problem_buckets:
                    dialog_has_problems = False
                for utterance_idx, utt in enumerate(data['dialog']):

                    d = {
                        'folder': info_dict['read_folder_name'],
                        'worker_id': info_dict['worker'],
                        'hit_id': info_dict['hit_id'],
                        'model_nickname': model_nickname,
                        'conversation_idx': conversation_idx,
                        'turn_idx': utterance_idx,
                        'agent_idx': utt['agent_idx'],
                        'text': utt['text'],
                        **{bucket: ''
                           for bucket in self.problem_buckets},
                    }

                    if utt['agent_idx'] == 1:

                        d['final_rating'] = utt.get('final_rating')

                        if self.use_problem_buckets:
                            if 'problem_data' not in utt:
                                for bucket in self.problem_buckets:
                                    d[bucket] = 'MALFORMED'
                                print(
                                    f'Warning: got MALFORMED utterance problem data inside complete convo: {utt}. Skipping.'
                                )
                                continue
                            else:
                                for bucket in self.regular_buckets + [
                                        'none_all_good'
                                ]:
                                    d[bucket] = utt['problem_data'][bucket]
                            for k in self.regular_buckets + ['none_all_good']:
                                if k not in stat_counts[model_nickname]:
                                    stat_counts[model_nickname][k] = 0
                                stat_counts[model_nickname][k] += d[k]
                                if k != 'none_all_good' and d[k]:
                                    dialog_has_problems = True

                        if 'total' not in stat_counts[model_nickname]:
                            stat_counts[model_nickname]['total'] = 0
                        if d['agent_idx'] == 1:
                            stat_counts[model_nickname]['total'] += 1
                        if d['final_rating'] is not None:
                            # Only on the last utterance (agent_idx == 1)
                            if 'count_ratings' not in stat_counts[
                                    model_nickname]:
                                stat_counts[model_nickname][
                                    'count_ratings'] = 0
                            stat_counts[model_nickname]['count_ratings'] += 1
                            if 'ratings' not in stat_counts[model_nickname]:
                                stat_counts[model_nickname]['ratings'] = []
                            stat_counts[model_nickname]['ratings'].append(
                                int(d['final_rating']))

                    else:

                        # Counting some aspects of the human's utterances
                        if 'human_utterance_count' not in stat_counts[
                                model_nickname]:
                            stat_counts[model_nickname][
                                'human_utterance_count'] = 0
                        stat_counts[model_nickname][
                            'human_utterance_count'] += 1

                        if 'human_word_count' not in stat_counts[
                                model_nickname]:
                            stat_counts[model_nickname]['human_word_count'] = 0
                        stat_counts[model_nickname]['human_word_count'] += len(
                            d['text'].strip().split(' '))

                        if 'human_question_count' not in stat_counts[
                                model_nickname]:
                            stat_counts[model_nickname][
                                'human_question_count'] = 0
                        stat_counts[model_nickname][
                            'human_question_count'] += d['text'].count('?')

                    d = self._add_additional_per_turn_stats(d=d, utt=utt)

                    df = df.append(d, ignore_index=True)

                if info_dict['worker'] not in worker_stats:
                    worker_stats[info_dict['worker']] = {'conversations': 0}
                    if self.use_problem_buckets:
                        worker_stats[info_dict['worker']]['problems_found'] = 0
                worker_stats[info_dict['worker']]['conversations'] += 1

                if self.use_problem_buckets:
                    # Count the number of problems the worker got
                    is_problem = ~df['none_all_good'].replace('', True)
                    # Only bot utterances should be counted: human rows have '' for
                    # 'none_all_good', which replace() maps to True, so they never
                    # register as problems after negation
                    count = is_problem.sum()
                    worker_stats[
                        info_dict['worker']]['problems_found'] += count

                # Logic for calculating percent of conversations that are clean
                if 'count_convos' not in stat_counts[model_nickname]:
                    stat_counts[model_nickname]['count_convos'] = 0
                stat_counts[model_nickname]['count_convos'] += 1

                if self.use_problem_buckets and not dialog_has_problems:
                    if 'convo_clean' not in stat_counts[model_nickname]:
                        stat_counts[model_nickname]['convo_clean'] = 0
                    stat_counts[model_nickname]['convo_clean'] += 1

                # Adding the full conversation to the list of conversations
                conversation_dfs.append(df)
                conversation_idx += 1

        for m, conversation_count in complete_convos_per_model.items():
            print(
                f'Got {conversation_count} complete conversation(s) for model: {m}'
            )

        print(f'{num_complete_convos:d} complete conversation(s) collected.')
        print(f'{len(bad_conversations):d} bad conversation(s).')
        num_approved_convos = num_complete_convos - len(bad_conversations)
        print(f'{num_approved_convos:d} approved conversation(s).')
        print(
            f'({num_incomplete_convos:d} incomplete conversation(s) collected.)'
        )
        for model_nickname, model_stats_dict in stat_counts.items():
            print(f'---{model_nickname}---')
            for p, v in model_stats_dict.items():
                if p == 'count_ratings':
                    continue
                if p == 'ratings':
                    print(
                        f'Average Engaging-ness Rating: {np.average(model_stats_dict["ratings"])} ({model_stats_dict["count_ratings"]} ratings)'
                    )
                    continue
                if p == 'human_word_count' or p == 'human_question_count':
                    print(
                        f'{p}: {v} ({v/model_stats_dict["human_utterance_count"]:.3})'
                    )
                elif p == 'human_utterance_count':
                    print(f'{p}: {v}')
                elif p == 'count_convos':
                    print(f'{p}: {v}')
                elif self.use_problem_buckets and p == 'convo_clean':
                    print(
                        f'{p}: {v} ({v/model_stats_dict["count_convos"]:.2%})')
                else:
                    print(f'{p}: {v} ({v/model_stats_dict["total"]:.2%})')

        print('Printing worker IDs not already in block list to add...')
        for b in bad_conversations:
            worker_id = b['workers'][0]
            if worker_id not in self.worker_block_list:
                print(f"""'{worker_id}',""")
        print('Done printing bad workers.')

        print('Worker stats:')
        worker_columns = ['worker_id', 'conversations']
        if self.use_problem_buckets:
            worker_columns += ['problems_found', 'avg_problems_per_convo']
        worker_df = pd.DataFrame([], columns=worker_columns)

        for worker_id, data in worker_stats.items():
            print(worker_id)

            stat = {
                'worker_id': worker_id,
                'conversations': data['conversations']
            }
            if self.use_problem_buckets:
                avg_problems_per_convo = data['problems_found'] / data[
                    'conversations']
                stat.update({
                    'problems_found': data['problems_found'],
                    'avg_problems_per_convo': avg_problems_per_convo,
                })
            worker_df = worker_df.append(stat, ignore_index=True)
        if self.use_problem_buckets:
            worker_df = worker_df.sort_values('avg_problems_per_convo',
                                              ascending=False)
        worker_df.to_csv(worker_results_file, index=False)
        print(worker_df)
        print(f'Wrote worker statistical results to: {worker_results_file}')

        # Save full results
        all_conversations_df = pd.DataFrame()
        for df in conversation_dfs:
            all_conversations_df = all_conversations_df.append(df)
        print(f'\nWorker conversation counts: {worker_conversation_counts}')

        return all_conversations_df

    def _add_additional_per_turn_stats(self, d: dict, utt: dict) -> dict:
        """
        Add in additional statistics on the level of each conversation turn.

        Useful for subclasses.
        """
        _ = utt  # utt is ignored in this passthrough method
        return d
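
# A minimal sketch (not part of the original example) of how a subclass could use the
# _add_additional_per_turn_stats() hook defined above. The subclass name and the
# 'word_count' column are hypothetical, and the parent class is assumed to be
# ModelChatResultsCompiler (the class that the next example's TODO refers to).
class VerboseModelChatResultsCompiler(ModelChatResultsCompiler):
    def _add_additional_per_turn_stats(self, d: dict, utt: dict) -> dict:
        # Record the utterance's word count as an extra per-turn column in the results
        d['word_count'] = len(utt.get('text', '').strip().split())
        return d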
Ejemplo n.º 17
0
class PerTurnEvalResultsCompiler(AbstractResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as aggregated
    up the level of each worker as a whole.
    """

    # TODO: deduplicate setup_args from ModelChatResultsCompiler
    @classmethod
    def setup_args(cls):
        parser = super().setup_args()
        parser.add_argument(
            '--worker-block-list',
            type=str,
            default='',
            help='Comma-separated list of all workers to block',
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):
        # TODO: deduplicate init from ModelChatResultsCompiler

        super().__init__(opt)

        # Input args
        os.makedirs(self.output_folder, exist_ok=True)
        self.worker_block_list = opt['worker_block_list'].split(',')

        # Save paths
        self.worker_results_path = os.path.join(
            self.output_folder, 'worker_results.csv'
        )
        self.unacceptable_worker_ids_path = os.path.join(
            self.output_folder, 'unacceptable_worker_ids.txt'
        )
        self.win_rate_by_date_path = os.path.join(
            self.output_folder, 'win_rates_by_date.csv'
        )
        self.stat_mean_length_by_date_path = os.path.join(
            self.output_folder, 'stat_mean_length_by_date.csv'
        )
        self.completion_time_by_model_pair_path = os.path.join(
            self.output_folder, 'mean_completion_times.csv'
        )

        self.acceptability_checker = AcceptabilityChecker()

        # Set fields that should be empty strings if the relevant information is not
        # present
        blank_field_columns = [
            'human_text',
            'human_choice',
            'human_justification',
            'accepted_bot_text',
            'not_accepted_bot_text',
        ]
        self.blank_fields = {field: '' for field in blank_field_columns}

        # Results attributes
        self.stat_counts = {}
        self.mean_completion_time = None
        # Useful for subclasses, to compare with other eval techniques

    def get_results_path_base(self) -> str:
        return os.path.join(self.output_folder, 'results')

    def compile_results(self) -> pd.DataFrame:

        # Load task data
        logging.info('Retrieving task data from Mephisto.')
        task_units_data = self.get_task_data()
        logging.info(f'Data for {len(task_units_data)} units loaded successfully.')

        # Read in each file
        num_convos_with_no_save_data = 0
        num_wrong_status_convos = 0
        num_complete_convos = 0
        worker_stats = {}
        worker_conversation_counts = {}
        total_utterances = 0

        unacceptable_task_units = []
        unacceptable_worker_ids = []
        conversation_idx = 0
        conversation_dfs = []

        for task_unit in task_units_data:

            worker_id = task_unit['worker_id']
            assignment_id = task_unit['assignment_id']

            # Determining whether the task unit should be skipped
            # Extract out custom data
            if task_unit['data']['save_data'] is None:
                logging.info('Found a task unit with no save data! Skipping.')
                num_convos_with_no_save_data += 1
                continue
            elif task_unit['status'] not in ['completed', 'approved']:
                logging.info(
                    f'Found a HIT with the status "{task_unit["status"]}"! '
                    f'Skipping.'
                )
                num_wrong_status_convos += 1
                continue
            else:
                num_complete_convos += 1

            # Check if the Turker is on the list of blocked Turkers
            if worker_id in self.worker_block_list:
                logging.info(
                    f'Found a HIT from worker {worker_id}, who is on the block list. '
                    f'Skipping.'
                )
                continue

            if worker_id in worker_conversation_counts:
                conversations_so_far = worker_conversation_counts[worker_id]
            else:
                conversations_so_far = 0
            worker_conversation_counts[worker_id] = conversations_so_far + 1

            data = task_unit['data']['save_data']['custom_data']

            # Extract out information about this conversation

            model_1_nickname = data['task_description']['model_1_nickname']
            model_2_nickname = data['task_description']['model_2_nickname']

            # Since we have two models, we use the format model_1_name:model_2_name
            model_pair_nickname = f"{model_1_nickname}:{model_2_nickname}"
            if model_pair_nickname not in self.stat_counts:
                self.stat_counts[model_pair_nickname] = {
                    'per_turn': defaultdict(
                        lambda: {model_1_nickname: 0, model_2_nickname: 0}
                    )
                }

            # Extract non-message info
            mturk_worker_id_match = re.fullmatch(
                r'--NOT-MTURK-AGENT-(.*)', data['workers'][0]
            )
            # TODO: figure out why --NOT-MTURK-AGENT appears at the beginning of this
            #  field, and remove it; then, remove this re.fullmatch() call
            if mturk_worker_id_match is not None:
                mturk_worker_id = mturk_worker_id_match.group(1)
            else:
                mturk_worker_id = None
            task_start = datetime.utcfromtimestamp(task_unit['task_start'])
            task_end = datetime.utcfromtimestamp(task_unit['task_end'])
            single_convo_info_dict = {
                'worker': worker_id,
                'mturk_worker_id': mturk_worker_id,
                'model_pair_nickname': model_pair_nickname,
                'bad_workers': ','.join(data['bad_workers']),
                'hit_id': data['hit_ids'][0],
                'assignment_id': assignment_id,
                'context_dataset': data['context_dataset'],
                'additional_context': data['additional_context'],
                'date': task_start.strftime('%Y-%m-%d'),
                'task_start': task_start,
                'task_end': task_end,
                'completion_time': (task_end - task_start).total_seconds(),
            }
            # TODO: 'task_start' and 'task_end' assume that the original datetime floats
            #  are stored in UTC. Check this!

            # Check that the conversation consists of pairs of comments between
            # agents 0 and 1, with 0 speaking first
            assert all(
                [
                    utterance_data['agent_idx'] == utterance_idx % 2
                    for utterance_idx, utterance_data in enumerate(data['dialog'])
                ]
            )
            messages_0 = [utt for utt in data['dialog'] if utt['agent_idx'] == 0]
            messages_1 = [utt for utt in data['dialog'] if utt['agent_idx'] == 1]
            assert len(messages_0) + len(messages_1) == len(data['dialog'])

            # Determine whether the HIT contains unacceptable messages.
            # (We do this for every HIT, even if acceptability violation info
            # was already saved, because the violation criteria may have
            # changed since the HIT was collected.)
            utterances_0 = [m['text'] for m in messages_0]
            assert utterances_0[0] == 'Hi!', (
                'This script assumes that the first human message is "Hi!", which is '
                'set by default and cannot be changed by the crowdsourcing worker.'
            )
            acceptability_violations = self.acceptability_checker.check_messages(
                messages=utterances_0[1:],  # Don't use the initial "Hi!"
                is_worker_0=True,
                violation_types=self.acceptability_checker.ALL_VIOLATION_TYPES,
            )
            if acceptability_violations != '':
                logging.info(
                    f'Conversation fails acceptability checks with a violation of '
                    f'"{acceptability_violations}", given the following utterances: '
                    f'{utterances_0[1:]}. Skipping.'
                )
                unacceptable_task_units.append(task_unit)
                assert (
                    mturk_worker_id is not None
                ), "MTurk worker ID cannot be determined for this unacceptable conversation!"
                unacceptable_worker_ids.append(mturk_worker_id)
                continue
            single_convo_info_dict[
                'acceptability_violations'
            ] = acceptability_violations

            # Identify information to put in each line of the output DataFrame
            info_for_each_turn = {
                'worker_id': single_convo_info_dict['worker'],
                'mturk_worker_id': single_convo_info_dict['mturk_worker_id'],
                'hit_id': single_convo_info_dict['hit_id'],
                'model_pair_nickname': model_pair_nickname,
                'conversation_idx': conversation_idx,
                'date': single_convo_info_dict['date'],
                'completion_time': single_convo_info_dict['completion_time'],
            }

            single_turn_dicts = []

            # Compile personas and previous utterances
            text_parts = []
            if data['personas'] is not None and len(data['personas']) > 0:
                assert len(data['personas']) == 2
                text_parts += [
                    'human persona: ' + ' '.join(data['personas'][0]),
                    'bot persona: ' + ' '.join(data['personas'][1]),
                ]
            if (
                data['additional_context'] is not None
                and len(data['additional_context']) > 0
            ):
                text_parts.append(data['additional_context'])
            single_turn_dicts.append(
                {
                    **info_for_each_turn,
                    'turn_idx': -1,
                    'agent_idx': -1,
                    'context': '\n'.join(text_parts),
                    **self.blank_fields,
                }
            )

            total_utterances += len([d for d in data["dialog"] if d["agent_idx"] == 1])
            if len(data['dialog']) > 20:
                logging.info(
                    f'Got long dialogue of {len(data["dialog"])} utterances, hit id: '
                    f'{single_convo_info_dict["hit_id"]}, model_pair_nickname: '
                    f'{model_pair_nickname}.'
                )

            # Loop over conversation turns

            for utterance_idx, utt in enumerate(data['dialog']):

                this_turn_dict = {
                    **info_for_each_turn,
                    'turn_idx': utterance_idx,
                    'agent_idx': utt['agent_idx'],
                    **self.blank_fields,
                }

                if utt['agent_idx'] == 1:

                    # This is a turn in which the bots have responded

                    human_turn_idx = int(utterance_idx / 2) + 1
                    # Turns are 1-indexed

                    # TODO: maybe clean up some of this logic
                    this_turn_dict['human_choice'] = utt['human_choice']
                    this_turn_dict['human_justification'] = utt['human_justification']
                    this_turn_dict['accepted_bot_text'] = (
                        utt['accepted_bot_data']['text']
                        .replace('\n', '__newline__')
                        .replace('\r', '__CR__')
                    )
                    this_turn_dict['not_accepted_bot_text'] = (
                        utt['not_accepted_bot_data']['text']
                        .replace('\n', '__newline__')
                        .replace('\r', '__CR__')
                    )

                    if 'total' not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname]['total'] = 0
                    if this_turn_dict['agent_idx'] == 1:
                        self.stat_counts[model_pair_nickname]['total'] += 1

                    # Calculating overall human choice statistics
                    if model_1_nickname not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname][model_1_nickname] = 0
                    if model_2_nickname not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname][model_2_nickname] = 0

                    # Calculating per-turn human choice statistics
                    if utt['human_choice'] == model_1_nickname:
                        self.stat_counts[model_pair_nickname][model_1_nickname] += 1
                        self.stat_counts[model_pair_nickname]['per_turn'][
                            human_turn_idx
                        ][model_1_nickname] += 1
                    elif utt['human_choice'] == model_2_nickname:
                        self.stat_counts[model_pair_nickname][model_2_nickname] += 1
                        self.stat_counts[model_pair_nickname]['per_turn'][
                            human_turn_idx
                        ][model_2_nickname] += 1
                    else:
                        raise Exception(
                            'Something wrong has occurred: human choice is not equal '
                            'to either of the two models!'
                        )

                else:

                    # This is a turn in which the human has responded

                    this_turn_dict['human_text'] = utt['text']

                    # Counting some aspects of the human's utterances
                    if (
                        'human_utterance_count'
                        not in self.stat_counts[model_pair_nickname]
                    ):
                        self.stat_counts[model_pair_nickname][
                            'human_utterance_count'
                        ] = 0
                    self.stat_counts[model_pair_nickname]['human_utterance_count'] += 1

                    if 'human_word_count' not in self.stat_counts[model_pair_nickname]:
                        self.stat_counts[model_pair_nickname]['human_word_count'] = 0
                    self.stat_counts[model_pair_nickname]['human_word_count'] += len(
                        this_turn_dict['human_text'].strip().split(' ')
                    )

                    if (
                        'human_question_count'
                        not in self.stat_counts[model_pair_nickname]
                    ):
                        self.stat_counts[model_pair_nickname][
                            'human_question_count'
                        ] = 0
                    self.stat_counts[model_pair_nickname][
                        'human_question_count'
                    ] += this_turn_dict['human_text'].count('?')

                single_turn_dicts.append(this_turn_dict)

            # Finish up collecting per-conversation stats

            if single_convo_info_dict['worker'] not in worker_stats:
                worker_stats[single_convo_info_dict['worker']] = {'conversations': 0}
            worker_stats[single_convo_info_dict['worker']]['conversations'] += 1

            # Logic for calculating percent of conversations that are clean
            if 'acceptable_convos' not in self.stat_counts[model_pair_nickname]:
                self.stat_counts[model_pair_nickname]['acceptable_convos'] = 0
            self.stat_counts[model_pair_nickname]['acceptable_convos'] += 1

            # Adding the full conversation to the list of conversations
            single_convo_df = pd.DataFrame(single_turn_dicts)
            conversation_dfs.append(single_convo_df)
            conversation_idx += 1

        # Print results
        # TODO: all of this would be cleaner if saved as CSVs, so we don't have to
        #  re-run to get the results

        logging.info(
            f'{num_convos_with_no_save_data:d} conversations found with no save data.'
        )
        logging.info(
            f'{num_wrong_status_convos:d} conversations found with the wrong status.'
        )
        logging.info(f'{num_complete_convos:d} complete conversations found:')
        logging.info(f'\t{len(unacceptable_task_units):d} unacceptable conversations.')
        logging.info(f'\t{len(conversation_dfs):d} acceptable conversations.')
        for model_pair_nickname, model_stats_dict in self.stat_counts.items():
            logging.info(f'---{model_pair_nickname}---')
            model_1_nickname = model_pair_nickname.split(":")[0]
            model_2_nickname = model_pair_nickname.split(":")[1]
            for p, v in model_stats_dict.items():
                if p == 'per_turn':
                    for human_turn_idx in model_stats_dict['per_turn']:
                        per_turn_model_1 = model_stats_dict['per_turn'][human_turn_idx][
                            model_1_nickname
                        ]
                        per_turn_model_2 = model_stats_dict['per_turn'][human_turn_idx][
                            model_2_nickname
                        ]
                        per_turn_model_total = per_turn_model_1 + per_turn_model_2
                        logging.info(
                            f"Turn {human_turn_idx}, {model_1_nickname}: {per_turn_model_1} "
                            f"({per_turn_model_1/per_turn_model_total:.2%})"
                            f", {model_2_nickname}: {per_turn_model_2} "
                            f"({per_turn_model_2/per_turn_model_total:.2%})"
                        )
                    continue
                if p == 'human_word_count' or p == 'human_question_count':
                    logging.info(
                        f'{p}: {v} ({v/model_stats_dict["human_utterance_count"]:.3})'
                    )
                elif p == 'human_utterance_count':
                    logging.info(f'{p}: {v}')
                elif p == 'acceptable_convos':
                    logging.info(f'{p}: {v}')
                else:
                    logging.info(f'{p}: {v} ({v/model_stats_dict["total"]:.2%})')

        logging.info('Printing worker IDs not already in block list to add...')
        for b in unacceptable_task_units:
            worker_id = b['worker_id']
            if worker_id not in self.worker_block_list:
                logging.info(f"""'{worker_id}',""")
        logging.info('Done printing bad workers.')

        logging.info(f'\nWorker conversation counts: {worker_conversation_counts}')

        # Compile full results

        all_conversations_df = pd.DataFrame()
        for single_convo_df in conversation_dfs:
            all_conversations_df = all_conversations_df.append(single_convo_df)
        for field in self.blank_fields.keys():
            assert all_conversations_df[field].isna().sum() == 0, (
                f'Some rows of the "{field}" column have NaNs in them, making them '
                f'hard to calculate statistics on!'
            )

        # Save analysis files

        logging.info(
            f'Saving worker statistical results to {self.worker_results_path}.'
        )
        worker_columns = ['worker_id', 'conversations']
        worker_df = pd.DataFrame([], columns=worker_columns)
        for worker_id, data in worker_stats.items():
            stat = {'worker_id': worker_id, 'conversations': data['conversations']}
            worker_df = worker_df.append(stat, ignore_index=True)
        worker_df.to_csv(self.worker_results_path, index=False)

        logging.info(
            f'Saving MTurk IDs of workers with unacceptable conversations to '
            f'{self.unacceptable_worker_ids_path}.'
        )
        with open(self.unacceptable_worker_ids_path, 'w') as f:
            for worker_id in unacceptable_worker_ids:
                f.write(worker_id + '\n')

        logging.info(f'Saving win rates cut by date to {self.win_rate_by_date_path}.')
        pivoted_win_rate_df = (
            all_conversations_df[lambda df: df['human_choice'].notna()]
            .assign(count=1)
            .groupby(['model_pair_nickname', 'date', 'human_choice'])
            .agg({'count': 'sum'})
            .reset_index()
            .pivot(
                index=['model_pair_nickname', 'date'],
                columns='human_choice',
                values='count',
            )
        )
        model_names = pivoted_win_rate_df.columns
        pivoted_win_rate_df.loc[:, 'total_count'] = pivoted_win_rate_df[
            model_names
        ].sum(axis=1)
        for model_name in model_names:
            pivoted_win_rate_df.loc[:, f'frac_{model_name}'] = (
                pivoted_win_rate_df[model_name] / pivoted_win_rate_df['total_count']
            )
        pivoted_win_rate_df.to_csv(self.win_rate_by_date_path)
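
        # For reference (hypothetical model names, not in the original example), the CSV
        # written above has one row per (model_pair_nickname, date), a raw count column
        # per model, a total, and a per-model win-rate fraction, e.g.:
        #   model_pair_nickname,date,model_a,model_b,total_count,frac_model_a,frac_model_b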

        logging.info(
            f'Saving mean word count of different stats, cut by date, to '
            f'{self.stat_mean_length_by_date_path}.'
        )
        stats_to_calculate_mean_length_of = ['human_text', 'human_justification']
        assert (
            len(stats_to_calculate_mean_length_of) == 2
        ), 'This section of the code won\'t work with more than 2 stats!'
        stat_mean_length_dfs = []
        for stat in stats_to_calculate_mean_length_of:
            stat_mean_length_dfs.append(
                all_conversations_df[lambda df: df[stat] != '']
                .assign(word_count=lambda df: df[stat].str.split().str.len())
                .groupby(['model_pair_nickname', 'date'])['word_count']
                .mean()
                .to_frame(stat)
            )
        joined_stat_mean_length_df = stat_mean_length_dfs[0].join(
            stat_mean_length_dfs[1]
        )
        joined_stat_mean_length_df.to_csv(self.stat_mean_length_by_date_path)

        logging.info(
            f'Saving mean completion time stats to '
            f'{self.completion_time_by_model_pair_path}.'
        )
        completion_time_by_convo_df = all_conversations_df[
            ['model_pair_nickname', 'conversation_idx', 'completion_time']
        ].drop_duplicates()
        for model_pair_nickname in completion_time_by_convo_df[
            'model_pair_nickname'
        ].unique():
            assert (
                completion_time_by_convo_df[
                    lambda df: df['model_pair_nickname'] == model_pair_nickname
                ].index.size
                == self.stat_counts[model_pair_nickname]['acceptable_convos']
            ), (
                f"The count of convos for the model pair {model_pair_nickname} is "
                f"inconsistent!"
            )
        completion_time_by_model_pair_df = (
            completion_time_by_convo_df.groupby('model_pair_nickname')[
                'completion_time'
            ]
            .mean()
            .to_frame('mean_completion_time')
        )
        completion_time_by_model_pair_df.to_csv(self.completion_time_by_model_pair_path)

        return all_conversations_df
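
# A hypothetical driver sketch (not part of the original example): it runs the compiler
# from the command line using only the methods shown above (setup_args, compile_results,
# get_results_path_base). Saving to '<results path>.csv' is an assumption.
if __name__ == '__main__':
    parser_ = PerTurnEvalResultsCompiler.setup_args()
    args_ = parser_.parse_args()
    compiler = PerTurnEvalResultsCompiler(vars(args_))
    all_conversations_df = compiler.compile_results()
    all_conversations_df.to_csv(compiler.get_results_path_base() + '.csv', index=False)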
Ejemplo n.º 18
0
class TurnAnnotationsChatWorld(MultiAgentDialogWorld):
    def __init__(
        self,
        opt,
        agents=None,
        shared=None,
        num_turns=6,
        tag=None,
        max_resp_time=120,
        agent_timeout_shutdown=120,
        context_info: Optional[dict] = None,
    ):
        # 6 turns for a single side (so 12 total), and really it appears to be
        # 14 total b/c of the "Hi!" and first bot utterance

        self.agents = agents
        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = tag
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        self.agent_timeout_shutdown = agent_timeout_shutdown
        print(
            f'Creating {self.__class__.__name__} for tag {tag} with {num_turns} turns.'
        )
        super().__init__(opt, agents, shared)

    def __add_problem_data_to_utterance(self, p):
        # Human has just responded. Problem data received
        # now will be from bot's prior utterance (turn_idx
        # is also present to be safe that data matches)
        is_fake_utterance = ('fake_start' in self.dialog[p['turn_idx']]
                             and self.dialog[p['turn_idx']]['fake_start'])
        annotations = []
        for a in self.opt['annotations_config']:
            annotations.append(p[a['value']])
        assert any(annotations) or is_fake_utterance
        self.dialog[p['turn_idx']]['problem_data'] = p

    def parley(self):
        control_msg = {
            'episode_done': False,
            'config': {
                'min_num_turns': self.num_turns,
                'annotations_config': self.opt['annotations_config'],
            },
            'left_pane_text': self.opt['left_pane_text'],
        }

        print(
            f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:

            for agent_idx, agent in enumerate(self.agents):
                if agent_idx == 1 and self.opt['include_persona']:
                    print('Including persona for the bot.')
                    # The Bot agent
                    # We add the personas and 1/3 of the time WoW topic as the
                    # first utterance in the history.
                    # Previously for BST task, we also had a big first utterance
                    # that gave instructions. Removing that for this task.
                    persona_strings = [
                        s.strip() for s in self.personas[agent_idx]
                    ]
                    persona_utterance = self._get_persona_utterance(
                        persona_strings=persona_strings,
                        context_dataset=self.context_info['context_dataset'],
                        additional_context=self.
                        context_info['additional_context'],
                        is_bot=(agent_idx == 1),
                    )
                    message = control_msg.copy()
                    message['text'] = persona_utterance
                    agent.observe(validate(message), increment_turn=False)
                    # The bot seeing its persona does not count as a "turn"
                if agent_idx == 0:
                    time.sleep(3)

            if self.opt['conversation_start_mode'] == 'bst':

                print('[Displaying first utterances as per BST task.]')
                # Display the previous two utterances
                human_first_msg = {
                    'left_pane_text': self.opt['left_pane_text'],
                    'episode_done': False,
                    'id': self.agents[0].id,
                    'text': self.context_info['person1_seed_utterance'],
                    'fake_start': True,
                    'agent_idx': 0,
                }
                for k, v in control_msg.items():
                    human_first_msg[k] = v
                bot_first_msg = {
                    'episode_done': False,
                    'id': self.agents[1].id,
                    'text': self.context_info['person2_seed_utterance'],
                    'fake_start': True,
                    'agent_idx': 1,
                }
                print(
                    f'human_first_msg: {human_first_msg}, bot_first_msg: {bot_first_msg}'
                )

                self.dialog.append(human_first_msg)
                self.dialog.append(bot_first_msg)

                for agent in self.agents:
                    agent.observe(validate(human_first_msg))
                    agent.observe(validate(bot_first_msg))

            elif self.opt['conversation_start_mode'] == 'hi':

                print('[Displaying "Hi!" only as per Meena task.]')
                human_first_msg = {
                    'left_pane_text': self.opt['left_pane_text'],
                    'episode_done': False,
                    'id': self.agents[0].id,
                    'text': 'Hi!',
                    'fake_start': True,
                    'agent_idx': 0,
                }
                for k, v in control_msg.items():
                    human_first_msg[k] = v

                self.dialog.append(human_first_msg)
                self.agents[0].observe(validate(human_first_msg))
                self.agents[1].observe(validate(human_first_msg))

                first_bot_act = self.agents[1].act()
                first_bot_act = Compatibility.maybe_fix_act(first_bot_act)

                self.agents[0].observe(validate(first_bot_act))

                bot_utterance_data = {
                    'agent_idx': 1,
                    # Get rid of annotations HTML from bot response
                    'text': first_bot_act['text'].split('<br>')[0],
                    'id': first_bot_act['id'],
                }
                self.dialog.append(bot_utterance_data)

            else:

                raise ValueError(
                    f"Conversation start mode {self.opt['conversation_start_mode']} "
                    f"not recognized!")

            self.task_turn_idx += 1
            return
        """Otherwise, we proceed accordingly"""
        print(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}'
        )
        acts = [None, None]
        for idx, agent in enumerate(self.agents):
            if not self.chat_done:
                acts[idx] = agent.act(timeout=self.max_resp_time)
                acts[idx] = Compatibility.maybe_fix_act(acts[idx])
                print(
                    f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
                )

            if self.check_timeout(acts[idx]):
                return

            if acts[idx]['episode_done']:
                self.chat_done = True
                for ag in self.agents:
                    # if agent disconnected
                    if ag != agent and ag.some_agent_disconnected:
                        if idx == 0:
                            # Human
                            message = control_msg.copy()
                            message['text'] = (
                                'The other worker unexpectedly disconnected. '
                                'Please click "Done with this HIT" button below to finish this HIT.'
                            )
                            message['episode_done'] = True
                            ag.observe(validate(message))
                        return
                # agent ends chat after exceeding minimum number of turns
                if self.task_turn_idx > self.num_turns:
                    for ag in self.agents:
                        if idx == 0:
                            print(
                                'One of you ended the chat; sending the final message.')
                            message = control_msg.copy()
                            message['text'] = (
                                'One of you ended the chat. Thanks for your '
                                'time! Please click "Done with this HIT"'
                                'button below to finish this HIT.')
                            message['episode_done'] = True
                            ag.observe(validate(message))
                            # Human has just responded. Problem data received
                            # now will be from bot's prior utterance (turn_idx
                            # is also present to be safe that data matches)
                            p = acts[idx]['problem_data_for_prior_message']
                            self.__add_problem_data_to_utterance(p)
                return

            else:
                utterance_data = {
                    'agent_idx': idx,
                    # Get rid of annotations HTML if it's the bot response
                    'text': acts[idx]['text'].split('<br>')[0],
                    'id': acts[idx]['id'] if 'id' in acts[idx] else
                    'NULL_ID',  # Person1 or Polyencoder
                }
                self.dialog.append(utterance_data)
                if idx == 0:
                    # Human has just responded. Problem data received
                    # now will be from bot's prior utterance (turn_idx
                    # is also present to be safe that data matches)
                    p = acts[idx]['problem_data_for_prior_message']
                    self.__add_problem_data_to_utterance(p)

                for other_agent in self.agents:
                    if other_agent != agent:
                        other_agent.observe(validate(acts[idx]))

                print(
                    f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}'
                )
                self.task_turn_idx += 1

    def shutdown(self):
        global shutdown_agent

        def shutdown_agent(mturk_agent):
            mturk_agent.shutdown()

        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(shutdown_agent)(agent)
                                      for agent in self.agents)

    def episode_done(self):
        return self.chat_done

    def _get_persona_utterance(
        self,
        persona_strings: Optional[List[str]] = None,
        context_dataset: Optional[str] = None,
        additional_context: Optional[str] = None,
        is_bot: bool = False,
    ):
        if is_bot:
            # Pass back the original context
            persona_pieces = [
                f"your persona: {str_}" for str_ in persona_strings
            ]
            if context_dataset == 'wizard_of_wikipedia':
                additional_context_pieces = [additional_context]
            else:
                additional_context_pieces = []
            full_context = '\n'.join(persona_pieces +
                                     additional_context_pieces)
            print(f'FULL CONTEXT: {full_context}')
            return full_context
        else:
            if context_dataset == 'convai2':
                last_sentence = 'Pretend that the conversation has already begun.'
            elif context_dataset == 'empathetic_dialogues':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about the following situation: '
                    f'<b>"{additional_context}"</b>')
            elif context_dataset == 'wizard_of_wikipedia':
                last_sentence = (
                    f'Pretend that the conversation has already begun, and that you '
                    f'had been talking about <b>{additional_context}</b>.')
            else:
                raise ValueError('Context dataset unrecognized!')
            joined_personas = '\n'.join(persona_strings)
            return (
                f'\nSuccessfully matched with another user! Now let\'s get to know '
                f'each other through the chat. You need to finish at least '
                f'<b>{self.num_turns} chat turns</b>, and after that you can click the '
                f'"Done" button to end the chat.\n\n'
                f'<b>Your character description is:\n<span style="color:blue">{joined_personas}</span></b> '
                '\n\n<b>Remember that you can get to know each '
                'other as your characters, talk about any topic, or talk about a '
                'situation that might have happened to your character.</b>'
                '\n<b>Do not trivially copy the '
                'character descriptions into the message.</b><br><br>'
                f'{last_sentence}')
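
    # For illustration only (not in the original code): with hypothetical persona
    # strings ['i have two dogs.', 'i love hiking.'] and the Wizard of Wikipedia topic
    # 'Hiking', the bot-side context returned by _get_persona_utterance() above would be:
    #
    #   your persona: i have two dogs.
    #   your persona: i love hiking.
    #   Hiking
    #
    # i.e. one "your persona: ..." line per persona string, plus the additional context
    # line only when context_dataset == 'wizard_of_wikipedia'.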

    def save_data(self):
        convo_finished = True
        bad_workers = []
        for ag in self.agents:
            if (ag.hit_is_abandoned or ag.hit_is_returned or ag.disconnected
                    or ag.hit_is_expired):
                bad_workers.append(ag.worker_id)
                convo_finished = False
                ag.not_approve = True

        if self.check_acceptability:
            human_texts = [
                message['text'] for message in self.dialog
                if message['agent_idx'] == 0
            ]
            violation_types = [
                'min_words', 'all_caps', 'exact_match', 'safety'
            ]
            if self.opt['conversation_start_mode'] == 'bst':
                # The BST mode starts the conversation with two previous utterances, so
                # there should be no new greeting
                violation_types.append('penalize_greetings')

            violations_agent_0 = self.acceptability_checker.check_messages(
                messages=human_texts,
                is_worker_0=False,
                violation_types=violation_types)
        else:
            violations_agent_0 = None

        time_string = time.strftime('%Y%m%d_%H%M%S')
        data_path = self.opt['save_folder']
        if convo_finished:
            filename = os.path.join(
                data_path,
                '{}_{}_{}.json'.format(time_string, np.random.randint(0, 1000),
                                       self.task_type),
            )
        else:
            filename = os.path.join(
                data_path,
                '{}_{}_{}_incomplete.json'.format(time_string,
                                                  np.random.randint(0, 1000),
                                                  self.task_type),
            )
        with open(os.path.join(filename), 'w+') as f_json:
            data = {
                'personas':
                self.personas,
                'context_dataset':
                self.context_info.get('context_dataset'),
                'person1_seed_utterance':
                self.context_info.get('person1_seed_utterance'),
                'person2_seed_utterance':
                self.context_info.get('person2_seed_utterance'),
                'additional_context':
                self.context_info.get('additional_context'),
                'dialog':
                self.dialog,
                'workers': [ag.worker_id for ag in self.agents],
                'bad_workers':
                bad_workers,
                'acceptability_violations': (violations_agent_0, ),
                'hit_ids': [ag.hit_id for ag in self.agents],
                'assignment_ids': [ag.assignment_id for ag in self.agents],
                'task_description': {
                    'annotations_config': self.opt['annotations_config'],
                    'model_nickname': self.agents[1].worker_id,
                    'model_file':
                    self.agents[1].model_agent.opt.get('model_file'),
                    'model_opt': self.agents[1].model_agent.opt,
                },
            }
            if self.check_acceptability:
                data['acceptability_violations'] = (violations_agent_0, )
                # Make a tuple for compatibility with a human/human conversation in
                # which we check both sides for acceptability
            data_str = json.dumps(data)
            f_json.write(data_str)
        print(
            f'{self.__class__.__name__}:{self.tag}: Data successfully saved at '
            f'{filename} for model: {self.agents[1].worker_id}.')
        if self.check_acceptability:
            print(f'Acceptability violations for agent 0: '
                  f'{violations_agent_0}')
            return (self.agents[1].worker_id, violations_agent_0 != '',
                    convo_finished)
        else:
            return self.agents[1].worker_id, False, convo_finished

    def check_timeout(self, act):
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg[
                'text'] = 'HIT has timed out. Please click the "Done with this HIT" button below to exit this HIT. No rejections.'
            for ag in self.agents:
                if ag.id != act['id']:
                    if ag.id != AGENT_1:
                        ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        else:
            return False

    def review_work(self):
        pass
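
# For reference (hypothetical values, not part of the original example): the
# context_info dict consumed by TurnAnnotationsChatWorld above would carry at least the
# keys referenced in __init__(), parley(), and _get_persona_utterance():
example_context_info = {
    'persona_1_strings': ['i love gardening.', 'i have three cats.'],
    'persona_2_strings': ['i am a chef.', 'i run marathons on weekends.'],
    'context_dataset': 'wizard_of_wikipedia',
    'additional_context': 'Gardening',
    'person1_seed_utterance': 'Do you spend much time outdoors?',
    'person2_seed_utterance': 'Mostly in the kitchen, but I garden when I can.',
}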