Example #1
def span_tokens_to_text(source_text, tokens, span_start, span_end):
    text_with_sentiment_tokens = tokens
    predicted_start = span_start
    predicted_end = span_end

    while (predicted_start >= 0
           and text_with_sentiment_tokens[predicted_start].idx is None):
        predicted_start -= 1
    if predicted_start < 0:
        character_start = 0
    else:
        character_start = text_with_sentiment_tokens[predicted_start].idx

    while (predicted_end < len(text_with_sentiment_tokens)
           and text_with_sentiment_tokens[predicted_end].idx is None):
        predicted_end += 1

    if predicted_end >= len(text_with_sentiment_tokens):
        # No end offset was found; fall back to the end of the source text.
        character_end = len(source_text)
    else:
        end_token = text_with_sentiment_tokens[predicted_end]
        if end_token.idx == 0:
            character_end = (end_token.idx +
                             len(sanitize_wordpiece(end_token.text)) + 1)
        else:
            character_end = end_token.idx + len(
                sanitize_wordpiece(end_token.text))

    best_span_string = source_text[character_start:character_end].strip()
    return best_span_string
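
A minimal usage sketch for the helper above. The Token namedtuple and the
sanitize_wordpiece stub below are simplified stand-ins for the library's own
types; special tokens carry idx=None, which is exactly the case the
backing-off loops handle:

from collections import namedtuple

# Simplified stand-ins for the library's Token class and sanitize_wordpiece.
Token = namedtuple("Token", ["text", "idx"])

def sanitize_wordpiece(text):
    # Strip the BERT-style "##" continuation marker; the real helper also
    # handles other tokenizer conventions.
    return text[2:] if text.startswith("##") else text

source = "unbelievable story"
tokens = [
    Token("[CLS]", None),   # special tokens carry no character offset
    Token("un", 0),
    Token("##believ", 2),
    Token("##able", 8),
    Token("story", 13),
    Token("[SEP]", None),
]

# Token span (1, 3) covers "un ##believ ##able".
print(span_tokens_to_text(source, tokens, 1, 3))  # -> "unbelievable"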
Example #2
    def span_tokens_to_text(source_text, tokens, span_start, span_end):
        text_with_sentiment_tokens = tokens
        predicted_start = span_start
        predicted_end = span_end

        while (predicted_start >= 0
               and text_with_sentiment_tokens[predicted_start].idx is None):
            predicted_start -= 1
        if predicted_start < 0:
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_start].text}' at index "
                f"'{span_start}' to an offset in the original text.")
            character_start = 0
        else:
            character_start = text_with_sentiment_tokens[predicted_start].idx

        while (predicted_end < len(text_with_sentiment_tokens)
               and text_with_sentiment_tokens[predicted_end].idx is None):
            predicted_end += 1

        if predicted_end >= len(text_with_sentiment_tokens):
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_end].text}' at index "
                f"'{span_end}' to an offset in the original text.")
            character_end = len(source_text)
        else:
            end_token = text_with_sentiment_tokens[predicted_end]
            if end_token.idx == 0:
                character_end = (end_token.idx +
                                 len(sanitize_wordpiece(end_token.text)) + 1)
            else:
                character_end = end_token.idx + len(
                    sanitize_wordpiece(end_token.text))

        best_span_string = source_text[character_start:character_end].strip()
        return best_span_string
Example #3
    def _collect_best_span_strings(
        self,
        best_spans: torch.Tensor,
        context_span: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        cls_index: Optional[torch.LongTensor],
    ) -> Tuple[List[str], torch.Tensor]:
        """
        Collect the string of the best predicted span from the context metadata and
        update `self._per_instance_metrics`, which in the case of SQuAD v1.1 / v2.0
        includes the EM and F1 score.

        This returns a `Tuple[List[str], torch.Tensor]`, where the `List[str]` is the
        predicted answer for each instance in the batch, and the tensor is just the input
        tensor `best_spans` after adjustments so that each answer span corresponds to the
        context tokens only, and not the question tokens. Spans that correspond to the
        `[CLS]` token, i.e. the question was predicted to be impossible, will be set
        to `(-1, -1)`.
        """
        _best_spans = best_spans.detach().cpu().numpy()

        best_span_strings = []
        for (metadata_entry, best_span, cspan,
             cls_ind) in zip(metadata, _best_spans, context_span, cls_index
                             or (0 for _ in range(len(metadata)))):
            context_tokens_for_question = metadata_entry["context_tokens"]

            if best_span[0] == cls_ind:
                # Predicting [CLS] is interpreted as predicting the question as unanswerable.
                best_span_string = ""
                # NOTE: even though we've "detached" 'best_spans' above, this still
                # modifies the original tensor in-place.
                best_span[0], best_span[1] = -1, -1
            else:
                best_span -= int(cspan[0])
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (predicted_start >= 0
                       and context_tokens_for_question[predicted_start].idx is
                       None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[
                        predicted_start].idx

                while (predicted_end < len(context_tokens_for_question) and
                       context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(
                        sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][
                    character_start:character_end]

            best_span_strings.append(best_span_string)

            answers = metadata_entry.get("answers")
            if answers:
                self._per_instance_metrics(best_span_string, answers)

        return best_span_strings, best_spans
Example #4
    def _tokenize(self, sentence_1: str, sentence_2: str = None):
        """
        This method works on both single sentences and sentence pairs.
        """

        encoded_tokens = self._tokenizer.encode_plus(
            text=sentence_1,
            text_pair=sentence_2,
            add_special_tokens=self._add_special_tokens,
            max_length=self._max_length,
            stride=self._stride,
            truncation_strategy=self._truncation_strategy,
            return_tensors=None,
        )
        # token_ids contains the final list of ids for both regular and special tokens
        token_ids, token_type_ids = encoded_tokens["input_ids"], encoded_tokens["token_type_ids"]

        tokens = []
        for token_id, token_type_id in zip(token_ids, token_type_ids):
            token_str = self._tokenizer.convert_ids_to_tokens(token_id, skip_special_tokens=False)
            tokens.append(Token(text=token_str, text_id=token_id, type_id=token_type_id))

        if self._calculate_character_offsets:
            # The huggingface tokenizers produce tokens that may or may not be slices from the original text.
            # Differences arise from lowercasing, Unicode normalization, and other kinds of normalization, as well
            # as special characters that are included to denote various situations, such as "##" in BERT for word
            # pieces from the middle of a word, or "Ġ" in RoBERTa for the beginning of words not at the start of a
            # sentence.
            # This code attempts to calculate character offsets while being tolerant to these differences. It
            # scans through the text and the tokens in parallel, trying to match up positions in both. If it
            # gets out of sync, it backs off to not adding any token indices, and attempts to catch back up
            # afterwards. This procedure is approximate. Don't rely on precise results, especially in non-English
            # languages that are far more affected by Unicode normalization.

            whole_text = sentence_1
            if sentence_2 is not None:
                whole_text += sentence_2  # Calculating character offsets with sentence pairs is sketchy at best.
            if self._tokenizer_lowercases:
                whole_text = whole_text.lower()

            min_allowed_skipped_whitespace = 3
            allowed_skipped_whitespace = min_allowed_skipped_whitespace

            text_index = 0
            token_index = 0
            while text_index < len(whole_text) and token_index < len(tokens):
                token_text = tokens[token_index].text
                if self._tokenizer_lowercases:
                    token_text = token_text.lower()
                token_text = sanitize_wordpiece(token_text)
                token_start_index = whole_text.find(token_text, text_index)

                # Did we not find it at all?
                if token_start_index < 0:
                    token_index += 1
                    # When we skip a token, we increase our tolerance, so we have a chance of catching back up.
                    allowed_skipped_whitespace += 1 + min_allowed_skipped_whitespace
                    continue

                # Did we jump too far?
                non_whitespace_chars_skipped = sum(
                    1 for c in whole_text[text_index:token_start_index] if not c.isspace()
                )
                if non_whitespace_chars_skipped > allowed_skipped_whitespace:
                    # Too many skipped characters. Something is wrong. Ignore this token.
                    token_index += 1
                    # When we skip a token, we increase our tolerance, so we have a chance of catching back up.
                    allowed_skipped_whitespace += 1 + min_allowed_skipped_whitespace
                    continue
                allowed_skipped_whitespace = min_allowed_skipped_whitespace

                tokens[token_index] = tokens[token_index]._replace(idx=token_start_index)
                text_index = token_start_index + len(token_text)
                token_index += 1

        return tokens
Example #5
    def _estimate_character_indices(
            self, text: str,
            token_ids: List[int]) -> List[Optional[Tuple[int, int]]]:
        """
        The huggingface tokenizers produce tokens that may or may not be slices from the
        original text.  Differences arise from lowercasing, Unicode normalization, and other
        kinds of normalization, as well as special characters that are included to denote
        various situations, such as "##" in BERT for word pieces from the middle of a word, or
        "Ġ" in RoBERTa for the beginning of words not at the start of a sentence.

        This code attempts to calculate character offsets while being tolerant to these
        differences. It scans through the text and the tokens in parallel, trying to match up
        positions in both. If it gets out of sync, it backs off to not adding any token
        indices, and attempts to catch back up afterwards. This procedure is approximate.
        Don't rely on precise results, especially in non-English languages that are far more
        affected by Unicode normalization.
        """

        token_texts = [
            sanitize_wordpiece(t)
            for t in self.tokenizer.convert_ids_to_tokens(token_ids)
        ]
        token_offsets: List[Optional[Tuple[int, int]]] = [None] * len(token_ids)
        if self._tokenizer_lowercases:
            text = text.lower()
            token_texts = [t.lower() for t in token_texts]

        min_allowed_skipped_whitespace = 3
        allowed_skipped_whitespace = min_allowed_skipped_whitespace

        text_index = 0
        token_index = 0
        while text_index < len(text) and token_index < len(token_ids):
            token_text = token_texts[token_index]
            token_start_index = text.find(token_text, text_index)

            # Did we not find it at all?
            if token_start_index < 0:
                token_index += 1
                # When we skip a token, we increase our tolerance, so we have a chance of catching back up.
                allowed_skipped_whitespace += 1 + min_allowed_skipped_whitespace
                continue

            # Did we jump too far?
            non_whitespace_chars_skipped = sum(
                1 for c in text[text_index:token_start_index]
                if not c.isspace())
            if non_whitespace_chars_skipped > allowed_skipped_whitespace:
                # Too many skipped characters. Something is wrong. Ignore this token.
                token_index += 1
                # When we skip a token, we increase our tolerance, so we have a chance of catching back up.
                allowed_skipped_whitespace += 1 + min_allowed_skipped_whitespace
                continue
            allowed_skipped_whitespace = min_allowed_skipped_whitespace

            token_offsets[token_index] = (
                token_start_index,
                token_start_index + len(token_text),
            )
            text_index = token_start_index + len(token_text)
            token_index += 1
        return token_offsets
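
The greedy matching loop above can be demonstrated in isolation. This
self-contained sketch (the estimate_offsets helper is hypothetical, and the
whitespace-tolerance bookkeeping is omitted for brevity) shows how tokens
that never occur in the text, such as special tokens, keep a None offset:

from typing import List, Optional, Tuple

def estimate_offsets(text: str, token_texts: List[str]) -> List[Optional[Tuple[int, int]]]:
    # Greedy left-to-right matching of token strings against the text.
    offsets: List[Optional[Tuple[int, int]]] = [None] * len(token_texts)
    text_index = 0
    for i, token_text in enumerate(token_texts):
        start = text.find(token_text, text_index)
        if start < 0:
            continue  # not found (e.g. a special token); leave the offset as None
        offsets[i] = (start, start + len(token_text))
        text_index = start + len(token_text)
    return offsets

print(estimate_offsets("the quick fox", ["[CLS]", "the", "quick", "fox", "[SEP]"]))
# -> [None, (0, 3), (4, 9), (10, 13), None]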
Example #6
    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question_with_context : Dict[str, Dict[str, torch.LongTensor]]
            From a ``TextField``. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and
            [SEP]) has type id 1.
        context_span : ``torch.IntTensor``
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which answers can come.
        answer_span : ``torch.IntTensor``, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the ``metadata`` list should be the
            batch size, and each dictionary should have the keys ``id``, ``question``, ``context``,
            ``question_tokens``, ``context_tokens``, and ``answers``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        best_span_scores : torch.FloatTensor
            The score for each of the best spans.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._text_field_embedder(question_with_context)
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context))
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = 1

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss for training.
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            start_loss = cross_entropy(span_start_logits,
                                       span_start,
                                       ignore_index=-1)
            if torch.any(start_loss > 1e9):
                logger.critical("Start loss too high (%r)", start_loss)
                logger.critical("span_start_logits: %r", span_start_logits)
                logger.critical("span_start: %r", span_start)
                assert False

            end_loss = cross_entropy(span_end_logits,
                                     span_end,
                                     ignore_index=-1)
            if torch.any(end_loss > 1e9):
                logger.critical("End loss too high (%r)", end_loss)
                logger.critical("span_end_logits: %r", span_end_logits)
                logger.critical("span_end: %r", span_end)
                assert False

            loss = (start_loss + end_loss) / 2

            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()

            output_dict["best_span_str"] = []
            context_tokens = []
            for metadata_entry, best_span in zip(metadata, best_spans):
                context_tokens_for_question = metadata_entry["context_tokens"]
                context_tokens.append(context_tokens_for_question)

                best_span -= 1 + len(metadata_entry["question_tokens"]) + 2
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (predicted_start >= 0
                       and context_tokens_for_question[predicted_start].idx is
                       None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[
                        predicted_start].idx

                while (predicted_end < len(context_tokens_for_question) and
                       context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(
                        sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][
                    character_start:character_end]
                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if answers:
                    self._per_instance_metrics(best_span_string, answers)
            output_dict["context_tokens"] = context_tokens
        return output_dict
Example #7
    def text_to_instance(
        self,
        text: str,
        sentiment: str,
        selected_text: Optional[str] = None,
    ) -> Instance:
        fields = {}
        text_tokens = self._tokenizer.tokenize(text)
        sentiment_tokens = self._tokenizer.tokenize(sentiment)
        # add special tokens

        text_with_sentiment_tokens = self._tokenizer.add_special_tokens(
            text_tokens, sentiment_tokens)
        tokens_field = TextField(text_with_sentiment_tokens,
                                 {"tokens": self._tokenindexer})
        fields["tokens"] = tokens_field

        additional_metadata = {}
        if selected_text is not None:
            context = text
            answer = selected_text
            additional_metadata["selected_text"] = selected_text
            first_answer_offset = context.find(answer)

            def tokenize_slice(start: int, end: int) -> Iterable[Token]:
                text_to_tokenize = context[start:end]
                if start - 1 >= 0 and context[start - 1].isspace():
                    prefix = (
                        "a "
                    )  # must end in a space, and be short so we can be sure it becomes only one token
                    wordpieces = self._tokenizer.tokenize(prefix +
                                                          text_to_tokenize)
                    for wordpiece in wordpieces:
                        if wordpiece.idx is not None:
                            wordpiece.idx -= len(prefix)
                    return wordpieces[1:]
                else:
                    return self._tokenizer.tokenize(text_to_tokenize)

            tokenized_context = []
            token_start = 0
            for i, c in enumerate(context):
                if c.isspace():
                    for wordpiece in tokenize_slice(token_start, i):
                        if wordpiece.idx is not None:
                            wordpiece.idx += token_start
                        tokenized_context.append(wordpiece)
                    token_start = i + 1
            for wordpiece in tokenize_slice(token_start, len(context)):
                if wordpiece.idx is not None:
                    wordpiece.idx += token_start
                tokenized_context.append(wordpiece)

            if first_answer_offset < 0:  # str.find returns -1 when the answer is absent
                (token_answer_span_start, token_answer_span_end) = (-1, -1)
            else:
                (
                    token_answer_span_start,
                    token_answer_span_end,
                ), _ = char_span_to_token_span(
                    [(t.idx, t.idx + len(sanitize_wordpiece(t.text)))
                     if t.idx is not None else None
                     for t in tokenized_context],
                    (first_answer_offset, first_answer_offset + len(answer)),
                )
            tags = ["O"] * len(tokens_field)
            for i in range(token_answer_span_start, token_answer_span_end + 1):
                tags[i] = "I"
            fields["tags"] = SequenceLabelField(tags, tokens_field)

        # make the metadata
        metadata = {
            "text": text,
            "sentiment": sentiment,
            "words": text,
            "text_with_sentiment_tokens": text_with_sentiment_tokens,
        }
        if additional_metadata:
            metadata.update(additional_metadata)
        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
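
The "a " prefix trick inside tokenize_slice exists because byte-level BPE
tokenizers such as RoBERTa's encode leading whitespace into the token itself.
A quick check with the transformers library (assuming "roberta-base" can be
downloaded) illustrates the difference the prefix makes:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# Tokenized on its own, the word looks sentence-initial: no leading-space marker.
print(tokenizer.tokenize("world"))    # ['world']

# With a short prefix, the word gets its mid-sentence form; dropping the first
# wordpiece recovers the correctly marked token.
print(tokenizer.tokenize("a world"))  # ['a', 'Ġworld']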
Example #8
    def make_instances(
        self,
        qid: str,
        question: str,
        answers: List[str],
        context: str,
        first_answer_offset: Optional[int],
        always_add_answer_span: bool = False,
    ) -> Iterable[Instance]:
        """
        Create training instances from a SQuAD example.
        """
        # tokenize context by spaces first, and then with the wordpiece tokenizer
        # For RoBERTa, this produces a bug where every token is marked as beginning-of-sentence. To fix it, we
        # detect whether a space comes before a word, and if so, add "a " in front of the word.
        def tokenize_slice(start: int, end: int) -> Iterable[Token]:
            text_to_tokenize = context[start:end]
            if start - 1 >= 0 and context[start - 1].isspace():
                prefix = "a "  # must end in a space, and be short so we can be sure it becomes only one token
                wordpieces = self._tokenizer.tokenize(prefix + text_to_tokenize)
                for wordpiece in wordpieces:
                    if wordpiece.idx is not None:
                        wordpiece.idx -= len(prefix)
                return wordpieces[1:]
            else:
                return self._tokenizer.tokenize(text_to_tokenize)

        tokenized_context = []
        token_start = 0
        for i, c in enumerate(context):
            if c.isspace():
                for wordpiece in tokenize_slice(token_start, i):
                    if wordpiece.idx is not None:
                        wordpiece.idx += token_start
                    tokenized_context.append(wordpiece)
                token_start = i + 1
        for wordpiece in tokenize_slice(token_start, len(context)):
            if wordpiece.idx is not None:
                wordpiece.idx += token_start
            tokenized_context.append(wordpiece)

        if first_answer_offset is None:
            (token_answer_span_start, token_answer_span_end) = (-1, -1)
        else:
            (token_answer_span_start, token_answer_span_end), _ = char_span_to_token_span(
                [
                    (t.idx, t.idx + len(sanitize_wordpiece(t.text))) if t.idx is not None else None
                    for t in tokenized_context
                ],
                (first_answer_offset, first_answer_offset + len(answers[0])),
            )

        # Tokenize the question.
        tokenized_question = self._tokenizer.tokenize(question)
        tokenized_question = tokenized_question[: self.max_query_length]

        # Stride over the context, making instances.
        space_for_context = (
            self.length_limit
            - len(tokenized_question)
            - len(self._tokenizer.sequence_pair_start_tokens)
            - len(self._tokenizer.sequence_pair_mid_tokens)
            - len(self._tokenizer.sequence_pair_end_tokens)
        )
        stride_start = 0
        while True:
            tokenized_context_window = tokenized_context[stride_start:]
            tokenized_context_window = tokenized_context_window[:space_for_context]

            window_token_answer_span = (
                token_answer_span_start - stride_start,
                token_answer_span_end - stride_start,
            )
            if any(i < 0 or i >= len(tokenized_context_window) for i in window_token_answer_span):
                # The answer is not contained in the window.
                window_token_answer_span = None

            if not self.skip_impossible_questions or window_token_answer_span is not None:
                additional_metadata = {"id": qid}
                instance = self.text_to_instance(
                    question,
                    tokenized_question,
                    context,
                    tokenized_context_window,
                    answers=answers,
                    token_answer_span=window_token_answer_span,
                    additional_metadata=additional_metadata,
                    always_add_answer_span=always_add_answer_span,
                )
                yield instance

            stride_start += space_for_context
            if stride_start >= len(tokenized_context):
                break
            stride_start -= self.stride
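
The windowing loop above advances by space_for_context and then backs up by
self.stride, so consecutive windows overlap. A standalone sketch of the same
arithmetic, with made-up sizes:

# Hypothetical sizes: 100 context tokens, room for 40 per window, overlap of 10.
num_context_tokens = 100
space_for_context = 40
stride = 10

stride_start = 0
while True:
    window_end = min(stride_start + space_for_context, num_context_tokens)
    print(f"window covers tokens [{stride_start}, {window_end})")
    stride_start += space_for_context
    if stride_start >= num_context_tokens:
        break
    stride_start -= stride  # back up so consecutive windows overlap

# window covers tokens [0, 40)
# window covers tokens [30, 70)
# window covers tokens [60, 100)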
Example #9
    def make_instances(
        self,
        qid: str,
        question: str,
        answers: List[str],
        context: str,
        first_answer_offset: Optional[int],
        always_add_answer_span: bool = False,
        is_training: bool = False,
        cached_tokenized_context: Optional[List[Token]] = None,
    ) -> Iterable[Instance]:
        """
        Create training instances from a SQuAD example.
        """
        if cached_tokenized_context is not None:
            # In training, we will use the same context in multiple instances, so we use
            # cached_tokenized_context to avoid duplicate tokenization
            tokenized_context = cached_tokenized_context
        else:
            # In prediction, no cached_tokenized_context is provided, so we tokenize context here
            tokenized_context = self._tokenize_context(context)

        if first_answer_offset is None:
            (token_answer_span_start, token_answer_span_end) = (-1, -1)
        else:
            (token_answer_span_start,
             token_answer_span_end), _ = char_span_to_token_span(
                 [(t.idx, t.idx + len(sanitize_wordpiece(t.text)))
                  if t.idx is not None else None for t in tokenized_context],
                 (first_answer_offset, first_answer_offset + len(answers[0])),
             )

        # Tokenize the question.
        tokenized_question = self._tokenizer.tokenize(question)
        tokenized_question = tokenized_question[:self.max_query_length]

        # Stride over the context, making instances.
        space_for_context = (self.length_limit - len(tokenized_question) -
                             len(self._tokenizer.sequence_pair_start_tokens) -
                             len(self._tokenizer.sequence_pair_mid_tokens) -
                             len(self._tokenizer.sequence_pair_end_tokens))
        stride_start = 0
        while True:
            tokenized_context_window = tokenized_context[stride_start:]
            tokenized_context_window = tokenized_context_window[:space_for_context]

            window_token_answer_span = (
                token_answer_span_start - stride_start,
                token_answer_span_end - stride_start,
            )
            if any(i < 0 or i >= len(tokenized_context_window)
                   for i in window_token_answer_span):
                # The answer is not contained in the window.
                window_token_answer_span = None

            if (not is_training or not self.skip_impossible_questions
                    or window_token_answer_span is not None):
                additional_metadata = {"id": qid}
                instance = self.text_to_instance(
                    question,
                    tokenized_question,
                    context,
                    tokenized_context_window,
                    answers=answers,
                    token_answer_span=window_token_answer_span,
                    additional_metadata=additional_metadata,
                    always_add_answer_span=always_add_answer_span,
                )
                yield instance

            stride_start += space_for_context
            if stride_start >= len(tokenized_context):
                break
            stride_start -= self.stride
Example #10
    def make_instances(
        self,
        qid: str,
        question: str,
        answers: List[str],
        context: str,
        first_answer_offset: Optional[int],
    ) -> Iterable[Instance]:
        # tokenize context by spaces first, and then with the wordpiece tokenizer
        # For RoBERTa, this produces a bug where every token is marked as beginning-of-sentence. To fix it, we
        # detect whether a space comes before a word, and if so, add "a " in front of the word.
        def tokenize_slice(start: int, end: int) -> Iterable[Token]:
            text_to_tokenize = context[start:end]
            if start - 1 >= 0 and context[start - 1].isspace():
                prefix = "a "  # must end in a space, and be short so we can be sure it becomes only one token
                wordpieces = self._tokenizer.tokenize(prefix +
                                                      text_to_tokenize)
                for wordpiece in wordpieces:
                    if wordpiece.idx is not None:
                        wordpiece.idx -= len(prefix)
                return wordpieces[1:]
            else:
                return self._tokenizer.tokenize(text_to_tokenize)

        tokenized_context = []
        token_start = 0
        for i, c in enumerate(context):
            if c.isspace():
                for wordpiece in tokenize_slice(token_start, i):
                    if wordpiece.idx is not None:
                        wordpiece.idx += token_start
                    tokenized_context.append(wordpiece)
                token_start = i + 1
        for wordpiece in tokenize_slice(token_start, len(context)):
            if wordpiece.idx is not None:
                wordpiece.idx += token_start
            tokenized_context.append(wordpiece)

        if first_answer_offset is None:
            (token_answer_span_start, token_answer_span_end) = (-1, -1)
        else:
            (token_answer_span_start,
             token_answer_span_end), _ = char_span_to_token_span(
                 [(t.idx, t.idx + len(sanitize_wordpiece(t.text)))
                  if t.idx is not None else None for t in tokenized_context],
                 (first_answer_offset, first_answer_offset + len(answers[0])),
             )

        # Tokenize the question
        tokenized_question = self._tokenizer.tokenize(question)
        tokenized_question = tokenized_question[:self.max_query_length]
        for token in tokenized_question:
            token.type_id = self.non_content_type_id
            token.idx = None

        # Stride over the context, making instances
        # Sequences are [CLS] question [SEP] [SEP] context [SEP], hence the - 4 for four special tokens.
        # This is technically not correct for anything but RoBERTa, but it does not affect the scores.
        space_for_context = self.length_limit - len(tokenized_question) - 4
        stride_start = 0
        while True:
            tokenized_context_window = tokenized_context[stride_start:]
            tokenized_context_window = tokenized_context_window[:space_for_context]

            window_token_answer_span = (
                token_answer_span_start - stride_start,
                token_answer_span_end - stride_start,
            )
            if any(i < 0 or i >= len(tokenized_context_window)
                   for i in window_token_answer_span):
                # The answer is not contained in the window.
                window_token_answer_span = None

            if not self.skip_invalid_examples or window_token_answer_span is not None:
                additional_metadata = {"id": qid}
                instance = self.text_to_instance(
                    question,
                    tokenized_question,
                    context,
                    tokenized_context_window,
                    answers,
                    window_token_answer_span,
                    additional_metadata,
                )
                yield instance

            stride_start += space_for_context
            if stride_start >= len(tokenized_context):
                break
            stride_start -= self.stride
Example #11
    def get_instances_from_example(
            self,
            example: Dict,
            always_add_answer_span: bool = False) -> Iterable[Instance]:
        """
        Helper function to get instances from an example.

        Much of this comes from `transformer_squad.make_instances`

        # Parameters

        example: `Dict[str,Any]`
            The example dict.

        # Returns:

        `Iterable[Instance]` The instances for each example
        """
        # Get the passage dict from the example, it has text and
        # entities
        example_id: str = example["id"]
        passage_dict: Dict = example["passage"]
        passage_text: str = passage_dict["text"]

        # Tokenize the passage
        tokenized_passage: List[Token] = self.tokenize_str(passage_text)

        # TODO: Determine what to do with entities. Superglue marks them
        #   explicitly as input (https://arxiv.org/pdf/1905.00537.pdf)

        # Get the queries from the example dict
        queries: List = example["qas"]
        logger.debug(f"{len(queries)} queries for example {example_id}")

        # Tokenize and get the context windows for each queries
        for query in queries:

            # Create the additional metadata dict that will be passed w/ extra
            # data for each query. We store the question & query ids, all
            # answers, and other data following `transformer_qa`.
            additional_metadata = {
                "id": query["id"],
                "example_id": example_id,
            }
            instances_yielded = 0
            # Tokenize, and truncate, the query based on the max set in
            # `__init__`
            tokenized_query = self.tokenize_str(
                query["query"])[:self._query_len_limit]

            # Calculate where the context needs to start and how many tokens we have
            # for it. This is due to the limit on the number of tokens that a
            # transformer can use because they have quadratic memory usage. But if
            # you are reading this code, you probably know that.
            space_for_context = (
                self._length_limit
                - len(list(tokenized_query))
                # Used getattr so I can test without having to load a
                # transformer model.
                - len(getattr(self._tokenizer, "sequence_pair_start_tokens", []))
                - len(getattr(self._tokenizer, "sequence_pair_mid_tokens", []))
                - len(getattr(self._tokenizer, "sequence_pair_end_tokens", []))
            )

            # Check whether answers exist for this query; if there are none,
            # skip the query entirely.
            answers = query.get("answers", [])
            if not answers:
                logger.warning(f"Skipping {query['id']}, no answers")
                continue

            # Create the arguments needed for `char_span_to_token_span`
            token_offsets = [(t.idx, t.idx + len(sanitize_wordpiece(t.text)))
                             if t.idx is not None else None
                             for t in tokenized_passage]

            # Get the token offsets for the answers for this current passage.
            answer_token_start, answer_token_end = (-1, -1)
            for answer in answers:

                # Try to find the offsets.
                offsets, _ = char_span_to_token_span(
                    token_offsets, (answer["start"], answer["end"]))

                # If offsets for an answer were found, it means the answer is in
                # the passage, and thus we can stop looking.
                if offsets != (-1, -1):
                    answer_token_start, answer_token_end = offsets
                    break

            # Go through the context and find the window that has the answer in it.
            stride_start = 0

            while True:
                tokenized_context_window = tokenized_passage[stride_start:]
                tokenized_context_window = tokenized_context_window[:space_for_context]

                # Get the token offsets w.r.t the current window.
                window_token_answer_span = (
                    answer_token_start - stride_start,
                    answer_token_end - stride_start,
                )
                if any(i < 0 or i >= len(tokenized_context_window)
                       for i in window_token_answer_span):
                    # The answer is not contained in the window.
                    window_token_answer_span = None

                if (
                        # not self.skip_impossible_questions
                        window_token_answer_span is not None):
                    # The answer WAS found in the context window, and thus we
                    # can make an instance for the answer.
                    instance = self.text_to_instance(
                        query["query"],
                        tokenized_query,
                        passage_text,
                        tokenized_context_window,
                        answers=[answer["text"] for answer in answers],
                        token_answer_span=window_token_answer_span,
                        additional_metadata=additional_metadata,
                        always_add_answer_span=always_add_answer_span,
                    )
                    yield instance
                    instances_yielded += 1

                if instances_yielded == 1 and self._one_instance_per_query:
                    break

                stride_start += space_for_context

                # If we have reached the end of the passage, stop.
                if stride_start >= len(tokenized_passage):
                    break

                # Back up by the stride so consecutive context windows overlap.
                stride_start -= self._stride
Example #12
def _str_compare_tokens(a: Token, b: Token):
    return sanitize_wordpiece(a.text) == sanitize_wordpiece(b.text)
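
A short sketch of what this comparison buys, using a simplified stand-in for
sanitize_wordpiece that strips BERT's "##" continuation marker and RoBERTa's
"Ġ" leading-space marker:

def sanitize_wordpiece(text: str) -> str:
    # Simplified stand-in; the real helper covers more tokenizer conventions.
    if text.startswith("##"):
        return text[2:]
    return text.lstrip("Ġ")

# The same surface word produced by different tokenizers compares as equal.
print(sanitize_wordpiece("Ġworld") == sanitize_wordpiece("world"))  # True
print(sanitize_wordpiece("##able") == sanitize_wordpiece("able"))   # True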