def behler_parrinello(
    displacement: DisplacementFn,
    species: Array = None,
    mlp_sizes: Tuple[int, ...] = (30, 30),
    mlp_kwargs: Dict[str, Any] = None,
    sym_kwargs: Dict[str, Any] = None,
    per_particle: bool = False,
) -> Tuple[nn.InitFn, Callable[[PyTree, Array], Array]]:
    """Build a Behler-Parrinello style neural-network model.

    Positions are first turned into descriptors by
    `nn.behler_parrinello_symmetry_functions`. A Haiku MLP is then applied
    independently to each row of the descriptor array (via `vmap`), and the
    outputs are either returned as-is or summed to a scalar.

    Args:
      displacement: Displacement function forwarded to
        `nn.behler_parrinello_symmetry_functions`.
      species: Forwarded unchanged to
        `nn.behler_parrinello_symmetry_functions`.
      mlp_sizes: Hidden-layer widths of the MLP. Any sequence of ints is
        accepted. A final layer of width 1 is always appended, and no
        activation is applied after it (`activate_final=False`).
      mlp_kwargs: Extra keyword arguments for `hk.nets.MLP`. Defaults to
        `{'activation': np.tanh}` when None.
      sym_kwargs: Extra keyword arguments for
        `nn.behler_parrinello_symmetry_functions`. Defaults to `{}`.
      per_particle: If True, `apply` returns the per-row MLP outputs
        (presumably one row per particle with a trailing axis of size 1 —
        confirm against the symmetry-function output). Otherwise it returns
        their sum.

    Returns:
      `(init_fn, apply_fn)` from the Haiku-transformed model. Because of
      `hk.without_apply_rng`, `apply_fn` is called as
      `apply_fn(params, R, **kwargs)` with no RNG argument. Any `**kwargs`
      are forwarded to the symmetry-function callable.
    """
    if sym_kwargs is None:
        sym_kwargs = {}
    if mlp_kwargs is None:
        mlp_kwargs = {'activation': np.tanh}

    # Normalize to a tuple so that a list of hidden widths also works;
    # `list + tuple` would raise TypeError.
    output_sizes = tuple(mlp_sizes) + (1,)

    sym_fn = nn.behler_parrinello_symmetry_functions(
        displacement, species, **sym_kwargs)

    @hk.without_apply_rng
    @hk.transform
    def model(R, **kwargs):
        embedding_fn = hk.nets.MLP(
            output_sizes=output_sizes,
            activate_final=False,
            name='BPEncoder',
            **mlp_kwargs)
        # Map the same MLP over the leading axis of the descriptor array,
        # so every row shares one set of weights.
        embedding_fn = vmap(embedding_fn)
        sym = sym_fn(R, **kwargs)
        readout = embedding_fn(sym)
        if per_particle:
            return readout
        return np.sum(readout)

    return model.init, model.apply
def behler_parrinello(displacement,
                      species=None,
                      mlp_sizes=(30, 30),
                      mlp_kwargs=None,
                      sym_kwargs=None,
                      per_particle=False):
    """Build a Behler-Parrinello style neural-network model.

    Positions are first turned into descriptors by
    `nn.behler_parrinello_symmetry_functions`. A Haiku MLP is then applied
    independently to each row of the descriptor array (via `vmap`), and the
    outputs are either returned as-is or summed to a scalar.

    Args:
      displacement: Displacement function forwarded to
        `nn.behler_parrinello_symmetry_functions`.
      species: Forwarded unchanged to
        `nn.behler_parrinello_symmetry_functions`.
      mlp_sizes: Hidden-layer widths of the MLP. Any sequence of ints is
        accepted. A final layer of width 1 is always appended, and no
        activation is applied after it (`activate_final=False`).
      mlp_kwargs: Extra keyword arguments for `hk.nets.MLP`. Defaults to
        `{'activation': np.tanh}` when None.
      sym_kwargs: Extra keyword arguments for
        `nn.behler_parrinello_symmetry_functions`. Defaults to `{}`.
      per_particle: If True, `apply` returns the per-row MLP outputs
        (presumably one row per particle — confirm against the
        symmetry-function output). Otherwise it returns their sum.

    Returns:
      `(init_fn, apply_fn)` from the Haiku-transformed model. Because of
      `hk.without_apply_rng`, `apply_fn` is called as
      `apply_fn(params, R, **kwargs)` with no RNG argument.
    """
    if sym_kwargs is None:
        sym_kwargs = {}
    if mlp_kwargs is None:
        mlp_kwargs = {'activation': np.tanh}

    # Normalize to a tuple so that a list of hidden widths also works;
    # `list + tuple` would raise TypeError.
    output_sizes = tuple(mlp_sizes) + (1,)

    sym_fn = nn.behler_parrinello_symmetry_functions(
        displacement, species, **sym_kwargs)

    @hk.without_apply_rng
    @hk.transform
    def model(R, **kwargs):
        embedding_fn = hk.nets.MLP(
            output_sizes=output_sizes,
            activate_final=False,
            name='BPEncoder',
            **mlp_kwargs)
        # Share one MLP across all rows of the descriptor array.
        embedding_fn = vmap(embedding_fn)
        sym = sym_fn(R, **kwargs)
        readout = embedding_fn(sym)
        if per_particle:
            return readout
        return np.sum(readout)

    return model.init, model.apply
def test_behler_parrinello_symmetry_functions(self, N_types, N_etas, dtype):
    """Check output shape and one reference value of the BP descriptors.

    Three particles are given the species array `[1, 1, N_types]`. Every eta,
    lambda and zeta array has `N_etas` identical entries, and the cutoff is
    8.0.

    Args:
      N_types: Used both as the third particle's species label and in the
        expected-shape formula.
      N_etas: Length of each eta/lambda/zeta parameter array.
      dtype: Floating dtype used for the parameters, positions and
        reference value.
    """
    displacement, _ = space.free()
    # Same eta for radial and angular terms. The 0.529177 factor looks like a
    # Bohr-to-Angstrom conversion — NOTE(review): confirm.
    eta = 1e-4 / (0.529177 ** 2)
    gr = nn.behler_parrinello_symmetry_functions(
        displacement,
        np.array([1, 1, N_types]),
        radial_etas=np.array([eta] * N_etas, dtype),
        angular_etas=np.array([eta] * N_etas, dtype),
        lambdas=np.array([-1.0] * N_etas, dtype),
        zetas=np.array([1.0] * N_etas, dtype),
        cutoff_distance=8.0)
    R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
    gr_out = gr(R)
    # The feature count is N_etas * (N_types + N_types*(N_types+1)/2). That is
    # presumably one radial block per type plus one angular block per
    # unordered type pair.
    expected_shape = (3, N_etas * (N_types + N_types * (N_types + 1) // 2))
    # Shapes are integer tuples, so compare them exactly rather than with
    # float tolerances.
    self.assertEqual(tuple(gr_out.shape), expected_shape)
    self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)