def indices_to_tokens(
    self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
) -> List[Token]:
    """
    Inverse of `tokens_to_indices()`: reconstructs `Token` objects from indexed ids,
    looking the token text back up in the vocabulary.
    """
    token_ids = indexed_tokens["token_ids"]
    type_ids = indexed_tokens.get("type_ids")

    return [
        Token(
            text=vocabulary.get_token_from_index(token_ids[i], self._namespace),
            text_id=token_ids[i],
            type_id=type_ids[i] if type_ids is not None else None,
        )
        for i in range(len(token_ids))
    ]
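# Usage sketch (illustrative, not from this file): round-tripping indices back to
# `Token`s. The enclosing class being AllenNLP's `PretrainedTransformerIndexer` and
# the "bert-base-uncased" model name are assumptions; any wordpiece model works the
# same way.
#
#     from allennlp.data import Vocabulary
#     from allennlp.data.tokenizers import PretrainedTransformerTokenizer
#
#     tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
#     indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased")
#     vocab = Vocabulary()
#     tokens = tokenizer.tokenize("hello world")
#     indexed = indexer.tokens_to_indices(tokens, vocab)
#     recovered = indexer.indices_to_tokens(indexed, vocab)
#     # `recovered` holds Token objects with text, text_id, and type_id set.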
def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
    """
    Takes an `IndexedTokenList` about to be returned by `tokens_to_indices()` and
    adds any necessary postprocessing, e.g. long sequence splitting.

    The input should have a `"token_ids"` key corresponding to the token indices.
    They should have special tokens already inserted.
    """
    if self._max_length is not None:
        # We prepare long indices by converting them to (assuming max_length == 5)
        # [CLS] A B C [SEP] [CLS] D E F [SEP] ...
        # The embedder is responsible for folding this 1-d sequence to 2-d and
        # feeding it to the transformer model.
        # TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.

        indices = output["token_ids"]
        type_ids = output.get("type_ids", [0] * len(indices))

        # Strip the original special tokens; each segment gets its own below.
        indices = indices[
            self._num_added_start_tokens : len(indices) - self._num_added_end_tokens
        ]
        type_ids = type_ids[
            self._num_added_start_tokens : len(type_ids) - self._num_added_end_tokens
        ]

        # Fold the indices into segments of at most `_effective_max_length`.
        folded_indices = [
            indices[i : i + self._effective_max_length]
            for i in range(0, len(indices), self._effective_max_length)
        ]
        folded_type_ids = [
            type_ids[i : i + self._effective_max_length]
            for i in range(0, len(type_ids), self._effective_max_length)
        ]

        # Add special tokens (and their type ids) to each segment.
        folded_indices = [
            self._tokenizer.build_inputs_with_special_tokens(segment)
            for segment in folded_indices
        ]
        single_sequence_start_type_ids = [
            t.type_id for t in self._allennlp_tokenizer.single_sequence_start_tokens
        ]
        single_sequence_end_type_ids = [
            t.type_id for t in self._allennlp_tokenizer.single_sequence_end_tokens
        ]
        folded_type_ids = [
            single_sequence_start_type_ids + segment + single_sequence_end_type_ids
            for segment in folded_type_ids
        ]
        assert all(
            len(segment_indices) == len(segment_type_ids)
            for segment_indices, segment_type_ids in zip(folded_indices, folded_type_ids)
        )

        # Flatten the segments back into 1-d lists.
        indices = [i for segment in folded_indices for i in segment]
        type_ids = [i for segment in folded_type_ids for i in segment]

        output["token_ids"] = indices
        output["type_ids"] = type_ids
        output["segment_concat_mask"] = [True] * len(indices)

    return output
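# A minimal, self-contained sketch of the folding above (not part of the indexer),
# assuming max_length == 5 with one start and one end special token; the ids
# 101/102 stand in for [CLS]/[SEP] and are assumptions, not taken from any
# particular tokenizer:
#
#     effective_max_length = 5 - 2  # room left once [CLS]/[SEP] are re-added
#     indices = [1, 2, 3, 4, 5, 6]  # original special tokens already stripped
#     folded = [
#         [101] + indices[i : i + effective_max_length] + [102]
#         for i in range(0, len(indices), effective_max_length)
#     ]
#     assert folded == [[101, 1, 2, 3, 102], [101, 4, 5, 6, 102]]
#     flat = [i for segment in folded for i in segment]
#     assert flat == [101, 1, 2, 3, 102, 101, 4, 5, 6, 102]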