Exemplo n.º 1
0
 def test_get_atomic_chain_potential_incorrect_ndim(
     self, grids, locations, nuclear_charges, expected_message):
   """Invalid chain inputs must make get_atomic_chain_potential raise.

   The test arguments are converted to jnp arrays and passed to
   `utils.get_atomic_chain_potential` with `utils.exponential_coulomb`. The call
   must raise a ValueError whose message matches `expected_message`. Judging
   by the test name, the inputs presumably have an unsupported number of
   dimensions; this is decided by the caller's test cases.
   """
   with self.assertRaisesRegex(ValueError, expected_message):
     # The arrays are built inside the `with` block, so a matching ValueError
     # raised during conversion would also satisfy the assertion.
     chain_arrays = {
         'grids': jnp.array(grids),
         'locations': jnp.array(locations),
         'nuclear_charges': jnp.array(nuclear_charges),
     }
     utils.get_atomic_chain_potential(
         interaction_fn=utils.exponential_coulomb, **chain_arrays)
Exemplo n.º 2
0
 def test_get_atomic_chain_potential_soft_coulomb(self):
   """Checks the soft-Coulomb chain potential at three grid points.

   Nuclei with charges 2 and 1 sit at x=0 and x=1 on a 201-point grid spanning
   [-10, 10]. Each expected value is given in closed form next to it.
   """
   potential = utils.get_atomic_chain_potential(
       grids=jnp.linspace(-10, 10, 201),
       locations=jnp.array([0., 1.]),
       nuclear_charges=jnp.array([2, 1]),
       interaction_fn=utils.soft_coulomb)
   checks = (
       # x = -10: -2 / jnp.sqrt(10 ** 2 + 1) - 1 / jnp.sqrt(11 ** 2 + 1)
       (0, -0.28954318),
       # x = 0: -2 / jnp.sqrt(0 ** 2 + 1) - 1 / jnp.sqrt(1 ** 2 + 1)
       (100, -2.70710678),
       # x = 10: -2 / jnp.sqrt(10 ** 2 + 1) - 1 / jnp.sqrt(9 ** 2 + 1)
       (200, -0.30943896),
   )
   for grid_index, expected in checks:
     self.assertAlmostEqual(float(potential[grid_index]), expected)
Exemplo n.º 3
0
 def test_get_atomic_chain_potential_exponential_coulomb(self):
   """Checks the exponential-Coulomb chain potential at three grid points.

   Nuclei with charges 2 and 1 sit at x=0 and x=1 on a 201-point grid spanning
   [-10, 10]. Each expected value is written out with the exponential
   interaction constants used in the closed-form check.
   """
   potential = utils.get_atomic_chain_potential(
       grids=jnp.linspace(-10, 10, 201),
       locations=jnp.array([0., 1.]),
       nuclear_charges=jnp.array([2, 1]),
       interaction_fn=utils.exponential_coulomb)
   checks = (
       # x = -10: -2 * 1.071295 * jnp.exp(-np.abs(10) / 2.385345)
       #          - 1.071295 * jnp.exp(-np.abs(11) / 2.385345)
       (0, -0.04302427),
       # x = 0: -2 * 1.071295 * jnp.exp(-np.abs(0) / 2.385345)
       #        - 1.071295 * jnp.exp(-np.abs(1) / 2.385345)
       (100, -2.84702559),
       # x = 10: -2 * 1.071295 * jnp.exp(-np.abs(10) / 2.385345)
       #         - 1.071295 * jnp.exp(-np.abs(9) / 2.385345)
       (200, -0.05699946),
   )
   for grid_index, expected in checks:
     self.assertAlmostEqual(float(potential[grid_index]), expected)
Exemplo n.º 4
0
 def _create_testing_initial_state(self, interaction_fn):
     """Returns a KohnShamState for two unit charges at x = -0.5 and x = 0.5.

     The density is `utils.gaussian` (center 0, sigma 1) on `self.grids`,
     scaled by `self.num_electrons`. The external potential is built from the
     two nuclei with `interaction_fn`.

     Args:
       interaction_fn: interaction function passed to
           utils.get_atomic_chain_potential.

     Returns:
       scf.KohnShamState defined on self.grids.
     """
     nuclear_charges = jnp.array([1, 1])
     locations = jnp.array([-0.5, 0.5])
     initial_density = self.num_electrons * utils.gaussian(
         grids=self.grids, center=0., sigma=1.)
     external_potential = utils.get_atomic_chain_potential(
         grids=self.grids,
         locations=locations,
         nuclear_charges=nuclear_charges,
         interaction_fn=interaction_fn)
     return scf.KohnShamState(
         density=initial_density,
         # Placeholder energy: the Kohn-Sham calculation does not use the
         # initial total energy.
         total_energy=jnp.inf,
         locations=locations,
         nuclear_charges=nuclear_charges,
         external_potential=external_potential,
         grids=self.grids,
         num_electrons=self.num_electrons)
Exemplo n.º 5
0
def _kohn_sham(locations, nuclear_charges, num_electrons, num_iterations,
               grids, xc_energy_density_fn, interaction_fn, initial_density,
               alpha, alpha_decay, enforce_reflection_symmetry,
               num_mixing_iterations, density_mse_converge_tolerance,
               stop_gradient_step):
    """Jit-able Kohn Sham calculation.

    Runs exactly `num_iterations` steps inside `jax.lax.scan`, so the loop has
    a static length. Once the convergence flag is set, each remaining step
    returns the previous state flagged `converged=True` instead of running
    another Kohn-Sham iteration.

    Args:
      locations: nuclear locations, used for the external potential and
          stored in the state.
      nuclear_charges: nuclear charges, used for the external potential and
          stored in the state.
      num_electrons: number of electrons, forwarded to `kohn_sham_iteration`.
      num_iterations: number of scan steps. It is also the number of rows in
          the density-difference buffer.
      grids: grid points. `grids.shape[0]` sets the width of the
          density-difference buffer.
      xc_energy_density_fn: forwarded to `kohn_sham_iteration`.
      interaction_fn: used to build the external potential and forwarded to
          `kohn_sham_iteration`.
      initial_density: density of the initial state.
      alpha: initial density-mixing factor.
      alpha_decay: factor multiplied into alpha after every step.
      enforce_reflection_symmetry: forwarded to `kohn_sham_iteration`.
      num_mixing_iterations: forwarded to `_connection_weights` to build the
          mixing weights.
      density_mse_converge_tolerance: the run is flagged converged once the
          mean squared change between the (mixed) densities of consecutive
          states falls below this value.
      stop_gradient_step: states produced at steps whose index is <= this
          value are passed through `jax.lax.stop_gradient`.

    Returns:
      The per-step states from `jax.lax.scan`, stacked along a leading axis of
      length num_iterations.
    """
    num_grids = grids.shape[0]
    # Row `idx` of `weights` is dotted with the density-difference buffer to
    # mix the density at step `idx`.
    weights = _connection_weights(num_iterations, num_mixing_iterations)

    def _converged_kohn_sham_iteration(old_state_differences):
        """Returns the previous state unchanged, flagged as converged."""
        old_state, differences = old_state_differences
        return old_state._replace(converged=True), differences

    def _unconverged_kohn_sham_iteration(idx_old_state_alpha_differences):
        """Runs one Kohn-Sham iteration followed by density mixing."""
        idx, old_state, alpha, differences = idx_old_state_alpha_differences
        state = kohn_sham_iteration(
            state=old_state,
            num_electrons=num_electrons,
            xc_energy_density_fn=xc_energy_density_fn,
            interaction_fn=interaction_fn,
            enforce_reflection_symmetry=enforce_reflection_symmetry)
        # Store this step's unmixed density change in row `idx` using JAX's
        # functional indexed update (equivalent to the deprecated
        # `jax.ops.index_update`).
        differences = differences.at[idx].set(
            state.density - old_state.density)
        # Density mixing.
        state = state._replace(density=old_state.density +
                               alpha * jnp.dot(weights[idx], differences))
        return state, differences

    def _single_kohn_sham_iteration(carry, inputs):
        """One lax.scan step; the scanned `inputs` are unused."""
        del inputs
        idx, old_state, alpha, converged, differences = carry
        state, differences = jax.lax.cond(
            converged,
            true_operand=(old_state, differences),
            true_fun=_converged_kohn_sham_iteration,
            false_operand=(idx, old_state, alpha, differences),
            false_fun=_unconverged_kohn_sham_iteration)
        # Convergence compares the (mixed) densities of consecutive states.
        converged = jnp.mean(
            jnp.square(state.density -
                       old_state.density)) < density_mse_converge_tolerance
        # Block gradients through the early steps.
        state = jax.lax.cond(idx <= stop_gradient_step,
                             true_fun=jax.lax.stop_gradient,
                             false_fun=lambda x: x,
                             operand=state)
        return (idx + 1, state, alpha * alpha_decay, converged,
                differences), state

    # Create initial state.
    state = scf.KohnShamState(
        density=initial_density,
        total_energy=jnp.inf,
        locations=locations,
        nuclear_charges=nuclear_charges,
        external_potential=utils.get_atomic_chain_potential(
            grids=grids,
            locations=locations,
            nuclear_charges=nuclear_charges,
            interaction_fn=interaction_fn),
        grids=grids,
        num_electrons=num_electrons,
        # Add dummy fields so the input and output of lax.scan have the same type
        # structure.
        hartree_potential=jnp.zeros_like(grids),
        xc_potential=jnp.zeros_like(grids),
        xc_energy_density=jnp.zeros_like(grids),
        gap=0.,
        converged=False)
    # Initialize the density differences with all zeros since the carry in
    # lax.scan must keep the same shape.
    differences = jnp.zeros((num_iterations, num_grids))

    _, states = jax.lax.scan(_single_kohn_sham_iteration,
                             init=(0, state, alpha, state.converged,
                                   differences),
                             xs=jnp.arange(num_iterations))
    return states
Exemplo n.º 6
0
def kohn_sham(
    locations,
    nuclear_charges,
    num_electrons,
    num_iterations,
    grids,
    xc_energy_density_fn,
    interaction_fn,
    initial_density=None,
    alpha=0.5,
    alpha_decay=0.9,
    enforce_reflection_symmetry=False,
    num_mixing_iterations=2,
    density_mse_converge_tolerance=-1.):
  """Runs Kohn-Sham to solve ground state of external potential.

  Args:
    locations: Float numpy array with shape (num_nuclei,), the locations of
        atoms.
    nuclear_charges: Float numpy array with shape (num_nuclei,), the nuclear
        charges.
    num_electrons: Integer, the number of electrons in the system. The first
        num_electrons states are occupied.
    num_iterations: Integer, the number of Kohn-Sham iterations. Must be
        positive.
    grids: Float numpy array with shape (num_grids,).
    xc_energy_density_fn: function takes density (num_grids,) and returns
        the energy density (num_grids,).
    interaction_fn: function takes displacements and returns
        float numpy array with the same shape of displacements.
    initial_density: Float numpy array with shape (num_grids,), initial guess
        of the density for Kohn-Sham calculation. Default None, the initial
        density is non-interacting solution from the external_potential.
    alpha: Float between 0 and 1, density linear mixing factor, the fraction
        of the output of the k-th Kohn-Sham iteration.
        If 0, the input density to the k-th Kohn-Sham iteration is fed into
        the (k+1)-th iteration. The output of the k-th Kohn-Sham iteration is
        completely ignored.
        If 1, the output density from the k-th Kohn-Sham iteration is fed into
        the (k+1)-th iteration, equivalent to no density mixing.
    alpha_decay: Float between 0 and 1, the decay factor of alpha. The mixing
        factor after k-th iteration is alpha * alpha_decay ** k.
    enforce_reflection_symmetry: Boolean, whether to enforce reflection
        symmetry. If True, the density is symmetric with respect to the
        center.
    num_mixing_iterations: Integer, the number of density differences in the
        previous iterations to mix the density.
    density_mse_converge_tolerance: Float, the stopping criterion. When the
        mean squared difference between the output and the input density of
        a Kohn-Sham iteration is smaller than this value, the Kohn-Sham
        iterations finish. The outputs of the rest of the steps are padded by
        the output of the converged step. Set this value to negative to
        disable early stopping.

  Returns:
    KohnShamState, the states of all the Kohn-Sham iteration steps, with each
    array leaf stacked along a new leading axis of length num_iterations.

  Raises:
    ValueError: if num_iterations is not positive, because there would be no
        states to stack.
  """
  if num_iterations < 1:
    raise ValueError(
        'num_iterations must be a positive integer, got {}.'.format(
            num_iterations))
  external_potential = utils.get_atomic_chain_potential(
      grids=grids,
      locations=locations,
      nuclear_charges=nuclear_charges,
      interaction_fn=interaction_fn)
  if initial_density is None:
    # Use the non-interacting solution from the external_potential as initial
    # guess.
    initial_density, _, _ = solve_noninteracting_system(
        external_potential=external_potential,
        num_electrons=num_electrons,
        grids=grids)
  # Create initial state.
  state = KohnShamState(
      density=initial_density,
      total_energy=jnp.inf,
      locations=locations,
      nuclear_charges=nuclear_charges,
      external_potential=external_potential,
      grids=grids,
      num_electrons=num_electrons)
  states = []
  # Unmixed density change (output minus input) of every executed iteration.
  # Appending to a list is O(1); only the mixing window is stacked below.
  differences = []
  converged = False
  for _ in range(num_iterations):
    if converged:
      # Pad the remaining steps with the converged state.
      states.append(state)
      continue

    old_state = state
    state = kohn_sham_iteration(
        state=old_state,
        num_electrons=num_electrons,
        xc_energy_density_fn=xc_energy_density_fn,
        interaction_fn=interaction_fn,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    density_difference = state.density - old_state.density
    differences.append(density_difference)
    if jnp.mean(
        jnp.square(density_difference)) < density_mse_converge_tolerance:
      converged = True
    state = state._replace(converged=converged)
    # Density mixing: step from the input density by alpha times the mean of
    # the most recent num_mixing_iterations unmixed differences.
    mixing_window = jnp.stack(differences[-num_mixing_iterations:])
    state = state._replace(
        density=old_state.density + alpha * jnp.mean(mixing_window, axis=0))
    states.append(state)
    alpha *= alpha_decay

  # Stack the per-step states leaf by leaf into a single KohnShamState.
  return tree_util.tree_map(lambda *x: jnp.stack(x), *states)
Exemplo n.º 7
0
 def _create_testing_external_potential(self, interaction_fn):
     """Builds the external potential of the fixture's atomic chain.

     Args:
       interaction_fn: interaction function passed to
           utils.get_atomic_chain_potential.

     Returns:
       The result of utils.get_atomic_chain_potential for self.grids,
       self.locations and self.nuclear_charges.
     """
     chain = dict(
         grids=self.grids,
         locations=self.locations,
         nuclear_charges=self.nuclear_charges)
     return utils.get_atomic_chain_potential(
         interaction_fn=interaction_fn, **chain)