Ejemplo n.º 1
0
# encoding: utf-8
"""
@author: J.Zhang
@contact: [email protected]
@site: https://github.com/kaitokuroba7
@software: PyCharm
@file: RNN_main.py
@time: 2021/3/19 11:29
"""
import torch
import Function.utils as d2l
import RNN_basic as Func
import RNN_Func as RNN_Func

# Load the lyrics corpus. Judging by the unpacked names, this yields the encoded
# corpus, the char<->index lookup tables and the vocabulary size (TODO confirm in d2l).
corpus_indices, char_to_idx, idx_to_char, vocab_size = d2l.load_data_jay_lyrics()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Input and output widths both equal the vocabulary size; 256 hidden units.
num_inputs = vocab_size
num_hiddens = 256
num_outputs = vocab_size
print('will use', device)

# Toy batch of character indices 0..9 laid out as a (2, 5) tensor.
X = torch.arange(10).reshape(2, 5)

# Sanity-check the one-hot encoding on the CPU copy of X and report its shape.
inputs = Func.to_onehot(X, vocab_size)
print(len(inputs), inputs[0].shape)

# One forward pass of the hand-written RNN on `device`, then report the shapes
# of the outputs and of the new hidden state.
state = Func.init_rnn_state(X.size(0), num_hiddens, device)
inputs = Func.to_onehot(X.to(device), vocab_size)
params = Func.get_params(num_inputs, num_hiddens, num_outputs, device)
output, state_new = Func.rnn(inputs, state, params)
print(len(output), output[0].shape, state_new[0].shape)
Ejemplo n.º 2
0
"""
@author: J.Zhang
@contact: [email protected]
@site: https://github.com/kaitokuroba7
@software: PyCharm
@file: RNN_main.py
@time: 2021/3/22 17:01
"""
import torch
import torch.nn as nn
import Function.utils as d2l
import RNN_model as Model
import RNN_Func as Func

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
corpus_indices, char_to_idx, idx_to_char, vocab_size = d2l.load_data_jay_lyrics()

num_steps, num_hiddens = 35, 256

# Recurrent layer whose input width is the vocabulary size (presumably one-hot
# characters -- TODO confirm) and which has num_hiddens hidden units.
# An LSTM drop-in has also been tested:
#     rnn_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)
model = Model.RNNModel(rnn_layer=rnn_layer, vocab_size=vocab_size)
model = model.to(device)

# Predict from the prefix '分开' using the freshly constructed model; it has not
# been trained at this point (training is done under the __main__ guard).
# The positional 10 is presumably the number of characters to generate.
res = Func.predict_rnn_pytorch(
    '分开', 10, model,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    device=device,
    vocab_size=vocab_size,
)
print(res)

if __name__ == "__main__":
    num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2  # 注意这里的学习率设置
    pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
    Func.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,