Ejemplo n.º 1
0
 def loss(flatten_params, initial_state, target_energy):
   """Squared total-energy error after a single Kohn-Sham iteration.

   Args:
     flatten_params: flat parameter vector; rebuilt into the XC model's
       parameter structure via `np_utils.unflatten(spec, ...)`.
     initial_state: state handed to `scf.kohn_sham_iteration` as `state`.
     target_energy: reference total energy to compare against.

   Returns:
     `(total_energy - target_energy) ** 2` for the state returned by one
     `scf.kohn_sham_iteration` call with reflection symmetry enforced.
   """
   # `spec`, `xc_energy_density_fn` and `self` are captured from the
   # enclosing scope.
   xc_params = np_utils.unflatten(spec, flatten_params)
   bound_xc_fn = tree_util.Partial(xc_energy_density_fn, params=xc_params)
   next_state = scf.kohn_sham_iteration(
       state=initial_state,
       num_electrons=self.num_electrons,
       xc_energy_density_fn=bound_xc_fn,
       interaction_fn=utils.exponential_coulomb,
       enforce_reflection_symmetry=True)
   energy_error = next_state.total_energy - target_energy
   return energy_error ** 2
Ejemplo n.º 2
0
 def loss(flatten_params, initial_state, target_density):
   """Absolute density error after a single Kohn-Sham iteration.

   Args:
     flatten_params: flat parameter vector; rebuilt into the XC model's
       parameter structure via `np_utils.unflatten(spec, ...)`.
     initial_state: state handed to `scf.kohn_sham_iteration` as `state`.
     target_density: reference density, compared elementwise with
       `state.density`.

   Returns:
     Sum of `|density - target_density|` multiplied by
     `utils.get_dx(self.grids)` (presumably a discretised L1 distance).
   """
   # `spec`, `xc_energy_density_fn` and `self` are captured from the
   # enclosing scope.
   xc_params = np_utils.unflatten(spec, flatten_params)
   next_state = scf.kohn_sham_iteration(
       state=initial_state,
       num_electrons=self.num_electrons,
       xc_energy_density_fn=tree_util.Partial(
           xc_energy_density_fn, params=xc_params),
       interaction_fn=utils.exponential_coulomb,
       enforce_reflection_symmetry=False)
   summed_abs_error = jnp.sum(jnp.abs(next_state.density - target_density))
   return summed_abs_error * utils.get_dx(self.grids)
Ejemplo n.º 3
0
 def loss(flatten_params, target_energy):
   """Squared total-energy error of the last state of a 3-iteration KS run.

   Args:
     flatten_params: flat parameter vector; rebuilt into the XC model's
       parameter structure via `np_utils.unflatten(spec, ...)`.
     target_energy: reference total energy to compare against.

   Returns:
     `(total_energy - target_energy) ** 2`, where the energy is taken from
     `scf.get_final_state` of the `scf.kohn_sham` result.
   """
   # `spec`, `xc_energy_density_fn` and `self` are captured from the
   # enclosing scope.
   bound_xc_fn = tree_util.Partial(
       xc_energy_density_fn,
       params=np_utils.unflatten(spec, flatten_params))
   ks_result = scf.kohn_sham(
       locations=self.locations,
       nuclear_charges=self.nuclear_charges,
       num_electrons=self.num_electrons,
       num_iterations=3,
       grids=self.grids,
       xc_energy_density_fn=bound_xc_fn,
       interaction_fn=utils.exponential_coulomb)
   last_state = scf.get_final_state(ks_result)
   energy_error = last_state.total_energy - target_energy
   return energy_error ** 2
Ejemplo n.º 4
0
 def loss(flatten_params, target_density):
   """Absolute density error of the last state of a 3-iteration KS run.

   Args:
     flatten_params: flat parameter vector; rebuilt into the XC model's
       parameter structure via `np_utils.unflatten(spec, ...)`.
     target_density: reference density, compared elementwise with the
       final state's `density`.

   Returns:
     Sum of `|density - target_density|` multiplied by
     `utils.get_dx(self.grids)` (presumably a discretised L1 distance).
   """
   # `spec`, `xc_energy_density_fn` and `self` are captured from the
   # enclosing scope.
   bound_xc_fn = tree_util.Partial(
       xc_energy_density_fn,
       params=np_utils.unflatten(spec, flatten_params))
   # NOTE(review): a tolerance of -1 presumably disables early stopping so
   # that all 3 iterations run; confirm against scf.kohn_sham.
   ks_result = scf.kohn_sham(
       locations=self.locations,
       nuclear_charges=self.nuclear_charges,
       num_electrons=self.num_electrons,
       num_iterations=3,
       grids=self.grids,
       xc_energy_density_fn=bound_xc_fn,
       interaction_fn=utils.exponential_coulomb,
       density_mse_converge_tolerance=-1)
   last_state = scf.get_final_state(ks_result)
   summed_abs_error = jnp.sum(jnp.abs(last_state.density - target_density))
   return summed_abs_error * utils.get_dx(self.grids)
Ejemplo n.º 5
0
    def test_flatten(self):
        # Round-trip a nested structure of arrays through np_utils.flatten and
        # np_utils.unflatten, checking both the flat vector and rebuilt leaves.
        nested = [
            (jnp.array([1, 2, 3]), jnp.array([4, 5])),
            jnp.array([99]),
        ]
        (tree_spec, leaf_shapes), flat_vec = np_utils.flatten(nested)
        self.assertIsInstance(flat_vec, np.ndarray)
        np.testing.assert_allclose(flat_vec, [1., 2., 3., 4., 5., 99.])
        self.assertEqual(leaf_shapes, [(3, ), (2, ), (1, )])

        # Unflattening the 1-D vector should rebuild the original nesting,
        # with numpy-array leaves.
        rebuilt = np_utils.unflatten((tree_spec, leaf_shapes), flat_vec)

        def assert_numpy_leaf(leaf, expected):
            self.assertIsInstance(leaf, np.ndarray)
            np.testing.assert_allclose(leaf, expected)

        assert_numpy_leaf(rebuilt[0][0], [1., 2., 3.])
        assert_numpy_leaf(rebuilt[0][1], [4., 5.])
        assert_numpy_leaf(rebuilt[1], [99.])