def test_context_stack(mocker):
    """Check that the context stack is pushed and popped correctly.

    Entering ``midnite.device(...)`` must make ``midnite.get_device()`` return
    the newly created device, and leaving a context must restore the device
    that was active before it -- including when contexts are nested.
    """
    # Default is cpu. Build the reference device before torch is patched.
    cpu = torch.device("cpu")
    assert_that(midnite.get_device()).is_equal_to(cpu)

    # Patch out torch so no real GPU is needed. Every device name gets its own
    # mock; with a single shared mock the nested contexts would be
    # indistinguishable and a wrong pop could go unnoticed.
    real_device = torch.device
    fake_devices = {
        "cuda:0": mocker.Mock(spec=real_device),
        "cuda:1": mocker.Mock(spec=real_device),
    }
    mocker.patch("torch.device", side_effect=lambda name: fake_devices[name])

    # Outer context
    with midnite.device("cuda:0"):
        torch.device.assert_called_once_with("cuda:0")
        torch.device.reset_mock()
        assert_that(midnite.get_device()).is_equal_to(fake_devices["cuda:0"])

        # Nested context
        with midnite.device("cuda:1"):
            torch.device.assert_called_once_with("cuda:1")
            assert_that(midnite.get_device()).is_equal_to(fake_devices["cuda:1"])

        # Leaving the inner context must restore the outer device, not the default
        assert_that(midnite.get_device()).is_equal_to(fake_devices["cuda:0"])

    # Leaving the outer context must restore the default device
    assert_that(midnite.get_device()).is_equal_to(cpu)
#
# The input image is split spatially (`SpatialSplit`), i.e. we _stride_ through the spatial positions, doing a measurement at every step. The _chunk size_ controls how large the occluded region at each step is.
#
# ## Parameters: Chunk Size and Stride
# To be generic for any split, _chunk size_ and _stride_ are three-dimensional tuples for (depth, height, width). In our example, the depth dimension is irrelevant, since our split is spatial.
#
# A small _chunk size_ creates a noisy, fine-grained heatmap, as only small parts of the image features are occluded. A larger chunk size, on the other hand, shows which _areas_ are important.

# In[3]:


import midnite
import torch
from midnite.visualization.base import *

# Run on the GPU when one is available; otherwise fall back to the (slower) cpu
with midnite.device("cuda:0" if torch.cuda.is_available() else "cpu"):
    # Small 3x3 chunks: noisy, fine-grained heatmap
    show_heatmap(Occlusion(
        alexnet,
        SplitSelector(NeuronSplit(), [283]),
        SpatialSplit(),
        chunk_size=(1, 3, 3),
        stride=(1, 10, 10),
    ).visualize(example_img))

    # Larger 15x15 chunks: highlights which areas matter
    show_heatmap(Occlusion(
        alexnet,
        SplitSelector(NeuronSplit(), [283]),
        SpatialSplit(),
        chunk_size=(1, 15, 15),
        stride=(1, 10, 10),
    ).visualize(example_img))
# ### GradCAM
# Shows which spatial locations influenced the most important channels for the classification.

# In[6]:


gc_heatmap = gradcam(features, classifier, example_img)
show_heatmap(gc_heatmap, scale=1.2)


# ### Occlusion
# Shows which parts of the image are crucial for correct classification. It requires much more computation time since it is not gradient-based, but it should be more robust to noise than other methods. Read more about it in the [Occlusion notebook](https://luminovo.gitlab.io/public/midnite/latest/notebooks/details_occlusion.html).

# In[7]:


import torch

# Use the GPU when one is available; falling back to the cpu works but is slow
with midnite.device("cuda:0" if torch.cuda.is_available() else "cpu"):
    oc_heatmap = occlusion(net, example_img)
show_heatmap(oc_heatmap, scale=1.2)


# ### Class Visualization
# Generates an image that maximally excites the network for a given class. The image is found by gradient-based optimization in image space, which requires considerable computation power.
#
# However, since there is no defined "goal" for visual clarity, it is usually necessary to manually tune the optimization parameters. This can be done by using the [Max Mean Activation](https://luminovo.gitlab.io/public/midnite/latest/notebooks/details_max_mean_activation.html) building block.

# In[8]:


# Use the GPU when one is available; falling back to the cpu works but is slow
with midnite.device("cuda:0" if torch.cuda.is_available() else "cpu"):
    cv_img = class_visualization(net, 283)
def test_context_fail_early():
    """A device context for a non-existent device raises RuntimeError before its body runs."""
    with pytest.raises(RuntimeError), midnite.device("does not exist"):
        pass
# ### Step 4: Calculate Uncertainties
#
# The correct label for the in-distribution image is 283.

# In[5]:


import torch
import midnite
import tabulate

# Inference only: gradients are disabled and everything runs on the cpu
# (switch to "cuda:0" if you have a gpu available)
with torch.no_grad(), midnite.device("cpu"):
    id_pr = alexnet_ensemble(id_example)
    ood_pr = alexnet_ensemble(ood_example)
    rand_pr = alexnet_ensemble(random_example)

# Print pretty table: one row per metric, one column per example (id, ood, random)
table = [
    ["max prediction", *(pr[0].argmax() for pr in (id_pr, ood_pr, rand_pr))],
    ["max probability", *(f"{pr[0].max():.3f}" for pr in (id_pr, ood_pr, rand_pr))],
    ["pred. entropy (~total uncert.)", *(f"{pr[1].sum():.3f}" for pr in (id_pr, ood_pr, rand_pr))],
    ["mutual info. (~model uncert.)", *(f"{pr[2].sum():.3f}" for pr in (id_pr, ood_pr, rand_pr))],
]