def make_val_plots_univariate(
    model,
    data,
    idxs_ts,
    n_steps_forecast,
    savepath,
    marginalize_states=False,
    future_target_groundtruth=None,
    idx_particle=None,
    show=False,
):
    """Plot filtered + forecast predictive distributions for selected series.

    Runs ``model`` on ``data``, which yields per-step filtered predictions followed
    by ``n_steps_forecast`` forecast predictions. It then stacks the emission
    means/variances and normalized particle weights of every step, and draws one
    figure per entry of ``idxs_ts`` with ``plot_predictive_distribution``.

    Args:
        model: Object exposing ``.device``. It is called with the entries of
            ``data`` (except ``"future_target"``) plus ``n_steps_forecast`` and
            ``marginalize_states``, and must return a pair
            ``(predictions_filtered, predictions_forecast)``. Each element must
            expose ``.emissions`` and ``.latents.log_weights``.
        data: Mapping of name -> tensor containing at least ``"past_target"``
            and ``"future_target"``. The caller's mapping is not mutated; a
            device-moved copy is used.
        idxs_ts: Indices of the time series to plot (passed as ``idx_timeseries``).
        n_steps_forecast: Number of forecast steps requested from the model.
        savepath: Prefix of the output path; figure for index ``i`` is passed
            ``savepath=f"{savepath}_b{i}.pdf"``.
        marginalize_states: Forwarded to ``model``.
        future_target_groundtruth: Optional tensor plotted instead of
            ``data["future_target"]`` as the observed future.
        idx_particle: Forwarded to ``plot_predictive_distribution``.
        show: Forwarded to ``plot_predictive_distribution``.

    Raises:
        ValueError: If the emissions are neither a ``torch.distributions.Distribution``
            nor a ``torch.Tensor``.
    """
    device = model.device
    # Build a new dict so the ``pop`` below does not mutate the caller's mapping.
    data = {name: val.to(device) for name, val in data.items()}
    future_target = data.pop("future_target")
    if future_target_groundtruth is not None:
        # Must live on the same device as past_target, otherwise torch.cat fails.
        future_target_plot = future_target_groundtruth.to(device)
    else:
        future_target_plot = future_target
    y_plot = torch.cat([data["past_target"], future_target_plot])

    predictions_filtered, predictions_forecast = model(
        **data,
        n_steps_forecast=n_steps_forecast,
        marginalize_states=marginalize_states,
    )
    predictions = predictions_filtered + predictions_forecast

    if isinstance(predictions[0].emissions, torch.distributions.Distribution):
        mpy_trajectory = torch.stack([p.emissions.mean for p in predictions])
        # Use distribution.variance (i.e. the diagonal) as predictive variance:
        # covariances are not needed for plots, and this allows non-Gaussian emissions.
        Vpy_trajectory = batch_diag_matrix(
            torch.stack([p.emissions.variance for p in predictions])
        )
    elif isinstance(predictions[0].emissions, torch.Tensor):
        # Point predictions: plot them with zero predictive variance.
        mpy_trajectory = torch.stack([p.emissions for p in predictions])
        Vpy_trajectory = batch_diag_matrix(torch.zeros_like(mpy_trajectory))
    else:
        raise ValueError(
            f"Unexpected emission type: {type(predictions[0].emissions)}",
        )

    log_weights = torch.stack([p.latents.log_weights for p in predictions])
    norm_weights_trajectory = torch.exp(normalize_log_weights(log_weights))

    for idx_ts in idxs_ts:
        fig, _ = plot_predictive_distribution(
            y=y_plot.detach(),
            mpy=mpy_trajectory.detach(),
            Vpy=Vpy_trajectory.detach(),
            norm_weights=norm_weights_trajectory.detach(),
            n_steps_forecast=n_steps_forecast,
            idx_timeseries=idx_ts,
            idx_particle=idx_particle,
            show=show,
            savepath=f"{savepath}_b{idx_ts}.pdf",
        )
        # Release the figure so plotting many series does not accumulate memory.
        plt.close(fig)
def forward(self, switch, controls: ControlInputsSGLSISSM) -> GLSParams:
    """Assemble the switch-dependent GLS parameters.

    The link transformers map ``switch`` to mixing weights. These weights combine
    the base bias matrices (``B``, ``D``) and the Q/R noise parameters. The ISSM
    supplies the emission matrix ``C`` and the projector used for R. ``A`` is
    returned as ``None``.

    Returns:
        GLSParams with A, b, C, d, Q, R, LR, LQ.
    """
    weights = self.link_transformers(switch=switch)

    # Mix the K base bias matrices with the switch weights (None if unused).
    B = None if self.B is None else torch.einsum("...k,koi->...oi", weights.B, self.B)
    D = None if self.D is None else torch.einsum("...k,koi->...oi", weights.D, self.D)

    b = self.compute_bias(
        s=switch,
        u=controls.state,
        bias_fn=self.b_fn,
        bias_matrix=B,
    )
    d = self.compute_bias(
        s=switch,
        u=controls.target,
        bias_fn=self.d_fn,
        bias_matrix=D,
    )

    # Emission matrix and R projector come from the ISSM; its transition is unused.
    _, C, R_diag_projector = self.issm(
        seasonal_indicators=controls.seasonal_indicators,
    )
    # None stands in for the identity transition to save computation.
    A = None

    # Choose how the per-component noise parameters are averaged.
    if self.make_cov_from_cholesky_avg:
        average_noise = self.var_from_average_scales
    else:
        average_noise = self.var_from_average_variances

    Q_diag, LQ_diag = average_noise(
        weights=weights.Q,
        Linv_logdiag=self.LQinv_logdiag_limiter(self.LQinv_logdiag),
    )
    R_diag, LR_diag = average_noise(
        weights=weights.R,
        Linv_logdiag=self.LRinv_logdiag_limiter(self.LRinv_logdiag),
    )

    Q = batch_diag_matrix(Q_diag)
    R = batch_diag_matrix(matvec(R_diag_projector, R_diag))
    LQ = batch_diag_matrix(LQ_diag)
    LR = batch_diag_matrix(matvec(R_diag_projector, LR_diag))
    return GLSParams(A=A, b=b, C=C, d=d, Q=Q, R=R, LR=LR, LQ=LQ)
def cov_from_average_log_scales(
    weights: torch.Tensor,
    Linv_logdiag: torch.Tensor,
    Linv_tril: (torch.Tensor, None),
):
    """Weighted average of diagonal covariances via their log-scales.

    Only the diagonal case (``Linv_tril is None``) is supported. The averaged
    diagonal variances and scales come from
    ``GLSParameters.var_from_average_log_scales`` and are returned as diagonal
    matrices.

    Returns:
        ``(mat_weighted, Lmat_weighted)``: the averaged covariance matrix and its
        diagonal scale (square-root) matrix.

    Raises:
        NotImplementedError: If ``Linv_tril`` is given (full covariance factors).
    """
    if Linv_tril is not None:
        raise NotImplementedError(
            "cov_from_average_log_scales supports only diagonal covariances "
            "(Linv_tril=None); averaging log-scales of full Cholesky factors "
            "is not implemented."
        )
    (
        mat_diag_weighted,
        Lmat_diag_weighted,
    ) = GLSParameters.var_from_average_log_scales(
        weights=weights,
        Linv_logdiag=Linv_logdiag,
    )
    mat_weighted = batch_diag_matrix(mat_diag_weighted)
    Lmat_weighted = batch_diag_matrix(Lmat_diag_weighted)
    return mat_weighted, Lmat_weighted
def cov_from_average_variances(
    weights: torch.Tensor,
    Linv_logdiag: torch.Tensor,
    Linv_tril: (torch.Tensor, None),
):
    """Weighted average of K covariance matrices (averaging in variance space).

    Diagonal case (``Linv_tril is None``): delegate to
    ``GLSParameters.var_from_average_variances`` and wrap the results as
    diagonal matrices.

    Full case: build each component's covariance from its inverse-Cholesky
    parameters, average the covariances with ``weights``, and take the Cholesky
    factor of the average.

    Returns:
        ``(mat_weighted, Lmat_weighted)``: averaged covariance and its
        lower-triangular factor.
    """
    if Linv_tril is None:
        (
            mat_diag_weighted,
            Lmat_diag_weighted,
        ) = GLSParameters.var_from_average_variances(
            weights=weights,
            Linv_logdiag=Linv_logdiag,
        )
        mat_weighted = batch_diag_matrix(mat_diag_weighted)
        Lmat_weighted = batch_diag_matrix(Lmat_diag_weighted)
    else:
        mat, _ = cov_and_chol_from_invcholesky_param(
            Linv_tril=Linv_tril,
            Linv_logdiag=Linv_logdiag,
        )
        # ``mat`` stacks one square matrix per component (assumed (K, n, n), same
        # layout as Linv_tril in cov_from_average_scales), so contract only over k.
        mat_weighted = torch.einsum("...k,koi->...oi", weights, mat)
        # torch.cholesky is deprecated; torch.linalg.cholesky returns the same
        # lower-triangular factor.
        Lmat_weighted = torch.linalg.cholesky(mat_weighted)
    return mat_weighted, Lmat_weighted
def cov_from_average_scales(
    weights: torch.Tensor,
    Linv_logdiag: torch.Tensor,
    Linv_tril: (torch.Tensor, None),
):
    """Weighted average of K covariances, averaging their scale factors.

    The scale factors are averaged first; the covariance is then rebuilt as
    ``L @ L^T`` from the averaged factor ``L``.

    Returns:
        ``(mat_weighted, Lmat_weighted)``: averaged covariance and averaged
        scale matrix.
    """
    if Linv_tril is not None:
        # Inverse factor: strictly-lower part from Linv_tril, exp(logdiag) on the diagonal.
        Linv = torch.tril(Linv_tril, -1) + batch_diag_matrix(torch.exp(Linv_logdiag))
        Lmat = torch.inverse(Linv)
        Lmat_weighted = torch.einsum("...k,koi->...oi", weights, Lmat)
        mat_weighted = matmul(Lmat_weighted, Lmat_weighted.transpose(-1, -2))
        return mat_weighted, Lmat_weighted

    # Diagonal case: averaged variances and scales come back as vectors.
    diag_var, diag_scale = GLSParameters.var_from_average_scales(
        weights=weights,
        Linv_logdiag=Linv_logdiag,
    )
    return batch_diag_matrix(diag_var), batch_diag_matrix(diag_scale)
def precision_matrix(self):
    """Return the diagonal precision matrix ``diag(1 / base_dist.variance)``."""
    inverse_variance = self.base_dist.variance ** -1
    return batch_diag_matrix(inverse_variance)
def scale_tril(self):
    """Return ``diag(base_dist.scale)``; a diagonal matrix is trivially lower-triangular."""
    scale = self.base_dist.scale
    return batch_diag_matrix(scale)
def covariance_matrix(self):
    """Return the diagonal covariance matrix ``diag(base_dist.variance)``."""
    variance = self.base_dist.variance
    return batch_diag_matrix(variance)
def forward(self, diagonal):
    """Embed ``diagonal`` into (batched) diagonal matrices via ``batch_diag_matrix``."""
    diag_mats = batch_diag_matrix(diagonal=diagonal)
    return diag_mats
def R_diag_projector(self, seasonal_indicators):
    """Return identity matrices of size ``n_state``, batched like ``seasonal_indicators``.

    The batch shape is ``seasonal_indicators.shape[:-1]``. The result is a
    diagonal matrix of ones with ``self.dtype`` on ``self.device``.
    """
    # Allocate directly with the target dtype/device. This avoids building the
    # (n_state x n_state) matrices on CPU and copying them over on every call.
    ones = torch.ones(
        seasonal_indicators.shape[:-1] + (self.n_state,),
        dtype=self.dtype,
        device=self.device,
    )
    return batch_diag_matrix(ones)