コード例 #1
0
 def test_global_functional_wrong_num_spatial_shift(self):
     """global_functional must reject num_spatial_shift=0 with a ValueError."""
     expected = 'num_spatial_shift can not be less than 1 but got 0'
     with self.assertRaisesRegex(ValueError, expected):
         unet = neural_xc.build_unet(
             num_filters_list=[4, 2], core_num_filters=4, activation='swish')
         # Per the expected message, num_spatial_shift must be >= 1, so 0 is
         # deliberately invalid here.
         neural_xc.global_functional(
             unet, grids=self.grids, num_spatial_shift=0)
コード例 #2
0
 def test_global_functional_wrong_num_grids(self):
     """global_functional must reject a grid whose size is not 2**k + 1."""
     expected = (
         'The num_grids must be power of two plus one for global functional '
         'but got 6')
     with self.assertRaisesRegex(ValueError, expected):
         unet = neural_xc.build_unet(
             num_filters_list=[4, 2], core_num_filters=4, activation='softplus')
         # Six points is deliberately not of the form 2**k + 1.
         bad_grids = jnp.linspace(-1, 1, 6)
         neural_xc.global_functional(unet, grids=bad_grids)
コード例 #3
0
 def test_global_functional_with_sliding_net(self, activation):
     """A sliding-net global functional yields an output of shape (17,).

     Args:
       activation: activation name forwarded to build_sliding_net.
     """
     sliding_net = neural_xc.build_sliding_net(
         window_size=3, num_filters_list=[4, 2, 2], activation=activation)
     init_fn, xc_energy_density_fn = neural_xc.global_functional(
         sliding_net, grids=self.grids)
     params = init_fn(rng=random.PRNGKey(0))
     # NOTE(review): 17 presumably equals the number of points in self.grids.
     self.assertEqual(
         xc_energy_density_fn(self.density, params).shape, (17,))
コード例 #4
0
    def test_global_functional_with_sliding_net_wrong_output_shape(self):
        """A network whose output shape mismatches the grid must be rejected.

        An extra stride-2 VALID conv layer is appended after the sliding net,
        which breaks the output shape. Building the functional and initializing
        its parameters succeed. The ValueError is expected only when the
        energy density function is evaluated.
        """
        sliding_net = neural_xc.build_sliding_net(
            window_size=3, num_filters_list=[4, 2, 2], activation='softplus')
        # A width-1 filter with stride 2 and VALID padding shrinks the spatial
        # axis (17 -> 9 per the expected error below), making the shape wrong.
        downsample = neural_xc.Conv1D(
            1, filter_shape=(1,), strides=(2,), padding='VALID')
        init_fn, xc_energy_density_fn = neural_xc.global_functional(
            stax.serial(sliding_net, downsample), grids=self.grids)
        params = init_fn(rng=random.PRNGKey(0))

        expected_error = (r'The output shape of the network '
                          r'should be \(-1, 17\) but got \(1, 9\)')
        with self.assertRaisesRegex(ValueError, expected_error):
            xc_energy_density_fn(self.density, params)