def simple_spring_bond(
    displacement_or_metric, bond, bond_type=None, length=1, epsilon=1, alpha=2
):
    """Build a bonded energy function that uses `simple_spring` per bond.

    The three spring parameters are converted to float32 arrays. They are
    then forwarded as keyword arguments to `smap.bond`, together with
    `simple_spring` and the canonicalized displacement/metric function.

    Args:
      displacement_or_metric: Displacement or metric function. It is passed
        through `space.canonicalize_displacement_or_metric` before use.
      bond: Bond specification, forwarded unchanged to `smap.bond`.
      bond_type: Optional bond types, forwarded unchanged to `smap.bond`.
      length: Forwarded to `smap.bond` as the `length` keyword (float32).
      epsilon: Forwarded to `smap.bond` as the `epsilon` keyword (float32).
      alpha: Forwarded to `smap.bond` as the `alpha` keyword (float32).

    Returns:
      The function produced by `smap.bond` for these arguments.
    """
    # Cast first so parameter errors surface before the metric is canonicalized.
    spring_params = {
        name: np.array(value, f32)
        for name, value in (("length", length), ("epsilon", epsilon), ("alpha", alpha))
    }
    canonical_metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    return smap.bond(simple_spring, canonical_metric, bond, bond_type, **spring_params)
def test_bond_no_type_static(self, spatial_dimension, dtype):
    """Check that `smap.bond` with static, untyped bonds sums a harmonic energy.

    Bonds (0, 1) and (0, 2) are fixed when the mapped function is built. On
    random positions, the mapped energy must match the harmonic energy
    evaluated on each bond's metric distance and summed by hand.
    """
    def harmonic(dr, **kwargs):
        return (dr - f32(1)) ** f32(2)

    displacement, _ = space.free()
    metric = space.metric(displacement)
    static_bonds = np.array([[0, 1], [0, 2]], i32)
    mapped = smap.bond(harmonic, metric, static_bonds)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
        key, subkey = random.split(key)
        positions = random.uniform(
            subkey, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        energy_01 = harmonic(metric(positions[0], positions[1]))
        energy_02 = harmonic(metric(positions[0], positions[2]))
        self.assertAllClose(mapped(positions), dtype(energy_01 + energy_02))
def simple_spring_bond(displacement_or_metric: DisplacementOrMetricFn,
                       bond: Array,
                       bond_type: Array = None,
                       length: Array = 1,
                       epsilon: Array = 1,
                       alpha: Array = 2) -> Callable[[Array], Array]:
    """Build a bonded energy function that uses `simple_spring` per bond.

    Args:
      displacement_or_metric: Displacement or metric function. It is
        canonicalized with `space.canonicalize_displacement_or_metric`.
      bond: Bond specification, passed straight to `smap.bond`.
      bond_type: Optional bond types, passed straight to `smap.bond`.
      length: Cast to a float32 array; forwarded as `length=`.
      epsilon: Cast to a float32 array; forwarded as `epsilon=`.
      alpha: Cast to a float32 array; forwarded as `alpha=`.

    Returns:
      The callable returned by `smap.bond`.
    """
    length, epsilon, alpha = (np.array(p, f32) for p in (length, epsilon, alpha))
    metric_fn = space.canonicalize_displacement_or_metric(displacement_or_metric)
    return smap.bond(simple_spring, metric_fn, bond, bond_type,
                     length=length, epsilon=epsilon, alpha=alpha)
def bistable_spring_bond(
    displacement_or_metric, bond, bond_type=None, r0=1, a2=2, a4=5
):
    """Build a bonded energy function that uses `bistable_spring` per bond.

    Convenience wrapper to compute the energy of particles bonded by bistable
    springs. The parameters `r0`, `a2` and `a4` are cast to float32 arrays and
    forwarded by keyword to `smap.bond`.

    Args:
      displacement_or_metric: Displacement or metric function. It is
        canonicalized with `space.canonicalize_displacement_or_metric`.
      bond: Bond specification, forwarded to `smap.bond`.
      bond_type: Optional bond types, forwarded to `smap.bond`.
      r0: Forwarded as `r0=` (float32). Its physical meaning is defined by
        `bistable_spring`, which is not shown here.
      a2: Forwarded as `a2=` (float32).
      a4: Forwarded as `a4=` (float32).

    Returns:
      The callable produced by `smap.bond`.
    """
    r0, a2, a4 = (jnp.array(v, jnp.float32) for v in (r0, a2, a4))
    return smap.bond(
        bistable_spring,
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        bond,
        bond_type,
        r0=r0,
        a2=a2,
        a4=a4,
    )
def simple_spring_bond(displacement_or_metric: DisplacementOrMetricFn,
                       bond: Array,
                       bond_type: Array = None,
                       length: Array = 1,
                       epsilon: Array = 1,
                       alpha: Array = 2) -> Callable[[Array], Array]:
    """Build a bonded energy function that uses `simple_spring` per bond.

    Each spring parameter goes through `maybe_downcast` before it is forwarded
    to `smap.bond`. The call also sets `ignore_unused_parameters=True`.

    Args:
      displacement_or_metric: Displacement or metric function. It is
        canonicalized with `space.canonicalize_displacement_or_metric`.
      bond: Bond specification, forwarded to `smap.bond`.
      bond_type: Optional bond types, forwarded to `smap.bond`.
      length: Passed through `maybe_downcast`, then forwarded as `length=`.
      epsilon: Passed through `maybe_downcast`, then forwarded as `epsilon=`.
      alpha: Passed through `maybe_downcast`, then forwarded as `alpha=`.

    Returns:
      The callable produced by `smap.bond`.
    """
    spring_kwargs = dict(
        length=maybe_downcast(length),
        epsilon=maybe_downcast(epsilon),
        alpha=maybe_downcast(alpha),
    )
    metric_fn = space.canonicalize_displacement_or_metric(displacement_or_metric)
    return smap.bond(simple_spring, metric_fn, bond, bond_type,
                     ignore_unused_parameters=True, **spring_kwargs)
def test_bond_params_dynamic(self, spatial_dimension, dtype):
    """Check that per-bond parameters passed at call time override the default.

    The mapped function is built with a scalar default `sigma=1.0` and no
    static bonds. It is then called with explicit bonds and a per-bond sigma
    array of [1.0, 2.0]. The result must match the harmonic energy summed by
    hand with sigma 1 for bond (0, 1) and sigma 2 for bond (0, 2).
    """
    def harmonic(dr, sigma, **kwargs):
        return (dr - sigma) ** f32(2)

    displacement, _ = space.free()
    metric = space.metric(displacement)
    per_bond_sigma = np.array([1.0, 2.0], f32)
    mapped = smap.bond(harmonic, metric, sigma=1.0)
    dynamic_bonds = np.array([[0, 1], [0, 2]], i32)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
        key, subkey = random.split(key)
        positions = random.uniform(
            subkey, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        energy_01 = harmonic(metric(positions[0], positions[1]), 1)
        energy_02 = harmonic(metric(positions[0], positions[2]), 2)
        actual = mapped(positions, dynamic_bonds, sigma=per_bond_sigma)
        self.assertAllClose(actual, dtype(energy_01 + energy_02))