Ejemplo n.º 1
0
    def test_truncate_seq_pair(self):
        """`truncate_seq_pair` trims the longer list in place until the
        combined length fits the given budget."""
        pair_a, pair_b = [1, 2, 3], [4, 5, 6]
        utils.truncate_seq_pair(pair_a, pair_b, 4)
        self.assertListEqual(pair_a, [1, 2])
        self.assertListEqual(pair_b, [4, 5])

        # When one side is already short, only the longer side is trimmed.
        pair_a, pair_b = [1], [2, 3, 4, 5]
        utils.truncate_seq_pair(pair_a, pair_b, 3)
        self.assertListEqual(pair_a, [1])
        self.assertListEqual(pair_b, [2, 3])
    def encode_text(self,  # type: ignore
                    text_a: str,
                    text_b: Optional[str] = None,
                    max_seq_length: Optional[int] = None) -> \
            Tuple[List[int], List[int]]:
        r"""Adds special tokens to a sequence or sequence pair and computes the
        corresponding input mask for RoBERTa specific tasks.
        The sequence will be truncated if its length is larger than
        ``max_seq_length``.

        A RoBERTa sequence has the following format:
        `[cls_token]` X `[sep_token]`

        A RoBERTa sequence pair has the following format:
        `[cls_token]` A `[sep_token]` `[sep_token]` B `[sep_token]`

        Args:
            text_a: The first input text.
            text_b: The second input text.
            max_seq_length: Maximum sequence length. Defaults to
                ``self.max_len`` when not given.

        Returns:
            A tuple of `(input_ids, input_mask)`, where

            - ``input_ids``: A list of input token ids with added
              special token ids, zero-padded to ``max_seq_length``.
            - ``input_mask``: A list of mask ids. The mask has 1 for real
              tokens and 0 for padding tokens. Only real tokens are
              attended to.
        """
        if max_seq_length is None:
            max_seq_length = self.max_len

        cls_token_id = self._map_token_to_id(self.cls_token)
        sep_token_id = self._map_token_to_id(self.sep_token)

        token_ids_a = self.map_text_to_id(text_a)
        assert isinstance(token_ids_a, list)

        token_ids_b = None
        if text_b:
            token_ids_b = self.map_text_to_id(text_b)

        if token_ids_b:
            assert isinstance(token_ids_b, list)
            # Modifies `token_ids_a` and `token_ids_b` in place so that the
            # total length is less than the specified length.
            # Account for <s>, </s>, </s>, </s> with "- 4"
            truncate_seq_pair(token_ids_a, token_ids_b, max_seq_length - 4)

            # RoBERTa uses a *double* separator between the two segments.
            input_ids = ([cls_token_id] + token_ids_a + [sep_token_id] +
                         [sep_token_id] + token_ids_b + [sep_token_id])
        else:
            # Account for <s> and </s> with "- 2"
            token_ids_a = token_ids_a[:max_seq_length - 2]

            input_ids = [cls_token_id] + token_ids_a + [sep_token_id]

        input_mask = [1] * len(input_ids)

        # Zero-pad up to the maximum sequence length.
        input_ids = input_ids + [0] * (max_seq_length - len(input_ids))
        input_mask = input_mask + [0] * (max_seq_length - len(input_mask))

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length

        return input_ids, input_mask
Ejemplo n.º 3
0
    def encode_text(self,
                    text_a: str,
                    text_b: Optional[str] = None,
                    max_seq_length: Optional[int] = None) -> \
            Tuple[List[int], List[int], List[int]]:
        r"""Adds special tokens to a sequence or sequence pair and computes the
        corresponding segment ids and input mask for XLNet specific tasks.
        The sequence will be truncated if its length is larger than
        ``max_seq_length``.

        An XLNet sequence has the following format:
        X `[sep_token]` `[cls_token]`

        An XLNet sequence pair has the following format:
        A `[sep_token]` B `[sep_token]` `[cls_token]`

        Args:
            text_a: The first input text.
            text_b: The second input text.
            max_seq_length: Maximum sequence length. Defaults to
                ``self.max_len`` when not given.

        Returns:
            A tuple of `(input_ids, segment_ids, input_mask)`, where

            - ``input_ids``: A list of input token ids with added
              special token ids, left-padded to ``max_seq_length``.
            - ``segment_ids``: A list of segment ids.
            - ``input_mask``: A list of mask ids. Following the XLNet
              convention, the mask has 0 for real tokens and 1 for padding
              tokens. Only real tokens are attended to.
        """
        if max_seq_length is None:
            max_seq_length = self.max_len

        cls_token_id = self._map_token_to_id(self.cls_token)
        sep_token_id = self._map_token_to_id(self.sep_token)

        token_ids_a = self.map_text_to_id(text_a)
        assert isinstance(token_ids_a, list)

        token_ids_b = None
        if text_b:
            token_ids_b = self.map_text_to_id(text_b)

        if token_ids_b:
            assert isinstance(token_ids_b, list)
            # Modifies `token_ids_a` and `token_ids_b` in place so that the
            # total length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            truncate_seq_pair(token_ids_a, token_ids_b, max_seq_length - 3)

            # Unlike BERT, XLNet places the [CLS] token at the *end*.
            input_ids = (token_ids_a + [sep_token_id] + token_ids_b +
                         [sep_token_id] + [cls_token_id])
            segment_ids = [SEG_ID_A] * (len(token_ids_a) + 1) + \
                          [SEG_ID_B] * (len(token_ids_b) + 1) + [SEG_ID_CLS]
        else:
            # Account for [CLS] and [SEP] with "- 2"
            token_ids = token_ids_a[:max_seq_length - 2]

            input_ids = token_ids + [sep_token_id] + [cls_token_id]
            segment_ids = [SEG_ID_A] * (len(input_ids) - 1) + [SEG_ID_CLS]

        # XLNet mask convention: 0 = real token, 1 = padding.
        input_mask = [0] * len(input_ids)

        # XLNet pads on the *left*, up to the maximum sequence length.
        input_ids = [0] * (max_seq_length - len(input_ids)) + input_ids
        input_mask = [1] * (max_seq_length - len(input_mask)) + input_mask
        segment_ids = ([SEG_ID_PAD] * (max_seq_length - len(segment_ids)) +
                       segment_ids)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        return input_ids, segment_ids, input_mask
Ejemplo n.º 4
0
    def encode_text(self,
                    text_a: str,
                    text_b: Optional[str] = None,
                    max_seq_length: Optional[int] = None) -> \
            Tuple[List[int], List[int], List[int], List[int], int]:
        r"""Adds special tokens to a sequence or sequence pair and computes the
        corresponding segment ids, input mask, and per-token sentence ids for
        BERT specific tasks.
        The sequence will be truncated if its length is larger than
        ``max_seq_length``.

        A BERT sequence has the following format:
        `[cls_token]` X `[sep_token]`

        A BERT sequence pair has the following format:
        `[cls_token]` A `[sep_token]` B `[sep_token]`

        Args:
            text_a: The first input text.
            text_b: The second input text.
            max_seq_length: Maximum sequence length. Defaults to
                ``self.max_len`` when not given.

        Returns:
            A tuple of `(input_ids, segment_ids, input_mask, sentence_ids,
            sentence_num)`, where

            - ``input_ids``: A list of input token ids with added
              special token ids, zero-padded to ``max_seq_length``.
            - ``segment_ids``: A list of segment ids.
            - ``input_mask``: A list of mask ids. The mask has 1 for real
              tokens and 0 for padding tokens. Only real tokens are
              attended to.
            - ``sentence_ids``: A list assigning each token (after the
              leading ``[CLS]``) the 0-based index of the '.'-delimited
              sentence it belongs to; ``[CLS]``, everything from the first
              ``[SEP]`` onward, and padding positions get -1.
            - ``sentence_num``: The number of '.'-terminated sentences
              counted in segment A.
        """
        if max_seq_length is None:
            max_seq_length = self.max_len

        cls_token_id = self._map_token_to_id(self.cls_token)
        sep_token_id = self._map_token_to_id(self.sep_token)

        token_ids_a = self.map_text_to_id(text_a)
        assert isinstance(token_ids_a, list)

        token_ids_b = None
        if text_b:
            token_ids_b = self.map_text_to_id(text_b)

        if token_ids_b:
            assert isinstance(token_ids_b, list)
            # Modifies `token_ids_a` and `token_ids_b` in place so that the
            # total length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            truncate_seq_pair(token_ids_a, token_ids_b, max_seq_length - 3)

            input_ids = ([cls_token_id] + token_ids_a + [sep_token_id] +
                         token_ids_b + [sep_token_id])
            segment_ids = [0] * (len(token_ids_a) + 2) + \
                          [1] * (len(token_ids_b) + 1)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            token_ids_a = token_ids_a[:max_seq_length - 2]

            input_ids = [cls_token_id] + token_ids_a + [sep_token_id]
            segment_ids = [0] * len(input_ids)

        input_mask = [1] * len(input_ids)

        # Assign a sentence index to every token, splitting on '.' tokens.
        # The leading [CLS] gets -1; scanning stops at the first literal
        # '[SEP]' token so segment B is never labeled.
        sentence_ids, cur_id = [-1], 0
        eos_id = self._map_token_to_id('.')
        end_id = self._map_token_to_id('[SEP]')
        for token_id in input_ids[1:]:
            if token_id == end_id:
                break
            sentence_ids.append(cur_id)
            if token_id == eos_id:
                cur_id += 1
        # Pad remaining positions (including [SEP] and beyond) with -1.
        sentence_ids = sentence_ids + \
            [-1] * (max_seq_length - len(sentence_ids))
        sentence_num = cur_id

        # Zero-pad up to the maximum sequence length.
        input_ids = input_ids + [0] * (max_seq_length - len(input_ids))
        segment_ids = segment_ids + [0] * (max_seq_length - len(segment_ids))
        input_mask = input_mask + [0] * (max_seq_length - len(input_mask))

        assert len(input_ids) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(sentence_ids) == max_seq_length

        return input_ids, segment_ids, input_mask, sentence_ids, sentence_num