Esempio n. 1
0
def run(args):
    """Evaluate a trained Transformer checkpoint on the Multi30k test split.

    Builds the source/target fields via ``build_dataset``, loads the Multi30k
    test split (keeping only examples whose source and target lengths are both
    at most ``args.max_seq_length``), restores the model weights from
    ``args.model`` and hands everything to ``test``.

    Args:
        args: parsed options. Read here: ``src_language``, ``tgt_language``,
            ``max_seq_length``, ``batch_size``, ``device``, ``dropout`` and
            ``model`` (path to a saved state dict). ``args`` is also
            forwarded to ``build_dataset`` and ``test``.
    """
    writer = SummaryWriter()
    try:
        src, tgt, _, _ = build_dataset(args)

        print('Loading test data split.')
        _, _, test_data = datasets.Multi30k.splits(
            exts=(build_file_extension(args.src_language),
                  build_file_extension(args.tgt_language)),
            fields=(('src', src), ('tgt', tgt)),
            filter_pred=lambda x: len(vars(x)['src']) <= args.max_seq_length and
            len(vars(x)['tgt']) <= args.max_seq_length)
        print('Finished loading test data split.')

        src_vocab_size = len(src.vocab.itos)
        tgt_vocab_size = len(tgt.vocab.itos)

        # Only the test split is evaluated, so build one iterator for it
        # instead of feeding placeholder datasets to Iterator.splits.
        # train=False mirrors how Iterator.splits configures every split
        # after the first one.
        test_iterator = data.Iterator(
            test_data,
            batch_size=args.batch_size,
            sort_key=lambda x: len(x.src),
            train=False)

        print('Instantiating model...')
        device = args.device
        model = Transformer(src_vocab_size,
                            tgt_vocab_size,
                            device,
                            p_dropout=args.dropout)
        model = model.to(device)
        # map_location restores tensors directly onto ``device`` so that a
        # checkpoint saved on another device (e.g. a GPU) still loads.
        model.load_state_dict(torch.load(args.model, map_location=device))
        print('Model instantiated!')

        print('Starting testing...')
        test(model, test_iterator, src.vocab, tgt.vocab, args, writer)
        print('Finished testing.')
    finally:
        # Flush any pending TensorBoard events even if testing fails.
        writer.close()
Esempio n. 2
0
def do_predict(model_dir="./checkpoint/transformer", checkpoint="model_9.pt",
               max_len=32):
    """Translate the Multi30k test set with a trained Transformer and dump results.

    Builds a Transformer with fixed hyper-parameters, loads
    ``model_dir/checkpoint`` into it, translates every test example one at a
    time with ``translate_tokens`` and writes the predicted and reference
    sentences (tokens joined by spaces) as JSON lists to ``pre.json`` and
    ``gth.json`` inside ``model_dir``.

    Args:
        model_dir: directory holding the checkpoint; the JSON outputs are
            written there as well.
        checkpoint: checkpoint file name inside ``model_dir``.
        max_len: maximum output length passed to ``translate_tokens``.
    """
    _, _, test_iterator, SRC, TGT = prepare_data_multi30k()
    src_pad_idx = SRC.vocab.stoi[SRC.pad_token]
    tgt_pad_idx = TGT.vocab.stoi[TGT.pad_token]
    src_vocab_size = len(SRC.vocab)
    tgt_vocab_size = len(TGT.vocab)

    model = Transformer(n_src_vocab=src_vocab_size,
                        n_trg_vocab=tgt_vocab_size,
                        src_pad_idx=src_pad_idx,
                        trg_pad_idx=tgt_pad_idx,
                        d_word_vec=256,
                        d_model=256,
                        d_inner=512,
                        n_layer=3,
                        n_head=8,
                        dropout=0.1,
                        n_position=200)
    model.cuda()

    model_path = os.path.join(model_dir, checkpoint)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.eval()

    pre_sents = []
    gth_sents = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_iterator):
            if batch_idx % 10 == 0:
                print("[TIME] --- time: {} --- [TIME]".format(time.ctime()))
            # NOTE(review): rows of src_seq / tgt_seq are treated as individual
            # examples below, so both are assumed batch-first
            # ([batch_size, seq_len]) — confirm against prepare_data_multi30k.
            # The second element of each pair (lengths) is not needed here.
            src_seq, _ = batch.src
            tgt_seq, _ = batch.trg

            pre_tokens = [
                translate_tokens(src_seq[i], SRC, TGT, model, max_len=max_len)
                for i in range(src_seq.size(0))
            ]

            for tokens, gth_ids in zip(pre_tokens, tgt_seq.cpu().tolist()):
                # Padding only makes the batch rectangular; it is not part of
                # the reference sentence, so drop it before detokenizing.
                gth = [TGT.vocab.itos[tok_id] for tok_id in gth_ids
                       if tok_id != tgt_pad_idx]
                pre_sents.append(" ".join(tokens))
                gth_sents.append(" ".join(gth))

    pre_path = os.path.join(model_dir, "pre.json")
    gth_path = os.path.join(model_dir, "gth.json")
    with open(pre_path, "w", encoding="utf-8") as writer:
        json.dump(pre_sents, writer, ensure_ascii=False, indent=4)
    with open(gth_path, "w", encoding="utf-8") as writer:
        json.dump(gth_sents, writer, ensure_ascii=False, indent=4)
    def __init__(self, src_vocab, tgt_vocab, src_vocab_size, tgt_vocab_size,
                 args):
        """Build a Transformer, restore its weights and switch it to eval mode.

        Args:
            src_vocab: source vocabulary, stored as ``self.src_vocab``.
            tgt_vocab: target vocabulary, stored as ``self.tgt_vocab``.
            src_vocab_size: source vocabulary size passed to ``Transformer``.
            tgt_vocab_size: target vocabulary size passed to ``Transformer``.
            args: options providing ``max_seq_length``, ``device``,
                ``beam_size`` and ``model`` (path to a saved state dict).
        """
        self.max_seq_length = args.max_seq_length
        self.device = args.device
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.beam_size = args.beam_size

        model = Transformer(src_vocab_size, tgt_vocab_size, args.device)
        # map_location restores tensors straight onto the target device, so a
        # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(args.model, map_location=args.device))
        model = model.to(args.device)
        self.model = model
        self.model.eval()
def run(args):
    """Train a Transformer, saving a checkpoint and validating after each epoch.

    The model is either restored from ``args.checkpoint`` or, when that is
    ``None``, its weight matrices are Xavier-initialised. After every epoch
    the state dict is written to
    ``models/model_<epoch index>_<train accuracy>.pth``.

    Args:
        args: parsed options. Read here: ``device``, ``dropout``,
            ``checkpoint``, ``lr`` and ``epochs``. ``args`` is also forwarded
            to ``build_dataset``, ``train`` and ``validate``.
    """
    import os  # only needed to create the checkpoint directory

    writer = SummaryWriter()
    try:
        src, tgt, train_iterator, val_iterator = build_dataset(args)

        src_vocab_size = len(src.vocab.itos)
        tgt_vocab_size = len(tgt.vocab.itos)

        print('Instantiating model...')
        device = args.device
        model = Transformer(src_vocab_size,
                            tgt_vocab_size,
                            device,
                            p_dropout=args.dropout)
        model = model.to(device)

        if args.checkpoint is not None:
            # map_location lets a checkpoint saved on another device resume
            # training on ``device``.
            model.load_state_dict(
                torch.load(args.checkpoint, map_location=device))
        else:
            # Xavier-initialise every parameter with more than one dimension;
            # 1-D parameters keep whatever initialisation the Transformer
            # constructor gave them.
            for p in model.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

        print('Model instantiated!')

        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.98),
                               eps=1e-9)

        # torch.save does not create missing parent directories.
        os.makedirs('models', exist_ok=True)

        print('Starting training...')
        for epoch in range(args.epochs):
            acc = train(model, epoch + 1, train_iterator, optimizer, src.vocab,
                        tgt.vocab, args, writer)
            # The file name uses the 0-based epoch index, while train/validate
            # receive the 1-based epoch number.
            model_file = 'models/model_' + str(epoch) + '_' + str(acc) + '.pth'
            torch.save(model.state_dict(), model_file)
            print('Saved model to ' + model_file)
            validate(model, epoch + 1, val_iterator, src.vocab, tgt.vocab, args,
                     writer)
        print('Finished training.')
    finally:
        # Flush any pending TensorBoard events even if training fails.
        writer.close()
Esempio n. 5
0
"""
import time
import torch
from transformer.transformer import Transformer


if __name__ == '__main__':
    # Extract the model from the .tar checkpoint and re-save it as a .pt
    # state-dict file, then verify that the new file loads into a fresh
    # Transformer.
    checkpoint_path = 'BEST_Model.tar'
    print('loading {}...'.format(checkpoint_path))
    start = time.time()
    # SECURITY NOTE: the .tar holds a whole model object under 'model', so
    # loading it unpickles arbitrary objects — only use trusted checkpoints.
    # NOTE(review): on PyTorch versions whose torch.load defaults to
    # weights_only=True, this full-object load may need weights_only=False.
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine and
    # makes the re-saved state dict device-independent.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    print('elapsed {} sec'.format(time.time() - start))
    model = checkpoint['model']
    print(type(model))

    filename = 'reading_comprehension.pt'
    print('saving {}...'.format(filename))
    start = time.time()
    torch.save(model.state_dict(), filename)
    print('elapsed {} sec'.format(time.time() - start))

    print('loading {}...'.format(filename))
    start = time.time()
    model = Transformer()
    model.load_state_dict(torch.load(filename))
    print('elapsed {} sec'.format(time.time() - start))