def _get_hartree_energy(density, grids, interaction_fn):
    """Gets the Hartree energy.

    Evaluates 0.5 * sum_ij density[i] * density[j]
    * interaction_fn(grids[j] - grids[i]) * dx ** 2, where dx comes from
    utils.get_dx(grids).

    Args:
      density: Array of density values.
        NOTE(review): presumably shape (num_grids,) -- confirm with callers.
      grids: Array of grid points, same length as density.
      interaction_fn: function takes displacements and returns an array with
        the same shape of displacements.

    Returns:
      Scalar Hartree energy.
    """
    # Broadcast density and grids against themselves to enumerate all pairs.
    density_row = jnp.expand_dims(density, axis=0)
    density_col = jnp.expand_dims(density, axis=1)
    grids_row = jnp.expand_dims(grids, axis=0)
    grids_col = jnp.expand_dims(grids, axis=1)

    pair_interaction = interaction_fn(grids_row - grids_col)
    pair_sum = jnp.sum(density_row * density_col * pair_interaction)
    # Two grid integrations, hence the squared grid spacing.
    return 0.5 * pair_sum * utils.get_dx(grids) ** 2
def test_self_interaction_weight(self, density_integral, expected_weight):
    """Checks self_interaction_weight() on a scaled utils.gaussian density."""
    grids = jnp.linspace(-5, 5, 11)
    gaussian_density = utils.gaussian(grids=grids, center=1., sigma=1.)
    # Add leading and trailing singleton axes, then scale by the requested
    # density integral.
    reshaped_density = (
        density_integral * gaussian_density[jnp.newaxis, :, jnp.newaxis])

    weight = neural_xc.self_interaction_weight(
        reshaped_density=reshaped_density,
        dx=utils.get_dx(grids),
        width=1.)

    self.assertAlmostEqual(weight, expected_weight)
def _wavefunctions_to_density(num_electrons, wavefunctions, grids):
    """Converts wavefunctions to density.

    Args:
      num_electrons: Integer, the number of electrons.
      wavefunctions: Array whose rows are wavefunctions sampled on grids.
        NOTE(review): presumably ordered from lowest to highest eigenstate --
        confirm with the caller.
      grids: Array of grid points; only its spacing (utils.get_dx) is used.

    Returns:
      Array with the density on the grid: the sum of the first num_electrons
      rows of the spin-duplicated squared, normalized wavefunctions.
    """
    # Rows beyond num_electrons can never be occupied, so drop them early to
    # reduce the amount of computation.
    occupied = wavefunctions[:num_electrons]

    # Normalize every remaining wavefunction so that sum(psi ** 2) * dx == 1.
    norm = jnp.sqrt(
        jnp.sum(occupied ** 2, axis=1, keepdims=True) * utils.get_dx(grids))
    normalized = occupied / norm

    # Each eigenstate holds a spin-up and a spin-down electron, so every row is
    # duplicated before the first num_electrons rows are summed.
    spin_resolved = jnp.repeat(normalized ** 2, repeats=2, axis=0)
    return jnp.sum(spin_resolved[:num_electrons], axis=0)
def loss(flatten_params, initial_state, target_density):
    """Density loss after one Kohn-Sham iteration.

    Relies on `self`, `spec` and `xc_energy_density_fn` from the enclosing
    scope.

    Args:
      flatten_params: Flattened parameters, restored with
        np_utils.unflatten(spec, flatten_params).
      initial_state: State passed to scf.kohn_sham_iteration.
      target_density: Density the resulting state is compared against.

    Returns:
      Sum of the absolute density difference multiplied by the grid spacing.
    """
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params))
    next_state = scf.kohn_sham_iteration(
        state=initial_state,
        num_electrons=self.num_electrons,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        enforce_reflection_symmetry=True)
    abs_error = jnp.abs(next_state.density - target_density)
    return jnp.sum(abs_error) * utils.get_dx(self.grids)
def get_kinetic_matrix(grids):
    """Gets kinetic matrix.

    The matrix is -0.5 * discrete_laplacian(grids.size) / dx ** 2, where dx is
    the grid spacing returned by utils.get_dx(grids).

    Args:
      grids: Float numpy array with shape (num_grids,).

    Returns:
      Float numpy array with shape (num_grids, num_grids).
    """
    grid_spacing = utils.get_dx(grids)
    laplacian = discrete_laplacian(grids.size)
    return -0.5 * laplacian / (grid_spacing * grid_spacing)
def get_external_potential_energy(external_potential, density, grids):
    """Gets external potential energy.

    Computes dot(density, external_potential) * dx, with dx taken from
    utils.get_dx(grids).

    Args:
      external_potential: Float numpy array with shape (num_grids,).
      density: Float numpy array with shape (num_grids,).
      grids: Float numpy array with shape (num_grids,).

    Returns:
      Float.
    """
    weighted_sum = jnp.dot(density, external_potential)
    return weighted_sum * utils.get_dx(grids)
def get_xc_energy(density, xc_energy_density_fn, grids):
    r"""Gets xc energy.

    E_xc = \int density * xc_energy_density_fn(density) dx.

    The integral is evaluated as a dot product over the grid multiplied by the
    grid spacing from utils.get_dx(grids).

    Args:
      density: Float numpy array with shape (num_grids,).
      xc_energy_density_fn: function takes density and returns float numpy
        array with shape (num_grids,).
      grids: Float numpy array with shape (num_grids,).

    Returns:
      Float.
    """
    xc_energy_density = xc_energy_density_fn(density)
    integrand_sum = jnp.dot(xc_energy_density, density)
    return integrand_sum * utils.get_dx(grids)
def loss(flatten_params, target_density):
    """Density loss of the final state of a 3-iteration Kohn-Sham run.

    Relies on `self`, `spec` and `xc_energy_density_fn` from the enclosing
    scope.

    Args:
      flatten_params: Flattened parameters, restored with
        np_utils.unflatten(spec, flatten_params).
      target_density: Density the final state is compared against.

    Returns:
      Sum of the absolute density difference multiplied by the grid spacing.
    """
    xc_fn = tree_util.Partial(
        xc_energy_density_fn,
        params=np_utils.unflatten(spec, flatten_params))
    # NOTE(review): a negative tolerance presumably means convergence is never
    # declared, so all iterations run -- confirm against scf.kohn_sham.
    states = scf.kohn_sham(
        locations=self.locations,
        nuclear_charges=self.nuclear_charges,
        num_electrons=self.num_electrons,
        num_iterations=3,
        grids=self.grids,
        xc_energy_density_fn=xc_fn,
        interaction_fn=utils.exponential_coulomb,
        density_mse_converge_tolerance=-1)
    last_state = scf.get_final_state(states)
    abs_error = jnp.abs(last_state.density - target_density)
    return jnp.sum(abs_error) * utils.get_dx(self.grids)
def self_interaction_layer(grids, interaction_fn):
    """Layer construction function for self-interaction.

    The layer takes two inputs: the density and the feature to mix. Its output
    is

      beta * (-0.5 * Hartree potential) + (1 - beta) * features

    where beta is computed by self_interaction_weight() from the density, the
    grid spacing and the layer's single `width` parameter. The Hartree
    potential is reshaped to the shape of the density input.

    Args:
      grids: Float numpy array with shape (num_grids,). Cast to float32.
      interaction_fn: function takes displacements and returns float numpy
        array with the same shape of displacements.

    Returns:
      (init_fn, apply_fn) pair.
    """
    grids = grids.astype(jnp.float32)
    dx = utils.get_dx(grids)

    def init_fn(rng, input_shape):
        """Validates the two input shapes and creates the width parameter.

        Args:
          rng: Unused.
          input_shape: Sequence holding the shapes of the two inputs.

        Returns:
          (output_shape, params) where output_shape is the shape of the first
          input and params is a 1-tuple holding the initial width 1.

        Raises:
          ValueError: If there are not exactly two input shapes, or if the
            two shapes differ.
        """
        del rng
        if len(input_shape) != 2:
            raise ValueError(
                f'self_interaction_layer must have two inputs, '
                f'but got {len(input_shape)}')
        if input_shape[0] != input_shape[1]:
            raise ValueError(
                f'The input shape to self_interaction_layer must be equal, '
                f'but got {input_shape[0]} and {input_shape[1]}')
        initial_params = (jnp.array(1.),)
        return input_shape[0], initial_params

    def apply_fn(params, inputs, **kwargs):
        """Mixes -0.5 * Hartree potential with the features.

        Args:
          params: 1-tuple (width,).
          inputs: Pair (reshaped_density, features).
          **kwargs: Other key word arguments. Unused.

        Returns:
          Array with the same shape as reshaped_density (given features of
          matching shape).
        """
        del kwargs
        (width,) = params
        reshaped_density, features = inputs

        beta = self_interaction_weight(
            reshaped_density=reshaped_density, dx=dx, width=width)

        # The Hartree potential is evaluated on the flattened density and then
        # restored to the layer's input shape.
        hartree_potential = scf.get_hartree_potential(
            density=reshaped_density.reshape(-1),
            grids=grids,
            interaction_fn=interaction_fn)
        hartree = -0.5 * hartree_potential.reshape(reshaped_density.shape)

        return hartree * beta + features * (1 - beta)

    return init_fn, apply_fn
def get_xc_potential(density, xc_energy_density_fn, grids):
    """Gets xc potential.

    The xc potential is obtained by automatic differentiation: the gradient of
    get_xc_energy with respect to its first argument (density), divided by the
    grid spacing from utils.get_dx(grids).

    Args:
      density: Float numpy array with shape (num_grids,).
      xc_energy_density_fn: function takes density and returns float numpy
        array with shape (num_grids,).
      grids: Float numpy array with shape (num_grids,).

    Returns:
      Float numpy array with shape (num_grids,).
    """
    xc_energy_grad_fn = jax.grad(get_xc_energy)
    energy_gradient = xc_energy_grad_fn(density, xc_energy_density_fn, grids)
    return energy_gradient / utils.get_dx(grids)
def test_get_hartree_potential(self, interaction_fn):
    """Compares scf.get_hartree_potential with an explicit double loop."""
    grids = jnp.linspace(-5, 5, 11)
    dx = utils.get_dx(grids)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)

    # Compute the expected Hartree potential by nested for loops: for every
    # target grid point, accumulate the contribution of every source point.
    expected_hartree_potential = np.zeros_like(grids)
    for target_index, target_point in enumerate(grids):
        for source_point, source_density in zip(grids, density):
            contribution = np.sum(
                source_density * interaction_fn(target_point - source_point))
            expected_hartree_potential[target_index] += contribution * dx

    actual_hartree_potential = scf.get_hartree_potential(
        density=density, grids=grids, interaction_fn=interaction_fn)
    np.testing.assert_allclose(
        actual_hartree_potential, expected_hartree_potential)
def spherical_superposition_density(grids, locations, nuclear_charges):
    """Builds initial guess of density by superposition of spherical densities.

    For every nucleus the density returned by _get_exact_h_atom_density is
    evaluated at the displacements between the grid points and the nucleus;
    the per-nucleus densities are then combined with the nuclear charges as
    weights.

    Args:
      grids: Float numpy array with shape (num_grids,).
      locations: Float numpy array with shape (num_nuclei,).
      nuclear_charges: Float numpy array with shape (num_nuclei,).

    Returns:
      Float numpy array with shape (num_grids,).
    """
    grid_row = np.expand_dims(np.array(grids), axis=0)
    nuclei_col = np.expand_dims(np.array(locations), axis=1)
    # Shape (num_nuclei, num_grids): grid point minus nucleus location.
    displacements = grid_row - nuclei_col

    grid_spacing = float(utils.get_dx(grids))
    per_nucleus_density = _get_exact_h_atom_density(
        displacements, grid_spacing)
    return np.dot(nuclear_charges, per_nucleus_density)
def test_get_hartree_energy(self, interaction_fn):
    """Compares scf.get_hartree_energy with an explicit double loop."""
    grids = jnp.linspace(-5, 5, 11)
    dx = utils.get_dx(grids)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)

    # Compute the expected Hartree energy by nested for loops over all pairs
    # of grid points.
    expected_hartree_energy = 0.
    for point_a, density_a in zip(grids, density):
        for point_b, density_b in zip(grids, density):
            expected_hartree_energy += 0.5 * density_a * density_b * interaction_fn(
                point_a - point_b) * dx**2

    actual_hartree_energy = scf.get_hartree_energy(
        density=density, grids=grids, interaction_fn=interaction_fn)
    self.assertAlmostEqual(
        float(actual_hartree_energy), float(expected_hartree_energy))
def _test_state(self, state, external_potential):
    """Asserts the invariants expected of a Kohn-Sham state.

    Args:
      state: State object to check.
      external_potential: External potential the state was built from.
    """
    # The density in the final state should normalize to number of electrons.
    density_integral = float(
        jnp.sum(state.density) * utils.get_dx(self.grids))
    self.assertAlmostEqual(density_integral, self.num_electrons)

    # The total energy should be finite after iterations.
    self.assertTrue(jnp.isfinite(state.total_energy))

    # Potentials live on the same grid as the state.
    self.assertLen(state.hartree_potential, len(state.grids))
    self.assertLen(state.xc_potential, len(state.grids))

    # locations, nuclear_charges, external_potential, grids and num_electrons
    # remain unchanged. Each pair is (actual, desired) for assert_allclose.
    unchanged_arrays = (
        (state.locations, self.locations),
        (state.nuclear_charges, self.nuclear_charges),
        (external_potential, state.external_potential),
        (state.grids, self.grids),
    )
    for actual, desired in unchanged_arrays:
        np.testing.assert_allclose(actual, desired)
    self.assertEqual(state.num_electrons, self.num_electrons)

    self.assertGreater(state.gap, 0)
def test_get_dx(self):
    """utils.get_dx returns the spacing of a uniform grid."""
    # 11 points on [0, 1] are 0.1 apart.
    grids = jnp.linspace(0, 1, 11)
    self.assertAlmostEqual(utils.get_dx(grids), 0.1)
def test_get_dx_incorrect_ndim(self):
    """utils.get_dx rejects grids that are not one-dimensional."""
    grids_2d = jnp.array([[-0.1], [0.], [0.1]])
    expected_message = 'grids.ndim is expected to be 1 but got 2'
    with self.assertRaisesRegex(ValueError, expected_message):
        utils.get_dx(grids_2d)
def exponential_global_convolution(num_channels,
                                   grids,
                                   minval,
                                   maxval,
                                   downsample_factor=0,
                                   eta_init=nn.initializers.normal()):
    """Layer construction function for exponential global convolution.

    The kernels are built by _exponential_function_channels from the pairwise
    displacements of the (downsampled) grids, with one width per channel. Each
    width is minval + (maxval - minval) * sigmoid(eta), so it stays between
    minval and maxval.

    Args:
      num_channels: Integer, the number of channels.
      grids: Float numpy array with shape (num_grids,).
      minval: Float, the lower bound of the exponential width.
      maxval: Float, the upper bound of the exponential width.
      downsample_factor: Integer, the factor of downsampling. The grids are
        downsampled with step size 2 ** downsample_factor.
      eta_init: Initializer function in nn.initializers.

    Returns:
      (init_fn, apply_fn) pair.
    """
    stride = 2**downsample_factor
    grids = grids.astype(jnp.float32)[::stride]
    grids_row = jnp.expand_dims(grids, axis=0)
    grids_col = jnp.expand_dims(grids, axis=1)
    displacements = grids_row - grids_col
    dx = utils.get_dx(grids)

    def init_fn(rng, input_shape):
        """Validates the input shape and initializes eta.

        Args:
          rng: Random key passed to eta_init.
          input_shape: Tuple (batch_size, num_grids, 1), where num_grids must
            equal the length of the downsampled grids.

        Returns:
          (output_shape, params) where output_shape is input_shape with the
          last dimension replaced by num_channels, and params is (eta,).

        Raises:
          ValueError: If num_channels is not positive or input_shape does not
            have the form described above.
        """
        if num_channels <= 0:
            raise ValueError(
                f'num_channels must be positive but got {num_channels}')
        if len(input_shape) != 3:
            raise ValueError(
                f'The ndim of input should be 3, but got {len(input_shape)}')
        if input_shape[1] != len(grids):
            raise ValueError(
                f'input_shape[1] should be len(grids), but got {input_shape[1]}'
            )
        if input_shape[2] != 1:
            raise ValueError(
                f'input_shape[2] should be 1, but got {input_shape[2]}')

        eta = eta_init(rng, shape=(num_channels,))
        output_shape = input_shape[:-1] + (num_channels,)
        return output_shape, (eta,)

    def apply_fn(params, inputs, **kwargs):
        """Applies layer.

        Args:
          params: Layer parameters, (eta,).
          inputs: Float numpy array with shape (batch_size, num_grids, 1), as
            enforced by init_fn.
          **kwargs: Other key word arguments. Unused.

        Returns:
          Float numpy array with shape (batch_size, num_grids, num_channels).
        """
        del kwargs
        (eta,) = params

        # The sigmoid maps eta into (0, 1), keeping widths in (minval, maxval).
        widths = minval + (maxval - minval) * nn.sigmoid(eta)
        # shape (num_grids, num_grids, num_channels)
        kernels = _exponential_function_channels(displacements, widths=widths)

        # Contract the grid axis of inputs with the first kernel axis.
        # shape (batch_size, 1, num_grids, num_channels)
        convolved = jnp.tensordot(inputs, kernels, axes=(1, 0)) * dx
        # shape (batch_size, num_grids, num_channels)
        return jnp.squeeze(convolved, axis=1)

    return init_fn, apply_fn
def _get_hartree_potential(density, grids, interaction_fn):
    """Gets the Hartree potential.

    For 1d inputs, entry i of the result is
    sum_j density[j] * interaction_fn(grids[j] - grids[i]) * dx, where dx
    comes from utils.get_dx(grids).

    Args:
      density: Array of density values.
        NOTE(review): presumably shape (num_grids,) -- confirm with callers.
      grids: Array of grid points, same length as density.
      interaction_fn: function takes displacements and returns an array with
        the same shape of displacements.

    Returns:
      Array with the Hartree potential at every grid point.
    """
    density_row = jnp.expand_dims(density, axis=0)
    grids_row = jnp.expand_dims(grids, axis=0)
    grids_col = jnp.expand_dims(grids, axis=1)

    # Row i holds the contributions of every source point to grid point i.
    contributions = density_row * interaction_fn(grids_row - grids_col)
    return jnp.sum(contributions, axis=1) * utils.get_dx(grids)