def test_prepare_for_inference_cpu(self):
        """Compare a small two-branch conv net against its prepared version.

        Builds a module with two parallel stacks of Conv2d -> BatchNorm2d ->
        ReLU, runs ``optimization.prepare_for_inference`` on it in eval mode,
        and asserts that both modules give matching outputs for one random
        input.
        """
        import torch.nn as nn

        class Foo(nn.Module):
            """Sum of two independent 10-stage Conv/BatchNorm/ReLU stacks."""

            def __init__(self):
                super().__init__()
                first_stack, second_stack = [], []
                for _ in range(10):
                    # One Conv -> BN -> ReLU stage per stack per iteration,
                    # created in the order: first stack, then second stack.
                    first_stack += [nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()]
                    second_stack += [nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()]
                self.model = nn.Sequential(*first_stack)
                self.model2 = nn.Sequential(*second_stack)

            def forward(self, x):
                left = self.model(x)
                right = self.model2(x)
                return left + right

        # Single 3-channel 224x224 input; drawn before the model is built.
        inp = torch.randn(1, 3, 224, 224)
        with torch.no_grad():
            model = Foo().eval()
            optimized_model = optimization.prepare_for_inference(model)
            expected = model(inp)
            actual = optimized_model(inp)
            torch.testing.assert_allclose(expected, actual)
    def test_prepare_for_inference_cpu_torchvision(self):
        """Compare several torchvision models against their prepared versions.

        For each architecture: instantiate it, run one forward pass on a
        batch of 3 before calling ``eval()``, switch to eval mode, run
        ``optimization.prepare_for_inference``, and assert that the original
        and prepared modules give matching outputs for one random input.
        """
        model_factories = (
            torchvision.models.resnet18,
            torchvision.models.resnet50,
            torchvision.models.densenet121,
            torchvision.models.shufflenet_v2_x1_0,
            torchvision.models.vgg16,
            torchvision.models.mobilenet_v2,
            torchvision.models.mnasnet1_0,
            torchvision.models.resnext50_32x4d,
        )
        channels, height, width = 3, 224, 224
        with torch.no_grad():
            for build_model in model_factories:
                net = build_model()
                # Forward pass happens before eval() is called.
                # NOTE(review): presumably this is to move BatchNorm running
                # statistics off their initial values -- confirm.
                net(torch.randn(3, channels, height, width))
                net.eval()

                sample = torch.randn(1, channels, height, width)
                # NOTE(review): the autotuner heuristic is built but never
                # passed on to prepare_for_inference -- confirm whether that
                # is intended.
                heuristic = optimization.gen_mkl_autotuner(sample, iters=0, warmup=0)
                prepared = optimization.prepare_for_inference(net)

                expected = net(sample)
                actual = prepared(sample)
                torch.testing.assert_allclose(expected, actual)