import torch
from torch.utils.tensorboard import SummaryWriter  # assumed import; may be tensorboardX in the original
from torchtext import data, datasets                # assumed import; torchtext.legacy on newer versions

# build_dataset, build_file_extension, Transformer and test are project-local
# helpers whose imports are omitted here.


def run(args):
    writer = SummaryWriter()
    src, tgt, _, _ = build_dataset(args)

    print('Loading test data split.')
    # Keep all three splits explicitly instead of reusing `_`; only the test
    # split is consumed below.
    train_gen, val_gen, test_gen = datasets.Multi30k.splits(
        exts=(build_file_extension(args.src_language),
              build_file_extension(args.tgt_language)),
        fields=(('src', src), ('tgt', tgt)),
        filter_pred=lambda x: len(vars(x)['src']) <= args.max_seq_length
        and len(vars(x)['tgt']) <= args.max_seq_length)
    print('Finished loading test data split.')

    src_vocab_size = len(src.vocab.itos)
    tgt_vocab_size = len(tgt.vocab.itos)

    _, _, test_iterator = data.Iterator.splits(
        (train_gen, val_gen, test_gen),
        sort_key=lambda x: len(x.src),
        batch_sizes=(args.batch_size, args.batch_size, args.batch_size))

    print('Instantiating model...')
    device = args.device
    model = Transformer(src_vocab_size, tgt_vocab_size, device,
                        p_dropout=args.dropout)
    model = model.to(device)
    model.load_state_dict(torch.load(args.model))
    print('Model instantiated!')

    print('Starting testing...')
    test(model, test_iterator, src.vocab, tgt.vocab, args, writer)
    print('Finished testing.')
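# Hypothetical invocation sketch: the fields below are exactly the attributes
# run() reads from `args`; the values are illustrative only.
import argparse

args = argparse.Namespace(
    src_language='de', tgt_language='en',
    max_seq_length=100, batch_size=128,
    device='cuda', dropout=0.1,
    model='models/model_0.pth',  # path to a trained checkpoint (example)
)
run(args)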
import json
import os
import time

import torch


def do_predict():
    train_iterator, valid_iterator, test_iterator, SRC, TGT = prepare_data_multi30k()
    src_pad_idx = SRC.vocab.stoi[SRC.pad_token]
    tgt_pad_idx = TGT.vocab.stoi[TGT.pad_token]
    src_vocab_size = len(SRC.vocab)
    tgt_vocab_size = len(TGT.vocab)

    model = Transformer(n_src_vocab=src_vocab_size, n_trg_vocab=tgt_vocab_size,
                        src_pad_idx=src_pad_idx, trg_pad_idx=tgt_pad_idx,
                        d_word_vec=256, d_model=256, d_inner=512,
                        n_layer=3, n_head=8, dropout=0.1, n_position=200)
    model.cuda()

    model_dir = "./checkpoint/transformer"
    model_path = os.path.join(model_dir, "model_9.pt")
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.eval()

    pre_sents = []
    gth_sents = []
    for idx, batch in enumerate(test_iterator):
        if idx % 10 == 0:
            print("[TIME] --- time: {} --- [TIME]".format(time.ctime(time.time())))

        # src_seq, tgt_seq: [batch_size, seq_len] -- the code below indexes
        # dimension 0 as the batch dimension
        src_seq, src_len = batch.src
        tgt_seq, tgt_len = batch.trg
        batch_size = src_seq.size(0)

        pre_tokens = []
        with torch.no_grad():
            # use a separate loop variable so the outer `idx` is not shadowed
            for i in range(batch_size):
                tokens = translate_tokens(src_seq[i], SRC, TGT, model, max_len=32)
                pre_tokens.append(tokens)

        # gth_tokens: [batch_size, seq_len] lists of token ids
        gth_tokens = tgt_seq.cpu().detach().numpy().tolist()
        for tokens, gth_ids in zip(pre_tokens, gth_tokens):
            gth = [TGT.vocab.itos[token_id] for token_id in gth_ids]
            pre_sents.append(" ".join(tokens))
            gth_sents.append(" ".join(gth))

    pre_path = os.path.join(model_dir, "pre.json")
    gth_path = os.path.join(model_dir, "gth.json")
    with open(pre_path, "w", encoding="utf-8") as writer:
        json.dump(pre_sents, writer, ensure_ascii=False, indent=4)
    with open(gth_path, "w", encoding="utf-8") as writer:
        json.dump(gth_sents, writer, ensure_ascii=False, indent=4)
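# `translate_tokens` is defined elsewhere in the repo. Below is a minimal
# greedy-decoding sketch of what such a helper could look like, assuming the
# model's forward takes (src_seq, trg_seq) and returns per-position vocabulary
# logits, and that TGT was built with init_token/eos_token set -- adapt to the
# actual Transformer API:
def translate_tokens(src_ids, SRC, TGT, model, max_len=32):
    device = next(model.parameters()).device
    src = src_ids.unsqueeze(0).to(device)                # [1, src_len]
    bos = TGT.vocab.stoi[TGT.init_token]
    eos = TGT.vocab.stoi[TGT.eos_token]
    tgt = torch.full((1, 1), bos, dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in range(max_len):
            logits = model(src, tgt).view(1, tgt.size(1), -1)
            next_id = logits[0, -1].argmax().item()      # greedy choice
            if next_id == eos:
                break
            tgt = torch.cat(
                [tgt, torch.tensor([[next_id]], device=device)], dim=1)
    # drop the leading <bos> token and map ids back to strings
    return [TGT.vocab.itos[i] for i in tgt[0, 1:].tolist()]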
def __init__(self, src_vocab, tgt_vocab, src_vocab_size, tgt_vocab_size, args):
    self.max_seq_length = args.max_seq_length
    self.device = args.device
    self.src_vocab = src_vocab
    self.tgt_vocab = tgt_vocab
    self.beam_size = args.beam_size

    model = Transformer(src_vocab_size, tgt_vocab_size, args.device)
    model.load_state_dict(torch.load(args.model))
    model = model.to(args.device)
    self.model = model
    self.model.eval()
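# The enclosing class is not shown here; assuming it is a beam-search
# translator wrapper (hypothetical name `Translator`), construction would
# mirror the signature above:
#
#   translator = Translator(src.vocab, tgt.vocab,
#                           len(src.vocab.itos), len(tgt.vocab.itos), args)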
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter  # assumed import; may be tensorboardX in the original


def run(args):
    writer = SummaryWriter()
    src, tgt, train_iterator, val_iterator = build_dataset(args)
    src_vocab_size = len(src.vocab.itos)
    tgt_vocab_size = len(tgt.vocab.itos)

    print('Instantiating model...')
    device = args.device
    model = Transformer(src_vocab_size, tgt_vocab_size, device,
                        p_dropout=args.dropout)
    model = model.to(device)
    if args.checkpoint is not None:
        model.load_state_dict(torch.load(args.checkpoint))
    else:
        # Xavier-initialise all weight matrices when training from scratch
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    print('Model instantiated!')

    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(0.9, 0.98), eps=1e-9)

    print('Starting training...')
    for epoch in range(args.epochs):
        acc = train(model, epoch + 1, train_iterator, optimizer,
                    src.vocab, tgt.vocab, args, writer)
        model_file = 'models/model_' + str(epoch) + '_' + str(acc) + '.pth'
        torch.save(model.state_dict(), model_file)
        print('Saved model to ' + model_file)
        validate(model, epoch + 1, val_iterator, src.vocab, tgt.vocab,
                 args, writer)
    print('Finished training.')
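# The Adam hyperparameters above (betas=(0.9, 0.98), eps=1e-9) are the ones
# from "Attention Is All You Need", where they are paired with a warmup
# learning-rate schedule. A sketch of that schedule (hypothetical -- the code
# above keeps a fixed args.lr); set the optimizer's base lr to 1.0 so the
# lambda's value is the effective learning rate:
from torch.optim.lr_scheduler import LambdaLR

def noam_schedule(optimizer, d_model=512, warmup_steps=4000):
    def lr_lambda(step):
        step = max(step, 1)  # avoid 0 ** -0.5 on the scheduler's first call
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return LambdaLR(optimizer, lr_lambda)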
""" import time import torch from transformer.transformer import Transformer if __name__ == '__main__': ''' 从tar中提取模型 整理成pt文件 ''' checkpoint = 'BEST_Model.tar' print('loading {}...'.format(checkpoint)) start = time.time() checkpoint = torch.load(checkpoint) print('elapsed {} sec'.format(time.time() - start)) model = checkpoint['model'] print(type(model)) filename = 'reading_comprehension.pt' print('saving {}...'.format(filename)) start = time.time() torch.save(model.state_dict(), filename) print('elapsed {} sec'.format(time.time() - start)) print('loading {}...'.format(filename)) start = time.time() model = Transformer() model.load_state_dict(torch.load(filename)) print('elapsed {} sec'.format(time.time() - start))