def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
            self, density_mse_converge_tolerance, expected_converged):
        """Checks the `converged` flags reported by `jit_scf.kohn_sham`.

        A small LDA network (Dense(8) -> Elu -> Dense(1)) supplies the XC
        energy density. After three Kohn-Sham iterations the returned
        `converged` field must equal `expected_converged` for the given
        tolerance, and every state yielded by `scf.state_iterator` must pass
        `self._test_state`.

        Args:
          density_mse_converge_tolerance: Forwarded unchanged to
            `jit_scf.kohn_sham`. Presumably supplied by a parameterized
            decorator outside this view.
          expected_converged: Expected value of `states.converged`.
        """
        xc_network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
        net_init, xc_fn = neural_xc.local_density_approximation(xc_network)
        net_params = net_init(rng=random.PRNGKey(0))
        bound_xc_fn = tree_util.Partial(xc_fn, params=net_params)
        # Gaussian guess scaled by the electron count.
        start_density = self.num_electrons * utils.gaussian(
            grids=self.grids, center=0., sigma=0.5)

        states = jit_scf.kohn_sham(
            locations=self.locations,
            nuclear_charges=self.nuclear_charges,
            num_electrons=self.num_electrons,
            num_iterations=3,
            grids=self.grids,
            xc_energy_density_fn=bound_xc_fn,
            interaction_fn=utils.exponential_coulomb,
            initial_density=start_density,
            density_mse_converge_tolerance=density_mse_converge_tolerance)

        np.testing.assert_array_equal(states.converged, expected_converged)

        for per_step_state in scf.state_iterator(states):
            reference = self._create_testing_external_potential(
                utils.exponential_coulomb)
            self._test_state(per_step_state, reference)
    def test_local_density_approximation_wrong_output_shape(self):
        """A network whose last layer is Dense(3) must be rejected.

        Calling the XC function returned by `local_density_approximation` on
        `self.density` must raise a ValueError whose message matches the
        regex below.
        """
        bad_network = stax.serial(stax.Dense(16), stax.Elu, stax.Dense(3))
        net_init, xc_fn = neural_xc.local_density_approximation(bad_network)
        net_params = net_init(rng=random.PRNGKey(0))
        expected_message = (r'The output shape of the network '
                            r'should be \(-1, 1\) but got \(11, 3\)')

        with self.assertRaisesRegex(ValueError, expected_message):
            xc_fn(self.density, net_params)
    def test_local_density_approximation(self):
        """Checks parameter and output shapes of an LDA network.

        The first Dense layer's weight matrix must be (1, 16). The XC energy
        density computed from `self.density` must have shape (11,),
        presumably the number of points in the density fixture (verify
        against setUp).
        """
        xc_network = stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1))
        net_init, xc_fn = neural_xc.local_density_approximation(xc_network)
        net_params = net_init(rng=random.PRNGKey(0))
        output = xc_fn(self.density, net_params)

        # The first layer of the network takes 1 feature (density), so its
        # weight matrix has a single input row.
        first_layer_weights = net_params[0][0]
        self.assertEqual(first_layer_weights.shape, (1, 16))
        self.assertEqual(output.shape, (11, ))
Example #4
0
 def test_kohn_sham_iteration_neural_xc(
         self, interaction_fn, enforce_reflection_symmetry):
     """Runs one `scf.kohn_sham_iteration` with a small LDA network as XC.

     The XC network is Dense(8) -> Elu -> Dense(1). The resulting state is
     validated against the starting state via `self._test_state`.
     """
     xc_network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
     net_init, xc_fn = neural_xc.local_density_approximation(xc_network)
     net_params = net_init(rng=random.PRNGKey(0))
     start = self._create_testing_initial_state(interaction_fn)
     result = scf.kohn_sham_iteration(
         state=start,
         num_electrons=self.num_electrons,
         xc_energy_density_fn=tree_util.Partial(xc_fn, params=net_params),
         interaction_fn=interaction_fn,
         enforce_reflection_symmetry=enforce_reflection_symmetry)
     self._test_state(result, start)
Example #5
0
 def test_kohn_sham_neural_xc(self, interaction_fn):
     """Runs three `scf.kohn_sham` iterations with a small LDA network as XC.

     Every state yielded by `scf.state_iterator` is validated via
     `self._test_state` against the value returned by
     `self._create_testing_external_potential(interaction_fn)`.
     """
     net_init, xc_fn = neural_xc.local_density_approximation(
         stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1)))
     net_params = net_init(rng=random.PRNGKey(0))
     bound_xc_fn = tree_util.Partial(xc_fn, params=net_params)
     states = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=bound_xc_fn,
         interaction_fn=interaction_fn)
     for per_step_state in scf.state_iterator(states):
         reference = self._create_testing_external_potential(interaction_fn)
         self._test_state(per_step_state, reference)
Example #6
0
    def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(
            self):
        """Checks jax.grad of a density loss through one Kohn-Sham iteration.

        The XC network is a single Dense(1) layer, whose flattened params
        hold two numbers (weight and bias).

        The loss is the sum of |density - target| times the grid spacing,
        i.e. a grid-weighted L1 distance. The density comes from one
        `scf.kohn_sham_iteration` with reflection symmetry enforced. The
        target is the sum of two unit-width Gaussians centred at -0.5 and 0.5.

        The gradient is compared both with hard-coded values and with a
        finite-difference estimate from `optimize.approx_fprime`.
        """
        # The network only has one layer.
        # The initial params contains weights with shape (1, 1) and bias (1,).
        net_init, xc_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(1)))
        net_params = net_init(rng=random.PRNGKey(0))
        start = self._create_testing_initial_state(utils.exponential_coulomb)
        left_peak = utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
        right_peak = utils.gaussian(grids=self.grids, center=0.5, sigma=1.)
        target = left_peak + right_peak
        spec, flat_params = np_utils.flatten(net_params)

        def loss(flatten_params, initial_state, target_density):
            # jax.grad differentiates with respect to this first argument.
            xc_with_params = tree_util.Partial(
                xc_fn, params=np_utils.unflatten(spec, flatten_params))
            next_state = scf.kohn_sham_iteration(
                state=initial_state,
                num_electrons=self.num_electrons,
                xc_energy_density_fn=xc_with_params,
                interaction_fn=utils.exponential_coulomb,
                enforce_reflection_symmetry=True)
            abs_diff = jnp.abs(next_state.density - target_density)
            return jnp.sum(abs_diff) * utils.get_dx(self.grids)

        analytic_grad = jax.grad(loss)(
            flat_params, initial_state=start, target_density=target)

        # Check gradient values.
        np.testing.assert_allclose(analytic_grad, [0.2013181, 0.], atol=1e-7)

        # The analytic gradient must agree with a finite-difference estimate.
        numeric_grad = optimize.approx_fprime(
            xk=flat_params,
            f=functools.partial(
                loss, initial_state=start, target_density=target),
            epsilon=1e-9)
        np.testing.assert_allclose(numeric_grad, analytic_grad, atol=1e-4)
Example #7
0
    def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
        """Checks jax.grad of an energy loss through one Kohn-Sham iteration.

        The loss is (total_energy - 2)**2. The total energy is taken after one
        `scf.kohn_sham_iteration` with reflection symmetry enforced. Its
        gradient with respect to the flattened params of a single Dense(1)
        XC network is compared with hard-coded values and with a
        finite-difference estimate from `optimize.approx_fprime`.
        """
        # The network only has one layer.
        # The initial params contains weights with shape (1, 1) and bias (1,).
        net_init, xc_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(1)))
        net_params = net_init(rng=random.PRNGKey(0))
        start = self._create_testing_initial_state(utils.exponential_coulomb)
        energy_target = 2.
        spec, flat_params = np_utils.flatten(net_params)

        def loss(flatten_params, initial_state, target_energy):
            # jax.grad differentiates with respect to this first argument.
            xc_with_params = tree_util.Partial(
                xc_fn, params=np_utils.unflatten(spec, flatten_params))
            next_state = scf.kohn_sham_iteration(
                state=initial_state,
                num_electrons=self.num_electrons,
                xc_energy_density_fn=xc_with_params,
                interaction_fn=utils.exponential_coulomb,
                enforce_reflection_symmetry=True)
            energy_error = next_state.total_energy - target_energy
            return energy_error**2

        analytic_grad = jax.grad(loss)(
            flat_params, initial_state=start, target_energy=energy_target)

        # Check gradient values.
        # NOTE(review): only the default rtol (1e-7) applies here, whereas the
        # sibling gradient tests pass an explicit atol. Confirm this tight
        # tolerance is intended.
        np.testing.assert_allclose(analytic_grad, [-1.40994668, -2.58881225])

        # The analytic gradient must agree with a finite-difference estimate.
        numeric_grad = optimize.approx_fprime(
            xk=flat_params,
            f=functools.partial(
                loss, initial_state=start, target_energy=energy_target),
            epsilon=1e-9)
        np.testing.assert_allclose(numeric_grad, analytic_grad, atol=3e-4)
Example #8
0
    def test_kohn_sham_neural_xc_energy_loss_gradient(self):
        """Checks jax.grad of an energy loss through a full `scf.kohn_sham` run.

        The run uses three Kohn-Sham iterations with a single Dense(1) XC
        network. The final state is taken via `scf.get_final_state`, and the
        loss is (total_energy - 2)**2.

        The gradient with respect to the flattened params is compared with
        hard-coded values and with a finite-difference estimate from
        `optimize.approx_fprime`.
        """
        # The network only has one layer.
        # The initial params contains weights with shape (1, 1) and bias (1,).
        net_init, xc_fn = neural_xc.local_density_approximation(
            stax.serial(stax.Dense(1)))
        net_params = net_init(rng=random.PRNGKey(0))
        energy_target = 2.
        spec, flat_params = np_utils.flatten(net_params)

        def loss(flatten_params, target_energy):
            # jax.grad differentiates with respect to this first argument.
            xc_with_params = tree_util.Partial(
                xc_fn, params=np_utils.unflatten(spec, flatten_params))
            states = scf.kohn_sham(
                locations=self.locations,
                nuclear_charges=self.nuclear_charges,
                num_electrons=self.num_electrons,
                num_iterations=3,
                grids=self.grids,
                xc_energy_density_fn=xc_with_params,
                interaction_fn=utils.exponential_coulomb)
            last_state = scf.get_final_state(states)
            energy_error = last_state.total_energy - target_energy
            return energy_error**2

        analytic_grad = jax.grad(loss)(
            flat_params, target_energy=energy_target)

        # Check gradient values.
        np.testing.assert_allclose(
            analytic_grad, [-3.908153, -5.448675], atol=1e-6)

        # The analytic gradient must agree with a finite-difference estimate.
        numeric_grad = optimize.approx_fprime(
            xk=flat_params,
            f=functools.partial(loss, target_energy=energy_target),
            epsilon=1e-8)
        np.testing.assert_allclose(numeric_grad, analytic_grad, atol=5e-3)