import torch

from lmp.model import RNNModel  # Adjust to where RNNModel lives in this project.


def test_forward_path(
    rnn_model: RNNModel,
    batch_prev_tkids: torch.Tensor,
):
    r"""Parameters used during forward must have gradients."""
    # Reset to training mode and clear out any stale gradients.
    rnn_model = rnn_model.train()
    rnn_model.zero_grad()

    # A dummy scalar objective is enough to exercise the backward pass.
    rnn_model(batch_prev_tkids).sum().backward()

    # Every parameter used in the forward pass must now hold a gradient.
    assert rnn_model.emb.weight.grad is not None
    assert rnn_model.pre_hid[0].weight.grad is not None
    assert rnn_model.hid.weight_ih_l0.grad is not None
    assert rnn_model.post_hid[-1].weight.grad is not None


def test_back_propagation_path(
    rnn_model: RNNModel,
    batch_prev_tkids: torch.Tensor,
    batch_next_tkids: torch.Tensor,
):
    r"""Gradients with respect to loss must get back propagated."""
    # Reset to training mode and clear out any stale gradients.
    rnn_model = rnn_model.train()
    rnn_model.zero_grad()

    rnn_model.loss_fn(
        batch_prev_tkids=batch_prev_tkids,
        batch_next_tkids=batch_next_tkids,
    ).backward()

    # Every parameter in the loss computation must now hold a gradient.
    assert rnn_model.emb.weight.grad is not None
    assert rnn_model.pre_hid[0].weight.grad is not None
    assert rnn_model.hid.weight_ih_l0.grad is not None
    assert rnn_model.post_hid[-1].weight.grad is not None
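

# ---------------------------------------------------------------------------
# Minimal sketch of the tensor fixtures the tests above assume; these would
# normally live in conftest.py.  Shapes and values here are illustrative
# assumptions, not the project's actual test data.  The `rnn_model` fixture is
# omitted because it depends on RNNModel's real constructor signature.
# ---------------------------------------------------------------------------
import pytest


@pytest.fixture
def batch_prev_tkids() -> torch.Tensor:
    # Hypothetical input batch: 2 sequences of 4 token ids each.
    return torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])


@pytest.fixture
def batch_next_tkids(batch_prev_tkids: torch.Tensor) -> torch.Tensor:
    # Next-token targets: the input batch shifted left by one position.
    return torch.roll(batch_prev_tkids, shifts=-1, dims=1)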