-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
44 lines (33 loc) · 975 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#!/usr/bin/env python
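"""
test.py -- manual test script for the pygunrock SSSP binding.

Loads a graph from a Matrix Market file, copies its CSR arrays to the GPU,
and runs single-source shortest paths from a few source vertices, printing
the resulting distance vector for each source.
"""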
import torch
import sys
sys.path.append('.')
sys.path.append('build')
import pygunrock as pyg
import numpy as np
from time import time
from tqdm import trange
from scipy.io import mmread
np.set_printoptions(linewidth=240)

# Load the graph from a Matrix Market file and convert it to CSR.
# Alternative, larger input: 'cit-Patents-sub.mtx'
csr = mmread('chesapeake.mtx').tocsr()
if csr.shape[0] != csr.shape[1]:
    # from_csr below receives n_vertices as both the row and column count,
    # so a non-square matrix would be described incorrectly.
    raise ValueError(f'expected a square adjacency matrix, got shape {csr.shape}')
n_vertices = csr.shape[0]
n_edges = csr.nnz
# The CSR offset/index arrays are narrowed to int32 below; refuse graphs whose
# offsets would not fit rather than letting them silently wrap around.
if n_edges > np.iinfo(np.int32).max:
    raise ValueError(f'graph has {n_edges} edges; int32 CSR offsets would overflow')

device = torch.device('cuda')

# Convert the CSR arrays to GPU tensors with explicit dtypes:
# int32 row offsets / column indices and float32 edge weights.
indptr = torch.as_tensor(csr.indptr, dtype=torch.int32, device=device)
indices = torch.as_tensor(csr.indices, dtype=torch.int32, device=device)
data = torch.as_tensor(csr.data, dtype=torch.float32, device=device)

# Allocate the output buffers directly on the GPU (no host allocation + copy).
distances = torch.zeros(n_vertices, dtype=torch.float32, device=device)
predecessors = torch.zeros(n_vertices, dtype=torch.int32, device=device)

# Create the graph once: its CSR inputs do not change between sources.
# NOTE(review): assumes pyg.sssp does not mutate G -- confirm against pygunrock.
G = pyg.from_csr(n_vertices, n_vertices, n_edges, indptr, indices, data)

# Run SSSP from several source vertices, reusing the same output buffers.
for single_source in [0, 1, 2]:
    # Reset the outputs so no values carry over from the previous source.
    distances.zero_()
    predecessors.zero_()
    _ = pyg.sssp(G, single_source, distances, predecessors)
    print(distances)