Beispiel #1
0
    def __call__(  # type: ignore
        self,
        inputs: torch.Tensor,
        network: nn.Module,
        *args: Any,
        **kwargs: Any,
    ):
        """Build the configured class-activation-map explainer and run it on ``inputs``.

        The explainer type is picked from ``self.cam_name``:

        - ``"cam"`` selects ``CAM``;
        - ``"gradcam"`` selects ``GradCAM``;
        - any other value falls back to ``GradCAMpp``.

        The explainer is constructed as
        ``cls(network, self.target_layers, *self.args, **self.kwargs)``.

        Args:
            inputs: model input data for inference.
            network: target model to explain. It is passed as the first
                positional argument to the explainer constructor.
            args: extra positional args forwarded to the explainer's ``__call__``
                after ``inputs`` and ``self.class_idx``.
            kwargs: extra keyword args forwarded to the explainer's ``__call__``.

        Returns:
            Whatever the selected explainer's ``__call__`` returns.
        """
        # Name -> explainer class; names not listed here fall through to GradCAMpp.
        explainer_types = {"cam": CAM, "gradcam": GradCAM}
        explainer_cls = explainer_types.get(self.cam_name, GradCAMpp)

        explainer: Union[CAM, GradCAM, GradCAMpp] = explainer_cls(
            network, self.target_layers, *self.args, **self.kwargs
        )
        return explainer(inputs, self.class_idx, *args, **kwargs)
Beispiel #2
0
 def test_shape(self, input_data, expected_shape):
     """Check GradCAMpp result and feature-map shapes for one parameterized case.

     ``input_data`` is read for these keys:

     - ``"model"``: one of ``"densenet2d"``, ``"densenet3d"``, ``"senet2d"``,
       ``"senet3d"``; it selects the network under test.
     - ``"target_layers"``: forwarded to ``GradCAMpp``.
     - ``"shape"``: shape of the random input image and of the
       ``feature_map_size`` query.
     - ``"feature_shape"``: expected ``feature_map_size`` result.

     ``expected_shape`` is the expected ``shape`` of the GradCAMpp output.
     """
     model_name = input_data["model"]
     # Exactly one branch builds the model. An unknown name now fails the test
     # with a clear message instead of an UnboundLocalError on ``model``.
     if model_name == "densenet2d":
         model = densenet121(spatial_dims=2, in_channels=1, out_channels=3)
     elif model_name == "densenet3d":
         model = DenseNet(spatial_dims=3,
                          in_channels=1,
                          out_channels=3,
                          init_features=2,
                          growth_rate=2,
                          block_config=(6, ))
     elif model_name == "senet2d":
         model = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
     elif model_name == "senet3d":
         model = se_resnet50(spatial_dims=3, in_channels=3, num_classes=4)
     else:
         self.fail(f"unsupported model in test case: {model_name!r}")

     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     model.to(device)
     model.eval()
     cam = GradCAMpp(nn_module=model,
                     target_layers=input_data["target_layers"])
     image = torch.rand(input_data["shape"], device=device)
     # NOTE(review): layer_idx=-1 presumably selects the last hooked target layer — confirm.
     result = cam(x=image, layer_idx=-1)
     fea_shape = cam.feature_map_size(input_data["shape"], device=device)
     self.assertTupleEqual(fea_shape, input_data["feature_shape"])
     self.assertTupleEqual(result.shape, expected_shape)