Esempio n. 1
0
  def test_save_and_load_state(self):
    """Round-trips a random KohnShamState through scf.save_state/load_state."""
    num_samples, num_grids = 5, 100
    original = scf.KohnShamState(
        density=np.random.random((num_samples, num_grids)),
        total_energy=np.random.random(num_samples),
        locations=np.random.random((num_samples, 2)),
        nuclear_charges=np.random.random((num_samples, 2)),
        external_potential=np.random.random((num_samples, num_grids)),
        grids=np.random.random((num_samples, num_grids)),
        num_electrons=np.random.randint(10, size=num_samples),
        hartree_potential=np.random.random((num_samples, num_grids)))
    state_path = os.path.join(self.test_dir, 'test_state')

    scf.save_state(state_path, original)
    restored = scf.load_state(state_path)

    # The restored state must expose exactly the same field names...
    self.assertEqual(original._fields, restored._fields)
    # ...and every field must survive the round trip: unset (None) fields stay
    # None, array fields compare numerically close.
    for name in original._fields:
      expected = getattr(original, name)
      actual = getattr(restored, name)
      if expected is not None:
        np.testing.assert_allclose(expected, actual)
      else:
        self.assertIsNone(actual)
Esempio n. 2
0
 def setUp(self):
   """Stores a batch of five random KohnShamStates on `self.states`."""
   super().setUp()
   batch_size, num_grids = 5, 100
   self.states = scf.KohnShamState(
       density=np.random.random((batch_size, num_grids)),
       total_energy=np.random.random(batch_size),
       locations=np.random.random((batch_size, 2)),
       nuclear_charges=np.random.random((batch_size, 2)),
       external_potential=np.random.random((batch_size, num_grids)),
       grids=np.random.random((batch_size, num_grids)),
       num_electrons=np.random.randint(10, size=batch_size))
Esempio n. 3
0
  def get_molecules(self, selected_distance_x100=None):
    """Returns the selected molecules as one batched KohnShamState.

    Args:
      selected_distance_x100: Selection forwarded to `self.get_mask`;
        presumably a list of integer distances (x100) or None to select
        everything — TODO confirm against get_mask.

    Returns:
      scf.KohnShamState whose per-molecule fields (density, total_energy,
      locations, nuclear_charges, external_potential) are indexed by the mask,
      and whose shared fields (grids, num_electrons) are replicated once per
      selected molecule. `converged` is True for every selected molecule.
    """
    mask = self.get_mask(selected_distance_x100)
    densities = self.densities[mask]
    # Count the selected molecules from the indexed result instead of
    # np.sum(mask). For a boolean mask the two agree. For an integer index
    # array, np.sum would add up the indices rather than count them.
    num_samples = densities.shape[0]

    return scf.KohnShamState(
        density=densities,
        total_energy=self.total_energies[mask],
        locations=self.locations[mask],
        nuclear_charges=self.nuclear_charges[mask],
        external_potential=self.external_potentials[mask],
        # One copy of the shared grid per selected molecule (assumes
        # self.grids is 1-D — TODO confirm).
        grids=np.tile(
            np.expand_dims(self.grids, axis=0), reps=(num_samples, 1)),
        num_electrons=np.repeat(self.num_electrons, repeats=num_samples),
        converged=np.repeat(True, repeats=num_samples),
        )
Esempio n. 4
0
 def _create_testing_initial_state(self, interaction_fn):
     """Builds the initial KohnShamState for a two-nucleus test system.

     Two nuclei of charge 1 sit at -0.5 and 0.5. The initial density is
     `self.num_electrons` times utils.gaussian(center=0., sigma=1.) on
     `self.grids`.

     Args:
       interaction_fn: Forwarded to utils.get_atomic_chain_potential to build
         the external potential.

     Returns:
       scf.KohnShamState with total_energy set to jnp.inf.
     """
     nuclei_positions = jnp.array([-0.5, 0.5])
     charges = jnp.array([1, 1])
     initial_density = self.num_electrons * utils.gaussian(
         grids=self.grids, center=0., sigma=1.)
     potential = utils.get_atomic_chain_potential(
         grids=self.grids,
         locations=nuclei_positions,
         nuclear_charges=charges,
         interaction_fn=interaction_fn)
     # The total energy is only a placeholder; the Kohn-Sham calculation does
     # not use the initial value.
     return scf.KohnShamState(
         density=initial_density,
         total_energy=jnp.inf,
         locations=nuclei_positions,
         nuclear_charges=charges,
         external_potential=potential,
         grids=self.grids,
         num_electrons=self.num_electrons)
Esempio n. 5
0
def _kohn_sham(locations, nuclear_charges, num_electrons, num_iterations,
               grids, xc_energy_density_fn, interaction_fn, initial_density,
               alpha, alpha_decay, enforce_reflection_symmetry,
               num_mixing_iterations, density_mse_converge_tolerance,
               stop_gradient_step):
    """Jit-able Kohn-Sham calculation with density mixing.

    Always runs exactly `num_iterations` steps under `jax.lax.scan`. After the
    density has converged, later steps only flag the state as converged and
    carry it forward unchanged. This keeps the output length fixed.

    Args:
      locations: Nuclear locations, passed to utils.get_atomic_chain_potential.
      nuclear_charges: Nuclear charges, passed to the same helper.
      num_electrons: Number of electrons, forwarded to kohn_sham_iteration.
      num_iterations: Number of scan steps. Also the number of rows in the
        density-difference history.
      grids: Grid array; grids.shape[0] is used as the number of grid points
        (presumably 1-D — TODO confirm).
      xc_energy_density_fn: Forwarded to kohn_sham_iteration.
      interaction_fn: Forwarded to kohn_sham_iteration and to
        utils.get_atomic_chain_potential.
      initial_density: Density of the initial state.
      alpha: Initial density-mixing factor.
      alpha_decay: Factor multiplied into the mixing factor after every step.
      enforce_reflection_symmetry: Forwarded to kohn_sham_iteration.
      num_mixing_iterations: Forwarded to _connection_weights.
      density_mse_converge_tolerance: A step counts as converged when the mean
        squared change in density falls below this value.
      stop_gradient_step: Steps whose index is <= this value have
        jax.lax.stop_gradient applied to their output state.

    Returns:
      The per-step KohnShamStates stacked by jax.lax.scan along a new leading
      axis of length num_iterations.
    """
    num_grids = grids.shape[0]
    weights = _connection_weights(num_iterations, num_mixing_iterations)

    def _converged_kohn_sham_iteration(operands):
        # Already converged: keep the state and history, only set the flag.
        _, old_state, _, differences = operands
        return old_state._replace(converged=True), differences

    def _unconverged_kohn_sham_iteration(operands):
        idx, old_state, current_alpha, differences = operands
        state = kohn_sham_iteration(
            state=old_state,
            num_electrons=num_electrons,
            xc_energy_density_fn=xc_energy_density_fn,
            interaction_fn=interaction_fn,
            enforce_reflection_symmetry=enforce_reflection_symmetry)
        # Record this step's density update in row `idx` of the history.
        # `.at[...].set` replaces jax.ops.index_update, which JAX removed.
        differences = differences.at[idx].set(
            state.density - old_state.density)
        # Density mixing: the old density plus a weighted combination of the
        # recorded updates, scaled by the current mixing factor.
        state = state._replace(
            density=old_state.density
            + current_alpha * jnp.dot(weights[idx], differences))
        return state, differences

    def _single_kohn_sham_iteration(carry, inputs):
        del inputs  # The step index is tracked in the carry instead.
        idx, old_state, current_alpha, converged, differences = carry
        # Both branches receive the same operand tuple, as required by the
        # positional lax.cond(pred, true_fun, false_fun, operand) form.
        state, differences = jax.lax.cond(
            converged,
            _converged_kohn_sham_iteration,
            _unconverged_kohn_sham_iteration,
            (idx, old_state, current_alpha, differences))
        converged = jnp.mean(
            jnp.square(state.density - old_state.density)
        ) < density_mse_converge_tolerance
        # Block gradients through the first `stop_gradient_step` + 1 steps.
        state = jax.lax.cond(
            idx <= stop_gradient_step,
            jax.lax.stop_gradient,
            lambda x: x,
            state)
        return (idx + 1, state, current_alpha * alpha_decay, converged,
                differences), state

    # Create initial state.
    state = scf.KohnShamState(
        density=initial_density,
        total_energy=jnp.inf,
        locations=locations,
        nuclear_charges=nuclear_charges,
        external_potential=utils.get_atomic_chain_potential(
            grids=grids,
            locations=locations,
            nuclear_charges=nuclear_charges,
            interaction_fn=interaction_fn),
        grids=grids,
        num_electrons=num_electrons,
        # Add dummy fields so the input and output of lax.scan have the same type
        # structure.
        hartree_potential=jnp.zeros_like(grids),
        xc_potential=jnp.zeros_like(grids),
        xc_energy_density=jnp.zeros_like(grids),
        gap=0.,
        converged=False)
    # Initialize the density differences with all zeros since the carry in
    # lax.scan must keep the same shape.
    differences = jnp.zeros((num_iterations, num_grids))

    _, states = jax.lax.scan(
        _single_kohn_sham_iteration,
        init=(0, state, alpha, state.converged, differences),
        xs=jnp.arange(num_iterations))
    return states