Example #1
0
def _get_hartree_energy(density, grids, interaction_fn):
  """Computes the Hartree energy as a double sum over all grid-point pairs.

  The value returned is
  0.5 * sum_ij density[i] * density[j] * interaction_fn(grids[j] - grids[i])
  multiplied by dx ** 2, where dx comes from utils.get_dx(grids).

  Args:
    density: Density values on the grid; broadcast against itself to form all
        pairwise products.
    grids: Grid coordinates passed to utils.get_dx() and used to build the
        pairwise displacements.
    interaction_fn: Function applied elementwise to the pairwise displacement
        array.

  Returns:
    The Hartree energy.
  """
  # Broadcasting a row against a column enumerates every (i, j) pair.
  density_row = jnp.expand_dims(density, axis=0)
  density_col = jnp.expand_dims(density, axis=1)
  displacements = (
      jnp.expand_dims(grids, axis=0) - jnp.expand_dims(grids, axis=1))
  pair_sum = jnp.sum(density_row * density_col * interaction_fn(displacements))
  # One factor of dx per integration variable.
  return 0.5 * pair_sum * utils.get_dx(grids) ** 2
 def test_self_interaction_weight(self, density_integral, expected_weight):
     """Checks self_interaction_weight() on a scaled Gaussian density.

     The density is utils.gaussian(center=1, sigma=1) on an 11-point grid over
     [-5, 5], scaled by density_integral and given leading and trailing axes
     of size one.
     """
     grids = jnp.linspace(-5, 5, 11)
     gaussian_density = utils.gaussian(grids=grids, center=1., sigma=1.)
     # Add leading and trailing singleton axes around the grid axis.
     reshaped_density = (
         density_integral * gaussian_density[jnp.newaxis, :, jnp.newaxis])
     weight = neural_xc.self_interaction_weight(
         reshaped_density=reshaped_density,
         dx=utils.get_dx(grids),
         width=1.)
     self.assertAlmostEqual(weight, expected_weight)
Example #3
0
def _wavefunctions_to_density(num_electrons, wavefunctions, grids):
  """Builds the electron density from the leading wavefunctions.

  Args:
    num_electrons: Integer, the number of electrons to place.
    wavefunctions: Array whose rows (axis 0) are eigenstates and whose
        columns (axis 1) are grid points. The leading rows are the ones
        treated as occupied.
    grids: Grid coordinates, used only to obtain dx via utils.get_dx().

  Returns:
    The density on the grid: the sum of the first num_electrons spin-resolved
    squared wavefunctions.
  """
  # At most num_electrons spatial states can be occupied, so drop the rest
  # of the unoccupied states up front to save computation.
  occupied = wavefunctions[:num_electrons]
  # Normalize each state so that sum(psi ** 2) * dx == 1.
  norms = jnp.sum(occupied ** 2, axis=1, keepdims=True) * utils.get_dx(grids)
  occupied = occupied / jnp.sqrt(norms)
  # Duplicate every state in place: one copy for spin up, one for spin down.
  spin_resolved = jnp.repeat(occupied ** 2, repeats=2, axis=0)
  # Fill the spin-resolved states in order until all electrons are placed.
  return jnp.sum(spin_resolved[:num_electrons], axis=0)
Example #4
0
 def loss(flatten_params, initial_state, target_density):
     """Integrated absolute density error after scf.kohn_sham_iteration().

     Args:
       flatten_params: Flattened network parameters; restored to their
           structure with np_utils.unflatten(spec, ...).
       initial_state: State passed to scf.kohn_sham_iteration().
       target_density: Density compared against the resulting state.density.

     Returns:
       sum(|state.density - target_density|) * dx on self.grids.
     """
     params = np_utils.unflatten(spec, flatten_params)
     xc_fn = tree_util.Partial(xc_energy_density_fn, params=params)
     state = scf.kohn_sham_iteration(
         state=initial_state,
         num_electrons=self.num_electrons,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb,
         enforce_reflection_symmetry=True)
     abs_error = jnp.abs(state.density - target_density)
     return jnp.sum(abs_error) * utils.get_dx(self.grids)
Example #5
0
def get_kinetic_matrix(grids):
  """Builds the kinetic matrix on the given grid.

  The result is -0.5 * discrete_laplacian(grids.size) / dx ** 2, with dx taken
  from utils.get_dx(grids).

  Args:
    grids: Float numpy array with shape (num_grids,).

  Returns:
    Float numpy array with shape (num_grids, num_grids).
  """
  grid_spacing = utils.get_dx(grids)
  laplacian = discrete_laplacian(grids.size)
  # Dividing by dx ** 2 converts the index-space Laplacian to grid units.
  return -0.5 * laplacian / (grid_spacing * grid_spacing)
Example #6
0
def get_external_potential_energy(external_potential, density, grids):
  """Integrates density * external_potential over the grid.

  Args:
    external_potential: Float numpy array with shape (num_grids,).
    density: Float numpy array with shape (num_grids,).
    grids: Float numpy array with shape (num_grids,).

  Returns:
    Float, dot(density, external_potential) * dx.
  """
  overlap = jnp.dot(density, external_potential)
  grid_spacing = utils.get_dx(grids)
  return overlap * grid_spacing
Example #7
0
def get_xc_energy(density, xc_energy_density_fn, grids):
  r"""Computes the exchange-correlation (xc) energy.

  E_xc = \int density * xc_energy_density_fn(density) dx,

  evaluated as a dot product on the grid multiplied by dx.

  Args:
    density: Float numpy array with shape (num_grids,).
    xc_energy_density_fn: function takes density and returns float numpy array
        with shape (num_grids,).
    grids: Float numpy array with shape (num_grids,).

  Returns:
    Float.
  """
  xc_energy_density = xc_energy_density_fn(density)
  integrand_sum = jnp.dot(xc_energy_density, density)
  return integrand_sum * utils.get_dx(grids)
Example #8
0
 def loss(flatten_params, target_density):
     """Integrated absolute density error of a 3-iteration scf.kohn_sham() run.

     Args:
       flatten_params: Flattened network parameters; restored to their
           structure with np_utils.unflatten(spec, ...).
       target_density: Density compared against the final state's density.

     Returns:
       sum(|final_state.density - target_density|) * dx on self.grids.
     """
     params = np_utils.unflatten(spec, flatten_params)
     xc_fn = tree_util.Partial(xc_energy_density_fn, params=params)
     # NOTE(review): the negative tolerance presumably prevents the run from
     # being flagged as converged early -- confirm against scf.kohn_sham.
     state = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb,
         density_mse_converge_tolerance=-1)
     final_state = scf.get_final_state(state)
     abs_error = jnp.abs(final_state.density - target_density)
     return jnp.sum(abs_error) * utils.get_dx(self.grids)
Example #9
0
def self_interaction_layer(grids, interaction_fn):
    """Layer construction function for self-interaction.

    The layer takes two inputs: the density first and the feature to mix
    second.

    When the density integral is one, the layer outputs -0.5 * Hartree
    potential in the same shape as the density, which will cancel the Hartree
    term. Otherwise the output is a linear combination of the two inputs, with
    the mixing weight given by self_interaction_weight().

    Args:
      grids: Float numpy array with shape (num_grids,).
      interaction_fn: function takes displacements and returns
          float numpy array with the same shape of displacements.

    Returns:
      (init_fn, apply_fn) pair.
    """
    float_grids = grids.astype(jnp.float32)
    dx = utils.get_dx(float_grids)

    def init_fn(rng, input_shape):
        """Validates the two input shapes and creates the width parameter.

        Returns the shape of the first input as the output shape, together
        with a one-element parameter tuple holding the width, initialized
        to 1.
        """
        del rng  # The width always starts at a fixed value.
        if len(input_shape) != 2:
            raise ValueError(f'self_interaction_layer must have two inputs, '
                             f'but got {len(input_shape)}')
        if input_shape[0] != input_shape[1]:
            raise ValueError(
                f'The input shape to self_interaction_layer must be equal, '
                f'but got {input_shape[0]} and {input_shape[1]}')
        initial_width = jnp.array(1.)
        return input_shape[0], (initial_width, )

    def apply_fn(params, inputs, **kwargs):
        """Mixes -0.5 * Hartree potential with the features using beta."""
        del kwargs
        (width, ) = params
        reshaped_density, features = inputs
        beta = self_interaction_weight(reshaped_density=reshaped_density,
                                       dx=dx,
                                       width=width)
        # The Hartree potential is computed on the flattened density and then
        # restored to the shape of the input density.
        hartree_potential = scf.get_hartree_potential(
            density=reshaped_density.reshape(-1),
            grids=float_grids,
            interaction_fn=interaction_fn)
        hartree = -0.5 * hartree_potential.reshape(reshaped_density.shape)
        return hartree * beta + features * (1 - beta)

    return init_fn, apply_fn
Example #10
0
def get_xc_potential(density, xc_energy_density_fn, grids):
  """Computes the exchange-correlation (xc) potential.

  The potential is the gradient of get_xc_energy() with respect to density,
  obtained by automatic differentiation and divided by dx.

  Args:
    density: Float numpy array with shape (num_grids,).
    xc_energy_density_fn: function takes density and returns float numpy array
        with shape (num_grids,).
    grids: Float numpy array with shape (num_grids,).

  Returns:
    Float numpy array with shape (num_grids,).
  """
  # jax.grad differentiates with respect to the first argument (density).
  xc_energy_grad_fn = jax.grad(get_xc_energy)
  energy_gradient = xc_energy_grad_fn(density, xc_energy_density_fn, grids)
  return energy_gradient / utils.get_dx(grids)
Example #11
0
    def test_get_hartree_potential(self, interaction_fn):
        """Compares scf.get_hartree_potential() with an explicit double loop."""
        grids = jnp.linspace(-5, 5, 11)
        density = utils.gaussian(grids=grids, center=1., sigma=1.)
        dx = utils.get_dx(grids)

        # Reference Hartree potential: for every target grid point, accumulate
        # the contribution of every source grid point with nested for loops.
        expected_hartree_potential = np.zeros_like(grids)
        for row, x_target in enumerate(grids):
            for x_source, n_source in zip(grids, density):
                contribution = np.sum(
                    n_source * interaction_fn(x_target - x_source)) * dx
                expected_hartree_potential[row] += contribution

        actual_hartree_potential = scf.get_hartree_potential(
            density=density, grids=grids, interaction_fn=interaction_fn)
        np.testing.assert_allclose(actual_hartree_potential,
                                   expected_hartree_potential)
Example #12
0
def spherical_superposition_density(grids, locations, nuclear_charges):
    """Builds initial guess of density by superposition of spherical densities.

    One density per nucleus is evaluated with _get_exact_h_atom_density() on
    the displacements from that nucleus. The per-nucleus densities are then
    combined with the nuclear charges as weights.

    Args:
      grids: Float numpy array with shape (num_grids,).
      locations: Float numpy array with shape (num_nuclei,).
      nuclear_charges: Float numpy array with shape (num_nuclei,).

    Returns:
      Float numpy array with shape (num_grids,).
    """
    grid_row = np.expand_dims(np.array(grids), axis=0)
    nuclei_col = np.expand_dims(np.array(locations), axis=1)
    # Shape (num_nuclei, num_grids): grid position relative to each nucleus.
    displacements = grid_row - nuclei_col
    grid_spacing = float(utils.get_dx(grids))
    per_nucleus_density = _get_exact_h_atom_density(displacements,
                                                    grid_spacing)
    # Weighted sum over nuclei.
    return np.dot(nuclear_charges, per_nucleus_density)
Example #13
0
    def test_get_hartree_energy(self, interaction_fn):
        """Compares scf.get_hartree_energy() with an explicit double loop."""
        grids = jnp.linspace(-5, 5, 11)
        density = utils.gaussian(grids=grids, center=1., sigma=1.)
        dx = utils.get_dx(grids)

        # Reference Hartree energy: accumulate every pair of grid points with
        # nested for loops.
        expected_hartree_energy = 0.
        for x_a, n_a in zip(grids, density):
            for x_b, n_b in zip(grids, density):
                pair_energy = 0.5 * n_a * n_b * interaction_fn(
                    x_a - x_b) * dx**2
                expected_hartree_energy += pair_energy

        actual_hartree_energy = scf.get_hartree_energy(
            density=density, grids=grids, interaction_fn=interaction_fn)
        self.assertAlmostEqual(float(actual_hartree_energy),
                               float(expected_hartree_energy))
Example #14
0
 def _test_state(self, state, external_potential):
     """Runs the shared sanity checks on a Kohn-Sham state.

     Args:
       state: State object exposing density, total_energy, hartree_potential,
           xc_potential, locations, nuclear_charges, external_potential,
           grids, num_electrons and gap.
       external_potential: The external potential the state is expected to
           still carry.
     """
     # Integrating the density over the grid must give the electron count.
     integrated_density = float(
         jnp.sum(state.density) * utils.get_dx(self.grids))
     self.assertAlmostEqual(integrated_density, self.num_electrons)

     # The total energy must not be nan or inf.
     self.assertTrue(jnp.isfinite(state.total_energy))

     # Both potentials are defined on every grid point.
     num_grids = len(state.grids)
     self.assertLen(state.hartree_potential, num_grids)
     self.assertLen(state.xc_potential, num_grids)

     # The inputs carried along in the state must be unchanged: locations,
     # nuclear_charges, external_potential, grids and num_electrons.
     np.testing.assert_allclose(state.locations, self.locations)
     np.testing.assert_allclose(state.nuclear_charges, self.nuclear_charges)
     np.testing.assert_allclose(external_potential,
                                state.external_potential)
     np.testing.assert_allclose(state.grids, self.grids)
     self.assertEqual(state.num_electrons, self.num_electrons)

     # The gap must be strictly positive.
     self.assertGreater(state.gap, 0)
Example #15
0
 def test_get_dx(self):
   """get_dx() on 11 evenly spaced points over [0, 1] gives 0.1."""
   uniform_grids = jnp.linspace(0, 1, 11)
   self.assertAlmostEqual(utils.get_dx(uniform_grids), 0.1)
Example #16
0
 def test_get_dx_incorrect_ndim(self):
   """get_dx() rejects a 2d grids array with a ValueError."""
   two_dim_grids = jnp.array([[-0.1], [0.], [0.1]])
   expected_message = 'grids.ndim is expected to be 1 but got 2'
   with self.assertRaisesRegex(ValueError, expected_message):
     utils.get_dx(two_dim_grids)
Example #17
0
def exponential_global_convolution(num_channels,
                                   grids,
                                   minval,
                                   maxval,
                                   downsample_factor=0,
                                   eta_init=nn.initializers.normal()):
    """Layer construction function for exponential global convolution.

    Args:
      num_channels: Integer, the number of channels.
      grids: Float numpy array with shape (num_grids,).
      minval: Float, lower bound of the exponential width.
      maxval: Float, upper bound of the exponential width.
      downsample_factor: Integer, the factor of downsampling. The grids are
          downsampled with step size 2 ** downsample_factor.
      eta_init: Initializer function in nn.initializers.

    Returns:
      (init_fn, apply_fn) pair.
    """
    stride = 2**downsample_factor
    coarse_grids = grids.astype(jnp.float32)[::stride]
    # Pairwise displacements between all (downsampled) grid points.
    grid_row = jnp.expand_dims(coarse_grids, axis=0)
    grid_col = jnp.expand_dims(coarse_grids, axis=1)
    displacements = grid_row - grid_col
    dx = utils.get_dx(coarse_grids)

    def init_fn(rng, input_shape):
        """Validates the configuration and input shape, then samples eta.

        Raises:
          ValueError: If num_channels is not positive, the input is not 3d,
              its grid axis does not match the downsampled grids, or it has
              more than one input channel.
        """
        if num_channels <= 0:
            raise ValueError(
                f'num_channels must be positive but got {num_channels}')
        if len(input_shape) != 3:
            raise ValueError(
                f'The ndim of input should be 3, but got {len(input_shape)}')
        if input_shape[1] != len(coarse_grids):
            raise ValueError(
                f'input_shape[1] should be len(grids), but got {input_shape[1]}'
            )
        if input_shape[2] != 1:
            raise ValueError(
                f'input_shape[2] should be 1, but got {input_shape[2]}')
        # One eta per output channel; the last input axis is replaced by the
        # channel axis.
        eta = eta_init(rng, shape=(num_channels, ))
        output_shape = input_shape[:-1] + (num_channels, )
        return output_shape, (eta, )

    def apply_fn(params, inputs, **kwargs):
        """Applies layer.

        Args:
          params: Layer parameters, (eta,).
          inputs: Float numpy array with shape
              (batch_size, num_grids, num_in_channels).
          **kwargs: Other key word arguments. Unused.

        Returns:
          Float numpy array with shape (batch_size, num_grids, num_channels).
        """
        del kwargs
        (eta, ) = params
        # The sigmoid keeps every width inside (minval, maxval).
        widths = minval + (maxval - minval) * nn.sigmoid(eta)
        # shape (num_grids, num_grids, num_channels)
        kernels = _exponential_function_channels(displacements, widths=widths)
        # shape (batch_size, 1, num_grids, num_channels)
        convolved = jnp.tensordot(inputs, kernels, axes=(1, 0)) * dx
        # shape (batch_size, num_grids, num_channels)
        return jnp.squeeze(convolved, axis=1)

    return init_fn, apply_fn
Example #18
0
def _get_hartree_potential(density, grids, interaction_fn):
  """Computes the Hartree potential on every grid point.

  Entry i of the result is
  sum_j density[j] * interaction_fn(grids[j] - grids[i]) * dx,
  with dx taken from utils.get_dx(grids).

  Args:
    density: Density values on the grid.
    grids: Grid coordinates passed to utils.get_dx() and used to build the
        pairwise displacements.
    interaction_fn: Function applied elementwise to the pairwise displacement
        array.

  Returns:
    Array with one potential value per grid point.
  """
  density_row = jnp.expand_dims(density, axis=0)
  # displacements[i, j] = grids[j] - grids[i]
  displacements = (
      jnp.expand_dims(grids, axis=0) - jnp.expand_dims(grids, axis=1))
  weighted_interaction = density_row * interaction_fn(displacements)
  # Summing along axis 1 integrates over the source points j.
  return jnp.sum(weighted_interaction, axis=1) * utils.get_dx(grids)