Example #1
  def testInitWithVariableReuse(self):
    with self.cached_session():
      p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
      p_copy = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      tf.global_variables_initializer().run()
      sparsity = p._sparsity.eval()
      self.assertAlmostEqual(sparsity, 0.5)
      self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval())
Example #2
  def testConditionalMaskUpdate(self):
    param_list = [
        "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6",
        "nbins=100"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = tf.Variable(0.00, name="sparsity")
    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    p._spec.threshold_decay = 0.0
    mask_update_op = p.conditional_mask_update_op()
    sparsity_val = tf.linspace(0.0, 0.9, 10)
    increment_global_step = tf.assign_add(self.global_step, 1)
    non_zero_count = []
    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      for i in range(10):
        session.run(tf.assign(sparsity, sparsity_val[i]))
        session.run(mask_update_op)
        session.run(increment_global_step)
        non_zero_count.append(np.count_nonzero(masked_weights.eval()))
    # With begin_pruning_step=1, end_pruning_step=6, and pruning_frequency=2,
    # the mask is updated at steps 2, 4, and 6, so the non-zero count drops
    # every other step and then stays fixed.
    expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
    self.assertAllEqual(expected_non_zero_count, non_zero_count)
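The expected counts follow directly from the schedule: the mask only changes at steps 2, 4, and 6, and a sparsity s leaves round(100 * (1 - s)) weights non-zero. A minimal pure-Python sketch under those two assumptions:

import numpy as np

n = 100
sparsity_val = np.linspace(0.0, 0.9, 10)
non_zero = n
expected = []
for step in range(10):
    if step in (2, 4, 6):  # mask-update steps given begin=1, end=6, frequency=2
        non_zero = int(round(n * (1 - sparsity_val[step])))
    expected.append(non_zero)
print(expected)  # [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]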
Example #3
  def __init__(self, scope='default_scope', spec=None, global_step=None):
    super(KMeansPruningCompressionOp, self).__init__(scope, spec, global_step)

    pruning_spec = copy.deepcopy(self._spec)
    pruning_spec.prune_option = 'weight'
    self.pruning_obj = pruning.Pruning(
        pruning_spec, global_step=self._global_step)
Example #4
def get_matrix_compression_object(hparams,  # pylint:disable=invalid-name
                                  global_step=None,
                                  sparsity=None):
  """Returns a pruning/compression object.

  Args:
    hparams: Pruning spec as defined in pruning.py.
    global_step: A TensorFlow variable that is used for scheduling
      pruning/compression.
    sparsity: A TensorFlow scalar variable storing the sparsity.

  Returns:
    A Pruning or compression_lib.compression_op.ApplyCompression object.
  """
  if global_step is None:
    train_global_step = tf.train.get_global_step()
    if train_global_step is None:
      global_step = 0
    else:
      global_step = tf.cast(train_global_step, tf.int32)
  if hparams.prune_option in [
      'weight', 'first_order_gradient', 'second_order_gradient'
  ]:
    return pruning.Pruning(hparams, global_step, sparsity)
  else:
    return compression_wrapper.get_apply_compression(
        hparams, global_step=global_step)
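A hedged usage sketch for the helper above: for the weight/gradient prune options it returns a pruning.Pruning object, so the usual mask-update API is available. The hparams string and variable names here are illustrative assumptions.

hparams = pruning.get_pruning_hparams().parse(
    "prune_option=weight,target_sparsity=0.5")
step = tf.train.get_or_create_global_step()
obj = get_matrix_compression_object(hparams, global_step=step)
mask_update_op = obj.conditional_mask_update_op()  # Pruning API, per the branch above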
Example #5
  def testWeightSparsityTiebreaker(self):
    param_list = [
        "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100",
        "target_sparsity=0.5",
        "threshold_decay=0.0"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with tf.variable_scope("layer1"):
      w1 = tf.Variable(np.ones([100], dtype=np.float32),
                       name="weights")
      _ = pruning.apply_mask(w1)

    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
    increment_global_step = tf.assign_add(self.global_step, 1)

    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      for _ in range(110):
        session.run(mask_update_op)
        session.run(increment_global_step)

      self.assertAllClose(
          session.run(pruning.get_weight_sparsity()), [0.5])
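All 100 weights are identical here, so a pure magnitude threshold would keep either all of them or none; the assertion that exactly 0.5 sparsity is reached implies deterministic tie-breaking. A numpy illustration of one rank-based way to break ties (not necessarily pruning.py's exact mechanism):

import numpy as np

w = np.ones(100, dtype=np.float32)
k = int(0.5 * w.size)                         # number of weights to prune
order = np.argsort(np.abs(w), kind="stable")  # stable sort breaks ties by index
mask = np.ones_like(w)
mask[order[:k]] = 0.0
print(1.0 - mask.mean())  # 0.5 exactly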
Example #6
  def testInit(self):
    p = pruning.Pruning(self.pruning_hparams)
    self.assertEqual(p._spec.name, "test")
    self.assertAlmostEqual(p._spec.threshold_decay, 0.9)
    self.assertEqual(p._spec.pruning_frequency, 10)
    self.assertEqual(p._spec.sparsity_function_end_step, 100)
    self.assertAlmostEqual(p._spec.target_sparsity, 0.9)
Example #7
  def testInitWithExternalSparsity(self):
    with self.cached_session():
      p = pruning.Pruning(spec=self.pruning_hparams,
                          sparsity=self.sparsity)
      tf.global_variables_initializer().run()
      sparsity = p._sparsity.eval()
      self.assertAlmostEqual(sparsity, 0.5)
Example #8
  def testFirstOrderGradientCalculation(self):
    param_list = [
        "prune_option=first_order_gradient",
        "gradient_decay_rate=0.5",
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    tf.logging.info(pruning_hparams)

    w = tf.Variable(tf.linspace(1.0, 10.0, 10), name="weights")
    _ = pruning.apply_mask(w, prune_option="first_order_gradient")

    p = pruning.Pruning(pruning_hparams)
    old_weight_update_op = p.old_weight_update_op()
    gradient_update_op = p.gradient_update_op()

    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      session.run(gradient_update_op)
      session.run(old_weight_update_op)

      weights = pruning.get_weights()
      old_weights = pruning.get_old_weights()
      gradients = pruning.get_gradients()

      weight = weights[0]
      old_weight = old_weights[0]
      gradient = gradients[0]
      self.assertAllEqual(
          gradient.eval(),
          tf.math.scalar_mul(0.5,
                             tf.nn.l2_normalize(tf.linspace(1.0, 10.0,
                                                            10))).eval())
      self.assertAllEqual(weight.eval(), old_weight.eval())
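For reference, the assertion is consistent with a gradient accumulator that starts at zero and is updated as an exponential moving average with coefficient gradient_decay_rate: one update then stores 0.5 times the l2-normalized gradient. A numpy sketch under that assumption (the actual update in pruning.py may differ in detail):

import numpy as np

decay = 0.5                            # gradient_decay_rate
grad = np.linspace(1.0, 10.0, 10)
grad /= np.linalg.norm(grad)           # l2-normalized, as in the assertion
ema = np.zeros_like(grad)              # accumulator starts at zero
ema = decay * ema + (1 - decay) * grad
print(np.allclose(ema, 0.5 * grad))    # True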
Example #9
    def setUp(self):
        super(PruningSpeechUtilsTest, self).setUp()
        # Add global step variable to the graph
        self.global_step = tf.train.get_or_create_global_step()
        # Add sparsity
        self.sparsity = tf.Variable(0.5, name="sparsity")
        # Parse hparams
        self.pruning_hparams = pruning.get_pruning_hparams().parse(
            self.TEST_HPARAMS)

        self.pruning_obj = pruning.Pruning(self.pruning_hparams,
                                           global_step=self.global_step)

        self.compression_obj = pruning_interface.get_matrix_compression_object(
            self.pruning_hparams, global_step=self.global_step)

        def MockWeightParamsFn(shape, init=None, dtype=None):
            if init is None:
                init = MockWeightInit.Constant(0.0)
            if dtype is None:
                dtype = tf.float32
            return {"dtype": dtype, "shape": shape, "init": init}

        self.mock_weight_params_fn = MockWeightParamsFn
        self.mock_lstmobj = MockLSTMCell()
        self.wm_pc = np.zeros((2, 2))
Example #10
  def testWeightSpecificSparsity(self):
    param_list = [
        "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100",
        "target_sparsity=0.5",
        "weight_sparsity_map=[layer1:0.6,layer2/weights:0.75,.*kernel:0.6]",
        "threshold_decay=0.0"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with tf.variable_scope("layer1"):
      w1 = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
      _ = pruning.apply_mask(w1)
    with tf.variable_scope("layer2"):
      w2 = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
      _ = pruning.apply_mask(w2)
    with tf.variable_scope("layer3"):
      w3 = tf.Variable(tf.linspace(1.0, 100.0, 100), name="kernel")
      _ = pruning.apply_mask(w3)

    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
    increment_global_step = tf.assign_add(self.global_step, 1)

    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      for _ in range(110):
        session.run(mask_update_op)
        session.run(increment_global_step)

      self.assertAllClose(
          session.run(pruning.get_weight_sparsity()), [0.6, 0.75, 0.6])
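The weight_sparsity_map keys are matched against variable names and may be regular expressions (.*kernel above matches layer3/kernel); unmatched variables fall back to target_sparsity. A small sketch of that matching semantics, not the library's exact lookup code:

import re

weight_sparsity_map = {"layer1": 0.6, "layer2/weights": 0.75, ".*kernel": 0.6}
default_sparsity = 0.5

def sparsity_for(var_name):
    # First matching pattern wins; otherwise use the default target sparsity.
    for pattern, value in weight_sparsity_map.items():
        if re.search(pattern, var_name):
            return value
    return default_sparsity

for name in ["layer1/weights", "layer2/weights", "layer3/kernel"]:
    print(name, sparsity_for(name))  # 0.6, 0.75, 0.6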
Example #11
  def testPerLayerBlockSparsity(self):
    param_list = [
        "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
        "block_pooling_function=AVG", "threshold_decay=0.0"
    ]

    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with tf.variable_scope("layer1"):
      w1 = tf.Variable([[-0.1, 0.1], [-0.2, 0.2]], name="weights")
      pruning.apply_mask(w1)

    with tf.variable_scope("layer2"):
      w2 = tf.Variable([[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]],
                       name="weights")
      pruning.apply_mask(w2)

    sparsity = tf.Variable(0.5, name="sparsity")

    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    mask_update_op = p.mask_update_op()
    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      session.run(mask_update_op)
      mask1_eval = session.run(pruning.get_masks()[0])
      mask2_eval = session.run(pruning.get_masks()[1])

      self.assertAllEqual(
          session.run(pruning.get_weight_sparsity()), [0.5, 0.5])

      self.assertAllEqual(mask1_eval, [[0.0, 0.0], [1., 1.]])
      self.assertAllEqual(mask2_eval, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
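With block_dims 1x2 and AVG pooling, each 1x2 block is scored by its mean absolute value, thresholded, and the block mask is expanded back to the weight shape. A numpy sketch of that pipeline for layer2 at sparsity 0.5 (illustrative; the library does this with pooling ops):

import numpy as np

w2 = np.array([[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]])
# Pool 1x2 blocks with AVG over |w|: shape (2, 2) of block scores.
pooled = np.abs(w2).reshape(2, 2, 2).mean(axis=-1)
# Keep the top half of the blocks (sparsity = 0.5).
threshold = np.median(pooled)
block_mask = (pooled > threshold).astype(np.float32)
# Expand the block mask back to the weight shape.
mask = np.repeat(block_mask, 2, axis=1)
print(mask)  # [[0. 0. 1. 1.], [0. 0. 1. 1.]]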
Example #12
  def _prune_model(self, session):
    pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec)
    p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
    self.mask_update_op = p.conditional_mask_update_op()

    tf.global_variables_initializer().run()
    for _ in range(20):
      session.run(self.mask_update_op)
      session.run(self.increment_global_step)
Example #13
  def testUpdateSingleMask(self):
    with self.cached_session() as session:
      weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
      masked_weights = pruning.apply_mask(weights)
      sparsity = tf.Variable(0.95, name="sparsity")
      p = pruning.Pruning(sparsity=sparsity)
      p._spec.threshold_decay = 0.0
      mask_update_op = p.mask_update_op()
      tf.global_variables_initializer().run()
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
      session.run(mask_update_op)
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 5)
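Conceptually, the mask update picks a threshold at the target-sparsity percentile of |w| and keeps weights above it; with 100 evenly spaced weights and sparsity 0.95, the 5 largest survive. A numpy sketch of that idea (not the library's exact histogram-based thresholding):

import numpy as np

w = np.linspace(1.0, 100.0, 100)
sparsity = 0.95
threshold = np.percentile(np.abs(w), sparsity * 100)
mask = (np.abs(w) > threshold).astype(np.float32)
print(int(mask.sum()))  # 5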
Example #14
  def _GetMaskUpdateOp(self):
    """Returns op to update masks and threshold variables for model pruning."""
    p = self.params
    tp = p.train
    mask_update_op = tf.no_op()
    if tp.pruning_hparams_dict:
      assert isinstance(tp.pruning_hparams_dict, dict)
      pruning_hparams = pruning.get_pruning_hparams().override_from_dict(
          tp.pruning_hparams_dict)
      pruning_obj = pruning.Pruning(
          pruning_hparams, global_step=self.global_step)
      pruning_obj.add_pruning_summaries()
      mask_update_op = pruning_obj.conditional_mask_update_op()
    return mask_update_op
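A hedged sketch of wiring the returned op into a training loop; model, train_op, and num_steps below are assumptions, not part of the snippet above:

mask_update_op = model._GetMaskUpdateOp()  # `model` is a hypothetical host object
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_steps):
        sess.run(train_op)         # regular training step
        sess.run(mask_update_op)   # then refresh the pruning masks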
Example #15
  def _blockMasking(self, hparams, weights, expected_mask):
    threshold = tf.Variable(0.0, name="threshold")
    sparsity = tf.Variable(0.5, name="sparsity")
    test_spec = ",".join(hparams)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    with self.cached_session():
      tf.global_variables_initializer().run()
      _, new_mask = p._maybe_update_block_mask(weights, threshold)
      # Check if the mask is the same size as the weights
      self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
      mask_val = new_mask.eval()
      self.assertAllEqual(mask_val, expected_mask)
Example #16
    def testGroupSpecificBlockSparsity(self):
        param_list = [
            "begin_pruning_step=1",
            "pruning_frequency=1",
            "end_pruning_step=100",
            "target_sparsity=0.5",
            "group_sparsity_map=[group1:0.6,group2:0.75]",
            "group_block_dims_map=[group1:2x2,group2:2x4]",
            "threshold_decay=0.0",
            "group_pruning=True",
        ]
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        stacked_tensor_1 = pruning_utils.expand_tensor(
            tf.reshape(tf.linspace(1.0, 100.0, 100), [1, 100]), [2, 2])
        stacked_tensor_2 = pruning_utils.expand_tensor(
            tf.reshape(tf.linspace(1.0, 100.0, 100), [1, 100]), [2, 4])
        stacked_tensor_3 = pruning_utils.expand_tensor(
            tf.reshape(tf.linspace(1.0, 200.0, 100), [1, 100]), [2, 4])

        with tf.variable_scope("layer1"):
            w1 = tf.Variable(stacked_tensor_1, name="weights")
            _ = pruning.apply_mask_with_group(w1, group_name="group1")
        with tf.variable_scope("layer2"):
            w2 = tf.Variable(stacked_tensor_2, name="weights")
            _ = pruning.apply_mask_with_group(w2, group_name="group2")
        with tf.variable_scope("layer3"):
            w3 = tf.Variable(stacked_tensor_2, name="kernel")
            _ = pruning.apply_mask_with_group(w3, group_name="group2")
        with tf.variable_scope("layer4"):
            w4 = tf.Variable(stacked_tensor_3, name="kernel")
            _ = pruning.apply_mask_with_group(w4, group_name="group2")

        p = pruning.Pruning(pruning_hparams)
        mask_update_op = p.conditional_mask_update_op()
        increment_global_step = tf.assign_add(self.global_step, 1)

        with self.cached_session() as session:
            tf.global_variables_initializer().run()
            for _ in range(110):
                session.run(mask_update_op)
                session.run(increment_global_step)

            self.assertAllClose(session.run(pruning.get_weight_sparsity()),
                                [0.6, 0.9, 0.9, 0.45])
Example #17
  def testPartitionedVariableMasking(self):
    partitioner = tf.variable_axis_size_partitioner(40)
    with self.cached_session() as session:
      with tf.variable_scope("", partitioner=partitioner):
        sparsity = tf.Variable(0.5, name="Sparsity")
        weights = tf.get_variable(
            "weights", initializer=tf.linspace(1.0, 100.0, 100))
        masked_weights = pruning.apply_mask(
            weights, scope=tf.get_variable_scope())
      p = pruning.Pruning(sparsity=sparsity)
      p._spec.threshold_decay = 0.0
      mask_update_op = p.mask_update_op()
      tf.global_variables_initializer().run()
      masked_weights_val = masked_weights.eval()
      session.run(mask_update_op)
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)
Example #18
    def Setup(cls, pruning_hparams_dict, global_step):  # pylint:disable=invalid-name
        """Sets up the pruning op with pruning hyperparameters and global step.

        Args:
          pruning_hparams_dict: A dict of pruning hyperparameters.
          global_step: The global step in TensorFlow.
        """
        if cls._pruning_obj is not None:
            return  # Already initialized; reuse the existing pruning object.
        assert pruning_hparams_dict is not None
        assert isinstance(pruning_hparams_dict, dict)
        cls._pruning_hparams_dict = pruning_hparams_dict
        cls._global_step = global_step
        cls._pruning_hparams = pruning.get_pruning_hparams(
        ).override_from_dict(pruning_hparams_dict)
        cls._pruning_obj = pruning.Pruning(spec=cls._pruning_hparams,
                                           global_step=global_step)
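A hedged usage sketch for the classmethod above; the class name PruningOp is an assumption standing in for whatever class hosts Setup, and the hparam values are illustrative:

pruning_hparams_dict = {
    "begin_pruning_step": 1000,
    "end_pruning_step": 20000,
    "pruning_frequency": 100,
    "target_sparsity": 0.75,
}
global_step = tf.train.get_or_create_global_step()
PruningOp.Setup(pruning_hparams_dict, global_step)  # PruningOp: hypothetical host class
# Subsequent calls return early once cls._pruning_obj exists.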
Example #19
    def testFirstOrderGradientBlockMasking(self):
        param_list = [
            "prune_option=first_order_gradient",
            "gradient_decay_rate=0.5",
            "block_height=2",
            "block_width=2",
            "threshold_decay=0",
            "block_pooling_function=AVG",
        ]
        threshold = tf.Variable(0.0, name="threshold")
        sparsity = tf.Variable(0.5, name="sparsity")
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        weights_avg = tf.constant([[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
                                   [0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4, 0.4]])
        expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
                         [1., 1., 1., 1.], [1., 1., 1., 1.]]

        w = tf.Variable(weights_avg, name="weights")
        _ = pruning.apply_mask(w, prune_option="first_order_gradient")

        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        old_weight_update_op = p.old_weight_update_op()
        gradient_update_op = p.gradient_update_op()

        with self.cached_session() as session:
            tf.global_variables_initializer().run()
            session.run(gradient_update_op)
            session.run(old_weight_update_op)

            weights = pruning.get_weights()
            _ = pruning.get_old_weights()
            gradients = pruning.get_gradients()

            weight = weights[0]
            gradient = gradients[0]

            _, new_mask = p._maybe_update_block_mask(weight, threshold,
                                                     gradient)
            self.assertAllEqual(new_mask.get_shape(), weight.get_shape())
            mask_val = new_mask.eval()
            self.assertAllEqual(mask_val, expected_mask)
Example #20
    def _sparsity_m_by_n_masking(self, weight, block_size=4, sparsity=0.5):
        block_sparse_param = "block_width=" + str(block_size)
        param_list = [
            "target_sparsity=0.5",
            "intra_block_sparsity=True",
            block_sparse_param,
        ]
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
        sparsity = tf.Variable(sparsity, name="sparsity")

        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        mask_update_op = p.conditional_mask_update_op()

        with self.cached_session() as session:
            tf.global_variables_initializer().run()
            session.run(mask_update_op)
            _, new_mask = p._maybe_update_block_mask(weight, block_size)
            return new_mask
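Intra-block (m-by-n style) sparsity keeps the largest-magnitude entries inside every block of block_width consecutive weights. A numpy sketch of 2-out-of-4 masking (illustrative, not pruning.py's kernel):

import numpy as np

def two_of_four_mask(w, block=4, keep=2):
    # Keep the `keep` largest-magnitude entries in each block of `block` weights.
    blocks = np.abs(w).reshape(-1, block)
    order = np.argsort(blocks, axis=1)          # ascending within each block
    mask = np.ones_like(blocks)
    np.put_along_axis(mask, order[:, :block - keep], 0.0, axis=1)
    return mask.reshape(w.shape)

w = np.array([0.4, 0.1, 0.3, 0.2, 0.9, 0.8, 0.1, 0.7])
print(two_of_four_mask(w))  # [1. 0. 1. 0. 1. 1. 0. 0.]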
Example #21
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = contrib_framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(
            FLAGS.pruning_hparams)

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # Use the pruning_obj to add ops to the training graph to update the masks
        # The conditional_mask_update_op will update the masks only when the
        # training step is in [begin_pruning_step, end_pruning_step] specified in
        # the pruning spec proto
        mask_update_op = pruning_obj.conditional_mask_update_op()

        # Use the pruning_obj to add summaries to the graph to track the sparsity
        # of each of the layers
        pruning_obj.add_pruning_summaries()

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    num_examples_per_step = 128
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
                # Update the masks
                mon_sess.run(mask_update_op)
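As a hedged variant of the loop above, the mask update can be folded into a single fetched op with control dependencies, so each iteration needs only one session call:

# Build before the session is created, alongside train_op and mask_update_op.
with tf.control_dependencies([train_op]):
    train_and_prune_op = tf.group(mask_update_op)

while not mon_sess.should_stop():
    mon_sess.run(train_and_prune_op)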