def get_data():
    """Build train and test DataLoaders over the AIFB Entities dataset.

    Returns:
        (trainLoader, testloader): two ``DataLoader`` objects with
        ``batch_size=1`` and no shuffling.
    """
    dataset = args.name  # NOTE(review): read but never used below — confirm intent
    path = '../data/geometric/Entities-AIFB'
    trainset = Entities(path, "AIFB")
    testset = Entities(path, "AIFB")
    lenTrain = len(trainset)
    lenTest = len(testset)
    print("Len Dataset:", lenTrain)
    trainLoader = DataLoader(trainset[:lenTrain], batch_size=1, shuffle=False)
    # BUG FIX: the test loader previously sliced `trainset`; it must wrap
    # `testset` so the two loaders are built from their intended splits.
    testloader = DataLoader(testset[:lenTest], batch_size=1, shuffle=False)
    print("Len TrainLoader:", len(trainLoader))
    return trainLoader, testloader
def train_rgcn():
    """Train and evaluate an R-GCN node classifier on the MUTAG dataset."""
    from torch_geometric.datasets import Entities
    from rgcn import RGCN

    name = 'MUTAG'
    path = osp.join(osp.dirname(osp.realpath(__file__)), './', 'data',
                    'Entities', name)
    dataset = Entities(path, name)
    data = dataset[0]

    # The dataset ships no node features; use Xavier-initialised random ones.
    x = torch.zeros(data.num_nodes, 16)
    torch.nn.init.xavier_uniform_(x)

    model = RGCN(x.size(1), 16, dataset.num_relations, 30,
                 dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=0.0005)

    # Re-mixed train/test split (20% held out for testing).
    train_idx, test_idx, train_y, test_y = mix_train_test(name, 0.2)
    print('train_size: {:03d}, test_size: {:03d}'.format(
        train_idx.size(0), test_idx.size(0)))

    def run_epoch():
        # One optimisation step over the training nodes; returns the loss.
        model.train()
        optimizer.zero_grad()
        out = model.forward_(x, data.edge_index, data.edge_type, None)
        loss = F.nll_loss(out[train_idx], train_y)
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate():
        # Accuracy on both splits with the model in eval mode.
        model.eval()
        out = model.forward_(x, data.edge_index, data.edge_type, None)
        test_pred = out[test_idx].max(1)[1]
        test_acc = test_pred.eq(test_y).sum().item() / test_y.size(0)
        train_pred = out[train_idx].max(1)[1]
        train_acc = train_pred.eq(train_y).sum().item() / train_y.size(0)
        return train_acc, test_acc

    for epoch in range(1, 100):
        loss = run_epoch()
        train_acc, test_acc = evaluate()
        print(
            'Epoch: {:02d}, Loss: {:.4f}, Train Accuracy: {:.4f}, Test Accuracy: {:.4f}'
            .format(epoch, loss, train_acc, test_acc))
def download_data():
    """Download the AM Entities dataset and dump it to plain-text files.

    Writes three files under ``data/Entities/AM/``:
      - ``am_edge``: one ``src dst edge_type`` triple per line
      - ``am_label``: ``node_idx label`` pairs, train split first, then test
      - ``am_feature``: ``idx<TAB>space-separated floats`` random features
    """
    from torch_geometric.datasets import Entities
    name = 'AM'
    path = osp.join(osp.dirname(osp.realpath(__file__)), './', 'data',
                    'Entities', name)
    dataset = Entities(path, name)
    data = dataset[0]
    print(dataset.num_relations)
    print(dataset.num_classes)
    print(data.num_nodes)
    print(data.train_idx.size(0))
    print(data.test_idx.size(0))

    # write edge_index and edge_types
    # FIX: use context managers — the original opened three writers and
    # never closed any of them, risking unflushed/truncated output.
    with open('data/Entities/AM/am_edge', 'w') as writer:
        src, dst = data.edge_index[0], data.edge_index[1]
        size = src.size(0)
        for idx in range(size):
            writer.write(
                '%d %d %s\n' %
                (src[idx].item(), dst[idx].item(), data.edge_type[idx].item()))

    # write labels (train split first, then test split, into the same file)
    with open('data/Entities/AM/am_label', 'w') as writer:
        train_idx, train_y = data.train_idx, data.train_y
        size = train_idx.size(0)
        for idx in range(size):
            writer.write('%d %d\n' % (train_idx[idx].item(),
                                      train_y[idx].item()))
        test_idx, test_y = data.test_idx, data.test_y
        size = test_idx.size(0)
        for idx in range(size):
            writer.write('%d %d\n' % (test_idx[idx].item(),
                                      test_y[idx].item()))

    # write random Xavier-initialised 32-dim node features
    x = torch.zeros(data.num_nodes, 32)
    torch.nn.init.xavier_uniform_(x)
    size = data.num_nodes
    with open('data/Entities/AM/am_feature', 'w') as writer:
        for idx in range(size):
            fs = x[idx].numpy()
            fs = [str(f) for f in fs]
            writer.write('%d\t%s\n' % (idx, ' '.join(fs)))
def mix_train_test(name, test_ratio):
    """Re-split the labelled nodes of an Entities dataset into train/test.

    Concatenates the dataset's original train and test node indices/labels,
    shuffles them, and cuts a new split.

    Args:
        name: Entities dataset name (e.g. 'MUTAG', 'AIFB').
        test_ratio: fraction of labelled nodes assigned to the test split.

    Returns:
        (train_idx, test_idx, train_y, test_y) tensors.
    """
    from torch_geometric.datasets import Entities
    path = osp.join(osp.dirname(osp.realpath(__file__)), './', 'data',
                    'Entities', name)
    dataset = Entities(path, name)
    data = dataset[0]
    train_test_idx = torch.cat([data.train_idx, data.test_idx])
    train_test_y = torch.cat([data.train_y, data.test_y])
    size = train_test_idx.size(0)
    index = np.arange(size)
    random.shuffle(index)
    index = torch.from_numpy(index)
    # BUG FIX: the shuffled permutation was computed but never applied, so
    # the "mixed" split was just the original ordering — permute both
    # tensors before cutting the split.
    train_test_idx = train_test_idx[index]
    train_test_y = train_test_y[index]
    train_size = int(size * (1 - test_ratio))
    train_idx, train_y = (train_test_idx[0:train_size],
                          train_test_y[0:train_size])
    test_idx, test_y = train_test_idx[train_size:], train_test_y[train_size:]
    return train_idx, test_idx, train_y, test_y
optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() def test(): with torch.no_grad(): model.eval() _, pred = model(data).max(dim=1) correct = float(pred[data.test_mask].eq( data.y[data.test_mask]).sum().item()) acc = correct / data.test_mask.sum().item() print("Accuracy: {:.4f}".format(acc)) train() test() name = 'MUTAG' path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities', name) dataset = Entities(path, name) print(len(dataset), dataset.num_classes, dataset.num_node_features) data = dataset[0] print("Nodes: {}, Features: {}, Edges: {}".format(data.num_nodes, data.num_node_features, data.num_edges)) print("Directed: {}, Self loops: {}".format(data.is_directed(), data.contains_self_loops()))
from torch_geometric.datasets import Entities
from torch_geometric.nn import FastRGCNConv, RGCNConv
from torch_geometric.utils import k_hop_subgraph

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str,
                    choices=['AIFB', 'MUTAG', 'BGS', 'AM'])
args = parser.parse_args()

# Trade memory consumption for faster computation.
if args.dataset in ['AIFB', 'MUTAG']:
    RGCNConv = FastRGCNConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')
dataset = Entities(path, args.dataset)
data = dataset[0]

# BGS and AM graphs are too big to process them in a full-batch fashion.
# Since our model does only make use of a rather small receptive field, we
# filter the graph to only contain the nodes that are at most 2-hop neighbors
# away from any training/test node.
node_idx = torch.cat([data.train_idx, data.test_idx], dim=0)
node_idx, edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx, 2, data.edge_index, relabel_nodes=True)

# Overwrite the graph in place with its relabelled 2-hop subgraph.
data.num_nodes = node_idx.size(0)
data.edge_index = edge_index
data.edge_type = data.edge_type[edge_mask]
def load_dataset(name):
    """
    Load real-world datasets, available in PyTorch Geometric.
    Used as a helper for DiskDataSource.

    Args:
        name: lower-case dataset key, e.g. "enzymes", "aifb", "wn18".

    Returns:
        (train, test, task): two lists of graphs and the task string
        ("graph"). For the knowledge-graph datasets (aifb, wn18,
        fb15k237) train and test each hold one copy of the same graph;
        otherwise the shuffled dataset is split 80/20.
    """
    task = "graph"
    # TUDataset variants share one loading pattern — dispatch via a table
    # instead of a long elif chain. Roots are cached under /tmp.
    tu_datasets = {
        "enzymes": ("/tmp/ENZYMES", "ENZYMES"),
        "proteins": ("/tmp/PROTEINS", "PROTEINS"),
        "cox2": ("/tmp/cox2", "COX2"),
        "aids": ("/tmp/AIDS", "AIDS"),
        "reddit-binary": ("/tmp/REDDIT-BINARY", "REDDIT-BINARY"),
        "imdb-binary": ("/tmp/IMDB-BINARY", "IMDB-BINARY"),
        "firstmm_db": ("/tmp/FIRSTMM_DB", "FIRSTMM_DB"),
        "dblp": ("/tmp/DBLP_v1", "DBLP_v1"),
    }
    if name in tu_datasets:
        root, tu_name = tu_datasets[name]
        dataset = TUDataset(root=root, name=tu_name)
    elif name == "ppi":
        dataset = PPI(root="/tmp/PPI")
    elif name == "qm9":
        dataset = QM9(root="/tmp/QM9")
    elif name == "atlas":
        dataset = [g for g in nx.graph_atlas_g()[1:] if nx.is_connected(g)]
    elif name == 'aifb':
        dataset = Entities(root="/tmp/aifb", name='AIFB')  # 90 edge types
    elif name == 'wn18':
        dataset = WordNet18(root="/tmp/wn18")
    elif name == 'fb15k237':
        # Placeholder element; the real graph is loaded inside the loop.
        dataset = [None]

    if task == "graph":
        train_len = int(0.8 * len(dataset))
        train, test = [], []
        if name not in ['aifb', 'wn18', 'fb15k237']:
            dataset = list(dataset)
            random.shuffle(dataset)
            has_name = hasattr(dataset[0], "name")
        else:
            has_name = True
        for i, graph in tqdm(enumerate(dataset)):
            # FIX: isinstance instead of `type(x) == nx.Graph`; graphs that
            # are already networkx need no conversion.
            if not isinstance(graph, nx.Graph):
                if has_name:
                    # FIX: narrowed the bare `except:` — only AttributeError
                    # (element has no deletable `name`, or is None) is expected.
                    try:
                        del graph.name
                    except AttributeError:
                        pass
                if name in ('aifb', 'wn18'):
                    # Keep relation types on the edges when converting.
                    graph = pyg_utils.to_networkx(graph,
                                                  edge_attrs=['edge_type'])
                elif name == 'fb15k237':
                    data = FB15k_237()
                    (graph, _, _, _) = data.load()
                    graph = graph.to_networkx()
                    # Map string edge labels to contiguous integer ids.
                    labels = sorted({graph.edges[j]['label']
                                     for j in graph.edges})
                    edge_type_dict = {lab: ind
                                      for ind, lab in enumerate(labels)}
                    for j in graph.edges:
                        graph.edges[j]['edge_type'] = edge_type_dict[
                            graph.edges[j]['label']]
                        del graph.edges[j]['label']
                        del graph.edges[j]['weight']
                else:
                    graph = pyg_utils.to_networkx(graph).to_undirected()
            if name in ('aifb', 'wn18', 'fb15k237'):
                # Single knowledge graph: both splits see it (test gets an
                # independent copy so mutation cannot leak between splits).
                train.append(graph)
                test.append(deepcopy(graph))
            else:
                if i < train_len:
                    train.append(graph)
                else:
                    test.append(graph)
        return train, test, task
from itertools import product

import torch

from runtime.gat import GAT
from runtime.gcn import GCN
from runtime.rgcn import RGCN
from runtime.train import train_runtime
from torch_geometric.datasets import Entities, Planetoid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
Cora = Planetoid(osp.join(root, 'Cora'), 'Cora')
CiteSeer = Planetoid(osp.join(root, 'CiteSeer'), 'CiteSeer')
PubMed = Planetoid(osp.join(root, 'PubMed'), 'PubMed')
MUTAG = Entities(osp.join(root, 'EntitiesMUTAG'), 'MUTAG')

# One training run before we start tracking duration to warm up GPU.
model = GCN(Cora.num_features, Cora.num_classes)
train_runtime(model, Cora[0], epochs=200, device=device)

# Time GCN and GAT on each citation dataset.
for d, Net in product([Cora, CiteSeer, PubMed], [GCN, GAT]):
    model = Net(d.num_features, d.num_classes)
    t = train_runtime(model, d[0], epochs=200, device=device)
    print(f'{repr(d)[:-2]} - {Net.__name__}: {t:.2f}s')

# R-GCN needs relation counts, so it gets its own loop over MUTAG.
for d, Net in product([MUTAG], [RGCN]):
    model = Net(d[0].num_nodes, d.num_classes, d.num_relations)
    t = train_runtime(model, d[0], epochs=200, device=device)
    print(f'{repr(d)[:-2]} - {Net.__name__}: {t:.2f}s')
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Entities
from torch_geometric.nn import RGATConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')
dataset = Entities(path, 'AIFB')
data = dataset[0]
# AIFB ships no node features; attach random 16-dim ones.
data.x = torch.randn(data.num_nodes, 16)


class RGAT(torch.nn.Module):
    """Two relational-GAT layers followed by a linear classifier head.

    Returns per-node log-probabilities over `out_channels` classes.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_relations):
        super().__init__()
        self.conv1 = RGATConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGATConv(hidden_channels, hidden_channels, num_relations)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_type):
        h = self.conv1(x, edge_index, edge_type).relu()
        h = self.conv2(h, edge_index, edge_type).relu()
        return F.log_softmax(self.lin(h), dim=-1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)