def test_lsnn_cell():
    """Step a 2->2 LSNN cell twice with a constant all-ones batch.

    Asserts that the first call returns an all-zero ``z`` (presumably the
    spike output). It also asserts that after the second call the state's
    ``i`` field is not all zeros.

    NOTE(review): ``test_lsnn_cell`` is defined more than once in this file.
    Python keeps only the last definition, so pytest does not collect this one.
    """
    batch_size, features = 5, 2
    lsnn_cell = lsnn.LSNNCell(features, features)
    cell_state = lsnn_cell.initial_state(batch_size, "cpu")
    inputs = torch.ones(batch_size, features)
    all_zero = np.zeros((batch_size, features))

    # Step 1: z must be exactly zero.
    output, cell_state = lsnn_cell(inputs, cell_state)
    np.testing.assert_equal(output.numpy(), all_zero)

    # Step 2: the state's `i` component must differ from zero somewhere.
    output, cell_state = lsnn_cell(inputs, cell_state)
    # detach() first: numpy() refuses tensors that require grad, and `i` may
    # carry autograd history at this point.
    assert not np.array_equal(cell_state.i.detach().numpy(), all_zero)
def test_lsnn_cell():
    """Feed an all-zero batch to a 2->10 LSNN cell and expect an all-zero ``z``.

    NOTE(review): ``test_lsnn_cell`` is defined more than once in this file.
    Python keeps only the last definition, so pytest does not collect this one.
    """
    batch_size = 5
    lsnn_cell = lsnn.LSNNCell(2, 10)
    initial = lsnn_cell.initial_state(batch_size, "cpu")
    output, _ = lsnn_cell(torch.zeros(batch_size, 2), initial)
    np.testing.assert_equal(output.numpy(), np.zeros((batch_size, 10)))
def test_lsnn_forward_shape_fail():
    """``forward`` must reject an input whose shape does not match the cell.

    ``data`` here is a 1-D tensor of 10 elements. The other tests in this file
    feed ``(batch, 2)``-shaped inputs to an ``LSNNCell(2, ...)``. The test
    passes only if ``forward`` raises, and fails explicitly if the call succeeds.
    """
    cell = lsnn.LSNNCell(2, 10)
    state = cell.initial_state(5, "cpu")
    data = torch.zeros(10)
    try:
        cell.forward(data, state)
    except RuntimeError:
        # NOTE(review): assumes the mismatch surfaces as RuntimeError (how torch
        # reports shape errors in linear/matmul). Confirm against LSNNCell.forward.
        pass
    else:
        raise AssertionError("LSNNCell.forward accepted an input of mismatched shape")
def test_lsnn_state_fail():
    """``initial_state`` called without its required arguments must raise ``TypeError``.

    The test fails explicitly if the call succeeds, instead of silently erroring.

    NOTE(review): ``test_lsnn_state_fail`` is defined more than once in this
    file. Python keeps only the last definition, so pytest does not collect this one.
    """
    cell = lsnn.LSNNCell(2, 10)
    try:
        cell.initial_state()  # pylint: disable=E1120
    except TypeError:
        pass
    else:
        raise AssertionError(
            "LSNNCell.initial_state() without arguments did not raise TypeError"
        )
def test_lsnn_cell_param_fail():
    """Constructing ``LSNNCell`` without its required arguments must raise ``TypeError``.

    The test fails explicitly if construction succeeds, instead of silently erroring.

    NOTE(review): ``test_lsnn_cell_param_fail`` is defined more than once in
    this file. Python keeps only the last definition, so pytest does not
    collect this one.
    """
    try:
        _ = lsnn.LSNNCell()  # pylint: disable=E1120
    except TypeError:
        pass
    else:
        raise AssertionError("lsnn.LSNNCell() without arguments did not raise TypeError")
def test_lsnn_cell():
    """An all-zero batch through a 2->10 LSNN cell must yield an all-zero ``z``.

    NOTE(review): this redefines ``test_lsnn_cell``. Earlier definitions of
    the same name in this file are shadowed and never collected by pytest.
    """
    in_features, out_features, batch_size = 2, 10, 5
    lsnn_cell = lsnn.LSNNCell(in_features, out_features)
    initial = lsnn_cell.initial_state(batch_size, "cpu")
    output, _ = lsnn_cell(torch.zeros(batch_size, in_features), initial)
    assert torch.equal(output, torch.zeros(batch_size, out_features))
def test_lsnn_state_fail():
    """``initial_state`` called without its required arguments must raise ``TypeError``.

    The test fails explicitly if the call succeeds, instead of silently erroring.

    NOTE(review): this redefines ``test_lsnn_state_fail``. An earlier
    definition of the same name in this file is shadowed.
    """
    cell = lsnn.LSNNCell(2, 10)
    try:
        cell.initial_state()  # pylint: disable=E1120
    except TypeError:
        pass
    else:
        raise AssertionError(
            "LSNNCell.initial_state() without arguments did not raise TypeError"
        )
def test_lsnn_cell_param_fail():
    """Constructing ``LSNNCell`` without its required arguments must raise ``TypeError``.

    The test fails explicitly if construction succeeds, instead of silently erroring.

    NOTE(review): this redefines ``test_lsnn_cell_param_fail``. An earlier
    definition of the same name in this file is shadowed.
    """
    try:
        lsnn.LSNNCell()  # pylint: disable=E1120
    except TypeError:
        pass
    else:
        raise AssertionError("lsnn.LSNNCell() without arguments did not raise TypeError")