import math
import time

import torch
from torch import nn

import d2l  # the book's helper utilities (data_iter_consecutive, grad_clipping)


def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                  corpus_indices, idx_to_char, char_to_idx,
                                  num_epochs, num_steps, lr, clipping_theta,
                                  batch_size, pred_period, pred_len, prefixes):
    # pred_period: print perplexity and sample predictions every pred_period epochs
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    state = None
    for epoch in range(num_epochs):
        l_sum, n, start = 0.0, 0, time.time()
        # consecutive (adjacent) sampling: minibatches follow one another in the corpus
        data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size,
                                              num_steps, device)
        for X, Y in data_iter:
            if state is not None:
                # Detach the hidden state from the computation graph so that
                # gradients are only computed through the current minibatch;
                # otherwise backpropagation through time would grow without
                # bound and become too expensive.
                if isinstance(state, tuple):  # LSTM state is (h, c)
                    state = (state[0].detach(), state[1].detach())
                else:
                    state = state.detach()
            # output has shape (num_steps * batch_size, vocab_size)
            (output, state) = model(X, state)
            # Y has shape (batch_size, num_steps); transpose and flatten it into
            # a vector of length batch_size * num_steps so its entries line up
            # with the rows of output. Note that transpose(Y, 0, 1) and
            # transpose(Y, 1, 0) are the same operation (both swap the first two
            # dimensions), so any perplexity difference between them is just
            # run-to-run noise.
            y = torch.transpose(Y, 0, 1).contiguous().view(-1)
            l = loss(output, y.long())

            optimizer.zero_grad()
            l.backward()
            # clip gradients to keep them from exploding
            d2l.grad_clipping(model.parameters(), clipping_theta, device)
            optimizer.step()
            l_sum += l.item() * y.shape[0]
            n += y.shape[0]

        try:
            perplexity = math.exp(l_sum / n)
        except OverflowError:
            perplexity = float('inf')
        if (epoch + 1) % pred_period == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch + 1, perplexity, time.time() - start))
            for prefix in prefixes:
                print(' -', predict_rnn_pytorch(
                    prefix, pred_len, model, vocab_size, device, idx_to_char,
                    char_to_idx))
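# predict_rnn_pytorch is called above but not defined in this section. Below is
# a minimal sketch consistent with the call signature, assuming the same
# model(X, state) contract as in train_and_predict_rnn_pytorch and greedy
# (argmax) decoding; the book supplies its own implementation.
def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device,
                        idx_to_char, char_to_idx):
    state = None
    output = [char_to_idx[prefix[0]]]  # indices of the characters emitted so far
    for t in range(num_chars + len(prefix) - 1):
        # feed the most recent character as a (batch_size=1, num_steps=1) input
        X = torch.tensor([output[-1]], device=device).view(1, 1)
        (Y, state) = model(X, state)
        if t < len(prefix) - 1:
            output.append(char_to_idx[prefix[t + 1]])  # still consuming the prefix
        else:
            output.append(int(Y.argmax(dim=1).item()))  # greedy next-character pick
    return ''.join(idx_to_char[i] for i in output)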
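# A usage sketch under stated assumptions: the trainer expects model(X, state)
# to return logits of shape (num_steps * batch_size, vocab_size) together with
# the new state. RNNModel below is a minimal wrapper meeting that contract
# (one-hot inputs, time-major RNN, linear projection). The hyperparameter
# values and the d2l.load_data_jay_lyrics loader are illustrative assumptions,
# not fixed by the code above.
import torch.nn.functional as F


class RNNModel(nn.Module):
    def __init__(self, rnn_layer, vocab_size):
        super().__init__()
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.dense = nn.Linear(rnn_layer.hidden_size, vocab_size)

    def forward(self, inputs, state):
        # inputs: (batch_size, num_steps) -> one-hot, time-major:
        # (num_steps, batch_size, vocab_size)
        X = F.one_hot(inputs.T.long(), self.vocab_size).float()
        Y, state = self.rnn(X, state)
        # flatten to (num_steps * batch_size, vocab_size) to match the loss
        return self.dense(Y.reshape(-1, Y.shape[-1])), state


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load_data_jay_lyrics is the book's dataset loader (assumed available locally)
corpus_indices, char_to_idx, idx_to_char, vocab_size = d2l.load_data_jay_lyrics()
num_hiddens = 256
model = RNNModel(nn.RNN(input_size=vocab_size, hidden_size=num_hiddens),
                 vocab_size)
train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                              corpus_indices, idx_to_char, char_to_idx,
                              num_epochs=250, num_steps=35, lr=1e-3,
                              clipping_theta=1e-2, batch_size=32,
                              pred_period=50, pred_len=50,
                              prefixes=['分开', '不分开'])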