def test_ellipsoid_clustering():
    """Visual check of ``ellipsoid_clustering`` on two overlapping uniform blobs.

    Builds 30 points in the unit square plus 10 points shifted by 1.25,
    clusters them with ``ellipsoid_clustering`` (depth=4), prints the cluster
    parameters together with per-cluster point counts, and plots every
    ellipsoid with its member points.
    """
    import pylab as plt
    from jax import disable_jit, jit

    points = jnp.concatenate(
        [random.uniform(random.PRNGKey(0), shape=(30, 2)),
         1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
        axis=0)
    # Unit circle; mapped through each ellipsoid's (mu, rotation, radii) for plotting.
    angles = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=0)

    # Target volume: the single bounding ellipsoid's volume divided by 5.
    mask = jnp.ones(points.shape[0], jnp.bool_)
    _, C = bounding_ellipsoid(points, mask)
    radii_all, _ = ellipsoid_params(C)
    log_VS = log_ellipsoid_volume(radii_all) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 4, log_VS)
        )(random.PRNGKey(0), points, log_VS)
    mu, radii, rotation = ellipsoid_parameters

    # The cluster count is the leading dimension of the returned parameters.
    # It is neither len(ellipsoid_parameters) (always 3) nor the depth argument.
    num_clusters = mu.shape[0]
    print(mu, radii, rotation,
          jnp.bincount(cluster_id, minlength=0, length=num_clusters))

    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mu, radii, rotation)):
        outline = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        plt.plot(outline[0, :], outline[1, :])
        in_cluster = cluster_id == i
        # atleast_2d: a bare RGBA tuple for `c` is ambiguous to matplotlib.
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1],
                    c=jnp.atleast_2d(plt.cm.jet(i / num_clusters)))
    plt.show()
def test_sample_multi_ellipsoid():
    """Visual check of ``sample_multi_ellipsoid``.

    Clusters two overlapping uniform blobs with ``ellipsoid_clustering``
    (depth=4), draws 1000 points with ``sample_multi_ellipsoid`` using
    ``unit_cube_constraint=True``, and plots the draws together with the
    outline of every ellipsoid.
    """
    import pylab as plt
    from jax import disable_jit, jit, vmap

    points = jnp.concatenate(
        [random.uniform(random.PRNGKey(0), shape=(30, 2)),
         1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
        axis=0)
    # Unit circle; mapped through each ellipsoid's (mu, rotation, radii) for plotting.
    angles = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=0)

    # Target volume: the single bounding ellipsoid's volume divided by 5.
    mask = jnp.ones(points.shape[0], jnp.bool_)
    _, C = bounding_ellipsoid(points, mask)
    radii_all, _ = ellipsoid_params(C)
    log_VS = log_ellipsoid_volume(radii_all) - jnp.log(5)

    with disable_jit():
        _, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 4, log_VS)
        )(random.PRNGKey(0), points, log_VS)
    mu, radii, rotation = ellipsoid_parameters

    # Element [1] of sample_multi_ellipsoid's result is plotted as a 2-D point.
    u = vmap(lambda key: sample_multi_ellipsoid(
        key, mu, radii, rotation, unit_cube_constraint=True)[1])(
        random.split(random.PRNGKey(0), 1000))
    plt.scatter(u[:, 0], u[:, 1], marker='+')

    for mu_i, radii_i, rotation_i in zip(mu, radii, rotation):
        outline = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        plt.plot(outline[0, :], outline[1, :])
    plt.show()
def init_multi_ellipsoid_sampler_state(key, live_points_U, depth, log_X):
    """Create the initial ``MultiEllipsoidSamplerState`` for ``live_points_U``.

    The live points are partitioned with ``ellipsoid_clustering(key,
    live_points_U, depth, log_X)``. The returned state holds the cluster
    assignment, the per-cluster ellipsoid parameters (mu, radii, rotation),
    the number of points in each cluster (``num_k``), and ``num_fev_ma``
    initialised to 1.
    """
    cluster_id, ellipsoids = ellipsoid_clustering(key, live_points_U, depth, log_X)
    mu, radii, rotation = ellipsoids
    # One count per cluster; the cluster count is the leading dim of mu.
    counts = jnp.bincount(cluster_id, minlength=0, length=mu.shape[0])
    return MultiEllipsoidSamplerState(
        cluster_id=cluster_id,
        mu=mu,
        radii=radii,
        rotation=rotation,
        num_k=counts,
        num_fev_ma=jnp.asarray(1.),
    )
def init_slice_sampler_state(key, live_points_U, depth, log_X, num_slices):
    """Create the initial ``SliceSamplerState`` for ``live_points_U``.

    The live points are partitioned with ``ellipsoid_clustering(key,
    live_points_U, depth, log_X)``. The returned state holds the cluster
    assignment, the per-cluster ellipsoid parameters, per-cluster point
    counts, and an initial ``num_fev_ma`` of
    ``num_slices * live_points_U.shape[1] + 2``.
    """
    cluster_id, ellipsoids = ellipsoid_clustering(key, live_points_U, depth, log_X)
    mu, radii, rotation = ellipsoids
    counts = jnp.bincount(cluster_id, minlength=0, length=mu.shape[0])
    # Initial function-evaluation estimate: num_slices per dimension, plus 2.
    initial_fev = num_slices * live_points_U.shape[1] + 2.
    return SliceSamplerState(
        cluster_id=cluster_id,
        mu=mu,
        radii=radii,
        rotation=rotation,
        num_k=counts,
        num_fev_ma=jnp.asarray(initial_fev),
    )
def main():
    """Cluster the high-likelihood region of a 2-D eggbox likelihood and plot it.

    Draws 700 uniform unit-cube samples, keeps those with likelihood > 150,
    clusters the kept points with ``ellipsoid_clustering`` (depth=7), and
    plots every cluster's ellipsoid together with its member points.
    """
    # Imported locally, matching the sibling test functions in this file.
    import pylab as plt
    from jax import disable_jit, jit, vmap
    from jaxns.prior_transforms import PriorChain, UniformPrior

    def log_likelihood(theta, **kwargs):
        return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

    prior_chain = PriorChain() \
        .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))
    U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(
        random.split(random.PRNGKey(0), 700))
    theta = vmap(lambda u: prior_chain(u))(U)
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)
    select = lik > 150.
    print("Selecting", jnp.sum(select), "need", 18 * 3)
    # The accepted fraction of the unit-cube samples is used as V(S).
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))
    U = U[select, :]

    with disable_jit():
        cluster_id, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
        )(random.PRNGKey(0), U, log_VS)
    mu, radii, rotation = ellipsoid_parameters
    # Normalise colours by the real cluster count (leading dim of mu).
    # len(ellipsoid_parameters) is always 3, which made jet() clip for i >= 3.
    num_clusters = mu.shape[0]

    angles = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=0)
    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mu, radii, rotation)):
        outline = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        plt.plot(outline[0, :], outline[1, :])
        in_cluster = cluster_id == i
        plt.scatter(U[in_cluster, 0], U[in_cluster, 1],
                    c=jnp.atleast_2d(plt.cm.jet(i / num_clusters)))
    plt.show()
def test_generic_kmeans():
    """Visual check of ellipsoid clustering on a thresholded likelihood region.

    Draws uniform unit-cube samples for a 2-D prior and keeps those whose
    likelihood exceeds a threshold. The kept points are clustered with
    ``ellipsoid_clustering`` (depth=7), and each cluster's bounding ellipsoid
    is plotted with its member points.

    ``data`` selects the likelihood: ``'eggbox'`` or ``'shells'`` (the
    log-sum-exp of two Gaussian rings).

    Raises:
        ValueError: if ``data`` names an unknown data set.
    """
    from jaxns.prior_transforms import PriorChain, UniformPrior
    from jax import vmap, disable_jit, jit
    import pylab as plt

    data = 'shells'
    if data == 'eggbox':
        def log_likelihood(theta, **kwargs):
            return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))
        num_samples, threshold = 1000, 100.
    elif data == 'shells':
        def log_likelihood(theta, **kwargs):
            def log_circ(theta, c, r, w):
                # Log-density of a Gaussian in radial distance |theta - c| around radius r.
                return (-0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2
                        - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2)))

            w1 = w2 = jnp.array(0.1)
            r1 = r2 = jnp.array(2.)
            c1 = jnp.array([0., -4.])
            c2 = jnp.array([0., 4.])
            return jnp.logaddexp(log_circ(theta, c1, r1, w1), log_circ(theta, c2, r2, w2))

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=-12. * jnp.ones(2), high=12. * jnp.ones(2)))
        num_samples, threshold = 40000, 1.
    else:
        raise ValueError(f"Unknown data set {data!r}")

    U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(
        random.split(random.PRNGKey(0), num_samples))
    theta = vmap(lambda u: prior_chain(u))(U)
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)
    select = lik > threshold

    print("Selecting", jnp.sum(select))
    # The accepted fraction of the unit-cube samples is used as V(S).
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))
    points = U[select, :]

    sc = plt.scatter(U[:, 0], U[:, 1], c=jnp.exp(lik))
    plt.colorbar(sc)
    plt.show()

    with disable_jit():
        cluster_id, _ = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
        )(random.PRNGKey(0), points, log_VS)

    # Recompute one bounding ellipsoid per cluster label that occurs in cluster_id.
    K = int(jnp.max(cluster_id) + 1)
    mu, C = vmap(lambda k: bounding_ellipsoid(points, cluster_id == k))(jnp.arange(K))
    radii, rotation = vmap(ellipsoid_params)(C)

    angles = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=0)
    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mu, radii, rotation)):
        colour = plt.cm.jet(i / K)
        outline = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        plt.plot(outline[0, :], outline[1, :], c=colour)
        in_cluster = cluster_id == i
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1],
                    c=jnp.atleast_2d(colour))
    plt.xlim(-1, 2)
    plt.ylim(-1, 2)
    plt.show()