def loss(flatten_params, initial_state, target_energy):
    """Squared total-energy error after a single Kohn-Sham iteration.

    The flat parameter vector is unflattened with `spec` and bound to
    `xc_energy_density_fn`. One `scf.kohn_sham_iteration` then runs from
    `initial_state` with reflection symmetry enforced, and its total energy
    is compared to the target.

    Relies on `self`, `spec` and `xc_energy_density_fn` from the enclosing
    scope.

    Args:
      flatten_params: 1D parameter vector understood by
        `np_utils.unflatten(spec, ...)`.
      initial_state: State passed to `scf.kohn_sham_iteration`.
      target_energy: Reference total energy.

    Returns:
      `(state.total_energy - target_energy) ** 2` for the updated state.
    """
    xc_params = np_utils.unflatten(spec, flatten_params)
    bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=xc_params)
    next_state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True,
    )
    energy_error = next_state.total_energy - target_energy
    return energy_error ** 2
def loss(flatten_params, initial_state, target_density):
    """Integrated absolute density error after a single Kohn-Sham iteration.

    The flat parameter vector is unflattened with `spec` and bound to
    `xc_energy_density_fn`. One `scf.kohn_sham_iteration` then runs from
    `initial_state` without enforcing reflection symmetry, and its density
    is compared to the target.

    Relies on `self`, `spec` and `xc_energy_density_fn` from the enclosing
    scope.

    Args:
      flatten_params: 1D parameter vector understood by
        `np_utils.unflatten(spec, ...)`.
      initial_state: State passed to `scf.kohn_sham_iteration`.
      target_density: Reference density on `self.grids`.

    Returns:
      `sum(|density - target_density|) * utils.get_dx(self.grids)`.
    """
    xc_params = np_utils.unflatten(spec, flatten_params)
    bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=xc_params)
    next_state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=False,
    )
    abs_density_error = jnp.abs(next_state.density - target_density)
    # Sum over grid points scaled by get_dx (presumably the grid spacing),
    # i.e. a Riemann-sum estimate of the integrated absolute error.
    return jnp.sum(abs_density_error) * utils.get_dx(self.grids)
def loss(flatten_params, target_energy):
    """Squared total-energy error of the final state from `scf.kohn_sham`.

    The flat parameter vector is unflattened with `spec` and bound to
    `xc_energy_density_fn`. `scf.kohn_sham` then runs for 3 iterations on
    the test system (`self.locations`, `self.nuclear_charges`,
    `self.num_electrons`, `self.grids`). The total energy of the state
    returned by `scf.get_final_state` is compared to the target.

    Relies on `self`, `spec` and `xc_energy_density_fn` from the enclosing
    scope.

    Args:
      flatten_params: 1D parameter vector understood by
        `np_utils.unflatten(spec, ...)`.
      target_energy: Reference total energy.

    Returns:
      `(final_state.total_energy - target_energy) ** 2`.
    """
    xc_params = np_utils.unflatten(spec, flatten_params)
    bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=xc_params)
    trajectory = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
    )
    last_state = scf.get_final_state(trajectory)
    energy_error = last_state.total_energy - target_energy
    return energy_error ** 2
def loss(flatten_params, target_density):
    """Integrated absolute density error of the final state from `scf.kohn_sham`.

    The flat parameter vector is unflattened with `spec` and bound to
    `xc_energy_density_fn`. `scf.kohn_sham` then runs for 3 iterations on
    the test system (`self.locations`, `self.nuclear_charges`,
    `self.num_electrons`, `self.grids`). The density of the state returned
    by `scf.get_final_state` is compared to the target.

    Relies on `self`, `spec` and `xc_energy_density_fn` from the enclosing
    scope.

    Args:
      flatten_params: 1D parameter vector understood by
        `np_utils.unflatten(spec, ...)`.
      target_density: Reference density on `self.grids`.

    Returns:
      `sum(|final_density - target_density|) * utils.get_dx(self.grids)`.
    """
    xc_params = np_utils.unflatten(spec, flatten_params)
    bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=xc_params)
    trajectory = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
        # NOTE(review): a negative tolerance presumably can never be met, so
        # no early convergence stop occurs — confirm against scf.kohn_sham.
        density_mse_converge_tolerance=-1,
    )
    last_state = scf.get_final_state(trajectory)
    abs_density_error = jnp.abs(last_state.density - target_density)
    # Sum over grid points scaled by get_dx (presumably the grid spacing).
    return jnp.sum(abs_density_error) * utils.get_dx(self.grids)
def test_flatten(self):
    """np_utils.flatten / unflatten round-trip a nested pytree of arrays."""
    nested = [
        (jnp.array([1, 2, 3]), (jnp.array([4, 5]))),
        jnp.array([99]),
    ]
    (tree, shapes), vec = np_utils.flatten(nested)

    # Flattening yields a single numpy vector plus the per-leaf shapes.
    self.assertIsInstance(vec, np.ndarray)
    np.testing.assert_allclose(vec, [1., 2., 3., 4., 5., 99.])
    self.assertEqual(shapes, [(3,), (2,), (1,)])

    # Unflattening the 1d vector must rebuild the original pytree, with each
    # leaf restored as a numpy array holding its original values.
    params = np_utils.unflatten((tree, shapes), vec)
    restored_leaves = [params[0][0], params[0][1], params[1]]
    expected_leaves = [[1., 2., 3.], [4., 5.], [99.]]
    for leaf, expected in zip(restored_leaves, expected_leaves):
        self.assertIsInstance(leaf, np.ndarray)
        np.testing.assert_allclose(leaf, expected)