예제 #1
0
    def prepare_gnn_training(self):
        """Build the PyG ``Data`` object and, unless full-graph training is
        configured, a mini-batch loader over it.

        Reads ``self.x_data``, ``self.edge_index_data``, ``self.edge_type_data``,
        ``self.y_data``, ``self.split_masks``, ``self.node2id``, ``self.config``
        and ``self.vocab_size``; assigns ``self.data`` and ``self.loader``.

        Loader selection (in priority order):
          * ``config['full_graph']`` truthy -> no loader (``None``).
          * ``config['cluster']`` truthy    -> ``ClusterLoader`` over
            ``ClusterData`` with ``config['clusters']`` partitions.
          * ``config['saint']`` in ``{'random_walk', 'node', 'edge'}`` ->
            the matching GraphSAINT sampler.

        Returns:
            tuple: ``(self.loader, self.vocab_size, self.data)``.

        Raises:
            ValueError: if mini-batch training is requested but
                ``config['saint']`` names no supported sampler. (Previously this
                case left ``self.loader`` unassigned, or stale from an earlier call.)
        """
        # NOTE(review): `verbose` is neither a local nor an attribute; it is looked
        # up in the enclosing module scope — confirm it is defined there.
        if verbose:
            print("\n\n==>> Clustering the graph and preparing dataloader....")

        self.data = Data(x=self.x_data.float(), edge_index=self.edge_index_data.long(),
                         edge_attr=self.edge_type_data, y=self.y_data)

        # Split masks are stored as float tensors (not bool), as the original code did.
        self.data.train_mask = torch.FloatTensor(self.split_masks['train_mask'])
        self.data.val_mask = torch.FloatTensor(self.split_masks['val_mask'])
        self.data.representation_mask = torch.FloatTensor(self.split_masks['repr_mask'])
        # Ordered id values of the node2id mapping (dict insertion order).
        self.data.node2id = torch.tensor(list(self.node2id.values()))

        if self.config['full_graph']:
            # Full-batch training: callers use self.data directly.
            self.loader = None
        elif self.config['cluster']:
            cluster_data = ClusterData(self.data, num_parts=self.config['clusters'], recursive=False)
            self.loader = ClusterLoader(cluster_data, batch_size=self.config['batch_size'],
                                        shuffle=self.config['shuffle'], num_workers=0)
        else:
            saint = self.config['saint']
            if saint == 'random_walk':
                self.loader = GraphSAINTRandomWalkSampler(self.data, batch_size=6000, walk_length=2,
                                                          num_steps=5, sample_coverage=100, num_workers=0)
            elif saint == 'node':
                self.loader = GraphSAINTNodeSampler(self.data, batch_size=6000, num_steps=5,
                                                    sample_coverage=100, num_workers=0)
            elif saint == 'edge':
                self.loader = GraphSAINTEdgeSampler(self.data, batch_size=6000, num_steps=5,
                                                    sample_coverage=100, num_workers=0)
            else:
                # Fail loudly instead of returning an unset/stale self.loader.
                raise ValueError(
                    f"Unsupported config['saint'] value {saint!r}; "
                    "expected 'random_walk', 'node' or 'edge' (or set 'cluster'/'full_graph')."
                )

        return self.loader, self.vocab_size, self.data
예제 #2
0
def build_sampler(args, data, save_dir):
    """Create the mini-batch graph loader named by ``args.sampler``.

    Supported names: ``'rw-my'``, ``'node-my'`` (MySAINTSampler variants),
    ``'rw'``, ``'node'``, ``'edge'`` (GraphSAINT samplers) and ``'cluster'``
    (ClusterData partitioning + ClusterLoader). The node-based samplers use
    three times ``args.batch_size``.

    Args:
        args: object exposing ``sampler`` and ``batch_size``; ``num_parts`` is
            read only for the ``'cluster'`` sampler.
        data: graph handed to the selected sampler.
        save_dir: forwarded as ``save_dir`` to the sampler / ClusterData.

    Returns:
        tuple: ``(loader, msg)`` where ``msg`` describes the chosen sampler.

    Raises:
        KeyError: if ``args.sampler`` is not a supported name.
    """
    requested = args.sampler

    def make_cluster_loader():
        # Partition the graph first, then batch 20 partitions per step.
        parts = ClusterData(data, num_parts=args.num_parts, save_dir=save_dir)
        return ClusterLoader(parts, batch_size=20, shuffle=True, num_workers=0)

    # Factories are only invoked for the matching entry, so sampler-specific
    # attributes (e.g. args.num_parts) are read only when actually needed.
    registry = (
        ('rw-my', 'Use GraphSaint randomwalk sampler(mysaint sampler)',
         lambda: MySAINTSampler(data, batch_size=args.batch_size, sample_type='random_walk',
                                walk_length=2, sample_coverage=1000, save_dir=save_dir)),
        ('node-my', 'Use random node sampler(mysaint sampler)',
         lambda: MySAINTSampler(data, sample_type='node', batch_size=args.batch_size * 3,
                                walk_length=2, sample_coverage=1000, save_dir=save_dir)),
        ('rw', 'Use GraphSaint randomwalk sampler',
         lambda: GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size, walk_length=2,
                                             num_steps=5, sample_coverage=1000,
                                             save_dir=save_dir)),
        ('node', 'Use GraphSaint node sampler',
         lambda: GraphSAINTNodeSampler(data, batch_size=args.batch_size * 3, num_steps=5,
                                       sample_coverage=1000, num_workers=0,
                                       save_dir=save_dir)),
        ('edge', 'Use GraphSaint edge sampler',
         lambda: GraphSAINTEdgeSampler(data, batch_size=args.batch_size, num_steps=5,
                                       sample_coverage=1000, save_dir=save_dir,
                                       num_workers=0)),
        ('cluster', 'Use cluster sampler', make_cluster_loader),
    )

    # Linear scan with `==` (not a dict lookup) keeps the original comparison
    # semantics for any value of args.sampler.
    for name, description, factory in registry:
        if requested == name:
            return factory(), description

    raise KeyError('Sampler type error')
def test_graph_saint():
    """Smoke-test the node, edge and random-walk GraphSAINT samplers.

    Uses a 6-node toy graph. The RNG is re-seeded before each sampler is
    constructed, and every batch it yields is checked for its attribute
    count, size bounds, and normalization vectors whose lengths match the
    batch's node / edge counts.
    """
    adj = torch.tensor([
        [1, 1, 1, 0, 1, 0],
        [1, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
    ])

    edge_index = adj.nonzero().t()
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=x, num_nodes=6)

    def check_batches(batches, node_cap, edge_cap):
        # len(batch) == 4: presumably x, edge_index, node_norm, edge_norm.
        for batch in batches:
            assert len(batch) == 4
            assert batch.num_nodes <= node_cap
            assert batch.num_edges <= edge_cap
            assert batch.node_norm.numel() == batch.num_nodes
            assert batch.edge_norm.numel() == batch.num_edges

    torch.manual_seed(12345)
    node_loader = GraphSAINTNodeSampler(data, batch_size=2, num_steps=4,
                                        sample_coverage=10, log=False)
    check_batches(node_loader, node_cap=2, edge_cap=3 * 2)

    torch.manual_seed(12345)
    edge_loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,
                                        sample_coverage=10, log=False)
    check_batches(edge_loader, node_cap=4, edge_cap=3 * 4)

    torch.manual_seed(12345)
    walk_loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
                                              num_steps=4, sample_coverage=10,
                                              log=False)
    check_batches(walk_loader, node_cap=4, edge_cap=3 * 4)
예제 #4
0
def build_sampler(args, data, save_dir):
    """Return ``(loader, msg)`` for the sampler named by ``args.sampler``.

    Recognised names: ``'rw-my'``, ``'node-my'``, ``'rw'``, ``'node'`` and
    ``'edge'``. The node-based variants use three times ``args.batch_size``.
    Every sampler receives ``save_dir``.

    Raises:
        KeyError: for any other name. This includes ``'cluster'``, which this
            version does not support.
    """
    choice = args.sampler

    if choice == 'rw-my':
        loader = MySAINTSampler(data, batch_size=args.batch_size, sample_type='random_walk',
                                walk_length=2, sample_coverage=1000, save_dir=save_dir)
        return loader, 'Use GraphSaint randomwalk sampler(mysaint sampler)'

    if choice == 'node-my':
        loader = MySAINTSampler(data, sample_type='node', batch_size=args.batch_size * 3,
                                walk_length=2, sample_coverage=1000, save_dir=save_dir)
        return loader, 'Use random node sampler(mysaint sampler)'

    if choice == 'rw':
        loader = GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size, walk_length=2,
                                             num_steps=5, sample_coverage=1000,
                                             save_dir=save_dir)
        return loader, 'Use GraphSaint randomwalk sampler'

    if choice == 'node':
        loader = GraphSAINTNodeSampler(data, batch_size=args.batch_size * 3, num_steps=5,
                                       sample_coverage=1000, num_workers=0, save_dir=save_dir)
        return loader, 'Use GraphSaint node sampler'

    if choice == 'edge':
        loader = GraphSAINTEdgeSampler(data, batch_size=args.batch_size, num_steps=5,
                                       sample_coverage=1000, save_dir=save_dir, num_workers=0)
        return loader, 'Use GraphSaint edge sampler'

    # NOTE: a 'cluster' branch was drafted here but left disabled (marked not
    # implemented), so 'cluster' deliberately falls through to this error.
    raise KeyError('Sampler type error')
예제 #5
0
def test_graph_saint():
    """Check GraphSAINT batches on a 6-node graph with typed edges.

    Each nonzero adjacency entry doubles as that edge's ``edge_type``, so the
    sampled subgraph's edge types reveal which original edges were kept. For
    each sampler the RNG is re-seeded with 12345 and the sampler is built.
    The first batch is then compared against pinned values, followed by
    length / size-bound / normalization checks over a full pass.

    NOTE(review): the pinned first-batch values depend on the seed and on the
    installed library's sampling implementation.
    """
    adj = torch.tensor([
        [+1, +2, +3, +0, +4, +0],
        [+5, +6, +0, +7, +0, +8],
        [+9, +0, 10, +0, 11, +0],
        [+0, 12, +0, 13, +0, 14],
        [15, +0, 16, +0, 17, +0],
        [+0, 18, +0, 19, +0, 20],
    ])

    edge_index = adj.nonzero(as_tuple=False).t()
    edge_type = adj[edge_index[0], edge_index[1]]
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=x, edge_type=edge_type, num_nodes=6)

    # (loader factory, first-batch x, first-batch edge_index,
    #  first-batch edge_type, max nodes per batch, max edges per batch)
    scenarios = [
        (lambda: GraphSAINTNodeSampler(data, batch_size=3, num_steps=4,
                                       sample_coverage=10, log=False),
         [[2, 2], [4, 4], [5, 5]],
         [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]],
         [10, 11, 16, 17, 20],
         3, 3 * 4),
        (lambda: GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,
                                       sample_coverage=10, log=False),
         [[0, 0], [2, 2], [3, 3]],
         [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]],
         [1, 3, 9, 10, 13],
         4, 4 * 4),
        (lambda: GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
                                             num_steps=4, sample_coverage=10,
                                             log=False),
         [[1, 1], [2, 2], [4, 4]],
         [[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]],
         [6, 10, 11, 16, 17],
         4, 4 * 4),
    ]

    for build, want_x, want_edges, want_types, node_cap, edge_cap in scenarios:
        # Seed immediately before construction: the sampler draws random
        # numbers while pre-computing its normalization statistics.
        torch.manual_seed(12345)
        loader = build()

        first = next(iter(loader))
        assert first.x.tolist() == want_x
        assert first.edge_index.tolist() == want_edges
        assert first.edge_type.tolist() == want_types

        assert len(loader) == 4
        for sample in loader:
            assert len(sample) == 5
            assert sample.num_nodes <= node_cap
            assert sample.num_edges <= edge_cap
            assert sample.node_norm.numel() == sample.num_nodes
            assert sample.edge_norm.numel() == sample.num_edges
예제 #6
0
def test_graph_saint():
    """Verify structural invariants of GraphSAINT node/edge/random-walk batches.

    Each nonzero adjacency entry is stored as that edge's ``edge_id``, and
    ``n_id`` holds each node's original index. For every batch the test checks:
      * node ids stay in range and ``x`` matches ``x[n_id]``;
      * local ``edge_index`` values are valid indices into the batch;
      * ``edge_id`` values lie within the range actually present in ``adj``;
      * ``node_norm`` / ``edge_norm`` have one entry per node / edge.

    The edge-id upper bound was previously hard-coded as ``<= 21``, but the
    largest adjacency value is 20. Both bounds are now derived from
    ``edge_id`` itself.
    """
    adj = torch.tensor([
        [+1, +2, +3, +0, +4, +0],
        [+5, +6, +0, +7, +0, +8],
        [+9, +0, 10, +0, 11, +0],
        [+0, 12, +0, 13, +0, 14],
        [15, +0, 16, +0, 17, +0],
        [+0, 18, +0, 19, +0, 20],
    ])

    edge_index = adj.nonzero(as_tuple=False).t()
    edge_id = adj[edge_index[0], edge_index[1]]
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    n_id = torch.arange(6)
    data = Data(edge_index=edge_index,
                x=x,
                n_id=n_id,
                edge_id=edge_id,
                num_nodes=6)

    # Exact range of edge ids present in the graph (1..20 for this adj).
    min_edge_id = int(edge_id.min())
    max_edge_id = int(edge_id.max())

    def check_loader(loader):
        assert len(loader) == 4
        for sample in loader:
            assert sample.num_nodes <= data.num_nodes
            assert sample.n_id.min() >= 0 and sample.n_id.max() < data.num_nodes
            assert sample.num_nodes == sample.n_id.numel()
            # Node features must be the original rows selected by n_id.
            assert sample.x.tolist() == x[sample.n_id].tolist()
            # edge_index is relabelled to local (0..num_nodes-1) indices.
            assert sample.edge_index.min() >= 0
            assert sample.edge_index.max() < sample.num_nodes
            assert sample.edge_id.min() >= min_edge_id
            assert sample.edge_id.max() <= max_edge_id
            assert sample.edge_id.numel() == sample.num_edges
            assert sample.node_norm.numel() == sample.num_nodes
            assert sample.edge_norm.numel() == sample.num_edges

    check_loader(GraphSAINTNodeSampler(data,
                                       batch_size=3,
                                       num_steps=4,
                                       sample_coverage=10,
                                       log=False))

    check_loader(GraphSAINTEdgeSampler(data,
                                       batch_size=2,
                                       num_steps=4,
                                       sample_coverage=10,
                                       log=False))

    check_loader(GraphSAINTRandomWalkSampler(data,
                                             batch_size=2,
                                             walk_length=1,
                                             num_steps=4,
                                             sample_coverage=10,
                                             log=False))