def lax_fun(dataset, points, w):
    """Build a Gaussian KDE over `dataset` and evaluate it at `points`.

    Uses the enclosing scope's `weights` (flag), `method` (bandwidth rule),
    and `func` (optional KDE method name) — presumably set by the test
    harness; confirm against the caller.
    """
    # Only pass weights through when the closure flag is set.
    abs_w = None
    if weights:
        abs_w = jax.numpy.abs(w)
    kde = lsp_stats.gaussian_kde(dataset, bw_method=method, weights=abs_w)
    # `func is None` means evaluate the KDE itself (its __call__).
    evaluate = kde if func is None else getattr(kde, func)
    return evaluate(points)
def testKdeResample1d(self, shape, dtype):
    """KS-test that samples drawn from a 1D KDE follow the KDE's own CDF."""
    make = jtu.rand_default(self.rng())
    data = make(shape, dtype)
    wts = jax.numpy.abs(make(shape[-1:], dtype))
    kde = lsp_stats.gaussian_kde(data, weights=wts)
    draws = kde.resample(jax.random.PRNGKey(5), shape=(1000,))
    samples = jax.numpy.squeeze(draws)

    def cdf(x):
        cumulative = jax.vmap(partial(kde.integrate_box_1d, -np.inf))(x)
        # Manually casting to numpy in order to avoid type promotion error
        return np.array(cumulative)

    self.assertGreater(osp_stats.kstest(samples, cdf).pvalue, 0.01)
def testKdePyTree(self):
    """A gaussian_kde must survive pytree flatten/unflatten and work under jit."""
    dtype = np.float32
    make = jtu.rand_default(self.rng())
    dataset = make((3, 15), dtype)
    x = make((3, 12), dtype)

    @jax.jit
    def evaluate_kde(kde, x):
        return kde.evaluate(x)

    kde = lsp_stats.gaussian_kde(dataset)
    # Round-trip through the pytree machinery and compare leaf-by-leaf.
    leaves, treedef = tree_util.tree_flatten(kde)
    restored = tree_util.tree_unflatten(treedef, leaves)
    tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, restored)
    # The jitted evaluation must agree with the eager one.
    self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
def resample(key, dataset, weights, *, shape):
    """Draw `shape`-many samples from a Gaussian KDE fit to `dataset`.

    Weights are made non-negative with abs() before fitting; `key` is the
    JAX PRNG key used for the draw.
    """
    nonneg_weights = jax.numpy.abs(weights)
    kde = lsp_stats.gaussian_kde(dataset, weights=nonneg_weights)
    return kde.resample(key, shape=shape)
def lax_fun(dataset, weights):
    """Integrate the product of two KDEs: one over `dataset`, one over a
    shifted copy of `dataset` with its last 3 points (and weights) dropped."""
    abs_weights = jax.numpy.abs(weights)
    kde = lsp_stats.gaussian_kde(dataset, weights=abs_weights)
    shifted = dataset[..., :-3] + 0.1
    other = lsp_stats.gaussian_kde(shifted, weights=abs_weights[:-3])
    return kde.integrate_kde(other)
def lax_fun(dataset, weights):
    """Probability mass of the fitted KDE over the fixed interval [-0.5, 1.5]."""
    kde = lsp_stats.gaussian_kde(
        dataset, weights=jax.numpy.abs(weights))
    low, high = -0.5, 1.5
    return kde.integrate_box_1d(low, high)
def lax_fun(dataset, weights):
    """Integrate the fitted KDE against a Gaussian defined by the enclosing
    scope's `mean` and `covariance` — presumably fixed by the test harness."""
    nonneg = jax.numpy.abs(weights)
    kde = lsp_stats.gaussian_kde(dataset, weights=nonneg)
    return kde.integrate_gaussian(mean, covariance)