def test_captum_attribution_methods(mask_type, method):
    """Run a Captum attribution `method` against the wrapped GCN and check
    that the attribution shapes match the masked inputs.

    Args:
        mask_type: One of 'node', 'edge', or 'node_and_edge' — which model
            input(s) Captum attributes over.
        method: Name of a class in `captum.attr` (e.g. 'IntegratedGradients').
    """
    from captum import attr  # noqa
    captum_model = to_captum(GCN, mask_type, 0)
    input_mask = torch.ones((1, edge_index.shape[1]), dtype=torch.float,
                            requires_grad=True)
    explainer = getattr(attr, method)(captum_model)

    # Build the attribution inputs per mask type. Captum expects dimension 0
    # to be the sample dimension, hence the `unsqueeze(0)` on node features.
    # NOTE: renamed from `input` to avoid shadowing the builtin.
    if mask_type == 'node':
        inputs = x.clone().unsqueeze(0)
        additional_forward_args = (edge_index, )
        sliding_window_shapes = (3, 3)
    elif mask_type == 'edge':
        inputs = input_mask
        additional_forward_args = (x, edge_index)
        sliding_window_shapes = (5, )
    elif mask_type == 'node_and_edge':
        inputs = (x.clone().unsqueeze(0), input_mask)
        additional_forward_args = (edge_index, )
        sliding_window_shapes = ((3, 3), (5, ))

    # Each Captum method takes a slightly different set of keyword arguments:
    if method == 'IntegratedGradients':
        attributions, delta = explainer.attribute(
            inputs, target=0, internal_batch_size=1,
            additional_forward_args=additional_forward_args,
            return_convergence_delta=True)
    elif method == 'GradientShap':
        attributions, delta = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            n_samples=1, additional_forward_args=additional_forward_args)
    elif method in ('DeepLiftShap', 'DeepLift'):
        attributions, delta = explainer.attribute(
            inputs, target=0, return_convergence_delta=True, baselines=inputs,
            additional_forward_args=additional_forward_args)
    elif method == 'Occlusion':
        attributions = explainer.attribute(
            inputs, target=0, sliding_window_shapes=sliding_window_shapes,
            additional_forward_args=additional_forward_args)
    else:
        attributions = explainer.attribute(
            inputs, target=0,
            additional_forward_args=additional_forward_args)

    # Attribution shapes mirror the (batched) input shapes:
    # node features are (1, 8, 3) and the edge mask is (1, 14).
    if mask_type == 'node':
        assert attributions.shape == (1, 8, 3)
    elif mask_type == 'edge':
        assert attributions.shape == (1, 14)
    else:
        assert attributions[0].shape == (1, 8, 3)
        assert attributions[1].shape == (1, 14)
def test_to_captum(model, mask_type, output_idx):
    """Verify that the Captum wrapper produced by `to_captum` accepts masked
    inputs and that applying a mask actually changes the model output.

    Args:
        model: The GNN to wrap.
        mask_type: One of 'node', 'edge', or 'node_and_edge'.
        output_idx: Node index to restrict the wrapped output to, or ``None``
            for all nodes.
    """
    captum_model = to_captum(model, mask_type=mask_type,
                             output_idx=output_idx)
    unmasked_out = model(x, edge_index)

    # Feed the wrapper a zeroed node mask and/or a half-weight edge mask;
    # Captum treats dimension 0 as the sample dimension, so masks are
    # batched via `unsqueeze(0)`.
    if mask_type == 'node':
        node_mask = x * 0.0
        out = captum_model(node_mask.unsqueeze(0), edge_index)
    elif mask_type == 'edge':
        edge_mask = 0.5 * torch.ones(edge_index.shape[1], dtype=torch.float,
                                     requires_grad=True)
        out = captum_model(edge_mask.unsqueeze(0), x, edge_index)
    elif mask_type == 'node_and_edge':
        node_mask = x * 0.0
        edge_mask = 0.5 * torch.ones(edge_index.shape[1], dtype=torch.float,
                                     requires_grad=True)
        out = captum_model(node_mask.unsqueeze(0), edge_mask.unsqueeze(0),
                           edge_index)

    # A single-node wrapper returns one row; otherwise one row per node.
    # In both cases the masked output must differ from the unmasked one.
    if output_idx is None:
        assert out.shape == (8, 7)
        assert torch.any(out != unmasked_out)
    else:
        assert out.shape == (1, 7)
        assert torch.any(out != unmasked_out[[output_idx]])
# Single training step on the node-classification objective.
model.train()
optimizer.zero_grad()
log_logits = model(data.x, data.edge_index)
loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

output_idx = 10
target = int(data.y[output_idx])

# Edge explainability
# ===================
# Captum assumes that for all given input tensors, dimension 0 is
# equal to the number of samples. Therefore, we use unsqueeze(0).
captum_model = to_captum(model, mask_type='edge', output_idx=output_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

# Attribute the prediction for node `output_idx` to individual edges
# via Integrated Gradients over the edge mask.
ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(
    edge_mask.unsqueeze(0), target=target,
    additional_forward_args=(data.x, data.edge_index),
    internal_batch_size=1)

# Scale attributions to [0, 1]:
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()

# Visualize absolute values of attributions with GNNExplainer visualizer
explainer = GNNExplainer(model)  # TODO: Change to general Explainer visualizer
ax, G = explainer.visualize_subgraph(output_idx, data.edge_index,
                                     ig_attr_edge)