Пример #1
0
 def output_fn(array):
     """Applies `fn` between two reflection-symmetrization passes.

     `array` is passed through `utils.flip_and_average`, the result is handed
     to `fn`, and the output of `fn` is passed through
     `utils.flip_and_average` once more. `locations`, `grids` and `fn` are
     free variables, presumably bound by the enclosing function (not visible
     here).

     Args:
       array: The array to symmetrize before calling `fn`.

     Returns:
       The symmetrized output of `fn`.
     """
     def symmetrize(values):
         return utils.flip_and_average(
             locations=locations, grids=grids, array=values)

     return symmetrize(fn(symmetrize(array)))
Пример #2
0
 def test_flip_and_average_location_not_on_grids(self):
     """flip_and_average raises ValueError for a location off the grids."""
     expected_message = r'Location 0\.25 is not on the grids'
     with self.assertRaisesRegex(ValueError, expected_message):
         # 0.25 does not coincide with any grid point.
         off_grid_locations = jnp.array([0.0, 0.25])
         grids = jnp.array([-0.1, 0.0, 0.1, 0.2, 0.3])
         # The array values are irrelevant; only the location check matters.
         unused_values = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2])
         utils.flip_and_average(locations=off_grid_locations,
                                grids=grids,
                                array=unused_values)
Пример #3
0
 def test_flip_and_average_the_back_of_array_center_not_on_grids(self):
   """Flips and averages near the back of the array, center off the grids."""
   # The center of locations [0.4, 0.5] is 0.45. It is not itself a grid
   # point; it lies midway between the grid points with index 6 (0.4) and
   # index 7 (0.5).
   # The array values on the grids [0.3, 0.4, 0.5, 0.6] (slice(5, 9)) are
   # flipped:
   #   [0.3, 0.5, 0.1, 0.8] -> [0.8, 0.1, 0.5, 0.3]
   # and averaged with the unflipped values:
   #   [0.55, 0.3, 0.3, 0.55]
   # These replace slice(5, 9) of the original array:
   #   [0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8]
   #   -> [0.1, 0.2, 0.6, 0.7, 0.2, 0.55, 0.3, 0.3, 0.55]
   actual = utils.flip_and_average(
       locations=jnp.array([0.4, 0.5]),
       grids=jnp.array([-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
       array=jnp.array([0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8]))
   expected = [0.1, 0.2, 0.6, 0.7, 0.2, 0.55, 0.3, 0.3, 0.55]
   np.testing.assert_allclose(actual, expected)
Пример #4
0
 def test_flip_and_average_the_front_of_array_center_on_grids(self):
   """Flips and averages near the front of the array, center on a grid."""
   # The center of locations [-0.1, 0.3] is 0.1, the grid point with index 3.
   # The array values on the grids [-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4]
   # (slice(0, 7)) are flipped:
   #   [0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5]
   #   -> [0.5, 0.3, 0.2, 0.7, 0.6, 0.2, 0.1]
   # and averaged with the unflipped values:
   #   [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3]
   # These replace slice(0, 7) of the original array:
   #   [0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8]
   #   -> [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3, 0.1, 0.8]
   locations = jnp.array([-0.1, 0.3])
   grids = jnp.array([-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
   values = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8])
   np.testing.assert_allclose(
       utils.flip_and_average(locations=locations, grids=grids, array=values),
       [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3, 0.1, 0.8])
Пример #5
0
def kohn_sham_iteration(state, num_electrons, xc_energy_density_fn,
                        interaction_fn, enforce_reflection_symmetry):
    """Performs one iteration of the Kohn-Sham calculation.

    Note that xc_energy_density_fn must be wrapped by jax.tree_util.Partial so
    that this function can take a callable. When the arguments of this
    callable change, e.g. the parameters of the neural network,
    kohn_sham_iteration() will not be recompiled.

    Args:
        state: KohnShamState.
        num_electrons: Integer, the number of electrons in the system. The
            first num_electrons states are occupied.
        xc_energy_density_fn: function that takes density (num_grids,) and
            returns the energy density (num_grids,).
        interaction_fn: function that takes displacements and returns a float
            numpy array with the same shape as displacements.
        enforce_reflection_symmetry: Boolean, whether to enforce reflection
            symmetry. If True, the system is symmetric with respect to the
            center.

    Returns:
        KohnShamState, the next state of the Kohn-Sham iteration: `state`
        with density, total_energy, hartree_potential, xc_potential,
        xc_energy_density and gap replaced.
    """
    grids = state.grids
    input_density = state.density

    if enforce_reflection_symmetry:
        # Wrap the xc functional so its input and output are flipped and
        # averaged about state.locations.
        xc_energy_density_fn = _flip_and_average_fn(
            xc_energy_density_fn, locations=state.locations, grids=grids)

    # The potentials are built from the incoming state.density.
    hartree_potential = get_hartree_potential(
        density=input_density, grids=grids, interaction_fn=interaction_fn)
    xc_potential = get_xc_potential(
        density=input_density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    ks_potential = hartree_potential + xc_potential + state.external_potential
    xc_energy_density = xc_energy_density_fn(input_density)

    # Solve the Kohn-Sham equation in the effective potential.
    density, total_eigen_energies, gap = solve_noninteracting_system(
        external_potential=ks_potential,
        num_electrons=num_electrons,
        grids=grids)

    # The energy terms are evaluated on the newly solved density.
    # kinetic energy = total_eigen_energies - external_potential_energy
    kinetic_energy = total_eigen_energies - get_external_potential_energy(
        external_potential=ks_potential, density=density, grids=grids)
    hartree_energy = get_hartree_energy(
        density=density, grids=grids, interaction_fn=interaction_fn)
    xc_energy = get_xc_energy(
        density=density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    external_energy = get_external_potential_energy(
        external_potential=state.external_potential,
        density=density,
        grids=grids)
    total_energy = kinetic_energy + hartree_energy + xc_energy + external_energy

    # NOTE: total_energy above uses the density before this symmetrization.
    if enforce_reflection_symmetry:
        density = utils.flip_and_average(
            locations=state.locations, grids=grids, array=density)

    return state._replace(
        density=density,
        total_energy=total_energy,
        hartree_potential=hartree_potential,
        xc_potential=xc_potential,
        xc_energy_density=xc_energy_density,
        gap=gap)