示例#1
0
    def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension,
                                                     dtype):
        """Check `smap.pair_neighbor_list` against `smap.pair` for matrix sigma.

        For every stochastic sample, positions and a symmetrized (N, N)
        `sigma` matrix are drawn at random, and the neighbor-list result is
        asserted to be close to the all-pairs result.
        """
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            # Quadratic below the cutoff `sigma`, zero otherwise.
            return np.where(dr < sigma, dr**2, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d))
        mapped_square = jit(smap.pair(truncated_square, d))

        for _ in range(STOCHASTIC_SAMPLES):
            # Independent subkeys for positions and sigma: the carried `key`
            # is only ever split, never sampled from, so no key is reused.
            key, R_key, sigma_key = random.split(key, 3)
            R = box_size * random.uniform(R_key, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(sigma_key, (N, N), minval=0.5, maxval=1.5)
            # Symmetrize so that sigma[i, j] == sigma[j, i].
            sigma = 0.5 * (sigma + sigma.T)
            neighbor_fn = partition.neighbor_list(disp, box_size,
                                                  np.max(sigma), 0.)
            nbrs = neighbor_fn(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))
示例#2
0
    def test_pair_dynamic_species_vector(self, spatial_dimension, dtype):
        """Check `smap.pair` with dynamic species on a vector displacement.

        The mapped result is compared against an explicit sum over the four
        (i, j) species combinations, each weighted by `params[i, j]`.
        """
        key = random.PRNGKey(0)

        def square(dr, param=1.0):
            # Squared norm over the last (third) axis, scaled by `param`.
            return param * np.sum(dr**2, axis=2)

        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, species_key = random.split(key)
        species = random.randint(species_key, (PARTICLE_COUNT, ), 0, 2)
        disp, _ = space.free()

        mapped_square = smap.pair(square,
                                  disp,
                                  species=quantity.Dynamic,
                                  param=params)

        # All-pairs displacement built from two nested vmaps, used only for
        # the reference computation below.
        pairwise_disp = vmap(vmap(disp, (0, None), 0), (None, 0), 0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, sample_key = random.split(key)
            R = random.uniform(sample_key,
                               (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            expected = 0.0
            for i in range(2):
                R_i = R[species == i]
                for j in range(2):
                    R_j = R[species == j]
                    block = square(pairwise_disp(R_i, R_j), params[i, j])
                    expected = expected + 0.5 * np.sum(block)
            self.assertAllClose(mapped_square(R, species, 2),
                                np.array(expected, dtype=dtype), True)
示例#3
0
    def test_pair_neighbor_list_force_scalar_diverging_potential(
            self, spatial_dimension, dtype):
        """Compare neighbor-list and all-pairs forces for a `dr**-6` potential.

        `quantity.force` is applied to both the `smap.pair_neighbor_list` and
        the `smap.pair` energies, and the two results are asserted to be
        close for every stochastic sample.
        """
        key = random.PRNGKey(0)

        def potential(dr, sigma):
            # dr**-6 below the cutoff `sigma`, zero otherwise.
            return np.where(dr < sigma, dr**-6, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = jit(
            quantity.force(smap.pair_neighbor_list(potential, d)))
        mapped_square = jit(quantity.force(smap.pair(potential, d)))

        for _ in range(STOCHASTIC_SAMPLES):
            # Independent subkeys for positions and sigma: the carried `key`
            # is only ever split, never sampled from, so no key is reused.
            key, R_key, sigma_key = random.split(key, 3)
            R = box_size * random.uniform(R_key, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(sigma_key, (), minval=0.5, maxval=4.5)
            neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0)
            nbrs = neighbor_fn(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))
示例#4
0
    def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension,
                                                      dtype):
        """Check `smap.pair_neighbor_list` against `smap.pair` with species.

        Particles are assigned one of three species by index, and a
        symmetrized (3, 3) `sigma` matrix is drawn for every stochastic
        sample.
        """
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            # Quadratic below the cutoff `sigma`, zero otherwise.
            return np.where(dr < sigma, dr**2, f32(0.))

        tol = 2e-6 if dtype == np.float32 else None

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)
        # Species 0 by default, 1 for indices above N / 3, 2 above 2 * N / 3.
        species = np.zeros((N, ), np.int32)
        species = np.where(np.arange(N) > N / 3, 1, species)
        species = np.where(np.arange(N) > 2 * N / 3, 2, species)

        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = jit(
            smap.pair_neighbor_list(truncated_square, d, species=species))
        mapped_square = jit(smap.pair(truncated_square, d, species=species))

        for _ in range(STOCHASTIC_SAMPLES):
            # Independent subkeys for positions and sigma: the carried `key`
            # is only ever split, never sampled from, so no key is reused.
            key, R_key, sigma_key = random.split(key, 3)
            R = box_size * random.uniform(R_key, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(sigma_key, (3, 3), minval=0.5, maxval=1.5)
            # Symmetrize so that sigma[i, j] == sigma[j, i].
            sigma = 0.5 * (sigma + sigma.T)
            neighbor_fn = jit(
                partition.neighbor_list(disp, box_size, np.max(sigma), R))
            idx = neighbor_fn(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, idx, sigma=sigma), True,
                                tol, tol)
示例#5
0
    def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype):
        """Check `smap.pair_neighbor_list` against `smap.pair` for scalar sigma.

        For every stochastic sample, positions and a scalar cutoff `sigma`
        are drawn at random and the two mapped energies are compared.
        """
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            # Quadratic below the cutoff `sigma`, zero otherwise.
            return np.where(dr < sigma, dr**2, f32(0.))

        tol = 2e-10 if dtype == np.float32 else None

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d))
        mapped_square = jit(smap.pair(truncated_square, d))

        for _ in range(STOCHASTIC_SAMPLES):
            # Independent subkeys for positions and sigma: the carried `key`
            # is only ever split, never sampled from, so no key is reused.
            key, R_key, sigma_key = random.split(key, 3)
            R = box_size * random.uniform(R_key, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(sigma_key, (), minval=0.5, maxval=2.5)
            neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma,
                                                      R))
            idx = neighbor_fn(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, idx, sigma=sigma), True,
                                tol, tol)
示例#6
0
    def test_pair_neighbor_list_scalar_nonadditive(self, spatial_dimension,
                                                   dtype, format):
        """Check a multiplicative per-particle `sigma` combinator.

        The neighbor-list version receives per-particle `sigma` with the
        combinator `x * y`; the all-pairs reference receives the explicit
        outer-product matrix `sigma[:, None] * sigma[None, :]`.
        """
        key = random.PRNGKey(0)

        def truncated_square(dR, sigma):
            # Quadratic in the distance below the cutoff, zero otherwise.
            dr = space.distance(dR)
            return np.where(dr < sigma, dr**2, f32(0.))

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)

        neighbor_square = jit(
            smap.pair_neighbor_list(truncated_square,
                                    disp,
                                    sigma=lambda x, y: x * y))
        mapped_square = jit(smap.pair(truncated_square, disp, sigma=1.0))

        for _ in range(STOCHASTIC_SAMPLES):
            # Independent subkeys for positions and sigma: the carried `key`
            # is only ever split, never sampled from, so no key is reused.
            key, R_key, sigma_key = random.split(key, 3)
            R = box_size * random.uniform(R_key, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(sigma_key, (N, ), minval=0.5, maxval=1.5)
            sigma_pair = sigma[:, None] * sigma[None, :]
            # The largest pairwise cutoff is max(sigma) squared because the
            # combinator multiplies the two per-particle values.
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  np.max(sigma)**2,
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma_pair),
                                neighbor_square(R, nbrs, sigma=sigma))
示例#7
0
    def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
        """Check `smap.pair` with dynamic species on a scalar metric.

        The mapped result is compared against an explicit sum over the four
        (i, j) species combinations, each weighted by `params[i, j]`.
        """
        key = random.PRNGKey(0)

        def square(dr, param=1.0):
            return param * dr**2

        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, species_key = random.split(key)
        species = random.randint(species_key, (PARTICLE_COUNT, ), 0, 2)
        displacement, _ = space.free()

        def metric(Ra, Rb, **kwargs):
            # Sum of squared displacement components over the last axis.
            return np.sum(displacement(Ra, Rb, **kwargs)**2, axis=-1)

        mapped_square = smap.pair(square,
                                  metric,
                                  species=quantity.Dynamic,
                                  param=params)

        # All-pairs version of the metric, used only for the reference sum.
        pairwise_metric = space.map_product(metric)

        for _ in range(STOCHASTIC_SAMPLES):
            key, sample_key = random.split(key)
            R = random.uniform(sample_key,
                               (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            expected = 0.0
            for i in range(2):
                R_i = R[species == i]
                for j in range(2):
                    R_j = R[species == j]
                    block = square(pairwise_metric(R_i, R_j), params[i, j])
                    expected = expected + 0.5 * np.sum(block)
            self.assertAllClose(mapped_square(R, species, 2),
                                np.array(expected, dtype=dtype), True)
示例#8
0
    def test_pair_neighbor_list_vector(self, spatial_dimension, dtype):
        """Check vector-valued `smap.pair_neighbor_list` against `smap.pair`.

        Both mapped functions reduce over axis 1 only, and their outputs are
        asserted to be close for every stochastic sample.
        """
        key = random.PRNGKey(0)

        def truncated_square(dR, sigma):
            # Distance with a trailing singleton axis so that it broadcasts
            # against the components of dR.
            dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1, ))
            return np.where(dr < sigma, dR**2, f32(0.))

        tol = 5e-6 if dtype == np.float32 else 1e-14

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)

        neighbor_square = jit(
            smap.pair_neighbor_list(truncated_square, disp, reduce_axis=(1, )))
        mapped_square = jit(
            smap.pair(truncated_square, disp, reduce_axis=(1, )))

        for _ in range(STOCHASTIC_SAMPLES):
            # Independent subkeys for positions and sigma: the carried `key`
            # is only ever split, never sampled from, so no key is reused.
            key, R_key, sigma_key = random.split(key, 3)
            R = box_size * random.uniform(R_key, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(sigma_key, (), minval=0.5, maxval=1.5)
            neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma,
                                                      R))
            idx = neighbor_fn(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, idx, sigma=sigma), True,
                                tol, tol)
示例#9
0
    def test_pair_scalar_dummy_arg(self, spatial_dimension, dtype):
        """Call a mapped pair function with an extra keyword argument `t`.

        There is no assertion; the test only exercises the call.
        """
        def square(dr, param=f32(1.0), **unused_kwargs):
            return param * dr**2

        # Positions are drawn from the first half of the split; the second
        # half is not needed.
        sample_key, _ = random.split(random.PRNGKey(0))
        positions = random.normal(sample_key,
                                  (PARTICLE_COUNT, spatial_dimension),
                                  dtype=dtype)
        displacement, _ = space.free()

        energy_fn = smap.pair(square, space.metric(displacement))

        energy_fn(positions, t=f32(0))
示例#10
0
def soft_sphere_pair(
    displacement_or_metric, species=None, sigma=1.0, epsilon=1.0, alpha=2.0):
  """Build the `soft_sphere` energy mapped over all pairs by `smap.pair`.

  Args:
    displacement_or_metric: Passed through
      `space.canonicalize_displacement_or_metric` before being mapped.
    species: Forwarded unchanged to `smap.pair`.
    sigma: Converted to an `f32` array and forwarded as a keyword.
    epsilon: Converted to an `f32` array and forwarded as a keyword.
    alpha: Converted to an `f32` array and forwarded as a keyword.

  Returns:
    Whatever `smap.pair` returns for these arguments.
  """
  sigma_arr, epsilon_arr, alpha_arr = (
      np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(soft_sphere,
                   metric,
                   species=species,
                   sigma=sigma_arr,
                   epsilon=epsilon_arr,
                   alpha=alpha_arr)
示例#11
0
def morse_pair(
    displacement_or_metric,
    species=None, sigma=1.0, epsilon=5.0, alpha=5.0, r_onset=2.0, r_cutoff=2.5):
  """Build the cutoff-wrapped `morse` energy mapped over all pairs.

  Args:
    displacement_or_metric: Passed through
      `space.canonicalize_displacement_or_metric` before being mapped.
    species: Forwarded unchanged to `smap.pair`.
    sigma: Converted to an `f32` array and forwarded as a keyword.
    epsilon: Converted to an `f32` array and forwarded as a keyword.
    alpha: Converted to an `f32` array and forwarded as a keyword.
    r_onset: Forwarded unchanged to `multiplicative_isotropic_cutoff`.
    r_cutoff: Forwarded unchanged to `multiplicative_isotropic_cutoff`.

  Returns:
    Whatever `smap.pair` returns for these arguments.
  """
  sigma_arr, epsilon_arr, alpha_arr = (
    np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))
  truncated_morse = multiplicative_isotropic_cutoff(morse, r_onset, r_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(truncated_morse,
                   metric,
                   species=species,
                   sigma=sigma_arr,
                   epsilon=epsilon_arr,
                   alpha=alpha_arr)
示例#12
0
def lennard_jones_pair(
    displacement_or_metric,
    species=None, sigma=1.0, epsilon=1.0, r_onset=2.0, r_cutoff=2.5):
  """Build the cutoff-wrapped `lennard_jones` energy mapped over all pairs.

  Args:
    displacement_or_metric: Passed through
      `space.canonicalize_displacement_or_metric` before being mapped.
    species: Forwarded unchanged to `smap.pair`.
    sigma: Converted to an `f32` array and forwarded as a keyword.
    epsilon: Converted to an `f32` array and forwarded as a keyword.
    r_onset: Multiplied by the largest `sigma`, cast with `f32`, and given to
      `multiplicative_isotropic_cutoff`.
    r_cutoff: Multiplied by the largest `sigma`, cast with `f32`, and given
      to `multiplicative_isotropic_cutoff`.

  Returns:
    Whatever `smap.pair` returns for these arguments.
  """
  sigma_arr = np.array(sigma, dtype=f32)
  epsilon_arr = np.array(epsilon, dtype=f32)
  # Both radii are given in units of the largest sigma.
  sigma_max = np.max(sigma_arr)
  scaled_onset = f32(r_onset * sigma_max)
  scaled_cutoff = f32(r_cutoff * sigma_max)
  truncated_lj = multiplicative_isotropic_cutoff(
    lennard_jones, scaled_onset, scaled_cutoff)
  metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  return smap.pair(truncated_lj,
                   metric,
                   species=species,
                   sigma=sigma_arr,
                   epsilon=epsilon_arr)
示例#13
0
  def test_pair_no_species_vector(self, spatial_dimension, dtype):
    """Check `smap.pair` without species against half the all-pairs sum."""
    def square(dr):
      return np.sum(dr ** 2, axis=2)

    disp, _ = space.free()
    mapped_square = smap.pair(square, disp)

    # All-pairs displacement, used only to build the reference value.
    all_pairs_disp = space.map_product(disp)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, sample_key = random.split(key)
      R = random.uniform(
        sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      pair_sum = np.sum(square(all_pairs_disp(R, R)))
      expected = np.array(0.5 * pair_sum, dtype=dtype)
      self.assertAllClose(mapped_square(R), expected)
示例#14
0
def soft_sphere_pair(displacement_or_metric: DisplacementOrMetricFn,
                     species: Array = None,
                     sigma: Array = 1.0,
                     epsilon: Array = 1.0,
                     alpha: Array = 2.0,
                     per_particle: bool = False):
    """Build the `soft_sphere` energy mapped over all pairs by `smap.pair`.

    Args:
        displacement_or_metric: Passed through
            `space.canonicalize_displacement_or_metric` before being mapped.
        species: Forwarded unchanged to `smap.pair`.
        sigma: Converted to an `f32` array and forwarded as a keyword.
        epsilon: Converted to an `f32` array and forwarded as a keyword.
        alpha: Converted to an `f32` array and forwarded as a keyword.
        per_particle: When truthy, `reduce_axis=(1,)` is passed to
            `smap.pair`; otherwise `reduce_axis=None`.

    Returns:
        Whatever `smap.pair` returns for these arguments.
    """
    sigma_arr, epsilon_arr, alpha_arr = (
        np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    if per_particle:
        reduce_axis = (1, )
    else:
        reduce_axis = None

    return smap.pair(soft_sphere,
                     metric,
                     species=species,
                     sigma=sigma_arr,
                     epsilon=epsilon_arr,
                     alpha=alpha_arr,
                     reduce_axis=reduce_axis)
示例#15
0
    def test_pair_no_species_scalar(self, spatial_dimension, dtype):
        """Check `smap.pair` on a scalar metric against half the pair sum."""
        def square(dr):
            return dr**2

        displacement, _ = space.free()

        def metric(Ra, Rb, **kwargs):
            # Sum of squared displacement components over the last axis.
            return np.sum(displacement(Ra, Rb, **kwargs)**2, axis=-1)

        mapped_square = smap.pair(square, metric)
        # All-pairs version of the metric, used only for the reference value.
        all_pairs_metric = space.map_product(metric)

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, sample_key = random.split(key)
            R = random.uniform(sample_key,
                               (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            pair_sum = np.sum(square(all_pairs_metric(R, R)))
            expected = np.array(0.5 * pair_sum, dtype=dtype)
            self.assertAllClose(mapped_square(R), expected)
示例#16
0
def bks_pair(displacement_or_metric, species, Q_sq, exp_coeff, exp_decay,
             attractive_coeff, repulsive_coeff, coulomb_alpha, cutoff):
    """Build the `bks` energy mapped over all pairs by `smap.pair`.

    `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and
    `repulsive_coeff` are converted to `f32` arrays; `species`,
    `coulomb_alpha` and `cutoff` are forwarded unchanged.

    NOTE(review): `displacement_or_metric` is handed to `smap.pair` as-is,
    without `space.canonicalize_displacement_or_metric`; confirm callers
    pass what `bks` expects.

    Returns:
        Whatever `smap.pair` returns for these arguments.
    """
    pair_kwargs = {
        'Q_sq': np.array(Q_sq, f32),
        'exp_coeff': np.array(exp_coeff, f32),
        'exp_decay': np.array(exp_decay, f32),
        'attractive_coeff': np.array(attractive_coeff, f32),
        'repulsive_coeff': np.array(repulsive_coeff, f32),
        'coulomb_alpha': coulomb_alpha,
        'cutoff': cutoff,
    }
    return smap.pair(bks,
                     displacement_or_metric,
                     species=species,
                     **pair_kwargs)
示例#17
0
    def test_stress_lammps_periodic_general(self, dim, dtype):
        """Compare energy per particle and stress with loaded reference data.

        The reference tuple comes from `test_util.load_lammps_stress_data`.
        The energy is a Lennard-Jones pair energy set to zero at distances of
        2.5 and above, evaluated in the `space.periodic_general` box.
        """
        (box, R, V), (E, C) = test_util.load_lammps_stress_data(dtype)

        def truncated_lennard_jones(dr, **kwargs):
            # Hard truncation at 2.5; extra keyword arguments are ignored.
            return jnp.where(dr < f32(2.5), energy.lennard_jones(dr),
                             f32(0.0))

        displacement_fn, _ = space.periodic_general(box)
        energy_fn = smap.pair(
            truncated_lennard_jones,
            space.canonicalize_displacement_or_metric(displacement_fn))

        ad_stress = quantity.stress(energy_fn, R, box, velocity=V)

        tol = 5e-5

        # E is compared against the energy divided by the particle count.
        self.assertAllClose(energy_fn(R) / len(R), E, atol=tol, rtol=tol)
        self.assertAllClose(C, ad_stress, atol=tol, rtol=tol)
示例#18
0
  def test_pair_no_species_vector_nonadditive(self, spatial_dimension, dtype):
    """Check `smap.pair` with a multiplicative per-particle combinator.

    Per-particle `params` are combined by `x * y` inside `smap.pair`; the
    reference uses the explicit outer product of `params` with itself.
    """
    def square(dr, params):
      return params * np.sum(dr ** 2, axis=2)

    disp, _ = space.free()
    mapped_square = smap.pair(square, disp, params=lambda x, y: x * y)

    # All-pairs displacement, used only to build the reference value.
    all_pairs_disp = space.map_product(disp)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, R_key, params_key = random.split(key, 3)
      R = random.uniform(
        R_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      params = random.uniform(
        params_key, (PARTICLE_COUNT,), dtype=dtype, minval=0.1, maxval=1.5)
      outer_params = params[None, :] * params[:, None]
      pair_sum = np.sum(square(all_pairs_disp(R, R), outer_params))
      expected = np.array(0.5 * pair_sum, dtype=dtype)
      self.assertAllClose(mapped_square(R, params=params), expected)
示例#19
0
def lennard_jones_pair(displacement_or_metric: DisplacementOrMetricFn,
                       species: Array = None,
                       sigma: Array = 1.0,
                       epsilon: Array = 1.0,
                       r_onset: Array = 2.0,
                       r_cutoff: Array = 2.5,
                       per_particle: bool = False) -> Callable[[Array], Array]:
    """Build the cutoff-wrapped `lennard_jones` energy mapped over all pairs.

    Args:
        displacement_or_metric: Passed through
            `space.canonicalize_displacement_or_metric` before being mapped.
        species: Forwarded unchanged to `smap.pair`.
        sigma: Converted to an `f32` array and forwarded as a keyword.
        epsilon: Converted to an `f32` array and forwarded as a keyword.
        r_onset: Multiplied by the largest `sigma` and given to
            `multiplicative_isotropic_cutoff`.
        r_cutoff: Multiplied by the largest `sigma` and given to
            `multiplicative_isotropic_cutoff`.
        per_particle: When truthy, `reduce_axis=(1,)` is passed to
            `smap.pair`; otherwise `reduce_axis=None`.

    Returns:
        Whatever `smap.pair` returns for these arguments.
    """
    sigma_arr = np.array(sigma, dtype=f32)
    epsilon_arr = np.array(epsilon, dtype=f32)
    # Both radii are given in units of the largest sigma.
    sigma_max = np.max(sigma_arr)
    truncated_lj = multiplicative_isotropic_cutoff(lennard_jones,
                                                   r_onset * sigma_max,
                                                   r_cutoff * sigma_max)
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    if per_particle:
        reduce_axis = (1, )
    else:
        reduce_axis = None

    return smap.pair(truncated_lj,
                     metric,
                     species=species,
                     sigma=sigma_arr,
                     epsilon=epsilon_arr,
                     reduce_axis=reduce_axis)
示例#20
0
    def test_pair_grid_energy(self, spatial_dimension, dtype):
        """Check that `smap.grid` reproduces the wrapped pair energy.

        The pair energy uses dynamic species and reduces over axis 1 with
        `keepdims=True`; it is evaluated directly with a single all-zero
        species array and compared with the grid-wrapped version.
        """
        sample_key = random.PRNGKey(1)

        box_size = f16(9.0)
        cell_size = f16(2.0)
        displacement, _ = space.periodic(box_size)
        energy_fn = smap.pair(energy.soft_sphere,
                              space.metric(displacement),
                              quantity.Dynamic,
                              reduce_axis=(1, ),
                              keepdims=True)

        R = box_size * random.uniform(
            sample_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
        grid_energy_fn = smap.grid(energy_fn, box_size, cell_size, R)

        # Every particle belongs to species 0; one species in total.
        species = np.zeros((PARTICLE_COUNT, ), dtype=np.int64)
        expected = np.array(energy_fn(R, species, 1), dtype=dtype)
        self.assertAllClose(expected, grid_energy_fn(R), True)
示例#21
0
def soft_sphere_pair(displacement_or_metric: DisplacementOrMetricFn,
                     species: Array = None,
                     sigma: Array = 1.0,
                     epsilon: Array = 1.0,
                     alpha: Array = 2.0,
                     per_particle: bool = False):
    """Build the `soft_sphere` energy mapped over all pairs by `smap.pair`.

    Args:
        displacement_or_metric: Passed through
            `space.canonicalize_displacement_or_metric` before being mapped.
        species: Forwarded unchanged to `smap.pair`.
        sigma: Passed through `maybe_downcast` and forwarded as a keyword.
        epsilon: Passed through `maybe_downcast` and forwarded as a keyword.
        alpha: Passed through `maybe_downcast` and forwarded as a keyword.
        per_particle: When truthy, `reduce_axis=(1,)` is passed to
            `smap.pair`; otherwise `reduce_axis=None`.

    Returns:
        Whatever `smap.pair` returns for these arguments; the call also sets
        `ignore_unused_parameters=True`.
    """
    sigma_cast, epsilon_cast, alpha_cast = (
        maybe_downcast(value) for value in (sigma, epsilon, alpha))
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    if per_particle:
        reduce_axis = (1, )
    else:
        reduce_axis = None

    return smap.pair(soft_sphere,
                     metric,
                     ignore_unused_parameters=True,
                     species=species,
                     sigma=sigma_cast,
                     epsilon=epsilon_cast,
                     alpha=alpha_cast,
                     reduce_axis=reduce_axis)
示例#22
0
  def test_pair_no_species_scalar_dynamic(self, spatial_dimension, dtype):
    """Check `smap.pair` when `epsilon` is overridden at call time.

    The mapped function is built with `epsilon=1.0` and then called with a
    per-particle `epsilon`; the reference uses the pairwise average
    `0.5 * (epsilon_i + epsilon_j)`.
    """
    def square(dr, epsilon):
      return epsilon * dr ** 2

    displacement, _ = space.free()

    def metric(Ra, Rb, **kwargs):
      # Sum of squared displacement components over the last axis.
      return np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    mapped_square = smap.pair(square, metric, epsilon=1.0)
    # All-pairs version of the metric, used only for the reference value.
    all_pairs_metric = space.map_product(metric)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, R_key, epsilon_key = random.split(key, 3)
      R = random.uniform(
        R_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      epsilon = random.uniform(epsilon_key, (PARTICLE_COUNT,), dtype=dtype)
      mean_epsilon = 0.5 * (epsilon[:, np.newaxis] + epsilon[np.newaxis, :])
      pair_sum = np.sum(square(all_pairs_metric(R, R), mean_epsilon))
      expected = np.array(0.5 * pair_sum, dtype=dtype)
      self.assertAllClose(mapped_square(R, epsilon=epsilon), expected)
示例#23
0
def morse_pair(displacement_or_metric: DisplacementOrMetricFn,
               species: Array = None,
               sigma: Array = 1.0,
               epsilon: Array = 5.0,
               alpha: Array = 5.0,
               r_onset: float = 2.0,
               r_cutoff: float = 2.5,
               per_particle: bool = False) -> Callable[[Array], Array]:
    """Build the cutoff-wrapped `morse` energy mapped over all pairs.

    Args:
        displacement_or_metric: Passed through
            `space.canonicalize_displacement_or_metric` before being mapped.
        species: Forwarded unchanged to `smap.pair`.
        sigma: Converted to an `f32` array and forwarded as a keyword.
        epsilon: Converted to an `f32` array and forwarded as a keyword.
        alpha: Converted to an `f32` array and forwarded as a keyword.
        r_onset: Forwarded unchanged to `multiplicative_isotropic_cutoff`.
        r_cutoff: Forwarded unchanged to `multiplicative_isotropic_cutoff`.
        per_particle: When truthy, `reduce_axis=(1,)` is passed to
            `smap.pair`; otherwise `reduce_axis=None`.

    Returns:
        Whatever `smap.pair` returns for these arguments.
    """
    sigma_arr, epsilon_arr, alpha_arr = (
        np.array(value, dtype=f32) for value in (sigma, epsilon, alpha))
    truncated_morse = multiplicative_isotropic_cutoff(morse, r_onset,
                                                      r_cutoff)
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    if per_particle:
        reduce_axis = (1, )
    else:
        reduce_axis = None

    return smap.pair(truncated_morse,
                     metric,
                     species=species,
                     sigma=sigma_arr,
                     epsilon=epsilon_arr,
                     alpha=alpha_arr,
                     reduce_axis=reduce_axis)
示例#24
0
def bks_pair(displacement_or_metric: DisplacementOrMetricFn, species: Array,
             Q_sq: Array, exp_coeff: Array, exp_decay: Array,
             attractive_coeff: Array, repulsive_coeff: Array,
             coulomb_alpha: Array, cutoff: float) -> Callable[[Array], Array]:
    """Build the `bks` energy mapped over all pairs by `smap.pair`.

    `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and
    `repulsive_coeff` are converted to `f32` arrays; `species`,
    `coulomb_alpha` and `cutoff` are forwarded unchanged.

    NOTE(review): `displacement_or_metric` is handed to `smap.pair` as-is,
    without `space.canonicalize_displacement_or_metric`; confirm callers
    pass what `bks` expects.

    Returns:
        Whatever `smap.pair` returns for these arguments.
    """
    pair_kwargs = {
        'Q_sq': np.array(Q_sq, f32),
        'exp_coeff': np.array(exp_coeff, f32),
        'exp_decay': np.array(exp_decay, f32),
        'attractive_coeff': np.array(attractive_coeff, f32),
        'repulsive_coeff': np.array(repulsive_coeff, f32),
        'coulomb_alpha': coulomb_alpha,
        'cutoff': cutoff,
    }
    return smap.pair(bks,
                     displacement_or_metric,
                     species=species,
                     **pair_kwargs)
示例#25
0
def bks_pair(displacement_or_metric: DisplacementOrMetricFn, species: Array,
             Q_sq: Array, exp_coeff: Array, exp_decay: Array,
             attractive_coeff: Array, repulsive_coeff: Array,
             coulomb_alpha: Array, cutoff: float) -> Callable[[Array], Array]:
    """Build the `bks` energy mapped over all pairs by `smap.pair`.

    `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and
    `repulsive_coeff` are passed through `maybe_downcast`; `species`,
    `coulomb_alpha` and `cutoff` are forwarded unchanged. The call sets
    `ignore_unused_parameters=True`.

    NOTE(review): `displacement_or_metric` is handed to `smap.pair` as-is,
    without `space.canonicalize_displacement_or_metric`; confirm callers
    pass what `bks` expects.

    Returns:
        Whatever `smap.pair` returns for these arguments.
    """
    pair_kwargs = {
        'Q_sq': maybe_downcast(Q_sq),
        'exp_coeff': maybe_downcast(exp_coeff),
        'exp_decay': maybe_downcast(exp_decay),
        'attractive_coeff': maybe_downcast(attractive_coeff),
        'repulsive_coeff': maybe_downcast(repulsive_coeff),
        'coulomb_alpha': coulomb_alpha,
        'cutoff': cutoff,
    }
    return smap.pair(bks,
                     displacement_or_metric,
                     species=species,
                     ignore_unused_parameters=True,
                     **pair_kwargs)
示例#26
0
def harmonic_morse_pair(displacement_or_metric,
                        species=None,
                        D0=5.0,
                        alpha=10.0,
                        r0=1.0,
                        k=50.0):
    """Build the `harmonic_morse` energy mapped over all pairs of particles.

    Args:
        displacement_or_metric: Passed through
            `space.canonicalize_displacement_or_metric` before being mapped.
        species: Forwarded unchanged to `smap.pair`.
        D0: Converted to a `float32` array and forwarded as a keyword.
        alpha: Converted to a `float32` array and forwarded as a keyword.
        r0: Converted to a `float32` array and forwarded as a keyword.
        k: Converted to a `float32` array and forwarded as a keyword.

    Returns:
        Whatever `smap.pair` returns for these arguments.
    """
    # Normalize every potential parameter to a float32 array up front.
    D0_arr, alpha_arr, r0_arr, k_arr = (
        jnp.array(value, dtype=jnp.float32) for value in (D0, alpha, r0, k))

    # Map the pairwise function, with its parameters, over the canonicalized
    # displacement/metric function.
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    return smap.pair(harmonic_morse,
                     metric,
                     species=species,
                     D0=D0_arr,
                     alpha=alpha_arr,
                     r0=r0_arr,
                     k=k_arr)
示例#27
0
def bks_pair(displacement_or_metric: DisplacementOrMetricFn,
             species: Array,
             Q_sq: Array,
             exp_coeff: Array,
             exp_decay: Array,
             attractive_coeff: Array,
             repulsive_coeff: Array,
             coulomb_alpha: Array,
             cutoff: float) -> Callable[[Array], Array]:
  """Build the `bks` energy mapped over all pairs by `smap.pair`.

  `Q_sq`, `exp_coeff`, `exp_decay`, `attractive_coeff` and `repulsive_coeff`
  are converted to `f32` arrays; `species`, `coulomb_alpha` and `cutoff` are
  forwarded unchanged.

  NOTE(review): `displacement_or_metric` is handed to `smap.pair` as-is,
  without `space.canonicalize_displacement_or_metric`; confirm callers pass
  what `bks` expects.

  Returns:
    Whatever `smap.pair` returns for these arguments.
  """
  pair_kwargs = {
      'Q_sq': np.array(Q_sq, f32),
      'exp_coeff': np.array(exp_coeff, f32),
      'exp_decay': np.array(exp_decay, f32),
      'attractive_coeff': np.array(attractive_coeff, f32),
      'repulsive_coeff': np.array(repulsive_coeff, f32),
      'coulomb_alpha': coulomb_alpha,
      'cutoff': cutoff,
  }
  return smap.pair(bks, displacement_or_metric,
                   species=species,
                   **pair_kwargs)
示例#28
0
    def test_pair_neighbor_list_vector(self, spatial_dimension, dtype, format):
        """Check vector-valued `smap.pair_neighbor_list` against `smap.pair`.

        Skipped for `partition.OrderedSparse`. Both mapped functions are
        built with `sigma=1.0` and `reduce_axis=(1,)`, then called with a
        freshly sampled scalar `sigma` on every stochastic sample.
        """
        if format is partition.OrderedSparse:
            self.skipTest('Vector valued pair_neighbor_list not supported.')
        key = random.PRNGKey(0)

        def truncated_square(dR, sigma):
            # Distance with a trailing singleton axis so that it broadcasts
            # against the components of dR.
            dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1, ))
            return np.where(dr < sigma, dR**2, f32(0.))

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        disp, _ = space.periodic(box_size)

        neighbor_square = jit(
            smap.pair_neighbor_list(truncated_square,
                                    disp,
                                    sigma=1.0,
                                    reduce_axis=(1, )))
        mapped_square = jit(
            smap.pair(truncated_square, disp, sigma=1.0, reduce_axis=(1, )))

        for _ in range(STOCHASTIC_SAMPLES):
            # Independent subkeys for positions and sigma: the carried `key`
            # is only ever split, never sampled from, so no key is reused.
            key, R_key, sigma_key = random.split(key, 3)
            R = box_size * random.uniform(R_key, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(sigma_key, (), minval=0.5, maxval=1.5)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  sigma,
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))