def loss(flatten_params, target_energy):
    """Return the squared total-energy error of a 3-iteration Kohn-Sham run.

    ``self``, ``spec`` and ``xc_energy_density_fn`` are not parameters of
    this function; they are resolved from the enclosing scope.

    Args:
        flatten_params: Value handed to ``np_utils.unflatten`` together with
            ``spec``; the result is bound as the ``params`` keyword of
            ``xc_energy_density_fn``.
        target_energy: Reference value subtracted from the final state's
            ``total_energy``.

    Returns:
        ``(final_state.total_energy - target_energy) ** 2``.
    """
    # Rebuild the structured parameters and bind them to the XC functional
    # before handing it to the solver.
    xc_params = np_utils.unflatten(spec, flatten_params)
    bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=xc_params)

    ks_states = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
    )

    last_state = scf.get_final_state(ks_states)
    energy_error = last_state.total_energy - target_energy
    return energy_error ** 2
def loss(flatten_params, target_density):
    """Return the grid-weighted absolute density error of a Kohn-Sham run.

    ``self``, ``spec`` and ``xc_energy_density_fn`` are not parameters of
    this function; they are resolved from the enclosing scope.

    Args:
        flatten_params: Value handed to ``np_utils.unflatten`` together with
            ``spec``; the result is bound as the ``params`` keyword of
            ``xc_energy_density_fn``.
        target_density: Reference value subtracted from the final state's
            ``density``.

    Returns:
        ``sum(|final_state.density - target_density|)`` multiplied by
        ``utils.get_dx(self.grids)``.
    """
    # Rebuild the structured parameters and bind them to the XC functional
    # before handing it to the solver.
    xc_params = np_utils.unflatten(spec, flatten_params)
    bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=xc_params)

    ks_states = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=bound_xc_fn,
        interaction_fn=utils.exponential_coulomb,
        # NOTE(review): a negative tolerance presumably prevents early
        # convergence so all 3 iterations run -- confirm in scf.kohn_sham.
        density_mse_converge_tolerance=-1,
    )

    last_state = scf.get_final_state(ks_states)
    abs_density_error = jnp.abs(last_state.density - target_density)
    return jnp.sum(abs_density_error) * utils.get_dx(self.grids)