def test_ellipsoid_clustering():
    """Cluster two uniform point clouds with ``ellipsoid_clustering`` and plot them.

    Builds 40 two-dimensional points (30 uniform draws plus 10 uniform draws
    shifted by 1.25), computes a single bounding ellipsoid to derive
    ``log_VS``, runs ``ellipsoid_clustering`` with 4 as its third argument,
    prints the returned ellipsoid parameters and the per-cluster point
    counts, and finally draws each ellipsoid outline with its member points.

    The function makes no assertions; it is a visual smoke test.
    """
    import pylab as plt
    from jax import disable_jit, jit
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    # Points on the unit circle; mapped through rotation @ diag(radii) they
    # trace an ellipsoid outline.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    # Log-volume of the single bounding ellipsoid minus log(5).
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        # The lambda forwards the key it is given (the same PRNGKey(0) as before).
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(key, points, 4, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        mu, radii, rotation = ellipsoid_parameters
        print(mu, radii, rotation, jnp.bincount(cluster_id, minlength=0, length=4))

    # Distinct loop names so the stacked parameter arrays are not rebound
    # while they are being iterated.
    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mu, radii, rotation)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ x
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        # NOTE(review): len(ellipsoid_parameters) is the length of the
        # (mu, radii, rotation) tuple, i.e. 3, not the number of clusters --
        # confirm this is the intended colour normalisation.
        plt.scatter(points[mask, 0], points[mask, 1],
                    color=plt.cm.jet(i / len(ellipsoid_parameters)))
def test_sample_multi_ellipsoid():
    """Draw samples from clustered ellipsoids and plot them with the outlines.

    Builds the same 40-point data set as the clustering test, runs
    ``ellipsoid_clustering`` with 4 as its third argument, then calls
    ``sample_multi_ellipsoid`` (with ``unit_cube_constraint=True``) for 1000
    split keys, keeping element ``[1]`` of each result. The samples are
    scattered with ``'+'`` markers and every ellipsoid outline is drawn.

    The function makes no assertions; it is a visual smoke test.
    """
    import pylab as plt
    from jax import disable_jit, jit, vmap
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    # Points on the unit circle; mapped through rotation @ diag(radii) they
    # trace an ellipsoid outline.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    # Log-volume of the single bounding ellipsoid minus log(5).
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        # The lambda forwards the key it is given (the same PRNGKey(0) as before).
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(key, points, 4, log_VS)
                )(random.PRNGKey(0), points, log_VS)

        mu, radii, rotation = ellipsoid_parameters
        # One draw per key; only element [1] of each returned value is kept.
        keys = random.split(random.PRNGKey(0), 1000)
        u = vmap(lambda key: sample_multi_ellipsoid(key, mu, radii, rotation,
                                                    unit_cube_constraint=True)[1])(keys)
    plt.scatter(u[:, 0], u[:, 1], marker='+')
    # Distinct loop names so the stacked parameter arrays are not rebound
    # while they are being iterated.
    for mu_i, radii_i, rotation_i in zip(mu, radii, rotation):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ x
        plt.plot(y[0, :], y[1, :])
def test_cluster_split():
    """Split one point cloud into two ellipsoids with ``cluster_split`` and plot.

    Builds 40 two-dimensional points (30 uniform draws plus 10 uniform draws
    shifted by 1.25), plots the single bounding ellipsoid, calls
    ``cluster_split`` with ``kmeans_init=True``, prints the combined
    log-volume of the two children next to the parent's log-volume along
    with the returned parameters, and then plots both child ellipsoids and
    the points assigned to cluster 0 and cluster 1.

    The function makes no assertions; it is a visual smoke test.
    """
    import pylab as plt
    from jax import disable_jit
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    # Points on the unit circle; mapped through rotation @ diag(radii) they
    # trace an ellipsoid outline.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    # NOTE(review): the mask handed to cluster_split is all False, while the
    # bounding ellipsoid below is fitted with an all-True mask -- confirm the
    # mask convention cluster_split expects.
    mask = jnp.zeros(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, jnp.ones(points.shape[0], jnp.bool_))
    radii, rotation = ellipsoid_params(C)
    y = mu[:, None] + rotation @ jnp.diag(radii) @ x
    plt.plot(y[0, :], y[1, :])
    # Log-volume of the single bounding ellipsoid minus log(5).
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)
    with disable_jit():
        cluster_id, log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split = \
            cluster_split(random.PRNGKey(0), points, mask, log_VS, log_ellipsoid_volume(radii), kmeans_init=True)
        # Combined child log-volume versus the parent log-volume.
        print(jnp.logaddexp(log_ellipsoid_volume(radii1), log_ellipsoid_volume(radii2)), log_ellipsoid_volume(radii))
        print(log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split)

    y = mu1[:, None] + rotation1 @ jnp.diag(radii1) @ x
    plt.plot(y[0, :], y[1, :])

    y = mu2[:, None] + rotation2 @ jnp.diag(radii2) @ x
    plt.plot(y[0, :], y[1, :])

    mask = cluster_id == 0
    plt.scatter(points[mask, 0], points[mask, 1])
    mask = cluster_id == 1
    plt.scatter(points[mask, 0], points[mask, 1])
def test_bounding_ellipsoid():
    """Fit a bounding ellipsoid to ten normal draws, print it and plot it.

    Calls ``bounding_ellipsoid`` with an all-ones mask, converts the result
    with ``ellipsoid_params``, prints the centre, matrix, radii and rotation,
    and draws the sample points together with the ellipsoid outline.

    The function makes no assertions; it is a visual smoke test.
    """
    samples = random.normal(random.PRNGKey(0), shape=(10, 2))
    keep_all = jnp.ones(samples.shape[0])
    centre, C_fit = bounding_ellipsoid(samples, keep_all)
    axis_radii, rot = ellipsoid_params(C_fit)
    print(centre, C_fit, axis_radii, rot)

    # Unit circle pushed through rotation @ diag(radii) and shifted to the centre.
    angles = jnp.linspace(0., jnp.pi * 2, 100)
    unit_circle = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=0)
    outline = centre[:, None] + rot @ jnp.diag(axis_radii) @ unit_circle

    import pylab as plt
    plt.scatter(samples[:, 0], samples[:, 1])
    plt.plot(outline[0, :], outline[1, :])
def test_kmeans():
    """Run ``kmeans`` with ``K=2`` on two normal point clouds and plot the split.

    Builds 40 two-dimensional points (30 standard-normal draws plus 10
    standard-normal draws shifted by 3), clusters them with ``kmeans`` using
    an all-True mask, and plots the points of cluster 0 and cluster 1
    together with the outline of the bounding ellipsoid of all points.

    The function makes no assertions; it is a visual smoke test.
    """
    points = jnp.concatenate([random.normal(random.PRNGKey(0), shape=(30, 2)),
                              3. + random.normal(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)

    cluster_id, centers = kmeans(random.PRNGKey(0), points, jnp.ones(points.shape[0], dtype=jnp.bool_), K=2)

    mu, C = bounding_ellipsoid(points, jnp.ones(points.shape[0]))
    radii, rotation = ellipsoid_params(C)
    # Unit circle pushed through rotation @ diag(radii) and shifted to mu.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = mu[:, None] + rotation @ jnp.diag(radii) @ jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    import pylab as plt
    mask = cluster_id == 0
    plt.scatter(points[mask, 0], points[mask, 1])
    mask = cluster_id == 1
    plt.scatter(points[mask, 0], points[mask, 1])
    plt.plot(x[0, :], x[1, :])
def test_generic_kmeans():
    """Cluster the high-likelihood region of a toy problem and plot the result.

    Depending on ``data`` (hard-coded to ``'shells'``) a toy log-likelihood
    and a uniform prior are set up, unit-cube samples ``U`` are drawn and
    pushed through the prior chain, and the samples whose likelihood exceeds
    a threshold are selected. The selected unit-cube points are clustered
    with ``ellipsoid_clustering`` (7 as its third argument); a bounding
    ellipsoid is then fitted per cluster id and drawn with that cluster's
    points, coloured by cluster index.

    The function makes no assertions; it is a visual smoke test.
    """
    from jaxns.prior_transforms import PriorChain, UniformPrior
    from jax import vmap, disable_jit, jit
    import pylab as plt

    data = 'shells'
    if data == 'eggbox':
        def log_likelihood(theta, **kwargs):
            # Eggbox form: (2 + prod_i cos(theta_i / 2)) ** 5.
            return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(random.split(random.PRNGKey(0), 1000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 100.

    if data == 'shells':

        def log_likelihood(theta, **kwargs):
            def log_circ(theta, c, r, w):
                # Gaussian log-density (width w) in the distance of theta
                # from the circle of radius r centred at c.
                return -0.5*(jnp.linalg.norm(theta - c) - r)**2/w**2 - jnp.log(jnp.sqrt(2*jnp.pi*w**2))
            c1 = jnp.array([0., -4.])
            c2 = jnp.array([0., 4.])
            # NOTE(review): r1, r2, w1 and w2 are not defined inside this
            # function; they must be supplied by an enclosing/module scope,
            # otherwise this line raises NameError -- confirm.
            return jnp.logaddexp(log_circ(theta, c1,r1,w1) , log_circ(theta,c2,r2,w2))

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=-12.*jnp.ones(2), high=12.*jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(random.split(random.PRNGKey(0), 40000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 1.

    print("Selecting", jnp.sum(select))
    # Log of the fraction of samples that passed the likelihood threshold.
    log_VS = jnp.log(jnp.sum(select)/select.size)

    points = U[select, :]
    sc = plt.scatter(U[:,0], U[:,1],c=jnp.exp(lik))
    mask = jnp.ones(points.shape[0], dtype=jnp.bool_)
    K = 18
    with disable_jit():
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='ellipsoid',K=K,meta=dict(log_VS=log_VS))
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='mahalanobis',K=K)
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='euclidean',K=K)
        # cluster_id, log_cluster_VS = hierarchical_clustering(random.PRNGKey(0), points, 7, log_VS)
        # The lambda forwards the key it is given (the same PRNGKey(0) as before).
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        # mu, radii, rotation = ellipsoid_parameters
        # Number of cluster ids actually used (largest id + 1).
        K = int(jnp.max(cluster_id)+1)

    # Re-fit a bounding ellipsoid to the members of every cluster id.
    mu, C = vmap(lambda k: bounding_ellipsoid(points, cluster_id == k))(jnp.arange(K))
    radii, rotation = vmap(ellipsoid_params)(C)

    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)

    # Distinct loop names so the stacked parameter arrays are not rebound
    # while they are being iterated.
    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mu, radii, rotation)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ x
        plt.plot(y[0, :], y[1, :], c=plt.cm.jet(i / K))
        mask = cluster_id == i
        # A 2-D (1, 4) RGBA row so scatter applies one colour to all points.
        plt.scatter(points[mask, 0], points[mask, 1], c=jnp.atleast_2d(plt.cm.jet(i / K)))