def test_shuffled_neuron_no_input_ablation_mask_invalid_model(self):
  """Checks that mask generation on the unmasked model raises ValueError."""
  # Adjacent literals are concatenated into the single regex that the raised
  # error message is matched against.
  expected_message = ('Model does not support masking, i.e. no layers are '
                      'wrapped by a MaskedModule.')

  with self.assertRaisesRegex(ValueError, expected_message):
    masked.shuffled_neuron_no_input_ablation_mask(
        self._unmasked_model, self._rng, 0.5)
def test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity(self):
  """Checks that sparsities of -0.5 and 1.5 are rejected with ValueError."""
  # Each entry: (subtest name, sparsity passed in, regex the error must match).
  invalid_cases = (
      ('sparsity_too_small', -0.5,
       r'Given sparsity, -0.5, is not in range \[0, 1\]'),
      ('sparsity_too_large', 1.5,
       r'Given sparsity, 1.5, is not in range \[0, 1\]'),
  )

  for case_name, sparsity, message_regex in invalid_cases:
    with self.subTest(name=case_name):
      with self.assertRaisesRegex(ValueError, message_regex):
        masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model, self._rng, sparsity)
def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer(
    self):
  """Checks the two-layer mask generated with sparsity 0.0.

  Both 'MaskedModule_0' and 'MaskedModule_1' must be present with kernel masks
  that are all ones, and applying the mask must reproduce the stored unmasked
  two-layer output in both value and shape.
  """
  mask = masked.shuffled_neuron_no_input_ablation_mask(
      self._masked_model_twolayer, self._rng, 0.0)

  with self.subTest(name='shuffled_empty_mask_layer1'):
    self.assertIn('MaskedModule_0', mask)

  with self.subTest(name='shuffled_empty_mask_values_layer1'):
    first_kernel_mask = mask['MaskedModule_0']['kernel']
    self.assertTrue((first_kernel_mask == 1).all())

  with self.subTest(name='shuffled_empty_mask_layer2'):
    self.assertIn('MaskedModule_1', mask)

  with self.subTest(name='shuffled_empty_mask_values_layer2'):
    second_kernel_mask = mask['MaskedModule_1']['kernel']
    self.assertTrue((second_kernel_mask == 1).all())

  masked_output = self._masked_model_twolayer(self._input, mask=mask)

  with self.subTest(name='shuffled_empty_dense_values'):
    elementwise_close = jnp.isclose(masked_output,
                                    self._unmasked_output_twolayer)
    self.assertTrue(elementwise_close.all())

  with self.subTest(name='shuffled_empty_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape,
                             self._unmasked_output_twolayer.shape)
def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):
  """Checks the single-layer mask generated with sparsity 1.0.

  The kernel mask must keep exactly as many nonzero entries as the product of
  the input dimensions, no column of the kernel mask may be all zero, the bias
  mask must be None, and the masked output must keep the unmasked output shape.
  """
  mask = masked.shuffled_neuron_no_input_ablation_mask(
      self._masked_model, self._rng, 1.0)

  with self.subTest(name='shuffled_full_mask'):
    self.assertIn('MaskedModule_0', mask)

  with self.subTest(name='shuffled_full_mask_values'):
    surviving_weights = jnp.count_nonzero(mask['MaskedModule_0']['kernel'])
    expected_surviving = jnp.prod(jnp.array(self._input_dimensions))
    self.assertEqual(surviving_weights, expected_surviving)

  with self.subTest(name='shuffled_full_no_input_ablation'):
    # Reducing over axis 0 gives one nonzero count per column of the kernel
    # mask; none of them may be zero.
    # NOTE(review): presumably columns correspond to neurons -- confirm that
    # axis 0 is the intended axis for the "no input ablation" property.
    nonzero_per_column = jnp.count_nonzero(
        mask['MaskedModule_0']['kernel'], axis=0)
    self.assertTrue((nonzero_per_column != 0).all())

  with self.subTest(name='shuffled_full_mask_not_masked_values'):
    self.assertIsNone(mask['MaskedModule_0']['bias'])

  masked_output = self._masked_model(self._input, mask=mask)

  with self.subTest(name='shuffled_full_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape,
                             self._unmasked_output.shape)
def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self):
  """Checks the single-layer mask generated with sparsity 0.5.

  The kernel mask must sum to (shape[0] // 2) * shape[1] of the unmasked kernel
  parameter, and no column of the kernel mask may be all zero.
  """
  mask = masked.shuffled_neuron_no_input_ablation_mask(
      self._masked_model, self._rng, 0.5)

  layer_params = self._masked_model.params['MaskedModule_0']
  kernel_shape = layer_params['unmasked']['kernel'].shape

  with self.subTest(name='shuffled_mask_values'):
    mask_total = jnp.sum(mask['MaskedModule_0']['kernel'])
    # Floor-divide the first dimension before scaling by the second.
    expected_total = (kernel_shape[0] // 2) * kernel_shape[1]
    self.assertEqual(mask_total, expected_total)

  with self.subTest(name='shuffled_half_no_input_ablation'):
    # Reducing over axis 0 gives one nonzero count per column of the kernel
    # mask; none of them may be zero.
    # NOTE(review): presumably columns correspond to neurons -- confirm that
    # axis 0 is the intended axis for the "no input ablation" property.
    nonzero_per_column = jnp.count_nonzero(
        mask['MaskedModule_0']['kernel'], axis=0)
    self.assertTrue((nonzero_per_column != 0).all())
def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer(
    self):
  """Checks the two-layer mask generated with sparsity 1.0.

  For 'MaskedModule_0' the nonzero kernel-mask count must equal the product of
  the input dimensions; for 'MaskedModule_1' it must equal
  MaskedTwoLayerDense.NUM_FEATURES[0]. Both bias masks must be None, and the
  masked output must keep the shape of the unmasked two-layer output.
  """
  mask = masked.shuffled_neuron_no_input_ablation_mask(
      self._masked_model_twolayer, self._rng, 1.0)

  with self.subTest(name='shuffled_full_mask_layer1'):
    self.assertIn('MaskedModule_0', mask)

  with self.subTest(name='shuffled_full_mask_values_layer1'):
    surviving_weights = jnp.count_nonzero(mask['MaskedModule_0']['kernel'])
    expected_surviving = jnp.prod(jnp.array(self._input_dimensions))
    self.assertEqual(surviving_weights, expected_surviving)

  with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
    self.assertIsNone(mask['MaskedModule_0']['bias'])

  with self.subTest(name='shuffled_full_no_input_ablation_layer1'):
    # Reducing over axis 0 gives one nonzero count per column of the kernel
    # mask; none of them may be zero.
    # NOTE(review): presumably columns correspond to neurons -- confirm that
    # axis 0 is the intended axis for the "no input ablation" property.
    nonzero_per_column = jnp.count_nonzero(
        mask['MaskedModule_0']['kernel'], axis=0)
    self.assertTrue((nonzero_per_column != 0).all())

  with self.subTest(name='shuffled_full_mask_layer2'):
    self.assertIn('MaskedModule_1', mask)

  with self.subTest(name='shuffled_full_mask_values_layer2'):
    surviving_weights = jnp.count_nonzero(mask['MaskedModule_1']['kernel'])
    expected_surviving = jnp.prod(MaskedTwoLayerDense.NUM_FEATURES[0])
    self.assertEqual(surviving_weights, expected_surviving)

  with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
    self.assertIsNone(mask['MaskedModule_1']['bias'])

  with self.subTest(name='shuffled_full_no_input_ablation_layer2'):
    # The per-column nonzero counts must add up to NUM_FEATURES[0].
    # NOTE(review): presumably NUM_FEATURES[0] is this layer's input count and
    # is smaller than its neuron count, so individual columns may legitimately
    # be empty here -- confirm against MaskedTwoLayerDense.
    nonzero_per_column = jnp.count_nonzero(
        mask['MaskedModule_1']['kernel'], axis=0)
    self.assertEqual(jnp.sum(nonzero_per_column),
                     MaskedTwoLayerDense.NUM_FEATURES[0])

  masked_output = self._masked_model_twolayer(self._input, mask=mask)

  with self.subTest(name='shuffled_full_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape,
                             self._unmasked_output_twolayer.shape)
def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty(self):
  """Checks the single-layer mask generated with sparsity 0.0.

  The kernel mask must be all ones, the bias mask must be None, and applying
  the mask must reproduce the stored unmasked output in both value and shape.
  """
  mask = masked.shuffled_neuron_no_input_ablation_mask(
      self._masked_model, self._rng, 0.0)

  with self.subTest(name='shuffled_empty_mask'):
    self.assertIn('MaskedModule_0', mask)

  with self.subTest(name='shuffled_empty_mask_values'):
    kernel_mask = mask['MaskedModule_0']['kernel']
    self.assertTrue((kernel_mask == 1).all())

  with self.subTest(name='shuffled_empty_mask_not_masked_values'):
    self.assertIsNone(mask['MaskedModule_0']['bias'])

  masked_output = self._masked_model(self._input, mask=mask)

  with self.subTest(name='shuffled_empty_dense_values'):
    elementwise_close = jnp.isclose(masked_output, self._unmasked_output)
    self.assertTrue(elementwise_close.all())

  with self.subTest(name='shuffled_empty_mask_dense_shape'):
    self.assertSequenceEqual(masked_output.shape,
                             self._unmasked_output.shape)