Example #1
0
    def message_mutation(self, message: Message) -> Message:
        """
        Randomly subsample the retrieved docs, keeping docs with gold sentences.

        Up to ``woi_doc_max_chunks`` docs (default 100) are kept at random via
        ``torch.randperm``. Every other doc is dropped unless it contains one of
        the gold checked sentences, taken from ``CONST.SELECTED_SENTENCES`` and
        falling back to ``labels``.

        :param message:
            The message to mutate. It is returned unchanged if it has no
            ``CONST.RETRIEVED_DOCS``.
        :return:
            A copy of ``message`` whose ``CONST.RETRIEVED_DOCS`` holds the kept
            docs in their original order.
        """
        if CONST.RETRIEVED_DOCS not in message:
            return message
        new_message = message.copy()
        docs = message.get(CONST.RETRIEVED_DOCS)
        max_chunks = self.opt.get('woi_doc_max_chunks', 100)

        keep = torch.randperm(len(docs))[0:max_chunks]
        remove = torch.ones(len(docs))
        remove[keep] = 0

        # The gold sentences do not depend on the doc being examined, so compute
        # them once rather than once per dropped doc.
        checked_sentences = message.get(
            CONST.SELECTED_SENTENCES,
            message.get('labels', [CONST.NO_SELECTED_SENTENCES_TOKEN]),
        )
        if ' '.join(checked_sentences) != CONST.NO_SELECTED_SENTENCES_TOKEN:
            gold_sentences = [sent.strip(' ') for sent in checked_sentences]
        else:
            gold_sentences = []

        new_docs = []
        for i, doc in enumerate(docs):
            # Keep randomly selected docs. Also keep any doc that contains a gold
            # checked sentence, so that knowledge is not lost.
            if remove[i] == 0 or any(s in doc for s in gold_sentences):
                new_docs.append(doc)

        new_message.force_set(CONST.RETRIEVED_DOCS, new_docs)
        return new_message
Example #2
0
    def act(self):
        """
        Send new dialog message.

        On first use the teacher is reset. The next example is fetched via
        ``next_example``, which also updates ``self.epochDone``. The example is
        wrapped in a ``Message``, stamped with this teacher's id, and its gold
        answer is remembered in ``self.lastY``.

        When not training, or when in ``evalmode``, ``labels`` are moved to
        ``eval_labels`` so they are not used for training. If the
        ``hide_labels`` option is set, the labels are dropped instead.

        :return: the (possibly copied) action message.
        """
        if not hasattr(self, 'epochDone'):
            # reset if haven't yet
            self.reset()

        # get next example, action is episode_done dict if already out of exs
        action, self.epochDone = self.next_example()
        # TODO: all teachers should eventually create messages
        # while setting up the data, so this won't be necessary
        action = Message(action)
        action.force_set('id', self.getID())

        # remember correct answer if available. This must read the same
        # 'labels' / 'eval_labels' keys handled below; the previous
        # 'labels_1' / 'eval_labels_1' keys left lastY as None for standard
        # messages.
        self.lastY = action.get('labels', action.get('eval_labels', None))
        if (not self.datatype.startswith('train')
                or 'evalmode' in self.datatype) and 'labels' in action:
            # move labels to eval field so not used for training
            # but this way the model can use the labels for perplexity or loss
            action = action.copy()
            labels = action.pop('labels')
            if not self.opt.get('hide_labels', False):
                action['eval_labels'] = labels

        return action
Example #3
0
 def message_mutation(self, message: Message) -> Message:
     """
     Use the gold checked sentence as the single retrieved document.

     On a copy of ``message``, the ``checked_sentence`` and ``title`` become
     one-element retrieved-doc and retrieved-title lists, with an empty URL, and
     the ``checked_sentence`` key is removed. A non-empty ``history`` is
     prepended to ``text``, separated by a newline.
     """
     mutated = message.copy()
     mutated[CONST.RETRIEVED_DOCS] = [message['checked_sentence']]
     mutated[CONST.RETRIEVED_DOCS_TITLES] = [message['title']]
     mutated[CONST.RETRIEVED_DOCS_URLS] = ['']
     mutated.pop('checked_sentence')

     history = message.get('history', '')
     if history != '':
         # Fold prior dialogue history into the text field.
         mutated.force_set('text', history + '\n' + message.get('text', ''))
     return mutated
Example #4
0
 def get_temp_history(self, observation: Message) -> tp.Optional[str]:
     """
     Return knowledge to use as temporary history, or None.

     This applies only when ``use_knowledge`` is enabled and
     ``add_knowledge_to_history`` is not. The knowledge is then set as temp
     history in the agent's history, and the vectorize function runs
     accordingly. The observation's ``checked_sentence`` is returned when
     ``chosen_sentence`` is truthy (the default); otherwise its ``knowledge``
     field is returned.
     """
     opt = self.opt
     knowledge_in_history = opt.get('add_knowledge_to_history', False)
     if not opt.get('use_knowledge', False) or knowledge_in_history:
         return None
     source = ('checked_sentence'
               if opt.get('chosen_sentence', True) else 'knowledge')
     return observation.get(source, None)
Example #5
0
    def self_observe(self, self_message: Message):
        """
        Override TA.self_observe.

        The knowledge agent observes ``self_message`` first, and its history
        string becomes the shared observation. Every dialogue-agent clone is
        reset and replayed from that observation. The optional search-decision
        agent (only when ``search_decision`` is COMPUTE) and search-query agent
        also self-observe and then have their histories rebuilt from it.

        This keeps all agents on the same history and eliminates unnecessary
        copies of the previous knowledge in the history.
        """
        knowledge_agent = self.knowledge_agent
        knowledge_agent.self_observe(self_message)
        self.knowledge_responses.append(self_message.get('knowledge_response', ''))
        observation = {'text': knowledge_agent.history.get_history_str()}

        # Clones get a full reset. Their temp history comes from the main
        # dialogue agent.
        for clone in self.dialogue_agent_clones:
            clone.reset()
            clone.history.update_history(
                observation,
                temp_history=self.dialogue_agent.get_temp_history(observation),
            )

        def resync(aux_agent):
            # Let the auxiliary agent self-observe, then replace its history
            # with the shared observation.
            aux_agent.self_observe(self_message)
            aux_agent.history.reset()
            aux_agent.history.update_history(
                observation,
                temp_history=aux_agent.get_temp_history(observation),
            )

        wants_decision = (
            self.search_decision_agent
            and self.search_decision is SearchDecision.COMPUTE
        )
        if wants_decision:
            resync(self.search_decision_agent)
        if self.search_query_agent:
            resync(self.search_query_agent)
Example #6
0
    def custom_evaluation(self, teacher_action: Message, labels,
                          model_response: Message):
        """
        Record slot metrics for API calls and delexicalized BLEU for API responses.

        For ``apicall`` turns whose response starts with ``'apicall: '``, the
        rest of the response is parsed as ``name = value`` pairs separated by
        ``' ; '``. These are compared to the gold ``slots`` to produce
        ``slot_p`` (per predicted slot) and ``slot_r`` (per gold slot) samples.

        For ``apiresp`` turns, ``delex_bleu`` is computed between the response
        and ``labels[0]``, both passed through ``self._delex``. Nothing is
        recorded when the response text is empty or missing.
        """
        resp = model_response.get('text')
        if not resp:
            return

        call_prefix = 'apicall: '
        if teacher_action['type'] == 'apicall' and resp.startswith(call_prefix):
            gold = teacher_action['slots']
            parsed = {}
            for slot_str in resp[len(call_prefix):].split(' ; '):
                if ' = ' not in slot_str:
                    if slot_str != '':
                        # syntactically invalid generations should count against us
                        self.metrics.add('slot_p', AverageMetric(0))
                    continue
                # Split on the first ' = ' only. A value containing ' = '
                # previously raised "too many values to unpack" and aborted
                # the evaluation.
                name, value = slot_str.split(' = ', 1)
                parsed[name] = value

            # slot precision: one sample per predicted slot
            for k, v in parsed.items():
                self.metrics.add('slot_p', AverageMetric(v == gold.get(k)))
            # slot recall: one sample per gold slot
            for k, v in gold.items():
                self.metrics.add('slot_r', AverageMetric(v == parsed.get(k)))
        elif teacher_action['type'] == 'apiresp':
            delex_resp = self._delex(resp, teacher_action['slots'])
            delex_label = self._delex(labels[0], teacher_action['slots'])
            self.metrics.add('delex_bleu',
                             BleuMetric.compute(delex_resp, [delex_label]))
Example #7
0
 def message_mutation(self, message: Message) -> Message:
     """
     Prefix the message text with its available knowledge.

     Messages without a truthy ``available_knowledge_text`` are returned as-is.
     Otherwise the knowledge is wrapped in TOKEN_KNOWLEDGE / TOKEN_END_KNOWLEDGE.
     If the original text is SILENCE, it is replaced by the wrapped knowledge.
     Otherwise the wrapped knowledge is joined to the text with the ``delimiter``
     option (default newline). ``message`` is modified in place and returned.
     """
     knowledge_text = message.get('available_knowledge_text')
     if not knowledge_text:
         return message
     context = message.pop('text')
     knowledge = f'{TOKEN_KNOWLEDGE} {knowledge_text} {TOKEN_END_KNOWLEDGE}'
     delimiter = self.opt.get('delimiter', '\n')
     if context == SILENCE:
         message['text'] = knowledge
     else:
         message['text'] = f'{knowledge}{delimiter}{context}'
     return message
Example #8
0
 def observe(self, observation: Message) -> Message:
     """
     Optionally append knowledge to the observation text, then observe.

     This applies when both ``use_knowledge`` and ``add_knowledge_to_history``
     are set. The ``checked_sentence`` is used, or ``knowledge`` when
     ``chosen_sentence`` is falsy. It is appended to ``observation['text']``
     with ``self.history.delimiter``, provided it is a non-empty string.
     ``observation`` is updated in place before the parent ``observe`` runs.
     """
     opt = self.opt
     wants_knowledge = (opt.get('use_knowledge', False)
                        and opt.get('add_knowledge_to_history', False))
     if wants_knowledge:
         source = ('checked_sentence'
                   if opt.get('chosen_sentence', True) else 'knowledge')
         extra = observation.get(source, None)
         if isinstance(extra, str) and extra:
             combined = observation['text'] + self.history.delimiter + extra
             observation.force_set('text', combined)
     return super().observe(observation)
Example #9
0
    def evaluate_response(self, observation: Message,
                          labels: List[str]) -> None:
        """
        Compute all required text-based metrics based on an observation and labels.

        One ``exs`` count is always added. When the observation has a ``text``
        prediction, accuracy and F1 are added, along with whichever BLEU, ROUGE
        and inter/intra-distinct metrics appear in ``self._metrics_list``.
        Ranking metrics are then updated. Finally, user-reported metrics from
        ``observation['metrics']`` are added, prefixed with ``USER_`` when their
        name collides with a built-in metric.
        """
        prediction = observation.get('text', None)

        self.add('exs', SumMetric(1))

        if prediction is not None:
            self.add('accuracy', ExactMatchMetric.compute(prediction, labels))
            self.add('f1', F1Metric.compute(prediction, labels))

            for k in range(1, 5):  # 1..4
                if f'bleu-{k}' in self._metrics_list:
                    self.add(f'bleu-{k}',
                             BleuMetric.compute(prediction, labels, k))
            # if any of the rouges are in the list
            if self._metrics_list & ROUGE_METRICS:
                r1, r2, rL = RougeMetric.compute_many(prediction, labels)
                if 'rouge-1' in self._metrics_list and r1:
                    self.add('rouge_1', r1)
                if 'rouge-2' in self._metrics_list and r2:
                    self.add('rouge_2', r2)
                if 'rouge-L' in self._metrics_list and rL:
                    self.add('rouge_L', rL)
            # compute distinct-k
            for k in [1, 2]:
                if f'interdistinct-{k}' in self._metrics_list:
                    self.add(f'interdistinct-{k}',
                             InterDistinctMetric.compute(prediction, k))
                if f'intradistinct-{k}' in self._metrics_list:
                    self.add(f'intradistinct-{k}',
                             IntraDistinctMetric.compute(prediction, k))

        # Ranking metrics.
        self._update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for uk, v in observation['metrics'].items():
                if uk in ALL_METRICS:
                    # don't let the user override our metrics
                    uk = f'USER_{uk}'
                # Report the offending key's type. The old message used the
                # stale loop variable `k`, which is unbound when there is no
                # prediction.
                assert isinstance(uk, str), type(uk)
                if not isinstance(v, Metric):
                    warn_once(
                        f'Metric {uk} is assumed to be averaged per example.')
                    v = AverageMetric(v)
                assert isinstance(v, Metric)
                self.add(uk, v)
Example #10
0
 def get_temp_history(self, observation: Message) -> Optional[str]:
     """
     Conditionally return a style-token string to temporarily insert into history.

     With probability ``self.use_style_frac`` (one ``random.random()`` draw),
     the observation's ``personality`` is used as the style, and
     ``STYLE_SEP_TOKEN + style`` is returned. None is returned when the style
     is not drawn, or when the personality is missing or empty.
     """
     if random.random() >= self.use_style_frac:
         return None
     # This key name is dependent on Image-Chat and will change for other tasks.
     # The observation may lack 'personality' (i.e. at the end of an epoch during
     # validation), in which case there is no style.
     style = observation.get('personality')
     if style is None or style == '':
         return None
     return STYLE_SEP_TOKEN + style
Example #11
0
    def custom_evaluation(
        self,
        teacher_action: Message,
        labels: Optional[Tuple[str]],
        model_response: Message,
    ) -> None:
        """
        Record F1 overlap between the predicted and gold knowledge graphs.

        Both graphs are lower-cased and split with ``break_knowledge_graph``.
        Their elements are encoded with ``encode_set_elements`` and scored with
        ``F1Metric`` at three granularities:

        - ``response_elements_f1``: full elements.
        - ``graph_subject_relation_f1``: text before the last comma.
        - ``graph_subject_f1``: text before the first comma of that.

        Nothing is recorded for padding, for empty responses, or when no gold
        label is available.
        """
        if model_response.is_padding() or not model_response.get('text', None):
            return
        if not labels:
            # `labels` is Optional. Without a gold graph there is nothing to
            # score, and `labels[0]` below would raise.
            return

        expected_graph = break_knowledge_graph(labels[0].lower())
        predicted_graph = break_knowledge_graph(model_response['text'].lower())

        # Encoding the graph edges/mutation operations into ints for readily use of F1Metric
        expected_graph_enc, predicted_graph_enc = encode_set_elements(
            expected_graph, predicted_graph)
        self.metrics.add(
            'response_elements_f1',
            F1Metric.compute(
                guess=' '.join(predicted_graph_enc),
                answers=[' '.join(expected_graph_enc)],
            ),
        )

        # Subject, Relation F1
        # Changing "(MUT) < you , in , house >"   --into-->   "(MUT) < you , in "
        # This is to check F1 for the predicted subject and relation overlap.
        ekg_sub_rel = {e.rsplit(',', 1)[0] for e in expected_graph}
        pkg_sub_rel = {e.rsplit(',', 1)[0] for e in predicted_graph}
        ekg_sub_rel_ids, pkg_sub_rel_ids = encode_set_elements(
            ekg_sub_rel, pkg_sub_rel)
        self.metrics.add(
            'graph_subject_relation_f1',
            F1Metric.compute(guess=' '.join(pkg_sub_rel_ids),
                             answers=[' '.join(ekg_sub_rel_ids)]),
        )

        # Subject F1
        # Changing "(MUT) < you , in " (produced above)   --into-->   "(MUT) < you "
        # This is to check F1 for the predicted subject overlap.
        ekg_sub = {e.split(',')[0] for e in ekg_sub_rel}
        pkg_sub = {e.split(',')[0] for e in pkg_sub_rel}
        ekg_sub_ids, pkg_sub_ids = encode_set_elements(ekg_sub, pkg_sub)
        self.metrics.add(
            'graph_subject_f1',
            F1Metric.compute(guess=' '.join(pkg_sub_ids),
                             answers=[' '.join(ekg_sub_ids)]),
        )
Example #12
0
    def get_retrieved_knowledge(self, message: Message):
        """
        Pick at most ``self._n_docs`` retrieved docs from ``message``.

        Docs containing any selected (gold) sentence are taken first, in message
        order. Remaining slots are filled with the other docs, also in message
        order, for repeatability. The final list is shuffled so the gold docs'
        positions do not leak to the model.

        An empty list is returned when the message has no retrieved docs, or when
        the selected sentences join to ``NO_SELECTED_SENTENCES_TOKEN``.
        """
        retrieved_docs = []
        if not message.get(consts.RETRIEVED_DOCS):
            return retrieved_docs

        selected_sentences = message[consts.SELECTED_SENTENCES]
        all_docs = message[consts.RETRIEVED_DOCS]
        n_docs_in_message = len(all_docs)
        already_added_doc_idx = set()

        if ' '.join(selected_sentences) == consts.NO_SELECTED_SENTENCES_TOKEN:
            return retrieved_docs  # `retrieved_docs` is empty at this point

        # First adding the docs with selected sentences.
        for doc_idx, doc_content in enumerate(all_docs):
            if not any(sel_sentc in doc_content for sel_sentc in selected_sentences):
                continue
            # Check capacity *before* adding. The old post-append check could
            # skip the break and overshoot `_n_docs`; the filler loop's `==`
            # test then never fired and every remaining doc was appended.
            if len(retrieved_docs) >= self._n_docs:
                logging.warning(
                    f'More than {self._n_docs} documents have selected sentences. Trimming them to the first {self._n_docs}'
                )
                break
            retrieved_docs.append(self._extract_doc_from_message(message, doc_idx))
            already_added_doc_idx.add(doc_idx)

        # Then adding other (filler) docs.
        # We add them by iterating forward in the __retrieved-docs__ list for repeatability,
        # but we shuffle the order of the final returned docs, to make sure model doesn't cheat.
        for doc_idx in range(n_docs_in_message):
            if len(retrieved_docs) >= self._n_docs:
                break

            if doc_idx in already_added_doc_idx:
                continue

            retrieved_docs.append(self._extract_doc_from_message(message, doc_idx))

        if n_docs_in_message > len(retrieved_docs):
            logging.debug(
                f'Trimmed retrieved docs from {n_docs_in_message} to {len(retrieved_docs)}'
            )
        random.shuffle(retrieved_docs)
        return retrieved_docs
Example #13
0
    def custom_evaluation(self, teacher_action: Message, labels,
                          model_response: Message):
        """
        Record slot metrics or NLG metrics, depending on the teacher turn type.

        ``tod.STANDARD_RESP`` turns: the response, with that prefix stripped if
        present, is parsed by ``SerializationHelpers.str_to_api_dict`` and scored
        against the gold ``slots`` with ``SlotMetrics``.

        ``tod.STANDARD_USER_UTTERANCE`` turns: the response is scored against
        ``labels`` with ``NlgMetrics``.

        Each entry of the resulting report is added to ``self.metrics``. Empty
        responses and other turn types record nothing.
        """
        resp = model_response.get("text")
        if not resp:
            return

        turn_type = teacher_action["type"]
        if turn_type == tod.STANDARD_RESP:
            prefix = tod.STANDARD_RESP
            payload = resp[len(prefix):] if resp.startswith(prefix) else resp
            predicted = SerializationHelpers.str_to_api_dict(payload)
            report = SlotMetrics(teacher_action["slots"], predicted).report()
        elif turn_type == tod.STANDARD_USER_UTTERANCE:
            report = NlgMetrics(resp, labels).report()
        else:
            return

        for key, value in report.items():
            self.metrics.add(key, value)
Example #14
0
    def custom_evaluation(self, teacher_action: Message, labels,
                          model_response: Message):
        """
        Extend the parent evaluation with joint-goal accuracy on out-of-domain APIs.

        After the parent's metrics are recorded, ``OutDomainOnlyApis/jga`` is
        added for ``tod.STANDARD_CALL`` turns whose gold API name is in
        ``VALID_OUT_DOMAIN_API_NAMES``. It records whether the parsed predicted
        call equals the gold slots exactly.
        """
        super().custom_evaluation(teacher_action, labels, model_response)
        resp = model_response.get("text")
        if not resp:
            return
        if teacher_action["type"] != tod.STANDARD_CALL:
            return
        gold_slots = teacher_action["slots"]
        if tod.STANDARD_API_NAME_SLOT not in gold_slots:
            return
        if gold_slots[tod.STANDARD_API_NAME_SLOT] not in VALID_OUT_DOMAIN_API_NAMES:
            return

        if resp.startswith(tod.STANDARD_CALL):
            resp = resp[len(tod.STANDARD_CALL):]
        predicted = tod.SerializationHelpers.str_to_api_dict(resp)
        self.metrics.add(
            "OutDomainOnlyApis/jga",
            AverageMetric(gold_slots == predicted),
        )
Example #15
0
    def evaluate_response(self, observation: Message,
                          labels: List[str]) -> None:
        """
        Compute all required text-based metrics based on an observation and labels.

        One ``exs`` count is always added. If the observation carries a ``text``
        prediction, accuracy and F1 are added, plus whichever BLEU-n, ROUGE and
        inter/intra-distinct metrics are enabled in ``self._metrics_list``.
        Ranking metrics are then updated and user-reported metrics consumed.
        """
        prediction = observation.get('text', None)
        self.add('exs', SumMetric(1))

        if prediction is not None:
            enabled = self._metrics_list
            self.add('accuracy', ExactMatchMetric.compute(prediction, labels))
            self.add('f1', F1Metric.compute(prediction, labels))

            for order in (1, 2, 3, 4):
                if f'bleu-{order}' in enabled:
                    self.add(f'bleu-{order}',
                             BleuMetric.compute(prediction, labels, order))

            # ROUGE is computed once for all variants, and only if any is enabled.
            if enabled & ROUGE_METRICS:
                r1, r2, rL = RougeMetric.compute_many(prediction, labels)
                rouge_plan = (
                    ('rouge-1', 'rouge_1', r1),
                    ('rouge-2', 'rouge_2', r2),
                    ('rouge-L', 'rouge_L', rL),
                )
                for option_name, metric_name, score in rouge_plan:
                    if option_name in enabled and score:
                        self.add(metric_name, score)

            # distinct-n for n = 1, 2
            for order in (1, 2):
                distinct_plan = (
                    (f'interdistinct-{order}', InterDistinctMetric),
                    (f'intradistinct-{order}', IntraDistinctMetric),
                )
                for metric_name, metric_cls in distinct_plan:
                    if metric_name in enabled:
                        self.add(metric_name,
                                 metric_cls.compute(prediction, order))

        # Ranking metrics.
        self._update_ranking_metrics(observation, labels)

        self._consume_user_metrics(observation)
Example #16
0
    def custom_evaluation(
        self,
        teacher_action: Message,
        labels: Optional[Tuple[str]],
        model_response: Message,
    ) -> Optional[bool]:
        """
        Compute RPA for a model response.

        The teacher text is appended to the running context, and the classifier
        predicts which character the model response belongs to. A
        ``character_accuracy`` sample is recorded, comparing that prediction with
        ``_self_name`` extracted from the context. The context is cleared at
        episode end; otherwise the gold label is appended to it.

        :param teacher_action:
            The message last sent from this teacher.
        :param labels:
            The previous correct labels
        :param model_response:
            The raw response from the model
        :return:
            None when the model produced no text; otherwise whether the predicted
            character matched the correct one.
        """
        if not model_response or not model_response.get('text'):
            # NOTE(review): on this early return, self.context is neither
            # extended nor reset at episode end. Confirm this cannot carry
            # context over into the next episode.
            return
        self.context.append(teacher_action['text'])
        context = self.delimiter.join(self.context)
        characters = extract_characters(context)
        correct_character = characters['_self_name']
        model_text = model_response['text']
        classifier_act = self.classifier.classify(context, model_text)
        predicted_character = classifier_act['text']
        is_correct = predicted_character == correct_character
        self.metrics.add('character_accuracy', AverageMetric(int(is_correct)))
        if teacher_action['episode_done']:
            self.context = []
        else:
            assert labels
            self.context.append(labels[0])

        return is_correct
Example #17
0
    def custom_evaluation(self, teacher_action: Message, labels,
                          model_response: Message):
        """
        Record slot metrics for API calls and NLG metrics for system utterances.

        ``tod.STANDARD_CALL`` turns: the response, with that prefix stripped if
        present, is parsed into an API dict and scored with ``SlotMetrics``.
        Scores are prefixed by domain when ``domain_jga_record`` is set. With
        ``api_jga_record``, a per-API joint-goal-accuracy sample is also
        recorded, keyed by API name and its non-name slots.

        ``tod.STANDARD_SYSTEM_UTTERANCE`` turns: scored with ``NlgMetrics``,
        prefixed by domain when ``domain_nlg_record`` is set.

        Empty responses record nothing.
        """
        resp = model_response.get("text")
        if not resp:
            return
        if teacher_action["type"] == tod.STANDARD_CALL:
            if resp.startswith(tod.STANDARD_CALL):
                resp = resp[len(tod.STANDARD_CALL):]
            predicted = SerializationHelpers.str_to_api_dict(resp)
            domains = ([teacher_action["domain"]]
                       if self.opt["domain_jga_record"] else [])

            metrics = SlotMetrics(
                teacher_slots=teacher_action["slots"],
                predicted_slots=predicted,
                prefixes=domains,
            ).report()
            for key, value in metrics.items():
                self.metrics.add(key, value)

            teacher = teacher_action["slots"]
            # Require the API-name slot itself, not merely a non-empty slot
            # dict. Otherwise `teacher[...]` / removing the name slot would
            # raise for calls lacking it.
            if (self.opt["api_jga_record"]
                    and tod.STANDARD_API_NAME_SLOT in teacher):
                slots = [
                    key for key in teacher
                    if key != tod.STANDARD_API_NAME_SLOT
                ]
                api_here = ("api-" + teacher[tod.STANDARD_API_NAME_SLOT] +
                            "--" + "-".join(slots))
                self.metrics.add(f"{api_here}/jga",
                                 AverageMetric(teacher == predicted))

        elif teacher_action["type"] == tod.STANDARD_SYSTEM_UTTERANCE:
            domains = ([teacher_action["domain"]]
                       if self.opt["domain_nlg_record"] else [])
            metrics = NlgMetrics(guess=resp, labels=labels,
                                 prefixes=domains).report()
            for key, value in metrics.items():
                self.metrics.add(key, value)
    def update_history(self, obs: Message, temp_history: Optional[str] = None):
        """
        Update the history with the given observation.

        When ``obs`` has a non-None ``text``, that text is pushed into the raw
        strings, history strings, dialogues and vectors. The dialogues also
        receive ``log_prob`` and the cache from ``get_cache``, and the text is
        optionally wrapped with person tokens. Each of ``context``,
        ``background``, ``title`` and ``section_title`` is copied from ``obs``
        only while the history's own value is still falsy.

        :param obs:
            Observation used to update the history.
        :param temp_history:
            Optional temporary string. If it is not None, this string will be
            appended to the end of the history. It will not be in the history
            on the next dialogue turn. Set to None to stop adding to the
            history.
        """
        text = obs.get("text")
        if text is not None:
            log_prob = obs.get('log_prob', None)
            cache = self.get_cache(obs)
            # Latch each piece of episode metadata the first time it appears.
            for key in ('context', 'background', 'title', 'section_title'):
                if not getattr(self, key) and obs.get(key, None):
                    setattr(self, key, obs[key])
            self._update_raw_strings(text)
            if self.add_person_tokens:
                # NOTE(review): this reads obs[self.field], while the guard
                # above checks obs["text"]. The two agree only when
                # self.field == 'text'; confirm that is intended.
                text = self._add_person_tokens(obs[self.field], self.p1_token,
                                               self.add_p1_after_newln)
            # update history string
            self._update_strings(text)
            # update history dialogues
            self._update_dialogues(text, log_prob=log_prob, cache=cache)
            # update history vecs
            self._update_vecs(text)
        self.temp_history = temp_history
Example #19
0
    def _setup_data(self, opt: Opt) -> List[List[Message]]:
        """
        Flatten and classify the normal task data.

        Save/load where applicable. If ``load_data`` returns data, it is used
        directly. Otherwise the original task's teacher is run for
        ``num_examples()`` examples, capped by ``max_examples`` when that is > 0.
        Each episode is flattened with ``flatten_and_classify``, and the result
        is saved with ``save_data``.

        :param opt:
            options dict.
        :return:
            the flattened episodes.
        """
        # create save directory, if it does not already exist
        self.original_task_name = ':'.join(opt['task'].split(':')[2:])
        self.save_dir = self._get_save_path(
            opt['datapath'], str(datetime.datetime.today())
        )
        os.makedirs(self.save_dir, exist_ok=True)

        fname = f"{opt['datatype'].split(':')[0]}.json"
        self.save_path = os.path.join(self.save_dir, fname)

        data = self.load_data(opt, fname)
        if data is not None:
            # successfully load data
            return data

        # build the original teacher
        original_task_module = get_original_task_module(opt)
        teacher_opt = deepcopy(opt)
        teacher_opt['task'] = self.original_task_name
        teacher = original_task_module(teacher_opt)

        total_exs = teacher.num_examples()
        if self.opt['max_examples'] > 0:
            total_exs = min(self.opt['max_examples'], total_exs)

        progress_bar = tqdm(
            total=total_exs, unit='ex', unit_scale=True, desc='Building flattened data'
        )

        all_episodes = []
        num_exs = 0
        try:
            while num_exs < total_exs:
                current_episode = []
                episode_done = False

                while not episode_done:
                    action = Message(teacher.act())
                    current_episode.append(action)
                    episode_done = action.get('episode_done', False)
                    num_exs += 1

                # flatten the episode into 1-example episodes with context
                flattened_ep = flatten_and_classify(
                    current_episode,
                    opt['flatten_max_context_length'],
                    include_labels=opt['flatten_include_labels'],
                    delimiter=opt['flatten_delimiter'],
                    word_lists=self.word_lists,
                )
                all_episodes += flattened_ep

                progress_bar.update(len(flattened_ep))
        finally:
            # The bar was never closed before. Release it even if building
            # fails part way.
            progress_bar.close()

        # save data for future use
        self.save_data(all_episodes)

        return all_episodes
Example #20
0
 def _to_tuple(self, msg: Message) -> Tuple:
     """
     Turn ``msg`` into an indexable tuple.

     The result holds ``self._val`` applied to the message's ``text``,
     ``labels`` and ``eval_labels`` (None for a missing key), in that order.
     """
     return (
         self._val(msg.get('text')),
         self._val(msg.get('labels')),
         self._val(msg.get('eval_labels')),
     )
Example #21
0
    def custom_evaluation(
        self,
        teacher_action: Message,
        labels: Optional[Tuple[str]],
        model_response: Message,
    ) -> None:
        """
        Various F1 metrics for the generated model response.

        The response text is compared with:

        - the selected docs and sentences;
        - the retrieved docs, joined, per doc, and split into newline-separated
          sentences;
        - the sentences of the first selected doc.

        The comparisons use F1, exact-match, copied-substring, BLEU-1..4 and
        ROUGE metrics. Nothing is recorded when the model produced no text.
        """
        response = model_response.get('text')
        if not response:
            # No response generated by model.
            return

        selected_docs = teacher_action[CONST.SELECTED_DOCS]
        # F1 metric over the *selected* knowledge.
        self.metrics.add('knowledge_f1_docs',
                         F1Metric.compute(response, selected_docs))
        self.metrics.add(
            'knowledge_f1_sentences',
            F1Metric.compute(response, teacher_action[CONST.SELECTED_SENTENCES]),
        )

        retrieved_docs = teacher_action[CONST.RETRIEVED_DOCS]
        # F1 Metrics over the *retrieved* docs.
        self.metrics.add('f1_retrieved_docs',
                         F1Metric.compute(response, ' '.join(retrieved_docs)))
        self.metrics.add('max_f1_retrieved_docs',
                         F1Metric.compute(response, retrieved_docs))

        selected_doc_sentences = selected_docs[0].split('\n')
        all_doc_sentences = [
            sentence for doc in retrieved_docs for sentence in doc.split('\n')
        ]

        self.metrics.add('exact_copied_sentences',
                         ExactMatchMetric.compute(response, all_doc_sentences))
        self.metrics.add('max_substring_copied_sentences',
                         CopiedSubstringMetric.compute(response, all_doc_sentences))
        self.metrics.add('max_substring_copied_docs',
                         CopiedSubstringMetric.compute(response, retrieved_docs))
        self.metrics.add(
            'substring_copied_docs',
            CopiedSubstringMetric.compute(response, [''.join(retrieved_docs)]),
        )
        self.metrics.add('max_f1_selected_docs_senetences',
                         F1Metric.compute(response, selected_doc_sentences))
        self.metrics.add('max_f1_docs_senetences',
                         F1Metric.compute(response, all_doc_sentences))

        # N-gram matching metrics, BLEU orders 1..4
        for order in (1, 2, 3, 4):
            self.metrics.add(
                f'max_bleu_selected_docs_senetences-{order}',
                BleuMetric.compute(response, selected_doc_sentences, order),
            )

        rouge_1, rouge_2, rouge_l = RougeMetric.compute_many(
            response, selected_doc_sentences)
        self.metrics.add('max_rouge_selected_docs_senetences_1', rouge_1)
        self.metrics.add('max_rouge_selected_docs_senetences_2', rouge_2)
        self.metrics.add('max_rouge_selected_docs_senetences_L', rouge_l)
Example #22
0
    def _setup_data(self, opt):
        """
        Build the gender-classification examples from ConvAI2 into ``self.data``.

        For each ``OrigConvai2Teacher`` episode, persona lines are removed from
        the first example's text and passed to ``get_genders``. "your" and
        "partner" are swapped, because the ``text`` field is read rather than
        the label.

        Each utterance becomes a standalone example labelled ``SELF:<gender>``
        and/or ``PARTNER:<gender>``, depending on ``self.labels_to_use``.
        Inferred "about" data may be mixed in, training data is shuffled, and
        per-class counts are printed.

        :param opt: options dict.
        """
        counts = {
            side: {gend_utils.UNKNOWN: 0, gend_utils.FEM: 0, gend_utils.MASC: 0}
            for side in ('partner', 'self')
        }

        if opt['datatype'].split(':')[0] == 'test':
            # NOTE(review): nothing in this method changes the datatype given to
            # OrigConvai2Teacher; confirm the switch to valid happens there.
            warn_once('No test set; switching to valid')

        # build data
        print('[ Building data ... ]')
        new_eps = []
        orig_teacher = OrigConvai2Teacher(opt)
        total_exs = orig_teacher.num_examples()
        num_exs = 0
        while num_exs < total_exs:
            current_episode = []
            episode_done = False

            while not episode_done:
                # TODO: eventually all teachers should return Messages, so
                # we should assert this
                action = Message(orig_teacher.act())
                current_episode.append(action)
                episode_done = action.get('episode_done', False)
                num_exs += 1

            # now we have the entire episode,... do something
            first_ex = current_episode[0]
            first_ex_text = []
            partner_persona = []
            your_persona = []
            for line in first_ex['text'].split('\n'):
                # NOTE: we flip "your" and "partner" here since we are taking the 'text'
                # field instead of the 'label'
                if 'partner\'s persona: ' in line:
                    your_persona.append(line.split('partner\'s persona: ')[1])
                elif 'your persona: ' in line:
                    partner_persona.append(line.split('your persona: ')[1])
                else:
                    first_ex_text.append(line)

            your, your_prob, partner, partner_prob = self.get_genders(
                your_persona, partner_persona)

            for i, ex in enumerate(current_episode):
                counts['self'][your] += 1
                counts['partner'][partner] += 1
                if i == 0:
                    text = '\n'.join(first_ex_text)
                else:
                    text = ex['text']
                new_ex = {
                    'text': text,
                    'episode_done': True,
                    'your_persona': '\n'.join(your_persona),
                    'partner_persona': '\n'.join(partner_persona),
                    'id': 'ConvAI2 Gender',
                }
                if not self.use_probably:
                    new_ex['partner_prob'] = partner_prob
                    new_ex['your_prob'] = your_prob

                if your is not None and self.labels_to_use != 'partner':
                    # Get the your task
                    labels = [f'SELF:{your}']
                    your_ex = deepcopy(new_ex)
                    your_ex['labels'] = labels
                    your_ex['class_type'] = 'self'
                    new_eps.append(your_ex)

                if partner is not None and self.labels_to_use != 'self':
                    # Get the partner task
                    labels = [f'PARTNER:{partner}']
                    partner_ex = deepcopy(new_ex)
                    partner_ex['labels'] = labels
                    partner_ex['class_type'] = 'partner'
                    new_eps.append(partner_ex)

        if self.labels_to_use == 'all' and self.add_unknown_classes:
            # load about data
            all_about_data = gend_utils.get_inferred_about_data(
                self.opt['task'], self.opt)
            sample_rate = self.opt['unknown_temp']
            if sample_rate < 1.0:
                to_samp = int(sample_rate * len(all_about_data))
                sampled = random.sample(all_about_data, to_samp)
                new_eps += sampled
            else:
                new_eps += all_about_data

        if self.is_train:
            random.shuffle(new_eps)

        self.data = new_eps
        print(f'Missing cnt: {self.missing_cnt} / {len(self.data) * 2}')
        for x in ['self', 'partner']:
            print(f'Totals for {x}:')
            subtot = sum(counts[x].values())
            for k, v in counts[x].items():
                # Guard against an empty dataset, which used to raise
                # ZeroDivisionError here.
                ratio = v / subtot if subtot else 0.0
                print(f'\t{k}: {v} ({ratio})')
Example #23
0
def build_data(opt):
    """
    Build (or locate) the serialized data used by the pytorch data teacher.

    If ``pytorch_teacher_task`` is not set, the data must already exist on
    disk: the path in ``pytorch_datapath`` is returned if it is a file. When
    that path does not contain ``pytorch``, ``.pytorch`` is appended first,
    followed by the agent id if preprocessing is enabled.

    Otherwise the teacher task is streamed in order. Each example is turned
    into a single-turn episode whose text includes the dialogue context seen so
    far in its episode. It is optionally preprocessed by the agent's
    ``observe``. The results are written as one JSON object per line to
    ``<datapath>/data``. A ``char_index`` file (start offset of every example
    in that file) and a ``data_length`` file (episode/example counts) are
    written alongside it.

    :param opt:
        options dict. NOTE: mutated in place (``dict_file`` is always set and
        ``model`` is set to ``'repeat_label'`` when missing/falsy).

    :return:
        path to the pre-built data file (no teacher task given) or to the
        directory holding the built data.

    :raises Exception:
        if no teacher task is given and either ``--pytorch-datapath`` is not
        set or the data file does not exist.
    """
    if not opt.get('model', False):
        opt['model'] = 'repeat_label'
    preprocess = opt.get('pytorch_preprocess', True)
    opt['dict_file'] = get_pyt_dict_file(opt)
    dictionary = None
    if 'dict_maxexs' in opt:
        # Note: only build dictionary if dict loop args specified
        dictionary = build_dict(opt, skip_if_built=True)
    agent = create_agent(opt)

    # If build teacher not specified, we are simply looking for the file
    if not opt.get('pytorch_teacher_task', None):
        df = opt.get('pytorch_datapath')
        # check if the user set a datafile
        if not df:
            raise Exception(
                'Tried to find data but `--pytorch-datapath` is not set')
        # check if the user provided the already built file
        if 'pytorch' not in df:
            df += '.pytorch' + (agent.getID() if preprocess else '')
        if not os.path.isfile(df):
            raise Exception('Tried to find data but it is not built, please '
                            'specify `--pytorch-teacher-task`')
        return df

    ordered_opt = copy.deepcopy(opt)
    # we use streaming to build the data: one ordered example at a time
    dt = opt['datatype'].split(':')[0]
    ordered_opt['datatype'] = dt + ':ordered:stream'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    ordered_opt['task'] = ordered_opt['pytorch_teacher_task']
    # The dataset option is not always present; a bare pop() would raise
    # KeyError in that case.
    ordered_opt.pop('pytorch_teacher_dataset', None)
    ordered_opt['no_cuda'] = True
    world_data = create_task(ordered_opt, agent)
    teacher = world_data.get_task_agent()
    # use the agent instance held by the world from here on
    agent = world_data.agents[1]
    datapath = os.path.join(
        opt.get('datapath', '.'),
        '{}_pyt_data'.format(ordered_opt['task'].replace(':', '_')),
        dt,
    )
    if preprocess:
        datapath += '_{}_preprocess'.format(agent.getID().replace(':', '_'))
    # 'data_length' is the last file written below, so its presence is used
    # as the "already built" marker.
    if os.path.isdir(datapath) and 'data_length' in os.listdir(datapath):
        # Data already built
        print("[ pytorch data already built, at {}. ]".format(datapath))
        return datapath
    print('----------\n[ setting up pytorch data, saving to {}/ ]\n----------'.
          format(datapath))
    os.makedirs(datapath, exist_ok=True)
    num_eps = 0
    num_exs = 0
    current = []
    episode_done = False
    include_labels = opt.get('pytorch_include_labels', True)
    context_length = opt.get('pytorch_context_length', -1)
    # a non-positive context length means the context is unbounded
    context = deque(maxlen=context_length if context_length > 0 else None)
    total_exs = world_data.num_examples()
    pbar = tqdm.tqdm(total=total_exs,
                     unit='ex',
                     unit_scale=True,
                     desc='Building pytorch data')
    idx_to_char = []
    cumulative_char_len = 0
    # pass examples to dictionary
    with open(os.path.join(datapath, 'data'), 'w') as pytorch_data:
        while num_exs < total_exs:
            # collect one complete episode from the teacher
            while not episode_done:
                # TODO: eventually all teachers should return Messages, so
                # we should assert this
                action = Message(teacher.act())
                current.append(action)
                episode_done = action.get('episode_done', False)

            # build separate episodes: every example becomes its own
            # single-turn episode carrying the preceding context in 'text'
            for ex in current:
                context.append(ex.get('text', ''))
                if len(context) > 1:
                    ex.force_set('text', '\n'.join(context))
                ex.force_set('episode_done', True)
                labels = ex.get('labels', ex.get('eval_labels', None))
                if labels is not None and include_labels:
                    # one randomly chosen label joins the context seen by
                    # the following examples of this episode
                    context.append(random.choice(labels))
                # generate observation from new example
                if preprocess:
                    ex = agent.observe(ex)
                    ex.pop('label_candidates', '')
                    ex['preprocessed'] = True
                    if hasattr(agent, 'self_observe') and 'labels' in ex:
                        # Lie to the agent and tell it that it spoke the gold label
                        agent.self_observe(
                            Message({'text': random.choice(ex['labels'])}))
                num_eps += 1
                num_exs += 1
                pbar.update(1)
                # In text mode write() returns the number of characters
                # written, so idx_to_char records each example's start offset.
                ex_len = pytorch_data.write(
                    json.dumps(make_serializable(ex)) + "\n")
                idx_to_char.append(cumulative_char_len)
                cumulative_char_len += ex_len
            # reset per-episode state before reading the next episode
            episode_done = False
            current.clear()
            context.clear()
            agent.reset()
    pbar.close()
    with open(os.path.join(datapath, 'char_index'), 'w') as char_index:
        json.dump(idx_to_char, char_index)
    with open(os.path.join(datapath, 'data_length'), 'w') as pytorch_data_len:
        pytorch_data_len.write(
            json.dumps({
                'num_eps': num_eps,
                'num_exs': num_exs
            }))
    if dictionary:
        dictionary.save(get_pyt_dict_file(opt), sort=True)

    print('[ pytorch data built. ]')
    return datapath
Example #24
0
 def get_label(self, message: Message) -> str:
     """
     Return the search-decision label for ``message``.

     ``DO_NOT_SEARCH`` when the message's ``checked_sentence`` equals
     ``'no_passages_used'``; ``DO_SEARCH`` in every other case, including a
     missing ``checked_sentence``.
     """
     if message.get('checked_sentence') == 'no_passages_used':
         return DO_NOT_SEARCH
     return DO_SEARCH
Example #25
0
 def get_retrieved_knowledge(self, message: Message) -> List[Document]:
     """
     Retrieve knowledge for ``message`` unless retrieval is disabled for it.

     A truthy ``skip_retrieval`` flag on the message yields an empty list;
     otherwise retrieval is delegated to the parent class.
     """
     skip = message.get('skip_retrieval')
     if not skip:
         return super().get_retrieved_knowledge(message)
     return []
Example #26
0
    def parley(self):
        """
        Run one turn of the evaluation chat.

        On turn 1 every agent receives its persona text and start instructions.
        When ``turn_idx`` equals ``n_turn + 1`` every agent is told it may end
        the chat. On turn 1 with ``other_first`` set, the other side speaks
        first. That side is the model agent if present (it is first shown its
        persona), otherwise the second human.

        Then the evaluating human acts:

        * If they end the episode and ``turn_idx >= n_turn``, the dialogue is
          packed into ``dialog_list``, ``parallel_eval_mode()`` runs,
          ``chat_done`` is set and all agents are told the chat ended.
        * An ``episode_done`` before ``n_turn`` simply returns.
        * Otherwise the other side replies. It is skipped when ``other_first``
          is set and ``turn_idx >= n_turn``.

        If either human times out, the remaining human (if any) is sent the
        disconnect message and the method returns early.
        """
        self.turn_idx += 1
        print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))

        # If at first turn, we need to give each agent their persona
        if self.turn_idx == 1:
            for idx, agent in enumerate(self.agents):
                persona_text = ''
                for s in self.personas[idx]:
                    persona_text += ('<b><span style="color:blue">'
                                     '{}\n</span></b>'.format(s.strip()))
                control_msg = self.get_control_msg()
                control_msg['persona_text'] = persona_text
                control_msg['text'] = self.get_instruction(tag='start',
                                                           agent_id=agent.id)
                # TODO: check that get instruction actually exists?
                agent.observe(validate(control_msg))
                # NOTE(review): fixed 3s pause after the first agent only;
                # the reason is not visible from here.
                if idx == 0:
                    time.sleep(3)

        # If we get to the min turns, inform turker that they can end if they
        # want.
        if self.turn_idx == self.n_turn + 1:
            for idx, agent in enumerate(self.agents):
                control_msg = self.get_control_msg()
                # NOTE(review): passes the positional index here, whereas the
                # 'start' call above passes agent_id=agent.id; confirm which
                # argument get_instruction expects.
                control_msg['text'] = self.get_instruction(
                    idx, tag='exceed_min_turns')
                control_msg['exceed_min_turns'] = True
                agent.observe(validate(control_msg))

        # Otherwise, we proceed accordingly.
        # Other agent first
        if self.other_first and self.turn_idx == 1:
            if self.model_agent is not None:
                # Model must observe its persona
                persona_act = {
                    'text': '\n'.join([self.model_persona_text,
                                       '__SILENCE__']),
                    'episode_done': False,
                }
                self.model_agent.observe(persona_act)
                self.bot_seen_persona = True
                model_act = copy.deepcopy(self.model_agent.act())
                model_act.force_set('text', normalize_reply(model_act['text']))
                model_act.force_set('id', 'PERSON_2')
                self.dialog.append((1, model_act.get('text')))
                _random_delay()
                self.eval_agent.observe(_strip_tensors(model_act))
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return
                else:
                    self.dialog.append((1, act.get('text')))
                    act = copy.deepcopy(act)
                    act.force_set('text', normalize_reply(act['text']))
                    self.eval_agent.observe(act)

        # Eval agent turn
        act = Message(self.get_human_agent_act(self.eval_agent))
        timeout = self.check_timeout(act)
        if timeout:
            if self.model_agent is None:
                control_msg = self.get_control_msg()
                control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                self.other_agent.observe(validate(control_msg))
            return

        if act['episode_done']:
            if self.turn_idx >= self.n_turn:
                # Pair up consecutive utterances; when the other side spoke
                # first, its opening line stands alone as the first entry.
                if not self.other_first:
                    self.dialog_list = [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(0, len(self.dialog), 2)
                    ]
                else:
                    self.dialog_list = [' \n' + self.dialog[0][1]] + [
                        '\n'.join([self.dialog[i][1], self.dialog[i + 1][1]])
                        for i in range(1,
                                       len(self.dialog) - 1, 2)
                    ]
                self.parallel_eval_mode()

                self.chat_done = True
                for ag in self.agents:
                    control_msg = self.get_control_msg()
                    control_msg['text'] = CHAT_ENDED_MSG
                    ag.observe(validate(control_msg))
            return

        self.dialog.append((0, act['text']))

        if not self.bot_seen_persona and self.model_agent is not None:
            # Add persona for model to observe
            act.force_set('text',
                          '\n'.join([self.model_persona_text, act['text']]))
            self.bot_seen_persona = True
        if self.model_agent is not None:
            self.model_agent.observe(act)
        else:
            act = copy.deepcopy(act)
            act.force_set('text', normalize_reply(act['text']))
            self.other_agent.observe(act)

        # Model_agent turn
        if not self.other_first or self.turn_idx < self.n_turn:
            if self.model_agent is not None:
                _random_delay()
                act = _strip_tensors(copy.deepcopy(self.model_agent.act()))
                act.force_set('text', normalize_reply(act['text']))
                act.force_set('id', 'PERSON_2')
                # NOTE: your model may or may not need to observe itself here
                # If it does, call model_observes_itself or some other specialized
                # function
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return

            self.dialog.append((1, act.get('text')))
            act = copy.deepcopy(act)
            # NOTE(review): for the model branch the text was already
            # normalized above, so normalize_reply runs twice on it.
            act.force_set('text', normalize_reply(act['text']))
            self.eval_agent.observe(act)
Example #27
0
>>>>>>> 4f6b99642d60aff1a41b9eae8bd2ccd9e40ebba4
>>>>>>> ef574cebef2a8d5aa38b73176b1e71a919d6670f
                # NOTE: your model may or may not need to observe itself here
                # If it does, call model_observes_itself or some other specialized
                # function
            else:
                act = self.get_human_agent_act(self.other_agent)
                timeout = self.check_timeout(act)
                if timeout:
                    # eval agent early disconnect
                    control_msg = self.get_control_msg()
                    control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                    self.eval_agent.observe(validate(control_msg))
                    return

            self.dialog.append((1, act.get('text')))
            act = copy.deepcopy(act)
<<<<<<< HEAD
            act.force_set('text', self.format_model_reply(act['text']))
=======
<<<<<<< HEAD
<<<<<<< HEAD
            act.force_set('text', self.format_model_reply(act['text']))
=======
            act['text'] = self.format_model_reply(act['text'])
>>>>>>> 4f6b99642d60aff1a41b9eae8bd2ccd9e40ebba4
>>>>>>> origin/master
=======
            act['text'] = self.format_model_reply(act['text'])
>>>>>>> 4f6b99642d60aff1a41b9eae8bd2ccd9e40ebba4
>>>>>>> ef574cebef2a8d5aa38b73176b1e71a919d6670f
Example #28
0
    def observe(self, observation: Message) -> Dict[str, Message]:
        """
        Observe in 3 out of the 4 modules.

        Steps:

        * The incoming message is normalized: converted to a ``Message`` if
          needed, ``label_candidates``/``knowledge`` removed, and
          ``knowledge_response`` set from ``checked_sentence``.
        * It is optionally passed through the knowledge-module mutators.
        * It is then observed by the knowledge agent.
        * The search query agent also observes it, if present.
        * The search decision agent also observes it, if present and the
          search decision is ``SearchDecision.COMPUTE``.

        NOTE: when ``observation`` is already a ``Message`` it is modified in
        place by the normalization step.

        :param observation:
            incoming message

        :return self.observation:
            returned observation is actually a dictionary mapping
            agent module name to the corresponding observation. The
            ``search_query_agent`` / ``search_decision_agent`` entries are
            ``None`` when that module did not observe. ``raw`` is a deep copy
            of the normalized message taken before the mutators ran.
        """
        if not isinstance(observation, Message):
            observation = Message(observation)
        for key in ['label_candidates', 'knowledge']:
            # Delete unnecessarily large keys
            observation.pop(key, '')
        # Message.__setitem__ refuses to overwrite an existing key, and this
        # message may already carry one (e.g. the same Message observed twice,
        # since it is mutated in place above), so force the assignment.
        observation.force_set(
            'knowledge_response', observation.get('checked_sentence', '')
        )

        raw_observation = copy.deepcopy(observation)
        # This part is *specifically* for document chunking.
        if self.krm_mutators:
            observation = observation.copy()
            for mutator in self.krm_mutators:
                assert isinstance(mutator, MessageMutator), "not message mutator"
                observation = next(mutator([observation]))

        knowledge_observation = self.knowledge_agent.observe(observation)
        knowledge_observation['prior_knowledge_responses'] = ' '.join(
            self.knowledge_responses
        )
        # reset the accumulated knowledge responses at the end of an episode
        if observation.get('episode_done'):
            self.knowledge_responses = ['__SILENCE__']
        search_query_observation = None
        if self.search_query_agent:
            sqm_obs = copy.deepcopy(observation)
            if self.opt['search_query_control_token']:
                sqm_obs.force_set(
                    'temp_history', f" {self.opt['search_query_control_token']}"
                )
            sqm_obs.force_set('skip_retrieval', True)
            search_query_observation = self.search_query_agent.observe(sqm_obs)

        search_decision_observation = None
        if (
            self.search_decision_agent
            and self.search_decision is SearchDecision.COMPUTE
        ):
            assert (
                self.search_decision_agent.history.size == 1
            ), "wrong history size! set --sdm-history-size 1"
            sdm_obs = copy.deepcopy(observation)
            if self.opt['search_decision_control_token']:
                sdm_obs.force_set(
                    'temp_history', f" {self.opt['search_decision_control_token']}"
                )
            sdm_obs.force_set('skip_retrieval', True)
            search_decision_observation = self.search_decision_agent.observe(sdm_obs)

        observations = {
            'raw': raw_observation,
            'knowledge_agent': knowledge_observation,
            'search_query_agent': search_query_observation,
            'search_decision_agent': search_decision_observation,
        }
        self.observations = observations
        return observations
Example #29
0
    def parley(self):
        """
        The main function that controls the logic of the task. Uses self.task_turn_idx
        to control the sequence of the conversation.

        Specifically, when self.task_turn_idx is even, we know that the bots just gave
        their potential responses, and that it is the human's turn to choose one of the
        responses and give a justification value.

        When self.task_turn_idx is odd, we know that the human just chose one of the
        bots' responses, and now needs to respond to that response.

        self.task_turn_idx is initially 0, and during _run_initial_turn() the UI is
        redrawn to have the human select between the bots' responses. Then,
        self.task_turn_idx is incremented to 1.

        During self.agent.observe(), the UI is redrawn for the following human input,
        and during self.agent.act(), the code awaits the human input.

        After the initial turn, each call runs one full cycle:

        1. Read the human's choice and justification.
        2. Have both bots observe the accepted response and record it in
           self.dialog.
        3. Send it to the frontend.
        4. Unless the frontend reports 'finished', read the human's reply.
        5. Have both bots respond.
        6. Send both responses to the frontend in random top/bottom order.

        self.task_turn_idx is therefore incremented twice per full cycle. When
        the chat is finished it is incremented once, final chat data is
        collected, and the worker is granted self.block_qualification if
        acceptability violations were reported.
        """

        logging.info(
            f'{self.__class__.__name__}:{self.tag}: is at task_turn_idx '
            f'{self.task_turn_idx}, with {self.num_turns} pairs of turns needed...'
        )

        if self.task_turn_idx == 0:
            self._run_initial_turn()
            self.task_turn_idx += 1
            return

        logging.info(
            f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: '
            f'{self.task_turn_idx}')

        # At this point, we know that the human now needs to respond to the bot's
        # response that the human just chose

        # We retrieve information regarding the human's choice and justification using
        # self.agent.act()

        human_choose_bot_response_act = self.agent.act(
            timeout=self.max_resp_time)
        # NOTE(review): the act is normalized through Compatibility.maybe_fix_act
        # and converted with json_safe_payload(); it is indexed like a dict below.
        human_choose_bot_response_act = Message(
            Compatibility.maybe_fix_act(
                human_choose_bot_response_act)).json_safe_payload()

        logging.info(
            f'Got act for human, act was: {human_choose_bot_response_act} and '
            f'self.task_turn_idx: {self.task_turn_idx}.')

        # Unpack the human's choice from the frontend's task_data payload
        accepted_bot_response = human_choose_bot_response_act['task_data'][
            'accepted_bot_response']
        accepted_bot_id = human_choose_bot_response_act['task_data'][
            'accepted_bot_id']
        accepted_bot_justification_value = human_choose_bot_response_act[
            'task_data']['justification_value']

        not_accepted_bot_response = human_choose_bot_response_act['task_data'][
            'not_accepted_bot_response']
        not_accepted_bot_id = human_choose_bot_response_act['task_data'][
            'not_accepted_bot_id']

        # We have both bots observe the accepted bot's response so that the conversation
        # history stays the same

        self.bots[0].observe(accepted_bot_response)
        self.bots[1].observe(accepted_bot_response)

        task_data = {}

        # Only the text before the first '<br>' is recorded; presumably the
        # remainder is annotation HTML (see the human utterance below) — confirm.
        accepted_bot_utterance_data = {
            'text': accepted_bot_response['text'].split('<br>')[0],
            'id': accepted_bot_id,
        }
        not_accepted_bot_utterance_data = {
            'text': not_accepted_bot_response['text'].split('<br>')[0],
            'id': not_accepted_bot_id,
        }
        bot_utterance_data = {
            'agent_idx': 1,
            'accepted_bot_data': accepted_bot_utterance_data,
            'not_accepted_bot_data': not_accepted_bot_utterance_data,
            'human_choice': accepted_bot_id,
            'human_justification': accepted_bot_justification_value,
        }
        self.dialog.append(bot_utterance_data)

        self._postprocess_acts(acts=None, agent_idx=0)

        # All logic and processing for this step has now been done, so we do
        # self.agent.observe() to send the accepted response back to the frontend to
        # display and update task turn index, as well as await for the next action,
        # which is the human typing their response

        task_data['task_turn_idx'] = self.task_turn_idx
        # The UI will ask the human to respond to the chosen bot response
        self.agent.observe({
            'text': accepted_bot_response['text'],
            'task_data': task_data
        })

        # Make self.task_turn_idx even now
        self.task_turn_idx += 1

        # Check for whether 6 pairs of turns has been done, since the last message of a
        # conversation should always be the bot's response

        # NOTE(review): the `is not None` test cannot be False here, because
        # the act was already indexed above; the effective check is whether the
        # frontend set task_data['finished'].
        if (human_choose_bot_response_act is not None
                and human_choose_bot_response_act.get(
                    'task_data', {}).get('finished') is not None):
            self.chat_done = True
            # agent ends chat after exceeding minimum number of turns

            # Bot has just responded. Any problem data received now will be
            # regarding this bot's utterance

            # Get the final chat data
            self.final_chat_data = self.get_final_chat_data()

            # Soft-block the worker if there were acceptability violations
            acceptability_violations = self.final_chat_data[
                'acceptability_violations'][0]
            if acceptability_violations is not None and acceptability_violations != '':
                logging.info(f'**NOTE** Acceptability violations detected: '
                             f'{acceptability_violations}')
                # Grant the failed qualification
                self.agent.mephisto_agent.get_worker().grant_qualification(
                    self.block_qualification, 1)

            return

        logging.info(
            f'[human agent] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: '
            f'{self.dialog}')

        # NOTE(review): duplicates the earlier log line; it still reports the
        # human's *choice* act, not the reply that is read next.
        logging.info(
            f'Got act for human, act was: {human_choose_bot_response_act} and '
            f'self.task_turn_idx: {self.task_turn_idx}.')

        # At this point, we know that the human now needs to respond to the bot's
        # response that the human just chose

        # We retrieve information regarding the human's response using self.agent.act()

        human_response_act = self.agent.act(timeout=self.max_resp_time)

        # Again, we have both bots observe the human response so that the conversation
        # history stays the same
        self.bots[0].observe(validate(human_response_act))
        self.bots[1].observe(validate(human_response_act))

        # Check that the models' conversation histories are the same
        # (this is an assert, so the check is skipped under `python -O`)
        bot_1_history = self.bots[0].model_agent.history.history_strings
        bot_2_history = self.bots[1].model_agent.history.history_strings

        assert (
            bot_1_history == bot_2_history
        ), f"The two bots' conversation histories are different.\nBot 1 history: {bot_1_history}\nBot 2 history: {bot_2_history}"

        # After the bots have observed the human response, it's time for them to produce
        # their response to the human using self.bots.act()

        bot_1_response = self.bots[0].act()
        bot_1_response = Compatibility.maybe_fix_act(bot_1_response)

        bot_2_response = self.bots[1].act()
        bot_2_response = Compatibility.maybe_fix_act(bot_2_response)

        # We display the result to the frontend randomly so there is no selection bias.
        # Also, we attach our result to task_data to send arbitrary data to the frontend

        # Both branches build the same payload; only the top/bottom placement of
        # the two bots is swapped.
        if random.random() > 0.5:
            task_data = {
                'top_bot_data': {
                    'top_bot_id': self.bots[0].worker_id,
                    'top_bot_response': bot_1_response,
                },
                'bottom_bot_data': {
                    'bottom_bot_id': self.bots[1].worker_id,
                    'bottom_bot_response': bot_2_response,
                },
                'task_turn_idx': self.task_turn_idx,
            }
        else:
            task_data = {
                'top_bot_data': {
                    'top_bot_id': self.bots[1].worker_id,
                    'top_bot_response': bot_2_response,
                },
                'bottom_bot_data': {
                    'bottom_bot_id': self.bots[0].worker_id,
                    'bottom_bot_response': bot_1_response,
                },
                'task_turn_idx': self.task_turn_idx,
            }

        human_utterance_data = {
            'agent_idx':
            0,
            # Get rid of annotations HTML if it's the bot response
            'text':
            human_response_act['text'].split('<br>')[0],
            'id':
            human_response_act['id'] if 'id' in human_response_act else
            'NULL_ID',  # Person1 or Polyencoder
        }

        self.dialog.append(human_utterance_data)

        # Human has just responded. Any problem data received now will be regarding the
        # bot's prior utterance
        p = human_response_act['task_data'].get(
            'problem_data_for_prior_message')
        if p is not None:
            turn_idx = -2
            # Attach the problem data to the second-to-last utterance, since the last
            # utterance is what the human just said
            # (double-underscore name: resolved via name mangling to a method
            # defined in this same class)
            self.__add_problem_data_to_utterance(p, turn_idx=turn_idx)

        self._postprocess_acts(acts=None, agent_idx=0)

        task_data['task_turn_idx'] = self.task_turn_idx

        # All logic and processing for this step has now been done, so we do
        # self.agent.observe() to send the two bots' responses back to the frontend to
        # display and update task turn index, as well as await for the next action,
        # which is the human choosing from the two responses and providing a
        # justification value

        # The UI will ask the human to choose between two bot responses and give a
        # justification
        logging.info(f'*** self.task_turn_idx: {self.task_turn_idx} ***')
        self.agent.observe({'text': '', 'task_data': task_data})

        # Make self.task_turn_idx odd now
        self.task_turn_idx += 1

        logging.info(
            f'[bot agent] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: '
            f'{self.dialog}')
Example #30
0
 def get_label(self, message: Message) -> str:
     """
     Return the search-decision label for ``message``.

     The message's selected sentences are joined with single spaces. If the
     result is exactly ``CONST.NO_SELECTED_SENTENCES_TOKEN`` the label is
     ``DO_NOT_SEARCH``; otherwise it is ``DO_SEARCH``. If the key is missing,
     ``get`` returns None and the join raises ``TypeError``.
     """
     joined = ' '.join(message.get(CONST.SELECTED_SENTENCES))
     if joined == CONST.NO_SELECTED_SENTENCES_TOKEN:
         return DO_NOT_SEARCH
     return DO_SEARCH