def test_get_nuclear_interaction_energy_incorrect_ndim(
    self, locations, nuclear_charges, expected_message):
  """Checks that bad inputs to `get_nuclear_interaction_energy` raise.

  Both inputs are converted to `jnp` arrays and passed to
  `utils.get_nuclear_interaction_energy` with
  `utils.exponential_coulomb` as the interaction function. The call must
  raise a `ValueError` whose message matches the regex `expected_message`.

  NOTE(review): judging by the test name, the parameterized cases
  presumably supply arrays of the wrong rank. Confirm against the
  decorator.

  Args:
    locations: Nuclear locations, converted with `jnp.array`.
    nuclear_charges: Nuclear charges, converted with `jnp.array`.
    expected_message: Regex that the raised `ValueError` message must match.
  """
  with self.assertRaisesRegex(ValueError, expected_message):
    # The conversions stay inside the context manager on purpose: a
    # ValueError raised while building the arrays is caught here as well.
    location_array = jnp.array(locations)
    charge_array = jnp.array(nuclear_charges)
    utils.get_nuclear_interaction_energy(
        locations=location_array,
        nuclear_charges=charge_array,
        interaction_fn=utils.exponential_coulomb)
def test_get_nuclear_interaction_energy(
    self, locations, nuclear_charges, interaction_fn, ecpected_energy):
  """Checks the value returned by `get_nuclear_interaction_energy`.

  Both inputs are converted to `jnp` arrays and the energy is computed
  with the supplied `interaction_fn`. The result is cast to a Python
  float and compared to the expected value with `assertAlmostEqual`
  (default tolerance).

  Args:
    locations: Nuclear locations, converted with `jnp.array`.
    nuclear_charges: Nuclear charges, converted with `jnp.array`.
    interaction_fn: Interaction function forwarded unchanged to
      `utils.get_nuclear_interaction_energy`.
    ecpected_energy: The expected energy value.
  """
  # NOTE(review): `ecpected_energy` is a misspelling of "expected_energy".
  # It is left unchanged because parameterized cases may pass it by keyword.
  computed_energy = utils.get_nuclear_interaction_energy(
      locations=jnp.array(locations),
      nuclear_charges=jnp.array(nuclear_charges),
      interaction_fn=interaction_fn)
  self.assertAlmostEqual(float(computed_energy), ecpected_energy)