예제 #1
0
 def lax_fun(dataset, points, w):
   """Evaluate a Gaussian KDE built from `dataset` at `points`.

   `weights`, `method` and `func` are not parameters of this function; they
   are read from a scope outside it (not visible here).

   Args:
     dataset: data handed to `lsp_stats.gaussian_kde`.
     points: argument passed to the KDE call, or to the method named by `func`.
     w: raw weights; used (after `abs`) only when `weights` is truthy.

   Returns:
     `kde(points)` when `func` is None, otherwise `kde.<func>(points)`.
   """
   if weights:
     kde_weights = jax.numpy.abs(w)
   else:
     kde_weights = None
   estimator = lsp_stats.gaussian_kde(
     dataset, bw_method=method, weights=kde_weights)
   # A named method is looked up dynamically; otherwise the KDE is called.
   if func is not None:
     return getattr(estimator, func)(points)
   return estimator(points)
예제 #2
0
  def testKdeResample1d(self, shape, dtype):
    """Check that samples drawn from a weighted KDE match the KDE's own CDF.

    Draws 1000 samples with a fixed PRNG key and runs `osp_stats.kstest` on
    them against a CDF built from `kde.integrate_box_1d`; the test passes when
    the resulting p-value exceeds 0.01.
    """
    make_array = jtu.rand_default(self.rng())
    # Keep this draw order (dataset first, then weights): both come from the
    # same generator.
    data = make_array(shape, dtype)
    raw_weights = make_array(shape[-1:], dtype)
    kde = lsp_stats.gaussian_kde(data, weights=jax.numpy.abs(raw_weights))

    key = jax.random.PRNGKey(5)
    drawn = kde.resample(key, shape=(1000,))
    samples = jax.numpy.squeeze(drawn)

    def cdf(x):
      # Integrate the KDE from -inf up to each entry of x.
      integrate_from_neg_inf = partial(kde.integrate_box_1d, -np.inf)
      # Manually casting to numpy in order to avoid type promotion error
      return np.array(jax.vmap(integrate_from_neg_inf)(x))

    ks_result = osp_stats.kstest(samples, cdf)
    self.assertGreater(ks_result.pvalue, 0.01)
예제 #3
0
  def testKdePyTree(self):
    """Check that a gaussian_kde round-trips as a pytree and works under jit.

    Flattens and unflattens the KDE with `tree_util`, compares the two objects
    leaf by leaf, then compares a jitted `evaluate` call against a direct one.
    """
    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    # Keep this draw order (dataset first, then x): both come from the same
    # generator.
    dataset = rng((3, 15), dtype)
    x = rng((3, 12), dtype)
    kde = lsp_stats.gaussian_kde(dataset)

    # Flatten / unflatten round trip must reproduce every leaf.
    flat, structure = tree_util.tree_flatten(kde)
    rebuilt = tree_util.tree_unflatten(structure, flat)
    tree_util.tree_map(self.assertAllClose, kde, rebuilt)

    # The KDE object itself is passed as an argument to the jitted function.
    def _evaluate(estimator, pts):
      return estimator.evaluate(pts)

    jitted_evaluate = jax.jit(_evaluate)
    self.assertAllClose(jitted_evaluate(kde, x), kde.evaluate(x))
예제 #4
0
 def resample(key, dataset, weights, *, shape):
   """Resample from a weighted Gaussian KDE fitted to `dataset`.

   Args:
     key: PRNG key forwarded to `kde.resample`.
     dataset: data handed to `lsp_stats.gaussian_kde`.
     weights: raw weights; their absolute values are used as KDE weights.
     shape: keyword-only sample shape forwarded to `kde.resample`.

   Returns:
     Whatever `kde.resample(key, shape=shape)` returns.
   """
   nonneg_weights = jax.numpy.abs(weights)
   estimator = lsp_stats.gaussian_kde(dataset, weights=nonneg_weights)
   return estimator.resample(key, shape=shape)
예제 #5
0
 def lax_fun(dataset, weights):
   """Integrate the product of two weighted Gaussian KDEs.

   The first KDE uses all of `dataset`; the second uses `dataset` with the
   last three entries along the final axis dropped and shifted by 0.1, with
   `weights` trimmed to match.

   Args:
     dataset: data handed to `lsp_stats.gaussian_kde`.
     weights: raw weights; their absolute values are used as KDE weights.

   Returns:
     The result of `kde.integrate_kde(other)`.
   """
   full_weights = jax.numpy.abs(weights)
   trimmed_weights = jax.numpy.abs(weights[:-3])
   shifted_dataset = dataset[..., :-3] + 0.1
   kde = lsp_stats.gaussian_kde(dataset, weights=full_weights)
   other = lsp_stats.gaussian_kde(shifted_dataset, weights=trimmed_weights)
   return kde.integrate_kde(other)
예제 #6
0
 def lax_fun(dataset, weights):
   """Integrate a weighted Gaussian KDE over the interval [-0.5, 1.5].

   Args:
     dataset: data handed to `lsp_stats.gaussian_kde`.
     weights: raw weights; their absolute values are used as KDE weights.

   Returns:
     The result of `kde.integrate_box_1d(-0.5, 1.5)`.
   """
   nonneg_weights = jax.numpy.abs(weights)
   estimator = lsp_stats.gaussian_kde(dataset, weights=nonneg_weights)
   lower, upper = -0.5, 1.5
   return estimator.integrate_box_1d(lower, upper)
예제 #7
0
 def lax_fun(dataset, weights):
   """Integrate a weighted Gaussian KDE against a Gaussian.

   `mean` and `covariance` are not parameters of this function; they are read
   from a scope outside it (not visible here).

   Args:
     dataset: data handed to `lsp_stats.gaussian_kde`.
     weights: raw weights; their absolute values are used as KDE weights.

   Returns:
     The result of `kde.integrate_gaussian(mean, covariance)`.
   """
   nonneg_weights = jax.numpy.abs(weights)
   estimator = lsp_stats.gaussian_kde(dataset, weights=nonneg_weights)
   return estimator.integrate_gaussian(mean, covariance)