Example 1
def frank_wolfe_unsigned_projection(element: FiniteVec,
                                    solution: FiniteVec = None,
                                    num_samples: np.int32 = 100):
    assert (len(element) == 1)
    if solution is None:
        solution = FiniteVec(element.k,
                             element.inspace_points[:1, :],
                             np.zeros(1),
                             points_per_split=1)
    for k in range(num_samples):

        def cost(x):
            x = x.reshape((1, -1))
            return (solution(x) - element(x)).sum()

        g_cost = grad(cost)
        idx = randint(0, element.points_per_split - 1)
        cand = element.inspace_points[idx:idx + 1, :]
        res = minimize(__casted_output(cost),
                       cand,
                       jac=__casted_output(g_cost))
        solution.inspace_points = np.vstack(
            [solution.inspace_points, res["x"]])
        gamma_k = 1. / (k + 1)
        solution.prefactors = np.hstack([(1 - gamma_k) * solution.prefactors,
                                         gamma_k])
        solution.points_per_split = solution.points_per_split + 1
    return solution
Example 2
def kmeans(key, points, mask, K=2):
    """
    Perform kmeans clustering with Euclidean metric.

    Args:
        key: PRNG key for the random initial cluster assignment.
        points: [N, D]
        mask: [N] bool
        K: int

    Returns: cluster_id [N], centers [K, D]

    """
    N, D = points.shape

    def body(state):
        (i, done, old_cluster_id, centers) = state
        new_centers = vmap(lambda k: jnp.average(
            points, weights=(old_cluster_id == k) & mask, axis=0))(
                jnp.arange(K))
        dx = points - new_centers[:, None, :]  # K, N, D
        squared_norm = jnp.sum(jnp.square(dx), axis=-1)  # K, N
        new_cluster_id = jnp.argmin(squared_norm, axis=0)  # N
        done = jnp.all(new_cluster_id == old_cluster_id)
        # print("kmeans reassigns", jnp.sum(old_cluster_id!=new_cluster_id))
        return i + 1, done, new_cluster_id, new_centers

    do_kmeans = jnp.sum(mask) > K
    i, _, cluster_id, centers = while_loop(
        lambda state: ~state[1], body,
        (jnp.array(0), ~do_kmeans,
         # the initial assignment must cover all K clusters, not just {0, 1}
         random.randint(key, shape=(N, ), minval=0, maxval=K), jnp.zeros(
             (K, D))))
    return cluster_id, centers
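A minimal usage sketch; the imports and the call below are assumptions added for illustration, not part of the original snippet:

import jax.numpy as jnp
from jax import random, vmap
from jax.lax import while_loop

key = random.PRNGKey(0)
points = random.normal(key, (100, 2))   # [N, D]
mask = jnp.ones(100, dtype=bool)        # every point participates
cluster_id, centers = kmeans(key, points, mask, K=2)
print(cluster_id.shape, centers.shape)  # (100,) (2, 2)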
Example 3
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, constraints._Real):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
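A hedged sanity check, assuming numpyro's public constraints module and that a constraint is callable as a predicate: every generated value should violate the constraint.

import jax.numpy as jnp
from jax import random
from numpyro.distributions import constraints

bad = gen_values_outside_bounds(constraints.positive, size=(5,),
                                key=random.PRNGKey(0))
assert not jnp.any(constraints.positive(bad))  # all samples are <= 0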
Example 4
def test_taylor_proxy_norm(subsample_size):
    data_key, tr_key, rng_key = random.split(random.PRNGKey(0), 3)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = .1

    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (100, ))
    n, _ = data.shape

    def model(data, subsample_size):
        mean = numpyro.sample(
            'mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
        with numpyro.plate('data',
                           data.shape[0],
                           subsample_size=subsample_size,
                           dim=-2) as idx:
            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])

    def log_prob_fn(params):
        return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1)

    log_prob = log_prob_fn(ref_params)
    log_norm_jac = jacrev(log_prob_fn)(ref_params)
    log_norm_hessian = hessian(log_prob_fn)(ref_params)

    tr = numpyro.handlers.trace(numpyro.handlers.seed(model,
                                                      tr_key)).get_trace(
                                                          data, subsample_size)
    plate_sizes = {'data': (n, subsample_size)}

    proxy_constructor = HMCECS.taylor_proxy({'mean': ref_params})
    proxy_fn, gibbs_init, gibbs_update = proxy_constructor(
        tr, plate_sizes, model, (data, subsample_size), {})

    def taylor_expand_2nd_order(idx, pos):
        return log_prob[idx] + (
            log_norm_jac[idx] @ pos) + .5 * (pos @ log_norm_hessian[idx]) @ pos

    def taylor_expand_2nd_order_sum(pos):
        return log_prob.sum() + log_norm_jac.sum(
            0) @ pos + .5 * pos @ log_norm_hessian.sum(0) @ pos

    for _ in range(5):
        split_key, perturb_key, rng_key = random.split(rng_key, 3)
        perturb_params = ref_params + dist.Normal(.1, 0.1).sample(
            perturb_key, ref_params.shape)
        subsample_idx = random.randint(rng_key, (subsample_size, ), 0, n)
        gibbs_site = {'data': subsample_idx}
        proxy_state = gibbs_init(None, gibbs_site)
        actual_proxy_sum, actual_proxy_sub = proxy_fn(
            {'data': perturb_params}, ['data'], proxy_state)
        assert_allclose(actual_proxy_sub['data'],
                        taylor_expand_2nd_order(subsample_idx,
                                                perturb_params - ref_params),
                        rtol=1e-5)
        assert_allclose(actual_proxy_sum['data'],
                        taylor_expand_2nd_order_sum(perturb_params -
                                                    ref_params),
                        rtol=1e-5)
Example 5
def run_update(batch_idx, opt_state):
    kl_warmup = kl_warmup_fun(batch_idx)
    didxs = random.randint(next(dkeyg), [lfads_hps['batch_size']], 0,
                           train_data.shape[0])
    x_bxt = train_data[didxs].astype(np.float32)
    opt_state = update_fun(batch_idx, opt_state, lfads_hps, lfads_opt_hps,
                           next(fkeyg), x_bxt, kl_warmup)
    return opt_state
Example 6
def sample_batch(self, batch_size):
    """Sample past experience."""
    self.rng, rng_input = random.split(self.rng)
    indexes = random.randint(rng_input,
                             shape=(batch_size, ),
                             minval=0,
                             maxval=self.size)
    return self.data_points[indexes]
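For context, a minimal buffer class around this method; the field names rng, size, and data_points come from the snippet above, the rest is an assumed sketch.

import jax.numpy as jnp
from jax import random

class ReplayBuffer:
    def __init__(self, data_points, seed=0):
        self.data_points = jnp.asarray(data_points)
        self.size = self.data_points.shape[0]
        self.rng = random.PRNGKey(seed)

    def sample_batch(self, batch_size):
        """Sample past experience."""
        self.rng, rng_input = random.split(self.rng)
        indexes = random.randint(rng_input, shape=(batch_size, ),
                                 minval=0, maxval=self.size)
        return self.data_points[indexes]

buffer = ReplayBuffer(jnp.arange(10.0))
print(buffer.sample_batch(4))  # four experiences drawn with replacement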
Example 7
def random_integer(shape, dtype, minval, maxval, seed):
  """Generates a sample from uniform distribution over [minval, maxval)."""
  return random.randint(
      shape=tuple(shape),
      dtype=dtype,
      minval=minval,
      maxval=maxval,
      key=make_tensor_seed(seed))
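make_tensor_seed comes from the surrounding module; assuming it maps an integer seed to a PRNG key, a call would look like:

import jax.numpy as jnp

draws = random_integer(shape=[3, 2], dtype=jnp.int32,
                       minval=0, maxval=10, seed=42)
# draws: int32 array of shape (3, 2) with values uniform over [0, 10)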
Example 8
def get_action(self, x_t):
    """Return the control action for state x_t."""
    self.T += 1
    eta_t = 1 - 2 * random.randint(
        generate_key(), minval=0, maxval=2, shape=(self.m, ))
    self.eta.append(eta_t)
    self.x_history.append(np.squeeze(x_t, axis=1))
    return -self.K @ x_t + np.expand_dims(eta_t, axis=1)
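The 1 - 2 * randint(..., minval=0, maxval=2, ...) idiom maps {0, 1} draws to Rademacher (+1/-1) noise; a standalone sketch:

from jax import random

eta = 1 - 2 * random.randint(random.PRNGKey(0), minval=0, maxval=2, shape=(5, ))
print(eta)  # each entry is +1 or -1 with probability 1/2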
Example 9
def train_op(net: Module, opt: Optimizer, x, y, key,
             hyperparams: ServerHyperParams):
    index = random.randint(key,
                           shape=(hyperparams.oracle_batch_size, ),
                           minval=0,
                           maxval=x.shape[0])
    v, g = vg(net, opt.target, x[index], y[index])
    return v, opt.apply_gradient(g)
Example 10
def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None):
    """
    Merges subposteriors following consensus Monte Carlo algorithm.

    **References:**

    1. *Bayes and big data: The consensus Monte Carlo algorithm*,
       Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman,
       Edward I. George, Robert E. McCulloch

    :param list subposteriors: a list in which each element is a collection of samples.
    :param int num_draws: number of draws from the merged posterior.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :param jax.random.PRNGKey rng_key: source of the randomness, defaults to `jax.random.PRNGKey(0)`.
    :return: if `num_draws` is None, merges subposteriors without resampling; otherwise, returns
        a collection of `num_draws` samples with the same data structure as each subposterior.
    """
    # stack subposteriors
    joined_subposteriors = tree_multimap(lambda *args: jnp.stack(args), *subposteriors)
    # shape of joined_subposteriors: n_subs x n_samples x sample_shape
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(
        joined_subposteriors
    )

    if num_draws is not None:
        rng_key = random.PRNGKey(0) if rng_key is None else rng_key
        # randomly gets num_draws from subposteriors
        n_subs = len(subposteriors)
        n_samples = tree_flatten(subposteriors[0])[0][0].shape[0]
        # shape of draw_idxs: n_subs x num_draws x sample_shape
        draw_idxs = random.randint(
            rng_key, shape=(n_subs, num_draws), minval=0, maxval=n_samples
        )
        joined_subposteriors = vmap(lambda x, idx: x[idx])(
            joined_subposteriors, draw_idxs
        )

    if diagonal:
        # compute weights for each subposterior (ref: Section 3.1 of [1])
        weights = vmap(lambda x: 1 / jnp.var(x, ddof=1, axis=0))(joined_subposteriors)
        normalized_weights = weights / jnp.sum(weights, axis=0)
        # get weighted samples
        samples_flat = jnp.einsum(
            "ij,ikj->kj", normalized_weights, joined_subposteriors
        )
    else:
        weights = vmap(lambda x: jnp.linalg.inv(jnp.cov(x.T)))(joined_subposteriors)
        normalized_weights = jnp.matmul(
            jnp.linalg.inv(jnp.sum(weights, axis=0)), weights
        )
        samples_flat = jnp.einsum(
            "ijk,ilk->lj", normalized_weights, joined_subposteriors
        )

    # unravel_fn acts on 1 sample of a subposterior
    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
    return vmap(lambda x: unravel_fn(x))(samples_flat)
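A small usage sketch; the two subposteriors here are synthetic, and the imports are assumed:

import jax.numpy as jnp
from jax import random

k1, k2 = random.split(random.PRNGKey(0))
sub1 = {'mu': random.normal(k1, (1000, 2))}
sub2 = {'mu': random.normal(k2, (1000, 2)) + 1.0}
merged = consensus([sub1, sub2], num_draws=500)
print(merged['mu'].shape)  # (500, 2)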
Example 11
def flip_spin(grid: np.DeviceArray, n_x: int, n_y: int,
              x_subkey: np.DeviceArray,
              y_subkey: np.DeviceArray) -> np.DeviceArray:
    """Flip the spin of a single element on the grid

    :param grid: grid with spins
    :param n_x: grid's first dimension
    :param n_y: grid's second dimension
    :param x_subkey: subkey for the random x coordinate of the flip
    :param y_subkey: subkey for the random y coordinate of the flip
    :return: grid with one flipped spin
    """

    x = random.randint(x_subkey, (1, ), 0, n_x)
    y = random.randint(y_subkey, (1, ), 0, n_y)
    mask = index_update(np.ones_like(grid), index[x, y], -1)
    flipped_grid = grid * mask
    return flipped_grid
Example 12
def testRandomBroadcast(self):
  """Issue 4033"""
  # test for broadcast issue in https://github.com/google/jax/issues/4033
  key = random.PRNGKey(0)
  shape = (10, 2)
  x = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
  assert x.shape == shape
  x = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
  assert x.shape == shape
Example 13
    def initializeMeans(self, X):
        n = X.shape[0]
        # choose n_clusters random data points as the initial means
        idx = random.randint(self.key, (self.n_clusters, ), 0, n)
        means = X[idx]
        return means
Example 14
def distort_image_with_randaugment(image,
                                   num_layers,
                                   magnitude,
                                   random_key,
                                   cutout_const=40,
                                   translate_const=50.0,
                                   default_replace_value=None,
                                   available_ops=DEFAULT_OPS,
                                   op_probs=DEFAULT_PROBS,
                                   join_transforms=False):
    """Applies the RandAugment policy to `image`.

    RandAugment is from the paper https://arxiv.org/abs/1909.13719,

    Args:
        image: `Tensor` of shape [height, width, 3] representing an image.
        num_layers: Integer, the number of augmentation transformations to apply
          sequentially to an image. Represented as (N) in the paper. Usually best
          values will be in the range [1, 3].
        magnitude: Integer, shared magnitude across all augmentation operations.
          Represented as (M) in the paper. Usually best values are in the range
          [5, 30].
        random_key: JAX PRNG key used to sample operations and their parameters.
        cutout_const: maximum cutout size (int).
        translate_const: maximum translation amount (float).
        default_replace_value: default replacement value for pixels outside of the image.
        available_ops: available operations.
        op_probs: probabilities of the operations.
        join_transforms: if True, reduce the geometric transforms to a single one
          and apply it once; more efficient but simpler.

    Returns:
        The augmented version of `image`.
    """

    geometric_transforms = jnp.identity(4)

    for_i_args = (image, geometric_transforms, random_key, available_ops, op_probs,
                  magnitude, cutout_const, translate_const, join_transforms, default_replace_value)

    if DEBUG:  # un-jitted
        for i in range(num_layers):
            for_i_args = _randaugment_inner_for_loop(i, for_i_args)
    else:  # jitted
        for_i_args = jax.lax.fori_loop(0, num_layers, _randaugment_inner_for_loop, for_i_args)

    image, geometric_transforms = for_i_args[0], for_i_args[1]

    if join_transforms:
        replace_value = default_replace_value or random.randint(random_key,
                                                                [image.shape[-1]],
                                                                minval=0,
                                                                maxval=256)
        image = transforms.apply_transform(image, geometric_transforms, mask_value=replace_value)

    return image
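A hedged call sketch; DEFAULT_OPS and DEFAULT_PROBS are supplied by the surrounding module, and the image here is a dummy:

import jax.numpy as jnp
from jax import random

image = jnp.zeros((224, 224, 3), dtype=jnp.uint8)
augmented = distort_image_with_randaugment(image, num_layers=2, magnitude=10,
                                           random_key=random.PRNGKey(0))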
Example 15
    def sample(self, rng, batch_size):
        ind = random.randint(rng, (batch_size, ), 0, self.size)

        return BufferOutput(
            obs=jax.device_put(self.obs[ind]),
            action=jax.device_put(self.action[ind]),
            next_obs=jax.device_put(self.next_obs[ind]),
            reward=jax.device_put(self.reward[ind]),
            not_done=jax.device_put(self.not_done[ind]),
        )
Example 16
def body_fn(val, idx):
    i_p1 = size - idx
    i = i_p1 - 1
    j = random.randint(rng_keys[idx], (), 0, i_p1)
    val = ops.index_update(
        val,
        ops.index[[i, j], ],
        val[ops.index[[j, i], ]],
    )
    return val, None
Example 17
def test_sliced_distances(self):
    num_points = 100
    dim = 50
    key = random.PRNGKey(42)
    use_keys = random.split(key, num=3)
    indices1 = random.randint(use_keys[0],
                              shape=(int(1.5 * num_points), ),
                              minval=0,
                              maxval=num_points)
    indices2 = random.randint(use_keys[1],
                              shape=(int(1.5 * num_points), ),
                              minval=0,
                              maxval=num_points)
    inputs = random.normal(use_keys[2], shape=(num_points, dim))
    distance_fn = trimap.euclidean_dist
    dist_sliced = trimap.sliced_distances(indices1, indices2, inputs,
                                          distance_fn)
    dist_direct = distance_fn(inputs[indices1], inputs[indices2])
    npt.assert_equal(np.array(dist_sliced), np.array(dist_direct))
Example 18
    def sample(self, rng: PRNGSequence, batch_size: int):
        ind = random.randint(rng, (batch_size, ), 0, self.size)

        return (
            jax.device_put(self.state[ind]),
            jax.device_put(self.action[ind]),
            jax.device_put(self.next_state[ind]),
            jax.device_put(self.reward[ind]),
            jax.device_put(self.not_done[ind]),
        )
Example 19
def _update_block(rng_key, num_blocks, subsample_idx, plate_size):
    size, subsample_size = plate_size
    rng_key, subkey, block_key = random.split(rng_key, 3)
    block_size = (subsample_size - 1) // num_blocks + 1
    pad = block_size - (subsample_size - 1) % block_size - 1

    chosen_block = random.randint(block_key,
                                  shape=(),
                                  minval=0,
                                  maxval=num_blocks)
    new_idx = random.randint(subkey,
                             minval=0,
                             maxval=size,
                             shape=(block_size, ))
    subsample_idx_padded = jnp.pad(subsample_idx, (0, pad))
    start = chosen_block * block_size
    subsample_idx_padded = lax.dynamic_update_slice_in_dim(
        subsample_idx_padded, new_idx, start, 0)
    return rng_key, subsample_idx_padded[:subsample_size], pad, new_idx, start
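A worked example of the block arithmetic (imports assumed): with size=100, subsample_size=6, and num_blocks=3, block_size is (6 - 1) // 3 + 1 = 2 and pad is 0, so exactly one block of two indices gets resampled.

import jax.numpy as jnp
from jax import random

rng_key, new_subsample, pad, new_idx, start = _update_block(
    random.PRNGKey(0), num_blocks=3,
    subsample_idx=jnp.arange(6), plate_size=(100, 6))
# new_subsample matches jnp.arange(6) except for the two entries at `start`.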
Example 20
def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
    rng_key, rng_proposal = random.split(rng_key, 2)
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

    proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
    z_new_flat = ops.index_update(z_discrete_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_accept_ratio = pe - pe_new
    return rng_key, z_new, pe_new, log_accept_ratio
Example 21
    def sample_batch(self, batch_size):
        self.key = random.split(self.key)[0]
        idxs = random.randint(self.key, (batch_size, ), 0, len(self))
        batch = [self[idx] for idx in idxs]

        # Each item to be its own tensor of len batch_size
        b = list(zip(*batch))
        buf_mean = 0
        buf_std = 1
        return [(jnp.asarray(t) - buf_mean) / buf_std for t in b]
Example 22
def create_grid(n_x: int, n_y: int, random_seed: int) -> np.DeviceArray:
    """Create a grid randomly filled with -1 and +1

    :param n_x: grid's first dimension
    :param n_y: grid's second dimension
    :param random_seed: seed for random functions
    :return: grid of size (n_x, n_y)
    """
    key = random.PRNGKey(random_seed)
    return random.randint(key, (n_x, n_y), 0, 2) * 2 - 1
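Together with flip_spin from Example 11, a short end-to-end sketch (assumes `from jax import random` and the deprecated jax.ops index helpers used above):

grid = create_grid(8, 8, random_seed=0)   # 8x8 grid of +/-1 spins
key, kx, ky = random.split(random.PRNGKey(1), 3)
grid = flip_spin(grid, 8, 8, kx, ky)      # flip one randomly chosen spin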
Example 23
    def data_loader(batch_shape,
                    key=None,
                    start=None,
                    split='train',
                    return_if_at_end=False):
        assert (key is None) ^ (start is None)
        at_end = False

        # We have the option to choose the index of the images
        if (key is None):
            data_indices = file_indices
            n_files = batch_shape[-1]
            batch_idx = start + jnp.arange(n_files)

            # Trim the batch indices so that we don't exceed the size of the dataset
            batch_idx = batch_idx[batch_idx < validation_indices.shape[0]]
            batch_shape = batch_shape[:-1] + (batch_idx.shape[0], )

            # Check if we're at the end of the dataset
            if (batch_idx.shape[0] < n_files):
                at_end = True

            batch_idx = np.broadcast_to(batch_idx, batch_shape)
        else:
            if (split == 'train'):
                data_indices = train_indices
            elif (split == 'test'):
                data_indices = test_indices
            elif (split == 'validation'):
                data_indices = validation_indices
            else:
                raise ValueError("Invalid split name. Choose from 'train', 'test' or 'validation'.")

            batch_idx = random.randint(key,
                                       batch_shape,
                                       minval=0,
                                       maxval=data_indices.shape[0])
            batch_idx = data_indices[batch_idx]

        batch_files = all_files[np.array(batch_idx)]

        images = np.zeros(batch_shape + x_shape,
                          dtype=np.int32).reshape((-1, ) + x_shape)
        for k, path in enumerate(batch_files.ravel()):
            im = plt.imread(path, format='jpg')
            im = im[::strides[0], ::strides[1]][crop[0][0]:crop[0][1],
                                                crop[1][0]:crop[1][1]]
            im = im // quantize_factor

            images[k] = im

        ret = images.reshape(batch_shape + x_shape)
        if (return_if_at_end):
            return ret, at_end
        return ret
Example 24
def create_angle_vecs(angles, rkey):
    angle_rand_idx = random.randint(rkey, (ntrials, ), 0, np.size(angles))
    cos_angles = np.array(
        [np.cos(np.radians(angles[i])) for i in angle_rand_idx])
    sin_angles = np.array(
        [np.sin(np.radians(angles[i])) for i in angle_rand_idx])
    cos_angles = np.expand_dims(cos_angles, axis=0)
    cos_angles = np.expand_dims(cos_angles, axis=2)
    sin_angles = np.expand_dims(sin_angles, axis=0)
    sin_angles = np.expand_dims(sin_angles, axis=2)
    return cos_angles, sin_angles
Example 25
def _crop(key, image, hps):
  """Randomly shifts the window viewing the image."""
  pixpad = (hps.crop_num_pixels, hps.crop_num_pixels)
  zero = (0, 0)
  padded_image = jnp.pad(image, (pixpad, pixpad, zero),
                         mode='constant', constant_values=0.0)
  corner = random.randint(key, (2,), 0, 2 * hps.crop_num_pixels)
  corner = jnp.concatenate((corner, jnp.zeros((1,), jnp.int32)))
  cropped_image = lax.dynamic_slice(padded_image, corner, image.shape)

  return cropped_image
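A minimal call sketch; hps only needs a crop_num_pixels field, so a SimpleNamespace stands in for the original hyperparameter object:

from types import SimpleNamespace
import jax.numpy as jnp
from jax import random, lax

hps = SimpleNamespace(crop_num_pixels=4)
image = jnp.zeros((32, 32, 3))
shifted = _crop(random.PRNGKey(0), image, hps)
print(shifted.shape)  # (32, 32, 3), window shifted by up to 4 pixels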
Example 26
    def _sample(self, key, n_samps, factors, bits_to_fix=-1, values_to_fix=-1):
        """Generate samples from a distribution with a given set of factors.

        Parameters
        ----------
        key : jax.random.PRNGKey
            jax random number generator
        n_samps : int
            number of samples to generate
        factors : array_like
            factors of the distribution
        bits_to_fix : array_like, optional
            indices of bits to clamp during sampling (-1 to disable)
        values_to_fix : array_like, optional
            values to clamp the fixed bits to (-1 to disable)

        Returns
        -------
        array_like
            samples from the model
        """
        # use distinct subkeys for the initial state and the MH uniforms
        init_key, unif_key = random.split(key)
        state = random.randint(init_key, minval=0, maxval=2, shape=(self.N,))
        unifs = random.uniform(unif_key, shape=(n_samps * self.N,))
        all_states = np.zeros((n_samps, self.N))

        if bits_to_fix != -1:
            condition = True
            bits_to_keep = np.array([x for x in range(self.N) if x not in bits_to_fix])
            N = bits_to_keep.size
            values_to_fix = np.array(values_to_fix)
        else:
            condition = False
            bits_to_keep = np.arange(self.N)
            N = self.N
        @jit
        def run_mh(j, loop_carry):
            state, all_states = loop_carry
            if condition:
                state = index_update(state, bits_to_fix, values_to_fix)
            all_states = index_update(all_states, j // N, state)  # a bit wasteful
            state_flipped = index_update(state, bits_to_keep[j % N],
                                         1 - state[bits_to_keep[j % N]])
            dE = self.calc_e(factors, state_flipped) - self.calc_e(factors, state)
            accept = ((dE < 0) | (unifs[j] < np.exp(-dE)))
            state = np.where(accept, state_flipped, state)
            return state, all_states

        state, all_states = fori_loop(0, n_samps * N, run_mh, (state, all_states))
        return all_states
Example 27
def body(state, X):
    (key, i, center) = state
    key, new_point_key, t_key = random.split(key, 3)
    new_point = points[random.randint(new_point_key, (), 0, points.shape[0]), :]
    dx = points - center
    p = new_point - center
    p = p / jnp.linalg.norm(p)
    t_new = jnp.max(dx @ p, axis=0)
    new_point = center + random.uniform(t_key) * t_new * p
    center = (center * i + new_point) / (i + 1)
    return (key, i + 1, center), (new_point,)
Example 28
def test_simulate_polymer_brownian():

    key = random.PRNGKey(0)

    polymer = random.randint(key, minval=0, maxval=4, shape=(10, ))

    key, _ = random.split(key)

    positions, energy_fn = physics.simulate_polymer_brownian(key=key,
                                                             polymer=polymer,
                                                             box_size=6.8)
Example 29
def body_fun(carry):
    key, samples, _ = carry
    key, use_key = random.split(key)
    new_samples = random.randint(use_key,
                                 shape=shape,
                                 minval=0,
                                 maxval=maxval)
    discard = jnp.logical_or(in1dvec(new_samples, samples),
                             in1dvec(new_samples, rejects))
    samples = jnp.where(discard, samples, new_samples)
    return key, samples, in1dvec(samples, rejects)
Example 30
def initialize_input_sel(shape, dtype):
    dim = shape[-1]
    if self.method == "random":
        rng = hk.next_rng_key()
        input_sel = random.randint(rng,
                                   shape=(dim, ),
                                   minval=1,
                                   maxval=dim + 1)
    else:
        input_sel = jnp.arange(1, dim + 1)
    return input_sel