from collections import defaultdict
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

# `Selector`, `bmat_idx`, `tobit_probs`, and `tobit_adjustment` are
# library-internal helpers assumed to be in scope for these methods.


def _update_group(self,
                  obs: Tensor,
                  group_idx: Union[slice, Sequence[int]],
                  which_valid: Union[slice, Sequence[int]],
                  lower: Optional[Tensor] = None,
                  upper: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    # indices:
    idx_2d = bmat_idx(group_idx, which_valid)
    idx_3d = bmat_idx(group_idx, which_valid, which_valid)

    # observed values, censoring limits:
    obs = obs[idx_2d]
    if lower is not None:
        lower = lower[idx_2d]
        if torch.isnan(lower).any():
            raise ValueError("NaNs not allowed in `lower`")
    if upper is not None:
        upper = upper[idx_2d]
        if torch.isnan(upper).any():
            raise ValueError("NaNs not allowed in `upper`")
    if lower is not None and upper is not None and (lower == upper).any():
        raise RuntimeError("`lower` cannot == `upper`")

    # subset belief / design-mats:
    means = self.means[group_idx]
    covs = self.covs[group_idx]
    R = self.R[idx_3d]
    H = self.H[idx_2d]
    measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

    # calculate censoring fx:
    prob_lo, prob_up = tobit_probs(mean=measured_means, cov=R, lower=lower, upper=upper)
    prob_obs = torch.diag_embed(1 - prob_up - prob_lo)
    mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                     cov=R,
                                     lower=lower,
                                     upper=upper,
                                     probs=(prob_lo, prob_up))

    # kalman gain:
    K = self.kalman_gain(covariance=covs, H=H, R_adjusted=R_adj, prob_obs=prob_obs)

    # update:
    means_new = self.mean_update(mean=means, K=K, residuals=obs - mm_adj)
    covs_new = self.covariance_update(covariance=covs, K=K, H=H, prob_obs=prob_obs)
    return means_new, covs_new
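# `tobit_probs` is not shown in this section. A minimal sketch of what it is
# assumed to compute -- the per-measure probabilities of falling below `lower`
# or above `upper` under the marginal normal predictive distribution -- might
# look like this (the real helper may differ, e.g. in how it batches):
def tobit_probs_sketch(mean: Tensor,
                       cov: Tensor,
                       lower: Optional[Tensor] = None,
                       upper: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    _std_normal = torch.distributions.Normal(0., 1.)
    std = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
    prob_lo = torch.zeros_like(mean) if lower is None else _std_normal.cdf((lower - mean) / std)
    prob_up = torch.zeros_like(mean) if upper is None else 1. - _std_normal.cdf((upper - mean) / std)
    return prob_lo, prob_up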
def _log_prob_with_subsetting(self,
                              obs: Tensor,
                              group_idx: Selector,
                              time_idx: Selector,
                              measure_idx: Selector,
                              **kwargs) -> Tensor:
    self._check_lp_sub_input(group_idx, time_idx)

    idx_3d = bmat_idx(group_idx, time_idx, measure_idx)
    idx_4d = bmat_idx(group_idx, time_idx, measure_idx, measure_idx)

    dist = MultivariateNormal(self.predictions[idx_3d], self.prediction_uncertainty[idx_4d])
    return dist.log_prob(obs[idx_3d])
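# `bmat_idx` is assumed to build indices that select a sub-block along each of
# the given dimensions (np.ix_-style meshing). A minimal sketch covering only
# the all-sequence case -- the real helper presumably also passes slices
# through unchanged:
def bmat_idx_sketch(*indices: Sequence[int]) -> Tuple[Tensor, ...]:
    grids = torch.meshgrid(*(torch.as_tensor(idx) for idx in indices), indexing='ij')
    return tuple(grids)

# `std_normal`, used by the censored log-prob below, is assumed to be a
# module-level standard normal distribution:
std_normal = torch.distributions.Normal(0., 1.)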
def _log_prob_with_subsetting(self,
                              obs: Tensor,
                              group_idx: Selector,
                              time_idx: Selector,
                              measure_idx: Selector,
                              method: str = 'independent',
                              lower: Optional[Tensor] = None,
                              upper: Optional[Tensor] = None) -> Tensor:
    self._check_lp_sub_input(group_idx, time_idx)

    idx_3d = bmat_idx(group_idx, time_idx, measure_idx)
    idx_4d = bmat_idx(group_idx, time_idx, measure_idx, measure_idx)

    # subset obs, lower, upper:
    if upper is None:
        upper = torch.full_like(obs, float('inf'))
    if lower is None:
        lower = torch.full_like(obs, -float('inf'))
    obs, lower, upper = obs[idx_3d], lower[idx_3d], upper[idx_3d]

    pred_mean = self.predictions[idx_3d]
    pred_cov = self.prediction_uncertainty[idx_4d]

    cens_up = torch.isclose(obs, upper)
    cens_lo = torch.isclose(obs, lower)

    loglik_uncens = torch.zeros_like(obs)
    loglik_cens_up = torch.zeros_like(obs)
    loglik_cens_lo = torch.zeros_like(obs)
    for m in range(pred_mean.shape[-1]):
        std = pred_cov[..., m, m].sqrt()
        z = (pred_mean[..., m] - obs[..., m]) / std
        # the pdf is well-behaved at the tails:
        loglik_uncens[..., m] = std_normal.log_prob(z) - std.log()
        # but the cdf is not, so clamp:
        z = torch.clamp(z, -5., 5.)
        loglik_cens_up[..., m] = std_normal.cdf(z).log()
        loglik_cens_lo[..., m] = (1. - std_normal.cdf(z)).log()

    loglik = torch.zeros_like(obs)
    loglik[cens_up] = loglik_cens_up[cens_up]
    loglik[cens_lo] = loglik_cens_lo[cens_lo]
    loglik[~(cens_up | cens_lo)] = loglik_uncens[~(cens_up | cens_lo)]

    # take the product of the per-measure probs (i.e., assume independence):
    return torch.sum(loglik, -1)
def _update_group(self,
                  obs: Tensor,
                  group_idx: Union[slice, Sequence[int]],
                  which_valid: Union[slice, Sequence[int]]) -> Tuple[Tensor, Tensor]:
    idx_2d = bmat_idx(group_idx, which_valid)
    idx_3d = bmat_idx(group_idx, which_valid, which_valid)

    group_obs = obs[idx_2d]
    group_means = self.means[group_idx]
    group_covs = self.covs[group_idx]
    group_H = self.H[idx_2d]
    group_R = self.R[idx_3d]
    group_measured_means = group_H.matmul(group_means.unsqueeze(2)).squeeze(2)

    group_system_covs = self.system_uncertainty(covs=group_covs, H=group_H, R=group_R)
    group_K = self.kalman_gain(system_covariance=group_system_covs, covariance=group_covs, H=group_H)

    means_new = self.mean_update(mean=group_means, K=group_K, residuals=group_obs - group_measured_means)
    covs_new = self.covariance_update(covariance=group_covs, K=group_K, H=group_H, R=group_R)
    return means_new, covs_new
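# For reference, the textbook update equations that `kalman_gain`,
# `mean_update`, and `covariance_update` are assumed to implement -- a sketch
# under those assumptions, not the library's exact code:
def kalman_gain_sketch(system_covariance: Tensor, covariance: Tensor, H: Tensor) -> Tensor:
    # K = P Hᵀ S⁻¹, where S = H P Hᵀ + R is the system uncertainty
    return covariance.matmul(H.transpose(-1, -2)).matmul(torch.inverse(system_covariance))

def mean_update_sketch(mean: Tensor, K: Tensor, residuals: Tensor) -> Tensor:
    # x' = x + K (y - H x)
    return mean + K.matmul(residuals.unsqueeze(-1)).squeeze(-1)

def covariance_update_sketch(covariance: Tensor, K: Tensor, H: Tensor, R: Tensor) -> Tensor:
    # Joseph form, numerically stabler than (I - K H) P:
    # P' = (I - K H) P (I - K H)ᵀ + K R Kᵀ
    I = torch.eye(covariance.shape[-1], dtype=covariance.dtype, device=covariance.device)
    ImKH = I - K.matmul(H)
    return ImKH.matmul(covariance).matmul(ImKH.transpose(-1, -2)) + K.matmul(R).matmul(K.transpose(-1, -2))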
def _log_prob_with_subsetting(self,
                              obs: Tensor,
                              group_idx: Selector,
                              time_idx: Selector,
                              measure_idx: Selector,
                              method: str = 'independent',
                              lower: Optional[Tensor] = None,
                              upper: Optional[Tensor] = None) -> Tensor:
    self._check_lp_sub_input(group_idx, time_idx)

    idx_no_measure = bmat_idx(group_idx, time_idx)
    idx_3d = bmat_idx(group_idx, time_idx, measure_idx)
    idx_4d = bmat_idx(group_idx, time_idx, measure_idx, measure_idx)

    # subset obs, lower, upper (defaulting absent limits to +/- inf):
    if upper is None:
        upper = torch.full_like(obs, float('inf'))
    if lower is None:
        lower = torch.full_like(obs, -float('inf'))
    obs, lower, upper = obs[idx_3d], lower[idx_3d], upper[idx_3d]

    if method.lower() == 'update':
        means = self.means[idx_no_measure]
        covs = self.covs[idx_no_measure]
        H = self.H[idx_3d]
        R = self.R[idx_4d]
        measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

        # calculate prob-obs:
        prob_lo, prob_up = tobit_probs(mean=measured_means, cov=R, lower=lower, upper=upper)
        prob_obs = torch.diag_embed(1 - prob_up - prob_lo)

        # calculate adjusted measure mean and cov:
        mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                         cov=R,
                                         lower=lower,
                                         upper=upper,
                                         probs=(prob_lo, prob_up))

        # system uncertainty:
        Ht = H.permute(0, 1, 3, 2)
        system_uncertainty = prob_obs.matmul(H).matmul(covs).matmul(Ht).matmul(prob_obs) + R_adj

        # log prob:
        dist = torch.distributions.MultivariateNormal(mm_adj, system_uncertainty)
        return dist.log_prob(obs)
    elif method.lower() == 'independent':
        pred_mean = self.predictions[idx_3d]
        pred_cov = self.prediction_uncertainty[idx_4d]

        cens_up = torch.isclose(obs, upper)
        cens_lo = torch.isclose(obs, lower)

        loglik_uncens = torch.zeros_like(obs)
        loglik_cens_up = torch.zeros_like(obs)
        loglik_cens_lo = torch.zeros_like(obs)
        for m in range(pred_mean.shape[-1]):
            std = pred_cov[..., m, m].sqrt()
            z = (pred_mean[..., m] - obs[..., m]) / std
            # the pdf is well-behaved at the tails:
            loglik_uncens[..., m] = std_normal.log_prob(z) - std.log()
            # but the cdf is not, so clamp:
            z = torch.clamp(z, -5., 5.)
            loglik_cens_up[..., m] = std_normal.cdf(z).log()
            loglik_cens_lo[..., m] = (1. - std_normal.cdf(z)).log()

        loglik = torch.zeros_like(obs)
        loglik[cens_up] = loglik_cens_up[cens_up]
        loglik[cens_lo] = loglik_cens_lo[cens_lo]
        loglik[~(cens_up | cens_lo)] = loglik_uncens[~(cens_up | cens_lo)]

        # take the product of the per-measure probs (i.e., assume independence):
        return torch.sum(loglik, -1)
    else:
        raise RuntimeError("Expected `method` to be one of: {}.".format({'update', 'independent'}))
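# `tobit_adjustment` is assumed to replace each measure's mean and variance
# with the moments of the corresponding censored normal, via the standard
# truncated-normal formulas. A diagonal-only sketch, assuming finite bounds
# (the real helper presumably also handles infinite bounds and the
# off-diagonal covariance):
def tobit_adjustment_sketch(mean: Tensor,
                            cov: Tensor,
                            lower: Tensor,
                            upper: Tensor,
                            probs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
    prob_lo, prob_up = probs
    _std_normal = torch.distributions.Normal(0., 1.)
    std = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
    alpha, beta = (lower - mean) / std, (upper - mean) / std
    pdf_a, pdf_b = _std_normal.log_prob(alpha).exp(), _std_normal.log_prob(beta).exp()
    prob_mid = (1. - prob_lo - prob_up).clamp(min=1e-12)  # P(lower < Y < upper)
    # mean/variance of the truncated part:
    trunc_mean = mean + std * (pdf_a - pdf_b) / prob_mid
    trunc_var = std ** 2 * (1. + (alpha * pdf_a - beta * pdf_b) / prob_mid
                            - ((pdf_a - pdf_b) / prob_mid) ** 2)
    # the censored mean mixes the truncated mean with the point-masses at the bounds:
    mean_adj = prob_lo * lower + prob_up * upper + prob_mid * trunc_mean
    return mean_adj, torch.diag_embed(trunc_var)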
def update(self, obs: Tensor) -> 'Gaussian':
    isnan = torch.isnan(obs)
    means_new = self.means.data.clone()
    covs_new = self.covs.data.clone()

    # need to do a different update depending on which (if any) dimensions are missing:
    update_groups = defaultdict(list)
    anynan_by_group = (torch.sum(isnan, 1) > 0)

    # groups with nan:
    nan_group_idx = anynan_by_group.nonzero().squeeze(-1).tolist()
    for i in nan_group_idx:
        if isnan[i].all():
            continue  # if all nan, then simply skip the update
        which_valid = (~isnan[i]).nonzero().squeeze(-1).tolist()
        update_groups[tuple(which_valid)].append(i)
    update_groups = list(update_groups.items())

    # groups without nan:
    if isnan.any():
        nonan_group_idx = (~anynan_by_group).nonzero().squeeze(-1).tolist()
        if len(nonan_group_idx):
            update_groups.append((slice(None), nonan_group_idx))
    else:
        # if no nans at all, then faster to use slices:
        update_groups.append((slice(None), slice(None)))

    measured_means, system_covs = self.measurement

    # updates:
    for which_valid, group_idx in update_groups:
        idx_2d = bmat_idx(group_idx, which_valid)
        idx_3d = bmat_idx(group_idx, which_valid, which_valid)
        group_obs = obs[idx_2d]
        group_means = self.means[group_idx]
        group_covs = self.covs[group_idx]
        group_measured_means = measured_means[idx_2d]
        group_system_covs = system_covs[idx_3d]
        group_H = self.H[idx_2d]
        group_R = self.R[idx_3d]
        group_K = self.kalman_gain(system_covariance=group_system_covs, covariance=group_covs, H=group_H)
        means_new[group_idx] = self.mean_update(mean=group_means, K=group_K,
                                                residuals=group_obs - group_measured_means)
        covs_new[group_idx] = self.covariance_update(covariance=group_covs, K=group_K, H=group_H, R=group_R)

    # calculate last-measured:
    any_measured_group_idx = (torch.sum(~isnan, 1) > 0).nonzero().squeeze(-1)
    last_measured = self.last_measured.clone()
    last_measured[any_measured_group_idx] = 0

    return self.__class__(means=means_new, covs=covs_new, last_measured=last_measured)
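# A small, runnable demonstration of the grouping trick in `update` above
# (using the imports at the top of this module): rows that share the same
# missingness pattern are batched into a single update call. Names here are
# illustrative, not part of the library's API.
_demo_obs = torch.tensor([[1.0, float('nan')],
                          [2.0, float('nan')],
                          [3.0, 4.0]])
_demo_isnan = torch.isnan(_demo_obs)
_demo_groups = defaultdict(list)
for _i in _demo_isnan.any(1).nonzero().squeeze(-1).tolist():
    if _demo_isnan[_i].all():
        continue  # fully-missing rows are skipped entirely
    _demo_groups[tuple((~_demo_isnan[_i]).nonzero().squeeze(-1).tolist())].append(_i)
print(dict(_demo_groups))  # {(0,): [0, 1]} -- rows 0 and 1 both observe only measure 0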