Example #1
0
    def _update_group(self,
                      obs: Tensor,
                      group_idx: Union[slice, Sequence[int]],
                      which_valid: Union[slice, Sequence[int]],
                      lower: Optional[Tensor] = None,
                      upper: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """
        Censored (tobit) Kalman update for one group of series sharing the same observed measures.

        :param obs: Observations; subset with ``bmat_idx(group_idx, which_valid)``.
        :param group_idx: Which rows of ``self.means`` / ``self.covs`` to update.
        :param which_valid: Which measure dimensions are observed for these rows.
        :param lower: Optional lower censoring limits, indexed like ``obs``.
        :param upper: Optional upper censoring limits, indexed like ``obs``.
        :return: Tuple of (updated means, updated covariances) for the selected rows.
        :raises ValueError: If a provided ``lower`` or ``upper`` contains NaNs (after subsetting).
        :raises RuntimeError: If any ``lower`` equals the corresponding ``upper``.
        """
        # indices:
        idx_2d = bmat_idx(group_idx, which_valid)
        idx_3d = bmat_idx(group_idx, which_valid, which_valid)

        # observed values, censoring limits. Limits are validated only when provided; a `None`
        # limit is passed through unchanged.
        obs = obs[idx_2d]
        if lower is not None:
            lower = lower[idx_2d]
            if torch.isnan(lower).any():
                raise ValueError("NaNs not allowed in `lower`")
        if upper is not None:
            upper = upper[idx_2d]
            if torch.isnan(upper).any():
                raise ValueError("NaNs not allowed in `upper`")

        # a zero-width censoring interval is degenerate; only comparable when both limits exist
        if lower is not None and upper is not None and (lower == upper).any():
            raise RuntimeError("lower cannot == upper")

        # subset belief / design-mats:
        means = self.means[group_idx]
        covs = self.covs[group_idx]
        R = self.R[idx_3d]
        H = self.H[idx_2d]
        # project the state means into measurement space
        measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

        # calculate censoring fx. prob_obs is a diagonal matrix of (1 - prob_up - prob_lo).
        # NOTE(review): `None` limits are forwarded as-is to tobit_probs / tobit_adjustment --
        # confirm those helpers accept None.
        prob_lo, prob_up = tobit_probs(mean=measured_means,
                                       cov=R,
                                       lower=lower,
                                       upper=upper)
        prob_obs = torch.diag_embed(1 - prob_up - prob_lo)

        # censoring-adjusted measured means and measurement covariance
        mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                         cov=R,
                                         lower=lower,
                                         upper=upper,
                                         probs=(prob_lo, prob_up))

        # kalman gain:
        K = self.kalman_gain(covariance=covs,
                             H=H,
                             R_adjusted=R_adj,
                             prob_obs=prob_obs)

        # update
        means_new = self.mean_update(mean=means, K=K, residuals=obs - mm_adj)
        covs_new = self.covariance_update(covariance=covs,
                                          K=K,
                                          H=H,
                                          prob_obs=prob_obs)
        return means_new, covs_new
Example #2
0
    def _log_prob_with_subsetting(self,
                                  obs: Tensor,
                                  group_idx: Selector,
                                  time_idx: Selector,
                                  measure_idx: Selector,
                                  **kwargs) -> Tensor:
        """
        Log-probability of ``obs`` under the multivariate-normal predictive distribution,
        restricted to the selected groups, time-steps and measures.

        Any extra keyword arguments are accepted and ignored.
        """
        self._check_lp_sub_input(group_idx, time_idx)

        # index for means/obs (3 axes) and for covariance blocks (4 axes)
        mean_sel = bmat_idx(group_idx, time_idx, measure_idx)
        cov_sel = bmat_idx(group_idx, time_idx, measure_idx, measure_idx)

        predictive = MultivariateNormal(
            self.predictions[mean_sel],
            self.prediction_uncertainty[cov_sel],
        )
        selected_obs = obs[mean_sel]
        return predictive.log_prob(selected_obs)
Example #3
0
    def _log_prob_with_subsetting(self,
                                  obs: Tensor,
                                  group_idx: Selector,
                                  time_idx: Selector,
                                  measure_idx: Selector,
                                  method: str = 'independent',
                                  lower: Optional[Tensor] = None,
                                  upper: Optional[Tensor] = None) -> Tensor:
        """
        Censored log-probability of ``obs`` for the selected groups / time-steps / measures.

        Each measure is treated independently (only the diagonal of the prediction covariance is
        used). An observation that is close to ``upper`` contributes ``log Phi(z)``. One close to
        ``lower`` contributes ``log(1 - Phi(z))``, where ``z = (mean - obs) / std`` is clamped to
        [-5, 5]. All others contribute the normal log-density. Per-measure terms are summed.

        :param method: Only ``'independent'`` (case-insensitive) is implemented.
        :param lower: Lower censoring limits, indexed like ``obs``; defaults to ``-inf``.
        :param upper: Upper censoring limits, indexed like ``obs``; defaults to ``+inf``.
        :return: Log-probabilities with the measure dimension summed out.
        :raises RuntimeError: If ``method`` is anything other than ``'independent'``.
        """
        self._check_lp_sub_input(group_idx, time_idx)

        # `method` is part of the signature but only one method exists here; refuse other values
        # rather than silently computing something the caller did not ask for.
        if method.lower() != 'independent':
            raise RuntimeError("Expected method to be one of: {}.".format({'independent'}))

        idx_3d = bmat_idx(group_idx, time_idx, measure_idx)
        idx_4d = bmat_idx(group_idx, time_idx, measure_idx, measure_idx)

        # subset obs, lower, upper (a missing limit means "never censored on that side"):
        if upper is None:
            upper = torch.full_like(obs, float('inf'))
        if lower is None:
            lower = torch.full_like(obs, -float('inf'))
        obs, lower, upper = obs[idx_3d], lower[idx_3d], upper[idx_3d]

        pred_mean = self.predictions[idx_3d]
        pred_cov = self.prediction_uncertainty[idx_4d]

        # which observations sit on a censoring limit:
        cens_up = torch.isclose(obs, upper)
        cens_lo = torch.isclose(obs, lower)
        uncens = ~(cens_up | cens_lo)

        # marginal std-dev of every measure at once (diagonal of the covariance blocks):
        std = torch.diagonal(pred_cov, dim1=-2, dim2=-1).sqrt()
        z = (pred_mean - obs) / std

        # pdf is well behaved at tails:
        loglik_uncens = std_normal.log_prob(z) - std.log()

        # but cdf is not, clamp:
        cdf = std_normal.cdf(torch.clamp(z, -5., 5.))
        loglik_cens_up = cdf.log()
        loglik_cens_lo = (1. - cdf).log()

        # pick the right term per element; if an obs matches both limits, lower-censoring wins
        loglik = torch.zeros_like(obs)
        loglik[cens_up] = loglik_cens_up[cens_up]
        loglik[cens_lo] = loglik_cens_lo[cens_lo]
        loglik[uncens] = loglik_uncens[uncens]

        # take the product of the dimension probs (i.e., assume independence)
        return torch.sum(loglik, -1)
Example #4
0
 def _update_group(self,
                   obs: Tensor,
                   group_idx: Union[slice, Sequence[int]],
                   which_valid: Union[slice, Sequence[int]]) -> Tuple[Tensor, Tensor]:
     """
     Kalman update for the rows in ``group_idx``, using only the measure dims in ``which_valid``.

     :return: Tuple of (updated means, updated covariances) for the selected rows.
     """
     # selectors for (row, measure) and (row, measure, measure) subsets
     vec_sel = bmat_idx(group_idx, which_valid)
     mat_sel = bmat_idx(group_idx, which_valid, which_valid)

     observed = obs[vec_sel]
     prior_means = self.means[group_idx]
     prior_covs = self.covs[group_idx]
     design = self.H[vec_sel]
     noise = self.R[mat_sel]

     # state means projected into measurement space
     projected = design.matmul(prior_means.unsqueeze(2)).squeeze(2)

     sys_cov = self.system_uncertainty(covs=prior_covs, H=design, R=noise)
     gain = self.kalman_gain(system_covariance=sys_cov, covariance=prior_covs, H=design)

     residuals = observed - projected
     posterior_means = self.mean_update(mean=prior_means, K=gain, residuals=residuals)
     posterior_covs = self.covariance_update(covariance=prior_covs, K=gain, H=design, R=noise)
     return posterior_means, posterior_covs
Example #5
0
    def _log_prob_with_subsetting(self,
                                  obs: Tensor,
                                  group_idx: Selector,
                                  time_idx: Selector,
                                  measure_idx: Selector,
                                  method: str = 'independent',
                                  lower: Optional[Tensor] = None,
                                  upper: Optional[Tensor] = None) -> Tensor:
        """
        Censored log-probability of ``obs`` for the selected groups / time-steps / measures.

        The ``method`` argument is case-insensitive.

        ``'update'``
            Uses a multivariate normal with the tobit-adjusted measured mean. Its covariance is
            ``P H C H^T P + R_adj``, where ``P`` is the diagonal prob-observed matrix.

        ``'independent'``
            Treats measures independently. An observation close to ``upper`` contributes
            ``log Phi(z)``. One close to ``lower`` contributes ``log(1 - Phi(z))``. Here
            ``z = (mean - obs) / std`` is clamped to [-5, 5]. All others contribute the normal
            log-density. The terms are summed over measures.

        :param lower: Lower censoring limits, indexed like ``obs``; defaults to ``-inf``.
        :param upper: Upper censoring limits, indexed like ``obs``; defaults to ``+inf``.
        :return: Log-probabilities with the measure dimension reduced.
        :raises RuntimeError: If ``method`` is not one of ``'update'`` / ``'independent'``.
        """
        self._check_lp_sub_input(group_idx, time_idx)

        idx_no_measure = bmat_idx(group_idx, time_idx)
        idx_3d = bmat_idx(group_idx, time_idx, measure_idx)
        idx_4d = bmat_idx(group_idx, time_idx, measure_idx, measure_idx)

        # Missing limits mean "never censored on that side". Without these defaults, indexing
        # `None` below raises a TypeError for the default call.
        # NOTE(review): in the 'update' branch these infinities reach tobit_probs /
        # tobit_adjustment -- confirm those helpers handle infinite limits.
        if upper is None:
            upper = torch.full_like(obs, float('inf'))
        if lower is None:
            lower = torch.full_like(obs, -float('inf'))

        # subset obs, lower, upper:
        obs, lower, upper = obs[idx_3d], lower[idx_3d], upper[idx_3d]

        method = method.lower()
        if method == 'update':
            means = self.means[idx_no_measure]
            covs = self.covs[idx_no_measure]
            H = self.H[idx_3d]
            R = self.R[idx_4d]
            measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

            # calculate prob-obs:
            prob_lo, prob_up = tobit_probs(mean=measured_means,
                                           cov=R,
                                           lower=lower,
                                           upper=upper)
            prob_obs = torch.diag_embed(1 - prob_up - prob_lo)

            # calculate adjusted measure mean and cov:
            mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                             cov=R,
                                             lower=lower,
                                             upper=upper,
                                             probs=(prob_lo, prob_up))

            # system uncertainty (transpose of the last two dims; same as permute(0, 1, 3, 2)
            # for the 4-d H used here, but not tied to that rank):
            Ht = H.transpose(-1, -2)
            system_uncertainty = prob_obs.matmul(H).matmul(covs).matmul(
                Ht).matmul(prob_obs) + R_adj

            # log prob:
            dist = torch.distributions.MultivariateNormal(
                mm_adj, system_uncertainty)
            return dist.log_prob(obs)
        elif method == 'independent':
            pred_mean = self.predictions[idx_3d]
            pred_cov = self.prediction_uncertainty[idx_4d]

            # which observations sit on a censoring limit:
            cens_up = torch.isclose(obs, upper)
            cens_lo = torch.isclose(obs, lower)
            uncens = ~(cens_up | cens_lo)

            # marginal std-dev of every measure at once (diagonal of the covariance blocks):
            std = torch.diagonal(pred_cov, dim1=-2, dim2=-1).sqrt()
            z = (pred_mean - obs) / std

            # pdf is well behaved at tails:
            loglik_uncens = std_normal.log_prob(z) - std.log()

            # but cdf is not, clamp:
            cdf = std_normal.cdf(torch.clamp(z, -5., 5.))
            loglik_cens_up = cdf.log()
            loglik_cens_lo = (1. - cdf).log()

            # pick the right term per element; if an obs matches both limits, lower wins
            loglik = torch.zeros_like(obs)
            loglik[cens_up] = loglik_cens_up[cens_up]
            loglik[cens_lo] = loglik_cens_lo[cens_lo]
            loglik[uncens] = loglik_uncens[uncens]

            # take the product of the dimension probs (i.e., assume independence)
            return torch.sum(loglik, -1)
        else:
            raise RuntimeError("Expected method to be one of: {}.".format(
                {'update', 'independent'}))
    def update(self, obs: Tensor) -> 'Gaussian':
        """
        Kalman-update every group with its observations, handling missing (NaN) measures.

        Groups that have some (but not all) measures missing are bucketed by the set of observed
        measure indices, so each bucket is updated with one batched call. Groups with no NaNs form
        one more bucket. Groups whose observations are all NaN keep their current state.

        :param obs: Observations indexed as (group, measure); NaN marks a missing measure.
        :return: A new instance of this class holding the updated means/covs. Its
            ``last_measured`` is reset to 0 for every group with at least one observed measure.
        """
        isnan = torch.isnan(obs)

        # Plain `.clone()` (not `.data.clone()`): rows that are not overwritten below, e.g.
        # all-NaN groups, must keep their autograd history just like the updated rows do.
        # `.data` silently cut the gradient path for those rows only.
        means_new = self.means.clone()
        covs_new = self.covs.clone()

        # need to do a different update depending on which (if any) dimensions are missing:
        update_groups = defaultdict(list)
        anynan_by_group = isnan.any(1)

        # groups with nan:
        nan_group_idx = anynan_by_group.nonzero().squeeze(-1).tolist()
        for i in nan_group_idx:
            if isnan[i].all():
                continue  # if all nan, then simply skip update
            which_valid = (~isnan[i]).nonzero().squeeze(-1).tolist()
            update_groups[tuple(which_valid)].append(i)

        # list of (which_valid, group_idx) pairs
        update_groups = list(update_groups.items())

        # groups without nan:
        if isnan.any():
            nonan_group_idx = (~anynan_by_group).nonzero().squeeze(-1).tolist()
            if len(nonan_group_idx):
                update_groups.append((slice(None), nonan_group_idx))
        else:
            # if no nans at all, then faster to use slices:
            update_groups.append((slice(None), slice(None)))

        measured_means, system_covs = self.measurement

        # updates:
        for which_valid, group_idx in update_groups:
            idx_2d = bmat_idx(group_idx, which_valid)
            idx_3d = bmat_idx(group_idx, which_valid, which_valid)
            group_obs = obs[idx_2d]
            group_means = self.means[group_idx]
            group_covs = self.covs[group_idx]
            group_measured_means = measured_means[idx_2d]
            group_system_covs = system_covs[idx_3d]
            group_H = self.H[idx_2d]
            group_R = self.R[idx_3d]
            group_K = self.kalman_gain(system_covariance=group_system_covs,
                                       covariance=group_covs,
                                       H=group_H)
            means_new[group_idx] = self.mean_update(mean=group_means,
                                                    K=group_K,
                                                    residuals=group_obs -
                                                    group_measured_means)
            covs_new[group_idx] = self.covariance_update(covariance=group_covs,
                                                         K=group_K,
                                                         H=group_H,
                                                         R=group_R)

        # calculate last-measured: reset the counter for any group with >= 1 observed measure
        any_measured_group_idx = (~isnan).any(1).nonzero().squeeze(-1)
        last_measured = self.last_measured.clone()
        last_measured[any_measured_group_idx] = 0
        return self.__class__(means=means_new,
                              covs=covs_new,
                              last_measured=last_measured)