def test_correct_labels_resampled():
    """Check label frequencies of the resampled dataset against their caps.

    Loads the dataset with ``resample=True``, derives one label per graph via
    ``get_label_type`` and asserts that no label occurs more often than the
    limit recorded for it in ``expected_label_count``.
    """
    dataset = datasets.omitted_with_actions(['LT1001_TA05.net'], resample=True)

    per_graph_labels = []
    for graph in dataset:
        per_graph_labels.append(get_label_type(dataset, graph))

    # np.unique returns each distinct label once, paired with its frequency,
    # so the pairs can be walked directly.
    distinct_labels, frequencies = np.unique(per_graph_labels,
                                             return_counts=True)
    for (label, count) in zip(distinct_labels, frequencies):
        max_count = expected_label_count[label]
        assert count <= max_count, f'Expected <= {max_count} #{label} labels but found {count}'
def test_targets_should_be_different_types():
    """Check that flagged rows and valid-target rows share the same indices.

    For every graph (loaded without resampling or normalisation) the rows
    whose last feature column equals 1 are paired, in order, with the rows
    whose target is greater than -1; each pair must be the same row index.
    """
    dataset = datasets.omitted_with_actions(
        ['LT1001_TA05.net'], resample=False, normalize=False)

    for sample in dataset:
        flag_column = sample.x[:, -1]
        flagged_rows = (flag_column == 1).nonzero()[0]
        labelled_rows = (sample.y > -1).nonzero()[0]
        # NOTE: zip stops at the shorter sequence, so only the overlapping
        # prefix of the two index lists is compared.
        for (flagged_row, labelled_row) in zip(flagged_rows, labelled_rows):
            assert flagged_row == labelled_row
def test_only_one_prototype_per_type():
    """Check that no feature column sums above 1 across the flagged rows.

    For every graph, the rows whose last feature column equals 1 are
    selected, their columns are summed, and every column except the last
    must have a total of at most 1.
    """
    dataset = datasets.omitted_with_actions(
        ['LT1001_TA05.net'], resample=False, normalize=False)

    for sample in dataset:
        is_flagged = sample.x[:, -1] == 1
        flagged_rows = sample.x[is_flagged.nonzero()[0]]
        # Column totals over the flagged rows; the final column is the flag
        # itself and is dropped. The remaining columns are presumably
        # per-type indicators -- TODO confirm against the dataset encoding.
        column_totals = np.sum(flagged_rows, axis=0)
        per_type_totals = column_totals[:-1]
        repeated_columns = (per_type_totals > 1).nonzero()[0]
        assert repeated_columns.size == 0
def test_components_should_be_connected():
    """Check that every column of each graph's adjacency matrix is non-zero.

    The adjacency matrix ``graph.a`` is densified and summed along axis 0;
    a total that is not greater than 0 is reported as a disconnected node.
    """
    dataset = datasets.omitted_with_actions(
        ['LT1001_TA05.net'], resample=False, normalize=False, shuffle=False)

    for sample in dataset:
        dense_adjacency = sample.a.toarray()
        column_totals = np.sum(dense_adjacency, axis=0)
        for (edge_index, edge_count) in enumerate(column_totals):
            assert edge_count > 0, f'Found disconnected node: {edge_index}'
def test_no_dupe_types_in_proto():
    """Check that the rows with a valid target all have distinct node types.

    For every graph, the feature rows whose target is greater than -1 are
    passed to ``dataset.get_node_types``; each resulting type must occur
    exactly once.
    """
    dataset = datasets.omitted_with_actions(
        ['LT1001_TA05.net'], resample=False, normalize=False, shuffle=False)

    for sample in dataset:
        labelled_rows = (sample.y > -1).nonzero()[0]
        labelled_features = sample.x[labelled_rows]
        labelled_types = dataset.get_node_types(labelled_features)
        # np.unique yields each distinct type once with its frequency, so the
        # pairs can be checked directly.
        distinct_types, frequencies = np.unique(labelled_types,
                                                return_counts=True)
        for (node_type, count) in zip(distinct_types, frequencies):
            assert count == 1, f'Found {count} nodes of type {node_type}'
def test_get_node_types():
    """Check the per-type node counts reported for the first graph.

    Every label in ``expected_label_count`` must occur ``component_count + 1``
    times among the node types of ``dataset[0]``. At most one label may occur
    only ``component_count`` times (presumably the type of the omitted
    component -- TODO confirm).

    Raises:
        Exception: if a label's count matches neither allowance, or if a
            second label is one short.
    """
    dataset = datasets.omitted_with_actions(['LT1001_TA05.net'],
                                            resample=False)
    graph = dataset[0]
    types = dataset.get_node_types(graph.x)
    unique, counts = np.unique(types, return_counts=True)
    single_graph_count = dict(zip(unique, counts))

    found_omitted_type = False
    for (label, component_count) in expected_label_count.items():
        count = single_graph_count[label]
        expected = component_count + 1
        if count == expected:
            continue
        # Tolerate a single type that is exactly one node short.
        if count == component_count and not found_omitted_type:
            found_omitted_type = True
            continue
        # The message previously referenced an undefined name, which turned
        # every mismatch into a NameError instead of this diagnostic.
        raise Exception(
            f'Expected {expected} #{label} labels but found {count}')
def test_to_networkx():
    """Check the networkx conversion of the dataset graph by graph.

    Asserts that ``datasets.helpers.to_networkx`` returns as many graphs as
    the dataset holds, that each converted graph has 19 nodes, and that its
    edge count equals half of the source graph's ``n_edges``.
    """
    dataset = datasets.omitted_with_actions([filename], shuffle=False)
    nx_data = datasets.helpers.to_networkx(dataset)

    # Report the size actually compared against rather than a fixed number.
    assert len(dataset) == len(nx_data), (
        f'Expected to find {len(dataset)} graphs. Found {len(nx_data)}')
    for (i, nx_graph) in enumerate(nx_data):
        # Check node counts; the message reports the same quantity that is
        # asserted so a failure cannot be masked by an attribute lookup.
        node_count = len(nx_graph.nodes)
        assert node_count == 19, f'Expected 19 nodes. Found {node_count}'

        # Check edge counts
        graph = dataset[i]
        # n_edges is halved for the comparison -- presumably the source graph
        # stores each undirected edge in both directions (TODO confirm).
        expected_edge_count = graph.n_edges / 2
        edge_count = len(nx_graph.edges)
        assert edge_count == expected_edge_count, (
            f'Expected {expected_edge_count} edges. Found {edge_count}')
def test_load_graphs():
    """Check that loading without resampling yields exactly 20 graphs."""
    dataset = datasets.omitted_with_actions(
        ['LT1001_TA05.net'], resample=False)
    has_expected_size = len(dataset) == 20
    assert has_expected_size, f'Expected to find 20 graphs. Found {len(dataset)}'
def test_has_edges():
    """Check that every graph in the dataset reports at least one edge."""
    dataset = datasets.omitted_with_actions(
        ['LT1001_TA05.net'], resample=False, normalize=False)
    for sample in dataset:
        edge_total = sample.n_edges
        assert edge_total > 0