def run_minimization_while(energy_fn, R_init, shift, max_grad_thresh = 1e-12, max_num_steps=1000000, **kwargs):
  """Run FIRE descent inside a `lax.while_loop` until forces are small.

  The loop keeps stepping while the largest absolute force component of the
  current FIRE state exceeds `max_grad_thresh` AND fewer than `max_num_steps`
  steps have been taken.

  Args:
    energy_fn: Energy function; it is jit-compiled and handed to
      `minimize.fire_descent`.
    R_init: Initial positions passed to the FIRE init function.
    shift: Shift function forwarded to `minimize.fire_descent`.
    max_grad_thresh: Stop once `max(|force|)` is no longer above this value.
    max_num_steps: Hard cap on the number of FIRE steps.
    **kwargs: Extra keyword arguments forwarded to `minimize.fire_descent`.

  Returns:
    A tuple `(positions, max_abs_force, num_steps)` for the final state.
  """
  fire_init, fire_step = minimize.fire_descent(jit(energy_fn), shift, **kwargs)
  fire_step = jit(fire_step)

  def _max_abs_force(opt_state):
    # Largest force component by magnitude.
    return jnp.amax(jnp.abs(opt_state.force))
  max_abs_force = jit(_max_abs_force)

  def _keep_going(carry):
    opt_state, step = carry
    not_converged = max_abs_force(opt_state) > max_grad_thresh
    under_budget = step < max_num_steps
    return jnp.logical_and(not_converged, under_budget)

  def _advance(carry):
    opt_state, step = carry
    return fire_step(opt_state), step + 1

  start_state = fire_init(R_init)
  final_state, steps_taken = lax.while_loop(
      jit(_keep_going), jit(_advance), (start_state, 0))
  return final_state.position, max_abs_force(final_state), steps_taken
def test_fire_descent(self, spatial_dimension, dtype):
  """FIRE on a quadratic bowl must strictly decrease energy every three steps.

  For each stochastic sample, random start positions `R` and a random target
  `R0` are drawn. The energy is `sum((R - R0)**2)`. After every block of three
  jitted FIRE steps, the test asserts that both the energy and the squared
  distance to `R0` strictly decrease and keep the requested dtype.
  """
  key = random.PRNGKey(0)

  for _ in range(STOCHASTIC_SAMPLES):
    key, pos_key, target_key = random.split(key, 3)
    shape = (PARTICLE_COUNT, spatial_dimension)
    R = random.uniform(pos_key, shape, dtype=dtype)
    R0 = random.uniform(target_key, shape, dtype=dtype)

    def energy(R, **kwargs):
      return np.sum((R - R0)**2)

    _, shift_fn = space.free()
    opt_init, opt_apply = minimize.fire_descent(energy, shift_fn)
    opt_state = opt_init(R)

    prev_energy = energy(R)
    prev_sq_dist = np.sum((R - R0)**2)

    @jit
    def three_steps(state):
      # Unrolled at trace time: identical to opt_apply(opt_apply(opt_apply(s))).
      for _ in range(3):
        state = opt_apply(state)
      return state

    for _ in range(OPTIMIZATION_STEPS):
      opt_state = three_steps(opt_state)
      positions = opt_state.position
      cur_energy = energy(positions)
      cur_sq_dist = np.sum((positions - R0)**2)

      assert cur_energy < prev_energy
      assert cur_energy.dtype == dtype
      assert cur_sq_dist < prev_sq_dist
      assert cur_sq_dist.dtype == dtype

      prev_energy, prev_sq_dist = cur_energy, cur_sq_dist
def main(unused_argv):
  """Minimize a 2D periodic 50:50 soft-sphere mixture with FIRE.

  Builds `N` randomly placed particles in a periodic box, with two species
  interacting through `energy.soft_sphere_pair`. It then runs
  `minimize_steps` FIRE steps, printing the step, energy and largest force
  magnitude every `print_every` steps.

  Args:
    unused_argv: Ignored command-line arguments.
  """
  key = random.PRNGKey(0)

  # Setup some variables describing the system.
  N = 500
  dimension = 2
  box_size = f32(25.0)

  # Create helper functions to define a periodic box of some size.
  displacement, shift = space.periodic(box_size)

  # Use JAX's random number generator to generate random initial positions.
  key, split = random.split(key)
  R = random.uniform(
      split, (N, dimension), minval=0.0, maxval=box_size, dtype=f32)

  # The system ought to be a 50:50 mixture of two types of particles, one
  # large and one small.
  sigma = np.array([[1.0, 1.2], [1.2, 1.4]], dtype=f32)
  N_2 = N // 2
  # Give any remainder to species 1 so `species` always has exactly N
  # entries, matching the number of rows in R even when N is odd.
  species = np.array([0] * N_2 + [1] * (N - N_2), dtype=i32)

  # Create an energy function.
  energy_fn = energy.soft_sphere_pair(displacement, species, sigma)
  force_fn = quantity.force(energy_fn)

  # Create a minimizer.
  init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
  opt_state = init_fn(R)

  # Minimize the system.
  minimize_steps = 50
  print_every = 10

  print('Minimizing.')
  print('Step\tEnergy\tMax Force')
  print('-----------------------------------')
  for step in range(minimize_steps):
    opt_state = apply_fn(opt_state)

    if step % print_every == 0:
      R = opt_state.position
      # Report the largest force component by magnitude. A plain max over
      # the signed components would ignore large negative forces.
      max_force = np.max(np.abs(force_fn(R)))
      print('{:d}\t{:.2f}\t{:.2f}'.format(step, energy_fn(R), max_force))
def test_fire_descent(self, spatial_dimension, dtype):
  """FIRE stepped via `lax.fori_loop` under `jit` must decrease a quadratic.

  For each stochastic sample, random positions `R` and a random target `R0`
  are drawn. The energy is `sum((R - R0)**2)`. After every jitted block of
  three FIRE steps, the test asserts that both the energy and the squared
  distance to `R0` strictly decrease and keep the requested dtype.
  """
  key = random.PRNGKey(0)

  for _ in range(STOCHASTIC_SAMPLES):
    key, pos_key, target_key = random.split(key, 3)
    shape = (PARTICLE_COUNT, spatial_dimension)
    R = random.uniform(pos_key, shape, dtype=dtype)
    R0 = random.uniform(target_key, shape, dtype=dtype)

    def energy(R, **kwargs):
      return np.sum((R - R0)**2)

    _, shift_fn = space.free()
    opt_init, opt_apply = minimize.fire_descent(energy, shift_fn)
    opt_state = opt_init(R)

    last_energy = energy(R)
    last_sq_dist = np.sum((R - R0)**2)

    # NOTE(schsam): Running the steps inside lax.fori_loop under jit checks
    # that we can jit through the creation of FireDescentState.
    def one_step(i, state):
      return opt_apply(state)

    @jit
    def three_steps(state):
      return lax.fori_loop(0, 3, one_step, state)

    for _ in range(OPTIMIZATION_STEPS):
      opt_state = three_steps(opt_state)
      positions = opt_state.position
      energy_now = energy(positions)
      sq_dist_now = np.sum((positions - R0)**2)

      assert energy_now < last_energy
      assert energy_now.dtype == dtype
      assert sq_dist_now < last_sq_dist
      assert sq_dist_now.dtype == dtype

      last_energy, last_sq_dist = energy_now, sq_dist_now