  def test_flatten(self):
    (tree, shapes), vec = np_utils.flatten(
        [(jnp.array([1, 2, 3]), jnp.array([4, 5])), jnp.array([99])])
    self.assertIsInstance(vec, np.ndarray)
    np.testing.assert_allclose(vec, [1., 2., 3., 4., 5., 99.])
    self.assertEqual(shapes, [(3,), (2,), (1,)])

    # unflatten should convert the 1d array back to the original pytree.
    params = np_utils.unflatten((tree, shapes), vec)
    self.assertIsInstance(params[0][0], np.ndarray)
    np.testing.assert_allclose(params[0][0], [1., 2., 3.])
    self.assertIsInstance(params[0][1], np.ndarray)
    np.testing.assert_allclose(params[0][1], [4., 5.])
    self.assertIsInstance(params[1], np.ndarray)
    np.testing.assert_allclose(params[1], [99.])
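
As a minimal sketch of the same round trip outside the test harness: np_utils.flatten packs a nested pytree of arrays into a single 1d numpy vector plus a spec describing the tree structure and leaf shapes, and np_utils.unflatten inverts it. The `from jax_dft import np_utils` import path is an assumption based on the module names used in these tests.

import jax.numpy as jnp
import numpy as np

from jax_dft import np_utils  # assumed import path for the module under test

# Flatten a nested pytree of arrays into one 1d numpy vector plus a spec
# (tree structure + per-leaf shapes).
params = [(jnp.array([1., 2., 3.]), jnp.array([4., 5.])), jnp.array([99.])]
spec, vec = np_utils.flatten(params)
print(vec)  # [ 1.  2.  3.  4.  5. 99.]

# unflatten rebuilds the original pytree from the spec and the flat vector.
restored = np_utils.unflatten(spec, vec)
np.testing.assert_allclose(restored[0][1], [4., 5.])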
Example #2
  def test_kohn_sham_iteration_neural_xc_density_loss_gradient_symmetry(self):
    # The network only has one layer.
    # The initial params contains weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_density = (
        utils.gaussian(grids=self.grids, center=-0.5, sigma=1.)
        + utils.gaussian(grids=self.grids, center=0.5, sigma=1.))
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_density):
      state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              xc_energy_density_fn,
              params=np_utils.unflatten(spec, flatten_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      return jnp.sum(jnp.abs(state.density - target_density)) * utils.get_dx(
          self.grids)

    grad_fn = jax.grad(loss)

    params_grad = grad_fn(
        flatten_init_params,
        initial_state=initial_state,
        target_density=target_density)

    # Check gradient values.
    np.testing.assert_allclose(params_grad, [-1.34137017, 0.], atol=5e-7)

    # Check whether the gradient values match the numerical gradient.
    np.testing.assert_allclose(
        optimize.approx_fprime(
            xk=flatten_init_params,
            f=functools.partial(
                loss,
                initial_state=initial_state,
                target_density=target_density),
            epsilon=1e-9),
        params_grad, atol=2e-4)
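
The analytic-versus-numerical comparison above is a generic pattern that can be tried on a toy problem. The quadratic loss below is hypothetical, and the float64 flag mirrors what the very small epsilon in the test appears to rely on.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from scipy import optimize

# Forward differences with a tiny epsilon need float64 precision; the
# epsilon=1e-9 above suggests the original test suite enables this flag too.
jax.config.update('jax_enable_x64', True)


def toy_loss(x, target):
  # Hypothetical quadratic loss standing in for the density/energy losses.
  return jnp.sum((x - target) ** 2)


x0 = np.array([0.5, -1.0])
target = np.array([1.0, 2.0])

analytic_grad = jax.grad(toy_loss)(x0, target)  # 2 * (x0 - target)
numerical_grad = optimize.approx_fprime(
    xk=x0,
    f=functools.partial(toy_loss, target=target),
    epsilon=1e-7)
np.testing.assert_allclose(analytic_grad, numerical_grad, atol=1e-5)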
Example #3
  def test_kohn_sham_iteration_neural_xc_energy_loss_gradient(self):
    # The network only has one layer.
    # The initial params contains weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    initial_state = self._create_testing_initial_state(
        utils.exponential_coulomb)
    target_energy = 2.
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, initial_state, target_energy):
      state = scf.kohn_sham_iteration(
          state=initial_state,
          num_electrons=self.num_electrons,
          xc_energy_density_fn=tree_util.Partial(
              xc_energy_density_fn,
              params=np_utils.unflatten(spec, flatten_params)),
          interaction_fn=utils.exponential_coulomb,
          enforce_reflection_symmetry=True)
      return (state.total_energy - target_energy) ** 2

    grad_fn = jax.grad(loss)

    params_grad = grad_fn(
        flatten_init_params,
        initial_state=initial_state,
        target_energy=target_energy)

    # Check gradient values.
    np.testing.assert_allclose(params_grad, [-8.549952, -14.754195])

    # Check whether the gradient values match the numerical gradient.
    np.testing.assert_allclose(
        optimize.approx_fprime(
            xk=flatten_init_params,
            f=functools.partial(
                loss, initial_state=initial_state, target_energy=target_energy),
            epsilon=1e-9),
        params_grad, atol=2e-3)
Example #4
  def test_kohn_sham_neural_xc_energy_loss_gradient(self):
    # The network only has one layer.
    # The initial params contains weights with shape (1, 1) and bias (1,).
    init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
        stax.serial(stax.Dense(1)))
    init_params = init_fn(rng=random.PRNGKey(0))
    target_energy = 2.
    spec, flatten_init_params = np_utils.flatten(init_params)

    def loss(flatten_params, target_energy):
      state = scf.kohn_sham(
          locations=self.locations,
          nuclear_charges=self.nuclear_charges,
          num_electrons=self.num_electrons,
          num_iterations=3,
          grids=self.grids,
          xc_energy_density_fn=tree_util.Partial(
              xc_energy_density_fn,
              params=np_utils.unflatten(spec, flatten_params)),
          interaction_fn=utils.exponential_coulomb)
      final_state = scf.get_final_state(state)
      return (final_state.total_energy - target_energy) ** 2

    grad_fn = jax.grad(loss)

    params_grad = grad_fn(flatten_init_params, target_energy=target_energy)

    # Check gradient values.
    np.testing.assert_allclose(params_grad, [-3.908153, -5.448675], atol=1e-6)

    # Check whether the gradient values match the numerical gradient.
    np.testing.assert_allclose(
        optimize.approx_fprime(
            xk=flatten_init_params,
            f=functools.partial(loss, target_energy=target_energy),
            epsilon=1e-8),
        params_grad, atol=5e-3)
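
A side note on the tree_util.Partial calls used in all three losses above: Partial binds pytree-valued arguments (here the network params) to a function so the bundled callable can be handed to jax-transformed code and differentiated through. The energy_density function and its `scale` parameter below are stand-ins for illustration, not the neural_xc network.

import jax
import jax.numpy as jnp
from jax import tree_util


def energy_density(density, params):
  # Hypothetical local functional: a learnable scale times density squared.
  return params['scale'] * density ** 2


def total_energy(xc_fn, density):
  # xc_fn already carries its params, so callers never handle them explicitly.
  return jnp.sum(xc_fn(density))


density = jnp.array([0.1, 0.2, 0.3])
xc_fn = tree_util.Partial(energy_density, params={'scale': 2.0})
print(jax.jit(total_energy)(xc_fn, density))  # 2 * (0.01 + 0.04 + 0.09) = 0.28


def loss(scale, density):
  fn = tree_util.Partial(energy_density, params={'scale': scale})
  return jnp.sum(fn(density))


print(jax.grad(loss)(2.0, density))  # d/dscale = sum(density**2) = 0.14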