コード例 #1
0
def test_ellipsoid_clustering():
    """Visual check of ``ellipsoid_clustering`` on two offset uniform blobs.

    Builds 30 points in the unit square plus 10 points shifted by 1.25,
    clusters them with depth 4, prints the fitted ellipsoid parameters and
    the per-ellipsoid point counts, then plots every ellipsoid outline
    together with the points assigned to it.
    """
    import pylab as plt
    from jax import disable_jit, jit
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    # Unit circle; mapped through each ellipsoid's affine transform for plotting.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)
    # log of one fifth of the volume of the ellipsoid bounding all points;
    # passed to ellipsoid_clustering as its log_VS argument.
    _, C = bounding_ellipsoid(points, mask)
    full_radii, _ = ellipsoid_params(C)
    log_VS = log_ellipsoid_volume(full_radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 4, log_VS)
        )(random.PRNGKey(0), points, log_VS)
        mus, radii, rotations = ellipsoid_parameters
        # Number of fitted ellipsoids (same convention as the sampler-state
        # initialisers), rather than a hard-coded 4.
        num_clusters = mus.shape[0]
        print(mus, radii, rotations,
              jnp.bincount(cluster_id, minlength=0, length=num_clusters))

    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mus, radii, rotations)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ x
        plt.plot(y[0, :], y[1, :])
        in_cluster = cluster_id == i
        # Normalise the colour by the number of ellipsoids (len() of the
        # parameter tuple is always 3). atleast_2d keeps matplotlib from
        # reading a 4-element RGBA tuple as per-point values when a cluster
        # happens to contain exactly 4 points.
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1],
                    c=jnp.atleast_2d(plt.cm.jet(i / num_clusters)))

    plt.show()
コード例 #2
0
def test_sample_multi_ellipsoid():
    """Visual check of ``sample_multi_ellipsoid`` over clustered ellipsoids.

    Clusters 30 unit-square points plus 10 points shifted by 1.25 with
    ``ellipsoid_clustering`` (depth 4), draws 1000 samples from the fitted
    ellipsoids with ``unit_cube_constraint=True``, and plots the samples
    ('+' markers) together with every ellipsoid outline.
    """
    import pylab as plt
    from jax import disable_jit, jit, vmap
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    # Unit circle; mapped through each ellipsoid's affine transform for plotting.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)
    # log of one fifth of the volume of the ellipsoid bounding all points;
    # passed to ellipsoid_clustering as its log_VS argument.
    _, C = bounding_ellipsoid(points, mask)
    full_radii, _ = ellipsoid_params(C)
    log_VS = log_ellipsoid_volume(full_radii) - jnp.log(5)

    with disable_jit():
        _, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 4, log_VS)
        )(random.PRNGKey(0), points, log_VS)
        mus, radii, rotations = ellipsoid_parameters
        sample_keys = random.split(random.PRNGKey(0), 1000)
        # Element [1] of sample_multi_ellipsoid's result is what gets plotted;
        # presumably the sampled point -- confirm against its definition.
        u = vmap(lambda key: sample_multi_ellipsoid(key, mus, radii, rotations,
                                                    unit_cube_constraint=True)[1])(sample_keys)
    plt.scatter(u[:, 0], u[:, 1], marker='+')
    for mu_i, radii_i, rotation_i in zip(mus, radii, rotations):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ x
        plt.plot(y[0, :], y[1, :])
    plt.show()
コード例 #3
0
def init_multi_ellipsoid_sampler_state(key, live_points_U, depth, log_X):
    """Create the initial ``MultiEllipsoidSamplerState`` for the live points.

    ``key``, ``live_points_U``, ``depth`` and ``log_X`` are forwarded unchanged
    to ``ellipsoid_clustering``. The returned state holds the cluster
    assignment, the ellipsoid parameters (mu, radii, rotation), the number of
    points assigned to each ellipsoid, and ``num_fev_ma`` seeded with 1.
    """
    cluster_id, ellipsoid_parameters = ellipsoid_clustering(key, live_points_U, depth, log_X)
    mu, radii, rotation = ellipsoid_parameters
    # One bin per ellipsoid, so ellipsoids without members still get a 0 count.
    counts = jnp.bincount(cluster_id, minlength=0, length=mu.shape[0])
    return MultiEllipsoidSamplerState(
        cluster_id=cluster_id,
        mu=mu,
        radii=radii,
        rotation=rotation,
        num_k=counts,
        num_fev_ma=jnp.asarray(1.),
    )
コード例 #4
0
ファイル: slice.py プロジェクト: fehiepsi/jaxns
def init_slice_sampler_state(key, live_points_U, depth, log_X, num_slices):
    """Create the initial ``SliceSamplerState`` for the live points.

    ``key``, ``live_points_U``, ``depth`` and ``log_X`` are forwarded unchanged
    to ``ellipsoid_clustering``. The state records the cluster assignment, the
    ellipsoid parameters (mu, radii, rotation) and the per-ellipsoid point
    counts. ``num_fev_ma`` is seeded with ``num_slices * D + 2``, where
    ``D = live_points_U.shape[1]``.
    """
    clustering = ellipsoid_clustering(key, live_points_U, depth, log_X)
    cluster_id, (mu, radii, rotation) = clustering
    # One bin per ellipsoid, so ellipsoids without members still get a 0 count.
    counts = jnp.bincount(cluster_id, minlength=0, length=mu.shape[0])
    num_dims = live_points_U.shape[1]
    initial_num_fev = jnp.asarray(num_slices * num_dims + 2.)
    return SliceSamplerState(cluster_id=cluster_id,
                             mu=mu,
                             radii=radii,
                             rotation=rotation,
                             num_k=counts,
                             num_fev_ma=initial_num_fev)
コード例 #5
0
def main():
    """Cluster high-likelihood eggbox samples and plot the fitted ellipsoids.

    Draws 700 uniform unit-cube samples, keeps those with
    ``(2 + prod(cos(theta / 2)))**5 > 150``, and clusters the survivors with
    ``ellipsoid_clustering`` (depth 7). Each ellipsoid outline is then plotted
    with the points assigned to it.
    """
    def log_likelihood(theta, **kwargs):
        return (2. + jnp.prod(jnp.cos(0.5 * theta)))**5

    prior_chain = PriorChain() \
        .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

    U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims, )))(
        random.split(random.PRNGKey(0), 700))
    theta = vmap(lambda u: prior_chain(u))(U)
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)

    select = lik > 150.
    print("Selecting", jnp.sum(select), "need", 18 * 3)
    # Fraction of prior samples that pass the cut, used as the V(S) estimate.
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))

    U = U[select, :]

    with disable_jit():
        cluster_id, ellipsoid_parameters = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
        )(random.PRNGKey(0), U, log_VS)
        mus, radii, rotations = ellipsoid_parameters
    # Colours are spread over the actual number of ellipsoids; the parameter
    # tuple always has length 3, which saturated the colormap for i > 3.
    num_clusters = mus.shape[0]

    # Unit circle; mapped through each ellipsoid's affine transform for plotting.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)

    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mus, radii, rotations)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ x
        plt.plot(y[0, :], y[1, :])
        in_cluster = cluster_id == i
        plt.scatter(U[in_cluster, 0],
                    U[in_cluster, 1],
                    c=jnp.atleast_2d(plt.cm.jet(i / num_clusters)))

    plt.show()
コード例 #6
0
def test_generic_kmeans():
    """Visual check of ``ellipsoid_clustering`` on a thresholded 2-D likelihood.

    ``data`` selects the toy problem:

    * ``'eggbox'``: 1000 unit-cube samples, kept where
      ``(2 + prod(cos(theta / 2)))**5 > 100``;
    * ``'shells'``: 40000 unit-cube samples, kept where the log-likelihood of
      two Gaussian shells (centres (0, -4) and (0, 4), radius 2, width 0.1)
      exceeds 1.

    The kept points are clustered (depth 7). A bounding ellipsoid is refit to
    every resulting cluster and plotted with its member points.

    Raises:
        ValueError: if ``data`` is not one of the supported problem names.
    """
    from jaxns.prior_transforms import PriorChain, UniformPrior
    from jax import vmap, disable_jit, jit
    import pylab as plt

    data = 'shells'
    if data == 'eggbox':
        def log_likelihood(theta, **kwargs):
            return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(random.split(random.PRNGKey(0), 1000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 100.
    elif data == 'shells':
        def log_likelihood(theta, **kwargs):
            def log_circ(theta, c, r, w):
                # Log of a Gaussian in the distance from centre c, peaked at radius r with width w.
                return -0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2 - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2))
            w1 = w2 = jnp.array(0.1)
            r1 = r2 = jnp.array(2.)
            c1 = jnp.array([0., -4.])
            c2 = jnp.array([0., 4.])
            return jnp.logaddexp(log_circ(theta, c1, r1, w1), log_circ(theta, c2, r2, w2))

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=-12. * jnp.ones(2), high=12. * jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(random.split(random.PRNGKey(0), 40000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 1.
    else:
        # Without this, an unknown name surfaced later as a NameError on U/select.
        raise ValueError("Unknown data set {!r}".format(data))

    print("Selecting", jnp.sum(select))
    # Fraction of prior samples that pass the cut, used as the V(S) estimate.
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))

    points = U[select, :]
    sc = plt.scatter(U[:, 0], U[:, 1], c=jnp.exp(lik))
    plt.colorbar(sc)
    plt.show()
    # Alternatives previously tried at this point: generic_kmeans with
    # method='ellipsoid' / 'mahalanobis' / 'euclidean', and hierarchical_clustering.
    with disable_jit():
        cluster_id, _ = jit(
            lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
        )(random.PRNGKey(0), points, log_VS)
        # Cluster ids are treated as 0..K-1 below; K is one past the largest id.
        K = int(jnp.max(cluster_id) + 1)

    # Refit one bounding ellipsoid per cluster for plotting.
    mus, C = vmap(lambda k: bounding_ellipsoid(points, cluster_id == k))(jnp.arange(K))
    radii, rotations = vmap(ellipsoid_params)(C)

    # Unit circle; mapped through each ellipsoid's affine transform for plotting.
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)

    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mus, radii, rotations)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ x
        plt.plot(y[0, :], y[1, :], c=plt.cm.jet(i / K))
        in_cluster = cluster_id == i
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1], c=jnp.atleast_2d(plt.cm.jet(i / K)))
    plt.xlim(-1, 2)
    plt.ylim(-1, 2)
    plt.show()