def test_broadcast_kde_pdf_shape(n_samples):
    """Check that the broadcast KDE pdf returns one density per support point.

    Draws ``n_samples`` normal samples from the shared ``generator``, builds an
    evenly spaced grid between the bounds returned by ``get_domain_extension``,
    evaluates the KDE pdf on that grid, and asserts that both the grid and the
    resulting densities are 1-D arrays whose length equals the grid size.
    """
    bandwidth = 0.1
    n_grid = 10
    domain_extension = 10

    samples = objax.random.normal((n_samples,), generator=generator)

    lower, upper = get_domain_extension(samples, domain_extension)
    grid = np.linspace(lower, upper, n_grid)
    density = broadcast_kde_pdf(grid, samples, bandwidth)

    # The grid and the density evaluated on it must both have shape (n_grid,).
    for arr in (grid, density):
        chex.assert_shape(arr, (n_grid,))
def test_broadcast_kde_cdf_shape(n_samples):
    """Check that the broadcast KDE cdf returns one value per support point.

    Draws ``n_samples`` normal samples from the shared ``generator``, builds an
    evenly spaced grid between the bounds returned by ``get_domain_extension``,
    evaluates the KDE cdf on that grid, and asserts the output has shape
    ``(n_grid,)``.
    """
    bandwidth, n_grid, domain_extension = 0.1, 10, 10

    samples = objax.random.normal((n_samples,), generator=generator)
    lower, upper = get_domain_extension(samples, domain_extension)
    grid = np.linspace(lower, upper, n_grid)

    # Unlike the pdf helper, the cdf helper takes a precomputed normalization
    # factor instead of the raw bandwidth.
    norm_factor = normalization_factor(samples, bandwidth)
    cdf_values = broadcast_kde_cdf(grid, samples, norm_factor)

    chex.assert_shape(cdf_values, (n_grid,))
transform=transform, bijector=bijector, transform_and_bijector=transform_and_bijector, transform_gradient_bijector=transform_gradient_bijector, ) def init_kde_params( X: jnp.ndarray, bw: float = 0.1, support_extension: Union[int, float] = 10, precision: int = 1_000, return_params: bool = True, ): # generate support points lb, ub = get_domain_extension(X, support_extension) support = jnp.linspace(lb, ub, precision) # calculate the pdf for gaussian pdf pdf_support = broadcast_kde_pdf(support, X, bw) # calculate the cdf for support points factor = normalization_factor(X, bw) quantiles = broadcast_kde_cdf(support, X, factor) return UniKDEParams( support=support, quantiles=quantiles, support_pdf=support, empirical_pdf=pdf_support,