Example #1
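This example appears to come from Pyro's state-space GP utilities. For the class below to run on its own it needs roughly the following imports and module-level constants; the module paths and constant definitions are inferred from the identifiers used in the code (PyroModule, PyroParam, pyro_method, constraints, root_three, root_five, five_thirds) and should be treated as assumptions:

import math

import torch
from torch.distributions import constraints

from pyro.nn import PyroModule, PyroParam, pyro_method

# Constants assumed by the Matern-3/2 and Matern-5/2 branches below.
root_three = math.sqrt(3.0)
root_five = math.sqrt(5.0)
five_thirds = 5.0 / 3.0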
class MaternKernel(PyroModule):
    """
    Provides the building blocks for representing univariate Gaussian Processes (GPs)
    with Matern kernels as state space models.

    :param float nu: The order of the Matern kernel (one of 0.5, 1.5 or 2.5)
    :param int num_gps: the number of GPs
    :param torch.Tensor length_scale_init: optional `num_gps`-dimensional vector of initializers
        for the length scale
    :param torch.Tensor kernel_scale_init: optional `num_gps`-dimensional vector of initializers
        for the kernel scale

    **References**

    [1] `Kalman Filtering and Smoothing Solutions to Temporal Gaussian Process Regression Models`,
        Jouni Hartikainen and Simo Sarkka.
    [2] `Stochastic Differential Equation Methods for Spatio-Temporal Gaussian Process Regression`,
        Arno Solin.
    """
    def __init__(self,
                 nu=1.5,
                 num_gps=1,
                 length_scale_init=None,
                 kernel_scale_init=None):
        if nu not in [0.5, 1.5, 2.5]:
            raise NotImplementedError(
                "The only supported values of nu are 0.5, 1.5 and 2.5")
        self.nu = nu
        self.state_dim = {0.5: 1, 1.5: 2, 2.5: 3}[nu]
        self.num_gps = num_gps

        if length_scale_init is None:
            length_scale_init = torch.ones(num_gps)
        assert length_scale_init.shape == (num_gps, )

        if kernel_scale_init is None:
            kernel_scale_init = torch.ones(num_gps)
        assert kernel_scale_init.shape == (num_gps, )

        super().__init__()

        self.length_scale = PyroParam(length_scale_init,
                                      constraint=constraints.positive)
        self.kernel_scale = PyroParam(kernel_scale_init,
                                      constraint=constraints.positive)

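        # Pre-register one-hot "mask" buffers so that each entry of the transition and
        # stationary covariance matrices can be assembled as a scalar coefficient times
        # a fixed (state_dim, state_dim) basis matrix.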
        if self.state_dim > 1:
            for x in range(self.state_dim):
                for y in range(self.state_dim):
                    mask = torch.zeros(self.state_dim, self.state_dim)
                    mask[x, y] = 1.0
                    self.register_buffer("mask{}{}".format(x, y), mask)

    @pyro_method
    def transition_matrix(self, dt):
        """
        Compute the (exponentiated) transition matrix of the GP latent space.
        The resulting matrix has layout (num_gps, old_state, new_state), i.e. this
        matrix multiplies states from the right.

        See section 5 in reference [1] for details.

        :param float dt: the time interval over which the GP latent space evolves.
        :returns torch.Tensor: a 3-dimensional tensor of transition matrices of shape
            (num_gps, state_dim, state_dim).
        """
        if self.nu == 0.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            return torch.exp(-dt / rho)
        elif self.nu == 1.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            dt_rho = dt / rho
            trans = (1.0 + root_three * dt_rho) * self.mask00 + \
                (-3.0 * dt_rho / rho) * self.mask01 + \
                dt * self.mask10 + \
                (1.0 - root_three * dt_rho) * self.mask11
            return torch.exp(-root_three * dt_rho) * trans
        elif self.nu == 2.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            dt_rho = root_five * dt / rho
            dt_rho_sq = dt_rho.pow(2.0)
            dt_rho_cu = dt_rho.pow(3.0)
            dt_rho_qu = dt_rho.pow(4.0)
            dt_sq = dt**2.0
            trans = (1.0 + dt_rho + 0.5 * dt_rho_sq) * self.mask00 + \
                (-0.5 * dt_rho_cu / dt) * self.mask01 + \
                ((0.5 * dt_rho_qu - dt_rho_cu) / dt_sq) * self.mask02 + \
                ((dt_rho + 1.0) * dt) * self.mask10 + \
                (1.0 + dt_rho - dt_rho_sq) * self.mask11 + \
                ((dt_rho_cu - 3.0 * dt_rho_sq) / dt) * self.mask12 + \
                (0.5 * dt_sq) * self.mask20 + \
                ((1.0 - 0.5 * dt_rho) * dt) * self.mask21 + \
                (1.0 - 2.0 * dt_rho + 0.5 * dt_rho_sq) * self.mask22
            return torch.exp(-dt_rho) * trans

    @pyro_method
    def stationary_covariance(self):
        """
        Compute the stationary state covariance. See Eqn. 3.26 in reference [2].

        :returns torch.Tensor: a 3-dimensional tensor of covariance matrices of shape
            (num_gps, state_dim, state_dim).
        """
        if self.nu == 0.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            return sigmasq
        elif self.nu == 1.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            rhosq = self.length_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            p_infinity = self.mask00 + (3.0 / rhosq) * self.mask11
            return sigmasq * p_infinity
        elif self.nu == 2.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            rhosq = self.length_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            p_infinity = self.mask00 + \
                (five_thirds / rhosq) * (self.mask11 - self.mask02 - self.mask20) + \
                (25.0 / rhosq.pow(2.0)) * self.mask22
            return sigmasq * p_infinity

    @pyro_method
    def process_covariance(self, A):
        """
        Given a transition matrix `A` computed with `transition_matrix`, compute the
        process covariance as described in Eqn. 3.11 in reference [2].

        :returns torch.Tensor: a batched covariance matrix of shape (num_gps, state_dim, state_dim)
        """
        assert A.shape[-3:] == (self.num_gps, self.state_dim, self.state_dim)
        p = self.stationary_covariance()
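        # For a stationary process the discrete-time process noise is Q = P_inf - A^T P_inf A,
        # with A laid out as (old_state, new_state); cf. Eqn. 3.11 in reference [2].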
        q = p - torch.matmul(A.transpose(-1, -2), torch.matmul(p, A))
        return q

    @pyro_method
    def transition_matrix_and_covariance(self, dt):
        """
        Get the transition matrix and process covariance corresponding to a time interval `dt`.

        :param float dt: the time interval over which the GP latent space evolves.
        :returns tuple: (`transition_matrix`, `process_covariance`) both 3-dimensional tensors of
            shape (num_gps, state_dim, state_dim)
        """
        trans_matrix = self.transition_matrix(dt)
        process_covar = self.process_covariance(trans_matrix)
        return trans_matrix, process_covar
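A minimal usage sketch for the kernel above (assuming the imports listed before the class; the numbers are arbitrary):

kernel = MaternKernel(nu=1.5, num_gps=2)
A = kernel.transition_matrix(dt=0.5)                       # (num_gps, state_dim, state_dim) = (2, 2, 2)
Q = kernel.process_covariance(A)                           # (2, 2, 2)
A2, Q2 = kernel.transition_matrix_and_covariance(dt=0.5)   # both matrices in a single call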
Example #2
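As with the first example, the class below seems to come from Pyro's time series module and relies on the MaternKernel defined above plus a handful of imports; the module paths given here are assumptions based on how Pyro is typically laid out:

import torch
from torch.distributions import constraints, MultivariateNormal

import pyro.distributions as dist
from pyro.contrib.timeseries.base import TimeSeriesModel
from pyro.nn import PyroParam, pyro_method
from pyro.ops.ssm_gp import MaternKernel  # or the MaternKernel class from Example #1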
class IndependentMaternGP(TimeSeriesModel):
    """
    A time series model in which each output dimension is modeled independently
    with a univariate Gaussian Process with a Matern kernel. The targets are assumed
    to be evenly spaced in time. Training and inference are logarithmic in the length
    of the time series T.

    :param float nu: The order of the Matern kernel; one of 0.5, 1.5 or 2.5.
    :param float dt: The time spacing between neighboring observations of the time series.
    :param int obs_dim: The dimension of the targets at each time step.
    :param torch.Tensor length_scale_init: optional initial values for the kernel length scale
        given as a ``obs_dim``-dimensional tensor
    :param torch.Tensor kernel_scale_init: optional initial values for the kernel scale
        given as a ``obs_dim``-dimensional tensor
    :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale
        given as a ``obs_dim``-dimensional tensor
    """
    def __init__(self, nu=1.5, dt=1.0, obs_dim=1,
                 length_scale_init=None, kernel_scale_init=None,
                 obs_noise_scale_init=None):
        self.nu = nu
        self.dt = dt
        self.obs_dim = obs_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim,)

        super().__init__()

        self.kernel = MaternKernel(nu=nu, num_gps=obs_dim,
                                   length_scale_init=length_scale_init,
                                   kernel_scale_init=kernel_scale_init)

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)

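        # The observation matrix picks out the first coordinate of each GP's latent state,
        # i.e. the function value; the remaining state coordinates track its derivatives.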
        obs_matrix = [1.0] + [0.0] * (self.kernel.state_dim - 1)
        self.register_buffer("obs_matrix", torch.tensor(obs_matrix).unsqueeze(-1))

    def _get_init_dist(self):
        return torch.distributions.MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, self.kernel.state_dim),
                                                      self.kernel.stationary_covariance().squeeze(-3))

    def _get_obs_dist(self):
        return dist.Normal(self.obs_matrix.new_zeros(self.obs_dim, 1, 1),
                           self.obs_noise_scale.unsqueeze(-1).unsqueeze(-1)).to_event(1)

    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to ``obs_dim``-many independent Matern GPs.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
        trans_dist = MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim),
                                        process_covar.unsqueeze(-3))
        trans_matrix = trans_matrix.unsqueeze(-3)
        return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                                self.obs_matrix, self._get_obs_dist(), duration=duration)

    @pyro_method
    def log_prob(self, targets):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued ``targets`` at each time step
        :returns torch.Tensor: A 1-dimensional tensor of log probabilities of shape ``(obs_dim,)``
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().log_prob(targets.t().unsqueeze(-1))

    @torch.no_grad()
    def _filter(self, targets):
        """
        Return the filtering state for the associated state space model.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().filter(targets.t().unsqueeze(-1))

    @torch.no_grad()
    def _forecast(self, dts, filtering_state, include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        assert dts.dim() == 1
        dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=dts)
        trans_matrix = trans_matrix[..., 0:1]
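        # Propagate the filtering distribution through the transition and keep only the first
        # state coordinate (the function value), then add that coordinate's process noise.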
        predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_matrix).squeeze(-2)[..., 0]
        predicted_function_covar = torch.matmul(trans_matrix.transpose(-1, -2), torch.matmul(
                                                filtering_state.covariance_matrix, trans_matrix))[..., 0, 0] + \
            process_covar[..., 0, 0]

        if include_observation_noise:
            predicted_function_covar = predicted_function_covar + self.obs_noise_scale.pow(2.0)
        return predicted_mean, predicted_function_covar

    @pyro_method
    def forecast(self, targets, dts):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued targets at each time step. These
            represent the training data that are conditioned on for the purpose of making
            forecasts.
        :param torch.Tensor dts: A 1-dimensional tensor of times to forecast into the future,
            with zero corresponding to the time of the final target ``targets[-1]``.
        :returns torch.distributions.Normal: A predictive Normal distribution with batch shape ``(S,)`` and
            event shape ``(obs_dim,)``, where ``S`` is the size of ``dts``.
        """
        filtering_state = self._filter(targets)
        predicted_mean, predicted_covar = self._forecast(dts, filtering_state)
        return torch.distributions.Normal(predicted_mean, predicted_covar.sqrt())
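A minimal end-to-end sketch of how the model above is meant to be used (synthetic data and default hyperparameters; in practice the length scale, kernel scale and observation noise would be fit, e.g. by maximizing ``log_prob``, before forecasting):

model = IndependentMaternGP(nu=1.5, dt=1.0, obs_dim=3)
targets = torch.randn(100, 3)          # (T, obs_dim) training series
lp = model.log_prob(targets)           # shape (obs_dim,): one log probability per output dimension
dts = torch.tensor([1.0, 2.0, 3.0])    # horizons measured from the final observation targets[-1]
pred = model.forecast(targets, dts)    # predictive Normal; pred.mean has shape (len(dts), obs_dim)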