Example #1
    def test_python_en_tokenizer_prepare_batch(self):
        batch = self.tokenizer(self.src_text,
                               padding=True,
                               truncation=True,
                               max_length=len(self.expected_src_tokens),
                               return_tensors="pt")
        # Encode the targets with the tokenizer switched into target-language mode.
        with self.tokenizer.as_target_tokenizer():
            targets = self.tokenizer(
                self.tgt_text,
                padding=True,
                truncation=True,
                max_length=len(self.expected_src_tokens),
                return_tensors="pt",
            )
        labels = targets["input_ids"]
        # PLBart (MBart-style) shift: the language code at the end of each
        # label row is wrapped around to position 0 of the decoder inputs.
        batch["decoder_input_ids"] = shift_tokens_right(
            labels, self.tokenizer.pad_token_id)

        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 26), batch.input_ids.shape)
        self.assertEqual((2, 26), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
        # Test that special tokens are reset
        self.assertEqual(self.tokenizer.prefix_tokens, [])
        self.assertEqual(self.tokenizer.suffix_tokens,
                         [self.tokenizer.eos_token_id, PYTHON_CODE])
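
These examples match the PLBart Python-to-English tokenizer tests from Hugging Face transformers. They depend on fixtures this page does not show; the sketch below reconstructs a plausible setup, where the checkpoint name, the sample texts, and the PYTHON_CODE/EN_CODE ids are assumptions rather than values taken from the original suite.

# Assumed fixture for the tests on this page (illustrative only).
import unittest

from transformers import BatchEncoding, PLBartTokenizer
from transformers.models.plbart.modeling_plbart import shift_tokens_right

PYTHON_CODE = 50002  # assumed id of the "__python__" language code token
EN_CODE = 50003      # assumed id of the "__en_XX__" language code token


class PLBartPythonEnIntegrationTest(unittest.TestCase):
    def setUp(self):
        self.tokenizer = PLBartTokenizer.from_pretrained(
            "uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
        self.src_text = ["def add(a,b):\n    return a+b",
                         "def sub(a,b):\n    return a-b"]
        self.tgt_text = ["Returns the sum of a and b.",
                         "Returns the difference of a and b."]
        # The real suite hardcodes the expected ids; encoding the first source
        # sample stands in for that here. The exact shape/id assertions in the
        # tests above depend on the original fixture texts.
        self.expected_src_tokens = self.tokenizer(self.src_text[0]).input_ids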
Example #2
    def test_batch_fairseq_parity(self):
        batch = self.tokenizer(self.src_text, padding=True)
        with self.tokenizer.as_target_tokenizer():
            targets = self.tokenizer(self.tgt_text,
                                     padding=True,
                                     return_tensors="pt")
        labels = targets["input_ids"]
        # Convert back to lists so the entries match the un-tensorized batch.
        batch["decoder_input_ids"] = shift_tokens_right(
            labels, self.tokenizer.pad_token_id).tolist()

        # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])  # source ends with [EOS, lang code]
        self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)  # lang code wrapped to position 0
        self.assertEqual(batch.decoder_input_ids[1][-1], 2)  # EOS
        self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
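
The decoder_input_ids assertions hinge on the MBart-style shift that PLBart uses: instead of prepending a fixed decoder-start token, shift_tokens_right wraps the last non-pad token of each label row (the language code) around to position 0. A minimal re-implementation for illustration, not the library function itself:

import torch

def shift_tokens_right_sketch(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Wrap the last non-pad token (the language code) to the front,
    # shifting everything else one position to the right.
    output = input_ids.clone()
    index_of_last = (output.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = output.gather(1, index_of_last).squeeze(-1)
    output[:, 1:] = input_ids[:, :-1].clone()
    output[:, 0] = decoder_start_tokens
    return output

# A label row [..., 2, EN_CODE] becomes [EN_CODE, ..., 2], which is exactly
# what the decoder_input_ids assertions above check.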
Example #3
    def test_seq2seq_max_length(self):
        batch = self.tokenizer(self.src_text,
                               padding=True,
                               truncation=True,
                               max_length=3,
                               return_tensors="pt")
        with self.tokenizer.as_target_tokenizer():
            targets = self.tokenizer(self.tgt_text,
                                     padding=True,
                                     truncation=True,
                                     max_length=10,
                                     return_tensors="pt")
        labels = targets["input_ids"]
        batch["decoder_input_ids"] = shift_tokens_right(
            labels, self.tokenizer.pad_token_id)

        # Source and target were truncated to their independent max_length values.
        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
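
Note that as_target_tokenizer() has since been deprecated in transformers; on recent versions the targets are encoded by passing them through the text_target argument instead. A sketch of the equivalent body for the test above, assuming the same fixtures:

        batch = self.tokenizer(self.src_text,
                               padding=True,
                               truncation=True,
                               max_length=3,
                               return_tensors="pt")
        targets = self.tokenizer(text_target=self.tgt_text,
                                 padding=True,
                                 truncation=True,
                                 max_length=10,
                                 return_tensors="pt")
        labels = targets["input_ids"]
        batch["decoder_input_ids"] = shift_tokens_right(
            labels, self.tokenizer.pad_token_id)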