def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
        self, density_mse_converge_tolerance, expected_converged):
    """Check the `converged` flags returned by the jitted Kohn-Sham solver.

    A small neural LDA functional (Dense(8) -> Elu -> Dense(1)) is used as
    the XC energy density and `jit_scf.kohn_sham` is run for 3 iterations.

    Args:
      density_mse_converge_tolerance: forwarded unchanged to
        `jit_scf.kohn_sham` under the keyword of the same name.
      expected_converged: value that `states.converged` is asserted to
        equal element-wise.

    NOTE(review): both arguments are presumably supplied by a parameterized
    decorator that is not visible in this chunk -- confirm.
    """
    network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    network_params = init_fn(rng=random.PRNGKey(0))

    # Bind the freshly initialized parameters so the solver only sees a
    # function of the density.
    neural_xc_fn = tree_util.Partial(
        xc_energy_density_fn, params=network_params)
    # Starting density: a Gaussian at the origin scaled by the electron count.
    density_guess = self.num_electrons * utils.gaussian(
        grids=self.grids, center=0., sigma=0.5)

    states = jit_scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=neural_xc_fn,
        interaction_fn=utils.exponential_coulomb,
        initial_density=density_guess,
        density_mse_converge_tolerance=density_mse_converge_tolerance)

    np.testing.assert_array_equal(states.converged, expected_converged)

    # Every per-iteration state must also pass the shared state checks.
    for iteration_state in scf.state_iterator(states):
        reference = self._create_testing_external_potential(
            utils.exponential_coulomb)
        self._test_state(iteration_state, reference)
def test_local_density_approximation_wrong_output_shape(self):
    """A network whose last layer has 3 outputs must raise ValueError.

    The network ends in Dense(3) rather than Dense(1); evaluating the XC
    energy density on `self.density` is expected to fail with the message
    matched below.
    """
    three_output_network = stax.serial(
        stax.Dense(16), stax.Elu, stax.Dense(3))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        three_output_network)
    network_params = init_fn(rng=random.PRNGKey(0))

    expected_message = (
        r'The output shape of the network '
        r'should be \(-1, 1\) but got \(11, 3\)')
    with self.assertRaisesRegex(ValueError, expected_message):
        xc_energy_density_fn(self.density, network_params)
def test_local_density_approximation(self):
    """Check parameter and output shapes of the neural LDA functional.

    Builds a Dense(16) -> Elu -> Dense(1) network, evaluates it on
    `self.density`, and asserts the first-layer weight shape and the shape
    of the returned XC energy density.
    """
    network = stax.serial(stax.Dense(16), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    network_params = init_fn(rng=random.PRNGKey(0))

    energy_density = xc_energy_density_fn(self.density, network_params)

    # The first layer of the network takes 1 feature (density).
    first_layer_weights = network_params[0][0]
    self.assertEqual(first_layer_weights.shape, (1, 16))
    self.assertEqual(energy_density.shape, (11, ))
def test_kohn_sham_iteration_neural_xc(self, interaction_fn,
                                       enforce_reflection_symmetry):
    """Run one Kohn-Sham iteration with a neural XC functional.

    Args:
      interaction_fn: used to build the initial state and forwarded to
        `scf.kohn_sham_iteration`.
      enforce_reflection_symmetry: forwarded unchanged to
        `scf.kohn_sham_iteration`.

    NOTE(review): both arguments are presumably supplied by a parameterized
    decorator that is not visible in this chunk -- confirm.
    """
    network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    network_params = init_fn(rng=random.PRNGKey(0))
    starting_state = self._create_testing_initial_state(interaction_fn)

    # Bind the parameters so the iteration sees a function of the density.
    neural_xc_fn = tree_util.Partial(
        xc_energy_density_fn, params=network_params)
    updated_state = scf.kohn_sham_iteration(
        state=starting_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=neural_xc_fn,
        interaction_fn=interaction_fn,
        enforce_reflection_symmetry=enforce_reflection_symmetry)

    self._test_state(updated_state, starting_state)
def test_kohn_sham_neural_xc(self, interaction_fn):
    """Run 3 Kohn-Sham iterations with a neural XC functional.

    Each state yielded by `scf.state_iterator` is passed through the shared
    `_test_state` check.

    Args:
      interaction_fn: forwarded to `scf.kohn_sham` and used to build the
        reference passed to `_test_state`.

    NOTE(review): `interaction_fn` is presumably supplied by a parameterized
    decorator that is not visible in this chunk -- confirm.
    """
    network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        network)
    network_params = init_fn(rng=random.PRNGKey(0))

    # Bind the parameters so the solver sees a function of the density.
    neural_xc_fn = tree_util.Partial(
        xc_energy_density_fn, params=network_params)
    states = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=neural_xc_fn,
        interaction_fn=interaction_fn)

    for iteration_state in scf.state_iterator(states):
        reference = self._create_testing_external_potential(interaction_fn)
        self._test_state(iteration_state, reference)
def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
    """Check the density-loss gradient through one symmetric KS iteration.

    The loss is the sum of absolute density differences against a target
    density, multiplied by the grid spacing. Its gradient with respect to
    the flattened network parameters is compared with hard-coded values
    and with a finite-difference estimate.
    """
    # The network only has one layer.
    # The initial params contain weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)

    # Target density: two unit-width Gaussians centered at -0.5 and +0.5.
    left_lobe = utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
    right_lobe = utils.gaussian(grids=self.grids, center=0.5, sigma=1.)
    target_density = left_lobe + right_lobe

    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_density):
        # Rebuild the parameter tree from the flat vector being
        # differentiated.
        neural_xc_fn = tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params))
        state = scf.kohn_sham_iteration(
            state=initial_state,
            num_electrons=self.num_electrons,
            xc_energy_density_fn=neural_xc_fn,
            interaction_fn=utils.exponential_coulomb,
            enforce_reflection_symmetry=True)
        absolute_error = jnp.abs(state.density - target_density)
        return jnp.sum(absolute_error) * utils.get_dx(self.grids)

    analytic_grad = jax.grad(loss)(
        flatten_init_params,
        initial_state=initial_state,
        target_density=target_density)

    # Check gradient values.
    np.testing.assert_allclose(analytic_grad, [0.2013181, 0.], atol=1e-7)

    # Check whether the gradient values match the numerical gradient.
    numerical_grad = optimize.approx_fprime(
        xk=flatten_init_params,
        f=functools.partial(
            loss,
            initial_state=initial_state,
            target_density=target_density),
        epsilon=1e-9)
    np.testing.assert_allclose(numerical_grad, analytic_grad, atol=1e-4)
def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
    """Check the energy-loss gradient through one Kohn-Sham iteration.

    The loss is the squared difference between the total energy after one
    iteration and a target energy of 2. Its gradient with respect to the
    flattened network parameters is compared with hard-coded values and
    with a finite-difference estimate.
    """
    # The network only has one layer.
    # The initial params contain weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_energy = 2.
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_energy):
        # Rebuild the parameter tree from the flat vector being
        # differentiated.
        neural_xc_fn = tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params))
        state = scf.kohn_sham_iteration(
            state=initial_state,
            num_electrons=self.num_electrons,
            xc_energy_density_fn=neural_xc_fn,
            interaction_fn=utils.exponential_coulomb,
            enforce_reflection_symmetry=True)
        energy_error = state.total_energy - target_energy
        return energy_error**2

    analytic_grad = jax.grad(loss)(
        flatten_init_params,
        initial_state=initial_state,
        target_energy=target_energy)

    # Check gradient values.
    np.testing.assert_allclose(analytic_grad, [-1.40994668, -2.58881225])

    # Check whether the gradient values match the numerical gradient.
    numerical_grad = optimize.approx_fprime(
        xk=flatten_init_params,
        f=functools.partial(
            loss,
            initial_state=initial_state,
            target_energy=target_energy),
        epsilon=1e-9)
    np.testing.assert_allclose(numerical_grad, analytic_grad, atol=3e-4)
def test_kohn_sham_neural_xc_energy_loss_gradient(self):
    """Check the energy-loss gradient through a 3-iteration Kohn-Sham run.

    The loss is the squared difference between the total energy of the
    final state and a target energy of 2. Its gradient with respect to the
    flattened network parameters is compared with hard-coded values and
    with a finite-difference estimate.
    """
    # The network only has one layer.
    # The initial params contain weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    target_energy = 2.
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, target_energy):
        # Rebuild the parameter tree from the flat vector being
        # differentiated.
        neural_xc_fn = tree_util.Partial(
            xc_energy_density_fn,
            params=np_utils.unflatten(spec, flatten_params))
        states = scf.kohn_sham(
            locations=self.locations,
            nuclear_charges=self.nuclear_charges,
            num_electrons=self.num_electrons,
            num_iterations=3,
            grids=self.grids,
            xc_energy_density_fn=neural_xc_fn,
            interaction_fn=utils.exponential_coulomb)
        last_state = scf.get_final_state(states)
        energy_error = last_state.total_energy - target_energy
        return energy_error**2

    analytic_grad = jax.grad(loss)(
        flatten_init_params, target_energy=target_energy)

    # Check gradient values.
    np.testing.assert_allclose(
        analytic_grad, [-3.908153, -5.448675], atol=1e-6)

    # Check whether the gradient values match the numerical gradient.
    numerical_grad = optimize.approx_fprime(
        xk=flatten_init_params,
        f=functools.partial(loss, target_energy=target_energy),
        epsilon=1e-8)
    np.testing.assert_allclose(numerical_grad, analytic_grad, atol=5e-3)