Esempio n. 1
0
  def testMaskedLSTMCell(self):
    """A single MaskedLSTMCell call registers exactly one set of pruning vars.

    The cell is built with `self.dim` units and called once on a
    (batch_size, dim) input with an LSTMStateTuple state. Afterwards each
    pruning collection is expected to hold one entry, and the shaped entries
    (masks, weights, old weights, old-old weights, gradients) are expected to
    have shape (2 * dim, 4 * dim).
    """
    num_masks = 1
    kernel_shape = (2 * self.dim, 4 * self.dim)
    with self.cached_session():
      batch_shape = [self.batch_size, self.dim]
      inputs = tf.Variable(tf.random_normal(batch_shape))
      cell_state = tf.Variable(tf.random_normal(batch_shape))
      hidden_state = tf.Variable(tf.random_normal(batch_shape))
      lstm_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
          cell_state, hidden_state)
      cell = rnn_cells.MaskedLSTMCell(self.dim)
      cell(inputs, lstm_state)

      # Each pruning collection should contain exactly `num_masks` entries.
      count_getters = (
          pruning.get_masks,
          pruning.get_masked_weights,
          pruning.get_thresholds,
          pruning.get_weights,
          pruning.get_old_weights,
          pruning.get_old_old_weights,
          pruning.get_gradients,
      )
      for getter in count_getters:
        self.assertEqual(len(getter()), num_masks)

      # Only these collections are shape-checked (masked weights and
      # thresholds are count-checked above but not shape-checked).
      shape_getters = (
          pruning.get_masks,
          pruning.get_weights,
          pruning.get_old_weights,
          pruning.get_old_old_weights,
          pruning.get_gradients,
      )
      for getter in shape_getters:
        for tensor in getter():
          self.assertEqual(tensor.shape, kernel_shape)
  def testFirstOrderGradientCalculation(self):
    """Checks the first-order gradient and old-weight updates on one weight.

    With prune_option=first_order_gradient and gradient_decay_rate=0.5, and
    a masked weight initialized to linspace(1, 10, 10), running
    gradient_update_op once and then old_weight_update_op once is expected to
    leave:
      * the tracked gradient equal (within float tolerance) to
        0.5 * l2_normalize(linspace(1, 10, 10)), and
      * the old-weight variable exactly equal to the current weight.
    """
    param_list = [
        "prune_option=first_order_gradient",
        "gradient_decay_rate=0.5",
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    tf.logging.info(pruning_hparams)

    w = tf.Variable(tf.linspace(1.0, 10.0, 10), name="weights")
    _ = pruning.apply_mask(w, prune_option="first_order_gradient")

    p = pruning.Pruning(pruning_hparams)
    old_weight_update_op = p.old_weight_update_op()
    gradient_update_op = p.gradient_update_op()

    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      # The gradient update is run before the old weights are synced.
      session.run(gradient_update_op)
      session.run(old_weight_update_op)

      weight = pruning.get_weights()[0]
      old_weight = pruning.get_old_weights()[0]
      gradient = pruning.get_gradients()[0]

      expected_gradient = tf.math.scalar_mul(
          0.5, tf.nn.l2_normalize(tf.linspace(1.0, 10.0, 10)))
      # l2_normalize involves a square root and a division, so compare the
      # float results with a tolerance rather than bit-for-bit.
      self.assertAllClose(gradient.eval(), expected_gradient.eval())
      # old_weight_update_op is expected to make old_weight exactly equal to
      # weight, so exact equality is the intended check here.
      self.assertAllEqual(weight.eval(), old_weight.eval())
Esempio n. 3
0
    def testFirstOrderGradientBlockMasking(self):
        """Block masking driven by first-order gradient scores.

        A 4x4 weight built from four constant 2x2 blocks (0.1, 0.2 on top;
        0.3, 0.4 on the bottom) is pruned with 2x2 AVG block pooling at
        sparsity 0.5, after one gradient update and one old-weight update.
        The resulting mask must have the weight's shape, zero the two top
        blocks and keep the two bottom blocks.
        """
        hparam_spec = ",".join([
            "prune_option=first_order_gradient",
            "gradient_decay_rate=0.5",
            "block_height=2",
            "block_width=2",
            "threshold_decay=0",
            "block_pooling_function=AVG",
        ])
        threshold = tf.Variable(0.0, name="threshold")
        sparsity = tf.Variable(0.5, name="sparsity")
        pruning_hparams = pruning.get_pruning_hparams().parse(hparam_spec)

        block_weights = tf.constant([
            [0.1, 0.1, 0.2, 0.2],
            [0.1, 0.1, 0.2, 0.2],
            [0.3, 0.3, 0.4, 0.4],
            [0.3, 0.3, 0.4, 0.4],
        ])
        # Top two rows (the 0.1 and 0.2 blocks) pruned, bottom two kept.
        expected_mask = [[0.0] * 4, [0.0] * 4, [1.0] * 4, [1.0] * 4]

        w = tf.Variable(block_weights, name="weights")
        _ = pruning.apply_mask(w, prune_option="first_order_gradient")

        pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        update_old_weights = pruner.old_weight_update_op()
        update_gradients = pruner.gradient_update_op()

        with self.cached_session() as sess:
            tf.global_variables_initializer().run()
            sess.run(update_gradients)
            sess.run(update_old_weights)

            weight = pruning.get_weights()[0]
            # Old weights are fetched but not inspected by this test.
            pruning.get_old_weights()
            gradient = pruning.get_gradients()[0]

            _, new_mask = pruner._maybe_update_block_mask(
                weight, threshold, gradient)
            self.assertAllEqual(new_mask.get_shape(), weight.get_shape())
            self.assertAllEqual(new_mask.eval(), expected_mask)