Пример #1
0
 def test_build_sliding_net_invalid_window_size(self):
   """build_sliding_net must reject window_size=0 with a ValueError.

   The error message is expected to name both the lower bound (1) and the
   offending value (0).
   """
   expected_message = 'window_size cannot be less than 1, but got 0'
   with self.assertRaisesRegex(ValueError, expected_message):
     neural_xc.build_sliding_net(
         activation='softplus',
         num_filters_list=[2, 4, 8],
         window_size=0)
 def test_global_functional_with_sliding_net(self, activation):
   """global_functional wrapping a sliding net yields a (17,)-shaped density.

   NOTE(review): `activation` is presumably supplied by a parameterization
   decorator that is not visible in this chunk — confirm it is applied,
   otherwise the test runner cannot call this method.
   The expected length 17 presumably equals the number of points in
   `self.grids` — confirm against the fixture setup.
   """
   sliding_net = neural_xc.build_sliding_net(
       window_size=3,
       num_filters_list=[4, 2, 2],
       activation=activation)
   init_fn, xc_energy_density_fn = neural_xc.global_functional(
       sliding_net, grids=self.grids)

   params = init_fn(rng=random.PRNGKey(0))
   energy_density = xc_energy_density_fn(self.density, params)
   self.assertEqual(energy_density.shape, (17, ))
    def test_build_sliding_net(self):
        """build_sliding_net preserves the shape of its input.

        Checks both the output shape reported by init_fn for a symbolic batch
        dimension (-1) and the shape produced by an actual forward pass
        through apply_fn on a concrete batch of 6.
        """
        init_fn, apply_fn = neural_xc.build_sliding_net(
            window_size=3, num_filters_list=[2, 4, 8], activation='softplus')
        output_shape, init_params = init_fn(random.PRNGKey(0),
                                            input_shape=(-1, 9, 1))
        self.assertEqual(output_shape, (-1, 9, 1))

        # Draw the input from a locally seeded generator instead of the global
        # NumPy RNG so the test is reproducible and does not consume or depend
        # on shared global random state.
        inputs = jnp.array(np.random.RandomState(0).randn(6, 9, 1))
        output = apply_fn(init_params, inputs)
        self.assertEqual(output.shape, (6, 9, 1))
    def test_global_functional_with_sliding_net_wrong_output_shape(self):
        """global_functional rejects a network with the wrong output shape.

        A Conv1D with stride 2 and 'VALID' padding is appended after the
        sliding net. Evaluating the resulting functional must therefore
        raise a ValueError that reports both the expected shape and the
        actual shape.
        """
        sliding_net = neural_xc.build_sliding_net(
            window_size=3, num_filters_list=[4, 2, 2], activation='softplus')
        # Extra conv layer whose only purpose is to make the output shape wrong.
        strided_conv = neural_xc.Conv1D(
            1, filter_shape=(1, ), strides=(2, ), padding='VALID')
        network = stax.serial(sliding_net, strided_conv)

        init_fn, xc_energy_density_fn = neural_xc.global_functional(
            network, grids=self.grids)
        init_params = init_fn(rng=random.PRNGKey(0))

        expected_error = (r'The output shape of the network '
                          r'should be \(-1, 17\) but got \(1, 9\)')
        with self.assertRaisesRegex(ValueError, expected_error):
            xc_energy_density_fn(self.density, init_params)