def _sw_angle_interaction(dR12, dR13, gamma=1.2, sigma=2.0951,
                          cutoff=1.8 * 2.0951):
  """Stillinger-Weber three-body (angular) term for one pair of neighbors.

  Only a single neighbor pair of a central atom is handled here; callers vmap
  this function to cover every triplet in the system.

  Args:
    dR12: d-dimensional displacement vector to the first neighbor. The
      potential is usually used in three dimensions.
    dR13: d-dimensional displacement vector to the second neighbor.
    gamma: scalar fitting parameter of the angular term.
    sigma: scalar length scale between neighbors.
    cutoff: distance at or beyond which a neighbor contributes nothing. Keep
      the default for the standard SW parametrization.

  Returns:
    Angular interaction energy of this neighbor pair. It is 0 when either
    neighbor is outside the cutoff, or when the two displacements coincide
    to within 1e-5.
  """
  reduced_cutoff = cutoff / sigma

  def _clip_to_cutoff(r):
    # Out-of-range distances become 0. This flags them for the final mask and
    # keeps gamma / (r / sigma - reduced_cutoff) finite.
    return np.where(r < cutoff, r, 0)

  r12 = _clip_to_cutoff(space.distance(dR12))
  r13 = _clip_to_cutoff(space.distance(dR13))

  exponent_12 = gamma / (r12 / sigma - reduced_cutoff)
  exponent_13 = gamma / (r13 / sigma - reduced_cutoff)
  radial = np.exp(exponent_12 + exponent_13)

  # `angle_between_two_vectors` returns the cosine of the angle.
  cos_theta = quantity.angle_between_two_vectors(dR12, dR13)
  angular = (cos_theta + 1. / 3)**2

  distinct_neighbors = np.linalg.norm(dR12 - dR13) > 1e-5
  active = (r12 > 0) & (r13 > 0) & distinct_neighbors
  return np.where(active, radial * angular, 0)
def test_periodic_displacement(self, spatial_dimension, dtype):
  """Checks periodic_displacement against a brute-force minimum-image search.

  For each random configuration in the unit box, every pairwise displacement
  is shifted by each integer offset in {-1, 0, 1}^d and the shortest image is
  kept. `space.periodic_displacement` must reproduce that result. The search
  works for any `spatial_dimension`; the previous hand-written loops only
  covered 2 and 3.
  """
  import itertools  # Local import: only this test needs it.

  key = random.PRNGKey(0)

  for _ in range(STOCHASTIC_SAMPLES):
    key, split = random.split(key)

    R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    dR = space.map_product(space.pairwise_displacement)(R, R)

    dR_wrapped = space.periodic_displacement(f32(1.0), dR)

    dR_direct = dR
    dr_direct = space.distance(dR)
    # Trailing singleton axis so the distance mask broadcasts over components.
    dr_direct = np.reshape(dr_direct, dr_direct.shape + (1,))

    # itertools.product visits offsets in the same order as nested
    # `for i in range(-1, 2)` loops over each axis.
    for offset in itertools.product(range(-1, 2), repeat=spatial_dimension):
      dR_shifted = dR + np.array(offset, dtype=R.dtype)
      dr_shifted = space.distance(dR_shifted)
      dr_shifted = np.reshape(dr_shifted, dr_shifted.shape + (1,))
      closer = dr_shifted < dr_direct
      dR_direct = np.where(closer, dR_shifted, dR_direct)
      dr_direct = np.where(closer, dr_shifted, dr_direct)

    dR_direct = np.array(dR_direct, dtype=dR.dtype)
    assert dR_wrapped.dtype == dtype
    self.assertAllClose(dR_wrapped, dR_direct, True)
def compute_fn(R, **kwargs):
  """Total Stillinger-Weber energy of positions R.

  Extra keyword arguments are forwarded to the displacement function.
  Returns epsilon * (two-body + three_body_strength * three-body).
  """
  displacement_fn = partial(displacement, **kwargs)
  pair_dR = space.map_product(displacement_fn)(R, R)
  pair_dr = space.distance(pair_dR)

  # The all-pairs matrix counts every unordered pair twice, hence the / 2.0.
  two_body = np.sum(_sw_radial_interaction(pair_dr)) / 2.0 * A
  # Presumably each neighbor pair (j, k) is visited in both orders by
  # sw_three_body_term, hence the / 2.0 here too.
  three_body = lam * np.sum(sw_three_body_term(pair_dR, pair_dR)) / 2.0

  return epsilon * (two_body + three_body_strength * three_body)
def compute_fn(R):
  """Total Gupta energy of positions R.

  Returns U_n / 2 * sum_i (repulsive_i - sqrt(attractive_i)), where each
  per-atom term sums over that atom's row of the pair-distance matrix.
  """
  dR = space.map_product(displacement)(R, R)
  dr = space.distance(dR)
  first_term = A * np.sum(_gupta_term1(dr, p, r_0n, cutoff), axis=1)

  # sqrt has an infinite derivative at 0. If an atom's attractive sum is
  # exactly 0 (e.g. no neighbors inside the cutoff), a plain sqrt gives nan
  # forces. Double-where: feed sqrt a masked operand and select 0 outside the
  # mask, so neither the value nor the gradient becomes nan.
  attractive_term = np.sum(_gupta_term2(dr, q, r_0n, cutoff), axis=1)
  is_positive = attractive_term > 0
  safe_attractive = np.where(is_positive, attractive_term, 0.)
  second_term = np.where(is_positive, np.sqrt(safe_attractive), 0.)

  return U_n / 2.0 * np.sum(first_term - second_term)
def energy(R, **kwargs):
  """Embedded-atom-method energy of positions R.

  Per atom: embedding_fn(sum of charge_fn over its row) plus half of its
  diagonal-masked pairwise_fn row. The per-atom energies are then summed with
  `smap._high_precision_sum(..., axis=axis)`.

  NOTE(review): charge_fn(dr) is not diagonal-masked, unlike pairwise_fn, so
  any nonzero charge_fn(0) self-term enters the density. Confirm that this is
  intended.
  """
  pair_dr = space.distance(displacement(R, R, **kwargs))

  density = smap._high_precision_sum(charge_fn(pair_dr), axis=1)
  embedding = embedding_fn(density)

  # Each unordered pair appears twice in the matrix; halve it.
  masked_pairs = smap._diagonal_mask(pairwise_fn(pair_dr))
  pair_half = smap._high_precision_sum(masked_pairs, axis=1) / f32(2.0)

  per_atom = embedding + pair_half
  return smap._high_precision_sum(per_atom, axis=axis)
def exact_stress(R):
  """Reference virial stress for a soft-sphere system, computed pair by pair.

  Returns -sum_{ij} 0.5 * U'(r_ij) * dR_ij (x) dR_ij / (V * r_ij) as a
  (dim, dim) array.
  """
  pair_dR = space.map_product(displacement_fn)(R, R)
  pair_dr = space.distance(pair_dR)

  dU_dr = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
  volume = quantity.volume(dim, box)

  half_dUdr = 0.5 * dU_dr(pair_dr)[:, :, None, None]
  # Add the identity so self-pairs (distance 0) do not divide by zero. Their
  # displacement is presumably zero, so they contribute nothing anyway.
  safe_dr = (pair_dr + jnp.eye(N))[:, :, None, None]

  contribution = (half_dUdr * pair_dR[:, :, None, :] * pair_dR[:, :, :, None]
                  / (volume * safe_dr))
  return -jnp.sum(contribution, axis=(0, 1))
def single_pair_angular_symmetry_function(dR12, dR13, eta, lam, zeta,
                                          cutoff_distance):
  """Computes the angular symmetry function due to one pair of neighbors.

  Value: 2^(1 - zeta) * (1 + lam * cos(theta))^zeta
         * exp(-eta * (r12^2 + r13^2 + r23^2)) * fc(r12) * fc(r13) * fc(r23),
  where fc is `_behler_parrinello_cutoff_fn` with `cutoff_distance`.
  """
  dR23 = dR12 - dR13

  sq_12 = space.square_distance(dR12)
  sq_13 = space.square_distance(dR13)
  sq_23 = space.square_distance(dR23)

  r_12 = space.distance(dR12)
  r_13 = space.distance(dR13)
  r_23 = space.distance(dR23)

  sum_of_squares = sq_12 + sq_13 + sq_23

  def fc(r):
    return _behler_parrinello_cutoff_fn(r, cutoff_distance)

  # Left-to-right product starting from 1.0, matching a fold over the three
  # distances.
  smooth_cutoff = 1.0 * fc(r_12) * fc(r_13) * fc(r_23)

  prefactor = 2.0 ** (1.0 - zeta)
  angular = (1.0 + lam * quantity.angle_between_two_vectors(dR12, dR13)) ** zeta
  radial = np.exp(-eta * sum_of_squares)
  return prefactor * angular * radial * smooth_cutoff
def compute_fn(R):
  """Total Gupta energy of positions R.

  Returns U_n / 2 * sum_i (repulsive_i - sqrt(attractive_i)), with a masked
  square root.
  """
  pairwise_dR = space.map_product(displacement)(R, R)
  pairwise_dr = space.distance(pairwise_dR)

  repulsive = A * np.sum(_gupta_term1(pairwise_dr, p, r_0n, cutoff), axis=1)

  # Masked sqrt keeps forces finite when particles are so widely separated
  # (e.g. at initialization) that an atom's attractive sum is exactly 0.
  attractive_sum = np.sum(_gupta_term2(pairwise_dr, q, r_0n, cutoff), axis=1)
  attractive = util.safe_mask(attractive_sum > 0, np.sqrt, attractive_sum)

  return U_n / 2.0 * np.sum(repulsive - attractive)
def get_dir_cos(dist_vec):
    """Calculate directional cosines from distance vectors.

    Each vector is divided by its length. Zero-length vectors map to zero
    instead of producing inf/nan, in both the value and the gradient.

    Args:
        dist_vec: distance vectors between particles. The last axis holds the
            Cartesian components, and any number of components is accepted.

    Returns:
        dir_cos: array with the same shape as ``dist_vec`` holding the
        directional cosines.
    """
    norm = distance(dist_vec)
    is_zero = jnp.linalg.norm(dist_vec, axis=-1) == 0
    # Double-where: never evaluate 1 / 0, so reverse-mode autodiff does not
    # produce 0 * inf = nan for zero-length vectors.
    safe_norm = jnp.where(is_zero, 1.0, norm)
    inv_norm = jnp.where(is_zero, 0.0, 1 / safe_norm)
    # Broadcast over the component axis instead of hard-coding 3 components.
    return dist_vec * inv_norm[..., None]
def test_graph_network_learning(self, spatial_dimension, dtype):
  """A few Adam steps on a graph network must cut a synthetic loss by >= 5%.

  Targets are sum((r_ij - dr0_ij)^2) for six random 3-particle systems.
  """
  key = random.PRNGKey(0)
  R_key, dr0_key, params_key = random.split(key, 3)

  d, _ = space.free()

  R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
  dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)

  def target_energy(positions, reference_dr):
    pair_dr = space.distance(space.map_product(d)(positions, positions))
    return np.sum((pair_dr - reference_dr) ** 2)

  E_gt = vmap(target_energy)

  cutoff = 0.2
  init_fn, energy_fn = energy.graph_network(d, cutoff)
  params = init_fn(params_key, R[0])

  batched_energy = vmap(energy_fn, (None, 0))

  @jit
  def loss(params, R):
    residual = batched_energy(params, R) - E_gt(R, dr0)
    return np.mean(residual ** 2)

  opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))

  @jit
  def update(params, opt_state, R):
    gradients = grad(loss)(params, R)
    updates, opt_state = opt.update(gradients, opt_state)
    return optax.apply_updates(params, updates), opt_state

  opt_state = opt.init(params)

  initial_loss = loss(params, R)
  for _ in range(4):
    params, opt_state = update(params, opt_state, R)

  assert loss(params, R) < initial_loss * 0.95
def angle_between_two_vectors(dR_12, dR_13):
  """Cosine of the angle between two displacement vectors.

  Despite the name, this returns cos(theta), clipped to [-1, 1], not the
  angle itself. 1e-7 is added to each length so zero-length vectors do not
  divide by zero; the cosine is then 0.

  The last axis is the vector axis. Leading batch axes broadcast, because the
  dot product is taken as an elementwise product summed over the last axis
  rather than `np.dot`. `np.dot` would turn into a matrix product for 2-D
  inputs and may use reduced default matmul precision on accelerators.
  """
  dr_12 = space.distance(dR_12) + 1e-7
  dr_13 = space.distance(dR_13) + 1e-7
  cos_angle = np.sum(dR_12 * dR_13, axis=-1) / dr_12 / dr_13
  return np.clip(cos_angle, -1.0, 1.0)
def truncated_square(dR, sigma):
  """Componentwise square of dR, zeroed for vectors with length >= sigma."""
  # The length gets a trailing singleton axis so the mask applies to every
  # component of the vector.
  column_shape = dR.shape[:-1] + (1,)
  lengths = np.reshape(space.distance(dR), column_shape)
  inside = lengths < sigma
  return np.where(inside, dR**2, f32(0.))
def angle_between_two_vectors(dR_12: Array, dR_13: Array) -> Array:
  """Cosine of the angle between two displacement vectors.

  Despite the name, this returns cos(theta), clipped to [-1, 1], not the
  angle itself. 1e-7 is added to each length so zero-length vectors do not
  divide by zero; the cosine is then 0.

  The last axis is the vector axis. Leading batch axes broadcast, because the
  dot product is taken as an elementwise product summed over the last axis
  rather than `jnp.dot`. `jnp.dot` would turn into a matrix product for 2-D
  inputs and may use reduced default matmul precision on accelerators.
  """
  dr_12 = space.distance(dR_12) + 1e-7
  dr_13 = space.distance(dR_13) + 1e-7
  cos_angle = jnp.sum(dR_12 * dR_13, axis=-1) / dr_12 / dr_13
  return jnp.clip(cos_angle, -1.0, 1.0)
def truncated_square(dR, sigma):
  """Squared length of dR where it is shorter than sigma, 0 elsewhere."""
  length = space.distance(dR)
  squared_length = length ** 2
  return np.where(sigma > length, squared_length, f32(0.))
def create_hamiltonian_wo_k(positions, species, shifts, kwargs, kwargs_diag, kwargs_overlap):
    """Build per-image Hamiltonian and overlap matrices (no k-point dependence).

    Hopping and overlap blocks are evaluated for every
    (particle, particle, image) entry with ``get_hop_int``. They are then laid
    out as one square matrix per periodic image.

    Args:
        positions: particle position matrix (2D).
        species: array of species, one entry per particle.
        shifts: 2D matrix of image shifts, as produced by the
            ``compute_shifts`` function.
        kwargs: dict of 2D matrices of Slater-Koster (off-site) parameters.
        kwargs_diag: dict of 2D matrices of on-site parameters.
        kwargs_overlap: dict of 2D matrices of off-site overlap parameters.

    Returns:
        ``(hamiltonian, overlap_matrix)``, each of shape
        ``(n_particles * n_orbitals, n_particles * n_orbitals, n_images)``,
        with ``n_particles = species.shape[0]`` and
        ``n_images = shifts.shape[0]``. ``n_orbitals`` is read from the
        blocks returned by ``get_hop_int`` (4 for an sp basis, 9 for spd).

    Note:
        No on-site identity contribution is added to ``overlap_matrix``.
    """
    n_particles = species.shape[0]
    n_images = shifts.shape[0]
    create_bondmatrix_mask = bondmatrix_masking(cutoff)
    # get_hop_int applied independently to every (particle, particle, image) entry.
    hop_blocks = vmap(vmap(vmap(get_hop_int, 0), 0), 0)

    shifted_positions = vmap(shift_fn, (0, None))(positions, shifts)
    shifted_pair_distance_vectors = vmap(
        vmap(vmap(dist_vec, (0, None), 0), (1, None), 1), (None, 0), 0
    )(shifted_positions, positions)

    # Give every shifted copy its species, then flatten the first dimension so
    # the cartesian product with `species` lines up with the pair axes.
    shifted_species = jnp.repeat(jnp.expand_dims(species, axis=0), n_images, axis=0).T
    shifted_species = shifted_species.reshape((shifted_species.shape[0] * shifted_species.shape[1],))
    shifted_species = cartesian_prod(shifted_species, species).T
    # Species of each pair member, shaped like the pair axes of the distance vectors.
    pair_shape = shifted_pair_distance_vectors.shape[0:-1]
    species_a = shifted_species[0].reshape(pair_shape)
    species_b = shifted_species[1].reshape(pair_shape)

    # Same shape as shifted_pair_distance_vectors.
    dir_cos = get_dir_cos(shifted_pair_distance_vectors)
    # Presumably shaped like pair_shape (vector axis reduced); TODO confirm.
    pair_distances = distance(shifted_pair_distance_vectors)
    # Zeroes parameters of pairs rejected by the cutoff-based bond mask.
    bond_mask = jnp.expand_dims(create_bondmatrix_mask(pair_distances), axis=-1)

    # Off-site hopping.
    param_vec = get_params(pair_distances, species_a, species_b, kwargs)
    param_vec *= bond_mask
    hamiltonian = hop_blocks(jnp.concatenate([param_vec, dir_cos], axis=-1))

    # On-site terms.
    hamiltonian += get_params_diag(pair_distances, species_a, kwargs_diag)

    # Off-site overlap.
    overlap_vec = get_params_overlap(pair_distances, species_a, species_b, kwargs_overlap)
    overlap_vec *= bond_mask
    overlap_matrix = hop_blocks(jnp.concatenate([overlap_vec, dir_cos], axis=-1))

    # (n, n, images, orb, orb) -> (n, orb, n, orb, images) -> (n*orb, n*orb, images)
    n_orbitals = hamiltonian.shape[-1]
    matrix_shape = (n_particles * n_orbitals, n_particles * n_orbitals, n_images)
    hamiltonian = jnp.reshape(jnp.transpose(hamiltonian, (0, 3, 1, 4, 2)), matrix_shape)
    overlap_matrix = jnp.reshape(jnp.transpose(overlap_matrix, (0, 3, 1, 4, 2)), matrix_shape)
    return hamiltonian, overlap_matrix