Пример #1
0
def test_diag_part():
    """Check the diagonal of the matrix built by ``nn.fill_triangular``.

    The six parameters ``1..6`` (a batch of one) are passed to
    ``nn.fill_triangular`` with ``dim=3``; the main diagonal of the result
    is expected to have shape ``(1, 3)`` and values ``[4, 5, 1]``.
    """
    # torch.range is deprecated. torch.arange excludes the end point, so the
    # upper bound is 7.0 to keep the values 1..6; float bounds keep the
    # default floating-point dtype that torch.range produced.
    params = torch.arange(1.0, 7.0).view(1, 6)
    dim = 3
    L = nn.fill_triangular(params, dim)
    # Take the main diagonal over the last two (matrix) axes.
    diag_L = torch.diagonal(L, dim1=-2, dim2=-1)
    assert diag_L.shape == (1, 3)
    assert np.allclose(diag_L.numpy().squeeze(), np.array([4, 5, 1]))
Пример #2
0
def test_filltriangular():
    """Check the shape and values produced by ``nn.fill_triangular``.

    The six parameters ``0..5`` (a batch of one) with ``dim=3`` are expected
    to yield a ``(1, 3, 3)`` result equal to
    ``[[3, 0, 0], [5, 4, 0], [2, 1, 0]]``.
    """
    # torch.range is deprecated. torch.arange excludes the end point, so the
    # upper bound is 6.0 to keep the values 0..5; float bounds keep the
    # default floating-point dtype that torch.range produced.
    params = torch.arange(0.0, 6.0).view(1, 6)
    dim = 3
    L = nn.fill_triangular(params, dim)
    assert L.shape == (1, 3, 3)
    # The (3, 3) expectation broadcasts against the leading batch axis of L.
    expected = np.array([[3, 0, 0], [5, 4, 0], [2, 1, 0]])
    assert np.allclose(L, expected)
Пример #3
0
def test_set_diag():
    """Check that ``nn.set_diag`` replaces the diagonal and stays differentiable.

    A matrix is built with ``nn.fill_triangular`` from the parameters
    ``1..6``, its diagonal is overwritten with sixes via ``nn.set_diag``, and
    a backward pass is run on the sum. The test then asserts that both the
    input and output matrices are attached to the autograd graph and that
    ``nn.diag_part`` of the result returns the diagonal that was set.
    """
    with torch.enable_grad():
        # torch.range is deprecated. torch.arange excludes the end point, so
        # the upper bound is 7.0 to keep the values 1..6; float bounds keep a
        # floating-point dtype, which requires_grad=True needs.
        params = torch.arange(1.0, 7.0, requires_grad=True).view(1, 6)
        dim = 3
        L = nn.fill_triangular(params, dim)
        diag_L = torch.ones((1, 3), dtype=params.dtype, requires_grad=True) * 6
        M = nn.set_diag(L, diag_L)
        loss = M.sum()
        # Backward must run without raising for the graph built above.
        loss.backward()
    # Explicit None checks: a tensor produced by a differentiable op carries
    # a grad_fn, while a detached one has grad_fn set to None.
    assert M.grad_fn is not None
    assert L.grad_fn is not None
    assert np.allclose(nn.diag_part(M).detach(), diag_L.detach())