Exemplo n.º 1
0
def test_captum_attribution_methods(mask_type, method):
    """Run one Captum attribution method on the wrapped model and check shapes.

    Args:
        mask_type: One of ``'node'``, ``'edge'`` or ``'node_and_edge'``;
            selects which inputs the attributions are computed for.
        method: Name of an attribution class in ``captum.attr``
            (e.g. ``'IntegratedGradients'`` or ``'Occlusion'``).

    Raises:
        ValueError: If ``mask_type`` is not one of the supported values.

    NOTE(review): ``GCN``, ``x`` and ``edge_index`` are defined outside this
    block; the shapes asserted below imply 8 nodes with 3 features and 14
    edges -- confirm against the fixture definitions.
    """
    from captum import attr  # noqa

    captum_model = to_captum(GCN, mask_type, 0)
    input_mask = torch.ones((1, edge_index.shape[1]), dtype=torch.float,
                            requires_grad=True)
    explainer = getattr(attr, method)(captum_model)

    # Captum attributes w.r.t. ``inputs``; everything the model needs but
    # that should not receive attributions goes into
    # ``additional_forward_args``. ``sliding_window_shapes`` is only
    # consumed by the 'Occlusion' branch below.
    if mask_type == 'node':
        inputs = x.clone().unsqueeze(0)
        additional_forward_args = (edge_index, )
        sliding_window_shapes = (3, 3)
    elif mask_type == 'edge':
        inputs = input_mask
        additional_forward_args = (x, edge_index)
        sliding_window_shapes = (5, )
    elif mask_type == 'node_and_edge':
        inputs = (x.clone().unsqueeze(0), input_mask)
        additional_forward_args = (edge_index, )
        sliding_window_shapes = ((3, 3), (5, ))
    else:
        # Fail with a clear message instead of an unrelated NameError below.
        raise ValueError(f"Unsupported mask_type: {mask_type!r}")

    # Methods in the first three branches are asked for a convergence delta,
    # so they return a pair; the delta itself is not checked by this test.
    if method == 'IntegratedGradients':
        attributions, _ = explainer.attribute(
            inputs, target=0, internal_batch_size=1,
            additional_forward_args=additional_forward_args,
            return_convergence_delta=True)
    elif method == 'GradientShap':
        attributions, _ = explainer.attribute(
            inputs, target=0, return_convergence_delta=True,
            baselines=inputs, n_samples=1,
            additional_forward_args=additional_forward_args)
    elif method in ('DeepLiftShap', 'DeepLift'):
        attributions, _ = explainer.attribute(
            inputs, target=0, return_convergence_delta=True,
            baselines=inputs,
            additional_forward_args=additional_forward_args)
    elif method == 'Occlusion':
        attributions = explainer.attribute(
            inputs, target=0, sliding_window_shapes=sliding_window_shapes,
            additional_forward_args=additional_forward_args)
    else:
        attributions = explainer.attribute(
            inputs, target=0,
            additional_forward_args=additional_forward_args)

    # Attributions mirror the structure of ``inputs``: a single tensor for
    # 'node' / 'edge', a (node, edge) pair for 'node_and_edge'.
    if mask_type == 'node':
        assert attributions.shape == (1, 8, 3)
    elif mask_type == 'edge':
        assert attributions.shape == (1, 14)
    else:
        assert attributions[0].shape == (1, 8, 3)
        assert attributions[1].shape == (1, 14)
Exemplo n.º 2
0
def test_to_captum(model, mask_type, output_idx):
    """Check that the Captum-wrapped model reacts to node/edge masks.

    The wrapped model is called with a zeroed node mask and/or an edge mask
    of 0.5, and its output is required to have the expected shape and to
    differ from the unmasked model output.

    Args:
        model: Model passed to ``to_captum`` and called as
            ``model(x, edge_index)``.
        mask_type: ``'node'``, ``'edge'`` or ``'node_and_edge'``.
        output_idx: Index of the output row to select, or ``None`` to keep
            all rows.

    NOTE(review): the last statements of this function run a single
    optimisation step using ``optimizer``, ``data`` and ``F`` from outside
    this block; they look like they were pasted from a training loop --
    confirm they belong here.
    """
    wrapped = to_captum(model, mask_type=mask_type, output_idx=output_idx)
    reference = model(x, edge_index)

    def _half_edge_mask():
        # One weight of 0.5 per edge, with a leading sample dimension.
        ones = torch.ones(edge_index.shape[1], dtype=torch.float,
                          requires_grad=True)
        return (ones * 0.5).unsqueeze(0)

    if mask_type == 'node':
        out = wrapped((x * 0.0).unsqueeze(0), edge_index)
    elif mask_type == 'edge':
        out = wrapped(_half_edge_mask(), x, edge_index)
    elif mask_type == 'node_and_edge':
        out = wrapped((x * 0.0).unsqueeze(0), _half_edge_mask(), edge_index)

    if output_idx is None:
        assert out.shape == (8, 7)
        assert torch.any(out != reference)
    else:
        assert out.shape == (1, 7)
        assert torch.any(out != reference[[output_idx]])

    # Single training step on the training nodes (see NOTE in the docstring).
    model.train()
    optimizer.zero_grad()
    log_probs = model(data.x, data.edge_index)
    train_mask = data.train_mask
    F.nll_loss(log_probs[train_mask], data.y[train_mask]).backward()
    optimizer.step()

# Node whose prediction is explained, and the class to attribute towards
# (taken from the label stored in ``data.y`` for that node).
output_idx = 10
target = int(data.y[output_idx])

# Edge explainability
# ===================

# Captum assumes that for all given input tensors, dimension 0 is
# equal to the number of samples. Therefore, we use unsqueeze(0).
captum_model = to_captum(model, mask_type='edge', output_idx=output_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0),
                            target=target,
                            additional_forward_args=(data.x, data.edge_index),
                            internal_batch_size=1)

# Scale attributions to [0, 1]:
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
max_attr = ig_attr_edge.max()
# Skip the division when every attribution is zero; dividing by a zero
# maximum would turn the whole mask into NaNs.
if max_attr > 0:
    ig_attr_edge /= max_attr

# Visualize absolute values of attributions with GNNExplainer visualizer
explainer = GNNExplainer(model)  # TODO: Change to general Explainer visualizer
ax, G = explainer.visualize_subgraph(output_idx, data.edge_index, ig_attr_edge)