Ejemplo n.º 1
0
def test_forward_path(
    gru_model: GRUModel,
    batch_prev_tkids: torch.Tensor,
):
    r"""Parameters used during forward must have gradients.

    Puts the model in training mode, clears every parameter gradient,
    back-propagates the sum of the forward output, and checks that one
    representative parameter from each sub-module (``emb``, ``pre_hid[0]``,
    ``hid`` and ``post_hid[-1]``) received a gradient.
    """
    gru_model = gru_model.train()
    # Make sure model has no gradients. Set them to ``None`` explicitly
    # instead of relying on ``zero_grad()``: on PyTorch versions where
    # ``zero_grad()`` defaults to ``set_to_none=False`` it leaves zero-filled
    # tensors behind, which would make the ``is not None`` checks below pass
    # even if back-propagation never reached a parameter.
    for param in gru_model.parameters():
        param.grad = None

    gru_model(batch_prev_tkids).sum().backward()

    checked_params = {
        'emb.weight': gru_model.emb.weight,
        'pre_hid[0].weight': gru_model.pre_hid[0].weight,
        'hid.weight_ih_l0': gru_model.hid.weight_ih_l0,
        'post_hid[-1].weight': gru_model.post_hid[-1].weight,
    }
    for name, param in checked_params.items():
        assert param.grad is not None, f'{name} received no gradient'


def test_back_propagation_path(
    gru_model: GRUModel,
    batch_prev_tkids: torch.Tensor,
    batch_next_tkids: torch.Tensor,
):
    r"""Gradients with respect to loss must get back propagated.

    Puts the model in training mode, clears every parameter gradient,
    back-propagates the value returned by ``gru_model.loss_fn`` and checks
    that one representative parameter from each sub-module (``emb``,
    ``pre_hid[0]``, ``hid`` and ``post_hid[-1]``) received a gradient.
    """
    gru_model = gru_model.train()
    # Make sure model has no gradients. Set them to ``None`` explicitly
    # instead of relying on ``zero_grad()``: on PyTorch versions where
    # ``zero_grad()`` defaults to ``set_to_none=False`` it leaves zero-filled
    # tensors behind, which would make the ``is not None`` checks below pass
    # even if back-propagation never reached a parameter.
    for param in gru_model.parameters():
        param.grad = None

    gru_model.loss_fn(
        batch_prev_tkids=batch_prev_tkids,
        batch_next_tkids=batch_next_tkids,
    ).backward()

    checked_params = {
        'emb.weight': gru_model.emb.weight,
        'pre_hid[0].weight': gru_model.pre_hid[0].weight,
        'hid.weight_ih_l0': gru_model.hid.weight_ih_l0,
        'post_hid[-1].weight': gru_model.post_hid[-1].weight,
    }
    for name, param in checked_params.items():
        assert param.grad is not None, f'{name} received no gradient'