예제 #1
0
 def test_shuffled_neuron_no_input_ablation_mask_invalid_model(self):
     """Verifies that a model with no masked layers raises ValueError."""
     # Pattern the raised ValueError message has to match.
     expected_error = (
         'Model does not support masking, i.e. no layers are '
         'wrapped by a MaskedModule.')
     sparsity = 0.5

     with self.assertRaisesRegex(ValueError, expected_error):
         masked.shuffled_neuron_no_input_ablation_mask(
             self._unmasked_model, self._rng, sparsity)
예제 #2
0
    def test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity(self):
        """Verifies that sparsity values outside [0, 1] raise ValueError.

        One subtest is run per out-of-range value (below 0 and above 1); each
        expects a ValueError whose message matches the paired pattern.
        """
        # (subtest name, sparsity passed in, pattern the error must match)
        out_of_range_cases = (
            ('sparsity_too_small', -0.5,
             r'Given sparsity, -0.5, is not in range \[0, 1\]'),
            ('sparsity_too_large', 1.5,
             r'Given sparsity, 1.5, is not in range \[0, 1\]'),
        )

        for subtest_name, sparsity, expected_error in out_of_range_cases:
            with self.subTest(name=subtest_name):
                with self.assertRaisesRegex(ValueError, expected_error):
                    masked.shuffled_neuron_no_input_ablation_mask(
                        self._masked_model, self._rng, sparsity)
예제 #3
0
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer(
            self):
        """Checks the 0%-sparsity shuffled mask built for the two-layer model.

        Both layers must appear in the mask with kernel masks that are all
        ones, and applying the mask must give the same output values and
        shape as the stored unmasked two-layer output.
        """
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model_twolayer, self._rng, 0.0)

        # (mask key, presence subtest name, values subtest name) per layer.
        layer_checks = (
            ('MaskedModule_0', 'shuffled_empty_mask_layer1',
             'shuffled_empty_mask_values_layer1'),
            ('MaskedModule_1', 'shuffled_empty_mask_layer2',
             'shuffled_empty_mask_values_layer2'),
        )
        for layer_key, presence_name, values_name in layer_checks:
            with self.subTest(name=presence_name):
                self.assertIn(layer_key, mask)

            with self.subTest(name=values_name):
                kernel_mask = mask[layer_key]['kernel']
                self.assertTrue((kernel_mask == 1).all())

        masked_output = self._masked_model_twolayer(self._input, mask=mask)

        with self.subTest(name='shuffled_empty_dense_values'):
            elementwise_close = jnp.isclose(masked_output,
                                            self._unmasked_output_twolayer)
            self.assertTrue(elementwise_close.all())

        with self.subTest(name='shuffled_empty_mask_dense_shape'):
            expected_shape = self._unmasked_output_twolayer.shape
            self.assertSequenceEqual(masked_output.shape, expected_shape)
예제 #4
0
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):
        """Checks the shuffled mask generated with 100% sparsity.

        Verifies that the layer is present in the mask, that the number of
        non-zero kernel mask entries equals the product of the input
        dimensions, that the bias mask is None, and that the masked model
        output keeps the shape of the unmasked output.
        """
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model, self._rng, 1.0)

        with self.subTest(name='shuffled_full_mask'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_full_mask_values'):
            kept_entries = jnp.count_nonzero(mask['MaskedModule_0']['kernel'])
            expected_entries = jnp.prod(jnp.array(self._input_dimensions))
            self.assertEqual(kept_entries, expected_entries)

        with self.subTest(name='shuffled_full_no_input_ablation'):
            # Reducing over axis 0 gives one non-zero count per column; none
            # of those counts may be zero.
            # NOTE(review): the earlier comment called this a per-row check
            # (with neurons as columns) - confirm the kernel layout.
            per_column_counts = jnp.count_nonzero(
                mask['MaskedModule_0']['kernel'], axis=0)
            self.assertTrue((per_column_counts != 0).all())

        with self.subTest(name='shuffled_full_mask_not_masked_values'):
            bias_mask = mask['MaskedModule_0']['bias']
            self.assertIsNone(bias_mask)

        masked_output = self._masked_model(self._input, mask=mask)

        with self.subTest(name='shuffled_full_mask_dense_shape'):
            expected_shape = self._unmasked_output.shape
            self.assertSequenceEqual(masked_output.shape, expected_shape)
예제 #5
0
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self):
        """Checks the shuffled mask generated with a sparsity of 0.5.

        The kernel mask must sum to half of the first kernel dimension
        (integer division) times the second dimension, and no column of the
        kernel mask may be entirely zero.
        """
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model, self._rng, 0.5)
        layer_params = self._masked_model.params['MaskedModule_0']
        kernel_shape = layer_params['unmasked']['kernel'].shape

        with self.subTest(name='shuffled_mask_values'):
            mask_total = jnp.sum(mask['MaskedModule_0']['kernel'])
            expected_total = kernel_shape[0] // 2 * kernel_shape[1]
            self.assertEqual(mask_total, expected_total)

        with self.subTest(name='shuffled_half_no_input_ablation'):
            # Reducing over axis 0 gives one non-zero count per column; none
            # of those counts may be zero.
            # NOTE(review): the earlier comment called this a per-row check
            # (with neurons as columns) - confirm the kernel layout.
            per_column_counts = jnp.count_nonzero(
                mask['MaskedModule_0']['kernel'], axis=0)
            self.assertTrue((per_column_counts != 0).all())
예제 #6
0
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer(
            self):
        """Checks the 100%-sparsity shuffled mask built for the two-layer model.

        For each layer this verifies the layer is present in the mask, the
        count of non-zero kernel mask entries, and that the bias mask is
        None. It also runs a per-layer ablation check on the kernel mask and
        confirms the masked output keeps the unmasked output's shape.
        """
        first_layer = 'MaskedModule_0'
        second_layer = 'MaskedModule_1'

        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model_twolayer, self._rng, 1.0)

        # First layer.
        with self.subTest(name='shuffled_full_mask_layer1'):
            self.assertIn(first_layer, mask)

        with self.subTest(name='shuffled_full_mask_values_layer1'):
            kept_entries = jnp.count_nonzero(mask[first_layer]['kernel'])
            expected_entries = jnp.prod(jnp.array(self._input_dimensions))
            self.assertEqual(kept_entries, expected_entries)

        with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
            first_bias_mask = mask[first_layer]['bias']
            self.assertIsNone(first_bias_mask)

        with self.subTest(name='shuffled_full_no_input_ablation_layer1'):
            # Reducing over axis 0 gives one non-zero count per column; none
            # of those counts may be zero.
            # NOTE(review): the earlier comment called this a per-row check
            # (with neurons as columns) - confirm the kernel layout.
            per_column_counts = jnp.count_nonzero(
                mask[first_layer]['kernel'], axis=0)
            self.assertTrue((per_column_counts != 0).all())

        # Second layer.
        with self.subTest(name='shuffled_full_mask_layer2'):
            self.assertIn(second_layer, mask)

        with self.subTest(name='shuffled_full_mask_values_layer2'):
            kept_entries = jnp.count_nonzero(mask[second_layer]['kernel'])
            expected_entries = jnp.prod(MaskedTwoLayerDense.NUM_FEATURES[0])
            self.assertEqual(kept_entries, expected_entries)

        with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
            second_bias_mask = mask[second_layer]['bias']
            self.assertIsNone(second_bias_mask)

        with self.subTest(name='shuffled_full_no_input_ablation_layer2'):
            # The per-column non-zero counts must add up to NUM_FEATURES[0].
            # The original note describes this as checking that no *inputs*
            # are ablated, for the case where inputs < num_neurons.
            per_column_counts = jnp.count_nonzero(
                mask[second_layer]['kernel'], axis=0)
            self.assertEqual(jnp.sum(per_column_counts),
                             MaskedTwoLayerDense.NUM_FEATURES[0])

        masked_output = self._masked_model_twolayer(self._input, mask=mask)

        with self.subTest(name='shuffled_full_mask_dense_shape'):
            expected_shape = self._unmasked_output_twolayer.shape
            self.assertSequenceEqual(masked_output.shape, expected_shape)
예제 #7
0
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty(self):
        """Checks the shuffled mask generated with 0% sparsity.

        The layer must be present with a kernel mask of all ones and a bias
        mask of None, and applying the mask must reproduce both the values
        and the shape of the stored unmasked output.
        """
        layer_key = 'MaskedModule_0'
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model, self._rng, 0.0)

        with self.subTest(name='shuffled_empty_mask'):
            self.assertIn(layer_key, mask)

        with self.subTest(name='shuffled_empty_mask_values'):
            kernel_mask = mask[layer_key]['kernel']
            self.assertTrue((kernel_mask == 1).all())

        with self.subTest(name='shuffled_empty_mask_not_masked_values'):
            bias_mask = mask[layer_key]['bias']
            self.assertIsNone(bias_mask)

        masked_output = self._masked_model(self._input, mask=mask)

        with self.subTest(name='shuffled_empty_dense_values'):
            elementwise_close = jnp.isclose(masked_output,
                                            self._unmasked_output)
            self.assertTrue(elementwise_close.all())

        with self.subTest(name='shuffled_empty_mask_dense_shape'):
            expected_shape = self._unmasked_output.shape
            self.assertSequenceEqual(masked_output.shape, expected_shape)