# One SGD training step for a hand-rolled RNN classifier: sample a
# mini-batch, run the manual forward/backward passes, clip gradients,
# apply a plain gradient-descent update, and track loss/accuracy.
X, y = trainData.sample(batch_size, i)
# Reorder to sequence-first layout; presumably X arrives as
# (batch, seq, feat) and becomes (seq, batch, feat) — TODO confirm
# against trainData.sample.
X = X.permute(1, 0, 2)

classifier.reset()  # clear recurrent state before the fresh forward pass
y_pred = classifier.forward(X)
loss = criterion.forward(y_pred, y)

# .item() forces a device->host sync; call it once and reuse the value
# (the original called it twice for the two accumulators).
cur_loss = loss.item()
totloss += cur_loss
tot2loss += cur_loss

# Manual backprop: the criterion yields dL/dy_pred, the classifier
# propagates it through the network and fills layer.grad* tensors.
gradLoss = criterion.backward(y_pred, y)
classifier.backward(gradLoss)

# Clip every gradient to [-5, 5] (a standard guard against exploding
# RNN gradients), then take a vanilla SGD step. The loop replaces five
# copies of identical clamp-and-update code; `param -= ...` updates the
# parameter tensor in place, exactly like the original `layer.X -= ...`.
for name in ("Whh", "Wxh", "Why", "Bh", "By"):
    grad = torch.clamp(getattr(layer, "grad" + name), -5, +5)
    setattr(layer, "grad" + name, grad)
    param = getattr(layer, name)
    param -= alpha * grad

# Accuracy bookkeeping: predicted class = argmax over class scores.
label = torch.argmax(y_pred, dim=1)
correct += torch.sum(label == y.long()).item()
count += len(y)

print('Epoch', epoch, 'complete')