def loss(flatten_params, target_energy):
    """Squared difference between the final total energy and a target.

    Runs `scf.kohn_sham` with `num_iterations=3` using the exponential
    Coulomb interaction, takes the state returned by
    `scf.get_final_state`, and compares its `total_energy` with
    `target_energy`.

    Args:
      flatten_params: flattened parameters; converted back with
        `np_utils.unflatten(spec, flatten_params)` and bound as `params`
        of `xc_energy_density_fn`.
      target_energy: reference value subtracted from the total energy.

    Returns:
      `(total_energy - target_energy)**2` for the final state.
    """
    # `self`, `spec` and `xc_energy_density_fn` come from the enclosing scope.
    bound_xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params),
    )
    ks_states = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
    )
    energy_error = scf.get_final_state(ks_states).total_energy - target_energy
    return energy_error**2
def test_kohn_sham_convergence(
        self, density_mse_converge_tolerance, expected_converged):
    """Checks the `converged` field returned by `scf.kohn_sham`.

    Args:
      density_mse_converge_tolerance: forwarded unchanged to
        `scf.kohn_sham` under the same keyword.
      expected_converged: value that `state.converged` is compared against
        with `np.testing.assert_allclose`.
    """

    def lda_exchange(density):
        # Use 3d LDA exchange functional and zero correlation functional.
        return -0.73855 * density**(1 / 3)

    solver_kwargs = dict(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=tree_util.Partial(lda_exchange),
        interaction_fn=utils.exponential_coulomb,
        density_mse_converge_tolerance=density_mse_converge_tolerance,
    )
    ks_states = scf.kohn_sham(**solver_kwargs)

    np.testing.assert_allclose(ks_states.converged, expected_converged)
def loss(flatten_params, target_density):
    """Absolute density error of the final state, scaled by the grid spacing.

    Runs `scf.kohn_sham` with `num_iterations=3` using the exponential
    Coulomb interaction, takes the state returned by
    `scf.get_final_state`, and sums `|density - target_density|` before
    multiplying by `utils.get_dx(self.grids)`.

    Args:
      flatten_params: flattened parameters; converted back with
        `np_utils.unflatten(spec, flatten_params)` and bound as `params`
        of `xc_energy_density_fn`.
      target_density: reference density subtracted from the final density.

    Returns:
      `sum(|density - target_density|) * dx` for the final state.
    """
    # `self`, `spec` and `xc_energy_density_fn` come from the enclosing scope.
    bound_xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params),
    )
    ks_states = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
        # NOTE(review): a negative tolerance presumably keeps the solver from
        # ever flagging convergence -- confirm against scf.kohn_sham.
        density_mse_converge_tolerance=-1,
    )
    density_diff = scf.get_final_state(ks_states).density - target_density
    summed_abs_error = jnp.sum(jnp.abs(density_diff))
    return summed_abs_error * utils.get_dx(self.grids)
def test_kohn_sham(self, interaction_fn):
    """Runs `scf.kohn_sham` and checks every state it yields.

    Each state produced by `scf.state_iterator` is passed to
    `self._test_state` together with the external potential built by
    `self._create_testing_external_potential(interaction_fn)`.

    Args:
      interaction_fn: interaction function handed to both `scf.kohn_sham`
        and `self._create_testing_external_potential`.
    """

    def lda_exchange(density):
        # Use 3d LDA exchange functional and zero correlation functional.
        return -0.73855 * density**(1 / 3)

    solver_kwargs = dict(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=tree_util.Partial(lda_exchange),
        interaction_fn=interaction_fn,
    )
    ks_states = scf.kohn_sham(**solver_kwargs)

    for step_state in scf.state_iterator(ks_states):
        # The reference potential is rebuilt for every state, as before.
        reference_potential = self._create_testing_external_potential(
            interaction_fn)
        self._test_state(step_state, reference_potential)
def test_kohn_sham_neural_xc(self, interaction_fn):
    """Runs `scf.kohn_sham` with a neural XC functional and checks each state.

    The functional comes from `neural_xc.local_density_approximation`
    applied to a Dense(8) -> Elu -> Dense(1) `stax` network, with
    parameters initialised from `random.PRNGKey(0)`. Each state produced by
    `scf.state_iterator` is passed to `self._test_state` together with the
    external potential built by
    `self._create_testing_external_potential(interaction_fn)`.

    Args:
      interaction_fn: interaction function handed to both `scf.kohn_sham`
        and `self._create_testing_external_potential`.
    """
    network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    params_init = init_fn(rng=random.PRNGKey(0))

    bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
    ks_states = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=interaction_fn,
    )

    for step_state in scf.state_iterator(ks_states):
        # The reference potential is rebuilt for every state, as before.
        reference_potential = self._create_testing_external_potential(
            interaction_fn)
        self._test_state(step_state, reference_potential)