def test_parse_multiling(self):
    """Check where parse_multilingual puts the language id in each sentence.

    The same target text file is loaded twice, under dialect ids 10 and 11.
    It is parsed once with ``prepend_language_id=True`` and once with
    ``prepend_language_id=False``. The test expects the id to come first in
    the former case and last in the latter, with all sentences of the first
    corpus preceding all sentences of the second.
    """
    prepend_dataset = data.InMemoryNumpyDataset()
    append_dataset = data.InMemoryNumpyDataset()

    corpora = [
        data.MultilingualCorpusConfig(
            dialect_id=dialect, data_file=self.trg_txt, dict=self.d, oversampling=1
        )
        for dialect in (10, 11)
    ]
    first_lang, second_lang = corpora[0].dialect_id, corpora[1].dialect_id

    shared_kwargs = {"reverse_order": False, "append_eos": False}
    prepend_dataset.parse_multilingual(
        corpora, prepend_language_id=True, **shared_kwargs
    )
    append_dataset.parse_multilingual(
        corpora, prepend_language_id=False, **shared_kwargs
    )

    # Each corpus contributes every sentence once (oversampling=1).
    self.assertEqual(2 * self.num_sentences, len(prepend_dataset))
    self.assertEqual(2 * self.num_sentences, len(append_dataset))

    for idx in range(self.num_sentences):
        reference = self.trg_ref[idx]
        # Index of the same sentence inside the second corpus' section.
        second_idx = idx + self.num_sentences

        self.assertListEqual(
            [first_lang] + reference, prepend_dataset[idx].tolist()
        )
        self.assertListEqual(
            reference + [first_lang], append_dataset[idx].tolist()
        )
        self.assertListEqual(
            [second_lang] + reference, prepend_dataset[second_idx].tolist()
        )
        self.assertListEqual(
            reference + [second_lang], append_dataset[second_idx].tolist()
        )
def make_multiling_corpus_configs(
    language_ids,
    text_files,
    dictionaries,
    char_dictionaries=None,
    oversampling_rates=None,
):
    """Build one ``MultilingualCorpusConfig`` per language.

    All sequence arguments are parallel: position ``k`` of each describes
    the same corpus.

    Args:
        language_ids: Per-corpus language ids. An entry of ``None`` yields a
            config with ``dialect_id=None``; any other entry is shifted by
            ``pytorch_translate_data.MULTILING_DIALECT_ID_OFFSET``.
        text_files: Per-corpus data file, passed through as ``data_file``.
        dictionaries: Per-corpus dictionary, passed through as ``dict``.
        char_dictionaries: Optional per-corpus character dictionary, passed
            through as ``char_dict``. Defaults to ``None`` for every corpus.
        oversampling_rates: Optional per-corpus oversampling factor. When
            omitted or empty, every corpus gets a factor of 1.

    Returns:
        A list of ``MultilingualCorpusConfig`` objects, one per language id,
        in input order.

    Raises:
        AssertionError: If any of the sequences differs in length from
            ``language_ids``.
    """
    num_languages = len(language_ids)
    if not oversampling_rates:
        oversampling_rates = [1] * num_languages
    if char_dictionaries is None:
        char_dictionaries = [None] * num_languages

    assert num_languages == len(text_files), (
        "text_files must have one entry per language id"
    )
    assert num_languages == len(dictionaries), (
        "dictionaries must have one entry per language id"
    )
    assert num_languages == len(oversampling_rates), (
        "oversampling_rates must have one entry per language id"
    )
    # zip() stops at the shortest input, so a mismatched char_dictionaries
    # would otherwise silently drop corpora instead of failing loudly.
    assert num_languages == len(char_dictionaries), (
        "char_dictionaries must have one entry per language id"
    )

    return [
        pytorch_translate_data.MultilingualCorpusConfig(
            dialect_id=None
            if i is None
            else i + pytorch_translate_data.MULTILING_DIALECT_ID_OFFSET,
            data_file=p,
            dict=d,
            char_dict=cd,
            oversampling=o,
        )
        for i, p, d, cd, o in zip(
            language_ids,
            text_files,
            dictionaries,
            char_dictionaries,
            oversampling_rates,
        )
    ]
def test_parse_oversampling(self):
    """Check that the dataset length scales with the oversampling factors.

    For each pair of factors, the same target text file is loaded as two
    corpora without a dialect id. The parsed dataset is expected to hold
    ``(factor_a + factor_b) * num_sentences`` entries.
    """
    dataset = data.InMemoryNumpyDataset()
    # NOTE(review): one dataset instance is reused for every pair, so the
    # length check presumably relies on parse_multilingual replacing any
    # previously parsed contents -- confirm against its implementation.
    for rate_a, rate_b in ((1, 0), (3, 2), (4, 4)):
        corpora = [
            data.MultilingualCorpusConfig(
                dialect_id=None,
                data_file=self.trg_txt,
                dict=self.d,
                oversampling=rate,
            )
            for rate in (rate_a, rate_b)
        ]
        dataset.parse_multilingual(corpora)
        expected_len = (rate_a + rate_b) * self.num_sentences
        self.assertEqual(expected_len, len(dataset))