示例#1
0
 def test_collate_char_dataset(self):
     """Collate a paired source/target char dataset and check the batch.

     Parses ``self.src_txt`` and ``self.trg_txt`` into word/char datasets,
     wraps them in a ``LanguagePairCharDataset``, collates every sample into
     a single batch, and asserts the id count, token count and tensor sizes.
     """
     parse_opts = {"reverse_order": True, "append_eos": True}

     source = char_data.InMemoryNumpyWordCharDataset()
     source.parse(self.src_txt, self.d, self.char_dict, **parse_opts)
     target = char_data.InMemoryNumpyWordCharDataset()
     target.parse(self.trg_txt, self.d, self.char_dict, **parse_opts)

     pair_dataset = char_data.LanguagePairCharDataset(
         src=source,
         src_sizes=source.sizes,
         src_dict=self.d,
         tgt=target,
         tgt_sizes=target.sizes,
         tgt_dict=self.d,
     )

     # Collate the whole dataset as one batch.
     batch = pair_dataset.collater(
         [pair_dataset[idx] for idx in range(len(pair_dataset))]
     )

     sample_ids, token_count = batch["id"], batch["ntokens"]
     assert len(sample_ids) == 4
     assert token_count == 32

     net_input = batch["net_input"]
     assert net_input["char_inds"].size() == torch.Size([4, 11, 4])
     prev_chars = net_input["prev_output_chars"]
     assert prev_chars.size() == torch.Size([4, 12, 4])
     assert batch["target_char_inds"].size() == torch.Size([4, 11, 4])
     prev_word_lengths = net_input["prev_output_word_lengths"]
     assert prev_word_lengths.size() == torch.Size([4, 12])

     for row in range(prev_chars.size(0)):
         # The first previous-output word of every row starts with eos ...
         assert prev_chars[row, 0, 0] == self.d.eos_index
         # ... and has length 1, i.e. it consists of that eos only.
         assert prev_word_lengths[row][0] == 1
示例#2
0
    def test_collate_char_dataset_w_unk(self):
        """Collate a char dataset whose text contains unknown words.

        The input text is deliberately written with words that are not in
        the dictionary. The target dataset is built with
        ``ignore_chars_for_unks=True``, and the test asserts that the unknown
        target word takes a single char slot holding the char dictionary's
        eos, both in the raw dataset buffers and in the collated batch.
        """
        # srct and trgt are unknown words.
        src_lines = ["srcA srct srcA srcA", "srcA srcA srcA srcA"]
        trg_lines = ["trgA trgA trgt trgB", "trgA trgB trgC trgD"]
        src_txt = test_utils.write_lines_to_temp_file(src_lines)
        trg_txt = test_utils.write_lines_to_temp_file(trg_lines)

        parse_opts = {"reverse_order": True, "append_eos": True}

        source = char_data.InMemoryNumpyWordCharDataset()
        source.parse(src_txt, self.d, self.char_dict, **parse_opts)
        target = char_data.InMemoryNumpyWordCharDataset(ignore_chars_for_unks=True)
        target.parse(trg_txt, self.d, self.char_dict, **parse_opts)

        # The word at index 1 spans exactly one char slot, and that slot is
        # the char dictionary's eos. Presumably this is "trgt" (third word of
        # the first line, stored at index 1 because of reverse_order=True).
        unk_start = target.char_offsets[1]
        assert unk_start + 1 == target.char_offsets[2]
        assert target.char_buffer[target.char_offsets[1]] == self.char_dict.eos()

        pair_dataset = char_data.LanguagePairCharDataset(
            src=source,
            src_sizes=source.sizes,
            src_dict=self.d,
            tgt=target,
            tgt_sizes=target.sizes,
            tgt_dict=self.d,
        )

        # Collate the whole dataset as one batch.
        batch = pair_dataset.collater(
            [pair_dataset[idx] for idx in range(len(pair_dataset))]
        )
        assert len(batch["id"]) == 2
        net_input = batch["net_input"]

        # NOTE(review): 2 is presumably the eos index and 0 the padding
        # index -- confirm against the dictionaries used by the fixture.
        expected_unk_chars = torch.tensor([2, 0, 0, 0])

        assert net_input["prev_output_word_lengths"][0][2] == 1
        assert torch.equal(net_input["prev_output_chars"][0][2], expected_unk_chars)
        assert torch.equal(batch["target_char_inds"][0][1], expected_unk_chars)
示例#3
0
 def test_collate_char_dataset_no_tgt(self):
     """Collate a source-only char dataset and check target fields are None.

     Builds a ``LanguagePairCharDataset`` without any target arguments,
     collates every sample into one batch, and asserts that the source
     tensor has the expected size while ``ntokens`` and all target-derived
     entries are ``None``.
     """
     source = char_data.InMemoryNumpyWordCharDataset()
     source.parse(
         self.src_txt,
         self.d,
         self.char_dict,
         reverse_order=True,
         append_eos=True,
     )
     source_only = char_data.LanguagePairCharDataset(
         src=source,
         src_sizes=source.sizes,
         src_dict=self.d,
     )

     # Collate the whole dataset as one batch.
     batch = source_only.collater(
         [source_only[idx] for idx in range(len(source_only))]
     )

     sample_ids, token_count = batch["id"], batch["ntokens"]
     assert len(sample_ids) == 4
     assert token_count is None

     net_input = batch["net_input"]
     assert net_input["char_inds"].size() == torch.Size([4, 11, 4])
     # Without a target dataset, every target-derived entry must be None.
     assert net_input["prev_output_chars"] is None
     assert batch["target_char_inds"] is None
     assert net_input["prev_output_word_lengths"] is None