def test_dataset_kwargs(self, tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    if tok_name == MBART_TINY:
        # mBART needs source/target language codes, so they must land in dataset_kwargs.
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
            type_path="train",
            max_source_length=4,
            max_target_length=8,
            src_lang="EN",
            tgt_lang="FR",
        )
        kwargs = train_dataset.dataset_kwargs
        assert "src_lang" in kwargs and "tgt_lang" in kwargs
    else:
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
            type_path="train",
            max_source_length=4,
            max_target_length=8,
        )
        kwargs = train_dataset.dataset_kwargs
        # Only the BART tokenizer (byte-level BPE) needs add_prefix_space; others get no extra kwargs.
        assert ("add_prefix_space" in kwargs) if tok_name == BART_TINY else ("add_prefix_space" not in kwargs)
        assert (len(kwargs) == 1) if tok_name == BART_TINY else (len(kwargs) == 0)
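
Note: make_test_data_dir is a fixture helper that is not part of this snippet. A minimal sketch of what such a helper might look like, purely as an assumption about its contract (each split only needs parallel <split>.source / <split>.target files inside data_dir):

from pathlib import Path

# Hypothetical fixtures: ARTICLES mirrors the comment in test_pack_dataset;
# SUMMARIES is a placeholder, not the real test data.
ARTICLES = [" Sam ate lunch today.\n", "Sams lunch ingredients."]
SUMMARIES = ["One line summary.", "Another line summary."]


def make_test_data_dir(tmp_dir):
    # Write parallel source/target files for every split the datasets may ask for.
    for split in ["train", "val", "test"]:
        Path(tmp_dir).joinpath(f"{split}.source").write_text("".join(ARTICLES))
        Path(tmp_dir).joinpath(f"{split}.target").write_text("\n".join(SUMMARIES) + "\n")
    return tmp_dir
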
def test_pack_dataset(self):
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

    tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
    orig_examples = tmp_dir.joinpath("train.source").open().readlines()
    save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
    pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
    orig_paths = {x.name for x in tmp_dir.iterdir()}
    new_paths = {x.name for x in save_dir.iterdir()}
    packed_examples = save_dir.joinpath("train.source").open().readlines()
    # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
    # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
    assert len(packed_examples) < len(orig_examples)
    assert len(packed_examples) == 1
    assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
    assert orig_paths == new_paths
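
The test above asserts that pack_data_dir merges the two short source lines into one. The underlying idea can be sketched as a greedy merge of consecutive lines while the merged line stays within the token budget (a sketch of the concept, not the library's implementation):

def pack_examples(tok, examples, max_tokens=128):
    # Greedily concatenate consecutive lines while the packed line still fits in max_tokens.
    packed, current = [], ""
    for line in examples:
        candidate = current + " " + line.strip() if current else line.strip()
        if len(tok.encode(candidate)) <= max_tokens:
            current = candidate
        else:
            packed.append(current)
            current = line.strip()
    if current:
        packed.append(current)
    return packed
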
def test_legacy_dataset_truncation(self, tok):
    tokenizer = AutoTokenizer.from_pretrained(tok)
    tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc_target = 4
    train_dataset = LegacySeq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=20,
        max_target_length=trunc_target,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_len_source
        assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
        # show that targets were truncated
        assert batch["labels"].shape[1] == trunc_target  # Truncated
        assert max_len_target > trunc_target  # Truncated
        break  # No need to test every batch
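
The fixed widths asserted above come from the tokenizer doing the truncation and padding inside the dataset's collate_fn. A rough sketch of such a collate step, assuming only the standard tokenizer __call__ API (not the dataset's actual code):

def collate_sketch(tokenizer, src_texts, tgt_texts, max_source_length, max_target_length):
    # Sources: truncate to max_source_length, pad to the longest example in the batch.
    batch = tokenizer(
        src_texts, max_length=max_source_length, truncation=True, padding="longest", return_tensors="pt"
    )
    # Targets: truncate to max_target_length, so labels have at most that many columns.
    batch["labels"] = tokenizer(
        tgt_texts, max_length=max_target_length, truncation=True, padding="longest", return_tensors="pt"
    ).input_ids
    return batch
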
def test_seq2seq_dataset_truncation(self, tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    max_src_len = 4
    max_tgt_len = 8
    assert max_len_target > max_src_len  # Will be truncated
    assert max_len_source > max_src_len  # Will be truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=max_src_len,
        max_target_length=max_tgt_len,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_src_len
        # show that targets are the same len
        assert batch["labels"].shape[1] == max_tgt_len
        if tok_name != MBART_TINY:
            continue
        # check language codes in correct place
        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
        break  # No need to test every batch
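
shift_tokens_right, as used above, wraps the last non-pad token of each label row (the target language code in the mBART label format) around to position 0 and shifts everything else one step to the right, which is exactly what the decoder_input_ids assertions check. A standalone sketch of that behavior, assuming a 2-D tensor of label ids:

import torch


def shift_tokens_right_sketch(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Move the last non-pad token of every row to column 0 and shift the rest right by one.
    shifted = input_ids.clone()
    index_of_last = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    shifted[:, 0] = input_ids.gather(1, index_of_last).squeeze(-1)
    shifted[:, 1:] = input_ids[:, :-1]
    return shifted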