Beispiel #1
0
def _get_hartree_energy(density, grids, interaction_fn):
    """Computes the Hartree energy by summing over all pairs of grid points.

    Evaluates 0.5 * sum_ij n_i * n_j * interaction_fn(x_j - x_i) * dx**2,
    where the pairwise terms are built through broadcasting.

    Args:
      density: Array of density values on the grid (expanded along axes 0 and
          1 to form the pairwise products).
      grids: Array of grid points; its spacing comes from utils.get_dx.
      interaction_fn: Function applied to the matrix of pairwise
          displacements.

    Returns:
      The Hartree energy.
    """
    # Outer product of the density with itself: entry (i, j) is n_i * n_j.
    density_pairs = (
        jnp.expand_dims(density, axis=0) * jnp.expand_dims(density, axis=1))
    # Entry (i, j) is x_j - x_i.
    displacements = (
        jnp.expand_dims(grids, axis=0) - jnp.expand_dims(grids, axis=1))
    pair_sum = jnp.sum(density_pairs * interaction_fn(displacements))
    return 0.5 * pair_sum * utils.get_dx(grids)**2
 def test_self_interaction_weight(self, density_integral, expected_weight):
     """Checks self_interaction_weight on a Gaussian scaled by density_integral."""
     grids = jnp.linspace(-5, 5, 11)
     gaussian = utils.gaussian(grids=grids, center=1., sigma=1.)
     # Add a leading and a trailing singleton axis around the grid axis.
     reshaped_density = (
         density_integral * gaussian[jnp.newaxis, :, jnp.newaxis])
     weight = neural_xc.self_interaction_weight(
         reshaped_density=reshaped_density,
         dx=utils.get_dx(grids),
         width=1.)
     self.assertAlmostEqual(weight, expected_weight)
Beispiel #3
0
def _wavefunctions_to_density(num_electrons, wavefunctions, grids):
    """Builds the electron density from the lowest wavefunctions.

    Args:
      num_electrons: Integer, the number of electrons to place.
      wavefunctions: Array whose rows (axis 0) are wavefunctions and whose
          axis 1 runs over grid points.
      grids: Array of grid points; its spacing comes from utils.get_dx.

    Returns:
      The density: the sum over axis 0 of the first num_electrons rows of
      the spin-duplicated squared wavefunctions.
    """
    # Rows beyond num_electrons are never selected below, so drop most of the
    # unoccupied states up front to reduce the amount of computation.
    kept = wavefunctions[:num_electrons]
    # Normalize each row so that sum(psi**2) * dx equals one.
    norms = jnp.sqrt(
        jnp.sum(kept**2, axis=1, keepdims=True) * utils.get_dx(grids))
    kept = kept / norms
    # Each eigenstate has spin up and spin down, so every row appears twice
    # before the first num_electrons entries are taken.
    spin_intensities = jnp.repeat(kept**2, repeats=2, axis=0)
    return jnp.sum(spin_intensities[:num_electrons], axis=0)
Beispiel #4
0
 def loss(flatten_params, initial_state, target_density):
   """Returns the L1 density error after one Kohn-Sham iteration.

   The flat parameters are unflattened with `spec` and bound to
   `xc_energy_density_fn`; both names come from the enclosing scope.
   """
   xc_fn = tree_util.Partial(
       xc_energy_density_fn,
       params=np_utils.unflatten(spec, flatten_params))
   state = scf.kohn_sham_iteration(
       state=initial_state,
       num_electrons=self.num_electrons,
       xc_energy_density_fn=xc_fn,
       interaction_fn=utils.exponential_coulomb,
       enforce_reflection_symmetry=True)
   abs_error = jnp.abs(state.density - target_density)
   # Sum of absolute differences weighted by the grid spacing.
   return jnp.sum(abs_error) * utils.get_dx(self.grids)
Beispiel #5
0
def get_kinetic_matrix(grids):
    """Builds the kinetic matrix for the given grid.

    The result is -0.5 times discrete_laplacian(grids.size), divided by the
    square of the grid spacing.

    Args:
      grids: Float numpy array with shape (num_grids,).

    Returns:
      Float numpy array with shape (num_grids, num_grids).
    """
    spacing = utils.get_dx(grids)
    laplacian = discrete_laplacian(grids.size)
    return -0.5 * laplacian / (spacing * spacing)
Beispiel #6
0
def get_external_potential_energy(external_potential, density, grids):
    """Computes the external potential energy.

    This is the dot product of the density with the external potential,
    multiplied by the grid spacing.

    Args:
      external_potential: Float numpy array with shape (num_grids,).
      density: Float numpy array with shape (num_grids,).
      grids: Float numpy array with shape (num_grids,).

    Returns:
      Float.
    """
    overlap = jnp.dot(density, external_potential)
    return overlap * utils.get_dx(grids)
 def loss(flatten_params, target_density):
     """Returns the L1 density error of the final Kohn-Sham state.

     Runs scf.kohn_sham for three iterations with the xc functional
     parameterized by `flatten_params` (unflattened with `spec` from the
     enclosing scope), then compares the final density to the target.
     """
     xc_fn = tree_util.Partial(
         xc_energy_density_fn,
         params=np_utils.unflatten(spec, flatten_params))
     state = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb,
         density_mse_converge_tolerance=-1)
     final_density = scf.get_final_state(state).density
     l1_error = jnp.sum(jnp.abs(final_density - target_density))
     # Weight the summed absolute difference by the grid spacing.
     return l1_error * utils.get_dx(self.grids)
Beispiel #8
0
def get_xc_energy(density, xc_energy_density_fn, grids):
  r"""Computes the xc energy.

  E_xc = \int density * xc_energy_density_fn(density) dx, evaluated as a dot
  product on the grid multiplied by the grid spacing.

  Args:
    density: Float numpy array with shape (num_grids,).
    xc_energy_density_fn: function takes density and returns float numpy array
        with shape (num_grids,).
    grids: Float numpy array with shape (num_grids,).

  Returns:
    Float.
  """
  xc_energy_density = xc_energy_density_fn(density)
  weighted_sum = jnp.dot(xc_energy_density, density)
  return weighted_sum * utils.get_dx(grids)
Beispiel #9
0
  def test_get_hartree_energy(self, interaction_fn):
    """Compares scf.get_hartree_energy with an explicit double loop."""
    grids = jnp.linspace(-5, 5, 11)
    dx = utils.get_dx(grids)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)

    # Reference value: accumulate every (point, point) pair one at a time.
    samples = list(zip(grids, density))
    reference = 0.
    for x_0, n_0 in samples:
      for x_1, n_1 in samples:
        pair_term = 0.5 * n_0 * n_1 * interaction_fn(x_0 - x_1) * dx ** 2
        reference += pair_term

    actual = scf.get_hartree_energy(
        density=density, grids=grids, interaction_fn=interaction_fn)
    self.assertAlmostEqual(float(actual), float(reference))
Beispiel #10
0
  def test_get_hartree_potential(self, interaction_fn):
    """Compares scf.get_hartree_potential with an explicit double loop."""
    grids = jnp.linspace(-5, 5, 11)
    dx = utils.get_dx(grids)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)

    # Compute the expected Hartree potential by nested for loops: one entry
    # per grid point, each accumulating the contribution of every grid point.
    reference = np.zeros_like(grids)
    samples = list(zip(grids, density))
    for row, x_0 in enumerate(grids):
      for x_1, n_1 in samples:
        contribution = np.sum(n_1 * interaction_fn(x_0 - x_1)) * dx
        reference[row] += contribution

    actual = scf.get_hartree_potential(
        density=density, grids=grids, interaction_fn=interaction_fn)
    np.testing.assert_allclose(actual, reference)
Beispiel #11
0
def spherical_superposition_density(grids, locations, nuclear_charges):
  """Builds an initial density guess by superposing per-nucleus densities.

  The density returned by _get_exact_h_atom_density for each nucleus is
  weighted by that nucleus' charge and summed over nuclei.

  Args:
    grids: Float numpy array with shape (num_grids,).
    locations: Float numpy array with shape (num_nuclei,).
    nuclear_charges: Float numpy array with shape (num_nuclei,).

  Returns:
    Float numpy array with shape (num_grids,).
  """
  grid_row = np.expand_dims(np.array(grids), axis=0)
  nuclei_column = np.expand_dims(np.array(locations), axis=1)
  # (num_nuclei, num_grids): displacement of every grid point from each
  # nucleus.
  displacements = grid_row - nuclei_column
  dx = float(utils.get_dx(grids))
  atom_densities = _get_exact_h_atom_density(displacements, dx)
  return np.dot(nuclear_charges, atom_densities)
def self_interaction_layer(grids, interaction_fn):
    """Layer construction function for self-interaction.

    The layer takes two inputs: the density and the feature to mix.

    When the density integral is one, this layer outputs
    -0.5 * Hartree potential in the same shape as the density, which will
    cancel the Hartree term. When the density integral is not one, the output
    is a linear combination of the two inputs. The weight is determined by
    self_interaction_weight().

    Args:
      grids: Float numpy array with shape (num_grids,).
      interaction_fn: function takes displacements and returns
          float numpy array with the same shape of displacements.

    Returns:
      (init_fn, apply_fn) pair.
    """
    grids = grids.astype(jnp.float32)
    dx = utils.get_dx(grids)

    def init_fn(rng, input_shape):
        """Checks the two input shapes and returns the initial width."""
        del rng  # No randomness is needed: the width always starts at 1.
        if len(input_shape) != 2:
            raise ValueError(f'self_interaction_layer must have two inputs, '
                             f'but got {len(input_shape)}')
        if input_shape[0] != input_shape[1]:
            raise ValueError(
                f'The input shape to self_interaction_layer must be equal, '
                f'but got {input_shape[0]} and {input_shape[1]}')
        output_shape = input_shape[0]
        initial_params = (jnp.array(1.), )
        return output_shape, initial_params

    def apply_fn(params, inputs, **kwargs):
        """Mixes -0.5 * Hartree potential with the features using beta."""
        del kwargs
        (width,) = params
        reshaped_density, features = inputs
        beta = self_interaction_weight(
            reshaped_density=reshaped_density, dx=dx, width=width)
        # The Hartree potential is evaluated on the flattened density and
        # then restored to the shape of the density input.
        hartree_potential = scf.get_hartree_potential(
            density=reshaped_density.reshape(-1),
            grids=grids,
            interaction_fn=interaction_fn)
        hartree = -0.5 * hartree_potential.reshape(reshaped_density.shape)
        return hartree * beta + features * (1 - beta)

    return init_fn, apply_fn
Beispiel #13
0
def get_xc_potential(density, xc_energy_density_fn, grids):
    """Computes the xc potential.

    The xc potential is obtained by automatic differentiation of
    get_xc_energy with respect to the density, divided by the grid spacing.

    Args:
      density: Float numpy array with shape (num_grids,).
      xc_energy_density_fn: function takes density and returns float numpy
          array with shape (num_grids,).
      grids: Float numpy array with shape (num_grids,).

    Returns:
      Float numpy array with shape (num_grids,).
    """
    # jax.grad differentiates with respect to the first positional argument
    # by default, which here is the density.
    energy_gradient_fn = jax.grad(get_xc_energy)
    energy_gradient = energy_gradient_fn(density, xc_energy_density_fn, grids)
    return energy_gradient / utils.get_dx(grids)
Beispiel #14
0
 def _test_state(self, state, external_potential):
     """Runs the shared sanity checks on a Kohn-Sham state."""
     # The density in the final state should normalize to number of electrons.
     total_density = float(jnp.sum(state.density) * utils.get_dx(self.grids))
     self.assertAlmostEqual(total_density, self.num_electrons)
     # The total energy should be finite after iterations.
     self.assertTrue(jnp.isfinite(state.total_energy))
     # Both potentials have one entry per grid point.
     for potential in (state.hartree_potential, state.xc_potential):
         self.assertLen(potential, len(state.grids))
     # locations, nuclear_charges, external_potential, grids and num_electrons
     # remain unchanged.
     unchanged_pairs = (
         (state.locations, self.locations),
         (state.nuclear_charges, self.nuclear_charges),
         (external_potential, state.external_potential),
         (state.grids, self.grids),
     )
     for actual, desired in unchanged_pairs:
         np.testing.assert_allclose(actual, desired)
     self.assertEqual(state.num_electrons, self.num_electrons)
     self.assertGreater(state.gap, 0)
Beispiel #15
0
 def test_get_dx_incorrect_ndim(self):
     """utils.get_dx must raise ValueError for a two-dimensional input."""
     column_grids = jnp.array([[-0.1], [0.], [0.1]])
     expected_message = 'grids.ndim is expected to be 1 but got 2'
     with self.assertRaisesRegex(ValueError, expected_message):
         utils.get_dx(column_grids)
Beispiel #16
0
def _get_hartree_potential(density, grids, interaction_fn):
    """Computes the Hartree potential at every grid point.

    Entry i of the result is sum_j n_j * interaction_fn(x_j - x_i) * dx.

    Args:
      density: Array of density values on the grid.
      grids: Array of grid points; its spacing comes from utils.get_dx.
      interaction_fn: Function applied to the matrix of pairwise
          displacements.

    Returns:
      Array holding the Hartree potential, one value per grid point.
    """
    density_row = jnp.expand_dims(density, axis=0)
    # Entry (i, j) is x_j - x_i.
    displacements = (
        jnp.expand_dims(grids, axis=0) - jnp.expand_dims(grids, axis=1))
    integrand = density_row * interaction_fn(displacements)
    return jnp.sum(integrand, axis=1) * utils.get_dx(grids)
Beispiel #17
0
def exponential_global_convolution(
    num_channels,
    grids,
    minval,
    maxval,
    downsample_factor=0,
    eta_init=nn.initializers.normal()):
  """Layer construction function for exponential global convolution.

  Args:
    num_channels: Integer, the number of channels.
    grids: Float numpy array with shape (num_grids,).
    minval: Float, lower bound of the exponential width. The width of each
        channel is minval + (maxval - minval) * sigmoid(eta).
    maxval: Float, upper bound of the exponential width.
    downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.
    eta_init: Initializer function in nn.initializers.

  Returns:
    (init_fn, apply_fn) pair.
  """
  stride = 2 ** downsample_factor
  grids = grids[::stride]
  # Pairwise displacements between the (downsampled) grid points.
  displacements = (
      jnp.expand_dims(grids, axis=0) - jnp.expand_dims(grids, axis=1))
  dx = utils.get_dx(grids)

  def init_fn(rng, input_shape):
    """Validates the configuration and input shape, then initializes eta."""
    if num_channels <= 0:
      raise ValueError(f'num_channels must be positive but got {num_channels}')
    if len(input_shape) != 3:
      raise ValueError(
          f'The ndim of input should be 3, but got {len(input_shape)}')
    if input_shape[1] != len(grids):
      raise ValueError(
          f'input_shape[1] should be len(grids), but got {input_shape[1]}')
    if input_shape[2] != 1:
      raise ValueError(
          f'input_shape[2] should be 1, but got {input_shape[2]}')
    # One eta per output channel; the channel axis replaces the last axis.
    eta = eta_init(rng, shape=(num_channels,))
    output_shape = input_shape[:-1] + (num_channels,)
    return output_shape, (eta,)

  def apply_fn(params, inputs, **kwargs):
    """Applies layer.

    Args:
      params: Layer parameters, (eta,).
      inputs: Float numpy array with shape (batch_size, num_grids, 1); init_fn
          rejects any other size for the last axis.
      **kwargs: Other key word arguments. Unused.

    Returns:
      Float numpy array with shape (batch_size, num_grids, num_channels).
    """
    del kwargs
    (eta,) = params
    # The sigmoid keeps every width between minval and maxval.
    widths = minval + (maxval - minval) * nn.sigmoid(eta)
    # shape (num_grids, num_grids, num_channels)
    kernels = _exponential_function_channels(displacements, widths=widths)
    # shape (batch_size, 1, num_grids, num_channels)
    convolved = jnp.tensordot(inputs, kernels, axes=(1, 0)) * dx
    # shape (batch_size, num_grids, num_channels)
    return jnp.squeeze(convolved, axis=1)

  return init_fn, apply_fn
Beispiel #18
0
 def test_get_dx(self):
     """Eleven evenly spaced points on [0, 1] have spacing 0.1."""
     grids = jnp.linspace(0, 1, 11)
     spacing = utils.get_dx(grids)
     self.assertAlmostEqual(spacing, 0.1)