def test_wrap_network_with_self_interaction_layer_large_num_electrons(
        self):
    """Wrapped network output should equal the bare inner network's output.

    The density is a Gaussian scaled by 100 (the test name indicates a large
    electron count). For such a density the test expects the
    self-interaction wrapper to return exactly what the inner U-Net returns
    when evaluated with its own slice of the initial parameters.
    """
    grids = jnp.linspace(-5, 5, 9, dtype=jnp.float32)
    unit_gaussian = utils.gaussian(
        grids=grids, center=1., sigma=1.).astype(jnp.float32)
    density = 100. * unit_gaussian
    # Add leading batch and trailing channel axes: (9,) -> (1, 9, 1).
    batched_density = density[jnp.newaxis, :, jnp.newaxis]

    inner_init_fn, inner_apply_fn = neural_xc.build_unet(
        num_filters_list=[2, 4],
        core_num_filters=4,
        activation='swish')
    init_fn, apply_fn = neural_xc.wrap_network_with_self_interaction_layer(
        network=(inner_init_fn, inner_apply_fn),
        grids=grids,
        interaction_fn=utils.exponential_coulomb)

    output_shape, init_params = init_fn(
        random.PRNGKey(0), input_shape=(-1, 9, 1))
    self.assertEqual(output_shape, (-1, 9, 1))

    wrapped_output = apply_fn(init_params, batched_density)
    self.assertEqual(wrapped_output.shape, (1, 9, 1))

    # NOTE(review): index [1][1] is taken to hold the inner network's
    # initial parameters inside the wrapper's parameter tree; confirm against
    # wrap_network_with_self_interaction_layer if its layer layout changes.
    inner_params = init_params[1][1]
    np.testing.assert_allclose(
        wrapped_output, inner_apply_fn(inner_params, batched_density))
def test_global_functional_with_unet(self, activation):
    """A U-Net based global functional returns a 1-D energy density.

    Args:
      activation: name of the U-Net activation (supplied by the test
        parameterization, which is outside this view).
    """
    unet = neural_xc.build_unet(
        num_filters_list=[4, 2], core_num_filters=4, activation=activation)
    init_fn, xc_energy_density_fn = neural_xc.global_functional(
        unet, grids=self.grids)
    params = init_fn(rng=random.PRNGKey(0))

    energy_density = xc_energy_density_fn(self.density, params)
    # Presumably one value per point of self.grids (17 points); confirm in
    # the test fixture.
    self.assertEqual(energy_density.shape, (17,))
def test_build_unet(self):
    """build_unet preserves the (batch, 9, 1) input shape.

    Checks the output shape reported by ``init_fn``. It also checks the
    shape produced by ``apply_fn`` on a batch of 6 random inputs.
    """
    init_fn, apply_fn = neural_xc.build_unet(
        num_filters_list=[2, 4, 8],
        core_num_filters=16,
        activation='softplus')
    output_shape, init_params = init_fn(
        random.PRNGKey(0), input_shape=(-1, 9, 1))
    self.assertEqual(output_shape, (-1, 9, 1))

    # Draw inputs from a locally seeded generator instead of the global
    # np.random state. This keeps the test reproducible and avoids
    # perturbing the random stream other tests in the process may rely on.
    inputs = jnp.array(np.random.default_rng(0).standard_normal((6, 9, 1)))
    output = apply_fn(init_params, inputs)
    self.assertEqual(output.shape, (6, 9, 1))
def test_global_functional_wrong_num_grids(self):
    """global_functional rejects a grid of 6 points with a ValueError.

    The expected message states that num_grids must be a power of two plus
    one.
    """
    expected_message = (
        'The num_grids must be power of two plus one for global functional '
        'but got 6')
    with self.assertRaisesRegex(ValueError, expected_message):
        network = neural_xc.build_unet(
            num_filters_list=[4, 2],
            core_num_filters=4,
            activation='softplus')
        # 6 points: deliberately not of the form 2**k + 1.
        bad_grids = jnp.linspace(-1, 1, 6)
        neural_xc.global_functional(network, grids=bad_grids)
def test_global_functional_wrong_num_spatial_shift(self):
    """global_functional rejects num_spatial_shift=0 with a ValueError."""
    expected_message = 'num_spatial_shift can not be less than 1 but got 0'
    with self.assertRaisesRegex(ValueError, expected_message):
        network = neural_xc.build_unet(
            num_filters_list=[4, 2],
            core_num_filters=4,
            activation='swish')
        neural_xc.global_functional(
            network,
            grids=self.grids,
            # Invalid on purpose: the expected error requires at least 1.
            num_spatial_shift=0)
def test_global_functional_with_unet_wrong_output_shape(self):
    """Evaluating a network with a mismatched output shape raises ValueError.

    A strided VALID convolution is appended after the U-Net so the network
    no longer returns one value per grid point. Construction and parameter
    initialization succeed, but evaluating the energy density is expected
    to raise.
    """
    unet = neural_xc.build_unet(
        num_filters_list=[4, 2], core_num_filters=4, activation='softplus')
    # Extra conv layer whose stride/padding shrink the spatial axis, making
    # the output shape wrong on purpose.
    shrinking_conv = neural_xc.Conv1D(
        1, filter_shape=(3,), strides=(2,), padding='VALID')
    network = stax.serial(unet, shrinking_conv)

    init_fn, xc_energy_density_fn = neural_xc.global_functional(
        network, grids=self.grids)
    params = init_fn(rng=random.PRNGKey(0))

    expected_message = (
        r'The output shape of the network '
        r'should be \(-1, 17\) but got \(1, 8\)')
    with self.assertRaisesRegex(ValueError, expected_message):
        xc_energy_density_fn(self.density, params)
def test_wrap_network_with_self_interaction_layer_one_electron(self):
    """For a single-electron density the wrapper outputs -0.5 * Hartree.

    The density is an unscaled Gaussian (presumably normalized to one
    electron; confirm in utils.gaussian). The test expects the wrapped
    network's output to equal minus one half of the Hartree potential
    computed with the same interaction function.
    """
    grids = jnp.linspace(-5, 5, 9)
    density = utils.gaussian(grids=grids, center=1., sigma=1.)
    # (9,) -> (1, 9, 1): add batch and channel axes expected by the network.
    batched_density = density[jnp.newaxis, :, jnp.newaxis]

    unet = neural_xc.build_unet(
        num_filters_list=[2, 4], core_num_filters=4, activation='swish')
    init_fn, apply_fn = neural_xc.wrap_network_with_self_interaction_layer(
        network=unet,
        grids=grids,
        interaction_fn=utils.exponential_coulomb)

    output_shape, init_params = init_fn(
        random.PRNGKey(0), input_shape=(-1, 9, 1))
    self.assertEqual(output_shape, (-1, 9, 1))

    hartree_potential = scf.get_hartree_potential(
        density=density, grids=grids, interaction_fn=utils.exponential_coulomb)
    expected = -0.5 * hartree_potential[jnp.newaxis, :, jnp.newaxis]
    np.testing.assert_allclose(apply_fn(init_params, batched_density), expected)