Example #1
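These methods are excerpted from a seq2seq tokenizer test class. A minimal sketch, assuming the transformers library, of the surrounding context they rely on (the language-code ids are hypothetical; the real values come from the checkpoint's vocabulary):

import torch
from transformers import BatchEncoding
from transformers.models.m2m_100.modeling_m2m_100 import shift_tokens_right

EN_CODE = 128022  # hypothetical id of the English language-code token
FR_CODE = 128028  # hypothetical id of the French language-code token
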
    def test_batch_fairseq_parity(self):
        self.tokenizer.src_lang = "en"
        self.tokenizer.tgt_lang = "fr"

        batch = self.tokenizer(self.src_text,
                               text_target=self.tgt_text,
                               padding=True,
                               return_tensors="pt")

        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.tokenizer.pad_token_id,
            self.tokenizer.eos_token_id)

        for k in batch:
            batch[k] = batch[k].tolist()
        # batch = {k: v.tolist() for k,v in batch.items()}
        # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        # batch.decoder_inputs_ids[0][0] ==
        assert batch.input_ids[1][0] == EN_CODE
        assert batch.input_ids[1][-1] == 2
        assert batch.labels[1][0] == FR_CODE
        assert batch.labels[1][-1] == 2
        assert batch.decoder_input_ids[1][:2] == [2, FR_CODE]
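
    # A minimal illustration (hypothetical helper, not part of the original
    # tests) of what shift_tokens_right does above: it prepends the decoder
    # start token and drops the last position, so labels of the form
    # [lang_code, ..., eos] become decoder inputs [eos, lang_code, ...].
    def demo_shift_tokens_right(self):
        import torch

        labels = torch.tensor([[FR_CODE, 5, 6, 2]])  # [lang_code, tok, tok, eos]
        shifted = shift_tokens_right(labels, self.tokenizer.pad_token_id,
                                     decoder_start_token_id=2)
        assert shifted.tolist() == [[2, FR_CODE, 5, 6]]
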
    def test_seq2seq_max_length(self):
        batch = self.tokenizer(self.src_text,
                               padding=True,
                               truncation=True,
                               max_length=3,
                               return_tensors="pt")
        targets = self.tokenizer(text_target=self.tgt_text,
                                 padding=True,
                                 truncation=True,
                                 max_length=10,
                                 return_tensors="pt")
        labels = targets["input_ids"]
        batch["decoder_input_ids"] = shift_tokens_right(
            labels,
            self.tokenizer.pad_token_id,
            decoder_start_token_id=self.tokenizer.lang_code_to_id[
                self.tokenizer.tgt_lang],
        )

        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
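
    # Note on the two-call pattern above: a single tokenizer call with
    # text_target applies one max_length to both source and target, so the
    # test tokenizes them separately to truncate the source to 3 tokens and
    # the targets to 10. A hypothetical helper wrapping that pattern:
    def demo_prepare_seq2seq_batch(self, max_src_length=3, max_tgt_length=10):
        batch = self.tokenizer(self.src_text,
                               padding=True,
                               truncation=True,
                               max_length=max_src_length,
                               return_tensors="pt")
        targets = self.tokenizer(text_target=self.tgt_text,
                                 padding=True,
                                 truncation=True,
                                 max_length=max_tgt_length,
                                 return_tensors="pt")
        batch["labels"] = targets["input_ids"]
        return batch
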
    def test_enro_tokenizer_prepare_batch(self):
        batch = self.tokenizer(
            self.src_text,
            text_target=self.tgt_text,
            padding=True,
            truncation=True,
            max_length=len(self.expected_src_tokens),
            return_tensors="pt",
        )
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.tokenizer.pad_token_id,
            self.tokenizer.lang_code_to_id["ron_Latn"])

        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 15), batch.input_ids.shape)
        self.assertEqual((2, 15), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
        # Test that special tokens are reset
        self.assertEqual(self.tokenizer.prefix_tokens, [])
        self.assertEqual(self.tokenizer.suffix_tokens,
                         [self.tokenizer.eos_token_id, EN_CODE])
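
    # Sketch (hypothetical check, not in the original tests): setting src_lang
    # re-runs the special-token setup, so with the legacy token layout these
    # tests assume ([eos, lang_code] appended to the source), the last entry
    # of suffix_tokens tracks the current source language.
    def demo_src_lang_switch(self):
        self.tokenizer.src_lang = "ron_Latn"
        assert (self.tokenizer.suffix_tokens[-1]
                == self.tokenizer.lang_code_to_id["ron_Latn"])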