def testInitWithVariableReuse(self):
  """Two Pruning objects built on the same sparsity variable read the same value."""
  with self.cached_session():
    shared_sparsity = self.sparsity
    first = pruning.Pruning(spec=self.pruning_hparams, sparsity=shared_sparsity)
    second = pruning.Pruning(
        spec=self.pruning_hparams, sparsity=shared_sparsity)
    tf.global_variables_initializer().run()
    first_value = first._sparsity.eval()
    self.assertAlmostEqual(first_value, 0.5)
    self.assertEqual(first._sparsity.eval(), second._sparsity.eval())
def testConditionalMaskUpdate(self):
  """Masks are only updated on scheduled pruning steps.

  Sparsity is stepped through tf.linspace(0.0, 0.9, 10) while the global
  step advances by one per iteration. After each iteration the number of
  non-zero masked weights is recorded and compared against the expected
  schedule for pruning_frequency=2 within [begin, end] = [1, 6].
  """
  param_list = [
      "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6",
      "nbins=100"
  ]
  test_spec = ",".join(param_list)
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
  weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
  masked_weights = pruning.apply_mask(weights)
  sparsity = tf.Variable(0.00, name="sparsity")
  # Set up pruning
  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
  p._spec.threshold_decay = 0.0
  mask_update_op = p.conditional_mask_update_op()
  num_steps = 10
  sparsity_val = tf.linspace(0.0, 0.9, num_steps)
  # Build the per-step assign ops once, up front. Creating them inside the
  # session loop would add new slice/assign ops to the graph every iteration.
  assign_sparsity_ops = [
      tf.assign(sparsity, sparsity_val[i]) for i in range(num_steps)
  ]
  increment_global_step = tf.assign_add(self.global_step, 1)
  non_zero_count = []
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    for assign_sparsity_op in assign_sparsity_ops:
      session.run(assign_sparsity_op)
      session.run(mask_update_op)
      session.run(increment_global_step)
      non_zero_count.append(np.count_nonzero(masked_weights.eval()))
  # Weights pruned at steps 0, 2, 4, and 6
  expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
  self.assertAllEqual(expected_non_zero_count, non_zero_count)
def __init__(self, scope='default_scope', spec=None, global_step=None):
  """Initializes the base op and attaches a weight-pruning helper.

  Args:
    scope: forwarded unchanged to the base-class constructor.
    spec: forwarded unchanged to the base-class constructor.
    global_step: forwarded unchanged to the base-class constructor.
  """
  super(KMeansPruningCompressionOp, self).__init__(scope, spec, global_step)
  # Deep-copy so that forcing prune_option below does not modify self._spec.
  weight_pruning_spec = copy.deepcopy(self._spec)
  weight_pruning_spec.prune_option = 'weight'
  self.pruning_obj = pruning.Pruning(
      weight_pruning_spec, global_step=self._global_step)
def get_matrix_compression_object(
    hparams,  # pylint:disable=invalid-name
    global_step=None,
    sparsity=None):
  """Returns a pruning/compression object selected by hparams.prune_option.

  Args:
    hparams: Pruning spec as defined in pruning.py.
    global_step: A tensorflow variable that is used for scheduling
      pruning/compression. When None, the training global step cast to
      tf.int32 is used, or 0 if there is no training global step.
    sparsity: A tensorflow scalar variable storing the sparsity. Only passed
      on when a Pruning object is returned.

  Returns:
    A pruning.Pruning object when hparams.prune_option is one of 'weight',
    'first_order_gradient' or 'second_order_gradient'; otherwise the
    compression_lib.compression_op.ApplyCompression object returned by
    compression_wrapper.get_apply_compression.
  """
  if global_step is None:
    train_global_step = tf.train.get_global_step()
    global_step = (0 if train_global_step is None else
                   tf.cast(train_global_step, tf.int32))

  pruning_options = ('weight', 'first_order_gradient', 'second_order_gradient')
  if hparams.prune_option not in pruning_options:
    return compression_wrapper.get_apply_compression(
        hparams, global_step=global_step)
  return pruning.Pruning(hparams, global_step, sparsity)
def testWeightSparsityTiebreaker(self):
  """All-equal weights are still pruned to the 0.5 target sparsity."""
  test_spec = ",".join([
      "begin_pruning_step=1",
      "pruning_frequency=1",
      "end_pruning_step=100",
      "target_sparsity=0.5",
      "threshold_decay=0.0",
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
  with tf.variable_scope("layer1"):
    all_ones = tf.Variable(np.ones([100], dtype=np.float32), name="weights")
    pruning.apply_mask(all_ones)

  pruner = pruning.Pruning(pruning_hparams)
  mask_update_op = pruner.conditional_mask_update_op()
  step_op = tf.assign_add(self.global_step, 1)
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    for _ in range(110):
      session.run(mask_update_op)
      session.run(step_op)
    self.assertAllClose(session.run(pruning.get_weight_sparsity()), [0.5])
def testInit(self):
  """The fixture hyperparameters are visible through Pruning._spec."""
  spec = pruning.Pruning(self.pruning_hparams)._spec
  self.assertEqual(spec.name, "test")
  self.assertAlmostEqual(spec.threshold_decay, 0.9)
  self.assertEqual(spec.pruning_frequency, 10)
  self.assertEqual(spec.sparsity_function_end_step, 100)
  self.assertAlmostEqual(spec.target_sparsity, 0.9)
def testInitWithExternalSparsity(self):
  """Pruning._sparsity evaluates to the caller-supplied variable's value (0.5)."""
  with self.cached_session():
    pruner = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
    tf.global_variables_initializer().run()
    self.assertAlmostEqual(pruner._sparsity.eval(), 0.5)
def testFirstOrderGradientCalculation(self):
  """Checks the first-order-gradient bookkeeping after one update.

  After gradient_update_op and old_weight_update_op each run once, the test
  checks two things:
  - The stored gradient equals gradient_decay_rate (0.5) times the
    L2-normalized initial weights.
  - The stored old weights equal the current weights.
  """
  param_list = [
      "prune_option=first_order_gradient",
      "gradient_decay_rate=0.5",
  ]
  test_spec = ",".join(param_list)
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
  tf.logging.info(pruning_hparams)
  w = tf.Variable(tf.linspace(1.0, 10.0, 10), name="weights")
  _ = pruning.apply_mask(w, prune_option="first_order_gradient")
  p = pruning.Pruning(pruning_hparams)
  old_weight_update_op = p.old_weight_update_op()
  gradient_update_op = p.gradient_update_op()
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    session.run(gradient_update_op)
    session.run(old_weight_update_op)
    weight = pruning.get_weights()[0]
    old_weight = pruning.get_old_weights()[0]
    gradient = pruning.get_gradients()[0]
    expected_gradient = tf.math.scalar_mul(
        0.5, tf.nn.l2_normalize(tf.linspace(1.0, 10.0, 10)))
    # The expectation is rebuilt here with its own ops, which may differ from
    # how gradient_update_op computes the stored value. Compare float32
    # results within tolerance rather than bit-for-bit.
    self.assertAllClose(gradient.eval(), expected_gradient.eval())
    self.assertAllEqual(weight.eval(), old_weight.eval())
def setUp(self):
  """Creates the global step, sparsity, pruning objects and mocks for a test."""
  super(PruningSpeechUtilsTest, self).setUp()
  # Graph-level state the tests rely on.
  self.global_step = tf.train.get_or_create_global_step()
  self.sparsity = tf.Variable(0.5, name="sparsity")
  # Pruning and compression objects built from the class-level TEST_HPARAMS.
  self.pruning_hparams = pruning.get_pruning_hparams().parse(
      self.TEST_HPARAMS)
  self.pruning_obj = pruning.Pruning(
      self.pruning_hparams, global_step=self.global_step)
  self.compression_obj = pruning_interface.get_matrix_compression_object(
      self.pruning_hparams, global_step=self.global_step)

  def MockWeightParamsFn(shape, init=None, dtype=None):
    # Missing init defaults to a zero constant; missing dtype to float32.
    if init is None:
      init = MockWeightInit.Constant(0.0)
    return {
        "dtype": tf.float32 if dtype is None else dtype,
        "shape": shape,
        "init": init,
    }

  self.mock_weight_params_fn = MockWeightParamsFn
  self.mock_lstmobj = MockLSTMCell()
  self.wm_pc = np.zeros((2, 2))
def testWeightSpecificSparsity(self):
  """weight_sparsity_map entries override target_sparsity for matching weights."""
  test_spec = ",".join([
      "begin_pruning_step=1",
      "pruning_frequency=1",
      "end_pruning_step=100",
      "target_sparsity=0.5",
      "weight_sparsity_map=[layer1:0.6,layer2/weights:0.75,.*kernel:0.6]",
      "threshold_decay=0.0",
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
  layer_specs = (
      ("layer1", "weights"),
      ("layer2", "weights"),
      ("layer3", "kernel"),
  )
  for scope_name, var_name in layer_specs:
    with tf.variable_scope(scope_name):
      layer_weight = tf.Variable(tf.linspace(1.0, 100.0, 100), name=var_name)
      pruning.apply_mask(layer_weight)

  pruner = pruning.Pruning(pruning_hparams)
  mask_update_op = pruner.conditional_mask_update_op()
  step_op = tf.assign_add(self.global_step, 1)
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    for _ in range(110):
      session.run(mask_update_op)
      session.run(step_op)
    self.assertAllClose(
        session.run(pruning.get_weight_sparsity()), [0.6, 0.75, 0.6])
def testPerLayerBlockSparsity(self):
  """block_dims_map gives each layer its own pruning block shape."""
  test_spec = ",".join([
      "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
      "block_pooling_function=AVG",
      "threshold_decay=0.0",
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

  with tf.variable_scope("layer1"):
    layer1_weights = tf.Variable([[-0.1, 0.1], [-0.2, 0.2]], name="weights")
    pruning.apply_mask(layer1_weights)

  with tf.variable_scope("layer2"):
    layer2_weights = tf.Variable(
        [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]], name="weights")
    pruning.apply_mask(layer2_weights)

  sparsity = tf.Variable(0.5, name="sparsity")
  pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity)
  mask_update_op = pruner.mask_update_op()
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    session.run(mask_update_op)
    masks = pruning.get_masks()
    first_mask = session.run(masks[0])
    second_mask = session.run(masks[1])

    self.assertAllEqual(
        session.run(pruning.get_weight_sparsity()), [0.5, 0.5])
    self.assertAllEqual(first_mask, [[0.0, 0.0], [1., 1.]])
    self.assertAllEqual(second_mask, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
def _prune_model(self, session):
  """Builds a conditional mask-update op and runs it for 20 global steps.

  Stores the op on self.mask_update_op. Note that this runs
  tf.global_variables_initializer(), which (re)initializes every global
  variable before the loop starts.
  """
  spec = pruning.get_pruning_hparams().parse(self.pruning_spec)
  pruner = pruning.Pruning(spec, sparsity=self.sparsity)
  self.mask_update_op = pruner.conditional_mask_update_op()
  tf.global_variables_initializer().run()
  for _ in range(20):
    session.run(self.mask_update_op)
    session.run(self.increment_global_step)
def testUpdateSingleMask(self):
  """One mask update at sparsity 0.95 leaves 5 of the 100 weights non-zero."""
  with self.cached_session() as session:
    weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = tf.Variable(0.95, name="sparsity")
    pruner = pruning.Pruning(sparsity=sparsity)
    pruner._spec.threshold_decay = 0.0
    mask_update_op = pruner.mask_update_op()
    tf.global_variables_initializer().run()

    def _surviving_count():
      return np.count_nonzero(masked_weights.eval())

    self.assertAllEqual(_surviving_count(), 100)
    session.run(mask_update_op)
    self.assertAllEqual(_surviving_count(), 5)
def _GetMaskUpdateOp(self):
  """Returns op to update masks and threshold variables for model pruning.

  A tf.no_op() is returned when params.train.pruning_hparams_dict is empty
  or otherwise falsy. Otherwise a Pruning object is built from that dict,
  its summaries are added, and its conditional mask-update op is returned.
  """
  train_params = self.params.train
  default_op = tf.no_op()
  hparams_dict = train_params.pruning_hparams_dict
  if not hparams_dict:
    return default_op

  assert isinstance(hparams_dict, dict)
  pruning_hparams = pruning.get_pruning_hparams().override_from_dict(
      hparams_dict)
  pruning_obj = pruning.Pruning(pruning_hparams, global_step=self.global_step)
  pruning_obj.add_pruning_summaries()
  return pruning_obj.conditional_mask_update_op()
def _blockMasking(self, hparams, weights, expected_mask):
  """Asserts that block masking of `weights` at sparsity 0.5 yields `expected_mask`.

  Args:
    hparams: list of "name=value" strings, joined with commas and parsed into
      the pruning hparams.
    weights: the weight tensor to mask.
    expected_mask: expected mask value, compared with assertAllEqual.
  """
  threshold = tf.Variable(0.0, name="threshold")
  sparsity = tf.Variable(0.5, name="sparsity")
  pruning_hparams = pruning.get_pruning_hparams().parse(",".join(hparams))
  pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity)
  with self.cached_session():
    tf.global_variables_initializer().run()
    _, new_mask = pruner._maybe_update_block_mask(weights, threshold)
    # The mask must have the same shape as the weights it covers.
    self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
    self.assertAllEqual(new_mask.eval(), expected_mask)
def testGroupSpecificBlockSparsity(self):
  """Per-group sparsity and block shapes apply when group_pruning is enabled."""
  test_spec = ",".join([
      "begin_pruning_step=1",
      "pruning_frequency=1",
      "end_pruning_step=100",
      "target_sparsity=0.5",
      "group_sparsity_map=[group1:0.6,group2:0.75]",
      "group_block_dims_map=[group1:2x2,group2:2x4]",
      "threshold_decay=0.0",
      "group_pruning=True",
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

  def _stacked(end_value, block_dims):
    return pruning_utils.expand_tensor(
        tf.reshape(tf.linspace(1.0, end_value, 100), [1, 100]), block_dims)

  stacked_tensor_1 = _stacked(100.0, [2, 2])
  stacked_tensor_2 = _stacked(100.0, [2, 4])
  stacked_tensor_3 = _stacked(200.0, [2, 4])

  layer_specs = (
      ("layer1", "weights", stacked_tensor_1, "group1"),
      ("layer2", "weights", stacked_tensor_2, "group2"),
      ("layer3", "kernel", stacked_tensor_2, "group2"),
      ("layer4", "kernel", stacked_tensor_3, "group2"),
  )
  for scope_name, var_name, initial_value, group in layer_specs:
    with tf.variable_scope(scope_name):
      layer_weight = tf.Variable(initial_value, name=var_name)
      pruning.apply_mask_with_group(layer_weight, group_name=group)

  pruner = pruning.Pruning(pruning_hparams)
  mask_update_op = pruner.conditional_mask_update_op()
  step_op = tf.assign_add(self.global_step, 1)
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    for _ in range(110):
      session.run(mask_update_op)
      session.run(step_op)
    self.assertAllClose(
        session.run(pruning.get_weight_sparsity()), [0.6, 0.9, 0.9, 0.45])
def testPartitionedVariableMasking(self):
  """Mask updates work on a partitioned weight variable.

  A 100-element weight is split by tf.variable_axis_size_partitioner(40).
  Before any update all 100 masked weights are non-zero. After a single mask
  update at sparsity 0.5, with threshold_decay disabled, 50 remain non-zero.
  """
  partitioner = tf.variable_axis_size_partitioner(40)
  with self.cached_session() as session:
    with tf.variable_scope("", partitioner=partitioner):
      sparsity = tf.Variable(0.5, name="Sparsity")
      weights = tf.get_variable(
          "weights", initializer=tf.linspace(1.0, 100.0, 100))
      masked_weights = pruning.apply_mask(
          weights, scope=tf.get_variable_scope())
    p = pruning.Pruning(sparsity=sparsity)
    p._spec.threshold_decay = 0.0
    mask_update_op = p.mask_update_op()
    tf.global_variables_initializer().run()
    # Nothing is masked yet; check the pre-update state instead of
    # computing it and throwing it away.
    masked_weights_val = masked_weights.eval()
    self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
    session.run(mask_update_op)
    masked_weights_val = masked_weights.eval()
    self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)
def Setup(cls, pruning_hparams_dict, global_step):  # pylint:disable=invalid-name
  """Set up the pruning op with pruning hyperparameters and global step.

  Every call stores the given dict and global step on the class and builds a
  new class-level pruning object, replacing any object from an earlier call.

  Args:
    pruning_hparams_dict: a dict containing pruning hyperparameters;
    global_step: global step in TensorFlow.
  """
  # There is deliberately no early return when cls._pruning_obj already
  # exists: re-calling Setup rebuilds it from the new hyperparameters.
  assert pruning_hparams_dict is not None
  assert isinstance(pruning_hparams_dict, dict)
  cls._pruning_hparams_dict = pruning_hparams_dict
  cls._global_step = global_step
  cls._pruning_hparams = pruning.get_pruning_hparams().override_from_dict(
      pruning_hparams_dict)
  cls._pruning_obj = pruning.Pruning(
      spec=cls._pruning_hparams, global_step=global_step)
def testFirstOrderGradientBlockMasking(self):
  """Checks 2x2 AVG block masking under prune_option=first_order_gradient.

  The expected result keeps the bottom two rows and prunes the top two.
  """
  hparam_strings = [
      "prune_option=first_order_gradient",
      "gradient_decay_rate=0.5",
      "block_height=2",
      "block_width=2",
      "threshold_decay=0",
      "block_pooling_function=AVG",
  ]
  threshold = tf.Variable(0.0, name="threshold")
  sparsity = tf.Variable(0.5, name="sparsity")
  pruning_hparams = pruning.get_pruning_hparams().parse(
      ",".join(hparam_strings))

  weights_avg = tf.constant([
      [0.1, 0.1, 0.2, 0.2],
      [0.1, 0.1, 0.2, 0.2],
      [0.3, 0.3, 0.4, 0.4],
      [0.3, 0.3, 0.4, 0.4],
  ])
  expected_mask = [
      [0.0, 0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0, 0.0],
      [1., 1., 1., 1.],
      [1., 1., 1., 1.],
  ]

  w = tf.Variable(weights_avg, name="weights")
  pruning.apply_mask(w, prune_option="first_order_gradient")

  pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity)
  old_weight_update_op = pruner.old_weight_update_op()
  gradient_update_op = pruner.gradient_update_op()

  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    session.run(gradient_update_op)
    session.run(old_weight_update_op)

    weight = pruning.get_weights()[0]
    _ = pruning.get_old_weights()
    gradient = pruning.get_gradients()[0]

    _, new_mask = pruner._maybe_update_block_mask(weight, threshold, gradient)
    self.assertAllEqual(new_mask.get_shape(), weight.get_shape())
    self.assertAllEqual(new_mask.eval(), expected_mask)
def _sparsity_m_by_n_masking(self, weight, block_size=4, sparsity=0.5):
  """Builds an intra-block sparsity mask for `weight`.

  Args:
    weight: the weight tensor to mask.
    block_size: used as the `block_width` hyperparameter.
    sparsity: initial value of the sparsity variable handed to Pruning.

  Returns:
    The mask tensor returned by Pruning._maybe_update_block_mask.
  """
  block_width_param = "block_width=" + str(block_size)
  pruning_hparams = pruning.get_pruning_hparams().parse(",".join([
      "target_sparsity=0.5",
      "intra_block_sparsity=True",
      block_width_param,
  ]))
  sparsity_var = tf.Variable(sparsity, name="sparsity")
  pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity_var)
  mask_update_op = pruner.conditional_mask_update_op()
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    session.run(mask_update_op)
    # NOTE(review): block_size is passed in the position other callers use for
    # a threshold variable. Presumably the intra-block path ignores it;
    # confirm this is intended.
    _, new_mask = pruner._maybe_update_block_mask(weight, block_size)
  return new_mask
def train():
  """Train CIFAR-10 for a number of steps."""
  # The whole graph, including the pruning ops, is built in a fresh graph and
  # then run by a MonitoredTrainingSession until FLAGS.max_steps.
  with tf.Graph().as_default():
    global_step = contrib_framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(
        FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_obj.add_pruning_summaries()

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        # Incremented before every run, so the first run is step 0.
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        # Print a progress line every 10 steps.
        if self._step % 10 == 0:
          # NOTE(review): hard-coded; presumably meant to match the input
          # batch size used by cifar10.distorted_inputs() - confirm.
          num_examples_per_step = 128
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = (
              '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
              'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _LoggerHook()
        ],
        config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    ) as mon_sess:
      while not mon_sess.should_stop():
        # One training step, then a separate run of the (conditional) mask
        # update, so masks are refreshed after each parameter update.
        mon_sess.run(train_op)
        # Update the masks
        mon_sess.run(mask_update_op)