def energy(R, **kwargs):
    """Compute the summed embedding-plus-pairwise energy for positions `R`.

    Args:
        R: Positions passed as both arguments of the enclosing scope's
            `metric`, so the result `metric(R, R)` covers every pair of rows.
        **kwargs: Forwarded unchanged to `metric`.

    Returns:
        The sum of the embedding and pairwise terms, reduced over `axis`
        (a name taken from the enclosing scope).
    """
    pair_dr = metric(R, R, **kwargs)

    # Embedding term: accumulate the charge contributions along axis 1 and
    # feed the totals through the embedding function.
    charge_totals = smap._high_precision_sum(charge_fn(pair_dr), axis=1)
    embed_term = embedding_fn(charge_totals)

    # Pairwise term: the diagonal mask is applied before summing
    # (presumably to drop self-pairs -- confirm against smap._diagonal_mask).
    # NOTE(review): the division by 2 looks like a double-counting correction
    # for the symmetric pair matrix -- confirm.
    masked_pairs = smap._diagonal_mask(pairwise_fn(pair_dr))
    pair_term = smap._high_precision_sum(masked_pairs, axis=1) / f32(2.0)

    combined = embed_term + pair_term
    return smap._high_precision_sum(combined, axis=axis)
def return_radial(atom_type):
    """Return the radial symmetry functions restricted to one neighbor type.

    Args:
        atom_type: Species value; only neighbors whose species equals this
            value contribute to the sum.

    Returns:
        The transpose of the masked radial terms summed over axis 2. The
        radial terms are `radial_fn` mapped over `etas` (all taken from the
        enclosing scope).
    """
    neighbor_idx = neighbor.idx
    neighbor_positions = R[neighbor_idx]
    neighbor_species = species[neighbor_idx]

    # A neighbor entry counts only if its index is inside R (indices at or
    # beyond R.shape[0] are presumably padding -- confirm) and its species
    # matches the requested type.
    index_in_range = neighbor_idx < R.shape[0]
    species_matches = neighbor_species == atom_type
    keep = np.logical_and(index_in_range, species_matches)

    distances = _metric(R, neighbor_positions)

    # Map radial_fn over the leading axis of etas while sharing the same
    # distances; the result gains a new leading axis.
    per_eta = vmap(radial_fn, (0, None))(etas, distances)

    # Broadcast the mask across that new leading axis before reducing.
    masked = per_eta * keep[np.newaxis, :, :]
    return smap._high_precision_sum(masked, axis=2).T