Ejemplo n.º 1
0
def run_minimization_while(energy_fn,
                           R_init,
                           shift,
                           max_grad_thresh=1e-12,
                           max_num_steps=1000000,
                           **kwargs):
  """Minimizes `energy_fn` with FIRE descent inside a single `lax.while_loop`.

  The optimization runs as one compiled loop. It stops as soon as either of
  these holds:
    * the largest absolute force component is no longer above
      `max_grad_thresh`;
    * `max_num_steps` FIRE steps have been taken.

  Args:
    energy_fn: Energy function; it is jitted and handed to
      `minimize.fire_descent`.
    R_init: Initial positions used to build the optimizer state.
    shift: Shift function forwarded to `minimize.fire_descent`.
    max_grad_thresh: Convergence threshold on max(|force|).
    max_num_steps: Upper bound on the number of FIRE steps.
    **kwargs: Extra keyword arguments forwarded to `minimize.fire_descent`.

  Returns:
    A tuple `(positions, max_abs_force, steps_taken)` describing the final
    optimizer state.
  """
  fire_init, fire_step = minimize.fire_descent(jit(energy_fn), shift, **kwargs)
  fire_step = jit(fire_step)

  # Largest absolute force component. It serves both as the stopping
  # criterion and as the residual reported to the caller.
  max_abs_force = jit(lambda opt_state: jnp.amax(jnp.abs(opt_state.force)))

  def keep_going(carry):
    opt_state, step = carry
    unconverged = max_abs_force(opt_state) > max_grad_thresh
    under_budget = step < max_num_steps
    return jnp.logical_and(unconverged, under_budget)

  def advance(carry):
    opt_state, step = carry
    return fire_step(opt_state), step + 1

  start = fire_init(R_init)
  final_state, steps_taken = lax.while_loop(
      jit(keep_going), jit(advance), (start, 0))

  return final_state.position, max_abs_force(final_state), steps_taken
Ejemplo n.º 2
0
    def test_fire_descent(self, spatial_dimension, dtype):
        """FIRE descent on a quadratic bowl must strictly lower the energy.

        For each random sample:
          * particles start at random positions `R`;
          * the energy is the summed squared distance to random targets.

        After every jitted group of three FIRE steps, two conditions must hold:
          * the energy and the squared distance to the targets both strictly
            decrease;
          * both values keep `dtype`.

        Args:
          spatial_dimension: Number of coordinates per particle (presumably
            supplied by the test parameterization).
          dtype: Floating-point dtype for the positions and the results.
        """
        rng = random.PRNGKey(0)
        shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            rng, start_key, target_key = random.split(rng, 3)
            R = random.uniform(start_key, shape, dtype=dtype)
            target = random.uniform(target_key, shape, dtype=dtype)

            def energy(R, **kwargs):
                return np.sum((R - target)**2)

            _, shift_fn = space.free()
            opt_init, opt_apply = minimize.fire_descent(energy, shift_fn)

            def three_steps(state):
                # This Python loop runs at trace time, so the jitted function
                # is three chained opt_apply calls.
                for _ in range(3):
                    state = opt_apply(state)
                return state

            three_steps = jit(three_steps)

            opt_state = opt_init(R)
            E_current = energy(R)
            dr_current = np.sum((R - target)**2)

            for _ in range(OPTIMIZATION_STEPS):
                opt_state = three_steps(opt_state)
                R = opt_state.position
                E_new = energy(R)
                dr_new = np.sum((R - target)**2)
                assert E_new < E_current
                assert E_new.dtype == dtype
                assert dr_new < dr_current
                assert dr_new.dtype == dtype
                E_current, dr_current = E_new, dr_new
Ejemplo n.º 3
0
def main(unused_argv):
    """Minimizes a random 2D binary soft-sphere mixture with FIRE descent.

    Places N particles uniformly at random in a periodic square box and splits
    them into two species that interact via `energy.soft_sphere_pair`. It then
    runs `minimize_steps` FIRE steps and, every `print_every` steps, prints:
      * the step index;
      * the total energy;
      * the largest force-component magnitude.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    key = random.PRNGKey(0)

    # Setup some variables describing the system.
    N = 500
    dimension = 2
    box_size = f32(25.0)

    # Create helper functions to define a periodic box of some size.
    displacement, shift = space.periodic(box_size)

    # Use JAX's random number generator to generate random initial positions.
    key, split = random.split(key)
    R = random.uniform(split, (N, dimension),
                       minval=0.0,
                       maxval=box_size,
                       dtype=f32)

    # The system ought to be a 50:50 mixture of two types of particles, one
    # small (sigma 1.0) and one large (sigma 1.4). Using N - N_0 for the second
    # species keeps `species` at exactly N entries even when N is odd.
    sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
    N_0 = N // 2
    species = np.array([0] * N_0 + [1] * (N - N_0), dtype=i32)

    # Create an energy function.
    energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
    force_fn = quantity.force(energy_fn)

    # Create a minimizer.
    init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
    opt_state = init_fn(R)

    # Minimize the system.
    minimize_steps = 50
    print_every = 10

    print('Minimizing.')
    print('Step\tEnergy\tMax Force')
    print('-----------------------------------')
    for step in range(minimize_steps):
        opt_state = apply_fn(opt_state)

        if step % print_every == 0:
            R = opt_state.position
            E = energy_fn(R)
            # Report the largest force *magnitude*. A plain max over the signed
            # components would hide large negative forces.
            max_force = np.max(np.abs(force_fn(R)))
            # `step` is an integer counter, so it is printed without decimals.
            print('{}\t{:.2f}\t{:.2f}'.format(step, E, max_force))
Ejemplo n.º 4
0
    def test_fire_descent(self, spatial_dimension, dtype):
        """FIRE descent driven by `lax.fori_loop` must lower a quadratic energy.

        Same check as a plain FIRE descent test, except that each group of
        three steps runs inside a jitted `lax.fori_loop`. The energy is the
        summed squared distance between the current positions and random
        targets. After every group, two conditions must hold:
          * the energy and the squared distance both strictly decrease;
          * both values keep `dtype`.

        Args:
          spatial_dimension: Number of coordinates per particle (presumably
            supplied by the test parameterization).
          dtype: Floating-point dtype for the positions and the results.
        """
        rng = random.PRNGKey(0)
        shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            rng, start_key, target_key = random.split(rng, 3)
            R = random.uniform(start_key, shape, dtype=dtype)
            target = random.uniform(target_key, shape, dtype=dtype)

            def energy(R, **kwargs):
                return np.sum((R - target)**2)

            _, shift_fn = space.free()
            opt_init, opt_apply = minimize.fire_descent(energy, shift_fn)

            # NOTE(schsam): Stepping through lax.fori_loop under jit checks that
            # a FireDescentState can be created inside traced code.
            def three_steps(state):
                return lax.fori_loop(0, 3, lambda _, s: opt_apply(s), state)

            three_steps = jit(three_steps)

            opt_state = opt_init(R)
            E_current = energy(R)
            dr_current = np.sum((R - target)**2)

            for _ in range(OPTIMIZATION_STEPS):
                opt_state = three_steps(opt_state)
                R = opt_state.position
                E_new = energy(R)
                dr_new = np.sum((R - target)**2)
                assert E_new < E_current
                assert E_new.dtype == dtype
                assert dr_new < dr_current
                assert dr_new.dtype == dtype
                E_current, dr_current = E_new, dr_new