def test_flatten(self):
  """Checks that flatten/unflatten round-trip a nested pytree of arrays."""
  pytree = [
      (jnp.array([1, 2, 3]), jnp.array([4, 5])),
      jnp.array([99]),
  ]
  (tree, shapes), vec = np_utils.flatten(pytree)

  # flatten is expected to return one 1d numpy array holding every leaf in
  # order, plus the shape of each leaf.
  self.assertIsInstance(vec, np.ndarray)
  np.testing.assert_allclose(vec, [1., 2., 3., 4., 5., 99.])
  self.assertEqual(shapes, [(3,), (2,), (1,)])

  # unflatten should convert the 1d array back to the original pytree layout,
  # with every leaf as a numpy array.
  params = np_utils.unflatten((tree, shapes), vec)
  leaf_checks = (
      (params[0][0], [1., 2., 3.]),
      (params[0][1], [4., 5.]),
      (params[1], [99.]),
  )
  for leaf, expected_values in leaf_checks:
    self.assertIsInstance(leaf, np.ndarray)
    np.testing.assert_allclose(leaf, expected_values)
def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
  """Checks the density-loss gradient of one reflection-symmetric KS step."""
  # The network is a single Dense(1) layer, so the initial params contain a
  # weight of shape (1, 1) and a bias of shape (1,).
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  # Target density: sum of two gaussians (sigma=1.) centered at -0.5 and 0.5.
  left_peak = utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
  right_peak = utils.gaussian(grids=self.grids, center=0.5, sigma=1.)
  target_density = left_peak + right_peak
  tree_spec, flat_init_params = np_utils.flatten(init_params)

  def density_loss(flat_params, initial_state, target_density):
    # Rebuild the params pytree from the flat vector so the loss can be
    # differentiated with respect to a single 1d array.
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(tree_spec, flat_params))
    state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    # Sum of |density - target| scaled by the grid's dx.
    abs_diff = jnp.abs(state.density - target_density)
    return jnp.sum(abs_diff) * utils.get_dx(self.grids)

  # Bind the non-differentiated arguments once; jax.grad still differentiates
  # with respect to the first positional argument only.
  bound_loss = functools.partial(
      density_loss, initial_state=initial_state, target_density=target_density)

  params_grad = jax.grad(bound_loss)(flat_init_params)

  # Check gradient values.
  np.testing.assert_allclose(params_grad, [-1.34137017, 0.], atol=5e-7)

  # The autodiff gradient should agree with a finite-difference estimate.
  numerical_grad = optimize.approx_fprime(
      xk=flat_init_params, f=bound_loss, epsilon=1e-9)
  np.testing.assert_allclose(numerical_grad, params_grad, atol=2e-4)
def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
  """Checks the energy-loss gradient of one reflection-symmetric KS step."""
  # The network is a single Dense(1) layer, so the initial params contain a
  # weight of shape (1, 1) and a bias of shape (1,).
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  initial_state = self._create_testing_initial_state(
      utils.exponential_coulomb)
  target_energy = 2.
  tree_spec, flat_init_params = np_utils.flatten(init_params)

  def energy_loss(flat_params, initial_state, target_energy):
    # Rebuild the params pytree from the flat vector so the loss can be
    # differentiated with respect to a single 1d array.
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(tree_spec, flat_params))
    state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    # Squared error between the total energy and the target.
    energy_error = state.total_energy - target_energy
    return energy_error ** 2

  # Bind the non-differentiated arguments once; jax.grad still differentiates
  # with respect to the first positional argument only.
  bound_loss = functools.partial(
      energy_loss, initial_state=initial_state, target_energy=target_energy)

  params_grad = jax.grad(bound_loss)(flat_init_params)

  # Check gradient values.
  # NOTE(review): unlike the sibling tests this relies on assert_allclose's
  # default tolerance (no atol); confirm it is stable across backends.
  np.testing.assert_allclose(params_grad, [-8.549952, -14.754195])

  # The autodiff gradient should agree with a finite-difference estimate.
  numerical_grad = optimize.approx_fprime(
      xk=flat_init_params, f=bound_loss, epsilon=1e-9)
  np.testing.assert_allclose(numerical_grad, params_grad, atol=2e-3)
def test_kohn_sham_neural_xc_energy_loss_gradient(self):
  """Checks the energy-loss gradient through scf.kohn_sham (3 iterations)."""
  # The network is a single Dense(1) layer, so the initial params contain a
  # weight of shape (1, 1) and a bias of shape (1,).
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      stax.serial(stax.Dense(1)))
  init_params = init_fn(rng=random.PRNGKey(0))
  target_energy = 2.
  tree_spec, flat_init_params = np_utils.flatten(init_params)

  def energy_loss(flat_params, target_energy):
    # Rebuild the params pytree from the flat vector so the loss can be
    # differentiated with respect to a single 1d array.
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(tree_spec, flat_params))
    ks_result = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb)
    final_state = scf.get_final_state(ks_result)
    # Squared error between the final total energy and the target.
    energy_error = final_state.total_energy - target_energy
    return energy_error ** 2

  # Bind the non-differentiated argument once; jax.grad still differentiates
  # with respect to the first positional argument only.
  bound_loss = functools.partial(energy_loss, target_energy=target_energy)

  params_grad = jax.grad(bound_loss)(flat_init_params)

  # Check gradient values.
  np.testing.assert_allclose(params_grad, [-3.908153, -5.448675], atol=1e-6)

  # The autodiff gradient should agree with a finite-difference estimate.
  numerical_grad = optimize.approx_fprime(
      xk=flat_init_params, f=bound_loss, epsilon=1e-8)
  np.testing.assert_allclose(numerical_grad, params_grad, atol=5e-3)