def test_hetero_in_memory_dataset():
    """Store two `HeteroData` graphs in `MyTestDataset` and read them back.

    Each graph carries a graph-level `y`, features `x` on the 'paper' node
    type and an `edge_index` on the ('paper', 'paper') edge type. Every
    item fetched from the dataset must hold the same values as the graph
    it was built from.
    """
    def build_graph():
        # Attributes are assigned in a fixed order (y, x, edge_index) so the
        # sequence of random draws is the same for every graph.
        graph = HeteroData()
        graph.y = torch.randn(5)
        graph['paper'].x = torch.randn(10, 16)
        graph['paper', 'paper'].edge_index = torch.randint(
            0, 10, (2, 30)).long()
        return graph

    originals = [build_graph(), build_graph()]
    dataset = MyTestDataset(originals)

    assert str(dataset) == 'MyTestDataset(2)'
    assert len(dataset) == 2

    edge_key = ('paper', 'paper')
    for idx, expected in enumerate(originals):
        assert len(dataset[idx]) == 3
        assert dataset[idx].y.tolist() == expected.y.tolist()
        assert dataset[idx]['paper'].x.tolist() == expected['paper'].x.tolist()
        assert (dataset[idx][edge_key].edge_index.tolist() ==
                expected[edge_key].edge_index.tolist())
def test_hetero_data_functions():
    """Check the graph-level helpers of a populated `HeteroData` object.

    The graph is filled from module-level fixtures with two node types
    ('paper', 'author') and three edge types. The test then checks key
    listing, node/edge counts, feature sizes, `metadata()`, `collect()`,
    graph-level attributes, deletion of an edge type, and the dict and
    namedtuple conversions.
    """
    graph = HeteroData()
    graph['paper'].x = x_paper
    graph['author'].x = x_author
    graph['paper', 'paper'].edge_index = edge_index_paper_paper
    graph['paper', 'author'].edge_index = edge_index_paper_author
    graph['author', 'paper'].edge_index = edge_index_author_paper
    graph['paper', 'paper'].edge_attr = edge_attr_paper_paper

    # Keys are reported once each, regardless of how many stores hold them.
    assert len(graph) == 3
    assert sorted(graph.keys) == ['edge_attr', 'edge_index', 'x']
    assert all(key in graph for key in ('x', 'edge_index', 'edge_attr'))

    assert graph.num_nodes == 15
    assert graph.num_edges == 110

    expected_node_features = {'paper': 16, 'author': 32}
    assert graph.num_node_features == expected_node_features

    # Edge types given as (src, dst) pairs are reported with the relation
    # name 'to'; edge types without `edge_attr` report zero edge features.
    expected_edge_features = {
        ('paper', 'to', 'paper'): 8,
        ('paper', 'to', 'author'): 0,
        ('author', 'to', 'paper'): 0,
    }
    assert graph.num_edge_features == expected_edge_features

    node_types, edge_types = graph.metadata()
    assert node_types == ['paper', 'author']
    assert edge_types == [
        ('paper', 'to', 'paper'),
        ('paper', 'to', 'author'),
        ('author', 'to', 'paper'),
    ]

    features = graph.collect('x')
    assert len(features) == 2
    assert features['paper'].tolist() == x_paper.tolist()
    assert features['author'].tolist() == x_author.tolist()
    assert features == graph.x_dict

    # A graph-level attribute is reachable by item and attribute access and
    # counts as an additional key.
    graph.y = 0
    assert graph['y'] == 0
    assert graph.y == 0
    assert len(graph) == 4
    assert sorted(graph.keys) == ['edge_attr', 'edge_index', 'x', 'y']

    # Dropping one edge type leaves the node types and the remaining edge
    # types in place.
    del graph['paper', 'author']
    node_types, edge_types = graph.metadata()
    assert node_types == ['paper', 'author']
    assert edge_types == [('paper', 'to', 'paper'), ('author', 'to', 'paper')]

    assert len(graph.to_dict()) == 5
    assert len(graph.to_namedtuple()) == 5
    assert graph.to_namedtuple().y == 0
    assert len(graph.to_namedtuple().paper) == 1
def generate_data(self) -> HeteroData:
    """Build one random heterogeneous graph from this object's settings.

    Node stores: for every pair from `self.node_types` and
    `self.num_channels`, a node count is obtained from `get_num_nodes`.
    The store gets `torch.randn` features `x` when the channel count is
    positive, otherwise only `num_nodes` is recorded. For the 'node' task
    with a positive class count, only the first node type receives
    labels `y`.

    Edge stores: every `(src, rel, dst)` in `self.edge_types` gets an
    `edge_index` from `get_edge_index`, plus `edge_attr` when
    `self.edge_dim > 1` or `edge_weight` when `self.edge_dim == 1`.

    Graph level: for the 'graph' task with a positive class count, a
    single random label is stored in `y`. Every entry of `self.kwargs`
    (attribute name -> shape) is attached as a `torch.randn` tensor.

    Returns:
        The populated `HeteroData` object.
    """
    graph = HeteroData()

    node_specs = zip(self.node_types, self.num_channels)
    for position, (node_type, channels) in enumerate(node_specs):
        node_count = get_num_nodes(self.avg_num_nodes, self.avg_degree)
        node_store = graph[node_type]

        if channels > 0:
            node_store.x = torch.randn(node_count, channels)
        else:
            # Without features the node count has to be stored explicitly.
            node_store.num_nodes = node_count

        if self._num_classes > 0 and self.task == 'node' and position == 0:
            node_store.y = torch.randint(self._num_classes, (node_count, ))

    for (src, rel, dst) in self.edge_types:
        edge_store = graph[(src, rel, dst)]
        edge_store.edge_index = get_edge_index(
            graph[src].num_nodes,
            graph[dst].num_nodes,
            self.avg_degree,
            is_undirected=False,
            remove_loops=False,
        )

        # Multi-dimensional edge features go to `edge_attr`; the
        # one-dimensional case is stored as a flat `edge_weight` instead.
        if self.edge_dim > 1:
            edge_store.edge_attr = torch.rand(edge_store.num_edges,
                                              self.edge_dim)
        elif self.edge_dim == 1:
            edge_store.edge_weight = torch.rand(edge_store.num_edges)

    if self._num_classes > 0 and self.task == 'graph':
        label = random.randint(0, self._num_classes - 1)
        graph.y = torch.tensor([label])

    for attr_name, attr_shape in self.kwargs.items():
        setattr(graph, attr_name, torch.randn(attr_shape))

    return graph