from monai.networks.nets import (
    se_resnet50,
    se_resnet101,
    se_resnet152,
    se_resnext50_32x4d,
    se_resnext101_32x4d,
    senet154,
)
from tests.utils import skip_if_quick

# One network per case, each built with positional arguments (3, 2, 2);
# presumably (spatial_dims, in_channels, num_classes) -- TODO confirm against
# the monai.networks.nets signatures.
TEST_CASE_1 = [senet154(3, 2, 2)]
TEST_CASE_2 = [se_resnet50(3, 2, 2)]
TEST_CASE_3 = [se_resnet101(3, 2, 2)]
TEST_CASE_4 = [se_resnet152(3, 2, 2)]
TEST_CASE_5 = [se_resnext50_32x4d(3, 2, 2)]
TEST_CASE_6 = [se_resnext101_32x4d(3, 2, 2)]
# NOTE(review): this network is constructed at import time. If
# pretrained=True fetches weights (presumably it does -- confirm), that
# happens during test collection, before any skip decorator on the consuming
# test can take effect.
TEST_CASE_PRETRAINED = [se_resnet50(2, 3, 2, pretrained=True)]


class TestSENET(unittest.TestCase):
    @parameterized.expand(
        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]
    )
    def test_senet_shape(self, net):
        """Run a forward pass on a (2, 2, 64, 64, 64) input and check the output shape is (2, 2)."""
        input_data = torch.randn(2, 2, 64, 64, 64)
        expected_shape = (2, 2)
        net.eval()
        with torch.no_grad():
            # Call the module instead of .forward() so registered hooks run.
            result = net(input_data)
        # Without this check the test only verified that inference did not raise.
        self.assertEqual(result.shape, expected_shape)
se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154, ) device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [senet154(3, 2, 2).to(device)] TEST_CASE_2 = [se_resnet50(3, 2, 2).to(device)] TEST_CASE_3 = [se_resnet101(3, 2, 2).to(device)] TEST_CASE_4 = [se_resnet152(3, 2, 2).to(device)] TEST_CASE_5 = [se_resnext50_32x4d(3, 2, 2).to(device)] TEST_CASE_6 = [se_resnext101_32x4d(3, 2, 2).to(device)] TEST_CASE_PRETRAINED = [se_resnet50(2, 3, 2, pretrained=True).to(device)] class TestSENET(unittest.TestCase): @parameterized.expand([ TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6 ]) def test_senet_shape(self, net): input_data = torch.randn(2, 2, 64, 64, 64).to(device) expected_shape = (2, 2) net.eval() with torch.no_grad(): result = net.forward(input_data)
se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154, ) input_param = {"spatial_dims": 3, "in_channels": 2, "num_classes": 10} TEST_CASE_1 = [senet154(**input_param)] TEST_CASE_2 = [se_resnet50(**input_param)] TEST_CASE_3 = [se_resnet101(**input_param)] TEST_CASE_4 = [se_resnet152(**input_param)] TEST_CASE_5 = [se_resnext50_32x4d(**input_param)] TEST_CASE_6 = [se_resnext101_32x4d(**input_param)] class TestSENET(unittest.TestCase): @parameterized.expand([ TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6 ]) def test_senet154_shape(self, net): input_data = torch.randn(5, 2, 64, 64, 64) expected_shape = (5, 10) net.eval() with torch.no_grad(): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape)