def test_save_and_load_state(self):
    """Saving a KohnShamState and loading it back preserves every field."""
    # Build a KohnShamState filled with random values.
    original = scf.KohnShamState(
        density=np.random.random((5, 100)),
        total_energy=np.random.random(5),
        locations=np.random.random((5, 2)),
        nuclear_charges=np.random.random((5, 2)),
        external_potential=np.random.random((5, 100)),
        grids=np.random.random((5, 100)),
        num_electrons=np.random.randint(10, size=5),
        hartree_potential=np.random.random((5, 100)),
    )
    state_path = os.path.join(self.test_dir, 'test_state')
    scf.save_state(state_path, original)
    restored = scf.load_state(state_path)

    # The field names must survive the round trip ...
    self.assertEqual(original._fields, restored._fields)
    # ... and so must every value; fields left as None must load as None.
    for name in original._fields:
        expected = getattr(original, name)
        actual = getattr(restored, name)
        if expected is None:
            self.assertIsNone(actual)
            continue
        np.testing.assert_allclose(expected, actual)
def setUp(self):
    """Creates ``self.states``, a KohnShamState of random arrays.

    Every array field has a leading dimension of 5.
    """
    super().setUp()
    batch, num_grids = 5, 100
    self.states = scf.KohnShamState(
        density=np.random.random((batch, num_grids)),
        total_energy=np.random.random(batch),
        locations=np.random.random((batch, 2)),
        nuclear_charges=np.random.random((batch, 2)),
        external_potential=np.random.random((batch, num_grids)),
        grids=np.random.random((batch, num_grids)),
        num_electrons=np.random.randint(10, size=batch),
    )
def get_molecules(self, selected_distance_x100=None):
    """Selects molecules from list of integers.

    Args:
      selected_distance_x100: Forwarded unchanged to ``self.get_mask``.
        Presumably a list of integer distances (x100) to keep, or None for
        the default selection; confirm against ``get_mask``.

    Returns:
      A ``scf.KohnShamState`` whose per-molecule fields are the entries
      picked by the mask:
        * ``grids`` is ``self.grids`` tiled once per selected molecule.
        * ``num_electrons`` repeats ``self.num_electrons``.
        * ``converged`` is all True.
    """
    mask = self.get_mask(selected_distance_x100)
    # Number of selected molecules (assumes get_mask returns a boolean mask).
    count = np.sum(mask)
    # One copy of the shared grid per selected molecule.
    stacked_grids = np.tile(
        np.expand_dims(self.grids, axis=0), reps=(count, 1))
    electron_counts = np.repeat(self.num_electrons, repeats=count)
    all_converged = np.repeat(True, repeats=count)
    return scf.KohnShamState(
        density=self.densities[mask],
        total_energy=self.total_energies[mask],
        locations=self.locations[mask],
        nuclear_charges=self.nuclear_charges[mask],
        external_potential=self.external_potentials[mask],
        grids=stacked_grids,
        num_electrons=electron_counts,
        converged=all_converged,
    )
def _create_testing_initial_state(self, interaction_fn):
    """Returns the starting KohnShamState used by the tests.

    Two nuclei of charge 1 sit at -0.5 and 0.5. The starting density is
    ``self.num_electrons`` times ``utils.gaussian`` (center 0, sigma 1)
    evaluated on ``self.grids``.

    Args:
      interaction_fn: Interaction function used to build the external
        potential through ``utils.get_atomic_chain_potential``.
    """
    nuclear_charges = jnp.array([1, 1])
    locations = jnp.array([-0.5, 0.5])
    initial_density = self.num_electrons * utils.gaussian(
        grids=self.grids, center=0., sigma=1.)
    external_potential = utils.get_atomic_chain_potential(
        grids=self.grids,
        locations=locations,
        nuclear_charges=nuclear_charges,
        interaction_fn=interaction_fn)
    return scf.KohnShamState(
        density=initial_density,
        # The energy starts at inf; its actual value is not used by the
        # Kohn-Sham calculation.
        total_energy=jnp.inf,
        locations=locations,
        nuclear_charges=nuclear_charges,
        external_potential=external_potential,
        grids=self.grids,
        num_electrons=self.num_electrons)
def _kohn_sham(locations, nuclear_charges, num_electrons, num_iterations,
               grids, xc_energy_density_fn, interaction_fn, initial_density,
               alpha, alpha_decay, enforce_reflection_symmetry,
               num_mixing_iterations, density_mse_converge_tolerance,
               stop_gradient_step):
    """Jit-able Kohn Sham calculation.

    Runs exactly ``num_iterations`` Kohn-Sham steps inside ``jax.lax.scan``,
    so the whole loop can be traced and jitted.

    After each step, the new density is mixed with the density updates
    recorded so far, weighted by ``_connection_weights``. Once a step's mean
    squared density change drops below ``density_mse_converge_tolerance``,
    later steps skip the Kohn-Sham update. They carry the previous state
    forward with ``converged=True``, so the scan still runs to completion
    with fixed shapes.

    Args:
      locations: Nuclear locations, used for the external potential.
      nuclear_charges: Nuclear charges, used for the external potential.
      num_electrons: Passed to ``kohn_sham_iteration`` and stored in the
        initial state.
      num_iterations: Number of scan steps. It sets array shapes, so it must
        be a concrete Python int.
      grids: Grid points. ``grids.shape[0]`` is used as the number of grids.
      xc_energy_density_fn: Passed to ``kohn_sham_iteration``.
      interaction_fn: Used for the external potential and passed to
        ``kohn_sham_iteration``.
      initial_density: Density of the initial state.
      alpha: Density-mixing coefficient used in the first step.
      alpha_decay: Factor multiplying the mixing coefficient after each step.
      enforce_reflection_symmetry: Passed to ``kohn_sham_iteration``.
      num_mixing_iterations: Passed to ``_connection_weights``.
      density_mse_converge_tolerance: A step is flagged converged when the
        mean squared change of the density is below this value.
      stop_gradient_step: States produced at step index
        ``<= stop_gradient_step`` are wrapped in ``jax.lax.stop_gradient``.

    Returns:
      The ``KohnShamState`` of every step, stacked by ``jax.lax.scan`` along
      a new leading axis of length ``num_iterations``.
    """
    num_grids = grids.shape[0]
    # NOTE(review): row ``idx`` of this matrix weights the recorded density
    # differences in the mixing step; presumably shape
    # (num_iterations, num_iterations). Confirm against _connection_weights.
    weights = _connection_weights(num_iterations, num_mixing_iterations)

    def _converged_kohn_sham_iteration(operands):
        # Already converged: keep the previous state and flag it. ``idx`` and
        # ``alpha`` are unused; they are accepted only so that both
        # jax.lax.cond branches take the same operand structure.
        _, old_state, _, differences = operands
        return old_state._replace(converged=True), differences

    def _unconverged_kohn_sham_iteration(operands):
        idx, old_state, alpha, differences = operands
        state = kohn_sham_iteration(
            state=old_state,
            num_electrons=num_electrons,
            xc_energy_density_fn=xc_energy_density_fn,
            interaction_fn=interaction_fn,
            enforce_reflection_symmetry=enforce_reflection_symmetry)
        # Record this step's raw density update in row ``idx``
        # (``.at[].set`` replaces the removed ``jax.ops.index_update``).
        differences = differences.at[idx].set(
            state.density - old_state.density)
        # Density mixing.
        state = state._replace(
            density=old_state.density
            + alpha * jnp.dot(weights[idx], differences))
        return state, differences

    def _single_kohn_sham_iteration(carry, inputs):
        del inputs  # The step index is tracked in ``carry`` instead.
        idx, old_state, alpha, converged, differences = carry
        state, differences = jax.lax.cond(
            converged,
            _converged_kohn_sham_iteration,
            _unconverged_kohn_sham_iteration,
            (idx, old_state, alpha, differences))
        converged = jnp.mean(
            jnp.square(state.density - old_state.density)
        ) < density_mse_converge_tolerance
        state = jax.lax.cond(
            idx <= stop_gradient_step,
            jax.lax.stop_gradient,
            lambda x: x,
            state)
        return (idx + 1, state, alpha * alpha_decay, converged,
                differences), state

    # Create initial state.
    state = scf.KohnShamState(
        density=initial_density,
        total_energy=jnp.inf,
        locations=locations,
        nuclear_charges=nuclear_charges,
        external_potential=utils.get_atomic_chain_potential(
            grids=grids,
            locations=locations,
            nuclear_charges=nuclear_charges,
            interaction_fn=interaction_fn),
        grids=grids,
        num_electrons=num_electrons,
        # Add dummy fields so the input and output of lax.scan have the same
        # type structure.
        hartree_potential=jnp.zeros_like(grids),
        xc_potential=jnp.zeros_like(grids),
        xc_energy_density=jnp.zeros_like(grids),
        gap=0.,
        converged=False)
    # Initialize the density differences with all zeros since the carry in
    # lax.scan must keep the same shape.
    differences = jnp.zeros((num_iterations, num_grids))

    _, states = jax.lax.scan(
        _single_kohn_sham_iteration,
        init=(0, state, alpha, state.converged, differences),
        xs=jnp.arange(num_iterations))
    return states