def forward(self, u, s, x=None, m=None, V=None):
    """
    m and V are from the Gaussian state x_{t-1}.
    Forward marginalises out the Gaussian if m and V are given,
    or transforms a sample x if that is given instead.
    """
    # assert len({(x is None), (m is None and V is None)}) == 2
    assert (x is None) or (m is None and V is None)

    h = torch.cat((u, s), dim=-1) if u is not None else s
    switch_to_switch_dist = self.conditional_dist_transform(h)
    if self.F is not None:
        if m is not None:  # marginalise
            mp, Vp = filter_forward_prediction_step(
                m=m,
                V=V,
                A=self.F,
                R=switch_to_switch_dist.covariance_matrix,
                b=switch_to_switch_dist.loc,
            )
        else:  # single sample fwd
            mp = matvec(self.F, x) + switch_to_switch_dist.loc
            Vp = switch_to_switch_dist.covariance_matrix
        return MultivariateNormal(loc=mp, scale_tril=cholesky(Vp))
    else:
        return switch_to_switch_dist
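The marginalisation branch above uses the standard linear-Gaussian identity: if x ~ N(m, V) and s | x ~ N(F x + b, R), then s ~ N(F m + b, F V F^T + R). Below is a minimal, self-contained sketch of that step; it is an assumption about what filter_forward_prediction_step computes, not the repository's implementation, and the function name is illustrative only.

import torch

def gaussian_prediction_sketch(m, V, F, R, b):
    # Moments of s = F x + b + eps, with x ~ N(m, V) and eps ~ N(0, R).
    mp = torch.einsum("...ij,...j->...i", F, m) + b
    Vp = F @ V @ F.transpose(-1, -2) + R
    return mp, Vp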
def _make_switch_transition_dist(
    self,
    lat_vars_tm1: GLSVariablesSGLS,
    ctrl_t: ControlInputsSGLS,
) -> torch.distributions.MultivariateNormal:
    r"""
    Compute p(s_t | s_{t-1})
        = \int p(x_{t-1}, s_t | s_{t-1}) dx_{t-1}
        = \int p(s_t | s_{t-1}, x_{t-1}) N(x_{t-1} | s_{t-1}) dx_{t-1}.

    We use an additive structure, resulting in a convolution of PDFs, i.e.
    i) the conditional from the switch-to-switch transition and
    ii) the marginal from the state-to-switch transition (state marginalised).

    The Gaussian is a 'stable distribution': the sum of Gaussian variables
    (not PDFs!) is Gaussian, for which locs and *covariances* are summed.
    (In case of a weighted sum, means and *scales* are weighted.)
    """
    if len({lat_vars_tm1.x is None, lat_vars_tm1.m is None}) != 2:
        raise Exception("Provide samples XOR dist params (-> marginalize).")
    marginalize_states = lat_vars_tm1.m is not None

    # i) switch-to-switch conditional
    switch_to_switch_dist = super()._make_switch_transition_dist(
        lat_vars_tm1=lat_vars_tm1,
        ctrl_t=ctrl_t,
    )

    # ii) state-to-switch
    rec_base_params = self.recurrent_base_parameters(switch=lat_vars_tm1.switch)
    if marginalize_states:
        m, V = filter_forward_prediction_step(
            m=lat_vars_tm1.m,
            V=lat_vars_tm1.V,
            A=rec_base_params.F,
            R=rec_base_params.S,
            b=None,
        )
    else:
        m = matvec(rec_base_params.F, lat_vars_tm1.x)
        V = rec_base_params.S
    state_to_switch_dist = MultivariateNormal(loc=m, scale_tril=cholesky(V))

    # combine i) & ii): sum variables (= convolve PDFs).
    switch_model_dist = gaussian_linear_combination(
        {state_to_switch_dist: 0.5, switch_to_switch_dist: 0.5}
    )
    return switch_model_dist
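gaussian_linear_combination is not shown in this excerpt. Following the convention stated in the docstring (weighted means and weighted scale factors), a minimal sketch could look as follows; the name, signature, and behaviour are assumptions for illustration and may differ from the actual helper.

from torch.distributions import MultivariateNormal

def gaussian_linear_combination_sketch(weighted_dists):
    # weighted_dists: dict mapping MultivariateNormal -> weight (weights sum to 1).
    loc = sum(w * d.loc for d, w in weighted_dists.items())
    scale_tril = sum(w * d.scale_tril for d, w in weighted_dists.items())
    return MultivariateNormal(loc=loc, scale_tril=scale_tril)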
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # pre-compute biases (per-step zeros when no controls are given,
    # so that b[t - 1] and d[t] below remain valid).
    b = matvec(B, u_state) if u_state is not None else [0.0] * dims.timesteps
    d = matvec(D, u_obs) if u_obs is not None else [0.0] * dims.timesteps

    m_fw = torch.zeros(
        (dims.timesteps, dims.batch, dims.state), device=device, dtype=dtype
    )
    V_fw = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )
    for t in range(0, dims.timesteps):
        (mp, Vp) = (
            filter_forward_prediction_step(
                m=m_fw[t - 1], V=V_fw[t - 1], R=R, A=A, b=b[t - 1],
            )
            if t > 0
            else (m0, V0)
        )
        m_fw[t], V_fw[t] = filter_forward_measurement_step(
            y=y[t], m=mp, V=Vp, Q=Q, C=C, d=d[t],
        )
    return m_fw, V_fw
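filter_forward_measurement_step is assumed here to be the standard Kalman measurement update; the following self-contained sketch shows the computation it presumably performs (illustrative only, not the repository's code).

import torch

def kalman_measurement_update_sketch(y, mp, Vp, Q, C, d):
    # Innovation covariance, Kalman gain, and posterior moments.
    S = C @ Vp @ C.transpose(-1, -2) + Q
    K = Vp @ C.transpose(-1, -2) @ torch.inverse(S)
    residual = y - (torch.einsum("...ij,...j->...i", C, mp) + d)
    m = mp + torch.einsum("...ij,...j->...i", K, residual)
    V = Vp - K @ C @ Vp
    return m, V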
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # pre-compute biases (per-step zeros when no controls are given,
    # so that b[t - 1] and d[t] below remain valid).
    b = matvec(B, u_state) if u_state is not None else [0.0] * dims.timesteps
    d = matvec(D, u_obs) if u_obs is not None else [0.0] * dims.timesteps

    # Note: We cannot use (more readable) in-place operations due to backprop problems.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps

    loss = torch.zeros((dims.batch,), device=device, dtype=dtype)
    for t in range(0, dims.timesteps):
        (mp, Vp) = (
            filter_forward_prediction_step(
                m=m_fw[t - 1], V=V_fw[t - 1], R=R, A=A, b=b[t - 1],
            )
            if t > 0
            else (m0, V0)
        )
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t], m=mp, V=Vp, Q=Q, C=C, d=d[t], return_loss_components=True,
        )
        loss += (
            0.5 * torch.sum(dobs_norm ** 2, dim=-1)
            - 0.5 * 2 * torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1,))
            + 0.5 * dims.target * LOG_2PI
        )
    return loss
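For reference, the per-step term accumulated above is the Gaussian negative log-likelihood of y_t under the one-step predictive distribution with covariance S = C V_p C^T + Q. This restates the code, assuming LVpyinv is the Cholesky factor of S^{-1} and dobs_norm is the whitened residual:

    -log N(y_t | C m_p + d, S)
        =   0.5 * ||dobs_norm||^2
          - sum(log(diag(LVpyinv)))
          + 0.5 * dims.target * log(2 * pi),

where ||dobs_norm||^2 = (y_t - C m_p - d)^T S^{-1} (y_t - C m_p - d).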
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    # TODO: assumes B != None, D != None. u_state and u_obs can be None.
    #  Better to work with vectors b, d. matmul Bu should be done outside!
    m_fw = create_zeros_state_vec(dims=dims, device=A.device, dtype=A.dtype)
    V_fw = create_zeros_state_mat(dims=dims, device=A.device, dtype=A.dtype)

    for t in range(0, dims.timesteps):
        if t == 0:
            V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
            mp, Vp = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)
        else:
            R_tm1 = cov_from_invcholesky_param(
                LRinv_tril[t - 1], LRinv_logdiag[t - 1]
            )
            b_tm1 = (
                matvec(B[t - 1], u_state[t - 1]) if u_state is not None else 0
            )
            mp, Vp = filter_forward_prediction_step(
                m=m_fw[t - 1], V=V_fw[t - 1], R=R_tm1, A=A[t - 1], b=b_tm1,
            )
        Q_t = cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t])
        d_t = matvec(D[t], u_obs[t]) if u_obs is not None else 0
        m_fw[t], V_fw[t] = filter_forward_measurement_step(
            y=y[t], m=mp, V=Vp, Q=Q_t, C=C[t], d=d_t,
        )
    return torch.stack(m_fw, dim=0), torch.stack(V_fw, dim=0)
def marginal_step(
    self,
    lats_tm1: LatentsSGLS,
    ctrl_t: ControlInputsSGLS,
    deterministic: bool = False,
) -> Prediction:
    # TODO: duplication with sample_step. Requires refactoring.
    n_batch = lats_tm1.variables.m.shape[1]

    if lats_tm1.variables.switch is None:
        switch_model_dist_t = self._make_switch_prior_dist(
            n_particle=self.n_particle,
            n_batch=n_batch,
            lat_vars_tm1=lats_tm1.variables,
            ctrl_t=ctrl_t,
        )
    else:
        switch_model_dist_t = self._make_switch_transition_dist(
            lat_vars_tm1=lats_tm1.variables,
            ctrl_t=ctrl_t,
        )

    s_t = (
        switch_model_dist_t.mean
        if deterministic
        else switch_model_dist_t.sample()
    )
    gls_params_t = self.gls_base_parameters(switch=s_t, controls=ctrl_t)

    x_t = None
    m_t, V_t = filter_forward_prediction_step(
        m=lats_tm1.variables.m,
        V=lats_tm1.variables.V,
        R=gls_params_t.R,
        A=gls_params_t.A,
        b=gls_params_t.b,
    )
    mpy_t, Vpy_t = filter_forward_predictive_distribution(
        m=m_t,
        V=V_t,
        Q=gls_params_t.Q,
        C=gls_params_t.C,
        d=gls_params_t.d,
    )
    emission_dist_t = torch.distributions.MultivariateNormal(
        loc=mpy_t,
        scale_tril=cholesky(Vpy_t),
    )

    # NOTE: Should compute Cov if the forecast joint distribution is needed.
    lats_t = LatentsSGLS(
        log_weights=lats_tm1.log_weights,  # does not change w/o evidence.
        gls_params=None,  # not used outside this function
        variables=GLSVariablesSGLS(
            x=x_t,
            m=m_t,
            V=V_t,
            Cov=None,
            switch=s_t,
        ),
    )
    return Prediction(latents=lats_t, emissions=emission_dist_t)
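filter_forward_predictive_distribution is assumed to map state moments to emission moments as in a standard linear-Gaussian emission model; an illustrative, self-contained sketch (name and signature are assumptions, not the repository's code):

import torch

def predictive_distribution_sketch(m, V, Q, C, d):
    # Moments of y = C x + d + noise, with x ~ N(m, V) and noise ~ N(0, Q).
    mpy = torch.einsum("...ij,...j->...i", C, m) + d
    Vpy = C @ V @ C.transpose(-1, -2) + Q
    return mpy, Vpy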
def filter_step(
    self,
    lats_tm1: Optional[LatentsSGLS],
    tar_t: torch.Tensor,
    ctrl_t: ControlInputsSGLS,
    tar_is_obs_t: Optional[torch.Tensor] = None,
):
    if tar_is_obs_t is not None:
        raise NotImplementedError("cannot handle missing data atm.")

    is_initial_step = lats_tm1 is None
    if is_initial_step:
        n_particle, n_batch = self.n_particle, len(tar_t)
        state_prior = self.state_prior_model(
            None, batch_shape_to_prepend=(n_particle, n_batch),
        )
        log_norm_weights = normalize_log_weights(
            log_weights=torch.zeros_like(state_prior.loc[..., 0]),
        )
        lats_tm1 = LatentsSGLS(
            log_weights=None,  # Not used. We use log_norm_weights instead.
            gls_params=None,  # First (previous) step has no gls_params.
            variables=GLSVariablesSGLS(
                m=state_prior.loc,
                V=state_prior.covariance_matrix,
                Cov=None,
                x=None,
                switch=None,
            ),
        )
        switch_model_dist = self._make_switch_prior_dist(
            lat_vars_tm1=lats_tm1.variables,
            ctrl_t=ctrl_t,
            n_particle=n_particle,
            n_batch=n_batch,
        )
    else:
        log_norm_weights = normalize_log_weights(
            log_weights=lats_tm1.log_weights,
        )
        log_norm_weights, resampled_tensors = resample(
            n_particle=self.n_particle,
            log_norm_weights=log_norm_weights,
            tensors_to_resample={
                key: val
                for key, val in lats_tm1.variables.__dict__.items()
                if key not in ("x", "Cov")  # below set to None explicitly
            },
            resampling_indices_fn=self.resampling_indices_fn,
            criterion_fn=self.resampling_criterion_fn,
        )
        # We don't use gls_params anymore.
        # If needed, e.g. for evaluation, remember to re-sample all params!
        lats_tm1 = LatentsSGLS(
            log_weights=None,  # Not used. We use log_norm_weights instead.
            gls_params=None,  # not used outside this function. Read above.
            variables=GLSVariablesSGLS(**resampled_tensors, x=None, Cov=None),
        )
        switch_model_dist = self._make_switch_transition_dist(
            lat_vars_tm1=lats_tm1.variables,
            ctrl_t=ctrl_t,
        )

    switch_proposal_dist = self._make_switch_proposal_dist(
        switch_model_dist=switch_model_dist,
        switch_encoder_dist=self._make_encoder_dists(
            tar_t=tar_t, ctrl_t=ctrl_t,
        ).switch,
    )
    s_t = switch_proposal_dist.rsample()
    gls_params_t = self.gls_base_parameters(switch=s_t, controls=ctrl_t)

    mp, Vp = filter_forward_prediction_step(
        m=lats_tm1.variables.m,
        V=lats_tm1.variables.V,
        R=gls_params_t.R,
        A=gls_params_t.A,
        b=gls_params_t.b,
    )
    m_t, V_t = filter_forward_measurement_step(
        y=tar_t, m=mp, V=Vp, Q=gls_params_t.Q, C=gls_params_t.C, d=gls_params_t.d,
    )
    mpy_t, Vpy_t = filter_forward_predictive_distribution(
        m=mp, V=Vp, Q=gls_params_t.Q, C=gls_params_t.C, d=gls_params_t.d,
    )
    measurement_dist = MultivariateNormal(loc=mpy_t, scale_tril=cholesky(Vpy_t))

    log_update = (
        measurement_dist.log_prob(tar_t)
        + switch_model_dist.log_prob(s_t)
        - switch_proposal_dist.log_prob(s_t)
    )
    log_weights_t = log_norm_weights + log_update

    return LatentsSGLS(
        log_weights=log_weights_t,
        gls_params=None,  # not used outside this function
        variables=GLSVariablesSGLS(
            m=m_t, V=V_t, x=None, Cov=None, switch=s_t,
        ),
    )
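The final lines of filter_step implement the usual importance-weight update of a Rao-Blackwellised particle filter, with the Gaussian state marginalised out via the Kalman one-step predictive distribution. In generic notation (a restatement of the code above, not additional functionality):

    log w_t = log w_{t-1}
              + log p(y_t | s_{1:t}, y_{1:t-1})    # measurement_dist.log_prob(tar_t)
              + log p(s_t | s_{t-1}, ...)          # switch_model_dist.log_prob(s_t)
              - log q(s_t | y_t, s_{t-1}, ...)     # switch_proposal_dist.log_prob(s_t)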
def filter_step(
    self,
    lats_tm1: Optional[LatentsKVAE],
    tar_t: torch.Tensor,
    ctrl_t: ControlInputs,
    tar_is_obs_t: Optional[torch.Tensor] = None,
) -> LatentsKVAE:
    is_initial_step = lats_tm1 is None
    if tar_is_obs_t is None:
        tar_is_obs_t = torch.ones(
            tar_t.shape[:-1], dtype=tar_t.dtype, device=tar_t.device,
        )

    # 1) Initial step must prepare previous latents with prior and learnt z.
    if is_initial_step:
        n_particle, n_batch = self.n_particle, len(tar_t)
        state_prior = self.state_prior_model(
            None, batch_shape_to_prepend=(n_particle, n_batch),
        )
        z_init = self.z_initial[None, None].repeat(n_particle, n_batch, 1)
        lats_tm1 = LatentsKVAE(
            variables=GLSVariablesKVAE(
                m=state_prior.loc,
                V=state_prior.covariance_matrix,
                Cov=None,
                x=None,
                auxiliary=z_init,
                rnn_state=None,
                m_auxiliary_variational=None,
                V_auxiliary_variational=None,
            ),
            gls_params=None,
        )

    # 2) Compute GLS params.
    rnn_state_t, rnn_output_t = self.compute_deterministic_switch_step(
        rnn_input=lats_tm1.variables.auxiliary,
        rnn_prev_state=lats_tm1.variables.rnn_state,
    )
    gls_params_t = self.gls_base_parameters(
        switch=rnn_output_t,
        controls=ctrl_t,
    )

    # Perform filter step:
    # 3) Prediction step: only for t > 0, using the previous GLS params.
    #    (In KVAE, they do first the update and then the prediction step.)
    if is_initial_step:
        mp, Vp = lats_tm1.variables.m, lats_tm1.variables.V
    else:
        mp, Vp = filter_forward_prediction_step(
            m=lats_tm1.variables.m,
            V=lats_tm1.variables.V,
            R=lats_tm1.gls_params.R,
            A=lats_tm1.gls_params.A,
            b=lats_tm1.gls_params.b,
        )

    # 4) Update step.
    # 4a) Observed data: infer pseudo-obs by encoding obs && Bayes update.
    auxiliary_variational_dist_t = self.encoder(tar_t)
    z_infer_t = auxiliary_variational_dist_t.rsample([self.n_particle])
    m_infer_t, V_infer_t = filter_forward_measurement_step(
        y=z_infer_t,
        m=mp,
        V=Vp,
        Q=gls_params_t.Q,
        C=gls_params_t.C,
        d=gls_params_t.d,
    )

    # 4b) Choice: inferred / predicted m, V for observed / missing data.
    is_filtered = tar_is_obs_t[None, :].repeat(self.n_particle, 1).bool()
    replace_m_fw = is_filtered[:, :, None].repeat(1, 1, mp.shape[2])
    replace_V_fw = is_filtered[:, :, None, None].repeat(
        1, 1, Vp.shape[2], Vp.shape[3],
    )
    assert replace_m_fw.shape == m_infer_t.shape == mp.shape
    assert replace_V_fw.shape == V_infer_t.shape == Vp.shape
    m_t = torch.where(replace_m_fw, m_infer_t, mp)
    V_t = torch.where(replace_V_fw, V_infer_t, Vp)

    # 4c) Missing data: predict pseudo-observations && no Bayes update.
    mpz_t, Vpz_t = filter_forward_predictive_distribution(
        m=m_t,  # posterior predictive or one-step-predictive (if missing)
        V=V_t,
        Q=gls_params_t.Q,
        C=gls_params_t.C,
        d=gls_params_t.d,
    )
    auxiliary_predictive_dist_t = MultivariateNormal(
        loc=mpz_t,
        covariance_matrix=Vpz_t,
    )
    z_gen_t = auxiliary_predictive_dist_t.rsample()

    # 4d) Choice: inferred / predicted z for observed / missing data.
    #     One-step predictive if missing, inferred from the encoder otherwise.
    replace_z = is_filtered[:, :, None].repeat(1, 1, z_gen_t.shape[2])
    z_t = torch.where(replace_z, z_infer_t, z_gen_t)

    # 5) Put the result into a Latents object, used in the next iteration.
    lats_t = LatentsKVAE(
        variables=GLSVariablesKVAE(
            m=m_t,
            V=V_t,
            Cov=None,
            x=None,
            auxiliary=z_t,
            rnn_state=rnn_state_t,
            m_auxiliary_variational=auxiliary_variational_dist_t.loc,
            V_auxiliary_variational=auxiliary_variational_dist_t.covariance_matrix,
        ),
        gls_params=gls_params_t,
    )
    return lats_t
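The torch.where selections in steps 4b) and 4d) keep the filtered quantities where the target is observed and fall back to the one-step predictions where it is missing. A tiny, self-contained illustration of this masking pattern (all shapes and values are made up):

import torch

mask = torch.tensor([[1.0, 0.0]])                  # (particle=1, batch=2); 1 = observed
m_filtered = torch.zeros(1, 2, 3)                  # filtered means
m_predicted = torch.ones(1, 2, 3)                  # one-step predicted means
keep = mask.bool()[:, :, None].expand_as(m_filtered)
m_t = torch.where(keep, m_filtered, m_predicted)   # batch 0 -> filtered, batch 1 -> predicted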
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """
    Computes the sample-wise (e.g. (particle, batch)) forward / filter loss.
    If particle and batch dims are used, the ordering is TPBF
    (time, particle, batch, feature).

    Note: it would be numerically better (in amount of computation and
    precision) to sum / average over these dimensions in all loss terms.
    However, the loss must be computed sample-wise for many algorithms,
    e.g. Rao-Blackwellised Particle Filters.
    """
    device, dtype = A.device, A.dtype

    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # Note: We cannot use (more readable) in-place operations due to backprop problems.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps

    dim_particle = (
        (dims.particle,)
        if dims.particle is not None and dims.particle != 0
        else tuple()
    )
    loss = torch.zeros(dim_particle + (dims.batch,), device=device, dtype=dtype)

    for t in range(0, dims.timesteps):
        (mp, Vp) = (
            filter_forward_prediction_step(
                m=m_fw[t - 1],
                V=V_fw[t - 1],
                R=cov_from_invcholesky_param(
                    LRinv_tril[t - 1], LRinv_logdiag[t - 1]
                ),
                A=A[t - 1],
                b=matvec(B[t - 1], u_state[t - 1]) if u_state is not None else 0,
            )
            if t > 0
            else (m0, V0)
        )
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t]),
            C=C[t],
            d=matvec(D[t], u_obs[t]) if u_obs is not None else 0,
            return_loss_components=True,
        )
        loss += 0.5 * (
            torch.sum(dobs_norm ** 2, dim=-1)
            - 2 * torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1,))
            + dims.target * LOG_2PI
        )
    return loss