def train_and_eval(model, corpus, optimizer, criterion, params, args, save_path):
    """Train for ``args["epochs"]`` epochs, checkpoint the best model, then test.

    Each epoch runs one pass of ``train`` followed by a validation pass.
    The model with the lowest validation loss seen so far is saved via
    ``model_save``.  When the optimizer is SGD and validation loss stops
    improving (non-monotonic trigger), training switches to ASGD; while on
    ASGD, validation is done with the averaged ('ax') weights.  After the
    loop (or Ctrl-C), the best checkpoint is reloaded and evaluated on the
    test set.

    Args:
        model: the network being trained (mutated in place).
        corpus: dataset object exposing ``.valid`` and ``.test`` tensors.
        optimizer: initial optimizer; may be replaced by ``torch.optim.ASGD``.
        criterion: loss module; may be replaced by the checkpointed one.
        params: parameter collection forwarded to ``train``.
        args: dict of hyperparameters (batch sizes, epochs, lr, nonmono, ...).
        save_path: filesystem path for ``model_save``/``model_load``.
    """
    val_data = batchify(corpus.valid, args["eval_batch_size"], args)
    test_data = batchify(corpus.test, args["test_batch_size"], args)
    best_val_loss = []
    # Sentinel "infinity": any real validation loss beats this on epoch 1.
    stored_loss = 100000000
    try:
        for epoch in range(1, args["epochs"] + 1):
            epoch_start_time = time.time()
            train(model, corpus, optimizer, criterion, params, epoch, args)
            if 't0' in optimizer.param_groups[0]:
                # ASGD is active ('t0' only exists in ASGD param groups).
                # Temporarily swap the averaged weights ('ax') into the
                # model for validation, keeping the live weights in `tmp`.
                tmp = {}
                for prm in model.parameters():
                    tmp[prm] = prm.data.clone()
                    prm.data = optimizer.state[prm]['ax'].clone()
                # NOTE(review): unlike the SGD branch below, no batch size
                # is passed here — presumably `evaluate` has a matching
                # default; confirm against its signature.
                val_loss2 = evaluate(model, criterion, args, val_data)
                print('-' * 89)
                print(
                    '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                    'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                        epoch, (time.time() - epoch_start_time), val_loss2,
                        math.exp(val_loss2), val_loss2 / math.log(2)))
                print('-' * 89)
                if val_loss2 < stored_loss:
                    # New best averaged-weight validation loss: checkpoint
                    # while the averaged weights are still swapped in.
                    model_save(save_path, model, criterion, optimizer)
                    print('Saving Averaged!')
                    stored_loss = val_loss2
                # Restore the live (non-averaged) weights to keep training.
                for prm in model.parameters():
                    prm.data = tmp[prm].clone()
            else:
                val_loss = evaluate(model, criterion, args, val_data,
                                    args["eval_batch_size"])
                print('-' * 89)
                print(
                    '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                    'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                        epoch, (time.time() - epoch_start_time), val_loss,
                        math.exp(val_loss), val_loss / math.log(2)))
                print('-' * 89)
                if val_loss < stored_loss:
                    model_save(save_path, model, criterion, optimizer)
                    print('Saving model (new best validation)')
                    stored_loss = val_loss
                # Non-monotonic trigger: if validation loss has not improved
                # over the best of the runs preceding the last `nonmono`
                # epochs, switch SGD -> ASGD (t0=0 starts averaging now).
                if args["optimizer"] == 'sgd' and 't0' not in optimizer.param_groups[
                        0] and (len(best_val_loss) > args["nonmono"]
                                and val_loss > min(
                                    best_val_loss[:-args["nonmono"]])):
                    print('Switching to ASGD')
                    optimizer = torch.optim.ASGD(model.parameters(),
                                                 lr=args["lr"],
                                                 t0=0,
                                                 lambd=0.,
                                                 weight_decay=args["wdecay"])
                # Scheduled LR decay at the epochs listed in args["when"];
                # snapshot the model before changing the learning rate.
                if "when" in args and epoch in args["when"]:
                    print('Saving model before learning rate decreased')
                    model_save('{}.e{}'.format(save_path, epoch), model,
                               criterion, optimizer)
                    print('Dividing learning rate by 10')
                    optimizer.param_groups[0]['lr'] /= 10.
                best_val_loss.append(val_loss)
    except KeyboardInterrupt:
        # Allow Ctrl-C to end training early but still run the test pass.
        print('-' * 89)
        print('Exiting from training early')
    # Load the best saved model (if any checkpoint was written); note this
    # rebinds the local `criterion`/`params` to the checkpointed versions.
    if os.path.exists(save_path):
        model_state_dict, criterion, params = model_load(save_path)
        model.load_state_dict(model_state_dict)
    # Run on test data.
    test_loss = evaluate(model, criterion, args, test_data,
                         args["test_batch_size"])
    print('=' * 89)
    print(
        '| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'
        .format(test_loss, math.exp(test_loss), test_loss / math.log(2)))
    print('=' * 89)
import torch

from data_utils import load_hierarchical_corpus
from main import args, data_hierarchical_path, model_type_save_path, model_type_to_subtype_save_path, model_subtype_to_word_save_path, model_hierarchical_save_path
from rnn_model.build_model import get_model as get_rnn_model
from double_input_rnn_model.build_model import get_model as get_double_input_rnn_model
from hierarchical_model.build_model import get_model as get_hierarchical_model
from rnn_model.utils import model_load
from beam_search import beam_search, idx2word

# Pick the torch device string based on the shared CLI/config flags.
device = "cuda:0" if args["cuda"] else "cpu"

# Pure inference: no gradients are needed anywhere below.
with torch.no_grad():
    corpus_hierarchical = load_hierarchical_corpus(data_hierarchical_path)

    # Assemble the three sub-models, then compose them hierarchically.
    model_type, _, _ = get_rnn_model(corpus_hierarchical, args)
    model_type_to_subtype, _, _ = get_double_input_rnn_model(corpus_hierarchical, args)
    model_subtype_to_word, _, _ = get_double_input_rnn_model(corpus_hierarchical, args)
    model_hierarchical, _, _ = get_hierarchical_model(
        corpus_hierarchical,
        model_subtype_to_word,
        model_type_to_subtype,
        model_type,
        args,
    )

    # Restore the trained weights for the composed model only.
    model_hierarchical_state_dict, _, _ = model_load(model_hierarchical_save_path, device)
    model_hierarchical.load_state_dict(model_hierarchical_state_dict)

    initial_sentence = 'season chicken with salt and pepper .'
    print()
    print("initial_sentence : ", initial_sentence.replace('_', ' '))
    print()
    initial_sentence = initial_sentence.split()

    # Decode with beam search from a fresh hidden state (batch size 1);
    # the last beam entry is the highest-scoring hypothesis.
    hidden_hierarchical = model_hierarchical.init_hidden(1)
    best_hypothesis = beam_search(
        model_hierarchical, corpus_hierarchical, hidden_hierarchical, initial_sentence
    )[-1]
    tokens = [idx2word(idx, corpus_hierarchical) for idx in best_hypothesis]
    # Undo the underscore-joining used for multi-word vocabulary entries.
    search = ' '.join(tokens).replace('_', ' ')
    print(search)