def frank_wolfe_unsigned_projection(element: FiniteVec, solution: FiniteVec = None, num_samples: np.int32 = 100):
    """Approximate `element` by a growing mixture of points using Frank-Wolfe updates."""
    assert len(element) == 1
    if solution is None:
        solution = FiniteVec(element.k, element.inspace_points[:1, :], np.zeros(1), points_per_split=1)
    for k in range(num_samples):
        def cost(x):
            x = x.reshape((1, -1))
            return (solution(x) - element(x)).sum()

        g_cost = grad(cost)
        # Start the line search from a randomly chosen input point.
        idx = randint(0, element.points_per_split - 1)
        cand = element.inspace_points[idx:idx + 1, :]
        res = minimize(__casted_output(cost), cand, jac=__casted_output(g_cost))
        solution.inspace_points = np.vstack([solution.inspace_points, res["x"]])
        # Standard Frank-Wolfe step size schedule.
        gamma_k = 1. / (k + 1)
        solution.prefactors = np.hstack([(1 - gamma_k) * solution.prefactors, gamma_k])
        solution.points_per_split = solution.points_per_split + 1
    return solution
def kmeans(key, points, mask, K=2):
    """Perform kmeans clustering with Euclidean metric.

    Args:
        key: PRNG key for the random initial assignment
        points: [N, D]
        mask: [N] bool, which points to include in the clustering
        K: int, number of clusters

    Returns:
        cluster_id [N], centers [K, D]
    """
    N, D = points.shape

    def body(state):
        (i, done, old_cluster_id, centers) = state
        new_centers = vmap(lambda k: jnp.average(
            points, weights=(old_cluster_id == k) & mask, axis=0))(jnp.arange(K))
        dx = points - new_centers[:, None, :]  # K, N, D
        squared_norm = jnp.sum(jnp.square(dx), axis=-1)  # K, N
        new_cluster_id = jnp.argmin(squared_norm, axis=0)  # N
        done = jnp.all(new_cluster_id == old_cluster_id)
        return i + 1, done, new_cluster_id, new_centers

    do_kmeans = jnp.sum(mask) > K
    i, _, cluster_id, centers = while_loop(
        lambda state: ~state[1], body,
        (jnp.array(0), ~do_kmeans,
         random.randint(key, shape=(N,), minval=0, maxval=K),
         jnp.zeros((K, D))))
    return cluster_id, centers
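# Usage sketch (an assumption, not from the source): cluster 100 random 2-D
# points while masking out the final 10. Assumes the same jax imports used by
# `kmeans` above (jnp, vmap, random, while_loop) are in scope.
key = random.PRNGKey(0)
pts = random.normal(key, (100, 2))
valid = jnp.arange(100) < 90
ids, centers = kmeans(key, pts, valid, K=2)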
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, constraints._Real):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
def test_taylor_proxy_norm(subsample_size):
    data_key, tr_key, rng_key = random.split(random.PRNGKey(0), 3)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = .1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (100,))
    n, _ = data.shape

    def model(data, subsample_size):
        mean = numpyro.sample('mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
        with numpyro.plate('data', data.shape[0], subsample_size=subsample_size, dim=-2) as idx:
            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])

    def log_prob_fn(params):
        return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1)

    log_prob = log_prob_fn(ref_params)
    log_norm_jac = jacrev(log_prob_fn)(ref_params)
    log_norm_hessian = hessian(log_prob_fn)(ref_params)

    tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(data, subsample_size)
    plate_sizes = {'data': (n, subsample_size)}

    proxy_constructor = HMCECS.taylor_proxy({'mean': ref_params})
    proxy_fn, gibbs_init, gibbs_update = proxy_constructor(tr, plate_sizes, model, (data, subsample_size), {})

    def taylor_expand_2nd_order(idx, pos):
        return log_prob[idx] + (log_norm_jac[idx] @ pos) + .5 * (pos @ log_norm_hessian[idx]) @ pos

    def taylor_expand_2nd_order_sum(pos):
        return log_prob.sum() + log_norm_jac.sum(0) @ pos + .5 * pos @ log_norm_hessian.sum(0) @ pos

    for _ in range(5):
        split_key, perturb_key, rng_key = random.split(rng_key, 3)
        perturb_params = ref_params + dist.Normal(.1, 0.1).sample(perturb_key, ref_params.shape)
        subsample_idx = random.randint(rng_key, (subsample_size,), 0, n)
        gibbs_site = {'data': subsample_idx}
        proxy_state = gibbs_init(None, gibbs_site)
        actual_proxy_sum, actual_proxy_sub = proxy_fn({'data': perturb_params}, ['data'], proxy_state)
        assert_allclose(actual_proxy_sub['data'],
                        taylor_expand_2nd_order(subsample_idx, perturb_params - ref_params),
                        rtol=1e-5)
        assert_allclose(actual_proxy_sum['data'],
                        taylor_expand_2nd_order_sum(perturb_params - ref_params),
                        rtol=1e-5)
def run_update(batch_idx, opt_state):
    kl_warmup = kl_warmup_fun(batch_idx)
    didxs = random.randint(next(dkeyg), [lfads_hps['batch_size']], 0, train_data.shape[0])
    x_bxt = train_data[didxs].astype(np.float32)
    opt_state = update_fun(batch_idx, opt_state, lfads_hps, lfads_opt_hps,
                           next(fkeyg), x_bxt, kl_warmup)
    return opt_state
def sample_batch(self, batch_size):
    """Sample past experience."""
    self.rng, rng_input = random.split(self.rng)
    indexes = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=self.size)
    return self.data_points[indexes]
def random_integer(shape, dtype, minval, maxval, seed):
    """Generates a sample from uniform distribution over [minval, maxval)."""
    return random.randint(shape=tuple(shape), dtype=dtype, minval=minval,
                          maxval=maxval, key=make_tensor_seed(seed))
def get_action(self, x_t):
    """Return the control action for state x_t."""
    self.T += 1
    # eta_t is a Rademacher (+/-1) exploration perturbation.
    eta_t = 1 - 2 * random.randint(generate_key(), minval=0, maxval=2, shape=(self.m,))
    self.eta.append(eta_t)
    self.x_history.append(np.squeeze(x_t, axis=1))
    return -self.K @ x_t + np.expand_dims(eta_t, axis=1)
def train_op(net: Module, opt: Optimizer, x, y, key, hyperparams: ServerHyperParams):
    index = random.randint(key, shape=(hyperparams.oracle_batch_size,), minval=0, maxval=x.shape[0])
    v, g = vg(net, opt.target, x[index], y[index])
    return v, opt.apply_gradient(g)
def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None):
    """
    Merges subposteriors following consensus Monte Carlo algorithm.

    **References:**

    1. *Bayes and big data: The consensus Monte Carlo algorithm*,
       Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman,
       Edward I. George, Robert E. McCulloch

    :param list subposteriors: a list in which each element is a collection of samples.
    :param int num_draws: number of draws from the merged posterior.
    :param bool diagonal: whether to compute weights using variance or covariance,
        defaults to `False` (using covariance).
    :param jax.random.PRNGKey rng_key: source of the randomness, defaults to
        `jax.random.PRNGKey(0)`.
    :return: if `num_draws` is None, merges subposteriors without resampling;
        otherwise, returns a collection of `num_draws` samples with the same data
        structure as each subposterior.
    """
    # stack subposteriors
    joined_subposteriors = tree_multimap(lambda *args: jnp.stack(args), *subposteriors)
    # shape of joined_subposteriors: n_subs x n_samples x sample_shape
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(joined_subposteriors)

    if num_draws is not None:
        rng_key = random.PRNGKey(0) if rng_key is None else rng_key
        # randomly gets num_draws from subposteriors
        n_subs = len(subposteriors)
        n_samples = tree_flatten(subposteriors[0])[0][0].shape[0]
        # shape of draw_idxs: n_subs x num_draws x sample_shape
        draw_idxs = random.randint(rng_key, shape=(n_subs, num_draws), minval=0, maxval=n_samples)
        joined_subposteriors = vmap(lambda x, idx: x[idx])(joined_subposteriors, draw_idxs)

    if diagonal:
        # compute weights for each subposterior (ref: Section 3.1 of [1])
        weights = vmap(lambda x: 1 / jnp.var(x, ddof=1, axis=0))(joined_subposteriors)
        normalized_weights = weights / jnp.sum(weights, axis=0)
        # get weighted samples
        samples_flat = jnp.einsum("ij,ikj->kj", normalized_weights, joined_subposteriors)
    else:
        weights = vmap(lambda x: jnp.linalg.inv(jnp.cov(x.T)))(joined_subposteriors)
        normalized_weights = jnp.matmul(jnp.linalg.inv(jnp.sum(weights, axis=0)), weights)
        samples_flat = jnp.einsum("ijk,ilk->lj", normalized_weights, joined_subposteriors)

    # unravel_fn acts on 1 sample of a subposterior
    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
    return vmap(lambda x: unravel_fn(x))(samples_flat)
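# Usage sketch (an assumption, not from the source): merge two subposteriors,
# each a dict of per-sample values, resampling 500 draws. Assumes the same
# jax imports as `consensus` above.
k1, k2 = random.split(random.PRNGKey(0))
sub1 = {"mu": random.normal(k1, (1000,))}
sub2 = {"mu": 0.5 + random.normal(k2, (1000,))}
merged = consensus([sub1, sub2], num_draws=500, rng_key=random.PRNGKey(1))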
def flip_spin(grid: np.DeviceArray, n_x: int, n_y: int,
              x_subkey: np.DeviceArray, y_subkey: np.DeviceArray) -> np.DeviceArray:
    """Flip the spin of a single element on the grid

    :param grid: grid with spins
    :param n_x: grid's first dimension
    :param n_y: grid's second dimension
    :param x_subkey: subkey for random x coordinate in flip
    :param y_subkey: subkey for random y coordinate in flip
    :return: grid with one flipped spin
    """
    x = random.randint(x_subkey, (1,), 0, n_x)
    y = random.randint(y_subkey, (1,), 0, n_y)
    mask = index_update(np.ones_like(grid), index[x, y], -1)
    flipped_grid = grid * mask
    return flipped_grid
def testRandomBroadcast(self):
    """Issue 4033"""
    # test for broadcast issue in https://github.com/google/jax/issues/4033
    key = random.PRNGKey(0)
    shape = (10, 2)
    x = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
    assert x.shape == shape
    x = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
    assert x.shape == shape
def initializeMeans(self, X):
    n = X.shape[0]
    # Pick n_clusters random data points as the initial means.
    idx = random.randint(self.key, (self.n_clusters,), 0, n)
    means = X[idx]
    return means
def distort_image_with_randaugment(image, num_layers, magnitude, random_key,
                                   cutout_const=40, translate_const=50.0,
                                   default_replace_value=None,
                                   available_ops=DEFAULT_OPS,
                                   op_probs=DEFAULT_PROBS,
                                   join_transforms=False):
    """Applies the RandAugment policy to `image`.

    RandAugment is from the paper https://arxiv.org/abs/1909.13719.

    Args:
        image: `Tensor` of shape [height, width, 3] representing an image.
        num_layers: Integer, the number of augmentation transformations to apply
            sequentially to an image. Represented as (N) in the paper. Usually best
            values will be in the range [1, 3].
        magnitude: Integer, shared magnitude across all augmentation operations.
            Represented as (M) in the paper. Usually best values are in the range [5, 30].
        random_key: random key used to sample operations and replacement values
        cutout_const: max cutout size int
        translate_const: maximum translation amount int
        default_replace_value: default replacement value for pixels outside of the image
        available_ops: available operations
        op_probs: probabilities of operations
        join_transforms: reduce multiple geometric transforms to one before applying.
            Much more efficient but simpler.

    Returns:
        The augmented version of `image`.
    """
    geometric_transforms = jnp.identity(4)
    for_i_args = (image, geometric_transforms, random_key, available_ops, op_probs,
                  magnitude, cutout_const, translate_const, join_transforms,
                  default_replace_value)
    if DEBUG:  # un-jitted
        for i in range(num_layers):
            for_i_args = _randaugment_inner_for_loop(i, for_i_args)
    else:  # jitted
        for_i_args = jax.lax.fori_loop(0, num_layers, _randaugment_inner_for_loop, for_i_args)
    image, geometric_transforms = for_i_args[0], for_i_args[1]

    if join_transforms:
        replace_value = default_replace_value or random.randint(
            random_key, [image.shape[-1]], minval=0, maxval=256)
        image = transforms.apply_transform(image, geometric_transforms,
                                           mask_value=replace_value)
    return image
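# Usage sketch (an assumption, not from the source): two RandAugment layers at
# magnitude 10 on a dummy image, relying on the module-level DEFAULT_OPS and
# DEFAULT_PROBS defaults.
dummy = jnp.zeros((32, 32, 3), dtype=jnp.uint8)
augmented = distort_image_with_randaugment(dummy, num_layers=2, magnitude=10,
                                           random_key=random.PRNGKey(0))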
def sample(self, rng, batch_size):
    ind = random.randint(rng, (batch_size,), 0, self.size)
    return BufferOutput(
        obs=jax.device_put(self.obs[ind]),
        action=jax.device_put(self.action[ind]),
        next_obs=jax.device_put(self.next_obs[ind]),
        reward=jax.device_put(self.reward[ind]),
        not_done=jax.device_put(self.not_done[ind]),
    )
def body_fn(val, idx):
    # One Fisher-Yates step: swap position i with a random j in [0, i].
    i_p1 = size - idx
    i = i_p1 - 1
    j = random.randint(rng_keys[idx], (), 0, i_p1)
    val = ops.index_update(
        val,
        ops.index[[i, j], ],
        val[ops.index[[j, i], ]],
    )
    return val, None
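# Usage sketch (an assumption, not from the source): drive body_fn with
# lax.scan to shuffle a full permutation. Assumes an older JAX where
# jax.ops.index_update exists, and defines the `size` and `rng_keys` names
# body_fn expects from its enclosing scope.
size = 8
rng_keys = random.split(random.PRNGKey(0), size)
shuffled, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(size))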
def test_sliced_distances(self):
    num_points = 100
    dim = 50
    key = random.PRNGKey(42)
    use_keys = random.split(key, num=3)
    indices1 = random.randint(use_keys[0], shape=(int(1.5 * num_points),),
                              minval=0, maxval=num_points)
    indices2 = random.randint(use_keys[1], shape=(int(1.5 * num_points),),
                              minval=0, maxval=num_points)
    inputs = random.normal(use_keys[2], shape=(num_points, dim))
    distance_fn = trimap.euclidean_dist
    dist_sliced = trimap.sliced_distances(indices1, indices2, inputs, distance_fn)
    dist_direct = distance_fn(inputs[indices1], inputs[indices2])
    npt.assert_equal(np.array(dist_sliced), np.array(dist_direct))
def sample(self, rng: PRNGSequence, batch_size: int):
    ind = random.randint(rng, (batch_size,), 0, self.size)
    return (
        jax.device_put(self.state[ind]),
        jax.device_put(self.action[ind]),
        jax.device_put(self.next_state[ind]),
        jax.device_put(self.reward[ind]),
        jax.device_put(self.not_done[ind]),
    )
def _update_block(rng_key, num_blocks, subsample_idx, plate_size):
    size, subsample_size = plate_size
    rng_key, subkey, block_key = random.split(rng_key, 3)
    block_size = (subsample_size - 1) // num_blocks + 1
    pad = block_size - (subsample_size - 1) % block_size - 1
    # Pick one block uniformly at random and draw fresh indices for it.
    chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks)
    new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,))
    subsample_idx_padded = jnp.pad(subsample_idx, (0, pad))
    start = chosen_block * block_size
    subsample_idx_padded = lax.dynamic_update_slice_in_dim(
        subsample_idx_padded, new_idx, start, 0)
    return rng_key, subsample_idx_padded[:subsample_size], pad, new_idx, start
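# Usage sketch (an assumption, not from the source): refresh one of two
# blocks of a size-5 subsample drawn from a population of 100.
key = random.PRNGKey(0)
key, new_subsample, pad, new_idx, start = _update_block(
    key, num_blocks=2, subsample_idx=jnp.arange(5), plate_size=(100, 5))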
def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
    rng_key, rng_proposal = random.split(rng_key, 2)
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    # Propose a uniform draw over the support for the coordinate at `idx`.
    proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
    z_new_flat = ops.index_update(z_discrete_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_accept_ratio = pe - pe_new
    return rng_key, z_new, pe_new, log_accept_ratio
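# Usage sketch (an assumption, not from the source): propose a new value for
# one coordinate of a discrete latent with 4 support points under a toy
# potential. Assumes an older JAX where jax.ops.index_update exists.
z = {"c": jnp.array([1, 2])}
toy_potential = lambda z: jnp.sum(z["c"] ** 2) * 1.0
key, z_new, pe_new, log_ratio = _discrete_rw_proposal(
    random.PRNGKey(0), z, toy_potential(z), toy_potential, idx=0, support_size=4)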
def sample_batch(self, batch_size):
    # Advance the stored key so successive batches differ.
    self.key = random.split(self.key)[0]
    idxs = random.randint(self.key, (batch_size,), 0, len(self))
    batch = [self[idx] for idx in idxs]
    # Each item to be its own tensor of len batch_size
    b = list(zip(*batch))
    buf_mean = 0
    buf_std = 1
    return [(jnp.asarray(t) - buf_mean) / buf_std for t in b]
def create_grid(n_x: int, n_y: int, random_seed: int) -> np.DeviceArray:
    """Create a grid randomly filled with -1 and +1

    :param n_x: grid's first dimension
    :param n_y: grid's second dimension
    :param random_seed: seed for random functions
    :return: grid of size (n_x, n_y)
    """
    key = random.PRNGKey(random_seed)
    # Draw 0/1 uniformly, then map to -1/+1 spins.
    return random.randint(key, (n_x, n_y), 0, 2) * 2 - 1
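# Usage sketch (an assumption, not from the source): build an 8x8 spin grid
# and flip one random spin with the flip_spin helper defined earlier.
grid = create_grid(8, 8, random_seed=0)
xk, yk = random.split(random.PRNGKey(1))
new_grid = flip_spin(grid, 8, 8, xk, yk)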
def data_loader(batch_shape, key=None, start=None, split='train', return_if_at_end=False):
    assert (key is None) ^ (start is None)
    at_end = False

    # We have the option to choose the index of the images
    if key is None:
        data_indices = file_indices
        n_files = batch_shape[-1]
        batch_idx = start + jnp.arange(n_files)

        # Trim the batch indices so that we don't exceed the size of the dataset
        batch_idx = batch_idx[batch_idx < validation_indices.shape[0]]
        batch_shape = batch_shape[:-1] + (batch_idx.shape[0],)

        # Check if we're at the end of the dataset
        if batch_idx.shape[0] < n_files:
            at_end = True

        batch_idx = np.broadcast_to(batch_idx, batch_shape)
    else:
        if split == 'train':
            data_indices = train_indices
        elif split == 'test':
            data_indices = test_indices
        elif split == 'validation':
            data_indices = validation_indices
        else:
            raise ValueError("Invalid split name. Choose from 'train', 'test' or 'validation'")
        batch_idx = random.randint(key, batch_shape, minval=0, maxval=data_indices.shape[0])
        batch_idx = data_indices[batch_idx]

    batch_files = all_files[np.array(batch_idx)]

    images = np.zeros(batch_shape + x_shape, dtype=np.int32).reshape((-1,) + x_shape)
    for k, path in enumerate(batch_files.ravel()):
        im = plt.imread(path, format='jpg')
        im = im[::strides[0], ::strides[1]][crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
        im = im // quantize_factor
        images[k] = im

    ret = images.reshape(batch_shape + x_shape)
    if return_if_at_end:
        return ret, at_end
    return ret
def create_angle_vecs(angles, rkey):
    angle_rand_idx = random.randint(rkey, (ntrials,), 0, np.size(angles))
    cos_angles = np.array([np.cos(np.radians(angles[i])) for i in angle_rand_idx])
    sin_angles = np.array([np.sin(np.radians(angles[i])) for i in angle_rand_idx])
    cos_angles = np.expand_dims(cos_angles, axis=0)
    cos_angles = np.expand_dims(cos_angles, axis=2)
    sin_angles = np.expand_dims(sin_angles, axis=0)
    sin_angles = np.expand_dims(sin_angles, axis=2)
    return cos_angles, sin_angles
def _crop(key, image, hps):
    """Randomly shifts the window viewing the image."""
    pixpad = (hps.crop_num_pixels, hps.crop_num_pixels)
    zero = (0, 0)
    padded_image = jnp.pad(image, (pixpad, pixpad, zero),
                           mode='constant', constant_values=0.0)
    corner = random.randint(key, (2,), 0, 2 * hps.crop_num_pixels)
    corner = jnp.concatenate((corner, jnp.zeros((1,), jnp.int32)))
    cropped_image = lax.dynamic_slice(padded_image, corner, image.shape)
    return cropped_image
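# Usage sketch (an assumption, not from the source): random 4-pixel shift of
# a 32x32x3 image; the output shape matches the input.
from types import SimpleNamespace
hps = SimpleNamespace(crop_num_pixels=4)
cropped = _crop(random.PRNGKey(0), jnp.ones((32, 32, 3)), hps)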
def _sample(self, key, n_samps, factors, bits_to_fix=-1, values_to_fix=-1):
    """generate samples from a distribution with a given set of factors

    Parameters
    ----------
    key : jax.random.PRNGKey
        jax random number generator
    n_samps : int
        number of samples to generate
    factors : array_like
        factors of the distribution
    bits_to_fix : array_like, optional
        indices of bits to clamp to fixed values (-1 to disable)
    values_to_fix : array_like, optional
        values the clamped bits are fixed to (-1 to disable)

    Returns
    -------
    array_like
        samples from the model
    """
    state = random.randint(key, minval=0, maxval=2, shape=(self.N,))
    unifs = random.uniform(key, shape=(n_samps * self.N,))
    all_states = np.zeros((n_samps, self.N))

    if bits_to_fix != -1:
        condition = True
        bits_to_keep = np.array([x for x in range(self.N) if x not in bits_to_fix])
        N = bits_to_keep.size
        values_to_fix = np.array(values_to_fix)
    else:
        condition = False
        bits_to_keep = np.arange(self.N)
        N = self.N

    @jit
    def run_mh(j, loop_carry):
        # One Metropolis-Hastings step: flip one (non-clamped) bit and
        # accept with probability min(1, exp(-dE)).
        state, all_states = loop_carry
        if condition:
            state = index_update(state, bits_to_fix, values_to_fix)
        all_states = index_update(all_states, j // N, state)  # a bit wasteful
        state_flipped = index_update(state, bits_to_keep[j % N], 1 - state[bits_to_keep[j % N]])
        dE = self.calc_e(factors, state_flipped) - self.calc_e(factors, state)
        accept = ((dE < 0) | (unifs[j] < np.exp(-dE)))
        state = np.where(accept, state_flipped, state)
        return state, all_states

    state, all_states = fori_loop(0, n_samps * N, run_mh, (state, all_states))
    return all_states
def body(state, X):
    (key, i, center) = state
    key, new_point_key, t_key = random.split(key, 3)
    # Pick a random data point, move a random fraction of the way towards the
    # farthest extent of the point cloud in that direction, and update the
    # running center with the proposal.
    new_point = points[random.randint(new_point_key, (), 0, points.shape[0]), :]
    dx = points - center
    p = new_point - center
    p = p / jnp.linalg.norm(p)
    t_new = jnp.max(dx @ p, axis=0)
    new_point = center + random.uniform(t_key) * t_new * p
    center = (center * i + new_point) / (i + 1)
    return (key, i + 1, center), (new_point,)
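# Usage sketch (an assumption, not from the source): drive `body` with
# lax.scan to generate 100 proposals, starting from the mean of `points`
# (the name `body` resolves from its enclosing scope).
points = random.normal(random.PRNGKey(0), (50, 3))
init = (random.PRNGKey(1), jnp.asarray(1), jnp.mean(points, axis=0))
(_, _, center), (proposals,) = lax.scan(body, init, jnp.arange(100))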
def test_simulate_polymer_brownian():
    key = random.PRNGKey(0)
    polymer = random.randint(key, minval=0, maxval=4, shape=(10,))
    key, _ = random.split(key)
    positions, energy_fn = physics.simulate_polymer_brownian(key=key, polymer=polymer, box_size=6.8)
def body_fun(carry):
    key, samples, _ = carry
    key, use_key = random.split(key)
    new_samples = random.randint(use_key, shape=shape, minval=0, maxval=maxval)
    # Keep the old value wherever the fresh draw collides with an existing
    # sample or with the reject list.
    discard = jnp.logical_or(in1dvec(new_samples, samples), in1dvec(new_samples, rejects))
    samples = jnp.where(discard, samples, new_samples)
    return key, samples, in1dvec(samples, rejects)
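# Usage sketch (an assumption, not from the source): drive body_fun with
# lax.while_loop until no sample hits the reject list. The `in1dvec`, `shape`,
# `maxval` and `rejects` definitions stand in for the names body_fun expects
# from its enclosing scope.
shape, maxval = (5,), 20
rejects = jnp.array([0, 1, 2, 3, 4])
in1dvec = lambda a, b: jnp.isin(a, b)
init = (random.PRNGKey(0), jnp.zeros(shape, jnp.int32), jnp.ones(shape, bool))
key, samples, _ = lax.while_loop(lambda c: jnp.any(c[2]), body_fun, init)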
def initialize_input_sel(shape, dtype):
    dim = shape[-1]
    if self.method == "random":
        # Random assignment of input degrees in [1, dim].
        rng = hk.next_rng_key()
        input_sel = random.randint(rng, shape=(dim,), minval=1, maxval=dim + 1)
    else:
        input_sel = jnp.arange(1, dim + 1)
    return input_sel