def init_training(args):
    from functools import partial
    import pickle

    # Patch pickle so Python 2 checkpoints (byte strings) load under Python 3.
    pickle.load = partial(pickle.load, encoding="latin1")
    pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    # model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)
    vocab = torch.load(args.vocab, map_location=lambda storage, loc: storage, pickle_module=pickle)

    model = NMT(args, vocab)
    model.train()

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' % (args.uniform_init, args.uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-args.uniform_init, args.uniform_init)

    # Zero out the loss weight of <pad> so padding tokens do not contribute.
    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    nll_loss = nn.NLLLoss(weight=vocab_mask, reduction='sum')
    cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='sum')

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()
        nll_loss = nll_loss.cuda()
        cross_entropy_loss = cross_entropy_loss.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    return vocab, model, optimizer, nll_loss, cross_entropy_loss
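# Usage sketch (hypothetical argparse wiring; the actual CLI flags live
# elsewhere in this script). init_training expects `args` to provide at
# least `vocab`, `uniform_init`, `cuda`, and `lr`:
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--vocab', type=str, default='data/vocab.bin')
#   parser.add_argument('--uniform_init', type=float, default=0.1)
#   parser.add_argument('--cuda', action='store_true')
#   parser.add_argument('--lr', type=float, default=1e-3)
#   args = parser.parse_args()
#
#   vocab, model, optimizer, nll_loss, cross_entropy_loss = init_training(args)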
def test(args):
    test_data_src = read_corpus(args.test_src, source='src')
    test_data_tgt = read_corpus(args.test_tgt, source='tgt')
    test_data = list(zip(test_data_src, test_data_tgt))

    if args.load_model:
        print('load model from [%s]' % args.load_model)
        params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
        vocab = params['vocab']
        saved_args = params['args']
        state_dict = params['state_dict']
        model = NMT(saved_args, vocab)
        model.load_state_dict(state_dict)
    else:
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    hypotheses = decode(model, test_data)
    top_hypotheses = [hyps[0] for hyps in hypotheses]

    bleu_score = get_bleu([tgt for src, tgt in test_data], top_hypotheses)
    word_acc = get_acc([tgt for src, tgt in test_data], top_hypotheses, 'word_acc')
    sent_acc = get_acc([tgt for src, tgt in test_data], top_hypotheses, 'sent_acc')
    print('Corpus Level BLEU: %f, word level acc: %f, sentence level acc: %f' %
          (bleu_score, word_acc, sent_acc), file=sys.stderr)

    if args.save_to_file:
        print('save decoding results to %s' % args.save_to_file)
        with open(args.save_to_file, 'w') as f:
            for hyps in hypotheses:
                f.write(' '.join(hyps[0][1:-1]) + '\n')  # strip <s> and </s>

        if args.save_nbest:
            nbest_file = args.save_to_file + '.nbest'
            print('save nbest decoding results to %s' % nbest_file)
            with open(nbest_file, 'w') as f:
                for src_sent, tgt_sent, hyps in zip(test_data_src, test_data_tgt, hypotheses):
                    print('Source: %s' % ' '.join(src_sent), file=f)
                    print('Target: %s' % ' '.join(tgt_sent), file=f)
                    print('Hypotheses:', file=f)
                    for i, hyp in enumerate(hyps, 1):
                        print('[%d] %s' % (i, ' '.join(hyp)), file=f)
                    print('*' * 30, file=f)
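# `get_bleu` is called above but not defined in this section; a minimal
# sketch using NLTK's corpus_bleu, assuming references and hypotheses are
# token lists that still carry the <s>/</s> markers (stripped here):
#
#   from nltk.translate.bleu_score import corpus_bleu
#
#   def get_bleu(references, hypotheses):
#       return corpus_bleu([[ref[1:-1]] for ref in references],
#                          [hyp[1:-1] for hyp in hypotheses])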
def sample(args):
    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')
    # Materialize the pairs: zip() is a single-use iterator in Python 3.
    train_data = list(zip(train_data_src, train_data_tgt))

    if args.load_model:
        print('load model from [%s]' % args.load_model)
        params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
        vocab = params['vocab']
        saved_args = params['args']
        state_dict = params['state_dict']
        model = NMT(saved_args, vocab)
        model.load_state_dict(state_dict)
    else:
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    print('begin sampling')

    check_every = 10
    train_iter = cum_samples = 0
    train_time = time.time()
    for src_sents, tgt_sents in data_iter(train_data, batch_size=args.batch_size):
        train_iter += 1
        samples = model.sample(src_sents, sample_size=args.sample_size, to_word=True)
        cum_samples += sum(len(sample) for sample in samples)

        if train_iter % check_every == 0:
            elapsed = time.time() - train_time
            print('sampling speed: %d/s' % (cum_samples / elapsed))
            cum_samples = 0
            train_time = time.time()

        for i, tgt_sent in enumerate(tgt_sents):
            print('*' * 80)
            print('target: ' + ' '.join(tgt_sent))
            tgt_samples = samples[i]
            print('samples:')
            for sid, sample in enumerate(tgt_samples, 1):
                print('[%d] %s' % (sid, ' '.join(sample[1:-1])))
            print('*' * 80)
def interactive(args):
    assert args.load_model, 'You have to specify a pre-trained model'
    print('load model from [%s]' % args.load_model)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    saved_args = params['args']
    state_dict = params['state_dict']

    model = NMT(saved_args, vocab)
    model.load_state_dict(state_dict)
    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    while True:
        src_sent = input('Source Sentence: ')
        src_sent = src_sent.strip().split(' ')
        hyps = model.translate(src_sent)
        for i, hyp in enumerate(hyps, 1):
            print('Hypothesis #%d: %s' % (i, ' '.join(hyp)))
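# Example interactive session (hypothetical entry point, checkpoint path,
# and input; hypotheses typically include the model's <s>/</s> markers):
#
#   $ python nmt.py --mode interactive --load_model model.bin --cuda
#   Source Sentence: das ist ein Test
#   Hypothesis #1: <s> this is a test </s>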
def train():
    text = Text(config.src_corpus, config.tar_corpus)
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_loader = DataLoader(dataset=train_data, batch_size=config.batch_size,
                              shuffle=True, collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data, batch_size=config.dev_batch_size,
                            shuffle=True, collate_fn=utils.get_batch)

    parser = OptionParser()
    parser.add_option("--embed_size", dest="embed_size", default=config.embed_size)
    parser.add_option("--hidden_size", dest="hidden_size", default=config.hidden_size)
    parser.add_option("--window_size_d", dest="window_size_d", default=config.window_size_d)
    parser.add_option("--encoder_layer", dest="encoder_layer", default=config.encoder_layer)
    parser.add_option("--decoder_layers", dest="decoder_layers", default=config.decoder_layers)
    parser.add_option("--dropout_rate", dest="dropout_rate", default=config.dropout_rate)
    (options, args) = parser.parse_args()

    device = torch.device("cuda:0" if config.cuda else "cpu")
    # To resume training, load a checkpoint instead of building a fresh model:
    # model = NMT.load(model_path)
    model = NMT(text, options, device)
    model = model.to(device)  # moves to cuda:0 when config.cuda is set
    model.train()

    optimizer = Optim(torch.optim.Adam(model.parameters()))
    # Warm-up variant:
    # optimizer = Optim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9),
    #                   config.hidden_size, config.warm_up_step)

    epoch = 0
    valid_num = 1
    hist_valid_ppl = []

    print("begin training!")
    while True:
        epoch += 1
        max_iter = int(math.ceil(len(train_data) / config.batch_size))
        with tqdm(total=max_iter, desc="train") as pbar:
            for src_sents, tar_sents, tar_words_num_to_predict in train_loader:
                optimizer.zero_grad()
                batch_size = len(src_sents)
                # model() returns per-sentence log-likelihoods; negate for NLL.
                now_loss = -model(src_sents, tar_sents)
                now_loss = now_loss.sum()
                loss = now_loss / batch_size
                loss.backward()
                _ = torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad)
                optimizer.step_and_updata_lr()
                pbar.set_postfix({"epoch": epoch,
                                  "avg_loss": loss.item(),
                                  "ppl": math.exp(now_loss.item() / tar_words_num_to_predict),
                                  "lr": optimizer.lr})
                pbar.update(1)

        if epoch % config.valid_iter == 0:
            # Decay the learning rate every fifth validation round.
            if valid_num % 5 == 0:
                valid_num = 0
                optimizer.updata_lr()
            valid_num += 1

            print("now begin validation ...", file=sys.stderr)
            eval_ppl = evaluate_ppl(model, dev_data, dev_loader)
            print("validation ppl %.2f" % eval_ppl, file=sys.stderr)
            is_best = len(hist_valid_ppl) == 0 or eval_ppl < min(hist_valid_ppl)
            if is_best:
                print("current model is the best! save to [%s]" % config.model_save_path,
                      file=sys.stderr)
                hist_valid_ppl.append(eval_ppl)
                model.save(os.path.join(
                    config.model_save_path,
                    f"02.08_window35drop0.2_{epoch}_{eval_ppl}_checkpoint.pth"))
                torch.save(optimizer.optimizer.state_dict(), os.path.join(
                    config.model_save_path,
                    f"02.08_window35drop0.2_{epoch}_{eval_ppl}_optimizer.optim"))

        if epoch == config.max_epoch:
            print("reach the maximum number of epochs!", file=sys.stderr)
            return
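# `evaluate_ppl` is called above but not defined in this section; a minimal
# sketch, assuming `model(src, tar)` returns per-sentence log-likelihoods as
# in the training loop and `dev_loader` yields the same triples:
#
#   def evaluate_ppl(model, dev_data, dev_loader):
#       was_training = model.training
#       model.eval()
#       cum_loss = cum_tgt_words = 0.0
#       with torch.no_grad():
#           for src_sents, tar_sents, tar_words_num_to_predict in dev_loader:
#               cum_loss += -model(src_sents, tar_sents).sum().item()
#               cum_tgt_words += tar_words_num_to_predict
#       if was_training:
#           model.train()
#       return math.exp(cum_loss / cum_tgt_words)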
def compute_lm_prob(args):
    """Given source-target sentence pairs, compute per-word log-likelihood."""
    test_data_src = read_corpus(args.test_src, source='src')
    test_data_tgt = read_corpus(args.test_tgt, source='tgt')
    test_data = zip(test_data_src, test_data_tgt)

    if args.load_model:
        print('load model from [%s]' % args.load_model)
        params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
        vocab = params['vocab']
        saved_args = params['args']
        state_dict = params['state_dict']
        model = NMT(saved_args, vocab)
        model.load_state_dict(state_dict)
    else:
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    with open(args.save_to_file, 'w') as f:
        for src_sent, tgt_sent in test_data:
            src_sents = [src_sent]
            tgt_sents = [tgt_sent]

            batch_size = len(src_sents)
            src_sents_len = [len(s) for s in src_sents]
            pred_tgt_word_nums = [len(s[1:]) for s in tgt_sents]  # omitting leading `<s>`

            # (sent_len, batch_size)
            src_sents_var = to_input_variable(src_sents, model.vocab.src, cuda=args.cuda, is_test=True)
            tgt_sents_var = to_input_variable(tgt_sents, model.vocab.tgt, cuda=args.cuda, is_test=True)

            # (tgt_sent_len, batch_size, tgt_vocab_size)
            scores = model(src_sents_var, src_sents_len, tgt_sents_var[:-1])
            # (tgt_sent_len * batch_size, tgt_vocab_size)
            log_scores = F.log_softmax(scores.view(-1, scores.size(2)), dim=-1)

            # remove leading <s> in tgt sent, which is not used as the target
            # (batch_size * tgt_sent_len)
            flattened_tgt_sents = tgt_sents_var[1:].view(-1)

            # (batch_size * tgt_sent_len)
            tgt_log_scores = torch.gather(log_scores, 1,
                                          flattened_tgt_sents.unsqueeze(1)).squeeze(1)
            # zero out positions holding the <pad> symbol (index 0)
            tgt_log_scores = tgt_log_scores * (1. - torch.eq(flattened_tgt_sents, 0).float())

            # (tgt_sent_len, batch_size)
            tgt_log_scores = tgt_log_scores.view(-1, batch_size)  # .permute(1, 0)
            # (batch_size); keep the batch dimension so indexing below works
            tgt_sent_scores = tgt_log_scores.sum(dim=0)
            tgt_sent_word_scores = [tgt_sent_scores[i].item() / pred_tgt_word_nums[i]
                                    for i in range(batch_size)]

            for src_sent, tgt_sent, score in zip(src_sents, tgt_sents, tgt_sent_word_scores):
                f.write('%s ||| %s ||| %f\n' % (' '.join(src_sent), ' '.join(tgt_sent), score))
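# Example line in the output file (hypothetical sentences; the target side
# keeps the <s>/</s> markers added by read_corpus):
#
#   das ist ein Test ||| <s> this is a test </s> ||| -1.234500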