Beispiel #1
0
    def _setup_data(self, opt):
        """
        Build single-turn gender-classification examples from ConvAI2 episodes.

        Each original episode is read in full, the persona lines are removed
        from its first turn, and every turn is emitted as up to two
        ``episode_done=True`` examples: one labelled ``SELF:<gender>`` and one
        labelled ``PARTNER:<gender>`` (controlled by ``self.labels_to_use``).
        The examples are stored in ``self.data`` and per-class totals are
        printed.

        :param opt:
            options dict; ``opt['datatype']`` is inspected here and the whole
            dict is handed to ``OrigConvai2Teacher``.
        """
        counts = {
            speaker: {gend_utils.UNKNOWN: 0, gend_utils.FEM: 0, gend_utils.MASC: 0}
            for speaker in ('partner', 'self')
        }

        dt = opt['datatype'].split(':')[0]
        if dt == 'test':
            warn_once('No test set; switching to valid')
            # NOTE(review): `dt` is not read again in this method and the
            # teacher below is built from the unmodified `opt` -- confirm the
            # switch to valid actually takes effect somewhere.
            dt = 'valid'

        # build data
        print('[ Building data ... ]')
        new_eps = []
        orig_teacher = OrigConvai2Teacher(opt)
        total_exs = orig_teacher.num_examples()
        num_exs = 0
        while num_exs < total_exs:
            current_episode = []
            episode_done = False

            # Pull turns from the original teacher until the episode ends.
            while not episode_done:
                # TODO: eventually all teachers should return Messages, so
                # we should assert this
                action = Message(orig_teacher.act())
                current_episode.append(action)
                episode_done = action.get('episode_done', False)
                num_exs += 1

            # Split the first turn into persona lines and actual dialogue text.
            first_ex = current_episode[0]
            first_ex_text = []
            partner_persona = []
            your_persona = []
            for line in first_ex['text'].split('\n'):
                # NOTE: we flip "your" and "partner" here since we are taking the 'text'
                # field instead of the 'label'
                if 'partner\'s persona: ' in line:
                    your_persona.append(line.split('partner\'s persona: ')[1])
                elif 'your persona: ' in line:
                    partner_persona.append(line.split('your persona: ')[1])
                else:
                    first_ex_text.append(line)

            your, your_prob, partner, partner_prob = self.get_genders(
                your_persona, partner_persona)

            for i, ex in enumerate(current_episode):
                # NOTE(review): the `is not None` checks below suggest
                # get_genders may return None, which would raise KeyError
                # here -- confirm against get_genders.
                counts['self'][your] += 1
                counts['partner'][partner] += 1
                # Only the first turn carries persona lines that were removed.
                if i == 0:
                    text = '\n'.join(first_ex_text)
                else:
                    text = ex['text']
                new_ex = {
                    'text': text,
                    'episode_done': True,
                    'your_persona': '\n'.join(your_persona),
                    'partner_persona': '\n'.join(partner_persona),
                    'id': 'ConvAI2 Gender',
                }
                if not self.use_probably:
                    new_ex['partner_prob'] = partner_prob
                    new_ex['your_prob'] = your_prob

                if your is not None and self.labels_to_use != 'partner':
                    # Get the your task
                    your_ex = deepcopy(new_ex)
                    your_ex['labels'] = [f'SELF:{your}']
                    your_ex['class_type'] = 'self'
                    new_eps.append(your_ex)

                if partner is not None and self.labels_to_use != 'self':
                    # Get the partner task
                    partner_ex = deepcopy(new_ex)
                    partner_ex['labels'] = [f'PARTNER:{partner}']
                    partner_ex['class_type'] = 'partner'
                    new_eps.append(partner_ex)

        if self.labels_to_use == 'all' and self.add_unknown_classes:
            # load about data
            all_about_data = gend_utils.get_inferred_about_data(
                self.opt['task'], self.opt)
            sample_rate = self.opt['unknown_temp']
            if sample_rate < 1.0:
                # Keep only a `sample_rate` fraction of the extra examples.
                to_samp = int(sample_rate * len(all_about_data))
                sampled = random.sample(all_about_data, to_samp)
                new_eps += sampled
            else:
                new_eps += all_about_data

        if self.is_train:
            random.shuffle(new_eps)

        self.data = new_eps
        print(f'Missing cnt: {self.missing_cnt} / {len(self.data) * 2}')
        for x in ['self', 'partner']:
            print(f'Totals for {x}:')
            subtot = sum(counts[x].values())
            for k, v in counts[x].items():
                # An empty dataset leaves every count at 0; report a 0.0
                # fraction instead of raising ZeroDivisionError.
                frac = v / subtot if subtot else 0.0
                print(f'\t{k}: {v} ({frac})')
Beispiel #2
0
 def act(self):
     """Return a Message holding this agent's ID and its ``searchInput`` as text."""
     return Message({'id': self.getID(), 'text': self.searchInput})
Beispiel #3
0
 def setup_data(self, datafile):
     """
     Yield three identical 3-turn episodes of synthetic data.

     Each item is ``(Message, new_episode)`` where the text is the turn
     number (1..3) as a string, the label is twice that number as a string,
     and ``new_episode`` is True only on the first turn. ``datafile`` is not
     read.
     """
     for _episode in range(3):
         for turn in (1, 2, 3):
             starts_episode = turn == 1
             example = Message({'text': str(turn), 'label': str(turn * 2)})
             yield example, starts_episode
        def test_base_task(self):
            """
            End-to-end check of the model image chat task.

            Writes a pickled image context into a temp dir, configures the
            blueprint through override strings, runs one test agent through
            the task, and compares both the agent state and the saved chat
            data file against the JSON fixtures under ``expected_states``.
            """

            with testing_utils.tempdir() as tmpdir:

                # Paths
                expected_states_folder = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    'expected_states')
                expected_chat_data_path = os.path.join(
                    expected_states_folder, 'final_chat_data__image_chat.json')
                expected_state_path = os.path.join(expected_states_folder,
                                                   'state__image_chat.json')
                parlai_data_folder = os.path.join(tmpdir, 'parlai_data')
                chat_data_folder = os.path.join(tmpdir, 'final_chat_data')
                sample_image_path = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    'test_image_stack',
                    'sample_image.jpg',
                )
                image_context_path = os.path.join(tmpdir, 'image_contexts')
                stack_folder = os.path.join(tmpdir, 'image_stack')

                # Save image context: instead of downloading images, just save a pickle
                # file with all of the image act
                image_context = [{
                    'image_act':
                    Message({
                        'text':
                        'Obsessive',
                        'image_id':
                        '2923e28b6f588aff2d469ab2cccfac57',
                        'episode_done':
                        False,
                        'label_candidates': [
                            "I must learn that bird's name!",
                            "My, aren't you a pretty bird?",
                        ],
                        'image':
                        Image.open(sample_image_path),
                        'id':
                        'image_chat',
                        'eval_labels': ["I must learn that bird's name!"],
                    })
                }]
                with open(image_context_path, 'wb') as f:
                    pickle.dump(image_context, f)

                # Download the Transresnet Multimodal model
                download_transresnet(parlai_data_folder)

                # Set up the config and database
                num_convos = 1
                args = ModelImageChatBlueprintArgs()
                # First list: selected blueprint defaults forwarded as
                # "+key=value" overrides; second list: values specific to
                # this test (temp-dir paths, a single conversation).
                overrides = [
                    f'+mephisto.blueprint.{key}={val}'
                    for key, val in args.__dict__.items() if key in [
                        'evals_per_image_model_combo',
                        'max_resp_time',
                        'override_opt',
                        'random_seed',
                        'world_file',
                    ]
                ] + [
                    'mephisto.blueprint.annotations_config_path=""',
                    f'mephisto.blueprint.chat_data_folder={chat_data_folder}',
                    f'+mephisto.blueprint.image_context_path={image_context_path}',
                    '+mephisto.blueprint.left_pane_text_path=${task_dir}/task_config/left_pane_text.html',
                    '+mephisto.blueprint.max_concurrent_responses=1',
                    '+mephisto.blueprint.model_opt_path=${task_dir}/task_config/image_model_opts.yaml',
                    f'+mephisto.blueprint.num_conversations={num_convos:d}',
                    f'+mephisto.blueprint.stack_folder={stack_folder}',
                    '+mephisto.blueprint.task_description_file=${task_dir}/task_config/task_description.html',
                    'mephisto.blueprint.task_model_parallel=False',
                ]
                # TODO: remove all of these params once Hydra 1.1 is released with
                #  support for recursive defaults
                self._set_up_config(
                    blueprint_type=IMAGE_CHAT_BLUEPRINT_TYPE,
                    task_directory=TASK_DIRECTORY,
                    overrides=overrides,
                )

                # Set up the operator and server
                shared_state = SharedModelImageChatTaskState(
                    world_module=world_module)
                self._set_up_server(shared_state=shared_state)

                # Don't require onboarding for this test agent
                self._get_channel_info(
                ).job.task_runner.task_run.get_blueprint().use_onboarding = (
                    False)
                # Check that the agent states are as they should be
                with open(expected_state_path) as f:
                    expected_state = json.load(f)
                self._test_agent_states(
                    num_agents=1,
                    agent_display_ids=AGENT_DISPLAY_IDS,
                    agent_messages=AGENT_MESSAGES,
                    form_messages=FORM_MESSAGES,
                    form_task_data=FORM_TASK_DATA,
                    expected_states=(expected_state, ),
                )

                # Check that the contents of the chat data file are as expected
                with open(expected_chat_data_path) as f:
                    expected_chat_data = json.load(f)
                # Take the first file in the chat data folder that matches
                # the '*_*_*_sandbox.json' pattern.
                results_path = list(
                    glob.glob(
                        os.path.join(chat_data_folder,
                                     '*_*_*_sandbox.json')))[0]
                with open(results_path) as f:
                    actual_chat_data = json.load(f)
                self._check_final_chat_data(actual_value=actual_chat_data,
                                            expected_value=expected_chat_data)
Beispiel #5
0
    def update_parley(self, k, Dk_tr, Dk_val):
        """
        Perform one meta-update of the student from paired train/val dialogs.

        The student observes every turn of the training and validation
        dialogs, then:

        1. computes the loss on the training batch, backpropagates and calls
           ``update_params``;
        2. computes the loss on the validation batch with those updated
           parameters and backpropagates;
        3. restores the parameters saved before step 1 and calls
           ``update_params`` again.

        NOTE(review): this looks like a first-order meta-learning step, but
        whether gradients are cleared between steps 1 and 2 depends on
        ``student.update_params`` / ``student.backward`` -- confirm.

        :param k:
            not used by this method.
        :param Dk_tr:
            sequence of training dialogs, each passed to
            ``teacher.get_episode``.
        :param Dk_val:
            sequence of validation dialogs, indexed in parallel with
            ``Dk_tr``.

        :return:
            one placeholder ``Message`` (id + ``episode_done=False``) per
            training observation.
        """
        teacher, student = self.agents
        observations_tr = []
        observations_val = []

        # TODO: batchify all the turns from all the dialogs in the meta-batch
        # together.
        for i, dialog_tr in enumerate(Dk_tr):
            for a in teacher.get_episode(dialog_tr):
                obs = student.observe(a)
                observations_tr.append(obs)
                student.self_observe(obs)

            for a in teacher.get_episode(Dk_val[i]):
                obs = student.observe(a)
                observations_val.append(obs)
                student.self_observe(obs)

        batch_reply = [
            Message({
                'id': self.getID(),
                'episode_done': False
            }) for _ in observations_tr
        ]

        # Always train; the label-availability check is intentionally disabled.
        self.is_training = True

        # create a batch from the vectors
        batch = student.batchify(observations_tr)
        batch_val = student.batchify(observations_val)

        student._init_cuda_buffer(self.opt['batchsize'], student.label_truncate
                                  or 256)
        student.model.train()
        student.zero_grad()

        # Snapshot the parameters so they can be restored after the inner step.
        init_model_params = copy.deepcopy(student.model.state_dict())

        student._control_local_metrics(
            disabled=True)  # turn off local metric computation

        # Step 1: inner update on the training batch.
        loss = student.compute_loss(batch)
        student.backward(loss)
        student.update_params()

        # Step 2: gradients of the validation loss at the updated parameters.
        loss = student.compute_loss(batch_val)
        student.backward(loss)

        # Step 3: restore the snapshot, then update from the current gradients.
        student.model.load_state_dict(init_model_params)
        student.update_params()

        return batch_reply
Beispiel #6
0
    def parley(self):
        """
        Run one turn of the evaluation chat.

        On the first turn every agent is sent its persona and start
        instruction. Once ``turn_idx`` reaches ``n_turn + 1`` every agent is
        told it may end the chat. The eval agent then speaks, its message is
        relayed to the model (or to the other human agent), and the reply is
        relayed back to the eval agent. When ``other_first`` is set, the
        other side also opens the conversation on turn 1.

        Returns early (without a value) on a timeout/disconnect or when the
        eval agent's act has ``episode_done`` set.
        """
        self.turn_idx += 1
        print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))
        """If at first turn, we need to give each agent their persona"""
        if self.turn_idx == 1:
            for idx, agent in enumerate(self.agents):
                # Persona sentences are sent as blue, bold HTML.
                persona_text = ''
                for s in self.personas[idx]:
                    persona_text += ('<b><span style="color:blue">'
                                     '{}\n</span></b>'.format(s.strip()))
                control_msg = self.get_control_msg()
                control_msg['persona_text'] = persona_text
                control_msg['text'] = self.get_instruction(tag='start',
                                                           agent_id=agent.id)
                # TODO: check that get instruction actually exists?
                agent.observe(validate(control_msg))
                if idx == 0:
                    time.sleep(3)
        """If we get to the min turns, inform turker that they can end if they
        want.
        """
        if self.turn_idx == self.n_turn + 1:
            for idx, agent in enumerate(self.agents):
                control_msg = self.get_control_msg()
                # NOTE(review): this passes the enumerate index positionally,
                # while the start instruction above passes
                # agent_id=agent.id -- confirm which one get_instruction
                # expects.
                control_msg['text'] = self.get_instruction(
                    idx, tag='exceed_min_turns')
                control_msg['exceed_min_turns'] = True
                agent.observe(validate(control_msg))
        """Otherwise, we proceed accordingly."""
        # Other agent first
        if self.other_first and self.turn_idx == 1:
            if self.model_agent is not None:
                # Model must observe its persona
                persona_act = {
                    'text': '\n'.join([self.model_persona_text,
                                       '__SILENCE__']),
                    'episode_done': False,
                }
                self.model_agent.observe(persona_act)
                self.bot_seen_persona = True
                model_act = copy.deepcopy(self.model_agent.act())
                model_act.force_set('text',
                                    self.format_model_reply(model_act['text']))
                model_act.force_set('id', 'PERSON_2')
                # Dialog entries are (speaker, text); 1 is used for the
                # model/other agent and 0 for the eval agent.
                self.dialog.append((1, model_act.get('text')))
                _random_delay()
                self.eval_agent.observe(_strip_tensors(model_act))
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return
                else:
                    self.dialog.append((1, act.get('text')))
                    act = copy.deepcopy(act)
                    act.force_set('text', self.format_model_reply(act['text']))
                    self.eval_agent.observe(act)

        # Eval agent turn
        act = Message(self.get_human_agent_act(self.eval_agent))
        timeout = self.check_timeout(act)
        if timeout:
            # Only a human partner needs to be told about the disconnect.
            if self.model_agent is None:
                control_msg = self.get_control_msg()
                control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                self.other_agent.observe(validate(control_msg))
            return

        if act['episode_done']:
            # The chat is only wrapped up once the minimum number of turns
            # has been reached; either way this turn ends here.
            if self.turn_idx >= self.n_turn:
                # Pair consecutive dialog entries into one string per
                # exchange. When the other side spoke first, its opening
                # line forms an exchange of its own.
                if not self.other_first:
                    self.dialog_list = [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(0, len(self.dialog), 2)
                    ]
                else:
                    self.dialog_list = [' \n' + self.dialog[0][1]] + [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(1,
                                       len(self.dialog) - 1, 2)
                    ]
                self.parallel_eval_mode()

                self.chat_done = True
                for ag in self.agents:
                    control_msg = self.get_control_msg()
                    control_msg['text'] = CHAT_ENDED_MSG
                    ag.observe(validate(control_msg))
            return

        self.dialog.append((0, act['text']))

        if not self.bot_seen_persona and self.model_agent is not None:
            # Add persona for model to observe
            act.force_set('text',
                          '\n'.join([self.model_persona_text, act['text']]))
            self.bot_seen_persona = True
        if self.model_agent is not None:
            self.model_agent.observe(act)
        else:
            act = copy.deepcopy(act)
            act.force_set('text', self.format_model_reply(act['text']))
            self.other_agent.observe(act)

        # Model_agent turn
        if not self.other_first or self.turn_idx < self.n_turn:
            if self.model_agent is not None:
                _random_delay()
                act = _strip_tensors(copy.deepcopy(self.model_agent.act()))
                act.force_set('text', self.format_model_reply(act['text']))
                act.force_set('id', 'PERSON_2')
                # NOTE: your model may or may not need to observe itself here
                # If it does, call model_observes_itself or some other specialized
                # function
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return

            self.dialog.append((1, act.get('text')))
            act = copy.deepcopy(act)
            # NOTE(review): on the model path the text was already passed
            # through format_model_reply above, so it is formatted twice
            # here -- confirm format_model_reply is idempotent.
            act.force_set('text', self.format_model_reply(act['text']))
            self.eval_agent.observe(act)
Beispiel #7
0
    def _build_rpa_episodes(self, ep: List[Message]) -> List[Message]:
        """
        Construct new episodes from old, LIGHT ones.

        Slides a context window of ``self.num_utterances`` utterances over the
        exploded episode and, for every window position, emits one
        single-message episode per utterance from the end of the window
        onwards (or one per label prefix when ``self.left_to_right`` is set).

        :param ep:
            episode to explode and build into new eps

        :return episodes:
            return a list of episodes after enumerating over possible labels;
            each element is a one-item list holding a ``Message``.
        """
        episodes = []
        context, characters, utterances = self._explode_episode(
            ep, self.exclude_from_context, self.use_speech_prefix)
        candidates = (self.candidates if self.inline_candidate_type == 'all'
                      else list(characters.values()))

        # determine initial start and end indices
        num_utts = self.num_utterances
        if num_utts < 0:
            # A negative setting means "use all but the last utterance".
            num_utts = len(utterances) - 1

        start_idx, end_idx = (0, num_utts - 1)

        # Enumerate over all possible start, end positions
        while end_idx < len(utterances) - 1:
            # Step 0: (maybe) annotate the prior utterances of dialogue
            prev_utts = [
                maybe_annotate(
                    *utt[:-1],
                    self.annotate_speaker,
                    self.speaker_separator,
                    self.speaker_annotation_position,
                ) for utt in utterances[start_idx:end_idx]
            ]
            if self.include_light_context:
                prev_utts = [context] + prev_utts

            # NOTE(review): when num_utterances == 1 and include_light_context
            # is False, prev_utts is empty and the `prev_utts[-1]` accesses
            # below raise IndexError if delimit_label_control is False --
            # confirm that combination is never configured.

            # Step 1: enumerate over each successive utterance
            for speaker, label, listener in utterances[end_idx:]:
                # Step 2: determine the label control / task type
                if self.classifier_label_type == 'character':
                    speaker_label = (speaker if self.speaker_label_type
                                     == 'speaker' else listener)
                    label_control = (WHO_AM_I if self.speaker_label_type
                                     == 'speaker' else WHO_ARE_YOU)
                else:
                    speaker_label = (SELF if speaker
                                     == characters['_self_name'] else PARTNER)
                    label_control = WHO_IS_THIS
                # Step 3: Determine what label candidates to use
                if (self.num_train_inline_candidates > 0
                        and DatatypeHelper.is_training(self.datatype)
                        and self.inline_candidate_type == 'all'):
                    # Draw distractors only from candidates that are neither
                    # the speaker nor the listener, so adding those two back
                    # afterwards can never create duplicates and the draw
                    # always terminates.
                    distractor_pool = [
                        cand for cand in candidates
                        if cand not in (speaker, listener)
                    ]
                    num_distractors = min(
                        self.num_train_inline_candidates - 2,
                        len(distractor_pool))
                    label_cands = random.sample(distractor_pool,
                                                num_distractors)
                    label_cands += [speaker, listener]
                    random.shuffle(label_cands)
                else:
                    label_cands = candidates

                # Step 4: Build the Message
                if self.left_to_right:
                    # One example per growing prefix of the label's words.
                    label_words = label.split(' ')
                    for i in range(1, len(label_words) + 1):
                        if self.delimit_label_control:
                            text = self.delimiter.join(
                                prev_utts +
                                [label_control, ' '.join(label_words[:i])])
                        else:
                            text = self.delimiter.join(
                                prev_utts[:-1] +
                                [f"{prev_utts[-1]} {label_control}"] +
                                [' '.join(label_words[:i])])
                        message = Message({
                            'text': text,
                            'labels': [speaker_label],
                            'label_candidates': label_cands,
                            'episode_done': True,
                        })
                        episodes.append([message])
                else:
                    if self.delimit_label_control:
                        text = self.delimiter.join(prev_utts +
                                                   [label_control, label])
                    else:
                        text = self.delimiter.join(
                            prev_utts[:-1] +
                            [f"{prev_utts[-1]} {label_control}"] + [label])
                    message = Message({
                        'text': text,
                        'labels': [speaker_label],
                        'label_candidates': label_cands,
                        'episode_done': True,
                    })
                    episodes.append([message])

            if start_idx == end_idx:
                # edge case where num_utterances == 1
                break
            else:
                start_idx += 1
                end_idx += 1

        return episodes
Beispiel #8
0
 def act(self):
     """Reply with this agent's fixed response; the episode is never marked done."""
     reply = {'id': self.getID(), 'text': self.fixed_response, 'episode_done': False}
     return Message(reply)
Beispiel #9
0
    def get(self, episode_idx, entry_idx=0):
        """
        Build the example for one wizard turn of an episode.

        :param episode_idx:
            index into ``self.data``.
        :param entry_idx:
            index of the wizard turn within the episode; mapped to a position
            in ``d['dialog']`` depending on who spoke first.

        :return:
            a ``Message`` with text, labels, chosen topic, label candidates
            and, depending on the teacher's flags, the knowledge string and
            the checked title/sentence.
        """
        d = self.data[episode_idx]
        episode_done = entry_idx == (self.len_episode(episode_idx) - 1)

        # Wizard turns sit at even dialog positions when the wizard opens the
        # conversation and at odd positions otherwise.
        wizard_first = 'Wizard' in d['dialog'][0]['speaker']
        idx = entry_idx * 2 if wizard_first else (entry_idx * 2) + 1

        # first, get knowledge
        apprentice_ret_passages = []
        wizard_ret_passages = []

        if not wizard_first or idx != 0:
            apprentice_entry = d['dialog'][idx - 1]
            apprentice_ret_passages = apprentice_entry['retrieved_passages']
        if idx - 2 >= 0:
            wizard_prev_entry = d['dialog'][idx - 2]
            wizard_ret_passages = wizard_prev_entry['retrieved_passages']

        chosen_topic = d.get('chosen_topic', '')
        chosen_topic_passages = d['chosen_topic_passage']

        # The first passage seen for a title wins; later duplicates are ignored.
        knowledge_dict = {chosen_topic: chosen_topic_passages}
        for ret_passes in [apprentice_ret_passages, wizard_ret_passages]:
            for passage in ret_passes:
                for k, v in passage.items():
                    knowledge_dict.setdefault(k, v)

        # then, get text
        if idx == 0:
            # first message - only have the chosen topic
            text = chosen_topic
        elif idx == 1:
            # first response - only have the first message
            text = (
                f"{chosen_topic}{self.chosen_topic_delimiter}{apprentice_entry['text']}"
            )
        else:
            text = ''
            if self.label_type == 'chosen_sent':
                # if chosen_sent, add wizard response to dialog history
                text += '{}\n'.format(wizard_prev_entry['text'])
            text += apprentice_entry['text']

        # next, get label
        wizard_entry = d['dialog'][idx]
        if self.label_type == 'response':
            labels = [wizard_entry['text']]
        else:
            title, sentence = _get_chosen_title_and_sent(
                wizard_entry, knowledge_dict)
            if self.knowledge_separator and title != TOKEN_NOCHOSEN:
                labels = ['{} {} {}'.format(title, TOKEN_KNOWLEDGE, sentence)]
            else:
                labels = ['{} {}'.format(title, sentence)]

        # finally, get label_candidates
        knowledge_cands = []
        for topic, passage in knowledge_dict.items():
            for p in passage:
                if self.knowledge_separator:
                    cand = '{} {} {}'.format(topic, TOKEN_KNOWLEDGE, p)
                else:
                    cand = '{} {}'.format(topic, p)
                knowledge_cands.append(cand)
        # One knowledge candidate per line, each line newline-terminated.
        knowledge_str = ''.join(f'{cand}\n' for cand in knowledge_cands)

        if self.label_type == 'response':
            # Response mode does not rank knowledge sentences: no candidates
            # while training, otherwise the entry's own candidate responses.
            if 'train' in self.datatype:
                label_cands = []
            else:
                label_cands = wizard_entry.get('candidate_responses', [])
        else:
            label_cands = ['{} {}'.format(TOKEN_NOCHOSEN, TOKEN_NOCHOSEN)]
            label_cands += knowledge_cands

        action = Message({
            'id': 'WizardDialogKnowledgeTeacher',
            'text': text,
            'labels': labels,
            'chosen_topic': chosen_topic,
            'episode_done': episode_done,
            'label_candidates': label_cands,
        })
        if self.include_knowledge:
            action['knowledge'] = knowledge_str
        if self.include_checked_sentence:
            title, sentence = _get_chosen_title_and_sent(
                wizard_entry, knowledge_dict)
            action['title'] = title
            action['checked_sentence'] = sentence
        return action
Beispiel #10
0
    def test_custom_eval(self):
        """
        Test whether custom evaluation works.

        Covers the retrieval metrics (``passage_r@k`` / ``title_r@k``) for a
        correct top-1 answer, an answer that is only within the top 5, and a
        fully wrong answer, then ``knowledge_f1`` for a high- and a low-F1
        response with ``label_type`` switched to ``'response'``.
        """
        with testing_utils.capture_output():
            parser = setup_args()
            opt = parser.parse_args(
                [
                    '--task',
                    'wizard_of_wikipedia',
                    '--datatype',
                    'valid',
                    '--label-type',
                    'chosen_sent',
                ]
            )
            teacher = create_task_agent_from_taskname(opt)[0]

        title = 'Gardening'
        # Four single-character distractors: ['f', 'o', 'u', 'r'].
        cands = list('four')

        text = "Gardening\nI like Gardening, even when I've only been doing it for a short time."
        response = 'I live on a farm, we garden all year long, it is very relaxing.'
        checked_sent = (
            'Gardening is considered by many people to be a relaxing activity.'
        )
        checked_sent_label = f'{title}{TOKEN_KNOWLEDGE}{checked_sent}'

        retrieval_metric_keys = ['passage_r@1', 'passage_r@5', 'title_r@1', 'title_r@5']

        chosen_sent_teacher_action = Message(
            {
                'text': text,
                'labels': [checked_sent_label],
                'title': [title],
                'checked_sentence': [checked_sent],
            }
        )
        correct_chosen_sent_response = Message(
            {
                'text': checked_sent_label,
                'title_candidates': [title] + cands,
                'text_candidates': [checked_sent_label] + cands,
            }
        )
        # The gold title/sentence is ranked fifth, behind the distractors.
        top5_chosen_sent_response = Message(
            {
                'text': f'hello{TOKEN_KNOWLEDGE}goodbye',
                'title_candidates': cands + [title],
                'text_candidates': cands + [checked_sent_label],
            }
        )
        incorrect_chosen_sent_response = Message(
            {
                'text': f'hello{TOKEN_KNOWLEDGE}goodbye',
                'title_candidates': cands,
                'text_candidates': cands,
            }
        )

        response_teacher_action = Message(
            {'text': text, 'labels': [response], 'checked_sentence': checked_sent}
        )
        high_f1_response = Message({'text': checked_sent})
        low_f1_response = Message({'text': 'incorrect'})

        # 1) Test with correct top sentence
        teacher.reset_metrics()
        teacher.custom_evaluation(
            chosen_sent_teacher_action,
            [checked_sent_label],
            correct_chosen_sent_response,
        )
        report = teacher.report()
        for k in retrieval_metric_keys:
            assert k in report
            assert report[k] == AverageMetric(1)

        # 2) Test with top sentence in top 5
        teacher.reset_metrics()
        teacher.custom_evaluation(
            chosen_sent_teacher_action, [checked_sent_label], top5_chosen_sent_response
        )
        report = teacher.report()
        for k in retrieval_metric_keys:
            assert k in report
            # Pick the expected value first: `assert a == b if c else d`
            # parses as `assert ((a == b) if c else d)`, which would leave
            # the @1 metrics unchecked.
            expected = AverageMetric(1) if '5' in k else AverageMetric(0)
            assert report[k] == expected

        # 3) Test with no top sentences
        teacher.reset_metrics()
        teacher.custom_evaluation(
            chosen_sent_teacher_action,
            [checked_sent_label],
            incorrect_chosen_sent_response,
        )
        report = teacher.report()
        for k in retrieval_metric_keys:
            assert k in report
            assert report[k] == AverageMetric(0)

        # 4) Test knowledge f1 with high f1
        teacher.label_type = 'response'
        teacher.reset_metrics()
        teacher.custom_evaluation(response_teacher_action, [response], high_f1_response)
        report = teacher.report()
        assert 'knowledge_f1' in report
        assert report['knowledge_f1'] == F1Metric(1)

        # 5) Test knowledge f1 with low f1
        teacher.reset_metrics()
        teacher.custom_evaluation(response_teacher_action, [response], low_f1_response)
        report = teacher.report()
        assert 'knowledge_f1' in report
        assert report['knowledge_f1'] == F1Metric(0)
Beispiel #11
0
    def _setup_data(self, opt: Opt) -> List[List[Message]]:
        """
        Flatten and classify the normal task data.

        Save/load where applicable.

        :param opt:
            options dict.

        :return:
            the flattened episodes: either whatever ``self.load_data`` returns
            (when it is not None), or the episodes freshly built from the
            original task's teacher, which are also passed to
            ``self.save_data``.
        """
        # create save directory, if it does not already exist
        self.original_task_name = ':'.join(opt['task'].split(':')[2:])
        self.save_dir = self._get_save_path(opt['datapath'],
                                            str(datetime.datetime.today()))
        os.makedirs(self.save_dir, exist_ok=True)

        fname = f"{opt['datatype'].split(':')[0]}.json"
        self.save_path = os.path.join(self.save_dir, fname)

        data = self.load_data(opt, fname)
        if data is not None:
            # successfully load data
            return data

        # build the original teacher
        original_task_module = get_original_task_module(opt)
        teacher_opt = deepcopy(opt)
        teacher_opt['task'] = self.original_task_name
        teacher = original_task_module(teacher_opt)

        total_exs = teacher.num_examples()
        if self.opt['max_examples'] > 0:
            total_exs = min(self.opt['max_examples'], total_exs)

        all_episodes = []
        num_exs = 0
        # Use the progress bar as a context manager so that it is always
        # closed, even when the teacher or the flattening step raises.
        with tqdm(total=total_exs,
                  unit='ex',
                  unit_scale=True,
                  desc='Building flattened data') as progress_bar:
            while num_exs < total_exs:
                current_episode = []
                episode_done = False

                # Pull acts from the teacher until the episode is complete;
                # an episode is never cut short, so num_exs may overshoot
                # total_exs.
                while not episode_done:
                    action = Message(teacher.act())
                    current_episode.append(action)
                    episode_done = action.get('episode_done', False)
                    num_exs += 1

                # flatten the episode into 1-example episodes with context
                flattened_ep = flatten_and_classify(
                    current_episode,
                    opt['flatten_max_context_length'],
                    include_labels=opt['flatten_include_labels'],
                    delimiter=opt['flatten_delimiter'],
                    word_lists=self.word_lists,
                )
                all_episodes += flattened_ep

                progress_bar.update(len(flattened_ep))

        # save data for future use
        self.save_data(all_episodes)

        return all_episodes
Beispiel #12
0
 def _setup_data_with_context(self):
     """
     Yield the four example messages in order.

     EXAMPLE1 and EXAMPLE3 are passed through ``self._add_context`` before
     being wrapped in a Message; EXAMPLE2 and EXAMPLE4 are wrapped as-is.
     """
     plan = (
         (EXAMPLE1, True),
         (EXAMPLE2, False),
         (EXAMPLE3, True),
         (EXAMPLE4, False),
     )
     for example, with_context in plan:
         payload = self._add_context(example) if with_context else example
         yield Message(payload)
Beispiel #13
0
 def _setup_data(self):
     """
     Yield EXAMPLE1 through EXAMPLE4, in order, each wrapped in a Message.
     """
     for example in (EXAMPLE1, EXAMPLE2, EXAMPLE3, EXAMPLE4):
         yield Message(example)
Beispiel #14
0
 def get(self, episode_idx, entry_idx=0):
     """
     Return the stored episode at ``episode_idx`` as a Message.

     The stored entry is updated in place with the label candidates for its
     ``class_type``, a fixed id and ``episode_done=True``. ``entry_idx`` is
     accepted but not used.
     """
     episode = self.data[episode_idx]
     candidates = self.label_candidates[episode['class_type']]
     # Plain item assignment, one key at a time, in the original order.
     for key, value in (
         ('label_candidates', candidates),
         ('id', 'Yelp Gender'),
         ('episode_done', True),
     ):
         episode[key] = value
     return Message(episode)
Beispiel #15
0
 def _to_tuple(self, msg: Message) -> Tuple:
     """
     Convert a message into a tuple of its text, labels and eval_labels.

     Each field is looked up with ``msg.get`` and passed through
     ``self._val``.
     """
     fields = []
     for key in ('text', 'labels', 'eval_labels'):
         fields.append(self._val(msg.get(key)))
     return tuple(fields)
Beispiel #16
0
    def parley(self):
        """
        Run one step of the chat between ``self.agent`` and ``self.bot``.

        On the first call (``task_turn_idx == 0``) only the initial turn is run.
        On later calls each of the two agents acts in order (``self.agent``
        first, then ``self.bot``); every utterance is appended to
        ``self.dialog`` and shown to the other agent. If an act carries a
        ``final_rating`` in its ``task_data``, the chat is marked done, the
        final chat data is written to a JSON file under
        ``self.opt['chat_data_folder']`` and the method returns immediately.
        """
        print(
            f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:
            self._run_initial_turn()
            self.task_turn_idx += 1
            return
        """Otherwise, we proceed accordingly"""
        print(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}'
        )
        # One slot per participant: index 0 is self.agent, index 1 is self.bot.
        acts = [None, None]
        for idx, agent in enumerate([self.agent, self.bot]):
            if not self.chat_done:
                acts[idx] = agent.act(timeout=self.max_resp_time)
                # Normalize the raw act and convert it to a JSON-safe payload.
                acts[idx] = Message(Compatibility.maybe_fix_act(
                    acts[idx])).json_safe_payload()
                print(
                    f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
                )

            # A final rating in the act's task data signals the end of the chat.
            if acts[idx].get('task_data', {}).get('final_rating') is not None:

                self.chat_done = True
                # agent ends chat after exceeding minimum number of turns

                # Human has just responded. Any problem data received now will be
                # regarding the bot's prior utterance
                turn_idx = -1
                # Attach the problem data and final rating to the last utterance, since
                # the human hasn't said anything since then
                p = acts[idx]['task_data'].get(
                    'problem_data_for_prior_message')
                if p is not None:
                    self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)
                self.dialog[turn_idx]['final_rating'] = acts[idx]['task_data'][
                    'final_rating']

                # Save the final chat data
                # Files are grouped in one sub-folder per day; the file name
                # combines a timestamp, a random integer and the task type.
                date_folder = time.strftime('%Y_%m_%d')
                time_string = time.strftime('%Y%m%d_%H%M%S')
                chat_data_subfolder = os.path.join(
                    self.opt['chat_data_folder'], date_folder)
                os.makedirs(chat_data_subfolder, exist_ok=True)
                chat_data_path = os.path.join(
                    chat_data_subfolder,
                    f'{time_string}_{np.random.randint(0, 1000)}_{self.task_type}.json',
                )
                final_chat_data = self.get_final_chat_data()
                self.agent.mephisto_agent.state.messages.append(
                    {'final_chat_data': final_chat_data})
                # Append the chat data directly to the agent state's message list in
                # order to prevent the worker from seeing a new text response in the UI
                with open(chat_data_path, 'w+') as f_json:
                    data_str = json.dumps(final_chat_data)
                    f_json.write(data_str)
                print(f'{self.__class__.__name__}:{self.tag}: Data saved at '
                      f'{chat_data_path} for model: {self.bot.worker_id}.')

                # Soft-block the worker if there were acceptability violations
                # Only the first entry of the violations list is inspected.
                acceptability_violations = final_chat_data[
                    'acceptability_violations'][0]
                if (acceptability_violations is not None
                        and acceptability_violations != ''):
                    print(
                        f'**NOTE** Acceptability violations detected: {acceptability_violations}'
                    )
                    # Grant the failed qualification
                    self.agent.mephisto_agent.get_worker().grant_qualification(
                        self.block_qualification, 1)

                # The chat is over: skip the remaining agent and the turn
                # counter update.
                return

            else:
                # Regular turn: record the utterance and pass it on.
                utterance_data = {
                    'agent_idx': idx,
                    # Get rid of annotations HTML if it's the bot response
                    'text': acts[idx]['text'].split('<br>')[0],
                    'id':
                    acts[idx].get('id',
                                  'NULL_ID'),  # In case model doesn't set id
                }
                self.dialog.append(utterance_data)
                if idx == 0:
                    # Human has just responded. Any problem data received now will be
                    # regarding the bot's prior utterance
                    # NOTE(review): unlike the check above, this indexes
                    # 'task_data' directly and raises KeyError if it is absent.
                    p = acts[idx]['task_data'].get(
                        'problem_data_for_prior_message')
                    if p is not None:
                        turn_idx = -2
                        # Attach the problem data to the second-to-last utterance, since
                        # the last utterance is what the human just said
                        self.__add_problem_data_to_utterance(p,
                                                             turn_idx=turn_idx)

                self._postprocess_acts(acts=acts, agent_idx=idx)
                # Show this act to the other participant only.
                for other_agent in [self.agent, self.bot]:
                    if other_agent != agent:
                        other_agent.observe(validate(acts[idx]))

                print(
                    f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}'
                )
                # The counter advances once per utterance, not once per pair.
                self.task_turn_idx += 1
Beispiel #17
0
 def batch_act(self, observations):
     """
     Count the observations in a batch and reply with empty Messages.

     Padding observations add one 'val' entry each to
     ``self._padding_counter``; all other observations are converted with
     ``self._to_tuple`` and added to ``self._counter``. One empty Message is
     returned per observation.
     """
     num_padding = sum(1 for obs in observations if obs.is_padding())
     self._padding_counter.update(['val'] * num_padding)

     seen = [
         self._to_tuple(obs) for obs in observations if not obs.is_padding()
     ]
     self._counter.update(seen)

     return [Message() for _ in observations]
 def get(self, episode_idx, entry_idx):
     """
     Return the stored entry at ``episode_idx`` wrapped in a Message.

     ``entry_idx`` is accepted but not used.
     """
     episode = self.data[episode_idx]
     return Message(episode)
Beispiel #19
0
    def setup_data(self, datafile):
        """
        Load the dialogs for the current datatype and yield them turn by turn.

        :param datafile:
            directory containing ``train.txt``, ``valid.txt`` and ``test.txt``;
            each line of the chosen file is parsed as one JSON object.

        :return:
            a generator of ``(Message, is_first_turn)`` pairs, where the second
            item is True only for the first turn of each episode.
        """
        print('loading: ' + datafile)
        # Pick the split file from the datatype prefix; anything that is not
        # train/valid falls back to the test file.
        if self.datatype.startswith('train'):
            path_to_open = os.path.join(datafile, 'train.txt')
        elif self.datatype.startswith('valid'):
            path_to_open = os.path.join(datafile, 'valid.txt')
        else:
            path_to_open = os.path.join(datafile, 'test.txt')

        with PathManager.open(path_to_open) as f:
            raw_data = [json.loads(line.strip()) for line in f]

        data = []
        label_speaker_id_range = {}
        predicted_summary_dict = {}
        if self.use_predicted_summary:
            # Session-level summaries are used unless the persona type
            # contains 'utt_'.
            is_session_level = not ('utt_' in self.previous_persona_type)
            predsum_path = get_predicted_summary_path(self.msc_dpath, is_session_level)
            logger.warning(f"use the predicted summary from {predsum_path}")
            with PathManager.open(predsum_path) as jsonfile:
                predicted_summary_dict = json.load(jsonfile)

        def _get_time_gap(time_num, time_unit, time_token=""):
            """
            Format a time gap as '<num> <unit>', prefixed by the token if given.
            """
            time_gap = str(time_num) + ' ' + time_unit
            return f'{time_token} {time_gap}' if len(time_token) > 0 else time_gap

        def _compile_persona_dialog_input(
            dialog, personas, previous_dialogs, label_speaker_id
        ):
            """
            Build the persona strings and speaker-aligned copies of the dialogs.

            Returns ``(your_persona, partner_persona, new_dialog,
            new_previous_dialogs)``; the dialogs are deep copies, so the inputs
            are not modified.
            """
            new_dialog = copy.deepcopy(dialog)
            new_previous_dialogs = copy.deepcopy(previous_dialogs)
            your_persona = ""
            partner_persona = ""
            if label_speaker_id == 'self':
                your_persona = '\n'.join([f'your persona: {x}' for x in personas[1]])
                partner_persona = '\n'.join(
                    [f"partner's persona: {x}" for x in personas[0]]
                )
            elif label_speaker_id == 'their':
                # Same as above with the two persona lists swapped.
                your_persona = '\n'.join([f'your persona: {x}' for x in personas[0]])
                partner_persona = '\n'.join(
                    [f"partner's persona: {x}" for x in personas[1]]
                )
                # Prepending a dummy turn shifts every utterance by one
                # position, so turns that were at even indices move to the odd
                # (label) positions used when episodes are built below.
                for prev_dialog in new_previous_dialogs:
                    prev_dialog['dialog'].insert(0, {"text": DUMMY_TEXT})
                    # Without person tokens, pad odd-length dialogs with a
                    # trailing dummy turn so the length becomes even.
                    if len(prev_dialog['dialog']) % 2 == 1 and (
                        self.history_person_tokens is None
                    ):
                        prev_dialog['dialog'].append({"text": DUMMY_TEXT})
                new_dialog.insert(0, {"text": DUMMY_TEXT})

            return your_persona, partner_persona, new_dialog, new_previous_dialogs

        for dialog_dict in raw_data:
            initial_data_id = dialog_dict['metadata']['initial_data_id']
            # 'both' produces one episode per speaker for every dialog.
            if self.label_speaker_id == 'both':
                label_speaker_id_range = ['their', 'self']
            else:
                label_speaker_id_range = [self.label_speaker_id]

            for label_speaker_id in label_speaker_id_range:
                # Choose the persona source: predicted summaries of the
                # previous session, the initial personas, or the dialog's
                # personas.
                if self.use_predicted_summary:
                    personas_to_complie = predicted_summary_dict[
                        str(self.session_id - 1)
                    ][initial_data_id]
                elif self.previous_persona_type.startswith('init'):
                    personas_to_complie = dialog_dict['init_personas']
                else:
                    personas_to_complie = dialog_dict['personas']

                (
                    your_persona,
                    partner_persona,
                    new_dialog,
                    new_previous_dialogs,
                ) = _compile_persona_dialog_input(
                    dialog_dict['dialog'],
                    personas_to_complie,
                    dialog_dict['previous_dialogs'],
                    label_speaker_id,
                )
                previous_sessions_msgs = []
                if self.previous_persona_type == 'raw_history':
                    # Render every previous session as one newline-joined
                    # string of its utterances.
                    for d_id in range(len(new_previous_dialogs)):
                        previous_dialog_msg = [
                            x['text'] for x in new_previous_dialogs[d_id]['dialog']
                        ]
                        if self.history_person_tokens:
                            # Prefix each utterance with the token for its
                            # position parity and drop the dummy turns.
                            previous_dialog_msg = [
                                self.history_person_tokens[i % 2] + ' ' + text
                                for i, text in enumerate(previous_dialog_msg)
                                if text != DUMMY_TEXT
                            ]
                        if self.history_time_gaps_token:
                            # Append the session's time gap as a final line.
                            time_gap_i = _get_time_gap(
                                new_previous_dialogs[d_id]['time_num'],
                                new_previous_dialogs[d_id]['time_unit'],
                                time_token=self.history_time_gaps_token,
                            )
                            previous_sessions_msgs.append(
                                '\n'.join(previous_dialog_msg + [time_gap_i])
                            )
                        else:
                            previous_sessions_msgs.append(
                                '\n'.join(previous_dialog_msg)
                            )

                if self.previous_session_delimiter is not None:
                    # Interleave: each session string is followed by the
                    # delimiter.
                    previous_sessions_msgs = [
                        val
                        for pair in zip(
                            previous_sessions_msgs,
                            [self.previous_session_delimiter]
                            * len(previous_sessions_msgs),
                        )
                        for val in pair
                    ]
                previous_sessions_msgs = '\n'.join(previous_sessions_msgs)

                episode = []
                # Walk the dialog in (text, label) pairs: even index is the
                # input text, the following odd index is the label.
                for i in range(0, len(new_dialog) - 1, 2):
                    text = new_dialog[i]['text']
                    # Remove the newlines and split on the persona prefix;
                    # the pieces are re-joined with spaces below.
                    partner_persona_one_line = partner_persona.replace('\n', '').split(
                        "partner's persona: "
                    )
                    your_persona_one_line = your_persona.replace('\n', '').split(
                        "your persona: "
                    )
                    action = {
                        'id': self.id,
                        'text': self.normalize_replies(text),
                        'labels': [self.normalize_replies(new_dialog[i + 1]['text'])],
                        'session_id': self.session_id,
                        'initial_data_id': initial_data_id,
                        'personas': f'{partner_persona}\n{your_persona}',
                        'personas_one_line': f"partner's persona: {' '.join(partner_persona_one_line)}\nyour persona: {' '.join(your_persona_one_line)}",
                    }
                    episode.append(action)
                    # When only the session opening is wanted, keep just the
                    # first turn.
                    if self.session_openning:
                        break

                # Select the context that is prepended to the first turn.
                persona_context_str = ""
                if 'self' in self.previous_persona_type:
                    persona_context_str = your_persona
                elif 'their' in self.previous_persona_type:
                    persona_context_str = partner_persona
                elif 'both' in self.previous_persona_type:
                    if self.your_persona_first:
                        persona_context_str = (
                            (your_persona + '\n') if len(your_persona) > 0 else ""
                        ) + partner_persona
                    else:
                        persona_context_str = (
                            (partner_persona + '\n') if len(partner_persona) > 0 else ""
                        ) + your_persona
                elif self.previous_persona_type == 'raw_history':
                    persona_context_str = previous_sessions_msgs

                if self.include_last_time_gap:
                    # Add the gap since the last previous session, in square
                    # brackets, as the final context line.
                    time_gap = _get_time_gap(
                        dialog_dict['previous_dialogs'][-1]['time_num'],
                        dialog_dict['previous_dialogs'][-1]['time_unit'],
                    )
                    persona_context_str = (
                        (persona_context_str + '\n')
                        if len(persona_context_str) > 0
                        else ""
                    ) + f'[{time_gap}]'

                # NOTE(review): episode[0] raises IndexError if the dialog
                # produced no (text, label) pair — confirm the data never
                # contains such dialogs.
                if persona_context_str and len(persona_context_str) > 0:
                    episode[0]['text'] = persona_context_str + '\n' + episode[0]['text']

                data.append(episode)

        for episode in data:
            start_idx = 0
            for i, turn in enumerate(episode):
                # The second item flags the first turn of a new episode.
                yield Message(turn), i == start_idx
Beispiel #20
0
    def parley(self):
        """
        Agent 0 goes first.

        Alternate between the two agents.

        On the first turn the contexts returned by ``self.get_contexts()`` are
        shown to the agents. Agent 0's act is then passed to agent 1, whose
        reply text is written to ``output.txt`` and passed back to agent 0.
        ``input.txt`` is read to obtain a user id and is removed once the turn
        has been handled. If producing the reply fails, a canned
        "I didn't get that" reply is sent to agent 0 instead.
        """

        if self.turn_cnt == 0:
            self.p1, self.p2 = self.get_contexts()

        acts = self.acts
        print("acts", acts)
        agents = self.agents

        if self.turn_cnt == 0 and self.p1 != '':
            # add the context on to the first message to agent 0
            context_act = Message({
                'id': 'context',
                'text': self.p1,
                'episode_done': True
            })
            agents[0].observe(validate(context_act))
        try:
            act = deepcopy(agents[0].act())
        except StopIteration:
            self.reset()
            self.finalize_episode()
            self.turn_cnt = 0
            return
        acts[0] = act
        if self.turn_cnt == 0 and self.p2 != '':
            # add the context on to the first message to agent 1
            context_act = Message({
                'id': 'context',
                'text': self.p2,
                'episode_done': False
            })
            agents[1].observe(validate(context_act))
        try:
            # Best effort: read the user id from input.txt and swap in fresh
            # clones of both agents. Any failure here (missing file, no "_" in
            # the contents, clone error) is reported and otherwise ignored.
            try:
                with open("input.txt", "r") as input_file:
                    contents = input_file.read()
                user_id = contents.split("_")[1]
                print("user_id_value", user_id)

                # The original kept a per-call cache that was always empty, so
                # the agents are cloned on every turn; that is preserved here.
                print("getting user id")
                human_clone = deepcopy(agents[0].clone())
                bot_clone = deepcopy(agents[1].clone())
                agents[0] = human_clone
                agents[1] = bot_clone
                print("new agent created")
            except Exception:
                print("None")

            agents[1].observe(validate(act))
            if os.path.isfile("input.txt"):
                os.remove("input.txt")

            print("client:", act)
            acts[1] = agents[1].act()
            print("agent:", acts[1])
            print("text only:", acts[1]['text'])
            with open("output.txt", "w") as output_file:
                output_file.write(str(acts[1]['text']))

            agents[0].observe(validate(acts[1]))
            self.update_counters()
            self.turn_cnt += 1
            # A missing key counts as "not done" rather than raising into the
            # fallback below after the turn has already completed.
            if act.get('episode_done', False):
                self.finalize_episode()
                self.turn_cnt = 0
        except Exception:
            # Fallback: the bot could not produce a reply, so answer with a
            # canned message. KeyboardInterrupt/SystemExit are not caught.
            if os.path.isfile("input.txt"):
                os.remove("input.txt")
            print("client:", act)
            acts[1] = {
                'id': 'ImageSeq2seq',
                'episode_done': False,
                'text': "I didn't get that"
            }
            print("agent:", acts[1])
            print("text only:", acts[1]['text'])

            agents[0].observe(validate(acts[1]))
            self.update_counters()
            self.turn_cnt += 1
Beispiel #21
0
def _strip_tensors(act):
    """
    Return a new Message holding every non-tensor entry of ``act``.

    Tensor values are dropped to ensure we never try to serialize them.
    """
    kept = {}
    for key, value in act.items():
        if torch.is_tensor(value):
            continue
        kept[key] = value
    return Message(kept)