Example #1
File: model.py Project: stefensa/XLM_NER
class XLMForTokenClassification(nn.Module):
    def __init__(self, config, num_labels, params, dico, reloaded):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.xlm = TransformerModel(params, dico, True, True)
        self.xlm.eval()
        self.xlm.load_state_dict(reloaded['model'])

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(1024, num_labels)
        # init only the new head; self.apply() would also re-initialize the
        # pretrained XLM weights loaded above
        self.classifier.apply(self.init_bert_weights)

    def forward(self, word_ids, lengths, langs=None, causal=False):
        sequence_output = self.xlm('fwd',
                                   x=word_ids,
                                   lengths=lengths,
                                   langs=langs,
                                   causal=causal).contiguous()
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
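A minimal usage sketch for the classifier above. The checkpoint path, the `config` fields, and the `src.` import paths (the facebookresearch/XLM layout) are assumptions, and `params` must also carry the dictionary indices, as set up in Example #3:

import torch
from src.utils import AttrDict
from src.data.dictionary import Dictionary

# reload a pretrained XLM checkpoint (hypothetical path)
reloaded = torch.load('mlm_tlm_xnli15_1024.pth', map_location='cpu')
params = AttrDict(reloaded['params'])
dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                  reloaded['dico_counts'])
# ... complete params (n_words, bos/eos/pad/unk/mask indices) as in Example #3

config = AttrDict({'dropout': 0.1, 'initializer_range': 0.02})  # assumed fields
model = XLMForTokenClassification(config, num_labels=9,
                                  params=params, dico=dico, reloaded=reloaded)

word_ids = torch.randint(0, len(dico), (12, 2))  # (slen, bs) token ids
lengths = torch.LongTensor([12, 12])
logits = model(word_ids, lengths)                # (slen, bs, num_labels)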
Example #2
File: model.py Project: stefensa/XLM_NER
    def __init__(self, config, num_labels, params, dico, reloaded):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.xlm = TransformerModel(params, dico, True, True)
        self.xlm.eval()
        self.xlm.load_state_dict(reloaded['model'])

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(1024, num_labels)
        # init only the new head; self.apply() would also re-initialize the
        # pretrained XLM weights loaded above
        self.classifier.apply(self.init_bert_weights)
Example #3
def initialize_model():
    """
    """
    print('launching model')

    chemin = getcwd()
    curPath = chemin if "xlm" in chemin else (os.path.join(getcwd(), 'xlm'))

    onlyfiles = [f for f in listdir(chemin) if isfile(join(chemin, f))]
    print(onlyfiles)

    model_path = os.path.normpath(
        os.path.join(getcwd(), './mlm_tlm_xnli15_1024.pth'))
    print(model_path)
    reloaded = torch.load(model_path)

    #     print('allez le model')
    #     response = requests.get(url)
    #     print('response downloaded')
    #     f = io.BytesIO(response.content)
    #     reloaded = torch.load(f)
    #     print('file downloaded')

    #    reloaded = Reloaded.serve()

    params = AttrDict(reloaded['params'])
    print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    model.load_state_dict(reloaded['model'])

    #    bpe = fastBPE.fastBPE(
    #            path.normpath(path.join(curPath, "./codes_xnli_15") ),
    #            path.normpath(path.join(curPath, "./vocab_xnli_15") )  )
    print('finished loading')

    return model, params, dico
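A sketch of running the reloaded model on one sentence. The tokens are hypothetical and assumed to be already BPE-split; the (slen, bs) input layout and the 'fwd' call follow the examples in this listing:

model, params, dico = initialize_model()
model.eval()

words = 'this is a test'.split()  # assumed pre-tokenized BPE pieces
x = torch.LongTensor([[params.eos_index] + [dico.index(w) for w in words] +
                      [params.eos_index]]).t()          # (slen, 1)
lengths = torch.LongTensor([x.size(0)])
with torch.no_grad():
    hidden = model('fwd', x=x, lengths=lengths, causal=False)  # (slen, 1, dim)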
Example #4
File: model.py Project: stefensa/XLM_NER
class XLM_BiLSTM_CRF(nn.Module):
    def __init__(self, config, num_labels, params, dico, reloaded):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.batch_size = config.batch_size
        self.hidden_dim = config.hidden_dim

        self.xlm = TransformerModel(params, dico, True, True)
        self.xlm.eval()
        self.xlm.load_state_dict(reloaded['model'])

        self.lstm = nn.LSTM(config.embedding_dim,
                            config.hidden_dim // 2,
                            num_layers=1,
                            bidirectional=True)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_dim, config.num_class)
        # init only the new head; self.apply() would also re-initialize the
        # pretrained XLM weights loaded above
        self.classifier.apply(self.init_bert_weights)
        self.crf = CRF(config.num_class)

    def forward(self, word_ids, lengths, langs=None, causal=False):
        sequence_output = self.xlm('fwd',
                                   x=word_ids,
                                   lengths=lengths,
                                   langs=langs,
                                   causal=causal).contiguous()
        sequence_output, _ = self.lstm(sequence_output)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return self.crf.decode(logits)

    def log_likelihood(self, word_ids, lengths, tags):
        sequence_output = self.xlm('fwd',
                                   x=word_ids,
                                   lengths=lengths,
                                   causal=False).contiguous()
        sequence_output, _ = self.lstm(sequence_output)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return -self.crf(logits, tags.transpose(0, 1))

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
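A sketch of one training step with the CRF model above, using log_likelihood as the loss. The optimizer, learning rate, and dummy tensors are assumptions; tags are passed batch-first, matching the transpose inside log_likelihood:

import torch
import torch.optim as optim

# model: the XLM_BiLSTM_CRF built above
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # hypothetical choice

word_ids = torch.randint(0, 1000, (12, 2))  # (slen, bs) dummy token ids
lengths = torch.LongTensor([12, 12])
tags = torch.randint(0, 5, (2, 12))         # (bs, slen) dummy tag ids

loss = model.log_likelihood(word_ids, lengths, tags)  # negated CRF log-likelihood
optimizer.zero_grad()
loss.backward()
optimizer.step()

pred_tags = model(word_ids, lengths)        # list of best tag sequences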
Example #5
def train(rank, args):
    print(f"Running basic DDP example on rank {rank} {args.master_port}.")
    setup(rank, args.world_size, args.master_port)
    args.local_rank = rank
    torch.manual_seed(args.seed)
    torch.cuda.set_device(rank)
    src_vocab = Dictionary.read_vocab(args.vocab_src)
    tgt_vocab = Dictionary.read_vocab(args.vocab_tgt)

    # model init
    model = TransformerModel(d_model=args.d_model,
                             nhead=args.nhead,
                             num_encoder_layers=args.num_encoder_layers,
                             num_decoder_layers=args.num_decoder_layers,
                             dropout=args.dropout,
                             attention_dropout=args.attn_dropout,
                             src_dictionary=src_vocab,
                             tgt_dictionary=tgt_vocab)
    model.to(rank)
    model = DDP(model, device_ids=[rank])

    if rank == 0:
        print(model)

    # data load
    train_loader, sampler = dataloader.get_train_parallel_loader(
        args.train_src,
        args.train_tgt,
        src_vocab,
        tgt_vocab,
        batch_size=args.batch_size,
        world_size=args.world_size,
        rank=rank)
    valid_loader = dataloader.get_valid_parallel_loader(
        args.valid_src,
        args.valid_tgt,
        src_vocab,
        tgt_vocab,
        batch_size=args.batch_size)

    data = {'dataloader': {'train': train_loader, 'valid': valid_loader}}

    trainer = Trainer(model, data, args)
    for epoch in range(1, args.epoch_size):
        sampler.set_epoch(epoch)  # reshuffle the distributed sampler each epoch
        trainer.mt_step()
        trainer.evaluate(epoch)
        trainer.save_checkpoint(epoch)
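The train function above takes its rank as the first argument, which is how torch.multiprocessing.spawn invokes the target; a minimal launch sketch, where get_args is a hypothetical parser producing the fields used above:

import torch.multiprocessing as mp

if __name__ == '__main__':
    args = get_args()  # hypothetical: must define world_size, master_port, seed, ...
    mp.spawn(train, args=(args,), nprocs=args.world_size, join=True)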
Example #6
    def __init__(self, model_path, tgt_lang, src_lang, dump_path="./dumped/", exp_name="translate", exp_id="test", batch_size=32):
        
        # parse parameters
        parser = argparse.ArgumentParser(description="Translate sentences")
        
        # main parameters
        parser.add_argument("--dump_path", type=str, default=dump_path, help="Experiment dump path")
        parser.add_argument("--exp_name", type=str, default=exp_name, help="Experiment name")
        parser.add_argument("--exp_id", type=str, default=exp_id, help="Experiment ID")
        parser.add_argument("--batch_size", type=int, default=batch_size, help="Number of sentences per batch")
        # model / output paths
        parser.add_argument("--model_path", type=str, default=model_path, help="Model path")
        # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)")
        # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count")
        # source language / target language
        parser.add_argument("--src_lang", type=str, default=src_lang, help="Source language")
        parser.add_argument("--tgt_lang", type=str, default=tgt_lang, help="Target language")

        params = parser.parse_args()
        assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang

        # initialize the experiment
        logger = initialize_exp(params)
        
        # no GPU available
        #reloaded = torch.load(params.model_path)
        reloaded = torch.load(params.model_path, map_location=torch.device('cpu'))
        model_params = AttrDict(reloaded['params'])
        self.supported_languages = model_params.lang2id.keys() 
        logger.info("Supported languages: %s" % ", ".join(self.supported_languages))

        # update dictionary parameters
        for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
            setattr(params, name, getattr(model_params, name))

        # build dictionary / build encoder / build decoder / reload weights
        self.dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
        #self.encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
        self.encoder = TransformerModel(model_params, self.dico, is_encoder=True, with_output=True).eval()
        #self.decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
        self.decoder = TransformerModel(model_params, self.dico, is_encoder=False, with_output=True).eval()
        self.encoder.load_state_dict(reloaded['encoder'])
        self.decoder.load_state_dict(reloaded['decoder'])
        params.src_id = model_params.lang2id[params.src_lang]
        params.tgt_id = model_params.lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params
Example #7
def translate(args):
    batch_size = args.batch_size

    src_vocab = Dictionary.read_vocab(args.vocab_src)
    tgt_vocab = Dictionary.read_vocab(args.vocab_tgt)
    data = torch.load(args.reload_path, map_location='cpu')
    model = TransformerModel(src_dictionary=src_vocab,
                             tgt_dictionary=tgt_vocab)
    model.load_state_dict({k: data['module'][k] for k in data['module']})
    model.cuda()
    model.eval()

    if 'epoch' in data:
        print(f"Loading model from epoch_{data['epoch']}....")

    with open(args.src, "r") as f:
        src_sent = f.readlines()
    for i in range(0, len(src_sent), batch_size):
        word_ids = [
            torch.LongTensor([src_vocab.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(src_vocab.pad_index)
        batch[0] = src_vocab.bos_index

        for j, s in enumerate(word_ids):
            if lengths[j] > 2:
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = src_vocab.eos_index

        batch = batch.cuda()

        with torch.no_grad():
            encoder_out = model.encoder(batch)
            if args.beam == 1:
                generated = model.decoder.generate_greedy(encoder_out)
            else:
                generated = model.decoder.generate_beam(encoder_out,
                                                        beam_size=args.beam)

        for j, s in enumerate(src_sent[i:i + batch_size]):
            print(f"Source_{i+j}: {s.strip()}")
            hypo = []
            for w in generated[j][1:]:
                if tgt_vocab[w.item()] == '</s>':
                    break
                hypo.append(tgt_vocab[w.item()])
            hypo = " ".join(hypo)
            print(f"Target_{i+j}: {hypo}\n")
Example #8
    def on_init(self, params, p_params):
        dump_path = os.path.join(params.dump_path, "debias")
        checkpoint_path = os.path.join(dump_path, "checkpoint.pth")
        if os.path.isfile(checkpoint_path):
            self.params.dump_path = dump_path
            self.checkpoint_path = checkpoint_path
            self.from_deb = True
        else:
            self.checkpoint_path = os.path.join(params.dump_path,
                                                "checkpoint.pth")
            self.from_deb = False
        deb = TransformerModel(p_params,
                               self.model.dico,
                               is_encoder=True,
                               with_output=False,
                               with_emb=False)
        # deb = LinearDeb(p_params)
        self.deb = deb.to(params.device)
Example #9
File: model.py Project: stefensa/XLM_NER
    def __init__(self, config, num_labels, params, dico, reloaded):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.batch_size = config.batch_size
        self.hidden_dim = config.hidden_dim

        self.xlm = TransformerModel(params, dico, True, True)
        self.xlm.eval()
        self.xlm.load_state_dict(reloaded['model'])

        self.lstm = nn.LSTM(config.embedding_dim,
                            config.hidden_dim // 2,
                            num_layers=1,
                            bidirectional=True)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_dim, config.num_class)
        # init only the new head; self.apply() would also re-initialize the
        # pretrained XLM weights loaded above
        self.classifier.apply(self.init_bert_weights)
        self.crf = CRF(config.num_class)
Example #10
def reload_ar_checkpoint(path):
    """ Reload autoregressive params, dictionary, model from a given path """
    # Load dictionary/model/datasets first
    reloaded = torch.load(path)
    params = AttrDict(reloaded['params'])

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.n_langs = 1
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build Transformer model
    model = TransformerModel(params, dico, is_encoder=False, with_output=True)
    model.load_state_dict(reloaded['model'])
    return params, dico, model
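A sketch of a causal forward pass with the reloaded autoregressive model. The checkpoint path and tokens are hypothetical; the call shape follows the other examples in this listing:

import torch

params, dico, model = reload_ar_checkpoint('checkpoint.pth')  # hypothetical path
model.eval()

words = 'an example prefix'.split()  # assumed pre-tokenized BPE pieces
x = torch.LongTensor([[params.eos_index] + [dico.index(w) for w in words]]).t()
lengths = torch.LongTensor([x.size(0)])
with torch.no_grad():
    hidden = model('fwd', x=x, lengths=lengths, causal=True)  # (slen, 1, dim)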
Example #11
File: mymodel.py Project: RunxinXu/XLM
    def reload(path, params):
        """
        Create a sentence embedder from a pretrained model.
        """
        # reload model
        reloaded = torch.load(path)
        state_dict = reloaded['model']

        # handle models from multi-GPU checkpoints
        if 'checkpoint' in path:
            state_dict = {(k[7:] if k.startswith('module.') else k): v
                          for k, v in state_dict.items()}

        # reload dictionary and model parameters
        dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                          reloaded['dico_counts'])
        pretrain_params = AttrDict(reloaded['params'])
        pretrain_params.n_words = len(dico)
        pretrain_params.bos_index = dico.index(BOS_WORD)
        pretrain_params.eos_index = dico.index(EOS_WORD)
        pretrain_params.pad_index = dico.index(PAD_WORD)
        pretrain_params.unk_index = dico.index(UNK_WORD)
        pretrain_params.mask_index = dico.index(MASK_WORD)

        # build model and reload weights
        model = TransformerModel(pretrain_params, dico, True, True)
        model.load_state_dict(state_dict)
        model.eval()

        # adding missing parameters
        params.max_batch_size = 0

        return MyModel(model, dico, pretrain_params, params)
Example #12
File: bli.py Project: 15091444119/MASS
def load_xlm_embeddings(path, model_name="model"):
    """
    Load all XLM embeddings.

    Params:
        path: checkpoint path
        model_name: name of the model in the reloaded checkpoint: "model" for a
            pretrained XLM encoder, "encoder" for the encoder of a translation
            model, "decoder" for its decoder
    """
    reloaded = torch.load(path)

    assert model_name in ["model", "encoder", "decoder"]
    state_dict = reloaded[model_name]

    # handle models from multi-GPU checkpoints
    state_dict = {(k[7:] if k.startswith('module.') else k): v
                  for k, v in state_dict.items()}

    # reload dictionary and model parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    pretrain_params = AttrDict(reloaded['params'])
    pretrain_params.n_words = len(dico)
    pretrain_params.bos_index = dico.index(BOS_WORD)
    pretrain_params.eos_index = dico.index(EOS_WORD)
    pretrain_params.pad_index = dico.index(PAD_WORD)
    pretrain_params.unk_index = dico.index(UNK_WORD)
    pretrain_params.mask_index = dico.index(MASK_WORD)

    # build model and reload weights
    if model_name != "decoder":
        model = TransformerModel(pretrain_params, dico, True, True)
    else:
        model = TransformerModel(pretrain_params, dico, False, True)
    model.load_state_dict(state_dict)

    return model.embeddings.weight.data, dico
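A sketch of a nearest-neighbour lookup on the returned embedding table; the checkpoint path and query word are hypothetical, and the word is assumed to be in the vocabulary:

import torch
import torch.nn.functional as F

embeddings, dico = load_xlm_embeddings('mlm_tlm_xnli15_1024.pth')

query = embeddings[dico.index('language')]           # (dim,)
scores = F.cosine_similarity(embeddings, query[None], dim=1)
best = scores.topk(6).indices[1:]                    # skip the word itself
print([dico[i.item()] for i in best])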
Example #13
def reload_checkpoint(path):
    """ Reload params, dictionary, model from a given path """
    # Load dictionary/model/datasets first
    reloaded = torch.load(path)
    params = AttrDict(reloaded['params'])
    print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    model.load_state_dict(reloaded['model'])

    return params, dico, model
Example #14
File: qg.py Project: PythaGorilla/QAQG
def load_model(params):
    # check parameters
    assert os.path.isdir(params.data_path)
    assert os.path.isfile(params.model_path)
    reloaded = torch.load(params.model_path)

    encoder_model_params = AttrDict(reloaded['enc_params'])
    decoder_model_params = AttrDict(reloaded['dec_params'])

    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])

    params.n_langs = encoder_model_params['n_langs']
    params.id2lang = encoder_model_params['id2lang']
    params.lang2id = encoder_model_params['lang2id']
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    encoder = TransformerModel(encoder_model_params,
                               dico,
                               is_encoder=True,
                               with_output=False)
    decoder = TransformerModel(decoder_model_params,
                               dico,
                               is_encoder=False,
                               with_output=True)

    def _process_state_dict(state_dict):
        return {(k[7:] if k.startswith('module.') else k): v
                for k, v in state_dict.items()}

    encoder.load_state_dict(_process_state_dict(reloaded['encoder']))
    decoder.load_state_dict(_process_state_dict(reloaded['decoder']))

    return encoder, decoder, dico
Example #15
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # float16
    if params.fp16:
        assert torch.backends.cudnn.enabled
        encoder = network_to_half(encoder)
        decoder = network_to_half(decoder)

    input_data = torch.load(params.input)
    eval_dataset = Dataset(input_data["sentences"], input_data["positions"], params)

    if params.subset_start is not None:
        assert params.subset_end
        eval_dataset.select_data(params.subset_start, params.subset_end)

    eval_dataset.remove_empty_sentences()
    eval_dataset.remove_long_sentences(params.max_len)

    n_batch = 0

    inp_dump = io.open(os.path.join(params.dump_path, "input.txt"), "w", encoding="utf-8")
    logger.info("logging to {}".format(os.path.join(params.dump_path, 'input.txt')))

    with open(params.output_path, "w", encoding="utf-8") as out:

        for batch in eval_dataset.get_iterator(shuffle=False):
            n_batch += 1

            (x1, len1) = batch
            input_text = convert_to_text(x1, len1, input_data["dico"], params)
            inp_dump.write("\n".join(input_text))
            inp_dump.write("\n")

            langs1 = x1.clone().fill_(params.src_id)

            # cuda
            x1, len1, langs1 = to_cuda(x1, len1, langs1)

            # encode source sentence
            enc1 = encoder("fwd", x=x1, lengths=len1, langs=langs1, causal=False)
            enc1 = enc1.transpose(0, 1)

            # generate translation - translate / convert to text
            max_len = int(1.5 * len1.max().item() + 10)
            if params.beam_size == 1:
                generated, lengths = decoder.generate(enc1, len1, params.tgt_id, max_len=max_len)
            else:
                generated, lengths = decoder.generate_beam(
                    enc1, len1, params.tgt_id, beam_size=params.beam_size,
                    length_penalty=params.length_penalty,
                    early_stopping=params.early_stopping,
                    max_len=max_len)

            hypotheses_batch = convert_to_text(generated, lengths, input_data["dico"], params)

            out.write("\n".join(hypotheses_batch))
            out.write("\n")

            if n_batch % 100 == 0:
                logger.info("{} batches processed".format(n_batch))

    inp_dump.close()
Example #16
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded["params"])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            "n_words",
            "bos_index",
            "eos_index",
            "pad_index",
            "unk_index",
            "mask_index",
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded["dico_id2word"], reloaded["dico_word2id"],
                      reloaded["dico_counts"])
    encoder = (TransformerModel(model_params,
                                dico,
                                is_encoder=True,
                                with_output=True).cuda().eval())
    decoder = (TransformerModel(model_params,
                                dico,
                                is_encoder=False,
                                with_output=True).cuda().eval())
    encoder.load_state_dict(reloaded["encoder"])
    decoder.load_state_dict(reloaded["decoder"])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    f = io.open(params.output_path, "w", encoding="utf-8")

    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder(
            "fwd",
            x=batch.cuda(),
            lengths=lengths.cuda(),
            langs=langs.cuda(),
            causal=False,
        )
        encoded = encoded.transpose(0, 1)
        decoded, dec_lengths = decoder.generate(
            encoded,
            lengths.cuda(),
            params.tgt_id,
            max_len=int(1.5 * lengths.max().item() + 10),
        )

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip()
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
            sys.stderr.write("%i / %i: %s -> %s\n" %
                             (i + j, len(src_sent), source, target))
            f.write(target + "\n")

    f.close()
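The delimiter handling in the loop above is easy to misread; a minimal sketch with hypothetical ids shows what survives: the slice drops the leading eos and stops at the second eos if one was generated:

import torch

eos = 2
sent = torch.LongTensor([eos, 10, 11, eos, 0])   # one decoded column
delimiters = (sent == eos).nonzero().view(-1)    # tensor([0, 3])
sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
print(sent)                                      # tensor([10, 11])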
Example #17
File: train.py Project: michael-snower/XLM
def main(params):

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # load checkpoint
    if params.model_path != "":
        reloaded = torch.load(params.model_path)
        model_params = AttrDict(reloaded['params'])
        dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                          reloaded['dico_counts'])
        encoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=True,
                                   with_output=True).cuda().eval()
        decoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=False,
                                   with_output=True).cuda().eval()
        encoder.load_state_dict(reloaded['encoder'])
        decoder.load_state_dict(reloaded['decoder'])
        logger.info("Supported languages: %s" %
                    ", ".join(model_params.lang2id.keys()))
    else:
        # build model
        if params.encoder_only:
            model = build_model(params, data['dico'])
        else:
            encoder, decoder = build_model(params, data['dico'])

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)

    # evaluation
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)

        trainer.n_sentences = 0

        while trainer.n_sentences < trainer.epoch_size:

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                trainer.clm_step(lang1, lang2, params.lambda_clm)

            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                trainer.mlm_step(lang1, lang2, params.lambda_mlm)

            # parallel classification steps
            for lang1, lang2 in shuf_order(params.pc_steps, params):
                trainer.pc_step(lang1, lang2, params.lambda_pc)

            # denoising auto-encoder steps
            for lang in shuf_order(params.ae_steps):
                trainer.mt_step(lang, lang, params.lambda_ae)

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                trainer.mt_step(lang1, lang2, params.lambda_mt)

            # back-translation steps
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                trainer.bt_step(lang1, lang2, lang3, params.lambda_bt)

            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(trainer)

        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Example #18
class Translate():
    def __init__(self,
                 model_path,
                 tgt_lang,
                 src_lang,
                 dump_path="./dumped/",
                 exp_name="translate",
                 exp_id="test",
                 batch_size=32):

        # parse parameters
        parser = argparse.ArgumentParser(description="Translate sentences")

        # main parameters
        parser.add_argument("--dump_path",
                            type=str,
                            default=dump_path,
                            help="Experiment dump path")
        parser.add_argument("--exp_name",
                            type=str,
                            default=exp_name,
                            help="Experiment name")
        parser.add_argument("--exp_id",
                            type=str,
                            default=exp_id,
                            help="Experiment ID")
        parser.add_argument("--batch_size",
                            type=int,
                            default=batch_size,
                            help="Number of sentences per batch")
        # model / output paths
        parser.add_argument("--model_path",
                            type=str,
                            default=model_path,
                            help="Model path")
        # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)")
        # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count")
        # source language / target language
        parser.add_argument("--src_lang",
                            type=str,
                            default=src_lang,
                            help="Source language")
        parser.add_argument("--tgt_lang",
                            type=str,
                            default=tgt_lang,
                            help="Target language")
        parser.add_argument('-d',
                            "--text",
                            type=str,
                            default="",
                            nargs='+',
                            help="Text to be translated")

        params = parser.parse_args()
        assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang

        # initialize the experiment
        logger = initialize_exp(params)

        # no GPU available
        #reloaded = torch.load(params.model_path)
        reloaded = torch.load(params.model_path,
                              map_location=torch.device('cpu'))
        model_params = AttrDict(reloaded['params'])
        self.supported_languages = model_params.lang2id.keys()
        logger.info("Supported languages: %s" %
                    ", ".join(self.supported_languages))

        # update dictionary parameters
        for name in [
                'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
                'mask_index'
        ]:
            try:
                setattr(params, name, getattr(model_params, name))
            except AttributeError:
                key = list(model_params.meta_params.keys())[0]
                attr = getattr(model_params.meta_params[key], name)
                setattr(params, name, attr)
                setattr(model_params, name, attr)

        # build dictionary / build encoder / build decoder / reload weights
        self.dico = Dictionary(reloaded['dico_id2word'],
                               reloaded['dico_word2id'],
                               reloaded['dico_counts'])
        #self.encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
        self.encoder = TransformerModel(model_params,
                                        self.dico,
                                        is_encoder=True,
                                        with_output=True).eval()
        #self.decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
        self.decoder = TransformerModel(model_params,
                                        self.dico,
                                        is_encoder=False,
                                        with_output=True).eval()
        self.encoder.load_state_dict(reloaded['encoder'])
        self.decoder.load_state_dict(reloaded['decoder'])
        params.src_id = model_params.lang2id[params.src_lang]
        params.tgt_id = model_params.lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params

    def translate(self, src_sent=()):
        flag = False
        if isinstance(src_sent, str):
            src_sent = [src_sent]
            flag = True
        tgt_sent = []
        for i in range(0, len(src_sent), self.params.batch_size):
            # prepare batch
            word_ids = [
                torch.LongTensor(
                    [self.dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + self.params.batch_size]
            ]
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(),
                                     lengths.size(0)).fill_(
                                         self.params.pad_index)
            batch[0] = self.params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = self.params.eos_index
            langs = batch.clone().fill_(self.params.src_id)

            # encode source batch and translate it
            #encoded = self.encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
            encoded = self.encoder('fwd',
                                   x=batch,
                                   lengths=lengths,
                                   langs=langs,
                                   causal=False)
            encoded = encoded.transpose(0, 1)
            #decoded, dec_lengths = self.decoder.generate(encoded, lengths.cuda(), self.params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10))
            decoded, dec_lengths = self.decoder.generate(
                encoded,
                lengths,
                self.params.tgt_id,
                max_len=int(1.5 * lengths.max().item() + 10))

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters
                sent = decoded[:, j]
                delimiters = (sent == self.params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(
                    delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip()
                target = " ".join(
                    [self.dico[sent[k].item()] for k in range(len(sent))])
                sys.stderr.write("%i / %i: %s -> %s\n" %
                                 (i + j, len(src_sent), source, target))
                tgt_sent.append(target)

        if flag:
            return tgt_sent[0]
        return tgt_sent
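A usage sketch for the Translate class above. The checkpoint path is hypothetical, input is assumed to be pre-tokenized/BPE-split to match the model vocabulary, and since __init__ calls parser.parse_args(), a clean sys.argv is assumed:

translator = Translate('mlm_tlm_xnli15_1024.pth', tgt_lang='fr', src_lang='en')
print(translator.translate('hello world'))                   # str in, str out
print(translator.translate(['good morning', 'good night']))  # list in, list out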
Example #19
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder = None
    #    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    #    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    all_encodings = []
    # For each sentence...
    for i in range(0, len(src_sent), params.batch_size):
        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it, deal with padding
        encodings = encoderouts(encoder, batch, lengths, langs)

        # batch is actually in original order, append each sent to all_encodings
        for idx in encodings:
            all_encodings.append(idx.cpu().numpy())

    # Save all encodings to npy
    np.save(params.output_path, np.stack(all_encodings))
Example #20
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=True).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    with open(params.sentences_path, 'r') as file1:
        for line in file1:
            if 0 < len(line.strip().split()) < 100:
                src_sent.append(line)
    logger.info(
        "Read %i sentences from the sentences file. Writing them to a src file. Translating ..."
        % len(src_sent))
    f = io.open(params.output_path + '.src_sent', 'w', encoding='utf-8')
    for sentence in src_sent:
        f.write(sentence)
    f.close()
    logger.info("Wrote them to a src file")
    f = io.open(params.output_path, 'w', encoding='utf-8')

    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded, _ = encoder('fwd',
                             x=batch.cuda(),
                             lengths=lengths.cuda(),
                             langs=langs.cuda(),
                             causal=False)
        encoded = encoded.transpose(0, 1)
        #decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10))
        decoded, dec_lengths = decoder.generate_beam(
            encoded,
            lengths.cuda(),
            params.tgt_id,
            beam_size=params.beam_size,
            length_penalty=params.length_penalty,
            early_stopping=params.early_stopping,
            max_len=int(1.5 * lengths.cuda().max().item() + 10))
        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip()
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
            if (i + j) % 100 == 0:
                logger.info(
                    "Translation of %i / %i:\n Source sentence: %s \n Translation: %s\n"
                    % (i + j, len(src_sent), source, target))
            f.write(target + "\n")

    f.close()
Example #21
File: xnlg-ft.py Project: shiqing1234/XNLG
def run_xnlg():
    params = get_params()

    # initialize the experiment / build sentence embedder
    logger = initialize_exp(params)

    if params.tokens_per_batch > -1:
        params.group_by_size = True

    # check parameters
    assert os.path.isdir(params.data_path)
    assert os.path.isfile(params.model_path)

    # tasks
    params.transfer_tasks = params.transfer_tasks.split(',')
    assert len(params.transfer_tasks) > 0
    assert all([task in TASKS for task in params.transfer_tasks])

    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))
    params.n_langs = model_params['n_langs']
    params.id2lang = model_params['id2lang']
    params.lang2id = model_params['lang2id']

    if "enc_params" in reloaded:
        encoder_model_params = AttrDict(reloaded["enc_params"])
    elif params.n_enc_layers == model_params.n_layers or params.n_enc_layers == 0:
        encoder_model_params = model_params
    else:
        encoder_model_params = AttrDict(reloaded['params'])
        encoder_model_params.n_layers = params.n_enc_layers
        assert model_params.n_layers != encoder_model_params.n_layers

    if "dec_params" in reloaded:
        decoder_model_params = AttrDict(reloaded["dec_params"])
    elif params.n_dec_layers == model_params.n_layers or params.n_dec_layers == 0:
        decoder_model_params = model_params
    else:
        decoder_model_params = AttrDict(reloaded['params'])
        decoder_model_params.n_layers = params.n_dec_layers
        assert model_params.n_layers != decoder_model_params.n_layers

    params.encoder_model_params = encoder_model_params
    params.decoder_model_params = decoder_model_params

    if params.emb_dim != -1:
        encoder_model_params.emb_dim = params.emb_dim
        decoder_model_params.emb_dim = params.emb_dim

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])

    for p in [params, encoder_model_params, decoder_model_params]:
        p.n_words = len(dico)
        p.bos_index = dico.index(BOS_WORD)
        p.eos_index = dico.index(EOS_WORD)
        p.pad_index = dico.index(PAD_WORD)
        p.unk_index = dico.index(UNK_WORD)
        p.mask_index = dico.index(MASK_WORD)

    encoder = TransformerModel(encoder_model_params,
                               dico,
                               is_encoder=True,
                               with_output=False)
    decoder = TransformerModel(decoder_model_params,
                               dico,
                               is_encoder=False,
                               with_output=True)

    def _process_state_dict(state_dict):
        return {(k[7:] if k.startswith('module.') else k): v
                for k, v in state_dict.items()}

    if params.no_init == "all":
        logger.info("All Models will not load state dict.!!!")
    elif params.reload_emb != "":
        logger.info("Reloading embedding from %s ..." % params.reload_emb)
        word2id, embeddings = read_txt_embeddings(logger, params.reload_emb)
        set_pretrain_emb(logger, encoder, dico, word2id, embeddings)
        set_pretrain_emb(logger, decoder, dico, word2id, embeddings)
    else:
        if "model" in reloaded:
            if params.no_init != "encoder":
                encoder.load_state_dict(_process_state_dict(reloaded['model']),
                                        strict=False)
            if params.no_init != "decoder":
                decoder.load_state_dict(_process_state_dict(reloaded['model']),
                                        strict=False)
        else:
            if params.no_init != "encoder":
                encoder.load_state_dict(_process_state_dict(
                    reloaded['encoder']),
                                        strict=False)
            if params.no_init != "decoder":
                decoder.load_state_dict(
                    _process_state_dict(reloaded['decoder']))

    scores = {}

    # run
    for task in params.transfer_tasks:
        if task == "XQG":
            XQG_v3(encoder, decoder, scores, dico, params).run()
        elif task == "XSumm":
            XSumm(encoder, decoder, scores, dico, params).run()
Example #22
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default="", help="input file")
    parser.add_argument("--model", type=str, default="", help="model path")
    parser.add_argument("--spm_model",
                        type=str,
                        default="",
                        help="spm model path")
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="batch size")
    parser.add_argument("--max_words", type=int, default=100, help="max words")
    parser.add_argument("--cuda", type=str, default="True", help="use cuda")
    parser.add_argument("--output", type=str, default="", help="output file")
    args = parser.parse_args()

    # Reload a pretrained model
    reloaded = torch.load(args.model)
    params = AttrDict(reloaded['params'])

    # Reload the SPM model
    spm_model = spm.SentencePieceProcessor()
    spm_model.Load(args.spm_model)

    # cuda
    assert args.cuda in ["True", "False"]
    args.cuda = args.cuda == "True"

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    reloaded['model'] = OrderedDict({
        key.replace('module.', ''): reloaded['model'][key]
        for key in reloaded['model']
    })
    model.load_state_dict(reloaded['model'])
    model.eval()

    if args.cuda:
        model.cuda()

    # load sentences
    sentences = []
    with open(args.input) as f:
        for line in f:
            line = spm_model.EncodeAsPieces(line.rstrip())
            line = line[:args.max_words - 1]
            sentences.append(line)

    # encode sentences
    embs = []
    for i in range(0, len(sentences), args.batch_size):
        batch = sentences[i:i + args.batch_size]
        lengths = torch.LongTensor([len(s) + 1 for s in batch])
        bs, slen = len(batch), lengths.max().item()
        assert slen <= args.max_words

        x = torch.LongTensor(slen, bs).fill_(params.pad_index)
        for k in range(bs):
            sent = torch.LongTensor([params.eos_index] +
                                    [dico.index(w) for w in batch[k]])
            x[:len(sent), k] = sent

        if args.cuda:
            x = x.cuda()
            lengths = lengths.cuda()

        with torch.no_grad():
            embedding = model('fwd',
                              x=x,
                              lengths=lengths,
                              langs=None,
                              causal=False).contiguous()[0].cpu()

        embs.append(embedding)

    # save embeddings
    torch.save(torch.cat(embs, dim=0).squeeze(0), args.output)
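The saved tensor holds one row per input sentence, taken from the hidden state at position 0 (the leading eos token). A sketch of reading it back, with a hypothetical output path:

import torch

embs = torch.load('sentence_embeddings.pt')  # hypothetical --output path
print(embs.shape)                            # (n_sentences, emb_dim)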
Example #23
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    model_params.add_pred = ""
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    src_dico = load_binarized(params.src_data)
    tgt_dico = load_binarized(params.tgt_data)

    encoder = TransformerModel(model_params,
                               src_dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               tgt_dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()

    if all([k.startswith('module.') for k in reloaded['encoder'].keys()]):
        reloaded['encoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['encoder'].items()
        }
        reloaded['decoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['decoder'].items()
        }

    encoder.load_state_dict(reloaded['encoder'], strict=False)
    decoder.load_state_dict(reloaded['decoder'], strict=False)

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from the input file
    src_sent = []
    input_f = open(params.input_path, 'r')
    for line in input_f:
        line = line.strip()
        assert len(line.split()) > 0
        src_sent.append(line)
    input_f.close()
    logger.info("Read %i sentences from %s. Translating ..." %
                (len(src_sent), params.input_path))

    f = io.open(params.output_path, 'w', encoding='utf-8')

    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([src_dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder('fwd',
                          x=batch.cuda(),
                          lengths=lengths.cuda(),
                          langs=langs.cuda(),
                          causal=False)
        encoded = [enc.transpose(0, 1) for enc in encoded]
        decoded, dec_lengths = decoder.generate(
            encoded,
            lengths.cuda(),
            params.tgt_id,
            max_len=int(1.5 * lengths.max().item() + 10))

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
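            # generated sentences start with </s>; keep tokens up to the next
            # </s> (or to the end if no second delimiter was produced)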
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip()
            target = " ".join(
                [tgt_dico[sent[k].item()] for k in range(len(sent))])
            f.write(target + "\n")

    f.close()
Example #24
0
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    model_params['mnmt'] = params.mnmt
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    if model_params.share_word_embeddings or model_params.share_all_embeddings:
        dico = Dictionary(reloaded['dico_id2word'],
                          reloaded['dico_word2id'],
                          reloaded['dico_counts'])
    else:
        dico = {}
        for lang in [params.src_lang, params.tgt_lang]:
            dico[lang] = Dictionary(reloaded[lang]['dico_id2word'],
                                    reloaded[lang]['dico_word2id'],
                                    reloaded[lang]['dico_counts'])


    # define src_dico/tgt_dico in both cases: they are needed below to index
    # the source batch and to map decoded ids back to words
    if model_params.share_word_embeddings or model_params.share_all_embeddings:
        src_dico = tgt_dico = dico
        encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=False).cuda().eval()
        decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    else:
        src_dico = dico[params.src_lang]
        tgt_dico = dico[params.tgt_lang]
        encoder = TransformerModel(model_params, src_dico, is_encoder=True, with_output=False).cuda().eval()
        decoder = TransformerModel(model_params, tgt_dico, is_encoder=False, with_output=True).cuda().eval()

    try:
        encoder.load_state_dict(reloaded['encoder'])
        decoder.load_state_dict(reloaded['decoder'])
    except RuntimeError:
        enc_reload = reloaded['encoder']
        if all([k.startswith('module.') for k in enc_reload.keys()]):
            enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}

        dec_reload = reloaded['decoder']
        if all(k.startswith('module.') for k in dec_reload.keys()):
            dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()}

        encoder.load_state_dict(enc_reload)
        decoder.load_state_dict(dec_reload)

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))

    f = io.open(params.output_path, 'w', encoding='utf-8')

    for i in range(0, len(src_sent), params.batch_size):

        word_ids = [torch.LongTensor([src_dico.index(w) for w in s.strip().split()])
                        for s in src_sent[i:i + params.batch_size]]

        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
        encoded = encoded.transpose(0, 1)
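        # beam search when beam_size > 1, otherwise greedy decoding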
        if params.beam_size > 1:
            decoded, dec_lengths = decoder.generate_beam(encoded, lengths.cuda(), params.tgt_id,
                                                         beam_size=params.beam_size,
                                                         length_penalty=params.lenpen,
                                                         early_stopping=params.early_stopping,
                                                         max_len=int(1.5 * lengths.max().item() + 10))
        else:
            decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id,
                                                    max_len=int(1.5 * lengths.max().item() + 10))

        # convert sentences to words
        for j in range(decoded.size(1)):
            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip()
            target = " ".join([tgt_dico[sent[k].item()] for k in range(len(sent))])
            sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
            f.write(target + "\n")
    f.close()
Example #25
0
#%% [markdown]
# ## Build dictionary / update parameters / build model

#%%
# build dictionary / update parameters
dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                  reloaded['dico_counts'])
assert params.n_words == len(dico)
assert params.bos_index == dico.index(BOS_WORD)
assert params.eos_index == dico.index(EOS_WORD)
assert params.pad_index == dico.index(PAD_WORD)
assert params.unk_index == dico.index(UNK_WORD)
assert params.mask_index == dico.index(MASK_WORD)

# build model / reload weights
model = TransformerModel(params, dico, True, True)
model.load_state_dict(reloaded['model'])
model.cuda()
model.eval()

#%%
FASTBPE_PATH = '/private/home/guismay/tools/fastBPE/fast'
TOKENIZER_PATH = '/private/home/guismay/tools/mosesdecoder/scripts/tokenizer/tokenizer.perl'
DETOKENIZER_PATH = '/private/home/guismay/tools/mosesdecoder/scripts/tokenizer/detokenizer.perl'
BPE_CODES = '/checkpoint/guismay/ccclean/60000/codes.60000'


#%%
def apply_bpe(txt):
    # minimal sketch, assuming the fastBPE binary above exposes the
    # `applybpe_stream <codes>` mode (raw text on stdin, BPE on stdout)
    import subprocess
    result = subprocess.run([FASTBPE_PATH, 'applybpe_stream', BPE_CODES],
                            input=txt.encode('utf-8'),
                            stdout=subprocess.PIPE, check=True)
    return result.stdout.decode('utf-8').strip()
Example #26
0
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)
    parser = get_parser()
    params = parser.parse_args()
    models_path = params.model_path.split(',')

    # generate parser / parse parameters
    models_reloaded = []
    for model_path in models_path:
        models_reloaded.append(torch.load(model_path))
    model_params = AttrDict(models_reloaded[0]['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(models_reloaded[0]['dico_id2word'],
                      models_reloaded[0]['dico_word2id'],
                      models_reloaded[0]['dico_counts'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    encoders = []
    decoders = []

    def package_module(modules):
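        # normalize checkpoint keys saved from a DataParallel-wrapped model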
        state_dict = OrderedDict()
        for k, v in modules.items():
            if k.startswith('module.'):
                state_dict[k[7:]] = v
            else:
                state_dict[k] = v
        return state_dict

    for reloaded in models_reloaded:
        encoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=True,
                                   with_output=True).to(params.device).eval()
        decoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=False,
                                   with_output=True).to(params.device).eval()
        encoder.load_state_dict(package_module(reloaded['encoder']))
        decoder.load_state_dict(package_module(reloaded['decoder']))

        # float16
        if params.fp16:
            assert torch.backends.cudnn.enabled
            encoder = network_to_half(encoder)
            decoder = network_to_half(decoder)

        encoders.append(encoder)
        decoders.append(decoder)

    #src_sent = ['Poly@@ gam@@ ie statt Demokratie .']
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)

    f = io.open(params.output_path, 'w', encoding='utf-8')

    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
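        # each model in the ensemble encodes the batch separately; the custom
        # generate_beam below then scores candidates with all decoders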
        encodeds = []
        for encoder in encoders:
            encoded = encoder('fwd',
                              x=batch.to(params.device),
                              lengths=lengths.to(params.device),
                              langs=langs.to(params.device),
                              causal=False)
            encoded = encoded.transpose(0, 1)
            encodeds.append(encoded)

            assert encoded.size(0) == lengths.size(0)

        decoded, dec_lengths = generate_beam(
            decoders,
            encodeds,
            lengths.to(params.device),
            params.tgt_id,
            beam_size=params.beam,
            length_penalty=params.length_penalty,
            early_stopping=False,
            max_len=int(1.5 * lengths.max().item() + 10),
            params=params)

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip()
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
            sys.stderr.write("%i / %i: %s -> %s\n" %
                             (i + j, len(src_sent), source, target))
            f.write(target + "\n")

    f.close()
Example #27
0
def main():

    # Load pre-trained model
    model_path = './models/mlm_tlm_xnli15_1024.pth'
    reloaded = torch.load(model_path)
    params = AttrDict(reloaded['params'])

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    #model.cuda() #if using GPU
    model.load_state_dict(reloaded['model'])
    """ """

    with open(args.filename, "r") as f:
        sentence_list = f.readlines()[args.sn[0]:args.sn[1]]

    # remove new line symbols
    for i in range(0, len(sentence_list)):
        sentence_list[i] = sentence_list[i].replace("\n", "")

    # save as dataframe and add language tokens
    sentence_df = pd.DataFrame(sentence_list)
    sentence_df.columns = ['sentence']
    sentence_df['language'] = 'en'

    # match xlm format
    sentences = list(zip(sentence_df.sentence, sentence_df.language))

    # from the XLM repo
    # add </s> sentence delimiters
    sentences = [(('</s> %s </s>' % sent.strip()).split(), lang)
                 for sent, lang in sentences]

    # Create batch
    bs = len(sentences)
    slen = max([len(sent) for sent, _ in sentences])

    word_ids = torch.LongTensor(slen, bs).fill_(params.pad_index)
    for i in range(len(sentences)):
        sent = torch.LongTensor([dico.index(w) for w in sentences[i][0]])
        word_ids[:len(sent), i] = sent

    lengths = torch.LongTensor([len(sent) for sent, _ in sentences])
    langs = torch.LongTensor([params.lang2id[lang] for _, lang in sentences
                              ]).unsqueeze(0).expand(slen, bs)

    #if using GPU:
    #word_ids=word_ids.cuda()
    #lengths=lengths.cuda()
    #langs=langs.cuda()

    tensor = model('fwd',
                   x=word_ids,
                   lengths=lengths,
                   langs=langs,
                   causal=False).contiguous()
    print(tensor.size())

    # The variable tensor has shape (sequence_length, batch_size, model_dimension).
    # tensor[0] has shape (batch_size, model_dimension) and holds the first hidden
    # state of the last layer for each sentence.
    # This is the vector used to fine-tune on the GLUE and XNLI tasks.
    """ """

    torch.save(tensor[0], args.o)
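
The saved first-token states can serve as fixed sentence embeddings. A minimal
sketch (with 'sent_embs.pt' as a hypothetical stand-in for the args.o path
above) compares two sentences by cosine similarity:

import torch
import torch.nn.functional as F

embs = torch.load('sent_embs.pt')  # shape: (batch_size, model_dim)
sim = F.cosine_similarity(embs[0], embs[1], dim=0)
print(sim.item())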
Example #28
0
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    # set the random seed (multi-GPU runs also need torch.cuda.manual_seed_all(params.seed))
    torch.manual_seed(params.seed)
    assert (params.sample_temperature
            == 0) or (params.beam_size == 1), 'Cannot sample with beam search.'
    assert params.amp <= 1, f'params.amp == {params.amp} not yet supported.'
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    if all([k.startswith('module.') for k in reloaded['encoder'].keys()]):
        reloaded['encoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['encoder'].items()
        }
    encoder.load_state_dict(reloaded['encoder'])
    if all([k.startswith('module.') for k in reloaded['decoder'].keys()]):
        reloaded['decoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['decoder'].items()
        }
    decoder.load_state_dict(reloaded['decoder'])

    if params.amp != 0:
        models = apex.amp.initialize([encoder, decoder],
                                     opt_level=('O%i' % params.amp))
        encoder, decoder = models

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

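    # one translation list per beam rank; hypothesis[0] holds the best hypotheses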
    hypothesis = [[] for _ in range(params.beam_size)]
    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder('fwd',
                          x=batch.cuda(),
                          lengths=lengths.cuda(),
                          langs=langs.cuda(),
                          causal=False)
        encoded = encoded.transpose(0, 1)
        max_len = int(1.5 * lengths.max().item() + 10)
        if params.beam_size == 1:
            decoded, dec_lengths = decoder.generate(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                max_len=max_len,
                sample_temperature=(None if params.sample_temperature == 0 else
                                    params.sample_temperature))
        else:
            decoded, dec_lengths, all_hyp_strs = decoder.generate_beam(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                beam_size=params.beam_size,
                length_penalty=params.length_penalty,
                early_stopping=params.early_stopping,
                max_len=max_len,
                output_all_hyps=True)

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip().replace('<unk>', '<<unk>>')
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))
                               ]).replace('<unk>', '<<unk>>')
            if params.beam_size == 1:
                hypothesis[0].append(target)
            else:
                for hyp_rank in range(params.beam_size):
                    # fall back to the last hypothesis when the beam returns
                    # fewer than beam_size candidates
                    hyp = all_hyp_strs[j][min(hyp_rank, len(all_hyp_strs[j]) - 1)]
                    print(hyp)
                    hypothesis[hyp_rank].append(hyp)

            sys.stderr.write("%i / %i: %s -> %s\n" %
                             (i + j, len(src_sent), source.replace(
                                 '@@ ', ''), target.replace('@@ ', '')))

    # export sentences to reference and hypothesis files / restore BPE segmentation
    save_dir, split = params.output_path.rsplit('/', 1)
    for hyp_rank in range(len(hypothesis)):
        hyp_name = f'hyp.st={params.sample_temperature}.bs={params.beam_size}.lp={params.length_penalty}.es={params.early_stopping}.seed={params.seed if (len(hypothesis) == 1) else str(hyp_rank)}.{params.src_lang}-{params.tgt_lang}.{split}.txt'
        hyp_path = os.path.join(save_dir, hyp_name)
        with open(hyp_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(hypothesis[hyp_rank]) + '\n')
        restore_segmentation(hyp_path)

        # evaluate BLEU score
        if params.ref_path:
            bleu = eval_moses_bleu(params.ref_path, hyp_path)
            logger.info("BLEU %s %s : %f" % (hyp_path, params.ref_path, bleu))
Example #29
0
def main(params):
    # setup random seeds
    set_seed(params.seed)
    params.ar = True

    exp_path = os.path.join(params.dump_path, params.exp_name)
    # create exp path if it doesn't exist
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)
    # create logger
    logger = create_logger(os.path.join(exp_path, 'train.log'), 0)
    logger.info("============ Initialized logger ============")
    logger.info("Random seed is {}".format(params.seed))
    logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment will be stored in %s\n" % exp_path)
    logger.info("Running command: %s" % 'python ' + ' '.join(sys.argv))
    logger.info("")
    # load data
    data, loader = load_smiles_data(params)
    if params.data_type == 'ChEMBL':
        smiles_path = os.path.join(params.data_path, 'guacamol_v1_all.smiles')
    else:
        smiles_path = os.path.join(params.data_path, 'QM9_all.smiles')
    with open(smiles_path, 'r') as f:
        all_smiles_mols = f.readlines()
    train_data, val_data = data['train'], data['valid']
    dico = data['dico']
    logger.info('train_data len is {}'.format(len(train_data)))
    logger.info('val_data len is {}'.format(len(val_data)))

    # keep cycling through train_loader forever
    # stop when max iters is reached
    def rcycle(iterable):
        saved = []                 # in-memory cache of batches seen so far
        for element in iterable:
            yield element
            saved.append(element)
        while saved:
            random.shuffle(saved)  # reshuffle the cached batches on every pass
            for element in saved:
                yield element
    train_loader = rcycle(train_data.get_iterator(shuffle=True, group_by_size=True, n_sentences=-1))

    # extra param names for transformermodel
    params.n_langs = 1
    # build Transformer model
    model = TransformerModel(params, is_encoder=False, with_output=True)

    if params.local_cpu is False:
        model = model.cuda()
    opt = get_optimizer(model.parameters(), params.optimizer)
    scores = {'ppl': float('inf'), 'acc': 0}

    if params.load_path:
        reloaded_iter, scores = load_model(params, model, opt, logger)

    for total_iter, train_batch in enumerate(train_loader):
        if params.load_path:
            total_iter += reloaded_iter + 1

        epoch = total_iter // params.epoch_size
        if total_iter == params.max_steps:
            logger.info("============ Done training ... ============")
            break
        elif total_iter % params.epoch_size == 0:
            logger.info("============ Starting epoch %i ... ============" % epoch)
        model.train()
        opt.zero_grad()
        train_loss = calculate_loss(model, train_batch, params)
        train_loss.backward()
        if params.clip_grad_norm > 0:
            clip_grad_norm_(model.parameters(), params.clip_grad_norm)
        opt.step()
        if total_iter % params.print_after == 0:
            logger.info("Step {} ; Loss = {}".format(total_iter, train_loss))

        if total_iter > 0 and total_iter % params.epoch_size == (params.epoch_size - 1):
            # run eval step (calculate validation loss)
            model.eval()
            n_chars = 0
            xe_loss = 0
            n_valid = 0
            logger.info("============ Evaluating ... ============")
            val_loader = val_data.get_iterator(shuffle=True)
            for val_iter, val_batch in enumerate(val_loader):
                with torch.no_grad():
                    val_scores, val_loss, val_y = calculate_loss(model, val_batch, params, get_scores=True)
                # update stats
                n_chars += val_y.size(0)
                xe_loss += val_loss.item() * len(val_y)
                n_valid += (val_scores.max(1)[1] == val_y).sum().item()

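            # perplexity = exp(total cross-entropy / number of predicted tokens)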
            ppl = np.exp(xe_loss / n_chars)
            acc = 100. * n_valid / n_chars
            logger.info("Acc={}, PPL={}".format(acc, ppl))
            if acc > scores['acc']:
                scores['acc'] = acc
                scores['ppl'] = ppl
                save_model(params, data, model, opt, dico, logger, 'best_model', epoch, total_iter, scores)
                logger.info('Saving new best_model {}'.format(epoch))
                logger.info("Best Acc={}, PPL={}".format(scores['acc'], scores['ppl']))

            logger.info("============ Generating ... ============")
            number_samples = 100
            gen_smiles = generate_smiles(params, model, dico, number_samples)
            generator = ARMockGenerator(gen_smiles)

            try:
                benchmark = ValidityBenchmark(number_samples=number_samples)
                validity_score = benchmark.assess_model(generator).score
            except Exception:
                validity_score = -1
            try:
                benchmark = UniquenessBenchmark(number_samples=number_samples)
                uniqueness_score = benchmark.assess_model(generator).score
            except Exception:
                uniqueness_score = -1

            try:
                benchmark = KLDivBenchmark(number_samples=number_samples, training_set=all_smiles_mols)
                kldiv_score = benchmark.assess_model(generator).score
            except Exception:
                kldiv_score = -1
            logger.info('Validity Score={}, Uniqueness Score={}, KlDiv Score={}'.format(validity_score, uniqueness_score, kldiv_score))
            save_model(params, data, model, opt, dico, logger, 'model', epoch, total_iter, {'ppl': ppl, 'acc': acc})