def test_ellipsoid_clustering():
    """Visual check of ``ellipsoid_clustering`` on two overlapping uniform blobs.

    Builds 40 points in 2D: 30 uniform in the unit square plus 10 uniform in a
    square shifted by 1.25. It fits one bounding ellipsoid to all of them and
    sets log V(S) to that ellipsoid's log-volume minus log 5. It then calls
    ``ellipsoid_clustering`` with ``num_clusters`` and plots every returned
    ellipsoid together with the points labelled with its index, using one
    colour per ellipsoid.
    """
    import pylab as plt
    from jax import disable_jit, jit

    num_clusters = 4
    points = jnp.concatenate(
        [random.uniform(random.PRNGKey(0), shape=(30, 2)),
         1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
        axis=0)
    # Unit circle; each ellipsoid outline is its image under (mu, rotation, radii).
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)

    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, num_clusters, log_VS)
        )(random.PRNGKey(0), points, log_VS)
        cluster_mu, cluster_radii, cluster_rotation = ellipsoid_parameters
        print(cluster_mu, cluster_radii, cluster_rotation,
              jnp.bincount(cluster_id, minlength=0, length=num_clusters))

    # Normalise colours by the number of ellipsoids (first axis of cluster_mu).
    # The original used len(ellipsoid_parameters), which is always 3.
    n_ellipsoids = len(cluster_mu)
    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(cluster_mu, cluster_radii, cluster_rotation)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        color = plt.cm.jet(i / n_ellipsoids)
        plt.plot(y[0, :], y[1, :], color=color)
        in_cluster = cluster_id == i
        # Use color= rather than c=: a bare RGBA tuple passed as c is value-mapped
        # whenever the cluster happens to contain exactly 4 points.
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1], color=color)
    plt.show()
def test_sample_multi_ellipsoid():
    """Visual check of ``sample_multi_ellipsoid`` on a clustered point set.

    Clusters the same two-blob point set used by the clustering test, calling
    ``ellipsoid_clustering`` with 4. It then draws 1000 results from
    ``sample_multi_ellipsoid`` (with ``unit_cube_constraint=True``), each from
    its own split PRNG key, and scatter-plots them over the ellipsoid outlines.
    """
    import pylab as plt
    from jax import disable_jit, jit, vmap

    points = jnp.concatenate(
        [random.uniform(random.PRNGKey(0), shape=(30, 2)),
         1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
        axis=0)
    # Unit circle, mapped onto each ellipsoid for plotting its outline.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)

    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 4, log_VS)
        )(random.PRNGKey(0), points, log_VS)
        cluster_mu, cluster_radii, cluster_rotation = ellipsoid_parameters

    # Keep element [1] of each sample_multi_ellipsoid result; it is plotted
    # below as the sampled point (presumably the sample itself - confirm).
    u = vmap(lambda key: sample_multi_ellipsoid(key, cluster_mu, cluster_radii, cluster_rotation,
                                                unit_cube_constraint=True)[1])(
        random.split(random.PRNGKey(0), 1000))
    plt.scatter(u[:, 0], u[:, 1], marker='+')
    for mu_i, radii_i, rotation_i in zip(cluster_mu, cluster_radii, cluster_rotation):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        plt.plot(y[0, :], y[1, :])
    plt.show()
def test_cluster_split():
    """Visual check of ``cluster_split`` on two overlapping uniform blobs.

    Fits one ellipsoid around all 40 points and calls ``cluster_split`` on it
    with ``kmeans_init=True``. It prints the children's combined log-volume
    next to the parent's log-volume, then the raw split outputs and the
    per-point labels. Finally it plots the parent outline, both child outlines
    and the points labelled 0 and 1.
    """
    import pylab as plt
    from jax import disable_jit

    theta = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)

    def outline(center, axes, rot):
        # Image of the unit circle under the ellipsoid's affine map.
        return center[:, None] + rot @ jnp.diag(axes) @ circle

    key = random.PRNGKey(0)
    points = jnp.concatenate(
        [random.uniform(key, shape=(30, 2)), 1.25 + random.uniform(key, shape=(10, 2))],
        axis=0,
    )
    n_points = points.shape[0]

    # NOTE(review): the mask handed to cluster_split is all False, while
    # bounding_ellipsoid below receives an all-True mask. Confirm this matches
    # cluster_split's mask convention.
    split_mask = jnp.zeros(n_points, jnp.bool_)
    parent_mu, parent_C = bounding_ellipsoid(points, jnp.ones(n_points, jnp.bool_))
    parent_radii, parent_rotation = ellipsoid_params(parent_C)
    parent_curve = outline(parent_mu, parent_radii, parent_rotation)
    plt.plot(parent_curve[0, :], parent_curve[1, :])

    parent_log_volume = log_ellipsoid_volume(parent_radii)
    # log V(S) is taken as one fifth of the parent ellipsoid's volume.
    log_VS = parent_log_volume - jnp.log(5)
    with disable_jit():
        (cluster_id,
         log_VS1, mu1, radii1, rotation1,
         log_VS2, mu2, radii2, rotation2,
         do_split) = cluster_split(key, points, split_mask, log_VS, parent_log_volume,
                                   kmeans_init=True)

    print(jnp.logaddexp(log_ellipsoid_volume(radii1), log_ellipsoid_volume(radii2)), parent_log_volume)
    print(log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split)
    print(cluster_id)

    for child_mu, child_radii, child_rotation in ((mu1, radii1, rotation1), (mu2, radii2, rotation2)):
        child_curve = outline(child_mu, child_radii, child_rotation)
        plt.plot(child_curve[0, :], child_curve[1, :])
    for label in (0, 1):
        members = cluster_id == label
        plt.scatter(points[members, 0], points[members, 1])
    plt.show()
def test_bounding_ellipsoid():
    """Visual check that ``bounding_ellipsoid`` encloses a small 2D Gaussian sample.

    Fits an ellipsoid to 10 standard-normal points with every point selected,
    prints the fitted parameters, and plots the points with the ellipsoid
    outline.
    """
    import pylab as plt

    points = random.normal(random.PRNGKey(0), shape=(10, 2))
    # Boolean selection mask, matching how the other tests call bounding_ellipsoid
    # (previously a float array of ones).
    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    print(mu, C, radii, rotation)

    theta = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    outline = mu[:, None] + rotation @ jnp.diag(radii) @ circle
    plt.scatter(points[:, 0], points[:, 1])
    plt.plot(outline[0, :], outline[1, :])
    plt.show()
def test_kmeans():
    """Visual check of ``kmeans`` on two Gaussian blobs.

    Builds 40 points (30 standard-normal, plus 10 standard-normal shifted by
    3) and clusters them with ``kmeans`` using ``K=num_clusters``. It plots
    the points of each label together with one ellipsoid bounding all points.
    """
    import pylab as plt

    num_clusters = 2
    # NOTE(review): both blobs reuse PRNGKey(0); split the key if independent
    # draws are wanted.
    points = jnp.concatenate(
        [random.normal(random.PRNGKey(0), shape=(30, 2)),
         3. + random.normal(random.PRNGKey(0), shape=(10, 2))],
        axis=0)
    # One boolean "use every point" mask for both kmeans and bounding_ellipsoid
    # (bounding_ellipsoid previously received a float array of ones).
    mask = jnp.ones(points.shape[0], dtype=jnp.bool_)
    cluster_id, _ = kmeans(random.PRNGKey(0), points, mask, K=num_clusters)

    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    outline = mu[:, None] + rotation @ jnp.diag(radii) @ circle

    for k in range(num_clusters):
        in_cluster = cluster_id == k
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1])
    plt.plot(outline[0, :], outline[1, :])
    plt.show()
def test_generic_kmeans():
    """Visual check of ``ellipsoid_clustering`` on a thresholded likelihood region.

    Draws uniform samples in the unit cube and maps them through a
    ``PriorChain``. It keeps the samples whose log-likelihood exceeds a
    per-problem threshold and clusters the kept unit-cube points with
    ``ellipsoid_clustering`` (called with 7). One bounding ellipsoid is
    refitted per label and plotted with its points.

    ``data`` selects the toy problem: ``'shells'`` (two Gaussian-shell rings)
    or ``'eggbox'``.

    Raises:
        ValueError: if ``data`` names an unknown toy problem.
    """
    from jaxns.prior_transforms import PriorChain, UniformPrior
    from jax import vmap, disable_jit, jit
    import pylab as plt

    data = 'shells'
    if data == 'eggbox':
        def log_likelihood(theta, **kwargs):
            return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))
        num_samples = 1000
        threshold = 100.
    elif data == 'shells':
        def log_likelihood(theta, **kwargs):
            def log_circ(theta, c, r, w):
                return (-0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2
                        - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2)))

            w1 = w2 = jnp.array(0.1)
            r1 = r2 = jnp.array(2.)
            c1 = jnp.array([0., -4.])
            c2 = jnp.array([0., 4.])
            return jnp.logaddexp(log_circ(theta, c1, r1, w1), log_circ(theta, c2, r2, w2))

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=-12. * jnp.ones(2), high=12. * jnp.ones(2)))
        num_samples = 40000
        threshold = 1.
    else:
        raise ValueError(f"Unknown toy problem {data!r}")

    # Sampling/selection step shared by both problems (previously duplicated).
    U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(
        random.split(random.PRNGKey(0), num_samples))
    prior_samples = vmap(lambda u: prior_chain(u))(U)
    lik = vmap(lambda sample: log_likelihood(**sample))(prior_samples)
    select = lik > threshold

    print("Selecting", jnp.sum(select))
    # Fraction of the unit cube whose log-likelihood exceeds the threshold.
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))
    points = U[select, :]

    # Shift by the max before exponentiating. Eggbox log-likelihoods reach
    # 3**5 = 243, and exp of that overflows float32.
    sc = plt.scatter(U[:, 0], U[:, 1], c=jnp.exp(lik - jnp.max(lik)))
    plt.colorbar(sc)
    plt.show()

    # Alternatives previously tried at this point: generic_kmeans with
    # method='ellipsoid' | 'mahalanobis' | 'euclidean', and hierarchical_clustering.
    with disable_jit():
        cluster_id, _ = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
        )(random.PRNGKey(0), points, log_VS)

    # Refit one bounding ellipsoid for each label in 0..max(cluster_id).
    num_clusters = int(jnp.max(cluster_id) + 1)
    cluster_mu, C = vmap(lambda k: bounding_ellipsoid(points, cluster_id == k))(jnp.arange(num_clusters))
    cluster_radii, cluster_rotation = vmap(ellipsoid_params)(C)

    phi = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(phi), jnp.sin(phi)], axis=0)
    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(cluster_mu, cluster_radii, cluster_rotation)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        color = plt.cm.jet(i / num_clusters)
        plt.plot(y[0, :], y[1, :], color=color)
        in_cluster = cluster_id == i
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1], color=color)
    plt.xlim(-1, 2)
    plt.ylim(-1, 2)
    plt.show()