def test_graph_network_neighbor_list_moving(self, spatial_dimension, dtype,
                                            format):
  """Compares dense and neighbor-list GNN energies after particles move.

  A neighbor list is allocated from the initial positions and then reused,
  without being updated, to evaluate the energy of randomly perturbed
  positions. The neighbor-list energy must match the dense
  ``energy.graph_network`` energy: with ``assertAllClose``'s default tolerance
  for the Dense format, and with rtol/atol of 2e-4 for the other formats.
  OrderedSparse is skipped.
  """
  if format is partition.OrderedSparse:
    self.skipTest('OrderedSparse format incompatible with GNN '
                  'force field.')

  rng = random.PRNGKey(0)
  positions = random.uniform(rng, (32, spatial_dimension), dtype=dtype)
  displacement, _ = space.free()

  cutoff = 0.3
  dr_threshold = 0.1
  box_size = 1.0

  init_fn, dense_energy_fn = energy.graph_network(displacement, cutoff)
  params = init_fn(rng, positions)

  neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
      displacement, box_size, cutoff, dr_threshold, format=format)
  nbrs = neighbor_fn.allocate(positions)

  # Shift every coordinate by a value in [-0.05, 0.05]; `nbrs` is
  # deliberately NOT refreshed, so the stale list must still be sufficient.
  jitter_rng = random.fold_in(rng, 1)
  positions = positions + random.uniform(
      jitter_rng, (32, spatial_dimension),
      minval=-0.05, maxval=0.05, dtype=dtype)

  E_dense = dense_energy_fn(params, positions)
  E_neighbor = nl_energy_fn(params, positions, nbrs)

  # Dense keeps the default tolerance; other formats get a looser bound.
  if format is partition.Dense:
    tolerance = {}
  else:
    tolerance = dict(rtol=2e-4, atol=2e-4)
  self.assertAllClose(E_dense, E_neighbor, **tolerance)
def test_graph_network_shape_dtype(self, spatial_dimension, dtype):
  """Checks the GNN energy of one configuration is a scalar of `dtype`.

  Builds a dense ``energy.graph_network`` with a 0.2 cutoff over 32 uniformly
  sampled positions and asserts the returned energy has shape ``()`` and the
  same dtype as the input positions.
  """
  key = random.PRNGKey(0)
  R = random.uniform(key, (32, spatial_dimension), dtype=dtype)
  d, _ = space.free()
  cutoff = 0.2
  init_fn, energy_fn = energy.graph_network(d, cutoff)
  params = init_fn(key, R)
  E_out = energy_fn(params, R)
  # Use TestCase assertions rather than bare `assert`: they survive
  # `python -O` and report both operands on failure.
  self.assertEqual(E_out.shape, ())
  self.assertEqual(E_out.dtype, dtype)
def test_graph_network_neighbor_list(self, spatial_dimension, dtype):
  """Checks neighbor-list and dense GNN energies agree on a static system.

  The neighbor list is allocated from the same positions it is evaluated on,
  with a ``dr_threshold`` of 0.0, and its energy is compared against the dense
  ``energy.graph_network`` energy.
  """
  key = random.PRNGKey(0)
  R = random.uniform(key, (32, spatial_dimension), dtype=dtype)
  d, _ = space.free()
  cutoff = 0.2
  init_fn, energy_fn = energy.graph_network(d, cutoff)
  params = init_fn(key, R)
  neighbor_fn, _, nl_energy_fn = \
    energy.graph_network_neighbor_list(d, 1.0, cutoff, 0.0)
  # Use the explicit `allocate` API (as the other neighbor-list tests do)
  # instead of the backward-compatibility `neighbor_fn(R)` call form.
  nbrs = neighbor_fn.allocate(R)
  self.assertAllClose(energy_fn(params, R), nl_energy_fn(params, R, nbrs))
def test_graph_network_learning(self, spatial_dimension, dtype):
  """Checks a few optimizer steps reduce the GNN regression loss by >= 5%.

  The target for each of 6 three-particle configurations is the summed
  squared difference between its pairwise distance matrix and a random
  (3, 3) reference matrix. The GNN is trained with gradient-norm clipping
  plus Adam for 4 steps, and the final loss must be below 95% of the
  initial loss.
  """
  key = random.PRNGKey(0)
  R_key, dr0_key, params_key = random.split(key, 3)

  d, _ = space.free()

  R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
  dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)

  def pairwise_distance_error(R, dr0):
    # Summed squared deviation of the all-pairs distance matrix from dr0.
    dR = space.map_product(d)(R, R)
    return np.sum((space.distance(dR) - dr0) ** 2)

  # Batched over the leading (configuration) axis of R and dr0.
  E_gt = vmap(pairwise_distance_error)

  cutoff = 0.2

  init_fn, energy_fn = energy.graph_network(d, cutoff)
  params = init_fn(params_key, R[0])

  @jit
  def loss(params, R):
    return np.mean((vmap(energy_fn, (None, 0))(params, R) - E_gt(R, dr0)) ** 2)

  opt = optax.chain(optax.clip_by_global_norm(1.0),
                    optax.adam(1e-4))

  @jit
  def update(params, opt_state, R):
    updates, opt_state = opt.update(grad(loss)(params, R),
                                    opt_state)
    return optax.apply_updates(params, updates), opt_state

  opt_state = opt.init(params)

  l0 = loss(params, R)
  for _ in range(4):
    params, opt_state = update(params, opt_state, R)

  # TestCase assertion instead of bare `assert` (kept under `python -O`).
  self.assertLess(loss(params, R), l0 * 0.95)