Ejemplo n.º 1
0
def load_pyg(name, dataset_dir):
    '''
    Load a pyg format dataset and convert it to graphs.

    :param name: dataset name
    :param dataset_dir: data directory; the dataset root passed to the
        loaders is ``<dataset_dir>/<name>``
    :return: the result of ``GraphDataset.pyg_to_graphs`` on the loaded
        dataset (a list of networkx/deepsnap graphs)
    :raises ValueError: if ``name`` matches none of the handled datasets
    '''
    root = '{}/{}'.format(dataset_dir, name)

    if name in ('Cora', 'CiteSeer', 'PubMed'):
        pyg_data = Planetoid(root, name)
    elif name[:3] == 'TU_':
        tu_name = name[3:]
        if tu_name == 'IMDB':
            # TU_IMDB doesn't have node features, so a constant one is added.
            pyg_data = TUDataset(root, 'IMDB-MULTI', transform=T.Constant())
        else:
            pyg_data = TUDataset(root, tu_name)
        # TU datasets only have graph-level labels. For the synthetic
        # (non graph-level) tasks keep the 100 smallest graphs by edge count.
        # Graphs with fewer than 200 edges get the sentinel size 9999 so that
        # they sort behind graphs whose real edge count is below 9999.
        if cfg.dataset.tu_simple and cfg.dataset.task != 'graph':
            edge_counts = (data.edge_index.shape[1] for data in pyg_data)
            sort_keys = torch.tensor(
                [count if count >= 200 else 9999 for count in edge_counts])
            pyg_data = pyg_data[torch.argsort(sort_keys)[:100]]
    elif name == 'Karate':
        pyg_data = KarateClub()
    elif 'Coauthor' in name:
        # Any Coauthor name that does not mention 'CS' loads 'Physics'.
        pyg_data = Coauthor(root, name='CS' if 'CS' in name else 'Physics')
    elif 'Amazon' in name:
        # Any Amazon name that does not mention 'Computers' loads 'Photo'.
        pyg_data = Amazon(
            root, name='Computers' if 'Computers' in name else 'Photo')
    elif name == 'MNIST':
        pyg_data = MNISTSuperpixels(root)
    elif name == 'PPI':
        pyg_data = PPI(root)
    elif name == 'QM7b':
        pyg_data = QM7b(root)
    else:
        raise ValueError('{} not support'.format(name))

    return GraphDataset.pyg_to_graphs(pyg_data)
Ejemplo n.º 2
0
def load_pyg(name, dataset_dir):
    """
    Build the PyG dataset object selected by ``name``.

    Args:
        name (string): dataset name
        dataset_dir (string): data directory; the dataset root handed to the
            loaders is ``<dataset_dir>/<name>``

    Returns: PyG dataset object

    Raises:
        ValueError: if ``name`` matches none of the handled datasets

    """
    root = '{}/{}'.format(dataset_dir, name)

    if name in ('Cora', 'CiteSeer', 'PubMed'):
        return Planetoid(root, name)

    if name[:3] == 'TU_':
        tu_name = name[3:]
        if tu_name == 'IMDB':
            # TU_IMDB doesn't have node features, so a constant one is added.
            return TUDataset(root, 'IMDB-MULTI', transform=T.Constant())
        return TUDataset(root, tu_name)

    if name == 'Karate':
        return KarateClub()

    if 'Coauthor' in name:
        # Any Coauthor name that does not mention 'CS' loads 'Physics'.
        return Coauthor(root, name='CS' if 'CS' in name else 'Physics')

    if 'Amazon' in name:
        # Any Amazon name that does not mention 'Computers' loads 'Photo'.
        return Amazon(
            root, name='Computers' if 'Computers' in name else 'Photo')

    if name == 'MNIST':
        return MNISTSuperpixels(root)

    if name == 'PPI':
        return PPI(root)

    if name == 'QM7b':
        return QM7b(root)

    raise ValueError('{} not support'.format(name))
Ejemplo n.º 3
0
def load_pyg(name, dataset_dir):
    '''
    Load a pyg format dataset.

    :param name: dataset name
    :param dataset_dir: data directory; the dataset root handed to the
        loaders is ``<dataset_dir>/<name>``
    :return: the dataset object built by the matching loader; it is returned
        as-is and is NOT converted to networkx/deepsnap graphs
    :raises ValueError: if ``name`` matches none of the handled datasets
    '''
    dataset_dir = '{}/{}'.format(dataset_dir, name)
    if name in ('Cora', 'CiteSeer', 'PubMed'):
        dataset = Planetoid(dataset_dir, name)
    elif name.startswith('TU_'):
        if name == 'TU_IMDB':
            # TU_IMDB doesn't have node features, so a constant one is added.
            dataset = TUDataset(dataset_dir, 'IMDB-MULTI',
                                transform=T.Constant())
        else:
            # Strip the 'TU_' prefix to get the actual TU dataset name.
            dataset = TUDataset(dataset_dir, name[3:])
    elif name == 'Karate':
        dataset = KarateClub()
    elif 'Coauthor' in name:
        # Any Coauthor name that does not mention 'CS' loads 'Physics'.
        dataset = Coauthor(dataset_dir,
                           name='CS' if 'CS' in name else 'Physics')
    elif 'Amazon' in name:
        # Any Amazon name that does not mention 'Computers' loads 'Photo'.
        dataset = Amazon(dataset_dir,
                         name='Computers' if 'Computers' in name else 'Photo')
    elif name == 'MNIST':
        dataset = MNISTSuperpixels(dataset_dir)
    elif name == 'PPI':
        dataset = PPI(dataset_dir)
    elif name == 'QM7b':
        dataset = QM7b(dataset_dir)
    else:
        raise ValueError('{} not support'.format(name))

    return dataset
Ejemplo n.º 4
0
# Remaining command-line options; `parser` (and the options read below, such
# as --dataset and --fold) are set up outside this excerpt.
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--gpu', type=int)
args = parser.parse_args()

# Select the CUDA device only when --gpu was given explicitly.
if args.gpu is not None:
    torch.cuda.set_device(args.gpu)

# Per-dataset value handed to T.OneHotDegree below.
# NOTE(review): presumably the maximum node degree found in each dataset --
# confirm against the data.
max_degrees = {
    'IMDB-BINARY': 135,
    'IMDB-MULTI': 88,
    'COLLAB': 491,
}

# Build the transform pipeline:
#   * REDDIT datasets and the datasets listed in max_degrees get T.Constant(1)
#   * the datasets listed in max_degrees additionally get T.OneHotDegree
# For any other dataset the list stays empty.
transforms = []
if 'REDDIT' in args.dataset or args.dataset in max_degrees:
    transforms.append(T.Constant(1))
if args.dataset in max_degrees:
    transforms.append(T.OneHotDegree(max_degrees[args.dataset]))
print('transforms:', transforms)

# Data lives in ../data/<dataset> relative to the directory of this file.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                args.dataset)
dataset = TUDataset(path, name=args.dataset, transform=T.Compose(transforms))

# Use a different seed for each fold so that one particularly good or bad
# initialisation does not affect the results for the whole seed.
# The fold index is multiplied by 10 so that nets in different seeds are
# initialised with different seeds; (seed, fold) pairs map to distinct values
# as long as 0 <= args.seed < 10.
seed = args.seed + 10 * args.fold
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Shuffle only after seeding; presumably so the order is reproducible for a
# given (seed, fold) -- this relies on shuffle() using torch's seeded RNG.
dataset = dataset.shuffle()
Ejemplo n.º 5
0
        if self.transform is not None:
            data_s = self.transform(data_s)
            data_t = self.transform(data_t)

        data = Data(num_nodes=pos_s.size(0))
        for key in data_s.keys:
            data['{}_s'.format(key)] = data_s[key]
        for key in data_t.keys:
            data['{}_t'.format(key)] = data_t[key]

        return data


# Transform pipeline shared by the training and test datasets; the three
# transforms are applied in the order listed.
# NOTE(review): `T` is presumably torch_geometric.transforms -- confirm against
# the imports, which are outside this excerpt.
transform = T.Compose([
    T.Constant(),
    T.KNNGraph(k=8),
    T.Cartesian(),
])
# Training data comes from RandomGraphDataset; the meaning of the four
# positional arguments is defined by that class, outside this excerpt.
train_dataset = RandomGraphDataset(30, 60, 0, 20, transform=transform)
# follow_batch asks the loader to track batch assignment for the 'x_s' and
# 'x_t' attributes separately.
train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True,
                          follow_batch=['x_s', 'x_t'])

# One PascalPF test dataset per category, all using the same transform.
path = osp.join('..', 'data', 'PascalPF')
test_datasets = [PascalPF(path, cat, transform) for cat in PascalPF.categories]

# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Two SplineCNN networks passed to the matching model: psi_1 is built with
# cat=False, psi_2 with cat=True and args.rnd_dim as its first two arguments.
# NOTE(review): the positional arguments look like
# (in_channels, out_channels, dim, num_layers) -- confirm against SplineCNN.
psi_1 = SplineCNN(1, args.dim, 2, args.num_layers, cat=False, dropout=0.0)
psi_2 = SplineCNN(args.rnd_dim, args.rnd_dim, 2, args.num_layers, cat=True,
                  dropout=0.0)
model = DGMC_modified_v2(psi_1, psi_2, num_steps=args.num_steps).to(device)
Ejemplo n.º 6
0
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import FAUST
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SplineConv

# Data lives in ../data/FAUST relative to the directory of this file.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FAUST')
# pre_transform: FaceToEdge followed by Constant(value=1), composed in that
# order. T.Cartesian() is passed separately as the third argument of FAUST.
pre_transform = T.Compose([T.FaceToEdge(), T.Constant(value=1)])
# NOTE(review): the second positional argument presumably selects the split
# (True for training, False for testing), as the variable names suggest --
# confirm against the FAUST dataset signature.
train_dataset = FAUST(path, True, T.Cartesian(), pre_transform)
test_dataset = FAUST(path, False, T.Cartesian(), pre_transform)
# One sample per batch; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)
# First training sample; Net reads d.num_nodes to size its last linear layer.
d = train_dataset[0]


class Net(torch.nn.Module):
    def __init__(self):
        """Create the six SplineConv layers and the two linear layers.

        The final linear layer has ``d.num_nodes`` outputs, where ``d`` is the
        module-level sample taken from the training dataset.
        """
        super().__init__()

        def spline(in_channels, out_channels):
            # Every convolution uses dim=3, kernel_size=5 and 'add'
            # aggregation; only the channel widths differ.
            return SplineConv(in_channels, out_channels, dim=3, kernel_size=5,
                              aggr='add')

        # Layers are created in the same order as before (conv1..conv6, then
        # lin1, lin2) so that parameter initialisation order is unchanged.
        self.conv1 = spline(1, 32)
        self.conv2 = spline(32, 64)
        self.conv3 = spline(64, 64)
        self.conv4 = spline(64, 64)
        self.conv5 = spline(64, 64)
        self.conv6 = spline(64, 64)

        linear = torch.nn.Linear
        self.lin1 = linear(64, 256)
        self.lin2 = linear(256, d.num_nodes)

    def forward(self, data):