Example #1
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.utils import degree
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import PNAConv, BatchNorm, global_add_pool

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC')
train_dataset = ZINC(path, subset=True, split='train')
val_dataset = ZINC(path, subset=True, split='val')
test_dataset = ZINC(path, subset=True, split='test')

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)

# Compute in-degree histogram over training data.
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
    deg += torch.bincount(d, minlength=deg.numel())
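
# Sketch (not part of the original file): how the degree histogram feeds
# PNAConv's degree-based scalers. Aggregator and scaler choices follow the
# PNA paper; the channel sizes here are illustrative assumptions.
example_conv = PNAConv(in_channels=75, out_channels=75,
                       aggregators=['mean', 'min', 'max', 'std'],
                       scalers=['identity', 'amplification', 'attenuation'],
                       deg=deg)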


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.node_emb = Embedding(21, 75)
Example #2
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SplineConv, voxel_grid, max_pool, max_pool_x

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
transform = T.Cartesian(cat=False)
train_dataset = MNISTSuperpixels(path, True, transform=transform)
test_dataset = MNISTSuperpixels(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
d = train_dataset


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = SplineConv(d.num_features, 32, dim=2, kernel_size=5)
        self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5)
        self.conv3 = SplineConv(64, 64, dim=2, kernel_size=5)
        self.fc1 = torch.nn.Linear(4 * 64, 128)
        self.fc2 = torch.nn.Linear(128, d.num_classes)

    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        cluster = voxel_grid(data.pos, data.batch, size=5, start=0, end=28)
        data.edge_attr = None
Example #3
def test(loader):
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)
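
# train(epoch) is called by the driver below but is not shown in this
# snippet; a minimal sketch consistent with PyG's point-cloud classification
# examples (assumes the global model, optimizer, train_loader, device and
# torch.nn.functional as F defined elsewhere in the original file, and a
# model that returns log-probabilities):
def train(epoch):
    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), data.y)
        loss.backward()
        optimizer.step()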


if __name__ == '__main__':
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..',
                    'data/ModelNet10')
    pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)
    train_dataset = ModelNet(path, '10', True, transform, pre_transform)
    test_dataset = ModelNet(path, '10', False, transform, pre_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=6)
    test_loader = DataLoader(test_dataset,
                             batch_size=32,
                             shuffle=False,
                             num_workers=6)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, 201):
        train(epoch)
        test_acc = test(test_loader)
        print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}')
Example #4
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')
dataset = TUDataset(path, name='PROTEINS')
dataset = dataset.shuffle()
n = len(dataset) // 10
test_dataset = dataset[:n]
train_dataset = dataset[n:]
test_loader = DataLoader(test_dataset, batch_size=60)
train_loader = DataLoader(train_dataset, batch_size=60)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = GraphConv(dataset.num_features, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = GraphConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = GraphConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)

        self.lin1 = torch.nn.Linear(256, 128)
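
        # (Truncated here. The 256 input features come from concatenating
        # global max and mean pooling of the 128-dim node features, i.e.
        # torch.cat([gmp(x, batch), gap(x, batch)], dim=1), in forward().)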
Example #5
import time

import torch
from torch import tensor
from torch.optim import Adam
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DenseDataLoader as DenseLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def cross_validation_with_val_set(dataset,
                                  model,
                                  folds,
                                  epochs,
                                  batch_size,
                                  lr,
                                  lr_decay_factor,
                                  lr_decay_step_size,
                                  weight_decay,
                                  logger=None):

    val_losses, accs, durations = [], [], []
    for fold, (train_idx, test_idx,
               val_idx) in enumerate(zip(*k_fold(dataset, folds))):

        train_dataset = dataset[train_idx]
        test_dataset = dataset[test_idx]
        val_dataset = dataset[val_idx]

        if 'adj' in train_dataset[0]:
            train_loader = DenseLoader(train_dataset, batch_size, shuffle=True)
            val_loader = DenseLoader(val_dataset, batch_size, shuffle=False)
            test_loader = DenseLoader(test_dataset, batch_size, shuffle=False)
        else:
            train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size, shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

        model.to(device).reset_parameters()
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.perf_counter()

        for epoch in range(1, epochs + 1):
            train_loss = train(model, optimizer, train_loader)
            val_losses.append(eval_loss(model, val_loader))
            accs.append(eval_acc(model, test_loader))
            eval_info = {
                'fold': fold,
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_losses[-1],
                'test_acc': accs[-1],
            }

            if logger is not None:
                logger(eval_info)

            if epoch % lr_decay_step_size == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_decay_factor * param_group['lr']

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_end = time.perf_counter()
        durations.append(t_end - t_start)

    loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)
    loss, acc = loss.view(folds, epochs), acc.view(folds, epochs)
    loss, argmin = loss.min(dim=1)
    acc = acc[torch.arange(folds, dtype=torch.long), argmin]

    loss_mean = loss.mean().item()
    acc_mean = acc.mean().item()
    acc_std = acc.std().item()
    duration_mean = duration.mean().item()
    print(f'Val Loss: {loss_mean:.4f}, Test Accuracy: {acc_mean:.3f} '
          f'± {acc_std:.3f}, Duration: {duration_mean:.3f}')

    return loss_mean, acc_mean, acc_std
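

# k_fold(dataset, folds) is used above but not shown. Below is a sketch of a
# compatible implementation (stratified splits; each fold's validation set is
# the previous fold's test set), assuming scikit-learn is available. The
# train/eval_loss/eval_acc helpers are likewise assumed to be defined
# alongside this function.
def k_fold(dataset, folds):
    from sklearn.model_selection import StratifiedKFold

    skf = StratifiedKFold(folds, shuffle=True, random_state=12345)

    test_indices, train_indices = [], []
    for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y):
        test_indices.append(torch.from_numpy(idx).to(torch.long))

    val_indices = [test_indices[i - 1] for i in range(folds)]

    for i in range(folds):
        train_mask = torch.ones(len(dataset), dtype=torch.bool)
        train_mask[test_indices[i]] = 0
        train_mask[val_indices[i]] = 0
        train_indices.append(train_mask.nonzero(as_tuple=False).view(-1))

    return train_indices, test_indices, val_indices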
Example #6
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64, shuffle=True)
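    # (Presumably a dataloader hook on a PyTorch Lightning module; shown out
    # of context.)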
Example #7
import os.path as osp

import torch
import torch_geometric.transforms as T
from torch.nn import Linear as Lin, ReLU, Sequential as Seq
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GINConv, SAGPooling


class HandleNodeAttention(object):
    def __call__(self, data):
        data.attn = torch.softmax(data.x, dim=0).flatten()
        data.x = None
        return data


transform = T.Compose([HandleNodeAttention(), T.OneHotDegree(max_degree=14)])
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TRIANGLES')
dataset = TUDataset(path,
                    name='TRIANGLES',
                    use_node_attr=True,
                    transform=transform)

train_loader = DataLoader(dataset[:30000], batch_size=60, shuffle=True)
val_loader = DataLoader(dataset[30000:35000], batch_size=60)
test_loader = DataLoader(dataset[35000:], batch_size=60)


class Net(torch.nn.Module):
    def __init__(self, in_channels):
        super(Net, self).__init__()

        self.conv1 = GINConv(Seq(Lin(in_channels, 64), ReLU(), Lin(64, 64)))
        self.pool1 = SAGPooling(64, min_score=0.001, GNN=GCNConv)
        self.conv2 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64)))
        self.pool2 = SAGPooling(64, min_score=0.001, GNN=GCNConv)
        self.conv3 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64)))

        self.lin = torch.nn.Linear(64, 1)
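
        # A single output unit: the TRIANGLES task regresses the number of
        # triangles in each graph.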
Example #8
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
init_wandb(name=f'GIN-{args.dataset}',
           batch_size=args.batch_size,
           lr=args.lr,
           epochs=args.epochs,
           hidden_channels=args.hidden_channels,
           num_layers=args.num_layers,
           device=device)

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU')
dataset = TUDataset(path, name=args.dataset).shuffle()

train_dataset = dataset[len(dataset) // 10:]
train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)

test_dataset = dataset[:len(dataset) // 10]
test_loader = DataLoader(test_dataset, args.batch_size)


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            mlp = MLP([in_channels, hidden_channels, hidden_channels])
            self.convs.append(GINConv(nn=mlp, train_eps=False))
            in_channels = hidden_channels
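
        # (Truncated here. A readout head typically follows, e.g. a sketch:
        # self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
        #                norm=None, dropout=0.5),
        # applied after global pooling in forward().)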
Example #9
import os.path as osp

import torch
from torch_geometric.datasets import GDELT, ICEWS18
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.re_net import RENet

seq_len = 10

# Load the dataset and precompute history objects.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ICEWS18')
train_dataset = ICEWS18(path, pre_transform=RENet.pre_transform(seq_len))
test_dataset = ICEWS18(path, split='test')

# path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'GDELT')
# train_dataset = GDELT(path, pre_transform=RENet.pre_transform(seq_len))
# test_dataset = ICEWS18(path, split='test')

# Create dataloader for training and test dataset.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True,
                          follow_batch=['h_sub', 'h_obj'], num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False,
                         follow_batch=['h_sub', 'h_obj'], num_workers=6)

# Initialize model and optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RENet(
    train_dataset.num_nodes,
    train_dataset.num_rels,
    hidden_channels=200,
    seq_len=seq_len,
    dropout=0.5,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                             weight_decay=0.00001)
Example #10
import os.path as osp

import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SchNet

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')
dataset = QM9(path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for target in range(12):
    model, datasets = SchNet.from_qm9_pretrained(path, dataset, target)
    train_dataset, val_dataset, test_dataset = datasets

    model = model.to(device)
    loader = DataLoader(test_dataset, batch_size=256)

    maes = []
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data.z, data.pos, data.batch)
        mae = (pred.view(-1) - data.y[:, target]).abs()
        maes.append(mae)

    mae = torch.cat(maes, dim=0)

    # Report meV instead of eV.
    mae = 1000 * mae if target in [2, 3, 4, 6, 7, 8, 9, 10] else mae

    print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}')
Example #11
def test_to_superpixels():
    import torchvision.transforms as T
    from torchvision.datasets.mnist import (
        MNIST,
        read_image_file,
        read_label_file,
    )
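
    # Not shown in this snippet: standard imports (os.path as osp, random,
    # sys, shutil, torch, and DataLoader from torch_geometric.loader) and
    # module-level helpers: `resources` (MNIST download URLs), `makedirs`,
    # `download_url` and `extract_gz` from torch_geometric.data, and `ToSLIC`
    # from torch_geometric.transforms.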

    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))

    raw_folder = osp.join(root, 'MNIST', 'raw')
    processed_folder = osp.join(root, 'MNIST', 'processed')

    makedirs(raw_folder)
    makedirs(processed_folder)
    for resource in resources:
        path = download_url(resource, raw_folder)
        extract_gz(path, osp.join(root, raw_folder))

    test_set = (
        read_image_file(osp.join(raw_folder, 't10k-images-idx3-ubyte')),
        read_label_file(osp.join(raw_folder, 't10k-labels-idx1-ubyte')),
    )

    torch.save(test_set, osp.join(processed_folder, 'training.pt'))
    torch.save(test_set, osp.join(processed_folder, 'test.pt'))

    dataset = MNIST(root, download=False)

    dataset.transform = T.Compose([T.ToTensor(), ToSLIC()])

    data, y = dataset[0]
    assert len(data) == 2
    assert data.pos.dim() == 2 and data.pos.size(1) == 2
    assert data.x.dim() == 2 and data.x.size(1) == 1
    assert data.pos.size(0) == data.x.size(0)
    assert y == 7

    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    for data, y in loader:
        assert len(data) == 4
        assert data.pos.dim() == 2 and data.pos.size(1) == 2
        assert data.x.dim() == 2 and data.x.size(1) == 1
        assert data.batch.dim() == 1
        assert data.ptr.dim() == 1
        assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
        assert y.tolist() == [7, 2]
        break

    dataset.transform = T.Compose(
        [T.ToTensor(), ToSLIC(add_seg=True, add_img=True)])

    data, y = dataset[0]
    assert len(data) == 4
    assert data.pos.dim() == 2 and data.pos.size(1) == 2
    assert data.x.dim() == 2 and data.x.size(1) == 1
    assert data.pos.size(0) == data.x.size(0)
    assert data.seg.size() == (1, 28, 28)
    assert data.img.size() == (1, 1, 28, 28)
    assert data.seg.max().item() + 1 == data.x.size(0)
    assert y == 7

    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    for data, y in loader:
        assert len(data) == 6
        assert data.pos.dim() == 2 and data.pos.size(1) == 2
        assert data.x.dim() == 2 and data.x.size(1) == 1
        assert data.batch.dim() == 1
        assert data.ptr.dim() == 1
        assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
        assert data.seg.size() == (2, 28, 28)
        assert data.img.size() == (2, 1, 28, 28)
        assert y.tolist() == [7, 2]
        break

    shutil.rmtree(root)
Example #12
# (Imports reconstructed for context; `OGBG` is presumably
# ogb.graphproppred.PygGraphPropPredDataset.)
import argparse
import os.path as osp

import torch
import torch_geometric.transforms as T
from ogb.graphproppred import Evaluator
from ogb.graphproppred import PygGraphPropPredDataset as OGBG
from ogb.graphproppred.mol_encoder import AtomEncoder
from torch_geometric.loader import DataLoader
from torch_geometric.nn import EGConv

parser = argparse.ArgumentParser()
parser.add_argument('--use_multi_aggregators', action='store_true',
                    help='Switch between EGC-S and EGC-M')
args = parser.parse_args()

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')
dataset = OGBG('ogbg-molhiv', path, pre_transform=T.ToSparseTensor())
evaluator = Evaluator('ogbg-molhiv')

split_idx = dataset.get_idx_split()
train_dataset = dataset[split_idx['train']]
val_dataset = dataset[split_idx['valid']]
test_dataset = dataset[split_idx['test']]

train_loader = DataLoader(train_dataset, batch_size=32, num_workers=4,
                          shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256)
test_loader = DataLoader(test_dataset, batch_size=256)


class Net(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, num_heads, num_bases):
        super().__init__()
        if args.use_multi_aggregators:
            aggregators = ['sum', 'mean', 'max']
        else:
            aggregators = ['symnorm']

        self.encoder = AtomEncoder(hidden_channels)

        self.convs = torch.nn.ModuleList()
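        # (Truncated here. Each entry is presumably an EGConv, e.g. a sketch:
        # EGConv(hidden_channels, hidden_channels, aggregators=aggregators,
        #        num_heads=num_heads, num_bases=num_bases).)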
Example #13
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=128)
Example #14
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=128)
Example #15
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader, LinkNeighborLoader
from torch_geometric.nn import GraphSAGE

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')
train_dataset = PPI(path, split='train')
val_dataset = PPI(path, split='val')
test_dataset = PPI(path, split='test')

# Group all training graphs into a single graph to perform sampling:
train_data = Batch.from_data_list(train_dataset)
loader = LinkNeighborLoader(train_data,
                            batch_size=2048,
                            shuffle=True,
                            neg_sampling_ratio=0.5,
                            num_neighbors=[10, 10],
                            num_workers=6,
                            persistent_workers=True)

# Evaluation loaders (one datapoint corresponds to a graph)
train_loader = DataLoader(train_dataset, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(
    in_channels=train_dataset.num_features,
    hidden_channels=64,
    num_layers=2,
    out_channels=64,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)


def train():
    model.train()
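
    # The rest of train() is truncated. A sketch of a typical loop over the
    # LinkNeighborLoader batches (negative sampling yields binary edge_label
    # targets; edge_label_index holds the candidate links):
    total_loss = total_examples = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        h = model(batch.x, batch.edge_index)
        # Score each candidate link via a dot product of its endpoint
        # embeddings:
        pred = (h[batch.edge_label_index[0]] *
                h[batch.edge_label_index[1]]).sum(dim=-1)
        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    return total_loss / total_examples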
Example #16
File: main_pyg.py  Project: rpatil524/ogb
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help=
        'GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=300,
        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset',
                        type=str,
                        default="ogbg-molhiv",
                        help='dataset name (default: ogbg-molhiv)')

    parser.add_argument('--feature',
                        type=str,
                        default="full",
                        help='full feature or simple feature')
    parser.add_argument('--filename',
                        type=str,
                        default="",
                        help='filename to output result (default: )')
    args = parser.parse_args()

    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = PygGraphPropPredDataset(name=args.dataset)

    if args.feature == 'full':
        pass
    elif args.feature == 'simple':
        print('using simple feature')
        # only retain the top two node/edge features
        dataset.data.x = dataset.data.x[:, :2]
        dataset.data.edge_attr = dataset.data.edge_attr[:, :2]

    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args.dataset)

    train_loader = DataLoader(dataset[split_idx["train"]],
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]],
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin',
                    num_tasks=dataset.num_tasks,
                    num_layer=args.num_layer,
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=False).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin',
                    num_tasks=dataset.num_tasks,
                    num_layer=args.num_layer,
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=True).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn',
                    num_tasks=dataset.num_tasks,
                    num_layer=args.num_layer,
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=False).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn',
                    num_tasks=dataset.num_tasks,
                    num_layer=args.num_layer,
                    emb_dim=args.emb_dim,
                    drop_ratio=args.drop_ratio,
                    virtual_node=True).to(device)
    else:
        raise ValueError('Invalid GNN type')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    valid_curve = []
    test_curve = []
    train_curve = []

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer, dataset.task_type)

        print('Evaluating...')
        train_perf = eval(model, device, train_loader, evaluator)
        valid_perf = eval(model, device, valid_loader, evaluator)
        test_perf = eval(model, device, test_loader, evaluator)

        print({
            'Train': train_perf,
            'Validation': valid_perf,
            'Test': test_perf
        })

        train_curve.append(train_perf[dataset.eval_metric])
        valid_curve.append(valid_perf[dataset.eval_metric])
        test_curve.append(test_perf[dataset.eval_metric])

    if 'classification' in dataset.task_type:
        best_val_epoch = np.argmax(np.array(valid_curve))
        best_train = max(train_curve)
    else:
        best_val_epoch = np.argmin(np.array(valid_curve))
        best_train = min(train_curve)

    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))

    if not args.filename == '':
        torch.save(
            {
                'Val': valid_curve[best_val_epoch],
                'Test': test_curve[best_val_epoch],
                'Train': train_curve[best_val_epoch],
                'BestTrain': best_train
            }, args.filename)
Example #17
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv
from sklearn.metrics import f1_score

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')
train_dataset = PPI(path, split='train')
val_dataset = PPI(path, split='val')
test_dataset = PPI(path, split='test')
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GATConv(train_dataset.num_features, 256, heads=4)
        self.lin1 = torch.nn.Linear(train_dataset.num_features, 4 * 256)
        self.conv2 = GATConv(4 * 256, 256, heads=4)
        self.lin2 = torch.nn.Linear(4 * 256, 4 * 256)
        self.conv3 = GATConv(4 * 256,
                             train_dataset.num_classes,
                             heads=6,
                             concat=False)
        self.lin3 = torch.nn.Linear(4 * 256, train_dataset.num_classes)
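
    # forward() is truncated here; in the canonical PyG PPI example each
    # GATConv is paired with a linear skip connection, roughly:
    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index) + self.lin1(x))
        x = F.elu(self.conv2(x, edge_index) + self.lin2(x))
        return self.conv3(x, edge_index) + self.lin3(x)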
Example #18
File: main_gnn.py  Project: rpatil524/ogb
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with Pytorch Geometrics')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help=
        'GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    random.seed(42)

    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = PygPCQM4Mv2Dataset(root='dataset/')

    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4Mv2Evaluator()

    if args.train_subset:
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_loader = DataLoader(dataset[split_idx["train"][subset_idx]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
    else:
        train_loader = DataLoader(dataset[split_idx["train"]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)

    if args.save_test_dir != '':
        testdev_loader = DataLoader(dataset[split_idx["test-dev"]],
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers)
        testchallenge_loader = DataLoader(dataset[split_idx["test-challenge"]],
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.num_workers)

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True,
                    **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, device, train_loader, optimizer)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           os.path.join(args.checkpoint_dir, 'checkpoint.pt'))

            if args.save_test_dir != '':
                testdev_pred = test(model, device, testdev_loader)
                testdev_pred = testdev_pred.cpu().detach().numpy()

                testchallenge_pred = test(model, device, testchallenge_loader)
                testchallenge_pred = testchallenge_pred.cpu().detach().numpy()

                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': testdev_pred},
                                               args.save_test_dir,
                                               mode='test-dev')
                evaluator.save_test_submission({'y_pred': testchallenge_pred},
                                               args.save_test_dir,
                                               mode='test-challenge')

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir != '':
        writer.close()