Example #1
def load_data(args):
    dataset = args.input
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)

    if dataset in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
        num_features = dataset.num_features
        num_classes = dataset.num_classes
        data = dataset[0]
        return data, num_features, num_classes
    elif dataset == 'corafull':
        dataset = CoraFull(path)
    elif dataset in ['cs', 'physics']:
        dataset = Coauthor(path, name=dataset)
    elif dataset in ['computers', 'photo']:
        dataset = Amazon(path, name=dataset)
    elif dataset == 'reddit':
        dataset = Reddit(path)
        num_features = dataset.num_features
        num_classes = dataset.num_classes
        data = dataset[0]
        return data, num_features, num_classes
    num_features = dataset.num_features
    num_classes = dataset.num_classes
    data = dataset[0]

    data.train_mask, data.val_mask, data.test_mask = generate_split(
        data, num_classes)

    return data, num_features, num_classes
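
The `generate_split` helper called above (and again in Example #7) is not shown in this listing; a minimal sketch of what such a helper might do, assuming a simple random node split with hypothetical ratios, is:

import torch

def generate_split(data, num_classes, train_ratio=0.1, val_ratio=0.1):
    # Hypothetical helper matching the call sites above.
    # `num_classes` is accepted for signature compatibility; this sketch
    # uses a plain random split rather than a per-class split.
    num_nodes = data.x.size(0)
    perm = torch.randperm(num_nodes)
    n_train = int(num_nodes * train_ratio)
    n_val = int(num_nodes * val_ratio)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[perm[:n_train]] = True
    val_mask[perm[n_train:n_train + n_val]] = True
    test_mask[perm[n_train + n_val:]] = True
    return train_mask, val_mask, test_mask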
Example #2
def load_dataset(dataset='flickr'):
    """

    Args:
        dataset: str, name of dataset, assuming the raw dataset path is ./data/your_dataset/raw.
                 torch_geometric.dataset will automatically preprocess the raw files and save preprocess dataset into
                 ./data/your_dataset/preprocess

    Returns:
        dataset
    """
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset)
    if dataset == 'flickr':
        dataset = Flickr(path)

    elif dataset == 'reddit':
        dataset = Reddit(path)

    elif dataset == 'ppi':
        dataset = PPI(path)

    elif dataset == 'ppi-large':
        dataset = PPI(path)

    elif dataset == 'yelp':
        dataset = Yelp(path)

    else:
        raise KeyError('Dataset name error')

    return dataset
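
For reference, a minimal usage sketch of this loader (assuming the directory layout described in the docstring):

# Hypothetical usage of the loader above.
dataset = load_dataset('reddit')   # downloads/processes into ./data/reddit
data = dataset[0]                  # the single graph of a node-level dataset
print(data.num_nodes, dataset.num_classes)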
Example #3
 def __init__(self, args=None):
     self.url = "https://data.dgl.ai/dataset/reddit.zip"
     dataset = "Reddit"
     path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
     if not osp.exists(path):
         Reddit(path)
     super(RedditDataset, self).__init__(path, transform=T.TargetIndegree())
Example #4
def load_reddit(dataset):
    data_name = ['Reddit']
    assert dataset in data_name
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'Datasets',
                    'NodeData', 'Reddit')
    dataset = Reddit(path)
    return dataset
Example #5
def _run_trainer():
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                    'Reddit')
    print("Load Dataset")
    dataset = Reddit(path)
    data = dataset[0]
    print("Load Train Sampler")
    train_loader = NeighborSampler(data.edge_index,
                                   node_idx=data.train_mask,
                                   sizes=[25, 10],
                                   batch_size=1024,
                                   shuffle=True,
                                   num_workers=0)

    print("Creating SAGE model")
    model = SAGE(dataset.num_features, 256, dataset.num_classes)

    optimizer = DistributedOptimizer(torch.optim.Adam,
                                     model.parameter_rrefs(),
                                     lr=0.01)

    print("Start training")
    for epoch in range(1, 11):
        loss, acc = train(model, optimizer, epoch, data, train_loader)
        print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
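
The `SAGE` class and the `train` helper used above are assumed to be defined elsewhere in that project (note the RPC-specific `parameter_rrefs()` method); as a rough, non-distributed sketch, a two-layer GraphSAGE matching the constructor call could look like this:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class SAGE(torch.nn.Module):
    # Hypothetical stand-in for the model used above: two SAGEConv layers
    # sized (num_features -> hidden -> num_classes), applied to the sampled
    # mini-batches yielded by NeighborSampler.
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, adjs):
        # `adjs` is the list of (edge_index, e_id, size) tuples produced by
        # NeighborSampler, one entry per layer.
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # target nodes are listed first
            conv = self.conv1 if i == 0 else self.conv2
            x = conv((x, x_target), edge_index)
            if i == 0:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)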
Example #6
 def __init__(self):
     dataset = "Reddit"
     path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data",
                     dataset)
     if not osp.exists(path):
         Reddit(path)
     super(RedditDataset, self).__init__(path, transform=T.TargetIndegree())
Example #7
def load_dataset(dataset):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)

    if dataset in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
        num_features = dataset.num_features
        num_classes = dataset.num_classes
        data = dataset[0]
        data.adj = torch.zeros((data.x.size(0), data.x.size(0)))
        col, row = data.edge_index
        data.adj[col, row] = 1
        return data, num_features, num_classes
    elif dataset == 'reddit':
        dataset = Reddit(path)
    elif dataset == 'corafull':
        dataset = CoraFull(path)
    num_features = dataset.num_features
    num_classes = dataset.num_classes
    data = dataset[0]

    data.train_mask, data.val_mask, data.test_mask = generate_split(
        data, num_classes)
    data.adj = torch.zeros((data.x.size(0), data.x.size(0)))
    col, row = data.edge_index
    data.adj[col, row] = 1
    return data, num_features, num_classes
Example #8
def gen_reddit_dataset():
    """Returns a function to be called on each worker that returns Reddit Dataset."""

    # For Reddit dataset, we have to download the data on each node, so we create the
    # dataset on each training worker.
    with FileLock(os.path.expanduser("~/.reddit_dataset_lock")):
        dataset = Reddit("./data/Reddit")
    return dataset
Example #9
def main(args):
    path = osp.join('..', 'data', 'Reddit')
    dataset = Reddit(path)
    data = dataset[0]

    features = data.x
    labels = data.y
    train_mask = torch.BoolTensor(data.train_mask)
    val_mask = torch.BoolTensor(data.val_mask)
    test_mask = torch.BoolTensor(data.test_mask)

    edge_index = data.edge_index
    edge_index, _ = remove_self_loops(edge_index)
    edge_index, _ = add_self_loops(edge_index, num_nodes=features.size(0))

    model = GAT(num_layers=args.num_layers,
                in_feats=features.size(-1),
                num_hidden=args.num_hidden,
                num_classes=dataset.num_classes,
                heads=[1, 1, 1],
                dropout=args.dropout)

    loss_fcn = nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    dur = []
    for epoch in range(1, 1 + args.epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features, edge_index)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)
            print('Training time/epoch {}'.format(np.mean(dur)))

        if args.eval:
            acc = evaluate(model, edge_index, features, labels, val_mask)
        else:
            acc = 0
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} ".format(
            epoch, np.mean(dur), loss.item(), acc))

    if args.eval:
        print()
        acc = evaluate(model, edge_index, features, labels, test_mask)
        print("Test Accuracy {:.4f}".format(acc))
Example #10
def readdata():
    from torch_geometric.datasets import Reddit
    reddit = Reddit(root="data/reddit/")
    print(reddit.data.y.max())
    edges = reddit.data.edge_index.transpose(0, 1).numpy().tolist()
    g = nx.Graph()
    g.add_edges_from(edges)

    g.subgraph([])
Example #11
def main(args):
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    path = osp.join('..', 'data', 'Reddit')
    dataset = Reddit(path)
    data = dataset[0]

    features = data.x.to(device)
    labels = data.y.to(device)
    edge_index = data.edge_index.to(device)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])
    train_mask = torch.BoolTensor(data.train_mask).to(device)
    val_mask = torch.BoolTensor(data.val_mask).to(device)
    test_mask = torch.BoolTensor(data.test_mask).to(device)

    model = GraphSAGE(dataset.num_features, args.n_hidden, dataset.num_classes,
                      args.aggr, F.relu, args.dropout).to(device)

    loss_fcn = nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    dur = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features, adj)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        if args.eval:
            acc = evaluate(model, adj, features, labels, val_mask)
        else:
            acc = 0
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} ".
              format(epoch, np.mean(dur), loss.item(), acc))

    if args.eval:
        print()
        acc = evaluate(model, adj, features, labels, test_mask)
        print("Test Accuracy {:.4f}".format(acc))
Example #12
def get_dataset(dataset_name):
    """
    Retrieves the dataset corresponding to the given name.
    """
    path = 'dataset'
    if dataset_name == 'reddit':
        dataset = Reddit(path)
    elif dataset_name == 'amazon_comp':
        dataset = Amazon(path, name="Computers")
        data = dataset.data
        idx_train, idx_test = train_test_split(list(range(data.x.shape[0])),
                                               test_size=0.4,
                                               random_state=42)
        idx_val, idx_test = train_test_split(idx_test,
                                             test_size=0.5,
                                             random_state=42)

        train_mask = torch.tensor([False] * data.x.shape[0])
        val_mask = torch.tensor([False] * data.x.shape[0])
        test_mask = torch.tensor([False] * data.x.shape[0])

        train_mask[idx_train] = True
        val_mask[idx_val] = True
        test_mask[idx_test] = True

        data.train_mask = train_mask
        data.val_mask = val_mask
        data.test_mask = test_mask
        dataset.data = data
    elif dataset_name in ["Cora", "CiteSeer", "PubMed"]:
        dataset = Planetoid(
            path,
            name=dataset_name,
            split="full",
        )
    else:
        raise NotImplementedError

    return dataset
Example #13
def get_dataset(dataset_name):
    """
    Retrieves the dataset corresponding to the given name.
    """
    print("Getting dataset...")
    path = join('dataset', dataset_name)
    if dataset_name == 'reddit':
        dataset = Reddit(path)
    elif dataset_name == 'ppi':
        dataset = PPI(path)
    elif dataset_name == 'github':
        dataset = GitHub(path)
        data = dataset.data
        idx_train, idx_test = train_test_split(list(range(data.x.shape[0])),
                                               test_size=0.4,
                                               random_state=42)
        idx_val, idx_test = train_test_split(idx_test,
                                             test_size=0.5,
                                             random_state=42)
        data.train_mask = torch.tensor(idx_train)
        data.val_mask = torch.tensor(idx_val)
        data.test_mask = torch.tensor(idx_test)
        dataset.data = data
    elif dataset_name in ['amazon_comp', 'amazon_photo']:
        dataset = Amazon(path, "Computers", T.NormalizeFeatures()
                         ) if dataset_name == 'amazon_comp' else Amazon(
                             path, "Photo", T.NormalizeFeatures())
        data = dataset.data
        idx_train, idx_test = train_test_split(list(range(data.x.shape[0])),
                                               test_size=0.4,
                                               random_state=42)
        idx_val, idx_test = train_test_split(idx_test,
                                             test_size=0.5,
                                             random_state=42)
        data.train_mask = torch.tensor(idx_train)
        data.val_mask = torch.tensor(idx_val)
        data.test_mask = torch.tensor(idx_test)
        dataset.data = data
    elif dataset_name in ["Cora", "CiteSeer", "PubMed"]:
        dataset = Planetoid(path,
                            name=dataset_name,
                            split="full",
                            transform=T.NormalizeFeatures())
    else:
        raise NotImplementedError

    print("Dataset ready!")
    return dataset
Example #14
def prepare_data(dataset, seed):
    """
	:param dataset: name of the dataset used
	:return: data, in the correct format
	"""
    # Retrieve main path of project
    dirname = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    # Download and store dataset at chosen location
    if dataset == 'Cora' or dataset == 'PubMed' or dataset == 'Citeseer':
        path = os.path.join(dirname, 'data')
        data = Planetoid(path, name=dataset, split='full')[0]
        # data.train_mask, data.val_mask, data.test_mask = split_function(data.y.numpy())
        data.num_classes = (max(data.y) + 1).item()
        # dataset = Planetoid(path, name=dataset, split='public', transform=T.NormalizeFeatures(), num_train_per_class=20, num_val=500, num_test=1000)
        # data = modify_train_mask(data)

    elif dataset == 'Amazon':
        path = os.path.join(dirname, 'data', 'Amazon')
        data = Amazon(path, 'photo')[0]
        data.num_classes = (max(data.y) + 1).item()
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
        # Amazon: 4896 train, 1224 val, 1530 test

    elif dataset == 'Reddit':
        path = os.path.join(dirname, 'data', 'Reddit')
        data = Reddit(path)[0]
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())

    elif dataset == 'PPI':
        path = os.path.join(dirname, 'data', 'PPI')
        data = ppi_prepoc(path, seed)
        data.x = data.graphs[0].x
        data.num_classes = data.graphs[0].y.size(1)
        for df in data.graphs:
            df.num_classes = data.num_classes

    #elif dataset = 'MUTAG'

    # Get it in right format
    if dataset != 'PPI':
        print('Train mask is of size: ',
              data.train_mask[data.train_mask == True].shape)

    # data = add_noise_features(data, args.num_noise)

    return data
Example #15
    def __init__(self, path: str):
        pyg_dataset = Reddit(os.path.join(path, '_pyg'))
        if hasattr(pyg_dataset, "__data_list__"):
            delattr(pyg_dataset, "__data_list__")
        if hasattr(pyg_dataset, "_data_list"):
            delattr(pyg_dataset, "_data_list")
        pyg_data = pyg_dataset[0]

        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            {
                'x': pyg_data.x,
                'y': pyg_data.y,
                'train_mask': getattr(pyg_data, 'train_mask'),
                'val_mask': getattr(pyg_data, 'val_mask'),
                'test_mask': getattr(pyg_data, 'test_mask')
            }, pyg_data.edge_index)
        super(RedditDataset, self).__init__([static_graph])
Example #16
def get_dataset(name, root, use_sparse_tensor):
    path = osp.join(osp.dirname(osp.realpath(__file__)), root, name)
    transform = T.ToSparseTensor() if use_sparse_tensor else None
    if name == 'ogbn-mag':
        if transform is None:
            transform = T.ToUndirected(merge=True)
        else:
            transform = T.Compose([T.ToUndirected(merge=True), transform])
        dataset = OGB_MAG(root=path,
                          preprocess='metapath2vec',
                          transform=transform)
    elif name == 'ogbn-products':
        dataset = PygNodePropPredDataset('ogbn-products',
                                         root=path,
                                         transform=transform)
    elif name == 'Reddit':
        dataset = Reddit(root=path, transform=transform)

    return dataset[0], dataset.num_classes
Example #17
def main():
    seed_everything(42)

    dataset = Reddit(osp.join('data', 'Reddit'))
    data = dataset[0]

    datamodule = LightningNodeData(data, data.train_mask, data.val_mask,
                                   data.test_mask, loader='neighbor',
                                   num_neighbors=[25, 10], batch_size=1024,
                                   num_workers=8)

    model = Model(dataset.num_node_features, dataset.num_classes)

    devices = torch.cuda.device_count()
    strategy = pl.strategies.DDPSpawnStrategy(find_unused_parameters=False)
    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1)
    trainer = pl.Trainer(strategy=strategy, accelerator='gpu', devices=devices,
                         max_epochs=20, callbacks=[checkpoint])

    trainer.fit(model, datamodule)
    trainer.test(ckpt_path='best', datamodule=datamodule)
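
`Model` is assumed to be a LightningModule defined elsewhere; a compact sketch that matches the constructor call and the `val_acc` metric monitored by the checkpoint callback, with placeholder hyperparameters, might look like:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy  # assumes torchmetrics >= 0.11
from torch_geometric.nn import SAGEConv

class Model(pl.LightningModule):
    # Hypothetical two-layer GraphSAGE LightningModule matching
    # Model(num_node_features, num_classes) and logging 'val_acc'.
    def __init__(self, in_channels, out_channels, hidden_channels=256):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

    def training_step(self, batch, batch_idx):
        # NeighborLoader batches place the seed nodes first (batch.batch_size of them).
        out = self(batch.x, batch.edge_index)[:batch.batch_size]
        return F.cross_entropy(out, batch.y[:batch.batch_size])

    def validation_step(self, batch, batch_idx):
        out = self(batch.x, batch.edge_index)[:batch.batch_size]
        self.val_acc(out.softmax(dim=-1), batch.y[:batch.batch_size])
        self.log('val_acc', self.val_acc, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)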
Example #18
def load_data(dataset_name):
    """
    Loads required data set and normalizes features.
    Implemented data sets are any of type Planetoid and Reddit.
    :param dataset_name: Name of data set
    :return: Tuple of dataset and extracted graph
    """
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                    dataset_name)

    if dataset_name == 'cora_full':
        dataset = CoraFull(path, T.NormalizeFeatures())
    elif dataset_name.lower() == 'coauthor':
        dataset = Coauthor(path, 'Physics', T.NormalizeFeatures())
    elif dataset_name.lower() == 'reddit':
        dataset = Reddit(path, T.NormalizeFeatures())
    elif dataset_name.lower() == 'amazon':
        dataset = Amazon(path)
    else:
        dataset = Planetoid(path, dataset_name, T.NormalizeFeatures())

    print(f"Loading data set {dataset_name} from: ", path)
    data = dataset[0]  # Extract graph
    return dataset, data
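
As the docstring notes, the function returns both the dataset object and the extracted graph; a short, hypothetical usage sketch:

# Hypothetical usage of the loader above.
dataset, data = load_data('Cora')
print(dataset.num_classes, data.num_nodes, data.x.size(-1))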
Example #19
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader
from torch_geometric.nn import SAGEConv

dataset = Reddit('../data/Reddit')

data = dataset[0]

print('Partitioning the graph... (this may take a while)')
cluster_data = ClusterData(data, num_parts=1500,
                           save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data,
                             batch_size=20,
                             shuffle=True,
                             drop_last=True,
                             num_workers=6)
test_loader = ClusterLoader(cluster_data,
                            batch_size=20,
                            shuffle=False,
                            num_workers=6)
print('Done!')


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = SAGEConv(in_channels, 128, normalize=False)
        self.conv2 = SAGEConv(128, out_channels, normalize=False)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(self.conv1(x, edge_index))
Example #20
import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.data import NeighborSampler
import time
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import GATConv
from torch.nn import Linear as Lin

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
data = dataset[0]

train_loader = NeighborSampler(data.edge_index, node_idx=data.train_mask,
                               sizes=[25, 10], batch_size=1024, shuffle=True,
                               num_workers=8)
subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1],
                                  batch_size=1024, shuffle=False,
                                  num_workers=8)

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 heads):
        super(GAT, self).__init__()

        self.num_layers = num_layers

        self.convs = torch.nn.ModuleList()
Example #21
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv  # noqa
from torch_geometric.datasets import Reddit
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc',
                    action='store_true',
                    help='Use GDC preprocessing.')
args = parser.parse_args()

dataset = 'Reddit'
path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset', dataset)
dataset = Reddit(path, transform=T.NormalizeFeatures())
data = dataset[0]

if args.use_gdc:
    gdc = T.GDC(self_loop_weight=1,
                normalization_in='sym',
                normalization_out='col',
                diffusion_kwargs=dict(method='ppr', alpha=0.05),
                sparsification_kwargs=dict(method='topk', k=128, dim=0),
                exact=True)
    data = gdc(data)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
Example #22
import os
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.nn import SAGEConv

# Some configuration
n_data_workers = 4
do_eval = True
data_dir = os.path.expandvars('$SCRATCH/pytorch-build/data/Reddit')

print('Preparing dataset')
dataset = Reddit(data_dir)
data = dataset[0]

cluster_data = ClusterData(data,
                           num_parts=1500,
                           recursive=False,
                           save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data,
                             batch_size=20,
                             shuffle=True,
                             num_workers=n_data_workers)

subgraph_loader = NeighborSampler(data.edge_index,
                                  sizes=[-1],
                                  batch_size=1024,
                                  shuffle=False,
Example #23
 def setup(self, stage: Optional[str] = None):
     self.data = Reddit(self.data_dir, pre_transform=self.transform)[0]
Example #24
 def prepare_data(self):
     Reddit(self.data_dir, pre_transform=self.transform)
Example #25
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.datasets import Reddit, Reddit2
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.nn import SAGEConv

# Download https://data.dgl.ai/dataset/reddit.zip into the data/Reddit folder first
dataset = Reddit('data/Reddit')
# dataset = Reddit2('data/Reddit2')
data = dataset[0]

cluster_data = ClusterData(data, num_parts=1500, recursive=False,
                           save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
                             num_workers=12)

subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024,
                                  shuffle=False, num_workers=12)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.convs = ModuleList(
            [SAGEConv(in_channels, 128),
             SAGEConv(128, out_channels)])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
Example #26
from torch_geometric.datasets import Reddit
from GPT_GNN.data import *

dataset = Reddit(root='/datadrive/dataset')
graph_reddit = Graph()
el = defaultdict(  # target_id
    lambda: defaultdict(  # source_id
        lambda: int  # time
    ))
for i, j in tqdm(dataset.data.edge_index.t()):
    el[i.item()][j.item()] = 1

target_type = 'def'
graph_reddit.edge_list['def']['def']['def'] = el
n = list(el.keys())
degree = np.zeros(np.max(n) + 1)
for i in n:
    degree[i] = len(el[i])
x = np.concatenate((dataset.data.x.numpy(), np.log(degree).reshape(-1, 1)),
                   axis=-1)
graph_reddit.node_feature['def'] = pd.DataFrame({'emb': list(x)})

idx = np.arange(len(graph_reddit.node_feature[target_type]))
np.random.seed(43)
np.random.shuffle(idx)

graph_reddit.pre_target_nodes = idx[:int(len(idx) * 0.7)]
graph_reddit.train_target_nodes = idx[int(len(idx) * 0.7):int(len(idx) * 0.8)]
graph_reddit.valid_target_nodes = idx[int(len(idx) * 0.8):int(len(idx) * 0.9)]
graph_reddit.test_target_nodes = idx[int(len(idx) * 0.9):]
Example #27
def get_reddit_dataset(name):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'datasets',
                    'node_data', name)
    dataset = Reddit(path)
    return dataset
Example #28
 def __init__(self, path):
     dataset = "Reddit"
     # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
     Reddit(path)
     super(RedditDataset, self).__init__(path)
Example #29
def main(args):
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    path = osp.join('dataset', 'Reddit')
    dataset = Reddit(path)
    data = dataset[0]

    features = data.x.to(device)
    labels = data.y.to(device)
    edge_index = data.edge_index.to(device)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])
    train_mask = torch.BoolTensor(data.train_mask).to(device)
    val_mask = torch.BoolTensor(data.val_mask).to(device)
    test_mask = torch.BoolTensor(data.test_mask).to(device)

    model = GraphSAGE(dataset.num_features, args.n_hidden, dataset.num_classes,
                      args.aggr, F.relu, args.dropout).to(device)

    loss_fcn = nn.CrossEntropyLoss()

    logger = Logger(args.runs, args)
    dur = []
    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
        for epoch in range(1, args.epochs + 1):
            model.train()
            if epoch >= 3:
                t0 = time.time()
            # forward
            logits = model(features, adj)
            loss = loss_fcn(logits[train_mask], labels[train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch >= 3:
                dur.append(time.time() - t0)
                print('Training time/epoch {}'.format(np.mean(dur)))

            if not args.eval:
                continue

            train_acc, val_acc, test_acc = evaluate(model, features, adj,
                                                    labels, train_mask,
                                                    val_mask, test_mask)
            logger.add_result(run, (train_acc, val_acc, test_acc))

            print(
                "Run {:02d} | Epoch {:05d} | Loss {:.4f} | Train {:.4f} | Val {:.4f} | Test {:.4f}"
                .format(run, epoch, loss.item(), train_acc, val_acc, test_acc))

        if args.eval:
            logger.print_statistics(run)

    if args.eval:
        logger.print_statistics()
Example #30
 def prepare_data(self):
     path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "..",
                     "data", self.NAME)
     self.dataset = Reddit(path)
     self.data = self.dataset[0]