def kernel_to_state_space(self, R=None):
    """Combine the component kernels' state-space models into one model.

    The components are self.kernel0 ... self.kernel{num_kernels-1}. F, L, Qc
    and Pinf are stacked block-diagonally in that order, and the measurement
    rows H are concatenated side by side, so the combined observation is
    H0 x0 + H1 x1 + ... (the sum of the component outputs).

    :param R: forwarded unchanged to every component's kernel_to_state_space
    :return: F, L, Qc, H, Pinf of the combined model
    """
    F, L, Qc, H, Pinf = self.kernel0.kernel_to_state_space(R)
    for i in range(1, self.num_kernels):
        # Plain attribute lookup; avoids eval() on a constructed code string.
        kernel_i = getattr(self, "kernel" + str(i))
        F_, L_, Qc_, H_, Pinf_ = kernel_i.kernel_to_state_space(R)
        F = block_diag(F, F_)
        L = block_diag(L, L_)
        Qc = block_diag(Qc, Qc_)
        # Horizontal concatenation: a single output row summing all components.
        H = np.block([H, H_])
        Pinf = block_diag(Pinf, Pinf_)
    return F, L, Qc, H, Pinf
def stationary_covariance(self):
    """Stationary covariance of the combined state.

    Each component kernel's stationary covariance (self.kernel0,
    self.kernel1, ..., up to self.num_kernels) is placed on the diagonal
    in that order.

    :return: block-diagonal stationary covariance Pinf
    """
    Pinf = self.kernel0.stationary_covariance()
    for i in range(1, self.num_kernels):
        # Plain attribute lookup; avoids eval() on a constructed code string.
        kernel_i = getattr(self, "kernel" + str(i))
        Pinf = block_diag(Pinf, kernel_i.stationary_covariance())
    return Pinf
def measurement_model(self):
    """Measurement matrix of the combined state.

    Each component kernel's measurement matrix (self.kernel0, self.kernel1,
    ..., up to self.num_kernels) is placed block-diagonally, so each
    component contributes its own rows of H acting only on its own states.

    :return: block-diagonal measurement matrix H
    """
    H = self.kernel0.measurement_model()
    for i in range(1, self.num_kernels):
        # Plain attribute lookup; avoids eval() on a constructed code string.
        kernel_i = getattr(self, "kernel" + str(i))
        H = block_diag(H, kernel_i.measurement_model())
    return H
def get_meanfield_block_index(kernel):
    """Return the indices of the diagonal blocks of a mean-field state.

    The first two dimensions of ``kernel.stationary_covariance_meanfield()``
    are read as (number of blocks, block size). Presumably the array is
    shaped [num_latents, sub_state_dim, sub_state_dim] — TODO confirm.

    A mask is built with one all-ones square block per latent on the
    diagonal; at least one block is always present. The indices of the
    mask's nonzero entries are then returned.

    :param kernel: object providing ``stationary_covariance_meanfield()``
    :return: tuple of (row, column) index arrays, as produced by ``np.where``
    """
    pinf_meanfield = kernel.stationary_covariance_meanfield()
    n_blocks = pinf_meanfield.shape[0]
    block_size = pinf_meanfield.shape[1]
    # A single variadic call replaces the pairwise accumulation; max(..., 1)
    # keeps the one-block result when n_blocks is 0 or 1.
    unit_blocks = [np.ones([block_size, block_size])
                   for _ in range(max(n_blocks, 1))]
    mask = block_diag(*unit_blocks)
    return np.where(np.array(mask, dtype=bool))
def state_transition(self, dt):
    """
    Calculation of the discrete-time state transition matrix A = expm(FΔt) for a sum of GPs
    :param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
    :return: state transition matrix A [D, D]
    """
    # Each component's transition matrix occupies its own diagonal block,
    # in the order kernel0, kernel1, ..., kernel{num_kernels-1}.
    A = self.kernel0.state_transition(dt)
    for i in range(1, self.num_kernels):
        # Plain attribute lookup; avoids eval() on a constructed code string.
        kernel_i = getattr(self, "kernel" + str(i))
        A = block_diag(A, kernel_i.state_transition(dt))
    return A
def stationary_covariance(self):
    """Stationary covariance of the quasi-periodic state.

    First, coefficients q2 are computed from self.b_fmK_2igrid, self.igrid
    and self.lengthscale_periodic, using a unit periodic variance (var_p = 1).
    The result is block_diag over the entries q2_j of
    kron([[self.variance]], q2_j * I_2), i.e. one 2x2 block per entry of q2.

    :return: stationary covariance Pinf of size 2 * len(q2)
    """
    var_p = 1.
    ell_p = self.lengthscale_periodic
    a = self.b_fmK_2igrid * ell_p**(-2. * self.igrid) * np.exp(
        -1. / ell_p**2.) * var_p
    q2 = np.sum(a, axis=0)
    Pinf_m = np.array([[self.variance]])
    # One block per entry of q2 (presumably one per harmonic, order + 1 of
    # them) instead of a hard-coded seven. Seven blocks only fit order == 6;
    # other orders raised IndexError or silently truncated.
    Pinf = block_diag(*[np.kron(Pinf_m, q2_j * np.eye(2)) for q2_j in q2])
    return Pinf
def state_transition(self, dt):
    """
    Calculation of the closed form discrete-time state transition matrix A = expm(FΔt) for the Quasi-Periodic Matern-3/2 prior
    :param dt: step size(s), Δt = tₙ - tₙ₋₁ [M+1, 1]
    :return: state transition matrix A [M+1, D, D]
    """
    lam = np.sqrt(3.0) / self.lengthscale_matern
    # The angular frequency
    omega = 2 * np.pi / self.period
    # Harmonic frequencies 0, ω, 2ω, ..., order·ω
    harmonics = np.arange(self.order + 1) * omega
    # One sub-band block per harmonic. The count follows self.order rather
    # than a hard-coded seven, which only worked for order == 6.
    sub_bands = [self.subband_mat32(dt, lam, harmonic) for harmonic in harmonics]
    A = np.exp(-dt * lam) * block_diag(*sub_bands)
    return A
def kernel_to_state_space(self, R=None):
    """Continuous-time state-space form of the quasi-periodic Matern-3/2 kernel.

    The periodic part has self.order + 1 harmonics of angular frequency
    2π / self.period, each a 2-state rotation. Its per-harmonic variances q2
    come from self.b_fmK_2igrid, self.igrid and self.lengthscale_periodic.
    The Matern part is the 2-state model F_m = [[0, 1], [-λ², -2λ]] with
    λ = √3 / self.lengthscale_matern. The two parts are combined with
    Kronecker products.

    :param R: not used by this implementation (presumably kept so all
        kernels share one signature)
    :return: F, L, Qc, H, Pinf
    """
    var_p = 1.
    ell_p = self.lengthscale_periodic
    a = self.b_fmK_2igrid * ell_p**(-2. * self.igrid) * np.exp(
        -1. / ell_p**2.) * var_p
    q2 = np.sum(a, axis=0)
    # The angular frequency
    omega = 2 * np.pi / self.period
    # Periodic model: one 2x2 rotation generator per harmonic 0..order
    F_p = np.kron(np.diag(np.arange(self.order + 1)),
                  np.array([[0., -omega], [omega, 0.]]))
    L_p = np.eye(2 * (self.order + 1))
    Pinf_p = np.kron(np.diag(q2), np.eye(2))
    H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
    # Matern-3/2 model
    lam = 3.0**0.5 / self.lengthscale_matern
    F_m = np.array([[0.0, 1.0], [-lam**2, -2 * lam]])
    L_m = np.array([[0], [1]])
    Qc_m = np.array(
        [[12.0 * 3.0**0.5 / self.lengthscale_matern**3.0 * self.variance]])
    H_m = np.array([[1.0, 0.0]])
    Pinf_m = np.array(
        [[self.variance, 0.0],
         [0.0, 3.0 * self.variance / self.lengthscale_matern**2.0]])
    # Combined model; Matern-major state ordering (periodic index innermost)
    F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p)
    L = np.kron(L_m, L_p)
    Qc = np.kron(Qc_m, Pinf_p)
    H = np.kron(H_m, H_p)
    # One 4x4 block per entry of q2 (order + 1 of them) instead of a
    # hard-coded seven. With seven, Pinf only matched F's size when order == 6.
    # NOTE(review): this Pinf is assembled harmonic by harmonic (periodic
    # index outermost), whereas F, L, Qc and H above use Matern-major
    # ordering; np.kron(Pinf_m, Pinf_p) would be the ordering that solves
    # F P + P F^T + L Qc L^T = 0 for this F. The per-harmonic layout is kept
    # because it matches block_diag-over-harmonics code such as
    # state_transition — confirm which ordering callers expect.
    Pinf = block_diag(*[np.kron(Pinf_m, q2_j * np.eye(2)) for q2_j in q2])
    return F, L, Qc, H, Pinf
def block_diag(values):
    """Build a block-diagonal matrix from a list of (possibly boxed) arrays.

    ``values`` is unboxed with ``JaxBox.unbox_list`` and the results are
    passed as separate positional arguments to ``block_diag``.

    :param values: list of matrix-like objects accepted by JaxBox.unbox_list
    :return: whatever the underlying ``block_diag`` returns
    """
    # NOTE(review): if this def sits at module scope, the name ``block_diag``
    # below resolves to this same function, which takes a single argument.
    # It would then call itself with the unpacked arrays (a TypeError for 2+
    # arrays). Presumably this is meant to be a staticmethod inside a class
    # body, where ``block_diag`` refers to an imported variadic function
    # (e.g. scipy/jax ``block_diag``) — confirm the enclosing scope.
    return block_diag(*JaxBox.unbox_list(values))