def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
        self, density_mse_converge_tolerance, expected_converged):
    """Checks the convergence flags reported by the jitted Kohn-Sham solver.

    A small neural XC functional (Dense(8) -> Elu -> Dense(1), wrapped by
    ``neural_xc.local_density_approximation``) is initialised from
    ``random.PRNGKey(0)``. It is then run for three iterations of
    ``jit_scf.kohn_sham`` with ``utils.exponential_coulomb`` as the
    interaction. The returned ``converged`` flags must equal
    ``expected_converged`` exactly, and every per-iteration state yielded by
    ``scf.state_iterator`` must pass ``self._test_state``.

    Args:
      density_mse_converge_tolerance: Forwarded unchanged to
        ``jit_scf.kohn_sham``.
      expected_converged: Expected value of the solver's ``converged``
        field, compared with ``np.testing.assert_array_equal``.
    """
    # NOTE(review): both arguments are presumably supplied by a
    # parameterization decorator outside this view -- confirm.
    network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    network_params = init_fn(rng=random.PRNGKey(0))

    # Fix the network parameters with tree_util.Partial so the functional
    # can be handed to the solver as ``xc_energy_density_fn``.
    xc_fn = tree_util.Partial(xc_energy_density_fn, params=network_params)
    # Initial density: electron count times utils.gaussian(center=0.,
    # sigma=0.5) evaluated on the test grid.
    density_guess = self.num_electrons * utils.gaussian(
        grids=self.grids, center=0., sigma=0.5)

    scf_states = jit_scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        initial_density=density_guess,
        density_mse_converge_tolerance=density_mse_converge_tolerance)

    np.testing.assert_array_equal(scf_states.converged, expected_converged)

    for state_at_step in scf.state_iterator(scf_states):
        # Rebuilt on every iteration, exactly as the per-call original did.
        expected_potential = self._create_testing_external_potential(
            utils.exponential_coulomb)
        self._test_state(state_at_step, expected_potential)
def test_kohn_sham(self, interaction_fn):
    """Runs three iterations of ``scf.kohn_sham`` with a fixed LDA-type XC.

    The exchange-correlation energy density is ``-0.73855 * density**(1/3)``,
    i.e. 3D LDA exchange with the correlation part set to zero. Every state
    yielded by ``scf.state_iterator`` is checked with ``self._test_state``
    against the testing external potential built for ``interaction_fn``.

    Args:
      interaction_fn: Interaction function forwarded to ``scf.kohn_sham``.
        Presumably supplied by a parameterization decorator outside this
        view -- confirm.
    """
    def lda_exchange_density(density):
        # 3D LDA exchange functional; correlation is taken to be zero.
        return -0.73855 * density ** (1 / 3)

    solution = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=tree_util.Partial(lda_exchange_density),
        interaction_fn=interaction_fn)

    for step_state in scf.state_iterator(solution):
        reference_potential = self._create_testing_external_potential(
            interaction_fn)
        self._test_state(step_state, reference_potential)
def test_kohn_sham_neural_xc(self, interaction_fn):
    """Runs ``scf.kohn_sham`` with a small, freshly initialised neural XC.

    The functional is ``neural_xc.local_density_approximation`` applied to a
    Dense(8) -> Elu -> Dense(1) network, with parameters initialised from
    ``random.PRNGKey(0)``. Three solver iterations are run, and every state
    yielded by ``scf.state_iterator`` is checked with ``self._test_state``
    against the testing external potential for ``interaction_fn``.

    Args:
      interaction_fn: Interaction function forwarded to ``scf.kohn_sham``.
        Presumably supplied by a parameterization decorator outside this
        view -- confirm.
    """
    network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    network_params = init_fn(rng=random.PRNGKey(0))

    # Bind the initialised parameters so only the functional is passed on.
    bound_xc_fn = tree_util.Partial(
        xc_energy_density_fn, params=network_params)

    solution = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=interaction_fn)

    for step_state in scf.state_iterator(solution):
        reference_potential = self._create_testing_external_potential(
            interaction_fn)
        self._test_state(step_state, reference_potential)