Example no. 1
    def indices_to_tokens(self, indexed_tokens: IndexedTokenList,
                          vocabulary: Vocabulary) -> List[Token]:
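        """
        Reconstructs `Token` objects from an `IndexedTokenList`: `text` is recovered by
        looking each id up in `vocabulary` under this indexer's namespace, while
        `text_id` and (when present) `type_id` are carried over from the index lists.
        """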
        token_ids = indexed_tokens["token_ids"]
        type_ids = indexed_tokens.get("type_ids")

        return [
            Token(
                text=vocabulary.get_token_from_index(token_ids[i],
                                                     self._namespace),
                text_id=token_ids[i],
                type_id=type_ids[i] if type_ids is not None else None,
            ) for i in range(len(token_ids))
        ]

    def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
        """
        Takes an IndexedTokenList about to be returned by `tokens_to_indices()` and adds any
        necessary postprocessing, e.g. long sequence splitting.

        The input should have a `"token_ids"` key corresponding to the token indices. They should
        have special tokens already inserted.
        """
        if self._max_length is not None:
            # Long sequences are split into segments and re-concatenated into a single
            # 1-d sequence; e.g., with max_length == 5 the result looks like
            # [CLS] A B C [SEP] [CLS] D E F [SEP] ...
            # The embedder is responsible for folding this 1-d sequence back into 2-d
            # segments and feeding them to the transformer model.
            # (A standalone walk-through of this folding follows this method.)
            # TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.

            indices = output["token_ids"]
            type_ids = output.get("type_ids", [0] * len(indices))

            # Strips original special tokens
            indices = indices[
                self._num_added_start_tokens : len(indices) - self._num_added_end_tokens
            ]
            type_ids = type_ids[
                self._num_added_start_tokens : len(type_ids) - self._num_added_end_tokens
            ]

            # Folds indices
            folded_indices = [
                indices[i : i + self._effective_max_length]
                for i in range(0, len(indices), self._effective_max_length)
            ]
            folded_type_ids = [
                type_ids[i : i + self._effective_max_length]
                for i in range(0, len(type_ids), self._effective_max_length)
            ]

            # Adds special tokens to each segment
            folded_indices = [
                self._tokenizer.build_inputs_with_special_tokens(segment)
                for segment in folded_indices
            ]
            single_sequence_start_type_ids = [
                t.type_id for t in self._allennlp_tokenizer.single_sequence_start_tokens
            ]
            single_sequence_end_type_ids = [
                t.type_id for t in self._allennlp_tokenizer.single_sequence_end_tokens
            ]
            folded_type_ids = [
                single_sequence_start_type_ids + segment + single_sequence_end_type_ids
                for segment in folded_type_ids
            ]
            assert all(
                len(segment_indices) == len(segment_type_ids)
                for segment_indices, segment_type_ids in zip(folded_indices, folded_type_ids)
            )

            # Flattens
            indices = [i for segment in folded_indices for i in segment]
            type_ids = [i for segment in folded_type_ids for i in segment]

            output["token_ids"] = indices
            output["type_ids"] = type_ids
            output["segment_concat_mask"] = [True] * len(indices)

        return output
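
Below is a minimal, self-contained sketch (not part of the class above) of the segmentation arithmetic that _postprocess_output() performs when _max_length is set. The CLS/SEP ids, the max_length value, and the one-start/one-end special-token counts are illustrative assumptions, not values taken from a real tokenizer.

# Hypothetical values; a real tokenizer supplies its own special-token ids and counts.
CLS, SEP = 101, 102
max_length = 5
num_added_start_tokens, num_added_end_tokens = 1, 1
effective_max_length = max_length - num_added_start_tokens - num_added_end_tokens  # 3

# The input already carries its original special tokens, as _postprocess_output() expects.
token_ids = [CLS, 11, 12, 13, 14, 15, 16, SEP]

# Strip the original special tokens.
stripped = token_ids[num_added_start_tokens : len(token_ids) - num_added_end_tokens]

# Fold into segments of at most effective_max_length wordpieces each.
segments = [
    stripped[i : i + effective_max_length]
    for i in range(0, len(stripped), effective_max_length)
]

# Re-add special tokens around every segment and flatten back to a 1-d sequence.
flattened = [wp for segment in segments for wp in [CLS] + segment + [SEP]]
print(flattened)  # [101, 11, 12, 13, 102, 101, 14, 15, 16, 102]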