Ejemplo n.º 1
0
def test_detach_hidden_RNN():
    """Check that _detach_hidden detaches an nn.RNN hidden state.

    nn.RNN returns its hidden state as a single tensor (not a tuple), so the
    detached result is compared to it directly.
    """
    # Create a hidden state tensor attached to the autograd graph
    X = torch.ones(2, 3, 4)
    model = nn.RNN(4, 1)
    _, hidden = model(X)
    # Precondition: the input must be attached, otherwise the "detached"
    # assertion below would pass trivially.
    assert hidden.grad_fn is not None

    # Function to test
    hidden_ = _detach_hidden(hidden)

    assert hidden_.grad_fn is None  # properly detached
    # torch.equal also requires identical shapes, unlike broadcasting `==`.
    assert torch.equal(hidden, hidden_)  # same shape and values
Ejemplo n.º 2
0
def test_detach_hidden_LSTM():
    """Check that _detach_hidden detaches every tensor of an nn.LSTM hidden state.

    nn.LSTM returns its hidden state as an (h, c) tuple, so each element of the
    detached result is checked against the corresponding original element.
    """
    # Create hidden state (an (h, c) tuple attached to the autograd graph)
    X = torch.ones(2, 3, 4)
    model = nn.LSTM(4, 1)
    _, hidden = model(X)

    # Function to test
    hidden_ = _detach_hidden(hidden)

    # zip() silently stops at the shorter input, so make sure no element
    # was dropped (or added) before comparing pairwise.
    assert len(hidden_) == len(hidden)
    for h, h_ in zip(hidden, hidden_):
        # Precondition: input was attached, so the detach check is meaningful.
        assert h.grad_fn is not None
        assert h_.grad_fn is None  # properly detached
        # torch.equal also requires identical shapes, unlike broadcasting `==`.
        assert torch.equal(h, h_)  # same shape and values
Ejemplo n.º 3
0
def test_detach_hidden_raise():
    """_detach_hidden must raise TypeError when given a plain int instead of a hidden state."""
    invalid_hidden = 0
    with pytest.raises(TypeError):
        _detach_hidden(invalid_hidden)