class MaternKernel(PyroModule):
    """
    Provides the building blocks for representing univariate Gaussian Processes (GPs)
    with Matern kernels as state space models.

    :param float nu: The order of the Matern kernel (one of 0.5, 1.5 or 2.5)
    :param int num_gps: the number of GPs
    :param torch.Tensor length_scale_init: optional `num_gps`-dimensional vector of initializers
        for the length scale
    :param torch.Tensor kernel_scale_init: optional `num_gps`-dimensional vector of initializers
        for the kernel scale

    **References**

    [1] `Kalman Filtering and Smoothing Solutions to Temporal Gaussian Process Regression Models`,
        Jouni Hartikainen and Simo Sarkka.
    [2] `Stochastic Differential Equation Methods for Spatio-Temporal Gaussian Process Regression`,
        Arno Solin.
    """
    def __init__(self, nu=1.5, num_gps=1, length_scale_init=None, kernel_scale_init=None):
        if nu not in [0.5, 1.5, 2.5]:
            raise NotImplementedError("The only supported values of nu are 0.5, 1.5 and 2.5")
        self.nu = nu
        self.state_dim = {0.5: 1, 1.5: 2, 2.5: 3}[nu]
        self.num_gps = num_gps

        if length_scale_init is None:
            length_scale_init = torch.ones(num_gps)
        assert length_scale_init.shape == (num_gps,)

        if kernel_scale_init is None:
            kernel_scale_init = torch.ones(num_gps)
        assert kernel_scale_init.shape == (num_gps,)

        super().__init__()

        self.length_scale = PyroParam(length_scale_init, constraint=constraints.positive)
        self.kernel_scale = PyroParam(kernel_scale_init, constraint=constraints.positive)

        if self.state_dim > 1:
            for x in range(self.state_dim):
                for y in range(self.state_dim):
                    mask = torch.zeros(self.state_dim, self.state_dim)
                    mask[x, y] = 1.0
                    self.register_buffer("mask{}{}".format(x, y), mask)

    @pyro_method
    def transition_matrix(self, dt):
        """
        Compute the (exponentiated) transition matrix of the GP latent space.
        The resulting matrix has layout (num_gps, old_state, new_state), i.e. this
        matrix multiplies states from the right.

        See section 5 in reference [1] for details.

        :param float dt: the time interval over which the GP latent space evolves.
        :returns torch.Tensor: a 3-dimensional tensor of transition matrices of shape
            (num_gps, state_dim, state_dim).
        """
        if self.nu == 0.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            return torch.exp(-dt / rho)
        elif self.nu == 1.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            dt_rho = dt / rho
            trans = (1.0 + root_three * dt_rho) * self.mask00 + \
                (-3.0 * dt_rho / rho) * self.mask01 + \
                dt * self.mask10 + \
                (1.0 - root_three * dt_rho) * self.mask11
            return torch.exp(-root_three * dt_rho) * trans
        elif self.nu == 2.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            dt_rho = root_five * dt / rho
            dt_rho_sq = dt_rho.pow(2.0)
            dt_rho_cu = dt_rho.pow(3.0)
            dt_rho_qu = dt_rho.pow(4.0)
            dt_sq = dt ** 2.0
            trans = (1.0 + dt_rho + 0.5 * dt_rho_sq) * self.mask00 + \
                (-0.5 * dt_rho_cu / dt) * self.mask01 + \
                ((0.5 * dt_rho_qu - dt_rho_cu) / dt_sq) * self.mask02 + \
                ((dt_rho + 1.0) * dt) * self.mask10 + \
                (1.0 + dt_rho - dt_rho_sq) * self.mask11 + \
                ((dt_rho_cu - 3.0 * dt_rho_sq) / dt) * self.mask12 + \
                (0.5 * dt_sq) * self.mask20 + \
                ((1.0 - 0.5 * dt_rho) * dt) * self.mask21 + \
                (1.0 - 2.0 * dt_rho + 0.5 * dt_rho_sq) * self.mask22
            return torch.exp(-dt_rho) * trans

    @pyro_method
    def stationary_covariance(self):
        """
        Compute the stationary state covariance. See Eqn. 3.26 in reference [2].

        :returns torch.Tensor: a 3-dimensional tensor of covariance matrices of shape
            (num_gps, state_dim, state_dim).
        """
        if self.nu == 0.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            return sigmasq
        elif self.nu == 1.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            rhosq = self.length_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            p_infinity = self.mask00 + (3.0 / rhosq) * self.mask11
            return sigmasq * p_infinity
        elif self.nu == 2.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            rhosq = self.length_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            p_infinity = self.mask00 + \
                (five_thirds / rhosq) * (self.mask11 - self.mask02 - self.mask20) + \
                (25.0 / rhosq.pow(2.0)) * self.mask22
            return sigmasq * p_infinity

    @pyro_method
    def process_covariance(self, A):
        """
        Given a transition matrix `A` computed with `transition_matrix`, compute the
        process covariance as described in Eqn. 3.11 in reference [2].

        :returns torch.Tensor: a batched covariance matrix of shape (num_gps, state_dim, state_dim)
        """
        assert A.shape[-3:] == (self.num_gps, self.state_dim, self.state_dim)
        p = self.stationary_covariance()
        q = p - torch.matmul(A.transpose(-1, -2), torch.matmul(p, A))
        return q

    @pyro_method
    def transition_matrix_and_covariance(self, dt):
        """
        Get the transition matrix and process covariance corresponding to a time interval `dt`.

        :param float dt: the time interval over which the GP latent space evolves.
        :returns tuple: (`transition_matrix`, `process_covariance`) both 3-dimensional tensors of
            shape (num_gps, state_dim, state_dim)
        """
        trans_matrix = self.transition_matrix(dt)
        process_covar = self.process_covariance(trans_matrix)
        return trans_matrix, process_covar
class IndependentMaternGP(TimeSeriesModel):
    """
    A time series model in which each output dimension is modeled independently
    with a univariate Gaussian Process with a Matern kernel. The targets are assumed
    to be evenly spaced in time. Training and inference are logarithmic in the length
    of the time series T.

    :param float nu: The order of the Matern kernel; one of 0.5, 1.5 or 2.5.
    :param float dt: The time spacing between neighboring observations of the time series.
    :param int obs_dim: The dimension of the targets at each time step.
    :param torch.Tensor length_scale_init: optional initial values for the kernel length scale
        given as a ``obs_dim``-dimensional tensor
    :param torch.Tensor kernel_scale_init: optional initial values for the kernel scale
        given as a ``obs_dim``-dimensional tensor
    :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale
        given as a ``obs_dim``-dimensional tensor
    """
    def __init__(self, nu=1.5, dt=1.0, obs_dim=1,
                 length_scale_init=None, kernel_scale_init=None,
                 obs_noise_scale_init=None):
        self.nu = nu
        self.dt = dt
        self.obs_dim = obs_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim,)

        super().__init__()

        self.kernel = MaternKernel(nu=nu, num_gps=obs_dim,
                                   length_scale_init=length_scale_init,
                                   kernel_scale_init=kernel_scale_init)

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)

        obs_matrix = [1.0] + [0.0] * (self.kernel.state_dim - 1)
        self.register_buffer("obs_matrix", torch.tensor(obs_matrix).unsqueeze(-1))

    def _get_init_dist(self):
        return torch.distributions.MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, self.kernel.state_dim),
                                                      self.kernel.stationary_covariance().squeeze(-3))

    def _get_obs_dist(self):
        return dist.Normal(self.obs_matrix.new_zeros(self.obs_dim, 1, 1),
                           self.obs_noise_scale.unsqueeze(-1).unsqueeze(-1)).to_event(1)

    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to ``obs_dim``-many independent Matern GPs.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
        trans_dist = MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim),
                                        process_covar.unsqueeze(-3))
        trans_matrix = trans_matrix.unsqueeze(-3)
        return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                                self.obs_matrix, self._get_obs_dist(), duration=duration)

    @pyro_method
    def log_prob(self, targets):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and
            ``obs_dim`` is the dimension of the real-valued ``targets`` at each time step
        :returns torch.Tensor: A 1-dimensional tensor of log probabilities of shape ``(obs_dim,)``
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().log_prob(targets.t().unsqueeze(-1))

    @torch.no_grad()
    def _filter(self, targets):
        """
        Return the filtering state for the associated state space model.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().filter(targets.t().unsqueeze(-1))

    @torch.no_grad()
    def _forecast(self, dts, filtering_state, include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        assert dts.dim() == 1
        dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=dts)
        trans_matrix = trans_matrix[..., 0:1]
        predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_matrix).squeeze(-2)[..., 0]
        predicted_function_covar = torch.matmul(trans_matrix.transpose(-1, -2), torch.matmul(
            filtering_state.covariance_matrix, trans_matrix))[..., 0, 0] + \
            process_covar[..., 0, 0]

        if include_observation_noise:
            predicted_function_covar = predicted_function_covar + self.obs_noise_scale.pow(2.0)

        return predicted_mean, predicted_function_covar

    @pyro_method
    def forecast(self, targets, dts):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and
            ``obs_dim`` is the dimension of the real-valued targets at each time step.
            These represent the training data that are conditioned on for the purpose
            of making forecasts.
        :param torch.Tensor dts: A 1-dimensional tensor of times to forecast into the future,
            with zero corresponding to the time of the final target ``targets[-1]``.
        :returns torch.distributions.Normal: Returns a predictive Normal distribution with
            batch shape ``(S,)`` and event shape ``(obs_dim,)``, where ``S`` is the size of ``dts``.
        """
        filtering_state = self._filter(targets)
        predicted_mean, predicted_covar = self._forecast(dts, filtering_state)
        return torch.distributions.Normal(predicted_mean, predicted_covar.sqrt())