def loss(flatten_params, initial_state, target_energy):
    """Squared total-energy error after one `scf.kohn_sham_iteration` call.

    `self`, `spec` and `xc_energy_density_fn` are free variables taken from
    the enclosing scope (not visible in this block).

    Args:
      flatten_params: Flattened parameters; restored to their structured form
        with `np_utils.unflatten(spec, flatten_params)` and bound as `params`
        of `xc_energy_density_fn`.
      initial_state: State passed as `state` to `scf.kohn_sham_iteration`.
      target_energy: Reference value compared with `total_energy` of the
        returned state.

    Returns:
      `(state.total_energy - target_energy) ** 2`.
    """
    # Bind the unflattened parameters to the XC functional before the call.
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params))
    next_state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    energy_error = next_state.total_energy - target_energy
    return energy_error ** 2
def loss(flatten_params, initial_state, target_density):
    """Summed absolute density error after one `scf.kohn_sham_iteration` call.

    `self`, `spec` and `xc_energy_density_fn` are free variables taken from
    the enclosing scope (not visible in this block).

    Args:
      flatten_params: Flattened parameters; restored to their structured form
        with `np_utils.unflatten(spec, flatten_params)` and bound as `params`
        of `xc_energy_density_fn`.
      initial_state: State passed as `state` to `scf.kohn_sham_iteration`.
      target_density: Reference value compared with `density` of the returned
        state.

    Returns:
      Sum of `|state.density - target_density|` multiplied by
      `utils.get_dx(self.grids)`.
    """
    # Bind the unflattened parameters to the XC functional before the call.
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params))
    next_state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=False)
    abs_density_error = jnp.abs(next_state.density - target_density)
    # NOTE(review): get_dx presumably returns the grid spacing, which would
    # make this a discretized integral of the error -- confirm against utils.
    return jnp.sum(abs_density_error) * utils.get_dx(self.grids)
def test_kohn_sham_iteration(
        self, interaction_fn, enforce_reflection_symmetry):
    """Runs one `scf.kohn_sham_iteration` call with a fixed analytic XC term.

    Args:
      interaction_fn: Used to build the initial state and forwarded to
        `scf.kohn_sham_iteration`.
      enforce_reflection_symmetry: Forwarded unchanged to
        `scf.kohn_sham_iteration`.
    """
    # Use 3d LDA exchange functional and zero correlation functional.
    def lda_exchange_energy_density(density):
        return -0.73855 * density ** (1 / 3)

    start_state = self._create_testing_initial_state(interaction_fn)
    result_state = scf.kohn_sham_iteration(
        state=start_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=tree_util.Partial(lda_exchange_energy_density),
        interaction_fn=interaction_fn,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    # The resulting state is checked against the starting state by the
    # helper defined elsewhere on the test class.
    self._test_state(result_state, start_state)
def test_kohn_sham_iteration_neural_xc(
        self, interaction_fn, enforce_reflection_symmetry):
    """Runs one `scf.kohn_sham_iteration` call with a neural XC functional.

    The XC energy density comes from `neural_xc.local_density_approximation`
    wrapped around a small stax network (Dense(8) -> Elu -> Dense(1)), with
    parameters initialized from `random.PRNGKey(0)`.

    Args:
      interaction_fn: Used to build the initial state and forwarded to
        `scf.kohn_sham_iteration`.
      enforce_reflection_symmetry: Forwarded unchanged to
        `scf.kohn_sham_iteration`.
    """
    network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    # Fixed PRNG key keeps the initial parameters the same on every run.
    params_init = init_fn(rng=random.PRNGKey(0))

    start_state = self._create_testing_initial_state(interaction_fn)
    neural_xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
    result_state = scf.kohn_sham_iteration(
        state=start_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=neural_xc_fn,
        interaction_fn=interaction_fn,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    self._test_state(result_state, start_state)