    def test_mobilenet_v2_from_tf2(self):
        model = keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                               weights='imagenet',
                                               include_top=True,
                                               pooling='max')
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
        self.assertEqual(53, len(pruner.prune_masks_map))
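From the assertion above, `prune_masks_map` appears to be keyed by `(layer_index, (weight_index, ...))` tuples, with each value indexable by weight index to get that tensor's mask. A minimal inspection sketch under that assumption (the key layout is inferred from these tests, not from documented API):

import numpy as np
import keras
import lottery_ticket_pruner

model = keras.applications.MobileNetV2(weights=None)  # any Keras model works here
pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
for (layer_index, weight_indices), masks in pruner.prune_masks_map.items():
    for weight_index in weight_indices:
        # Masks start as all ones, so initially each sum equals that tensor's weight count
        print(layer_index, weight_index, int(np.sum(masks[weight_index])))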
Example #2
    def test_inception_v3(self):
        model = keras.applications.InceptionV3(input_shape=(299, 299, 3),
                                               weights='imagenet',
                                               include_top=True,
                                               pooling='max')
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
        self.assertEqual(95, len(pruner.prune_masks_map))
Example #3
    def test_LotteryTicketPruner_use_case_1(self):
        model = self._create_test_model()
        starting_weights = model.get_weights()
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        # First layer is the input layer; ignore it.
        # Second layer is a Dense layer with 2 weight tensors: the kernel (fully connected) weights and the biases.
        interesting_key = (1, (0,))
        num_unmasked = np.sum(pruner.prune_masks_map[interesting_key][0])
        self.assertEqual(num_unmasked, TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES)

        # No prune mask has been calculated yet, so no weights should change
        initial_model_weights_sum = self._summed_model_weights(model)
        pruner.apply_pruning(model)
        new_model_weights_sum = self._summed_model_weights(model)
        self.assertEqual(initial_model_weights_sum, new_model_weights_sum)

        pruner.calc_prune_mask(model, 0.5, 'random')
        num_masked = np.sum(pruner.prune_masks_map[interesting_key][0] == 0)
        self.assertEqual(num_masked, int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * 0.5))

        pruner.calc_prune_mask(model, 0.2, 'random')
        num_masked = np.sum(pruner.prune_masks_map[interesting_key][0] == 0)
        self.assertEqual(num_masked, int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * 0.6))

        model.set_weights(starting_weights)
        new_model_weights_sum = self._summed_model_weights(model)
        self.assertEqual(initial_model_weights_sum, new_model_weights_sum)

        pruner.apply_pruning(model)
        num_masked = np.sum(pruner.prune_masks_map[interesting_key][0] == 0)
        self.assertEqual(num_masked, int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * 0.6))
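The arithmetic being asserted here (pruning 50% and then 20% of the survivors yields 60% overall) is the standard compounding rule for repeated pruning; a quick self-contained check, with a helper name of my own choosing:

def cumulative_prune_rate(rates):
    """Overall fraction pruned after applying each rate to the surviving weights."""
    remaining = 1.0
    for rate in rates:
        remaining *= 1.0 - rate
    return 1.0 - remaining

assert abs(cumulative_prune_rate([0.5, 0.2]) - 0.6) < 1e-12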
Example #4
    def test_prune_func_large_final(self):
        """ Tests case where many or all weights are same value. Hence we might be tempted to mask on all of the
        smallest weights rather than honoring only up to the prune rate
        """
        model = self._create_test_dnn_model()
        interesting_layer = model.layers[1]
        interesting_weights_index = 0

        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        # Assign unique weight values 0..N-1, with roughly half of them made negative
        weights = interesting_layer.get_weights()
        interesting_weights = weights[interesting_weights_index]
        num_interesting_layer_weights = np.prod(interesting_weights.shape)
        new_weights = np.array(np.random.choice(range(num_interesting_layer_weights),
                                                size=num_interesting_layer_weights, replace=False))
        rand_multiplier = np.random.choice([1, -1], size=num_interesting_layer_weights, replace=True)
        new_weights *= rand_multiplier
        new_weights = new_weights.reshape(interesting_weights.shape)
        weights[interesting_weights_index] = new_weights
        interesting_layer.set_weights(weights)

        pruner.set_pretrained_weights(model)

        # Now verify that the absolute values of all unpruned weights are at least as large as the smallest
        # expected non-zero weight
        prune_rate = 0.2
        pruner.calc_prune_mask(model, prune_rate, 'large_final')
        pruner.apply_pruning(model)
        weights = interesting_layer.get_weights()
        pruned_weights = weights[interesting_weights_index]
        pruned_weights = np.abs(pruned_weights)
        num_zero = np.sum(pruned_weights == 0.0)
        self.assertEqual(int(num_interesting_layer_weights * prune_rate), num_zero)
        expected_non_zero_min = int(np.prod(pruned_weights.shape) * prune_rate)
        num_in_expected_range = np.sum(pruned_weights >= expected_non_zero_min)
        self.assertEqual(num_interesting_layer_weights - num_zero, num_in_expected_range)

        # Now do another round of pruning
        prune_rate = 0.5
        new_overall_prune_rate = 0.6    # 0.2 + (1.0 - 0.2) * 0.5
        pruner.calc_prune_mask(model, prune_rate, 'large_final')
        pruner.apply_pruning(model)
        weights = interesting_layer.get_weights()
        pruned_weights = weights[interesting_weights_index]
        pruned_weights = np.abs(pruned_weights)
        num_zero = np.sum(pruned_weights == 0.0)
        self.assertEqual(int(num_interesting_layer_weights * new_overall_prune_rate), num_zero)
        expected_non_zero_min = int(np.prod(pruned_weights.shape) * new_overall_prune_rate)
        num_in_expected_range = np.sum(pruned_weights >= expected_non_zero_min)
        self.assertEqual(num_interesting_layer_weights - num_zero, num_in_expected_range)
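For readers unfamiliar with 'large_final': the idea (per the paper cited later in this listing) is to keep the weights with the largest pretrained magnitude and mask the rest. A standalone numpy sketch of that selection, not the library's actual implementation:

import numpy as np

def large_final_mask(pretrained_weights, prune_rate):
    """Zero-mask the prune_rate fraction of weights with the smallest pretrained magnitudes."""
    flat = np.abs(pretrained_weights).ravel()
    num_to_prune = int(flat.size * prune_rate)
    mask = np.ones(flat.size)
    mask[np.argsort(flat, kind='stable')[:num_to_prune]] = 0.0
    return mask.reshape(pretrained_weights.shape)

w = np.array([[0.1, -2.0], [0.5, -0.05]])
print(large_final_mask(w, 0.5))  # masks 0.1 and -0.05, the two smallest magnitudes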
Example #5
    def test_apply_dwr(self):
        model = self._create_test_model()
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
        interesting_layer_index = 1
        interesting_weights_index = 0
        tpl = (interesting_layer_index, (interesting_weights_index, ))
        interesting_layer = model.layers[interesting_layer_index]

        # Assign unique weight values 0..N-1 so the smallest-weights ordering is unambiguous
        weights = interesting_layer.get_weights()
        interesting_weights = weights[interesting_weights_index]
        num_interesting_layer_weights = np.prod(interesting_weights.shape)
        test_weights = np.array(np.random.choice(range(num_interesting_layer_weights),
                                                 size=num_interesting_layer_weights, replace=False))
        test_weights = test_weights.reshape(interesting_weights.shape)
        weights[interesting_weights_index] = test_weights
        interesting_layer.set_weights(weights)

        prune_rate1 = 0.5
        pruner.calc_prune_mask(model, prune_rate1, 'smallest_weights')
        pruner.apply_pruning(model)
        pruner.apply_dwr(model)

        # Mask out any pruned weights
        pruned_weights = interesting_layer.get_weights()[interesting_weights_index]
        expected_test_weights = test_weights * pruner.prune_masks_map[tpl][interesting_weights_index]
        # We expect DWR to have increased the unmasked weights by a factor of 2.0 (1.0 / (1.0 - 0.5) = 2.0)
        expected_test_weights *= 1.0 / (1.0 - prune_rate1)
        np.testing.assert_array_equal(expected_test_weights, pruned_weights)

        # Prune again to make sure we accumulate the DWR multiplier as expected
        weights[interesting_weights_index] = test_weights
        interesting_layer.set_weights(weights)

        prune_rate2 = 0.2
        pruner.calc_prune_mask(model, prune_rate2, 'smallest_weights')
        pruner.apply_pruning(model)
        pruner.apply_dwr(model)

        # Mask out any pruned weights
        pruned_weights = interesting_layer.get_weights()[interesting_weights_index]
        expected_test_weights = test_weights * pruner.prune_masks_map[tpl][interesting_weights_index]
        # We expect DWR to have increased the unmasked weights by a factor of 2.5
        # (1.0 / ((1.0 - 0.5) * (1.0 - 0.2)) = 2.5)
        # But since the number of ones in the prune mask is an integer, rounding means the rescaling factor
        # is not exactly 2.5
        num_first_prune_ones = int(num_interesting_layer_weights * prune_rate1)
        denominator = (num_interesting_layer_weights - (num_first_prune_ones + int(num_first_prune_ones * prune_rate2)))
        rescale_factor = num_interesting_layer_weights / denominator
        expected_test_weights *= rescale_factor
        np.testing.assert_array_almost_equal(expected_test_weights, pruned_weights, decimal=3)
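The rounding caveat above is easy to reproduce: the DWR factor is the total weight count divided by the number of ones left in the mask, which only approximates 1 / (1 - overall prune rate) because masked counts are integers. A sketch with a hypothetical weight count:

num_weights = 1000  # hypothetical layer size
masked_round1 = int(num_weights * 0.5)                      # 500 masked
masked_round2 = int((num_weights - masked_round1) * 0.2)    # 100 more masked
ones_in_mask = num_weights - masked_round1 - masked_round2  # 400 survivors
rescale_factor = num_weights / ones_in_mask
print(rescale_factor)  # exactly 2.5 here; with counts that don't divide evenly it drifts slightly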
Example #6
    def test_calc_prune_mask_negative(self):
        model = self._create_test_model()
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
        with self.assertRaises(ValueError) as ex:
            pruner.calc_prune_mask(model, 0.3, 'unknown_strategy')
        self.assertIn('smallest_weights', str(ex.exception))
        self.assertIn('smallest_weights_global', str(ex.exception))

        with self.assertRaises(ValueError) as ex:
            pruner.calc_prune_mask(model, -0.25, 'smallest_weights_global')
        self.assertIn('exclusive', str(ex.exception))

        with self.assertRaises(ValueError) as ex:
            pruner.calc_prune_mask(model, 1.1, 'smallest_weights_global')
        self.assertIn('exclusive', str(ex.exception))
Example #7
    def test_inception_v3(self):
        if hasattr(keras.applications, 'InceptionV3'):
            factory_func = keras.applications.InceptionV3
        elif hasattr(keras.applications.inception_v3, 'InceptionV3'):
            factory_func = keras.applications.inception_v3.InceptionV3
        else:
            raise Exception(
                'Cannot find InceptionV3 while using `from tensorflow.python import keras`'
            )
        model = factory_func(input_shape=(299, 299, 3),
                             weights='imagenet',
                             include_top=True,
                             pooling='max')
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
        self.assertEqual(95, len(pruner.prune_masks_map))
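The hasattr fallbacks above generalize to other applications models; a hedged helper sketch (the helper name is mine, and which keras.applications submodules exist varies by TF/Keras version):

def find_application(name, *modules):
    """Return the first attribute `name` found on any of `modules`, else raise."""
    for module in modules:
        factory = getattr(module, name, None)
        if factory is not None:
            return factory
    raise AttributeError('{} not found in any provided module'.format(name))

# e.g. factory_func = find_application('InceptionV3', keras.applications,
#                                      keras.applications.inception_v3)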
Example #8
    def test_prune_large_final_negative(self):
        """ Negative tests for 'large_final' pruning strategy
        """
        model = self._create_test_dnn_model()
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        # set_pretrained_weights() is deliberately not called here; omitting it is the point of this test
        # pruner.set_pretrained_weights(model)

        # Calling calc_prune_mask() with 'large_final' without having set pretrained weights should raise
        with self.assertRaises(ValueError) as ex:
            pruner.calc_prune_mask(model, 0.2, 'large_final')
        self.assertIn('large_final', str(ex.exception))
        self.assertIn('LotteryTicketPruner.pretrained_weights()', str(ex.exception))
Example #9
    def test_prune_func_large_final_same_weight_values(self):
        """ Tests case where many or all weights are same value. Hence we might be tempted to mask on all of the
        smallest weights rather than honoring only up to the prune rate
        """
        model = self._create_test_dnn_model()
        interesting_layer = model.layers[1]
        interesting_weights_index = 0

        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        # Set every weight in the layer to the same value
        test_weight_value = 1.23
        weights = interesting_layer.get_weights()
        interesting_weights = weights[interesting_weights_index]
        num_interesting_layer_weights = np.prod(interesting_weights.shape)
        new_weights = np.array(interesting_weights)
        new_weights.fill(test_weight_value)
        weights[interesting_weights_index] = new_weights
        interesting_layer.set_weights(weights)

        pruner.set_pretrained_weights(model)

        # Now verify that exactly the prune rate's fraction of weights were zeroed and that the rest
        # still hold the original test value
        prune_rate = 0.2
        pruner.calc_prune_mask(model, prune_rate, 'large_final')
        pruner.apply_pruning(model)
        weights = interesting_layer.get_weights()
        pruned_weights = weights[interesting_weights_index]
        num_zero = np.sum(pruned_weights == 0.0)
        self.assertEqual(int(num_interesting_layer_weights * prune_rate), num_zero)
        num_of_expected_value = np.sum(pruned_weights == test_weight_value)
        self.assertEqual(num_interesting_layer_weights - num_zero, num_of_expected_value)

        # Now do another round of pruning
        prune_rate = 0.5
        new_overall_prune_rate = 0.6    # 0.2 + (1.0 - 0.2) * 0.5
        pruner.calc_prune_mask(model, prune_rate, 'large_final')
        pruner.apply_pruning(model)
        weights = interesting_layer.get_weights()
        pruned_weights = weights[interesting_weights_index]
        num_zero = np.sum(pruned_weights == 0.0)
        self.assertEqual(int(num_interesting_layer_weights * new_overall_prune_rate), num_zero)
        num_of_expected_value = np.sum(pruned_weights == test_weight_value)
        self.assertEqual(num_interesting_layer_weights - num_zero, num_of_expected_value)
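The tie case this test guards against is worth spelling out: with equal magnitudes, a threshold comparison would mask every tied weight, so a correct implementation must select exactly the requested count. An illustrative numpy approach (not the library's code):

import numpy as np

weights = np.full(10, 1.23)            # every weight tied
prune_count = int(weights.size * 0.2)  # want exactly 20% masked
# argpartition selects exactly prune_count indices even when all values tie
prune_indices = np.argpartition(np.abs(weights), prune_count)[:prune_count]
mask = np.ones_like(weights)
mask[prune_indices] = 0.0
assert int(np.sum(mask == 0)) == prune_count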
Example #10
    def test_smallest_weights_2(self):
        model = self._create_test_model()
        # First layer is the input layer; ignore it.
        # Second layer is a Dense layer with 2 weight tensors: the kernel (fully connected) weights and the biases.
        interesting_layer = model.layers[1]
        interesting_layer_shape = interesting_layer.weights[0].shape

        dl_test_weights = np.random.choice(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES,
                                           size=TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES, replace=False)
        # Make some weights negative
        dl_test_weights -= TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES // 2
        dl_test_weights = dl_test_weights.reshape(interesting_layer_shape)
        interesting_layer.set_weights([dl_test_weights, interesting_layer.get_weights()[1]])
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        prune_rate = 0.5
        pruner.calc_prune_mask(model, prune_rate, 'smallest_weights')
        pruner.apply_pruning(model)
        actual_weights = interesting_layer.get_weights()
        min_expected_pos = TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * prune_rate // 2 - 1
        max_expected_neg = -TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * prune_rate // 2 + 1
        unpruned_pos = np.sum(actual_weights[0] >= min_expected_pos)
        unpruned_neg = np.sum(actual_weights[0] <= max_expected_neg)
        unpruned = unpruned_pos + unpruned_neg
        self.assertIn(unpruned, [int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * prune_rate),
                                 int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * prune_rate) - 1])
        expected_to_be_pruned = TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES - unpruned - 1
        self.assertLessEqual(abs(int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * prune_rate) - expected_to_be_pruned),
                             1)

        # Prune again
        prune_rate2 = 0.1
        expected_to_be_pruned2 = int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * prune_rate2 * (1.0 - prune_rate))
        pruner.calc_prune_mask(model, prune_rate2, 'smallest_weights')
        pruner.apply_pruning(model)
        actual_weights = interesting_layer.get_weights()
        min_expected_pos = expected_to_be_pruned2 // 2 - 1
        max_expected_neg = -expected_to_be_pruned2 // 2 + 1
        unpruned_pos = np.sum(actual_weights[0] >= min_expected_pos)
        unpruned_neg = np.sum(actual_weights[0] <= max_expected_neg)
        unpruned = unpruned_pos + unpruned_neg
        expected_unpruned = TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES - expected_to_be_pruned - expected_to_be_pruned2
        self.assertLessEqual(abs(expected_unpruned - unpruned), 1)
Example #11
    def test_reset_masks(self):
        model = self._create_test_model()
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
        interesting_layer_index = 1
        interesting_weights_index = 0
        tpl = (interesting_layer_index, (interesting_weights_index,))

        original_mask = np.array(pruner.prune_masks_map[tpl][interesting_weights_index])
        self.assertEqual(TEST_DENSE_WEIGHT_COUNT, np.sum(original_mask))

        # Prune and make sure prune mask has changed
        pruner.calc_prune_mask(model, 0.2, 'smallest_weights')
        pruned_mask = pruner.prune_masks_map[tpl][interesting_weights_index]
        num_unmasked = np.sum(pruned_mask)
        self.assertLess(num_unmasked, TEST_DENSE_WEIGHT_COUNT)

        # Now reset
        pruner.reset_masks()
        reset_mask = np.array(pruner.prune_masks_map[tpl][interesting_weights_index])
        self.assertEqual(TEST_DENSE_WEIGHT_COUNT, np.sum(reset_mask))
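Pulling these tests together, the pruner lifecycle they exercise looks roughly like this (a usage sketch; the stand-in model is my own, and method semantics are as the tests above imply):

import keras
import lottery_ticket_pruner

model = keras.Sequential([keras.layers.Dense(10, input_shape=(784,))])  # stand-in model
pruner = lottery_ticket_pruner.LotteryTicketPruner(model)
pruner.calc_prune_mask(model, 0.2, 'smallest_weights')  # updates masks only; weights are untouched
pruner.apply_pruning(model)                             # zeroes the masked weights in the model
pruner.reset_masks()                                    # restores all-ones masks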
Example #12
    def test_prune_func_smallest_weights_global_negative(self):
        model = self._create_test_model()
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        # Both percentage and count are unspecified
        with self.assertRaises(ValueError) as ex:
            _ = _prune_func_smallest_weights_global(None, None, prune_percentage=None, prune_count=None)
        self.assertIn('prune_percentage', str(ex.exception))
        self.assertIn('prune_count', str(ex.exception))

        # Prune percentage is zero
        with unittest.mock.patch('logging.Logger.warning') as warning:
            _ = _prune_func_smallest_weights_global(pruner.iterate_prunables(model), None, prune_percentage=0.0,
                                                    prune_count=None)
            self.assertEqual(1, warning.call_count)

        # Prune count is zero
        with unittest.mock.patch('logging.Logger.warning') as warning:
            _ = _prune_func_smallest_weights_global(pruner.iterate_prunables(model), None, prune_percentage=None,
                                                    prune_count=0)
            self.assertEqual(1, warning.call_count)
Example #13
    def test_constructor(self):
        model1 = self._create_test_model()
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model1)

        # Disabled since there are legit cases where the two models may differ. E.g. when using transfer learning
        # one may choose to replace, say, a single head layer in the original model with 2 or more layers in the new
        # model.
        # # Different number of layers
        # model2 = self._create_test_mode_extra_layer()
        # with self.assertRaises(ValueError) as ex:
        #     pruner.calc_prune_mask(model2, 0.2, 'smallest_weights')
        # self.assertIn('must have the same number of layers', str(ex.exception))

        # Different shapes
        model2 = self._create_test_model_diff_shape(diff_input_shape=True)
        with self.assertRaises(ValueError) as ex:
            pruner.apply_pruning(model2)
        self.assertIn('must have the same input shape', str(ex.exception))

        model2 = self._create_test_model_diff_shape(diff_output_shape=True)
        with self.assertRaises(ValueError) as ex:
            pruner.calc_prune_mask(model2, 0.2, 'smallest_weights')
        self.assertIn('must have the same output shape', str(ex.exception))
Example #14
    def test_smallest_weights(self):
        model = self._create_test_model()
        # First layer is the input layer; ignore it.
        # Second layer is a Dense layer with 2 weight tensors: the kernel (fully connected) weights and the biases.
        interesting_layer_index = 1
        interesting_layer = model.layers[interesting_layer_index]
        interesting_layer_shape = interesting_layer.weights[0].shape
        interesting_layer_weight_count = int(np.prod(interesting_layer_shape))
        interesting_key = (interesting_layer_index, (0,))

        dl_test_weights = np.random.choice(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES,
                                           size=TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES, replace=False)
        # Get rid of zero weights since we count those below during verification
        dl_test_weights += 1
        dl_test_weights = dl_test_weights.reshape(interesting_layer_shape)
        interesting_layer.set_weights([dl_test_weights, interesting_layer.get_weights()[1]])
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        pruner.calc_prune_mask(model, 0.5, 'smallest_weights')
        num_masked = np.sum(pruner.prune_masks_map[interesting_key][0] == 0)
        self.assertEqual(num_masked, int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * 0.5))

        pruner.apply_pruning(model)
        actual_weights = interesting_layer.get_weights()
        actual_weights[0][actual_weights[0] == 0.0] = math.inf
        min_weight = np.min(actual_weights[0])
        self.assertGreaterEqual(min_weight, int(interesting_layer_weight_count * 0.5))

        pruner.calc_prune_mask(model, 0.2, 'smallest_weights')
        num_masked = np.sum(pruner.prune_masks_map[interesting_key][0] == 0)
        self.assertEqual(num_masked, int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * 0.6))

        pruner.apply_pruning(model)
        actual_weights = interesting_layer.get_weights()
        actual_weights[0][actual_weights[0] == 0.0] = math.inf
        min_weight = np.min(actual_weights[0])
        self.assertGreaterEqual(min_weight, int(interesting_layer_weight_count * 0.6))
Example #15
    def test_smallest_weights_similar_weights(self):
        """ Tests case where many or all weights are same value. Hence we might be tempted to mask on all of the
        smallest weights rather than honoring only up to the prune rate
        """
        model = self._create_test_model()
        # First layer is the input layer; ignore it.
        # Second layer is a Dense layer with 2 weight tensors: the kernel (fully connected) weights and the biases.
        interesting_layer = model.layers[1]
        interesting_layer_shape = interesting_layer.weights[0].shape

        # Make all weights the same
        dl_test_weights = np.ones([TEST_DENSE_LAYER_INPUTS, TEST_NUM_CLASSES], dtype=int)
        dl_test_weights = dl_test_weights.reshape(interesting_layer_shape)
        interesting_layer.set_weights([dl_test_weights, interesting_layer.get_weights()[1]])
        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        prune_rate = 0.5
        pruner.calc_prune_mask(model, prune_rate, 'smallest_weights')
        pruner.apply_pruning(model)
        actual_weights = interesting_layer.get_weights()
        expected = int(TEST_DENSE_LAYER_INPUTS * TEST_NUM_CLASSES * prune_rate)
        actual = np.sum(actual_weights[0])
        self.assertEqual(expected, actual)
Example #16
    def test_smallest_weights_global(self):
        """ Tests case where many or all weights are same value. Hence we might be tempted to mask on all of the
        smallest weights rather than honoring only up to the prune rate
        """
        model = self._create_test_dnn_model()
        interesting_layers = [
            model.layers[1], model.layers[4], model.layers[8]
        ]
        interesting_weights_index = 0

        # Make sure no weights are zero so the zero-counting checks below reflect only masked weights
        weight_counts = []
        for layer in interesting_layers:
            weights = layer.get_weights()
            weights[interesting_weights_index][
                weights[interesting_weights_index] == 0.0] = 0.1234
            layer.set_weights(weights)
            num_weights = np.prod(weights[interesting_weights_index].shape)
            weight_counts.append(num_weights)

        pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

        num_pruned1 = 0
        for layer in interesting_layers:
            weights = layer.get_weights()
            num_pruned1 += np.sum(weights[interesting_weights_index] == 0.0)

        prune_rate = 0.5
        pruner.calc_prune_mask(model, prune_rate, 'smallest_weights_global')

        # calc_prune_mask() shouldn't do the actual pruning so verify that weights didn't change
        num_pruned2 = 0
        for layer in interesting_layers:
            weights = layer.get_weights()
            num_pruned2 += np.sum(weights[interesting_weights_index] == 0.0)
        self.assertEqual(num_pruned1, num_pruned2)

        pruner.apply_pruning(model)
        pruned_counts = []
        for layer in interesting_layers:
            weights = layer.get_weights()
            pruned_counts.append(
                np.sum(weights[interesting_weights_index] == 0.0))

        total_weights = np.sum(weight_counts)
        num_pruned = np.sum(pruned_counts)
        self.assertAlmostEqual(prune_rate,
                               num_pruned / total_weights,
                               places=1)
        # Given the seeding we did at the beginning of this test, these results should be reproducible. They were
        # obtained by manual inspection.
        # Ranges are used here since TF 1.x on python 3.6, 3.7 gives slightly different results from TF 2.x on
        # python 3.8. These assertions accommodate both.
        self.assertTrue(62 <= pruned_counts[0] <= 67,
                        msg=f'pruned_counts={pruned_counts}')
        self.assertTrue(2 <= pruned_counts[1] <= 5,
                        msg=f'pruned_counts={pruned_counts}')
        self.assertTrue(5 <= pruned_counts[2] <= 9,
                        msg=f'pruned_counts={pruned_counts}')
        self.assertEqual(75, sum(pruned_counts))

        # Now prune once more to make sure cumulative pruning works as expected
        total_prune_rate = prune_rate
        prune_rate = 0.2
        total_prune_rate = total_prune_rate + (1.0 -
                                               total_prune_rate) * prune_rate
        pruner.calc_prune_mask(model, prune_rate, 'smallest_weights_global')
        pruner.apply_pruning(model)

        pruned_counts = []
        for layer in interesting_layers:
            weights = layer.get_weights()
            pruned_counts.append(
                np.sum(weights[interesting_weights_index] == 0.0))

        total_weights = np.sum(weight_counts)
        num_pruned = np.sum(pruned_counts)
        self.assertEqual(num_pruned / total_weights, total_prune_rate)
        # Given the seeding we did at the beginning of this test, these results should be reproducible. They were
        # obtained by manual inspection.
        # Ranges are used here since TF 1.x on python 3.6, 3.7 gives slightly different results from TF 2.x on
        # python 3.8. These assertions accommodate both.
        self.assertTrue(74 <= pruned_counts[0] <= 78,
                        msg=f'pruned_counts={pruned_counts}')
        self.assertTrue(2 <= pruned_counts[1] <= 5,
                        msg=f'pruned_counts={pruned_counts}')
        self.assertTrue(9 <= pruned_counts[2] <= 12,
                        msg=f'pruned_counts={pruned_counts}')
        self.assertEqual(90, sum(pruned_counts))
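A standalone sketch of the global strategy being tested: pool magnitudes across all prunable tensors, pick one threshold, and let per-layer pruned counts fall where they may. Illustrative only, and approximate when values tie at the threshold:

import numpy as np

def global_smallest_masks(weight_tensors, prune_rate):
    """One global magnitude threshold across all tensors; returns a mask per tensor."""
    magnitudes = np.concatenate([np.abs(w).ravel() for w in weight_tensors])
    kth = int(magnitudes.size * prune_rate)
    threshold = np.partition(magnitudes, kth)[kth]
    # Layers dominated by small weights get pruned well above prune_rate, others below
    return [(np.abs(w) >= threshold).astype(float) for w in weight_tensors]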
Example #17
def evaluate(which_set, prune_strategy, use_dwr, epochs, output_dir):
    """ Evaluates multiple training approaches:
            A model with randomly initialized weights evaluated with no training having been done
            A model trained from randomly initialized weights
            A model with randomly initialized weights evaluated with no training having been done *but* lottery ticket
                pruning has been done prior to evaluation.
            Several models trained from randomly initialized weights *but* with lottery ticket pruning applied at the
                end of every epoch.
        :param which_set: One of 'mnist', 'cifar10', 'cifar10_reduced_10x'.
            'mnist' is the standard MNIST data set (70k total images of digits 0-9).
            'cifar10' is the standard CIFAR10 data set (60k total images in 10 classes, 6k images/class)
            'cifar10_reduced_10x' is just like 'cifar10' but with the training and test sets reduced by 10x.
                (6k total images in 10 classes, 600 images/class).
                This is useful for seeing the effects of lottery ticket pruning on a smaller dataset.
        :param prune_strategy: One of the strategies supported by `LotteryTicketPruner.calc_prune_mask()`
            A string indicating how the pruning should be done.
                'random': Pruning is done randomly across all prunable layers.
                'smallest_weights': The smallest weights at each prunable layer are pruned. Each prunable layer has the
                    specified percentage pruned from the layer's weights.
                'smallest_weights_global': The smallest weights across all prunable layers are pruned. Some layers may
                    have substantially more or less weights than `prune_percentage` pruned from them. But overall,
                    across all prunable layers, `prune_percentage` weights will be pruned.
                'large_final': Keeps the weights that have the largest magnitude from the previously trained model.
                    This is 'large_final' as defined in https://arxiv.org/pdf/1905.01067.pdf
        :param boolean use_dwr: Whether or not to apply Dynamic Weight Rescaling (DWR) to the unpruned weights in the
            model.
            See section 5.2, "Dynamic Weight Rescaling" of https://arxiv.org/pdf/1905.01067.pdf.
            A quote from that paper describes it best:
                "For each training iteration and for each layer, we multiply the underlying weights by the ratio of the
                total number of weights in the layer over the number of ones in the corresponding mask."
        :param epochs: The number of epochs to train the models for.
        :param output_dir: The directory to put output files.
        :returns: losses and accuracies for the evaluations. Each is a dict keyed by experiment name whose value
            is the loss/accuracy.
    """
    losses = {}
    accuracies = {}

    experiment = 'xfer_learn'
    mnist = MNIST(experiment, which_set=which_set)
    # Split the dataset into two: the dataset we'll use to classically train a model, and the dataset we'll
    # use to train a new model via transfer learning and lottery ticket pruning.
    tl_dataset = mnist.dataset.split_dataset()
    model = mnist.create_model()

    experiment = 'xfer_learn_no_training'
    losses[experiment], accuracies[experiment] = mnist.evaluate(model)

    # Classically train a model on data from half of the classes
    experiment = 'xfer_learn_train_1st_half_data'
    mnist.fit(model, epochs)
    losses[experiment], accuracies[experiment] = mnist.evaluate(model)
    # For this experiment the starting weights are the weights of the model trained on the first N/2 classes'
    # samples. This is the source model we will use for transfer learning to train a model on the remaining
    # data, which has "new" class labels.
    starting_weights = model.get_weights()

    pruner = lottery_ticket_pruner.LotteryTicketPruner(model)

    # Now we classically train a model on the other half of the data from the previously unknown classes
    experiment = 'xfer_learn_train_2nd_half_data'
    mnist.fit(model, epochs, dataset=tl_dataset)
    trained_weights = model.get_weights()
    losses[experiment], accuracies[experiment] = mnist.evaluate(
        model, dataset=tl_dataset)
    epoch_logs = mnist.get_epoch_logs()
    pruner.set_pretrained_weights(model)

    # Evaluate performance of model with original weights and pruning applied
    num_prune_rounds = 4
    prune_rate = 0.2
    overall_prune_rate = 0.0
    for i in range(num_prune_rounds):
        prune_rate = pow(prune_rate, 1.0 / (i + 1))
        overall_prune_rate = overall_prune_rate + prune_rate * (
            1.0 - overall_prune_rate)

        # Make sure each iteration of pruning uses the same trained weights to determine the pruning mask
        model.set_weights(trained_weights)
        pruner.calc_prune_mask(model, prune_rate, prune_strategy)
        # Now revert model to original random starting weights and apply pruning
        model.set_weights(starting_weights)
        pruner.apply_pruning(model)

        experiment = 'xfer_learn_no_training_pruned@{:.4f}'.format(
            overall_prune_rate)
        losses[experiment], accuracies[experiment] = mnist.evaluate(model)

    pruner.reset_masks()

    # Calculate pruning mask below using trained weights
    model.set_weights(trained_weights)

    # Now train from original weights and prune during training
    prune_rate = 0.2
    overall_prune_rate = 0.0
    for i in range(num_prune_rounds):
        prune_rate = pow(prune_rate, 1.0 / (i + 1))
        overall_prune_rate = overall_prune_rate + prune_rate * (
            1.0 - overall_prune_rate)

        # Calculate the pruning mask using the trained model and its final trained weights
        pruner.calc_prune_mask(model, prune_rate, prune_strategy)

        # Now create a new model that has the original random starting weights and train it
        experiment = 'xfer_learn_pruned@{:.4f}'.format(overall_prune_rate)
        mnist_pruned = MNISTPruned(experiment,
                                   pruner,
                                   use_dwr=use_dwr,
                                   which_set=which_set)
        # Need to split the dataset here so the `mnist` and `mnist_pruned` models have the same shape
        _ = mnist_pruned.dataset.split_dataset()
        prune_trained_model = mnist_pruned.create_model()
        prune_trained_model.set_weights(starting_weights)
        mnist_pruned.fit(prune_trained_model, epochs, dataset=tl_dataset)
        losses[experiment], accuracies[experiment] = mnist_pruned.evaluate(
            prune_trained_model, dataset=tl_dataset)

        epoch_logs = _merge_epoch_logs(epoch_logs,
                                       mnist_pruned.get_epoch_logs())
        _to_floats(epoch_logs)

        # Periodically save the results to allow inspection during these multiple lengthy iterations
        with open(os.path.join(output_dir, 'epoch_logs.json'), 'w') as f:
            json.dump(epoch_logs, f, indent=4)

    # Now save a CSV file so it's easier to compare loss and accuracy across the experiments
    headings = []
    for experiment in epoch_logs[0].keys():
        headings.extend([experiment, '', '', ''])
    sub_headings = ['train_loss', 'train_acc', 'val_loss', 'val_acc'] * len(
        epoch_logs[0])
    epoch_logs_df = pd.DataFrame([], columns=[headings, sub_headings])
    all_keys = set(epoch_logs[0].keys())
    for epoch, epoch_results in epoch_logs.items():
        row = []
        for experiment in all_keys:
            if experiment in epoch_results:
                exp_dict = epoch_logs[epoch][experiment]
                row.extend([
                    exp_dict['loss'], exp_dict['acc'], exp_dict['val_loss'],
                    exp_dict['val_acc']
                ])
            else:
                row.extend([math.nan, math.nan, math.nan, math.nan])

        epoch_logs_df.loc[epoch] = row
    epoch_logs_df.to_csv(os.path.join(output_dir, 'epoch_logs.csv'))

    return losses, accuracies
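The prune_rate = pow(prune_rate, 1.0 / (i + 1)) schedule in the loops above is easy to misread, so here is the same arithmetic run standalone to show the per-round and cumulative rates it produces:

prune_rate = 0.2
overall = 0.0
for i in range(4):
    prune_rate = pow(prune_rate, 1.0 / (i + 1))
    overall += prune_rate * (1.0 - overall)
    print('round {}: rate={:.3f} overall={:.3f}'.format(i, prune_rate, overall))
# per-round rates ~ 0.200, 0.447, 0.765, 0.935; overall ~ 0.200, 0.558, 0.896, 0.993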
Example #18
    def setUp(self):
        self.mnist_test = MNISTTest()
        self.model = self.mnist_test.create_model()
        self.pruner = lottery_ticket_pruner.LotteryTicketPruner(self.model)