示例#1
0
  def test_graph_network_neighbor_list_moving(self,
                                              spatial_dimension,
                                              dtype,
                                              format):
    """Check that a neighbor list built before a small move still matches.

    A neighbor list is allocated for 32 random positions. Every coordinate is
    then shifted by a uniform random offset with minval -0.05 and
    maxval 0.05. The neighbor-list energy, evaluated with the list that was
    NOT rebuilt after the move, is compared against the all-pairs GNN energy.
    The Dense format uses the default tolerances of `assertAllClose`; all
    other formats use an explicit rtol/atol of 2e-4.
    """
    if format is partition.OrderedSparse:
      self.skipTest('OrderedSparse format incompatible with GNN '
                    'force field.')

    n_particles = 32
    cutoff = 0.3
    dr_threshold = 0.1

    rng = random.PRNGKey(0)
    positions = random.uniform(rng, (n_particles, spatial_dimension),
                               dtype=dtype)
    displacement, _ = space.free()

    init_fn, energy_fn = energy.graph_network(displacement, cutoff)
    params = init_fn(rng, positions)

    neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
        displacement, 1.0, cutoff, dr_threshold, format=format)
    nbrs = neighbor_fn.allocate(positions)

    # Perturb the particles but deliberately keep the neighbor list built
    # for the unperturbed positions.
    jitter_key = random.fold_in(rng, 1)
    positions = positions + random.uniform(
        jitter_key, (n_particles, spatial_dimension),
        minval=-0.05, maxval=0.05, dtype=dtype)

    if format is partition.Dense:
      tolerance = {}
    else:
      tolerance = {'rtol': 2e-4, 'atol': 2e-4}

    self.assertAllClose(energy_fn(params, positions),
                        nl_energy_fn(params, positions, nbrs),
                        **tolerance)
示例#2
0
    def test_graph_network_shape_dtype(self, spatial_dimension, dtype):
        """Check that the GNN energy is a 0-d array with the input's dtype.

        Uses unittest assertion methods rather than bare ``assert`` so the
        checks are not stripped when Python runs with ``-O``, and so that
        failures report the mismatching values.
        """
        key = random.PRNGKey(0)

        R = random.uniform(key, (32, spatial_dimension), dtype=dtype)

        d, _ = space.free()

        cutoff = 0.2

        init_fn, energy_fn = energy.graph_network(d, cutoff)
        params = init_fn(key, R)

        E_out = energy_fn(params, R)

        # The energy must be a 0-d array whose dtype follows the positions.
        self.assertEqual(E_out.shape, ())
        self.assertEqual(E_out.dtype, dtype)
示例#3
0
    def test_graph_network_neighbor_list(self, spatial_dimension, dtype):
        """Check that neighbor-list GNN energy equals all-pairs GNN energy.

        Both energies are evaluated on the same 32 random positions. The
        neighbor list is built directly from those positions, with
        ``dr_threshold`` passed as 0.0.
        """
        cutoff = 0.2
        disp_fn, _ = space.free()

        rng = random.PRNGKey(0)
        positions = random.uniform(rng, (32, spatial_dimension), dtype=dtype)

        init_fn, dense_energy_fn = energy.graph_network(disp_fn, cutoff)
        params = init_fn(rng, positions)

        neighbor_fn, _, nl_energy_fn = energy.graph_network_neighbor_list(
            disp_fn, 1.0, cutoff, 0.0)
        neighbors = neighbor_fn(positions)

        expected = dense_energy_fn(params, positions)
        actual = nl_energy_fn(params, positions, neighbors)
        self.assertAllClose(expected, actual)
示例#4
0
  def test_graph_network_learning(self, spatial_dimension, dtype):
    """Check that four optimizer steps cut the GNN fitting loss by over 5%.

    Six random 3-particle configurations are labelled with a synthetic target
    energy. That target is the summed squared deviation of the distances
    from `space.map_product(d)(R, R)` from random reference values `dr0`
    (3x3 per configuration). The GNN is fitted to it with
    `clip_by_global_norm(1.0)` followed by Adam(1e-4).
    """
    key = random.PRNGKey(0)

    R_key, dr0_key, params_key = random.split(key, 3)

    d, _ = space.free()

    R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
    dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)

    def target_energy(R, dr0):
      # Squared deviation of the 3x3 distance matrix from the reference dr0.
      dists = space.distance(space.map_product(d)(R, R))
      return np.sum((dists - dr0) ** 2)

    # Vectorized over the leading (configuration) axis of R and dr0.
    E_gt = vmap(target_energy)

    cutoff = 0.2

    init_fn, energy_fn = energy.graph_network(d, cutoff)
    params = init_fn(params_key, R[0])

    @jit
    def loss(params, R):
      E_pred = vmap(energy_fn, (None, 0))(params, R)
      return np.mean((E_pred - E_gt(R, dr0)) ** 2)

    opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))

    @jit
    def update(params, opt_state, R):
      updates, opt_state = opt.update(grad(loss)(params, R), opt_state)
      return optax.apply_updates(params, updates), opt_state

    opt_state = opt.init(params)

    l0 = loss(params, R)
    for _ in range(4):
      params, opt_state = update(params, opt_state, R)

    # Use a unittest assertion (not bare `assert`) so the check survives
    # `python -O` and reports both values on failure.
    self.assertLess(loss(params, R), l0 * 0.95)