예제 #1
0
 def test_solve_noninteracting_system(
     self, num_electrons, expected_total_eigen_energies, expected_gap):
   """Solves a harmonic-oscillator potential and checks energies and gap.

   The potential v(x) = 0.5 * x**2 is sampled on 1001 uniform grid points in
   [-10, 10]. The returned density shape, summed eigenvalues and gap are
   compared against the expected values passed in.
   """
   num_grids = 1001
   grid_points = jnp.linspace(-10, 10, num_grids)
   # Quantum harmonic oscillator.
   harmonic_potential = 0.5 * grid_points ** 2

   result_density, result_eigen_energies, result_gap = (
       scf.solve_noninteracting_system(
           external_potential=harmonic_potential,
           num_electrons=num_electrons,
           grids=grid_points))

   self.assertTupleEqual(result_density.shape, (num_grids,))
   self.assertAlmostEqual(
       float(result_eigen_energies), expected_total_eigen_energies, places=7)
   self.assertAlmostEqual(float(result_gap), expected_gap, places=7)
예제 #2
0
def _kohn_sham_iteration(density, external_potential, grids, num_electrons,
                         xc_energy_density_fn, interaction_fn,
                         enforce_reflection_symmetry):
    """Runs a single Kohn-Sham iteration.

    The Kohn-Sham potential is built from the input density. The resulting
    noninteracting problem is then solved, and the total energy is evaluated
    on the newly obtained density (before any reflection symmetrization).

    KohnShamState is not taken directly because num_electrons has to be a
    static argument of the jitted function. The relevant KohnShamState
    attributes are therefore passed as individual arguments instead.

    Args:
      density: Input density used to build the Hartree and xc potentials.
      external_potential: External potential on the grid.
      grids: Grid points.
      num_electrons: Number of electrons.
      xc_energy_density_fn: Callable mapping a density to the xc energy
        density.
      interaction_fn: Interaction function forwarded to the Hartree helpers.
      enforce_reflection_symmetry: If True, xc_energy_density_fn is wrapped
        with _flip_and_average_on_center_fn, and the returned density is
        passed through _flip_and_average_on_center.

    Returns:
      Tuple (density, total_energy, hartree_potential, xc_potential,
      xc_energy_density, gap).
      - density: the output density, symmetrized when requested.
      - hartree_potential, xc_potential, xc_energy_density: evaluated on the
        input density.
    """
    if enforce_reflection_symmetry:
        xc_energy_density_fn = _flip_and_average_on_center_fn(
            xc_energy_density_fn)

    # Potentials and the xc energy density are taken from the *input* density.
    hartree_potential = scf.get_hartree_potential(
        density=density, grids=grids, interaction_fn=interaction_fn)
    xc_potential = scf.get_xc_potential(
        density=density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    ks_potential = hartree_potential + xc_potential + external_potential
    xc_energy_density = xc_energy_density_fn(density)

    # Solve the Kohn-Sham equation; `new_density` is the output density.
    new_density, total_eigen_energies, gap = scf.solve_noninteracting_system(
        external_potential=ks_potential,
        num_electrons=num_electrons,
        grids=grids)

    # Kinetic energy = sum of eigenvalues minus the energy of the Kohn-Sham
    # potential on the new density.
    ks_potential_energy = scf.get_external_potential_energy(
        external_potential=ks_potential, density=new_density, grids=grids)
    kinetic_energy = total_eigen_energies - ks_potential_energy
    hartree_energy = scf.get_hartree_energy(
        density=new_density, grids=grids, interaction_fn=interaction_fn)
    xc_energy = scf.get_xc_energy(
        density=new_density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    external_energy = scf.get_external_potential_energy(
        external_potential=external_potential,
        density=new_density,
        grids=grids)
    # Summed left to right, matching the original single-expression order.
    total_energy = (
        kinetic_energy + hartree_energy + xc_energy + external_energy)

    if enforce_reflection_symmetry:
        new_density = _flip_and_average_on_center(new_density)

    return (new_density, total_energy, hartree_potential, xc_potential,
            xc_energy_density, gap)
예제 #3
0
def main(argv):
    """Solves a noninteracting system on a two-atom chain potential and logs it.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    logging.info('JAX devices: %s', jax.devices())

    # 513 uniformly spaced points from -20.48 to 20.48 (spacing 0.08).
    grids = np.arange(-256, 257) * 0.08
    # Two unit nuclear charges at x = -0.8 and x = 0.8.
    nuclear_locations = np.array([-0.8, 0.8])
    nuclear_charges = np.array([1., 1.])
    external_potential = utils.get_atomic_chain_potential(
        grids=grids,
        locations=nuclear_locations,
        nuclear_charges=nuclear_charges,
        interaction_fn=utils.exponential_coulomb)

    density, total_eigen_energies, _ = scf.solve_noninteracting_system(
        external_potential=external_potential,
        num_electrons=FLAGS.num_electrons,
        grids=grids)
    logging.info('density: %s', density)
    logging.info('total energy: %f', total_eigen_energies)