def _kohn_sham_iteration(
    density,
    external_potential,
    grids,
    num_electrons,
    xc_energy_density_fn,
    interaction_fn,
    enforce_reflection_symmetry):
  """Performs a single Kohn-Sham self-consistent-field update.

  The Kohn-Sham potential is built from the input density. The resulting
  non-interacting problem is solved for a new density, and the total energy
  is evaluated on that new density.

  NOTE(leeley): num_electrons in KohnShamState must be a static argument of
  the jit-compiled function, so this helper cannot take a KohnShamState
  directly. The relevant KohnShamState attributes are passed in as separate
  arguments instead.

  Args:
    density: Electron density for this iteration (presumably sampled on
      `grids`).
    external_potential: External potential (presumably sampled on `grids`).
    grids: Spatial grid passed to every scf helper.
    num_electrons: Number of electrons handed to the non-interacting solver.
    xc_energy_density_fn: Callable mapping a density to the exchange-
      correlation energy density.
    interaction_fn: Interaction used for the Hartree potential and energy.
    enforce_reflection_symmetry: If True, xc_energy_density_fn is first
      wrapped with _flip_and_average_on_center_fn. The output density is
      passed through _flip_and_average_on_center, but only after the total
      energy has been computed.

  Returns:
    Tuple (density, total_energy, hartree_potential, xc_potential,
    xc_energy_density, gap).
    - density: the output density from the solver.
    - hartree_potential, xc_potential, xc_energy_density: evaluated on the
      INPUT density.
    - total_energy: evaluated on the output density.
    - gap: returned by the non-interacting solver.
  """
  if enforce_reflection_symmetry:
    xc_energy_density_fn = _flip_and_average_on_center_fn(
        xc_energy_density_fn)

  # Effective (Kohn-Sham) potential built from the input density.
  hartree_potential = scf.get_hartree_potential(
      density=density, grids=grids, interaction_fn=interaction_fn)
  xc_potential = scf.get_xc_potential(
      density=density, xc_energy_density_fn=xc_energy_density_fn, grids=grids)
  ks_potential = hartree_potential + xc_potential + external_potential
  xc_energy_density = xc_energy_density_fn(density)

  # Solve the Kohn-Sham equation: a non-interacting system in ks_potential.
  new_density, total_eigen_energies, gap = scf.solve_noninteracting_system(
      external_potential=ks_potential,
      num_electrons=num_electrons,
      grids=grids)

  # All energy terms below use the NEW density.
  # Kinetic energy = sum of eigenvalues minus the KS-potential energy.
  kinetic_energy = total_eigen_energies - scf.get_external_potential_energy(
      external_potential=ks_potential, density=new_density, grids=grids)
  hartree_energy = scf.get_hartree_energy(
      density=new_density, grids=grids, interaction_fn=interaction_fn)
  xc_energy = scf.get_xc_energy(
      density=new_density,
      xc_energy_density_fn=xc_energy_density_fn,
      grids=grids)
  external_energy = scf.get_external_potential_energy(
      external_potential=external_potential, density=new_density, grids=grids)
  # Left-to-right summation keeps the same floating-point association.
  total_energy = kinetic_energy + hartree_energy + xc_energy + external_energy

  if enforce_reflection_symmetry:
    # Symmetrization happens after the energy evaluation above.
    new_density = _flip_and_average_on_center(new_density)

  return (new_density, total_energy, hartree_potential, xc_potential,
          xc_energy_density, gap)
def test_get_hartree_energy(self, interaction_fn):
  """Checks scf.get_hartree_energy against a brute-force double sum.

  The reference value is
    0.5 * sum_{i, j} n_i * n_j * interaction_fn(x_i - x_j) * dx ** 2,
  accumulated by explicit Python iteration over every pair of grid points.
  """
  grids = jnp.linspace(-5, 5, 11)
  dx = utils.get_dx(grids)
  density = utils.gaussian(grids=grids, center=1., sigma=1.)

  # Brute-force reference: every (i, j) pair in row-major order.
  pair_terms = (
      0.5 * n_i * n_j * interaction_fn(x_i - x_j) * dx ** 2
      for x_i, n_i in zip(grids, density)
      for x_j, n_j in zip(grids, density))
  expected_hartree_energy = sum(pair_terms, 0.)

  actual_hartree_energy = scf.get_hartree_energy(
      density=density, grids=grids, interaction_fn=interaction_fn)
  self.assertAlmostEqual(
      float(actual_hartree_energy), float(expected_hartree_energy))