def test_diag_part():
    """Check the diagonal of the 3x3 matrix that ``nn.fill_triangular`` builds.

    The six values 1..6 are packed into a batch of one 3x3 matrix. The
    expected diagonal ``[4, 5, 1]`` pins down the packing order.
    """
    # torch.range is deprecated. With float endpoints, torch.arange yields the
    # default floating dtype and the same six values as torch.range(1, 6).
    params = torch.arange(1.0, 7.0).view(1, 6)
    dim = 3
    L = nn.fill_triangular(params, dim)
    # Take the diagonal over the last two (matrix) dimensions, keeping the batch dim.
    diag_L = torch.diagonal(L, dim1=-2, dim2=-1)
    assert diag_L.shape == (1, 3)
    assert np.allclose(diag_L.numpy().squeeze(), np.array([4, 5, 1]))
def test_filltriangular():
    """Check the full matrix that ``nn.fill_triangular`` produces for the values 0..5.

    The result is a batch of one 3x3 matrix whose entries above the
    diagonal are zero.
    """
    # Replacement for the deprecated torch.range(0, 5). Float endpoints keep the default float dtype.
    params = torch.arange(0.0, 6.0).view(1, 6)
    dim = 3
    L = nn.fill_triangular(params, dim)
    assert L.shape == (1, 3, 3)
    expected = np.array([[3, 0, 0], [5, 4, 0], [2, 1, 0]])
    # expected is (3, 3). np.allclose broadcasts it against the (1, 3, 3) batch.
    assert np.allclose(L.numpy(), expected)
def test_set_diag():
    """Check that ``nn.set_diag`` overwrites the diagonal and stays differentiable.

    A matrix is built from learnable parameters, and its diagonal is replaced
    by a constant 6. The test then backpropagates through the sum of the result.
    """
    with torch.enable_grad():
        # Replacement for the deprecated torch.range(1, 6). Float endpoints keep the
        # default float dtype, which requires_grad needs.
        params = torch.arange(1.0, 7.0, requires_grad=True).view(1, 6)
        dim = 3
        L = nn.fill_triangular(params, dim)
        diag_L = torch.ones((1, 3), dtype=params.dtype, requires_grad=True) * 6
        M = nn.set_diag(L, diag_L)
        loss = M.sum()
        loss.backward()
        # Both results must be attached to the autograd graph.
        assert M.grad_fn is not None
        assert L.grad_fn is not None
        assert np.allclose(nn.diag_part(M).detach(), diag_L.detach())