def test_simple_cnn(self):
        """
        Check remove_batchnorm on a default SimpleCNN.

        After train_randomly has been applied to the network, the model
        returned by remove_batchnorm must have fewer entries in ``_modules``
        than the original, and compare_models must return a truthy value
        for the pair when called with (1, 32, 32).
        """
        original = SimpleCNN()
        train_randomly(original)
        stripped = remove_batchnorm(original)

        # len() of the module dict is the same as len() of its keys view.
        stripped_count = len(stripped._modules)
        original_count = len(original._modules)
        self.assertLess(stripped_count, original_count)

        models_agree = compare_models(original, stripped, (1, 32, 32))
        self.assertTrue(models_agree)
    def test_gsc(self):
        """
        Check remove_batchnorm on the pretrained GSC sparse CNN.

        The model returned by remove_batchnorm must have fewer entries in
        ``_modules`` than the original, and compare_models must return a
        truthy value for the pair when called with (1, 32, 32).
        """
        original = gsc_sparse_cnn(pretrained=True)
        stripped = remove_batchnorm(original)

        # len() of the module dict is the same as len() of its keys view.
        stripped_count = len(stripped._modules)
        original_count = len(original._modules)
        self.assertLess(stripped_count, original_count)

        models_agree = compare_models(original, stripped, (1, 32, 32))
        self.assertTrue(models_agree)
示例#3
0
    def test_simple_cnn(self):
        """
        Check remove_batchnorm on the network built by create_simple_cnn.

        After train_randomly has been applied, the direct children of the
        model returned by remove_batchnorm must be named exactly like the
        original's children that are not BATCH_NORM_CLASSES instances, and
        compare_models must return a truthy value for the pair when called
        with (1, 32, 32).
        """
        original = create_simple_cnn()
        train_randomly(original)
        stripped = remove_batchnorm(original)

        # Child names expected to survive: everything except batchnorm classes.
        expected_modules = {
            child_name
            for child_name, child in original.named_children()
            if not isinstance(child, BATCH_NORM_CLASSES)
        }
        actual_modules = {
            child_name for child_name, _ in stripped.named_children()
        }

        self.assertEqual(actual_modules, expected_modules)

        models_agree = compare_models(original, stripped, (1, 32, 32))
        self.assertTrue(models_agree)
    def test_cnn_more_out_channels(self):
        """
        Check remove_batchnorm on a SimpleCNN with non-default sizes.

        The network is built with cnn_out_channels=16 and linear_units=20 and
        passed through train_randomly. The model returned by remove_batchnorm
        must have fewer entries in ``_modules`` than the original, and
        compare_models must return a truthy value for the pair when called
        with (1, 32, 32).
        """
        original = SimpleCNN(cnn_out_channels=16, linear_units=20)
        train_randomly(original)
        stripped = remove_batchnorm(original)

        # len() of the module dict is the same as len() of its keys view.
        stripped_count = len(stripped._modules)
        original_count = len(original._modules)
        self.assertLess(stripped_count, original_count)

        models_agree = compare_models(original, stripped, (1, 32, 32))
        self.assertTrue(models_agree)
示例#5
0
    def test_gsc(self):
        """
        Check remove_batchnorm on the pretrained GSC sparse CNN.

        The direct children of the model returned by remove_batchnorm must be
        named exactly like the original's children that are not
        BATCH_NORM_CLASSES instances, and compare_models must return a truthy
        value for the pair when called with (1, 32, 32).
        """
        original = gsc_sparse_cnn(pretrained=True)
        stripped = remove_batchnorm(original)

        # Child names expected to survive: everything except batchnorm classes.
        expected_modules = {
            child_name
            for child_name, child in original.named_children()
            if not isinstance(child, BATCH_NORM_CLASSES)
        }
        actual_modules = {
            child_name for child_name, _ in stripped.named_children()
        }

        self.assertEqual(actual_modules, expected_modules)

        models_agree = compare_models(original, stripped, (1, 32, 32))
        self.assertTrue(models_agree)
    def test_cnn_sparse_weights(self):
        """
        Check remove_batchnorm on a 3-channel SimpleCNN with sparse weights.

        The network is passed through train_randomly and then rezero_weights
        is applied to it via ``apply``. The model returned by
        remove_batchnorm must have fewer entries in ``_modules`` than the
        original, and compare_models must return a truthy value for the pair
        when called with (3, 32, 32).
        """
        cnn_kwargs = {
            "in_channels": 3,
            "cnn_out_channels": 4,
            "linear_units": 5,
            "sparse_weights": True,
        }
        original = SimpleCNN(**cnn_kwargs)
        train_randomly(original, in_channels=3)
        original.apply(rezero_weights)
        stripped = remove_batchnorm(original)

        # len() of the module dict is the same as len() of its keys view.
        stripped_count = len(stripped._modules)
        original_count = len(original._modules)
        self.assertLess(stripped_count, original_count)

        models_agree = compare_models(original, stripped, (3, 32, 32))
        self.assertTrue(models_agree)
示例#7
0
    def test_cnn_sparse_weights(self):
        """
        Check remove_batchnorm on a 3-channel sparse-weight simple CNN.

        The network comes from create_simple_cnn and is passed through
        train_randomly. The direct children of the model returned by
        remove_batchnorm must be named exactly like the original's children
        that are not BATCH_NORM_CLASSES instances, and compare_models must
        return a truthy value for the pair when called with (3, 32, 32).
        """
        cnn_kwargs = {
            "in_channels": 3,
            "cnn_out_channels": 4,
            "linear_units": 5,
            "sparse_weights": True,
        }
        original = create_simple_cnn(**cnn_kwargs)
        train_randomly(original, in_channels=3)
        stripped = remove_batchnorm(original)

        # Child names expected to survive: everything except batchnorm classes.
        expected_modules = {
            child_name
            for child_name, child in original.named_children()
            if not isinstance(child, BATCH_NORM_CLASSES)
        }
        actual_modules = {
            child_name for child_name, _ in stripped.named_children()
        }

        self.assertEqual(actual_modules, expected_modules)

        models_agree = compare_models(original, stripped, (3, 32, 32))
        self.assertTrue(models_agree)