Ejemplo n.º 1
0
    def test_xlm_token_tensorizer(self):
        """Check the token and language ids returned by ScriptXLMTensorizer.

        Two random sequences (20 and 10 tokens) are tensorized twice: first
        without languages, so ``default_language`` is used, then with an
        explicit language per row.
        """
        mock_vocab = self._mock_vocab()
        tensorizer = ScriptXLMTensorizer(
            tokenizer=ScriptDoNothingTokenizer(),
            token_vocab=mock_vocab,
            language_vocab=ScriptVocabulary(["ar", "cn", "en"]),
            max_seq_len=256,
            default_language="en",
        )
        rows = [
            [str(random.randint(100, 200)) for _ in range(length)]
            for length in (20, 10)
        ]

        # First pass: no languages supplied.
        token_tensor, _, language_tensor, _ = tensorizer.tensorize(
            tokens=squeeze_2d(rows)
        )
        token_ids = token_tensor.tolist()
        first_row = token_ids[0]
        # Id 202 is expected at both ends of the first row.
        self.assertEqual(first_row[0], 202)
        self.assertEqual(first_row[-1], 202)
        # The shorter row is expected to end in ten ids equal to 200.
        self.assertEqual(token_ids[1][12:], [200] * 10)

        language_ids = language_tensor.tolist()
        # Language id 2 across the whole first row, 0 in the padded tail.
        self.assertEqual(language_ids[0], [2] * len(first_row))
        self.assertEqual(language_ids[1][12:], [0] * 10)

        # Second pass: one explicit language per row.
        token_tensor, _, language_tensor, _ = tensorizer.tensorize(
            tokens=squeeze_2d(rows),
            languages=squeeze_1d(["cn", "en"]),
        )
        language_ids = language_tensor.tolist()
        self.assertEqual(language_ids[0], [1] * len(token_tensor[0]))
        self.assertEqual(language_ids[1][:12], [2] * 12)
Ejemplo n.º 2
0
    def test_xlm_tensorizer_seq_padding_size_exceeds_max_seq_len(self):
        """Check tensorize output when padding control exceeds ``max_seq_len``.

        The tensorizer is built with ``max_seq_len=20`` while the
        "sequence_length" padding control is ``[0, 32, 256]``. The expected
        row width is ``min(max(longest_row, 32), max_seq_len)``; tokens,
        languages, positions and pad masks are all checked against the
        per-row content/padding split derived from that width.
        """
        vocab = self._mock_vocab()

        xlm = ScriptXLMTensorizer(
            tokenizer=ScriptDoNothingTokenizer(),
            token_vocab=vocab,
            language_vocab=ScriptVocabulary(["ar", "cn", "en"]),
            max_seq_len=20,
            default_language="en",
        )

        seq_padding_control = [0, 32, 256]
        xlm.set_padding_control("sequence_length", seq_padding_control)

        rand_tokens = [
            [str(random.randint(100, 200)) for _ in range(30)],
            [str(random.randint(100, 200)) for _ in range(20)],
            [str(random.randint(100, 200)) for _ in range(10)],
        ]

        tokens, pad_masks, languages, positions = xlm.tensorize(
            tokens=squeeze_2d(rand_tokens)
        )

        # Each row is counted with two extra ids (presumably the boundary
        # tokens wrapped around every sequence; confirm against tensorizer).
        token_count = [len(t) + 2 for t in rand_tokens]
        expected_batch_size = len(rand_tokens)
        expected_token_size = min(
            max(max(token_count), seq_padding_control[1]), xlm.max_seq_len
        )
        expected_padding_count = [
            max(0, expected_token_size - cnt) for cnt in token_count
        ]
        # From here on token_count is the number of non-padding entries per
        # row after capping at the expected width.
        token_count = [
            expected_token_size - cnt for cnt in expected_padding_count
        ]

        # verify tensorized tokens padding
        tokens = tokens.tolist()
        self.assertEqual(len(tokens), expected_batch_size)
        # assertEqual's third positional argument is the failure message, so
        # the width has to be compared row by row rather than passed there.
        for row in tokens:
            self.assertEqual(len(row), expected_token_size)
        for row, kept, padded in zip(
            tokens, token_count, expected_padding_count
        ):
            self.assertEqual(row[kept:], [200] * padded)

        # verify tensorized languages
        languages = languages.tolist()
        self.assertEqual(len(languages), expected_batch_size)
        for row, kept, padded in zip(
            languages, token_count, expected_padding_count
        ):
            self.assertEqual(row[:kept], [2] * kept)
            self.assertEqual(row[kept:], [0] * padded)

        # verify tensorized positions
        positions = positions.tolist()
        self.assertEqual(len(positions), expected_batch_size)
        for row, kept, padded in zip(
            positions, token_count, expected_padding_count
        ):
            self.assertEqual(row[kept:], [0] * padded)

        # verify pad_masks
        pad_masks = pad_masks.tolist()
        self.assertEqual(len(pad_masks), expected_batch_size)
        for row, kept, padded in zip(
            pad_masks, token_count, expected_padding_count
        ):
            self.assertEqual(row[:kept], [1] * kept)
            self.assertEqual(row[kept:], [0] * padded)