def testPerLayerBlockSparsity(self):
    """Per-layer block dimensions yield the expected block-sparse masks.

    The two weight matrices use different block dims:
    - `layer1/weights` uses 1x1 blocks.
    - `layer2/weights` uses 1x2 blocks.

    Both are pruned with AVG block pooling and a threshold decay of 0.0,
    toward a target sparsity of 0.5. The test checks the resulting masks and
    the per-layer weight sparsity reported by `pruning.get_weight_sparsity()`.
    """
    spec_items = (
        "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
        "block_pooling_function=AVG",
        "threshold_decay=0.0",
    )
    hparams = pruning.get_pruning_hparams().parse(",".join(spec_items))

    # The weights are created inside "layer1" / "layer2" scopes with the name
    # "weights", so they line up with the keys used in block_dims_map above.
    # Creation order also fixes the order of pruning.get_masks() below.
    with tf.variable_scope("layer1"):
        first_weights = tf.Variable(
            [[-0.1, 0.1], [-0.2, 0.2]], name="weights")
        pruning.apply_mask(first_weights)

    with tf.variable_scope("layer2"):
        second_weights = tf.Variable(
            [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]], name="weights")
        pruning.apply_mask(second_weights)

    target_sparsity = tf.Variable(0.5, name="sparsity")
    pruner = pruning.Pruning(hparams, sparsity=target_sparsity)
    update_masks = pruner.mask_update_op()

    with self.cached_session() as sess:
        tf.global_variables_initializer().run()
        sess.run(update_masks)
        first_mask = sess.run(pruning.get_masks()[0])
        second_mask = sess.run(pruning.get_masks()[1])

        layer_sparsity = sess.run(pruning.get_weight_sparsity())
        self.assertAllEqual(layer_sparsity, [0.5, 0.5])

        # Half of each matrix is masked out.
        # - layer1: the row with the smaller magnitudes (|0.1|) is zeroed.
        # - layer2: the leading column pairs (0.1/0.2) fall below the
        #   trailing pairs (0.3/0.4).
        self.assertAllEqual(first_mask, [[0.0, 0.0], [1., 1.]])
        self.assertAllEqual(second_mask, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
def testMaskedLSTMCell(self):
    """Calling a MaskedLSTMCell registers exactly one pruning mask.

    After one call of the cell, each pruning collection queried below must
    hold exactly one entry. The mask, weight, old-weight, old-old-weight and
    gradient tensors must also have shape (2 * self.dim, 4 * self.dim).
    """
    expected_count = 1
    # Presumably [input_dim + num_units, 4 * num_units] with
    # input_dim == num_units == self.dim; TODO confirm against rnn_cells.
    expected_shape = (2 * self.dim, 4 * self.dim)

    # Collections whose lengths are checked, in the original check order.
    counted_getters = (
        pruning.get_masks,
        pruning.get_masked_weights,
        pruning.get_thresholds,
        pruning.get_weights,
        pruning.get_old_weights,
        pruning.get_old_old_weights,
        pruning.get_gradients,
    )
    # Collections whose element shapes are checked, in the original order.
    shaped_getters = (
        pruning.get_masks,
        pruning.get_weights,
        pruning.get_old_weights,
        pruning.get_old_old_weights,
        pruning.get_gradients,
    )

    with self.cached_session():
        batch_shape = [self.batch_size, self.dim]
        inputs = tf.Variable(tf.random_normal(batch_shape))
        cell_state = tf.Variable(tf.random_normal(batch_shape))
        hidden_state = tf.Variable(tf.random_normal(batch_shape))
        state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
            cell_state, hidden_state)

        cell = rnn_cells.MaskedLSTMCell(self.dim)
        cell(inputs, state)

        for getter in counted_getters:
            self.assertEqual(len(getter()), expected_count)

        for getter in shaped_getters:
            for tensor in getter():
                self.assertEqual(tensor.shape, expected_shape)