Ejemplo n.º 1
0
  def testPerLayerBlockSparsity(self):
    """Checks that `block_dims_map` applies a distinct block size per layer.

    layer1/weights is pruned with 1x1 blocks and layer2/weights with 1x2
    blocks, both pooled with AVG and with threshold_decay=0.0. After a single
    mask update at a target sparsity of 0.5, each layer's weight sparsity
    should be exactly 0.5. The expected masks zero out the smaller-magnitude
    half of each layer: the first row for layer1 and the first 1x2 block of
    each row for layer2.
    """
    hparam_spec = ",".join([
        "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
        "block_pooling_function=AVG",
        "threshold_decay=0.0",
    ])
    pruning_hparams = pruning.get_pruning_hparams().parse(hparam_spec)

    # The scope names mirror the layer1/weights and layer2/weights keys used
    # in block_dims_map above.
    with tf.variable_scope("layer1"):
      w_layer1 = tf.Variable([[-0.1, 0.1], [-0.2, 0.2]], name="weights")
      pruning.apply_mask(w_layer1)

    with tf.variable_scope("layer2"):
      w_layer2 = tf.Variable(
          [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]], name="weights")
      pruning.apply_mask(w_layer2)

    target_sparsity = tf.Variable(0.5, name="sparsity")

    pruner = pruning.Pruning(pruning_hparams, sparsity=target_sparsity)
    mask_update_op = pruner.mask_update_op()
    with self.cached_session() as sess:
      tf.global_variables_initializer().run()
      sess.run(mask_update_op)
      # Masks come back in the order apply_mask was called: layer1, layer2.
      masks = pruning.get_masks()
      layer1_mask, layer2_mask = sess.run([masks[0], masks[1]])

      self.assertAllEqual(
          sess.run(pruning.get_weight_sparsity()), [0.5, 0.5])

      self.assertAllEqual(layer1_mask, [[0.0, 0.0], [1., 1.]])
      self.assertAllEqual(layer2_mask, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
Ejemplo n.º 2
0
  def testMaskedLSTMCell(self):
    """Checks the pruning collections populated by one MaskedLSTMCell call.

    Building and calling the cell once should register exactly one entry in
    each pruning collection. The mask, weight, old-weight, old-old-weight and
    gradient tensors should all have shape (2 * dim, 4 * dim). That is
    presumably the concatenated [input, h] x 4-gate LSTM kernel; TODO confirm
    against rnn_cells.
    """
    num_expected = 1
    kernel_shape = (2 * self.dim, 4 * self.dim)
    with self.cached_session():
      batch_shape = [self.batch_size, self.dim]
      # Created in the same order as before: input, then c, then h.
      cell_input = tf.Variable(tf.random_normal(batch_shape))
      cell_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
          tf.Variable(tf.random_normal(batch_shape)),
          tf.Variable(tf.random_normal(batch_shape)))
      cell = rnn_cells.MaskedLSTMCell(self.dim)
      cell(cell_input, cell_state)

      counted_collections = (
          pruning.get_masks,
          pruning.get_masked_weights,
          pruning.get_thresholds,
          pruning.get_weights,
          pruning.get_old_weights,
          pruning.get_old_old_weights,
          pruning.get_gradients,
      )
      for collection in counted_collections:
        self.assertEqual(len(collection()), num_expected)

      # Masked weights and thresholds are only counted; the collections below
      # are additionally required to hold kernel-shaped tensors.
      shaped_collections = (
          pruning.get_masks,
          pruning.get_weights,
          pruning.get_old_weights,
          pruning.get_old_old_weights,
          pruning.get_gradients,
      )
      for collection in shaped_collections:
        for tensor in collection():
          self.assertEqual(tensor.shape, kernel_shape)