Example #1
    def _forecast(self,
                  N_timesteps,
                  filtering_state,
                  include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        N_trans_matrix = repeated_matmul(self.trans_matrix, N_timesteps)
        N_trans_obs = torch.matmul(N_trans_matrix, self.obs_matrix)
        predicted_mean = torch.matmul(filtering_state.loc, N_trans_obs)

        # first compute the contribution from filtering_state.covariance_matrix
        predicted_covar1 = torch.matmul(
            N_trans_obs.transpose(-1, -2),
            torch.matmul(filtering_state.covariance_matrix, N_trans_obs))  # N O O

        # next compute the contribution from process noise that is injected at each timestep.
        # (we need to do a cumulative sum to integrate across time)
        process_covar = self._get_trans_dist().covariance_matrix
        N_trans_obs_shift = torch.cat(
            [self.obs_matrix.unsqueeze(0), N_trans_obs[:-1]])
        predicted_covar2 = torch.matmul(
            N_trans_obs_shift.transpose(-1, -2),
            torch.matmul(process_covar, N_trans_obs_shift))  # N O O

        predicted_covar = predicted_covar1 + torch.cumsum(predicted_covar2,
                                                          dim=0)

        if include_observation_noise:
            predicted_covar = predicted_covar + self.obs_noise_scale.pow(
                2.0).diag_embed()

        return predicted_mean, predicted_covar
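
The cumulative sum above works because process noise injected at step k reaches the horizon-n observation through the matrix power A^(n - k) and the observation matrix. Below is a minimal, self-contained check of that identity; all names here (A, H, Q, n_steps) are illustrative and not part of the class above.

    import torch

    torch.manual_seed(0)
    n_steps, state_dim, obs_dim = 5, 3, 2
    A = 0.3 * torch.randn(state_dim, state_dim)  # transition matrix (z_{t+1} = z_t A)
    H = torch.randn(state_dim, obs_dim)          # observation matrix (y_t = z_t H)
    L = torch.randn(state_dim, state_dim)
    Q = L @ L.T                                  # process-noise covariance

    # vectorized form, mirroring the method above: cumsum over shifted, H-projected powers of A
    powers = [A]
    for _ in range(n_steps - 1):
        powers.append(powers[-1] @ A)
    N_trans_obs = torch.stack([P @ H for P in powers])       # A^1 H, ..., A^n H
    shifted = torch.cat([H.unsqueeze(0), N_trans_obs[:-1]])  # A^0 H, ..., A^(n-1) H
    vectorized = torch.cumsum(shifted.transpose(-1, -2) @ Q @ shifted, dim=0)

    # brute force: noise injected at step k propagates through A^(n - k) before being observed
    brute = torch.zeros(n_steps, obs_dim, obs_dim)
    for n in range(1, n_steps + 1):
        for k in range(1, n + 1):
            AH = torch.matrix_power(A, n - k) @ H
            brute[n - 1] += AH.T @ Q @ AH

    assert torch.allclose(vectorized, brute, atol=1e-5)
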
Example #2
    def _forecast(self,
                  N_timesteps,
                  filtering_state,
                  include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        # forecast horizons dt = 1, 2, ..., N_timesteps, shaped to broadcast over the kernel batch dims
        dts = torch.arange(N_timesteps,
                           dtype=self.z_trans_matrix.dtype,
                           device=self.z_trans_matrix.device) + 1.0
        dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(
            dt=dts)
        gp_trans_matrix = block_diag_embed(gp_trans_matrix)
        gp_process_covar = block_diag_embed(gp_process_covar[..., 0:1, 0:1])

        N_trans_matrix = repeated_matmul(self.z_trans_matrix, N_timesteps)
        N_trans_obs = torch.matmul(N_trans_matrix, self.z_obs_matrix)

        # z-state contribution + gp contribution
        predicted_mean1 = torch.matmul(
            filtering_state.loc[-self.state_dim:].unsqueeze(-2),
            N_trans_obs).squeeze(-2)
        predicted_mean2 = torch.matmul(
            filtering_state.loc[:self.full_gp_state_dim].unsqueeze(-2),
            gp_trans_matrix[..., self.obs_selector]).squeeze(-2)
        predicted_mean = predicted_mean1 + predicted_mean2

        # first compute the contributions from filtering_state.covariance_matrix: z-space and gp
        fs_cov = filtering_state.covariance_matrix
        predicted_covar1z = torch.matmul(
            N_trans_obs.transpose(-1, -2),
            torch.matmul(fs_cov[self.full_gp_state_dim:, self.full_gp_state_dim:],
                         N_trans_obs))  # N O O
        gp_trans = gp_trans_matrix[..., self.obs_selector]
        predicted_covar1gp = torch.matmul(
            gp_trans.transpose(-1, -2),
            torch.matmul(fs_cov[:self.full_gp_state_dim, :self.full_gp_state_dim],
                         gp_trans))

        # next compute the contribution from process noise that is injected at each timestep.
        # (we need to do a cumulative sum to integrate across time for the z-state contribution)
        z_process_covar = self.trans_noise_scale_sq.diag_embed()
        N_trans_obs_shift = torch.cat(
            [self.z_obs_matrix.unsqueeze(0), N_trans_obs[0:-1]])
        predicted_covar2z = torch.matmul(
            N_trans_obs_shift.transpose(-1, -2),
            torch.matmul(z_process_covar, N_trans_obs_shift))  # N O O

        predicted_covar = predicted_covar1z + predicted_covar1gp + gp_process_covar + \
            torch.cumsum(predicted_covar2z, dim=0)

        if include_observation_noise:
            predicted_covar = predicted_covar + self.obs_noise_scale.pow(
                2.0).diag_embed()

        return predicted_mean, predicted_covar
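
Example #2 relies on block_diag_embed to pack the per-dimension GP transition blocks into a single state-space matrix. The toy helper below sketches what such a block-diagonal embedding does, assuming a (..., B, M, N) -> (..., B * M, B * N) layout; it is illustrative only, not the library routine.

    import torch

    def toy_block_diag_embed(mat):
        # place the B trailing (M, N) blocks on the diagonal of a (B * M, B * N) matrix
        B, M, N = mat.shape[-3:]
        out = mat.new_zeros(mat.shape[:-3] + (B * M, B * N))
        for b in range(B):
            out[..., b * M:(b + 1) * M, b * N:(b + 1) * N] = mat[..., b, :, :]
        return out

    blocks = torch.randn(4, 2, 2)        # e.g. one 2 x 2 state block per GP
    full = toy_block_diag_embed(blocks)  # shape (8, 8)
    assert torch.equal(full[2:4, 2:4], blocks[1])
    assert torch.equal(full[0:2, 2:4], torch.zeros(2, 2))
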
Example #3
def test_repeated_matmul(size, n):
    # `size` is a (possibly batched) matrix shape such as (4, 4) and `n` is the number
    # of matrix powers to compute; repeated_matmul should stack M, M @ M, ..., M^n
    # along a new leading dimension.
    M = torch.randn(size)
    result = repeated_matmul(M, n)
    assert result.shape == ((n, ) + size)

    # compare each slice against serially accumulated matrix powers
    serial_result = M
    for i in range(n):
        assert_equal(result[i, ...], serial_result)
        serial_result = torch.matmul(serial_result, M)
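
The assertions above pin down the contract of repeated_matmul: result[0] equals M and result[i] equals the (i + 1)-th matrix power, stacked along a new leading dimension. A naive reference version consistent with the test (illustrative only, not necessarily how the library implements it):

    import torch

    def naive_repeated_matmul(M, n):
        # stack the first n matrix powers M, M @ M, ..., M^n along a new leading dim
        powers = [M]
        for _ in range(n - 1):
            powers.append(torch.matmul(powers[-1], M))
        return torch.stack(powers)

    M = torch.randn(3, 3)
    out = naive_repeated_matmul(M, 4)
    assert out.shape == (4, 3, 3)
    assert torch.allclose(out[2], M @ M @ M)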