Ejemplo n.º 1
0
    def test_wrap_network_with_self_interaction_layer_large_num_electrons(
            self):
        """Checks the self-interaction wrapper defers to the inner network.

        The density is scaled by 100 so that (per the test name) it represents
        a large number of electrons. In that regime the wrapped network's
        output is expected to equal the inner network evaluated with the
        inner network's own initial parameters.
        """
        grids = jnp.linspace(-5, 5, 9, dtype=jnp.float32)
        density = 100. * utils.gaussian(grids=grids, center=1.,
                                        sigma=1.).astype(jnp.float32)
        # Add a leading batch axis and a trailing channel axis.
        reshaped_density = density[jnp.newaxis, :, jnp.newaxis]
        inner_network_init_fn, inner_network_apply_fn = neural_xc.build_unet(
            num_filters_list=[2, 4], core_num_filters=4, activation='swish')

        init_fn, apply_fn = neural_xc.wrap_network_with_self_interaction_layer(
            network=(inner_network_init_fn, inner_network_apply_fn),
            grids=grids,
            interaction_fn=utils.exponential_coulomb)
        output_shape, init_params = init_fn(random.PRNGKey(0),
                                            input_shape=(-1, 9, 1))

        self.assertEqual(output_shape, (-1, 9, 1))

        # Run the wrapped forward pass once and reuse it for both checks.
        output = apply_fn(init_params, reshaped_density)
        self.assertEqual(output.shape, (1, 9, 1))
        np.testing.assert_allclose(
            output,
            inner_network_apply_fn(
                # The initial parameters of the inner network. The [1][1]
                # nesting mirrors the wrapper's parameter layout as relied on
                # by the original test; confirm if the wrapper changes.
                init_params[1][1],
                reshaped_density))
 def test_global_functional_with_unet(self, activation):
     """Checks the XC energy density shape for a UNet-based global functional.

     Args:
       activation: Activation name forwarded to the UNet (presumably supplied
         by a parameterized-test decorator outside this view).
     """
     unet = neural_xc.build_unet(num_filters_list=[4, 2],
                                 core_num_filters=4,
                                 activation=activation)
     init_fn, xc_energy_density_fn = neural_xc.global_functional(
         unet, grids=self.grids)
     params = init_fn(rng=random.PRNGKey(0))
     energy_density = xc_energy_density_fn(self.density, params)
     # 17 presumably matches the number of points in self.grids (set up
     # elsewhere in the test class) -- confirm against setUp.
     self.assertEqual(energy_density.shape, (17,))
    def test_build_unet(self):
        """Checks that the UNet keeps the spatial/channel shape of its input.

        `init_fn` must report an output shape of (-1, 9, 1) for a (-1, 9, 1)
        input, and applying the network to a batch of 6 must return (6, 9, 1).
        """
        init_fn, apply_fn = neural_xc.build_unet(num_filters_list=[2, 4, 8],
                                                 core_num_filters=16,
                                                 activation='softplus')
        output_shape, init_params = init_fn(random.PRNGKey(0),
                                            input_shape=(-1, 9, 1))
        self.assertEqual(output_shape, (-1, 9, 1))

        # Draw the input from a seeded generator rather than NumPy's global
        # RNG state, so the test is reproducible run to run.
        inputs = jnp.array(np.random.RandomState(0).randn(6, 9, 1))
        output = apply_fn(init_params, inputs)
        self.assertEqual(output.shape, (6, 9, 1))
 def test_global_functional_wrong_num_grids(self):
     """global_functional rejects a grid whose size is not 2**k + 1."""
     expected_message = (
         'The num_grids must be power of two plus one for global functional '
         'but got 6')
     with self.assertRaisesRegex(ValueError, expected_message):
         unet = neural_xc.build_unet(num_filters_list=[4, 2],
                                     core_num_filters=4,
                                     activation='softplus')
         # Six grid points is not a power of two plus one.
         neural_xc.global_functional(unet, grids=jnp.linspace(-1, 1, 6))
 def test_global_functional_wrong_num_spatial_shift(self):
     """global_functional rejects num_spatial_shift values below 1."""
     expected_message = 'num_spatial_shift can not be less than 1 but got 0'
     with self.assertRaisesRegex(ValueError, expected_message):
         unet = neural_xc.build_unet(num_filters_list=[4, 2],
                                     core_num_filters=4,
                                     activation='swish')
         # 0 is below the minimum allowed num_spatial_shift of 1.
         neural_xc.global_functional(unet,
                                     grids=self.grids,
                                     num_spatial_shift=0)
    def test_global_functional_with_unet_wrong_output_shape(self):
        """xc_energy_density_fn raises if the network output shape is wrong.

        A stride-2 'VALID' convolution is appended after the UNet, which
        shrinks the spatial axis. The expected error reports (1, 8) instead
        of the required (-1, 17).
        """
        unet = neural_xc.build_unet(num_filters_list=[4, 2],
                                    core_num_filters=4,
                                    activation='softplus')
        # Extra layer whose only purpose is to make the output shape wrong.
        shrinking_conv = neural_xc.Conv1D(1,
                                          filter_shape=(3, ),
                                          strides=(2, ),
                                          padding='VALID')
        init_fn, xc_energy_density_fn = neural_xc.global_functional(
            stax.serial(unet, shrinking_conv), grids=self.grids)
        params = init_fn(rng=random.PRNGKey(0))

        expected_message = (r'The output shape of the network '
                            r'should be \(-1, 17\) but got \(1, 8\)')
        with self.assertRaisesRegex(ValueError, expected_message):
            xc_energy_density_fn(self.density, params)
    def test_wrap_network_with_self_interaction_layer_one_electron(self):
        """Wrapped output equals -0.5 times the Hartree potential.

        Per the test name, this is the one-electron case: the wrapper's output
        for the density is compared against -0.5 * the Hartree potential
        computed with the same interaction function.
        """
        grids = jnp.linspace(-5, 5, 9)
        density = utils.gaussian(grids=grids, center=1., sigma=1.)
        # Leading batch axis and trailing channel axis.
        batched_density = density[jnp.newaxis, :, jnp.newaxis]

        unet = neural_xc.build_unet(num_filters_list=[2, 4],
                                    core_num_filters=4,
                                    activation='swish')
        init_fn, apply_fn = neural_xc.wrap_network_with_self_interaction_layer(
            network=unet,
            grids=grids,
            interaction_fn=utils.exponential_coulomb)
        output_shape, params = init_fn(random.PRNGKey(0),
                                       input_shape=(-1, 9, 1))
        self.assertEqual(output_shape, (-1, 9, 1))

        actual = apply_fn(params, batched_density)
        hartree_potential = scf.get_hartree_potential(
            density=density,
            grids=grids,
            interaction_fn=utils.exponential_coulomb)
        # Give the reference the same batch/channel axes as the output.
        expected = -0.5 * hartree_potential[jnp.newaxis, :, jnp.newaxis]
        np.testing.assert_allclose(actual, expected)