def test_get_atomic_chain_potential_incorrect_ndim(
    self, grids, locations, nuclear_charges, expected_message):
  """Inputs with an unsupported number of dimensions must raise ValueError.

  The arguments are presumably supplied by a parameterized-test decorator
  (not visible here); each combination should be rejected with a ValueError
  whose message matches `expected_message`.
  """
  with self.assertRaisesRegex(ValueError, expected_message):
    # Conversion happens inside the context manager on purpose, so a
    # ValueError raised while building the arrays also satisfies the check.
    array_inputs = dict(
        grids=jnp.array(grids),
        locations=jnp.array(locations),
        nuclear_charges=jnp.array(nuclear_charges))
    utils.get_atomic_chain_potential(
        interaction_fn=utils.exponential_coulomb, **array_inputs)
def test_get_atomic_chain_potential_soft_coulomb(self):
  """Checks the soft-Coulomb chain potential at both edges and the center."""
  potential = utils.get_atomic_chain_potential(
      grids=jnp.linspace(-10, 10, 201),
      locations=jnp.array([0., 1.]),
      nuclear_charges=jnp.array([2, 1]),
      interaction_fn=utils.soft_coulomb)
  # Grid index 0 is x=-10, index 100 is x=0 and index 200 is x=10.
  #   x=-10: -2 / jnp.sqrt(10 ** 2 + 1) - 1 / jnp.sqrt(11 ** 2 + 1)
  #   x=0:   -2 / jnp.sqrt(0 ** 2 + 1) - 1 / jnp.sqrt(1 ** 2 + 1)
  #   x=10:  -2 / jnp.sqrt(10 ** 2 + 1) - 1 / jnp.sqrt(9 ** 2 + 1)
  expected_values = (
      (0, -0.28954318),
      (100, -2.70710678),
      (200, -0.30943896),
  )
  for grid_index, expected in expected_values:
    self.assertAlmostEqual(float(potential[grid_index]), expected)
def test_get_atomic_chain_potential_exponential_coulomb(self):
  """Checks the exponential-Coulomb chain potential at edges and center."""
  potential = utils.get_atomic_chain_potential(
      grids=jnp.linspace(-10, 10, 201),
      locations=jnp.array([0., 1.]),
      nuclear_charges=jnp.array([2, 1]),
      interaction_fn=utils.exponential_coulomb)
  # Each nucleus contributes -Z * 1.071295 * exp(-|x - R| / 2.385345).
  # Grid index 0 is x=-10, index 100 is x=0 and index 200 is x=10:
  #   x=-10: -2 * 1.071295 * jnp.exp(-np.abs(10) / 2.385345)
  #          - 1.071295 * jnp.exp(-np.abs(11) / 2.385345)
  #   x=0:   -2 * 1.071295 * jnp.exp(-np.abs(0) / 2.385345)
  #          - 1.071295 * jnp.exp(-np.abs(1) / 2.385345)
  #   x=10:  -2 * 1.071295 * jnp.exp(-np.abs(10) / 2.385345)
  #          - 1.071295 * jnp.exp(-np.abs(9) / 2.385345)
  expected_values = (
      (0, -0.04302427),
      (100, -2.84702559),
      (200, -0.05699946),
  )
  for grid_index, expected in expected_values:
    self.assertAlmostEqual(float(potential[grid_index]), expected)
def main(argv):
  """Solves a two-nucleus non-interacting system and logs the result."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  logging.info('JAX devices: %s', jax.devices())

  # 513 uniformly spaced points from -20.48 to 20.48 (spacing 0.08).
  grids = np.arange(-256, 257) * 0.08
  # Two unit-charge nuclei placed symmetrically at -0.8 and 0.8.
  nuclei_locations = np.array([-0.8, 0.8])
  nuclei_charges = np.array([1., 1.])
  external_potential = utils.get_atomic_chain_potential(
      grids=grids,
      locations=nuclei_locations,
      nuclear_charges=nuclei_charges,
      interaction_fn=utils.exponential_coulomb)

  density, total_eigen_energies, _ = scf.solve_noninteracting_system(
      external_potential,
      num_electrons=FLAGS.num_electrons,
      grids=grids)
  logging.info('density: %s', density)
  logging.info('total energy: %f', total_eigen_energies)
def _create_testing_initial_state(self, interaction_fn):
  """Builds a KohnShamState fixture for a two-atom chain.

  Args:
    interaction_fn: interaction function forwarded to
      utils.get_atomic_chain_potential.

  Returns:
    scf.KohnShamState whose density is a unit-width Gaussian centered at 0
    scaled by self.num_electrons, and whose external potential comes from
    two unit-charge nuclei at -0.5 and 0.5 on self.grids.
  """
  locations = jnp.array([-0.5, 0.5])
  nuclear_charges = jnp.array([1, 1])
  initial_density = self.num_electrons * utils.gaussian(
      grids=self.grids, center=0., sigma=1.)
  external_potential = utils.get_atomic_chain_potential(
      grids=self.grids,
      locations=locations,
      nuclear_charges=nuclear_charges,
      interaction_fn=interaction_fn)
  return scf.KohnShamState(
      density=initial_density,
      # Placeholder only: the Kohn-Sham calculation never reads this initial
      # energy, so inf is used.
      total_energy=jnp.inf,
      locations=locations,
      nuclear_charges=nuclear_charges,
      external_potential=external_potential,
      grids=self.grids,
      num_electrons=self.num_electrons)
def _create_testing_external_potential(self, interaction_fn):
  """Returns the atomic-chain potential of the fixture's nuclei on self.grids.

  Args:
    interaction_fn: interaction function forwarded to
      utils.get_atomic_chain_potential.
  """
  chain_potential = utils.get_atomic_chain_potential(
      grids=self.grids,
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      interaction_fn=interaction_fn)
  return chain_potential
def _kohn_sham(locations, nuclear_charges, num_electrons, num_iterations,
               grids, xc_energy_density_fn, interaction_fn, initial_density,
               alpha, alpha_decay, enforce_reflection_symmetry,
               num_mixing_iterations, density_mse_converge_tolerance,
               stop_gradient_step):
  """Jit-able Kohn Sham calculation.

  Runs `num_iterations` Kohn-Sham steps inside a single `jax.lax.scan` so the
  whole loop can be traced and jitted. Unlike a Python-loop implementation,
  every step must keep the same carry structure, so density differences live
  in a preallocated (num_iterations, num_grids) buffer and "early stopping"
  is implemented by a `lax.cond` that re-emits the previous state.

  Args:
    locations: nuclear locations, forwarded to the initial KohnShamState and
      to utils.get_atomic_chain_potential.
    nuclear_charges: nuclear charges, forwarded the same way.
    num_electrons: number of electrons, forwarded to kohn_sham_iteration.
    num_iterations: number of scan steps; also the number of rows in the
      density-difference buffer.
    grids: 1D grid array; `grids.shape[0]` sets the buffer width.
    xc_energy_density_fn: forwarded to kohn_sham_iteration.
    interaction_fn: forwarded to kohn_sham_iteration and used to build the
      external potential.
    initial_density: starting density. Required here; no None fallback is
      performed in this function.
    alpha: initial density-mixing factor; multiplied by `alpha_decay` after
      every step (including steps after convergence).
    alpha_decay: per-step decay factor for alpha.
    enforce_reflection_symmetry: forwarded to kohn_sham_iteration.
    num_mixing_iterations: forwarded to _connection_weights, which builds the
      mixing weights (presumably one row per iteration — TODO confirm).
    density_mse_converge_tolerance: once the mean squared change of the
      (mixed) density between consecutive steps drops below this value, the
      remaining steps reuse the converged state. A negative value disables
      early stopping.
    stop_gradient_step: states from steps with index <= this value are
      wrapped in jax.lax.stop_gradient.

  Returns:
    The per-step states emitted by jax.lax.scan, i.e. each KohnShamState field
    stacked along a leading axis of length num_iterations.
  """
  num_grids = grids.shape[0]
  weights = _connection_weights(num_iterations, num_mixing_iterations)

  def _converged_kohn_sham_iteration(old_state_differences):
    # Converged branch: re-emit the previous state unchanged (only the flag is
    # set) so the scan output is padded with the converged result.
    old_state, differences = old_state_differences
    return old_state._replace(converged=True), differences

  # NOTE(review): "uncoveraged" is a typo for "unconverged"; kept as-is here.
  def _uncoveraged_kohn_sham_iteration(idx_old_state_alpha_differences):
    # Unconverged branch: run one Kohn-Sham step, record its raw density
    # change in row `idx` of the buffer, then mix the density.
    idx, old_state, alpha, differences = idx_old_state_alpha_differences
    state = kohn_sham_iteration(
        state=old_state,
        num_electrons=num_electrons,
        xc_energy_density_fn=xc_energy_density_fn,
        interaction_fn=interaction_fn,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    # NOTE(review): jax.ops.index_update was removed in newer JAX releases;
    # the replacement is `differences.at[idx].set(...)`. Confirm against the
    # pinned JAX version before upgrading.
    differences = jax.ops.index_update(
        differences, idx, state.density - old_state.density)
    # Density mixing.
    # Rows of `differences` beyond `idx` are still zero, so weights[idx]
    # combines only the differences recorded so far (assuming weights[idx] has
    # num_iterations entries — TODO confirm in _connection_weights).
    state = state._replace(
        density=old_state.density
        + alpha * jnp.dot(weights[idx], differences))
    return state, differences

  def _single_kohn_sham_iteration(carry, inputs):
    # lax.scan body; the scanned `inputs` (jnp.arange) are unused because the
    # step index is tracked in the carry.
    del inputs
    idx, old_state, alpha, converged, differences = carry
    state, differences = jax.lax.cond(
        converged,
        true_operand=(old_state, differences),
        true_fun=_converged_kohn_sham_iteration,
        false_operand=(idx, old_state, alpha, differences),
        false_fun=_uncoveraged_kohn_sham_iteration)
    # Convergence is measured on the density change AFTER mixing (state
    # already holds the mixed density). The state emitted at the detection
    # step keeps old_state's `converged` value; the flag only becomes True on
    # the following step via _converged_kohn_sham_iteration.
    converged = jnp.mean(
        jnp.square(state.density - old_state.density)
    ) < density_mse_converge_tolerance
    # Block gradients through the early iterations.
    state = jax.lax.cond(
        idx <= stop_gradient_step,
        true_fun=jax.lax.stop_gradient,
        false_fun=lambda x: x,
        operand=state)
    return (idx + 1, state, alpha * alpha_decay, converged, differences), state

  # Create initial state.
  state = scf.KohnShamState(
      density=initial_density,
      total_energy=jnp.inf,
      locations=locations,
      nuclear_charges=nuclear_charges,
      external_potential=utils.get_atomic_chain_potential(
          grids=grids,
          locations=locations,
          nuclear_charges=nuclear_charges,
          interaction_fn=interaction_fn),
      grids=grids,
      num_electrons=num_electrons,
      # Add dummy fields so the input and output of lax.scan have the same type
      # structure.
      hartree_potential=jnp.zeros_like(grids),
      xc_potential=jnp.zeros_like(grids),
      xc_energy_density=jnp.zeros_like(grids),
      gap=0.,
      converged=False)
  # Initialize the density differences with all zeros since the carry in
  # lax.scan must keep the same shape.
  differences = jnp.zeros((num_iterations, num_grids))
  _, states = jax.lax.scan(
      _single_kohn_sham_iteration,
      init=(0, state, alpha, state.converged, differences),
      xs=jnp.arange(num_iterations))
  return states
def kohn_sham(locations,
              nuclear_charges,
              num_electrons,
              num_iterations,
              grids,
              xc_energy_density_fn,
              interaction_fn,
              initial_density=None,
              alpha=0.5,
              alpha_decay=0.9,
              enforce_reflection_symmetry=False,
              num_mixing_iterations=2,
              density_mse_converge_tolerance=-1.):
  """Runs Kohn-Sham to solve ground state of external potential.

  Args:
    locations: Float numpy array with shape (num_nuclei,), the locations of
      atoms.
    nuclear_charges: Float numpy array with shape (num_nuclei,), the nuclear
      charges.
    num_electrons: Integer, the number of electrons in the system. The first
      num_electrons states are occupied.
    num_iterations: Integer, the number of Kohn-Sham iterations.
    grids: Float numpy array with shape (num_grids,).
    xc_energy_density_fn: function that takes density (num_grids,) and returns
      the energy density (num_grids,).
    interaction_fn: function that takes displacements and returns float numpy
      array with the same shape as displacements.
    initial_density: Float numpy array with shape (num_grids,), initial guess
      of the density for Kohn-Sham calculation. Default None, the initial
      density is the non-interacting solution from the external_potential.
    alpha: Float between 0 and 1, density linear mixing factor. The mixed
      density is the input density plus alpha times the mean of the most
      recent num_mixing_iterations density differences. If 0, the input
      density to the k-th Kohn-Sham iteration is fed into the (k+1)-th
      iteration and the output of the k-th iteration is completely ignored.
      If 1 (and num_mixing_iterations is 1), the output density from the k-th
      Kohn-Sham iteration is fed into the (k+1)-th iteration, equivalent to no
      density mixing.
    alpha_decay: Float between 0 and 1, the decay factor of alpha. The mixing
      factor after k-th iteration is alpha * alpha_decay ** k.
    enforce_reflection_symmetry: Boolean, whether to enforce reflection
      symmetry. If True, the density is symmetric with respect to the center.
    num_mixing_iterations: Integer, the number of density differences in the
      previous iterations to mix the density.
    density_mse_converge_tolerance: Float, the stopping criteria. When the MSE
      density difference between two iterations is smaller than this value,
      the Kohn Sham iterations finish. The outputs of the rest of the steps
      are padded by the output of the converged step. Set this value to
      negative to disable early stopping.

  Returns:
    KohnShamState, the states of all the Kohn-Sham iteration steps, with every
    field stacked along a new leading axis of length num_iterations.
  """
  external_potential = utils.get_atomic_chain_potential(
      grids=grids,
      locations=locations,
      nuclear_charges=nuclear_charges,
      interaction_fn=interaction_fn)
  if initial_density is None:
    # Use the non-interacting solution from the external_potential as initial
    # guess.
    initial_density, _, _ = solve_noninteracting_system(
        external_potential=external_potential,
        num_electrons=num_electrons,
        grids=grids)

  # Create initial state.
  state = KohnShamState(
      density=initial_density,
      total_energy=jnp.inf,
      locations=locations,
      nuclear_charges=nuclear_charges,
      external_potential=external_potential,
      grids=grids,
      num_electrons=num_electrons)
  states = []
  # Raw (unmixed) density differences of every iteration run so far. A Python
  # list makes each append O(1); re-stacking the whole history with
  # jnp.vstack every iteration copied O(k) rows per step (quadratic overall).
  # Only the slice actually used for mixing is stacked below.
  differences = []
  converged = False
  for _ in range(num_iterations):
    if converged:
      # Pad the remaining steps with the converged state.
      states.append(state)
      continue

    old_state = state
    state = kohn_sham_iteration(
        state=old_state,
        num_electrons=num_electrons,
        xc_energy_density_fn=xc_energy_density_fn,
        interaction_fn=interaction_fn,
        enforce_reflection_symmetry=enforce_reflection_symmetry)
    density_difference = state.density - old_state.density
    differences.append(density_difference)
    # Convergence is judged on the raw (pre-mixing) density change.
    if jnp.mean(
        jnp.square(density_difference)) < density_mse_converge_tolerance:
      converged = True
    state = state._replace(converged=converged)
    # Density mixing.
    recent_differences = jnp.stack(differences[-num_mixing_iterations:])
    state = state._replace(
        density=old_state.density
        + alpha * jnp.mean(recent_differences, axis=0))
    states.append(state)
    alpha *= alpha_decay

  # tree_util.tree_multimap no longer exists in newer JAX releases, where
  # tree_util.tree_map accepts multiple trees instead. Prefer tree_multimap
  # when available so older JAX versions behave exactly as before.
  map_over_states = getattr(tree_util, 'tree_multimap', tree_util.tree_map)
  return map_over_states(lambda *x: jnp.stack(x), *states)