# encoding: utf-8
"""
Shape check for the hand-written RNN helpers in ``RNN_basic``.

The script loads the lyrics corpus, selects a compute device, builds a small
dummy index batch, and pushes it once through ``Func.rnn`` so that the sizes
of the outputs and of the returned state can be inspected on stdout.

@author: J.Zhang
@contact: [email protected]
@site: https://github.com/kaitokuroba7
@software: PyCharm
@file: RNN_main.py
@time: 2021/3/19 11:29
"""
import torch

import Function.utils as d2l
import RNN_basic as Func
import RNN_Func as RNN_Func

# The loader returns a 4-tuple; only ``vocab_size`` is used further down.
corpus_indices, char_to_idx, idx_to_char, vocab_size = d2l.load_data_jay_lyrics()

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Input and output widths both equal the vocabulary size; 256 hidden units.
num_inputs = vocab_size
num_hiddens = 256
num_outputs = vocab_size
print('will use', device)

# Dummy batch holding the indices 0..9 arranged as 2 rows by 5 columns.
X = torch.arange(10).view(2, 5)

# First conversion stays on the CPU and is only used to print its size.
# NOTE(review): presumably one one-hot tensor per column of X - confirm
# against RNN_basic.to_onehot.
inputs = Func.to_onehot(X, vocab_size)
print(len(inputs), inputs[0].shape)

# Initial state is built from the number of rows in X (2) and num_hiddens.
state = Func.init_rnn_state(X.shape[0], num_hiddens, device)

# Redo the conversion with X moved to the selected device, so the tensors
# handed to the forward pass come from a device-resident X.
inputs = Func.to_onehot(X.to(device), vocab_size)
params = Func.get_params(num_inputs, num_hiddens, num_outputs, device)

# Single forward pass; report the number of outputs and the shapes of the
# first output and of the first element of the returned state.
output, state_new = Func.rnn(inputs, state, params)
print(len(output), output[0].shape, state_new[0].shape)
""" @author: J.Zhang @contact: [email protected] @site: https://github.com/kaitokuroba7 @software: PyCharm @file: RNN_main.py @time: 2021/3/22 17:01 """ import torch import torch.nn as nn import Function.utils as d2l import RNN_model as Model import RNN_Func as Func device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') (corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics() num_steps = 35 num_hiddens = 256 # rnn_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens) # 已测试 rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens) model = Model.RNNModel(rnn_layer=rnn_layer, vocab_size=vocab_size).to(device) res = Func.predict_rnn_pytorch('分开', 10, model, vocab_size=vocab_size, device=device, idx_to_char=idx_to_char, char_to_idx=char_to_idx) print(res) if __name__ == "__main__": num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2 # 注意这里的学习率设置 pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开'] Func.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,