Code Example #1
    def _prune_model(self, session):
        pruning_hparams = pruning.get_pruning_hparams().parse(
            self.pruning_spec)
        p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
        self.mask_update_op = p.conditional_mask_update_op()

        tf.global_variables_initializer().run()
        for _ in range(20):
            session.run(self.mask_update_op)
            session.run(self.increment_global_step)
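
Here self.pruning_spec is a comma-separated hparams string of the same form parsed in the test examples below. For illustration only, a hypothetical spec built from keys that appear later in these examples could be:

    pruning_spec = ",".join([
        "begin_pruning_step=1",
        "end_pruning_step=100",
        "pruning_frequency=1",
        "target_sparsity=0.5",
        "threshold_decay=0.0",
    ])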
Code Example #2
  def ApplyPruning(cls, pruning_hparams_dict, lstmobj, weight_name, wm_pc,  # pylint:disable=invalid-name
                   dtype, scope):
    if not cls._pruning_obj:
      cls.Setup(pruning_hparams_dict, global_step=py_utils.GetGlobalStep())
    compression_op_spec = pruning.get_pruning_hparams().override_from_dict(
        pruning_hparams_dict)
    return apply_customized_lstm_matrix_compression(
        cls._pruning_obj, py_utils.WeightParams, py_utils.WeightInit,
        lstmobj, weight_name, wm_pc.shape, dtype, scope, compression_op_spec)
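
The pruning_hparams_dict argument mirrors those spec strings in dict form. A hypothetical value, restricted to keys that appear elsewhere in these examples, might be:

    pruning_hparams_dict = {
        "begin_pruning_step": 1,
        "end_pruning_step": 100,
        "pruning_frequency": 1,
        "target_sparsity": 0.5,
    }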
Code Example #3
  def _GetMaskUpdateOp(self):
    """Returns op to update masks and threshold variables for model pruning."""
    p = self.params
    tp = p.train
    mask_update_op = tf.no_op()
    if tp.pruning_hparams_dict:
      assert isinstance(tp.pruning_hparams_dict, dict)
      pruning_hparams = pruning.get_pruning_hparams().override_from_dict(
          tp.pruning_hparams_dict)
      pruning_obj = pruning.Pruning(
          pruning_hparams, global_step=self.global_step)
      pruning_obj.add_pruning_summaries()
      mask_update_op = pruning_obj.conditional_mask_update_op()
    return mask_update_op
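
The returned op is intended to be run next to each training step, as Code Example #11 does with its own mask_update_op. A minimal sketch, assuming a model object exposing this method plus existing train_op and num_steps definitions:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mask_update_op = model._GetMaskUpdateOp()
        for _ in range(num_steps):
            sess.run(train_op)
            # No-op unless the step is inside [begin_pruning_step, end_pruning_step].
            sess.run(mask_update_op)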
Code Example #4
  def _blockMasking(self, hparams, weights, expected_mask):
    threshold = tf.Variable(0.0, name="threshold")
    sparsity = tf.Variable(0.5, name="sparsity")
    test_spec = ",".join(hparams)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    with self.cached_session():
      tf.global_variables_initializer().run()
      _, new_mask = p._maybe_update_block_mask(weights, threshold)
      # Check if the mask is the same size as the weights
      self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
      mask_val = new_mask.eval()
      self.assertAllEqual(mask_val, expected_mask)
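
The hparams argument is a list of key=value strings. A call exercising 2x2 block pruning, reusing the block keys from Code Example #7, might look like:

    self._blockMasking(
        ["block_height=2", "block_width=2", "threshold_decay=0",
         "block_pooling_function=AVG"],
        weights, expected_mask)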
Code Example #5
    def testGroupSpecificBlockSparsity(self):
        param_list = [
            "begin_pruning_step=1",
            "pruning_frequency=1",
            "end_pruning_step=100",
            "target_sparsity=0.5",
            "group_sparsity_map=[group1:0.6,group2:0.75]",
            "group_block_dims_map=[group1:2x2,group2:2x4]",
            "threshold_decay=0.0",
            "group_pruning=True",
        ]
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        stacked_tensor_1 = pruning_utils.expand_tensor(
            tf.reshape(tf.linspace(1.0, 100.0, 100), [1, 100]), [2, 2])
        stacked_tensor_2 = pruning_utils.expand_tensor(
            tf.reshape(tf.linspace(1.0, 100.0, 100), [1, 100]), [2, 4])
        stacked_tensor_3 = pruning_utils.expand_tensor(
            tf.reshape(tf.linspace(1.0, 200.0, 100), [1, 100]), [2, 4])

        with tf.variable_scope("layer1"):
            w1 = tf.Variable(stacked_tensor_1, name="weights")
            _ = pruning.apply_mask_with_group(w1, group_name="group1")
        with tf.variable_scope("layer2"):
            w2 = tf.Variable(stacked_tensor_2, name="weights")
            _ = pruning.apply_mask_with_group(w2, group_name="group2")
        with tf.variable_scope("layer3"):
            w3 = tf.Variable(stacked_tensor_2, name="kernel")
            _ = pruning.apply_mask_with_group(w3, group_name="group2")
        with tf.variable_scope("layer4"):
            w4 = tf.Variable(stacked_tensor_3, name="kernel")
            _ = pruning.apply_mask_with_group(w4, group_name="group2")

        p = pruning.Pruning(pruning_hparams)
        mask_update_op = p.conditional_mask_update_op()
        increment_global_step = tf.assign_add(self.global_step, 1)

        with self.cached_session() as session:
            tf.global_variables_initializer().run()
            for _ in range(110):
                session.run(mask_update_op)
                session.run(increment_global_step)

            self.assertAllClose(session.run(pruning.get_weight_sparsity()),
                                [0.6, 0.9, 0.9, 0.45])
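
As a reading aid (an assumption about pruning_utils, not shown here): expand_tensor(t, block_dims) appears to tile each entry of t into a block_dims-sized block, so every stacked tensor above is made of constant blocks matching its group_block_dims_map entry. Under that assumption:

    # expand_tensor([[1., 2.]], [2, 2]) would produce
    # [[1., 1., 2., 2.],
    #  [1., 1., 2., 2.]]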
Code Example #6
    def Setup(cls, pruning_hparams_dict, global_step):  # pylint:disable=invalid-name
        """Set up the pruning op with pruning hyperparameters and global step.

        Args:
          pruning_hparams_dict: a dict containing pruning hyperparameters.
          global_step: global step in TensorFlow.
        """
        if cls._pruning_obj is not None:
            # An existing pruning object is simply rebuilt below.
            pass
        assert pruning_hparams_dict is not None
        assert isinstance(pruning_hparams_dict, dict)
        cls._pruning_hparams_dict = pruning_hparams_dict
        cls._global_step = global_step
        cls._pruning_hparams = pruning.get_pruning_hparams(
        ).override_from_dict(pruning_hparams_dict)
        cls._pruning_obj = pruning.Pruning(spec=cls._pruning_hparams,
                                           global_step=global_step)
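
A hypothetical call site for this classmethod (the enclosing class is not shown in the snippet; PruningOp is an assumed name):

    PruningOp.Setup(
        {"begin_pruning_step": 1, "target_sparsity": 0.5},  # hypothetical dict
        global_step=py_utils.GetGlobalStep())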
Code Example #7
    def testFirstOrderGradientBlockMasking(self):
        param_list = [
            "prune_option=first_order_gradient",
            "gradient_decay_rate=0.5",
            "block_height=2",
            "block_width=2",
            "threshold_decay=0",
            "block_pooling_function=AVG",
        ]
        threshold = tf.Variable(0.0, name="threshold")
        sparsity = tf.Variable(0.5, name="sparsity")
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        weights_avg = tf.constant([[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
                                   [0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4, 0.4]])
        expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
                         [1., 1., 1., 1.], [1., 1., 1., 1.]]

        w = tf.Variable(weights_avg, name="weights")
        _ = pruning.apply_mask(w, prune_option="first_order_gradient")

        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        old_weight_update_op = p.old_weight_update_op()
        gradient_update_op = p.gradient_update_op()

        with self.cached_session() as session:
            tf.global_variables_initializer().run()
            session.run(gradient_update_op)
            session.run(old_weight_update_op)

            weights = pruning.get_weights()
            _ = pruning.get_old_weights()
            gradients = pruning.get_gradients()

            weight = weights[0]
            gradient = gradients[0]

            _, new_mask = p._maybe_update_block_mask(weight, threshold,
                                                     gradient)
            self.assertAllEqual(new_mask.get_shape(), weight.get_shape())
            mask_val = new_mask.eval()
            self.assertAllEqual(mask_val, expected_mask)
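
Why the bottom half survives: with block_pooling_function=AVG, the 4x4 saliency map is pooled into a 2x2 grid of block scores, and at the 50% sparsity set by the sparsity variable the two lowest-scoring blocks (the top rows, which hold the smallest weights) are masked to zero while the other two blocks are kept, which is exactly expected_mask.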
Code Example #8
    def _sparsity_m_by_n_masking(self, weight, block_size=4, sparsity=0.5):
        block_sparse_param = "block_width=" + str(block_size)
        param_list = [
            "target_sparsity=0.5",
            "intra_block_sparsity=True",
            block_sparse_param,
        ]
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
        sparsity = tf.Variable(sparsity, name="sparsity")

        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        mask_update_op = p.conditional_mask_update_op()

        with self.cached_session() as session:
            tf.global_variables_initializer().run()
            session.run(mask_update_op)
            _, new_mask = p._maybe_update_block_mask(weight, block_size)
            return new_mask
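
A hypothetical invocation of this helper, with a 1x8 weight row split into two blocks of four columns:

    weight = tf.Variable(
        tf.reshape(tf.linspace(1.0, 8.0, 8), [1, 8]), name="weights")
    new_mask = self._sparsity_m_by_n_masking(weight, block_size=4, sparsity=0.5)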
Code Example #9
  def setUp(self):
    super(PruningSpeechUtilsTest, self).setUp()
    # Add global step variable to the graph
    self.global_step = tf.train.get_or_create_global_step()
    # Add sparsity
    self.sparsity = tf.Variable(0.5, name="sparsity")
    # Parse hparams
    self.pruning_hparams = pruning.get_pruning_hparams().parse(
        self.TEST_HPARAMS)

    self.pruning_obj = pruning.Pruning(
        self.pruning_hparams, global_step=self.global_step)

    def MockWeightParamsFn(shape, init=None, dtype=None):
      if init is None:
        init = MockWeightInit.Constant(0.0)
      if dtype is None:
        dtype = tf.float32
      return {"dtype": dtype, "shape": shape, "init": init}

    self.mock_weight_params_fn = MockWeightParamsFn
    self.mock_lstmobj = MockLSTMCell()
    self.wm_pc = np.zeros((2, 2))
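
MockWeightInit and MockLSTMCell are test doubles defined elsewhere in the test file. A minimal hypothetical stand-in consistent with the MockWeightInit.Constant(0.0) call above could be:

    class MockWeightInit(object):
        @staticmethod
        def Constant(value):
            # The code under test only threads this value through as an
            # opaque initializer spec, so a tagged tuple is enough here.
            return ("constant", value)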
Code Example #10
  def testSecondOrderGradientCalculation(self):
    param_list = [
        "prune_option=second_order_gradient",
        "gradient_decay_rate=0.5",
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    tf.logging.info(pruning_hparams)

    w = tf.Variable(tf.linspace(1.0, 10.0, 10), name="weights")
    _ = pruning.apply_mask(w, prune_option="second_order_gradient")

    p = pruning.Pruning(pruning_hparams)
    old_weight_update_op = p.old_weight_update_op()
    old_old_weight_update_op = p.old_old_weight_update_op()
    gradient_update_op = p.gradient_update_op()

    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      session.run(old_weight_update_op)
      session.run(old_old_weight_update_op)
      session.run(tf.assign(w, tf.math.scalar_mul(2.0, w)))
      session.run(gradient_update_op)

      old_weights = pruning.get_old_weights()
      old_old_weights = pruning.get_old_old_weights()
      gradients = pruning.get_gradients()

      old_weight = old_weights[0]
      old_old_weight = old_old_weights[0]
      gradient = gradients[0]
      self.assertAllEqual(
          gradient.eval(),
          tf.math.scalar_mul(0.5,
                             tf.nn.l2_normalize(tf.linspace(1.0, 10.0,
                                                            10))).eval())
      self.assertAllEqual(old_weight.eval(), old_old_weight.eval())
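
One way to read the assertions: the gradient variable starts at zero, the weights are doubled (so the weight change equals the original tf.linspace values), and a single gradient_update_op leaves 0.5 * l2_normalize(linspace(1, 10, 10)). That is consistent with an exponential-moving-average update of the normalized weight change using gradient_decay_rate=0.5, though the exact rule is internal to pruning.Pruning. The final check confirms both snapshot ops captured the same pre-doubling weights.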
Code Example #11
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = contrib_framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(
            FLAGS.pruning_hparams)

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # Use the pruning_obj to add ops to the training graph to update the masks.
        # The conditional_mask_update_op will update the masks only when the
        # training step is in [begin_pruning_step, end_pruning_step] specified in
        # the pruning spec proto
        mask_update_op = pruning_obj.conditional_mask_update_op()

        # Use the pruning_obj to add summaries to the graph to track the sparsity
        # of each of the layers.
        pruning_obj.add_pruning_summaries()

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    num_examples_per_step = 128
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
                # Update the masks
                mon_sess.run(mask_update_op)
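
train() relies on module-level flags; a sketch of the definitions it assumes (flag names taken from the code above, defaults hypothetical):

    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                               'Directory where to write event logs and checkpoints.')
    tf.app.flags.DEFINE_integer('max_steps', 1000000, 'Number of batches to run.')
    tf.app.flags.DEFINE_boolean('log_device_placement', False,
                                'Whether to log device placement.')
    tf.app.flags.DEFINE_string('pruning_hparams', '',
                               'Comma-separated list of pruning hyperparameters.')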