def testConditionalMaskUpdate(self):
    """Steps a table-driven sparsity schedule and counts surviving weights.

    The schedule used here always reports that an update is allowed and looks
    the target sparsity up by step index. After each of the ten steps the
    number of non-zero weight entries is recorded and the whole sequence is
    compared against the expected progression.
    """
    weight = K.variable(np.linspace(1.0, 100.0, 100), name="weights")
    mask = K.ones(weight.get_shape())
    threshold = K.zeros([])

    def linear_sparsity(step):
        # One sparsity entry per step; the flag returned alongside is always True.
        sparsity_by_step = ops.convert_to_tensor(
            [0.0, 0.1, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.5, 0.5])
        always_update = ops.convert_to_tensor(True)
        return always_update, sparsity_by_step[step]

    # Set up pruning.
    p = pruning_impl.Pruning(
        pruning_vars=[(weight, mask, threshold)],
        training_step_fn=self.training_step_fn,
        pruning_schedule=linear_sparsity,
        block_size=self.block_size,
        block_pooling_type=self.block_pooling_type)

    # Work performed on every step, in order: refresh the mask, apply it to
    # the weights, then advance the global step.
    step_actions = (
        p.conditional_mask_update,
        p.weight_mask_op,
        lambda: state_ops.assign_add(self.global_step, 1),
    )

    non_zero_count = []
    for _ in range(10):
        for action in step_actions:
            if context.executing_eagerly():
                action()
            else:
                # The session is fetched before the op is built, once per op.
                K.get_session().run(action())
        non_zero_count.append(np.count_nonzero(K.get_value(weight)))

    # Weights pruned at steps 1, 3, 5.
    expected_non_zero_count = [100, 90, 90, 70, 70, 50, 50, 50, 50, 50]
    self.assertAllEqual(expected_non_zero_count, non_zero_count)
def testExtremelySparseMask(self):
    """Checks that one mask entry survives a near-total sparsity request.

    The mask starts with all 100 entries set. After a single conditional mask
    update driven by `ConstantSparsity(0.9999, 0, 100, 1)`, exactly one
    non-zero entry is expected to remain.
    """
    weight = tf.Variable(np.linspace(1.0, 100.0, 100), name="weights")
    weight_dtype = weight.dtype.base_dtype

    # Mask and threshold share the dtype of the weights.
    initial_mask = tf.ones(weight.get_shape(), dtype=weight_dtype)
    mask = tf.Variable(initial_mask, name="mask", dtype=weight_dtype)
    initial_threshold = tf.zeros([], dtype=weight_dtype)
    threshold = tf.Variable(
        initial_threshold, name="threshold", dtype=weight_dtype)
    self.initialize()

    extreme_sparsity = pruning_schedule.ConstantSparsity(0.9999, 0, 100, 1)
    p = pruning_impl.Pruning(
        pruning_vars=[(weight, mask, threshold)],
        training_step_fn=self.training_step_fn,
        pruning_schedule=extreme_sparsity,
        block_size=self.block_size,
        block_pooling_type=self.block_pooling_type)

    kept_before = np.count_nonzero(K.get_value(mask))
    self.assertAllEqual(kept_before, 100)

    if not tf.executing_eagerly():
        K.get_session().run(p.conditional_mask_update())
    else:
        p.conditional_mask_update()

    # We should always have a single connection remaining.
    kept_after = np.count_nonzero(K.get_value(mask))
    self.assertAllEqual(kept_after, 1)
def build(self, input_shape):
    """Builds the wrapper and creates the per-weight pruning state.

    For every weight returned by the wrapped layer's `get_prunable_weights()`
    a non-trainable mask (same shape as the weight, initialised to ones) and
    a non-trainable scalar threshold (initialised to zero) are added. A
    scalar `pruning_step` counter starting at -1 is added as well, and a
    `pruning_impl.Pruning` object is created over all of these.

    Args:
      input_shape: Passed through unchanged to the parent class's `build`.
    """
    super(PruneLowMagnitude, self).build(input_shape)

    self.prunable_weights = self.layer.get_prunable_weights()

    # For each of the prunable weights, add mask and threshold variables.
    # `add_weight` is used directly: `add_variable` is only a deprecated
    # alias that forwards to it.
    pruning_vars = []
    for weight in self.prunable_weights:
        mask = self.add_weight(
            'mask',
            shape=weight.shape,
            initializer=tf.keras.initializers.get('ones'),
            dtype=weight.dtype,
            trainable=False,
            aggregation=tf.VariableAggregation.MEAN)
        threshold = self.add_weight(
            'threshold',
            shape=[],
            initializer=tf.keras.initializers.get('zeros'),
            dtype=weight.dtype,
            trainable=False,
            aggregation=tf.VariableAggregation.MEAN)
        pruning_vars.append((weight, mask, threshold))
    self.pruning_vars = pruning_vars

    # Add a scalar tracking the number of updates to the wrapped layer.
    self.pruning_step = self.add_weight(
        'pruning_step',
        shape=[],
        initializer=tf.keras.initializers.Constant(-1),
        dtype=tf.int64,
        trainable=False)

    def training_step_fn():
        # Hands the step counter variable to the pruning object on demand.
        return self.pruning_step

    # Create a pruning object.
    self.pruning_obj = pruning_impl.Pruning(
        training_step_fn=training_step_fn,
        pruning_vars=self.pruning_vars,
        pruning_schedule=self.pruning_schedule,
        block_size=self.block_size,
        sparsity_m_by_n=self.sparsity_m_by_n,
        block_pooling_type=self.block_pooling_type)
def _blockMasking(self, block_size, block_pooling_type, weight, expected_mask):
    """Asserts that block masking of `weight` produces `expected_mask`.

    Args:
      block_size: Forwarded to `pruning_impl.Pruning` as `block_size`.
      block_pooling_type: Forwarded to `pruning_impl.Pruning` as
        `block_pooling_type`.
      weight: Tensor whose block mask is computed; the resulting mask must
        have the same shape.
      expected_mask: Values the computed mask is compared against.
    """
    initial_mask = K.ones(weight.get_shape())
    initial_threshold = K.zeros([])

    # Set up pruning.
    pruner = pruning_impl.Pruning(
        pruning_vars=[(weight, initial_mask, initial_threshold)],
        training_step_fn=self.training_step_fn,
        pruning_schedule=self.constant_sparsity,
        block_size=block_size,
        block_pooling_type=block_pooling_type)

    # Only the second element of the returned pair (the mask) is inspected.
    _, block_mask = pruner._maybe_update_block_mask(weight)

    # Check if the mask is the same size as the weights.
    self.assertAllEqual(block_mask.get_shape(), weight.get_shape())

    self.assertAllEqual(K.get_value(block_mask), expected_mask)
def testConstructsMaskAndThresholdCorrectly(self):
    """Checks threshold and mask from `_update_mask` on the values 1..10.

    The schedule reports a sparsity slightly above 0.2; the expected result
    is a threshold of 3 and a mask whose first two entries are zero.
    """

    def step_fn():
        return 0

    def schedule(x):
        # Sparsity math often returns values with small tolerances.
        return (True, 0.200000018)

    # Arguments are passed positionally, in the same order as before.
    p = pruning_impl.Pruning(step_fn, None, schedule, (1, 1), None)

    # Input matrix is [ 1.0, 2.0, ..., 8.0, 9.0, 10.0 ].
    weights = np.arange(1, 11)
    threshold, mask = p._update_mask(weights)

    self.assertEqual(3, K.get_value(threshold))

    # Expected matrix is [ 0.0, 0.0, 1.0, 1.0 ... 1.0 ].
    expected_mask = np.concatenate((np.zeros(2), np.ones(8)))
    self.assertAllEqual(expected_mask, K.get_value(mask))
def _sparsity_m_by_n_masking(self, weight, m_by_n=(2, 4)):
    """Returns the mask computed for `weight` with `sparsity_m_by_n` set.

    Args:
      weight: Tensor for which the mask is computed.
      m_by_n: Value forwarded to `pruning_impl.Pruning` as `sparsity_m_by_n`.

    Returns:
      The second element of the pair returned by `_maybe_update_block_mask`.
    """
    all_ones = tf.ones(weight.get_shape())
    mask = tf.Variable(all_ones, name="mask")
    threshold = tf.Variable(1, name="threshold")
    self.initialize()

    # Set up pruning; block size and pooling type are fixed for this helper.
    pruner = pruning_impl.Pruning(
        pruning_vars=[(weight, mask, threshold)],
        training_step_fn=self.training_step_fn,
        pruning_schedule=self.constant_sparsity,
        block_size=(1, 1),
        block_pooling_type="AVG",
        sparsity_m_by_n=m_by_n,
    )

    _, m_by_n_mask = pruner._maybe_update_block_mask(weight)
    return m_by_n_mask
def testUpdateSingleMask(self):
    """Checks one mask update under `self.constant_sparsity`.

    The mask starts with all 100 entries set, and 50 non-zero entries are
    expected after a single conditional mask update.
    """
    weight = K.variable(np.linspace(1.0, 100.0, 100), name="weights")
    mask = K.ones(weight.get_shape())
    threshold = K.zeros([])

    p = pruning_impl.Pruning(
        pruning_vars=[(weight, mask, threshold)],
        training_step_fn=self.training_step_fn,
        pruning_schedule=self.constant_sparsity,
        block_size=self.block_size,
        block_pooling_type=self.block_pooling_type)

    kept_before = np.count_nonzero(K.get_value(mask))
    self.assertAllEqual(kept_before, 100)

    if not context.executing_eagerly():
        # Graph mode: the update op has to be run in the backend session.
        K.get_session().run(p.conditional_mask_update())
    else:
        p.conditional_mask_update()

    kept_after = np.count_nonzero(K.get_value(mask))
    self.assertAllEqual(kept_after, 50)