コード例 #1
0
def behler_parrinello_neighbor_list(displacement,
                                    species=None,
                                    mlp_sizes=(30, 30),
                                    mlp_kwargs=None,
                                    sym_kwargs=None):
    """Build a Behler-Parrinello style energy model over a neighbor list.

    The model evaluates Behler-Parrinello symmetry functions using a
    precomputed neighbor list. It feeds each row of the resulting descriptor
    array through one shared MLP that ends in a single output unit, then sums
    all outputs into one scalar.

    Args:
        displacement: Displacement function, forwarded to
            ``nn.behler_parrinello_symmetry_functions_neighbor_list``.
        species: Optional species information, forwarded unchanged to the
            symmetry-function constructor.
        mlp_sizes: Hidden layer widths of the MLP. Any sequence of ints is
            accepted (tuple or list); a final layer of width 1 is appended.
        mlp_kwargs: Extra keyword arguments for ``hk.nets.MLP``. Defaults to
            ``{'activation': np.tanh}``.
        sym_kwargs: Extra keyword arguments for the symmetry-function
            constructor. Defaults to ``{}``.

    Returns:
        A pair ``(init_fn, apply_fn)`` taken from the Haiku-transformed model.
        Because of ``hk.without_apply_rng``, ``apply_fn(params, R, neighbor,
        **kwargs)`` takes no RNG argument; it returns the summed scalar.
    """
    if sym_kwargs is None:
        sym_kwargs = {}
    if mlp_kwargs is None:
        mlp_kwargs = {'activation': np.tanh}

    # Normalize to a tuple so callers may pass a list: ``list + tuple`` would
    # raise TypeError. The trailing width-1 layer gives one value per row.
    output_sizes = tuple(mlp_sizes) + (1,)

    sym_fn = nn.behler_parrinello_symmetry_functions_neighbor_list(
        displacement, species, **sym_kwargs)

    @hk.without_apply_rng
    @hk.transform
    def model(R, neighbor, **kwargs):
        embedding_fn = hk.nets.MLP(output_sizes=output_sizes,
                                   activate_final=False,
                                   name='BPEncoder',
                                   **mlp_kwargs)
        # Apply the same MLP independently along the leading axis of the
        # descriptor array (presumably one row per particle -- confirm).
        embedding_fn = vmap(embedding_fn)
        sym = sym_fn(R, neighbor=neighbor, **kwargs)
        readout = embedding_fn(sym)
        return np.sum(readout)

    return model.init, model.apply
コード例 #2
0
ファイル: energy.py プロジェクト: scnlong/jax-md
def behler_parrinello_neighbor_list(displacement: DisplacementFn,
                                    box_size: float,
                                    species: Array=None,
                                    mlp_sizes: Tuple[int, ...]=(30, 30),
                                    mlp_kwargs: Dict[str, Any]=None,
                                    sym_kwargs: Dict[str, Any]=None,
                                    dr_threshold: float=0.5
                                    ) -> Tuple[NeighborFn,
                                               nn.InitFn,
                                               Callable[[PyTree,
                                                         Array,
                                                         NeighborList],
                                                        Array]]:
  """Build a neighbor-list function and a Behler-Parrinello energy model.

  Symmetry functions from
  ``nn.behler_parrinello_symmetry_functions_neighbor_list`` are evaluated
  over a neighbor list. Each row of the resulting descriptor array is passed
  through one shared MLP ending in a single output unit, and all outputs are
  summed into one scalar.

  Args:
    displacement: Displacement function, used both to build the neighbor
      list and by the symmetry functions.
    box_size: Box size forwarded to ``partition.neighbor_list``.
    species: Optional species information forwarded to the symmetry-function
      constructor.
    mlp_sizes: Hidden layer widths of the MLP; a tuple or list of ints. A
      final layer of width 1 is appended.
    mlp_kwargs: Extra keyword arguments for ``hk.nets.MLP``. Defaults to
      ``{'activation': np.tanh}``.
    sym_kwargs: Extra keyword arguments for the symmetry-function
      constructor. If it contains ``'cutoff_distance'``, that value is also
      used as the neighbor-list cutoff; otherwise the neighbor list uses 8.0.
    dr_threshold: Forwarded to ``partition.neighbor_list``.

  Returns:
    ``(neighbor_fn, init_fn, apply_fn)``: the neighbor-list function plus the
    Haiku-transformed model's ``init`` and ``apply``. Because of
    ``hk.without_apply_rng``, ``apply_fn(params, R, neighbor, **kwargs)``
    takes no RNG argument; it returns the summed scalar.
  """
  if sym_kwargs is None:
    sym_kwargs = {}
  if mlp_kwargs is None:
    mlp_kwargs = {
        'activation': np.tanh
    }

  # Build the neighbor list with the same cutoff the symmetry functions use
  # when the caller sets one. NOTE(review): the 8.0 fallback presumably
  # mirrors the symmetry-function default -- confirm they stay in sync.
  cutoff_distance = sym_kwargs.get('cutoff_distance', 8.0)

  neighbor_fn = partition.neighbor_list(displacement,
                                        box_size,
                                        cutoff_distance,
                                        dr_threshold)

  sym_fn = nn.behler_parrinello_symmetry_functions_neighbor_list(displacement,
                                                                 species,
                                                                 **sym_kwargs)

  # Normalize to a tuple so callers may pass a list: ``list + tuple`` would
  # raise TypeError. The trailing width-1 layer gives one value per row.
  output_sizes = tuple(mlp_sizes) + (1,)

  @hk.without_apply_rng
  @hk.transform
  def model(R, neighbor, **kwargs):
    embedding_fn = hk.nets.MLP(output_sizes=output_sizes,
                               activate_final=False,
                               name='BPEncoder',
                               **mlp_kwargs)
    # Apply the same MLP independently along the leading axis of the
    # descriptor array (presumably one row per particle -- confirm).
    embedding_fn = vmap(embedding_fn)
    sym = sym_fn(R, neighbor, **kwargs)
    readout = embedding_fn(sym)
    return np.sum(readout)
  return neighbor_fn, model.init, model.apply
コード例 #3
0
ファイル: nn_test.py プロジェクト: zheshen/jax-md
 def test_behler_parrinello_symmetry_functions_neighbor_list(self,
                                                             N_types,
                                                             N_etas,
                                                             dtype):
   """Check the output shape and one reference value of the neighbor-list
   Behler-Parrinello symmetry functions on a three-particle configuration."""
   displacement, _ = space.free()
   neighbor_fn = partition.neighbor_list(displacement, 10.0, 8.0, 0.0)

   # The radial and angular terms share the same eta values.
   eta_value = 1e-4 / (0.529177 ** 2)
   etas = [eta_value] * N_etas
   species = np.array([1, 1, N_types])

   gr = nn.behler_parrinello_symmetry_functions_neighbor_list(
       displacement,
       species,
       radial_etas=np.array(etas, dtype),
       angular_etas=np.array(etas, dtype),
       lambdas=np.array([-1.0] * N_etas, dtype),
       zetas=np.array([1.0] * N_etas, dtype),
       cutoff_distance=8.0)

   positions = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
   gr_out = gr(positions, neighbor=neighbor_fn(positions))

   # Expected features per particle: N_etas * (N_types + N_types*(N_types+1)/2)
   # -- presumably one radial block per species plus one angular block per
   # unordered species pair.
   n_features = N_etas * (N_types + N_types * (N_types + 1) // 2)
   self.assertAllClose(gr_out.shape, (3, n_features))
   self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)