def behler_parrinello_neighbor_list(displacement,
                                    species=None,
                                    mlp_sizes=(30, 30),
                                    mlp_kwargs=None,
                                    sym_kwargs=None):
  """Build a Behler-Parrinello network that consumes a neighbor list.

  Symmetry functions are computed for every particle from the neighbor list,
  each particle's feature row is passed through one shared MLP, and the
  per-particle outputs are summed into a single scalar.

  Args:
    displacement: Displacement function forwarded to
      `nn.behler_parrinello_symmetry_functions_neighbor_list`.
    species: Optional species array forwarded to the symmetry functions.
    mlp_sizes: Hidden-layer widths of the MLP. A width-1 output layer is always
      appended. Any sequence of ints (tuple or list) is accepted.
    mlp_kwargs: Extra keyword arguments for `hk.nets.MLP`. Defaults to
      `{'activation': np.tanh}`.
    sym_kwargs: Extra keyword arguments for the symmetry-function builder.

  Returns:
    `(init_fn, apply_fn)` of the Haiku-transformed model (without apply RNG):
    `init_fn(rng, R, neighbor, **kwargs)` creates parameters and
    `apply_fn(params, R, neighbor, **kwargs)` returns the summed scalar output.
  """
  if sym_kwargs is None:
    sym_kwargs = {}
  if mlp_kwargs is None:
    mlp_kwargs = {'activation': np.tanh}

  # Normalize to a tuple so a list of widths also works; `list + tuple` would
  # raise a TypeError when the final output layer is appended.
  output_sizes = tuple(mlp_sizes) + (1,)

  sym_fn = nn.behler_parrinello_symmetry_functions_neighbor_list(
      displacement, species, **sym_kwargs)

  @hk.without_apply_rng
  @hk.transform
  def model(R, neighbor, **kwargs):
    embedding_fn = hk.nets.MLP(output_sizes=output_sizes,
                               activate_final=False,
                               name='BPEncoder',
                               **mlp_kwargs)
    # Apply the same MLP independently to each particle's feature row.
    embedding_fn = vmap(embedding_fn)
    sym = sym_fn(R, neighbor=neighbor, **kwargs)
    readout = embedding_fn(sym)
    return np.sum(readout)

  return model.init, model.apply
def behler_parrinello_neighbor_list(
    displacement: DisplacementFn,
    box_size: float,
    species: Array=None,
    mlp_sizes: Tuple[int, ...]=(30, 30),
    mlp_kwargs: Dict[str, Any]=None,
    sym_kwargs: Dict[str, Any]=None,
    dr_threshold: float=0.5
    ) -> Tuple[NeighborFn,
               nn.InitFn,
               Callable[[PyTree, Array, NeighborList], Array]]:
  """Build a neighbor-list Behler-Parrinello network and its neighbor function.

  Symmetry functions are computed for every particle from a neighbor list,
  each particle's feature row is passed through one shared MLP, and the
  per-particle outputs are summed into a single scalar.

  Args:
    displacement: Displacement function used by both the neighbor list and
      the symmetry functions.
    box_size: Box size forwarded to `partition.neighbor_list`.
    species: Optional species array forwarded to the symmetry functions.
    mlp_sizes: Hidden-layer widths of the MLP. A width-1 output layer is always
      appended. Any sequence of ints (tuple or list) is accepted.
    mlp_kwargs: Extra keyword arguments for `hk.nets.MLP`. Defaults to
      `{'activation': np.tanh}`.
    sym_kwargs: Extra keyword arguments for the symmetry-function builder. If
      it contains `'cutoff_distance'`, that value is also used as the
      neighbor-list cutoff; otherwise the cutoff is 8.0.
    dr_threshold: Forwarded to `partition.neighbor_list`.

  Returns:
    `(neighbor_fn, init_fn, apply_fn)`: the neighbor-list function, plus the
    Haiku-transformed model (without apply RNG) where
    `init_fn(rng, R, neighbor, **kwargs)` creates parameters and
    `apply_fn(params, R, neighbor, **kwargs)` returns the summed scalar output.
  """
  if sym_kwargs is None:
    sym_kwargs = {}
  if mlp_kwargs is None:
    mlp_kwargs = {'activation': np.tanh}

  # Use the same cutoff for the neighbor list as for the symmetry functions.
  # NOTE(review): 8.0 is assumed to mirror the symmetry-function builder's own
  # default cutoff -- keep the two in sync.
  cutoff_distance = sym_kwargs.get('cutoff_distance', 8.0)
  neighbor_fn = partition.neighbor_list(displacement, box_size,
                                        cutoff_distance, dr_threshold)

  # Normalize to a tuple so a list of widths also works; `list + tuple` would
  # raise a TypeError when the final output layer is appended.
  output_sizes = tuple(mlp_sizes) + (1,)

  sym_fn = nn.behler_parrinello_symmetry_functions_neighbor_list(
      displacement, species, **sym_kwargs)

  @hk.without_apply_rng
  @hk.transform
  def model(R, neighbor, **kwargs):
    embedding_fn = hk.nets.MLP(output_sizes=output_sizes,
                               activate_final=False,
                               name='BPEncoder',
                               **mlp_kwargs)
    # Apply the same MLP independently to each particle's feature row.
    embedding_fn = vmap(embedding_fn)
    sym = sym_fn(R, neighbor=neighbor, **kwargs)
    readout = embedding_fn(sym)
    return np.sum(readout)

  return neighbor_fn, model.init, model.apply
def test_behler_parrinello_symmetry_functions_neighbor_list(self, N_types, N_etas, dtype):
  """Check output shape and one reference value of the neighbor-list BP features."""
  displacement, _ = space.free()
  neighbor_fn = partition.neighbor_list(displacement, 10.0, 8.0, 0.0)

  # One shared eta for all radial and angular functions. 0.529177 is
  # presumably the Bohr radius in Angstrom (unit conversion of 1e-4).
  eta = 1e-4 / (0.529177 ** 2)
  gr = nn.behler_parrinello_symmetry_functions_neighbor_list(
      displacement,
      np.array([1, 1, N_types]),
      radial_etas=np.array([eta] * N_etas, dtype),
      angular_etas=np.array([eta] * N_etas, dtype),
      lambdas=np.array([-1.0] * N_etas, dtype),
      zetas=np.array([1.0] * N_etas, dtype),
      cutoff_distance=8.0)

  R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
  nbrs = neighbor_fn(R)
  gr_out = gr(R, neighbor=nbrs)

  # Presumably N_types radial blocks plus one angular block per unordered
  # species pair, each holding N_etas features.
  n_features = N_etas * (N_types + N_types * (N_types + 1) // 2)
  # Shapes are integer tuples: compare exactly rather than with tolerances.
  self.assertEqual(gr_out.shape, (3, n_features))
  self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)