def output_fn(array):
    """Applies `fn` between two `utils.flip_and_average` passes.

    `fn`, `locations` and `grids` are not defined in this block; they are
    read from the enclosing scope.

    Args:
      array: the array handed to the first `utils.flip_and_average` call.

    Returns:
      The result of `utils.flip_and_average` applied to the output of `fn`,
      where `fn` is evaluated on the flip-averaged input.
    """
    # Flip-and-average the input before it reaches `fn`.
    symmetrized_input = utils.flip_and_average(
        array=array, grids=grids, locations=locations)
    raw_output = fn(symmetrized_input)
    # Flip-and-average the output of `fn` as well.
    return utils.flip_and_average(
        array=raw_output, grids=grids, locations=locations)
def test_flip_and_average_location_not_on_grids(self):
    """Checks that a location that is not a grid point raises ValueError."""
    # 0.25 is not on the grids.
    locations = jnp.array([0.0, 0.25])
    grids = jnp.array([-0.1, 0.0, 0.1, 0.2, 0.3])
    # Values of array do not matter in this test.
    array = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2])

    expected_message = r'Location 0\.25 is not on the grids'
    with self.assertRaisesRegex(ValueError, expected_message):
        utils.flip_and_average(locations=locations, grids=grids, array=array)
def test_flip_and_average_the_back_of_array_center_not_on_grids(self):
    """Checks flip-and-average near the end of the array, center off-grid."""
    grids = jnp.array([-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
    array = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8])

    actual = utils.flip_and_average(
        locations=jnp.array([0.4, 0.5]), grids=grids, array=array)

    # The center is 0.45, which lies between the grid points with index 6
    # and 7.
    # The array values on the grids [0.3, 0.4, 0.5, 0.6] are flipped:
    # [0.3, 0.5, 0.1, 0.8]
    # -> [0.8, 0.1, 0.5, 0.3]
    # The averaged array is
    # [0.55, 0.3, 0.3, 0.55]
    # Replace the corresponding range (slice(5, 9)) in the original array:
    # [0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8]
    # -> [0.1, 0.2, 0.6, 0.7, 0.2, 0.55, 0.3, 0.3, 0.55]
    expected = [0.1, 0.2, 0.6, 0.7, 0.2, 0.55, 0.3, 0.3, 0.55]

    np.testing.assert_allclose(actual, expected)
def test_flip_and_average_the_front_of_array_center_on_grids(self):
    """Checks flip-and-average near the start of the array, center on-grid."""
    grids = jnp.array([-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
    array = jnp.array([0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8])

    actual = utils.flip_and_average(
        locations=jnp.array([-0.1, 0.3]), grids=grids, array=array)

    # The center is 0.1, which is the grid point with index 3.
    # The array values on the grids [-0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4]
    # are flipped:
    # [0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5]
    # -> [0.5, 0.3, 0.2, 0.7, 0.6, 0.2, 0.1]
    # The averaged array is
    # [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3]
    # Replace the corresponding range (slice(0, 7)) in the original array:
    # [0.1, 0.2, 0.6, 0.7, 0.2, 0.3, 0.5, 0.1, 0.8]
    # -> [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3, 0.1, 0.8]
    expected = [0.3, 0.25, 0.4, 0.7, 0.4, 0.25, 0.3, 0.1, 0.8]

    np.testing.assert_allclose(actual, expected)
def kohn_sham_iteration(
    state,
    num_electrons,
    xc_energy_density_fn,
    interaction_fn,
    enforce_reflection_symmetry):
    """Runs one iteration of the Kohn-Sham calculation.

    Note xc_energy_density_fn must be wrapped by jax.tree_util.Partial so this
    function can take a callable. When the arguments of this callable change,
    e.g. the parameters of the neural network, kohn_sham_iteration() will not
    be recompiled.

    Args:
      state: KohnShamState.
      num_electrons: Integer, the number of electrons in the system. The first
        num_electrons states are occupied.
      xc_energy_density_fn: function takes density (num_grids,) and returns
        the energy density (num_grids,).
      interaction_fn: function takes displacements and returns float numpy
        array with the same shape of displacements.
      enforce_reflection_symmetry: Boolean, whether to enforce reflection
        symmetry. If True, the system is symmetric with respect to the center.

    Returns:
      KohnShamState, the next state of Kohn-Sham iteration.
    """
    grids = state.grids

    if enforce_reflection_symmetry:
        # Wrap the xc functional so it is evaluated through the
        # flip-and-average wrapper for the rest of this iteration.
        xc_energy_density_fn = _flip_and_average_fn(
            xc_energy_density_fn, locations=state.locations, grids=grids)

    # Assemble the Kohn-Sham potential from the density of the input state.
    hartree_potential = get_hartree_potential(
        density=state.density, grids=grids, interaction_fn=interaction_fn)
    xc_potential = get_xc_potential(
        density=state.density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    ks_potential = hartree_potential + xc_potential + state.external_potential
    xc_energy_density = xc_energy_density_fn(state.density)

    # Solve Kohn-Sham equation.
    density, total_eigen_energies, gap = solve_noninteracting_system(
        external_potential=ks_potential,
        num_electrons=num_electrons,
        grids=grids)

    # kinetic energy = total_eigen_energies - external_potential_energy,
    # where the "external" potential of the solved system is ks_potential.
    kinetic_energy = total_eigen_energies - get_external_potential_energy(
        external_potential=ks_potential, density=density, grids=grids)
    hartree_energy = get_hartree_energy(
        density=density, grids=grids, interaction_fn=interaction_fn)
    xc_energy = get_xc_energy(
        density=density,
        xc_energy_density_fn=xc_energy_density_fn,
        grids=grids)
    external_energy = get_external_potential_energy(
        external_potential=state.external_potential,
        density=density,
        grids=grids)
    total_energy = (
        kinetic_energy + hartree_energy + xc_energy + external_energy)

    if enforce_reflection_symmetry:
        # The energies above were computed from the density as returned by
        # the solver; only the density stored in the new state is averaged.
        density = utils.flip_and_average(
            locations=state.locations, grids=grids, array=density)

    return state._replace(
        density=density,
        total_energy=total_energy,
        hartree_potential=hartree_potential,
        xc_potential=xc_potential,
        xc_energy_density=xc_energy_density,
        gap=gap)