def test_simple_cnn(self):
    """
    Compare a network with itself after batchnorm is removed.

    Builds a default ``SimpleCNN``, runs ``train_randomly`` on it, and passes
    it through ``remove_batchnorm``. The result must have fewer direct child
    modules than the original, and ``compare_models`` must return a truthy
    value for the pair with the ``(1, 32, 32)`` argument (presumably the
    input shape -- confirm against ``compare_models``).
    """
    model = SimpleCNN()
    train_randomly(model)
    model2 = remove_batchnorm(model)

    # Count direct children through the public ``named_children()`` API
    # instead of reaching into the private ``_modules`` dict.
    num_original = len(list(model.named_children()))
    num_stripped = len(list(model2.named_children()))
    self.assertLess(num_stripped, num_original)

    self.assertTrue(compare_models(model, model2, (1, 32, 32)))
def test_gsc(self):
    """
    Compare the GSC network after batchnorm is removed.

    Loads ``gsc_sparse_cnn(pretrained=True)`` and passes it through
    ``remove_batchnorm``. The result must have fewer direct child modules
    than the original, and ``compare_models`` must return a truthy value for
    the pair with the ``(1, 32, 32)`` argument (presumably the input shape --
    confirm against ``compare_models``).
    """
    model = gsc_sparse_cnn(pretrained=True)
    model2 = remove_batchnorm(model)

    # Count direct children through the public ``named_children()`` API
    # instead of reaching into the private ``_modules`` dict.
    num_original = len(list(model.named_children()))
    num_stripped = len(list(model2.named_children()))
    self.assertLess(num_stripped, num_original)

    self.assertTrue(compare_models(model, model2, (1, 32, 32)))
def test_simple_cnn(self):
    """
    Check a simple CNN against its batchnorm-free counterpart.

    The model comes from ``create_simple_cnn()`` and is run through
    ``train_randomly`` before ``remove_batchnorm`` is applied. The child
    names left after removal must equal the names of the original's children
    that are not instances of ``BATCH_NORM_CLASSES``, and ``compare_models``
    must return a truthy value for the pair.
    """
    original = create_simple_cnn()
    train_randomly(original)
    stripped = remove_batchnorm(original)

    # Names of the children that should survive batchnorm removal.
    kept_names = {
        child_name
        for child_name, child in original.named_children()
        if not isinstance(child, BATCH_NORM_CLASSES)
    }
    remaining_names = {child_name for child_name, _ in stripped.named_children()}

    self.assertEqual(remaining_names, kept_names)
    self.assertTrue(compare_models(original, stripped, (1, 32, 32)))
def test_cnn_more_out_channels(self):
    """
    Compare another network with itself after batchnorm is removed.

    Builds a ``SimpleCNN`` with ``cnn_out_channels=16`` and
    ``linear_units=20``, runs ``train_randomly`` on it, and passes it through
    ``remove_batchnorm``. The result must have fewer direct child modules
    than the original, and ``compare_models`` must return a truthy value for
    the pair with the ``(1, 32, 32)`` argument (presumably the input shape --
    confirm against ``compare_models``).
    """
    model = SimpleCNN(
        cnn_out_channels=16,
        linear_units=20,
    )
    train_randomly(model)
    model2 = remove_batchnorm(model)

    # Count direct children through the public ``named_children()`` API
    # instead of reaching into the private ``_modules`` dict.
    num_original = len(list(model.named_children()))
    num_stripped = len(list(model2.named_children()))
    self.assertLess(num_stripped, num_original)

    self.assertTrue(compare_models(model, model2, (1, 32, 32)))
def test_gsc(self):
    """
    Check the GSC network against its batchnorm-free counterpart.

    The model is obtained from ``gsc_sparse_cnn(pretrained=True)`` and passed
    through ``remove_batchnorm``. The child names left after removal must
    equal the names of the original's children that are not instances of
    ``BATCH_NORM_CLASSES``, and ``compare_models`` must return a truthy value
    for the pair.
    """
    original = gsc_sparse_cnn(pretrained=True)
    stripped = remove_batchnorm(original)

    # Names of the children that should survive batchnorm removal.
    kept_names = {
        child_name
        for child_name, child in original.named_children()
        if not isinstance(child, BATCH_NORM_CLASSES)
    }
    remaining_names = {child_name for child_name, _ in stripped.named_children()}

    self.assertEqual(remaining_names, kept_names)
    self.assertTrue(compare_models(original, stripped, (1, 32, 32)))
def test_cnn_sparse_weights(self):
    """
    Compare a network with 3 in_channels with itself after batchnorm is
    removed.

    Builds a ``SimpleCNN`` with ``in_channels=3``, ``cnn_out_channels=4``,
    ``linear_units=5`` and ``sparse_weights=True``, runs ``train_randomly``
    on it, applies ``rezero_weights`` via ``model.apply``, and then passes it
    through ``remove_batchnorm``. The result must have fewer direct child
    modules than the original, and ``compare_models`` must return a truthy
    value for the pair with the ``(3, 32, 32)`` argument (presumably the
    input shape -- confirm against ``compare_models``).
    """
    model = SimpleCNN(
        in_channels=3,
        cnn_out_channels=4,
        linear_units=5,
        sparse_weights=True,
    )
    train_randomly(model, in_channels=3)
    model.apply(rezero_weights)
    model2 = remove_batchnorm(model)

    # Count direct children through the public ``named_children()`` API
    # instead of reaching into the private ``_modules`` dict.
    num_original = len(list(model.named_children()))
    num_stripped = len(list(model2.named_children()))
    self.assertLess(num_stripped, num_original)

    self.assertTrue(compare_models(model, model2, (3, 32, 32)))
def test_cnn_sparse_weights(self):
    """
    Check a 3-in_channels sparse-weight CNN against its batchnorm-free
    counterpart.

    The model comes from ``create_simple_cnn`` with ``in_channels=3``,
    ``cnn_out_channels=4``, ``linear_units=5`` and ``sparse_weights=True``,
    and is run through ``train_randomly`` before ``remove_batchnorm`` is
    applied. The child names left after removal must equal the names of the
    original's children that are not instances of ``BATCH_NORM_CLASSES``,
    and ``compare_models`` must return a truthy value for the pair.
    """
    original = create_simple_cnn(
        in_channels=3,
        cnn_out_channels=4,
        linear_units=5,
        sparse_weights=True,
    )
    train_randomly(original, in_channels=3)
    stripped = remove_batchnorm(original)

    # Names of the children that should survive batchnorm removal.
    kept_names = {
        child_name
        for child_name, child in original.named_children()
        if not isinstance(child, BATCH_NORM_CLASSES)
    }
    remaining_names = {child_name for child_name, _ in stripped.named_children()}

    self.assertEqual(remaining_names, kept_names)
    self.assertTrue(compare_models(original, stripped, (3, 32, 32)))