def filter_forward_predictive_distribution(m, V, Q, C, d=None):
    """
    The predictive distribution p(y_t | y_{1:t-1}), obtained by
    marginalising the state prediction distribution in the measurement model.
    """
    Vpy = symmetrize(matmul(C, matmul(V, C.transpose(-1, -2))) + Q)
    mpy = matvec(C, m)
    if d is not None:
        mpy = mpy + d  # avoid in-place add; see note in the prediction step
    return mpy, Vpy
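# A minimal usage sketch (illustrative, not part of the library API): for the
# measurement model y_t = C x_t + d + eps_t, eps_t ~ N(0, Q), the function
# above returns the moments of p(y_t | y_{1:t-1}) = N(C m + d, C V C^T + Q).
# It assumes `torch` and the module-level helpers are in scope.
def _example_predictive_distribution():
    n, p = 3, 2  # hypothetical state / observation dimensions
    m, V = torch.zeros(n), torch.eye(n)
    C, Q = torch.randn(p, n), torch.eye(p)
    mpy, Vpy = filter_forward_predictive_distribution(m=m, V=V, Q=Q, C=C)
    # sanity check against the closed-form predictive covariance
    assert torch.allclose(Vpy, C @ V @ C.T + Q, atol=1e-6)
    return mpy, Vpy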
def filter_forward_prediction_step(m, V, R, A, b=None):
    """
    Single prediction step in the forward filtering cycle
    (prediction --> measurement).
    """
    mp = matvec(A, m) if A is not None else m
    if b is not None:
        mp = mp + b  # do not use 'mp += b': in-place ops break autograd
    Vp = symmetrize(
        (matmul(A, matmul(V, A.transpose(-1, -2))) if A is not None else V)
        + R)
    return mp, Vp
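# A single prediction step in isolation (a sketch with hypothetical
# dimensions): propagates N(m, V) through x_{t+1} = A x_t + b + w_t,
# w_t ~ N(0, R), yielding mp = A m + b and Vp = A V A^T + R.
def _example_prediction_step():
    n = 3
    m, V = torch.zeros(n), torch.eye(n)
    A, R = 0.9 * torch.eye(n), 0.1 * torch.eye(n)
    mp, Vp = filter_forward_prediction_step(m=m, V=V, R=R, A=A)
    assert torch.allclose(Vp, A @ V @ A.T + R, atol=1e-6)
    return mp, Vp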
def smooth_backward_step(m_sm, V_sm, m_fw, V_fw, A, b, R):
    """ Single Rauch-Tung-Striebel backward smoothing step. """
    # filter one-step predictive covariance P = A V_fw A^T + R
    P = (matmul(A, matmul(V_fw, A.transpose(-1, -2)))
         if A is not None else V_fw) + R
    Pinv = batch_cholesky_inverse(cholesky(P))
    # m and V share the gain J in the conditioning operation
    # (joint --> posterior, backward); A is None stands for the identity
    J = (matmul(V_fw, matmul(A.transpose(-1, -2), Pinv))
         if A is not None else matmul(V_fw, Pinv))
    m_sm_t = m_fw + matvec(
        J, m_sm - (matvec(A, m_fw) if A is not None else m_fw) - b)
    V_sm_t = symmetrize(
        V_fw + matmul(J, matmul(V_sm - P, J.transpose(-1, -2))))
    Cov_sm_t = matmul(J, V_sm)
    return m_sm_t, V_sm_t, Cov_sm_t
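# A sketch of how this step is typically chained into a full backward pass
# (the library may ship its own driver; m_fw / V_fw are stacked filter
# moments of shape (T, ...) and A, b, R are assumed time-invariant here).
# The returned cross-covariance Cov(x_t, x_{t+1} | y_{1:T}) is needed
# e.g. for EM parameter updates.
def _example_rts_backward_pass(m_fw, V_fw, A, b, R):
    T = m_fw.shape[0]
    m_sm, V_sm, Cov_sm = [m_fw[-1]], [V_fw[-1]], []
    for t in reversed(range(T - 1)):
        m_t, V_t, C_t = smooth_backward_step(
            m_sm=m_sm[0], V_sm=V_sm[0], m_fw=m_fw[t], V_fw=V_fw[t],
            A=A, b=b, R=R)
        m_sm.insert(0, m_t)
        V_sm.insert(0, V_t)
        Cov_sm.insert(0, C_t)
    return torch.stack(m_sm), torch.stack(V_sm), torch.stack(Cov_sm)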
def filter_forward_measurement_step(y, m, V, Q, C, d=None,
                                    return_loss_components=False):
    """
    Single measurement/update step in the forward filtering cycle
    (prediction --> measurement).
    """
    mpy, Vpy = filter_forward_predictive_distribution(m=m, V=V, Q=Q, C=C, d=d)
    LVpyinv = torch.inverse(cholesky(Vpy))
    # C V = (V C^T)^T is the transposed state-observation covariance;
    # V C^T could be returned by the predictive distribution to avoid
    # recomputing it here.
    S = matmul(LVpyinv, matmul(C, V))
    dobs = y - mpy
    dobs_norm = matvec(LVpyinv, dobs)
    mt = m + matvec(S.transpose(-1, -2), dobs_norm)
    Vt = symmetrize(V - matmul(S.transpose(-1, -2), S))
    if return_loss_components:
        # returning internals is not clean, but convenient for loss terms
        return mt, Vt, dobs_norm, LVpyinv
    else:
        return mt, Vt
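# A minimal forward-filtering driver (a sketch assuming time-invariant
# A, C, R, Q and observations y of shape (T, ..., obs); the library may
# provide its own loop): alternate prediction and measurement steps.
def _example_forward_filter(y, m0, V0, A, C, R, Q):
    m, V = m0, V0
    means, covs = [], []
    for t in range(y.shape[0]):
        if t > 0:  # no prediction before the first measurement
            m, V = filter_forward_prediction_step(m=m, V=V, R=R, A=A)
        m, V = filter_forward_measurement_step(y=y[t], m=m, V=V, Q=Q, C=C)
        means.append(m)
        covs.append(V)
    return torch.stack(means), torch.stack(covs)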
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """ Compute the posterior by direct inversion of the unrolled model. """
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state,
         dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device, dtype=dtype)

    # pre-compute biases; they vanish when there is no control input
    b = matvec(B, u_state) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(torch.cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(torch.cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(torch.cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape(
        (dims.batch, dims.timesteps * dims.state))
    Q_obs = kron(
        torch.eye(dims.timesteps, dtype=dtype, device=device),
        matmul(C.transpose(-1, -2), matmul(Qinv, C)),
    )

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat(
        (dims.batch, ) + (1, ) * (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat(
        (dims.batch, ) + (1, ) * (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        idx = t * dims.state
        if b is not None:  # guard: b[t] is undefined without control inputs
            h_field[:, idx:idx + dims.state] += -matvec(
                RinvA.transpose(-1, -2), b[t])
            h_field[:, idx + dims.state:idx + 2 * dims.state] += matvec(
                Rinv, b[t])
        Q_field[:, idx:idx + dims.state, idx:idx + dims.state] += AtRinvA
        Q_field[:, idx:idx + dims.state,
                idx + dims.state:idx + 2 * dims.state] += (
                    -RinvA.transpose(-1, -2))
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx:idx + dims.state] += -RinvA
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx + dims.state:idx + 2 * dims.state] += Rinv

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-style (column-major) reshape, hence transpose.
    m = m_all.reshape(
        (dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        idx = t * dims.state
        V.append(V_all[:, idx:idx + dims.state, idx:idx + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, idx:idx + dims.state,
                             idx + dims.state:idx + 2 * dims.state])
        else:
            # no successor at the final time step; pad with zeros
            Cov.append(
                torch.zeros((dims.batch, dims.state, dims.state),
                            device=device, dtype=dtype))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)
    return m, V, Cov
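# The function above works in information form: Q_obs = I_T kron (C^T Q^-1 C)
# collects the block-diagonal observation terms, while Q_field accumulates
# the prior and transition terms on the block tri-diagonal. A tiny
# illustrative check of the Kronecker structure (assumes the module-level
# `kron` behaves like torch.kron):
def _example_kron_structure():
    T, n = 3, 2  # hypothetical horizon and state dimension
    block = torch.arange(float(n * n)).reshape(n, n)
    full = kron(torch.eye(T), block)
    # every diagonal n x n block equals `block`; off-diagonal blocks are zero
    assert torch.allclose(full[n:2 * n, n:2 * n], block)
    assert torch.all(full[:n, n:2 * n] == 0)
    return full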
def smooth_global_timevarying(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """
    Time-varying variant of `smooth_global` (renamed to avoid shadowing):
    compute the posterior by direct inversion of the unrolled model with
    per-time-step system matrices.
    """
    # This implementation requires all system matrices to carry leading time
    # and batch dimensions; otherwise broadcasting is incorrect.
    assert A.ndim == 4
    if B is not None:  # B is optional
        assert B.ndim == 4
    assert C.ndim == 4
    if D is not None:  # D is optional
        assert D.ndim == 4
    assert LQinv_tril.ndim == 4
    assert LQinv_logdiag.ndim == 3
    assert LRinv_tril.ndim == 4
    assert LRinv_logdiag.ndim == 3

    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state,
         dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device, dtype=dtype)

    # pre-compute biases; only T - 1 transitions need a state bias
    b = matvec(B, u_state[:-1]) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(batch_cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(batch_cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(batch_cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape(
        (dims.batch, dims.timesteps * dims.state))

    CtQinvC = matmul(C.transpose(-1, -2), matmul(Qinv, C))
    assert len(CtQinvC) == dims.timesteps
    Q_obs = make_block_diagonal(mats=tuple(CtQinvC))

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat(
        (dims.batch, ) + (1, ) * (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat(
        (dims.batch, ) + (1, ) * (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        idx = t * dims.state  # 'idx' rather than 'id' (shadows the builtin)
        if b is not None:  # guard: b[t] is undefined without control inputs
            h_field[:, idx:idx + dims.state] += -matvec(
                RinvA[t].transpose(-1, -2), b[t])
            h_field[:, idx + dims.state:idx + 2 * dims.state] += matvec(
                Rinv[t], b[t])
        Q_field[:, idx:idx + dims.state, idx:idx + dims.state] += AtRinvA[t]
        Q_field[:, idx:idx + dims.state,
                idx + dims.state:idx + 2 * dims.state] += (
                    -RinvA[t].transpose(-1, -2))
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx:idx + dims.state] += -RinvA[t]
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx + dims.state:idx + 2 * dims.state] += Rinv[t]

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-style (column-major) reshape, hence transpose.
    m = m_all.reshape(
        (dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        idx = t * dims.state
        V.append(V_all[:, idx:idx + dims.state, idx:idx + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, idx:idx + dims.state,
                             idx + dims.state:idx + 2 * dims.state])
        else:
            # no successor at the final time step; pad with zeros
            Cov.append(
                torch.zeros((dims.batch, dims.state, dims.state),
                            device=device, dtype=dtype))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)
    return m, V, Cov
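# For reference, a sketch of the covariance parameterisation both variants
# consume (an assumption consistent with how cov_from_invcholesky_param is
# called above, NOT necessarily the library's exact definition): the inverse
# Cholesky factor is parameterised as L^-1 = strict_tril + diag(exp(logdiag)),
# so that Cov = L L^T with L = (L^-1)^-1.
def _example_cov_from_invcholesky(tril, logdiag):
    Linv = torch.tril(tril, diagonal=-1) + torch.diag_embed(
        torch.exp(logdiag))
    L = torch.inverse(Linv)
    return L @ L.transpose(-1, -2)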