Example 1
0
def behler_parrinello(
    displacement: DisplacementFn,
    species: Array = None,
    mlp_sizes: Tuple[int, ...] = (30, 30),
    mlp_kwargs: Dict[str, Any] = None,
    sym_kwargs: Dict[str, Any] = None,
    per_particle: bool = False
) -> Tuple[nn.InitFn, Callable[[PyTree, Array], Array]]:
    """Build a Behler-Parrinello neural-network model on symmetry functions.

    Positions are turned into descriptors by
    ``nn.behler_parrinello_symmetry_functions``. A single Haiku MLP (shared
    weights) is mapped over the leading axis of those descriptors and ends in
    a linear layer of width 1. The per-row outputs are either returned as-is
    or summed to a scalar.

    Args:
      displacement: Displacement function forwarded to
        ``nn.behler_parrinello_symmetry_functions``.
      species: Species array forwarded to the symmetry functions (may be
        ``None``).
      mlp_sizes: Hidden-layer widths of the MLP. Any sequence of ints is
        accepted (tuple, list, ...); an output layer of width 1 is always
        appended.
      mlp_kwargs: Extra keyword arguments for ``hk.nets.MLP``. When ``None``,
        ``{'activation': np.tanh}`` is used. A user-supplied dict replaces that
        default entirely, so the activation falls back to Haiku's own default
        unless it is given explicitly.
      sym_kwargs: Extra keyword arguments for
        ``nn.behler_parrinello_symmetry_functions``.
      per_particle: If True, return the MLP output for every descriptor row
        instead of its sum.

    Returns:
      ``(init_fn, apply_fn)`` of an ``hk.transform``-ed model wrapped in
      ``hk.without_apply_rng``. These are used as ``init_fn(rng, R, **kwargs)``
      and ``apply_fn(params, R, **kwargs)``; any ``**kwargs`` are forwarded to
      the symmetry function. With ``per_particle=True`` the output presumably
      has shape ``(N, 1)`` for ``N`` particles; confirm against the
      symmetry-function output shape.
    """
    if sym_kwargs is None:
        sym_kwargs = {}
    if mlp_kwargs is None:
        mlp_kwargs = {'activation': np.tanh}

    # Normalize to a tuple so that lists of widths also work. `list + tuple`
    # would raise a TypeError. The trailing width-1 layer yields one scalar
    # per descriptor row.
    output_sizes = tuple(mlp_sizes) + (1,)

    sym_fn = nn.behler_parrinello_symmetry_functions(displacement, species,
                                                     **sym_kwargs)

    @hk.without_apply_rng
    @hk.transform
    def model(R, **kwargs):
        # Haiku modules must be created inside the transformed function so
        # their parameters are tracked by init/apply.
        embedding_fn = hk.nets.MLP(output_sizes=output_sizes,
                                   activate_final=False,  # linear readout
                                   name='BPEncoder',
                                   **mlp_kwargs)
        # Apply the same MLP to every row of the descriptor matrix.
        embedding_fn = vmap(embedding_fn)
        sym = sym_fn(R, **kwargs)
        readout = embedding_fn(sym)
        if per_particle:
            return readout
        return np.sum(readout)

    return model.init, model.apply
Example 2
0
def behler_parrinello(displacement,
                      species=None,
                      mlp_sizes=(30, 30),
                      mlp_kwargs=None,
                      sym_kwargs=None,
                      per_particle=False):
    """Build a Behler-Parrinello neural-network model on symmetry functions.

    Positions are turned into descriptors by
    ``nn.behler_parrinello_symmetry_functions``. A single Haiku MLP (shared
    weights) is mapped over the leading axis of those descriptors and ends in
    a linear layer of width 1. The per-row outputs are either returned as-is
    or summed to a scalar.

    Args:
      displacement: Displacement function forwarded to
        ``nn.behler_parrinello_symmetry_functions``.
      species: Species array forwarded to the symmetry functions (may be
        ``None``).
      mlp_sizes: Hidden-layer widths of the MLP. Any sequence of ints is
        accepted (tuple, list, ...); an output layer of width 1 is always
        appended.
      mlp_kwargs: Extra keyword arguments for ``hk.nets.MLP``. When ``None``,
        ``{'activation': np.tanh}`` is used. A user-supplied dict replaces that
        default entirely, so the activation falls back to Haiku's own default
        unless it is given explicitly.
      sym_kwargs: Extra keyword arguments for
        ``nn.behler_parrinello_symmetry_functions``.
      per_particle: If True, return the MLP output for every descriptor row
        instead of its sum.

    Returns:
      ``(init_fn, apply_fn)`` of an ``hk.transform``-ed model wrapped in
      ``hk.without_apply_rng``. These are used as ``init_fn(rng, R, **kwargs)``
      and ``apply_fn(params, R, **kwargs)``; any ``**kwargs`` are forwarded to
      the symmetry function. With ``per_particle=True`` the output presumably
      has shape ``(N, 1)`` for ``N`` particles; confirm against the
      symmetry-function output shape.
    """
    if sym_kwargs is None:
        sym_kwargs = {}
    if mlp_kwargs is None:
        mlp_kwargs = {'activation': np.tanh}

    # Normalize to a tuple so that lists of widths also work. `list + tuple`
    # would raise a TypeError. The trailing width-1 layer yields one scalar
    # per descriptor row.
    output_sizes = tuple(mlp_sizes) + (1,)

    sym_fn = nn.behler_parrinello_symmetry_functions(displacement, species,
                                                     **sym_kwargs)

    @hk.without_apply_rng
    @hk.transform
    def model(R, **kwargs):
        # Haiku modules must be created inside the transformed function so
        # their parameters are tracked by init/apply.
        embedding_fn = hk.nets.MLP(output_sizes=output_sizes,
                                   activate_final=False,  # linear readout
                                   name='BPEncoder',
                                   **mlp_kwargs)
        # Apply the same MLP to every row of the descriptor matrix.
        embedding_fn = vmap(embedding_fn)
        sym = sym_fn(R, **kwargs)
        readout = embedding_fn(sym)
        if per_particle:
            return readout
        return np.sum(readout)

    return model.init, model.apply
Example 3
0
 def test_behler_parrinello_symmetry_functions(self, N_types, N_etas, dtype):
   """Check shape and one reference value of the BP symmetry functions.

   Builds symmetry functions in free space for three particles with species
   labels ``[1, 1, N_types]``. All ``N_etas`` radial and angular parameters
   are set to the same values. The test then checks that the output has one
   row per particle and ``N_etas * (N_types + N_types * (N_types + 1) // 2)``
   columns, and that entry ``[2, 0]`` matches a stored reference value.
   """
   # Only the displacement function is needed; the shift function is unused.
   displacement, _ = space.free()
   # One shared eta for every term. The 0.529177 factor presumably converts
   # from Bohr-based units (TODO confirm).
   eta = 1e-4 / (0.529177 ** 2)
   gr = nn.behler_parrinello_symmetry_functions(
       displacement, np.array([1, 1, N_types]),
       radial_etas=np.array([eta] * N_etas, dtype),
       angular_etas=np.array([eta] * N_etas, dtype),
       lambdas=np.array([-1.0] * N_etas, dtype),
       zetas=np.array([1.0] * N_etas, dtype),
       cutoff_distance=8.0)
   R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
   gr_out = gr(R)
   # Shapes are integer tuples: compare them exactly, not with a tolerance.
   expected_features = N_etas * (N_types + N_types * (N_types + 1) // 2)
   self.assertEqual(gr_out.shape, (3, expected_features))
   self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)