コード例 #1
0
  def test_graph_network_neighbor_list_moving(self,
                                              spatial_dimension,
                                              dtype,
                                              format):
    """Neighbor-list GNN energy matches the dense GNN after a small move.

    The neighbor list is allocated from the initial positions. Every
    coordinate is then perturbed by a uniform offset in [-0.05, 0.05], and
    the list is reused without being rebuilt. The dense and neighbor-list
    energies must still agree.
    """
    if format is partition.OrderedSparse:
      self.skipTest('OrderedSparse format incompatible with GNN '
                    'force field.')

    n_particles = 32
    shape = (n_particles, spatial_dimension)
    cutoff = 0.3
    dr_threshold = 0.1
    displacement, _ = space.free()

    rng = random.PRNGKey(0)
    positions = random.uniform(rng, shape, dtype=dtype)

    init_fn, energy_fn = energy.graph_network(displacement, cutoff)
    # NOTE: the same key seeds both the positions and the parameters.
    params = init_fn(rng, positions)
    neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
        displacement, 1.0, cutoff, dr_threshold, format=format)

    # Allocate before the move so the list is stale afterwards. The extra
    # dr_threshold is presumably what keeps it covering every interacting
    # pair -- TODO confirm the displacement bound against the skin size.
    nbrs = neighbor_fn.allocate(positions)
    rng = random.fold_in(rng, 1)
    positions = positions + random.uniform(rng, shape,
                                           minval=-0.05, maxval=0.05,
                                           dtype=dtype)

    expected = energy_fn(params, positions)
    actual = nl_energy_fn(params, positions, nbrs)
    # Dense format is compared at the default tolerance; other formats get
    # a looser rtol/atol of 2e-4.
    if format is partition.Dense:
      tolerance = {}
    else:
      tolerance = dict(rtol=2e-4, atol=2e-4)
    self.assertAllClose(expected, actual, **tolerance)
コード例 #2
0
ファイル: energy_test.py プロジェクト: kessel/jax-md
    def test_graph_network_neighbor_list(self, spatial_dimension, dtype):
        """Dense and neighbor-list GNN energies agree on a static system."""
        cutoff = 0.2
        displacement, _ = space.free()

        rng = random.PRNGKey(0)
        positions = random.uniform(rng, (32, spatial_dimension), dtype=dtype)

        init_fn, energy_fn = energy.graph_network(displacement, cutoff)
        # NOTE: the same key seeds both the positions and the parameters.
        params = init_fn(rng, positions)

        # The positions are never perturbed in this test, so the list is
        # built with a dr_threshold of 0.0.
        neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
            displacement, 1.0, cutoff, 0.0)
        nbrs = neighbor_fn(positions)

        dense_energy = energy_fn(params, positions)
        nl_energy = nl_energy_fn(params, positions, nbrs)
        self.assertAllClose(dense_energy, nl_energy)