import torch

from lmp.model import GRUModel  # Assumed import path; adjust to the project's layout.


def test_forward_path(
    gru_model: GRUModel,
    batch_prev_tkids: torch.Tensor,
):
    r"""Parameters used during forward must have gradients."""
    # Make sure model has no gradients.
    gru_model.train()
    gru_model.zero_grad()

    gru_model(batch_prev_tkids).sum().backward()

    # After backward, every parameter on the forward path must hold a gradient.
    assert gru_model.emb.weight.grad is not None
    assert gru_model.pre_hid[0].weight.grad is not None
    assert gru_model.hid.weight_ih_l0.grad is not None
    assert gru_model.post_hid[-1].weight.grad is not None
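
# The tests in this module rely on pytest fixtures defined elsewhere
# (typically in conftest.py). The sketch below shows what the tensor
# fixture might look like; the name `batch_prev_tkids_sketch`, the batch
# shape (2, 4), and the vocabulary size 10 are all hypothetical. It is
# suffixed with `_sketch` so it does not shadow the real fixture.
import pytest


@pytest.fixture
def batch_prev_tkids_sketch() -> torch.Tensor:
    # Hypothetical batch of token ids: 2 sequences of length 4, drawn
    # from a vocabulary of size 10.
    return torch.randint(low=0, high=10, size=(2, 4))
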
def test_back_propagation_path(
    gru_model: GRUModel,
    batch_prev_tkids: torch.Tensor,
    batch_next_tkids: torch.Tensor,
):
    r"""Gradients with respect to loss must get back propagated."""
    # Make sure model has no gradients.
    gru_model.train()
    gru_model.zero_grad()

    gru_model.loss_fn(
        batch_prev_tkids=batch_prev_tkids,
        batch_next_tkids=batch_next_tkids,
    ).backward()

    # After backward, every parameter on the forward path must hold a gradient.
    assert gru_model.emb.weight.grad is not None
    assert gru_model.pre_hid[0].weight.grad is not None
    assert gru_model.hid.weight_ih_l0.grad is not None
    assert gru_model.post_hid[-1].weight.grad is not None
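
# GRUModel.loss_fn is exercised above only through backward(). The sketch
# below shows one plausible contract, assuming the loss is the average
# cross-entropy between next-token logits and target token ids; the helper
# name `loss_fn_sketch` and the logits shape are assumptions, not the
# project's actual implementation.
import torch.nn.functional as F


def loss_fn_sketch(
    model: torch.nn.Module,
    batch_prev_tkids: torch.Tensor,
    batch_next_tkids: torch.Tensor,
) -> torch.Tensor:
    # Forward pass is assumed to yield logits of shape
    # (batch_size, seq_len, vocab_size).
    logits = model(batch_prev_tkids)
    # Flatten all positions and average cross-entropy over every token.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        batch_next_tkids.reshape(-1),
    )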