Code example #1
 def test_parse_multiling(self):
     prepend_dataset = data.InMemoryNumpyDataset()
     append_dataset = data.InMemoryNumpyDataset()
     corpora = [
         data.MultilingualCorpusConfig(
             dialect_id=10, data_file=self.trg_txt, dict=self.d, oversampling=1
         ),
         data.MultilingualCorpusConfig(
             dialect_id=11, data_file=self.trg_txt, dict=self.d, oversampling=1
         ),
     ]
     lang1 = corpora[0].dialect_id
     lang2 = corpora[1].dialect_id
     # prepend_language_id=True puts each corpus's dialect ID token before the
     # sentence; prepend_language_id=False appends it after the sentence instead.
     prepend_dataset.parse_multilingual(
         corpora, reverse_order=False, append_eos=False, prepend_language_id=True
     )
     append_dataset.parse_multilingual(
         corpora, reverse_order=False, append_eos=False, prepend_language_id=False
     )
     self.assertEqual(2 * self.num_sentences, len(prepend_dataset))
     self.assertEqual(2 * self.num_sentences, len(append_dataset))
     for i in range(self.num_sentences):
         self.assertListEqual([lang1] + self.trg_ref[i], prepend_dataset[i].tolist())
         self.assertListEqual(self.trg_ref[i] + [lang1], append_dataset[i].tolist())
         self.assertListEqual(
             [lang2] + self.trg_ref[i],
             prepend_dataset[i + self.num_sentences].tolist(),
         )
         self.assertListEqual(
             self.trg_ref[i] + [lang2],
             append_dataset[i + self.num_sentences].tolist(),
         )
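
The index arithmetic in the loop reflects how parse_multilingual lays the corpora out back to back: entries 0 through num_sentences - 1 come from the first corpus (dialect ID 10), and the next num_sentences entries come from the second (dialect ID 11). A minimal sketch of that layout, using a hypothetical num_sentences of 3 and a placeholder sentence (neither value comes from the test fixtures):

# Illustrative layout only; sentence and num_sentences are placeholders.
num_sentences = 3
sentence = [7, 8, 9]
prepend_layout = [[10] + sentence] * num_sentences + [[11] + sentence] * num_sentences
assert prepend_layout[0] == [10, 7, 8, 9]                   # first corpus, lang ID in front
assert prepend_layout[0 + num_sentences] == [11, 7, 8, 9]   # second corpus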
Code example #2
File: preprocess.py  Project: gokmonk/translate
def make_multiling_corpus_configs(
    language_ids,
    text_files,
    dictionaries,
    char_dictionaries=None,
    oversampling_rates=None,
):
    if not oversampling_rates:
        oversampling_rates = [1] * len(language_ids)
    if char_dictionaries is None:
        char_dictionaries = [None] * len(language_ids)
    assert len(language_ids) == len(text_files)
    assert len(language_ids) == len(dictionaries)
    assert len(language_ids) == len(oversampling_rates)
    # Raw language indices are shifted by MULTILING_DIALECT_ID_OFFSET to obtain
    # the dialect IDs stored in each MultilingualCorpusConfig.
    return [
        pytorch_translate_data.MultilingualCorpusConfig(
            dialect_id=None
            if i is None
            else i + pytorch_translate_data.MULTILING_DIALECT_ID_OFFSET,
            data_file=p,
            dict=d,
            char_dict=cd,
            oversampling=o,
        )
        for i, p, d, cd, o in zip(
            language_ids,
            text_files,
            dictionaries,
            char_dictionaries,
            oversampling_rates,
        )
    ]
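
A rough usage sketch for the helper above; the file paths are hypothetical placeholders and None stands in for the Dictionary objects the project would pass, so this only shows the shape of a call, not real project data:

# Hypothetical call: one config per language, dialect IDs offset automatically.
corpus_configs = make_multiling_corpus_configs(
    language_ids=[0, 1],
    text_files=["train.de.txt", "train.fr.txt"],
    dictionaries=[None, None],
    oversampling_rates=[1, 2],
)
# corpus_configs[1].dialect_id == 1 + pytorch_translate_data.MULTILING_DIALECT_ID_OFFSET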
Code example #3
 def test_parse_oversampling(self):
     dataset = data.InMemoryNumpyDataset()
     factors = [(1, 0), (3, 2), (4, 4)]
     for o1, o2 in factors:
         corpora = [
             data.MultilingualCorpusConfig(
                 dialect_id=None,
                 data_file=self.trg_txt,
                 dict=self.d,
                 oversampling=o1,
             ),
             data.MultilingualCorpusConfig(
                 dialect_id=None,
                 data_file=self.trg_txt,
                 dict=self.d,
                 oversampling=o2,
             ),
         ]
         # Each corpus contributes oversampling * num_sentences copies of its
         # sentences, so the combined length is (o1 + o2) * num_sentences.
         dataset.parse_multilingual(corpora)
         self.assertEqual((o1 + o2) * self.num_sentences, len(dataset))