def test_dataset_kwargs(self, tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    if tok_name == MBART_TINY:
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
            type_path="train",
            max_source_length=4,
            max_target_length=8,
            src_lang="EN",
            tgt_lang="FR",
        )
        kwargs = train_dataset.dataset_kwargs
        assert "src_lang" in kwargs and "tgt_lang" in kwargs
    else:
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
            type_path="train",
            max_source_length=4,
            max_target_length=8,
        )
        kwargs = train_dataset.dataset_kwargs
        assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
        assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
def test_pack_dataset(self):
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

    tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
    orig_examples = tmp_dir.joinpath("train.source").open().readlines()
    save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
    pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
    orig_paths = {x.name for x in tmp_dir.iterdir()}
    new_paths = {x.name for x in save_dir.iterdir()}
    packed_examples = save_dir.joinpath("train.source").open().readlines()
    # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
    # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
    assert len(packed_examples) < len(orig_examples)
    assert len(packed_examples) == 1
    assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
    assert orig_paths == new_paths
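
# The assertions above rely on pack_data_dir greedily concatenating consecutive examples
# until a token budget is reached. The hypothetical helper below is only a rough sketch of
# that idea for a single file; the real pack_data_dir is imported from the example's utils
# and packs the source and target files in lock-step, so details may differ.
def _pack_lines_sketch(tokenizer, lines, max_tokens):
    packed, current = [], None
    for line in (x.rstrip("\n") for x in lines):
        if current is None:
            current = line
        elif len(tokenizer.encode(current + " " + line)) <= max_tokens:
            current = current + " " + line  # still under budget: merge into the running example
        else:
            packed.append(current)  # over budget: finalize and start a new example
            current = line
    if current is not None:
        packed.append(current)
    return packed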
def test_legacy_dataset_truncation(self, tok):
    tokenizer = AutoTokenizer.from_pretrained(tok)
    tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc_target = 4
    train_dataset = LegacySeq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=20,
        max_target_length=trunc_target,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_len_source
        assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
        # show that targets were truncated
        assert batch["labels"].shape[1] == trunc_target  # Truncated
        assert max_len_target > trunc_target  # Truncated
        break  # No need to test every batch
def test_seq2seq_dataset_truncation(self, tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    max_src_len = 4
    max_tgt_len = 8
    assert max_len_target > max_src_len  # Will be truncated
    assert max_len_source > max_src_len  # Will be truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=max_src_len,
        max_target_length=max_tgt_len,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_src_len
        # show that targets are the same len
        assert batch["labels"].shape[1] == max_tgt_len
        if tok_name != MBART_TINY:
            continue

        # check language codes in correct place
        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
        break  # No need to test every batch
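
# For reference, a minimal stand-in for the shift_tokens_right imported above (illustrative
# only; the real helper comes from the library). It shows why the mBART assertions hold,
# assuming labels end in [..., eos, tgt_lang_code]: the last non-pad token (the language
# code) is wrapped around to position 0, everything else shifts right by one, and eos ends
# up as the final decoder input token.
def _shift_tokens_right_sketch(input_ids, pad_token_id):
    prev_output_tokens = input_ids.clone()
    # index of the last non-pad token in each row (for mBART labels, the target language code)
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens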