def test_shape(self, input_data, expected_shape):
    """Check the CAM result shape and feature-map shape for one parameterized case.

    Args:
        input_data: case description. This test reads its ``"model"``,
            ``"target_layers"``, ``"fc_layers"``, ``"shape"`` and
            ``"feature_shape"`` entries.
        expected_shape: shape the CAM result must have.

    Raises:
        ValueError: if ``input_data["model"]`` is not one of the network keys
            handled below. Without this, ``model`` would be unbound and the
            test would die with an unhelpful ``UnboundLocalError``.
    """
    model_name = input_data["model"]
    # Exactly one network is built, matching the case's "model" key.
    if model_name == "densenet2d":
        model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
    elif model_name == "densenet3d":
        model = DenseNet(
            spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)
        )
    elif model_name == "senet2d":
        model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
    elif model_name == "senet3d":
        model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
    else:
        raise ValueError(f"unsupported model {model_name!r} in test case")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    cam = CAM(nn_module=model, target_layers=input_data["target_layers"], fc_layers=input_data["fc_layers"])
    image = torch.rand(input_data["shape"], device=device)
    result = cam(x=image, layer_idx=-1)

    fea_shape = cam.feature_map_size(input_data["shape"], device=device)
    self.assertTupleEqual(fea_shape, input_data["feature_shape"])
    self.assertTupleEqual(result.shape, expected_shape)
def test_shape(self, input_data, expected_shape):
    """Check GradCAM shapes and that its default class choice matches the model's prediction.

    The test verifies that:
    * the class index GradCAM picks when none is given equals the arg-max of the
      model's output;
    * the feature-map and result shapes match the case;
    * passing that class index explicitly gives the same map as letting GradCAM
      choose it.

    Args:
        input_data: case description. This test reads its ``"model"``,
            ``"target_layers"``, ``"shape"`` and ``"feature_shape"`` entries.
        expected_shape: shape the GradCAM result must have.

    Raises:
        ValueError: if ``input_data["model"]`` is not one of the network keys
            handled below. Without this, ``model`` would be unbound and the
            test would die with an unhelpful ``UnboundLocalError``.
    """

    def _as_numpy(arr):
        # cam(...) may return a torch tensor that lives on ``device``. numpy
        # cannot read a CUDA tensor directly, so detach and move it to host first.
        # NOTE(review): confirm the return type for the MONAI version in use.
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    model_name = input_data["model"]
    # Exactly one network is built, matching the case's "model" key.
    if model_name == "densenet2d":
        model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
    elif model_name == "densenet3d":
        model = DenseNet(
            spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)
        )
    elif model_name == "senet2d":
        model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
    elif model_name == "senet3d":
        model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
    else:
        raise ValueError(f"unsupported model {model_name!r} in test case")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    cam = GradCAM(nn_module=model, target_layers=input_data["target_layers"])
    image = torch.rand(input_data["shape"], device=device)
    result = cam(x=image, layer_idx=-1)

    # The raw model's arg-max class for ``image``. It is computed once and reused
    # for both checks below; in eval mode the prediction should not change
    # between calls.
    predicted_idx = model(image).max(1)[-1].cpu()
    np.testing.assert_array_equal(cam.nn_module.class_idx.cpu(), predicted_idx)

    fea_shape = cam.feature_map_size(input_data["shape"], device=device)
    self.assertTupleEqual(fea_shape, input_data["feature_shape"])
    self.assertTupleEqual(result.shape, expected_shape)

    # The map must be the same whether class_idx is left as None or given explicitly.
    result2 = cam(x=image, layer_idx=-1, class_idx=predicted_idx)
    np.testing.assert_array_almost_equal(_as_numpy(result), _as_numpy(result2))
def setUp(self):
    """Save the SE-Net weight table and, if needed, redirect it to test-data copies.

    The current ``se_mod.SE_NET_MODELS`` is shallow-copied to
    ``self.original_urls``. When ``test_is_quick()`` is truthy, or when a trial
    pretrained download raises an ``OSError`` mentioning "certificate", the table
    is replaced. The replacement maps each SE-Net variant to a URL from
    ``testing_data_config`` and a ``.pth`` path under the ``testing_data``
    directory next to this file.
    """
    self.original_urls = se_mod.SE_NET_MODELS.copy()

    use_local_weights = test_is_quick()
    if not use_local_weights:
        try:
            SEResNet50(pretrained=True, spatial_dims=2, in_channels=3, num_classes=2)
        except OSError as rt_e:
            print(rt_e)
            # [SSL: CERTIFICATE_VERIFY_FAILED]
            use_local_weights = "certificate" in str(rt_e)

    if use_local_weights:
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
        # (variant name, weight-file id). The id doubles as the testing_data_config
        # key and as the stem of the local .pth filename.
        weight_ids = (
            ("senet154", "senet154-c7b49a05"),
            ("se_resnet50", "se_resnet50-ce0d4300"),
            ("se_resnet101", "se_resnet101-7e38fcc6"),
            ("se_resnet152", "se_resnet152-d17c99b7"),
            ("se_resnext50_32x4d", "se_resnext50_32x4d-a260b3a4"),
            ("se_resnext101_32x4d", "se_resnext101_32x4d-3b2fe3d8"),
        )
        se_mod.SE_NET_MODELS = {
            variant: {
                "url": testing_data_config("models", weight_id, "url"),
                "filename": os.path.join(testing_dir, f"{weight_id}.pth"),
            }
            for variant, weight_id in weight_ids
        }
import unittest

import torch
from parameterized import parameterized

from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad

# Model instances built once at import time; every visualiser case below reuses
# the same objects.
DENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
DENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
SENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
SENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)

# (model, input shape, shape the test expects back) for each network under test.
_MODEL_CASES = (
    (DENSENET2D, (1, 1, 48, 64), (1, 1, 48, 64)),  # 2D densenet
    (DENSENET3D, (1, 1, 6, 6, 6), (1, 1, 6, 6, 6)),  # 3D densenet
    (SENET2D, (1, 3, 64, 64), (1, 1, 64, 64)),  # 2D senet
    (SENET3D, (1, 3, 8, 8, 48), (1, 1, 8, 8, 48)),  # 3D senet
)

# One parameterized case per (visualiser class, model case), ordered by visualiser
# class first. The comprehension variable is named ``vis_type`` rather than
# ``type``, so the builtin ``type`` is not shadowed at module level.
TESTS = [
    [vis_type, model, in_shape, out_shape]
    for vis_type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad)
    for model, in_shape, out_shape in _MODEL_CASES
]