Esempio n. 1
0
    def reload(path, params):
        """
        Build a sentence embedder from a pretrained checkpoint.

        The checkpoint at `path` is read with `torch.load`. Its dictionary,
        hyper-parameters and weights are used to rebuild a `TransformerModel`,
        which is switched to eval mode. `params.max_batch_size` is set to 0 on
        the caller's object and a `MyModel` wrapping the model, the
        dictionary, the pretraining parameters and `params` is returned.
        """
        checkpoint = torch.load(path)
        weights = checkpoint['model']

        # multi-GPU checkpoints: drop the leading 'module.' from weight names
        if 'checkpoint' in path:
            renamed = {}
            for name, tensor in weights.items():
                if name.startswith('module.'):
                    name = name[7:]
                renamed[name] = tensor
            weights = renamed

        # rebuild the dictionary, then refresh every vocabulary-dependent
        # field of the pretraining parameters from it
        dico = Dictionary(
            checkpoint['dico_id2word'],
            checkpoint['dico_word2id'],
            checkpoint['dico_counts'],
        )
        pretrain_params = AttrDict(checkpoint['params'])
        pretrain_params.n_words = len(dico)
        special_words = (
            ('bos_index', BOS_WORD),
            ('eos_index', EOS_WORD),
            ('pad_index', PAD_WORD),
            ('unk_index', UNK_WORD),
            ('mask_index', MASK_WORD),
        )
        for field, word in special_words:
            setattr(pretrain_params, field, dico.index(word))

        # instantiate the network and restore the trained weights
        model = TransformerModel(pretrain_params, dico, True, True)
        model.load_state_dict(weights)
        model.eval()

        # parameter missing from the caller-supplied params
        params.max_batch_size = 0

        return MyModel(model, dico, pretrain_params, params)
Esempio n. 2
0
class XLMForTokenClassification(nn.Module):
    """Per-token classifier on top of a pretrained XLM encoder.

    The encoder output goes through dropout and a linear layer that
    produces `num_labels` scores per position.

    Args:
        config: object providing `dropout` and `initializer_range`.
        num_labels: number of output classes of the linear head.
        params: hyper-parameters forwarded to `TransformerModel`.
        dico: dictionary forwarded to `TransformerModel`.
        reloaded: checkpoint mapping; `reloaded['model']` is the encoder
            state dict.
    """

    def __init__(self, config, num_labels, params, dico, reloaded):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.xlm = TransformerModel(params, dico, True, True)
        self.xlm.eval()
        self.xlm.load_state_dict(reloaded['model'])

        self.dropout = nn.Dropout(config.dropout)
        # NOTE(review): 1024 presumably matches the encoder's hidden size;
        # confirm against the checkpoint's params.
        self.classifier = nn.Linear(1024, num_labels)
        # Initialise only the new head. Module.apply() recurses into every
        # submodule, so calling it on `self` would overwrite the pretrained
        # Linear/Embedding weights that were just loaded into `self.xlm`.
        self.classifier.apply(self.init_bert_weights)

    def forward(self, word_ids, lengths, langs=None, causal=False):
        """Return the classifier logits for a batch.

        `langs` and `causal` are accepted but not used: the encoder is
        always called with `causal=False` and without language ids.
        """
        sequence_output = self.xlm('fwd',
                                   x=word_ids,
                                   lengths=lengths,
                                   causal=False).contiguous()
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits

    def init_bert_weights(self, module):
        """Initialise `module` in place.

        Linear and Embedding weights are drawn from a normal distribution
        with mean 0 and std `config.initializer_range`; Linear biases are
        set to zero. Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
Esempio n. 3
0
def translate(args):
    """Translate every line of `args.src` and print source/hypothesis pairs.

    Args:
        args: namespace providing `batch_size`, `vocab_src`, `vocab_tgt`,
            `reload_path` (checkpoint with a `'module'` state dict and an
            optional `'epoch'` entry), `src` (text file, one sentence per
            line) and `beam` (1 selects greedy decoding, anything else
            beam search).

    The model is moved to the GPU, so CUDA is required.
    """
    batch_size = args.batch_size

    src_vocab = Dictionary.read_vocab(args.vocab_src)
    tgt_vocab = Dictionary.read_vocab(args.vocab_tgt)
    data = torch.load(args.reload_path, map_location='cpu')
    model = TransformerModel(src_dictionary=src_vocab,
                             tgt_dictionary=tgt_vocab)
    model.load_state_dict({k: data['module'][k] for k in data['module']})
    model.cuda()
    model.eval()

    if 'epoch' in data:
        print(f"Loading model from epoch_{data['epoch']}....")

    # Read the whole source file up front and close the handle right away.
    with open(args.src, "r") as src_file:
        src_sent = src_file.readlines()

    for i in range(0, len(src_sent), batch_size):
        chunk = src_sent[i:i + batch_size]
        word_ids = [
            torch.LongTensor([src_vocab.index(w) for w in s.strip().split()])
            for s in chunk
        ]
        # +2 leaves room for the BOS and EOS markers around each sentence.
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(src_vocab.pad_index)
        batch[0] = src_vocab.bos_index

        for j, s in enumerate(word_ids):
            # An empty line has length 2: nothing to copy between BOS/EOS.
            if lengths[j] > 2:
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = src_vocab.eos_index

        batch = batch.cuda()

        # Pure inference: run the encoder under no_grad as well, so that no
        # autograd graph is built and kept alive for the encoder output.
        with torch.no_grad():
            encoder_out = model.encoder(batch)
            if args.beam == 1:
                generated = model.decoder.generate_greedy(encoder_out)
            else:
                # NOTE(review): the beam width is fixed at 5 regardless of
                # the value of args.beam; confirm whether that is intended.
                generated = model.decoder.generate_beam(encoder_out,
                                                        beam_size=5)

        for j, s in enumerate(chunk):
            print(f"Source_{i+j}: {s.strip()}")
            hypo = []
            # Skip the first generated token and stop at end-of-sentence.
            for w in generated[j][1:]:
                token = tgt_vocab[w.item()]
                if token == '</s>':
                    break
                hypo.append(token)
            hypo = " ".join(hypo)
            print(f"Target_{i+j}: {hypo}\n")
Esempio n. 4
0
class XLM_BiLSTM_CRF(nn.Module):
    """Sequence tagger: pretrained XLM encoder -> BiLSTM -> linear -> CRF.

    Args:
        config: object providing `batch_size`, `hidden_dim`,
            `embedding_dim`, `dropout`, `num_class` and
            `initializer_range`.
        num_labels: stored on the instance; the classifier and the CRF are
            sized from `config.num_class`.
        params: hyper-parameters forwarded to `TransformerModel`.
        dico: dictionary forwarded to `TransformerModel`.
        reloaded: checkpoint mapping; `reloaded['model']` is the encoder
            state dict.
    """

    def __init__(self, config, num_labels, params, dico, reloaded):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.batch_size = config.batch_size
        self.hidden_dim = config.hidden_dim

        self.xlm = TransformerModel(params, dico, True, True)
        self.xlm.eval()
        self.xlm.load_state_dict(reloaded['model'])

        # Each direction gets half of hidden_dim so that the concatenated
        # output feeds the hidden_dim-wide classifier below.
        self.lstm = nn.LSTM(config.embedding_dim,
                            config.hidden_dim // 2,
                            num_layers=1,
                            bidirectional=True)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_dim, config.num_class)
        # Initialise only the new linear head. Module.apply() recurses into
        # every submodule, so calling it on `self` would overwrite the
        # pretrained Linear/Embedding weights just loaded into `self.xlm`.
        self.classifier.apply(self.init_bert_weights)
        self.crf = CRF(config.num_class)

    def _emissions(self, word_ids, lengths):
        """Compute the per-position class scores that are fed to the CRF."""
        encoded = self.xlm('fwd',
                           x=word_ids,
                           lengths=lengths,
                           causal=False).contiguous()
        encoded, _ = self.lstm(encoded)
        encoded = self.dropout(encoded)
        return self.classifier(encoded)

    def forward(self, word_ids, lengths, langs=None, causal=False):
        """Return the CRF decoding of the batch.

        `langs` and `causal` are accepted but not used: the encoder is
        always called with `causal=False` and without language ids.
        """
        return self.crf.decode(self._emissions(word_ids, lengths))

    def log_likelihood(self, word_ids, lengths, tags):
        """Return the negated CRF score of `tags` for the batch.

        `tags` is transposed on its first two dimensions before being
        handed to the CRF.
        NOTE(review): presumably the CRF returns a log-likelihood, which
        makes this value a loss to minimise; confirm against the CRF
        implementation in use.
        """
        logits = self._emissions(word_ids, lengths)
        return -self.crf(logits, tags.transpose(0, 1))

    def init_bert_weights(self, module):
        """Initialise `module` in place.

        Linear and Embedding weights are drawn from a normal distribution
        with mean 0 and std `config.initializer_range`; Linear biases are
        set to zero. Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
Esempio n. 5
0
#%%
# Rebuild the dictionary stored in the checkpoint and check that the
# vocabulary-dependent parameters agree with it.
dico = Dictionary(
    reloaded['dico_id2word'],
    reloaded['dico_word2id'],
    reloaded['dico_counts'],
)
assert len(dico) == params.n_words
assert dico.index(BOS_WORD) == params.bos_index
assert dico.index(EOS_WORD) == params.eos_index
assert dico.index(PAD_WORD) == params.pad_index
assert dico.index(UNK_WORD) == params.unk_index
assert dico.index(MASK_WORD) == params.mask_index

# Instantiate the network, restore its weights, then move it to the GPU
# and switch it to eval mode.
model = TransformerModel(params, dico, True, True)
model.load_state_dict(reloaded['model'])
model.cuda()
model.eval()

#%%

#%%
# Machine-specific absolute paths to the external tools and to the BPE
# codes file; adjust them to the local installation.
FASTBPE_PATH = '/private/home/guismay/tools/fastBPE/fast'
TOKENIZER_PATH = ('/private/home/guismay/tools/mosesdecoder'
                  '/scripts/tokenizer/tokenizer.perl')
DETOKENIZER_PATH = ('/private/home/guismay/tools/mosesdecoder'
                    '/scripts/tokenizer/detokenizer.perl')
BPE_CODES = '/checkpoint/guismay/ccclean/60000/codes.60000'


#%%
def apply_bpe(txt):
    temp1_path = '/tmp/xxx1'
    temp2_path = '/tmp/xxx2'
    with open(temp1_path, 'w', encoding='utf-8') as f:
Esempio n. 6
0
def _benchmark_score(benchmark_cls, generator, logger, **benchmark_kwargs):
    """Return the score of one benchmark, or -1 if it cannot be computed."""
    try:
        benchmark = benchmark_cls(**benchmark_kwargs)
        return benchmark.assess_model(generator).score
    except Exception as err:
        # Best effort: a failing benchmark must not abort training, but the
        # reason is logged instead of being silently discarded.
        logger.info("{} failed: {!r}".format(benchmark_cls.__name__, err))
        return -1


def main(params):
    """Train an autoregressive Transformer on SMILES data.

    After every epoch the model is evaluated on the validation set, a best
    checkpoint is saved when accuracy improves, 100 molecules are generated
    and scored for validity, uniqueness and KL divergence, and a regular
    checkpoint is saved.

    Args:
        params: namespace providing `seed`, `dump_path`, `exp_name`,
            `data_type`, `data_path`, `local_cpu`, `optimizer`,
            `load_path`, `epoch_size`, `max_steps`, `clip_grad_norm` and
            `print_after`. `params.ar` and `params.n_langs` are set here.
    """
    # setup random seeds
    set_seed(params.seed)
    params.ar = True

    # create the experiment directory if it doesn't exist
    exp_path = os.path.join(params.dump_path, params.exp_name)
    os.makedirs(exp_path, exist_ok=True)

    # create logger
    logger = create_logger(os.path.join(exp_path, 'train.log'), 0)
    logger.info("============ Initialized logger ============")
    logger.info("Random seed is {}".format(params.seed))
    logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment will be stored in %s\n" % exp_path)
    # Parenthesised so the whole command line is the %s argument.
    logger.info("Running command: %s" % ('python ' + ' '.join(sys.argv)))
    logger.info("")

    # load data
    data, loader = load_smiles_data(params)
    if params.data_type == 'ChEMBL':
        smiles_file = 'guacamol_v1_all.smiles'
    else:
        smiles_file = 'QM9_all.smiles'
    with open(os.path.join(params.data_path, smiles_file), 'r') as f:
        all_smiles_mols = f.readlines()
    train_data, val_data = data['train'], data['valid']
    dico = data['dico']
    logger.info('train_data len is {}'.format(len(train_data)))
    logger.info('val_data len is {}'.format(len(val_data)))

    # keep cycling through train_loader forever
    # stop when max iters is reached
    def rcycle(iterable):
        saved = []                 # in-memory cache of the first pass
        for element in iterable:
            yield element
            saved.append(element)
        while saved:
            random.shuffle(saved)  # reshuffle the cache before each pass
            for element in saved:
                yield element

    train_loader = rcycle(train_data.get_iterator(shuffle=True,
                                                  group_by_size=True,
                                                  n_sentences=-1))

    # extra param names for transformermodel
    params.n_langs = 1
    # build Transformer model
    model = TransformerModel(params, is_encoder=False, with_output=True)

    if params.local_cpu is False:
        model = model.cuda()
    opt = get_optimizer(model.parameters(), params.optimizer)
    # builtin float: the np.float alias was removed in NumPy 1.24
    scores = {'ppl': float('inf'), 'acc': 0}

    # When resuming, continue counting right after the reloaded iteration.
    # The same truthiness test guards both the reload and the offset, so an
    # empty load_path can no longer reach an undefined reloaded_iter.
    first_iter = 0
    if params.load_path:
        reloaded_iter, scores = load_model(params, model, opt, logger)
        first_iter = reloaded_iter + 1

    for total_iter, train_batch in enumerate(train_loader, start=first_iter):
        epoch = total_iter // params.epoch_size
        if total_iter == params.max_steps:
            logger.info("============ Done training ... ============")
            break
        elif total_iter % params.epoch_size == 0:
            logger.info("============ Starting epoch %i ... ============" % epoch)

        # one optimisation step
        model.train()
        opt.zero_grad()
        train_loss = calculate_loss(model, train_batch, params)
        train_loss.backward()
        if params.clip_grad_norm > 0:
            clip_grad_norm_(model.parameters(), params.clip_grad_norm)
        opt.step()
        if total_iter % params.print_after == 0:
            logger.info("Step {} ; Loss = {}".format(total_iter, train_loss))

        # Only the last iteration of each epoch runs evaluation.
        if not (total_iter > 0
                and total_iter % params.epoch_size == (params.epoch_size - 1)):
            continue

        # run eval step (calculate validation loss)
        model.eval()
        n_chars = 0
        xe_loss = 0
        n_valid = 0
        logger.info("============ Evaluating ... ============")
        for val_batch in val_data.get_iterator(shuffle=True):
            with torch.no_grad():
                val_scores, val_loss, val_y = calculate_loss(
                    model, val_batch, params, get_scores=True)
            # update stats
            n_chars += val_y.size(0)
            xe_loss += val_loss.item() * len(val_y)
            n_valid += (val_scores.max(1)[1] == val_y).sum().item()

        ppl = np.exp(xe_loss / n_chars)
        acc = 100. * n_valid / n_chars
        logger.info("Acc={}, PPL={}".format(acc, ppl))
        if acc > scores['acc']:
            scores['acc'] = acc
            scores['ppl'] = ppl
            save_model(params, data, model, opt, dico, logger, 'best_model',
                       epoch, total_iter, scores)
            logger.info('Saving new best_model {}'.format(epoch))
            logger.info("Best Acc={}, PPL={}".format(scores['acc'],
                                                     scores['ppl']))

        logger.info("============ Generating ... ============")
        number_samples = 100
        gen_smiles = generate_smiles(params, model, dico, number_samples)
        generator = ARMockGenerator(gen_smiles)

        validity_score = _benchmark_score(
            ValidityBenchmark, generator, logger,
            number_samples=number_samples)
        uniqueness_score = _benchmark_score(
            UniquenessBenchmark, generator, logger,
            number_samples=number_samples)
        kldiv_score = _benchmark_score(
            KLDivBenchmark, generator, logger,
            number_samples=number_samples, training_set=all_smiles_mols)
        logger.info('Validity Score={}, Uniqueness Score={}, KlDiv Score={}'.format(validity_score, uniqueness_score, kldiv_score))
        save_model(params, data, model, opt, dico, logger, 'model', epoch,
                   total_iter, {'ppl': ppl, 'acc': acc})
Esempio n. 7
0
def main():
    """Embed every line of an input file with a pretrained model.

    Command-line options are registered on a `parser` object defined
    outside this function. Each input line is split into SentencePiece
    pieces, truncated to `--max_words - 1` pieces, prefixed with the EOS
    index and fed to the model in batches of `--batch_size`. Index 0 of
    each model output is kept per sentence (presumably the first
    position's hidden state; confirm against the model) and the
    concatenated result is written to `--output` with `torch.save`.
    """
    parser.add_argument("--input", type=str, default="", help="input file")
    parser.add_argument("--model", type=str, default="", help="model path")
    parser.add_argument("--spm_model",
                        type=str,
                        default="",
                        help="spm model path")
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="batch size")
    parser.add_argument("--max_words", type=int, default=100, help="max words")
    parser.add_argument("--cuda", type=str, default="True", help="use cuda")
    parser.add_argument("--output", type=str, default="", help="output file")
    args = parser.parse_args()

    # Reload a pretrained model. Loading onto the CPU lets a checkpoint
    # saved from a GPU be opened on a CPU-only machine; the model is moved
    # to the GPU below when requested.
    reloaded = torch.load(args.model, map_location='cpu')
    params = AttrDict(reloaded['params'])

    # Reload the SPM model
    spm_model = spm.SentencePieceProcessor()
    spm_model.Load(args.spm_model)

    # cuda: compare against the literal instead of eval()-ing user input
    assert args.cuda in ["True", "False"], "--cuda must be True or False"
    args.cuda = args.cuda == "True"

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    # Strip only a leading 'module.' from each weight name; str.replace
    # would also rewrite keys that contain 'module.' elsewhere.
    prefix = 'module.'
    reloaded['model'] = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in reloaded['model'].items()
    )
    model.load_state_dict(reloaded['model'])
    model.eval()

    if args.cuda:
        model.cuda()

    # load sentences (SentencePiece operates on UTF-8 text)
    sentences = []
    with open(args.input, encoding='utf-8') as f:
        for line in f:
            pieces = spm_model.EncodeAsPieces(line.rstrip())
            # keep one slot free for the EOS index prepended below
            sentences.append(pieces[:args.max_words - 1])

    # encode sentences
    embs = []
    for i in range(0, len(sentences), args.batch_size):
        batch = sentences[i:i + args.batch_size]
        lengths = torch.LongTensor([len(s) + 1 for s in batch])
        bs, slen = len(batch), lengths.max().item()
        assert slen <= args.max_words

        # one column per sentence, padded to the longest one in the batch
        x = torch.LongTensor(slen, bs).fill_(params.pad_index)
        for k, pieces in enumerate(batch):
            sent = torch.LongTensor([params.eos_index] +
                                    [dico.index(w) for w in pieces])
            x[:len(sent), k] = sent

        if args.cuda:
            x = x.cuda()
            lengths = lengths.cuda()

        with torch.no_grad():
            embedding = model('fwd',
                              x=x,
                              lengths=lengths,
                              langs=None,
                              causal=False).contiguous()[0].cpu()

        embs.append(embedding)

    # save embeddings
    torch.save(torch.cat(embs, dim=0).squeeze(0), args.output)