def testStringFirstConvDense(self):
  """Checks the 'firstconv' and 'firstdense' keywords on one-layer models.

  For each keyword a model is built with exactly one masked layer, and
  `pruner.process_layers2prune` is expected to return a one-element list
  holding that layer's name.
  """
  # (only masked layer in the model, keyword passed to the pruner)
  scenarios = (
      ('conv_1', 'firstconv'),
      ('dense_1', 'firstdense'),
  )
  for layer_name, keyword in scenarios:
    model = self._get_model_with_masked_layers([layer_name])
    resolved = pruner.process_layers2prune(keyword, model)
    self.assertAllEqual([layer_name], resolved)
def testStringCustom(self):
  """Checks the 'midconv' and 'lastconv' keywords for several conv depths.

  Each case lists the masked conv layers of the model, followed by the
  layer expected for 'midconv' and the layer expected for 'lastconv'.
  """
  cases = [
      (['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'], 'conv_3', 'conv_5'),
      (['conv_1', 'conv_2', 'conv_3', 'conv_4'], 'conv_2', 'conv_4'),
      (['conv_1', 'conv_2', 'conv_3'], 'conv_2', 'conv_3'),
      (['conv_1'], 'conv_1', 'conv_1'),
  ]
  for conv_layers, expected_mid, expected_last in cases:
    model = self._get_model_with_masked_layers(conv_layers)
    # The counter is set to one more than the number of conv layers.
    # NOTE(review): presumably 'C' holds the index the next conv layer would
    # receive -- confirm against the model implementation.
    model.layer_name_counter = {'C': (len(conv_layers) + 1)}

    mid_result = pruner.process_layers2prune('midconv', model)
    self.assertAllEqual([expected_mid], mid_result)

    last_result = pruner.process_layers2prune('lastconv', model)
    self.assertAllEqual([expected_last], last_result)
def testAssertion(self):
  """Checks that invalid `layers2prune` arguments raise AssertionError.

  A single model with three conv and three dense masked layers is shared by
  all cases; every argument below must be rejected by
  `pruner.process_layers2prune`.
  """
  masked_layers = [
      'conv_1', 'conv_2', 'conv_3', 'dense_1', 'dense_2', 'dense_3'
  ]
  model = self._get_model_with_masked_layers(masked_layers)

  # Lists containing at least one name that is not among `masked_layers`.
  unknown_name_lists = [
      ['conv_0', 'dense_1'],
      ['output_1', 'dense_3'],
      ['conv_4'],
      ['conv_0'],
  ]
  # Strings that are neither a layer name in `masked_layers` nor written in
  # the all-lowercase form of keywords such as 'firstconv' / 'lastconv'.
  bad_strings = ['firstConv', 'LastConv']

  for invalid_arg in unknown_name_lists + bad_strings:
    with self.assertRaises(AssertionError):
      pruner.process_layers2prune(invalid_arg, model)
def testList(self):
  """Checks that lists of existing layer names are returned unchanged.

  Covers conv/dense mixes, a single-element list and a list that repeats
  the same name twice; in every case the processed result must equal the
  list that was passed in.
  """
  masked_layers = [
      'conv_1', 'conv_2', 'conv_3', 'dense_1', 'dense_2', 'dense_3'
  ]
  model = self._get_model_with_masked_layers(masked_layers)

  requested_lists = (
      ['conv_1', 'dense_1'],
      ['conv_2', 'dense_3'],
      ['conv_3', 'dense_3'],
      ['conv_2'],
      ['conv_2', 'conv_2'],  # Duplicate entries are expected to be kept.
  )
  for requested in requested_lists:
    processed = pruner.process_layers2prune(requested, model)
    self.assertAllEqual(requested, processed)
def testStringAll(self):
  """Checks that 'all' yields exactly what `model.get_layer_keys` returns.

  `get_layer_keys` is replaced by a mock on each model so the test can
  verify both the returned layer list and that the pruner queried the
  model exactly once.
  """
  layer_sets = (
      ['conv_1', 'conv_2', 'dense_1'],
      ['dense_1'],
      ['conv_1'],
  )
  for layer_names in layer_sets:
    model = self._get_model_with_masked_layers(layer_names)
    keys_stub = mock.Mock(return_value=layer_names)
    model.get_layer_keys = keys_stub

    processed = pruner.process_layers2prune('all', model)

    self.assertAllEqual(layer_names, processed)
    keys_stub.assert_called_once()
    # A fresh mock is created on every iteration; the reset only clears the
    # call record of the one that was just checked.
    keys_stub.reset_mock()
def testStringMidLast(self):
  """Checks that a single layer name given as a string becomes a 1-item list.

  Every masked layer name of the model is passed on its own as a plain
  string, and the processed result must be a list containing just that name.

  NOTE(review): despite the method name, no 'midconv'/'lastconv' keyword is
  exercised here -- confirm whether the name is intentional.
  """
  conv_names = ['conv_1', 'conv_2', 'conv_3']
  model = self._get_model_with_masked_layers(conv_names)
  for single_name in conv_names:
    processed = pruner.process_layers2prune(single_name, model)
    self.assertAllEqual([single_name], processed)