def test_forward_path(
    lstm_model: LSTMModel,
    batch_prev_tkids: torch.Tensor,
):
    r"""Parameters used during forward must have gradients."""
    # Put the model in training mode and make sure it has no gradients.
    lstm_model = lstm_model.train()
    lstm_model.zero_grad()

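    # Run a forward pass and back propagate from the summed output so that
    # every parameter used in the computation receives a gradient.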
    lstm_model(batch_prev_tkids).sum().backward()

    # Every parameter used in the forward pass must now hold a gradient.
    assert lstm_model.emb.weight.grad is not None
    assert lstm_model.pre_hid[0].weight.grad is not None
    assert lstm_model.hid.weight_ih_l0.grad is not None
    assert lstm_model.post_hid[-1].weight.grad is not None


def test_back_propagation_path(
    lstm_model: LSTMModel,
    batch_prev_tkids: torch.Tensor,
    batch_next_tkids: torch.Tensor,
):
    r"""Gradients with respect to loss must get back propagated."""
    # Put the model in training mode and make sure it has no gradients.
    lstm_model = lstm_model.train()
    lstm_model.zero_grad()

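    # Compute the next-token prediction loss and back propagate it through
    # the whole model.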
    lstm_model.loss_fn(
        batch_prev_tkids=batch_prev_tkids,
        batch_next_tkids=batch_next_tkids,
    ).backward()

    # Gradients must reach every sub-module: the embedding, the layers before
    # the LSTM, the LSTM itself and the layers after it.
    assert lstm_model.emb.weight.grad is not None
    assert lstm_model.pre_hid[0].weight.grad is not None
    assert lstm_model.hid.weight_ih_l0.grad is not None
    assert lstm_model.post_hid[-1].weight.grad is not None