from jax import random
import jax.numpy as np

from jax.api import grad

from jax import test_util as jtu
from jax import jit, vmap

from jax_md import smap, space, energy, quantity, partition
from jax_md.util import *
from jax_md import test_util

# NOTE(review): `jax_config` is not imported by the import lines at the top of
# this excerpt — confirm `from jax.config import config as jax_config` exists.
# Parse JAX config flags via absl before FLAGS is read below.
jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

# Per-precision comparison tolerances for jax_md's test helpers.
test_util.update_test_tolerance(f32_tolerance=5e-6, f64_tolerance=1e-14)

# Sizes and sampling parameters used by the tests in this module.
PARTICLE_COUNT = 1000
NEIGHBOR_LIST_PARTICLE_COUNT = 100
STOCHASTIC_SAMPLES = 3
SPATIAL_DIMENSION = [2, 3]

# float64 positions are only tested when JAX x64 mode is enabled.
POSITION_DTYPE = [f32, f64] if FLAGS.jax_enable_x64 else [f32]


class SMapTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
Exemple #2
0
from jax.config import config as jax_config
from jax import random
import jax.numpy as np

from jax.api import jit, grad, vmap
from jax_md import space, quantity, test_util
from jax_md.util import *

from jax import test_util as jtu

# Configure JAX before anything reads FLAGS: parse flags via absl and
# enable omnistaging.
jax_config.parse_flags_with_absl()
jax_config.enable_omnistaging()
FLAGS = jax_config.FLAGS

# Positional (f32, f64) tolerances — TODO confirm argument order against
# jax_md.test_util.update_test_tolerance.
test_util.update_test_tolerance(1e-5, 2e-7)

# Sizes and sampling parameters used by the tests in this module.
SPATIAL_DIMENSION = [2, 3]
PARTICLE_COUNT = 10
STOCHASTIC_SAMPLES = 10

# Only add float64 cases when JAX x64 mode is enabled.
if FLAGS.jax_enable_x64:
    DTYPES = [f32, f64]
else:
    DTYPES = [f32]


class QuantityTest(jtu.JaxTestCase):
    def test_canonicalize_mass(self):
        # A Python float, an f32 scalar and an f64 scalar must each come back
        # equal to the value passed in.
        for mass in (3.0, f32(3.0), f64(3.0)):
            assert quantity.canonicalize_mass(mass) == mass

    @parameterized.named_parameters(
        jtu.cases_from_list({
Exemple #3
0
FLAGS = jax_config.FLAGS

# Sizes and sampling parameters used by the tests in this module.
PARTICLE_COUNT = 100
STOCHASTIC_SAMPLES = 10
SPATIAL_DIMENSION = [2, 3]
UNIT_CELL_SIZE = [7, 8]

# Additional parameter grids for the parameterized tests.
SOFT_SPHERE_ALPHA = [2.0, 3.0]
N_TYPES_TO_TEST = [1, 2]

# float64 positions are only tested when JAX x64 mode is enabled.
POSITION_DTYPE = [f32, f64] if FLAGS.jax_enable_x64 else [f32]

update_test_tolerance(2e-5, 1e-6)


def lattice_repeater(small_cell_pos, latvec, no_rep):
    """Tile a unit cell ``no_rep`` times along each of three lattice vectors.

    Every position in ``small_cell_pos`` is translated by
    ``latvec[0] * i + latvec[1] * j + latvec[2] * k`` for every
    ``i, j, k`` in ``range(no_rep)`` except ``(0, 0, 0)``.

    Args:
      small_cell_pos: Unit-cell positions. Must expose ``.dtype`` and be
        iterable over atoms; presumably shape ``(n_atoms, 3)`` — TODO confirm.
      latvec: Lattice vectors indexed by row; rows 0, 1 and 2 are used.
      no_rep: Number of repetitions along each lattice vector.

    Returns:
      ``(positions, box)``. ``positions`` lists the original unit-cell
      positions first. Then, atom by atom, it lists that atom's translated
      copies in lexicographic ``(i, j, k)`` order, keeping the dtype of
      ``small_cell_pos``. ``box`` is ``latvec * no_rep`` cast to ``f32``.
    """
    out_dtype = small_cell_pos.dtype
    offsets = [(i, j, k)
               for i in range(no_rep)
               for j in range(no_rep)
               for k in range(no_rep)
               if (i, j, k) != (0, 0, 0)]
    tiled = onp.copy(small_cell_pos).tolist()
    for site in small_cell_pos:
        for i, j, k in offsets:
            # Left-to-right summation order is kept on purpose: float
            # addition is not associative.
            shifted = site + latvec[0] * i + latvec[1] * j + latvec[2] * k
            tiled.append(onp.array(shifted).tolist())
    # NOTE(review): the box is always f32, even when positions are f64.
    return np.array(tiled, out_dtype), f32(latvec * no_rep)
Exemple #4
0
# Parameter grids for the parameterized tests.
SPATIAL_DIMENSION = [2, 3]
UNIT_CELL_SIZE = [7, 8]
SOFT_SPHERE_ALPHA = [2.0, 2.5, 3.0]
N_TYPES_TO_TEST = [1, 2]

# float64 positions are only tested when JAX x64 mode is enabled.
POSITION_DTYPE = [f32, f64] if FLAGS.jax_enable_x64 else [f32]

# Neighbor-list formats under test.
NEIGHBOR_LIST_FORMAT = [
  partition.Dense,
  partition.Sparse,
  partition.OrderedSparse,
]

test_util.update_test_tolerance(2e-5, 1e-10)


def lattice_repeater(small_cell_pos, latvec, no_rep):
  """Tile a unit cell ``no_rep`` times along each of three lattice vectors.

  Every position in ``small_cell_pos`` is translated by
  ``latvec[0] * i + latvec[1] * j + latvec[2] * k`` for every
  ``i, j, k`` in ``range(no_rep)`` except ``(0, 0, 0)``.

  Args:
    small_cell_pos: Unit-cell positions. Must expose ``.dtype`` and be
      iterable over atoms; presumably shape ``(n_atoms, 3)`` — TODO confirm.
    latvec: Lattice vectors indexed by row; rows 0, 1 and 2 are used.
    no_rep: Number of repetitions along each lattice vector.

  Returns:
    ``(positions, box)``. ``positions`` lists the original unit-cell
    positions first. Then, atom by atom, it lists that atom's translated
    copies in lexicographic ``(i, j, k)`` order, keeping the dtype of
    ``small_cell_pos``. ``box`` is ``latvec * no_rep`` cast to ``f32``.
  """
  out_dtype = small_cell_pos.dtype
  offsets = [(i, j, k)
             for i in range(no_rep)
             for j in range(no_rep)
             for k in range(no_rep)
             if (i, j, k) != (0, 0, 0)]
  tiled = onp.copy(small_cell_pos).tolist()
  for site in small_cell_pos:
    for i, j, k in offsets:
      # Left-to-right summation order is kept on purpose: float addition
      # is not associative.
      shifted = site + latvec[0] * i + latvec[1] * j + latvec[2] * k
      tiled.append(onp.array(shifted).tolist())
  # NOTE(review): the box is always f32, even when positions are f64.
  return np.array(tiled, out_dtype), f32(latvec * no_rep)

Exemple #5
0
# Parse JAX config flags via absl before FLAGS is read below.
jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

# Sizes and sampling parameters used by the tests in this module.
PARTICLE_COUNT = 100
STOCHASTIC_SAMPLES = 10
SPATIAL_DIMENSION = [2, 3]
UNIT_CELL_SIZE = [7, 8]
SOFT_SPHERE_ALPHA = [2.0, 3.0]

# float64 positions are only tested when JAX x64 mode is enabled.
POSITION_DTYPE = [f32, f64] if FLAGS.jax_enable_x64 else [f32]

update_test_tolerance(1e-5, 1e-7)


def lattice_repeater(small_cell_pos, latvec, no_rep):
    """Tile a unit cell ``no_rep`` times along each of three lattice vectors.

    Every position in ``small_cell_pos`` is translated by
    ``latvec[0] * i + latvec[1] * j + latvec[2] * k`` for every
    ``i, j, k`` in ``range(no_rep)`` except ``(0, 0, 0)``.

    Args:
      small_cell_pos: Unit-cell positions. Must expose ``.dtype`` and be
        iterable over atoms; presumably shape ``(n_atoms, 3)`` — TODO confirm.
      latvec: Lattice vectors indexed by row; rows 0, 1 and 2 are used.
      no_rep: Number of repetitions along each lattice vector.

    Returns:
      ``(positions, box)``. ``positions`` lists the original unit-cell
      positions first. Then, atom by atom, it lists that atom's translated
      copies in lexicographic ``(i, j, k)`` order, keeping the dtype of
      ``small_cell_pos``. ``box`` is ``latvec * no_rep`` cast to ``f32``.
    """
    out_dtype = small_cell_pos.dtype
    offsets = [(i, j, k)
               for i in range(no_rep)
               for j in range(no_rep)
               for k in range(no_rep)
               if (i, j, k) != (0, 0, 0)]
    tiled = onp.copy(small_cell_pos).tolist()
    for site in small_cell_pos:
        for i, j, k in offsets:
            # Left-to-right summation order is kept on purpose: float
            # addition is not associative.
            shifted = site + latvec[0] * i + latvec[1] * j + latvec[2] * k
            tiled.append(onp.array(shifted).tolist())
    # NOTE(review): the box is always f32, even when positions are f64.
    return np.array(tiled, out_dtype), f32(latvec * no_rep)
Exemple #6
0
from absl.testing import parameterized

from jax.config import config as jax_config
from jax import random
import jax.numpy as np

from jax.api import grad

from jax import test_util as jtu

from jax_md import space, test_util
from jax_md.util import *

from functools import partial

# Positional (f32, f64) tolerances; this call stays ahead of flag parsing,
# as in the original ordering.
test_util.update_test_tolerance(5e-5, 5e-13)

# Parse JAX config flags via absl before FLAGS is read below.
jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS

# Sizes and sampling parameters used by the tests in this module.
SPATIAL_DIMENSION = [2, 3]
PARTICLE_COUNT = 10
STOCHASTIC_SAMPLES = 10
SHIFT_STEPS = 10

# float64 positions are only tested when JAX x64 mode is enabled.
POSITION_DTYPE = [f32, f64] if FLAGS.jax_enable_x64 else [f32]