def log_Z_u(self, ts: SimpleNamespace, u: B.Numeric):
    """Compute the normalising constant term of the optimal :math:`q(z|u)`.

    The optimal :math:`q(z|u)` is proportional to the likelihood given
    :math:`u` and :math:`z` times the prior of :math:`z`. Its normaliser is
    therefore a Gaussian integral over :math:`z`, which is evaluated in
    closed form here from the natural parameters of `q_z_optimal(ts, u)`.

    Args:
        ts (:class:`types.SimpleNamespace`): Terms. These must have been
            constructed with observations, because `ts.y` is read.
        u (tensor): Sample for :math:`u` to use.

    Returns:
        scalar: Normalising constant term.
    """
    q_z = self.q_z_optimal(ts, u)
    noise = self.model.noise
    # Part of the squared residuals that does not depend on z.
    quadratic = B.sum(ts.y**2) + B.sum(u * B.mm(ts.A_sum, u)) + ts.c_sum
    return 0.5 * (
        -ts.n * B.log(2 * B.pi * noise)
        - quadratic / noise
        # Integrating over z contributes two terms.
        # The prior normaliser contributes -logdet(K_z).
        # Completing the square contributes -logdet(prec) + lam^T prec^{-1} lam.
        # Sanity check: with no data, quadratic = 0, lam = 0 and
        # prec = K_z^{-1}, so the normaliser must be exactly 1. That requires
        # the two log-determinants to cancel, hence the minus sign on
        # logdet(K_z). This assumes `q_z_optimal` passes K_z^{-1} as the
        # prior precision `K` to `_q_optimal` — TODO confirm.
        - B.logdet(self.model.K_z)
        - B.logdet(q_z.prec)
        + B.squeeze(B.iqf(q_z.prec, q_z.lam))
    )
def _q_optimal(self, A, I, I_sum, K, x, x2=None):
    """Build the optimal Gaussian approximation in natural parameters.

    Args:
        A (matrix): Term added to the data part of the precision before
            scaling by the inverse noise.
        I (tensor): Batch of matrices, stacked along axis 0. Each is
            sandwiched around the second moment of `x`, and the results are
            summed over axis 0.
        I_sum (matrix): Matrix multiplied with `x` to form the natural mean.
            Presumably a data-weighted sum of `I`.
        K (matrix): Term added to the precision without noise scaling.
            Presumably a prior precision.
        x (tensor): First moment, or a sample, of the other variable.
        x2 (tensor, optional): Second moment of the other variable. If it is
            omitted, the outer product of `x` with itself is used.

    Returns:
        :class:`NaturalNormal`: The natural mean is `I_sum x / noise`. The
            precision is `K + (A + sum_i I_i x2 I_i^T) / noise`.
    """
    if x2 is None:
        # Only `x` is available. Forming (I x)(I x)^T directly is _much_ more
        # efficient than first building x x^T and sandwiching it.
        projected = B.mm(I, x)
        outer = B.mm(projected, projected, tr_b=True)
    else:
        outer = B.mm(I, x2, I, tr_c=True)
    noise = self.model.noise
    natural_mean = B.mm(I_sum, x) / noise
    precision = K + (A + B.sum(outer, axis=0)) / noise
    return NaturalNormal(natural_mean, precision)
def elbo(
    self,
    state: B.RandomState,
    t: B.Numeric,
    y: B.Numeric,
    collapsed: Union[None, str] = None,
):
    """Compute the mean-field ELBO.

    Args:
        state (random state): Random state. It is not consumed here and is
            returned unchanged.
        t (vector): Locations of observations.
        y (vector): Observations.
        collapsed (str, optional): Collapse over :math:`z` or :math:`u`.
            - `"z"` replaces :math:`q(z)` by its optimal mean-field form
              given :math:`q(u)`.
            - `"u"` replaces :math:`q(u)` by its optimal mean-field form
              given :math:`q(z)`.
            - `None` (the default) uses the stored :math:`q(u)` and
              :math:`q(z)`.

    Raises:
        ValueError: If `collapsed` is not `None`, `"z"`, or `"u"`.

    Returns:
        tuple[random state, scalar]: Random state and ELBO.
    """
    ts = self.construct_terms(t, y)

    if collapsed is None:
        q_u = self.q_u
        q_z = self.q_z
    elif collapsed == "z":
        q_u = self.q_u
        q_z = self.q_z_optimal_mean_field(ts, q_u)
    elif collapsed == "u":
        q_z = self.q_z
        q_u = self.q_u_optimal_mean_field(ts, q_z)
    else:
        raise ValueError(f'Invalid value "{collapsed}" for `collapsed`.')

    # Expected sum of squared residuals under q(u) q(z). Only the first and
    # second moments of u and z enter.
    expected_sq_residuals = (
        B.sum(ts.y**2)
        + B.sum(ts.A_sum * q_u.m2)
        + B.sum(ts.B_sum * q_z.m2)
        # Cross term coupling the second moments of u and z.
        + B.sum(B.mm(ts.I_uz, q_z.m2, ts.I_uz, tr_c=True) * q_u.m2)
        + ts.c_sum
        - 2 * B.sum(q_u.mean * B.mm(ts.I_uz_sum, q_z.mean))
    )
    expected_log_lik = -0.5 * ts.n * B.log(2 * B.pi * self.model.noise) + (
        (-0.5 / self.model.noise) * expected_sq_residuals
    )
    elbo_value = expected_log_lik - q_u.kl(self.p_u) - q_z.kl(self.p_z)
    return state, elbo_value
def _predict_moments(self, ts, u, u2, z, z2):
    """Compute the first and second moments of the prediction at every point.

    Args:
        ts (:class:`types.SimpleNamespace`): Terms, as built by
            `construct_terms`.
        u (tensor): First moment, or a sample, of :math:`u`.
        u2 (tensor): Second moment of :math:`u`. Presumably this is `u u^T`
            when `u` is a sample — verify against callers.
        z (tensor): First moment, or a sample, of :math:`z`.
        z2 (tensor): Second moment of :math:`z`.

    Returns:
        tuple[vector, vector]: First and second moments, one entry per point
            (the leading batch axis of `ts.I_uz`).
    """
    model = self.model

    # First moment is u^T I_uz z for every point, flattened to a vector.
    first = B.flatten(B.mm(u, ts.I_uz, z, tr_a=True))

    # Second moment uses per-point versions of the terms whose sums
    # `construct_terms` stores as `A_sum`, `B_sum` and `c_sum`.
    quad_u = ts.I_ux - ts.K_z_squeezed
    quad_z = ts.I_hz - ts.K_u_squeezed
    offset = ts.I_hx - B.sum(model.K_u_inv * ts.I_ux)
    offset = offset - B.sum(model.K_z_inv * ts.I_hz, axis=(1, 2))
    offset = offset + B.sum(model.K_u_inv * ts.K_z_squeezed, axis=(1, 2))

    # Cross term coupling the second moments of u and z.
    cross = B.mm(ts.I_uz, z2, ts.I_uz, tr_c=True)

    second = B.sum(quad_u * u2, axis=(1, 2)) + B.sum(quad_z * z2, axis=(1, 2))
    second = second + offset
    second = second + B.sum(u2 * cross, axis=(1, 2))
    return first, second
def construct_terms(self, t, y=None):
    """Precompute quantities shared by the ELBO, normaliser and prediction code.

    Args:
        t (vector): Locations of observations.
        y (vector, optional): Observations. If they are given, the result
            also holds `y` and `I_uz_sum`, the sum of `I_uz` weighted by the
            observations.

    Returns:
        :class:`types.SimpleNamespace`: Namespace containing:
            - `n`.
            - The integrals `I_hx`, `I_ux`, `I_hz` and `I_uz`.
            - Their sums over observations.
            - `K_u_squeezed` and `K_z_squeezed`.
            - `A_sum`, `B_sum` and `c_sum`.
    """
    model = self.model
    n = B.length(t)

    # Integrals from the model. Only `I_ux` does not depend on `t`.
    i_hx = model.compute_i_hx(t, t)
    i_ux = model.compute_I_ux()
    i_hz = model.compute_I_hz(t)
    i_uz = model.compute_I_uz(t)

    # Sums over observations. `I_ux` is identical for every observation, so
    # its sum is `n` copies of it.
    i_hx_sum = B.sum(i_hx, axis=0)
    i_hz_sum = B.sum(i_hz, axis=0)
    i_ux_sum = n * i_ux

    # Sandwich the inverse kernel matrices between `I_uz`.
    k_u_squeezed = B.mm(i_uz, model.K_u_inv, i_uz, tr_a=True)
    k_z_squeezed = B.mm(i_uz, model.K_z_inv, i_uz, tr_c=True)

    a_sum = i_ux_sum - B.sum(k_z_squeezed, axis=0)
    b_sum = i_hz_sum - B.sum(k_u_squeezed, axis=0)
    # NOTE: It would be cheaper to first compute
    # `B.sum(k_z_squeezed, axis=0)` and then multiply by `K_u_inv`. That
    # variant caused a segmentation fault when run with the JIT on the GPU,
    # so the full element-wise product is summed instead.
    c_sum = (
        i_hx_sum
        - B.sum(model.K_u_inv * i_ux_sum)
        - B.sum(model.K_z_inv * i_hz_sum)
        + B.sum(model.K_u_inv * k_z_squeezed)
    )

    ts = SimpleNamespace(
        n=n,
        I_hx=i_hx,
        I_ux=i_ux,
        I_hz=i_hz,
        I_uz=i_uz,
        I_hx_sum=i_hx_sum,
        I_hz_sum=i_hz_sum,
        I_ux_sum=i_ux_sum,
        K_u_squeezed=k_u_squeezed,
        K_z_squeezed=k_z_squeezed,
        A_sum=a_sum,
        B_sum=b_sum,
        c_sum=c_sum,
    )

    if y is not None:
        ts.y = y
        # Weight each observation's `I_uz` by the observed value, then sum.
        ts.I_uz_sum = B.sum(y[:, None, None] * i_uz, axis=0)

    return ts