예제 #1
0
def load_model(model_filename):
    """Reconstruct a trained CharRNN from a checkpoint on disk.

    The checkpoint is expected to contain the keys 'n_hidden',
    'n_layers', 'state_dict' and 'chars', as written by the matching
    save routine.
    """
    # torch.load accepts a filename directly, so no explicit open() is needed.
    checkpoint = torch.load(model_filename)

    model = CharRNN(chars=checkpoint['chars'],
                    n_hidden=checkpoint['n_hidden'],
                    n_layers=checkpoint['n_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    return model
예제 #2
0
def start_text(start_text, n_words=250):
    """Generate text seeded with *start_text* from the 20-epoch model.

    Loads `saved_model/rnn_20_epoch.net`, samples up to *n_words* words,
    flattens newlines to spaces, and returns only the first sentence
    (everything up to and including the first period).
    """
    # Here we have loaded in a model that trained over 20 epochs `rnn_20_epoch.net`
    with open('saved_model/rnn_20_epoch.net', 'rb') as checkpoint_file:
        checkpoint = torch.load(checkpoint_file)

    net = CharRNN(checkpoint['tokens'],
                  n_hidden=checkpoint['n_hidden'],
                  n_layers=checkpoint['n_layers'])
    net.load_state_dict(checkpoint['state_dict'])

    raw_text = sample(net,
                      n_words,
                      top_k=5,
                      prime='{} '.format(start_text))

    # Flatten newlines, then keep only the first sentence.
    flattened = raw_text.replace('\n', ' ')
    return '{}.'.format(flattened.split('.')[0])
예제 #3
0
from multiprocessing import Value

import numpy as np
import torch
from flask import Flask, request, jsonify, render_template
from torch.nn import functional as F

from model import CharRNN
from settings import *
from utils import load_dict, create_tune_header

app = Flask(__name__)
print("Environment:", app.config["ENV"])
# Create and load model
# NOTE(review): n_char, default_model_path, int2char_path and char2int_path
# presumably come from the star import of `settings` — confirm.
model = CharRNN(n_char)
# map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only host.
model.load_state_dict(torch.load(default_model_path, map_location='cpu'))
model.eval()  # inference mode (disables dropout etc.)

# Load necessary dictionaries
int2char = load_dict(int2char_path)
char2int = load_dict(char2int_path)

# Shared integer counter ("i" = C int), safe to increment across worker processes.
counter = Value("i", 0)

# NOTE(review): missing space after the period ("melodies.Please") in this
# user-facing message — left as-is since it is runtime output.
error_message = "We created some tunes, but it seems like we can't create music from these melodies.Please try again!"

print("Ready!")


@app.route("/")
def generate_song():
예제 #4
0
# Build a character-level vocabulary from the quotes corpus.
quotes = read_file(file)  # NOTE(review): `file` must be defined earlier — confirm.
tokens = list(set(''.join(quotes)))
token_to_id = {token: idx for idx, token in enumerate(tokens)}
id_to_token = {idx: token for token, idx in token_to_id.items()}
num_tokens = len(tokens)
# NOTE(review): set() iteration order varies between runs (hash
# randomization), so these token ids may not match the ids the saved
# checkpoint was trained with — confirm how the original vocab was built.
# encode the text
encoded = np.array([token_to_id[ch] for ch in quotes])


# Here we have loaded in a model that trained over 2 epochs `rnn_20_epoch.net`
# NOTE(review): the comment above looks stale — the file actually opened
# below is rnn_50_epoch.net.
with open('rnn_50_epoch.net', 'rb') as f:
    checkpoint = torch.load(f)

loaded = CharRNN(num_tokens, n_hidden=checkpoint['n_hidden'], n_layers=checkpoint['n_layers'])
loaded.load_state_dict(checkpoint['state_dict'])

def predict(net, char, h=None, top_k=None):
    ''' Given a character, predict the next character.
        Returns the predicted character and the hidden state.
    '''
    # NOTE(review): the body below appears to be a garbled merge of two
    # different functions. The lines from "o, h = model(inp, h)" onward are
    # over-indented with no enclosing loop and reference names never defined
    # here (inp, temperature, result, seed_characters, int2str, all_chars,
    # model). As written this is a syntax error (unexpected indent) and the
    # default h=None would also crash the tuple() call — needs
    # reconstruction from the original source before use.

    # tensor inputs
    x = np.array([[token_to_id[char]]])
    inputs = torch.from_numpy(x)

    if (train_on_gpu):
        inputs = inputs.cuda()

    # detach hidden state from history
    h = tuple([each.data for each in h])
        o, h = model(inp, h)
        out_dist = o.view(-1).div(temperature).view(-1).exp().cpu()
        top_i = torch.multinomial(out_dist, 1)[0]
        result.append(top_i)

    return seed_characters + int2str(result, all_chars)


if __name__ == '__main__':
    # Vocabulary: every printable ASCII character.
    all_characters = string.printable
    input_size = len(all_characters)

    # Load the 50-epoch RNN weights and rebuild the model on the GPU.
    pre = torch.load("./model_50_RNN.pth")
    models_RNN = CharRNN(input_size, 512, input_size, 4).cuda()

    models_RNN.load_state_dict(pre)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Sampling temperatures to compare side by side.
    temp = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    # One generated sample per temperature, all primed with "A".
    rnns = []
    for i in temp:
        result = generate(models_RNN, "A", i, device)
        rnns.append(result)

    with open("RNN_Result_50.txt", "w") as f:
        for item in rnns:
            f.write("%s\n" % item)
    # NOTE(review): redundant — the `with` block above already closed f
    # (harmless, close() on a closed file is a no-op).
    f.close()

    # NOTE(review): the LSTM section below is cut off at the end of this
    # snippet.
    pre = torch.load("./model_50_LSTM.pth")
예제 #6
0
파일: sample.py 프로젝트: unexge/gdpn
            pred += vocab.idx_to_char(topi.item())

            input = torch.from_numpy(
                np.array([convert_idx_to_one_hot(vocab,
                                                 topi.item())])).unsqueeze(0)
            input = input.float().to(config.device)

        return pred


def load_checkpoint(path: str):
    """Read a training checkpoint from *path*.

    Returns a (model_state_dict, vocab) pair, where the vocab is rebuilt
    from the checkpoint's stored character set.
    """
    ckpt = torch.load(path)
    vocab = load_vocab_from_chars(ckpt['vocab_chars'])
    return ckpt['model_state_dict'], vocab


if __name__ == '__main__':
    # Checkpoint path is the first command-line argument.
    ckpt_state, vocab = load_checkpoint(sys.argv[1])

    # TODO: load hyperparameters from checkpoint?
    net = CharRNN(len(vocab), 256, 2)
    net.load_state_dict(ckpt_state)
    net.to(config.device)

    # Emit 25 independent samples, one per line.
    for _sample_idx in range(25):
        print(sample(net, vocab))