Example #1
    def compute_lds(
        self,
        F,
        feat_static_cat: Tensor,
        seasonal_indicators: Tensor,
        time_feat: Tensor,
        length: int,
        prior_mean: Optional[Tensor] = None,
        prior_cov: Optional[Tensor] = None,
        lstm_begin_state: Optional[List[Tensor]] = None,
    ):
        # embed categorical features and expand along time axis
        embedded_cat = self.embedder(feat_static_cat)
        repeated_static_features = embedded_cat.expand_dims(axis=1).repeat(
            axis=1, repeats=length)

        # concatenate time features with the repeated static features along
        # the feature axis to form the full context tensor
        features = F.concat(time_feat, repeated_static_features, dim=2)

        output, lstm_final_state = self.lstm.unroll(
            inputs=features,
            begin_state=lstm_begin_state,
            length=length,
            merge_outputs=True,
        )

        if prior_mean is None:
            # use the LSTM output at the first time step to parametrize
            # the prior distribution of the latent state
            prior_input = F.slice_axis(output, axis=1, begin=0,
                                       end=1).squeeze(axis=1)

            prior_mean = self.prior_mean_model(prior_input)
            # rescale the diagonal model output (assumed to lie in [0, 1])
            # into the configured [lower, upper] covariance bounds
            prior_cov_diag = (
                self.prior_cov_diag_model(prior_input) *
                (self.prior_cov_bounds.upper - self.prior_cov_bounds.lower) +
                self.prior_cov_bounds.lower)
            prior_cov = make_nd_diag(F, prior_cov_diag, self.issm.latent_dim())

        (
            emission_coeff,
            transition_coeff,
            innovation_coeff,
        ) = self.issm.get_issm_coeff(seasonal_indicators)

        noise_std, innovation, residuals = self.lds_proj(output)

        lds = LDS(
            emission_coeff=emission_coeff,
            transition_coeff=transition_coeff,
            innovation_coeff=F.broadcast_mul(innovation, innovation_coeff),
            noise_std=noise_std,
            residuals=residuals,
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            latent_dim=self.issm.latent_dim(),
            output_dim=self.issm.output_dim(),
            seq_length=length,
        )

        return lds, lstm_final_state
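
The embed-repeat-concat pattern at the top of compute_lds is independent of
the surrounding network. A minimal sketch of the same feature construction
with raw MXNet NDArrays (the shapes and random inputs are illustrative
assumptions, not taken from the source):

import mxnet as mx

batch_size, length, num_time_feat, embed_dim = 2, 6, 4, 3

# hypothetical stand-ins for self.embedder(feat_static_cat) and time_feat
embedded_cat = mx.nd.random.uniform(shape=(batch_size, embed_dim))
time_feat = mx.nd.random.uniform(shape=(batch_size, length, num_time_feat))

# expand along a new time axis and repeat once per time step
repeated_static = embedded_cat.expand_dims(axis=1).repeat(
    axis=1, repeats=length)

# concatenate along the feature axis, as in compute_lds above
features = mx.nd.concat(time_feat, repeated_static, dim=2)
assert features.shape == (batch_size, length, num_time_feat + embed_dim)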
Example #2
def kalman_filter_step(
    F,
    target: Tensor,
    prior_mean: Tensor,
    prior_cov: Tensor,
    emission_coeff: Tensor,
    residual: Tensor,
    noise_std: Tensor,
    latent_dim: int,
    output_dim: int,
):
    """
    One step of the Kalman filter.

    This function computes the filtered state (mean and covariance) given the
    linear system coefficients, the prior state (mean and covariance), and the
    observations.

    Parameters
    ----------
    F
        The function space to use (mx.nd or mx.sym)
    target
        Observations of the system output, shape (batch_size, output_dim)
    prior_mean
        Prior mean of the latent state, shape (batch_size, latent_dim)
    prior_cov
        Prior covariance of the latent state, shape
        (batch_size, latent_dim, latent_dim)
    emission_coeff
        Emission coefficient, shape (batch_size, output_dim, latent_dim)
    residual
        Residual component, shape (batch_size, output_dim)
    noise_std
        Standard deviation of the output noise, shape (batch_size, output_dim)
    latent_dim
        Dimension of the latent state vector
    output_dim
        Dimension of the output vector

    Returns
    -------
    Tensor
        Filtered_mean, shape (batch_size, latent_dim)
    Tensor
        Filtered_covariance, shape (batch_size, latent_dim, latent_dim)
    Tensor
        Log probability, shape (batch_size, )
    """
    # output_mean: mean of the target (batch_size, output_dim)
    output_mean = F.linalg_gemm2(
        emission_coeff, prior_mean.expand_dims(axis=-1)).squeeze(axis=-1)

    # noise covariance
    noise_cov = make_nd_diag(F=F, x=noise_std * noise_std, d=output_dim)

    S_hh_x_A_tr = F.linalg_gemm2(prior_cov, emission_coeff, transpose_b=True)

    # covariance of the target
    output_cov = F.linalg_gemm2(emission_coeff, S_hh_x_A_tr) + noise_cov

    # compute the Cholesky decomposition output_cov = LL^T
    L_output_cov = F.linalg_potrf(output_cov)

    # Compute Kalman gain matrix K:
    # K = S_hh X with X = A^T output_cov^{-1}
    # We have X = A^T output_cov^{-1} => X output_cov = A^T => X LL^T = A^T
    # We can thus obtain X by solving two linear systems involving L
    kalman_gain = F.linalg_trsm(
        L_output_cov,
        F.linalg_trsm(L_output_cov,
                      S_hh_x_A_tr,
                      rightside=True,
                      transpose=True),
        rightside=True,
    )

    # compute the error
    target_minus_residual = target - residual
    delta = target_minus_residual - output_mean

    # filtered estimates
    filtered_mean = prior_mean.expand_dims(axis=-1) + F.linalg_gemm2(
        kalman_gain, delta.expand_dims(axis=-1))
    filtered_mean = filtered_mean.squeeze(axis=-1)

    # Joseph form (symmetrized) update for the filtered covariance:
    ImKA = F.broadcast_sub(F.eye(latent_dim),
                           F.linalg_gemm2(kalman_gain, emission_coeff))

    filtered_cov = F.linalg_gemm2(
        ImKA, F.linalg_gemm2(
            prior_cov, ImKA, transpose_b=True)) + F.linalg_gemm2(
                kalman_gain,
                F.linalg_gemm2(noise_cov, kalman_gain, transpose_b=True))

    # likelihood term: (batch_size,)
    log_p = MultivariateGaussian(output_mean,
                                 L_output_cov).log_prob(target_minus_residual)

    return filtered_mean, filtered_cov, log_p
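
The gain computation above never forms output_cov^{-1} explicitly; it solves
two triangular systems against the Cholesky factor instead. A NumPy sketch of
the same filter step for a single, unbatched example, using the explicit
inverse for readability (all values illustrative):

import numpy as np

latent_dim, output_dim = 3, 2
rng = np.random.default_rng(0)

A = rng.normal(size=(output_dim, latent_dim))   # emission coefficient
m = rng.normal(size=latent_dim)                 # prior mean
S = np.eye(latent_dim)                          # prior covariance
noise_cov = np.diag(rng.uniform(0.1, 1.0, output_dim) ** 2)
y = rng.normal(size=output_dim)                 # target minus residual

output_mean = A @ m
output_cov = A @ S @ A.T + noise_cov

# Kalman gain K = S A^T output_cov^{-1}; the MXNet code instead solves
# two triangular systems with the Cholesky factor of output_cov
K = S @ A.T @ np.linalg.inv(output_cov)

filtered_mean = m + K @ (y - output_mean)

# Joseph form: keeps the result symmetric positive semi-definite
ImKA = np.eye(latent_dim) - K @ A
filtered_cov = ImKA @ S @ ImKA.T + K @ noise_cov @ K.T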
Example #3
    def sample_marginals(self,
                         num_samples: Optional[int] = None,
                         scale: Optional[Tensor] = None) -> Tensor:
        r"""
        Generates samples from the marginals p(z_t),
        t = 1, \ldots, `seq_length`.

        Parameters
        ----------
        num_samples
            Number of samples to generate
        scale
            Scale of each output sequence, shape (batch_size, output_dim)

        Returns
        -------
        Tensor
            Samples, shape (num_samples, batch_size, seq_length, output_dim)
        """
        F = self.F

        # initialize with the prior moments; the mean is kept as a column
        # vector for the gemm2 calls below
        state_mean = self.prior_mean.expand_dims(axis=-1)
        state_cov = self.prior_cov

        output_mean_seq = []
        output_cov_seq = []

        for t in range(self.seq_length):
            # compute and store observation mean at time t
            output_mean = F.linalg_gemm2(
                self.emission_coeff[t],
                state_mean) + self.residuals[t].expand_dims(axis=-1)

            output_mean_seq.append(output_mean)

            # compute and store observation cov at time t
            output_cov = F.linalg_gemm2(
                self.emission_coeff[t],
                F.linalg_gemm2(
                    state_cov, self.emission_coeff[t], transpose_b=True),
            ) + make_nd_diag(F=F,
                             x=self.noise_std[t] * self.noise_std[t],
                             d=self.output_dim)

            output_cov_seq.append(output_cov.expand_dims(axis=1))

            # propagate the latent state moments one step forward:
            # mean m <- A m, covariance S <- A S A^T + g^T g
            state_mean = F.linalg_gemm2(self.transition_coeff[t], state_mean)

            state_cov = F.linalg_gemm2(
                self.transition_coeff[t],
                F.linalg_gemm2(
                    state_cov, self.transition_coeff[t], transpose_b=True),
            ) + F.linalg_gemm2(
                self.innovation_coeff[t],
                self.innovation_coeff[t],
                transpose_a=True,
            )

        # stack the per-time-step moments along the time axis
        output_mean = F.concat(*output_mean_seq, dim=1)
        output_cov = F.concat(*output_cov_seq, dim=1)

        # Cholesky factor of the marginal covariances
        L = F.linalg_potrf(output_cov)

        output_distribution = MultivariateGaussian(output_mean, L)

        samples = output_distribution.sample(num_samples=num_samples)

        return (samples if scale is None else F.broadcast_mul(
            samples, scale.expand_dims(axis=1)))
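
The loop above is the standard moment recursion of a linear-Gaussian state
space model: the observation moments at each step come from the current
latent moments, which are then propagated through the transition. A NumPy
sketch for a single sequence with time-invariant coefficients (all values
illustrative):

import numpy as np

latent_dim, output_dim, seq_length = 2, 1, 5
rng = np.random.default_rng(1)

A = np.eye(latent_dim)                          # transition coefficient
H = rng.normal(size=(output_dim, latent_dim))   # emission coefficient
g = rng.normal(size=(1, latent_dim))            # innovation coefficient
noise_var = 0.1

m, S = np.zeros(latent_dim), np.eye(latent_dim)  # prior moments

means, covs = [], []
for _ in range(seq_length):
    # marginal moments of the observation at time t
    means.append(H @ m)
    covs.append(H @ S @ H.T + noise_var * np.eye(output_dim))
    # propagate the latent moments: m <- A m, S <- A S A^T + g^T g
    m = A @ m
    S = A @ S @ A.T + g.T @ g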
Example #4
     (3, 4, 5),
     (),
 ),
 (
     StudentT(
         mu=mx.nd.zeros(shape=(3, 4, 5)),
         sigma=mx.nd.ones(shape=(3, 4, 5)),
         nu=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (
     MultivariateGaussian(
         mu=mx.nd.zeros(shape=(3, 4, 5)),
         L=make_nd_diag(F=mx.nd, x=mx.nd.ones(shape=(3, 4, 5)), d=5),
     ),
     (3, 4),
     (5, ),
 ),
 (Dirichlet(alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, )),
 (
     DirichletMultinomial(
         dim=5, n_trials=9, alpha=mx.nd.ones(shape=(3, 4, 5))),
     (3, 4),
     (5, ),
 ),
 (
     Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)),
             b=mx.nd.ones(shape=(3, 4, 5))),
     (3, 4, 5),
     (),
 ),
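
Each tuple in this parametrization pairs a distribution instance with its
expected batch shape and event shape. A hypothetical pytest consumer of such
cases; the import path, the local CASES list, and reliance on the
batch_shape/event_shape properties are assumptions based on GluonTS
conventions, not confirmed by this excerpt:

import mxnet as mx
import pytest
from gluonts.mx.distribution import Laplace  # import path assumed

# one illustrative case mirroring the tuples above:
# (distribution, expected_batch_shape, expected_event_shape)
CASES = [
    (
        Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)),
                b=mx.nd.ones(shape=(3, 4, 5))),
        (3, 4, 5),
        (),
    ),
]

@pytest.mark.parametrize("distr, batch_shape, event_shape", CASES)
def test_distribution_shapes(distr, batch_shape, event_shape):
    assert distr.batch_shape == batch_shape
    assert distr.event_shape == event_shape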