def testGroupSpecificBlockSparsity(self):
  """Checks weight sparsity when groups carry their own targets and blocks.

  group1 is configured with sparsity 0.6 and 2x2 blocks; group2 with
  sparsity 0.75 and 2x4 blocks (see the hparam string below). After 110
  mask-update steps the per-variable sparsities are compared against fixed
  expected values.
  """
  hparam_spec = ",".join([
      "begin_pruning_step=1",
      "pruning_frequency=1",
      "end_pruning_step=100",
      "target_sparsity=0.5",
      "group_sparsity_map=[group1:0.6,group2:0.75]",
      "group_block_dims_map=[group1:2x2,group2:2x4]",
      "threshold_decay=0.0",
      "group_pruning=True",
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(hparam_spec)

  def _expanded_ramp(stop, block_dim):
    # A 1x100 linear ramp from 1.0 to `stop`, expanded by `block_dim`.
    ramp = tf.reshape(tf.linspace(1.0, stop, 100), [1, 100])
    return pruning_utils.expand_tensor(ramp, block_dim)

  ramp_2x2 = _expanded_ramp(100.0, [2, 2])
  ramp_2x4 = _expanded_ramp(100.0, [2, 4])
  steep_ramp_2x4 = _expanded_ramp(200.0, [2, 4])

  # (variable scope, initial value, variable name, pruning group); the order
  # here fixes the order of the values returned by get_weight_sparsity().
  layer_specs = (
      ("layer1", ramp_2x2, "weights", "group1"),
      ("layer2", ramp_2x4, "weights", "group2"),
      ("layer3", ramp_2x4, "kernel", "group2"),
      ("layer4", steep_ramp_2x4, "kernel", "group2"),
  )
  for scope_name, initial_value, var_name, group in layer_specs:
    with tf.variable_scope(scope_name):
      weight = tf.Variable(initial_value, name=var_name)
      pruning.apply_mask_with_group(weight, group_name=group)

  pruner = pruning.Pruning(pruning_hparams)
  mask_update_op = pruner.conditional_mask_update_op()
  step_op = tf.assign_add(self.global_step, 1)

  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    for _ in range(110):
      session.run(mask_update_op)
      session.run(step_op)
    observed = session.run(pruning.get_weight_sparsity())
    self.assertAllClose(observed, [0.6, 0.9, 0.9, 0.45])
def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
  """Asserts expand_tensor matches a Kronecker product with a ones block.

  Args:
    tensor: The tensor to expand.
    block_dim: Block shape passed to both pruning_utils.expand_tensor and
      tf.ones.
  """
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    expanded = pruning_utils.expand_tensor(tensor, block_dim)
    ones_block = tf.ones(block_dim)
    kronecker = pruning_utils.kronecker_product(tensor, ones_block)
    actual, expected = session.run([expanded, kronecker])
    self.assertAllEqual(actual, expected)
def _maybe_update_block_mask(self, weights, threshold, gradients=None):
  """Performs block-granular masking of the weights.

  Block pruning occurs only if the block dims are not [1, 1] and if the
  weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
  pruning occurs via self._update_mask.

  Args:
    weights: The weight tensor that needs to be masked.
    threshold: The current threshold value. The function will compute a new
      threshold and return the exponential moving average using the current
      value of threshold.
    gradients: Optional gradient tensor used for salience calculation.
      Presumably the same shape as `weights` (TODO confirm with callers).

  Returns:
    new_threshold: The new value of the threshold based on weights, and
      sparsity at the current global_step.
    new_mask: A tensor with the same shape as weights containing 0 or 1 to
      indicate which of the values in weights falls below the threshold.

  Raises:
    ValueError: if a gradient-based prune_option is configured but no
      gradients are supplied, or if the block pooling function is not AVG
      or MAX.
  """
  # Copy so resolving -1 entries below never mutates whatever list
  # _get_block_dims hands back (it may be shared configuration state).
  block_dims = list(self._get_block_dims(weights.op.name))
  squeezed_weights = tf.squeeze(weights)
  if squeezed_weights.get_shape().ndims != 2 or block_dims == [1, 1]:
    return self._update_mask(weights, threshold, gradients)

  if (self._spec.prune_option in ('first_order_gradient',
                                  'second_order_gradient') and
      gradients is None):
    raise ValueError(
        'Gradient based pruning implementation for block sparsity is not supported.'
    )

  # A block dim of -1 is replaced by the full size of that weight dimension.
  for axis in range(2):
    if block_dims[axis] == -1:
      block_dims[axis] = squeezed_weights.get_shape()[axis]

  if self._block_pooling_function not in ['AVG', 'MAX']:
    raise ValueError(
        'Unknown pooling function for block sparsity: %s' %
        self._block_pooling_function)

  with tf.name_scope(weights.op.name + '_pruning_ops'):
    abs_weights = tf.abs(squeezed_weights)
    abs_gradients = None
    if gradients is not None:
      abs_gradients = tf.abs(tf.squeeze(gradients))

    pool_window = block_dims
    pool_fn = pruning_utils.factorized_pool
    squeeze_axis = None
    if not self._spec.use_tpu:
      # tf.nn.pool works on [batch, height, width, channels]; view the 2-D
      # tensors as a single-batch, single-channel image.
      pool_fn = tf.nn.pool
      abs_weights = tf.reshape(abs_weights, [
          1, abs_weights.get_shape()[0], abs_weights.get_shape()[1], 1
      ])
      if abs_gradients is not None:
        # Use the *squeezed* gradient shape, mirroring the weights above;
        # the raw `gradients` shape may still carry singleton dims.
        abs_gradients = tf.reshape(abs_gradients, [
            1, abs_gradients.get_shape()[0], abs_gradients.get_shape()[1], 1
        ])
      squeeze_axis = [0, 3]

    def _pool_blocks(tensor, name):
      # Non-overlapping block pooling: window and stride are both the block.
      return pool_fn(
          tensor,
          window_shape=pool_window,
          pooling_type=self._block_pooling_function,
          strides=pool_window,
          padding='SAME',
          name=name)

    pooled_weights = _pool_blocks(abs_weights, weights.op.name + '_pooled')
    pooled_gradients = None
    if abs_gradients is not None:
      pooled_gradients = _pool_blocks(abs_gradients,
                                      gradients.op.name + '_pooled')

    # Drop the batch/channel axes added for tf.nn.pool (if any).
    if pooled_weights.get_shape().ndims != 2:
      pooled_weights = tf.squeeze(pooled_weights, axis=squeeze_axis)
    if (pooled_gradients is not None and
        pooled_gradients.get_shape().ndims != 2):
      pooled_gradients = tf.squeeze(pooled_gradients, axis=squeeze_axis)

    smoothed_threshold, new_mask = self._update_mask(
        pooled_weights, threshold, pooled_gradients)

    # Expand the per-block mask back to element granularity, then crop any
    # extra rows/cols introduced by 'SAME' padding during pooling.
    updated_mask = pruning_utils.expand_tensor(new_mask, block_dims)
    sliced_mask = tf.slice(updated_mask, [0, 0], [
        squeezed_weights.get_shape()[0],
        squeezed_weights.get_shape()[1]
    ])
    return smoothed_threshold, tf.reshape(sliced_mask, tf.shape(weights))