def _forecast(self, N_timesteps, filtering_state, include_observation_noise=True):
    """
    Internal helper for forecasting.

    Rolls the filtered state distribution forward ``N_timesteps`` steps through
    the transition and observation matrices and returns the predictive mean and
    covariance in observation space.

    :param int N_timesteps: number of future timesteps to forecast.
    :param filtering_state: distribution with ``loc`` and ``covariance_matrix``
        describing the current filtered latent state.
    :param bool include_observation_noise: whether to add the observation-noise
        variance to the predictive covariance.
    :returns: tuple ``(predicted_mean, predicted_covar)``.
    """
    # Powers of the transition matrix, pushed into observation space:
    # trans_obs[t] maps the current latent state (t + 1) steps ahead and observes it.
    trans_powers = repeated_matmul(self.trans_matrix, N_timesteps)
    trans_obs = torch.matmul(trans_powers, self.obs_matrix)

    predicted_mean = torch.matmul(filtering_state.loc, trans_obs)

    # Contribution of the filtering-state covariance, propagated to each horizon.
    covar_from_filter = torch.matmul(
        trans_obs.transpose(-1, -2),
        torch.matmul(filtering_state.covariance_matrix, trans_obs))  # N O O

    # Contribution of the process noise injected at every timestep. The shifted
    # observation maps weight noise injected at earlier steps, and the cumulative
    # sum below integrates those contributions across time.
    process_covar = self._get_trans_dist().covariance_matrix
    trans_obs_shifted = torch.cat([self.obs_matrix.unsqueeze(0), trans_obs[:-1]])
    covar_from_noise = torch.matmul(
        trans_obs_shifted.transpose(-1, -2),
        torch.matmul(process_covar, trans_obs_shifted))  # N O O

    predicted_covar = covar_from_filter + torch.cumsum(covar_from_noise, dim=0)
    if include_observation_noise:
        predicted_covar = predicted_covar + self.obs_noise_scale.pow(2.0).diag_embed()

    return predicted_mean, predicted_covar
def _forecast(self, N_timesteps, filtering_state, include_observation_noise=True):
    """
    Internal helper for forecasting.

    Combines the z-state and GP-state pieces of the filtered distribution to
    produce the predictive mean and covariance ``N_timesteps`` steps ahead.

    :param int N_timesteps: number of future timesteps to forecast.
    :param filtering_state: distribution with ``loc`` and ``covariance_matrix``
        over the concatenated (gp-state, z-state) latent space.
    :param bool include_observation_noise: whether to add the observation-noise
        variance to the predictive covariance.
    :returns: tuple ``(predicted_mean, predicted_covar)``.
    """
    # Time offsets 1..N, reshaped so the kernel broadcasts over them.
    dts = torch.arange(N_timesteps, dtype=self.z_trans_matrix.dtype,
                       device=self.z_trans_matrix.device) + 1.0
    dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    # GP transition and process-noise terms at each forecast horizon.
    gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(dt=dts)
    gp_trans_matrix = block_diag_embed(gp_trans_matrix)
    gp_process_covar = block_diag_embed(gp_process_covar[..., 0:1, 0:1])

    # Powers of the z transition matrix, pushed into observation space.
    N_trans_matrix = repeated_matmul(self.z_trans_matrix, N_timesteps)
    N_trans_obs = torch.matmul(N_trans_matrix, self.z_obs_matrix)

    # Predictive mean: z-state contribution + gp contribution.
    loc = filtering_state.loc
    gp_obs = gp_trans_matrix[..., self.obs_selector]
    mean_z = torch.matmul(loc[-self.state_dim:].unsqueeze(-2), N_trans_obs).squeeze(-2)
    mean_gp = torch.matmul(loc[:self.full_gp_state_dim].unsqueeze(-2), gp_obs).squeeze(-2)
    predicted_mean = mean_z + mean_gp

    # Contributions from filtering_state.covariance_matrix: z-block and gp-block.
    fs_cov = filtering_state.covariance_matrix
    covar_filter_z = torch.matmul(
        N_trans_obs.transpose(-1, -2),
        torch.matmul(fs_cov[self.full_gp_state_dim:, self.full_gp_state_dim:],
                     N_trans_obs))  # N O O
    covar_filter_gp = torch.matmul(
        gp_obs.transpose(-1, -2),
        torch.matmul(fs_cov[:self.full_gp_state_dim, :self.full_gp_state_dim], gp_obs))

    # Process noise injected into the z-state at each timestep; the cumulative
    # sum below integrates those injections across time.
    z_process_covar = self.trans_noise_scale_sq.diag_embed()
    N_trans_obs_shift = torch.cat([self.z_obs_matrix.unsqueeze(0), N_trans_obs[:-1]])
    covar_noise_z = torch.matmul(
        N_trans_obs_shift.transpose(-1, -2),
        torch.matmul(z_process_covar, N_trans_obs_shift))  # N O O

    predicted_covar = (covar_filter_z + covar_filter_gp + gp_process_covar
                       + torch.cumsum(covar_noise_z, dim=0))
    if include_observation_noise:
        predicted_covar = predicted_covar + self.obs_noise_scale.pow(2.0).diag_embed()

    return predicted_mean, predicted_covar
def test_repeated_matmul(size, n):
    """Check that repeated_matmul(M, n) stacks the first n matrix powers of M,
    i.e. result[i] == M^(i + 1), by comparing against a serial computation."""
    base = torch.randn(size)
    stacked = repeated_matmul(base, n)
    assert stacked.shape == (n,) + size

    # Build M, M @ M, M @ M @ M, ... one step at a time and compare slice-by-slice.
    power = base
    for k in range(n):
        assert_equal(stacked[k, ...], power)
        power = torch.matmul(power, base)