Пример #1
0
def test_register_hooks_to_relu_layers(mocker, available_models):
    """Guided backprop registers one forward and one backward hook per ReLU.

    Iterates over the ``(name, model_module)`` pairs supplied by the
    ``available_models`` fixture, spies on the hook-registration methods of
    every ReLU layer found in the model, runs a guided gradient calculation
    and asserts each spied method was called exactly once.
    """
    hook_methods = ('register_forward_hook', 'register_backward_hook')

    print()
    for name, model_module in available_models:
        # Progress line: '\r' returns to column 0 and the ANSI "erase line"
        # sequence clears it, so successive model names overwrite each other.
        print(f'Testing model: {name}', end='\r')
        stdout.write('\x1b[2K')

        model = model_module()
        relu_layers = find_relu_layers(model, nn.ReLU)

        # Spies must be installed before Backprop is built and run.
        for relu in relu_layers:
            for method_name in hook_methods:
                mocker.spy(relu, method_name)

        backprop = Backprop(model)

        target_class = 5
        # Models with 'inception' in their name are fed 299x299 inputs,
        # every other model 224x224.
        side = 299 if 'inception' in name else 224
        input_ = torch.zeros([1, 3, side, side])

        make_mock_output(mocker, model, target_class)

        backprop.calculate_gradients(input_, target_class, guided=True)

        for relu in relu_layers:
            relu.register_forward_hook.assert_called_once()
            relu.register_backward_hook.assert_called_once()
Пример #2
0
def test_checks_input_size_for_inception_model(mocker):
    """A 224x224 input to an Inception model must raise ``ValueError``.

    Setup (model and ``Backprop`` construction) happens outside the
    ``pytest.raises`` block so that only the call under test is allowed to
    raise; a ``ValueError`` from setup surfaces as a setup error rather than
    being captured by the context manager.
    """
    model = models.inception_v3()
    backprop = Backprop(model)

    target_class = 5
    # Deliberately the wrong size: the expected message asks for 299x299.
    input_ = torch.zeros([1, 3, 224, 224])

    with pytest.raises(ValueError) as error:
        backprop.calculate_gradients(input_, target_class)

    assert 'Image must be 299x299 for Inception models.' in str(error.value)
Пример #3
0
def test_handle_binary_classifier(mocker, model):
    """Smoke test: a single-element model output must not raise.

    The model's ``forward`` is patched to return a one-element tensor that
    requires grad; the test passes if ``calculate_gradients`` completes.
    """
    backprop = Backprop(model)

    # One-element output standing in for a binary classifier's score.
    single_score = torch.tensor([0.8], requires_grad=True)
    mocker.patch.object(model, 'forward', return_value=single_score)

    image = torch.zeros([1, 3, 224, 224])
    requested_class = 0

    backprop.calculate_gradients(image, requested_class)
Пример #4
0
def test_warn_when_prediction_is_wrong(mocker, model):
    """A ``UserWarning`` is emitted when the target differs from the top class.

    The mocked model output is built for class 1 while class 5 is requested.
    """
    backprop = Backprop(model)

    predicted_class = torch.tensor(1)
    requested_class = 5
    image = torch.zeros([1, 3, 224, 224])

    make_mock_output(mocker, model, predicted_class)

    with pytest.warns(UserWarning):
        backprop.calculate_gradients(image, requested_class)
Пример #5
0
def test_zero_out_gradients(mocker, model):
    """``calculate_gradients`` calls ``model.zero_grad`` exactly once."""
    backprop = Backprop(model)
    zero_grad_spy = mocker.spy(model, 'zero_grad')

    wanted_class = 5
    image = torch.zeros([1, 3, 224, 224])
    make_mock_output(mocker, model, wanted_class)

    backprop.calculate_gradients(image, wanted_class)

    # mocker.spy replaces the attribute in place, so the returned spy is the
    # same object reachable as ``model.zero_grad``.
    zero_grad_spy.assert_called_once()
Пример #6
0
def test_calc_gradients_of_top_class_if_target_not_provided(mocker, model):
    """Without a target class, backward is driven by the top class.

    The ``gradient`` keyword passed to the mocked output's ``backward`` must
    equal the expected gradient target built for the top class.
    """
    backprop = Backprop(model)

    top_class = 5
    image = torch.zeros([1, 3, 224, 224])

    expected_gradient = make_expected_gradient_target(top_class)
    mock_output = make_mock_output(mocker, model, top_class)

    # No target_class argument: the implementation has to pick one itself.
    backprop.calculate_gradients(image)

    # call_args is an (args, kwargs) pair; only the keyword part is checked.
    backward_kwargs = mock_output.backward.call_args[1]

    assert torch.all(backward_kwargs['gradient'].eq(expected_gradient))
Пример #7
0
def test_handle_greyscale_input(mocker, model_grayscale):
    """A single-channel 224x224 input yields gradients of shape (1, 224, 224)."""
    backprop = Backprop(model_grayscale)

    single_channel_input = torch.zeros([1, 1, 224, 224], requires_grad=True)
    expected_shape = (1, 224, 224)

    gradients = backprop.calculate_gradients(single_channel_input)

    assert gradients.shape == expected_shape
Пример #8
0
def test_calc_gradients_of_top_class_if_prediction_is_wrong(mocker, model):
    """When the requested class is not the top class, the top class is used.

    The mocked output is built for class 5 while class 7 is requested; a
    ``UserWarning`` is expected and the ``gradient`` keyword passed to
    ``backward`` must match the expected target for class 5.
    """
    backprop = Backprop(model)

    predicted_class = torch.tensor(5)
    requested_class = 7
    image = torch.zeros([1, 3, 224, 224])

    expected_gradient = make_expected_gradient_target(predicted_class)
    mock_output = make_mock_output(mocker, model, predicted_class)

    with pytest.warns(UserWarning):
        backprop.calculate_gradients(image, requested_class)

    # call_args is an (args, kwargs) pair; only the keyword part is checked.
    backward_kwargs = mock_output.backward.call_args[1]

    assert torch.all(backward_kwargs['gradient'].eq(expected_gradient))
Пример #9
0
def test_return_max_across_color_channels_if_specified(mocker, model):
    """With ``take_max=True`` a 3-channel input yields a (1, 224, 224) result."""
    backprop = Backprop(model)

    wanted_class = 5
    image = torch.zeros([1, 3, 224, 224])
    expected_shape = (1, 224, 224)

    make_mock_output(mocker, model, wanted_class)

    max_gradients = backprop.calculate_gradients(
        image, wanted_class, take_max=True)

    assert max_gradients.shape == expected_shape
Пример #10
0
def saliency_map(model, img_p):
    """
    Plot a smoothed saliency map over an image and return the map.

    Parameters
    ----------
    model : model object handed to flashtorch's ``Backprop``
    img_p : image path, passed straight to ``load_and_preprocess_img``

    Returns
    -------
    saliency_map_np : Gaussian-smoothed saliency map as a numpy array
        (described by the original author as shape (255, 255); the actual
        size depends on ``load_and_preprocess_img`` -- TODO confirm).

    Side effects: prints debug information about the raw map and opens a
    matplotlib window via ``plt.show()``.
    """

    # Load and preprocess image. The result is indexed with [0] below, so it
    # is expected to carry a leading batch axis, i.e. (1, 3, H, W).
    X = load_and_preprocess_img(img_p)
    X.requires_grad_()  # This is critical to actually get gradients.

    # Get gradients using the flashtorch library. No target class is given.
    # See https://mc.ai/feature-visualisation-in-pytorch%E2%80%8A-%E2%80%8Asaliency-maps/
    with torch.set_grad_enabled(True):
        backprop = Backprop(model)  # flashtorch.saliency Backprop object.
        gradients = backprop.calculate_gradients(X,
                                                 take_max=True,
                                                 guided=False)

    # Cast image and saliency map to numpy arrays. The tensor must be
    # detached from the autograd graph before it can be converted.
    # transpose(1, 2, 0) moves channels last: (3, H, W) -> (H, W, 3).
    img_np = X.detach().numpy()[0].transpose(1, 2, 0)
    saliency_map_np = gradients.numpy()[0]

    # Debug output, kept so the function's stdout is unchanged.
    print(max(np.max(saliency_map_np, axis=0)))
    print(saliency_map_np)
    print(img_np.shape)
    print(saliency_map_np.shape)

    # Smooth heatmap.
    saliency_map_np = gaussian_filter(saliency_map_np, sigma=10)

    # Plot image and overlay saliency map. sns.heatmap returns the Axes it
    # drew on, so the image is drawn on the same Axes.
    heatmap = sns.heatmap(saliency_map_np, alpha=0.5)
    heatmap.imshow(img_np, cmap="YlGnBu")
    plt.show()

    return saliency_map_np
Пример #11
0
def test_register_hooks_to_relu_layers(mocker, name, model_module):
    """Guided backprop registers one forward and one backward hook per ReLU.

    ``name`` and ``model_module`` identify one model; ``model_module`` is
    called with no arguments to build it.
    """
    hook_methods = ('register_forward_hook', 'register_backward_hook')

    model = model_module()
    relu_layers = find_relu_layers(model, nn.ReLU)

    # Spies must be installed before Backprop is built and run.
    for relu in relu_layers:
        for method_name in hook_methods:
            mocker.spy(relu, method_name)

    backprop = Backprop(model)

    target_class = 5
    # Models with 'inception' in their name are fed 299x299 inputs,
    # every other model 224x224.
    side = 299 if 'inception' in name else 224
    input_ = torch.zeros([1, 3, side, side])

    make_mock_output(mocker, model, target_class)

    backprop.calculate_gradients(input_, target_class, guided=True)

    for relu in relu_layers:
        relu.register_forward_hook.assert_called_once()
        relu.register_backward_hook.assert_called_once()
Пример #12
0
def get_saliency_map(model, img_p, ypred):
    """
    Compute a smoothed saliency map for one image and one target class.

    Parameters
    ----------
    model (PyTorch model object)
    img_p (str) path to image
    ypred (int) 0-4. Passed to flashtorch as the target class.

    Returns
    -----------
    img_np : preprocessed image as a channels-last numpy array,
        documented by the original author as (255, 255, 3)
    saliency_map_np : Gaussian-smoothed saliency map as a numpy array,
        documented by the original author as (255, 255)

    TODO:
    Examine this code more closely. At the moment, saliency maps
    don't change much across classes. A different saliency mapping
    technique may be needed. The guided=True flag may be possible
    but at the moment causes NaN errors.

    """

    # Preprocessed image; indexed with [0] below, so a leading batch axis
    # is expected.
    X = load_and_preprocess_img(img_p)

    # Without this the input has no gradient to read back.
    X.requires_grad_()

    # Gradients via flashtorch, forcing autograd on for the duration.
    with torch.set_grad_enabled(True):
        gradients = Backprop(model).calculate_gradients(
            input_=X,
            target_class=ypred,
            take_max=True,
            guided=False,
        )

    # Drop the batch axis and move channels last for the image.
    img_np = X.detach().numpy()[0].transpose(1, 2, 0)

    # The map is smoothed as returned; no absolute value is applied.
    raw_map = gradients[0].numpy()
    saliency_map_np = gaussian_filter(raw_map, sigma=10)

    return img_np, saliency_map_np
Пример #13
0
def test_calculate_gradients_for_all_models(mocker, name, model_module):
    """Gradients computed with ``use_gpu=True`` match the input's shape.

    The expected shape is the input size with its first (batch) dimension
    dropped.
    """
    model = model_module()
    backprop = Backprop(model)

    target_class = 5
    # Models with 'inception' in their name are fed 299x299 inputs,
    # every other model 224x224.
    side = 299 if 'inception' in name else 224
    input_ = torch.zeros([1, 3, side, side])

    make_mock_output(mocker, model, target_class)

    gradients = backprop.calculate_gradients(
        input_, target_class, use_gpu=True)

    expected_shape = input_.size()[1:]
    assert gradients.shape == expected_shape
Пример #14
0
def test_calculate_gradients_for_all_models(mocker, available_models):
    """For every available model, gradients match the input's shape.

    Iterates over the ``(name, factory)`` pairs supplied by the
    ``available_models`` fixture; the expected shape is the input size with
    its first (batch) dimension dropped.
    """
    print()
    for name, model_factory in available_models:
        # Progress line: '\r' returns to column 0 and the ANSI "erase line"
        # sequence clears it, so successive model names overwrite each other.
        print(f'Testing model: {name}', end='\r')
        stdout.write('\x1b[2K')

        model = model_factory()
        backprop = Backprop(model)

        target_class = 5
        # Models with 'inception' in their name are fed 299x299 inputs,
        # every other model 224x224.
        side = 299 if 'inception' in name else 224
        input_ = torch.zeros([1, 3, side, side])

        make_mock_output(mocker, model, target_class)

        gradients = backprop.calculate_gradients(input_, target_class)

        expected_shape = input_.size()[1:]
        assert gradients.shape == expected_shape
Пример #15
0
# Script fragment: `image` and the current matplotlib figure are set up
# before this excerpt begins.
plt.axis('off')

# next is needed, otherwise no image pops up
plt.waitforbuttonpress()

# Wrap a pretrained AlexNet in a Backprop object.
model = models.alexnet(pretrained=True)
backprop = Backprop(model)

# Look up the target for the 'tabby cat' label
# (presumably an ImageNet class index -- confirm against ImageNetIndex).
imagenet = ImageNetIndex()
target_class = imagenet['tabby cat']

# Convert the image into the model's input (see apply_transforms).
input_ = apply_transforms(image)

# Calculate the gradients of each pixel w.r.t. the input image

gradients = backprop.calculate_gradients(input_, target_class)

# Or, take the maximum of the gradients for each pixel across colour channels.

max_gradients = backprop.calculate_gradients(input_,
                                             target_class,
                                             take_max=True)

print('Shape of the gradients:', gradients.shape)
print('Shape of the max gradients:', max_gradients.shape)

# Show the results, then block until a key or mouse button is pressed.
visualize(input_, gradients, max_gradients)
plt.waitforbuttonpress()

# NOTE(review): rebinds `backprop` to a fresh instance for the same model;
# nothing visible in this excerpt uses it afterwards.
backprop = Backprop(model)