Ejemplo n.º 1
0
 def output_fn(array):
     """Applies `fn` with reflection symmetry enforced on its input and output.

     `array` is passed through utils.flip_and_average before `fn` is called,
     and the result of `fn` is passed through utils.flip_and_average again.
     `fn`, `locations` and `grids` are taken from the enclosing scope.
     """
     def _symmetrize(values):
         # Flip about the center of `locations` and average with the original.
         return utils.flip_and_average(
             locations=locations, grids=grids, array=values)

     return _symmetrize(fn(_symmetrize(array)))
Ejemplo n.º 2
0
 def test_flip_and_average_location_not_on_grids(self):
   """flip_and_average raises ValueError for a location not on the grids."""
   # 0.25 lies between the grid points 0.2 and 0.3, so it is not on the grids.
   locations = jnp.array([0.0, 0.25])
   grids = jnp.array([-0.1, 0.0, 0.1, 0.2, 0.3])
   # Values of the array do not matter in this test.
   array = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2])
   with self.assertRaisesRegex(
       ValueError, r'Location 0\.25 is not on the grids'):
     utils.flip_and_average(locations=locations, grids=grids, array=array)
Ejemplo n.º 3
0
 def test_flip_and_average_the_back_of_array_center_not_on_grids(self):
   """Flip-and-average near the end of the array, center between grid points."""
   locations = jnp.array([0.4, 0.5])
   grids = jnp.array([-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
   array = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8])
   # The center is 0.45, midway between the grid points at index 6 (0.4) and
   # index 7 (0.5). The symmetric window covers grids [0.3, 0.4, 0.5, 0.6],
   # i.e. slice(5, 9):
   #   [0.3, 0.5, 0.1, 0.8]
   #   flipped  -> [0.8, 0.1, 0.5, 0.3]
   #   averaged -> [0.55, 0.3, 0.3, 0.55]
   # Entries before the window [0.1, 0.2, 0.6, 0.7, 0.2] are unchanged.
   expected = [0.1, 0.2, 0.6, 0.7, 0.2, 0.55, 0.3, 0.3, 0.55]
   np.testing.assert_allclose(
       utils.flip_and_average(locations=locations, grids=grids, array=array),
       expected)
Ejemplo n.º 4
0
 def test_flip_and_average_the_front_of_array_center_on_grids(self):
   """Flip-and-average near the start of the array, center on a grid point."""
   locations = jnp.array([-0.1, 0.3])
   grids = jnp.array([-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
   array = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8])
   # The center is 0.1, the grid point at index 3. The symmetric window
   # covers grids [-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4], i.e. slice(0, 7):
   #   [0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5]
   #   flipped  -> [0.5, 0.3, 0.2, 0.7, 0.6, 0.2, 0.1]
   #   averaged -> [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3]
   # The trailing entries [0.1, 0.8] lie outside the window and are unchanged.
   expected = [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3, 0.1, 0.8]
   np.testing.assert_allclose(
       utils.flip_and_average(locations=locations, grids=grids, array=array),
       expected)
Ejemplo n.º 5
0
def kohn_sham_iteration(
    state,
    num_electrons,
    xc_energy_density_fn,
    interaction_fn,
    enforce_reflection_symmetry):
  """Runs one iteration of the Kohn-Sham calculation.

  Note that xc_energy_density_fn must be wrapped by jax.tree_util.Partial so
  that this function can take a callable. When the arguments of this callable
  change, e.g. the parameters of the neural network, kohn_sham_iteration()
  will not be recompiled.

  Args:
    state: KohnShamState, the current state.
    num_electrons: Integer, the number of electrons in the system. The first
        num_electrons states are occupied.
    xc_energy_density_fn: function that takes the density (num_grids,) and
        returns the xc energy density (num_grids,).
    interaction_fn: function that takes displacements and returns a float
        numpy array with the same shape as displacements.
    enforce_reflection_symmetry: Boolean, whether to enforce reflection
        symmetry. If True, the system is symmetric with respect to the center.

  Returns:
    KohnShamState, the next state of the Kohn-Sham iteration.
  """
  grids = state.grids

  if enforce_reflection_symmetry:
    # Wrap the xc functional with _flip_and_average_fn for symmetric systems.
    xc_energy_density_fn = _flip_and_average_fn(
        xc_energy_density_fn, locations=state.locations, grids=grids)

  # Potentials and xc energy density are evaluated on the input density.
  input_density = state.density
  hartree_potential = get_hartree_potential(
      density=input_density, grids=grids, interaction_fn=interaction_fn)
  xc_potential = get_xc_potential(
      density=input_density,
      xc_energy_density_fn=xc_energy_density_fn,
      grids=grids)
  ks_potential = hartree_potential + xc_potential + state.external_potential
  xc_energy_density = xc_energy_density_fn(input_density)

  # Solve the Kohn-Sham equation: a non-interacting system in ks_potential.
  density, total_eigen_energies, gap = solve_noninteracting_system(
      external_potential=ks_potential,
      num_electrons=num_electrons,
      grids=grids)

  # Energy terms on the new density (before any symmetrization below).
  # kinetic energy = total_eigen_energies - energy in the Kohn-Sham potential.
  kinetic_energy = total_eigen_energies - get_external_potential_energy(
      external_potential=ks_potential, density=density, grids=grids)
  hartree_energy = get_hartree_energy(
      density=density, grids=grids, interaction_fn=interaction_fn)
  xc_energy = get_xc_energy(
      density=density, xc_energy_density_fn=xc_energy_density_fn, grids=grids)
  external_energy = get_external_potential_energy(
      external_potential=state.external_potential,
      density=density,
      grids=grids)
  total_energy = kinetic_energy + hartree_energy + xc_energy + external_energy

  if enforce_reflection_symmetry:
    # Only the returned density is flip-averaged; total_energy above was
    # computed from the density before this step.
    density = utils.flip_and_average(
        locations=state.locations, grids=grids, array=density)

  return state._replace(
      density=density,
      total_energy=total_energy,
      hartree_potential=hartree_potential,
      xc_potential=xc_potential,
      xc_energy_density=xc_energy_density,
      gap=gap)