def testMaskedLSTMCell(self):
    """Building a MaskedLSTMCell should register one set of pruning tensors.

    The cell is invoked once on random inputs and a random (c, h) state.
    Every pruning collection is then expected to hold exactly one entry, and
    each shape-checked entry must be a (2 * dim, 4 * dim) matrix.
    """
    num_masks = 1
    # NOTE(review): presumably this is the LSTM kernel, mapping the
    # concatenation of inputs and h (2 * dim) onto four gates (4 * dim);
    # confirm against rnn_cells.MaskedLSTMCell.
    kernel_shape = (2 * self.dim, 4 * self.dim)
    with self.cached_session():
        random_shape = [self.batch_size, self.dim]
        inputs = tf.Variable(tf.random_normal(random_shape))
        c = tf.Variable(tf.random_normal(random_shape))
        h = tf.Variable(tf.random_normal(random_shape))
        state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(c, h)
        cell = rnn_cells.MaskedLSTMCell(self.dim)
        cell(inputs, state)

        # Each pruning collection must contain exactly one entry.
        for getter in (
            pruning.get_masks,
            pruning.get_masked_weights,
            pruning.get_thresholds,
            pruning.get_weights,
            pruning.get_old_weights,
            pruning.get_old_old_weights,
            pruning.get_gradients,
        ):
            self.assertEqual(len(getter()), num_masks)

        # Only these collections are shape-checked; masked weights and
        # thresholds are verified by count alone.
        for getter in (
            pruning.get_masks,
            pruning.get_weights,
            pruning.get_old_weights,
            pruning.get_old_old_weights,
            pruning.get_gradients,
        ):
            for tensor in getter():
                self.assertEqual(tensor.shape, kernel_shape)
def testFirstOrderGradientCalculation(self):
    """Checks one step of the first-order-gradient bookkeeping ops.

    With ``gradient_decay_rate=0.5``, running ``gradient_update_op`` once on
    freshly initialized variables is expected to leave the gradient slot
    equal to ``0.5 * l2_normalize(w)``. Running ``old_weight_update_op``
    afterwards is expected to make the old-weight slot equal to the current
    weights.
    """
    param_list = [
        "prune_option=first_order_gradient",
        "gradient_decay_rate=0.5",
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    tf.logging.info(pruning_hparams)

    w = tf.Variable(tf.linspace(1.0, 10.0, 10), name="weights")
    _ = pruning.apply_mask(w, prune_option="first_order_gradient")

    p = pruning.Pruning(pruning_hparams)
    old_weight_update_op = p.old_weight_update_op()
    gradient_update_op = p.gradient_update_op()

    with self.cached_session() as session:
        tf.global_variables_initializer().run()
        # NOTE(review): presumably the gradient update reads the old weights,
        # so it is run before they are overwritten; confirm against
        # Pruning.gradient_update_op.
        session.run(gradient_update_op)
        session.run(old_weight_update_op)

        weight = pruning.get_weights()[0]
        old_weight = pruning.get_old_weights()[0]
        gradient = pruning.get_gradients()[0]

        # The gradient comes from a normalization followed by scaling, so
        # it is a floating-point result. Compare with a tolerance rather
        # than bit-exactly to avoid spurious failures across kernels/devices.
        self.assertAllClose(
            gradient.eval(),
            tf.math.scalar_mul(
                0.5, tf.nn.l2_normalize(tf.linspace(1.0, 10.0, 10))).eval())
        # After the old-weight update, the old-weight slot should hold
        # exactly the current weights.
        self.assertAllEqual(weight.eval(), old_weight.eval())
def testFirstOrderGradientBlockMasking(self):
    """Checks block masking driven by the first-order-gradient criterion.

    A 4x4 weight matrix made of four constant 2x2 blocks (0.1, 0.2, 0.3 and
    0.4) is pruned with 2x2 AVG block pooling at a sparsity of 0.5. After
    one gradient update and one old-weight update, the recomputed block mask
    is expected to zero the top two rows (the 0.1 and 0.2 blocks) and keep
    the bottom two rows.
    """
    test_spec = ",".join((
        "prune_option=first_order_gradient",
        "gradient_decay_rate=0.5",
        "block_height=2",
        "block_width=2",
        "threshold_decay=0",
        "block_pooling_function=AVG",
    ))
    threshold = tf.Variable(0.0, name="threshold")
    sparsity = tf.Variable(0.5, name="sparsity")
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    # Rows 0-1 hold the 0.1 / 0.2 blocks; rows 2-3 hold the 0.3 / 0.4 blocks.
    top_row = [0.1, 0.1, 0.2, 0.2]
    bottom_row = [0.3, 0.3, 0.4, 0.4]
    weights_avg = tf.constant([top_row, top_row, bottom_row, bottom_row])

    pruned_row = [0.0, 0.0, 0.0, 0.0]
    kept_row = [1., 1., 1., 1.]
    expected_mask = [pruned_row, pruned_row, kept_row, kept_row]

    w = tf.Variable(weights_avg, name="weights")
    _ = pruning.apply_mask(w, prune_option="first_order_gradient")

    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    old_weight_update_op = p.old_weight_update_op()
    gradient_update_op = p.gradient_update_op()

    with self.cached_session() as session:
        tf.global_variables_initializer().run()
        # Run the gradient update first, then the old-weight update.
        for update_op in (gradient_update_op, old_weight_update_op):
            session.run(update_op)

        weight = pruning.get_weights()[0]
        # NOTE(review): result unused; looks like a leftover call.
        _ = pruning.get_old_weights()
        gradient = pruning.get_gradients()[0]

        _, new_mask = p._maybe_update_block_mask(weight, threshold, gradient)
        self.assertAllEqual(new_mask.get_shape(), weight.get_shape())
        self.assertAllEqual(new_mask.eval(), expected_mask)