예제 #1
0
def _safe_exp(loss):
    """Return ``math.exp(loss)``, or ``inf`` when the result would overflow.

    Perplexity is only printed, so a diverged run (loss above ~709) should log
    ``inf`` rather than abort training with ``OverflowError``.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float('inf')


def _report_epoch(epoch, epoch_start_time, loss):
    """Print the end-of-epoch validation banner (loss, perplexity, bits per char).

    Args:
        epoch: epoch number being reported.
        epoch_start_time: ``time.time()`` value taken when the epoch started.
        loss: validation loss returned by ``evaluate``.
    """
    print('-' * 89)
    print(
        '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
            epoch, (time.time() - epoch_start_time), loss,
            _safe_exp(loss), loss / math.log(2)))
    print('-' * 89)


def _is_non_monotone(history, loss, nonmono):
    """Return True if `loss` is worse than the best loss recorded before the last `nonmono` entries.

    ``history[:len(history) - nonmono]`` equals ``history[:-nonmono]`` for
    ``nonmono > 0``, but keeps the whole history for ``nonmono == 0`` instead
    of producing an empty slice (``min([])`` would raise ``ValueError``).
    """
    if len(history) <= nonmono:
        return False
    return loss > min(history[:len(history) - nonmono])


def train_and_eval(model, corpus, optimizer, criterion, params, args,
                   save_path):
    """Train `model`, checkpoint the best validation model, then evaluate it on test data.

    Every epoch runs ``train(...)`` and then validates:

    * If the optimizer's first param group has a ``'t0'`` key (as
      ``torch.optim.ASGD`` groups do), the averaged weights
      ``optimizer.state[p]['ax']`` are swapped in for validation and
      checkpointing, and the raw weights are restored afterwards.
    * Otherwise the current weights are validated. With
      ``args["optimizer"] == 'sgd'`` the optimizer is replaced by ASGD once the
      validation loss is worse than the best loss recorded before the last
      ``args["nonmono"]`` epochs. The learning rate is divided by 10 (after
      saving a ``'<save_path>.e<epoch>'`` snapshot) at every epoch listed in
      ``args["when"]``. This schedule is only applied in this branch.

    A checkpoint is written to `save_path` whenever validation loss improves.
    Ctrl-C stops training early. Afterwards, if `save_path` exists, the
    checkpoint is reloaded into `model` and evaluated on the test split.

    Args:
        model: model being trained. Its ``parameters()`` and
            ``load_state_dict`` are used here.
        corpus: data source exposing ``.valid`` and ``.test`` splits (also
            passed to ``train``).
        optimizer: initial optimizer. It may be replaced locally by ASGD.
        criterion: loss passed to ``train``/``evaluate``/``model_save``.
        params: forwarded to ``train``.
        args: hyper-parameter dict. Keys read: ``epochs``, ``eval_batch_size``,
            ``test_batch_size``, ``optimizer``, ``nonmono``, ``lr``,
            ``wdecay`` and optionally ``when``.
        save_path: checkpoint path for the best model.
    """
    val_data = batchify(corpus.valid, args["eval_batch_size"], args)
    test_data = batchify(corpus.test, args["test_batch_size"], args)

    best_val_loss = []  # validation losses of non-averaged epochs, oldest first
    stored_loss = float('inf')  # best validation loss checkpointed so far

    try:
        for epoch in range(1, args["epochs"] + 1):
            epoch_start_time = time.time()
            train(model, corpus, optimizer, criterion, params, epoch, args)
            if 't0' in optimizer.param_groups[0]:
                # Averaging optimizer: evaluate/checkpoint the averaged
                # weights, keeping a copy of the raw weights to restore.
                raw_weights = {}
                for prm in model.parameters():
                    raw_weights[prm] = prm.data.clone()
                    prm.data = optimizer.state[prm]['ax'].clone()
                try:
                    # Use the batch size val_data was batchified with, as the
                    # non-averaged branch does.
                    val_loss2 = evaluate(model, criterion, args, val_data,
                                         args["eval_batch_size"])
                    _report_epoch(epoch, epoch_start_time, val_loss2)

                    if val_loss2 < stored_loss:
                        model_save(save_path, model, criterion, optimizer)
                        print('Saving Averaged!')
                        stored_loss = val_loss2
                finally:
                    # Restore the raw weights even if evaluation or saving is
                    # interrupted, so training state is never left averaged.
                    for prm in model.parameters():
                        prm.data = raw_weights[prm]

            else:
                val_loss = evaluate(model, criterion, args, val_data,
                                    args["eval_batch_size"])
                _report_epoch(epoch, epoch_start_time, val_loss)

                if val_loss < stored_loss:
                    model_save(save_path, model, criterion, optimizer)
                    print('Saving model (new best validation)')
                    stored_loss = val_loss

                if (args["optimizer"] == 'sgd'
                        and 't0' not in optimizer.param_groups[0]
                        and _is_non_monotone(best_val_loss, val_loss,
                                             args["nonmono"])):
                    print('Switching to ASGD')
                    # ASGD receives model.parameters(), the same set the
                    # averaging branch above swaps. NOTE(review): `params`
                    # is not passed here — confirm that is intended.
                    optimizer = torch.optim.ASGD(model.parameters(),
                                                 lr=args["lr"],
                                                 t0=0,
                                                 lambd=0.,
                                                 weight_decay=args["wdecay"])

                if "when" in args and epoch in args["when"]:
                    print('Saving model before learning rate decreased')
                    model_save('{}.e{}'.format(save_path, epoch), model,
                               criterion, optimizer)
                    print('Dividing learning rate by 10')
                    optimizer.param_groups[0]['lr'] /= 10.

                best_val_loss.append(val_loss)

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # Load the best saved model and run it on the test split.
    if os.path.exists(save_path):
        model_state_dict, criterion, params = model_load(save_path)
        model.load_state_dict(model_state_dict)
        test_loss = evaluate(model, criterion, args, test_data,
                             args["test_batch_size"])
        print('=' * 89)
        print(
            '| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'
            .format(test_loss, _safe_exp(test_loss), test_loss / math.log(2)))
        print('=' * 89)
예제 #2
0
import torch
from data_utils import load_hierarchical_corpus
from main import args, data_hierarchical_path, model_type_save_path, model_type_to_subtype_save_path, model_subtype_to_word_save_path, model_hierarchical_save_path
from rnn_model.build_model import get_model as get_rnn_model
from double_input_rnn_model.build_model import get_model as get_double_input_rnn_model
from hierarchical_model.build_model import get_model as get_hierarchical_model
from rnn_model.utils import model_load
from beam_search import beam_search, idx2word

# Load the trained hierarchical model and beam-search a continuation of a
# fixed prompt, printing the prompt and the generated text.
device = "cuda:0" if args["cuda"] else "cpu"

with torch.no_grad():
    corpus_hierarchical = load_hierarchical_corpus(data_hierarchical_path)
    # Build the three component models that the hierarchical model wraps.
    # Their weights are presumably restored through the hierarchical
    # checkpoint below (assumes they are registered sub-modules — TODO confirm).
    model_type, _, _ = get_rnn_model(corpus_hierarchical, args)
    model_type_to_subtype, _, _ = get_double_input_rnn_model(corpus_hierarchical, args)
    model_subtype_to_word, _, _ = get_double_input_rnn_model(corpus_hierarchical, args)

    model_hierarchical, _, _ = get_hierarchical_model(corpus_hierarchical, model_subtype_to_word, model_type_to_subtype, model_type, args)
    model_hierarchical_state_dict, _, _ = model_load(model_hierarchical_save_path, device)
    model_hierarchical.load_state_dict(model_hierarchical_state_dict)

    # torch modules start in training mode, which keeps dropout active and
    # makes decoding random. Switch every model to inference mode. eval() is
    # idempotent, so this is harmless if beam_search already calls it.
    for net in (model_type, model_type_to_subtype, model_subtype_to_word,
                model_hierarchical):
        net.eval()

    initial_sentence = 'season chicken with salt and pepper .'
    print()
    print("initial_sentence : ", initial_sentence.replace('_', ' '))
    print()
    initial_sentence = initial_sentence.split()
    # Presumably a hidden state for a batch of one prompt — confirm init_hidden's argument.
    hidden_hierarchical = model_hierarchical.init_hidden(1)
    # Take the last entry of beam_search's result (presumably the final/best
    # beam — confirm) and map its indices back to tokens.
    beam_result = beam_search(model_hierarchical, corpus_hierarchical, hidden_hierarchical, initial_sentence)[-1]
    search = ' '.join([idx2word(word, corpus_hierarchical) for word in beam_result])
    # Underscores inside tokens are shown as spaces.
    search = search.replace('_', ' ')
    print(search)