def load_facebook_xml_model():
    print('loading facebook-XLM model..')

    # load pretrained model
    model_path = 'XLM/models/mlm_tlm_xnli15_1024.pth'
    reloaded = torch.load(model_path)
    params = AttrDict(reloaded['params'])
    # print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

    # build dictionary / update parameters / build model
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    model.load_state_dict(reloaded['model'])

    # get bpe
    bpe = get_bpe()

    return model, params, dico, bpe
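# Usage sketch (hedged): assumes the XLM sources are importable and the checkpoint exists
# at the hard-coded path above; `get_bpe` is whatever BPE helper this module provides.
#
#   model, params, dico, bpe = load_facebook_xml_model()
#   model.eval()  # inference only, no gradient tracking needed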
def __init__(self, model_path, tgt_lang, src_lang, dump_path="./dumped/",
             exp_name="translate", exp_id="test", batch_size=32):
    params = Param(dump_path, exp_name, exp_id, batch_size, model_path, tgt_lang, src_lang)
    assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang

    # initialize the experiment
    logger = initialize_exp(params)

    # no GPU available, so load the checkpoint on the CPU
    # reloaded = torch.load(params.model_path)
    reloaded = torch.load(params.model_path, map_location=torch.device('cpu'))
    model_params = AttrDict(reloaded['params'])
    self.supported_languages = model_params.lang2id.keys()
    logger.info("Supported languages: %s" % ", ".join(self.supported_languages))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    self.dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    # self.encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    self.encoder = TransformerModel(model_params, self.dico, is_encoder=True, with_output=True).eval()
    # self.decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    self.decoder = TransformerModel(model_params, self.dico, is_encoder=False, with_output=True).eval()
    self.encoder.load_state_dict(reloaded['encoder'])
    self.decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    self.model_params = model_params
    self.params = params
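# Usage sketch (hedged): the enclosing class is not shown in this excerpt, so the name
# `XLMTranslator` below is hypothetical, as are the checkpoint path and language codes.
#
#   translator = XLMTranslator(model_path='models/mlm_enfr_1024.pth',
#                              tgt_lang='fr', src_lang='en')
#   print(", ".join(translator.supported_languages))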
def __init__(self, params):
    reloaded = torch.load(params.model_path, map_location='cpu')
    # debug checks on the reloaded dictionary / state dicts:
    # print(reloaded['dico_word2id']['while'])
    # print(reloaded['dico_word2id']['return'])
    # print(reloaded['dico_word2id']['if'])
    # print(reloaded['encoder'].keys())
    # print(reloaded['decoder'].keys())

    # strip the 'module.' prefix that nn.DataParallel adds to parameter names
    reloaded['encoder'] = {(k[len('module.'):] if k.startswith('module.') else k): v
                           for k, v in reloaded['encoder'].items()}
    assert 'decoder' in reloaded or ('decoder_0' in reloaded and 'decoder_1' in reloaded)
    if 'decoder' in reloaded:
        decoders_names = ['decoder']
    else:
        decoders_names = ['decoder_0', 'decoder_1']
    for decoder_name in decoders_names:
        reloaded[decoder_name] = {(k[len('module.'):] if k.startswith('module.') else k): v
                                  for k, v in reloaded[decoder_name].items()}

    self.reloaded_params = AttrDict(reloaded['params'])

    # build dictionary / update parameters
    self.dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    assert self.reloaded_params.n_words == len(self.dico)
    assert self.reloaded_params.bos_index == self.dico.index(BOS_WORD)
    assert self.reloaded_params.eos_index == self.dico.index(EOS_WORD)
    assert self.reloaded_params.pad_index == self.dico.index(PAD_WORD)
    assert self.reloaded_params.unk_index == self.dico.index(UNK_WORD)
    assert self.reloaded_params.mask_index == self.dico.index(MASK_WORD)

    # build model / reload weights
    self.reloaded_params['reload_model'] = ','.join([params.model_path] * 2)
    encoder, decoder = build_model(self.reloaded_params, self.dico)

    self.encoder = encoder[0]
    self.encoder.load_state_dict(reloaded['encoder'])
    assert len(reloaded['encoder'].keys()) == len(self.encoder.state_dict())

    self.decoder = decoder[0]
    self.decoder.load_state_dict(reloaded['decoder'])
    assert len(reloaded['decoder'].keys()) == len(self.decoder.state_dict())

    # self.encoder.to('cpu'); self.decoder.to('cpu')  # CPU alternative
    self.encoder.cuda()
    self.decoder.cuda()
    self.encoder.eval()
    self.decoder.eval()

    self.bpe_model = fastBPE.fastBPE(os.path.abspath(params.BPE_path))
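# Usage sketch (hedged): the enclosing class name is not shown here, so `Translator` is
# hypothetical; `params` is assumed to expose the `model_path` and `BPE_path` attributes used above.
#
#   translator = Translator(params)   # loads encoder/decoder weights and moves them to the GPU
#   print(len(translator.dico))       # vocabulary size of the reloaded checkpoint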
def __init__(self, src_lang, tgt_lang):
    model_path = TranscoderClient.get_model_path(src_lang, tgt_lang)
    reloaded = torch.load(model_path, map_location='cpu')

    # strip the 'module.' prefix that nn.DataParallel adds to parameter names
    reloaded['encoder'] = {(k[len('module.'):] if k.startswith('module.') else k): v
                           for k, v in reloaded['encoder'].items()}
    assert 'decoder' in reloaded or ('decoder_0' in reloaded and 'decoder_1' in reloaded)
    if 'decoder' in reloaded:
        decoders_names = ['decoder']
    else:
        decoders_names = ['decoder_0', 'decoder_1']
    for decoder_name in decoders_names:
        reloaded[decoder_name] = {(k[len('module.'):] if k.startswith('module.') else k): v
                                  for k, v in reloaded[decoder_name].items()}

    self.reloaded_params = AttrDict(reloaded['params'])

    # build dictionary / update parameters
    self.dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    assert self.reloaded_params.n_words == len(self.dico)
    assert self.reloaded_params.bos_index == self.dico.index(BOS_WORD)
    assert self.reloaded_params.eos_index == self.dico.index(EOS_WORD)
    assert self.reloaded_params.pad_index == self.dico.index(PAD_WORD)
    assert self.reloaded_params.unk_index == self.dico.index(UNK_WORD)
    assert self.reloaded_params.mask_index == self.dico.index(MASK_WORD)

    # build model / reload weights
    self.reloaded_params['reload_model'] = ','.join([model_path] * 2)
    encoder, decoder = build_model(self.reloaded_params, self.dico)

    self.encoder = encoder[0]
    self.encoder.load_state_dict(reloaded['encoder'])
    assert len(reloaded['encoder'].keys()) == len(self.encoder.state_dict())

    self.decoder = decoder[0]
    self.decoder.load_state_dict(reloaded['decoder'])
    assert len(reloaded['decoder'].keys()) == len(self.decoder.state_dict())

    self.encoder.cuda()
    self.decoder.cuda()
    self.encoder.eval()
    self.decoder.eval()

    self.bpe_model = fastBPE.fastBPE(os.path.abspath(BPE_PATH))
    self.allowed_languages = [lang.value for lang in Languages]
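# Usage sketch (hedged): this __init__ calls TranscoderClient.get_model_path, so it presumably
# belongs to that class; the language codes below are illustrative and assumed to be values of
# the `Languages` enum referenced above.
#
#   client = TranscoderClient(src_lang='java', tgt_lang='python')
#   assert 'python' in client.allowed_languages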
import torch

import global_variables as glob
from XLM.src.utils import AttrDict
from XLM.src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from XLM.src.model.transformer import TransformerModel

"""_________________________________________ XLM ________________________________________________"""

# load pretrained model
model_path = 'pre-trained_embeddings/mlm_enfr_1024.pth'
reloaded = torch.load(model_path)
params = AttrDict(reloaded['params'])

# build dictionary / update parameters
dicoXLM = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
params.n_words = len(dicoXLM)
params.bos_index = dicoXLM.index(BOS_WORD)
params.eos_index = dicoXLM.index(EOS_WORD)
params.pad_index = dicoXLM.index(PAD_WORD)
params.unk_index = dicoXLM.index(UNK_WORD)
params.mask_index = dicoXLM.index(MASK_WORD)

# build model / reload weights
XLMmodel = TransformerModel(params, dicoXLM, True, True)
XLMmodel.load_state_dict(reloaded['model'])


def sen_list_to_xlm_sen_list(sentences):
    """Convert sentence objects into the (text, language) pairs expected by XLM."""
    sen_list = []
    for s in sentences:
        sen_list.append((s.as_str(), 'fr'))
    return sen_list
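# Usage sketch (hedged): `sentences` is assumed to be a list of objects exposing an
# `as_str()` method (e.g. the sentence wrapper used elsewhere in this project).
#
#   xlm_pairs = sen_list_to_xlm_sen_list(sentences)   # -> [('Bonjour tout le monde', 'fr'), ...]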