Ejemplo n.º 1
0
def test_vec_to_tril_and_back_exceptions(check_lazy_shapes):
    """Assert that the vec/tril conversions raise `ValueError` on bad inputs.

    Every form of each offending tensor must make the conversion under test
    raise `ValueError`.

    Args:
        check_lazy_shapes: Not used in the body; presumably a pytest fixture
            requested for its side effects (TODO confirm).
    """
    # Rank checks: both conversions must reject `Tensor()` ...
    for scalar in Tensor().forms():
        for convert in (B.vec_to_tril, B.tril_to_vec):
            with pytest.raises(ValueError):
                convert(scalar)

    # ... and `tril_to_vec` must reject `Tensor(3)` (rank check) as well as
    # `Tensor(3, 4)` and `Tensor(3, 4, 5)` (square checks).
    for bad_shape in ((3,), (3, 4), (3, 4, 5)):
        for bad_input in Tensor(*bad_shape).forms():
            with pytest.raises(ValueError):
                B.tril_to_vec(bad_input)
Ejemplo n.º 2
0
def test_vec_to_tril(offset, batch_shape, check_lazy_shapes):
    """Run `check_function` on `B.vec_to_tril` for the given offset.

    Args:
        offset: Offset forwarded to `B.tril_to_vec` and, wrapped in `Value`,
            to `B.vec_to_tril`.
        batch_shape: Leading dimensions placed in front of the vector length
            when building the input `Tensor`.
        check_lazy_shapes: Not used in the body; presumably a pytest fixture
            requested for its side effects (TODO confirm).
    """
    # Derive the vector length from a 7x7 matrix of ones at this offset.
    reference_vec = B.tril_to_vec(B.ones(7, 7), offset=offset)
    vec_length = B.length(reference_vec)

    positional = (Tensor(*batch_shape, vec_length),)
    keywords = {"offset": Value(offset)}
    check_function(B.vec_to_tril, positional, keywords)
Ejemplo n.º 3
0
def test_vec_to_tril_and_back_correctness(offset, batch_shape, check_lazy_shapes):
    """Check that `vec_to_tril` followed by `tril_to_vec` recovers the input.

    Args:
        offset: Offset passed to both conversions.
        batch_shape: Leading dimensions placed in front of the vector length
            when building the input `Tensor`.
        check_lazy_shapes: Not used in the body; presumably a pytest fixture
            requested for its side effects (TODO confirm).
    """
    # Derive the vector length from a 7x7 matrix of ones at this offset.
    template = B.ones(7, 7)
    size = B.length(B.tril_to_vec(template, offset=offset))

    for original in Tensor(*batch_shape, size).forms():
        round_trip = B.tril_to_vec(
            B.vec_to_tril(original, offset=offset),
            offset=offset,
        )
        approx(round_trip, original)
Ejemplo n.º 4
0
 def generate_init(shape, dtype):
     """Build an initial value from a random matrix.

     Draws `B.randn(dtype, *shape)`, converts it with
     `B.tril_to_vec(..., offset=-1)`, and passes the result through
     `transform`, which comes from a scope not visible here.

     Args:
         shape: Dimensions unpacked into `B.randn`.
         dtype: Data type passed as the first argument of `B.randn`.

     Returns:
         Whatever `transform` returns for the extracted vector.
     """
     sample = B.randn(dtype, *shape)
     # offset=-1 presumably keeps only entries strictly below the diagonal
     # (TODO confirm against the `B.tril_to_vec` documentation).
     below_diag = B.tril_to_vec(sample, offset=-1)
     return transform(below_diag)
Ejemplo n.º 5
0
 def inverse_transform(x):
     """Map `x` to a vector via a linear solve against `B.eye(x)`.

     Calls `B.solve(eye + x, eye - x)` and returns the
     `B.tril_to_vec(..., offset=-1)` vector of the solution.

     NOTE(review): this looks like the inverse of a Cayley-type
     parametrisation; confirm against the matching `transform`.
     """
     identity = B.eye(x)
     lhs = identity + x
     rhs = identity - x
     solved = B.solve(lhs, rhs)
     return B.tril_to_vec(solved, offset=-1)
Ejemplo n.º 6
0
 def inverse_transform(x):
     """Return the `tril_to_vec(offset=-1)` vector of `B.logm(x)`."""
     log_of_x = B.logm(x)
     return B.tril_to_vec(log_of_x, offset=-1)
Ejemplo n.º 7
0
 def inverse_transform(x):
     """Map `x` to a vector built from the Cholesky factor of `B.reg(x)`.

     The result concatenates the logarithm of the factor's diagonal with the
     factor's `B.tril_to_vec(..., offset=-1)` vector, in that order.
     """
     # `B.reg` presumably adds a small diagonal regulariser before the
     # factorisation (TODO confirm).
     factor = B.cholesky(B.reg(x))
     log_diagonal = B.log(B.diag(factor))
     below_diagonal = B.tril_to_vec(factor, offset=-1)
     return B.concat(log_diagonal, below_diagonal)
Ejemplo n.º 8
0
 def inverse_transform(x):
     """Return `B.tril_to_vec(x)` using that function's default offset."""
     flattened = B.tril_to_vec(x)
     return flattened