コード例 #1
0
def lds_sample(As, bs, Qi_sqrts, ms, Ri_sqrts, z=None):
    """
    Draw a sample from the Gaussian defined by a linear dynamical system.

    The LDS parameters are first turned into a block tridiagonal precision
    matrix J and a linear potential h.  The result is the mean J^{-1} h plus
    a zero-mean term obtained by solving U x = z, where U is the banded
    transpose of the lower banded Cholesky factor of J.

    If ``z`` is given it is reshaped to length T*D and used as the noise
    vector; otherwise fresh standard normal noise is drawn.  The returned
    array has shape (T, D), where ``T, D = ms.shape``.
    """
    num_steps, dim = ms.shape
    assert As.shape == (num_steps - 1, dim, dim)
    assert bs.shape == (num_steps - 1, dim)
    assert Qi_sqrts.shape == (num_steps - 1, dim, dim)
    assert Ri_sqrts.shape == (num_steps, dim, dim)

    # Precision blocks and linear potential of the equivalent Gaussian chain
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Banded storage lets the banded Lapack routines do the heavy lifting
    precision_banded = blocks_to_bands(J_diag, J_lower_diag, lower=True)
    chol_lower = cholesky_banded(precision_banded, lower=True)
    chol_upper = transpose_banded((2 * dim - 1, 0), chol_lower)

    # With J = U^T U the covariance is U^{-1} U^{-T}, so x = U^{-1} z has the
    # desired covariance; x is found by solving the banded system U x = z.
    if z is None:
        noise = npr.randn(num_steps * dim)
    else:
        noise = np.reshape(z, (num_steps * dim,))
    centered_flat = solve_banded((0, 2 * dim - 1), chol_upper, noise)
    centered = np.reshape(centered_flat, (num_steps, dim))

    # Mean of the Gaussian: mu = J^{-1} h
    mean_flat = solveh_banded(precision_banded, np.ravel(h), lower=True)
    mean = np.reshape(mean_flat, (num_steps, dim))

    return centered + mean
コード例 #2
0
ファイル: primitives.py プロジェクト: zhanghonglishanzai/ssm
    def vjp(C_bar):
        """
        Map the output cotangent ``C_bar`` to a cotangent shaped like
        ``A_banded``.

        ``b``, ``C``, ``A_banded``, ``lower`` and ``kwargs`` come from the
        enclosing scope.
        """
        # Solve the same banded system with C_bar as the right-hand side
        rhs_grad = solveh_banded(A_banded, C_bar, lower=lower, **kwargs)
        band_grad = np.zeros_like(A_banded)

        # Treat a vector right-hand side as a single column
        if b.ndim == 2:
            num_cols = b.shape[1]
        else:
            num_cols = 1

        # NOTE(review): band_grad is presumably filled in place by the helper,
        # since nothing else writes to it before it is returned -- confirm.
        _vjp_solveh_banded_A(
            band_grad,
            rhs_grad.reshape(-1, num_cols),
            C_bar.reshape(-1, num_cols),
            C.reshape(-1, num_cols),
            lower,
            A_banded,
        )

        return band_grad
コード例 #3
0
def lds_mean(As, bs, Qi_sqrts, ms, Ri_sqrts):
    """
    Return the posterior mean of the linear dynamical system.

    The LDS parameters are converted to a block tridiagonal precision J and
    linear potential h, and the banded system J mu = h is solved.  The result
    is reshaped to (T, D), where ``T, D = ms.shape``.
    """
    num_steps, dim = ms.shape
    assert As.shape == (num_steps - 1, dim, dim)
    assert bs.shape == (num_steps - 1, dim)
    assert Qi_sqrts.shape == (num_steps - 1, dim, dim)
    assert Ri_sqrts.shape == (num_steps, dim, dim)

    # Precision blocks and linear potential of the equivalent Gaussian chain
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Banded storage lets the banded Lapack routines do the heavy lifting
    precision_banded = blocks_to_bands(J_diag, J_lower_diag, lower=True)

    mean_flat = solveh_banded(precision_banded, h.ravel(), lower=True)
    return mean_flat.reshape((num_steps, dim))
    
コード例 #4
0
ファイル: primitives.py プロジェクト: hanyas/ssm
def block_tridiagonal_sample(J_diag, J_lower_diag, h, z=None):
    """
    Draw a sample from a Gaussian chain graph given by a block tridiagonal
    precision matrix (``J_diag``, ``J_lower_diag``) and linear potential ``h``.

    The result is the mean J^{-1} h plus a zero-mean term obtained by solving
    U x = z, where U is the banded transpose of the lower banded Cholesky
    factor of J.  If ``z`` is given it is reshaped to length T*D and used as
    the noise vector; otherwise fresh standard normal noise is drawn.  The
    returned array has shape (T, D), where ``T, D = h.shape``.
    """
    num_steps, dim = h.shape
    assert J_diag.shape == (num_steps, dim, dim)
    assert J_lower_diag.shape == (num_steps - 1, dim, dim)

    # Banded storage lets the banded Lapack routines do the heavy lifting
    precision_banded = blocks_to_bands(J_diag, J_lower_diag, lower=True)
    chol_lower = cholesky_banded(precision_banded, lower=True)
    chol_upper = transpose_banded((2 * dim - 1, 0), chol_lower)

    # With J = U^T U the covariance is U^{-1} U^{-T}, so x = U^{-1} z has the
    # desired covariance; x is found by solving the banded system U x = z.
    if z is None:
        noise = npr.randn(num_steps * dim)
    else:
        noise = np.reshape(z, (num_steps * dim,))
    centered_flat = solve_banded((0, 2 * dim - 1), chol_upper, noise)
    centered = np.reshape(centered_flat, (num_steps, dim))

    # Mean of the Gaussian: mu = J^{-1} h
    mean_flat = solveh_banded(precision_banded, np.ravel(h), lower=True)
    mean = np.reshape(mean_flat, (num_steps, dim))

    return centered + mean
コード例 #5
0
def solve_lds(As, bs, Qi_sqrts, ms, Ri_sqrts, v):
    """
    Solve J x = v, where J is the block tridiagonal precision matrix built
    from the LDS parameters, and return x with the same shape as ``v``.

    The linear potential produced by the conversion is not used here.
    """
    J_diag, J_lower_diag, _unused_h = convert_lds_to_block_tridiag(
        As, bs, Qi_sqrts, ms, Ri_sqrts)
    return np.reshape(
        solveh_banded(
            blocks_to_bands(J_diag, J_lower_diag, lower=True),
            np.ravel(v),
            lower=True,
        ),
        v.shape,
    )
コード例 #6
0
 def vjp(C_bar):
     # Solve the same banded system with the output cotangent C_bar as the
     # right-hand side; A_banded, lower and kwargs come from the enclosing
     # scope.
     rhs_grad = solveh_banded(A_banded, C_bar, lower=lower, **kwargs)
     return rhs_grad
コード例 #7
0
ファイル: primitives.py プロジェクト: hanyas/ssm
def block_tridiagonal_mean(J_diag, J_lower_diag, h, lower=True):
    """
    Solve J mu = h for a block tridiagonal J given by its diagonal and
    lower-diagonal blocks.

    ``lower`` is forwarded both to the block-to-band conversion and to the
    banded solver.  ``h`` is flattened before solving, and the solver's
    result is returned as is (it is not reshaped here).
    """
    # Banded storage lets the banded Lapack routines do the heavy lifting
    precision_banded = blocks_to_bands(J_diag, J_lower_diag, lower=lower)
    potential_flat = h.ravel()
    return solveh_banded(precision_banded, potential_flat, lower=lower)
コード例 #8
0
ファイル: primitives.py プロジェクト: hanyas/ssm
def solve_symm_block_tridiag(J_diag, J_lower_diag, v):
    """
    Solve J x = v for a symmetric block tridiagonal J given by its diagonal
    and lower-diagonal blocks, returning x with the same shape as ``v``.
    """
    return np.reshape(
        solveh_banded(
            blocks_to_bands(J_diag, J_lower_diag, lower=True),
            np.ravel(v),
            lower=True,
        ),
        v.shape,
    )