Example #1
    def from_network(cls,
                     net: Network,
                     prng_key: PRNGKey,
                     record_stride: int = 1,
                     props: dict = {}) -> "ProcessState":
        s = cls(**props)
        s["n"] = net.n
        s["m"] = net.m
        s["step"] = 0
        s["prng_key"] = prng_key

        s["edge.i"] = jnp.arange(s.m, dtype=jnp.int32)
        s["edge.src"] = net.edges[:, 0]
        s["edge.tgt"] = net.edges[:, 1]
        s["edge.weight"] = jnp.ones(net.m, dtype=jnp.float32)
        s["edge.active"] = jnp.ones(net.m, dtype=jnp.bool_)

        s["node.i"] = jnp.arange(s.n, dtype=jnp.int32)
        s["node.in_deg"] = jnp.bincount(s.edge["tgt"], length=net.n)
        s["node.out_deg"] = jnp.bincount(s.edge["src"], length=net.n)
        s["node.deg"] = s.node["in_deg"] + s.node["out_deg"]
        s["node.weight"] = jnp.ones(net.n, dtype=jnp.float32)
        s["node.active"] = jnp.ones(net.n, dtype=jnp.bool_)

        s._records = ProcessRecords(stride=record_stride)
        s._network = net
        s._assert_shapes()
        return s
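
The degree fields above are plain jnp.bincount calls over the edge endpoint columns, with length=net.n so the result always has one entry per node. A minimal standalone sketch of the same pattern (the edge list below is made up):

import jax.numpy as jnp

# Made-up edge list: (src, tgt) pairs over 3 nodes.
edges = jnp.array([[0, 1], [0, 2], [2, 1]])
n = 3
in_deg = jnp.bincount(edges[:, 1], length=n)   # [0, 2, 1]
out_deg = jnp.bincount(edges[:, 0], length=n)  # [2, 0, 1]
deg = in_deg + out_deg                         # [2, 2, 2]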
Example #2
def radial_profile(data):
    """
  Compute the radial profile of 2d image
  :param data: 2d image
  :return: radial profile
  """
    center = data.shape[0] / 2
    y, x = jnp.indices((data.shape))
    r = jnp.sqrt((x - center)**2 + (y - center)**2)
    r = r.astype('int32')

    tbin = jnp.bincount(r.ravel(), data.ravel())
    nr = jnp.bincount(r.ravel())
    radialprofile = tbin / nr
    return radialprofile
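
Note that both bincount calls above omit length, so the number of radial bins is inferred from the largest integer radius; this works eagerly but would fail under jit, where jnp.bincount needs a static length. A hypothetical smoke test (the image values are made up):

import jax.numpy as jnp

img = jnp.arange(64.0).reshape(8, 8)
profile = radial_profile(img)
print(profile.shape)  # one bin per integer radius occurring in the 8x8 grid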
Example #3
def test_ellipsoid_clustering():
    import pylab as plt
    from jax import disable_jit, jit
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    # plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(random.PRNGKey(0), points, 4, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        mu, radii, rotation = ellipsoid_parameters
        print(mu, radii, rotation, jnp.bincount(cluster_id, minlength=0, length=4))

    for i, (mu, radii, rotation) in enumerate(zip(mu, radii, rotation)):
        y = mu[:, None] + rotation @ jnp.diag(radii) @ x
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        plt.scatter(points[mask, 0], points[mask, 1], c=plt.cm.jet(i / len(ellipsoid_parameters)))

    plt.show()
Example #4
def main(args):
    annotators, annotations = get_data()
    model = NAME_TO_MODEL[args.model]
    data = ((annotations, ) if model in [multinomial, item_difficulty] else
            (annotators, annotations))

    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(random.PRNGKey(0), *data)
    mcmc.print_summary()

    posterior_samples = mcmc.get_samples()
    predictive = Predictive(model, posterior_samples, infer_discrete=True)
    discrete_samples = predictive(random.PRNGKey(1), *data)

    item_class = vmap(lambda x: jnp.bincount(x, length=4),
                      in_axes=1)(discrete_samples["c"].squeeze(-1))
    print("Histogram of the predicted class of each item:")
    row_format = "{:>10}" * 5
    print(row_format.format("", *["c={}".format(i) for i in range(4)]))
    for i, row in enumerate(item_class):
        print(row_format.format(f"item[{i}]", *row))
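
The vmap(lambda x: jnp.bincount(x, length=4), in_axes=1) call above builds one class histogram per item by mapping bincount over the columns of the posterior samples. A small sketch with made-up labels:

import jax.numpy as jnp
from jax import vmap

# Rows are posterior draws, columns are items; length=4 fixes the class count.
samples = jnp.array([[0, 1, 3],
                     [0, 1, 3],
                     [1, 1, 2]])
hist = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(samples)
print(hist)  # shape (3, 4): one histogram per item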
Example #5
def compute_entropy(max_indices):
    max_index_probabilities = jnp.bincount(
        max_indices, minlength=num_actions,
        length=num_actions) / len(max_indices)
    entropy = -jnp.sum((max_index_probabilities + LOG_EPSILON) *
                       jnp.log(max_index_probabilities + LOG_EPSILON))
    return entropy
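
num_actions and LOG_EPSILON come from the snippet's enclosing scope. A self-contained sketch of the same computation, with illustrative values for both:

import jax.numpy as jnp

num_actions = 4     # illustrative; defined in the enclosing scope of the snippet
LOG_EPSILON = 1e-8  # illustrative; defined in the enclosing scope of the snippet

max_indices = jnp.array([0, 1, 1, 3, 1, 0])
probs = jnp.bincount(max_indices, length=num_actions) / len(max_indices)
entropy = -jnp.sum((probs + LOG_EPSILON) * jnp.log(probs + LOG_EPSILON))
print(entropy)  # entropy of the empirical distribution over greedy actions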
Example #6
def _bincount(
        arr,
        weights=None,
        minlength=None,
        maxlength=None,  # pylint: disable=unused-argument
        dtype=tf.int32,
        name=None):  # pylint: disable=unused-argument
    return np.bincount(arr, weights,
                       minlength).astype(utils.numpy_dtype(dtype))
Example #7
def test_categorical(shape=(1000, ), num_classes=5):
    key = jr.PRNGKey(time.time_ns())
    diri = dists.Dirichlet(np.ones(num_classes))
    data = jr.choice(key, num_classes, shape=shape)
    cate = dists.Categorical.fit(
        data,
        prior=diri,
    )
    assert np.allclose(np.bincount(data, minlength=5) / len(data),
                       cate.probs_parameter(),
                       atol=1e-6)
Example #8
    def update_metric_state(cluster_id):
        num_k = jnp.bincount(cluster_id, weights, minlength=0, length=K)
        if method == 'euclidean':

            def get_mu(k):
                weights = (cluster_id == k) & mask
                mu = jnp.average(points, weights=weights, axis=0)
                mu = jnp.where(num_k[k] == 0, 0., mu)
                return mu

            mu = vmap(get_mu)(jnp.arange(K))
            return MetricState(cluster_centers=mu,
                               num_k=num_k,
                               C=None,
                               radii=None)
        if method == 'mahalanobis':

            def get_mu_and_C(k):
                weights = (cluster_id == k) & mask
                mu = jnp.average(points, weights=weights, axis=0)
                dist = points - mu
                Cov = jnp.average(dist[:, :, None] * dist[:, None, :],
                                  weights=weights,
                                  axis=0)
                C = jnp.linalg.pinv(Cov)
                mu = jnp.where(num_k[k] == 0, 0., mu)
                C = jnp.where(num_k[k] < D + 1, 0., C)
                return mu, C

            mu, C = vmap(get_mu_and_C)(jnp.arange(K))
            return MetricState(cluster_centers=mu,
                               num_k=num_k,
                               C=C,
                               radii=None)
        if method == 'ellipsoid':

            def get_mu_and_C_radii(k):
                weights = (cluster_id == k) & mask
                mu = jnp.average(points, weights=weights, axis=0)
                dist = points - mu
                Cov = jnp.average(dist[:, :, None] * dist[:, None, :],
                                  weights=weights,
                                  axis=0)
                C = jnp.linalg.pinv(Cov)
                mu = jnp.where(num_k[k] == 0, 0., mu)
                C = jnp.where(num_k[k] < D + 1, 0., C)
                radii, rotation = ellipsoid_params(C)
                return mu, C, radii

            mu, C, radii = vmap(get_mu_and_C_radii)(jnp.arange(K))
            return MetricState(cluster_centers=mu,
                               num_k=num_k,
                               C=C,
                               radii=radii)
Example #9
        def body(state):
            (done, i, centers, cluster_id) = state

            # [M, max_K]
            new_centers = vmap(lambda coords:
                               jnp.bincount(cluster_id, weights=coords, minlength=max_K, length=max_K))(points.T)
            # max_K, M
            new_centers = new_centers.T
            # max_K
            num_per_cluster = jnp.bincount(cluster_id, minlength=max_K, length=max_K)
            # max_K, M
            new_centers = jnp.where(num_per_cluster[:, None] == 0.,
                                    jnp.zeros_like(new_centers),
                                    new_centers / num_per_cluster[:, None])
            # N
            new_cluster_id = masked_cluster_id(points, new_centers, K)

            done = jnp.all(new_cluster_id == cluster_id)

            return (done, i + 1, new_centers, new_cluster_id)
Example #10
def init_multi_ellipsoid_sampler_state(key, live_points_U, depth, log_X):
    cluster_id, (mu, radii,
                 rotation) = ellipsoid_clustering(key, live_points_U, depth,
                                                  log_X)
    num_k = jnp.bincount(cluster_id, minlength=0, length=mu.shape[0])
    return MultiEllipsoidSamplerState(cluster_id=cluster_id,
                                      mu=mu,
                                      radii=radii,
                                      rotation=rotation,
                                      num_k=num_k,
                                      num_fev_ma=jnp.asarray(1.))
Example #11
def init_slice_sampler_state(key, live_points_U, depth, log_X, num_slices):
    cluster_id, (mu, radii,
                 rotation) = ellipsoid_clustering(key, live_points_U, depth,
                                                  log_X)
    num_k = jnp.bincount(cluster_id, minlength=0, length=mu.shape[0])
    return SliceSamplerState(
        cluster_id=cluster_id,
        mu=mu,
        radii=radii,
        rotation=rotation,
        num_k=num_k,
        num_fev_ma=jnp.asarray(num_slices * live_points_U.shape[1] + 2.))
Example #12
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    assert mat.ndim == 2
    m = mat.shape[0]

    row, col = jnp.nonzero(mat, size=nnz)
    data = mat[row, col]

    true_nonzeros = jnp.arange(nnz) < (mat != 0).sum()
    data = jnp.where(true_nonzeros, data, 0)
    row = jnp.where(true_nonzeros, row, m)
    indices = col.astype(index_dtype)
    indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
        jnp.cumsum(jnp.bincount(row, length=m)))
    return data, indices, indptr
Example #13
    def _p_x_min(self, y_fantasized):
        """Estimate a probablity mass function over the x-location of global min.

    Args:
      y_fantasized: (n, m) shaped array of m fantasized y values over a common
      set of n x-locations.

    Returns:
      Estimated (n,) shaped array of pmf of the global min over x-location where
      the domain of the pmf is the common set of previous x-locations.

    """
        counts = jnp.bincount(jnp.argmin(y_fantasized, axis=0),
                              length=y_fantasized.shape[0])
        return counts / jnp.sum(counts)
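
A concrete illustration of the counting step (the fantasized values below are made up): with 3 candidate x-locations and 5 fantasized draws, the pmf is just the normalized histogram of the per-draw argmin indices.

import jax.numpy as jnp

# 3 x-locations (rows) x 5 fantasized draws (columns), made-up values.
y_fantasized = jnp.array([[0.2, 0.5, 0.1, 0.3, 0.1],
                          [0.4, 0.1, 0.6, 0.2, 0.5],
                          [0.9, 0.8, 0.7, 0.9, 0.6]])
counts = jnp.bincount(jnp.argmin(y_fantasized, axis=0),
                      length=y_fantasized.shape[0])
print(counts / jnp.sum(counts))  # pmf over the 3 x-locations: [0.6, 0.4, 0.0]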
Example #14
def kmeans_step(
    val: Tuple[Array, Array, float, Optional[float], int],
    n_splits: int,
    parallel_computation: bool = False,
) -> Tuple[Array, Array, float, float, int]:
  """Perform single K-means step.

  Standard K-means step. Assigns observations to nearest cluster, then updates
  cluster centroids as mean of assigned observations. Inputs are packed into
  'val' to facilitate while loop condition for K-means.

  Args:
    val: tuple of [n_clusters, dim] cluster centroids. [n_observations, dim]
      observations. prev_dist, mean distance between observations and closest
      cluster centroid in previous K-means step. prev_2_dist, distance from two
      steps prior. step, idx of current step.
    n_splits: number of splits for computing assignments.
    parallel_computation: if true, assumes this is run inside a pmap with an
      'observations' axis.

  Returns:
    new_centroids: [n_clusters, dim] new cluster centroids.
    observations: [n_observations, dim].
    mean_dist: mean distance between observations and closest cluster centroid.
    prev_dist: mean distance for previous K-means step.
    step: idx of next step.
  """
  centroids, observations, prev_dist, _, step = val
  assignments, min_dist = compute_assignments(centroids, observations, n_splits)

  mean_dist = jnp.mean(min_dist)

  # Compute new cluster centroids as average of observations in cluster
  cluster_sums = jnp.zeros(centroids.shape).at[assignments].add(observations)
  counts = jnp.bincount(assignments, length=centroids.shape[0])

  if parallel_computation:
    cluster_sums = jax.lax.psum(cluster_sums, axis_name='observations')
    counts = jax.lax.psum(counts, axis_name='observations')
    mean_dist = jax.lax.pmean(mean_dist, axis_name='observations')

  new_centroids = cluster_sums / counts[:, None].clip(a_min=1.)

  hcb.id_print(step, what='step', tap_with_device=True)
  hcb.id_print(mean_dist - prev_dist, what='dist_diff', tap_with_device=True)
  step += 1

  return new_centroids, observations, mean_dist, prev_dist, step
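
The centroid update above follows a scatter-add plus bincount pattern: sum the observations assigned to each cluster, count the cluster members, and guard against empty clusters before dividing. A minimal sketch (made-up data; jnp.maximum plays the role of the clip in the snippet):

import jax.numpy as jnp

observations = jnp.array([[0., 0.], [1., 1.], [4., 4.]])
assignments = jnp.array([0, 0, 1])
n_clusters = 3  # cluster 2 is deliberately left empty
cluster_sums = jnp.zeros((n_clusters, 2)).at[assignments].add(observations)
counts = jnp.bincount(assignments, length=n_clusters)
new_centroids = cluster_sums / jnp.maximum(counts[:, None], 1)
# new_centroids -> [[0.5, 0.5], [4., 4.], [0., 0.]]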
Example #15
    def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
        if k > 0:
            diag_size = min(N, M - k)
        else:
            diag_size = min(N + k, M)

        if diag_size <= 0:
            # if k is out of range, return an empty matrix.
            return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)

        k = jnp.asarray(k)
        data = jnp.ones(diag_size, dtype=dtype)
        idx = jnp.arange(diag_size, dtype=index_dtype)
        zero = _const(idx, 0)
        k = _const(idx, k)
        col = lax.add(idx, lax.cond(k <= 0, lambda: zero, lambda: k))
        indices = col.astype(index_dtype)
        # TODO(jakevdp): this can be done more efficiently.
        row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k))
        indptr = jnp.zeros(N + 1, dtype=index_dtype).at[1:].set(
            jnp.cumsum(jnp.bincount(row, length=N).astype(index_dtype)))
        return cls((data, indices, indptr), shape=(N, M))
Example #16
    def normalize_f(self, f):
        sums = np.bincount(self.filter_ind, weights=np.square(f))
        sums = np.repeat(sums, self.nPix)
        f = f / np.sqrt(sums)
        return f
Example #17
def bincount(x, weights=None, minlength=0):
  if isinstance(x, JaxArray): x = x.value
  if isinstance(weights, JaxArray): weights = weights.value
  return JaxArray(jnp.bincount(x, weights=weights, minlength=minlength))
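
When a wrapper like this is meant to run under jit, a static length has to be threaded through, since the output shape of jnp.bincount must be known at trace time. A minimal sketch (bincount_fixed is an illustrative name, not part of the library above):

import jax
import jax.numpy as jnp

def bincount_fixed(x, weights=None, *, length):
    # A static `length` fixes the output size, which makes the call jit-compatible.
    return jnp.bincount(x, weights=weights, length=length)

counts = jax.jit(bincount_fixed, static_argnames="length")(
    jnp.array([0, 2, 2, 3]), length=5)
print(counts)  # [1 0 2 1 0]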
Example #18
    def vec_mag_one_fun_core(self, x):
        return np.bincount(self.filter_ind, weights=np.square(x)) - 1
Example #19
def compute_log_policy(max_indices):
    max_index_probabilities = jnp.bincount(
        max_indices, minlength=num_actions,
        length=num_actions) / len(max_indices)
    log_policy = jnp.log(max_index_probabilities + LOG_EPSILON)
    return log_policy
Example #20
def mode(arr, axis=0, max_value=250):
    return jnp.apply_along_axis(
        lambda x: jnp.bincount(x, length=max_value).argmax(),
        axis=axis,
        arr=arr.astype(jnp.int32))
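
An illustrative call (the input values are made up): max_value acts as the static histogram length, so it must exceed the largest value in arr.

import jax.numpy as jnp

arr = jnp.array([[1, 2, 2],
                 [3, 2, 1],
                 [1, 2, 1]])
print(mode(arr, axis=0, max_value=4))  # column-wise mode: [1 2 1]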
Example #21
def from_dense(dense: jnp.ndarray):
    data, row, col = coo.from_dense(dense)
    row_lengths = jnp.bincount(row, length=dense.shape[0])
    indptr = jnp.concatenate((jnp.zeros(
        (1, ), dtype=row_lengths.dtype), jnp.cumsum(row_lengths)))
    return data, indptr, col
Example #22
def _coo_to_csr(row, nrows):
    indptr = jnp.zeros(nrows + 1, row.dtype)
    return indptr.at[1:].set(jnp.cumsum(jnp.bincount(row, length=nrows)))
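
The same bincount-plus-cumsum trick appears in Examples #12, #15, #21 and #24: counting the entries per row and accumulating the counts yields the CSR row-pointer array. A small concrete check (made-up indices):

import jax.numpy as jnp

row = jnp.array([0, 0, 1, 3], dtype=jnp.int32)  # COO row indices of a 4-row matrix
nrows = 4
indptr = jnp.zeros(nrows + 1, row.dtype).at[1:].set(
    jnp.cumsum(jnp.bincount(row, length=nrows)))
print(indptr)  # [0 2 3 3 4]: row 0 holds entries 0:2, row 2 is empty, etc.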
Example #23
SIZE = len(data)
print("Data Size", SIZE)
# Create tables
key = jax.random.PRNGKey(0)
data = np.array(data, dtype=np.int32)
docs = np.array(doc, dtype=np.int32)
topic_token = jax.random.randint(key, (SIZE, ), 0, TOPICS, dtype=np.int32)

topic_word = jax.ops.index_add(
    np.zeros((TOPICS, VOCAB), dtype=np.int32),
    jax.ops.index[topic_token.reshape(-1),
                  data.reshape(-1)], 1)
topic_document = jax.ops.index_add(
    np.zeros((DOCUMENTS, TOPICS), dtype=np.int32),
    jax.ops.index[docs, topic_token], 1)
tokens_doc = np.bincount(docs, length=DOCUMENTS)

# Main code
ALPHA, BETA = 0.1, 0.01


def token_loop(state, scanned):
    topic_word, topic_document, topic_count = state
    topic_token, data, doc, key = scanned

    local_tw = topic_word[:, data].at[topic_token].add(-1)
    local_td = topic_document[doc].at[topic_token].add(-1)
    local_tc = topic_count.at[topic_token].add(-1)

    # Resample
    dist = ((local_tw + BETA) / (local_tc + VOCAB * BETA)) \
Example #24
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
    m = mat.shape[0]
    data, row, col = _coo_fromdense_impl(mat, nnz=nnz, index_dtype=index_dtype)
    indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
        jnp.cumsum(jnp.bincount(row, length=m)))
    return data, col, indptr
Example #25
def _nonzero_indices(x, N):
    """Find min(N, x.size) indices of nonzero elements of x."""
    return jnp.cumsum(jnp.bincount(jnp.cumsum(x != 0), length=min(N, x.size)))
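
A quick check of the trick with made-up input: for x = [2, 0, 0, 4, 1] the running count jnp.cumsum(x != 0) is [1, 1, 1, 2, 3], its bincount over length 3 is [0, 3, 1], and the final cumsum [0, 3, 4] is exactly the positions of the first three nonzero elements.

import jax.numpy as jnp

x = jnp.array([2, 0, 0, 4, 1])
print(_nonzero_indices(x, 3))  # [0 3 4]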