def testMaskedLSTMCell(self):
    """Check the pruning collections populated by one MaskedLSTMCell call.

    Builds a `MaskedLSTMCell` of size `self.dim`, applies it once to random
    inputs and a random `LSTMStateTuple`, and then verifies that:

    * every pruning collection queried below holds exactly one entry, and
    * masks, weights, old weights, old-old weights and gradients all have
      shape `(2 * self.dim, 4 * self.dim)`.

    Masked weights and thresholds are only counted, not shape-checked.
    """
    num_expected = 1
    # NOTE(review): presumably the fused [input; hidden] x 4-gate LSTM kernel;
    # confirm against the MaskedLSTMCell implementation.
    kernel_shape = (2 * self.dim, 4 * self.dim)

    def random_variable():
        # One (batch_size, dim) variable initialised from a normal sample.
        return tf.Variable(tf.random_normal([self.batch_size, self.dim]))

    # Collections whose length is asserted, in the order they are checked.
    counted_getters = (
        pruning.get_masks,
        pruning.get_masked_weights,
        pruning.get_thresholds,
        pruning.get_weights,
        pruning.get_old_weights,
        pruning.get_old_old_weights,
        pruning.get_gradients,
    )
    # Collections whose entries must all match the kernel shape.
    shaped_getters = (
        pruning.get_masks,
        pruning.get_weights,
        pruning.get_old_weights,
        pruning.get_old_old_weights,
        pruning.get_gradients,
    )

    with self.cached_session():
        inputs = random_variable()
        cell_state = random_variable()
        hidden_state = random_variable()
        initial_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
            cell_state, hidden_state)

        # Calling the cell is what creates the variables that end up in
        # the pruning collections; the output itself is not inspected.
        cell = rnn_cells.MaskedLSTMCell(self.dim)
        cell(inputs, initial_state)

        for getter in counted_getters:
            self.assertEqual(len(getter()), num_expected)

        for getter in shaped_getters:
            for tensor in getter():
                self.assertEqual(tensor.shape, kernel_shape)
def testSecondOrderGradientCalculation(self):
    """Check the gradient produced under `prune_option=second_order_gradient`.

    Sequence exercised:

    1. Mask a 10-element weight vector `linspace(1, 10, 10)`.
    2. Run the old-weight and old-old-weight update ops.
    3. Double the weights in place.
    4. Run the gradient update op.

    The stored gradient is then expected to equal
    `0.5 * l2_normalize(linspace(1, 10, 10))`, and the stored old weights
    are expected to equal the stored old-old weights.
    """
    test_spec = ",".join([
        "prune_option=second_order_gradient",
        "gradient_decay_rate=0.5",
    ])
    hparams = pruning.get_pruning_hparams().parse(test_spec)
    tf.logging.info(hparams)

    w = tf.Variable(tf.linspace(1.0, 10.0, 10), name="weights")
    # Called for its side effect of registering `w` with the pruning
    # collections; the masked tensor it returns is not needed here.
    pruning.apply_mask(w, prune_option="second_order_gradient")

    pruner = pruning.Pruning(hparams)
    snapshot_old_op = pruner.old_weight_update_op()
    snapshot_old_old_op = pruner.old_old_weight_update_op()
    update_gradient_op = pruner.gradient_update_op()

    with self.cached_session() as session:
        tf.global_variables_initializer().run()

        session.run(snapshot_old_op)
        session.run(snapshot_old_old_op)
        # Change the live weights after both snapshots were taken.
        session.run(tf.assign(w, tf.math.scalar_mul(2.0, w)))
        session.run(update_gradient_op)

        old_weight = pruning.get_old_weights()[0]
        old_old_weight = pruning.get_old_old_weights()[0]
        gradient = pruning.get_gradients()[0]

        actual_gradient = gradient.eval()
        expected_gradient = tf.math.scalar_mul(
            0.5, tf.nn.l2_normalize(tf.linspace(1.0, 10.0, 10))).eval()
        self.assertAllEqual(actual_gradient, expected_gradient)

        self.assertAllEqual(old_weight.eval(), old_old_weight.eval())