Ejemplo n.º 1
0
    def __init__(self, model_path, tgt_lang, src_lang, dump_path="./dumped/",
                 exp_name="translate", exp_id="test", batch_size=32):
        """Load an XLM encoder/decoder checkpoint on the CPU for src_lang -> tgt_lang.

        Args:
            model_path: checkpoint file read with ``torch.load``.
            tgt_lang: target language; looked up in the checkpoint's ``lang2id``.
            src_lang: source language; looked up in the checkpoint's ``lang2id``.
            dump_path, exp_name, exp_id, batch_size: forwarded to ``Param``
                together with the three arguments above.

        Raises:
            AssertionError: if either language is empty or both are equal.
            KeyError: if a language is missing from the checkpoint's ``lang2id``.
        """
        params = Param(dump_path, exp_name, exp_id, batch_size, model_path,
                       tgt_lang, src_lang)
        assert not (params.src_lang == '' or params.tgt_lang == ''
                    or params.src_lang == params.tgt_lang)

        # initialize the experiment; the returned logger is used below
        logger = initialize_exp(params)

        # No GPU available: map every tensor of the checkpoint onto the CPU.
        checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))
        model_params = AttrDict(checkpoint['params'])
        self.supported_languages = model_params.lang2id.keys()
        logger.info("Supported languages: %s" % ", ".join(self.supported_languages))

        # copy the vocabulary / special-token settings stored in the checkpoint
        vocab_fields = ('n_words', 'bos_index', 'eos_index', 'pad_index',
                        'unk_index', 'mask_index')
        for field in vocab_fields:
            setattr(params, field, getattr(model_params, field))

        self.dico = Dictionary(checkpoint['dico_id2word'],
                               checkpoint['dico_word2id'],
                               checkpoint['dico_counts'])

        def build_half(is_encoder):
            # The model stays on the CPU (no .cuda() call) and is put in eval mode.
            return TransformerModel(model_params, self.dico,
                                    is_encoder=is_encoder,
                                    with_output=True).eval()

        self.encoder = build_half(True)
        self.decoder = build_half(False)
        for part in ('encoder', 'decoder'):
            getattr(self, part).load_state_dict(checkpoint[part])

        lang2id = model_params.lang2id
        params.src_id = lang2id[params.src_lang]
        params.tgt_id = lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params
def load_facebook_xml_model(model_path='XLM/models/mlm_tlm_xnli15_1024.pth'):
    """Load a pretrained XLM checkpoint and rebuild its model, dictionary and BPE.

    Args:
        model_path: checkpoint file to load. Defaults to the path that used to
            be hard-coded here, so existing zero-argument callers are unaffected.

    Returns:
        ``(model, params, dico, bpe)``:
        - ``model``: the ``TransformerModel`` with the checkpoint's ``'model'``
          weights loaded.
        - ``params``: the checkpoint's params, with vocabulary indices
          refreshed from the rebuilt dictionary.
        - ``dico``: the ``Dictionary``.
        - ``bpe``: whatever ``get_bpe()`` returns.
    """
    print('loading facebook-XLM model..')
    # map_location='cpu' lets checkpoints that were saved from GPU tensors load
    # on machines without CUDA. Nothing below moves the model to a GPU, so
    # load_state_dict copies the weights into CPU parameters either way.
    reloaded = torch.load(model_path, map_location='cpu')
    params = AttrDict(reloaded['params'])

    # build dictionary / refresh vocabulary-related parameters from it
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model (is_encoder=True, with_output=True) / reload weights
    model = TransformerModel(params, dico, True, True)
    model.load_state_dict(reloaded['model'])

    bpe = get_bpe()

    return model, params, dico, bpe
Ejemplo n.º 3
0
    def __init__(self, params):
        """Load an encoder/decoder checkpoint, rebuild the models and prepare BPE.

        Args:
            params: object exposing ``model_path`` (checkpoint file) and
                ``BPE_path`` (codes file passed to ``fastBPE.fastBPE``).

        Raises:
            AssertionError: if the checkpoint has no decoder weights, if its
                stored vocabulary settings disagree with the rebuilt
                dictionary, or if a state dict's key count differs from the
                model's.
        """
        def drop_module_prefix(state):
            # Strip a leading 'module.' from parameter names (presumably added
            # because the model was wrapped, e.g. by DataParallel, when saved).
            prefix = 'module.'
            return {(name[len(prefix):] if name.startswith(prefix) else name): value
                    for name, value in state.items()}

        checkpoint = torch.load(params.model_path, map_location='cpu')
        checkpoint['encoder'] = drop_module_prefix(checkpoint['encoder'])
        assert 'decoder' in checkpoint or (
            'decoder_0' in checkpoint and 'decoder_1' in checkpoint)
        decoder_keys = (['decoder'] if 'decoder' in checkpoint
                        else ['decoder_0', 'decoder_1'])
        for key in decoder_keys:
            checkpoint[key] = drop_module_prefix(checkpoint[key])

        self.reloaded_params = AttrDict(checkpoint['params'])

        # rebuild the dictionary and check it agrees with the stored settings
        self.dico = Dictionary(checkpoint['dico_id2word'],
                               checkpoint['dico_word2id'],
                               checkpoint['dico_counts'])
        assert self.reloaded_params.n_words == len(self.dico)
        special_words = (('bos_index', BOS_WORD), ('eos_index', EOS_WORD),
                         ('pad_index', PAD_WORD), ('unk_index', UNK_WORD),
                         ('mask_index', MASK_WORD))
        for attr, word in special_words:
            assert getattr(self.reloaded_params, attr) == self.dico.index(word)

        # the same checkpoint path twice, comma-separated (presumably one entry
        # for the encoder and one for the decoder -- confirm in build_model)
        self.reloaded_params['reload_model'] = ','.join([params.model_path] * 2)
        encoders, decoders = build_model(self.reloaded_params, self.dico)

        self.encoder = encoders[0]
        self.encoder.load_state_dict(checkpoint['encoder'])
        assert len(checkpoint['encoder']) == len(self.encoder.state_dict())

        # NOTE(review): only checkpoint['decoder'] is loaded. A checkpoint that
        # passed the assert above with only 'decoder_0'/'decoder_1' raises
        # KeyError here.
        self.decoder = decoders[0]
        self.decoder.load_state_dict(checkpoint['decoder'])
        assert len(checkpoint['decoder']) == len(self.decoder.state_dict())

        # model placement is hard-coded to CUDA
        for module in (self.encoder, self.decoder):
            module.cuda()
        for module in (self.encoder, self.decoder):
            module.eval()
        self.bpe_model = fastBPE.fastBPE(os.path.abspath(params.BPE_path))
Ejemplo n.º 4
0
    def __init__(self, src_lang, tgt_lang):
        """Load the checkpoint chosen for (src_lang, tgt_lang) and prepare inference.

        Args:
            src_lang: source language value; passed to
                ``TranscoderClient.get_model_path``.
            tgt_lang: target language value; passed to
                ``TranscoderClient.get_model_path``.

        Raises:
            AssertionError: if the checkpoint has no decoder weights, if its
                stored vocabulary settings disagree with the rebuilt
                dictionary, or if a state dict's key count differs from the
                model's.
        """
        def without_module_prefix(weights):
            # Drop a leading 'module.' from every key (presumably a wrapper such
            # as DataParallel added it at save time).
            cut = len('module.')
            return {(key[cut:] if key.startswith('module.') else key): tensor
                    for key, tensor in weights.items()}

        model_path = TranscoderClient.get_model_path(src_lang, tgt_lang)
        checkpoint = torch.load(model_path, map_location='cpu')
        checkpoint['encoder'] = without_module_prefix(checkpoint['encoder'])
        has_single_decoder = 'decoder' in checkpoint
        assert has_single_decoder or (
            'decoder_0' in checkpoint and 'decoder_1' in checkpoint)
        for key in (['decoder'] if has_single_decoder else ['decoder_0', 'decoder_1']):
            checkpoint[key] = without_module_prefix(checkpoint[key])

        self.reloaded_params = AttrDict(checkpoint['params'])

        # rebuild the dictionary and check it agrees with the stored settings
        self.dico = Dictionary(checkpoint['dico_id2word'],
                               checkpoint['dico_word2id'],
                               checkpoint['dico_counts'])
        assert self.reloaded_params.n_words == len(self.dico)
        for attr, word in (('bos_index', BOS_WORD), ('eos_index', EOS_WORD),
                           ('pad_index', PAD_WORD), ('unk_index', UNK_WORD),
                           ('mask_index', MASK_WORD)):
            assert getattr(self.reloaded_params, attr) == self.dico.index(word)

        # the same checkpoint path twice, comma-separated (presumably one entry
        # for the encoder and one for the decoder -- confirm in build_model)
        self.reloaded_params['reload_model'] = ','.join([model_path] * 2)
        encoders, decoders = build_model(self.reloaded_params, self.dico)

        self.encoder = encoders[0]
        self.encoder.load_state_dict(checkpoint['encoder'])
        assert len(checkpoint['encoder']) == len(self.encoder.state_dict())

        # NOTE(review): a checkpoint with only 'decoder_0'/'decoder_1' passes the
        # assert above but raises KeyError here.
        self.decoder = decoders[0]
        self.decoder.load_state_dict(checkpoint['decoder'])
        assert len(checkpoint['decoder']) == len(self.decoder.state_dict())

        # model placement is hard-coded to CUDA
        for module in (self.encoder, self.decoder):
            module.cuda()
        for module in (self.encoder, self.decoder):
            module.eval()
        self.bpe_model = fastBPE.fastBPE(os.path.abspath(BPE_PATH))
        self.allowed_languages = [language.value for language in Languages]
Ejemplo n.º 5
0
class Translate:
    """Translate whitespace-tokenized sentences with an XLM encoder/decoder on the CPU."""

    def __init__(self,
                 model_path,
                 tgt_lang,
                 src_lang,
                 dump_path="./dumped/",
                 exp_name="translate",
                 exp_id="test",
                 batch_size=32):
        """Load the checkpoint and build the encoder/decoder for src_lang -> tgt_lang.

        Args:
            model_path: checkpoint file read with ``torch.load``.
            tgt_lang: target language; looked up in the checkpoint's ``lang2id``.
            src_lang: source language; looked up in the checkpoint's ``lang2id``.
            dump_path, exp_name, exp_id, batch_size: forwarded to ``Param``.
                ``batch_size`` is also the number of sentences per forward pass
                in ``translate``.

        Raises:
            AssertionError: if either language is empty or both are equal.
            KeyError: if a language is missing from the checkpoint's ``lang2id``.
        """
        params = Param(dump_path, exp_name, exp_id, batch_size, model_path,
                       tgt_lang, src_lang)

        assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang

        # initialize the experiment
        logger = initialize_exp(params)

        # No GPU available: map every checkpoint tensor onto the CPU.
        reloaded = torch.load(params.model_path,
                              map_location=torch.device('cpu'))
        model_params = AttrDict(reloaded['params'])
        self.supported_languages = model_params.lang2id.keys()
        logger.info("Supported languages: %s" %
                    ", ".join(self.supported_languages))

        # update dictionary parameters from the checkpoint
        for name in ('n_words', 'bos_index', 'eos_index', 'pad_index',
                     'unk_index', 'mask_index'):
            setattr(params, name, getattr(model_params, name))

        # Report the effective parameters through the experiment logger
        # instead of a stray print() to stdout.
        logger.info("params = %s" % (params,))

        # build dictionary / build encoder / build decoder / reload weights.
        # Both models stay on the CPU and are put in eval mode.
        self.dico = Dictionary(reloaded['dico_id2word'],
                               reloaded['dico_word2id'],
                               reloaded['dico_counts'])
        self.encoder = TransformerModel(model_params,
                                        self.dico,
                                        is_encoder=True,
                                        with_output=True).eval()
        self.decoder = TransformerModel(model_params,
                                        self.dico,
                                        is_encoder=False,
                                        with_output=True).eval()
        self.encoder.load_state_dict(reloaded['encoder'])
        self.decoder.load_state_dict(reloaded['decoder'])
        params.src_id = model_params.lang2id[params.src_lang]
        params.tgt_id = model_params.lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params

    def translate(self, src_sent=None):
        """Translate one sentence or a list of sentences.

        Args:
            src_sent: a single sentence (``str``) or a list of sentences. Each
                sentence is stripped and split on whitespace, and every word is
                mapped through ``self.dico.index``. ``None`` (the default)
                means an empty list.

        Returns:
            The translation (``str``) when ``src_sent`` is a ``str``; otherwise
            a list with one translation per input sentence, in order.
            Progress lines are written to stderr.
        """
        # None instead of a mutable [] default, which would be shared
        # across calls.
        if src_sent is None:
            src_sent = []
        single = isinstance(src_sent, str)
        if single:
            src_sent = [src_sent]

        params = self.params
        batch_size = params.batch_size
        tgt_sent = []
        # Inference only: without no_grad, autograd records a graph for every
        # forward pass, which wastes memory and time.
        with torch.no_grad():
            for i in range(0, len(src_sent), batch_size):
                # prepare batch: column j is eos, word ids, eos, then padding
                word_ids = [
                    torch.LongTensor(
                        [self.dico.index(w) for w in s.strip().split()])
                    for s in src_sent[i:i + batch_size]
                ]
                lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
                batch = torch.LongTensor(lengths.max().item(),
                                         lengths.size(0)).fill_(params.pad_index)
                batch[0] = params.eos_index
                for j, s in enumerate(word_ids):
                    if lengths[j] > 2:  # if sentence not empty
                        batch[1:lengths[j] - 1, j].copy_(s)
                    batch[lengths[j] - 1, j] = params.eos_index
                langs = batch.clone().fill_(params.src_id)

                # encode source batch and translate it
                encoded = self.encoder('fwd',
                                       x=batch,
                                       lengths=lengths,
                                       langs=langs,
                                       causal=False)
                encoded = encoded.transpose(0, 1)
                decoded, _ = self.decoder.generate(
                    encoded,
                    lengths,
                    params.tgt_id,
                    max_len=int(1.5 * lengths.max().item() + 10))

                # convert sentences to words
                for j in range(decoded.size(1)):
                    # remove delimiters: keep what lies between the leading eos
                    # and the next eos (or the end of the column)
                    sent = decoded[:, j]
                    delimiters = (sent == params.eos_index).nonzero().view(-1)
                    assert len(delimiters) >= 1 and delimiters[0].item() == 0
                    sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

                    # output translation
                    source = src_sent[i + j].strip()
                    target = " ".join(self.dico[tok.item()] for tok in sent)
                    sys.stderr.write("%i / %i: %s -> %s\n" %
                                     (i + j, len(src_sent), source, target))
                    tgt_sent.append(target)

        return tgt_sent[0] if single else tgt_sent
Ejemplo n.º 6
0
import torch
import global_variables as glob
from XLM.src.utils import AttrDict
from XLM.src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from XLM.src.model.transformer import TransformerModel
"""_________________________________________ XML ________________________________________________"""

# Load the pretrained XLM masked-LM checkpoint once, at import time
# (English/French according to the file name), and expose it as the module
# globals `params`, `dicoXLM` and `XLMmodel`.
model_path = 'pre-trained_embeddings/mlm_enfr_1024.pth'
# map_location='cpu' keeps the import working on machines without CUDA when
# the checkpoint was saved from GPU tensors. XLMmodel is built on the CPU
# below, so load_state_dict copies the weights into CPU parameters either way.
reloaded = torch.load(model_path, map_location='cpu')
params = AttrDict(reloaded['params'])

# build dictionary / update vocabulary-related parameters from it
dicoXLM = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                     reloaded['dico_counts'])
params.n_words = len(dicoXLM)
params.bos_index = dicoXLM.index(BOS_WORD)
params.eos_index = dicoXLM.index(EOS_WORD)
params.pad_index = dicoXLM.index(PAD_WORD)
params.unk_index = dicoXLM.index(UNK_WORD)
params.mask_index = dicoXLM.index(MASK_WORD)

# build model (is_encoder=True, with_output=True) / reload weights
XLMmodel = TransformerModel(params, dicoXLM, True, True)
XLMmodel.load_state_dict(reloaded['model'])


def sen_list_to_xlm_sen_list(sentences, lang='fr'):
    """Pair each sentence's string form with a language code.

    Args:
        sentences: iterable of objects exposing ``as_str()``.
        lang: language code attached to every sentence. Defaults to ``'fr'``,
            the value that was previously hard-coded.

    Returns:
        List of ``(sentence_string, lang)`` tuples, in input order.
    """
    return [(s.as_str(), lang) for s in sentences]
Ejemplo n.º 7
0
class TranscoderClient:
    """Code-to-code translator backed by one of two checkpoints under MODELS_PATH."""

    def __init__(self, src_lang, tgt_lang):
        """Load the checkpoint chosen for (src_lang, tgt_lang) and prepare inference.

        Args:
            src_lang: source language value (see ``get_model_path``).
            tgt_lang: target language value (see ``get_model_path``).

        Raises:
            AssertionError: if the checkpoint has no decoder weights, if its
                stored vocabulary settings disagree with the rebuilt
                dictionary, or if a state dict's key count differs from the
                model's.
        """
        def without_module_prefix(weights):
            # Drop a leading 'module.' from every key (presumably a wrapper such
            # as DataParallel added it at save time).
            cut = len('module.')
            return {(key[cut:] if key.startswith('module.') else key): tensor
                    for key, tensor in weights.items()}

        model_path = TranscoderClient.get_model_path(src_lang, tgt_lang)
        checkpoint = torch.load(model_path, map_location='cpu')
        checkpoint['encoder'] = without_module_prefix(checkpoint['encoder'])
        has_single_decoder = 'decoder' in checkpoint
        assert has_single_decoder or (
            'decoder_0' in checkpoint and 'decoder_1' in checkpoint)
        for key in (['decoder'] if has_single_decoder else ['decoder_0', 'decoder_1']):
            checkpoint[key] = without_module_prefix(checkpoint[key])

        self.reloaded_params = AttrDict(checkpoint['params'])

        # rebuild the dictionary and check it agrees with the stored settings
        self.dico = Dictionary(checkpoint['dico_id2word'],
                               checkpoint['dico_word2id'],
                               checkpoint['dico_counts'])
        assert self.reloaded_params.n_words == len(self.dico)
        for attr, word in (('bos_index', BOS_WORD), ('eos_index', EOS_WORD),
                           ('pad_index', PAD_WORD), ('unk_index', UNK_WORD),
                           ('mask_index', MASK_WORD)):
            assert getattr(self.reloaded_params, attr) == self.dico.index(word)

        # the same checkpoint path twice, comma-separated (presumably one entry
        # for the encoder and one for the decoder -- confirm in build_model)
        self.reloaded_params['reload_model'] = ','.join([model_path] * 2)
        encoders, decoders = build_model(self.reloaded_params, self.dico)

        self.encoder = encoders[0]
        self.encoder.load_state_dict(checkpoint['encoder'])
        assert len(checkpoint['encoder']) == len(self.encoder.state_dict())

        # NOTE(review): a checkpoint with only 'decoder_0'/'decoder_1' passes the
        # assert above but raises KeyError here.
        self.decoder = decoders[0]
        self.decoder.load_state_dict(checkpoint['decoder'])
        assert len(checkpoint['decoder']) == len(self.decoder.state_dict())

        # model placement is hard-coded to CUDA; translate() moves inputs to DEVICE
        for module in (self.encoder, self.decoder):
            module.cuda()
        for module in (self.encoder, self.decoder):
            module.eval()
        self.bpe_model = fastBPE.fastBPE(os.path.abspath(BPE_PATH))
        self.allowed_languages = [language.value for language in Languages]

    def translate(self, input, lang1, lang2, n=1, beam_size=1, sample_temperature=None):
        """Translate ``input`` source code from ``lang1`` to ``lang2``.

        Args:
            input: source code written in ``lang1``.
            lang1: source language; must be in ``self.allowed_languages``.
            lang2: target language; must be in ``self.allowed_languages``.
            n: number of copies of the encoded input to decode (the encoder
                output is repeated ``n`` times along the batch axis).
            beam_size: 1 selects ``decoder.generate``; any other value selects
                ``decoder.generate_beam`` (see ``_decode_solution``).
            sample_temperature: forwarded to ``decoder.generate`` only.

        Returns:
            List of detokenized ``lang2`` strings, one per decoded column.

        Raises:
            AssertionError: if a language is not in ``self.allowed_languages``.
        """
        with torch.no_grad():
            assert lang1 in self.allowed_languages, lang1
            assert lang2 in self.allowed_languages, lang2

            tokenize = getattr(code_tokenizer, f'tokenize_{lang1}')
            detokenize = getattr(code_tokenizer, f'detokenize_{lang2}')
            # the checkpoint's language ids use a '_sa' suffix
            src_id = self.reloaded_params.lang2id[lang1 + '_sa']
            tgt_id = self.reloaded_params.lang2id[lang2 + '_sa']

            pieces = self.bpe_model.apply(list(tokenize(input)))
            words = " ".join(['</s>'] + pieces + ['</s>']).split()

            # batch of one sequence: x1 has shape (len, 1)
            len1 = torch.LongTensor(1).fill_(len(words)).to(DEVICE)
            x1 = torch.LongTensor([self.dico.index(w) for w in words]).to(DEVICE)[:, None]
            langs1 = x1.clone().fill_(src_id)

            enc1 = self.encoder('fwd', x=x1, lengths=len1, langs=langs1, causal=False)
            enc1 = enc1.transpose(0, 1)
            if n > 1:
                enc1 = enc1.repeat(n, 1, 1)
                len1 = len1.expand(n)

            x2 = self._decode_solution(enc1, len1, tgt_id, sample_temperature, beam_size)

            # one hypothesis per column: drop the leading token, cut at the
            # first EOS_WORD, then undo BPE by removing the "@@ " joiners
            hypotheses = []
            for col in range(x2.shape[1]):
                out_words = [self.dico[x2[row, col].item()] for row in range(len(x2))][1:]
                if EOS_WORD in out_words:
                    out_words = out_words[:out_words.index(EOS_WORD)]
                hypotheses.append(" ".join(out_words).replace("@@ ", ""))

            return [detokenize(text) for text in hypotheses]

    def _decode_solution(self, enc1, len1, lang2_id, sample_temperature, beam_size):
        """Decode the encoder output, greedily/by sampling or with beam search.

        The generation length is capped at
        ``min(reloaded_params.max_len, 3 * len1.max() + 10)``.

        Returns:
            The first element returned by ``decoder.generate`` (when
            ``beam_size == 1``) or ``decoder.generate_beam`` (otherwise).
        """
        max_len = int(min(self.reloaded_params.max_len, 3 * len1.max().item() + 10))
        if beam_size == 1:
            decoded, _ = self.decoder.generate(
                enc1, len1, lang2_id,
                max_len=max_len,
                sample_temperature=sample_temperature,
            )
            return decoded
        decoded, _ = self.decoder.generate_beam(
            enc1, len1, lang2_id,
            max_len=max_len,
            early_stopping=False,
            length_penalty=1.0,
            beam_size=beam_size,
        )
        return decoded

    @staticmethod
    def get_model_path(source_lang, target_lang):
        """Return the checkpoint path for a language pair.

        ``MODELS_PATH + '/model_1.pth'`` when translating from Java, or from
        C++ to Java; ``MODELS_PATH + '/model_2.pth'`` otherwise.

        Raises:
            AssertionError: if both languages are equal.
        """
        assert source_lang != target_lang

        uses_first_model = (
            source_lang == Languages.JAVA.value
            or (source_lang == Languages.CPP.value
                and target_lang == Languages.JAVA.value)
        )
        return MODELS_PATH + ('/model_1.pth' if uses_first_model else '/model_2.pth')
Ejemplo n.º 8
0
class Translator:
    """Translate source code between Python, Java and C++ with a checkpointed encoder/decoder."""

    # languages accepted by translate() for both lang1 and lang2
    SUPPORTED_LANGUAGES = frozenset({'python', 'java', 'cpp'})

    def __init__(self, params):
        """Load the checkpoint at ``params.model_path`` and prepare inference.

        Args:
            params: object exposing ``model_path`` (checkpoint file) and
                ``BPE_path`` (codes file passed to ``fastBPE.fastBPE``).

        Raises:
            AssertionError: if the checkpoint has no decoder weights, if its
                stored vocabulary settings disagree with the rebuilt
                dictionary, or if a state dict's key count differs from the
                model's.
        """
        reloaded = torch.load(params.model_path, map_location='cpu')
        reloaded['encoder'] = self._strip_module_prefix(reloaded['encoder'])
        assert 'decoder' in reloaded or (
            'decoder_0' in reloaded and 'decoder_1' in reloaded)
        if 'decoder' in reloaded:
            decoders_names = ['decoder']
        else:
            decoders_names = ['decoder_0', 'decoder_1']
        for decoder_name in decoders_names:
            reloaded[decoder_name] = self._strip_module_prefix(reloaded[decoder_name])

        self.reloaded_params = AttrDict(reloaded['params'])

        # build dictionary / check it agrees with the stored parameters
        self.dico = Dictionary(
            reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
        assert self.reloaded_params.n_words == len(self.dico)
        assert self.reloaded_params.bos_index == self.dico.index(BOS_WORD)
        assert self.reloaded_params.eos_index == self.dico.index(EOS_WORD)
        assert self.reloaded_params.pad_index == self.dico.index(PAD_WORD)
        assert self.reloaded_params.unk_index == self.dico.index(UNK_WORD)
        assert self.reloaded_params.mask_index == self.dico.index(MASK_WORD)

        # build model / reload weights. 'reload_model' gets the same path
        # twice, comma-separated (presumably encoder,decoder -- confirm in
        # build_model).
        self.reloaded_params['reload_model'] = ','.join([params.model_path] * 2)
        encoder, decoder = build_model(self.reloaded_params, self.dico)

        self.encoder = encoder[0]
        self.encoder.load_state_dict(reloaded['encoder'])
        assert len(reloaded['encoder']) == len(self.encoder.state_dict())

        # NOTE(review): a checkpoint with only 'decoder_0'/'decoder_1' passes the
        # assert above but raises KeyError here.
        self.decoder = decoder[0]
        self.decoder.load_state_dict(reloaded['decoder'])
        assert len(reloaded['decoder']) == len(self.decoder.state_dict())

        # The model is placed on CUDA; translate()'s `device` must agree.
        self.encoder.cuda()
        self.decoder.cuda()

        self.encoder.eval()
        self.decoder.eval()
        self.bpe_model = fastBPE.fastBPE(os.path.abspath(params.BPE_path))

    @staticmethod
    def _strip_module_prefix(state):
        """Return a copy of ``state`` with a leading 'module.' removed from each key.

        The prefix is presumably added when a wrapped model (e.g. DataParallel)
        is saved; keys without it are kept unchanged.
        """
        prefix = 'module.'
        return {(k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state.items()}

    def translate(self, input, lang1, lang2, n=1, beam_size=1, sample_temperature=None, device='cuda:0'):
        """Translate ``input`` source code from ``lang1`` to ``lang2``.

        Args:
            input: source code written in ``lang1``.
            lang1: source language, one of ``SUPPORTED_LANGUAGES``.
            lang2: target language, one of ``SUPPORTED_LANGUAGES``.
            n: number of copies of the encoded input to decode (the encoder
                output is repeated ``n`` times along the batch axis).
            beam_size: 1 selects ``decoder.generate``; any other value selects
                ``decoder.generate_beam`` (see ``_decode_solution``).
            sample_temperature: forwarded to ``decoder.generate`` only.
            device: device the input tensors are moved to; must match where
                ``__init__`` placed the model (it calls ``.cuda()``).

        Returns:
            List of detokenized ``lang2`` strings, one per decoded column.

        Raises:
            AssertionError: if a language is not in ``SUPPORTED_LANGUAGES``.
        """
        with torch.no_grad():
            assert lang1 in self.SUPPORTED_LANGUAGES, lang1
            assert lang2 in self.SUPPORTED_LANGUAGES, lang2

            tokenizer = getattr(code_tokenizer, f'tokenize_{lang1}')
            detokenizer = getattr(code_tokenizer, f'detokenize_{lang2}')
            # the checkpoint's language ids use a '_sa' suffix
            lang1 += '_sa'
            lang2 += '_sa'

            lang1_id = self.reloaded_params.lang2id[lang1]
            lang2_id = self.reloaded_params.lang2id[lang2]

            tokens = self.bpe_model.apply(list(tokenizer(input)))
            tokens = ['</s>'] + tokens + ['</s>']
            words = " ".join(tokens).split()

            # batch of one sequence: x1 has shape (len, 1)
            len1 = torch.LongTensor(1).fill_(len(words)).to(device)
            x1 = torch.LongTensor([self.dico.index(w)
                                   for w in words]).to(device)[:, None]
            langs1 = x1.clone().fill_(lang1_id)

            enc1 = self.encoder('fwd', x=x1, lengths=len1,
                                langs=langs1, causal=False)
            enc1 = enc1.transpose(0, 1)
            if n > 1:
                enc1 = enc1.repeat(n, 1, 1)
                len1 = len1.expand(n)

            x2 = self._decode_solution(enc1, len1, lang2_id, sample_temperature, beam_size)

            # one hypothesis per column: drop the leading token, cut at the
            # first EOS_WORD, then undo BPE by removing the "@@ " joiners
            tok = []
            for i in range(x2.shape[1]):
                wid = [self.dico[x2[j, i].item()] for j in range(len(x2))][1:]
                wid = wid[:wid.index(EOS_WORD)] if EOS_WORD in wid else wid
                tok.append(" ".join(wid).replace("@@ ", ""))

            return [detokenizer(t) for t in tok]

    def _decode_solution(self, enc1, len1, lang2_id, sample_temperature, beam_size):
        """Decode the encoder output, greedily/by sampling or with beam search.

        The generation length is capped at
        ``min(reloaded_params.max_len, 3 * len1.max() + 10)``.

        Returns:
            The first element returned by ``decoder.generate`` (when
            ``beam_size == 1``) or ``decoder.generate_beam`` (otherwise).
        """
        max_len = int(min(self.reloaded_params.max_len, 3 * len1.max().item() + 10))
        if beam_size == 1:
            x2, _ = self.decoder.generate(
                enc1,
                len1,
                lang2_id,
                max_len=max_len,
                sample_temperature=sample_temperature
            )
        else:
            x2, _ = self.decoder.generate_beam(
                enc1,
                len1,
                lang2_id,
                max_len=max_len,
                early_stopping=False,
                length_penalty=1.0,
                beam_size=beam_size
            )
        return x2