Example #1
0
    def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
            self, density_mse_converge_tolerance, expected_converged):
        """Checks the `converged` flags of a jitted KS run with a neural XC.

        Runs three `jit_scf.kohn_sham` iterations using an LDA-style network
        functional and compares `states.converged` with `expected_converged`
        for the supplied `density_mse_converge_tolerance`. Every per-iteration
        state is then validated with `self._test_state`.
        """
        network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
        net_init, net_xc_energy_density = (
            neural_xc.local_density_approximation(network))
        initial_params = net_init(rng=random.PRNGKey(0))
        # Bind the parameters so the functional can be passed into the jitted
        # solver as a pytree-compatible callable.
        neural_functional = tree_util.Partial(
            net_xc_energy_density, params=initial_params)
        # Starting guess: a Gaussian (center 0, sigma 0.5) scaled by the
        # electron count.
        starting_density = self.num_electrons * utils.gaussian(
            grids=self.grids, center=0., sigma=0.5)

        states = jit_scf.kohn_sham(
            locations=self.locations,
            nuclear_charges=self.nuclear_charges,
            num_electrons=self.num_electrons,
            num_iterations=3,
            grids=self.grids,
            xc_energy_density_fn=neural_functional,
            interaction_fn=utils.exponential_coulomb,
            initial_density=starting_density,
            density_mse_converge_tolerance=density_mse_converge_tolerance)

        np.testing.assert_array_equal(states.converged, expected_converged)

        for per_iteration_state in scf.state_iterator(states):
            reference = self._create_testing_external_potential(
                utils.exponential_coulomb)
            self._test_state(per_iteration_state, reference)
    def test_local_density_approximation_wrong_output_shape(self):
        """An LDA network with a 3-wide output layer must raise ValueError."""
        three_output_network = stax.serial(
            stax.Dense(16), stax.Elu, stax.Dense(3))
        net_init, net_xc_energy_density = (
            neural_xc.local_density_approximation(three_output_network))
        params = net_init(rng=random.PRNGKey(0))

        expected_message = (r'The output shape of the network '
                            r'should be \(-1, 1\) but got \(11, 3\)')
        with self.assertRaisesRegex(ValueError, expected_message):
            net_xc_energy_density(self.density, params)
    def test_local_density_approximation(self):
        """Checks parameter and output shapes of an LDA network functional."""
        net_init, net_xc_energy_density = (
            neural_xc.local_density_approximation(
                stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1))))
        params = net_init(rng=random.PRNGKey(0))
        per_point_energy = net_xc_energy_density(self.density, params)

        # The first Dense kernel maps the single input feature (the density)
        # to 16 hidden units.
        first_dense_kernel = params[0][0]
        self.assertEqual(first_dense_kernel.shape, (1, 16))
        self.assertEqual(per_point_energy.shape, (11,))
Example #4
0
 def test_kohn_sham_iteration_neural_xc(self, enforce_reflection_symmetry):
     """Runs one jitted KS iteration with a randomly initialised neural XC."""
     net_init, net_xc_energy_density = neural_xc.local_density_approximation(
         stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
     neural_functional = tree_util.Partial(
         net_xc_energy_density, params=net_init(rng=random.PRNGKey(0)))
     starting_state = self._create_testing_initial_state(
         utils.exponential_coulomb)

     updated_state = jit_scf.kohn_sham_iteration(
         state=starting_state,
         num_electrons=self.num_electrons,
         xc_energy_density_fn=neural_functional,
         interaction_fn=utils.exponential_coulomb,
         enforce_reflection_symmetry=enforce_reflection_symmetry)
     self._test_state(updated_state, starting_state)
 def test_kohn_sham_neural_xc(self, interaction_fn):
     """Runs three `scf.kohn_sham` iterations with a neural XC functional.

     Every per-iteration state is validated with `self._test_state` against
     the testing external potential built from `interaction_fn`.
     """
     net_init, net_xc_energy_density = neural_xc.local_density_approximation(
         stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
     initial_params = net_init(rng=random.PRNGKey(0))
     # The result stacks every iteration, hence the plural name.
     states = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=tree_util.Partial(
             net_xc_energy_density, params=initial_params),
         interaction_fn=interaction_fn)

     for per_iteration_state in scf.state_iterator(states):
         reference = self._create_testing_external_potential(interaction_fn)
         self._test_state(per_iteration_state, reference)
Example #6
0
  def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
    """Gradient of an L1 density loss through one symmetric KS iteration.

    The XC network is a single Dense(1) layer, so the flattened parameter
    vector holds two scalars (the (1, 1) weight and the (1,) bias). The
    `jax.grad` result is pinned to reference values and cross-checked against
    a finite-difference estimate from `scipy.optimize.approx_fprime`.
    """
    # Single-layer network: weights of shape (1, 1) and bias of shape (1,).
    net_init, net_xc_energy_density = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    params = net_init(rng=random.PRNGKey(0))
    starting_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    # Target: sum of two unit-width Gaussians centred at -0.5 and +0.5.
    left_peak = utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
    right_peak = utils.gaussian(grids=self.grids, center=0.5, sigma=1.)
    target_density = left_peak + right_peak
    spec, flat_params0 = np_utils.flatten(params)

    def density_l1_loss(flat_params, initial_state, target_density):
      """Grid-integrated |density - target| after one KS iteration."""
      next_state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              net_xc_energy_density,
              params=np_utils.unflatten(spec, flat_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      abs_diff = jnp.abs(next_state.density - target_density)
      return jnp.sum(abs_diff) * utils.get_dx(self.grids)

    analytic_grad = jax.grad(density_l1_loss)(
        flat_params0,
        initial_state=starting_state,
        target_density=target_density)

    # Reference values. The second component is expected to be exactly zero
    # (presumably a constant shift in the XC potential leaves the density
    # unchanged; confirm).
    np.testing.assert_allclose(analytic_grad, [-1.34137017, 0.], atol=5e-7)

    # Finite-difference cross-check of the analytic gradient.
    fixed_loss = functools.partial(
        density_l1_loss,
        initial_state=starting_state,
        target_density=target_density)
    numeric_grad = optimize.approx_fprime(
        xk=flat_params0, f=fixed_loss, epsilon=1e-9)
    np.testing.assert_allclose(numeric_grad, analytic_grad, atol=2e-4)
Example #7
0
  def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
    """Gradient of a squared total-energy loss through one KS iteration.

    The XC network is a single Dense(1) layer. The `jax.grad` result is
    pinned to reference values and cross-checked against a finite-difference
    estimate from `scipy.optimize.approx_fprime`.
    """
    # Single-layer network: weights of shape (1, 1) and bias of shape (1,).
    net_init, net_xc_energy_density = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    params = net_init(rng=random.PRNGKey(0))
    starting_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_energy = 2.
    spec, flat_params0 = np_utils.flatten(params)

    def energy_sq_loss(flat_params, initial_state, target_energy):
      """Squared error of the total energy after one KS iteration."""
      next_state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              net_xc_energy_density,
              params=np_utils.unflatten(spec, flat_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      energy_error = next_state.total_energy - target_energy
      return energy_error ** 2

    analytic_grad = jax.grad(energy_sq_loss)(
        flat_params0,
        initial_state=starting_state,
        target_energy=target_energy)

    # Reference values for the two flattened parameters.
    np.testing.assert_allclose(analytic_grad, [-8.549952, -14.754195])

    # Finite-difference cross-check of the analytic gradient.
    fixed_loss = functools.partial(
        energy_sq_loss,
        initial_state=starting_state,
        target_energy=target_energy)
    numeric_grad = optimize.approx_fprime(
        xk=flat_params0, f=fixed_loss, epsilon=1e-9)
    np.testing.assert_allclose(numeric_grad, analytic_grad, atol=2e-3)
Example #8
0
  def test_kohn_sham_neural_xc_energy_loss_gradient(self):
    """Gradient of a squared final-energy loss through a 3-iteration KS run.

    The loss uses the total energy of the state returned by
    `scf.get_final_state`. The `jax.grad` result is pinned to reference
    values and cross-checked against `scipy.optimize.approx_fprime`.
    """
    # Single-layer network: weights of shape (1, 1) and bias of shape (1,).
    net_init, net_xc_energy_density = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    params = net_init(rng=random.PRNGKey(0))
    target_energy = 2.
    spec, flat_params0 = np_utils.flatten(params)

    def final_energy_sq_loss(flat_params, target_energy):
      """Squared error of the final state's total energy."""
      states = scf.kohn_sham(
          locations=self.locations,
          nuclear_charges=self.nuclear_charges,
          num_electrons=self.num_electrons,
          num_iterations=3,
          grids=self.grids,
          xc_energy_density_fn=tree_util.Partial(
              net_xc_energy_density,
              params=np_utils.unflatten(spec, flat_params)),
          interaction_fn=utils.exponential_coulomb)
      last_state = scf.get_final_state(states)
      return (last_state.total_energy - target_energy) ** 2

    analytic_grad = jax.grad(final_energy_sq_loss)(
        flat_params0, target_energy=target_energy)

    # Reference values for the two flattened parameters.
    np.testing.assert_allclose(
        analytic_grad, [-3.908153, -5.448675], atol=1e-6)

    # Finite-difference cross-check of the analytic gradient.
    fixed_loss = functools.partial(
        final_energy_sq_loss, target_energy=target_energy)
    numeric_grad = optimize.approx_fprime(
        xk=flat_params0, f=fixed_loss, epsilon=1e-8)
    np.testing.assert_allclose(numeric_grad, analytic_grad, atol=5e-3)