Esempio n. 1
0
def _sw_angle_interaction(dR12,
                          dR13,
                          gamma=1.2,
                          sigma=2.0951,
                          cutoff=1.8 * 2.0951):
    """Three-body (angular) term of the Stillinger-Weber potential.

    The function handles a single pair of neighbors of one central atom.
    Callers are expected to vectorize it (e.g. with vmap) over all neighbor
    pairs of all atoms.

    Args:
      dR12: d-dimensional displacement vector to the first neighbor. The
        potential is normally used in three dimensions.
      dR13: d-dimensional displacement vector to the second neighbor.
      gamma: Scalar fitting parameter of the angular interaction.
      sigma: Scalar length scale of the interaction.
      cutoff: Distance at or beyond which a neighbor does not contribute.
        Keep the default for the standard SW parameterization.

    Returns:
      The angular energy contributed by this pair of neighbors. It is 0 when
      either neighbor is at or beyond the cutoff (or at zero distance), or
      when the two displacement vectors (nearly) coincide.
    """
    reduced_cutoff = cutoff / sigma

    # Out-of-range neighbors are mapped to distance 0. This keeps the
    # exponent finite, and the mask at the end removes them anyway.
    r_first = space.distance(dR12)
    r_first = np.where(r_first < cutoff, r_first, 0)
    r_second = space.distance(dR13)
    r_second = np.where(r_second < cutoff, r_second, 0)

    exponent = (gamma / (r_first / sigma - reduced_cutoff)
                + gamma / (r_second / sigma - reduced_cutoff))
    radial_factor = np.exp(exponent)

    # angle_between_two_vectors yields the (clipped) cosine of the angle.
    # (cos + 1/3)^2 vanishes when cos == -1/3, i.e. at the tetrahedral angle.
    cosine = quantity.angle_between_two_vectors(dR12, dR13)
    angular_factor = (cosine + 1.0 / 3) ** 2

    both_in_range = (r_first > 0) & (r_second > 0)
    # Skip degenerate pairs in which both "neighbors" are the same atom.
    distinct_neighbors = np.linalg.norm(dR12 - dR13) > 1e-5
    return np.where(both_in_range & distinct_neighbors,
                    radial_factor * angular_factor, 0)
Esempio n. 2
0
    def test_periodic_displacement(self, spatial_dimension, dtype):
        """Compare space.periodic_displacement with a brute-force image search.

        Random points are drawn in the unit box. Each wrapped pairwise
        displacement must equal the shortest of the raw displacement shifted
        by every integer lattice vector in {-1, 0, 1}**spatial_dimension.
        """
        import itertools

        def _distance_column(vectors):
            # Distances with a trailing singleton axis so they broadcast
            # against the displacement vectors in np.where.
            dr = space.distance(vectors)
            return np.reshape(dr, dr.shape + (1, ))

        key = random.PRNGKey(0)
        # itertools.product enumerates the shifts in the same order as the
        # equivalent nested loops, so ties resolve identically. Unlike
        # hand-written per-dimension loops, it works for any dimension.
        lattice_shifts = list(
            itertools.product(range(-1, 2), repeat=spatial_dimension))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)

            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            dR = space.map_product(space.pairwise_displacement)(R, R)

            dR_wrapped = space.periodic_displacement(f32(1.0), dR)

            # Brute-force minimum image: keep the shortest shifted
            # displacement seen so far.
            dR_direct = dR
            dr_direct = _distance_column(dR)

            for shift in lattice_shifts:
                dR_shifted = dR + np.array(shift, dtype=R.dtype)
                dr_shifted = _distance_column(dR_shifted)

                closer = dr_shifted < dr_direct
                dR_direct = np.where(closer, dR_shifted, dR_direct)
                dr_direct = np.where(closer, dr_shifted, dr_direct)

            dR_direct = np.array(dR_direct, dtype=dR.dtype)
            assert dR_wrapped.dtype == dtype
            # check_dtypes is keyword-only in current JAX test utilities.
            self.assertAllClose(dR_wrapped, dR_direct, check_dtypes=True)
Esempio n. 3
0
 def compute_fn(R, **kwargs):
     """Total Stillinger-Weber energy of configuration R.

     The two-body (radial) and three-body (angular) contributions are each
     summed over all pairs produced by map_product and halved. They are then
     combined and scaled by epsilon. Keyword arguments are forwarded to the
     displacement function.
     """
     displacement_fn = partial(displacement, **kwargs)
     pair_vectors = space.map_product(displacement_fn)(R, R)
     pair_distances = space.distance(pair_vectors)

     two_body = np.sum(_sw_radial_interaction(pair_distances)) / 2.0 * A
     three_body = lam * np.sum(
         sw_three_body_term(pair_vectors, pair_vectors)) / 2.0
     return epsilon * (two_body + three_body_strength * three_body)
Esempio n. 4
0
 def compute_fn(R):
     """Total Gupta energy of configuration R.

     For each particle, the repulsive sum A * sum_j term1 is reduced by the
     square root of the attractive sum sum_j term2. The per-particle values
     are summed and scaled by U_n / 2.
     """
     dR = space.map_product(displacement)(R, R)
     dr = space.distance(dR)
     first_term = A * np.sum(_gupta_term1(dr, p, r_0n, cutoff), axis=1)
     # np.sqrt has an unbounded derivative at 0. When a particle's attractive
     # sum is 0 (e.g. particles too widely separated at initialization), a
     # plain sqrt makes the forces NaN. Feed sqrt a harmless placeholder
     # (1.0) at those entries and mask its output to 0. Both the value and
     # the gradient then stay finite (the "double where" pattern).
     attractive_term = np.sum(_gupta_term2(dr, q, r_0n, cutoff), axis=1)
     positive = attractive_term > 0
     safe_operand = np.where(positive, attractive_term, 1.0)
     second_term = np.where(positive, np.sqrt(safe_operand), 0.0)
     return U_n / 2.0 * np.sum(first_term - second_term)
Esempio n. 5
0
 def energy(R, **kwargs):
     """Embedded-atom-method energy of configuration R.

     Each particle's total "charge" is the sum of charge_fn over its pairwise
     distances, and embedding_fn is applied to it. The diagonal-masked
     pairwise_fn sum is halved and added. The result is reduced over `axis`.
     Keyword arguments are forwarded to the displacement function.
     """
     distances = space.distance(displacement(R, R, **kwargs))

     # NOTE(review): unlike the pairwise term, charge_fn(distances) is not
     # diagonal-masked. If displacement(R, R) includes i == j entries, each
     # particle adds charge_fn(0) to its own total charge. Confirm that
     # charge_fn vanishes at 0 or that this is intended.
     density = smap._high_precision_sum(charge_fn(distances), axis=1)
     embedding = embedding_fn(density)

     pair_energy = smap._high_precision_sum(
         smap._diagonal_mask(pairwise_fn(distances)), axis=1) / f32(2.0)
     return smap._high_precision_sum(embedding + pair_energy, axis=axis)
Esempio n. 6
0
 def exact_stress(R):
   """Pairwise virial stress of a soft-sphere system.

   Computes
     sigma_ab = -sum_{i,j} 0.5 * U'(r_ij) * dR_ij[a] * dR_ij[b] / (V * r_ij)
   with U = energy.soft_sphere and V = quantity.volume(dim, box).
   """
   # Derive the particle count from R rather than relying on an outer N.
   n_particles = R.shape[0]
   dR = space.map_product(displacement_fn)(R, R)
   dr = space.distance(dR)
   g = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
   V = quantity.volume(dim, box)
   # The factor 0.5 compensates for visiting every pair twice in the N x N sum.
   dUdr = 0.5 * g(dr)[:, :, None, None]
   # Add 1 on the diagonal (r_ii == 0) to avoid dividing by zero. Since
   # dR_ii == 0, the diagonal contributes nothing (provided U'(0) is finite).
   dr = (dr + jnp.eye(n_particles))[:, :, None, None]
   return -jnp.sum(dUdr * dR[:, :, None, :] * dR[:, :, :, None] / (V * dr),
                   axis=(0, 1))
Esempio n. 7
0
def single_pair_angular_symmetry_function(dR12, dR13, eta, lam, zeta,
                                          cutoff_distance):
    """Computes the angular symmetry function due to one pair of neighbors.

    With dR23 = dR12 - dR13, this returns

        2**(1 - zeta) * (1 + lam * c)**zeta
            * exp(-eta * (r12**2 + r13**2 + r23**2))
            * fc(r12) * fc(r13) * fc(r23)

    where c = quantity.angle_between_two_vectors(dR12, dR13) and
    fc(r) = _behler_parrinello_cutoff_fn(r, cutoff_distance).
    """
    dR23 = dR12 - dR13
    edges = (dR12, dR13, dR23)

    # Gaussian factor in the summed squared edge lengths of the triplet.
    sq12, sq13, sq23 = (space.square_distance(edge) for edge in edges)
    gaussian = np.exp(-eta * (sq12 + sq13 + sq23))

    # Product of the smooth cutoff over all three edges of the triplet.
    cutoff_product = 1.0
    for edge in edges:
        cutoff_product = cutoff_product * _behler_parrinello_cutoff_fn(
            space.distance(edge), cutoff_distance)

    angular = (1.0 + lam * quantity.angle_between_two_vectors(dR12, dR13)) ** zeta
    return 2.0 ** (1.0 - zeta) * angular * gaussian * cutoff_product
Esempio n. 8
0
 def compute_fn(R):
     """Total Gupta energy of configuration R, with a NaN-safe square root.

     Per particle: A * sum_j term1 minus sqrt(sum_j term2). The result is
     summed over particles and scaled by U_n / 2.
     """
     pair_distances = space.distance(space.map_product(displacement)(R, R))
     repulsive = A * np.sum(_gupta_term1(pair_distances, p, r_0n, cutoff),
                            axis=1)
     # The derivative of sqrt blows up at 0. When the attractive sum of a
     # particle is 0 (particles too far apart at initialization), a plain
     # sqrt would make the forces NaN. safe_mask applies sqrt only where the
     # sum is positive.
     attractive_sum = np.sum(_gupta_term2(pair_distances, q, r_0n, cutoff),
                             axis=1)
     attractive = util.safe_mask(attractive_sum > 0, np.sqrt, attractive_sum)
     return U_n / 2.0 * np.sum(repulsive - attractive)
Esempio n. 9
0
def get_dir_cos(dist_vec):
    """Calculate directional cosines from distance vectors.

    Each vector is divided by its length, which gives its cosines with
    respect to the standard Cartesian axes. Zero-length vectors map to
    all-zero cosines instead of inf/NaN, in both the value and the gradient.

    Args:
        dist_vec: distance vectors between particles, with the Cartesian
            components on the last axis (any number of components).
    Returns:
        dir_cos: array shaped like ``dist_vec`` holding the directional
        cosines.
    """
    norm = distance(dist_vec)
    is_zero = jnp.linalg.norm(dist_vec, axis=-1) == 0
    # Double-where: divide by a dummy 1 where the length is zero. This keeps
    # inf out of both the forward value and the gradient, then those entries
    # are zeroed.
    safe_norm = jnp.where(is_zero, 1.0, norm)
    inv_norm = jnp.where(is_zero, 0.0, 1.0 / safe_norm)
    # Broadcast over the component axis instead of repeating a hard-coded 3,
    # so 2-D (or any-D) vectors are supported too.
    dir_cos = dist_vec * inv_norm[..., None]
    return dir_cos
Esempio n. 10
0
  def test_graph_network_learning(self, spatial_dimension, dtype):
    """A few optimizer steps on a graph network must reduce its fitting loss."""
    key = random.PRNGKey(0)
    R_key, dr0_key, params_key = random.split(key, 3)

    d, _ = space.free()

    # 6 random configurations of 3 particles, each paired with a random 3x3
    # matrix of target pairwise distances.
    R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
    dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)

    def target_energy(R_single, dr0_single):
      pair_dr = space.distance(space.map_product(d)(R_single, R_single))
      return np.sum((pair_dr - dr0_single) ** 2)

    E_gt = vmap(target_energy)

    cutoff = 0.2
    init_fn, energy_fn = energy.graph_network(d, cutoff)
    params = init_fn(params_key, R[0])

    # Same parameters, batch of configurations.
    batched_energy = vmap(energy_fn, (None, 0))

    @jit
    def loss(params, R):
      return np.mean((batched_energy(params, R) - E_gt(R, dr0)) ** 2)

    opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))

    @jit
    def update(params, opt_state, R):
      gradients = grad(loss)(params, R)
      updates, opt_state = opt.update(gradients, opt_state)
      return optax.apply_updates(params, updates), opt_state

    opt_state = opt.init(params)

    initial_loss = loss(params, R)
    for _ in range(4):
      params, opt_state = update(params, opt_state, R)

    assert loss(params, R) < initial_loss * 0.95
Esempio n. 11
0
def angle_between_two_vectors(dR_12, dR_13):
  """Return the cosine of the angle between dR_12 and dR_13 (not the angle).

  1e-7 is added to both lengths so that zero-length vectors do not divide by
  zero. The result is clipped to [-1, 1] to absorb rounding error.
  """
  len_12 = space.distance(dR_12) + 1e-7
  len_13 = space.distance(dR_13) + 1e-7
  projection = np.dot(dR_12, dR_13)
  return np.clip(projection / len_12 / len_13, -1.0, 1.0)
Esempio n. 12
0
 def truncated_square(dR, sigma):
     """Element-wise square of dR, zeroed for vectors of length >= sigma."""
     # One length per vector, with a trailing axis to broadcast over components.
     inside = np.reshape(space.distance(dR), dR.shape[:-1] + (1, )) < sigma
     return np.where(inside, dR**2, f32(0.))
Esempio n. 13
0
def angle_between_two_vectors(dR_12: Array, dR_13: Array) -> Array:
    """Cosine (not the angle itself) between two displacement vectors.

    A small epsilon is added to each length to avoid division by zero. The
    result is clipped to [-1, 1].
    """
    eps = 1e-7
    inner = jnp.dot(dR_12, dR_13)
    cosine = inner / (space.distance(dR_12) + eps) / (space.distance(dR_13) + eps)
    return jnp.clip(cosine, -1.0, 1.0)
Esempio n. 14
0
 def truncated_square(dR, sigma):
   """Squared length of dR where the length is below sigma, otherwise 0."""
   length = space.distance(dR)
   inside = length < sigma
   return np.where(inside, length ** 2, f32(0.))
Esempio n. 15
0
    def create_hamiltonian_wo_k(positions, species, shifts, kwargs, kwargs_diag, kwargs_overlap):
        """Build the k-independent Hamiltonian and overlap matrices per periodic image.

        Args:
            positions: particle position matrix (2D).
            species: array of species, one entry per particle.
            shifts: 2D shifts matrix from the compute_shifts function
                (one row per periodic image).
            kwargs: dict of 2D matrices of Slater-Koster parameters.
            kwargs_diag: dict of 2D matrices of on-site parameters.
            kwargs_overlap: dict of 2D matrices of off-site overlap parameters.
        Returns:
            (hamiltonian, overlap_matrix), each reshaped to
            (N * n_orbitals, N * n_orbitals, N_images) with N = species.shape[0].
        """
        n_orbitals = 9  # 4 for sp, 9 for spd
        create_bondmatrix_mask = bondmatrix_masking(cutoff)
        shifted_positions = vmap(shift_fn, (0, None))(positions, shifts)
        shifted_pair_distance_vectors = (vmap(vmap(vmap(dist_vec, (0, None), 0), (1, None), 1), (None, 0), 0)
                                         (shifted_positions, positions))

        # expand species shape to be the same as the shifted coordinates
        shifted_species = jnp.repeat(jnp.expand_dims(species, axis=0), shifts.shape[0], axis=0).T
        # flatten first dimension for cartesian product
        shifted_species = shifted_species.reshape((shifted_species.shape[0] * shifted_species.shape[1],))
        shifted_species = cartesian_prod(shifted_species, species).T
        # separate into two vectors for particle pairs a,b and reshape to (particle number, particle_number, N_images)
        pair_grid_shape = shifted_pair_distance_vectors.shape[0:-1]
        species_a = shifted_species[0].reshape(pair_grid_shape)
        species_b = shifted_species[1].reshape(pair_grid_shape)
        dir_cos = get_dir_cos(shifted_pair_distance_vectors)  # (particle number, particle number, N_images, 3)

        # distance() reduces the vector axis: (particle number, particle number, N_images)
        pair_distances = distance(shifted_pair_distance_vectors)
        bondmatrix = create_bondmatrix_mask(pair_distances)
        # trailing axis so the bond mask broadcasts over the parameter axis
        bond_mask = jnp.expand_dims(bondmatrix, axis=-1)

        # get_hop_int applied to every (pair, image) entry of [params, dir_cos]
        hop_int_grid = vmap(vmap(vmap(get_hop_int, 0), 0), 0)

        def to_orbital_matrix(blocks):
            # (N, N, N_images, n_orb, n_orb) blocks, as implied by the transpose,
            # -> (N * n_orb, N * n_orb, N_images)
            return jnp.reshape(jnp.transpose(blocks, (0, 3, 1, 4, 2)),
                               (species.shape[0] * n_orbitals, species.shape[0] * n_orbitals, shifts.shape[0]))

        # off-site hopping
        param_vec = get_params(pair_distances, species_a, species_b, kwargs)
        param_vec *= bond_mask
        hamiltonian = hop_int_grid(jnp.concatenate([param_vec, dir_cos], axis=-1))

        # onsite
        param_diag = get_params_diag(pair_distances, species_a, kwargs_diag)
        hamiltonian += param_diag

        # off-site overlap matrix
        overlap_vec = get_params_overlap(pair_distances, species_a, species_b, kwargs_overlap)
        overlap_vec *= bond_mask
        overlap_matrix = hop_int_grid(jnp.concatenate([overlap_vec, dir_cos], axis=-1))

        # NOTE(review): no on-site (identity) overlap is added here; confirm
        # whether the caller is expected to add it.
        return to_orbital_matrix(hamiltonian), to_orbital_matrix(overlap_matrix)