Example #1
0
    def sample(self, key, sample_shape=()):
        """Draw samples by inverse-CDF sampling between the two tail probabilities.

        The direct formula
            A = icdf(cdf(low) + (cdf(high) - cdf(low)) * u)
              = icdf((1 - u) * cdf(low) + u * cdf(high))
        suffers from precision loss when ``low`` is large. For a base
        distribution symmetric about ``loc`` the sample is therefore reflected:
            if low <= loc:  A = icdf((1 - u) * cdf(low) + u * cdf(high))
            otherwise:      A = 2 * loc - icdf((1 - u) * cdf(2 * loc - low)
                                               + u * cdf(2 * loc - high))
        NOTE(review): the reflected branch relies on ``_tail_prob_at_low`` and
        ``_tail_prob_at_high`` already being expressed on the reflected side
        when ``low > loc``; confirm against their definitions.

        :param key: PRNG key (validated with ``is_prng_key``).
        :param sample_shape: leading shape prepended to ``self.batch_shape``.
        """
        assert is_prng_key(key)
        uniform_draw = random.uniform(key, sample_shape + self.batch_shape)

        center = self.base_dist.loc
        # +1 selects the direct formula, -1 the reflected one.
        direction = jnp.where(center >= self.low, 1.0, -1.0)
        blended_prob = (1 - uniform_draw) * self._tail_prob_at_low + uniform_draw * self._tail_prob_at_high
        quantile = self.base_dist.icdf(blended_prob)
        # direction == 1 -> quantile; direction == -1 -> 2 * center - quantile.
        return (1 - direction) * center + direction * quantile
Example #2
0
 def step(key, state, init_key=None):
     """Perform one Metropolis-Hastings transition around ``inner_step``.

     ``inner_step`` proposes ``next_state``. The proposal is accepted when
     ``log(u) < log p(next_state) - log p(state)`` with ``u ~ Uniform(0, 1)``,
     using ``unnormalized_log_prob`` for ``p``. The acceptance probability,
     clipped to [0, 1], is emitted with ``harvest.sow`` under ``MCMC_METRICS``.

     Args:
       key: PRNG key; split into a transition key and an acceptance key.
       state: current chain state (a pytree).
       init_key: key forwarded to ``st.init(inner_step)``.

     Returns:
       A pytree shaped like ``state``: leaves from the proposal where accepted,
       otherwise the original leaves.
     """
     transition_key, accept_key = random.split(key)
     next_state = st.init(inner_step)(init_key, transition_key,
                                      state)(transition_key, state)
     # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
     state_log_prob = unnormalized_log_prob(state)
     next_state_log_prob = unnormalized_log_prob(next_state)
     log_unclipped_accept_prob = next_state_log_prob - state_log_prob
     accept_prob = harvest.sow(np.clip(np.exp(log_unclipped_accept_prob),
                                       0., 1.),
                               tag=MCMC_METRICS,
                               name='accept_prob')
     u = primitive.tie_in(accept_prob, random.uniform(accept_key))
     # Compare in log space so a huge unclipped ratio cannot overflow exp().
     accept = np.log(u) < log_unclipped_accept_prob
     # tree_util.tree_multimap was removed from JAX; tree_map maps over
     # several trees with the same semantics.
     return tree_util.tree_map(lambda n, s: np.where(accept, n, s),
                               next_state, state)
Example #3
0
    def test_graph_network_shape_dtype(self, spatial_dimension, dtype):
        """The graph-network energy is a scalar in the same dtype as the positions."""
        rng_key = random.PRNGKey(0)
        positions = random.uniform(rng_key, (32, spatial_dimension), dtype=dtype)

        free_displacement, _ = space.free()
        init_fn, energy_fn = energy.graph_network(free_displacement, 0.2)
        params = init_fn(rng_key, positions)

        energy_out = energy_fn(params, positions)

        assert energy_out.shape == ()
        assert energy_out.dtype == dtype
Example #4
0
    def test_pair_no_species_vector(self, spatial_dimension, dtype):
        """smap.pair over a squared-distance function matches a dense reference."""
        def square(dr):
            return np.sum(dr**2, axis=2)

        free_disp, _ = space.free()
        mapped_square = smap.pair(square, free_disp)
        pairwise_disp = space.map_product(free_disp)

        key = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            positions = random.uniform(
                split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
            # Presumably every unordered pair appears twice in the full
            # product, hence the factor 0.5.
            total = 0.5 * np.sum(square(pairwise_disp(positions, positions)))
            expected = np.array(total, dtype=dtype)
            self.assertAllClose(mapped_square(positions), expected)
Example #5
0
    def test_soft_sphere_neighbor_list_energy(self, spatial_dimension, dtype):
        """Neighbor-list soft-sphere energy agrees with the all-pairs energy."""
        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)

        key = random.PRNGKey(1)
        positions = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

        exact_energy_fn = energy.soft_sphere_pair(displacement)
        neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
            displacement, box_size)
        neighbors = neighbor_fn(positions)

        expected = np.array(exact_energy_fn(positions), dtype=dtype)
        self.assertAllClose(expected, energy_fn(positions, neighbors))
Example #6
0
    def test_morse_small_neighbor_list_energy(self, spatial_dimension, dtype):
        """Neighbor-list Morse energy matches the all-pairs energy for 10 particles."""
        key = random.PRNGKey(1)

        box_size = f32(5.0)
        displacement, _ = space.periodic(box_size)
        exact_energy_fn = energy.morse_pair(displacement)

        R = box_size * random.uniform(key, (10, spatial_dimension),
                                      dtype=dtype)
        neighbor_fn, energy_fn = energy.morse_neighbor_list(
            displacement, box_size)

        nbrs = neighbor_fn(R)
        self.assertAllClose(np.array(exact_energy_fn(R), dtype=dtype),
                            energy_fn(R, nbrs))
Example #7
0
    def test_lennard_jones_cell_list_force(self, spatial_dimension, dtype):
        """Cell-list Lennard-Jones forces match the all-pairs forces."""
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        exact_force_fn = quantity.force(
            energy.lennard_jones_pair(displacement))

        R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        force_fn = quantity.force(
            energy.lennard_jones_cell_list(displacement, box_size, R))

        # NOTE(review): third positional argument is presumably check_dtypes.
        self.assertAllClose(np.array(exact_force_fn(R), dtype=dtype),
                            force_fn(R), True)
Example #8
0
    def _body_fn(kXVU):
        """One proposal round of a gamma rejection sampler.

        This presumably implements the Marsaglia-Tsang method; ``c`` comes from
        the enclosing scope.

        The round proceeds as follows:
          * Draw x ~ Normal(0, 1) until v = 1 + c * x is positive.
          * Draw U ~ Uniform(0, 1) for the caller's acceptance test.
          * Return ``(key, X, V, U)`` with X = x**2 and V = v**3.

        Each random draw consumes a fresh subkey from ``random.split``. The
        previous version sampled with a key and then split that same key to
        produce the next one; JAX advises against this key reuse.
        """
        def _next_kxv(kxv):
            key, subkey = random.split(kxv[0])
            x = random.normal(subkey, ())
            v = 1.0 + c * x
            return key, x, v

        key = kXVU[0]
        # Initial v = -1.0 forces at least one normal draw.
        key, x, v = lax.while_loop(lambda kxv: kxv[2] <= 0.0, _next_kxv,
                                   (key, 0.0, -1.0))
        key, subkey = random.split(key)
        X = x * x
        V = v * v * v
        U = random.uniform(subkey, ())
        return key, X, V, U
Example #9
0
    def test_pairwise_grid_force_incommensurate(self, spatial_dimension,
                                                dtype):
        """Grid-mapped forces match direct forces when 12.1 is not a multiple of 3.0."""
        box_size = f32(12.1)
        cell_size = f32(3.0)
        displacement, _ = space.periodic(box_size)

        force_fn = quantity.force(
            energy.soft_sphere_pairwise(displacement, quantity.Dynamic))

        key = random.PRNGKey(1)
        R = box_size * random.uniform(
            key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        grid_force_fn = jit(smap.grid(force_fn, box_size, cell_size, R))

        species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
        reference = np.array(force_fn(R, species, 1), dtype=dtype)
        self.assertAllClose(reference, grid_force_fn(R), True)
Example #10
0
    def test_mixed_list_assignment_in_setup(self):
        """A setup() list mixing submodules and plain functions is applied in order."""
        class Test(nn.Module):
            def setup(self):
                self.layers = [nn.Dense(10), nn.relu, nn.Dense(10)]

            def __call__(self, x):
                out = x
                for layer in self.layers:
                    out = layer(out)
                return out

        inputs = random.uniform(random.PRNGKey(0), (5, 5))
        variables = Test().init(random.PRNGKey(0), jnp.ones((5, 5)))
        outputs = Test().apply(variables, inputs)

        params = variables['params']
        first_kernel = params['layers_0']['kernel']
        second_kernel = params['layers_2']['kernel']
        # The reference ignores biases, presumably because they start at zero.
        expected = jnp.dot(nn.relu(jnp.dot(inputs, first_kernel)), second_kernel)
        self.assertTrue(jnp.all(outputs == expected))
Example #11
0
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            """Toy kernel: a randomly scaled ``|x1 @ x2|``.

            The product is squared when ``do_square`` is set and negated when
            ``do_flip`` is set. It is then multiplied by ``uniform(keys) * p``.
            Returns ``[result, params]`` with ``params`` passed through
            unchanged; ``_unused`` is ignored.
            """
            out = np.abs(np.matmul(x1, x2))
            if do_square:
                out *= out
            if do_flip:
                out = -out
            scale = random.uniform(keys) * p
            out *= scale
            return [out, params]
Example #12
0
    def sample(self, state, model_args, model_kwargs):
        """Run one Barker proposal followed by a Metropolis-Hastings accept/reject.

        :param state: tuple unpacked as ``(i, x, x_pe, x_grad, _,
            mean_accept_prob, adapt_state, rng_key)``. The ignored slot is
            presumably the previous step's ``accept_prob``, matching the
            returned ``BarkerMHState`` layout.
        :param model_args: not read in this method body.
        :param model_kwargs: not read in this method body.
        :return: ``BarkerMHState(i + 1, x, pe, x_grad, accept_prob,
            mean_accept_prob, adapt_state, rng_key)``. Here ``x``, ``pe`` and
            ``x_grad`` belong to the accepted point (proposal or current), and
            ``adapt_state`` is only updated while ``i < self._num_warmup``.
        """
        i, x, x_pe, x_grad, _, mean_accept_prob, adapt_state, rng_key = state
        # Work on flat vectors; unravel_fn maps them back to the pytree.
        x_flat, unravel_fn = ravel_pytree(x)
        x_grad_flat, _ = ravel_pytree(x_grad)
        shape = jnp.shape(x_flat)
        rng_key, key_normal, key_bernoulli, key_accept = random.split(rng_key, 4)

        mass_sqrt_inv = adapt_state.mass_matrix_sqrt_inv

        # Precondition the gradient: matrix product for a dense mass matrix,
        # elementwise product otherwise.
        x_grad_flat_scaled = mass_sqrt_inv @ x_grad_flat if self._dense_mass else mass_sqrt_inv * x_grad_flat

        # Generate proposal y.
        z = adapt_state.step_size * random.normal(key_normal, shape)

        # Barker step: keep the sign of each component of z with probability
        # p = expit(-z * scaled_grad), otherwise flip it.
        p = expit(-z * x_grad_flat_scaled)
        b = jnp.where(random.uniform(key_bernoulli, shape) < p, 1., -1.)

        dx_flat = b * z
        dx_flat_scaled = mass_sqrt_inv.T @ dx_flat if self._dense_mass else mass_sqrt_inv * dx_flat

        y_flat = x_flat + dx_flat_scaled

        y = unravel_fn(y_flat)
        y_pe, y_grad = jax.value_and_grad(self._potential_fn)(y)
        y_grad_flat, _ = ravel_pytree(y_grad)
        y_grad_flat_scaled = mass_sqrt_inv @ y_grad_flat if self._dense_mass else mass_sqrt_inv * y_grad_flat

        # Log MH ratio: potential-energy difference plus the log ratio of the
        # forward/backward Barker proposal densities, written with softplus.
        log_accept_ratio = x_pe - y_pe + jnp.sum(softplus(dx_flat * x_grad_flat_scaled) -
                                                 softplus(-dx_flat * y_grad_flat_scaled))
        # Cap at 1 so the value is a valid probability.
        accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.)

        # Accept y with probability accept_prob. Uses the older lax.cond form
        # (pred, true_operand, true_fun, false_operand, false_fun).
        # The returned x_flat is not read again below.
        x, x_flat, pe, x_grad = jax.lax.cond(random.bernoulli(key_accept, accept_prob),
                                             (y, y_flat, y_pe, y_grad), identity,
                                             (x, x_flat, x_pe, x_grad), identity)

        # do not update adapt_state after warmup phase
        adapt_state = jax.lax.cond(i < self._num_warmup,
                                   (i, accept_prob, (x,), adapt_state),
                                   lambda args: self._wa_update(*args),
                                   adapt_state,
                                   identity)

        itr = i + 1
        # Running-mean denominator: counts warmup steps during warmup, then
        # restarts from 1 so the mean covers only post-warmup steps.
        n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
        mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n

        return BarkerMHState(itr, x, pe, x_grad, accept_prob, mean_accept_prob, adapt_state, rng_key)
Example #13
0
def sample_along_rays(
    key,
    origins,
    directions,
    num_coarse_samples,
    near,
    far,
    use_stratified_sampling,
    use_linear_disparity,
):
    """Sample depths (z values) and the corresponding 3D points along rays.

    Depths are spaced evenly in depth between ``near`` and ``far``. When
    ``use_linear_disparity`` is set they are spaced evenly in inverse depth
    instead. With ``use_stratified_sampling`` each depth is jittered uniformly
    inside its bin. Bins are bounded by the midpoints between neighbouring
    depths, and the end bins are clamped at ``near`` and ``far``.

    Args:
      key: jnp.ndarray, random generator key (only used when stratified).
      origins: ray origins, [batch_size, 3].
      directions: ray directions, [batch_size, 3].
      num_coarse_samples: int, samples per ray.
      near: float, near clip.
      far: float, far clip.
      use_stratified_sampling: jitter samples within their bins.
      use_linear_disparity: sample linearly in disparity rather than depth.

    Returns:
      A tuple ``(z_vals, points)``:
        z_vals: jnp.ndarray, [batch_size, num_coarse_samples], sampled depths.
        points: jnp.ndarray, [batch_size, num_coarse_samples, 3], sampled points.
    """
    num_rays = origins.shape[0]
    out_shape = [num_rays, num_coarse_samples]

    fractions = jnp.linspace(0.0, 1.0, num_coarse_samples)
    if use_linear_disparity:
        inverse_depths = 1.0 / near * (1.0 - fractions) + 1.0 / far * fractions
        depths = 1.0 / inverse_depths
    else:
        depths = near * (1.0 - fractions) + far * fractions

    if not use_stratified_sampling:
        # Broadcast so both modes return the same shape.
        depths = jnp.broadcast_to(depths[None, ...], out_shape)
    else:
        midpoints = 0.5 * (depths[..., 1:] + depths[..., :-1])
        bin_upper = jnp.concatenate([midpoints, depths[..., -1:]], -1)
        bin_lower = jnp.concatenate([depths[..., :1], midpoints], -1)
        jitter = random.uniform(key, out_shape)
        depths = bin_lower + (bin_upper - bin_lower) * jitter

    points = origins[..., None, :] + depths[..., :, None] * directions[..., None, :]
    return depths, points
  def test_scale_invariance_weight_quantization(self, prec):
    """Scaling weights by a power of 2 scales the fake-quantized output identically."""
    weight_scale = 16
    weights = random.uniform(random.PRNGKey(0), (10, 1))
    scaled_weights = weights * weight_scale

    def fake_quant(w):
      return QuantOps.create_weights_fake_quant(
          w=w,
          weight_params=QuantOps.WeightParams(
              prec=prec, axis=None, half_shift=False))

    onp.testing.assert_array_equal(
        fake_quant(weights) * weight_scale, fake_quant(scaled_weights))
Example #15
0
def test_logaddexp():
    """Check ``logaddexp`` on real inputs and on complex logs of negative numbers."""
    # log(1) (+) log(1) == log(2). Compared with isclose: logaddexp may be
    # evaluated via a max/log1p form that need not round identically to
    # jnp.log(2.), so exact == is fragile.
    a = jnp.log(1.)
    b = jnp.log(1.)
    assert jnp.isclose(logaddexp(a, b), jnp.log(2.))

    # Complex logs encode signs: exp(logaddexp(log x, log y)) should be x + y.
    a = jnp.log(1.)
    b = jnp.log(-2. + 0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, -1.)

    a = jnp.log(-1. + 0j)
    b = jnp.log(2. + 0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, 1.)

    # Random pairs in [-10, 10) with deterministic keys.
    for i in range(100):
        u = random.uniform(random.PRNGKey(i), shape=(2,)) * 20. - 10.
        a = jnp.log(u[0] + 0j)
        b = jnp.log(u[1] + 0j)
        assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, u[0] + u[1])
Example #16
0
    def test_lennard_jones_cell_neighbor_list_energy(self, spatial_dimension,
                                                     dtype):
        """Neighbor-list Lennard-Jones energy matches the all-pairs energy."""
        key = random.PRNGKey(1)

        box_size = f32(15)
        displacement, _ = space.periodic(box_size)
        exact_energy_fn = energy.lennard_jones_pair(displacement)

        R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
            displacement, box_size, R)

        idx = neighbor_fn(R)
        # NOTE(review): third positional argument is presumably check_dtypes.
        self.assertAllClose(np.array(exact_energy_fn(R), dtype=dtype),
                            energy_fn(R, idx), True)
def _uniform_jax(shape,
                 minval=0,
                 maxval=None,
                 dtype=tf.float32,
                 seed=None,
                 name=None):  # pylint: disable=unused-argument
    """JAX backend for a TF-style uniform sampler.

    ``seed`` must be a JAX PRNG key. ``name`` is accepted for signature
    compatibility and ignored. ``maxval`` defaults to 1 when omitted.

    Raises:
      ValueError: if ``seed`` is None.
    """
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    # dtype is resolved before maxval is defaulted, so the implicit 1 does
    # not take part in dtype inference.
    resolved_dtype = utils.common_dtype([minval, maxval], dtype_hint=dtype)
    upper = maxval if maxval is not None else 1
    return jaxrand.uniform(key=seed,
                           shape=_shape([], shape),
                           dtype=resolved_dtype,
                           minval=minval,
                           maxval=upper)
def mala_kernel(key, paramCurrent, paramGradCurrent, log_post, logpostCurrent, dt):
    """Perform one Metropolis-adjusted Langevin (MALA) step.

    Proposes ``x' = x + dt * grad(x) + sqrt(2 * dt) * N(0, I)``. The proposal is
    accepted with the Metropolis-Hastings probability that corrects for the
    asymmetric Langevin proposal.

    Args:
      key: PRNG key; split into the returned key and two subkeys.
      paramCurrent: current parameters (array of any shape).
      paramGradCurrent: gradient of the log posterior at ``paramCurrent``.
      log_post: callable returning ``(log_posterior, gradient)`` at a point.
      logpostCurrent: log posterior at ``paramCurrent``.
      dt: step size.

    Returns:
      ``(key, params, grad, log_post_value, accepted)``. ``accepted`` is 1 if
      the proposal was taken and 0 otherwise.
    """
    key, subkey1, subkey2 = random.split(key, 3)
    paramProp = paramCurrent + dt*paramGradCurrent + jnp.sqrt(2*dt)*random.normal(key=subkey1, shape=paramCurrent.shape)
    new_log_post, new_grad = log_post(paramProp)

    # log q(x'|x) and log q(x|x') up to a shared constant. jnp.vdot flattens
    # its inputs, so this also works for non-vector parameters; jnp.dot
    # errored or produced a matrix there. For 1-D inputs it equals jnp.dot.
    term1 = paramProp - paramCurrent - dt*paramGradCurrent
    term2 = paramCurrent - paramProp - dt*new_grad
    q_new = -0.25*(1/dt)*jnp.vdot(term1, term1)
    q_current = -0.25*(1/dt)*jnp.vdot(term2, term2)

    log_ratio = new_log_post - logpostCurrent + q_current - q_new
    acceptBool = jnp.log(random.uniform(key=subkey2)) < log_ratio
    paramCurrent = jnp.where(acceptBool, paramProp, paramCurrent)
    current_grad = jnp.where(acceptBool, new_grad, paramGradCurrent)
    current_log_post = jnp.where(acceptBool, new_log_post, logpostCurrent)
    accepts_add = jnp.where(acceptBool, 1, 0)
    return key, paramCurrent, current_grad, current_log_post, accepts_add
    def test_affine_coupling(self):
        """AffineCoupling and AffineCouplingSplit flows pass the shape and bijectivity checks."""
        def transform(rng, input_dim, output_dim, hidden_dim=64, act=stax.Relu):
            # Two hidden layers of width hidden_dim, then a linear output layer.
            layers = [stax.Dense(hidden_dim), act,
                      stax.Dense(hidden_dim), act,
                      stax.Dense(output_dim)]
            init_fun, apply_fun = stax.serial(*layers)
            _, params = init_fun(rng, (input_dim,))
            return params, apply_fun

        inputs = random.uniform(random.PRNGKey(0), (20, 5), minval=-10.0, maxval=10.0)

        # Factories keep each flow's construction immediately before its checks.
        flow_factories = (
            lambda: flows.AffineCoupling(transform),
            lambda: flows.AffineCouplingSplit(transform, transform),
        )
        for make_flow in flow_factories:
            init_fun = make_flow()
            for check in (returns_correct_shape, is_bijective):
                check(self, init_fun, inputs)
Example #20
0
  def test_cell_list_direct_force_jit(self, spatial_dimension, dtype):
    """Jitted cell-list soft-sphere forces match the direct pairwise forces."""
    box_size = f32(9.0)
    cell_size = f32(1.0)
    displacement, _ = space.periodic(box_size)
    metric = space.metric(displacement)

    # Reference forces from the direct pairwise energy.
    force_fn = quantity.force(energy.soft_sphere_pair(displacement))

    key = random.PRNGKey(1)
    R = box_size * random.uniform(
      key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

    # Same interaction evaluated through a jitted cell list.
    cell_force_fn = quantity.force(
      smap.cartesian_product(energy.soft_sphere, metric))
    cell_force_fn = jit(smap.cell_list(cell_force_fn, box_size, cell_size, R))

    self.assertAllClose(
      np.array(force_fn(R), dtype=dtype), cell_force_fn(R), True)
Example #21
0
 def body(state):
     """Propose a point from one randomly chosen ellipsoid and decide acceptance.

     ``state`` is unpacked as ``(i, _, key, _, _)``. The ellipsoid index is
     drawn from ``categorical(log_p)``. The proposed point is accepted with
     probability ``1 / n_intersect``, where ``n_intersect`` counts the
     ellipsoids that contain it.

     Returns ``(i + 1, k, key, done, u_test)``.
     """
     i, _, key, _, _ = state
     key, accept_key, sample_key, select_key = random.split(key, 4)
     k = random.categorical(select_key, log_p)
     u_test = sample_ellipsoid(sample_key,
                               mu[k, :],
                               radii[k, :],
                               rotation[k, :, :],
                               unit_cube_constraint=unit_cube_constraint)

     def contains(center, axes, rot):
         return point_in_ellipsoid(u_test, center, axes, rot)

     n_intersect = jnp.sum(vmap(contains)(mu, radii, rotation))
     done = random.uniform(accept_key) < jnp.reciprocal(n_intersect)
     return (i + 1, k, key, done, u_test)
Example #22
0
def rejection_stitching(ssm_scenario: StateSpaceModel,
                        x0_all: jnp.ndarray,
                        t: float,
                        x1_all: jnp.ndarray,
                        tplus1: float,
                        x1_log_weight: jnp.ndarray,
                        random_key: jnp.ndarray,
                        maximum_rejections: int,
                        init_bound_param: float,
                        bound_inflation: float) -> Tuple[jnp.ndarray, int]:
    """Choose, for each particle in ``x0_all``, a successor index into ``x1_all``.

    Selection uses rejection sampling against the transition density.
    Proposals come from ``categorical(x1_log_weight)``. A proposal for particle
    i is accepted when a uniform draw is at most
    ``exp(-transition_potential(x0_all[i], t, x1, tplus1)) / bound``.

    Steps:
      1. Pre-run: one proposal per particle. If the largest observed
         conditional density exceeds ``init_bound_param``, the bound is that
         maximum times ``bound_inflation``. Otherwise the bound is
         ``init_bound_param``.
      2. ``while_loop`` over ``rejection_stitch_proposal_all``. It continues
         while any particle is unaccepted and fewer than
         ``maximum_rejections`` attempts have been made.
      3. Particles still unaccepted are passed to ``full_stitch_single_cond``.
         This is presumably an exact fallback over all of ``x1_all``: the
         accounting below charges ``len(x1_all)`` evaluations per such particle.

    Returns:
      ``(x1_final_inds, num_transition_evals)``.
      NOTE(review): ``num_transition_evals`` is annotated ``int`` but is built
      from array ops (``.sum()``), so it is likely a JAX scalar array.
    """
    rejection_initial_keys = random.split(random_key, 3)
    n = len(x1_all)

    # Prerun to initiate bound
    x1_initial_inds = random.categorical(rejection_initial_keys[0], x1_log_weight, shape=(n,))
    initial_cond_dens = jnp.exp(-vmap(ssm_scenario.transition_potential,
                                      (0, None, 0, None))(x0_all, t, x1_all[x1_initial_inds], tplus1))
    max_cond_dens = jnp.max(initial_cond_dens)
    initial_bound = jnp.where(max_cond_dens > init_bound_param, max_cond_dens * bound_inflation, init_bound_param)
    # True where the pre-run proposal was rejected.
    initial_not_yet_accepted_arr = random.uniform(rejection_initial_keys[1], (n,)) > initial_cond_dens / initial_bound

    # Loop state: (not_yet_accepted, indices, bound, per-particle keys,
    # attempts made so far = 1, transition evaluations so far = n).
    # tup[-2] is the attempt counter.
    out_tup = while_loop(lambda tup: jnp.logical_and(tup[0].sum() > 0, tup[-2] < maximum_rejections),
                         lambda tup: rejection_stitch_proposal_all(ssm_scenario, x0_all, t, x1_all, tplus1,
                                                                   x1_log_weight,
                                                                   bound_inflation, *tup),
                         (initial_not_yet_accepted_arr,
                          x1_initial_inds,
                          initial_bound,
                          random.split(rejection_initial_keys[2], n),
                          1,
                          n))
    not_yet_accepted_arr, x1_final_inds, final_bound, random_keys, rej_attempted, num_transition_evals = out_tup

    # Resolve each particle with full_stitch_single_cond. `map` here is
    # presumably jax.lax.map over particle indices.
    x1_final_inds = map(lambda i: full_stitch_single_cond(not_yet_accepted_arr[i],
                                                          x1_final_inds[i],
                                                          ssm_scenario,
                                                          x0_all[i],
                                                          t,
                                                          x1_all,
                                                          tplus1,
                                                          x1_log_weight,
                                                          random_keys[i]), jnp.arange(n))

    # Charge len(x1_all) evaluations for each particle that needed the fallback.
    num_transition_evals = num_transition_evals + len(x1_all) * not_yet_accepted_arr.sum()

    return x1_final_inds, num_transition_evals
Example #23
0
    def test_cell_list_random_emplace_rect(self, dtype, dim):
        """Scattering a rectangular-box cell list back by particle id recovers R."""
        if dim == 3:
            box_size = np.array([9.0, 3.0, 7.25])
        else:
            box_size = np.array([9.0, 3.25])
        cell_size = f32(1.0)

        key = random.PRNGKey(1)
        R = box_size * random.uniform(key, (PARTICLE_COUNT, dim))

        cell_list = partition.cell_list(box_size, cell_size).allocate(R)

        # One spare row receives writes from empty slots, presumably tagged
        # with the out-of-range id PARTICLE_COUNT. It is sliced off.
        ids = np.reshape(cell_list.id_buffer, (-1, ))
        positions = np.reshape(cell_list.position_buffer, (-1, dim))
        R_out = np.zeros((PARTICLE_COUNT + 1, dim)).at[ids].set(positions)[:-1]
        self.assertAllClose(R_out, R)
def predictive_resample_single_loop(key,logcdf_conditionals,logpdf_joints,rho,n,T):
    """Run the forward predictive-resampling recursion for T future steps.

    The procedure is:
      * Draw a (T, d) block of uniforms, where d is
        ``jnp.shape(logcdf_conditionals)[0]``.
      * Prepend n rows of zeros so the array has n + T rows.
      * Run ``mvcd.update_ptest_single_scan`` over indices n .. n + T - 1.

    Returns:
      The updated ``(logcdf_conditionals, logpdf_joints)``.
    """
    num_dims = jnp.shape(logcdf_conditionals)[0]

    _, subkey = random.split(key)
    uniforms = random.uniform(subkey, shape=(T, num_dims))

    # Zero rows in front give the array its full (n + T, d) size.
    padded = jnp.concatenate((jnp.zeros((n, num_dims)), uniforms), axis=0)

    step_indices = jnp.arange(n, n + T)
    scan_inputs = padded, logcdf_conditionals, logpdf_joints, rho
    outputs, _ = mvcd.update_ptest_single_scan(scan_inputs, step_indices)
    _, new_logcdf_conditionals, new_logpdf_joints, _ = outputs

    return new_logcdf_conditionals, new_logpdf_joints
Example #25
0
def conditional_sample(p, y, key):
    """
    Generate conditional Binary relaxed samples
    :param p: Binary relaxed params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param y: Conditioning parameters (jax.numpy array)
    :param key: PRNG key
    :return: logit(p) + logit(v'), where v' is a uniform restricted to the
        region consistent with y: [1 - p, 1] where y == 1 and [0, 1 - p]
        where y == 0 (both clipped away from 0 and 1).
    """
    eps = 1e-7
    probs = np.clip(p, eps, 1 - eps)

    noise = random.uniform(key, shape=y.shape)
    upper_branch = (noise * probs + (1 - probs)) * y
    lower_branch = (noise * (1 - probs)) * (1 - y)
    conditioned = np.clip(upper_branch + lower_branch, eps, 1 - eps)

    return logit(probs) + logit(conditioned)
Example #26
0
    def test_bond_no_type_static(self, spatial_dimension, dtype):
        """smap.bond with a static bond list sums the harmonic energy of each bond."""
        def harmonic(dr, **kwargs):
            return (dr - f32(1))**f32(2)

        disp, _ = space.free()
        metric = space.metric(disp)
        bonds = np.array([[0, 1], [0, 2]], i32)
        mapped = smap.bond(harmonic, metric, bonds)

        key = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            positions = random.uniform(
                split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

            # Expected energy: the two bonds (0, 1) and (0, 2) listed above.
            bond_01 = harmonic(metric(positions[0], positions[1]))
            bond_02 = harmonic(metric(positions[0], positions[2]))

            self.assertAllClose(mapped(positions), dtype(bond_01 + bond_02))
Example #27
0
    def test_cell_list_random_emplace(self, dtype, dim):
        """Scattering a cell list's buffers back by particle id recovers R."""
        key = random.PRNGKey(1)

        box_size = f32(9.0)
        cell_size = f32(1.0)

        R = box_size * random.uniform(key, (PARTICLE_COUNT, dim))

        cell_fn = partition.cell_list(box_size, cell_size, R)
        cell_list = cell_fn(R)

        id_flat = np.reshape(cell_list.id_buffer, (-1, ))
        R_flat = np.reshape(cell_list.R_buffer, (-1, dim))
        # One spare row receives writes from empty slots, presumably tagged
        # with the out-of-range id PARTICLE_COUNT. It is sliced off.
        # `.at[].set` replaces jax.ops.index_update, which was removed from JAX.
        R_out = np.zeros((PARTICLE_COUNT + 1, dim))
        R_out = R_out.at[id_flat].set(R_flat)[:-1]

        self.assertAllClose(R_out, R)
Example #28
0
    def test_morse_neighbor_list_force(self, spatial_dimension, dtype):
        """Neighbor-list Morse forces match the all-pairs forces."""
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        exact_force_fn = quantity.force(energy.morse_pair(displacement))

        R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        neighbor_fn, energy_fn = energy.morse_neighbor_list(
            displacement, box_size)
        force_fn = quantity.force(energy_fn)

        nbrs = neighbor_fn(R)
        self.assertAllClose(np.array(exact_force_fn(R), dtype=dtype),
                            force_fn(R, nbrs))
Example #29
0
def main(unused_argv):
  """Minimize a two-species soft-sphere system with FIRE, logging progress."""
  key = random.PRNGKey(0)

  # Setup some variables describing the system.
  N = 500
  dimension = 2
  box_size = f32(25.0)

  # Create helper functions to define a periodic box of some size.
  displacement, shift = space.periodic(box_size)

  # Use JAX's random number generator to generate random initial positions.
  key, split = random.split(key)
  R = random.uniform(
    split, (N, dimension), minval=0.0, maxval=box_size, dtype=f32)

  # The system ought to be a 50:50 mixture of two types of particles, one
  # large and one small. The second species takes the remainder so that
  # len(species) == N even when N is odd.
  sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
  N_2 = N // 2
  species = np.array([0] * N_2 + [1] * (N - N_2), dtype=i32)

  # Create an energy function.
  energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
  force_fn = quantity.force(energy_fn)

  # Create a minimizer.
  init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
  opt_state = init_fn(R)

  # Minimize the system.
  minimize_steps = 50
  print_every = 10

  print('Minimizing.')
  print('Step\tEnergy\tMax Force')
  print('-----------------------------------')
  for step in range(minimize_steps):
    opt_state = apply_fn(opt_state)

    if step % print_every == 0:
      R = opt_state.position
      # The step counter is an int; '{:.2f}' used to print it as e.g. "0.00".
      # NOTE(review): the last column is the largest force *component*, not
      # the largest force magnitude.
      print('{:d}\t{:.2f}\t{:.2f}'.format(
          step, energy_fn(R), np.max(force_fn(R))))
Example #30
0
    def body_fn(state):
        """Draw one candidate set of initial parameters and check its validity.

        ``state`` is unpacked as ``(i, key, _, _)``. There are two strategies:
          * If ``radius`` or ``prototype_params`` is None, the model is traced
            under ``seed`` and ``substitute(init_strategy)``. The values of
            sample sites that are neither observed nor discrete are then
            transformed with the inverse of ``biject_to(support)``, presumably
            into unconstrained space.
          * Otherwise each entry of ``prototype_params`` is taken from
            ``init_values`` when present. If absent, it is drawn from
            ``uniform(-radius, radius)`` with the prototype's shape.
        The potential energy and its gradient are then evaluated at the
        candidate.

        Returns:
          ``(i + 1, key, (params, pe, z_grad), is_valid)``. ``is_valid`` is
          True only when the potential energy and every gradient entry are
          finite.
        """
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model, subkey),
                                      substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(
                *model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if v['type'] == 'sample' and not v['is_observed'] and not v[
                        'fn'].is_discrete:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            params = transform_fn(
                inv_transforms, {k: v
                                 for k, v in constrained_values.items()},
                invert=True)
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey,
                                               jnp.shape(v),
                                               minval=-radius,
                                               maxval=radius)
                    # Fresh subkey for the next randomly initialized site; the
                    # first one used the subkey split at the top of the body.
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy,
                               model,
                               model_args,
                               model_kwargs,
                               enum=enum)
        # Forward-mode (jacfwd) or reverse-mode gradient, chosen by the flag.
        if forward_mode_differentiation:
            pe = potential_fn(params)
            z_grad = jacfwd(potential_fn)(params)
        else:
            pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        return i + 1, key, (params, pe, z_grad), is_valid