def test_graph_network_neighbor_list_moving(self, spatial_dimension, dtype,
                                            format):
  """Checks that a neighbor list built before a small move still matches.

  The neighbor list is allocated at the initial positions. The particles are
  then shifted by at most 0.05 per coordinate, and the neighbor-list GNN
  energy is evaluated with that *unrebuilt* list. The result is compared
  against the neighbor-list-free graph network energy at the moved positions.
  The OrderedSparse format is skipped.
  """
  if format is partition.OrderedSparse:
    self.skipTest('OrderedSparse format incompatible with GNN '
                  'force field.')

  shape = (32, spatial_dimension)
  cutoff = 0.3
  dr_threshold = 0.1

  base_key = random.PRNGKey(0)
  displacement, _ = space.free()
  positions = random.uniform(base_key, shape, dtype=dtype)

  init_fn, reference_energy_fn = energy.graph_network(displacement, cutoff)
  params = init_fn(base_key, positions)

  neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
      displacement, 1.0, cutoff, dr_threshold, format=format)
  neighbors = neighbor_fn.allocate(positions)

  # Perturb the particles with a derived key; `neighbors` is deliberately not
  # updated, so the list built at the old positions is what gets tested.
  # NOTE(review): in 2D/3D a particle can move up to 0.05 * sqrt(d), which may
  # exceed the margin dr_threshold provides; confirm this stays covered.
  shift_key = random.fold_in(base_key, 1)
  positions = positions + random.uniform(
      shift_key, shape, minval=-0.05, maxval=0.05, dtype=dtype)

  # Dense uses the default tolerance; every other format uses rtol/atol 2e-4.
  tolerance = {} if format is partition.Dense else {'rtol': 2e-4, 'atol': 2e-4}
  self.assertAllClose(reference_energy_fn(params, positions),
                      nl_energy_fn(params, positions, neighbors),
                      **tolerance)
def test_graph_network_neighbor_list(self, spatial_dimension, dtype):
  """Checks that the neighbor-list GNN energy matches the dense GNN energy.

  Builds 32 random particles in `spatial_dimension` dimensions, creates a
  neighbor list with a dr_threshold of 0.0, and asserts that the
  neighbor-list energy equals the neighbor-list-free graph network energy at
  the same positions.
  """
  key = random.PRNGKey(0)
  R = random.uniform(key, (32, spatial_dimension), dtype=dtype)

  d, _ = space.free()
  cutoff = 0.2
  dr_threshold = 0.0

  init_fn, energy_fn = energy.graph_network(d, cutoff)
  params = init_fn(key, R)

  neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
      d, 1.0, cutoff, dr_threshold)
  # Use the explicit `allocate` entry point, consistent with
  # test_graph_network_neighbor_list_moving. Calling `neighbor_fn(R)`
  # directly is JAX MD's backward-compatibility path.
  nbrs = neighbor_fn.allocate(R)

  self.assertAllClose(energy_fn(params, R), nl_energy_fn(params, R, nbrs))