コード例 #1
0
 def loss(flatten_params, initial_state, target_energy):
     """Squared total-energy error after one call to scf.kohn_sham_iteration.

     Args:
       flatten_params: Flat parameter vector; turned back into the XC model's
         params with `np_utils.unflatten(spec, ...)`.
       initial_state: State passed to `scf.kohn_sham_iteration`.
       target_energy: Reference total energy.

     Returns:
       (total_energy - target_energy) ** 2 for the returned state.
     """
     # Bind the unflattened params into the XC functional so the loss depends
     # only on the flat vector.
     xc_fn = tree_util.Partial(
         xc_energy_density_fn,
         params=np_utils.unflatten(spec, flatten_params))
     next_state = scf.kohn_sham_iteration(
         state=initial_state,
         num_electrons=self.num_electrons,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb,
         enforce_reflection_symmetry=True)
     energy_error = next_state.total_energy - target_energy
     return energy_error**2
コード例 #2
0
 def loss(flatten_params, initial_state, target_density):
     """Density error after one call to scf.kohn_sham_iteration.

     Args:
       flatten_params: Flat parameter vector; turned back into the XC model's
         params with `np_utils.unflatten(spec, ...)`.
       initial_state: State passed to `scf.kohn_sham_iteration`.
       target_density: Reference density on `self.grids`.

     Returns:
       Sum of |density - target_density| multiplied by
       `utils.get_dx(self.grids)` (presumably the grid spacing, making this a
       discretised L1 distance — confirm in utils).
     """
     # Bind the unflattened params into the XC functional so the loss depends
     # only on the flat vector.
     xc_fn = tree_util.Partial(
         xc_energy_density_fn,
         params=np_utils.unflatten(spec, flatten_params))
     next_state = scf.kohn_sham_iteration(
         state=initial_state,
         num_electrons=self.num_electrons,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb,
         enforce_reflection_symmetry=False)
     abs_diff = jnp.abs(next_state.density - target_density)
     return jnp.sum(abs_diff) * utils.get_dx(self.grids)
コード例 #3
0
 def loss(flatten_params, target_energy):
     """Squared error of the final total energy from a 3-iteration KS solve.

     Args:
       flatten_params: Flat parameter vector; turned back into the XC model's
         params with `np_utils.unflatten(spec, ...)`.
       target_energy: Reference total energy.

     Returns:
       (total_energy - target_energy) ** 2, using the state selected by
       `scf.get_final_state`.
     """
     xc_fn = tree_util.Partial(
         xc_energy_density_fn,
         params=np_utils.unflatten(spec, flatten_params))
     # NOTE(review): scf.kohn_sham's result is passed to get_final_state, so it
     # presumably holds the per-iteration history — confirm in scf.
     ks_states = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb)
     last_state = scf.get_final_state(ks_states)
     energy_error = last_state.total_energy - target_energy
     return energy_error**2
コード例 #4
0
 def loss(flatten_params, target_density):
     """Density error of the final state from a 3-iteration KS solve.

     Args:
       flatten_params: Flat parameter vector; turned back into the XC model's
         params with `np_utils.unflatten(spec, ...)`.
       target_density: Reference density on `self.grids`.

     Returns:
       Sum of |density - target_density| for the final state, multiplied by
       `utils.get_dx(self.grids)` (presumably the grid spacing — confirm).
     """
     xc_fn = tree_util.Partial(
         xc_energy_density_fn,
         params=np_utils.unflatten(spec, flatten_params))
     # NOTE(review): a negative tolerance presumably disables early stopping
     # on density convergence, so all 3 iterations run — confirm in scf.
     ks_states = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb,
         density_mse_converge_tolerance=-1)
     last_state = scf.get_final_state(ks_states)
     abs_diff = jnp.abs(last_state.density - target_density)
     return jnp.sum(abs_diff) * utils.get_dx(self.grids)
コード例 #5
0
  def test_flatten(self):
    """Round-trips a nested pytree through np_utils.flatten / unflatten."""
    # Structure: [(leaf_a, leaf_b), leaf_c]. The second tuple element is a
    # plain array, not a nested 1-tuple.
    pytree = [(jnp.array([1, 2, 3]), jnp.array([4, 5])), jnp.array([99])]
    (tree, shapes), vec = np_utils.flatten(pytree)

    # flatten concatenates every leaf, in order, into one numpy vector and
    # records each leaf's shape.
    self.assertIsInstance(vec, np.ndarray)
    np.testing.assert_allclose(vec, [1., 2., 3., 4., 5., 99.])
    self.assertEqual(shapes, [(3,), (2,), (1,)])

    def check_leaf(leaf, expected):
      self.assertIsInstance(leaf, np.ndarray)
      np.testing.assert_allclose(leaf, expected)

    # unflatten must rebuild the same nested structure from the 1-D vector.
    restored = np_utils.unflatten((tree, shapes), vec)
    check_leaf(restored[0][0], [1., 2., 3.])
    check_leaf(restored[0][1], [4., 5.])
    check_leaf(restored[1], [99.])