Example #1
0
 def apply_fn(params, inputs, **kwargs):
   """Blends a -0.5 * Hartree potential term with the incoming features.

   `dx`, `grids`, `interaction_fn` and `self_interaction_weight` are taken
   from the enclosing scope.

   Args:
     params: One-element sequence; its entry is passed as `width` to
       `self_interaction_weight`.
     inputs: Pair `(reshaped_density, features)`.
     **kwargs: Ignored.

   Returns:
     `hartree * beta + features * (1 - beta)`, where `beta` is the output of
     `self_interaction_weight` and `hartree` is -0.5 times the Hartree
     potential of the flattened density, reshaped back to
     `reshaped_density.shape`.
   """
   del kwargs  # Unused; accepted so the signature matches other apply_fns.
   (width,) = params
   reshaped_density, features = inputs

   beta = self_interaction_weight(
       reshaped_density=reshaped_density, dx=dx, width=width)

   # The Hartree solver works on a flat density; restore the layer's shape.
   flat_potential = scf.get_hartree_potential(
       density=reshaped_density.reshape(-1),
       grids=grids,
       interaction_fn=interaction_fn)
   hartree = -0.5 * flat_potential.reshape(reshaped_density.shape)

   weighted_hartree = hartree * beta
   weighted_features = features * (1 - beta)
   return weighted_hartree + weighted_features
Example #2
0
def _kohn_sham_iteration(density, external_potential, grids, num_electrons,
                         xc_energy_density_fn, interaction_fn,
                         enforce_reflection_symmetry):
    """One iteration of Kohn-Sham calculation.

    Builds the Kohn-Sham potential (Hartree + xc + external) from the input
    density, solves the non-interacting system in that potential, and
    assembles the total energy from the newly obtained density.

    NOTE(leeley): Since num_electrons in KohnShamState need to specify as
    static argument in jit function, this function can not directly take
    KohnShamState as input arguments. The related attributes in KohnShamState
    are used as input arguments for this helper function.

    Args:
      density: Input density.
      external_potential: External potential.
      grids: Grid points.
      num_electrons: Number of electrons, forwarded to the solver.
      xc_energy_density_fn: Callable mapping a density to an xc energy
        density.
      interaction_fn: Interaction function used for the Hartree terms.
      enforce_reflection_symmetry: If True, wrap `xc_energy_density_fn` with
        `_flip_and_average_on_center_fn` and pass the output density through
        `_flip_and_average_on_center`.

    Returns:
      Tuple `(density, total_energy, hartree_potential, xc_potential,
      xc_energy_density, gap)`. The two potentials and `xc_energy_density`
      are evaluated on the INPUT density; `density`, `total_energy` and
      `gap` come from the solved system.
    """
    if enforce_reflection_symmetry:
        xc_energy_density_fn = _flip_and_average_on_center_fn(
            xc_energy_density_fn)

    # Potentials evaluated on the input density.
    hartree_potential = scf.get_hartree_potential(
        density=density, grids=grids, interaction_fn=interaction_fn)
    xc_potential = scf.get_xc_potential(
        density=density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    ks_potential = hartree_potential + xc_potential + external_potential
    xc_energy_density = xc_energy_density_fn(density)

    # Solve Kohn-Sham equation.
    new_density, total_eigen_energies, gap = scf.solve_noninteracting_system(
        external_potential=ks_potential,
        num_electrons=num_electrons,
        grids=grids)

    # Energy terms, all on the new density, computed in the same order as
    # they are summed below.
    # kinetic energy = total_eigen_energies - external_potential_energy
    kinetic_energy = total_eigen_energies - scf.get_external_potential_energy(
        external_potential=ks_potential, density=new_density, grids=grids)
    hartree_energy = scf.get_hartree_energy(
        density=new_density, grids=grids, interaction_fn=interaction_fn)
    xc_energy = scf.get_xc_energy(
        density=new_density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    external_energy = scf.get_external_potential_energy(
        external_potential=external_potential,
        density=new_density,
        grids=grids)
    total_energy = (
        kinetic_energy + hartree_energy + xc_energy + external_energy)

    if enforce_reflection_symmetry:
        new_density = _flip_and_average_on_center(new_density)

    return (new_density, total_energy, hartree_potential, xc_potential,
            xc_energy_density, gap)
Example #3
0
  def test_get_xc_potential_hartree(self):
    """An xc energy density of 0.5 * v_H must yield an xc potential of v_H."""
    grids = jnp.linspace(-5, 5, 10001)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)

    def xc_energy_density_fn(density):
      # Stand-in xc energy density: half of the Hartree potential.
      return 0.5 * scf.get_hartree_potential(
          density=density,
          grids=grids,
          interaction_fn=utils.exponential_coulomb)

    actual_xc_potential = scf.get_xc_potential(
        density=density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    expected_potential = scf.get_hartree_potential(
        density, grids=grids, interaction_fn=utils.exponential_coulomb)

    np.testing.assert_allclose(actual_xc_potential, expected_potential)
Example #4
0
  def test_get_hartree_potential(self, interaction_fn):
    """Checks get_hartree_potential against a direct double-loop reference."""
    grids = jnp.linspace(-5, 5, 11)
    dx = utils.get_dx(grids)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)

    # Compute the expected Hartree potential (not energy) by nested for loops:
    #   v_H(x_0) = sum over x_1 of n(x_1) * interaction_fn(x_0 - x_1) * dx.
    expected_hartree_potential = np.zeros_like(grids)
    for i, x_0 in enumerate(grids):
      for x_1, n_1 in zip(grids, density):
        expected_hartree_potential[i] += np.sum(
            n_1 * interaction_fn(x_0 - x_1)) * dx

    np.testing.assert_allclose(
        scf.get_hartree_potential(
            density=density, grids=grids, interaction_fn=interaction_fn),
        expected_hartree_potential)
Example #5
0
  def test_self_interaction_layer_one_electron(self):
    """For this one-electron density the layer returns -0.5 * v_H."""
    grids = jnp.linspace(-5, 5, 11)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)
    reshaped_density = density[jnp.newaxis, :, jnp.newaxis]

    init_fn, apply_fn = neural_xc.self_interaction_layer(
        grids=grids, interaction_fn=utils.exponential_coulomb)
    output_shape, init_params = init_fn(
        random.PRNGKey(0), input_shape=((-1, 11, 1), (-1, 11, 1)))

    self.assertEqual(output_shape, (-1, 11, 1))
    # assertAlmostEqual cannot subtract two tuples (it raises TypeError
    # instead of reporting a failure), so check the length and compare the
    # single parameter as a scalar.
    self.assertEqual(len(init_params), 1)
    self.assertAlmostEqual(init_params[0], 1.)
    np.testing.assert_allclose(
        # The features (second input) are not used for one electron.
        apply_fn(
            init_params, (reshaped_density, jnp.ones_like(reshaped_density))),
        -0.5 * scf.get_hartree_potential(
            density=density,
            grids=grids,
            interaction_fn=utils.exponential_coulomb)[
                jnp.newaxis, :, jnp.newaxis])

  def test_wrap_network_with_self_interaction_layer_one_electron(self):
    """A network wrapped with the layer returns -0.5 * v_H for one electron."""
    grids = jnp.linspace(-5, 5, 9)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)
    reshaped_density = density[jnp.newaxis, :, jnp.newaxis]

    init_fn, apply_fn = neural_xc.wrap_network_with_self_interaction_layer(
        network=neural_xc.build_unet(
            num_filters_list=[2, 4], core_num_filters=4, activation='swish'),
        grids=grids,
        interaction_fn=utils.exponential_coulomb)
    # The wrapped network takes a single density input, so a single shape.
    output_shape, init_params = init_fn(
        random.PRNGKey(0), input_shape=(-1, 9, 1))

    self.assertEqual(output_shape, (-1, 9, 1))
    np.testing.assert_allclose(
        apply_fn(init_params, reshaped_density),
        -0.5 * scf.get_hartree_potential(
            density=density,
            grids=grids,
            interaction_fn=utils.exponential_coulomb)[
                jnp.newaxis, :, jnp.newaxis])
Example #7
0
 def half_hartree_potential(density):
   """Returns 0.5 times the Hartree potential of `density`.

   Uses `grids` from the enclosing scope and `utils.exponential_coulomb` as
   the interaction function.
   """
   hartree_potential = scf.get_hartree_potential(
       density=density,
       grids=grids,
       interaction_fn=utils.exponential_coulomb)
   return 0.5 * hartree_potential