Exemple #1
0
  def test_periodic_against_periodic_general_grad(self, spatial_dimension, dtype):
    """Gradients through `periodic` and a diagonal `periodic_general` agree.

    The sum of squared pairwise displacements is differentiated with
    `space.periodic` on real coordinates (`R * box_size`) and with
    `space.periodic_general` on the unscaled coordinates `R`. The two
    gradients must match to a dtype-dependent tolerance, and the general
    gradient must keep the input dtype.
    """
    key = random.PRNGKey(0)

    # Precision-dependent tolerance. It was previously computed but never
    # passed to the comparison, so the default tolerance was used instead.
    tol = 1e-5 if dtype is f32 else 1e-13

    for _ in range(STOCHASTIC_SAMPLES):
      # Keep a 4-way split so the sampled values match the original stream.
      key, split1, split2, _ = random.split(key, 4)

      max_box_size = f32(10.0)
      box_size = max_box_size * random.uniform(
        split1, (spatial_dimension,), dtype=dtype)
      transform = jnp.diag(box_size)

      R = random.uniform(
        split2, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      R_scaled = R * box_size

      disp_fn, _ = space.periodic(box_size)
      general_disp_fn, _ = space.periodic_general(transform)

      disp_fn = space.map_product(disp_fn)
      general_disp_fn = space.map_product(general_disp_fn)

      grad_fn = grad(lambda R: jnp.sum(disp_fn(R, R) ** 2))
      general_grad_fn = grad(lambda R: jnp.sum(general_disp_fn(R, R) ** 2))

      general_grad = general_grad_fn(R)
      self.assertAllClose(grad_fn(R_scaled), general_grad, atol=tol, rtol=tol)
      assert general_grad.dtype == dtype
Exemple #2
0
    def test_periodic_against_periodic_general(self, spatial_dimension, dtype):
        """`periodic` and a diagonal `periodic_general` must agree.

        Displacements and shifts from `space.periodic` on real coordinates
        (`R * box_size`) are compared with `space.periodic_general` on the
        unscaled coordinates. The dtype of the `periodic` results is also
        checked.
        """
        rng = random.PRNGKey(0)
        particle_shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            rng, box_key, pos_key, step_key = random.split(rng, 4)

            box_size = f16(10.0) * random.uniform(
                box_key, (spatial_dimension, ), dtype=dtype)
            transform = np.diag(box_size)

            frac_R = random.uniform(pos_key, particle_shape, dtype=dtype)
            real_R = frac_R * box_size
            step = random.normal(step_key, particle_shape, dtype=dtype)

            displacement, shift = space.periodic(box_size)
            general_displacement, general_shift = space.periodic_general(
                transform)

            pairwise = space.map_product(displacement)
            general_pairwise = space.map_product(general_displacement)

            real_dR = pairwise(real_R, real_R)
            self.assertAllClose(real_dR,
                                general_pairwise(frac_R, frac_R), True)
            assert real_dR.dtype == dtype

            shifted = shift(real_R, step)
            self.assertAllClose(shifted,
                                general_shift(frac_R, step) * box_size, True)
            assert shifted.dtype == dtype
Exemple #3
0
  def test_triplet_static_species_scalar(self, spatial_dimension, dtype):
      """Check `smap.triplet` with static species against a pairwise sum.

      The reference adds half the summed squared distance for every ordered
      pair of species blocks (a, b) with a, b in {0, 1}. The mapped triplet
      result is divided by the particle count before the comparison.
      """
      key = random.PRNGKey(0)
      angle_fn = lambda dR1, dR2, param=5.0: param * np.sum(np.square(dR1))
      params = f32(np.array([[[1., 1.], [2., 0.]], [[0., 2.], [1., 1.]]]))

      count = PARTICLE_COUNT // 50
      key, species_key = random.split(key)
      species = random.randint(species_key, (count,), 0, 2)
      displacement, _ = space.free()

      def squared_distance(Ra, Rb, **kwargs):
        return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

      triplet_square = smap.triplet(angle_fn,
                                    displacement,
                                    species=species,
                                    param=params,
                                    reduce_axis=None)
      pairwise_sq = space.map_product(squared_distance)

      for _ in range(STOCHASTIC_SAMPLES):
        key, pos_key = random.split(key)
        R = random.uniform(pos_key, (count, spatial_dimension), dtype=dtype)
        expected = 0.
        for a in range(2):
          for b in range(2):
            expected += 0.5 * np.sum(pairwise_sq(R[species == a],
                                                 R[species == b]))
        self.assertAllClose(triplet_square(R) / count,
                            np.array(expected, dtype=dtype))
    def test_pair_static_species_vector(self, spatial_dimension, dtype):
        """`smap.pair` with static species on displacement vectors.

        The mapped energy is compared with an explicit sum over species
        blocks, using `params[a, b]` for the block between species a and b.
        """
        key = random.PRNGKey(0)

        def square(dr, param=1.0):
            return param * np.sum(dr**2, axis=2)

        params = np.array([[1.0, 2.0], [2.0, 3.0]], dtype=f32)

        key, species_key = random.split(key)
        species = random.randint(species_key, (PARTICLE_COUNT, ), 0, 2)
        disp, _ = space.free()

        mapped_square = smap.pair(square, disp, species=species, param=params)
        pairwise_disp = space.map_product(disp)

        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key = random.split(key)
            R = random.uniform(pos_key, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            expected = 0.0
            for a in range(2):
                for b in range(2):
                    block = pairwise_disp(R[species == a], R[species == b])
                    expected = expected + 0.5 * np.sum(
                        square(block, params[a, b]))
            self.assertAllClose(mapped_square(R),
                                np.array(expected, dtype=dtype))
Exemple #5
0
 def compute_fn(R, **kwargs):
     """Return the total energy of positions `R`.

     The energy combines two parts, both halved and multiplied by `epsilon`:
     a pairwise sum of `_sw_radial_interaction` over distances, scaled by
     `A`, and a sum of `sw_three_body_term` over displacements, scaled by
     `lam` and `three_body_strength`. Extra keyword arguments are forwarded
     to `displacement`.
     """
     pair_displacement = space.map_product(partial(displacement, **kwargs))
     dR = pair_displacement(R, R)
     radial = np.sum(_sw_radial_interaction(space.distance(dR)))
     first_term = radial / 2.0 * A
     three_body = np.sum(sw_three_body_term(dR, dR))
     second_term = lam * three_body / 2.0
     return epsilon * (first_term + three_body_strength * second_term)
    def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
        """`smap.pair` with dynamic species on scalar squared distances.

        Species are passed at call time, along with `2` (presumably the
        species count; confirm against `smap.pair`). The result is checked
        against an explicit sum over species blocks.
        """
        key = random.PRNGKey(0)

        def square(dr, param=1.0):
            return param * dr**2

        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, species_key = random.split(key)
        species = random.randint(species_key, (PARTICLE_COUNT, ), 0, 2)
        displacement, _ = space.free()

        def metric(Ra, Rb, **kwargs):
            return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square,
                                  metric,
                                  species=quantity.Dynamic,
                                  param=params)
        pairwise_metric = space.map_product(metric)

        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key = random.split(key)
            R = random.uniform(pos_key, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            expected = 0.0
            for a in range(2):
                for b in range(2):
                    dr = pairwise_metric(R[species == a], R[species == b])
                    expected = expected + 0.5 * np.sum(
                        square(dr, params[a, b]))
            self.assertAllClose(mapped_square(R, species, 2),
                                np.array(expected, dtype=dtype))
Exemple #7
0
    def test_periodic_general_wrapped_vs_unwrapped(self, spatial_dimension,
                                                   dtype):
        """Wrapped and unwrapped `periodic_general` shifts are equivalent.

        After each random step, the displacement from the starting positions
        must be the same whether or not positions were wrapped. At the end,
        at least one unwrapped coordinate must lie outside (0, 1).
        """
        key = random.PRNGKey(0)
        identity = np.eye(spatial_dimension, dtype=dtype)
        shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key, box_key = random.split(key, 3)

            noise = random.normal(box_key,
                                  (spatial_dimension, spatial_dimension),
                                  dtype=dtype)
            T = identity + noise + np.transpose(noise)

            R0 = random.uniform(pos_key, shape, dtype=dtype)
            R = R0
            unwrapped_R = R0

            displacement, shift = space.periodic_general(T)
            _, unwrapped_shift = space.periodic_general(T, wrapped=False)
            pairwise = space.map_product(displacement)

            for _ in range(SHIFT_STEPS):
                key, step_key = random.split(key)
                dR = random.normal(step_key, shape, dtype=dtype)
                R = shift(R, dR)
                unwrapped_R = unwrapped_shift(unwrapped_R, dR)
                self.assertAllClose(pairwise(R, R0),
                                    pairwise(unwrapped_R, R0), True)

            inside_unit_cell = (np.all(unwrapped_R > 0) and
                                np.all(unwrapped_R < 1))
            assert not inside_unit_cell
Exemple #8
0
 def compute_fn(R):
     """Return the total energy of positions `R`.

     For each particle, the energy is `A` times the sum of `_gupta_term1`
     over its pair distances, minus the square root of the sum of
     `_gupta_term2`. These per-particle values are summed and scaled by
     `U_n / 2`.
     """
     dR = space.map_product(displacement)(R, R)
     dr = space.distance(dR)
     first_term = A * np.sum(_gupta_term1(dr, p, r_0n, cutoff), axis=1)
     # The attractive sum is 0 when a particle has no neighbors inside the
     # cutoff (e.g. widely separated particles at initialization). The
     # derivative of sqrt at 0 is infinite, which makes forces NaN, so only
     # apply sqrt where the sum is positive.
     # NOTE(review): relies on `util.safe_mask`'s default placeholder
     # (expected 0) for the masked entries.
     attractive_term = np.sum(_gupta_term2(dr, q, r_0n, cutoff), axis=1)
     second_term = util.safe_mask(attractive_term > 0, np.sqrt,
                                  attractive_term)
     return U_n / 2.0 * np.sum(first_term - second_term)
Exemple #9
0
 def fn_mapped(R, **dynamic_kwargs):
     """Evaluate `fn` on all pair distances of `R` and return half the sum.

     Dynamic kwargs are bound into `metric` before mapping. They are also
     merged with the static `kwargs` and converted to per-species parameters
     for `fn`. The diagonal is masked with `_diagonal_mask` before the
     high-precision sum, which is then multiplied by 0.5.
     """
     pair_metric = space.map_product(partial(metric, **dynamic_kwargs))
     merged = merge_dicts(kwargs, dynamic_kwargs)
     params = _kwargs_to_parameters(species, **merged)
     pair_values = fn(pair_metric(R, R), **params)
     total = _high_precision_sum(_diagonal_mask(pair_values),
                                 axis=reduce_axis,
                                 keepdims=keepdims)
     return total * f32(0.5)
Exemple #10
0
  def test_canonicalize_displacement_or_metric(self, spatial_dimension, dtype):
    """Canonicalizing a displacement must give the same metric as `space.metric`."""
    key = random.PRNGKey(0)

    displacement, _ = space.periodic_general(np.eye(spatial_dimension))
    expected_metric = space.map_product(space.metric(displacement))
    canonical_metric = space.map_product(
      space.canonicalize_displacement_or_metric(displacement))

    for _ in range(STOCHASTIC_SAMPLES):
      key, pos_key, _ = random.split(key, 3)
      R = random.normal(
        pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      self.assertAllClose(expected_metric(R, R), canonical_metric(R, R), True)
Exemple #11
0
 def exact_stress(R):
   """Reference stress tensor for a soft-sphere energy at positions `R`.

   Computes -sum over ordered pairs of 0.5 * dU/dr * dR_a * dR_b / (V * r),
   where dU/dr is the derivative of `energy.soft_sphere` and V is the box
   volume.
   """
   dR = space.map_product(displacement_fn)(R, R)
   dr = space.distance(dR)
   dU_dr_fn = jnp.vectorize(grad(energy.soft_sphere), signature='()->()')
   volume = quantity.volume(dim, box)
   half_dUdr = 0.5 * dU_dr_fn(dr)[:, :, None, None]
   # Adding the identity keeps the zero self-distance on the diagonal from
   # causing a division by zero. The self-displacement there is presumably
   # zero, so the diagonal contributes nothing.
   safe_dr = (dr + jnp.eye(N))[:, :, None, None]
   return -jnp.sum(
       half_dUdr * dR[:, :, None, :] * dR[:, :, :, None] / (volume * safe_dr),
       axis=(0, 1))
Exemple #12
0
def pair_correlation(displacement_or_metric: Union[DisplacementFn, MetricFn],
                     radii: Array,
                     sigma: float,
                     species: Array = None):
    r"""Computes the pair correlation function at a mesh of distances.

    The pair correlation function measures the number of particles at a given
    distance from a central particle. It is defined by
    $g(r) = \langle \sum_{i \neq j} \delta(r - |r_i - r_j|) \rangle$. We
    approximate the delta function by an unnormalized Gaussian,
    $\delta(r) \approx e^{-r^2 / (2\sigma^2)}$, and divide each contribution
    by $r^{d - 1}$, where $d$ is the spatial dimension.

    Args:
      displacement_or_metric: A function that computes the displacement or
        distance between two points.
      radii: An array of radii at which we would like to compute g(r). Since
        each contribution is divided by ``radii ** (dim - 1)``, radii should
        be nonzero when the spatial dimension is greater than one.
      sigma: A float specifying the width of the approximating Gaussian.
      species: An optional array specifying the species of each particle. If
        species is None then we compute a single g(r) for all particles,
        otherwise we compute one g(r) for each species.

    Returns:
      A function `g_fn` that computes the pair correlation function for a
      collection of particles. Without species, `g_fn(R)` returns one g(r)
      per particle, summed over all other particles (shape `[N, n_radii]`
      for 1D `radii`). With species, it returns a list with one such array
      per species type, each counting only neighbors of that species.

    Raises:
      TypeError: If `species` is given but is not an integer array.
    """
    d = space.canonicalize_displacement_or_metric(displacement_or_metric)
    d = space.map_product(d)

    def pairwise(dr, dim):
        # Unnormalized Gaussian centered on each radius, divided by r^(dim-1).
        return jnp.exp(-f32(0.5) *
                       (dr - radii)**2 / sigma**2) / radii**(dim - 1)

    # Map over both particle axes of the pairwise distance matrix.
    pairwise = vmap(vmap(pairwise, (0, None)), (0, None))

    if species is None:

        def g_fn(R):
            dim = R.shape[-1]
            # Exclude each particle's distance to itself.
            mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
            return jnp.sum(mask[:, :, jnp.newaxis] * pairwise(d(R, R), dim),
                           axis=(1, ))
    else:
        if not (isinstance(species, jnp.ndarray) and is_integer(species)):
            raise TypeError('Malformed species; expecting array of integers.')
        species_types = jnp.unique(species)

        def g_fn(R):
            dim = R.shape[-1]
            g_R = []
            mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
            for s in species_types:
                in_s = species == s
                # The self-interaction mask restricted to neighbor columns of
                # species `s` has shape [N, N_s, 1]. The distances must use the
                # same ordering (all particles x species-`s` neighbors), so we
                # need d(R, Rs). Using d(Rs, R) gives [N_s, N, ...], which does
                # not broadcast against the mask when there is more than one
                # species.
                mask_s = mask[:, in_s, jnp.newaxis]
                g_R += [jnp.sum(mask_s * pairwise(d(R, R[in_s]), dim),
                                axis=(1, ))]
            return g_R

    return g_fn
Exemple #13
0
 def fn_mapped(R: Array, **dynamic_kwargs) -> Array:
   """Evaluate `fn` on every ordered pair of `R` and return half the sum.

   Dynamic kwargs are bound into `displacement_or_metric` before mapping.
   They are also merged into the per-species parameters passed to `fn`.
   """
   bound = partial(displacement_or_metric, **dynamic_kwargs)
   pair_fn = space.map_product(bound)
   merged = merge_dicts(kwargs, dynamic_kwargs)
   params = _kwargs_to_parameters(species, **merged)
   pair_values = fn(pair_fn(R, R), **params)
   # NOTE(schsam): Currently we place a diagonal mask no matter what function
   # we are mapping. Should this be an option?
   masked = _diagonal_mask(pair_values)
   return high_precision_sum(masked, axis=reduce_axis,
                             keepdims=keepdims) * f32(0.5)
Exemple #14
0
    def test_periodic_general_time_dependence(self, spatial_dimension, dtype):
        """A time-dependent box must match a fixed box sampled at that time.

        The box is `T(t) = t * T_0 + (1 - t) * T_1`. Displacement and shift
        evaluated at a random `t` are compared with `periodic_general` built
        directly from `T(t)`.
        """
        key = random.PRNGKey(0)
        eye = np.eye(spatial_dimension)
        mat_shape = (spatial_dimension, spatial_dimension)
        particle_shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, scale0_key, noise0_key = random.split(key, 3)
            key, scale1_key, noise1_key = random.split(key, 3)
            key, t_key, pos_key, step_key = random.split(key, 4)

            size_0 = 10.0 * random.uniform(scale0_key, ())
            noise_0 = 0.5 * random.normal(noise0_key, mat_shape)
            T_0 = np.array(size_0 * (eye + noise_0), dtype=dtype)

            size_1 = 10.0 * random.uniform(scale1_key, (), dtype=dtype)
            noise_1 = 0.5 * random.normal(noise1_key, mat_shape, dtype=dtype)
            T_1 = np.array(size_1 * (eye + noise_1), dtype=dtype)

            def T(t):
                return t * T_0 + (f32(1.0) - t) * T_1

            t_g = random.uniform(t_key, (), dtype=dtype)

            disp_fn, shift_fn = space.periodic_general(T)
            true_disp_fn, true_shift_fn = space.periodic_general(T(t_g))

            pairwise = space.map_product(partial(disp_fn, t=t_g))
            true_pairwise = space.map_product(true_disp_fn)

            R = random.uniform(pos_key, particle_shape, dtype=dtype)
            dR = random.normal(step_key, particle_shape, dtype=dtype)

            self.assertAllClose(pairwise(R, R),
                                np.array(true_pairwise(R, R), dtype=dtype),
                                True)
            self.assertAllClose(shift_fn(R, dR, t=t_g),
                                np.array(true_shift_fn(R, dR), dtype=dtype),
                                True)
Exemple #15
0
 def fn_mapped(R, **dynamic_kwargs) -> Array:
   """Evaluate `fn` on every (i, j, k) triplet of pair displacements.

   For displacement matrix dR, element [i, j, k] of the mapped output is
   `fn(dR[i, j], dR[i, k], **params)`. The output is reduced with
   `high_precision_sum` over `reduce_axis` and halved.
   """
   bound = partial(displacement_or_metric, **dynamic_kwargs)
   merged = merge_dicts(kwargs, dynamic_kwargs)
   params = _kwargs_to_parameters(species, **merged)
   dR = space.map_product(bound)(R, R)
   triplet = partial(fn, **params)
   over_k = vmap(triplet, (None, 0))
   over_jk = vmap(over_k, (0, None))
   over_ijk = vmap(over_jk, 0)
   output = over_ijk(dR, dR)
   total = high_precision_sum(output, axis=reduce_axis, keepdims=keepdims)
   return total / 2.
Exemple #16
0
 def compute_fun(R, **dynamic_kwargs):
   """Return a Gaussian-smoothed pair correlation for each particle.

   Each pair distance is smoothed with a normalized Gaussian of width
   `sigma`, evaluated at the radii `rs`. The result is averaged over axis 1
   and divided by `rs ** (dim - 1)`. Keyword arguments are bound into
   `metric` before mapping over pairs.
   """
   _metric = partial(metric, **dynamic_kwargs)
   _metric = space.map_product(_metric)
   dr = _metric(R, R)
   # Move (near-)zero distances, such as self-pairs, far away so they do not
   # contribute to any Gaussian.
   dr = np.where(dr > f32(1e-7), dr, f32(1e7))
   dim = R.shape[1]
   # The previous unused `np.exp(dr / sigma ** 2)` was removed: it
   # overflowed to inf for the 1e7 placeholder and was never read.
   exp = np.exp(-f32(0.5) * (dr[:, :, np.newaxis] - rs) ** 2 / sigma ** 2)
   gaussian_distances = exp / np.sqrt(2 * np.pi * sigma ** 2)
   return np.mean(gaussian_distances, axis=1) / rs ** (dim - 1)
Exemple #17
0
def eam(displacement: DisplacementFn,
        charge_fn: Callable[[Array], Array],
        embedding_fn: Callable[[Array], Array],
        pairwise_fn: Callable[[Array], Array],
        axis: Tuple[int, ...] = None) -> Callable[[Array], Array]:
    """Interatomic potential as approximated by embedded atom model (EAM).

    This code implements the EAM approximation to interactions between
    metallic atoms. In EAM, the potential energy of an atom is given by two
    terms: a pairwise energy and an embedding energy due to the interaction
    between the atom and background charge density. The EAM potential for a
    single atomic species is often determined by three functions:
      1) Charge density contribution of an atom as a function of distance.
      2) Energy of embedding an atom in the background charge density.
      3) Pairwise energy.
    These three functions are usually provided as spline fits, and we follow
    the implementation and spline fits given by [1]. Note that in the current
    implementation, the three functions listed above can also be expressed
    by any function with the correct signature, including neural networks.

    Args:
      displacement: A displacement function between two points. It is turned
        into a metric with `space.metric` and mapped over all pairs of
        particles with `space.map_product`, so it should accept single
        particle positions.
      charge_fn: A function that takes an ndarray of shape [n, m] of distances
        between particles and returns a matrix of charge contributions.
      embedding_fn: Function that takes an ndarray of shape [n] of charges and
        returns an ndarray of shape [n] of the energy cost of embedding an
        atom into the charge.
      pairwise_fn: A function that takes an ndarray of shape [n, m] of
        distances and returns an ndarray of shape [n, m] of pairwise
        energies.
      axis: Specifies which axis the total energy should be summed over.

    Returns:
      A function `energy(R, **kwargs)` that computes the EAM energy of a set
      of atoms with positions given by an [n, spatial_dimension] ndarray.
      Keyword arguments are forwarded to `displacement`.

    [1] Y. Mishin, D. Farkas, M.J. Mehl, DA Papaconstantopoulos, "Interatomic
    potentials for monoatomic metals from experimental data and ab initio
    calculations." Physical Review B, 59 (1999)
    """
    metric = space.metric(displacement)

    def energy(R, **kwargs):
        # Bind kwargs before mapping over pairs. `map_product` is built on
        # `vmap`, which maps keyword arguments along their leading axis, so
        # passing e.g. a box or time straight to the mapped metric is wrong.
        pair_metric = space.map_product(
            lambda Ra, Rb: metric(Ra, Rb, **kwargs))
        dr = pair_metric(R, R)
        # NOTE(review): unlike the pairwise energy below, the self-pair
        # (diagonal) term is not masked here, so charge_fn at the self
        # distance is included. Confirm this is intended.
        total_charge = util.high_precision_sum(charge_fn(dr), axis=1)
        embedding_energy = embedding_fn(total_charge)
        pairwise_energy = util.high_precision_sum(
            smap._diagonal_mask(pairwise_fn(dr)), axis=1) / f32(2.0)
        return util.high_precision_sum(embedding_energy + pairwise_energy,
                                       axis=axis)

    return energy
Exemple #18
0
 def compute_fn(R):
     """Return the total energy of positions `R`.

     Per particle, `A` times the sum of `_gupta_term1` minus a safe square
     root of the summed `_gupta_term2`. The per-particle values are summed
     and scaled by `U_n / 2`.
     """
     pair_dr = space.distance(space.map_product(displacement)(R, R))
     first_term = A * np.sum(_gupta_term1(pair_dr, p, r_0n, cutoff), axis=1)
     attractive_term = np.sum(_gupta_term2(pair_dr, q, r_0n, cutoff), axis=1)
     # Safe sqrt used in order to ensure that force calculations are not nan
     # when the particles are too widely separated at initialization
     # (corresponding to the case where the attractive term is 0.).
     second_term = util.safe_mask(attractive_term > 0, np.sqrt,
                                  attractive_term)
     per_particle = first_term - second_term
     return U_n / 2.0 * np.sum(per_particle)
Exemple #19
0
 def test_cosine_angles(self, dtype):
     """`quantity.cosine_angles` on a three-point right-angle configuration.

     The expected entry [i, j, k] is the cosine between dR[i, j] and
     dR[i, k]. Entries involving the zero self-displacement are 0.
     """
     displacement, _ = space.free()
     pairwise = space.map_product(displacement)
     R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype)
     c45 = 1 / np.sqrt(2)
     expected = np.array(
         [[[0, 0, 0], [0, 1, c45], [0, c45, 1]],
          [[1, 0, 0], [0, 0, 0], [0, 0, 1]],
          [[1, c45, 0], [c45, 1, 0], [0, 0, 0]]],
         dtype=dtype)
     self.assertAllClose(quantity.cosine_angles(pairwise(R, R)), expected)
Exemple #20
0
        def compute_fn(R: Array, **kwargs) -> Array:
            """Radial symmetry functions of `R`, one column block per species.

            For each neighbor species (in `onp.unique(species)` order),
            `radial_fn(etas, dr)` is summed over neighbors and transposed.
            The blocks are then stacked horizontally.
            """
            pair_metric = space.map_product(partial(metric, **kwargs))

            def return_radial(atom_type):
                """Returns the radial symmetry functions for neighbor type atom_type."""
                neighbors = R[species == atom_type, :]
                summed = util.high_precision_sum(
                    radial_fn(etas, pair_metric(R, neighbors)), axis=1)
                return summed.T

            blocks = [return_radial(t) for t in onp.unique(species)]
            return jnp.hstack(blocks)
Exemple #21
0
    def test_neighbor_list_build_time_dependent(self, dtype, dim):
        """Neighbor lists built with a time-dependent box match brute force.

        The neighbor list for the sheared box is queried at t=0.25. For each
        particle, the sorted nonzero in-cutoff neighbor distances must equal
        those from an exact all-pairs computation.
        """
        key = random.PRNGKey(1)

        if dim == 2:
            box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
        elif dim == 3:
            box_fn = lambda t: np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0],
                                         [0.0, 0.0, 7.25]])
        min_length = np.min(np.diag(box_fn(0.)))
        cutoff = f32(1.23)
        # TODO(schsam): Get cell-list working with anisotropic cell sizes.
        cell_size = cutoff / min_length

        displacement, _ = space.periodic_general(box_fn)
        metric = space.metric(displacement)

        R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
        N = R.shape[0]
        neighbor_list_fn = partition.neighbor_list(metric,
                                                   1.,
                                                   cutoff,
                                                   0.0,
                                                   1.1,
                                                   cell_size=cell_size,
                                                   t=np.array(0.))

        idx = neighbor_list_fn(R, t=np.array(0.25)).idx
        neighbors = R[idx]
        # Indices >= N are presumably padding slots; they are masked out.
        is_real = idx < N

        metric_t = partial(metric, t=f32(0.25))
        neighbor_dist = vmap(vmap(metric_t, (None, 0)))(R, neighbors)
        exact_dist = space.map_product(metric_t)(R, R)

        neighbor_dist = np.where(neighbor_dist < cutoff, neighbor_dist,
                                 0) * is_real
        exact_dist = np.where(exact_dist < cutoff, exact_dist, 0)

        neighbor_dist = np.sort(neighbor_dist, axis=1)
        exact_dist = np.sort(exact_dist, axis=1)

        for row, exact_row in zip(neighbor_dist, exact_dist):
            self.assertAllClose(row[row > 0.], exact_row[exact_row > 0.])
Exemple #22
0
    def test_periodic_general_dynamic(self, spatial_dimension, dtype):
        """Passing `box=` at call time overrides the box given at construction.

        Functions built from `T_0` but called with `box=T_1` must match
        functions built directly from `T_1`.
        """
        key = random.PRNGKey(0)
        eye = jnp.eye(spatial_dimension)
        mat_shape = (spatial_dimension, spatial_dimension)
        particle_shape = (PARTICLE_COUNT, spatial_dimension)

        for _ in range(STOCHASTIC_SAMPLES):
            key, scale0_key, noise0_key = random.split(key, 3)
            key, scale1_key, noise1_key = random.split(key, 3)
            # The first key of this split is unused; keep it so the random
            # stream is unchanged.
            key, _, pos_key, step_key = random.split(key, 4)

            size_0 = 10.0 * random.uniform(scale0_key, ())
            noise_0 = 0.5 * random.normal(noise0_key, mat_shape)
            T_0 = jnp.array(size_0 * (eye + noise_0), dtype=dtype)

            size_1 = 10.0 * random.uniform(scale1_key, (), dtype=dtype)
            noise_1 = 0.5 * random.normal(noise1_key, mat_shape, dtype=dtype)
            T_1 = jnp.array(size_1 * (eye + noise_1), dtype=dtype)

            disp_fn, shift_fn = space.periodic_general(T_0)
            true_disp_fn, true_shift_fn = space.periodic_general(T_1)

            pairwise = space.map_product(partial(disp_fn, box=T_1))
            true_pairwise = space.map_product(true_disp_fn)

            R = random.uniform(pos_key, particle_shape, dtype=dtype)
            dR = random.normal(step_key, particle_shape, dtype=dtype)

            self.assertAllClose(pairwise(R, R),
                                jnp.array(true_pairwise(R, R), dtype=dtype))
            self.assertAllClose(shift_fn(R, dR, box=T_1),
                                jnp.array(true_shift_fn(R, dR), dtype=dtype))
Exemple #23
0
 def compute_fn(R, **kwargs):
     """Angular descriptors for every unordered pair of species.

     For species indices i <= j (in `onp.unique(species)` order),
     `_all_pairs_angular` is applied to the displacements from species-i and
     species-j particles to all particles. It is summed over axes 1 and 2,
     and the results are stacked horizontally.
     """
     atom_types = onp.unique(species)
     pair_disp = space.map_product(partial(displacement, **kwargs))
     per_type = [pair_disp(R[species == s, :], R) for s in atom_types]
     n_types = len(atom_types)
     out = [
         jnp.sum(_all_pairs_angular(per_type[i], per_type[j]), axis=(1, 2))
         for i in range(n_types)
         for j in range(i, n_types)
     ]
     return jnp.hstack(out)
Exemple #24
0
  def compute_fun(R, **kwargs):
    """Radial symmetry functions of `R`, one column block per species.

    Each block applies `exp(-eta * dr**2)` times the Behler-Parrinello cutoff
    to the distances to neighbors of one species, vectorized over `etas`,
    then sums over neighbors and transposes. The blocks are stacked
    horizontally.
    """
    pair_metric = space.map_product(partial(metric, **kwargs))

    def radial_fn(eta, dr):
      return (np.exp(-eta * dr**2) *
              _behler_parrinello_cutoff_fn(dr, cutoff_distance))

    radial_over_etas = vmap(radial_fn, (0, None))

    def return_radial(atom_type):
      """Returns the radial symmetry functions for neighbor type atom_type."""
      dr = pair_metric(R, R[species == atom_type, :])
      return np.sum(radial_over_etas(etas, dr), axis=1).T

    return np.hstack([return_radial(atom_type)
                      for atom_type in np.unique(species)])
Exemple #25
0
  def test_pair_no_species_vector(self, spatial_dimension, dtype):
    """`smap.pair` without species on displacement vectors."""
    def square(dr):
      return np.sum(dr ** 2, axis=2)

    disp, _ = space.free()
    mapped_square = smap.pair(square, disp)
    pairwise_disp = space.map_product(disp)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, pos_key = random.split(key)
      R = random.uniform(
        pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      expected = np.array(0.5 * np.sum(square(pairwise_disp(R, R))),
                          dtype=dtype)
      self.assertAllClose(mapped_square(R), expected)
Exemple #26
0
    def test_periodic_displacement(self, spatial_dimension, dtype):
        """`space.periodic_displacement` in a unit box matches brute force.

        The reference tries every image offset in {-1, 0, 1} along each axis
        and keeps the shortest displacement.
        """
        import itertools

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)

            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            dR = space.map_product(space.pairwise_displacement)(R, R)

            dR_wrapped = space.periodic_displacement(f32(1.0), dR)

            dR_direct = dR
            dr_direct = space.distance(dR)
            dr_direct = np.reshape(dr_direct, dr_direct.shape + (1, ))

            # Enumerate all 3**spatial_dimension image offsets. This works in
            # any dimension; the previous explicit 2D/3D loops left other
            # dimensions with no reference at all. `itertools.product` visits
            # offsets in the same order as the old nested loops.
            for offset in itertools.product(range(-1, 2),
                                            repeat=spatial_dimension):
                dR_shifted = dR + np.array(offset, dtype=R.dtype)

                dr_shifted = space.distance(dR_shifted)
                dr_shifted = np.reshape(dr_shifted, dr_shifted.shape + (1, ))

                closer = dr_shifted < dr_direct
                dR_direct = np.where(closer, dR_shifted, dR_direct)
                dr_direct = np.where(closer, dr_shifted, dr_direct)

            dR_direct = np.array(dR_direct, dtype=dR.dtype)
            assert dR_wrapped.dtype == dtype
            self.assertAllClose(dR_wrapped, dR_direct, True)
Exemple #27
0
def pair_correlation(displacement_or_metric, rs, sigma):
    """Return a function computing a Gaussian-smoothed pair correlation.

    Args:
      displacement_or_metric: A displacement or metric function between two
        points. It is canonicalized to a metric.
      rs: Radii at which g(r) is evaluated. A 1e-7 offset is added,
        presumably so that `rs ** (dim - 1)` is nonzero at r = 0.
      sigma: Width of the Gaussian used to smooth each pair distance.

    Returns:
      A function `compute_fun(R, **dynamic_kwargs)`. For each particle, it
      returns the mean over axis 1 (all partner particles) of normalized
      Gaussians evaluated at `rs`, divided by `rs ** (dim - 1)`. Keyword
      arguments are forwarded to the metric.
    """
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    sigma = f32(sigma)
    rs = np.array(rs + 1e-7, f32)

    def compute_fun(R, **dynamic_kwargs):
        # Bind kwargs before mapping over pairs. `map_product` is built on
        # `vmap`, which maps keyword arguments along their leading axis and
        # would break scalar (e.g. time) or box-shaped kwargs.
        pair_metric = space.map_product(
            lambda Ra, Rb: metric(Ra, Rb, **dynamic_kwargs))
        dr = pair_metric(R, R)
        # Move (near-)zero distances, such as self-pairs, far away so they
        # do not contribute to any Gaussian.
        dr = np.where(dr > f32(1e-7), dr, f32(1e7))
        dim = R.shape[1]
        exp = np.exp(-f32(0.5) * (dr[:, :, np.newaxis] - rs)**2 / sigma**2)
        gaussian_distances = exp / np.sqrt(2 * np.pi * sigma**2)
        return np.mean(gaussian_distances, axis=1) / rs**(dim - 1)

    return compute_fun
Exemple #28
0
 def fn_mapped(R, **dynamic_kwargs):
   """Sum `fn` over all particle pairs, one static species block at a time.

   Species indices 0..species_count (inclusive) are visited as pairs with
   i <= j. Same-species blocks have their diagonal masked and are halved;
   cross-species blocks are added in full.
   """
   d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
   # Loop-invariant: the merged kwargs do not depend on (i, j).
   merged = merge_dicts(kwargs, dynamic_kwargs)
   U = f32(0.0)
   for i in range(species_count + 1):
     Ra = R[species == i]
     for j in range(i, species_count + 1):
       s_kwargs = _kwargs_to_parameters((i, j), **merged)
       values = fn(d(Ra, R[species == j]), **s_kwargs)
       if j == i:
         U = U + f32(0.5) * high_precision_sum(_diagonal_mask(values))
       else:
         U = U + high_precision_sum(values)
   return U
Exemple #29
0
 def fn_mapped(R, species, **dynamic_kwargs):
   """Sum `fn` over all particle pairs with species supplied at call time.

   Pair values are computed once for all pairs. For every ordered species
   pair (i, j), a 0/1 mask selects the matching entries, and for i == j the
   diagonal is also masked. The accumulated total is halved.
   """
   _check_species_dtype(species)
   U = f32(0.0)
   N = R.shape[0]
   d = space.map_product(partial(displacement_or_metric, **dynamic_kwargs))
   _kwargs = merge_dicts(kwargs, dynamic_kwargs)
   dr = d(R, R)

   def species_mask(k):
     return jnp.array(jnp.reshape(species == k, (N,)), dtype=R.dtype)

   for i in range(species_count):
     for j in range(species_count):
       s_kwargs = _kwargs_to_parameters((i, j), **_kwargs)
       mask = species_mask(i)[:, jnp.newaxis] * species_mask(j)[jnp.newaxis, :]
       if i == j:
         mask = mask * _diagonal_mask(mask)
       U = U + high_precision_sum(mask * fn(dr, **s_kwargs),
                                  axis=reduce_axis, keepdims=keepdims)
   return U / f32(2.0)
    def test_pair_no_species_scalar(self, spatial_dimension, dtype):
        """`smap.pair` without species on scalar squared distances."""
        def square(dr):
            return dr**2

        displacement, _ = space.free()

        def metric(Ra, Rb, **kwargs):
            return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square, metric)
        pairwise_metric = space.map_product(metric)

        key = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key = random.split(key)
            R = random.uniform(pos_key, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            actual = mapped_square(R)
            expected = np.array(0.5 * np.sum(square(pairwise_metric(R, R))),
                                dtype=dtype)
            self.assertAllClose(actual, expected)