def run_evaluation(corpus_dir, save_dir, datafile, config_file, map_location=None):
    """Load a trained seq2seq checkpoint and start an interactive chat loop.

    Args:
        corpus_dir: Unused here; kept for interface compatibility with callers.
        save_dir: Root directory under which training checkpoints were saved.
        datafile: Unused here; kept for interface compatibility with callers.
        config_file: Path to the JSON model/training configuration.
        map_location: Optional device remapping forwarded to ``torch.load`` —
            pass ``torch.device('cpu')`` to load a GPU-trained model on CPU.
            ``None`` (the default) keeps ``torch.load``'s default behavior,
            i.e. loading onto the devices the tensors were saved from.
    """
    config = Config.from_json_file(config_file)
    vocab = Vocabulary("words")

    # Checkpoint path mirrors the layout used at training time:
    # <save_dir>/<model>/<corpus>/<enc_layers>-<dec_layers>_<hidden>/last_checkpoint.tar
    load_filename = os.path.join(
        save_dir, config.model_name, config.corpus_name,
        '{}-{}_{}'.format(config.encoder_n_layers, config.decoder_n_layers,
                          config.hidden_size), 'last_checkpoint.tar')

    checkpoint = torch.load(load_filename, map_location=map_location)
    encoder_sd = checkpoint["en"]
    decoder_sd = checkpoint["de"]
    embedding_sd = checkpoint["embedding"]
    # Restore the vocabulary exactly as it was during training so token ids
    # line up with the embedding rows.
    vocab.__dict__ = checkpoint["voc_dict"]

    print("Building encoder and decoder ...")
    # Embedding dimensionality must match the checkpoint's hidden size or
    # load_state_dict below will fail.
    embedding = nn.Embedding(vocab.num_words, config.hidden_size)
    embedding.load_state_dict(embedding_sd)

    # Rebuild the models with the same hyperparameters used at training time.
    encoder = EncoderRNN(config.hidden_size, embedding,
                         config.encoder_n_layers, config.dropout)
    decoder = LuongAttnDecoderRNN(config.attn_model, embedding,
                                  config.hidden_size, vocab.num_words,
                                  config.decoder_n_layers, config.dropout)
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)

    # eval() disables dropout layers for deterministic inference.
    encoder.eval()
    decoder.eval()

    # Greedy decoding is used for response generation.
    searcher = GreedySearchDecoder(encoder, decoder)

    # Interactive chat loop (blocks reading from stdin).
    evaluate_input(encoder, decoder, searcher, vocab)
# ---------------------------------------------------------------------------
# Example 2: script-style checkpoint loading.
# NOTE(review): extraction artifact — the original sample marker
# ("Exemplo n.º 2"), the loop's `for` header, and the initialization of
# `recent` (presumably `recent = 0`) were garbled here; confirm against the
# full original source.
# ---------------------------------------------------------------------------
    i = i.strip('_checkpoint.tar')
    if int(i) > int(recent):
        recent = i

LOAD_MODEL_PATH = LOAD_MODEL_PATH + recent + '_checkpoint.tar'

#checkpoint = torch.load(LOAD_MODEL_PATH)
checkpoint = torch.load(LOAD_MODEL_PATH, map_location=torch.device('cpu'))

encoder_sd = checkpoint['en']
decoder_sd = checkpoint['de']
encoder_optimizer_sd = checkpoint['en_opt']
decoder_optimizer_sd = checkpoint['de_opt']
embedding_sd = checkpoint['embedding']
voc = Vocabulary()
voc.__dict__ = checkpoint['voc_dict']

#do not edit these parameters
attn_model = 'dot'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1

embedding = nn.Embedding(voc.num_of_words, hidden_size)
embedding.load_state_dict(embedding_sd)
encoder = Encoder(hidden_size, embedding, encoder_n_layers, dropout)
encoder.load_state_dict(encoder_sd)
decoder = Attention_Decoder(attn_model, embedding, hidden_size, hidden_size,
                            voc.num_of_words, decoder_n_layers, dropout)
decoder.load_state_dict(decoder_sd)