예제 #1
0
 def testStringFirstConvDense(self):
     """Check the 'firstconv' and 'firstdense' keyword specs.

     For a model whose only masked layer is 'conv_1' (resp. 'dense_1'),
     the keyword must resolve to a one-element list holding that layer.
     """
     cases = (
         ('firstconv', 'conv_1'),
         ('firstdense', 'dense_1'),
     )
     for spec, layer_name in cases:
         masked_model = self._get_model_with_masked_layers([layer_name])
         resolved = pruner.process_layers2prune(spec, masked_model)
         self.assertAllEqual([layer_name], resolved)
예제 #2
0
 def testStringCustom(self):
     """Check the 'midconv' and 'lastconv' keyword specs.

     Each case pairs the list of masked conv layers with the expected
     (midconv, lastconv) resolution, for models with 5, 4, 3 and 1 conv
     layers.
     """
     cases = [
         (['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'],
          ('conv_3', 'conv_5')),
         (['conv_1', 'conv_2', 'conv_3', 'conv_4'], ('conv_2', 'conv_4')),
         (['conv_1', 'conv_2', 'conv_3'], ('conv_2', 'conv_3')),
         (['conv_1'], ('conv_1', 'conv_1')),
     ]
     for layer_names, (want_mid, want_last) in cases:
         model = self._get_model_with_masked_layers(layer_names)
         # NOTE(review): presumably mimics the model's per-type name counter
         # ('C' likely meaning conv, stored as count + 1) - confirm against
         # the model class.
         model.layer_name_counter = {'C': len(layer_names) + 1}
         self.assertAllEqual(
             [want_mid], pruner.process_layers2prune('midconv', model))
         self.assertAllEqual(
             [want_last], pruner.process_layers2prune('lastconv', model))
예제 #3
0
 def testAssertion(self):
     """Invalid specs must make process_layers2prune raise AssertionError.

     Covers layer names that are not in the model ('conv_0', 'output_1',
     'conv_4'), both alone and mixed with valid names, plus mis-cased
     variants of the keyword specs ('firstConv', 'LastConv').
     """
     all_layers = [
         'conv_1', 'conv_2', 'conv_3', 'dense_1', 'dense_2', 'dense_3'
     ]
     test_args = [['conv_0', 'dense_1'], ['output_1', 'dense_3'],
                  ['conv_4'], ['conv_0'], 'firstConv', 'LastConv']
     model = self._get_model_with_masked_layers(all_layers)
     for test_arg in test_args:
         # subTest labels each case, so a failure names the offending
         # argument instead of a bare "AssertionError not raised", and the
         # remaining cases still run.
         with self.subTest(test_arg=test_arg):
             with self.assertRaises(AssertionError):
                 _ = pruner.process_layers2prune(test_arg, model)
예제 #4
0
 def testList(self):
     """Explicit lists of existing layer names are returned unchanged.

     One case repeats a name ('conv_2' twice) to check that duplicates are
     preserved rather than collapsed.
     """
     masked = ['conv_1', 'conv_2', 'conv_3', 'dense_1', 'dense_2', 'dense_3']
     model = self._get_model_with_masked_layers(masked)
     requested_lists = (
         ['conv_1', 'dense_1'],
         ['conv_2', 'dense_3'],
         ['conv_3', 'dense_3'],
         ['conv_2'],
         ['conv_2', 'conv_2'],
     )
     for requested in requested_lists:
         self.assertAllEqual(
             requested, pruner.process_layers2prune(requested, model))
예제 #5
0
 def testStringAll(self):
     """The 'all' spec must return exactly what model.get_layer_keys() gives.

     get_layer_keys is replaced by a Mock returning the masked layer list,
     and the test also checks that the Mock is called exactly once.
     """
     combinations = [['conv_1', 'conv_2', 'dense_1'], ['dense_1'],
                     ['conv_1']]
     for layers2prune in combinations:
         model = self._get_model_with_masked_layers(layers2prune)
         # A fresh Mock is created on every iteration, so call counts never
         # carry over between cases and no reset_mock() is needed.
         mock_get_layer_keys = mock.Mock(return_value=layers2prune)
         model.get_layer_keys = mock_get_layer_keys
         layers2prune_processed = pruner.process_layers2prune('all', model)
         self.assertAllEqual(layers2prune, layers2prune_processed)
         mock_get_layer_keys.assert_called_once()
예제 #6
0
 def testStringMidLast(self):
     """A bare layer-name string resolves to a one-element list of that name.

     NOTE(review): despite the test name, this only exercises plain layer
     names ('conv_1'..'conv_3'), not the 'midconv'/'lastconv' keywords.
     """
     layer_names = ['conv_1', 'conv_2', 'conv_3']
     model = self._get_model_with_masked_layers(layer_names)
     for name in layer_names:
         resolved = pruner.process_layers2prune(name, model)
         self.assertAllEqual([name], resolved)