def test_solve_noninteracting_system(
    self, num_electrons, expected_total_eigen_energies, expected_gap):
  """Checks the non-interacting solver on a 1D quantum harmonic oscillator.

  The external potential is x**2 / 2 sampled on 1001 uniform points spanning
  [-10, 10]. The returned density shape, total eigen energies and gap are
  compared against the parameterized expectations (energies to 7 places).
  """
  num_grids = 1001
  grids = jnp.linspace(-10, 10, num_grids)
  harmonic_potential = 0.5 * grids ** 2

  result_density, result_energies, result_gap = (
      scf.solve_noninteracting_system(
          external_potential=harmonic_potential,
          num_electrons=num_electrons,
          grids=grids))

  self.assertTupleEqual(result_density.shape, (num_grids,))
  self.assertAlmostEqual(
      float(result_energies), expected_total_eigen_energies, places=7)
  self.assertAlmostEqual(float(result_gap), expected_gap, places=7)
def _kohn_sham_iteration(density,
                         external_potential,
                         grids,
                         num_electrons,
                         xc_energy_density_fn,
                         interaction_fn,
                         enforce_reflection_symmetry):
  """Runs a single Kohn-Sham step.

  The Hartree and exchange-correlation potentials are evaluated on the input
  density and added to the external potential to form the Kohn-Sham
  potential. The resulting non-interacting problem is then solved for a new
  density.

  NOTE(leeley): Since num_electrons in KohnShamState need to specify as
  static argument in jit function, this function can not directly take
  KohnShamState as input arguments. The related attributes in KohnShamState
  are used as input arguments for this helper function.

  Args:
    density: Input density used to build the Kohn-Sham potential.
    external_potential: External potential sampled on `grids`.
    grids: Grid points.
    num_electrons: Number of electrons handed to the non-interacting solver.
    xc_energy_density_fn: Callable mapping a density to an xc energy density.
    interaction_fn: Interaction function used for the Hartree terms.
    enforce_reflection_symmetry: If True, `xc_energy_density_fn` is wrapped
      with `_flip_and_average_on_center_fn`, and the output density is passed
      through `_flip_and_average_on_center`.

  Returns:
    Tuple (density, total_energy, hartree_potential, xc_potential,
    xc_energy_density, gap). The density, total_energy and gap come from the
    solved Kohn-Sham equation. The hartree_potential, xc_potential and
    xc_energy_density are evaluated on the input density.
  """
  if enforce_reflection_symmetry:
    xc_energy_density_fn = _flip_and_average_on_center_fn(
        xc_energy_density_fn)

  # Potentials and xc energy density are all built from the input density.
  hartree_potential = scf.get_hartree_potential(
      density=density, grids=grids, interaction_fn=interaction_fn)
  xc_potential = scf.get_xc_potential(
      density=density, xc_energy_density_fn=xc_energy_density_fn, grids=grids)
  ks_potential = hartree_potential + xc_potential + external_potential
  xc_energy_density = xc_energy_density_fn(density)

  # Solve the Kohn-Sham equation as a non-interacting problem.
  new_density, total_eigen_energies, gap = scf.solve_noninteracting_system(
      external_potential=ks_potential,
      num_electrons=num_electrons,
      grids=grids)

  # Kinetic energy: eigenvalue sum minus the energy of the KS potential.
  kinetic_energy = total_eigen_energies - scf.get_external_potential_energy(
      external_potential=ks_potential, density=new_density, grids=grids)
  hartree_energy = scf.get_hartree_energy(
      density=new_density, grids=grids, interaction_fn=interaction_fn)
  xc_energy = scf.get_xc_energy(
      density=new_density,
      xc_energy_density_fn=xc_energy_density_fn,
      grids=grids)
  external_energy = scf.get_external_potential_energy(
      external_potential=external_potential, density=new_density, grids=grids)
  # Summed left to right, matching the single-expression form term for term.
  total_energy = kinetic_energy + hartree_energy + xc_energy + external_energy

  if enforce_reflection_symmetry:
    new_density = _flip_and_average_on_center(new_density)

  return (new_density, total_energy, hartree_potential, xc_potential,
          xc_energy_density, gap)
def main(argv):
  """Solves a non-interacting two-nucleus chain and logs the result.

  Builds a 513-point grid with spacing 0.08 and places two unit nuclear
  charges at -0.8 and 0.8 using the exponential Coulomb interaction. It then
  solves the non-interacting system for FLAGS.num_electrons electrons and
  logs the density and the total eigen energies.

  Args:
    argv: Command-line arguments; anything beyond the program name is
      rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')
  logging.info('JAX devices: %s', jax.devices())

  nuclear_locations = np.array([-0.8, 0.8])
  nuclear_charges = np.array([1., 1.])
  grids = np.arange(-256, 257) * 0.08
  external_potential = utils.get_atomic_chain_potential(
      grids=grids,
      locations=nuclear_locations,
      nuclear_charges=nuclear_charges,
      interaction_fn=utils.exponential_coulomb)

  ground_density, eigen_energy_sum, _ = scf.solve_noninteracting_system(
      external_potential=external_potential,
      num_electrons=FLAGS.num_electrons,
      grids=grids)

  logging.info('density: %s', ground_density)
  logging.info('total energy: %f', eigen_energy_sum)