def test_global_functional_wrong_num_spatial_shift(self):
  """Checks that global_functional rejects num_spatial_shift below 1."""
  expected_error = 'num_spatial_shift can not be less than 1 but got 0'
  with self.assertRaisesRegex(ValueError, expected_error):
    unet = neural_xc.build_unet(
        num_filters_list=[4, 2], core_num_filters=4, activation='swish')
    # num_spatial_shift=0 is deliberately invalid.
    neural_xc.global_functional(unet, grids=self.grids, num_spatial_shift=0)
def test_global_functional_wrong_num_grids(self):
  """Checks that global_functional rejects a grid whose size is not 2**k + 1."""
  expected_error = (
      'The num_grids must be power of two plus one for global functional '
      'but got 6')
  with self.assertRaisesRegex(ValueError, expected_error):
    unet = neural_xc.build_unet(
        num_filters_list=[4, 2], core_num_filters=4, activation='softplus')
    # Six points is not of the form 2**k + 1, so this grid must be refused.
    bad_grids = jnp.linspace(-1, 1, 6)
    neural_xc.global_functional(unet, grids=bad_grids)
def test_global_functional_with_sliding_net(self, activation):
  """Checks the output shape of global_functional wrapping a sliding net.

  Args:
    activation: name of the activation passed to build_sliding_net.
  """
  sliding_net = neural_xc.build_sliding_net(
      window_size=3, num_filters_list=[4, 2, 2], activation=activation)
  init_fn, xc_energy_density_fn = neural_xc.global_functional(
      sliding_net, grids=self.grids)
  params = init_fn(rng=random.PRNGKey(0))
  energy_density = xc_energy_density_fn(self.density, params)
  # NOTE(review): 17 presumably equals the number of points in self.grids.
  self.assertEqual(energy_density.shape, (17,))
def test_global_functional_with_sliding_net_wrong_output_shape(self):
  """Checks that a network with a mismatched output shape is reported."""
  sliding_net = neural_xc.build_sliding_net(
      window_size=3, num_filters_list=[4, 2, 2], activation='softplus')
  # Extra stride-2 conv shrinks the spatial dimension, so the network output
  # no longer has the shape global_functional expects.
  downsample = neural_xc.Conv1D(
      1, filter_shape=(1,), strides=(2,), padding='VALID')
  init_fn, xc_energy_density_fn = neural_xc.global_functional(
      stax.serial(sliding_net, downsample), grids=self.grids)
  params = init_fn(rng=random.PRNGKey(0))
  expected_error = (r'The output shape of the network '
                    r'should be \(-1, 17\) but got \(1, 9\)')
  # The shape check fires when the energy density is evaluated, not at init.
  with self.assertRaisesRegex(ValueError, expected_error):
    xc_energy_density_fn(self.density, params)