def sample(self, key, sample_shape=()):
    """Draw samples by inverse-CDF sampling of the truncated base distribution.

    A uniform draw ``u`` is used to interpolate between the two stored tail
    probabilities, and the result is pushed through ``self.base_dist.icdf``.

    The naive formula ``icdf(cdf(low) + (cdf(high) - cdf(low)) * u)`` loses
    precision when ``low`` is large, so for a base distribution symmetric about
    ``loc`` the sample is taken as:

    * ``loc >= low``: ``icdf((1 - u) * tail_low + u * tail_high)``
    * ``loc <  low``: ``2 * loc - icdf((1 - u) * tail_low + u * tail_high)``,
      i.e. reflected about ``loc``.

    NOTE(review): this presumes ``_tail_prob_at_low`` / ``_tail_prob_at_high``
    are computed on the reflected side when ``loc < low``; they are defined
    outside this view.
    """
    assert is_prng_key(key)
    noise = random.uniform(key, sample_shape + self.batch_shape)
    # Interpolate between the tail probabilities at the two truncation points.
    target_prob = (1 - noise) * self._tail_prob_at_low + noise * self._tail_prob_at_high
    loc = self.base_dist.loc
    # +1 keeps the icdf value as-is; -1 reflects it to 2 * loc - icdf(...).
    sign = jnp.where(loc >= self.low, 1.0, -1.0)
    return (1 - sign) * loc + sign * self.base_dist.icdf(target_prob)
def step(key, state, init_key=None):
    """Perform one Metropolis-Hastings accept/reject step around ``inner_step``.

    A candidate is produced by initializing and applying ``inner_step``. It is
    accepted with probability ``min(1, p(next) / p(state))``, where ``p`` is
    ``unnormalized_log_prob`` exponentiated.

    NOTE(review): no proposal-density correction is applied, so this presumably
    assumes ``inner_step`` is a symmetric proposal; confirm against its definition.

    Args:
      key: PRNG key; split into a transition key and an acceptance key.
      state: current state (a pytree).
      init_key: key forwarded to ``st.init(inner_step)``.

    Returns:
      The candidate state where accepted, otherwise ``state``; selection is
      done leaf-by-leaf with ``np.where``.
    """
    transition_key, accept_key = random.split(key)
    next_state = st.init(inner_step)(init_key, transition_key, state)(transition_key, state)
    # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
    state_log_prob = unnormalized_log_prob(state)
    next_state_log_prob = unnormalized_log_prob(next_state)
    # log(p(next) / p(state)); may exceed 0, hence "unclipped".
    log_unclipped_accept_prob = next_state_log_prob - state_log_prob
    # The clipped acceptance probability is emitted via harvest.sow under the
    # MCMC_METRICS tag (presumably for metric collection); it is also used below.
    accept_prob = harvest.sow(np.clip(np.exp(log_unclipped_accept_prob), 0., 1.), tag=MCMC_METRICS, name='accept_prob')
    # tie_in links the uniform draw to accept_prob (presumably to keep the sown
    # value in the traced computation).
    u = primitive.tie_in(accept_prob, random.uniform(accept_key))
    # Compare in log space against the unclipped ratio; equivalent to u < accept_prob.
    accept = np.log(u) < log_unclipped_accept_prob
    return tree_util.tree_multimap(lambda n, s: np.where(accept, n, s), next_state, state)
def test_graph_network_shape_dtype(self, spatial_dimension, dtype):
    """The graph-network energy is a scalar whose dtype matches the positions."""
    key = random.PRNGKey(0)
    positions = random.uniform(key, (32, spatial_dimension), dtype=dtype)
    displacement_fn, _ = space.free()
    init_fn, energy_fn = energy.graph_network(displacement_fn, 0.2)
    energy_value = energy_fn(init_fn(key, positions), positions)
    assert energy_value.shape == ()
    assert energy_value.dtype == dtype
def test_pair_no_species_vector(self, spatial_dimension, dtype):
    """smap.pair without species equals half the sum over all ordered pairs."""
    def square(dr):
        return np.sum(dr**2, axis=2)

    displacement, _ = space.free()
    mapped_square = smap.pair(square, displacement)
    all_pairs = space.map_product(displacement)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
        key, subkey = random.split(key)
        positions = random.uniform(
            subkey, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        # Each unordered pair appears twice in the full product, hence 0.5.
        expected = np.array(
            0.5 * np.sum(square(all_pairs(positions, positions))), dtype=dtype)
        self.assertAllClose(mapped_square(positions), expected)
def test_soft_sphere_neighbor_list_energy(self, spatial_dimension, dtype):
    """Soft-sphere energy via a neighbor list matches the exact pair energy."""
    box_size = f32(15.0)
    displacement, _ = space.periodic(box_size)
    exact_energy_fn = energy.soft_sphere_pair(displacement)
    positions = box_size * random.uniform(
        random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
        displacement, box_size)
    neighbors = neighbor_fn(positions)
    expected = np.array(exact_energy_fn(positions), dtype=dtype)
    self.assertAllClose(expected, energy_fn(positions, neighbors))
def test_morse_small_neighbor_list_energy(self, spatial_dimension, dtype):
    """Morse energy via a neighbor list matches the exact pair energy.

    Uses a small system: 10 particles in a periodic box of side 5.
    """
    key = random.PRNGKey(1)
    box_size = f32(5.0)
    displacement, _ = space.periodic(box_size)
    exact_energy_fn = energy.morse_pair(displacement)
    R = box_size * random.uniform(key, (10, spatial_dimension), dtype=dtype)
    neighbor_fn, energy_fn = energy.morse_neighbor_list(displacement, box_size)
    nbrs = neighbor_fn(R)
    self.assertAllClose(
        np.array(exact_energy_fn(R), dtype=dtype), energy_fn(R, nbrs))
def test_lennard_jones_cell_list_force(self, spatial_dimension, dtype):
    """Lennard-Jones forces from a cell-list energy match the exact pair forces."""
    key = random.PRNGKey(1)
    box_size = f32(15.0)
    displacement, _ = space.periodic(box_size)
    exact_force_fn = quantity.force(energy.lennard_jones_pair(displacement))
    R = box_size * random.uniform(
        key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    force_fn = quantity.force(
        energy.lennard_jones_cell_list(displacement, box_size, R))
    # The trailing positional True is forwarded to assertAllClose
    # (presumably its check_dtypes flag; confirm against the test base class).
    self.assertAllClose(
        np.array(exact_force_fn(R), dtype=dtype), force_fn(R), True)
def _body_fn(kXVU):
    """Produce one proposal for what appears to be a Marsaglia-Tsang Gamma sampler.

    Standard normals ``x`` are drawn until ``v = 1 + c * x`` is positive. The
    function then returns ``x**2``, ``v**3`` and a fresh uniform ``U`` for the
    caller's acceptance test. ``c`` comes from the enclosing scope.

    Every random draw uses a freshly split subkey. The previous version sampled
    with a key and then split that same key to make the next one. JAX's PRNG
    design forbids this reuse because it can correlate successive draws.

    Args:
      kXVU: loop carry; only ``kXVU[0]`` (the PRNG key) is read.

    Returns:
      ``(key, X, V, U)``, where ``key`` has not been used for sampling.
    """
    def _next_kxv(kxv):
        key, subkey = random.split(kxv[0])
        x = random.normal(subkey, ())
        v = 1.0 + c * x
        return key, x, v

    key = kXVU[0]
    key, x, v = lax.while_loop(lambda kxv: kxv[2] <= 0.0, _next_kxv, (key, 0.0, -1.0))
    X = x * x
    V = v * v * v
    key, subkey = random.split(key)
    U = random.uniform(subkey, ())
    return key, X, V, U
def test_pairwise_grid_force_incommensurate(self, spatial_dimension, dtype):
    """Grid-mapped forces match direct pairwise forces.

    The box side (12.1) is not an integer multiple of the cell size (3.0).
    """
    box_size, cell_size = f32(12.1), f32(3.0)
    displacement, _ = space.periodic(box_size)
    force_fn = quantity.force(
        energy.soft_sphere_pairwise(displacement, quantity.Dynamic))
    positions = box_size * random.uniform(
        random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    grid_force_fn = jit(smap.grid(force_fn, box_size, cell_size, positions))
    species = np.zeros((PARTICLE_COUNT,), dtype=np.int64)
    expected = np.array(force_fn(positions, species, 1), dtype=dtype)
    self.assertAllClose(expected, grid_force_fn(positions), True)
def test_mixed_list_assignment_in_setup(self):
    """A list mixing submodules and plain functions can be assigned in setup()."""
    class Test(nn.Module):
        def setup(self):
            self.layers = [nn.Dense(10), nn.relu, nn.Dense(10)]

        def __call__(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    x = random.uniform(random.PRNGKey(0), (5, 5))
    variables = Test().init(random.PRNGKey(0), jnp.ones((5, 5)))
    y = Test().apply(variables, x)
    params = variables['params']
    # The function at index 1 owns no parameters; the Dense layers are
    # registered as layers_0 and layers_2.
    first_kernel = params['layers_0']['kernel']
    second_kernel = params['layers_2']['kernel']
    expected = jnp.dot(nn.relu(jnp.dot(x, first_kernel)), second_kernel)
    self.assertTrue(jnp.all(y == expected))
def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.65):
    """Toy kernel: ``|x1 @ x2|``, optionally squared and negated, randomly scaled.

    Args:
      x1, x2: matrix operands.
      do_flip: if true, negate the result.
      keys: PRNG key for the scalar scale factor.
      do_square: if true, square the result elementwise.
      params: returned unchanged.
      _unused: ignored.
      p: multiplier applied to the uniform scale factor.

    Returns:
      ``[result, params]``.
    """
    out = np.abs(np.matmul(x1, x2))
    if do_square:
        out = out * out
    if do_flip:
        out = -out
    scale = random.uniform(keys) * p
    return [out * scale, params]
def sample(self, state, model_args, model_kwargs):
    """Run one Barker Metropolis-Hastings transition.

    A Gaussian step ``z`` is drawn, and its sign per coordinate is chosen with
    probability ``expit(-z * scaled_grad)``. The move is preconditioned by the
    inverse square-root mass matrix (dense or diagonal). It is then accepted or
    rejected with the Barker MH ratio.

    Args:
      state: current ``BarkerMHState``, unpacked as
        ``(i, x, x_pe, x_grad, <ignored>, mean_accept_prob, adapt_state, rng_key)``.
      model_args, model_kwargs: not referenced in this body.

    Returns:
      The next ``BarkerMHState`` with the iteration counter incremented.

    NOTE(review): both ``jax.lax.cond`` calls use the legacy 5-argument form
    ``cond(pred, true_operand, true_fun, false_operand, false_fun)``. Newer JAX
    releases deprecate that form.
    """
    i, x, x_pe, x_grad, _, mean_accept_prob, adapt_state, rng_key = state
    x_flat, unravel_fn = ravel_pytree(x)
    x_grad_flat, _ = ravel_pytree(x_grad)
    shape = jnp.shape(x_flat)
    rng_key, key_normal, key_bernoulli, key_accept = random.split(rng_key, 4)

    mass_sqrt_inv = adapt_state.mass_matrix_sqrt_inv

    # Gradient expressed in the preconditioned coordinates.
    x_grad_flat_scaled = mass_sqrt_inv @ x_grad_flat if self._dense_mass else mass_sqrt_inv * x_grad_flat

    # Generate proposal y.
    z = adapt_state.step_size * random.normal(key_normal, shape)

    # Per-coordinate probability of keeping the sign of z (Barker proposal).
    p = expit(-z * x_grad_flat_scaled)
    b = jnp.where(random.uniform(key_bernoulli, shape) < p, 1., -1.)

    dx_flat = b * z
    # Map the step back to the original coordinates.
    dx_flat_scaled = mass_sqrt_inv.T @ dx_flat if self._dense_mass else mass_sqrt_inv * dx_flat

    y_flat = x_flat + dx_flat_scaled
    y = unravel_fn(y_flat)
    y_pe, y_grad = jax.value_and_grad(self._potential_fn)(y)
    y_grad_flat, _ = ravel_pytree(y_grad)
    y_grad_flat_scaled = mass_sqrt_inv @ y_grad_flat if self._dense_mass else mass_sqrt_inv * y_grad_flat

    # Potential-energy difference plus the Barker proposal-density correction.
    log_accept_ratio = x_pe - y_pe + jnp.sum(softplus(dx_flat * x_grad_flat_scaled) - softplus(-dx_flat * y_grad_flat_scaled))
    accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.)

    # Select the proposal or keep the current point (x_flat is not used afterwards).
    x, x_flat, pe, x_grad = jax.lax.cond(random.bernoulli(key_accept, accept_prob), (y, y_flat, y_pe, y_grad), identity, (x, x_flat, x_pe, x_grad), identity)

    # do not update adapt_state after warmup phase
    adapt_state = jax.lax.cond(i < self._num_warmup, (i, accept_prob, (x,), adapt_state), lambda args: self._wa_update(*args), adapt_state, identity)

    itr = i + 1
    # The running mean restarts once warmup ends (n counts post-warmup steps).
    n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
    mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n

    return BarkerMHState(itr, x, pe, x_grad, accept_prob, mean_accept_prob, adapt_state, rng_key)
def sample_along_rays(
    key,
    origins,
    directions,
    num_coarse_samples,
    near,
    far,
    use_stratified_sampling,
    use_linear_disparity,
):
    """Sample depths (optionally stratified) along each ray.

    Args:
      key: jnp.ndarray, random generator key (used only when stratifying).
      origins: ray origins, indexed along the first axis by ray.
      directions: ray directions, same leading shape as ``origins``.
      num_coarse_samples: int, samples per ray.
      near: float, near clip.
      far: float, far clip.
      use_stratified_sampling: jitter each sample uniformly within its bin.
      use_linear_disparity: space samples linearly in 1/depth instead of depth.

    Returns:
      z_vals: jnp.ndarray, [batch_size, num_coarse_samples], sampled z values.
      points: jnp.ndarray, [batch_size, num_coarse_samples, 3], sampled points.
    """
    batch_size = origins.shape[0]
    t_vals = jnp.linspace(0.0, 1.0, num_coarse_samples)
    if use_linear_disparity:
        z_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals)
    else:
        z_vals = near * (1.0 - t_vals) + far * t_vals

    sample_shape = [batch_size, num_coarse_samples]
    if not use_stratified_sampling:
        # Broadcast z_vals to make the returned shape consistent.
        z_vals = jnp.broadcast_to(z_vals[None, ...], sample_shape)
    else:
        # Bins are bounded by the midpoints between neighbouring depths.
        mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
        bin_hi = jnp.concatenate([mids, z_vals[..., -1:]], -1)
        bin_lo = jnp.concatenate([z_vals[..., :1], mids], -1)
        jitter = random.uniform(key, sample_shape)
        z_vals = bin_lo + (bin_hi - bin_lo) * jitter

    points = origins[..., None, :] + z_vals[..., :, None] * directions[..., None, :]
    return z_vals, points
def test_scale_invariance_weight_quantization(self, prec):
    """Scaling weights by a power of two scales the fake-quantized weights by the same factor."""
    raw = random.uniform(random.PRNGKey(0), (10, 1))
    scale = 16

    def fake_quant(w):
        return QuantOps.create_weights_fake_quant(
            w=w,
            weight_params=QuantOps.WeightParams(
                prec=prec, axis=None, half_shift=False))

    quantized = fake_quant(raw)
    quantized_scaled = fake_quant(raw * scale)
    onp.testing.assert_array_equal(quantized * scale, quantized_scaled)
def test_logaddexp():
    """Check ``logaddexp`` on real and complex (signed) log-space inputs.

    Complex logs encode negative numbers: ``log(-x + 0j)`` has imaginary part
    pi. So ``exp(logaddexp(a, b)).real`` should recover the signed sum of the
    underlying values.
    """
    # log(1) (+) log(1) == log(2). Compare with a tolerance: logaddexp may go
    # through exp/log1p and need not be bit-identical to jnp.log(2.).
    a = jnp.log(1.)
    b = jnp.log(1.)
    assert jnp.isclose(logaddexp(a, b), jnp.log(2.))

    # 1 + (-2) == -1
    a = jnp.log(1.)
    b = jnp.log(-2. + 0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, -1.)

    # -1 + 2 == 1
    a = jnp.log(-1. + 0j)
    b = jnp.log(2. + 0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, 1.)

    for i in range(100):
        u = random.uniform(random.PRNGKey(i), shape=(2,)) * 20. - 10.
        a = jnp.log(u[0] + 0j)
        b = jnp.log(u[1] + 0j)
        # An absolute tolerance is needed when u[0] is close to -u[1]: the sum
        # is then near zero and a purely relative check would be spuriously strict.
        assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, u[0] + u[1], atol=1e-5)
def test_lennard_jones_cell_neighbor_list_energy(self, spatial_dimension, dtype):
    """Lennard-Jones energy via a cell-backed neighbor list matches the exact pair energy."""
    key = random.PRNGKey(1)
    box_size = f32(15)
    displacement, _ = space.periodic(box_size)
    exact_energy_fn = energy.lennard_jones_pair(displacement)
    R = box_size * random.uniform(
        key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
        displacement, box_size, R)
    idx = neighbor_fn(R)
    self.assertAllClose(
        np.array(exact_energy_fn(R), dtype=dtype), energy_fn(R, idx), True)
def _uniform_jax(shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None):  # pylint: disable=unused-argument
    """JAX implementation of ``uniform``; ``seed`` must be a PRNGKey.

    Raises:
      ValueError: if no ``seed`` is given.
    """
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    # dtype is resolved from the bounds as given (maxval may still be None here).
    dtype = utils.common_dtype([minval, maxval], dtype_hint=dtype)
    upper = 1 if maxval is None else maxval
    return jaxrand.uniform(
        key=seed, shape=_shape([], shape), dtype=dtype, minval=minval, maxval=upper)
def mala_kernel(key, paramCurrent, paramGradCurrent, log_post, logpostCurrent, dt):
    """Perform one Metropolis-adjusted Langevin (MALA) step.

    Proposal: ``x' = x + dt * grad(x) + sqrt(2 dt) * N(0, I)``. It is accepted
    with the Metropolis-Hastings correction for this asymmetric Gaussian
    proposal.

    Args:
      key: JAX PRNG key.
      paramCurrent: current parameters (array of any shape).
      paramGradCurrent: gradient of the log posterior at ``paramCurrent``.
      log_post: callable returning ``(log_posterior, gradient)`` at a point.
      logpostCurrent: log posterior at ``paramCurrent``.
      dt: step size.

    Returns:
      ``(key, params, grad, log_post_value, accepted)``, where ``accepted`` is
      1 if the proposal was taken and 0 otherwise.
    """
    key, subkey1, subkey2 = random.split(key, 3)
    paramProp = paramCurrent + dt * paramGradCurrent + jnp.sqrt(2 * dt) * random.normal(key=subkey1, shape=paramCurrent.shape)
    new_log_post, new_grad = log_post(paramProp)

    # Log-densities (up to a shared constant) of the forward and reverse
    # proposals, q(x'|x) and q(x|x'). A sum of squares rather than jnp.dot
    # keeps these scalar for parameters of any shape, not only 1-D vectors.
    term1 = paramProp - paramCurrent - dt * paramGradCurrent
    term2 = paramCurrent - paramProp - dt * new_grad
    q_new = -0.25 * (1 / dt) * jnp.sum(jnp.square(term1))
    q_current = -0.25 * (1 / dt) * jnp.sum(jnp.square(term2))

    log_ratio = new_log_post - logpostCurrent + q_current - q_new
    acceptBool = jnp.log(random.uniform(key=subkey2)) < log_ratio
    paramCurrent = jnp.where(acceptBool, paramProp, paramCurrent)
    current_grad = jnp.where(acceptBool, new_grad, paramGradCurrent)
    current_log_post = jnp.where(acceptBool, new_log_post, logpostCurrent)
    accepts_add = jnp.where(acceptBool, 1, 0)
    return key, paramCurrent, current_grad, current_log_post, accepts_add
def test_affine_coupling(self):
    """AffineCoupling and AffineCouplingSplit flows have the right shape and are bijective."""
    def transform(rng, input_dim, output_dim, hidden_dim=64, act=stax.Relu):
        layers = (
            stax.Dense(hidden_dim),
            act,
            stax.Dense(hidden_dim),
            act,
            stax.Dense(output_dim),
        )
        init_fun, apply_fun = stax.serial(*layers)
        _, params = init_fun(rng, (input_dim,))
        return params, apply_fun

    inputs = random.uniform(random.PRNGKey(0), (20, 5), minval=-10.0, maxval=10.0)

    flow_builders = (
        lambda: flows.AffineCoupling(transform),
        lambda: flows.AffineCouplingSplit(transform, transform),
    )
    for build in flow_builders:
        init_fun = build()
        for check in (returns_correct_shape, is_bijective):
            check(self, init_fun, inputs)
def test_cell_list_direct_force_jit(self, spatial_dimension, dtype):
    """Jitted cell-list soft-sphere forces agree with direct pair forces."""
    box_size, cell_size = f32(9.0), f32(1.0)
    displacement, _ = space.periodic(box_size)
    force_fn = quantity.force(energy.soft_sphere_pair(displacement))
    positions = box_size * random.uniform(
        random.PRNGKey(1), (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    grid_energy_fn = smap.cartesian_product(
        energy.soft_sphere, space.metric(displacement))
    grid_force_fn = jit(smap.cell_list(
        quantity.force(grid_energy_fn), box_size, cell_size, positions))
    expected = np.array(force_fn(positions), dtype=dtype)
    self.assertAllClose(expected, grid_force_fn(positions), True)
def body(state):
    """One rejection-sampling attempt from a union of (possibly overlapping) ellipsoids.

    An ellipsoid index ``k`` is drawn from ``log_p``, and a point ``u_test`` is
    sampled from that ellipsoid. The function counts how many ellipsoids
    contain the point and accepts with probability ``1 / n_intersect``.
    Presumably this corrects for overlap so the union is sampled uniformly.
    ``log_p``, ``mu``, ``radii``, ``rotation`` and ``unit_cube_constraint``
    come from the enclosing scope.

    Args:
      state: ``(i, _, key, done, _)`` loop carry.

    Returns:
      ``(i + 1, k, key, done, u_test)``.

    NOTE(review): if ``n_intersect`` is 0, ``jnp.reciprocal`` gives inf and the
    point is always accepted; confirm ``sample_ellipsoid`` guarantees containment.
    """
    (i, _, key, done, _) = state
    key, accept_key, sample_key, select_key = random.split(key, 4)
    # Choose which ellipsoid to sample from.
    k = random.categorical(select_key, log_p)
    mu_k = mu[k, :]
    radii_k = radii[k, :]
    rotation_k = rotation[k, :, :]
    u_test = sample_ellipsoid(sample_key, mu_k, radii_k, rotation_k, unit_cube_constraint=unit_cube_constraint)
    # Membership of u_test in every ellipsoid.
    inside = vmap(lambda mu, radii, rotation: point_in_ellipsoid(u_test, mu, radii, rotation))(mu, radii, rotation)
    n_intersect = jnp.sum(inside)
    done = (random.uniform(accept_key) < jnp.reciprocal(n_intersect))
    return (i + 1, k, key, done, u_test)
def rejection_stitching(ssm_scenario: StateSpaceModel,
                        x0_all: jnp.ndarray,
                        t: float,
                        x1_all: jnp.ndarray,
                        tplus1: float,
                        x1_log_weight: jnp.ndarray,
                        random_key: jnp.ndarray,
                        maximum_rejections: int,
                        init_bound_param: float,
                        bound_inflation: float) -> Tuple[jnp.ndarray, int]:
    """Pick, for each ``x0_all[i]``, an index into ``x1_all`` by rejection sampling.

    Candidates are drawn from ``x1_log_weight``. Each is accepted with
    probability ``exp(-transition_potential) / bound``. Particles still
    unaccepted after ``maximum_rejections`` rounds are handled by
    ``full_stitch_single_cond``.

    Args:
      ssm_scenario: model providing ``transition_potential``.
      x0_all: particles at time ``t``.
      t: current time.
      x1_all: particles at time ``tplus1``.
      tplus1: next time.
      x1_log_weight: log weights over ``x1_all``.
      random_key: PRNG key.
      maximum_rejections: cap on rejection rounds in the while loop.
      init_bound_param: minimum value for the density bound.
      bound_inflation: factor applied to the observed max density to form the bound.

    Returns:
      ``(x1_final_inds, num_transition_evals)``.
      NOTE(review): ``num_transition_evals`` is a JAX array, not a Python int as
      the annotation says.
    """
    rejection_initial_keys = random.split(random_key, 3)
    n = len(x1_all)

    # Prerun to initiate bound
    x1_initial_inds = random.categorical(rejection_initial_keys[0], x1_log_weight, shape=(n,))
    initial_cond_dens = jnp.exp(-vmap(ssm_scenario.transition_potential, (0, None, 0, None))(x0_all, t, x1_all[x1_initial_inds], tplus1))
    max_cond_dens = jnp.max(initial_cond_dens)
    # Bound = inflated observed max, but never below init_bound_param.
    initial_bound = jnp.where(max_cond_dens > init_bound_param, max_cond_dens * bound_inflation, init_bound_param)
    initial_not_yet_accepted_arr = random.uniform(rejection_initial_keys[1], (n,)) > initial_cond_dens / initial_bound

    # Loop while some particle is unaccepted and the attempt counter
    # (tup[-2]) is below maximum_rejections.
    out_tup = while_loop(lambda tup: jnp.logical_and(tup[0].sum() > 0, tup[-2] < maximum_rejections),
                         lambda tup: rejection_stitch_proposal_all(ssm_scenario, x0_all, t, x1_all, tplus1, x1_log_weight, bound_inflation, *tup),
                         (initial_not_yet_accepted_arr, x1_initial_inds, initial_bound, random.split(rejection_initial_keys[2], n), 1, n))
    not_yet_accepted_arr, x1_final_inds, final_bound, random_keys, rej_attempted, num_transition_evals = out_tup

    # Fallback for particles that were never accepted.
    # NOTE(review): `map` is presumably jax.lax.map imported under that name;
    # the builtin map would return a lazy iterator, not an array.
    x1_final_inds = map(lambda i: full_stitch_single_cond(not_yet_accepted_arr[i], x1_final_inds[i], ssm_scenario, x0_all[i], t, x1_all, tplus1, x1_log_weight, random_keys[i]), jnp.arange(n))

    # Count len(x1_all) extra transition evaluations per fallback particle
    # (assumes full_stitch_single_cond evaluates against every x1; confirm).
    num_transition_evals = num_transition_evals + len(x1_all) * not_yet_accepted_arr.sum()

    return x1_final_inds, num_transition_evals
def test_cell_list_random_emplace_rect(self, dtype, dim):
    """Positions scattered into a rectangular cell list can be recovered exactly."""
    key = random.PRNGKey(1)
    if dim == 3:
        box_size = np.array([9.0, 3.0, 7.25])
    else:
        box_size = np.array([9.0, 3.25])
    cell_size = f32(1.0)
    positions = box_size * random.uniform(key, (PARTICLE_COUNT, dim))

    cell_list = partition.cell_list(box_size, cell_size).allocate(positions)
    ids = np.reshape(cell_list.id_buffer, (-1,))
    buffered = np.reshape(cell_list.position_buffer, (-1, dim))

    # One extra row absorbs out-of-range ids (presumably empty slots carry
    # id == PARTICLE_COUNT) and is dropped afterwards.
    recovered = np.zeros((PARTICLE_COUNT + 1, dim)).at[ids].set(buffered)[:-1]
    self.assertAllClose(recovered, positions)
def predictive_resample_single_loop(key, logcdf_conditionals, logpdf_joints, rho, n, T):
    """Run T forward predictive-resampling updates and return the updated state.

    Args:
      key: PRNG key.
      logcdf_conditionals: array whose leading dimension is taken as ``d``.
      logpdf_joints: passed through ``mvcd.update_ptest_single_scan``.
      rho: passed through ``mvcd.update_ptest_single_scan``.
      n: number of zero rows prepended before the T random rows.
      T: number of forward steps.

    Returns:
      ``(logcdf_conditionals, logpdf_joints)`` after the scan.
    """
    d = jnp.shape(logcdf_conditionals)[0]
    _, subkey = random.split(key)
    future_uniforms = random.uniform(subkey, shape=(T, d))
    # Prepend n zero rows so vT has n + T rows (presumably so row indices line
    # up with the step counter n..n+T-1).
    vT = jnp.concatenate((jnp.zeros((n, d)), future_uniforms), axis=0)
    steps = jnp.arange(n, n + T)
    (_, logcdf_conditionals, logpdf_joints, _), _ = mvcd.update_ptest_single_scan(
        (vT, logcdf_conditionals, logpdf_joints, rho), steps)
    return logcdf_conditionals, logpdf_joints
def conditional_sample(p, y, key):
    """
    Generate conditional Binary relaxed samples
    :param p: Binary relaxed params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param y: Conditioning parameters (jax.numpy array)
    :param key: PRNG key
    :return: logit(p) + logit(v'), where v' is a uniform draw restricted to
        (1 - p, 1) where y == 1 and to (0, 1 - p) where y == 0
    """
    eps = 1e-7
    p = np.clip(p, eps, 1 - eps)
    u = random.uniform(key, shape=y.shape)
    upper_branch = (u * p + (1 - p)) * y  # used where y == 1
    lower_branch = (u * (1 - p)) * (1 - y)  # used where y == 0
    v = np.clip(upper_branch + lower_branch, eps, 1 - eps)
    return logit(p) + logit(v)
def test_bond_no_type_static(self, spatial_dimension, dtype):
    """smap.bond with a static bond list sums harmonic energies of bonds (0,1) and (0,2)."""
    def harmonic(dr, **kwargs):
        return (dr - f32(1))**f32(2)

    disp, _ = space.free()
    metric = space.metric(disp)
    bonds = np.array([[0, 1], [0, 2]], i32)
    mapped = smap.bond(harmonic, metric, bonds)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
        key, subkey = random.split(key)
        positions = random.uniform(
            subkey, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        bond_01 = harmonic(metric(positions[0], positions[1]))
        bond_02 = harmonic(metric(positions[0], positions[2]))
        expected = bond_01 + bond_02
        self.assertAllClose(mapped(positions), dtype(expected))
def test_cell_list_random_emplace(self, dtype, dim):
    """Positions scattered into a cubic cell list can be recovered exactly."""
    key = random.PRNGKey(1)
    box_size = f32(9.0)
    cell_size = f32(1.0)
    R = box_size * random.uniform(key, (PARTICLE_COUNT, dim))
    cell_fn = partition.cell_list(box_size, cell_size, R)
    cell_list = cell_fn(R)
    id_flat = np.reshape(cell_list.id_buffer, (-1,))
    R_flat = np.reshape(cell_list.R_buffer, (-1, dim))
    # One extra row absorbs out-of-range ids (presumably empty slots carry
    # id == PARTICLE_COUNT) and is dropped afterwards.
    R_out = np.zeros((PARTICLE_COUNT + 1, dim))
    # jax.ops.index_update was removed from JAX; `.at[].set()` replaces it.
    R_out = R_out.at[id_flat].set(R_flat)[:-1]
    self.assertAllClose(R_out, R)
def test_morse_neighbor_list_force(self, spatial_dimension, dtype):
    """Morse forces from a neighbor-list energy match the exact pair forces."""
    key = random.PRNGKey(1)
    box_size = f32(15.0)
    displacement, _ = space.periodic(box_size)
    exact_force_fn = quantity.force(energy.morse_pair(displacement))
    R = box_size * random.uniform(
        key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    neighbor_fn, energy_fn = energy.morse_neighbor_list(displacement, box_size)
    force_fn = quantity.force(energy_fn)
    nbrs = neighbor_fn(R)
    self.assertAllClose(
        np.array(exact_force_fn(R), dtype=dtype), force_fn(R, nbrs))
def main(unused_argv):
    """Minimize a 2D binary soft-sphere mixture with FIRE descent and print progress."""
    key = random.PRNGKey(0)

    # Setup some variables describing the system.
    N = 500
    dimension = 2
    box_size = f32(25.0)

    # Create helper functions to define a periodic box of some size.
    displacement, shift = space.periodic(box_size)

    # Use JAX's random number generator to generate random initial positions.
    key, split = random.split(key)
    R = random.uniform(
        split, (N, dimension), minval=0.0, maxval=box_size, dtype=f32)

    # The system ought to be a 50:50 mixture of two types of particles, one
    # large and one small.
    sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
    N_2 = N // 2
    # Give the remainder to species 1 so the array always has N entries.
    species = np.array([0] * N_2 + [1] * (N - N_2), dtype=i32)

    # Create an energy function.
    energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
    force_fn = quantity.force(energy_fn)

    # Create a minimizer.
    init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
    opt_state = init_fn(R)

    # Minimize the system.
    minimize_steps = 50
    print_every = 10

    print('Minimizing.')
    print('Step\tEnergy\tMax Force')
    print('-----------------------------------')
    for step in range(minimize_steps):
        opt_state = apply_fn(opt_state)

        if step % print_every == 0:
            R = opt_state.position
            # The step counter is an integer; print it as one. The last column
            # is the largest signed force component.
            print('{:d}\t{:.2f}\t{:.2f}'.format(
                step, energy_fn(R), np.max(force_fn(R))))
def body_fn(state):
    """Propose one initial parameter set and report whether it is numerically valid.

    Without ``radius`` or ``prototype_params``, values are taken from a seeded
    trace of ``model`` under ``init_strategy`` and mapped to unconstrained
    space. Otherwise each prototype site takes its value from ``init_values``
    or from a uniform draw in ``[-radius, radius)``. Each random site uses its
    own fresh subkey, so different sites never share a random stream.

    Args:
      state: ``(i, key, _, _)`` loop carry.

    Returns:
      ``(i + 1, key, (params, pe, z_grad), is_valid)``, where ``is_valid`` means
      the potential energy and all gradient entries are finite.
    """
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        params = transform_fn(inv_transforms,
                              {k: v for k, v in constrained_values.items()},
                              invert=True)
    else:  # this branch doesn't require tracing the model
        params = {}
        for k, v in prototype_params.items():
            if k in init_values:
                params[k] = init_values[k]
            else:
                params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
                # Refresh the subkey per site; otherwise every site would be
                # drawn from the same key (identical values for equal shapes).
                key, subkey = random.split(key)

    potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
    if forward_mode_differentiation:
        pe = potential_fn(params)
        z_grad = jacfwd(potential_fn)(params)
    else:
        pe, z_grad = value_and_grad(potential_fn)(params)
    z_grad_flat = ravel_pytree(z_grad)[0]
    is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    return i + 1, key, (params, pe, z_grad), is_valid