def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
    to :class:`GenericLGSSMWithGPNoiseModel`.
    """
    gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
    # transition covariance: GP process noise in the upper-left block,
    # learned diagonal transition noise for the z-state in the lower-right block
    trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim, self.full_state_dim)
    trans_covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_process_covar)
    trans_covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.trans_noise_scale_sq.diag_embed()
    trans_dist = MultivariateNormal(trans_covar.new_zeros(self.full_state_dim), trans_covar)
    # transition matrix: block-diagonal GP dynamics alongside the z-state dynamics
    full_trans_mat = trans_covar.new_zeros(self.full_state_dim, self.full_state_dim)
    full_trans_mat[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_trans_matrix)
    full_trans_mat[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.z_trans_matrix
    return dist.GaussianHMM(self._get_init_dist(), full_trans_mat, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist())

def _get_dist(self):
    """
    Get the `GaussianHMM` distribution that corresponds to `GenericLGSSMWithGPNoiseModel`.
    """
    gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim, self.full_state_dim)
    trans_covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_process_covar)
    eye = torch.eye(self.state_dim, device=trans_covar.device, dtype=trans_covar.dtype)
    trans_covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.log_trans_noise_scale_sq.exp() * eye
    trans_dist = MultivariateNormal(trans_covar.new_zeros(self.full_state_dim), trans_covar)
    full_trans_mat = trans_covar.new_zeros(self.full_state_dim, self.full_state_dim)
    full_trans_mat[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_trans_matrix)
    full_trans_mat[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.z_trans_matrix
    return dist.GaussianHMM(self._get_init_dist(), full_trans_mat, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist())

def get_dist(self, duration=None):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
    to :class:`GenericLGSSMWithGPNoiseModel`.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """
    gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim, self.full_state_dim)
    trans_covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_process_covar)
    trans_covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.trans_noise_scale_sq.diag_embed()
    trans_dist = MultivariateNormal(trans_covar.new_zeros(self.full_state_dim), trans_covar)
    full_trans_mat = trans_covar.new_zeros(self.full_state_dim, self.full_state_dim)
    full_trans_mat[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(gp_trans_matrix)
    full_trans_mat[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.z_trans_matrix
    return dist.GaussianHMM(self._get_init_dist(), full_trans_mat, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist(), duration=duration)

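# A minimal usage sketch for get_dist. The constructor arguments (obs_dim,
# state_dim, nu) and their values are assumptions for illustration, not taken
# from the snippet above; check the class signature before relying on them.
import torch
from pyro.contrib.timeseries import GenericLGSSMWithGPNoiseModel

model = GenericLGSSMWithGPNoiseModel(obs_dim=2, state_dim=3, nu=1.5)
hmm = model.get_dist(duration=50)   # homogeneous GaussianHMM with event_shape (50, 2)
y = hmm.sample()                    # sampling requires duration, per the docstring above
assert y.shape == (50, 2)
print(hmm.log_prob(y))              # scalar log-likelihood of the sampled series
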
def _forecast(self, dts, filtering_state, include_observation_noise=True, full_covar=True):
    """
    Internal helper for forecasting.
    """
    assert dts.dim() == 1
    dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    trans_mat, process_covar = self.kernel.transition_matrix_and_covariance(dt=dts)
    trans_mat = block_diag_embed(trans_mat)  # S x full_state_dim x full_state_dim
    process_covar = block_diag_embed(process_covar)  # S x full_state_dim x full_state_dim
    obs_matrix = self._get_obs_matrix()  # full_state_dim x obs_dim
    trans_obs = torch.matmul(trans_mat, obs_matrix)  # S x full_state_dim x obs_dim
    predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_obs).squeeze(-2)
    predicted_function_covar = torch.matmul(trans_obs.transpose(-1, -2),
                                            torch.matmul(filtering_state.covariance_matrix, trans_obs))
    predicted_function_covar = predicted_function_covar + \
        torch.matmul(obs_matrix.transpose(-1, -2), torch.matmul(process_covar, obs_matrix))

    if include_observation_noise:
        obs_noise = self.obs_noise_scale.pow(2.0).diag_embed()
        predicted_function_covar = predicted_function_covar + obs_noise
    if not full_covar:
        predicted_function_covar = predicted_function_covar.diagonal(dim1=-1, dim2=-2)

    return predicted_mean, predicted_function_covar

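# For context, a hedged sketch of how this internal helper is typically reached
# through the public forecast() API of LinearlyCoupledMaternGP; the constructor
# argument values below are assumptions for illustration.
import torch
from pyro.contrib.timeseries import LinearlyCoupledMaternGP

gp = LinearlyCoupledMaternGP(nu=1.5, dt=0.1, obs_dim=2, num_gps=3)
targets = torch.randn(20, 2)            # 20 observed timesteps of dimension obs_dim
dts = torch.tensor([1.0, 2.0, 5.0])     # time gaps after the final observation
pred = gp.forecast(targets, dts)        # MultivariateNormal: batch (3,), event (2,)
print(pred.mean.shape, pred.covariance_matrix.shape)
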
def _forecast(self, N_timesteps, filtering_state, include_observation_noise=True):
    """
    Internal helper for forecasting.
    """
    dts = torch.arange(N_timesteps, dtype=self.z_trans_matrix.dtype,
                       device=self.z_trans_matrix.device) + 1.0
    dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(dt=dts)
    gp_trans_matrix = block_diag_embed(gp_trans_matrix)
    gp_process_covar = block_diag_embed(gp_process_covar[..., 0:1, 0:1])
    N_trans_matrix = repeated_matmul(self.z_trans_matrix, N_timesteps)
    N_trans_obs = torch.matmul(N_trans_matrix, self.z_obs_matrix)

    # z-state contribution + gp contribution
    predicted_mean1 = torch.matmul(filtering_state.loc[-self.state_dim:].unsqueeze(-2),
                                   N_trans_obs).squeeze(-2)
    predicted_mean2 = torch.matmul(filtering_state.loc[:self.full_gp_state_dim].unsqueeze(-2),
                                   gp_trans_matrix[..., self.obs_selector]).squeeze(-2)
    predicted_mean = predicted_mean1 + predicted_mean2

    # first compute the contributions from filtering_state.covariance_matrix: z-space and gp
    fs_cov = filtering_state.covariance_matrix
    predicted_covar1z = torch.matmul(N_trans_obs.transpose(-1, -2),
                                     torch.matmul(fs_cov[self.full_gp_state_dim:, self.full_gp_state_dim:],
                                                  N_trans_obs))  # N O O
    gp_trans = gp_trans_matrix[..., self.obs_selector]
    predicted_covar1gp = torch.matmul(gp_trans.transpose(-1, -2),
                                      torch.matmul(fs_cov[:self.full_gp_state_dim, :self.full_gp_state_dim],
                                                   gp_trans))  # N O O

    # next compute the contribution from process noise that is injected at each timestep.
    # (we need to do a cumulative sum to integrate across time for the z-state contribution)
    z_process_covar = self.trans_noise_scale_sq.diag_embed()
    N_trans_obs_shift = torch.cat([self.z_obs_matrix.unsqueeze(0), N_trans_obs[0:-1]])
    predicted_covar2z = torch.matmul(N_trans_obs_shift.transpose(-1, -2),
                                     torch.matmul(z_process_covar, N_trans_obs_shift))  # N O O

    predicted_covar = predicted_covar1z + predicted_covar1gp + gp_process_covar + \
        torch.cumsum(predicted_covar2z, dim=0)

    if include_observation_noise:
        predicted_covar = predicted_covar + self.obs_noise_scale.pow(2.0).diag_embed()

    return predicted_mean, predicted_covar

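# The cumulative z-state covariance above leans on repeated_matmul; a quick
# sanity sketch of its contract (the stacked matrix powers A, A^2, ..., A^N),
# assuming the helper from pyro.ops.tensor_utils:
import torch
from pyro.ops.tensor_utils import repeated_matmul

A = torch.randn(4, 4)
powers = repeated_matmul(A, 3)          # shape (3, 4, 4)
assert torch.allclose(powers[0], A)
assert torch.allclose(powers[1], A @ A, atol=1e-5)
assert torch.allclose(powers[2], A @ A @ A, atol=1e-4)
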
def _get_dist(self):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
    to a :class:`LinearlyCoupledMaternGP`.
    """
    trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_matrix = block_diag_embed(trans_matrix)
    process_covar = block_diag_embed(process_covar)
    loc = self.A.new_zeros(self.full_state_dim)
    trans_dist = MultivariateNormal(loc, process_covar)
    return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist())

def test_dependent_matern_gp(obs_dim):
    dt = 0.5 + torch.rand(1).item()
    gp = DependentMaternGP(nu=1.5, obs_dim=obs_dim, dt=dt,
                           length_scale_init=0.5 + torch.rand(obs_dim))

    # make sure the stationary covariance matrix satisfies the relevant
    # matrix Riccati equation
    lengthscale = gp.kernel.length_scale.unsqueeze(-1).unsqueeze(-1)
    F = torch.tensor([[0.0, 1.0], [0.0, 0.0]])
    mask1 = torch.tensor([[0.0, 0.0], [-3.0, 0.0]])
    mask2 = torch.tensor([[0.0, 0.0], [0.0, -math.sqrt(12.0)]])
    F = block_diag_embed(F + mask1 / lengthscale.pow(2.0) + mask2 / lengthscale)

    stat_cov = gp._stationary_covariance()
    wiener_cov = gp._get_wiener_cov()
    wiener_cov *= torch.tensor([[0.0, 0.0], [0.0, 1.0]]).repeat(obs_dim, obs_dim)

    expected_zero = (torch.matmul(F, stat_cov) + torch.matmul(stat_cov, F.transpose(-1, -2))
                     + wiener_cov)
    assert_equal(expected_zero, torch.zeros(gp.full_state_dim, gp.full_state_dim))

def get_dist(self, duration=None):
    """
    Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
    to a :class:`LinearlyCoupledMaternGP`.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """
    trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
    trans_matrix = block_diag_embed(trans_matrix)
    process_covar = block_diag_embed(process_covar)
    loc = self.A.new_zeros(self.full_state_dim)
    trans_dist = MultivariateNormal(loc, process_covar)
    return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                            self._get_obs_matrix(), self._get_obs_dist(), duration=duration)

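# A minimal sampling sketch for the method above (constructor argument values
# are assumptions for illustration):
import torch
from pyro.contrib.timeseries import LinearlyCoupledMaternGP

gp = LinearlyCoupledMaternGP(nu=1.5, dt=0.5, obs_dim=2, num_gps=2)
hmm = gp.get_dist(duration=100)         # duration is required before sampling
y = hmm.sample()                        # shape (100, 2)
loss = -gp.log_prob(y).sum()            # the same series scored via the model's log_prob
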
def _get_init_dist(self):
    loc = self.z_trans_matrix.new_zeros(self.full_state_dim)
    covar = self.z_trans_matrix.new_zeros(self.full_state_dim, self.full_state_dim)
    covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(
        self.kernel.stationary_covariance())
    covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.init_noise_scale_sq.diag_embed()
    return MultivariateNormal(loc, covar)

def _get_init_dist(self):
    loc = self.z_trans_matrix.new_zeros(self.full_state_dim)
    covar = self.z_trans_matrix.new_zeros(self.full_state_dim, self.full_state_dim)
    covar[:self.full_gp_state_dim, :self.full_gp_state_dim] = block_diag_embed(
        self.kernel.stationary_covariance())
    eye = torch.eye(self.state_dim, device=loc.device, dtype=loc.dtype)
    covar[self.full_gp_state_dim:, self.full_gp_state_dim:] = self.log_init_noise_scale_sq.exp() * eye
    return MultivariateNormal(loc, covar)

def test_block_diag_embed(batch_size, block_size):
    m = torch.randn(block_size).unsqueeze(0).expand((batch_size,) + block_size)
    b = block_diag_embed(m)

    assert b.shape == (batch_size * block_size[0], batch_size * block_size[1])
    assert_equal(b.sum(), m.sum())

    for k in range(batch_size):
        bottom, top = k * block_size[0], (k + 1) * block_size[0]
        left, right = k * block_size[1], (k + 1) * block_size[1]
        assert_equal(b[bottom:top, left:right], m[k])

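# Cross-check of block_diag_embed against torch.block_diag for a plain
# (unbatched) stack of blocks; a sketch assuming the helper from
# pyro.ops.tensor_utils:
import torch
from pyro.ops.tensor_utils import block_diag_embed

m = torch.randn(3, 2, 2)                # 3 blocks, each 2 x 2
b = block_diag_embed(m)                 # (6, 6) block-diagonal matrix
assert b.shape == (6, 6)
assert torch.allclose(b, torch.block_diag(*m))
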
def _trans_matrix_distribution_stat_covar(self, dts):
    stationary_covariance = self._stationary_covariance()
    trans_matrix = self.kernel.transition_matrix(dt=dts)
    trans_matrix = block_diag_embed(trans_matrix)
    trans_dist = self._get_trans_dist(trans_matrix, stationary_covariance)
    return trans_matrix, trans_dist, stationary_covariance

def _stationary_covariance(self):
    return block_diag_embed(self.kernel.stationary_covariance())

def test_block_diag(batch_shape, mat_size, block_size):
    mat = torch.randn(batch_shape + (block_size,) + mat_size)
    mat_embed = block_diag_embed(mat)
    mat_embed_diag = block_diagonal(mat_embed, block_size)
    assert_equal(mat_embed_diag, mat)