Example #1
    def test_propagate_masks_ablated_neurons_mixed_conv_dense_layers(self):
        """Tests mask propagation on a two-layer convolutional/dense model."""
        mask = {
            'MaskedModule_0': {
                'kernel':
                jnp.zeros(
                    self._masked_mixed_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
            'MaskedModule_1': {
                'kernel':
                jnp.ones(
                    self._masked_mixed_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
        }

        with self.assertRaisesRegex(
                ValueError,
                'propagate_masks requires knowledge of the spatial '
                'dimensions of the previous layer. Use a functionally equivalent '
                'conv. layer in place of a dense layer in a model with a mixed '
                'conv/dense setting.'):
            masked.propagate_masks(mask)
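The error message also names the fix: replace the dense head with a convolution that spans the previous layer's full spatial extent, so that propagate_masks can track spatial dimensions through the whole model. A minimal sketch of such a functionally equivalent layer, assuming flax.linen and a hypothetical ConvAsDense module (the tests themselves use an older Flax API):

import flax.linen as nn

class ConvAsDense(nn.Module):
    """A conv layer functionally equivalent to a Dense head.

    A VALID convolution whose kernel covers the full (H, W) feature map
    reduces it to shape (batch, 1, 1, features), which matches a Dense
    layer applied to the flattened activations.
    """
    features: int
    spatial_shape: tuple  # (H, W) of the previous conv layer's output

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.features, kernel_size=self.spatial_shape,
                    padding='VALID')(x)
        return x.reshape((x.shape[0], -1))  # (batch, features)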
Example #2
    def test_propagate_masks_ablated_neurons_two_conv_layers(self):
        """Tests mask propagation on a two-layer convolutional model."""
        mask = {
            'MaskedModule_0': {
                'kernel':
                jnp.zeros(
                    self._masked_conv_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
            'MaskedModule_1': {
                'kernel':
                jnp.ones(
                    self._masked_conv_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape),
                'bias':
                None,
            },
        }

        refined_mask = masked.propagate_masks(mask)

        with self.subTest(name='layer_1'):
            self.assertTrue(
                (refined_mask['MaskedModule_0']['kernel'] == 0).all())

        # Since layer 1 is all zero, layer 2 is also effectively zero.
        with self.subTest(name='layer_2'):
            self.assertTrue(
                (refined_mask['MaskedModule_1']['kernel'] == 0).all())
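The layer-2 assertion follows from the rule this test exercises: a neuron whose entire fan-in is masked to zero is effectively ablated, so its fan-out weights in the next layer carry no signal and can be zeroed as well. A minimal dense-layer sketch of that rule, assuming kernel masks of shape (fan_in, fan_out); propagate_dense_pair is a hypothetical name, not the library's API:

import jax.numpy as jnp

def propagate_dense_pair(mask_prev, mask_next):
    # A neuron of the previous layer survives only if at least one of its
    # fan-in weights is unmasked.
    alive = jnp.any(mask_prev != 0, axis=0)  # shape: (fan_out of layer 1,)
    # Ablated neurons output nothing, so zero their fan-out rows in the
    # next layer's kernel mask.
    return mask_next * alive[:, None].astype(mask_next.dtype)

Applied to the masks above (all-zeros layer 1, all-ones layer 2), this zeroes layer 2 entirely, which is exactly what the two subtests assert.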
Example #3
    def test_propagate_masks_ablated_neurons_two_layers_nonmasked(self):
        """Tests mask propagation where previous layer is not masked."""
        mask = {
            'Dense_0': {
                'kernel': None,
                'bias': None,
            },
            'MaskedModule_1': {
                'kernel':
                jax.random.normal(
                    self._rng,
                    self._masked_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape,
                    dtype=jnp.float32),
                'bias':
                None,
            },
        }

        refined_mask = masked.propagate_masks(mask)

        with self.subTest(name='layer_1'):
            self.assertIsNone(refined_mask['Dense_0']['kernel'])

        # Layer 1 is unmasked, so propagation cannot ablate anything upstream
        # of layer 2; its mask should come back unchanged.
        with self.subTest(name='layer_2'):
            self.assertTrue((mask['MaskedModule_1']['kernel'] ==
                             refined_mask['MaskedModule_1']['kernel']).all())
Example #4
    def test_propagate_masks_ablated_neurons_one_conv_layer(self):
        """Tests mask propagation on a single layer model."""
        mask = {
            'MaskedModule_0': {
                'kernel':
                jax.random.normal(
                    self._rng,
                    self._masked_conv_model.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape,
                    dtype=jnp.float32),
                'bias':
                None,
            },
        }

        refined_mask = masked.propagate_masks(mask)

        # Since this is a single layer, should not affect mask at all.
        self.assertTrue((mask['MaskedModule_0']['kernel'] ==
                         refined_mask['MaskedModule_0']['kernel']).all())
Example #5
    for label, value in mask_stats.items():
      try:
        summary_writer.scalar(f'mask/{label}', value, 0)
      # This is needed because permutations (long int) can't be cast to float32.
      except (OverflowError, ValueError):
        summary_writer.text(f'mask/{label}', str(value), 0)
        logging.error('Could not write mask/%s to tensorflow summary as float32'
                      ', writing as string instead.', label)

    if FLAGS.dump_json:
      mask_stats['permutations'] = str(mask_stats['permutations'])
      utils.dump_dict_json(
          mask_stats, path.join(experiment_dir, 'mask_stats.json'))

  mask = masked.propagate_masks(mask)

  if jax.host_id() == 0:
    mask_stats = symmetry.get_mask_stats(mask)
    logging.info('Propagated mask stats: %s', str(mask_stats))

    for label, value in mask_stats.items():
      try:
        summary_writer.scalar(f'propagated_mask/{label}', value, 0)
      # This is needed because permutations (long int) can't be cast to float32.
      except (OverflowError, ValueError):
        summary_writer.text(f'propagated_mask/{label}', str(value), 0)
        logging.error('Could not write propagated_mask/%s to tensorflow summary'
                      ' as float32, writing as string instead.', label)
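
As a sanity check on the comment above: permutation counts grow factorially, and once they exceed the double-precision range (about 1.8e308) even the initial int-to-float conversion fails, e.g.:

import math

float(math.factorial(256))  # OverflowError: int too large to convert to float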