예제 #1
0
    def __init__(self, args, source_dictionaries, target_dictionaries):
        """Store per-language dictionaries and build merged source/target vocabularies.

        Args:
            args: forwarded unchanged to the parent class constructor.
            source_dictionaries: mapping from source language code to that
                language's dictionary; every value is pushed into
                ``self.src_dict``.
            target_dictionaries: mapping from target language code to that
                language's dictionary; every value is pushed into
                ``self.tgt_dict``.
        """
        self.source_dictionaries = source_dictionaries
        self.target_dictionaries = target_dictionaries

        # Language-ID -> language-code lists, in the mappings' key order.
        # Training fills these completely; generation usually supplies only a
        # single source/target dictionary, so set_encoder/decoder_langs must be
        # called afterwards to populate them properly.
        source_langs = source_dictionaries.keys()
        target_langs = target_dictionaries.keys()
        self.encoder_langs = list(source_langs)
        self.decoder_langs = list(target_langs)

        def _merge_into_max_vocab(per_lang_dicts):
            # Push every per-language dictionary, in iteration order, into a
            # fresh MaxVocabDictionary and hand it back.
            merged = pytorch_translate_dictionary.MaxVocabDictionary()
            for vocab in per_lang_dicts.values():
                merged.push(vocab)
            return merged

        # Source side is fully built before the target side, as before.
        self.src_dict = _merge_into_max_vocab(source_dictionaries)
        self.tgt_dict = _merge_into_max_vocab(target_dictionaries)

        # The parent constructor receives the merged vocabularies.
        super().__init__(args, self.src_dict, self.tgt_dict)
예제 #2
0
 def test_push(self):
     """Check that MaxVocabDictionary.push grows the vocabulary and never shrinks it.

     A fresh MaxVocabDictionary holds only its special symbols. After pushing
     ``src_dict`` its length matches ``src_dict``. After pushing ``srctrg_dict``
     (built from both corpora) its length matches ``srctrg_dict``. Pushing
     ``src_dict`` again leaves the length at ``len(srctrg_dict)``.

     Every temporary file is registered with ``addCleanup`` as soon as it
     exists, so it is removed even when a build step or assertion fails.
     """
     max_vocab_dict = dictionary.MaxVocabDictionary()
     src_txt, trg_txt = test_utils.create_test_text_files()
     self.addCleanup(os.remove, src_txt)
     self.addCleanup(os.remove, trg_txt)

     tmp_prefix = test_utils.make_temp_file()
     # NOTE(review): only the ``.src`` / ``.srctrg`` files derived from
     # tmp_prefix are removed; if make_temp_file also creates a file at
     # tmp_prefix itself, that file is left behind — confirm.
     src_vocab_file = f"{tmp_prefix}.src"
     src_dict = dictionary.Dictionary.build_vocab_file(
         corpus_files=[src_txt],
         vocab_file=src_vocab_file,
         max_vocab_size=1000,
     )
     self.addCleanup(os.remove, src_vocab_file)

     srctrg_vocab_file = f"{tmp_prefix}.srctrg"
     srctrg_dict = dictionary.Dictionary.build_vocab_file(
         corpus_files=[src_txt, trg_txt],
         vocab_file=srctrg_vocab_file,
         max_vocab_size=1000,
     )
     self.addCleanup(os.remove, srctrg_vocab_file)

     # A freshly constructed dictionary contains only the special symbols.
     self.assertEqual(len(max_vocab_dict), max_vocab_dict.nspecial)
     max_vocab_dict.push(src_dict)
     self.assertEqual(len(max_vocab_dict), len(src_dict))
     max_vocab_dict.push(srctrg_dict)
     self.assertEqual(len(max_vocab_dict), len(srctrg_dict))
     # Re-pushing the source-only dictionary must not shrink the vocabulary.
     max_vocab_dict.push(src_dict)
     self.assertEqual(len(max_vocab_dict), len(srctrg_dict))