def test_detach_hidden_RNN():
    """_detach_hidden on an nn.RNN hidden state returns a graph-free equal copy.

    nn.RNN returns its hidden state as a single tensor (unlike nn.LSTM,
    which returns an (h_n, c_n) tuple), so the result is checked directly.
    """
    # Create a hidden state tensor that is attached to the autograd graph
    X = torch.ones(2, 3, 4)
    model = nn.RNN(4, 1)
    _, hidden = model(X)
    # Precondition: if hidden had no grad_fn, the detach check below would
    # pass trivially and prove nothing.
    assert hidden.grad_fn is not None

    # Function to test
    hidden_ = _detach_hidden(hidden)

    assert hidden_.grad_fn is None  # properly detached
    # torch.equal also requires identical shapes; elementwise == would
    # broadcast and could accept a mis-shaped result.
    assert torch.equal(hidden, hidden_)  # equal values
def test_detach_hidden_LSTM():
    """_detach_hidden on an nn.LSTM hidden state detaches every tensor in the tuple."""
    # Create hidden state (an (h_n, c_n) tuple) attached to the autograd graph
    X = torch.ones(2, 3, 4)
    model = nn.LSTM(4, 1)
    _, hidden = model(X)
    # Precondition: each element must start out attached, otherwise the
    # detach checks below would pass trivially.
    assert all(h.grad_fn is not None for h in hidden)

    # Function to test; materialize the result so its length can be checked
    hidden_ = tuple(_detach_hidden(hidden))
    # zip() silently stops at the shorter input: without this check an empty
    # or truncated result would skip the loop and the test would pass vacuously.
    assert len(hidden_) == len(hidden)

    for h, h_ in zip(hidden, hidden_):
        assert h_.grad_fn is None  # properly detached
        # torch.equal also requires identical shapes (== would broadcast)
        assert torch.equal(h, h_)  # equal values
def test_detach_hidden_raise():
    """Passing a plain integer to _detach_hidden must raise TypeError."""
    not_a_hidden_state = 0
    with pytest.raises(TypeError):
        _detach_hidden(not_a_hidden_state)