Ejemplo n.º 1
0
  def testMaskedLSTMCell(self):
    """Checks the pruning variables present after one MaskedLSTMCell call.

    Once the cell has been applied to an input/state pair, every pruning
    getter queried below must return exactly one entry, and each mask,
    weight, old weight, old-old weight and gradient must have shape
    (2 * dim, 4 * dim).
    """
    num_expected = 1
    kernel_shape = (2 * self.dim, 4 * self.dim)
    with self.cached_session():
      batch_shape = [self.batch_size, self.dim]
      inputs = tf.Variable(tf.random_normal(batch_shape))
      c = tf.Variable(tf.random_normal(batch_shape))
      h = tf.Variable(tf.random_normal(batch_shape))
      initial_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(c, h)
      cell = rnn_cells.MaskedLSTMCell(self.dim)
      # Output is ignored: only the variables found by the getters below are
      # inspected.
      cell(inputs, initial_state)

      # Every kind of pruning variable must be present exactly once.
      counted_getters = (
          pruning.get_masks,
          pruning.get_masked_weights,
          pruning.get_thresholds,
          pruning.get_weights,
          pruning.get_old_weights,
          pruning.get_old_old_weights,
          pruning.get_gradients,
      )
      for getter in counted_getters:
        self.assertEqual(len(getter()), num_expected)

      # Shape is only checked for these getters; masked weights and
      # thresholds are deliberately absent from this list.
      shaped_getters = (
          pruning.get_masks,
          pruning.get_weights,
          pruning.get_old_weights,
          pruning.get_old_old_weights,
          pruning.get_gradients,
      )
      for getter in shaped_getters:
        for tensor in getter():
          self.assertEqual(tensor.shape, kernel_shape)
Ejemplo n.º 2
0
  def testSecondOrderGradientCalculation(self):
    """Checks the gradient update under prune_option=second_order_gradient.

    Steps: run the old-weight and old-old-weight update ops on a weight
    vector linspace(1, 10, 10), double the weights, then run the gradient
    update op. The gradient must equal 0.5 * l2_normalize(linspace(1, 10, 10))
    exactly, and the old and old-old weights must be equal to each other.
    """
    hparam_overrides = ",".join([
        "prune_option=second_order_gradient",
        "gradient_decay_rate=0.5",
    ])
    pruning_hparams = pruning.get_pruning_hparams().parse(hparam_overrides)
    tf.logging.info(pruning_hparams)

    w = tf.Variable(tf.linspace(1.0, 10.0, 10), name="weights")
    # Return value intentionally discarded; presumably this call registers
    # the pruning variables for `w` -- confirm against pruning.apply_mask.
    pruning.apply_mask(w, prune_option="second_order_gradient")

    pruner = pruning.Pruning(pruning_hparams)
    snapshot_old_op = pruner.old_weight_update_op()
    snapshot_old_old_op = pruner.old_old_weight_update_op()
    refresh_gradient_op = pruner.gradient_update_op()

    with self.cached_session() as session:
      tf.global_variables_initializer().run()
      # Both snapshot ops run before the weights are modified.
      for snapshot_op in (snapshot_old_op, snapshot_old_old_op):
        session.run(snapshot_op)
      session.run(tf.assign(w, tf.math.scalar_mul(2.0, w)))
      session.run(refresh_gradient_op)

      old_weight = pruning.get_old_weights()[0]
      old_old_weight = pruning.get_old_old_weights()[0]
      gradient = pruning.get_gradients()[0]

      observed_gradient = gradient.eval()
      normalized_base = tf.nn.l2_normalize(tf.linspace(1.0, 10.0, 10))
      expected_gradient = tf.math.scalar_mul(0.5, normalized_base).eval()
      self.assertAllEqual(observed_gradient, expected_gradient)
      self.assertAllEqual(old_weight.eval(), old_old_weight.eval())