Code Example #1
    def test_collate_char_dataset(self):
        src_dataset = char_data.InMemoryNumpyWordCharDataset()
        src_dataset.parse(
            self.src_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
        )
        tgt_dataset = char_data.InMemoryNumpyWordCharDataset()
        tgt_dataset.parse(
            self.trg_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
        )
        char_dataset = char_data.LanguagePairCharDataset(
            src=src_dataset,
            src_sizes=src_dataset.sizes,
            src_dict=self.d,
            tgt=tgt_dataset,
            tgt_sizes=tgt_dataset.sizes,
            tgt_dict=self.d,
        )
        samples = [char_dataset[i] for i in range(len(char_dataset))]
        collate_data = char_dataset.collater(samples)
        ids = collate_data["id"]
        ntokens = collate_data["ntokens"]
        assert len(ids) == 4
        assert ntokens == 32
        net_input = collate_data["net_input"]
        assert net_input["char_inds"].size() == torch.Size([4, 11, 4])
        assert net_input["prev_output_chars"].size() == torch.Size([4, 12, 4])
        assert collate_data["target_char_inds"].size() == torch.Size([4, 11, 4])
        assert net_input["prev_output_word_lengths"].size() == torch.Size([4, 12])
        for i in range(net_input["prev_output_chars"].size()[0]):
            assert net_input["prev_output_chars"][i, 0, 0] == self.d.eos_index
            # Asserting that the "previous word" before the first output word is just eos.
            assert net_input["prev_output_word_lengths"][i][0] == 1
Code Example #2
    def test_collate_char_dataset_w_unk(self):
        """
        We intentionally introduce a text that has unknown words in it.
        """

        # srct and trgt are unknown words.
        src_txt = test_utils.write_lines_to_temp_file(
            ["srcA srct srcA srcA", "srcA srcA srcA srcA"])
        trg_txt = test_utils.write_lines_to_temp_file(
            ["trgA trgA trgt trgB", "trgA trgB trgC trgD"])
        src_dataset = char_data.InMemoryNumpyWordCharDataset()
        src_dataset.parse(src_txt,
                          self.d,
                          self.char_dict,
                          reverse_order=True,
                          append_eos=True)
        tgt_dataset = char_data.InMemoryNumpyWordCharDataset(
            ignore_chars_for_unks=True)
        tgt_dataset.parse(trg_txt,
                          self.d,
                          self.char_dict,
                          reverse_order=True,
                          append_eos=True)

        # Confirming that the third word is an unknown word.
        assert tgt_dataset.char_offsets[1] + 1 == tgt_dataset.char_offsets[2]
        assert (tgt_dataset.char_buffer[tgt_dataset.char_offsets[1]] ==
                self.char_dict.eos())

        char_dataset = char_data.LanguagePairCharDataset(
            src=src_dataset,
            src_sizes=src_dataset.sizes,
            src_dict=self.d,
            tgt=tgt_dataset,
            tgt_sizes=tgt_dataset.sizes,
            tgt_dict=self.d,
        )
        samples = [char_dataset[i] for i in range(len(char_dataset))]
        collate_data = char_dataset.collater(samples)
        ids = collate_data["id"]
        assert len(ids) == 2
        net_input = collate_data["net_input"]

        assert net_input["prev_output_word_lengths"][0][2] == 1
        assert torch.equal(net_input["prev_output_chars"][0][2],
                           torch.tensor([2, 0, 0, 0]))
        assert torch.equal(collate_data["target_char_inds"][0][1],
                           torch.tensor([2, 0, 0, 0]))
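The literal tensors in the last three assertions follow from the char_buffer check above: with ignore_chars_for_unks=True, an unknown target word is stored as the single character self.char_dict.eos() (index 2 in the default fairseq Dictionary layout, which is what the literal 2 encodes), and the collated row is zero-padded up to the maximum word length in the batch. Restated as a continuation of the test, using only its own fixtures:

        unk_word_chars = collate_data["target_char_inds"][0][1]
        # The unknown word is reduced to a single char-level eos ...
        assert unk_word_chars[0] == self.char_dict.eos()
        # ... followed by zero padding up to the max word length in the batch.
        assert all(c == 0 for c in unk_word_chars[1:].tolist())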
Code Example #3
    def test_subsample_char_dataset(self):
        """
        Test the InMemoryNumpyWordCharDataset.subsample() method, ensuring that
        the examples produced by the dataset are correctly permuted according to
        the indices argument.
        """
        src_dataset = char_data.InMemoryNumpyWordCharDataset()
        src_dataset.parse(self.src_txt,
                          self.d,
                          self.char_dict,
                          reverse_order=True,
                          append_eos=False)

        indices = np.random.permutation(len(src_dataset))[:2]
        token_samples = [src_dataset.get_tokens(i) for i in indices]
        char_samples = [src_dataset.get_chars_list(i) for i in indices]
        src_dataset.subsample(indices)
        for i in range(2):
            assert all(
                src_dataset.get_tokens(i).numpy() == token_samples[i].numpy())
            orig_chars_list = char_samples[i]
            sampled_chars_list = src_dataset.get_chars_list(i)
            assert len(sampled_chars_list) == len(orig_chars_list)
            for sampled_chars, orig_chars in zip(sampled_chars_list,
                                                 orig_chars_list):
                assert all(sampled_chars.numpy() == orig_chars.numpy())
Code Example #4
File: preprocess.py  Project: gokmonk/translate
def binarize_text_file(
    text_file: str,
    dictionary: Dictionary,
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    use_char_data: bool = False,
    embed_bytes: bool = False,
    char_dictionary: Optional[Dictionary] = None,
    already_numberized: bool = False,
) -> str:
    output_path = maybe_generate_temp_file_path(output_path)
    if use_char_data:
        dataset = char_data.InMemoryNumpyWordCharDataset()
        dataset.parse(
            path=text_file,
            word_dict=dictionary,
            char_dict=char_dictionary,
            embed_bytes=embed_bytes,
            reverse_order=reverse_order,
            append_eos=append_eos,
        )
    else:
        dataset = pytorch_translate_data.InMemoryNumpyDataset()
        dataset.parse(
            path=text_file,
            dictionary=dictionary,
            reverse_order=reverse_order,
            append_eos=append_eos,
            already_numberized=already_numberized,
        )
    dataset.save(output_path)
    return output_path
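A minimal usage sketch for the word+character branch of this function. It is not taken from the project's documentation: the file names are placeholders, and it assumes word- and character-level vocabularies have already been built and saved in fairseq Dictionary format.

from fairseq.data import Dictionary

word_dict = Dictionary.load("word_dict.txt")   # assumed: pre-built vocab files
char_dict = Dictionary.load("char_dict.txt")

binarized_path = binarize_text_file(
    text_file="train.src.tok",     # one tokenized sentence per line
    dictionary=word_dict,
    output_path="train.src.npz",   # may be rewritten by maybe_generate_temp_file_path()
    append_eos=True,
    reverse_order=False,
    use_char_data=True,            # selects the InMemoryNumpyWordCharDataset branch
    char_dictionary=char_dict,
)
print("binarized dataset written to", binarized_path)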
Code Example #5
File: preprocess.py  Project: gokmonk/translate
def binarize_text_file_multilingual(
    corpus_configs: List[pytorch_translate_data.MultilingualCorpusConfig],
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    prepend_language_id: bool,
    use_char_data: bool = False,
    embed_bytes: bool = False,
    already_numberized: bool = False,
) -> str:
    output_path = maybe_generate_temp_file_path(output_path)
    if use_char_data:
        dataset = char_data.InMemoryNumpyWordCharDataset()
        dataset.parse_multilingual(
            corpus_configs,
            reverse_order=reverse_order,
            append_eos=append_eos,
            embed_bytes=embed_bytes,
            prepend_language_id=prepend_language_id,
            already_numberized=already_numberized,
        )
    else:
        dataset = pytorch_translate_data.InMemoryNumpyDataset()
        dataset.parse_multilingual(
            corpus_configs,
            append_eos=append_eos,
            reverse_order=reverse_order,
            prepend_language_id=prepend_language_id,
            already_numberized=already_numberized,
        )
    dataset.save(output_path)
    return output_path
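The multilingual variant takes a list of corpus configs instead of a single file/dictionary pair. The sketch below only shows the call shape; constructing the MultilingualCorpusConfig entries is not covered here, so corpus_configs is assumed to have been built elsewhere, and the output path is a placeholder.

# corpus_configs: List[pytorch_translate_data.MultilingualCorpusConfig],
# assumed to be built elsewhere (one entry per corpus / language direction).
binarized_path = binarize_text_file_multilingual(
    corpus_configs=corpus_configs,
    output_path="train.multi.npz",
    append_eos=True,
    reverse_order=False,
    prepend_language_id=True,   # as the flag name suggests, mark each sentence with its language id
    use_char_data=False,        # word-level InMemoryNumpyDataset branch
    already_numberized=False,   # input is assumed to be raw text, not pre-mapped token ids
)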
Code Example #6
    def load_dataset_from_text(
        self,
        split: str,
        source_text_file: str,
        target_text_file: str,
        append_eos: Optional[bool] = False,
        reverse_source: Optional[bool] = True,
    ):
        dst_dataset = data.IndexedRawTextDataset(
            path=target_text_file,
            dictionary=self.target_dictionary,
            # We always append EOS to the target sentence since we still want
            # the model to output an indication the sentence has finished, even
            # if we don't append the EOS symbol to the source sentence
            # (to prevent the model from misaligning UNKs or other words
            # to the frequently occurring EOS).
            append_eos=True,
            # We don't reverse the order of the target sentence, since
            # even if the source sentence is fed to the model backwards,
            # we still want the model to start outputting from the first word.
            reverse_order=False,
        )

        if self.char_source_dict is not None:
            src_dataset = char_data.InMemoryNumpyWordCharDataset()
            src_dataset.parse(
                path=source_text_file,
                word_dict=self.source_dictionary,
                char_dict=self.char_source_dict,
                reverse_order=reverse_source,
                append_eos=append_eos,
            )
            self.datasets[split] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
            )
        else:
            src_dataset = data.IndexedRawTextDataset(
                path=source_text_file,
                dictionary=self.source_dictionary,
                append_eos=append_eos,
                reverse_order=reverse_source,
            )
            self.datasets[split] = data.LanguagePairDataset(
                src=src_dataset,
                src_sizes=src_dataset.sizes,
                src_dict=self.source_dictionary,
                tgt=dst_dataset,
                tgt_sizes=dst_dataset.sizes,
                tgt_dict=self.target_dictionary,
                left_pad_source=False,
            )

        print(f"| {split} {len(self.datasets[split])} examples")
Code Example #7
    def test_collate_char_dataset_no_tgt(self):
        src_dataset = char_data.InMemoryNumpyWordCharDataset()
        src_dataset.parse(
            self.src_txt, self.d, self.char_dict, reverse_order=True, append_eos=True
        )
        char_dataset = char_data.LanguagePairCharDataset(
            src=src_dataset, src_sizes=src_dataset.sizes, src_dict=self.d
        )
        samples = [char_dataset[i] for i in range(len(char_dataset))]
        collate_data = char_dataset.collater(samples)
        ids = collate_data["id"]
        ntokens = collate_data["ntokens"]
        assert len(ids) == 4
        assert ntokens is None
        net_input = collate_data["net_input"]
        assert net_input["char_inds"].size() == torch.Size([4, 11, 4])
        assert net_input["prev_output_chars"] is None
        assert collate_data["target_char_inds"] is None
        assert net_input["prev_output_word_lengths"] is None