def prepare_gnn_training(self):
    """Assemble the graph ``Data`` object and the mini-batch loader used for training.

    The ``Data`` object is built from ``self.x_data`` (cast to float),
    ``self.edge_index_data`` (cast to long), ``self.edge_type_data`` (stored as
    ``edge_attr``) and ``self.y_data``.  The train / val / representation masks
    from ``self.split_masks`` and the values of ``self.node2id`` are attached
    to it as tensors.

    Loader selection, driven by ``self.config``:

    * ``full_graph`` truthy -> no loader (``self.loader`` is ``None``);
    * ``cluster`` truthy    -> ``ClusterLoader`` over ``config['clusters']`` partitions;
    * otherwise ``saint`` must be ``'random_walk'``, ``'node'`` or ``'edge'`` and
      selects the matching GraphSAINT sampler.

    NOTE(review): ``verbose`` is not defined in this block -- presumably a
    module-level flag; confirm.

    Returns:
        tuple: ``(self.loader, self.vocab_size, self.data)``.

    Raises:
        ValueError: if a sampler is required (``full_graph`` is falsy and
            ``cluster`` is falsy) but ``config['saint']`` is not one of the
            supported values.  Previously this path left ``self.loader``
            unassigned.
    """
    if verbose:
        print("\n\n==>> Clustering the graph and preparing dataloader....")

    self.data = Data(
        x=self.x_data.float(),
        edge_index=self.edge_index_data.long(),
        edge_attr=self.edge_type_data,
        y=self.y_data,
    )

    self.data.train_mask = torch.FloatTensor(self.split_masks['train_mask'])
    self.data.val_mask = torch.FloatTensor(self.split_masks['val_mask'])
    self.data.representation_mask = torch.FloatTensor(self.split_masks['repr_mask'])
    self.data.node2id = torch.tensor(list(self.node2id.values()))
    # self.data.node_type = self.node_type

    if self.config['full_graph']:
        # Full-graph training: nothing to sample.
        self.loader = None
    elif self.config['cluster']:
        cluster_data = ClusterData(
            self.data,
            num_parts=self.config['clusters'],
            recursive=False,
        )
        self.loader = ClusterLoader(
            cluster_data,
            batch_size=self.config['batch_size'],
            shuffle=self.config['shuffle'],
            num_workers=0,
        )
    else:
        saint_kind = self.config['saint']
        # Settings shared by all three GraphSAINT samplers.
        saint_kwargs = dict(batch_size=6000, num_steps=5, sample_coverage=100, num_workers=0)
        if saint_kind == 'random_walk':
            self.loader = GraphSAINTRandomWalkSampler(self.data, walk_length=2, **saint_kwargs)
        elif saint_kind == 'node':
            self.loader = GraphSAINTNodeSampler(self.data, **saint_kwargs)
        elif saint_kind == 'edge':
            self.loader = GraphSAINTEdgeSampler(self.data, **saint_kwargs)
        else:
            # Fail loudly instead of returning an unset or stale loader.
            raise ValueError(
                f"Unsupported config['saint'] value {saint_kind!r}; "
                "expected 'random_walk', 'node' or 'edge'"
            )

    return self.loader, self.vocab_size, self.data
def build_sampler(args, data, save_dir):
    """Create the mini-batch loader named by ``args.sampler``.

    Args:
        args: object exposing ``sampler``, ``batch_size`` and (for the
            ``'cluster'`` sampler) ``num_parts``.
        data: graph handed unchanged to the selected sampler.
        save_dir: forwarded as ``save_dir`` to the SAINT samplers and to
            ``ClusterData``.

    Returns:
        tuple: ``(loader, msg)`` where ``msg`` is a short description of the
        sampler that was chosen.

    Raises:
        KeyError: if ``args.sampler`` is not a recognised sampler name.
    """

    def _make_cluster_loader():
        partitions = ClusterData(data, num_parts=args.num_parts, save_dir=save_dir)
        return ClusterLoader(partitions, batch_size=20, shuffle=True, num_workers=0)

    # (sampler name, description, zero-argument factory).  A factory is only
    # invoked for the entry that matches, so unused samplers are never built.
    registry = (
        (
            'rw-my',
            'Use GraphSaint randomwalk sampler(mysaint sampler)',
            lambda: MySAINTSampler(data, batch_size=args.batch_size,
                                   sample_type='random_walk', walk_length=2,
                                   sample_coverage=1000, save_dir=save_dir),
        ),
        (
            'node-my',
            'Use random node sampler(mysaint sampler)',
            lambda: MySAINTSampler(data, sample_type='node',
                                   batch_size=args.batch_size * 3, walk_length=2,
                                   sample_coverage=1000, save_dir=save_dir),
        ),
        (
            'rw',
            'Use GraphSaint randomwalk sampler',
            lambda: GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size,
                                                walk_length=2, num_steps=5,
                                                sample_coverage=1000,
                                                save_dir=save_dir),
        ),
        (
            'node',
            'Use GraphSaint node sampler',
            lambda: GraphSAINTNodeSampler(data, batch_size=args.batch_size * 3,
                                          num_steps=5, sample_coverage=1000,
                                          num_workers=0, save_dir=save_dir),
        ),
        (
            'edge',
            'Use GraphSaint edge sampler',
            lambda: GraphSAINTEdgeSampler(data, batch_size=args.batch_size,
                                          num_steps=5, sample_coverage=1000,
                                          save_dir=save_dir, num_workers=0),
        ),
        (
            'cluster',
            'Use cluster sampler',
            _make_cluster_loader,
        ),
    )

    for name, msg, make_loader in registry:
        if args.sampler == name:
            return make_loader(), msg

    raise KeyError('Sampler type error')
def test_graph_saint():
    """Check sample sizes and norm shapes for the three GraphSAINT samplers.

    Each sampler is built on the same 6-node toy graph after re-seeding torch,
    then every sample it yields is checked against per-sampler upper bounds.
    """
    adjacency = torch.tensor([
        [1, 1, 1, 0, 1, 0],
        [1, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
    ])
    edge_index = adjacency.nonzero().t()
    features = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=features, num_nodes=6)

    shared_kwargs = dict(batch_size=2, num_steps=4, sample_coverage=10, log=False)

    # (sampler class, sampler-specific kwargs, max nodes, max edges)
    cases = (
        (GraphSAINTNodeSampler, {}, 2, 3 * 2),
        (GraphSAINTEdgeSampler, {}, 4, 3 * 4),
        (GraphSAINTRandomWalkSampler, {'walk_length': 1}, 4, 3 * 4),
    )

    for sampler_cls, extra_kwargs, max_nodes, max_edges in cases:
        # Re-seed before building each sampler, as the samplers draw random subgraphs.
        torch.manual_seed(12345)
        loader = sampler_cls(data, **shared_kwargs, **extra_kwargs)
        for sample in loader:
            assert len(sample) == 4
            assert sample.num_nodes <= max_nodes
            assert sample.num_edges <= max_edges
            assert sample.node_norm.numel() == sample.num_nodes
            assert sample.edge_norm.numel() == sample.num_edges
def build_sampler(args, data, save_dir):
    """Create the SAINT-style loader named by ``args.sampler``.

    Args:
        args: object exposing ``sampler`` and ``batch_size``.
        data: graph handed unchanged to the selected sampler.
        save_dir: forwarded as ``save_dir`` to every sampler.

    Returns:
        tuple: ``(loader, msg)`` where ``msg`` is a short description of the
        sampler that was chosen.

    Raises:
        KeyError: if ``args.sampler`` is not one of ``'rw-my'``, ``'node-my'``,
            ``'rw'``, ``'node'`` or ``'edge'``.
    """
    kind = args.sampler
    # Keyword arguments every sampler below receives.
    common = dict(sample_coverage=1000, save_dir=save_dir)

    if kind == 'rw-my':
        loader = MySAINTSampler(data, batch_size=args.batch_size,
                                sample_type='random_walk', walk_length=2, **common)
        return loader, 'Use GraphSaint randomwalk sampler(mysaint sampler)'

    if kind == 'node-my':
        loader = MySAINTSampler(data, sample_type='node',
                                batch_size=args.batch_size * 3, walk_length=2, **common)
        return loader, 'Use random node sampler(mysaint sampler)'

    if kind == 'rw':
        loader = GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size,
                                             walk_length=2, num_steps=5, **common)
        return loader, 'Use GraphSaint randomwalk sampler'

    if kind == 'node':
        loader = GraphSAINTNodeSampler(data, batch_size=args.batch_size * 3,
                                       num_steps=5, num_workers=0, **common)
        return loader, 'Use GraphSaint node sampler'

    if kind == 'edge':
        loader = GraphSAINTEdgeSampler(data, batch_size=args.batch_size,
                                       num_steps=5, num_workers=0, **common)
        return loader, 'Use GraphSaint edge sampler'

    # 'cluster' is not handled by this builder; like any other unknown name
    # it ends up here.
    raise KeyError('Sampler type error')
def test_graph_saint():
    """Pin the first seeded sample and size bounds for each GraphSAINT sampler.

    The adjacency matrix stores a distinct non-zero value per edge, which is
    carried along as ``edge_type`` so the test can see which edges a sample
    kept.
    """
    adjacency = torch.tensor([
        [1, 2, 3, 0, 4, 0],
        [5, 6, 0, 7, 0, 8],
        [9, 0, 10, 0, 11, 0],
        [0, 12, 0, 13, 0, 14],
        [15, 0, 16, 0, 17, 0],
        [0, 18, 0, 19, 0, 20],
    ])
    edge_index = adjacency.nonzero(as_tuple=False).t()
    edge_type = adjacency[edge_index[0], edge_index[1]]
    features = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=features, edge_type=edge_type, num_nodes=6)

    shared_kwargs = dict(num_steps=4, sample_coverage=10, log=False)

    cases = (
        dict(
            sampler=GraphSAINTNodeSampler,
            kwargs=dict(batch_size=3),
            first_x=[[2, 2], [4, 4], [5, 5]],
            first_edge_index=[[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]],
            first_edge_type=[10, 11, 16, 17, 20],
            max_nodes=3,
            max_edges=3 * 4,
        ),
        dict(
            sampler=GraphSAINTEdgeSampler,
            kwargs=dict(batch_size=2),
            first_x=[[0, 0], [2, 2], [3, 3]],
            first_edge_index=[[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]],
            first_edge_type=[1, 3, 9, 10, 13],
            max_nodes=4,
            max_edges=4 * 4,
        ),
        dict(
            sampler=GraphSAINTRandomWalkSampler,
            kwargs=dict(batch_size=2, walk_length=1),
            first_x=[[1, 1], [2, 2], [4, 4]],
            first_edge_index=[[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]],
            first_edge_type=[6, 10, 11, 16, 17],
            max_nodes=4,
            max_edges=4 * 4,
        ),
    )

    for case in cases:
        # The pinned first sample depends on this seed being set right before
        # the sampler is constructed.
        torch.manual_seed(12345)
        loader = case['sampler'](data, **case['kwargs'], **shared_kwargs)

        first = next(iter(loader))
        assert first.x.tolist() == case['first_x']
        assert first.edge_index.tolist() == case['first_edge_index']
        assert first.edge_type.tolist() == case['first_edge_type']

        assert len(loader) == 4
        for sample in loader:
            assert len(sample) == 5
            assert sample.num_nodes <= case['max_nodes']
            assert sample.num_edges <= case['max_edges']
            assert sample.node_norm.numel() == sample.num_nodes
            assert sample.edge_norm.numel() == sample.num_edges
def test_graph_saint():
    """Check structural invariants of samples from each GraphSAINT sampler.

    Nodes carry their original index as ``n_id`` and edges carry a distinct
    non-zero ``edge_id`` taken from the adjacency matrix, so each sample can be
    checked against the full graph without pinning exact random draws.
    """
    adjacency = torch.tensor([
        [1, 2, 3, 0, 4, 0],
        [5, 6, 0, 7, 0, 8],
        [9, 0, 10, 0, 11, 0],
        [0, 12, 0, 13, 0, 14],
        [15, 0, 16, 0, 17, 0],
        [0, 18, 0, 19, 0, 20],
    ])
    edge_index = adjacency.nonzero(as_tuple=False).t()
    edge_id = adjacency[edge_index[0], edge_index[1]]
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    n_id = torch.arange(6)
    data = Data(edge_index=edge_index, x=x, n_id=n_id, edge_id=edge_id, num_nodes=6)

    def check_sample(sample):
        # Node ids are valid indices into the full graph and select the features.
        assert sample.num_nodes <= data.num_nodes
        assert sample.n_id.min() >= 0
        assert sample.n_id.max() < 6
        assert sample.num_nodes == sample.n_id.numel()
        assert sample.x.tolist() == x[sample.n_id].tolist()
        # Edge indices are relative to the sample's own nodes.
        assert sample.edge_index.min() >= 0
        assert sample.edge_index.max() < sample.num_nodes
        assert sample.edge_id.min() >= 1
        assert sample.edge_id.max() <= 21
        assert sample.edge_id.numel() == sample.num_edges
        # One normalisation coefficient per node and per edge.
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    shared_kwargs = dict(num_steps=4, sample_coverage=10, log=False)
    sampler_specs = (
        (GraphSAINTNodeSampler, dict(batch_size=3)),
        (GraphSAINTEdgeSampler, dict(batch_size=2)),
        (GraphSAINTRandomWalkSampler, dict(batch_size=2, walk_length=1)),
    )

    for sampler_cls, specific_kwargs in sampler_specs:
        loader = sampler_cls(data, **specific_kwargs, **shared_kwargs)
        assert len(loader) == 4
        for sample in loader:
            check_sample(sample)