Ejemplo n.º 1
0
def test_enzymes():
    """Exercise TUDataset('ENZYMES') through indexing and the three loaders.

    A randomly named path under ``/tmp`` is handed to ``TUDataset`` as its
    root and is removed again when the test ends, whether it passed or not.

    Checked here:
      * dataset size, feature/class counts and ``repr``;
      * indexing by int, slice, index tensor and boolean mask, plus shuffle;
      * one full-size batch from ``DataLoader``, ``DataListLoader`` and
        ``DenseDataLoader`` (the latter with a ``ToDense`` transform);
      * feature counts when the dataset is re-created with
        ``use_node_attr=True``.
    """
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    try:
        dataset = TUDataset(root, 'ENZYMES')

        assert len(dataset) == 600
        assert dataset.num_features == 3
        assert dataset.num_classes == 6
        assert repr(dataset) == 'ENZYMES(600)'

        assert len(dataset[0]) == 3
        assert len(dataset.shuffle()) == 600
        # With return_perm=True the call yields a pair
        # (presumably dataset and permutation -- confirm against the API).
        assert len(dataset.shuffle(return_perm=True)) == 2
        assert len(dataset[:100]) == 100
        assert len(dataset[torch.arange(100, dtype=torch.long)]) == 100
        mask = torch.zeros(600, dtype=torch.bool)
        mask[:100] = True
        assert len(dataset[mask]) == 100

        # batch_size == len(dataset): the loop body runs on one batch that
        # holds every graph.
        loader = DataLoader(dataset, batch_size=len(dataset))
        for data in loader:
            assert data.num_graphs == 600

            avg_num_nodes = data.num_nodes / data.num_graphs
            assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63

            # Halved edge count; the batch is asserted undirected below, so
            # presumably every edge is stored once per direction.
            avg_num_edges = data.num_edges / (2 * data.num_graphs)
            assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14

            assert len(data) == 5
            assert list(data.x.size()) == [data.num_nodes, 3]
            assert list(data.y.size()) == [data.num_graphs]
            assert data.y.max() + 1 == 6
            assert list(data.batch.size()) == [data.num_nodes]
            assert data.ptr.numel() == data.num_graphs + 1

            assert data.has_isolated_nodes()
            assert not data.has_self_loops()
            assert data.is_undirected()

        loader = DataListLoader(dataset, batch_size=len(dataset))
        for data_list in loader:
            assert len(data_list) == 600

        dataset.transform = ToDense(num_nodes=126)
        loader = DenseDataLoader(dataset, batch_size=len(dataset))
        for data in loader:
            assert len(data) == 4
            assert list(data.x.size()) == [600, 126, 3]
            assert list(data.adj.size()) == [600, 126, 126]
            assert list(data.mask.size()) == [600, 126]
            assert list(data.y.size()) == [600, 1]

        dataset = TUDataset(root, 'ENZYMES', use_node_attr=True)
        assert dataset.num_node_features == 21
        assert dataset.num_features == 21
        assert dataset.num_edge_features == 0
    finally:
        # Remove the temp directory even when an assertion fails.
        # ignore_errors keeps a missing directory (e.g. dataset creation
        # failed early) from hiding the original error.
        shutil.rmtree(root, ignore_errors=True)
def test_enzymes():
    """Exercise TUDataset('ENZYMES') with DataLoader and DenseDataLoader.

    The dataset root is ``/tmp/<random>/test``. The whole ``/tmp/<random>``
    directory is removed when the test ends, whether it passed or not, so
    that the randomly named parent is not left behind.

    NOTE(review): a function with this name is also defined earlier in this
    file; if both live in one module, this definition replaces the earlier
    one.
    """
    base_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    root = osp.join(base_dir, 'test')
    try:
        dataset = TUDataset(root, 'ENZYMES')

        assert len(dataset) == 600
        assert dataset.num_features == 21
        assert dataset.num_classes == 6
        assert repr(dataset) == 'ENZYMES(600)'

        assert len(dataset[0]) == 3

        assert len(dataset.shuffle()) == 600
        assert len(dataset[:100]) == 100
        assert len(dataset[torch.arange(100, dtype=torch.long)]) == 100
        # uint8 mask kept as written.
        # NOTE(review): newer torch versions prefer torch.bool masks --
        # confirm what the installed library versions accept before changing.
        mask = torch.zeros(600, dtype=torch.uint8)
        mask[:100] = 1
        assert len(dataset[mask]) == 100

        # batch_size == len(dataset): the loop body runs on one batch that
        # holds every graph.
        loader = DataLoader(dataset, batch_size=len(dataset))
        for data in loader:
            assert data.num_graphs == 600

            avg_num_nodes = data.num_nodes / data.num_graphs
            assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63

            # Halved edge count; the batch is asserted undirected below, so
            # presumably every edge is stored once per direction.
            avg_num_edges = data.num_edges / (2 * data.num_graphs)
            assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14

            assert len(data) == 4
            assert list(data.x.size()) == [data.num_nodes, 21]
            assert list(data.y.size()) == [data.num_graphs]
            assert data.y.max() + 1 == 6
            assert list(data.batch.size()) == [data.num_nodes]

            assert data.contains_isolated_nodes()
            assert not data.contains_self_loops()
            assert data.is_undirected()

        dataset.transform = ToDense(num_nodes=126)
        loader = DenseDataLoader(dataset, batch_size=len(dataset))
        for data in loader:
            assert len(data) == 4
            assert list(data.x.size()) == [600, 126, 21]
            assert list(data.adj.size()) == [600, 126, 126]
            assert list(data.mask.size()) == [600, 126]
            assert list(data.y.size()) == [600, 1]
    finally:
        # Remove the random parent directory, not just its 'test' leaf, and
        # do so even when an assertion fails. ignore_errors keeps a missing
        # directory from hiding the original error.
        shutil.rmtree(base_dir, ignore_errors=True)
Ejemplo n.º 3
0
def test_mutag():
    """Check the basic properties of TUDataset('MUTAG').

    A randomly named path under ``/tmp`` is handed to ``TUDataset`` as its
    root and is removed again when the test ends, whether it passed or not.
    """
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    try:
        dataset = TUDataset(root, 'MUTAG')

        assert len(dataset) == 188
        assert dataset.num_features == 7
        assert dataset.num_classes == 2
        assert repr(dataset) == 'MUTAG(188)'

        first_graph = dataset[0]
        assert len(first_graph) == 4
        assert first_graph.edge_attr.size(1) == 4

        # Re-creating with use_node_attr=True is expected to leave the
        # feature count at 7.
        dataset = TUDataset(root, 'MUTAG', use_node_attr=True)
        assert dataset.num_features == 7
    finally:
        # Remove the temp directory even when an assertion fails.
        # ignore_errors keeps a missing directory from hiding the original
        # error.
        shutil.rmtree(root, ignore_errors=True)
Ejemplo n.º 4
0
def test_bzr():
    """Check the feature bookkeeping of TUDataset('BZR').

    Without node attributes the test expects 53 features, all of them node
    labels; with ``use_node_attr=True`` it expects 3 additional node
    attributes, for 56 features in total.

    A randomly named path under ``/tmp`` is handed to ``TUDataset`` as its
    root and is removed again when the test ends, whether it passed or not.
    """
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    try:
        dataset = TUDataset(root, 'BZR')

        assert len(dataset) == 405
        assert dataset.num_features == 53
        assert dataset.num_node_labels == 53
        assert dataset.num_node_attributes == 0
        assert dataset.num_classes == 2
        assert repr(dataset) == 'BZR(405)'
        assert len(dataset[0]) == 3

        dataset = TUDataset(root, 'BZR', use_node_attr=True)
        assert dataset.num_features == 56
        assert dataset.num_node_labels == 53
        assert dataset.num_node_attributes == 3
    finally:
        # Remove the temp directory even when an assertion fails.
        # ignore_errors keeps a missing directory from hiding the original
        # error.
        shutil.rmtree(root, ignore_errors=True)
Ejemplo n.º 5
0
def test_baseDataset_Loader(filted_dataset=None):
    """Check loading, batching and dense filtering of the local 'SMT' dataset.

    The directory ``../input/smt`` (relative to the current working
    directory) is copied to a randomly named path under ``/tmp``, loaded via
    ``TUDataset`` and checked. The copy is removed when the test ends,
    whether it passed or not.

    Args:
        filted_dataset: Optional dataset expected to yield items that
            contain an ``'adj'`` attribute. When ``None``, one is built from
            the graphs with fewer than ``num_nodes`` nodes and given a
            ``T.ToDense(num_nodes)`` transform.
    """
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    try:
        shutil.copytree('../input/smt', root)
        dataset = TUDataset(root, 'SMT')

        assert len(dataset) == 2688
        assert dataset.num_features == 20
        assert dataset.num_classes == 2
        assert repr(dataset) == 'SMT(2688)'
        assert dataset[0].keys == ['x', 'edge_index', 'y']  # ==len(data.keys)
        assert len(dataset.shuffle()) == 2688

        # batch_size == len(dataset): the loop body runs on one batch that
        # holds every graph.
        loader = DataLoader(dataset, batch_size=len(dataset))
        assert repr(loader.dataset) == 'SMT(2688)'
        for batch in loader:
            assert batch.num_graphs == 2688
            assert batch.num_nodes == sum(
                data.num_nodes for data in dataset)  # 2788794
            assert batch.num_edges == sum(
                data.num_edges for data in dataset)  # 13347768
            assert batch.keys == ['x', 'edge_index', 'y', 'batch']

        # Node budget: five times the mean graph size, capped at the size of
        # the largest graph. Assumes dataset.data.num_nodes is an iterable of
        # per-graph node counts -- TODO confirm.
        total_num_nodes = sum(dataset.data.num_nodes)
        max_num_nodes = max(dataset.data.num_nodes)
        num_nodes = min(int(total_num_nodes / len(dataset) * 5),
                        max_num_nodes)

        assert num_nodes == 5187
        assert max_num_nodes == 34623

        # Positions of the graphs that have fewer nodes than the budget.
        indices = [
            i for i, data in enumerate(dataset) if data.num_nodes < num_nodes
        ]

        # Compare against the None sentinel rather than truthiness, so an
        # explicitly supplied dataset is never silently replaced.
        if filted_dataset is None:
            filted_dataset = dataset[torch.tensor(indices)]
            # add 'adj' attribute
            filted_dataset.transform = T.ToDense(num_nodes)

        assert 'adj' not in dataset[0]
        assert 'adj' in filted_dataset[0]
    finally:
        # Remove the copied tree even when an assertion fails.
        # ignore_errors keeps a missing directory (e.g. copytree failed)
        # from hiding the original error.
        shutil.rmtree(root, ignore_errors=True)