def compute_lds(
    self,
    F,
    feat_static_cat: Tensor,
    seasonal_indicators: Tensor,
    time_feat: Tensor,
    length: int,
    prior_mean: Optional[Tensor] = None,
    prior_cov: Optional[Tensor] = None,
    lstm_begin_state: Optional[List[Tensor]] = None,
):
    """
    Unroll the LSTM over the conditioning features and assemble the
    linear dynamical system (LDS) for a sequence of the given length.

    If no prior is supplied, the prior mean and (diagonal) prior
    covariance of the latent state are predicted from the first LSTM
    output step.

    Returns the constructed ``LDS`` together with the final LSTM state,
    so that unrolling can be continued (e.g. for prediction).
    """
    # Embed the static categorical features once, then tile them along
    # the time axis so they can be concatenated with the time features.
    static_feat = self.embedder(feat_static_cat)
    tiled_static = static_feat.expand_dims(axis=1).repeat(
        axis=1, repeats=length
    )

    # Full conditioning context: dynamic time features + tiled statics.
    ctx_features = F.concat(time_feat, tiled_static, dim=2)

    lstm_out, lstm_final_state = self.lstm.unroll(
        inputs=ctx_features,
        begin_state=lstm_begin_state,
        length=length,
        merge_outputs=True,
    )

    if prior_mean is None:
        # Derive the prior of the latent state from the first time step.
        first_step = F.slice_axis(
            lstm_out, axis=1, begin=0, end=1
        ).squeeze(axis=1)
        prior_mean = self.prior_mean_model(first_step)
        # Affinely map the covariance-diagonal model output into the
        # configured [lower, upper] bounds.
        diag = (
            self.prior_cov_diag_model(first_step)
            * (self.prior_cov_bounds.upper - self.prior_cov_bounds.lower)
            + self.prior_cov_bounds.lower
        )
        prior_cov = make_nd_diag(F, diag, self.issm.latent_dim())

    (
        emission_coeff,
        transition_coeff,
        innovation_coeff,
    ) = self.issm.get_issm_coeff(seasonal_indicators)

    # Project the LSTM outputs to the time-varying LDS parameters.
    noise_std, innovation, residuals = self.lds_proj(lstm_out)

    lds = LDS(
        emission_coeff=emission_coeff,
        transition_coeff=transition_coeff,
        innovation_coeff=F.broadcast_mul(innovation, innovation_coeff),
        noise_std=noise_std,
        residuals=residuals,
        prior_mean=prior_mean,
        prior_cov=prior_cov,
        latent_dim=self.issm.latent_dim(),
        output_dim=self.issm.output_dim(),
        seq_length=length,
    )

    return lds, lstm_final_state
def kalman_filter_step(
    F,
    target: Tensor,
    prior_mean: Tensor,
    prior_cov: Tensor,
    emission_coeff: Tensor,
    residual: Tensor,
    noise_std: Tensor,
    latent_dim: int,
    output_dim: int,
):
    """
    One step of the Kalman filter.

    Computes the filtered state (mean and covariance) given the linear
    system coefficients and the prior state (mean and covariance), as
    well as observations.

    Parameters
    ----------
    F
    target
        Observations of the system output, shape (batch_size, output_dim)
    prior_mean
        Prior mean of the latent state, shape (batch_size, latent_dim)
    prior_cov
        Prior covariance of the latent state, shape
        (batch_size, latent_dim, latent_dim)
    emission_coeff
        Emission coefficient, shape (batch_size, output_dim, latent_dim)
    residual
        Residual component, shape (batch_size, output_dim)
    noise_std
        Standard deviation of the output noise, shape
        (batch_size, output_dim)
    latent_dim
        Dimension of the latent state vector
    output_dim
        Dimension of the system output

    Returns
    -------
    Tensor
        Filtered_mean, shape (batch_size, latent_dim)
    Tensor
        Filtered_covariance, shape (batch_size, latent_dim, latent_dim)
    Tensor
        Log probability, shape (batch_size, )
    """
    # Predicted observation mean A mu, shape (batch_size, output_dim).
    pred_mean = F.linalg_gemm2(
        emission_coeff, prior_mean.expand_dims(axis=-1)
    ).squeeze(axis=-1)

    # Observation-noise covariance R = diag(sigma^2).
    noise_cov = make_nd_diag(F=F, x=noise_std * noise_std, d=output_dim)

    # P A^T — reused in both the innovation covariance and the gain.
    cov_times_emission_tr = F.linalg_gemm2(
        prior_cov, emission_coeff, transpose_b=True
    )

    # Innovation covariance S = A P A^T + R and its Cholesky factor L,
    # with S = L L^T.
    innovation_cov = (
        F.linalg_gemm2(emission_coeff, cov_times_emission_tr) + noise_cov
    )
    L_innovation = F.linalg_potrf(innovation_cov)

    # Kalman gain K = P A^T S^{-1}. With X = P A^T we need K L L^T = X,
    # obtained by two right-sided triangular solves against L.
    gain = F.linalg_trsm(
        L_innovation,
        F.linalg_trsm(
            L_innovation,
            cov_times_emission_tr,
            rightside=True,
            transpose=True,
        ),
        rightside=True,
    )

    # Innovation (prediction error) on the de-trended target.
    target_minus_residual = target - residual
    delta = target_minus_residual - pred_mean

    # Filtered mean: mu + K delta.
    filtered_mean = (
        prior_mean.expand_dims(axis=-1)
        + F.linalg_gemm2(gain, delta.expand_dims(axis=-1))
    ).squeeze(axis=-1)

    # Joseph-form (symmetrized) covariance update, which keeps the
    # result symmetric PSD: (I - K A) P (I - K A)^T + K R K^T.
    i_minus_ka = F.broadcast_sub(
        F.eye(latent_dim), F.linalg_gemm2(gain, emission_coeff)
    )
    filtered_cov = F.linalg_gemm2(
        i_minus_ka,
        F.linalg_gemm2(prior_cov, i_minus_ka, transpose_b=True),
    ) + F.linalg_gemm2(
        gain, F.linalg_gemm2(noise_cov, gain, transpose_b=True)
    )

    # Log-likelihood of the observation under N(pred_mean, S),
    # shape (batch_size,).
    log_p = MultivariateGaussian(pred_mean, L_innovation).log_prob(
        target_minus_residual
    )

    return filtered_mean, filtered_cov, log_p
def sample_marginals(
    self,
    num_samples: Optional[int] = None,
    scale: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Generates samples from the marginals p(z_t),
    t = 1, \ldots, `seq_length`.

    Parameters
    ----------
    num_samples
        Number of samples to generate
    scale
        Scale of each sequence in x, shape (batch_size, output_dim)

    Returns
    -------
    Tensor
        Samples, shape (num_samples, batch_size, seq_length, output_dim)
    """
    F = self.F

    # Latent-state moments, rolled forward through the loop below.
    mean_t = self.prior_mean.expand_dims(axis=-1)
    cov_t = self.prior_cov

    means = []
    covs = []

    for t in range(self.seq_length):
        # Observation mean at time t: A_t m_t + b_t.
        means.append(
            F.linalg_gemm2(self.emission_coeff[t], mean_t)
            + self.residuals[t].expand_dims(axis=-1)
        )

        # Observation covariance at time t: A_t P_t A_t^T + R_t.
        obs_cov = F.linalg_gemm2(
            self.emission_coeff[t],
            F.linalg_gemm2(
                cov_t, self.emission_coeff[t], transpose_b=True
            ),
        ) + make_nd_diag(
            F=F,
            x=self.noise_std[t] * self.noise_std[t],
            d=self.output_dim,
        )
        covs.append(obs_cov.expand_dims(axis=1))

        # Propagate the latent state to time t + 1:
        # m <- F_t m,  P <- F_t P F_t^T + g_t g_t^T.
        mean_t = F.linalg_gemm2(self.transition_coeff[t], mean_t)
        cov_t = F.linalg_gemm2(
            self.transition_coeff[t],
            F.linalg_gemm2(
                cov_t, self.transition_coeff[t], transpose_b=True
            ),
        ) + F.linalg_gemm2(
            self.innovation_coeff[t],
            self.innovation_coeff[t],
            transpose_a=True,
        )

    marginal_mean = F.concat(*means, dim=1)
    marginal_cov = F.concat(*covs, dim=1)

    # Sample from N(mean, L L^T) with L the Cholesky factor.
    samples = MultivariateGaussian(
        marginal_mean, F.linalg_potrf(marginal_cov)
    ).sample(num_samples=num_samples)

    if scale is None:
        return samples
    return F.broadcast_mul(samples, scale.expand_dims(axis=1))
(3, 4, 5), (), ), ( StudentT( mu=mx.nd.zeros(shape=(3, 4, 5)), sigma=mx.nd.ones(shape=(3, 4, 5)), nu=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( MultivariateGaussian( mu=mx.nd.zeros(shape=(3, 4, 5)), L=make_nd_diag(F=mx.nd, x=mx.nd.ones(shape=(3, 4, 5)), d=5), ), (3, 4), (5, ), ), (Dirichlet(alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, )), ( DirichletMultinomial( dim=5, n_trials=9, alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, ), ), ( Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)), b=mx.nd.ones(shape=(3, 4, 5))), (3, 4, 5),