Example #1
def test_from_scipy_sparse_matrix():
    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
    adj = to_scipy_sparse_matrix(edge_index)

    out = from_scipy_sparse_matrix(adj)
    assert out[0].tolist() == edge_index.tolist()
    assert out[1].tolist() == [1, 1, 1]
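
For reference, the conversion also round-trips explicit edge weights. A minimal
sketch, assuming only that scipy preserves the entry order of a matrix that is
already in COO format (the weights below are illustrative):

import scipy.sparse
from torch_geometric.utils import from_scipy_sparse_matrix

# A small weighted COO matrix with entries (0, 1), (1, 0) and (0, 0).
row, col, weight = [0, 1, 0], [1, 0, 0], [0.5, 0.5, 2.0]
adj = scipy.sparse.coo_matrix((weight, (row, col)), shape=(2, 2))

# from_scipy_sparse_matrix returns the edge index plus the edge weights.
edge_index, edge_weight = from_scipy_sparse_matrix(adj)
assert edge_index.tolist() == [[0, 1, 0], [1, 0, 0]]
assert edge_weight.tolist() == [0.5, 0.5, 2.0]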
Example #2
    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self._max_z = max(int(z.max()), self._max_z)

        return z.to(torch.long)
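
The arithmetic above is the closed-form DRNL label from the SEAL paper,
z = 1 + min(d_src, d_dst) + (d/2) * (d/2 + d%2 - 1) with d = d_src + d_dst.
A minimal standalone check of that formula (toy distances, not tied to any
graph):

import torch

def drnl_label(dist2src: torch.Tensor, dist2dst: torch.Tensor) -> torch.Tensor:
    # z = 1 + min(d_s, d_t) + (d/2) * (d/2 + d%2 - 1), with d = d_s + d_t.
    dist = dist2src + dist2dst
    dist_over_2, dist_mod_2 = dist // 2, dist % 2
    z = 1 + torch.min(dist2src, dist2dst)
    z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
    return z.to(torch.long)

# A node at distances (1, 1) from (src, dst) gets label 2; (1, 2) gets label 3.
print(drnl_label(torch.tensor([1., 1.]), torch.tensor([1., 2.])))  # tensor([2, 3])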
Example #3
    def __call__(self, data: Data) -> Data:
        from scipy.sparse.linalg import eigs, eigsh
        eig_fn = eigs if not self.is_undirected else eigsh

        num_nodes = data.num_nodes
        edge_index, edge_weight = get_laplacian(
            data.edge_index,
            normalization='sym',
            num_nodes=num_nodes,
        )

        L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)

        eig_vals, eig_vecs = eig_fn(
            L,
            k=self.k + 1,
            which='SR' if not self.is_undirected else 'SA',
            return_eigenvectors=True,
            **self.kwargs,
        )

        eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])
        pe = torch.from_numpy(eig_vecs[:, 1:self.k + 1])
        sign = -1 + 2 * torch.randint(0, 2, (self.k, ))
        pe *= sign

        data = add_node_attr(data, pe, attr_name=self.attr_name)
        return data
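
The same pipeline runs standalone outside the transform. A minimal sketch on a
toy 6-cycle (the graph and k are arbitrary choices for illustration):

import numpy as np
import torch
from scipy.sparse.linalg import eigsh
from torch_geometric.utils import get_laplacian, to_scipy_sparse_matrix

k = 2
# A 6-cycle stored with both edge directions, i.e. an undirected graph.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 0],
                           [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 0, 5]])

lap_index, lap_weight = get_laplacian(edge_index, normalization='sym',
                                      num_nodes=6)
L = to_scipy_sparse_matrix(lap_index, lap_weight, 6)

# 'SA' = smallest algebraic eigenvalues; drop the trivial first eigenvector.
eig_vals, eig_vecs = eigsh(L, k=k + 1, which='SA')
pe = torch.from_numpy(np.real(eig_vecs[:, eig_vals.argsort()])[:, 1:k + 1])
print(pe.shape)  # torch.Size([6, 2])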
Example #4
def test_to_scipy_sparse_matrix():
    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])

    adj = to_scipy_sparse_matrix(edge_index)
    assert isinstance(adj, scipy.sparse.coo_matrix)
    assert adj.shape == (2, 2)
    assert adj.row.tolist() == edge_index[0].tolist()
    assert adj.col.tolist() == edge_index[1].tolist()
    assert adj.data.tolist() == [1, 1, 1]

    edge_attr = torch.Tensor([1, 2, 3])
    adj = to_scipy_sparse_matrix(edge_index, edge_attr)
    assert isinstance(adj, scipy.sparse.coo_matrix)
    assert adj.shape == (2, 2)
    assert adj.row.tolist() == edge_index[0].tolist()
    assert adj.col.tolist() == edge_index[1].tolist()
    assert adj.data.tolist() == edge_attr.tolist()
Example #5
    def load_data(self, data):
        adj = data.edge_index
        adj, _ = remove_self_loops(adj)
        adj = to_scipy_sparse_matrix(adj).asformat('csr')
        features = data.ori_x.numpy()
        features = sp.csr_matrix(features)
        self.adj_orig = adj
        self.features_orig = features
Example #6
def test_to_scipy_sparse_matrix():
    row = torch.tensor([0, 1, 0])
    col = torch.tensor([1, 0, 0])

    adj = to_scipy_sparse_matrix(torch.stack([row, col], dim=0))
    assert isinstance(adj, scipy.sparse.coo_matrix)
    assert adj.shape == (2, 2)
    assert adj.row.tolist() == row.tolist()
    assert adj.col.tolist() == col.tolist()
    assert adj.data.tolist() == [1, 1, 1]

    edge_attr = torch.Tensor([1, 2, 3])
    adj = to_scipy_sparse_matrix(torch.stack([row, col], dim=0), edge_attr)
    assert isinstance(adj, scipy.sparse.coo_matrix)
    assert adj.shape == (2, 2)
    assert adj.row.tolist() == row.tolist()
    assert adj.col.tolist() == col.tolist()
    assert adj.data.tolist() == edge_attr.tolist()
Example #7
def main():
    parser = argparse.ArgumentParser(description='OGB (Node2Vec)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--task', type=str, default='ogbn')
    parser.add_argument('--dataset', type=str, default='arxiv')
    parser.add_argument('--embedding_dim', type=int, default=128)
    parser.add_argument('--walk_length', type=int, default=80)
    parser.add_argument('--context_size', type=int, default=20)
    parser.add_argument('--walks_per_node', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--dropedge_rate', type=float, default=0.4)
    parser.add_argument('--dump_adj_only', action='store_true',
                        help='dump adj matrix for proX')
    args = parser.parse_args()

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = create_dataset(name=f'{args.task}-{args.dataset}')
    data = dataset[0]
    if args.dataset == 'arxiv':
        data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    elif args.dataset == 'papers100M':
        data.edge_index, _ = dropout_adj(data.edge_index, p=args.dropedge_rate,
                                         num_nodes=data.num_nodes)
        data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    if args.dump_adj_only:
        adj = to_scipy_sparse_matrix(data.edge_index)
        sp.save_npz(f'data/{args.task}-{args.dataset}-adj.npz', adj)
        return

    model = Node2Vec(data.edge_index, args.embedding_dim, args.walk_length,
                     args.context_size, args.walks_per_node,
                     sparse=True).to(device)

    loader = model.loader(batch_size=args.batch_size, shuffle=True,
                          num_workers=4)
    optimizer = torch.optim.SparseAdam(model.parameters(), lr=args.lr)

    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, (pos_rw, neg_rw) in enumerate(loader):
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()

            if (i + 1) % args.log_steps == 0:
                print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, '
                      f'Loss: {loss:.4f}')

            if (i + 1) % 100 == 0:  # Save model every 100 steps.
                save_embedding(model, args.embedding_dim, args.dataset, args.context_size)
        save_embedding(model, args.embedding_dim, args.dataset, args.context_size)
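
The --dump_adj_only branch relies on scipy's .npz serialization, which
round-trips cleanly back into an edge index. A minimal sketch (the file name is
illustrative):

import scipy.sparse as sp
import torch
from torch_geometric.utils import (from_scipy_sparse_matrix,
                                   to_scipy_sparse_matrix)

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
adj = to_scipy_sparse_matrix(edge_index)

sp.save_npz('toy-adj.npz', adj)  # save_npz supports COO matrices directly
loaded = sp.load_npz('toy-adj.npz')

edge_index2, _ = from_scipy_sparse_matrix(loaded)
assert edge_index2.tolist() == edge_index.tolist()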
Example #8
    def __init__(self, data):
        # Data object -> adjacency matrix, feature matrix, labels.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.data = data
        self.A = to_scipy_sparse_matrix(data.edge_index,
                                        data.edge_weight).tocsr()
        self.X = data.x.cpu().numpy()
        self.labels = data.y.cpu().numpy()
        self.k = 500
        self.e = 100
Example #9
def get_edge_and_y(dataset):
    edge_list = []
    labels = []
    for i in range(len(dataset)):
        data = dataset.get(i)
        ea, ei = data.edge_attr, data.edge_index
        adj = to_scipy_sparse_matrix(ei, ea).toarray()
        np.fill_diagonal(adj, 1)
        edge_list.append(adj.reshape(-1))
        labels.append(data.y.numpy())
    edge_list = np.stack(edge_list, -1)
    labels = np.stack(labels, -1)
    return edge_list, labels
Example #10
    def __call__(self, data: Data) -> Data:
        import numpy as np
        import scipy.sparse as sp

        adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)

        num_components, component = sp.csgraph.connected_components(adj)

        if num_components <= self.num_components:
            return data

        _, count = np.unique(component, return_counts=True)
        subset = np.in1d(component, count.argsort()[-self.num_components:])

        return data.subgraph(torch.from_numpy(subset).to(torch.bool))
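
Under the hood, scipy's connected_components returns a component label per
node. A small sketch on a toy graph with two components:

import numpy as np
import scipy.sparse as sp
import torch
from torch_geometric.utils import to_scipy_sparse_matrix

# Two components: {0, 1, 2} and {3, 4}.
edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
adj = to_scipy_sparse_matrix(edge_index, num_nodes=5)

num_components, component = sp.csgraph.connected_components(adj)
print(num_components)  # 2
print(component)       # [0 0 0 1 1]

# Keep only the largest component, mirroring the transform above.
_, count = np.unique(component, return_counts=True)
subset = np.in1d(component, count.argsort()[-1:])
print(torch.from_numpy(subset))  # tensor([ True,  True,  True, False, False])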
Example #11
def test_mstgcn():
    """
    Testing MSTGCN block
    """
    node_count = 307
    num_classes = 10
    edge_per_node = 15

    num_of_vertices = node_count  # 307
    num_for_predict = 12
    len_input = 12
    nb_time_strides = 1

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    node_features = 1
    nb_block = 2
    K = 3
    nb_chev_filter = 64
    nb_time_filter = 64
    batch_size = 32

    x, edge_index = create_mock_data(node_count, edge_per_node, node_features)
    adj_mx = to_scipy_sparse_matrix(edge_index)
    model = MSTGCN(DEVICE, nb_block, node_features, K,
                   nb_chev_filter, nb_time_filter, nb_time_strides,
                   adj_mx.toarray(), num_for_predict, len_input)

    T = len_input
    x_seq = torch.zeros([batch_size, node_count, node_features, T]).to(DEVICE)
    target_seq = torch.zeros([batch_size, node_count, T]).to(DEVICE)
    for b in range(batch_size):
        for t in range(T):
            x, edge_index = create_mock_data(node_count, edge_per_node,
                                             node_features)
            x_seq[b, :, :, t] = x
            target = create_mock_target(node_count, num_classes)
            target_seq[b, :, t] = target
    shuffle = True
    train_dataset = torch.utils.data.TensorDataset(x_seq, target_seq)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle)
    criterion = torch.nn.MSELoss().to(DEVICE)
    for batch_data in train_loader:
        encoder_inputs, labels = batch_data
        outputs = model(encoder_inputs)
    assert outputs.shape == (batch_size, node_count, num_for_predict)
Example #12
    def __call__(self, data):
        edge_weight = data.edge_attr
        if edge_weight is not None and edge_weight.numel() != data.num_edges:
            edge_weight = None

        edge_index, edge_weight = get_laplacian(data.edge_index, edge_weight,
                                                self.normalization,
                                                num_nodes=data.num_nodes)

        L = to_scipy_sparse_matrix(edge_index, edge_weight, data.num_nodes)

        eig_fn = eigsh if self.normalization == 'sym' else eigs
        lambda_max = eig_fn(L, k=1, which='LM', return_eigenvectors=False)
        data.lambda_max = float(lambda_max.real)

        return data
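
A standalone check of the lambda_max computation on a toy graph (here with
'sym' normalization, so eigsh applies):

import torch
from scipy.sparse.linalg import eigsh
from torch_geometric.utils import get_laplacian, to_scipy_sparse_matrix

# Undirected path graph 0-1-2.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
lap_index, lap_weight = get_laplacian(edge_index, normalization='sym',
                                      num_nodes=3)
L = to_scipy_sparse_matrix(lap_index, lap_weight, 3)

# Largest-magnitude eigenvalue of the normalized Laplacian (always <= 2).
lambda_max = eigsh(L, k=1, which='LM', return_eigenvectors=False)
print(float(lambda_max.real[0]))  # 2.0 for this bipartite graph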
Example #13
    def eps_rule(layer, input, R, edge_index, edge_weight, index,
                 after_message, before_message, LeakyReLU, message_passing,
                 transpose):
        EPSILON = 1e-9
        a = copy_tensor(input)
        a.retain_grad()

        z = layer.forward(a)

        if LeakyReLU:
            w = torch.eye(a.shape[1])

        elif message_passing:
            edge_index_detached = edge_index.detach()
            edge_weight_detached = edge_weight.detach()

            # Stack both edge directions so the adjacency matrix is symmetric
            # (the graph is undirected, but scipy does not know that).
            # TODO: make it a mean, not a sum. Right now w corresponds to a
            # weighted sum of nodes; w has to be divided by a scale factor.
            edge_indices = torch.stack([
                torch.cat((edge_index_detached[0], edge_index_detached[1])),
                torch.cat((edge_index_detached[1], edge_index_detached[0])),
            ])
            edge_weights = torch.cat([edge_weight_detached,
                                      edge_weight_detached])

            w = torch.from_numpy(
                to_scipy_sparse_matrix(edge_indices, edge_weights).toarray())
            # Halve the diagonal to avoid double counting self-loops after
            # stacking the edge indices in both directions.
            w = w - 0.5 * torch.diag(torch.diag(w))

            z = after_message.transpose(0, 1)
            a = before_message.transpose(0, 1)
            R = R.transpose(0, 1)

        else:
            w = layer.weight

        if transpose:
            R = R.transpose(0, 1)

        I = torch.ones_like(R)

        numerator = a * torch.matmul(R, w)
        denominator = (a * torch.matmul(I, w)).sum(axis=1)
        denominator = denominator.reshape(-1, 1).expand(denominator.size()[0],
                                                        numerator.size()[1])

        R = numerator / (denominator + EPSILON * torch.sign(denominator))
        return R
Example #14
def perturb_edges(data,
                  name,
                  remove_pct,
                  add_pct,
                  hidden_channels=16,
                  epochs=400):
    if remove_pct == 0 and add_pct == 0:
        return
    try:
        cached = pickle.load(
            open(f'{ROOT}/cache/edge/{name}_{remove_pct}_{add_pct}.pt', 'rb'))
        print(f'Use cached edge augmentation for dataset {name}')

        if data.setting == 'inductive':
            data.train_edge_index = cached
        else:
            data.edge_index = cached
        return
    except FileNotFoundError:
        try:
            A_pred, adj_orig = pickle.load(
                open(f'{ROOT}/cache/edge/{name}.pt', 'rb'))
            A = sample_graph_det(adj_orig, A_pred, remove_pct, add_pct)
            data.edge_index, _ = from_scipy_sparse_matrix(A)
            pickle.dump(
                data.edge_index,
                open(f'{ROOT}/cache/edge/{name}_{remove_pct}_{add_pct}.pt',
                     'wb'))
            return
        except FileNotFoundError:
            print(
                f'cache/edge/{name}_{remove_pct}_{add_pct}.pt not found! Regenerating it now'
            )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if data.setting == 'inductive':
        train_data = Data(x=data.train_x,
                          ori_x=data.ori_x,
                          edge_index=data.train_edge_index,
                          y=data.train_y)
    else:
        train_data = deepcopy(data)

    edge_index = deepcopy(train_data.edge_index)
    train_data = train_test_split_edges(train_data,
                                        val_ratio=0.1,
                                        test_ratio=0)
    num_features = train_data.ori_x.shape[1]
    model = GAE(GCNEncoder(num_features, hidden_channels))
    model = model.to(device)
    x = train_data.ori_x.to(device)
    train_pos_edge_index = train_data.train_pos_edge_index.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    best_val_auc = 0
    best_z = None
    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        loss = model.recon_loss(z, train_pos_edge_index)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)

        auc, ap = model.test(z, train_data.val_pos_edge_index,
                             train_data.val_neg_edge_index)
        print('Val | Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(
            epoch, auc, ap))
        if auc > best_val_auc:
            best_val_auc = auc
            best_z = deepcopy(z)

    # Use the embeddings from the best validation epoch (tracked above).
    A_pred = torch.sigmoid(torch.mm(best_z, best_z.T)).cpu().numpy()

    adj_orig = to_scipy_sparse_matrix(edge_index).asformat('csr')
    adj_pred = sample_graph_det(adj_orig, A_pred, remove_pct, add_pct)

    if data.setting == 'inductive':
        data.train_edge_index, _ = from_scipy_sparse_matrix(adj_pred)
    else:
        data.edge_index, _ = from_scipy_sparse_matrix(adj_pred)

    pickle.dump((A_pred, adj_orig), open(f'{ROOT}/cache/edge/{name}.pt', 'wb'))

    if data.setting == 'inductive':
        pickle.dump(
            data.train_edge_index,
            open(f'{ROOT}/cache/edge/{name}_{remove_pct}_{add_pct}.pt', 'wb'))
    else:
        pickle.dump(
            data.edge_index,
            open(f'{ROOT}/cache/edge/{name}_{remove_pct}_{add_pct}.pt', 'wb'))
Example #15
import os.path as osp

import scipy.sparse as sp
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import CitationFull
from torch_geometric.utils import to_scipy_sparse_matrix
from torch.optim.lr_scheduler import MultiStepLR, StepLR

from utils import (normalize_adjacency_matrix, normalizemx,
                   sparse_mx_to_torch_sparse_tensor)
from DBLP_utils import SCAT_Red
from layers import GC_withres, GraphConvolution
#from torch_geometric.nn import GATConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'DBLP')

#dataset = TUDataset(root= path,name='REDDIT-BINARY')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = CitationFull(path, name='dblp', transform=T.TargetIndegree())
data = dataset[0]
# Num of feat:1639
adj = to_scipy_sparse_matrix(edge_index=data.edge_index)
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
A_tilde = sparse_mx_to_torch_sparse_tensor(
    normalize_adjacency_matrix(adj, sp.eye(adj.shape[0]))).to(device)
adj = sparse_mx_to_torch_sparse_tensor(adj).to(device)
#print(dataset)
#print(data.x.shape)
#print(data.y.shape)

#tp = SCAT_Red(in_features=1639,med_f0=10,med_f1=10,med_f2=10,med_f3=10,med_f4=10).to(device)
#tp2 = SCAT_Red(in_features=40,med_f0=30,med_f1=10,med_f2=10,med_f3=10,med_f4=10).to(device)
train_mask = torch.cat((torch.ones(10000), torch.zeros(2000),
                        torch.zeros(2000), torch.zeros(3716)), 0) > 0
val_mask = torch.cat((torch.zeros(10000), torch.ones(2000), torch.zeros(2000),
                      torch.zeros(3716)), 0) > 0
test_mask = torch.cat((torch.zeros(10000), torch.zeros(2000), torch.ones(2000),
                       torch.zeros(3716)), 0) > 0
Example #16
node_dim = 40
conv_channels = 32
residual_channels = 32
skip_channels = 64
in_dim = 2
seq_in_len = 12
batch_size = 16
propalpha = 0.05
tanhalpha = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(3)
x, edge_index = create_mock_data(number_of_nodes=num_nodes,
                                 edge_per_node=8,
                                 in_channels=in_dim)
mock_adj = to_scipy_sparse_matrix(edge_index)
predefined_A = torch.tensor(mock_adj.toarray()).to(device)
total_size = batch_size
num_batch = int(total_size // batch_size)
x_all = torch.zeros(total_size, seq_in_len, num_nodes, in_dim)
for i in range(total_size):
    for j in range(seq_in_len):
        x, _ = create_mock_data(number_of_nodes=num_nodes,
                                edge_per_node=8,
                                in_channels=in_dim)
        x_all[i, j] = x
# define model and optimizer
start_conv = torch.nn.Conv2d(in_channels=in_dim,
                             out_channels=residual_channels,
                             kernel_size=(1, 1)).to(device)
gc = graph_constructor(num_nodes,
Example #17
weight_decay = 5e-4

num_for_predict = 12
len_input = 12
nb_time_strides = 1

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
node_features = 1
nb_block = 2
K = 3
nb_chev_filter = 64
nb_time_filter = 64
batch_size = 32

x, edge_index = create_mock_data(node_count, edge_per_node, node_features)
adj_mx = to_scipy_sparse_matrix(edge_index)
model = ASTGCN(DEVICE, nb_block, node_features,
               K, nb_chev_filter, nb_time_filter, nb_time_strides,
               adj_mx.toarray(), num_for_predict, len_input, node_count)

optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             weight_decay=weight_decay)

model.train()
T = len_input
x_seq = torch.zeros([batch_size, node_count, node_features, T]).to(DEVICE)
target_seq = torch.zeros([batch_size, node_count, T]).to(DEVICE)
for b in range(batch_size):
    for t in range(T):
        x, edge_index = create_mock_data(node_count, edge_per_node,
                                         node_features)
Example #18
def to_sparse_cpu(data):
    return to_scipy_sparse_matrix(data.edge_index).tocsr().astype(np.float32)
Example #19
def label_propagation(data, name, alpha=0.99, max_iter=10):
    '''
    Label propagation algorithm, modified from the NetworkX implementation.
    Label some nodes, then add them to the training set.
    Only supports undirected graphs.
    '''
    assert data.setting == 'transductive'
    try:
        cached_train_mask, cached_y = pickle.load(
            open(f'{ROOT}/cache/label/{name}.pt', 'rb'))
        print(f'Use cached label augmentation for dataset {name}')
        data.train_mask = cached_train_mask
        data.y = cached_y
        return
    except FileNotFoundError:
        print(f'cache/label/{name}.pt not found! Regenerating it now')

    if hasattr(data, 'adj_t'):
        X = data.adj_t.to_scipy(layout='csr').astype('long')
    else:
        X = to_scipy_sparse_matrix(
            data.edge_index).asformat('csr').astype('long')
    n_samples = X.shape[0]
    n_classes = int(max(data.y)) + 1
    F = np.zeros((n_samples, n_classes))

    degrees = X.sum(axis=0).A[0]
    degrees[degrees == 0] = 1  # Avoid division by 0
    D2 = np.sqrt(sparse.diags((1.0 / degrees), offsets=0))
    P = alpha * D2.dot(X).dot(D2)

    train_idxs = torch.where(data.train_mask)[0].numpy()
    y = data.y.numpy().squeeze()
    labels = []
    label_to_id = {}
    lid = 0

    for node_id in train_idxs:
        label = y[node_id]
        if label not in label_to_id:
            label_to_id[label] = lid
            lid += 1
        labels.append([node_id, label_to_id[label]])
    labels = np.array(labels)
    label_dict = np.array([
        label for label, _ in sorted(label_to_id.items(), key=lambda x: x[1])
    ])
    B = np.zeros((n_samples, n_classes))
    B[labels[:, 0], labels[:, 1]] = 1 - alpha

    remaining_iter = max_iter
    while remaining_iter > 0:
        F = P.dot(F) + B
        remaining_iter -= 1

    predicted_label_ids = np.argmax(F, axis=1)
    predicted = label_dict[predicted_label_ids].tolist()

    all_labeled_mask = data.train_mask + data.val_mask + data.test_mask
    count = 0
    for node_id, row in enumerate(F):
        if row.max() >= (row.sum() -
                         row.max()) * 2 and not all_labeled_mask[node_id]:
            data.train_mask[node_id] = True
            data.y[node_id] = predicted[node_id]
            count += 1

    print('Label propagation: labeled {} additional nodes in the training set.'
          .format(count))
    pickle.dump((data.train_mask, data.y),
                open(f'{ROOT}/cache/label/{name}.pt', 'wb'))
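
The core update above is F <- alpha * D^(-1/2) A D^(-1/2) F + (1 - alpha) * Y,
iterated max_iter times. A minimal dense sketch on a 4-node path graph with the
two end nodes labeled (toy data, mirroring the sparse code above):

import numpy as np

alpha, max_iter = 0.99, 10
# Path graph 0-1-2-3; node 0 is labeled class 0, node 3 class 1.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
degrees = A.sum(axis=0)
D2 = np.diag(1.0 / np.sqrt(degrees))
P = alpha * D2 @ A @ D2

F = np.zeros((4, 2))
B = np.zeros((4, 2))
B[0, 0] = B[3, 1] = 1 - alpha

for _ in range(max_iter):
    F = P @ F + B

print(F.argmax(axis=1))  # [0 0 1 1]: each inner node takes the nearer label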
Example #20
def test_mtgnn():
    """
    Testing MTGNN block
    """
    gcn_true = True
    buildA_true = True
    cl = True
    dropout = 0.3
    subgraph_size = 20
    gcn_depth = 2
    num_nodes = 207
    node_dim = 40
    dilation_exponential = 1
    conv_channels = 32
    residual_channels = 32
    skip_channels = 64
    end_channels = 128
    in_dim = 2
    seq_in_len = 12
    seq_out_len = 12
    layers = 3
    batch_size = 16
    learning_rate = 0.001
    weight_decay = 0.00001
    clip = 5
    step_size1 = 2500
    step_size2 = 100
    epochs = 3
    seed = 101
    propalpha = 0.05
    tanhalpha = 3
    num_split = 1

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_num_threads(3)
    x, edge_index = create_mock_data(number_of_nodes=num_nodes,
                                     edge_per_node=8,
                                     in_channels=in_dim)
    mock_adj = to_scipy_sparse_matrix(edge_index)
    predefined_A = torch.tensor(mock_adj.toarray()).to(device)
    x_all = torch.zeros(batch_size, seq_in_len, num_nodes, in_dim)
    for i in range(batch_size):
        for j in range(seq_in_len):
            x, _ = create_mock_data(number_of_nodes=num_nodes,
                                    edge_per_node=8,
                                    in_channels=in_dim)
            x_all[i, j] = x
    # define model and optimizer
    model = MTGNN(gcn_true,
                  buildA_true,
                  gcn_depth,
                  num_nodes,
                  predefined_A=predefined_A,
                  dropout=dropout,
                  subgraph_size=subgraph_size,
                  node_dim=node_dim,
                  dilation_exponential=dilation_exponential,
                  conv_channels=conv_channels,
                  residual_channels=residual_channels,
                  skip_channels=skip_channels,
                  end_channels=end_channels,
                  seq_length=seq_in_len,
                  in_dim=in_dim,
                  out_dim=seq_out_len,
                  layers=layers,
                  propalpha=propalpha,
                  tanhalpha=tanhalpha,
                  layer_norm_affline=True)
    trainx = torch.Tensor(x_all).to(device)
    trainx = trainx.transpose(1, 3)
    perm = np.random.permutation(range(num_nodes))
    num_sub = int(num_nodes / num_split)  # number of nodes in each subgraph
    for j in range(num_split):
        if j != num_split - 1:
            id = perm[j * num_sub:(j + 1) * num_sub]
        else:
            id = perm[j * num_sub:]
        id = torch.tensor(id).to(device)  # a permutation of node id
        tx = trainx[:, :, id, :]
        output = model(tx, idx=id)
        output = output.transpose(1, 3)
        assert output.shape == (batch_size, 1, num_nodes, seq_out_len)