コード例 #1
0
 def test_seq2seq_max_target_length(self):
     """Check the source and target widths produced by ``prepare_seq2seq_batch``.

     The source ids are expected to be ``max_length`` wide. The decoder
     inputs (labels shifted right) are expected to be ``max_target_length``
     wide, or ``max_length`` wide when ``max_target_length`` is not given.
     """
     # Each scenario: (extra keyword arguments, expected decoder width).
     scenarios = (
         ({"max_target_length": 10}, 10),
         ({}, 3),  # max_target_length will default to max_length if not specified
     )
     for extra_kwargs, expected_target_width in scenarios:
         encoded = self.tokenizer.prepare_seq2seq_batch(
             self.src_text,
             tgt_texts=self.tgt_text,
             max_length=3,
             return_tensors="pt",
             **extra_kwargs,
         )
         encoded["decoder_input_ids"] = shift_tokens_right(encoded.labels, self.tokenizer.pad_token_id)
         self.assertEqual(encoded.input_ids.shape[1], 3)
         self.assertEqual(encoded.decoder_input_ids.shape[1], expected_target_width)
コード例 #2
0
    def test_enro_tokenizer_prepare_batch(self):
        """Encode source and target text and check the resulting batch.

        Source and target are tokenized with the same padding/truncation
        settings; the target is encoded inside ``as_target_tokenizer()``.
        The test then checks batch type, shapes, the first source row, the
        last decoder input token, and the tokenizer's prefix/suffix tokens
        after the context manager has exited.
        """
        tok = self.tokenizer
        encode_kwargs = {
            "padding": True,
            "truncation": True,
            "max_length": len(self.expected_src_tokens),
            "return_tensors": "pt",
        }

        batch = tok(self.src_text, **encode_kwargs)
        with tok.as_target_tokenizer():
            target_encoding = tok(self.tgt_text, **encode_kwargs)
        target_ids = target_encoding["input_ids"]
        batch["decoder_input_ids"] = shift_tokens_right(target_ids, tok.pad_token_id)

        self.assertIsInstance(batch, BatchEncoding)

        expected_shape = (2, 14)
        self.assertEqual(expected_shape, batch.input_ids.shape)
        self.assertEqual(expected_shape, batch.attention_mask.shape)

        first_source_row = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, first_source_row)

        last_decoder_token = batch.decoder_input_ids[0, -1]
        self.assertEqual(2, last_decoder_token)  # EOS

        # Special tokens must be reset once the target context is left.
        self.assertEqual(tok.prefix_tokens, [])
        self.assertEqual(tok.suffix_tokens, [tok.eos_token_id, EN_CODE])
コード例 #3
0
    def test_seq2seq_max_length(self):
        """Source and target may be truncated to different ``max_length`` values.

        The source is encoded with ``max_length=3`` and the target (inside
        ``as_target_tokenizer()``) with ``max_length=10``; the widths of the
        resulting tensors must match those limits.
        """
        tok = self.tokenizer
        common = {"padding": True, "truncation": True, "return_tensors": "pt"}

        batch = tok(self.src_text, max_length=3, **common)
        with tok.as_target_tokenizer():
            target_encoding = tok(self.tgt_text, max_length=10, **common)

        batch["decoder_input_ids"] = shift_tokens_right(target_encoding["input_ids"], tok.pad_token_id)

        source_width = batch.input_ids.shape[1]
        decoder_width = batch.decoder_input_ids.shape[1]
        self.assertEqual(source_width, 3)
        self.assertEqual(decoder_width, 10)
コード例 #4
0
    def test_batch_fairseq_parity(self):
        """Check special-token placement in the second example of the batch.

        Source and target are encoded in a single tokenizer call via
        ``text_target``; decoder inputs are the labels shifted right.
        Reference fairseq batch:
        https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        """
        tok = self.tokenizer
        batch = tok(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tok.pad_token_id)

        # All checks look at row 1 of each tensor.
        source_row = batch.input_ids[1]
        decoder_row = batch.decoder_input_ids[1]
        label_row = batch.labels[1]

        assert source_row[-2:].tolist() == [2, EN_CODE]
        assert decoder_row[0].tolist() == RO_CODE
        assert decoder_row[-1] == 2
        assert label_row[-2:].tolist() == [2, RO_CODE]
コード例 #5
0
    def test_batch_fairseq_parity(self):
        """Check special-token placement in the second example of the batch.

        The source is encoded without ``return_tensors``; the target is
        encoded as "pt" tensors inside ``as_target_tokenizer()`` so it can
        be shifted right, then converted to plain lists for comparison.
        Reference fairseq batch:
        https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        """
        tok = self.tokenizer
        batch = tok(self.src_text, padding=True)
        with tok.as_target_tokenizer():
            target_encoding = tok(self.tgt_text, padding=True, return_tensors="pt")

        label_tensor = target_encoding["input_ids"]
        shifted = shift_tokens_right(label_tensor, tok.pad_token_id)
        batch["decoder_input_ids"] = shifted.tolist()
        label_rows = label_tensor.tolist()

        # All checks look at row 1.
        source_row = batch.input_ids[1]
        assert source_row[0] == EN_CODE
        assert source_row[-1] == 2

        target_row = label_rows[1]
        assert target_row[0] == RO_CODE
        assert target_row[-1] == 2

        assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
コード例 #6
0
    def test_tokenizer_prepare_seq2seq_batch(self):
        """Build a batch with ``prepare_seq2seq_batch`` and check its contents.

        Checks the batch type, the source tensor shapes, the first source
        row, the first decoder input token, and the tokenizer's
        prefix/suffix tokens after the call.
        """
        tok = self.tokenizer
        source_limit = len(self.expected_src_tokens)

        batch = tok.prepare_seq2seq_batch(
            self.src_text,
            tgt_texts=self.tgt_text,
            max_length=source_limit,
            return_tensors="pt",
        )
        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, tok.pad_token_id)
        self.assertIsInstance(batch, BatchEncoding)

        expected_shape = (2, 14)
        self.assertEqual(expected_shape, batch.input_ids.shape)
        self.assertEqual(expected_shape, batch.attention_mask.shape)

        first_source_row = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, first_source_row)

        first_decoder_token = batch.decoder_input_ids[0, 0]
        self.assertEqual(2, first_decoder_token)  # decoder_start_token_id

        # Special tokens must be reset after the batch has been prepared.
        self.assertEqual(tok.prefix_tokens, [EN_CODE])
        self.assertEqual(tok.suffix_tokens, [tok.eos_token_id])
コード例 #7
0
    def test_batch_fairseq_parity(self):
        """Check special-token placement in the second example of the batch.

        The batch comes from ``prepare_seq2seq_batch``; decoder inputs are
        the labels shifted right. Every entry is converted to plain lists
        before comparison.
        Reference fairseq batch:
        https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        """
        tok = self.tokenizer
        batch: BatchEncoding = tok.prepare_seq2seq_batch(
            self.src_text,
            tgt_texts=self.tgt_text,
            return_tensors="pt",
        )
        batch["decoder_input_ids"] = shift_tokens_right(batch.labels, tok.pad_token_id)

        # Convert entries in place rather than rebuilding the batch as a plain
        # dict, so the attribute-style access used below keeps working.
        for field in tuple(batch):
            batch[field] = batch[field].tolist()

        # All checks look at row 1.
        source_row = batch.input_ids[1]
        decoder_row = batch.decoder_input_ids[1]
        label_row = batch.labels[1]

        assert source_row[-2:] == [2, EN_CODE]
        assert decoder_row[0] == RO_CODE
        assert decoder_row[-1] == 2
        assert label_row[-2:] == [2, RO_CODE]
コード例 #8
0
    def test_seq2seq_dataset_truncation(self, tok_name):
        """Check that ``Seq2SeqDataset`` batches have the configured widths.

        Every batch must be a dict whose ``input_ids`` are ``max_src_len``
        wide and whose ``labels`` are ``max_tgt_len`` wide. When ``tok_name``
        is ``MBART_TINY``, the first batch is also checked for language-code
        and EOS placement, after which the loop stops.

        Args:
            tok_name: Name passed to ``AutoTokenizer.from_pretrained``.
        """
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        data_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

        longest_article = max(len(tokenizer.encode(text)) for text in ARTICLES)
        longest_summary = max(len(tokenizer.encode(text)) for text in SUMMARIES)
        max_src_len, max_tgt_len = 4, 8

        # Preconditions: the fixtures must be long enough to be truncated.
        # NOTE(review): the summary length is compared against max_src_len,
        # not max_tgt_len -- confirm whether that is intended.
        assert longest_summary > max_src_len
        assert longest_article > max_src_len

        # Per the original author: ignored for all but mbart, but never causes error.
        src_lang, tgt_lang = "ro_RO", "de_DE"

        # NOTE(review): the original marked max_target_length as "ignored",
        # yet the labels width is asserted against max_tgt_len below -- verify.
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=data_dir,
            type_path="train",
            max_source_length=max_src_len,
            max_target_length=max_tgt_len,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)

        for batch in dataloader:
            assert isinstance(batch, dict)
            source_ids = batch["input_ids"]
            assert batch["attention_mask"].shape == source_ids.shape
            # Articles were trimmed to the source limit.
            assert source_ids.shape[1] == max_src_len
            # Targets all share the target limit.
            assert batch["labels"].shape[1] == max_tgt_len

            if tok_name == MBART_TINY:
                # Language codes must be in the expected positions.
                batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
                first_decoder_row = batch["decoder_input_ids"][0]
                first_source_row = batch["input_ids"][0]

                assert first_decoder_row[0].item() == tokenizer.lang_code_to_id[tgt_lang]
                assert first_decoder_row[-1].item() == tokenizer.eos_token_id
                assert first_source_row[-2].item() == tokenizer.eos_token_id
                assert first_source_row[-1].item() == tokenizer.lang_code_to_id[src_lang]

                break  # No need to test every batch