def logpdf(self, x):
    """Evaluate the log-density of this distribution at `x`.

    `x` is first passed through `B.uprank`. If the result has rank two with
    a single column and contains NaNs, the NaN rows are treated as missing:
    the mean, variance, and inputs are restricted to the non-missing rows and
    the log-pdf of that reduced distribution is returned instead. Missing
    data is not handled for batched inputs.

    Args:
        x (input): Values to compute the log-pdf of.

    Returns:
        list[tensor]: Log-pdf for every input in `x`. When the last dimension
            of the result has size one, that dimension is dropped.
    """
    x = B.uprank(x)

    # Missing-data path: only for a single, non-batched column of values.
    is_single_column = B.rank(x) == 2 and B.shape(x, 1) == 1
    if is_single_column:
        observed = B.jit_to_numpy(~B.isnan(x[:, 0]))
        if not B.all(observed):
            # Restrict the mean, variance, and inputs to the observed rows
            # and defer to the reduced distribution.
            sub_mean = B.take(self.mean, observed)
            sub_var = B.submatrix(self.var, observed)
            sub_x = B.take(x, observed)
            return Normal(sub_mean, sub_var).logpdf(sub_x)

    # The trailing axis on the log-determinant lines it up with `iqf_diag`.
    logdet_term = B.logdet(self.var)[..., None]
    constant_term = B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi)
    quadratic_term = B.iqf_diag(self.var, B.subtract(x, self.mean))
    logpdfs = -(logdet_term + constant_term + quadratic_term) / 2

    if B.shape(logpdfs, -1) == 1:
        return logpdfs[..., 0]
    return logpdfs
def test_predict_noisy(construct_ilmm, x):
    """With strongly amplified noise, non-latent predictions must be uncertain."""
    model = construct_ilmm(noise_amplification=1000)
    sample = model.sample(x)
    posterior = model.condition(x, sample)

    # The predicted means are not needed here; only the interval is checked.
    _, lowers, uppers = posterior.predict(x, latent=False)

    # Test that predictions have high uncertainty.
    interval_width = uppers - lowers
    assert B.all(interval_width > 10)
def test_predict_noiseless(construct_ilmm, x):
    """With negligible noise, latent predictions must reproduce the sample."""
    model = construct_ilmm(noise_amplification=1e-10)
    sample = model.sample(x)
    posterior = model.condition(x, sample)

    means, lowers, uppers = posterior.predict(x, latent=True)

    # Test that predictions match the sample ...
    approx(means, sample, atol=1e-3)

    # ... and have low uncertainty.
    interval_width = uppers - lowers
    assert B.all(interval_width < 1e-4)
def tuple_equal(x, y):
    """Check tuples for equality.

    Two tuples are considered equal when they have the same length and every
    pair of corresponding elements has the same shape (according to `_shape`)
    and compares equal everywhere (according to `B.all(xi == yi)`).

    Args:
        x (tuple): First tuple.
        y (tuple): Second tuple.

    Returns:
        bool: `x` and `y` are equal.
    """
    # Tuples of different length can never be equal; skip all comparisons.
    if len(x) != len(y):
        return False

    # A generator (rather than a list) lets `all` stop at the first pair that
    # differs, so later elements are not compared unnecessarily.
    return all(
        _shape(xi) == _shape(yi) and B.all(xi == yi) for xi, yi in zip(x, y)
    )
def test_mo_batched():
    """Check sampling, log-pdf shapes, and conditioning for a batched
    multi-output process built with `cross`."""
    inputs = B.randn(16, 10, 1)
    expected_logpdf_shape = (16,)
    expected_sample_shape = (16, 20, 1)

    with Measure():
        prior = cross(GP(1, 2 * EQ().stretch(0.5)), GP(2, 2 * EQ().stretch(0.5)))

    # Sample from the prior and score the sample under a noisy observation.
    y = prior(inputs).sample()
    prior_logpdf = prior(inputs, 0.1).logpdf(y)

    assert B.shape(prior_logpdf) == expected_logpdf_shape
    assert B.shape(y) == expected_sample_shape

    # Condition on the sample, then sample and score again.
    posterior = prior | (prior(inputs), y)
    y_post = posterior(inputs).sample()
    posterior_logpdf = posterior(inputs, 0.1).logpdf(y)

    assert B.shape(y_post) == expected_sample_shape
    assert B.shape(posterior_logpdf) == expected_logpdf_shape
    # The observed data must be more likely after conditioning on it.
    assert B.all(posterior_logpdf > prior_logpdf)
    # Posterior samples at the conditioned inputs reproduce the observations.
    approx(y, y_post, atol=1e-5)
def test_batched():
    """Check joint sampling, log-pdf shapes, and conditioning on two batched
    sets of inputs for a single process."""
    x1 = B.randn(16, 10, 1)
    x2 = B.randn(16, 5, 1)
    shape_1 = (16, 10, 1)
    shape_2 = (16, 5, 1)
    logpdf_shape = (16,)

    prior = GP(1, 2 * EQ().stretch(0.5))

    # Jointly sample at both input sets and score under noisy observations.
    y1, y2 = prior.measure.sample(prior(x1), prior(x2))
    prior_logpdf = prior.measure.logpdf(
        (prior(x1, 0.1), y1),
        (prior(x2, 0.1), y2),
    )

    assert B.shape(y1) == shape_1
    assert B.shape(y2) == shape_2
    assert B.shape(prior_logpdf) == logpdf_shape

    # Condition on both samples, then repeat sampling and scoring.
    post = prior | ((prior(x1), y1), (prior(x2), y2))
    y1_post, y2_post = post.measure.sample(post(x1), post(x2))
    post_logpdf = post.measure.logpdf(
        (post(x1, 0.1), y1),
        (post(x2, 0.1), y2),
    )

    assert B.shape(y1_post) == shape_1
    assert B.shape(y2_post) == shape_2
    # Posterior samples at the conditioned inputs reproduce the observations.
    approx(y1, y1_post, atol=1e-5)
    approx(y2, y2_post, atol=1e-5)
    assert B.shape(post_logpdf) == logpdf_shape
    # The observed data must be more likely after conditioning on it.
    assert B.all(post_logpdf > prior_logpdf)
def __eq__(self, other):
    """Compare by `noises`.

    Args:
        other: Object whose `noises` attribute is compared with this one's.

    Returns:
        Result of `B.all` applied to `self.noises == other.noises`.
    """
    noises_match = self.noises == other.noises
    return B.all(noises_match)
def __eq__(self, other):
    """Compare by `alpha`.

    Args:
        other: Object whose `alpha` attribute is compared with this one's.

    Returns:
        Result of `B.all` applied to `self.alpha == other.alpha`.
    """
    alpha_match = self.alpha == other.alpha
    return B.all(alpha_match)
def __eq__(self, other):
    """Compare the first component and `period`.

    Args:
        other: Object supporting `other[0]` and exposing `period`.

    Returns:
        The result of `self[0] == other[0]` if that is falsy; otherwise the
        result of `B.all` applied to `self.period == other.period`.
    """
    first_match = self[0] == other[0]
    # Mirror `and`: hand back the falsy operand itself, not a coerced bool.
    if not first_match:
        return first_match
    return B.all(self.period == other.period)
def __eq__(self, other):
    """Compare by `alpha` and `beta`.

    Args:
        other: Object exposing `alpha` and `beta`.

    Returns:
        The result of `B.all(self.alpha == other.alpha)` if that is falsy;
        otherwise the result of `B.all(self.beta == other.beta)`.
    """
    alpha_match = B.all(self.alpha == other.alpha)
    # Mirror `and`: `beta` is only compared when `alpha` already matches.
    if not alpha_match:
        return alpha_match
    return B.all(self.beta == other.beta)
def __eq__(self, other):
    """Compare by `scale` and the first component.

    Args:
        other: Object exposing `scale` and supporting `other[0]`.

    Returns:
        The result of `B.all(self.scale == other.scale)` if that is falsy;
        otherwise the result of `self[0] == other[0]`.
    """
    scale_match = B.all(self.scale == other.scale)
    # Mirror `and`: the first components are only compared when scales match.
    if not scale_match:
        return scale_match
    return self[0] == other[0]