def test_propagate_masks_ablated_neurons_mixed_conv_dense_layers(self):
  """Tests mask propagation on a two-layer convolutional/dense model."""
  params = self._masked_mixed_model_twolayer.params
  # First layer fully ablated (all zeros), second layer fully unmasked.
  layer_mask = {}
  for module_name, fill in (('MaskedModule_0', jnp.zeros),
                            ('MaskedModule_1', jnp.ones)):
    kernel_shape = params[module_name]['unmasked']['kernel'].shape
    layer_mask[module_name] = {'kernel': fill(kernel_shape), 'bias': None}
  # Mixing a dense layer after a conv layer loses the spatial dims needed
  # for propagation, so this must raise.
  expected_message = (
      'propagate_masks requires knowledge of the spatial '
      'dimensions of the previous layer. Use a functionally equivalent '
      'conv. layer in place of a dense layer in a model with a mixed '
      'conv/dense setting.')
  with self.assertRaisesRegex(ValueError, expected_message):
    masked.propagate_masks(layer_mask)
def test_propagate_masks_ablated_neurons_two_conv_layers(self):
  """Tests mask propagation on a two-layer convolutional model."""
  params = self._masked_conv_model_twolayer.params
  # Layer 0 fully ablated, layer 1 fully unmasked.
  layer_mask = {
      module_name: {
          'kernel': fill(params[module_name]['unmasked']['kernel'].shape),
          'bias': None,
      }
      for module_name, fill in (('MaskedModule_0', jnp.zeros),
                                ('MaskedModule_1', jnp.ones))
  }
  refined_mask = masked.propagate_masks(layer_mask)
  with self.subTest(name='layer_1'):
    self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())
  # Layer 1 is entirely zeroed, so layer 2's inputs — and therefore its
  # effective mask — must be zero as well after propagation.
  with self.subTest(name='layer_2'):
    self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())
def test_propagate_masks_ablated_neurons_two_layers_nonmasked(self):
  """Tests mask propagation where previous layer is not masked."""
  kernel_shape = (
      self._masked_model_twolayer.params['MaskedModule_1']
      ['unmasked']['kernel'].shape)
  random_kernel = jax.random.normal(
      self._rng, kernel_shape, dtype=jnp.float32)
  layer_mask = {
      'Dense_0': {'kernel': None, 'bias': None},
      'MaskedModule_1': {'kernel': random_kernel, 'bias': None},
  }
  refined_mask = masked.propagate_masks(layer_mask)
  # The unmasked layer's entries should remain None.
  with self.subTest(name='layer_1'):
    self.assertIsNone(refined_mask['Dense_0']['kernel'])
  with self.subTest(name='layer_2'):
    # With only one masked layer, propagation has nothing to refine and
    # should leave that layer's mask untouched.
    self.assertTrue(
        (layer_mask['MaskedModule_1']['kernel'] ==
         refined_mask['MaskedModule_1']['kernel']).all())
def test_propagate_masks_ablated_neurons_one_conv_layer(self):
  """Tests mask propagation on a single layer model."""
  kernel_shape = (
      self._masked_conv_model.params['MaskedModule_0']
      ['unmasked']['kernel'].shape)
  layer_mask = {
      'MaskedModule_0': {
          'kernel': jax.random.normal(
              self._rng, kernel_shape, dtype=jnp.float32),
          'bias': None,
      },
  }
  refined_mask = masked.propagate_masks(layer_mask)
  # A single-layer model has no downstream layer to propagate into, so the
  # mask must come back unchanged.
  self.assertTrue(
      (layer_mask['MaskedModule_0']['kernel'] ==
       refined_mask['MaskedModule_0']['kernel']).all())
# Write per-statistic summaries for the raw mask. Scalars go straight to the
# summary writer; values that cannot be cast to float32 fall back to a text
# summary.
for label, value in mask_stats.items():
  try:
    summary_writer.scalar(f'mask/{label}', value, 0)
  # This is needed because permutations (long int) can't be cast to float32.
  except (OverflowError, ValueError):
    summary_writer.text(f'mask/{label}', str(value), 0)
    logging.error('Could not write mask/%s to tensorflow summary as float32'
                  ', writing as string instead.', label)

if FLAGS.dump_json:
  # Stringify the (potentially huge) permutation count so JSON serialization
  # cannot lose precision.
  mask_stats['permutations'] = str(mask_stats['permutations'])
  utils.dump_dict_json(
      mask_stats, path.join(experiment_dir, 'mask_stats.json'))

# Refine the mask by propagating ablated neurons through the network, then
# log/record statistics of the propagated mask (on the primary host only).
mask = masked.propagate_masks(mask)

if jax.host_id() == 0:
  mask_stats = symmetry.get_mask_stats(mask)
  logging.info('Propagated mask stats: %s', str(mask_stats))

  for label, value in mask_stats.items():
    try:
      summary_writer.scalar(f'propagated_mask/{label}', value, 0)
    # This is needed because permutations (long int) can't be cast to float32.
    except (OverflowError, ValueError):
      summary_writer.text(f'propagated_mask/{label}', str(value), 0)
      # Fixed copy-paste bug: this loop writes propagated_mask/* summaries,
      # but the error message previously said 'mask/%s'.
      logging.error(
          'Could not write propagated_mask/%s to tensorflow summary as '
          'float32, writing as string instead.', label)