Exemplo n.º 1
0
def test_hetero_in_memory_dataset():
    """Check that an in-memory dataset hands back the graphs it was given.

    Two ``HeteroData`` graphs with random contents are stored in a
    ``MyTestDataset``. Each one is then read back by index and compared,
    attribute by attribute, against the graph that was put in.
    """
    edge_type = ('paper', 'paper')

    graphs = []
    for _ in range(2):
        graph = HeteroData()
        graph.y = torch.randn(5)
        graph['paper'].x = torch.randn(10, 16)
        edge_index = torch.randint(0, 10, (2, 30)).long()
        graph[edge_type].edge_index = edge_index
        graphs.append(graph)

    dataset = MyTestDataset(graphs)
    assert str(dataset) == 'MyTestDataset(2)'
    assert len(dataset) == 2

    for idx, expected in enumerate(graphs):
        # The dataset is indexed afresh for every check on purpose, so each
        # assertion exercises item retrieval on its own.
        assert len(dataset[idx]) == 3
        assert dataset[idx].y.tolist() == expected.y.tolist()
        assert (dataset[idx]['paper'].x.tolist() ==
                expected['paper'].x.tolist())
        assert (dataset[idx][edge_type].edge_index.tolist() ==
                expected[edge_type].edge_index.tolist())
Exemplo n.º 2
0
def test_hetero_data_functions():
    """Walk through the container-style API of ``HeteroData``.

    The graph is populated from module-level fixture tensors, after which
    the test checks attribute bookkeeping, node/edge counts, feature
    sizes, ``metadata()``, ``collect()``, graph-level attribute
    assignment, deletion of an edge type, and the dict/namedtuple
    conversions.
    """
    # Edge types are registered below with two-element keys; the assertions
    # expect them to be reported with 'to' as the relation name.
    paper_to_paper = ('paper', 'to', 'paper')
    paper_to_author = ('paper', 'to', 'author')
    author_to_paper = ('author', 'to', 'paper')

    data = HeteroData()
    data['paper'].x = x_paper
    data['author'].x = x_author
    data['paper', 'paper'].edge_index = edge_index_paper_paper
    data['paper', 'author'].edge_index = edge_index_paper_author
    data['author', 'paper'].edge_index = edge_index_author_paper
    data['paper', 'paper'].edge_attr = edge_attr_paper_paper

    # Three distinct attribute names exist across all stores.
    assert len(data) == 3
    assert sorted(data.keys) == ['edge_attr', 'edge_index', 'x']
    for attr_name in ('x', 'edge_index', 'edge_attr'):
        assert attr_name in data
    assert data.num_nodes == 15
    assert data.num_edges == 110

    expected_node_dims = {'paper': 16, 'author': 32}
    assert data.num_node_features == expected_node_dims
    # Only the paper-to-paper relation was given edge attributes.
    expected_edge_dims = {
        paper_to_paper: 8,
        paper_to_author: 0,
        author_to_paper: 0,
    }
    assert data.num_edge_features == expected_edge_dims

    node_names, relation_names = data.metadata()
    assert node_names == ['paper', 'author']
    assert relation_names == [paper_to_paper, paper_to_author, author_to_paper]

    features = data.collect('x')
    assert len(features) == 2
    assert features['paper'].tolist() == x_paper.tolist()
    assert features['author'].tolist() == x_author.tolist()
    assert features == data.x_dict

    # A graph-level attribute is readable by key and by attribute access.
    data.y = 0
    assert data['y'] == 0
    assert data.y == 0
    assert len(data) == 4
    assert sorted(data.keys) == ['edge_attr', 'edge_index', 'x', 'y']

    # Dropping an edge type must leave the node types untouched.
    del data['paper', 'author']
    node_names, relation_names = data.metadata()
    assert node_names == ['paper', 'author']
    assert relation_names == [paper_to_paper, author_to_paper]

    assert len(data.to_dict()) == 5
    assert len(data.to_namedtuple()) == 5
    assert data.to_namedtuple().y == 0
    assert len(data.to_namedtuple().paper) == 1
Exemplo n.º 3
0
    def generate_data(self) -> HeteroData:
        """Assemble one randomly generated heterogeneous graph.

        For every node type a node count is drawn via ``get_num_nodes``.
        The type receives random normal features ``x`` when its channel
        count is positive, and only an explicit ``num_nodes`` otherwise.
        For every edge type an ``edge_index`` is produced by
        ``get_edge_index`` and, depending on ``self.edge_dim``, either a
        2-D ``edge_attr`` or a 1-D ``edge_weight`` is attached.

        Labels are added when ``self._num_classes`` is positive: per-node
        labels on the first node type for the ``'node'`` task, or a single
        graph-level label for the ``'graph'`` task. Finally, every entry
        of ``self.kwargs`` becomes a graph-level random normal tensor of
        the given shape.

        Returns:
            The populated ``HeteroData`` object.
        """
        graph = HeteroData()

        node_specs = zip(self.node_types, self.num_channels)
        for position, (node_type, channels) in enumerate(node_specs):
            node_count = get_num_nodes(self.avg_num_nodes, self.avg_degree)
            node_store = graph[node_type]

            if channels > 0:
                node_store.x = torch.randn(node_count, channels)
            else:
                # Without features the node count has to be set explicitly.
                node_store.num_nodes = node_count

            # Only the first node type carries node-level labels.
            if self._num_classes > 0 and self.task == 'node' and position == 0:
                labels = torch.randint(self._num_classes, (node_count, ))
                node_store.y = labels

        for (src, rel, dst) in self.edge_types:
            edge_store = graph[src, rel, dst]

            src_count = graph[src].num_nodes
            dst_count = graph[dst].num_nodes
            edge_store.edge_index = get_edge_index(
                src_count,
                dst_count,
                self.avg_degree,
                is_undirected=False,
                remove_loops=False,
            )

            # Multi-dimensional edge features are stored as `edge_attr`; a
            # single dimension is stored as a flat `edge_weight` instead.
            if self.edge_dim > 1:
                edge_count = edge_store.num_edges
                edge_store.edge_attr = torch.rand(edge_count, self.edge_dim)
            elif self.edge_dim == 1:
                edge_store.edge_weight = torch.rand(edge_store.num_edges)

        if self._num_classes > 0 and self.task == 'graph':
            graph_label = random.randint(0, self._num_classes - 1)
            graph.y = torch.tensor([graph_label])

        for attr_name, attr_shape in self.kwargs.items():
            setattr(graph, attr_name, torch.randn(attr_shape))

        return graph