def test_collate_char_dataset(self):
    src_dataset = char_data.InMemoryNumpyWordCharDataset()
    src_dataset.parse(
        self.src_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
    )
    tgt_dataset = char_data.InMemoryNumpyWordCharDataset()
    tgt_dataset.parse(
        self.trg_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
    )
    char_dataset = char_data.LanguagePairCharDataset(
        src=src_dataset,
        src_sizes=src_dataset.sizes,
        src_dict=self.d,
        tgt=tgt_dataset,
        tgt_sizes=tgt_dataset.sizes,
        tgt_dict=self.d,
    )
    samples = [char_dataset[i] for i in range(len(char_dataset))]
    collate_data = char_dataset.collater(samples)
    ids = collate_data["id"]
    ntokens = collate_data["ntokens"]
    assert len(ids) == 4
    assert ntokens == 32
    net_input = collate_data["net_input"]
    assert net_input["char_inds"].size() == torch.Size([4, 11, 4])
    assert net_input["prev_output_chars"].size() == torch.Size([4, 12, 4])
    assert collate_data["target_char_inds"].size() == torch.Size([4, 11, 4])
    assert net_input["prev_output_word_lengths"].size() == torch.Size([4, 12])
    for i in range(net_input["prev_output_chars"].size()[0]):
        assert net_input["prev_output_chars"][i, 0, 0] == self.d.eos_index
        # Asserting that the word generated before the first one is just EOS.
        assert net_input["prev_output_word_lengths"][i][0] == 1
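# A minimal cross-check sketch (an assumption, not part of the original
# suite): if `tgt_dataset.sizes` holds per-sentence word counts including
# EOS, then the collated `ntokens` asserted above is simply their sum.
def _expected_ntokens(self, tgt_dataset):
    # Hypothetical helper; with the fixture above this should total 32.
    return int(sum(tgt_dataset.sizes))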
def test_collate_char_dataset_w_unk(self):
    """
    We intentionally introduce a text that has unknown words in it.
    """
    # srct and trgt are unknown words.
    src_txt = test_utils.write_lines_to_temp_file(
        ["srcA srct srcA srcA", "srcA srcA srcA srcA"]
    )
    trg_txt = test_utils.write_lines_to_temp_file(
        ["trgA trgA trgt trgB", "trgA trgB trgC trgD"]
    )
    src_dataset = char_data.InMemoryNumpyWordCharDataset()
    src_dataset.parse(
        src_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
    )
    tgt_dataset = char_data.InMemoryNumpyWordCharDataset(ignore_chars_for_unks=True)
    tgt_dataset.parse(
        trg_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
    )
    # Confirming that the third word ("trgt", at index 1 after reversal) is
    # an unknown: its char span is a single slot holding the EOS character.
    assert tgt_dataset.char_offsets[1] + 1 == tgt_dataset.char_offsets[2]
    assert (
        tgt_dataset.char_buffer[tgt_dataset.char_offsets[1]] == self.char_dict.eos()
    )
    char_dataset = char_data.LanguagePairCharDataset(
        src=src_dataset,
        src_sizes=src_dataset.sizes,
        src_dict=self.d,
        tgt=tgt_dataset,
        tgt_sizes=tgt_dataset.sizes,
        tgt_dict=self.d,
    )
    samples = [char_dataset[i] for i in range(len(char_dataset))]
    collate_data = char_dataset.collater(samples)
    ids = collate_data["id"]
    assert len(ids) == 2
    net_input = collate_data["net_input"]
    assert net_input["prev_output_word_lengths"][0][2] == 1
    assert torch.equal(
        net_input["prev_output_chars"][0][2], torch.tensor([2, 0, 0, 0])
    )
    assert torch.equal(
        collate_data["target_char_inds"][0][1], torch.tensor([2, 0, 0, 0])
    )
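# A hedged helper sketch (hypothetical, built on the layout asserted above:
# with ignore_chars_for_unks=True an UNK word's char row collapses to a
# single EOS index followed by zero padding).
def _is_unk_char_row(self, row):
    # `row` is a 1-D tensor of char indices for one word position.
    return row[0].item() == self.char_dict.eos() and bool(row[1:].eq(0).all())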
def test_collate_char_dataset_no_tgt(self):
    src_dataset = char_data.InMemoryNumpyWordCharDataset()
    src_dataset.parse(
        self.src_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
    )
    char_dataset = char_data.LanguagePairCharDataset(
        src=src_dataset, src_sizes=src_dataset.sizes, src_dict=self.d
    )
    samples = [char_dataset[i] for i in range(len(char_dataset))]
    collate_data = char_dataset.collater(samples)
    ids = collate_data["id"]
    ntokens = collate_data["ntokens"]
    assert len(ids) == 4
    assert ntokens is None
    net_input = collate_data["net_input"]
    assert net_input["char_inds"].size() == torch.Size([4, 11, 4])
    assert net_input["prev_output_chars"] is None
    assert collate_data["target_char_inds"] is None
    assert net_input["prev_output_word_lengths"] is None
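# Hedged usage sketch: a source-only batch (e.g. at inference time) leaves
# every target-side field as None, so downstream consumers should check for
# a target before touching those tensors. Hypothetical predicate:
def _has_target_side(self, collate_data):
    # Mirrors the None asserts above.
    return collate_data["target_char_inds"] is not None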