def lds_sample(As, bs, Qi_sqrts, ms, Ri_sqrts, z=None):
    """
    Sample a linear dynamical system.

    The LDS parameters are converted to a block tridiagonal precision matrix J
    and a linear term h (via convert_lds_to_block_tridiag). The returned sample
    is U^{-1} z + J^{-1} h, where J = U^T U comes from a banded Cholesky
    factorization, i.e. a draw with covariance J^{-1} and mean J^{-1} h.

    Parameters
    ----------
    As : array of shape (T-1, D, D)
    bs : array of shape (T-1, D)
    Qi_sqrts : array of shape (T-1, D, D)
    ms : array of shape (T, D)
    Ri_sqrts : array of shape (T, D, D)
    z : array with T*D elements, optional
        Noise to transform into the sample; it is reshaped to a flat
        (T*D,) vector. When None, fresh noise is drawn with npr.randn.

    Returns
    -------
    array of shape (T, D)
    """
    T, D = ms.shape
    assert As.shape == (T-1, D, D)
    assert bs.shape == (T-1, D)
    assert Qi_sqrts.shape == (T-1, D, D)
    assert Ri_sqrts.shape == (T, D, D)

    # Convert to block form
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Convert blocks to banded form so we can capitalize on Lapack code.
    # A block tridiagonal matrix with DxD blocks has 2*D-1 sub-diagonals.
    J_banded = blocks_to_bands(J_diag, J_lower_diag, lower=True)
    L = cholesky_banded(J_banded, lower=True)
    U = transpose_banded((2*D-1, 0), L)

    # We have (U^T U)^{-1} = U^{-1} U^{-T} = AA^T = Sigma
    # where A = U^{-1}. Samples are Az = U^{-1}z = x, or equivalently Ux = z.
    z = npr.randn(T*D,) if z is None else np.reshape(z, (T*D,))
    samples = np.reshape(solve_banded((0, 2*D-1), U, z), (T, D))

    # Get the mean mu = J^{-1} h
    mu = np.reshape(solveh_banded(J_banded, np.ravel(h), lower=True), (T, D))

    # Add the mean
    return samples + mu
def vjp(C_bar):
    """
    Map the output cotangent C_bar to a cotangent for A_banded.

    NOTE(review): this is a closure. A_banded, lower, kwargs, b, C and
    _vjp_solveh_banded_A are free variables bound in an enclosing scope
    that is not shown here. Presumably this is the VJP of
    C = solveh_banded(A_banded, b, lower=lower, **kwargs) with respect to
    A_banded; confirm against where it is registered.

    Returns an array shaped like A_banded.
    """
    # Solve against the same banded matrix using the incoming cotangent as the right-hand side.
    b_bar = solveh_banded(A_banded, C_bar, lower=lower, **kwargs)
    # Output buffer in banded layout; _vjp_solveh_banded_A receives it as its
    # first argument and it is returned unchanged afterwards, so it is
    # presumably filled in place. TODO confirm.
    A_bar = np.zeros_like(A_banded)
    # Number of right-hand-side columns: b may be a single vector (ndim 1) or a matrix.
    K = b.shape[1] if b.ndim == 2 else 1
    # All vectors are viewed as (-1, K) so the helper always sees 2-D inputs.
    _vjp_solveh_banded_A(A_bar, b_bar.reshape(-1, K), C_bar.reshape(-1, K), C.reshape(-1, K), lower, A_banded)
    return A_bar
def lds_mean(As, bs, Qi_sqrts, ms, Ri_sqrts):
    """
    Compute the posterior mean of the linear dynamical system.

    The LDS parameters are converted to a block tridiagonal precision J and
    linear term h (via convert_lds_to_block_tridiag), and the mean J^{-1} h
    is obtained with a banded symmetric solve.

    Parameters
    ----------
    As : array of shape (T-1, D, D)
    bs : array of shape (T-1, D)
    Qi_sqrts : array of shape (T-1, D, D)
    ms : array of shape (T, D)
    Ri_sqrts : array of shape (T, D, D)

    Returns
    -------
    array of shape (T, D)
    """
    T, D = ms.shape
    assert As.shape == (T-1, D, D)
    assert bs.shape == (T-1, D)
    assert Qi_sqrts.shape == (T-1, D, D)
    assert Ri_sqrts.shape == (T, D, D)

    # Convert to block form
    J_diag, J_lower_diag, h = convert_lds_to_block_tridiag(As, bs, Qi_sqrts, ms, Ri_sqrts)

    # Reuse the shared block tridiagonal solver (banded conversion + solveh_banded)
    # instead of duplicating it here; it returns the flat solution.
    return block_tridiagonal_mean(J_diag, J_lower_diag, h, lower=True).reshape((T, D))
def block_tridiagonal_sample(J_diag, J_lower_diag, h, z=None):
    """
    Sample a Gaussian chain graph represented by a block tridiagonal
    precision matrix and a linear potential.

    With J = U^T U (banded Cholesky), the returned sample is
    U^{-1} z + J^{-1} h, i.e. a draw with covariance J^{-1} and mean J^{-1} h.

    Parameters
    ----------
    J_diag : array of shape (T, D, D)
        Diagonal blocks of the precision matrix.
    J_lower_diag : array of shape (T-1, D, D)
        Sub-diagonal blocks of the precision matrix.
    h : array of shape (T, D)
        Linear potential.
    z : array with T*D elements, optional
        Noise to transform into the sample; it is reshaped to (T*D,).
        When None, fresh noise is drawn with npr.randn.

    Returns
    -------
    array of shape (T, D)
    """
    T, D = h.shape
    assert J_diag.shape == (T, D, D)
    assert J_lower_diag.shape == (T - 1, D, D)

    # Convert blocks to banded form so we can capitalize on Lapack code.
    # A block tridiagonal matrix with DxD blocks has 2*D-1 sub-diagonals.
    J_banded = blocks_to_bands(J_diag, J_lower_diag, lower=True)
    L = cholesky_banded(J_banded, lower=True)
    U = transpose_banded((2 * D - 1, 0), L)

    # We have (U^T U)^{-1} = U^{-1} U^{-T} = AA^T = Sigma
    # where A = U^{-1}. Samples are Az = U^{-1}z = x, or equivalently Ux = z.
    z = npr.randn(T * D, ) if z is None else np.reshape(z, (T * D, ))
    samples = np.reshape(solve_banded((0, 2 * D - 1), U, z), (T, D))

    # Get the mean mu = J^{-1} h
    mu = np.reshape(solveh_banded(J_banded, np.ravel(h), lower=True), (T, D))

    # Add the mean
    return samples + mu
def solve_lds(As, bs, Qi_sqrts, ms, Ri_sqrts, v):
    """
    Solve J x = v, where J is the block tridiagonal precision matrix built
    from the LDS parameters by convert_lds_to_block_tridiag.

    The linear term produced by the conversion is not needed and is discarded.
    The result has the same shape as v.
    """
    J_diag, J_lower_diag, _ = convert_lds_to_block_tridiag(As, bs, Qi_sqrts, ms, Ri_sqrts)
    # Delegate to the generic symmetric block tridiagonal solver rather than
    # repeating its banded conversion / solve / reshape steps here.
    return solve_symm_block_tridiag(J_diag, J_lower_diag, v)
def vjp(C_bar):
    """
    Map the output cotangent C_bar to a cotangent for the right-hand side.

    NOTE(review): closure over A_banded, lower and kwargs from an enclosing
    scope not shown here. Presumably this is the VJP of
    C = solveh_banded(A_banded, b) with respect to b; because the solver
    treats A_banded as symmetric, the cotangent is one more solve against
    the same matrix. Confirm against where it is registered.
    """
    rhs_cotangent = solveh_banded(A_banded, C_bar, lower=lower, **kwargs)
    return rhs_cotangent
def block_tridiagonal_mean(J_diag, J_lower_diag, h, lower=True):
    """
    Solve J mu = h for a symmetric block tridiagonal J.

    J is given by its diagonal blocks (J_diag) and off-diagonal blocks
    (J_lower_diag). The same `lower` flag is passed to blocks_to_bands and
    solveh_banded so the banded storage and the solver agree. h is raveled
    before the solve, so the solution comes back flat.
    """
    # Banded form lets the solve use Lapack's banded routines.
    banded_J = blocks_to_bands(J_diag, J_lower_diag, lower=lower)
    flat_h = h.ravel()
    return solveh_banded(banded_J, flat_h, lower=lower)
def solve_symm_block_tridiag(J_diag, J_lower_diag, v):
    """
    Solve J x = v for a symmetric block tridiagonal J.

    J is described by its diagonal blocks and lower off-diagonal blocks and
    is converted to lower banded storage for the solve. The right-hand side
    is flattened first, and the solution is returned in v's original shape.
    """
    out_shape = v.shape
    lower_bands = blocks_to_bands(J_diag, J_lower_diag, lower=True)
    rhs = np.ravel(v)
    solution = solveh_banded(lower_bands, rhs, lower=True)
    return np.reshape(solution, out_shape)