Example #1
 def prior_sample(self, num_samps, t=None):
     """
     Sample from the model prior f~N(0,K) multiple times using a nested loop.
     :param num_samps: the number of samples to draw [scalar]
     :param t: the input locations at which to sample (defaults to train+test set) [N, 1]
     :return:
         f_sample: the prior samples [N, func_dim, num_samps]
     """
     self.update_model(softplus_list(self.prior.hyp))
     if t is None:
         t = self.t_all
     else:
         x_ind = np.argsort(t[:, 0])
         t = t[x_ind]
     dt = np.concatenate([np.array([0.0]), np.diff(t[:, 0])])
     N = dt.shape[0]
     with loops.Scope() as s:
         s.f_sample = np.zeros([N, self.func_dim, num_samps])
         s.m = np.linalg.cholesky(self.Pinf) @ random.normal(random.PRNGKey(99), shape=[self.state_dim, 1])  # initialise the scope variable; overwritten for each sample below
         for i in s.range(num_samps):
             s.m = np.linalg.cholesky(self.Pinf) @ random.normal(random.PRNGKey(i), shape=[self.state_dim, 1])
             for k in s.range(N):
                 A = self.prior.state_transition(dt[k], self.prior.hyp)  # transition and noise process matrices
                 Q = self.Pinf - A @ self.Pinf @ A.T
                 C = np.linalg.cholesky(Q + 1e-6 * np.eye(self.state_dim))  # <--- can be a bit unstable
                 # we need a unique PRNG seed for every (sample, step) pair:
                 s.m = A @ s.m + C @ random.normal(random.PRNGKey(i * N + k), shape=[self.state_dim, 1])
                 H = self.prior.measurement_model(t[k, 1:], softplus_list(self.prior.hyp))
                 f = (H @ s.m).T
                 s.f_sample = index_add(s.f_sample, index[k, ..., i], np.squeeze(f))
     return s.f_sample
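The nested loop above is just the linear-Gaussian state recursion mₖ = Aₖ mₖ₋₁ + qₖ with qₖ ~ 𝓝(0, Qₖ) and stationary-noise covariance Qₖ = P∞ − Aₖ P∞ Aₖᵀ. Below is a minimal self-contained sketch of the same recursion for a hypothetical Matérn-1/2 (exponential) kernel, whose one-dimensional state gives A = exp(−Δt/ℓ) and P∞ = σ²; it uses the modern JAX `.at[...]` update API rather than the older `index_add` used above.

import jax.numpy as jnp
from jax import random

def prior_sample_matern12(key, t, var=1.0, lengthscale=1.0, num_samps=3):
    # Matern-1/2: scalar state with A_k = exp(-dt_k / l) and Pinf = var
    dt = jnp.concatenate([jnp.array([0.0]), jnp.diff(t)])
    A = jnp.exp(-dt / lengthscale)             # [N] transition coefficients
    Q = var * (1.0 - A ** 2)                   # Q_k = Pinf - A_k Pinf A_k
    xi = random.normal(key, (t.shape[0], num_samps))
    f = jnp.zeros((t.shape[0], num_samps))
    m = jnp.sqrt(var) * xi[0]                  # m_0 ~ N(0, Pinf)
    f = f.at[0].set(m)
    for k in range(1, t.shape[0]):
        m = A[k] * m + jnp.sqrt(Q[k]) * xi[k]  # m_k = A m_{k-1} + chol(Q_k) xi_k
        f = f.at[k].set(m)
    return f                                   # [N, num_samps]

samples = prior_sample_matern12(random.PRNGKey(0), jnp.linspace(0.0, 5.0, 100))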
Example #2
 def predict(self, y=None, dt=None, mask=None, site_params=None, sampling=False,
             r=None, return_full=False, compute_nlpd=True):
     """
     Calculate posterior predictive distribution p(f*|f,y) by filtering and smoothing across the
     training & test locations.
     This function is also used during posterior sampling to smooth the auxiliary data sampled from the prior.
     The output shapes depend on return_full.
     :param y: observations (nans at test locations) [M, 1]
     :param dt: step sizes Δtₙ = tₙ - tₙ₋₁ [M, 1]
     :param mask: a boolean array, True where the corresponding element of y is nan (i.e. a test location) [M, 1]
     :param site_params: the sites computed during a previous inference procedure [2, M, obs_dim]
     :param sampling: notify whether we are doing posterior sampling
     :param r: spatial locations [M, R]
     :param return_full: flag to notify if we are handling the case where spatial test locations are a different
                         size to training locations
     :param compute_nlpd: flag to notify whether to compute the negative log predictive density of the test data
     :return:
         posterior_mean: the posterior predictive mean [M, state_dim] or [M, obs_dim]
         posterior_cov: the posterior predictive (co)variance [M, M, state_dim] or [M, obs_dim]
         site_params: the site parameters. If none are provided then new sites are computed [2, M, obs_dim]
         nlpd_test: the negative log predictive density of the test data (np.nan if compute_nlpd is False) [scalar]
     """
     y = self.y_all if y is None else y
     r = self.r_all if r is None else r
     dt = self.dt_all if dt is None else dt
     mask = self.mask if mask is None else mask
     params = [self.prior.hyp.copy(), self.likelihood.hyp.copy()]
     site_params = self.sites.site_params if site_params is None else site_params
     if site_params is not None and not sampling:
         # construct a vector of site parameters that is the full size of the test data
         # test site parameters are 𝓝(0,∞), and will not be used
         site_mean = np.zeros([dt.shape[0], self.func_dim, 1])
         site_cov = 1e5 * np.tile(np.eye(self.func_dim), (dt.shape[0], 1, 1))
         # replace parameters at training locations with the supplied sites
         site_mean = index_add(site_mean, index[self.train_id], site_params[0])
         site_cov = index_update(site_cov, index[self.train_id], site_params[1])
         site_params = (site_mean, site_cov)
     _, (filter_mean, filter_cov, site_params) = self.kalman_filter(y, dt, params, True, mask, site_params, r)
     _, posterior_mean, posterior_cov = self.rauch_tung_striebel_smoother(params, filter_mean, filter_cov, dt,
                                                                          True, return_full, None, None, r)
     if compute_nlpd:
         nlpd_test = self.negative_log_predictive_density(self.t_all[self.test_id], self.y_all[self.test_id],
                                                          posterior_mean[self.test_id],
                                                          posterior_cov[self.test_id],
                                                          softplus_list(params[0]), softplus(params[1]),
                                                          return_full)
     else:
         nlpd_test = np.nan
     # in the spatial model, the train and test points may be of different size. This deals with that situation:
     if return_full:
         measure_func = vmap(self.compute_measurement, (0, 0, 0, None))
         posterior_mean, posterior_cov = measure_func(self.r_test,
                                                      posterior_mean[self.test_id], posterior_cov[self.test_id],
                                                      softplus_list(self.prior.hyp))
     return posterior_mean, posterior_cov, site_params, nlpd_test
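A typical call, assuming a model object `mod` that has already been through inference (the variable names here are illustrative):

posterior_mean, posterior_cov, site_params, nlpd_test = mod.predict()
print('test NLPD: %1.3f' % nlpd_test)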
Example #3
 def predict(self, t=None, r=None, site_params=None, sampling=False):
     """
     Calculate posterior predictive distribution p(f*|f,y) by filtering and smoothing across the
     training & test locations.
     This function is also used during posterior sampling to smooth the auxiliary data sampled from the prior.
     The output shapes depend on whether the spatial test locations differ in size from the training locations.
     :param t: test time steps [M, 1]
     :param r: spatial test locations [M, R]
     :param site_params: the sites computed during a previous inference procedure [2, M, obs_dim]
     :param sampling: notify whether we are doing posterior sampling
     :return:
         predict_mean: the posterior predictive mean at the test locations [M, state_dim] or [M, obs_dim]
         predict_cov: the posterior predictive (co)variance at the test locations [M, M, state_dim] or [M, obs_dim]
     """
     if t is None:
         t, y, r = self.t, self.y, self.r
     (t, y, r, r_test, dt, train_id, test_id,
      mask) = test_input_admin(self.t, self.y, self.r, t, None, r)
     # are the spatial test locations a different size to the training locations?
     return_full = r_test.shape[1] != r.shape[1]
     posterior_mean, posterior_cov, site_params = self.predict_everywhere(
         y, r, dt, train_id, mask, site_params, sampling, return_full)
     # only return the posterior at the test locations
     predict_mean, predict_cov = posterior_mean[test_id], posterior_cov[test_id]
     # in the spatial model, the train and test points may be of different size. This deals with that situation:
     if return_full:
         measure_func = vmap(self.compute_measurement, (0, 0, 0, None))
         predict_mean, predict_cov = measure_func(
             r_test, predict_mean, predict_cov, softplus_list(self.prior.hyp))
     return np.squeeze(predict_mean), np.squeeze(predict_cov)
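The `in_axes` tuple `(0, 0, 0, None)` passed to `vmap` maps `compute_measurement` over the leading axis of the first three arguments while broadcasting the hyperparameters to every call. A self-contained illustration of the pattern (toy function, not the model's actual `compute_measurement`):

import jax.numpy as jnp
from jax import vmap

def measure(r, mean, cov, hyp):
    # toy stand-in: scale the mean and covariance by a shared hyperparameter
    return hyp * mean, hyp ** 2 * cov

rs = jnp.zeros((5, 2))
means = jnp.ones((5, 1))
covs = jnp.ones((5, 1, 1))
m, c = vmap(measure, in_axes=(0, 0, 0, None))(rs, means, covs, 2.0)
# m: [5, 1], c: [5, 1, 1]; the scalar 2.0 is reused for all five slices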
Example #4
 def negative_log_predictive_density(self, t=None, y=None, r=None):
     """
     Compute the (normalised) negative log predictive density (NLPD) of the test data yₙ*:
         NLPD = - ∑ₙ log ∫ p(yₙ*|fₙ*) 𝓝(fₙ*|mₙ*,vₙ*) dfₙ*
     where fₙ* is the function value at the test location.
     The above can be computed using the EP moment matching method, which we vectorise using vmap.
     :param t: test time steps [M, 1]
     :param y: test observations [M, 1]
     :param r: test spatial locations [M, R]
     :return:
         NLPD: the negative log predictive density for the test data
     """
     if t is None:
         t, y, r = self.t, self.y, self.r
     (t, y, r, r_test, dt, train_id, test_id,
      mask) = test_input_admin(self.t, self.y, self.r, t, y, r)
     # are the spatial test locations a different size to the training locations?
     return_full = r_test.shape[1] != r.shape[1]
     # run the filter and smooth across both train and test points
     posterior_mean, posterior_cov, _ = self.predict_everywhere(
         y, r, dt, train_id, mask, sampling=False, return_full=return_full)
     test_mean, test_cov = posterior_mean[test_id], posterior_cov[test_id]
     hyp_prior, hyp_lik = softplus_list(self.prior.hyp), softplus(self.likelihood.hyp)
     if return_full:
         measure_func = vmap(self.compute_measurement, (0, 0, 0, None))
         test_mean, test_cov = measure_func(r_test, test_mean, test_cov, hyp_prior)
     # vectorise the EP moment matching method
     lpd_func = vmap(self.likelihood.moment_match, (0, 0, 0, None, None, None))
     log_predictive_density, _, _ = lpd_func(y[test_id], test_mean, test_cov, hyp_lik, 1, None)
     return -np.mean(log_predictive_density)  # mean = normalised sum
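For a Gaussian likelihood the integral in the docstring has a closed form, log ∫ 𝓝(yₙ*|fₙ*, σ²) 𝓝(fₙ*|mₙ*, vₙ*) dfₙ* = log 𝓝(yₙ*|mₙ*, vₙ* + σ²), which makes a useful sanity check for the moment-matching route. A sketch of that special case (this is not the model's `moment_match`, and the dummy data are illustrative):

import jax.numpy as jnp
from jax import vmap

def gaussian_log_predictive(y, mean, var, noise_var):
    # log N(y | mean, var + noise_var)
    s = var + noise_var
    return -0.5 * (jnp.log(2.0 * jnp.pi * s) + (y - mean) ** 2 / s)

y_test = jnp.array([0.3, -1.2, 0.8])
test_mean = jnp.array([0.1, -1.0, 0.5])
test_var = jnp.array([0.2, 0.3, 0.25])
nlpd = -jnp.mean(vmap(gaussian_log_predictive, (0, 0, 0, None))(
    y_test, test_mean, test_var, 0.1))  # mean = normalised sum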
Example #5
def gradient_step(i, state, mod, plot_num_, mu_prev_):
    params = get_params(state)
    mod.prior.hyp = params[0]
    mod.likelihood.hyp = params[1]

    # grad(Filter) + Smoother:
    neg_log_marg_lik, gradients = mod.run()
    # neg_log_marg_lik, gradients = mod.run_two_stage()  # <-- less elegant but reduces compile time

    prior_params = softplus_list(params[0])
    print('iter %2d: var=%1.2f len=%1.2f, nlml=%2.2f' %
          (i, prior_params[0], prior_params[1], neg_log_marg_lik))

    return opt_update(i, gradients, state), plot_num_, mu_prev_
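The `gradient_step` functions in these examples rely on optimiser state created elsewhere with `jax.experimental.optimizers` (the API era matching `index_add` above). A minimal surrounding training loop might look as follows; the Adam step size and iteration count are assumptions, not values from the source:

from jax.experimental import optimizers

opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
opt_state = opt_init([mod.prior.hyp, mod.likelihood.hyp])
plot_num, mu_prev = 0, None
for j in range(250):  # iteration count is an assumption
    opt_state, plot_num, mu_prev = gradient_step(j, opt_state, mod, plot_num, mu_prev)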
Example #6
def gradient_step(i, state, mod):
    params = get_params(state)
    mod.prior.hyp = params[0]
    mod.likelihood.hyp = params[1]

    # grad(Filter) + Smoother:
    neg_log_marg_lik, gradients = mod.run()
    # neg_log_marg_lik, gradients = mod.run_two_stage()

    prior_params = softplus_list(params[0])
    print('iter %2d: var_f1=%1.2f len_f1=%1.2f var_f2=%1.2f len_f2=%1.2f, nlml=%2.2f' %
          (i, prior_params[0][0], prior_params[0][1], prior_params[1][0], prior_params[1][1], neg_log_marg_lik))

    if plot_intermediate:
        plot(mod, i)

    return opt_update(i, gradients, state)
Example #7
def gradient_step(i, state, mod):
    params = get_params(state)
    mod.prior.hyp = params[0]
    mod.likelihood.hyp = params[1]

    # grad(Filter) + Smoother:
    neg_log_marg_lik, gradients = mod.run()
    # neg_log_marg_lik, gradients = mod.run_two_stage()  # <-- less elegant but reduces compile time

    prior_params, lik_param = softplus_list(params[0]), softplus(params[1])
    print('iter %2d: var_f=%1.2f len_f=%1.2f var_y=%1.2f, nlml=%2.2f' %
          (i, prior_params[0], prior_params[1], lik_param, neg_log_marg_lik))

    if plot_intermediate:
        plot(mod, i)

    return opt_update(i, gradients, state)
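Throughout these examples the hyperparameters are stored unconstrained and mapped through a softplus, φ(x) = log(1 + eˣ), to keep them positive; `softplus_list` presumably applies the same transform to each element of a list. A sketch of a plausible pair of transforms (the repository's versions may differ, e.g. in numerical stabilisation):

import jax.numpy as jnp

def softplus(x):
    return jnp.log(1.0 + jnp.exp(x))       # maps R to (0, inf)

def softplus_inv(x):
    return jnp.log(jnp.exp(x) - 1.0)       # for initialising unconstrained values

unconstrained = softplus_inv(jnp.array(1.0))  # store/optimise this value
var_y = softplus(unconstrained)               # constrain it before use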
Example #8
def gradient_step(i, state, mod, plot_num_, mu_prev_):
    params = get_params(state)
    mod.prior.hyp = params[0]
    mod.likelihood.hyp = params[1]

    # grad(Filter) + Smoother:
    neg_log_marg_lik, gradients = mod.run()
    # neg_log_marg_lik, gradients = mod.run_two_stage()

    prior_params = softplus_list(params[0])
    print('iter %2d: var=%1.2f len_time=%1.2f len_space=%1.2f, nlml=%2.2f' %
          (i, prior_params[0], prior_params[1], prior_params[2], neg_log_marg_lik))

    if plot_intermediate:
        plot_2d_classification(mod, i)
        # plot_num_, mu_prev_ = plot_2d_classification_filtering(mod, i, plot_num_, mu_prev_)

    return opt_update(i, gradients, state), plot_num_, mu_prev_
Example #9
def gradient_step(i, state, mod):
    params = get_params(state)
    mod.prior.hyp = params[0]
    mod.likelihood.hyp = params[1]

    # grad(Filter) + Smoother:
    # neg_log_marg_lik, gradients = mod.run()
    neg_log_marg_lik, gradients = mod.run_two_stage()

    prior_params = softplus_list(params[0])
    # print('iter %2d: var1=%1.2f len1=%1.2f om1=%1.2f var2=%1.2f len2=%1.2f om2=%1.2f var3=%1.2f len3=%1.2f om3=%1.2f '
    #       'var4=%1.2f len4=%1.2f var5=%1.2f len5=%1.2f var6=%1.2f len6=%1.2f '
    #       'vary=%1.2f, nlml=%2.2f' %
    #       (i, prior_params[0][0], prior_params[0][1], prior_params[0][2],
    #        prior_params[1][0], prior_params[1][1], prior_params[1][2],
    #        prior_params[2][0], prior_params[2][1], prior_params[2][2],
    #        prior_params[3][0], prior_params[3][1],
    #        prior_params[4][0], prior_params[4][1],
    #        prior_params[5][0], prior_params[5][1],
    #        softplus(params[1]), neg_log_marg_lik))
    # print('iter %2d: len1=%1.2f om1=%1.2f len2=%1.2f om2=%1.2f len3=%1.2f om3=%1.2f '
    #       'var4=%1.2f len4=%1.2f var5=%1.2f len5=%1.2f var6=%1.2f len6=%1.2f '
    #       'vary=%1.2f, nlml=%2.2f' %
    #       (i, prior_params[0][0], prior_params[0][1],
    #        prior_params[1][0], prior_params[1][1],
    #        prior_params[2][0], prior_params[2][1],
    #        prior_params[3][0], prior_params[3][1],
    #        prior_params[4][0], prior_params[4][1],
    #        prior_params[5][0], prior_params[5][1],
    #        softplus(params[1]), neg_log_marg_lik))
    print('iter %2d: len1=%1.2f om1=%1.2f len2=%1.2f om2=%1.2f len3=%1.2f om3=%1.2f '
          'len4=%1.2f len5=%1.2f len6=%1.2f '
          'vary=%1.2f, nlml=%2.2f' %
          (i, prior_params[0][0], prior_params[0][1], prior_params[1][0],
           prior_params[1][1], prior_params[2][0], prior_params[2][1],
           prior_params[3], prior_params[4], prior_params[5],
           softplus(params[1]), neg_log_marg_lik))

    if plot_intermediate:
        plot(mod, i)

    return opt_update(i, gradients, state)
Example #10
 def rauch_tung_striebel_smoother(self,
                                  params,
                                  m_filtered,
                                  P_filtered,
                                  dt,
                                  store=False,
                                  return_full=False,
                                  y=None,
                                  site_params=None,
                                  r=None):
     """
     Run the RTS smoother to get p(fₙ|y₁,...,y_N),
     i.e. compute p(f)𝚷ₙsₙ(fₙ) where sₙ(fₙ) are the sites (approx. likelihoods).
     If sites are provided, then it is assumed they are to be updated, which is done by
     calling the site-specific update() method.
     :param params: the model parameters, i.e. the hyperparameters of the prior & likelihood
     :param m_filtered: the intermediate distribution means computed during filtering [N, state_dim, 1]
     :param P_filtered: the intermediate distribution covariances computed during filtering [N, state_dim, state_dim]
     :param dt: step sizes Δtₙ = tₙ - tₙ₋₁ [N, 1]
     :param store: a flag determining whether to store and return state mean and covariance
     :param return_full: a flag determining whether to return the full state distribution or just the function(s)
     :param y: observed data [N, obs_dim]
     :param site_params: the Gaussian approximate likelihoods [2, N, obs_dim]
     :param r: spatial input locations
     :return:
         site_params: the updated sites [2, N, obs_dim]
         smoothed_mean: the posterior marginal means [N, obs_dim] (only if store is True)
         smoothed_cov: the posterior marginal (co)variances (only if store is True)
     """
     theta_prior, theta_lik = softplus_list(params[0]), softplus(params[1])
     # all model components that are not static must be computed inside the function
     self.update_model(theta_prior)
     N = dt.shape[0]
     dt = np.concatenate([dt[1:], np.array([0.0])], axis=0)
     with loops.Scope() as s:
         s.m, s.P = m_filtered[-1, ...], P_filtered[-1, ...]
         if return_full:
             s.smoothed_mean = np.zeros([N, self.state_dim, 1])
             s.smoothed_cov = np.zeros([N, self.state_dim, self.state_dim])
         else:
             s.smoothed_mean = np.zeros([N, self.func_dim, 1])
             s.smoothed_cov = np.zeros([N, self.func_dim, self.func_dim])
         if site_params is not None:
             s.site_mean = np.zeros([N, self.func_dim, 1])
             s.site_var = np.zeros([N, self.func_dim, self.func_dim])
         for n in s.range(N - 1, -1, -1):
             # --- First compute the smoothing distribution: ---
             # closed-form integration of the transition matrix
             A = self.prior.state_transition(dt[n], theta_prior)
             m_predicted = A @ m_filtered[n, ...]
             tmp_gain_cov = A @ P_filtered[n, ...]
             P_predicted = A @ (P_filtered[n, ...] - self.Pinf) @ A.T + self.Pinf
             # backward Kalman gain:
             # G = F * A' * P^{-1}
             # since both F(iltered) and P(redictive) are cov matrices, thus self-adjoint, we can take the transpose:
             #   = (P^{-1} * A * F)'
             G_transpose = solve(P_predicted, tmp_gain_cov)  # (P^-1)AF
             s.m = m_filtered[n, ...] + G_transpose.T @ (s.m - m_predicted)
             s.P = P_filtered[n, ...] + G_transpose.T @ (s.P - P_predicted) @ G_transpose
             H = self.prior.measurement_model(r[n], theta_prior)
             if store:
                 if return_full:
                     s.smoothed_mean = index_add(s.smoothed_mean, index[n, ...], s.m)
                     s.smoothed_cov = index_add(s.smoothed_cov, index[n, ...], s.P)
                 else:
                     s.smoothed_mean = index_add(s.smoothed_mean, index[n, ...], H @ s.m)
                     s.smoothed_cov = index_add(s.smoothed_cov, index[n, ...], H @ s.P @ H.T)
             # --- Now update the site parameters: ---
             if site_params is not None:
                 # extract mean and var from state:
                 post_mean, post_cov = H @ s.m, H @ s.P @ H.T
                 # calculate the new sites
                 _, site_mu, site_cov = self.sites.update(
                     self.likelihood, y[n][..., np.newaxis], post_mean, post_cov,
                     theta_lik, (site_params[0][n], site_params[1][n]))
                 s.site_mean = index_add(s.site_mean, index[n, ...], site_mu)
                 s.site_var = index_add(s.site_var, index[n, ...], site_cov)
     if site_params is not None:
         site_params = (s.site_mean, s.site_var)
     if store:
         return site_params, s.smoothed_mean, s.smoothed_cov
     return site_params
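The gain computation above avoids forming an explicit inverse: since the filtered and predicted covariances are symmetric, G = Fₙ Aᵀ (Pₙ₊₁⁻)⁻¹ can be obtained by solving against the predicted covariance and transposing, exactly as the comment in the loop notes. A standalone sketch of one backward step (hypothetical helper, same algebra as the loop body):

import jax.numpy as jnp
from jax.numpy.linalg import solve

def rts_backward_step(m_filt, P_filt, m_next, P_next, A, Pinf):
    # predict forward one step with stationary noise Q = Pinf - A Pinf A^T
    m_pred = A @ m_filt
    P_pred = A @ (P_filt - Pinf) @ A.T + Pinf
    G_transpose = solve(P_pred, A @ P_filt)   # G^T = P_pred^{-1} A P_filt
    m = m_filt + G_transpose.T @ (m_next - m_pred)
    P = P_filt + G_transpose.T @ (P_next - P_pred) @ G_transpose
    return m, P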
Example #11
 def kalman_filter(self,
                   y,
                   dt,
                   params,
                   store=False,
                   mask=None,
                   site_params=None,
                   r=None):
     """
     Run the Kalman filter to get p(fₙ|y₁,...,yₙ).
     The Kalman update step involves some control flow to work out whether we are
         i) initialising the sites
         ii) using supplied sites
         iii) performing a Gaussian update with fixed parameters (e.g. in posterior sampling or ELBO calc.)
     If store is True then we compute and return the intermediate filtering distributions
     p(fₙ|y₁,...,yₙ) and sites sₙ(fₙ), otherwise we do not store the intermediates and simply
     return the energy / negative log-marginal likelihood, -log p(y).
     :param y: observed data [N, obs_dim]
     :param dt: step sizes Δtₙ = tₙ - tₙ₋₁ [N, 1]
     :param params: the model parameters, i.e. the hyperparameters of the prior & likelihood
     :param store: flag to notify whether to store the intermediates
     :param mask: boolean array, True where the corresponding element of y is unobserved (nan) [N, obs_dim]
     :param site_params: the Gaussian approximate likelihoods [2, N, obs_dim]
     :param r: spatial input locations
     :return:
         if store is True:
             neg_log_marg_lik: the filter energy, i.e. negative log-marginal likelihood -log p(y),
                               used for hyperparameter optimisation (learning) [scalar]
             filtered_mean: intermediate filtering means [N, state_dim, 1]
             filtered_cov: intermediate filtering covariances [N, state_dim, state_dim]
             site_mean: mean of the approximate likelihood sₙ(fₙ) [N, obs_dim]
             site_cov: variance of the approximate likelihood sₙ(fₙ) [N, obs_dim]
         otherwise:
             neg_log_marg_lik: the filter energy, i.e. negative log-marginal likelihood -log p(y),
                               used for hyperparameter optimisation (learning) [scalar]
     """
     theta_prior, theta_lik = softplus_list(params[0]), softplus(params[1])
     # all model components that are not static must be computed inside the function
     self.update_model(theta_prior)
     N = dt.shape[0]
     with loops.Scope() as s:
         s.neg_log_marg_lik = 0.0  # negative log-marginal likelihood
         s.m, s.P = self.minf, self.Pinf
         if store:
             s.filtered_mean = np.zeros([N, self.state_dim, 1])
             s.filtered_cov = np.zeros([N, self.state_dim, self.state_dim])
             s.site_mean = np.zeros([N, self.func_dim, 1])
             s.site_cov = np.zeros([N, self.func_dim, self.func_dim])
         for n in s.range(N):
             y_n = y[n][..., np.newaxis]
             # -- KALMAN PREDICT --
             #  mₙ⁻ = Aₙ mₙ₋₁
             #  Pₙ⁻ = Aₙ Pₙ₋₁ Aₙ' + Qₙ, where Qₙ = Pinf - Aₙ Pinf Aₙ'
             A = self.prior.state_transition(dt[n], theta_prior)
             m_ = A @ s.m
             P_ = A @ (s.P - self.Pinf) @ A.T + self.Pinf
             # --- KALMAN UPDATE ---
             # Given previous predicted mean mₙ⁻ and cov Pₙ⁻, incorporate yₙ to get filtered mean mₙ &
             # cov Pₙ and compute the marginal likelihood p(yₙ|y₁,...,yₙ₋₁)
             H = self.prior.measurement_model(r[n], theta_prior)
             predict_mean = H @ m_
             predict_cov = H @ P_ @ H.T
             if mask is not None:  # note: this is a bit redundant but may come in handy in multi-output problems
                 # fill in masked observations with their expected value
                 y_n = np.where(mask[n][..., np.newaxis], predict_mean[:y_n.shape[0]], y_n)
             log_lik_n, site_mean, site_cov = self.sites.update(
                 self.likelihood, y_n, predict_mean, predict_cov, theta_lik, None)
             if site_params is not None:  # use supplied site parameters to perform the update
                 site_mean, site_cov = site_params[0][n], site_params[1][n]
             # modified Kalman update (see Nickisch et al., ICML 2018, or Wilkinson et al., ICML 2019):
             S = predict_cov + site_cov
             HP = H @ P_
             K = solve(S, HP).T  # PH'(S^-1)
             s.m = m_ + K @ (site_mean - predict_mean)
             s.P = P_ - K @ HP
             if mask is not None:  # note: this is a bit redundant but may come in handy in multi-output problems
                 s.m = np.where(np.any(mask[n]), m_, s.m)
                 s.P = np.where(np.any(mask[n]), P_, s.P)
                 log_lik_n = np.where(mask[n][..., 0], np.zeros_like(log_lik_n), log_lik_n)
             s.neg_log_marg_lik -= np.sum(log_lik_n)
             if store:
                 s.filtered_mean = index_add(s.filtered_mean, index[n, ...], s.m)
                 s.filtered_cov = index_add(s.filtered_cov, index[n, ...], s.P)
                 s.site_mean = index_add(s.site_mean, index[n, ...], site_mean)
                 s.site_cov = index_add(s.site_cov, index[n, ...], site_cov)
     if store:
         return s.neg_log_marg_lik, (s.filtered_mean, s.filtered_cov, (s.site_mean, s.site_cov))
     return s.neg_log_marg_lik
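The predict and update equations in the comments condense to a few lines. A minimal sketch of one filter step against a Gaussian site 𝓝(site_mean, site_cov), using the same solve-based gain K = P⁻Hᵀ S⁻¹ as above (hypothetical helper, not part of the class):

import jax.numpy as jnp
from jax.numpy.linalg import solve

def kalman_step(m, P, A, H, Pinf, site_mean, site_cov):
    # predict: m_ = A m,  P_ = A (P - Pinf) A^T + Pinf  (i.e. Q = Pinf - A Pinf A^T)
    m_ = A @ m
    P_ = A @ (P - Pinf) @ A.T + Pinf
    # update against the Gaussian site
    predict_mean = H @ m_
    S = H @ P_ @ H.T + site_cov
    HP = H @ P_
    K = solve(S, HP).T                        # K = P_ H^T S^{-1}
    return m_ + K @ (site_mean - predict_mean), P_ - K @ HP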