Example #1
    def _compute_cholesky_gp(
        self,
        kernel_matrix: Tensor,
        num_data_points: Optional[int] = None,
        noise: bool = True,
    ) -> Tensor:
        r"""
        Parameters
        ----------
        kernel_matrix
            Kernel matrix of shape (batch_size, num_data_points, num_data_points).
        num_data_points
            Number of rows in the kernel_matrix.
        noise
            Boolean to determine whether to add :math:`\sigma^2 I` to the kernel matrix.
            This is used in the predictive step when sampling the predictive
            covariance matrix without noise; it is set to True in every other case.

        Returns
        -------
        Tensor
            Cholesky factor :math:`L` of the kernel matrix with added noise,
            :math:`LL^T = K + \sigma^2 I`, of shape
            (batch_size, num_data_points, num_data_points).
        """
        if noise:  # Add sigma^2 * I to the kernel matrix
            kernel_matrix = self.F.broadcast_plus(
                kernel_matrix,
                self.F.broadcast_mul(
                    self.sigma ** 2,
                    self.F.eye(
                        num_data_points, ctx=self.ctx, dtype=self.float_type
                    ),
                ),
            )
        # Warning: this method is more expensive than the iterative jitter,
        # but it works for mx.sym
        if self.jitter_method == "eig":
            return jitter_cholesky_eig(
                self.F,
                kernel_matrix,
                num_data_points,
                self.ctx,
                self.float_type,
                self.diag_weight,
            )
        elif self.jitter_method == "iter" and self.F is mx.nd:
            return jitter_cholesky(
                self.F,
                kernel_matrix,
                num_data_points,
                self.ctx,
                self.float_type,
                self.max_iter_jitter,
                self.neg_tol,
                self.diag_weight,
                self.increase_jitter,
            )
        else:
            return self.F.linalg.potrf(kernel_matrix)
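The iterative jitter path referenced above retries the factorization with a growing diagonal term whenever plain Cholesky fails. Below is a minimal NumPy sketch of that idea, not GluonTS's `jitter_cholesky` implementation; the function name, default values, and growth factor are illustrative assumptions.

import numpy as np

def jitter_cholesky_sketch(K, sigma2=1e-2, max_iter=10, initial_jitter=1e-6):
    # Illustrative only: add sigma^2 * I (the `noise=True` branch above), then
    # retry the Cholesky factorization with a progressively larger jitter term.
    n = K.shape[-1]
    K = K + sigma2 * np.eye(n)
    jitter = 0.0
    for _ in range(max_iter):
        try:
            return np.linalg.cholesky(K + jitter * np.eye(n))
        except np.linalg.LinAlgError:
            jitter = initial_jitter if jitter == 0.0 else jitter * 10.0
    raise np.linalg.LinAlgError("Cholesky failed even with jitter")

K = np.array([[2.0, 1.0], [1.0, 2.0]])
L = jitter_cholesky_sketch(K)
print(np.allclose(L @ L.T, K + 1e-2 * np.eye(2)))  # True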
Example #2
def test_jitter_unit(jitter_method, float_type, ctx) -> None:
    # TODO: Enable GPU tests on Jenkins
    if ctx == mx.Context("gpu") and not check_gpu_support():
        return
    matrix = nd.array([[[1, 2], [3, 4]], [[10, 100], [-21.5, 41]]],
                      ctx=ctx,
                      dtype=float_type)
    F = mx.nd
    num_data_points = matrix.shape[1]
    if jitter_method == "eig":
        L = jitter_cholesky_eig(F, matrix, num_data_points, ctx, float_type)
    elif jitter_method == "iter":
        L = jitter_cholesky(F, matrix, num_data_points, ctx, float_type)
    assert np.sum(np.isnan(L.asnumpy())) == 0, "NaNs in Cholesky factor!"
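For comparison with the test above, here is a NumPy sketch of the eigenvalue-shift idea that the "eig" method alludes to: shift the diagonal of the symmetric part of the matrix so its smallest eigenvalue becomes positive, then factorize. The constants and the symmetrization step are illustrative assumptions, not necessarily what `jitter_cholesky_eig` does internally.

import numpy as np

def jitter_cholesky_eig_sketch(K, diag_weight=1e-6):
    # Illustrative eigenvalue-shift jitter: make the symmetric part of K
    # positive definite by shifting the diagonal, then factorize.
    n = K.shape[-1]
    K_sym = 0.5 * (K + K.T)
    min_eig = np.linalg.eigvalsh(K_sym).min()
    shift = max(-min_eig, 0.0) + diag_weight * np.abs(np.diag(K_sym)).mean()
    return np.linalg.cholesky(K_sym + shift * np.eye(n))

# The second batch element of the test matrix above is not even symmetric;
# after the shift its symmetric part factorizes without NaNs.
L = jitter_cholesky_eig_sketch(np.array([[10.0, 100.0], [-21.5, 41.0]]))
print(np.isnan(L).any())  # False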
Example #3
    def sample(self,
               num_samples: Optional[int] = None,
               scale: Optional[Tensor] = None) -> Tensor:
        r"""
        Generate samples from the LDS, i.e., from
        :math:`p(z_1, z_2, \ldots, z_T)` where :math:`T` is ``seq_length``.

        Parameters
        ----------
        num_samples
            Number of samples to generate
        scale
            Scale of each sequence in x, shape (batch_size, output_dim)

        Returns
        -------
        Tensor
            Samples, shape (num_samples, batch_size, seq_length, output_dim)
        """
        F = self.F

        # Note on shapes: here we work with tensors of the following shape
        # in each time step t: (num_samples, batch_size, dim, dim),
        # where dim can be obs_dim or latent_dim or a constant 1 to facilitate
        # generalized matrix multiplication (gemm2)

        # Sample observation noise for all time steps
        # noise_std: (batch_size, seq_length, obs_dim, 1)
        noise_std = F.stack(*self.noise_std, axis=1).expand_dims(axis=-1)

        # samples_eps_obs[t]: (num_samples, batch_size, obs_dim, 1)
        samples_eps_obs = _safe_split(
            Gaussian(noise_std.zeros_like(), noise_std).sample(num_samples),
            axis=-3,
            num_outputs=self.seq_length,
            squeeze_axis=True,
        )

        # Sample standard normal for all time steps
        # samples_eps_std_normal[t]: (num_samples, batch_size, obs_dim, 1)
        samples_std_normal = _safe_split(
            Gaussian(noise_std.zeros_like(),
                     noise_std.ones_like()).sample(num_samples),
            axis=-3,
            num_outputs=self.seq_length,
            squeeze_axis=True,
        )

        # Sample the prior state.
        # samples_lat_state: (num_samples, batch_size, latent_dim, 1)
        # The prior covariance can turn out slightly non-positive-definite
        # (small negative eigenvalues) whenever there is excessive zero padding
        # at the beginning of the time series.
        # We add a positive tolerance to the diagonal to avoid numerical issues.
        # Note that `jitter_cholesky` adds positive tolerance only if the decomposition without jitter fails.
        state = MultivariateGaussian(
            self.prior_mean,
            jitter_cholesky(F,
                            self.prior_cov,
                            self.latent_dim,
                            float_type=np.float32),
        )
        samples_lat_state = state.sample(num_samples).expand_dims(axis=-1)

        samples_seq = []
        for t in range(self.seq_length):
            # Expand all coefficients to include samples in axis 0
            # emission_coeff_t: (num_samples, batch_size, obs_dim, latent_dim)
            # transition_coeff_t:
            #   (num_samples, batch_size, latent_dim, latent_dim)
            # innovation_coeff_t: (num_samples, batch_size, 1, latent_dim)
            emission_coeff_t, transition_coeff_t, innovation_coeff_t = [
                _broadcast_param(coeff, axes=[0], sizes=[num_samples])
                if num_samples is not None else coeff for coeff in [
                    self.emission_coeff[t],
                    self.transition_coeff[t],
                    self.innovation_coeff[t],
                ]
            ]

            # Expand residuals as well
            # residual_t: (num_samples, batch_size, obs_dim, 1)
            residual_t = (_broadcast_param(
                self.residuals[t].expand_dims(axis=-1),
                axes=[0],
                sizes=[num_samples],
            ) if num_samples is not None else self.residuals[t].expand_dims(
                axis=-1))

            # (num_samples, batch_size, 1, obs_dim)
            samples_t = (F.linalg_gemm2(emission_coeff_t, samples_lat_state) +
                         residual_t + samples_eps_obs[t])
            samples_t = (samples_t.swapaxes(dim1=2, dim2=3) if num_samples
                         is not None else samples_t.swapaxes(dim1=1, dim2=2))
            samples_seq.append(samples_t)

            # sample next state: (num_samples, batch_size, latent_dim, 1)
            samples_lat_state = F.linalg_gemm2(
                transition_coeff_t, samples_lat_state) + F.linalg_gemm2(
                    innovation_coeff_t,
                    samples_std_normal[t],
                    transpose_a=True)

        # (num_samples, batch_size, seq_length, obs_dim)
        samples = F.concat(*samples_seq, dim=-2)
        return (samples if scale is None else F.broadcast_mul(
            samples,
            scale.expand_dims(axis=1).expand_dims(
                axis=0) if num_samples is not None else scale.expand_dims(
                    axis=1),
        ))
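Stripped of the MXNet batching and broadcasting machinery, the recursion implemented above is x_t = C_t state_t + b_t + noise_std_t * eps_t for the observations and state_{t+1} = A_t state_t + g_t * eta_t for the latent state. The following is a minimal single-batch, univariate-observation NumPy sketch of that recursion; the function name, argument layout, and shapes are illustrative assumptions, not the library's API.

import numpy as np

def lds_sample_sketch(A, C, g, b, noise_std, prior_mean, prior_cov,
                      num_samples, rng=None):
    # Illustrative single-batch, univariate-observation LDS sampler mirroring
    # the recursion above:
    #   x_t         = C_t @ state_t + b_t + noise_std_t * eps_t
    #   state_{t+1} = A_t @ state_t + g_t * eta_t
    rng = rng or np.random.default_rng(0)
    seq_length = len(A)
    # Prior state, analogous to sampling from the MultivariateGaussian above.
    state = rng.multivariate_normal(prior_mean, prior_cov, size=num_samples)
    samples = np.empty((num_samples, seq_length))
    for t in range(seq_length):
        eps = rng.standard_normal(num_samples)
        samples[:, t] = state @ C[t] + b[t] + noise_std[t] * eps
        eta = rng.standard_normal(num_samples)
        state = state @ A[t].T + np.outer(eta, g[t])
    return samples

# Usage: a local-level model, i.e. a random-walk latent state observed with noise.
T = 10
samples = lds_sample_sketch(
    A=[np.eye(1)] * T, C=[np.ones(1)] * T, g=[np.array([0.1])] * T,
    b=[0.0] * T, noise_std=[0.5] * T,
    prior_mean=np.zeros(1), prior_cov=np.eye(1), num_samples=100,
)
print(samples.shape)  # (100, 10)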