def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask,
                                    feat_mask_type):
    """Explain node 2 and validate the returned feature and edge masks.

    The explainer is built with the ``return_type``, ``allow_edge_mask`` and
    ``feat_mask_type`` options supplied by the caller (presumably pytest
    parametrization outside this view). The test explains node 2 of an
    8-node path graph, renders the result with ``visualize_subgraph`` and
    checks the shape and ``[0, 1]`` value range of both masks.
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)
    assert repr(explainer) == 'GNNExplainer()'

    x = torch.randn(8, 3)
    y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])
    # Path graph 0-1-2-...-7 with both directions of every edge listed.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    num_edges = edge_index.size(1)

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    # For a scalar feature mask, pass it as per-node alpha. Otherwise draw
    # random edge labels so the `edge_y` path of the renderer is exercised.
    vis_kwargs = {'y': y, 'threshold': 0.8}
    if feat_mask_type == 'scalar':
        vis_kwargs['node_alpha'] = node_feat_mask
    else:
        vis_kwargs['edge_y'] = torch.randint(low=0, high=30,
                                             size=(num_edges, ))
    # Unpacking into two names also checks that two values come back.
    _, _ = explainer.visualize_subgraph(2, edge_index, edge_mask,
                                        **vis_kwargs)

    # Expected feature-mask shape for each mask type; any other type
    # yields one entry per feature column.
    expected_feat_shape = {
        'individual_feature': tuple(x.size()),
        'scalar': (x.size(0), ),
    }.get(feat_mask_type, (x.size(1), ))
    assert node_feat_mask.size() == expected_feat_shape
    assert 0 <= node_feat_mask.min() and node_feat_mask.max() <= 1
    assert edge_mask.size() == (num_edges, )
    assert 0 <= edge_mask.min() and edge_mask.max() <= 1

    if not allow_edge_mask:
        # The first 8 columns of edge_index are the edges among nodes 0-4,
        # i.e. within 2 hops of node 2. NOTE(review): this presumably
        # relies on `model` having 2 message-passing layers - confirm
        # against the fixture.
        num_kept = 8
        assert edge_mask[:num_kept].tolist() == [1.] * num_kept
        assert edge_mask[num_kept:].tolist() == [0.] * (num_edges - num_kept)
def test_gnn_explainer_default_explain_node(model):
    """Check ``explain_node`` mask shapes and ranges with default options.

    The explainer is built with only ``log=False``. It explains node 2 of an
    8-node path graph and checks that the feature mask has one entry per
    feature column, the edge mask has one entry per edge, and both lie in
    ``[0, 1]``.
    """
    # This function was previously also named `test_gnn_explainer`. A later
    # definition with that name rebound it at import time, so pytest never
    # collected it. The unique name makes it run.
    explainer = GNNExplainer(model, log=False)
    assert repr(explainer) == 'GNNExplainer()'

    x = torch.randn(8, 3)
    # Path graph 0-1-2-...-7 with both directions of every edge listed.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)
    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_gnn_explainer(model):
    """Explain node 2 with default options and render it in several modes.

    The test checks the shape and ``[0, 1]`` range of the feature and edge
    masks. It then calls ``visualize_subgraph`` without and with node
    labels ``y``, each once with no threshold and once with threshold 0.5.
    """
    explainer = GNNExplainer(model, log=False)
    assert repr(explainer) == 'GNNExplainer()'

    x = torch.randn(8, 3)
    # Path graph 0-1-2-...-7 with both directions of every edge listed.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    y = torch.randint(0, 6, (8, ), dtype=torch.long)
    num_feats, num_edges = x.size(1), edge_index.size(1)

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)
    assert node_feat_mask.size() == (num_feats, )
    assert 0 <= node_feat_mask.min() and node_feat_mask.max() <= 1
    assert edge_mask.size() == (num_edges, )
    assert 0 <= edge_mask.min() and edge_mask.max() <= 1

    # Call order: (no y, None), (no y, 0.5), (y, None), (y, 0.5). The
    # unlabelled calls omit `y` entirely rather than passing y=None.
    for label_kwargs in ({}, {'y': y}):
        for threshold in (None, 0.5):
            explainer.visualize_subgraph(2, edge_index, edge_mask,
                                         threshold=threshold, **label_kwargs)