def test_truncate_seq_pair(self):
    tokens_a = [1, 2, 3]
    tokens_b = [4, 5, 6]
    utils.truncate_seq_pair(tokens_a, tokens_b, 4)
    self.assertListEqual(tokens_a, [1, 2])
    self.assertListEqual(tokens_b, [4, 5])

    tokens_a = [1]
    tokens_b = [2, 3, 4, 5]
    utils.truncate_seq_pair(tokens_a, tokens_b, 3)
    self.assertListEqual(tokens_a, [1])
    self.assertListEqual(tokens_b, [2, 3])
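# For reference, the behavior exercised by the test above matches the
# standard BERT-style pair truncation: repeatedly drop the last token of the
# currently longer sequence until the combined length fits. This is only a
# minimal sketch of what `utils.truncate_seq_pair` is assumed to do (the real
# implementation lives in the library's utils module); both lists are
# modified in place.
def _truncate_seq_pair_sketch(tokens_a, tokens_b, max_length):
    while len(tokens_a) + len(tokens_b) > max_length:
        # Trim the longer sequence so the pair stays balanced.
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()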
def encode_text(self,  # type: ignore
                text_a: str,
                text_b: Optional[str] = None,
                max_seq_length: Optional[int] = None) -> \
        Tuple[List[int], List[int]]:
    r"""Adds special tokens to a sequence or sequence pair and computes the
    corresponding input mask for RoBERTa specific tasks. The sequence will
    be truncated if its length is larger than ``max_seq_length``.

    A RoBERTa sequence has the following format:
    `[cls_token]` X `[sep_token]`

    A RoBERTa sequence pair has the following format:
    `[cls_token]` A `[sep_token]` `[sep_token]` B `[sep_token]`

    Args:
        text_a: The first input text.
        text_b: The second input text.
        max_seq_length: Maximum sequence length.

    Returns:
        A tuple of `(input_ids, input_mask)`, where

        - ``input_ids``: A list of input token ids with added special
          token ids.
        - ``input_mask``: A list of mask ids. The mask has 1 for real
          tokens and 0 for padding tokens. Only real tokens are
          attended to.
    """
    if max_seq_length is None:
        max_seq_length = self.max_len

    cls_token_id = self._map_token_to_id(self.cls_token)
    sep_token_id = self._map_token_to_id(self.sep_token)

    token_ids_a = self.map_text_to_id(text_a)
    assert isinstance(token_ids_a, list)

    token_ids_b = None
    if text_b:
        token_ids_b = self.map_text_to_id(text_b)

    if token_ids_b:
        assert isinstance(token_ids_b, list)
        # Modifies `token_ids_a` and `token_ids_b` in place so that the
        # total length is less than the specified length.
        # Account for <s>, </s>, </s>, </s> with "- 4"
        truncate_seq_pair(token_ids_a, token_ids_b, max_seq_length - 4)

        input_ids = ([cls_token_id] + token_ids_a + [sep_token_id] +
                     [sep_token_id] + token_ids_b + [sep_token_id])
    else:
        # Account for <s> and </s> with "- 2"
        token_ids_a = token_ids_a[:max_seq_length - 2]

        input_ids = [cls_token_id] + token_ids_a + [sep_token_id]

    input_mask = [1] * len(input_ids)

    # Zero-pad up to the maximum sequence length.
    input_ids = input_ids + [0] * (max_seq_length - len(input_ids))
    input_mask = input_mask + [0] * (max_seq_length - len(input_mask))

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length

    return input_ids, input_mask
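# A minimal, self-contained sketch of the RoBERTa packing scheme implemented
# above, using dummy token ids instead of a real vocabulary; `cls_id`,
# `sep_id`, and the example ids below are placeholders chosen for
# illustration only.
def _roberta_pack_sketch(ids_a, ids_b, max_seq_length, cls_id=0, sep_id=2):
    # <s> A </s> </s> B </s>, then zero-pad to `max_seq_length`.
    input_ids = [cls_id] + ids_a + [sep_id] + [sep_id] + ids_b + [sep_id]
    input_mask = [1] * len(input_ids)
    input_ids += [0] * (max_seq_length - len(input_ids))
    input_mask += [0] * (max_seq_length - len(input_mask))
    return input_ids, input_mask

# Example: two short "sentences" of dummy ids packed into a length-12 window.
# _roberta_pack_sketch([11, 12, 13], [21, 22], 12)
# -> ([0, 11, 12, 13, 2, 2, 21, 22, 2, 0, 0, 0],
#     [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0])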
def encode_text(self,
                text_a: str,
                text_b: Optional[str] = None,
                max_seq_length: Optional[int] = None) -> \
        Tuple[List[int], List[int], List[int]]:
    r"""Adds special tokens to a sequence or sequence pair and computes the
    corresponding segment ids and input mask for XLNet specific tasks.
    The sequence will be truncated if its length is larger than
    ``max_seq_length``.

    An XLNet sequence has the following format:
    X `[sep_token]` `[cls_token]`

    An XLNet sequence pair has the following format:
    A `[sep_token]` B `[sep_token]` `[cls_token]`

    Args:
        text_a: The first input text.
        text_b: The second input text.
        max_seq_length: Maximum sequence length.

    Returns:
        A tuple of `(input_ids, segment_ids, input_mask)`, where

        - ``input_ids``: A list of input token ids with added special
          token ids.
        - ``segment_ids``: A list of segment ids.
        - ``input_mask``: A list of mask ids. The mask has 0 for real
          tokens and 1 for padding tokens. Only real tokens are
          attended to.
    """
    if max_seq_length is None:
        max_seq_length = self.max_len

    cls_token_id = self._map_token_to_id(self.cls_token)
    sep_token_id = self._map_token_to_id(self.sep_token)

    token_ids_a = self.map_text_to_id(text_a)
    assert isinstance(token_ids_a, list)

    token_ids_b = None
    if text_b:
        token_ids_b = self.map_text_to_id(text_b)

    if token_ids_b:
        assert isinstance(token_ids_b, list)
        # Modifies `token_ids_a` and `token_ids_b` in place so that the
        # total length is less than the specified length.
        # Account for [SEP], [SEP], [CLS] with "- 3"
        truncate_seq_pair(token_ids_a, token_ids_b, max_seq_length - 3)

        input_ids = (token_ids_a + [sep_token_id] + token_ids_b +
                     [sep_token_id] + [cls_token_id])
        segment_ids = [SEG_ID_A] * (len(token_ids_a) + 1) + \
                      [SEG_ID_B] * (len(token_ids_b) + 1) + [SEG_ID_CLS]
    else:
        # Account for [SEP] and [CLS] with "- 2"
        token_ids = token_ids_a[:max_seq_length - 2]

        input_ids = token_ids + [sep_token_id] + [cls_token_id]
        segment_ids = [SEG_ID_A] * (len(input_ids) - 1) + [SEG_ID_CLS]

    input_mask = [0] * len(input_ids)

    # Zero-pad up to the maximum sequence length. XLNet pads on the left.
    input_ids = [0] * (max_seq_length - len(input_ids)) + input_ids
    input_mask = [1] * (max_seq_length - len(input_mask)) + input_mask
    segment_ids = ([SEG_ID_PAD] * (max_seq_length - len(segment_ids)) +
                   segment_ids)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    return input_ids, segment_ids, input_mask
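# A minimal, self-contained sketch of the XLNet packing scheme above, using
# dummy token ids and segment-id constants; the concrete values of the
# `SEG_*` constants, `cls_id`, and `sep_id` here are assumptions chosen for
# illustration only. Note that padding goes on the left and `[cls_token]`
# sits at the end of the sequence.
SEG_A, SEG_B, SEG_CLS, SEG_PAD = 0, 1, 2, 4

def _xlnet_pack_sketch(ids_a, ids_b, max_seq_length, cls_id=3, sep_id=4):
    # A [sep] B [sep] [cls]
    input_ids = ids_a + [sep_id] + ids_b + [sep_id] + [cls_id]
    segment_ids = ([SEG_A] * (len(ids_a) + 1) +
                   [SEG_B] * (len(ids_b) + 1) + [SEG_CLS])
    input_mask = [0] * len(input_ids)  # 0 = real token, 1 = padding
    pad = max_seq_length - len(input_ids)
    return ([0] * pad + input_ids,
            [SEG_PAD] * pad + segment_ids,
            [1] * pad + input_mask)

# Example with a length-10 output window:
# _xlnet_pack_sketch([11, 12], [21, 22, 23], 10)
# -> ([0, 0, 11, 12, 4, 21, 22, 23, 4, 3],
#     [4, 4, 0, 0, 0, 1, 1, 1, 1, 2],
#     [1, 1, 0, 0, 0, 0, 0, 0, 0, 0])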
def encode_text(self,
                text_a: str,
                text_b: Optional[str] = None,
                max_seq_length: Optional[int] = None) -> \
        Tuple[List[int], List[int], List[int], List[int], int]:
    r"""Adds special tokens to a sequence or sequence pair and computes the
    corresponding segment ids and input mask for BERT specific tasks.
    The sequence will be truncated if its length is larger than
    ``max_seq_length``.

    A BERT sequence has the following format:
    `[cls_token]` X `[sep_token]`

    A BERT sequence pair has the following format:
    `[cls_token]` A `[sep_token]` B `[sep_token]`

    Args:
        text_a: The first input text.
        text_b: The second input text.
        max_seq_length: Maximum sequence length.

    Returns:
        A tuple of `(input_ids, segment_ids, input_mask, sentence_ids,
        sentence_num)`, where

        - ``input_ids``: A list of input token ids with added special
          token ids.
        - ``segment_ids``: A list of segment ids.
        - ``input_mask``: A list of mask ids. The mask has 1 for real
          tokens and 0 for padding tokens. Only real tokens are
          attended to.
        - ``sentence_ids``: A list assigning each token of ``text_a`` the
          index of the sentence it belongs to, where sentences are
          delimited by ``'.'`` tokens; special and padding positions
          get -1.
        - ``sentence_num``: The number of ``'.'``-terminated sentences
          found in ``text_a``.
    """
    if max_seq_length is None:
        max_seq_length = self.max_len

    cls_token_id = self._map_token_to_id(self.cls_token)
    sep_token_id = self._map_token_to_id(self.sep_token)

    token_ids_a = self.map_text_to_id(text_a)
    assert isinstance(token_ids_a, list)

    token_ids_b = None
    if text_b:
        token_ids_b = self.map_text_to_id(text_b)

    if token_ids_b:
        assert isinstance(token_ids_b, list)
        # Modifies `token_ids_a` and `token_ids_b` in place so that the
        # total length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        truncate_seq_pair(token_ids_a, token_ids_b, max_seq_length - 3)

        input_ids = ([cls_token_id] + token_ids_a + [sep_token_id] +
                     token_ids_b + [sep_token_id])
        segment_ids = [0] * (len(token_ids_a) + 2) + \
                      [1] * (len(token_ids_b) + 1)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        token_ids_a = token_ids_a[:max_seq_length - 2]

        input_ids = [cls_token_id] + token_ids_a + [sep_token_id]
        segment_ids = [0] * len(input_ids)

    input_mask = [1] * len(input_ids)

    # Assign a sentence index to each token of `text_a`, where sentences
    # are delimited by '.' tokens. The [CLS] position gets -1; counting
    # stops at the first [SEP].
    sentence_ids, cur_id = [-1], 0
    eos_id = self._map_token_to_id('.')
    end_id = self._map_token_to_id('[SEP]')
    for token_id in input_ids[1:]:
        if token_id == end_id:
            break
        sentence_ids.append(cur_id)
        if token_id == eos_id:
            cur_id += 1
    sentence_ids = (sentence_ids +
                    [-1] * (max_seq_length - len(sentence_ids)))
    sentence_num = cur_id

    # Zero-pad up to the maximum sequence length.
    input_ids = input_ids + [0] * (max_seq_length - len(input_ids))
    segment_ids = segment_ids + [0] * (max_seq_length - len(segment_ids))
    input_mask = input_mask + [0] * (max_seq_length - len(input_mask))

    assert len(input_ids) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(sentence_ids) == max_seq_length

    return input_ids, segment_ids, input_mask, sentence_ids, sentence_num
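# A minimal, self-contained sketch of the single-text BERT packing plus the
# sentence-id bookkeeping above, using dummy token ids. The ids chosen for
# [CLS], [SEP], and '.' below are placeholders for illustration only.
def _bert_sentence_ids_sketch(ids_a, max_seq_length,
                              cls_id=101, sep_id=102, period_id=1012):
    # [CLS] A [SEP], then zero-pad to `max_seq_length`.
    input_ids = [cls_id] + ids_a[:max_seq_length - 2] + [sep_id]
    segment_ids = [0] * len(input_ids)
    input_mask = [1] * len(input_ids)

    # -1 for [CLS]; every following token gets the index of its sentence,
    # incrementing after each '.' token and stopping at [SEP].
    sentence_ids, cur_id = [-1], 0
    for token_id in input_ids[1:]:
        if token_id == sep_id:
            break
        sentence_ids.append(cur_id)
        if token_id == period_id:
            cur_id += 1

    def _pad(xs, fill):
        return xs + [fill] * (max_seq_length - len(xs))

    return (_pad(input_ids, 0), _pad(segment_ids, 0), _pad(input_mask, 0),
            _pad(sentence_ids, -1), cur_id)

# Example: two dummy "sentences" (each ending in the '.' id 1012) in a
# length-10 window; the last returned value is the sentence count, 2.
# _bert_sentence_ids_sketch([11, 12, 1012, 21, 1012], 10)
# -> input_ids    [101, 11, 12, 1012, 21, 1012, 102, 0, 0, 0]
#    sentence_ids [-1, 0, 0, 0, 1, 1, -1, -1, -1, -1]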