def __init__(self, args, source_dictionaries, target_dictionaries):
    """Set up per-language dictionaries and a merged max-vocab pair.

    Args:
        args: task/configuration namespace forwarded to the parent class.
        source_dictionaries: mapping of language code -> source Dictionary.
        target_dictionaries: mapping of language code -> target Dictionary.
    """
    self.source_dictionaries = source_dictionaries
    self.target_dictionaries = target_dictionaries
    # Mapping from language IDs to language codes. During training
    # this list is fully populated. During generation we typically
    # have only a single source/target dictionary, thus it is important to
    # call set_encoder/decoder_langs to properly populate these.
    self.encoder_langs = list(source_dictionaries.keys())
    self.decoder_langs = list(target_dictionaries.keys())

    # Fold every per-language vocabulary into a single dictionary that
    # always reflects the largest vocabulary pushed so far.
    self.src_dict = pytorch_translate_dictionary.MaxVocabDictionary()
    for vocab in source_dictionaries.values():
        self.src_dict.push(vocab)

    self.tgt_dict = pytorch_translate_dictionary.MaxVocabDictionary()
    for vocab in target_dictionaries.values():
        self.tgt_dict.push(vocab)

    super().__init__(args, self.src_dict, self.tgt_dict)
def test_push(self):
    """MaxVocabDictionary.push keeps the largest vocabulary seen so far."""
    combined = dictionary.MaxVocabDictionary()
    source_txt, target_txt = test_utils.create_test_text_files()
    prefix = test_utils.make_temp_file()

    # A vocab built from one corpus file, and a strictly larger one built
    # from two corpus files.
    small_vocab = dictionary.Dictionary.build_vocab_file(
        corpus_files=[source_txt],
        vocab_file=f"{prefix}.src",
        max_vocab_size=1000,
    )
    large_vocab = dictionary.Dictionary.build_vocab_file(
        corpus_files=[source_txt, target_txt],
        vocab_file=f"{prefix}.srctrg",
        max_vocab_size=1000,
    )

    # Before any push, only the special symbols are present.
    self.assertEqual(len(combined), combined.nspecial)

    combined.push(small_vocab)
    self.assertEqual(len(combined), len(small_vocab))

    combined.push(large_vocab)
    self.assertEqual(len(combined), len(large_vocab))

    # Pushing a smaller vocabulary again must not shrink the dictionary.
    combined.push(small_vocab)
    self.assertEqual(len(combined), len(large_vocab))

    for leftover in (f"{prefix}.src", f"{prefix}.srctrg", source_txt, target_txt):
        os.remove(leftover)