Example #1
0
    def test_wrap_network_with_self_interaction_layer_large_num_electrons(
            self):
        """Wrapped network matches the bare inner network for a large density.

        A Gaussian density scaled by 100 (a large electron count, per the test
        name) is passed through the self-interaction wrapper. Its output must
        have shape (1, num_grids, 1). It must also agree with the inner U-Net
        evaluated on the same input using the inner network's own initial
        parameters.
        """
        num_grids = 9
        grids = jnp.linspace(-5, 5, num_grids, dtype=jnp.float32)
        gaussian = utils.gaussian(grids=grids, center=1., sigma=1.)
        density = 100. * gaussian.astype(jnp.float32)
        # Add a leading batch axis and a trailing channel axis.
        batched_density = density[None, :, None]

        inner_init_fn, inner_apply_fn = neural_xc.build_unet(
            num_filters_list=[2, 4], core_num_filters=4, activation='swish')
        init_fn, apply_fn = neural_xc.wrap_network_with_self_interaction_layer(
            network=(inner_init_fn, inner_apply_fn),
            grids=grids,
            interaction_fn=utils.exponential_coulomb)

        output_shape, params = init_fn(random.PRNGKey(0),
                                       input_shape=(-1, num_grids, 1))
        self.assertEqual(output_shape, (-1, num_grids, 1))

        wrapped_output = apply_fn(params, batched_density)
        self.assertEqual(wrapped_output.shape, (1, num_grids, 1))

        # params[1][1] holds the initial parameters of the inner network.
        inner_params = params[1][1]
        np.testing.assert_allclose(
            wrapped_output, inner_apply_fn(inner_params, batched_density))

    def test_wrap_network_with_self_interaction_layer_one_electron(self):
        """Wrapped network returns -0.5 times the Hartree potential.

        A single Gaussian density (one electron, per the test name) is passed
        through the self-interaction wrapper around a U-Net. The expected
        output is -0.5 times the Hartree potential of that density from
        `scf.get_hartree_potential`, reshaped to (1, num_grids, 1).
        """
        num_grids = 9
        grids = jnp.linspace(-5, 5, num_grids)
        density = utils.gaussian(grids=grids, center=1., sigma=1.)

        unet = neural_xc.build_unet(
            num_filters_list=[2, 4], core_num_filters=4, activation='swish')
        init_fn, apply_fn = neural_xc.wrap_network_with_self_interaction_layer(
            network=unet,
            grids=grids,
            interaction_fn=utils.exponential_coulomb)

        output_shape, params = init_fn(random.PRNGKey(0),
                                       input_shape=(-1, num_grids, 1))
        self.assertEqual(output_shape, (-1, num_grids, 1))

        hartree_potential = scf.get_hartree_potential(
            density=density,
            grids=grids,
            interaction_fn=utils.exponential_coulomb)
        # Batch axis in front, channel axis at the end, to match the output.
        expected = -0.5 * hartree_potential[None, :, None]

        np.testing.assert_allclose(
            apply_fn(params, density[None, :, None]), expected)