Code Example #1
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader


def setup_data_loaders(data_folder,
                       batch_size=32,
                       run_test=False,
                       num_workers=3):
    """
    Sets up the dataloaders.
    """
    train_dataset = ZINC(data_folder, subset=True, split='train')
    val_dataset = ZINC(data_folder, subset=True, split='val')
    if run_test:
        test_dataset = ZINC(data_folder, subset=True, split='test')

    train_loader = DataLoader(train_dataset,
                              batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset,
                            batch_size,
                            shuffle=False,
                            num_workers=num_workers)
    if run_test:
        test_loader = DataLoader(test_dataset,
                                 batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)
    else:
        test_loader = None
    return train_loader, val_loader, test_loader
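
A minimal usage sketch for the helper above (the 'data/ZINC' path and batch size are illustrative, not taken from any of the projects on this page):

train_loader, val_loader, test_loader = setup_data_loaders('data/ZINC',
                                                           batch_size=64,
                                                           run_test=True)
batch = next(iter(train_loader))
print(batch.num_graphs)  # number of molecular graphs collated into this batch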
Code Example #2
File: zinc.py Project: tchaton/lightning-geometric
    def prepare_data(self):
        path = osp.join(
            osp.dirname(osp.realpath(__file__)), "..", "..", "data", self.NAME
        )

        self.train_dataset = ZINC(path, subset=True, split="train")
        self.val_dataset = ZINC(path, subset=True, split="val")
        self.test_dataset = ZINC(path, subset=True, split="test")
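
prepare_data only materializes the three splits, so a LightningDataModule written this way would typically expose them through loader hooks. A minimal sketch, assuming it lives in the same class (the batch size is illustrative):

    def train_dataloader(self):
        # prepare_data() above has already populated self.train_dataset
        return DataLoader(self.train_dataset, batch_size=128, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=128)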
Code Example #3
File: train_zinc.py Project: toenshoff/CRaWl
def load_split_data(config):
    train_data = ZINC(DATA_PATH,
                      subset=True,
                      split='train',
                      transform=feat_transform,
                      pre_transform=preproc)
    val_data = ZINC(DATA_PATH,
                    subset=True,
                    split='val',
                    transform=feat_transform,
                    pre_transform=preproc)

    train_iter = CRaWlLoader(train_data,
                             shuffle=True,
                             batch_size=config['batch_size'],
                             num_workers=4)
    val_iter = CRaWlLoader(val_data, batch_size=100, num_workers=4)
    return train_iter, val_iter
Code Example #4
File: configs.py Project: jingmouren/egc
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader


def zinc_data(root, batch_size=128):
    data = dict()
    for split in ["train", "val", "test"]:
        data[split] = DataLoader(
            ZINC(root, subset=True, split=split),
            batch_size=batch_size,
            shuffle=(split == "train"),
        )

    return data
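
A usage sketch for the factory above (root path assumed):

loaders = zinc_data('data/ZINC', batch_size=64)
num_train_batches = sum(1 for _ in loaders['train'])  # only the 'train' split is shuffled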
Code Example #5
from os.path import join

import torch
import torch_geometric.transforms as T
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import (PPI, QM9, ZINC, Amazon, Flickr,
                                      GitHub, Planetoid, Reddit)


def get_dataset(dataset_name):
    """
    Retrieves the dataset corresponding to the given name.
    """
    path = join('dataset', dataset_name)
    if dataset_name == 'reddit':
        dataset = Reddit(path)
    elif dataset_name == 'flickr':
        dataset = Flickr(path)
    elif dataset_name == 'zinc':
        dataset = ZINC(root='dataset', subset=True, split='train')
    elif dataset_name == 'QM9':
        dataset = QM9(root='dataset')
    elif dataset_name == 'github':
        dataset = GitHub(path)
    elif dataset_name == 'ppi':
        dataset = PPI(path)
    elif dataset_name in ['amazon_comp', 'amazon_photo']:
        if dataset_name == 'amazon_comp':
            dataset = Amazon(path, "Computers", T.NormalizeFeatures())
        else:
            dataset = Amazon(path, "Photo", T.NormalizeFeatures())
        data = dataset.data
        idx_train, idx_test = train_test_split(list(range(data.x.shape[0])),
                                               test_size=0.4,
                                               random_state=42)
        idx_val, idx_test = train_test_split(idx_test,
                                             test_size=0.5,
                                             random_state=42)
        data.train_mask = torch.tensor(idx_train)
        data.val_mask = torch.tensor(idx_val)
        data.test_mask = torch.tensor(idx_test)
        dataset.data = data
    elif dataset_name in ["Cora", "CiteSeer", "PubMed"]:
        dataset = Planetoid(path,
                            name=dataset_name,
                            split="public",
                            transform=T.NormalizeFeatures())
    else:
        raise NotImplementedError

    return dataset
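
A usage sketch; note that, unlike the node-classification branches, the 'zinc' branch returns only the training split:

dataset = get_dataset('zinc')
print(len(dataset))  # size of the ZINC training subset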
Code Example #6
import argparse
import os.path as osp

import torch

from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn.inits import zeros


argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--epochs', type=int, default=300)
argparser.add_argument('--hidden', type=int, default=100)
argparser.add_argument('--emb', type=int, default=100)
argparser.add_argument('--layers', type=int, default=4)
argparser.add_argument('--lr', type=float, default=0.001)
argparser.add_argument('--dropout', type=float, default=0.0)
argparser.add_argument('--rank', type=int, default=100)
argparser.add_argument('--batch', type=int, default=1000)
args = argparser.parse_args()
#args.hidden,args.rank = args.emb,args.emb

train_dataset = ZINC(osp.join('torch_geometric_data','zinc'), subset=True, split='train')
val_dataset = ZINC(osp.join('torch_geometric_data','zinc'), subset=True, split='val')
test_dataset = ZINC(osp.join('torch_geometric_data','zinc'), subset=True, split='test')

train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch)
test_loader = DataLoader(test_dataset, batch_size=args.batch)


class graph_cp_pooling(torch.nn.Module):
    def __init__(self, out_feats):
        super(graph_cp_pooling, self).__init__()
        self.w = torch.nn.Linear(out_feats+1, out_feats)
        self.reset_parameters()
        
    def reset_parameters(self):
Code Example #7
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
from torch_geometric.utils import degree

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC')
train_dataset = ZINC(path, subset=True, split='train')
val_dataset = ZINC(path, subset=True, split='val')
test_dataset = ZINC(path, subset=True, split='test')

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)

# Compute in-degree histogram over training data.
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
    deg += torch.bincount(d, minlength=deg.numel())
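
# The histogram computed above is later passed to PNAConv via its `deg`
# argument; degree-dependent scalers such as 'amplification' and
# 'attenuation' use it to normalize the aggregated messages.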


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
Code Example #8
File: train_zinc_subset.py Project: rusty1s/himp-gnn
import argparse

import torch

from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader

from model import Net
from transform import JunctionTree  # project-local module (assumed path)

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--hidden_channels', type=int, default=128)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--no_inter_message_passing', action='store_true')
args = parser.parse_args()
print(args)

root = 'data/ZINC'
transform = JunctionTree()

train_dataset = ZINC(root, subset=True, split='train', pre_transform=transform)
val_dataset = ZINC(root, subset=True, split='val', pre_transform=transform)
test_dataset = ZINC(root, subset=True, split='test', pre_transform=transform)

train_loader = DataLoader(train_dataset, 128, shuffle=True, num_workers=12)
val_loader = DataLoader(val_dataset, 1000, shuffle=False, num_workers=12)
test_loader = DataLoader(test_dataset, 1000, shuffle=False, num_workers=12)

device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
model = Net(hidden_channels=args.hidden_channels,
            out_channels=1,
            num_layers=args.num_layers,
            dropout=args.dropout,
            inter_message_passing=not args.no_inter_message_passing).to(device)

Code Example #9
File: zinc_main.py Project: cvignac/SMP
                                                       verbose=True)
lr_limit = args.lr_limit

if args.load_model:
    model = torch.load(savedir + model_name + '.pkl')

pytorch_total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters", pytorch_total_params)

loss_fct = nn.L1Loss(reduction='sum')

# Load the data
batch_size = args.batch_size
transform = OneHotNodeEdgeFeatures(model_config['num_input_features'] - 1, model_config['num_edge_features'])

train_data = ZINC(rootdir, subset=args.subset, split='train', pre_transform=transform)
val_data = ZINC(rootdir, subset=args.subset, split='val', pre_transform=transform)
test_data = ZINC(rootdir, subset=args.subset, split='test', pre_transform=transform)

train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size, shuffle=False)

print("Starting to train")
for epoch in range(args.epochs):
    if args.load_model:
        break
    epoch_start = time.time()
    tr_loss = train()
    current_lr = optimizer.param_groups[0]["lr"]
    if current_lr < lr_limit:
Code Example #10
    def __init__(self, f1_alpha, f2_alpha, f3_alpha):
        dataset = ZINC('data/ZINC')
        super(ZINCSampler, self).__init__(dataset, f1_alpha, f2_alpha,
                                          f3_alpha)
Code Example #11
File: test.py Project: jingmouren/egc
from torch_geometric.data import Batch
from torch_geometric.datasets import ZINC


def load_zinc(root):
    dataset = ZINC(root, subset=True)
    batch = Batch.from_data_list([dataset[i] for i in range(10000)])
    return batch
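
Batch.from_data_list collates the 10,000 graphs into a single disconnected graph. A usage sketch (root path assumed):

batch = load_zinc('data/ZINC')
print(batch.num_graphs)   # 10000
print(batch.batch.shape)  # one graph index per node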
Code Example #12
File: main_zinc.py Project: WillHua127/MFGNN-new
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN type: gin, gin-virtual, gcn, or gcn-virtual '
             '(default: gin-virtual)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=512,
        help='dimensionality of hidden units in GNNs (default: 512)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset',
                        type=str,
                        default="molhiv",
                        help='dataset name (default: molhiv)')
    parser.add_argument(
        '--rank',
        type=int,
        default=512,
        help='dimensionality of rank units in GNNs (default: 512)')
    parser.add_argument('--filename',
                        type=str,
                        default="",
                        help='filename to output result (default: )')
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--wd',
                        type=float,
                        default=5e-5,
                        help='Weight decay (L2 loss on parameters).')
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.device}"
                          if torch.cuda.is_available() else "cpu")

    train_dataset = ZINC(os.path.join('torch_geometric_data', 'zinc'),
                         subset=True,
                         split='train')
    val_dataset = ZINC(os.path.join('torch_geometric_data', 'zinc'),
                       subset=True,
                       split='val')
    test_dataset = ZINC(os.path.join('torch_geometric_data', 'zinc'),
                        subset=True,
                        split='test')
    n_classes = 1
    in_feat = train_dataset[0].x.shape[1]

    train_graphs = [to_dgl(g) for g in train_dataset]
    val_graphs = [to_dgl(g) for g in val_dataset]
    test_graphs = [to_dgl(g) for g in test_dataset]
    # Build DGL dataloaders with the project's collate_dgl collate function.
    train_loader = DataLoader(train_graphs,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collate_dgl)
    valid_loader = DataLoader(val_graphs,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collate_dgl)
    test_loader = DataLoader(test_graphs,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=collate_dgl)

    model = GNN(num_tasks=n_classes,
                in_dim=in_feat,
                num_layer=args.num_layer,
                emb_dim=args.emb_dim,
                rank=args.rank,
                drop_ratio=args.drop_ratio).to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)

    valid_curve = []
    test_curve = []
    train_curve = []

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer)

        print('Evaluating...')
        train_perf = eval(model, device, train_loader)
        valid_perf = eval(model, device, valid_loader)
        test_perf = eval(model, device, test_loader)

        print({
            'Train': train_perf,
            'Validation': valid_perf,
            'Test': test_perf
        })

        train_curve.append(train_perf)
        valid_curve.append(valid_perf)
        test_curve.append(test_perf)

    best_val_epoch = np.argmin(np.array(valid_curve))
    best_train = min(train_curve)

    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))

    if args.filename != '':
        torch.save(
            {
                'Val': valid_curve[best_val_epoch],
                'Test': test_curve[best_val_epoch],
                'Train': train_curve[best_val_epoch],
                'BestTrain': best_train
            }, args.filename)
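
The snippet defines main() but ends before invoking it; the conventional entry point would be:

if __name__ == '__main__':
    main()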
Code Example #13
                        type=str,
                        default='test',
                        choices={'test', 'val'},
                        help="split to evaluate on")
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print('Loading Graphs...')

    data = ZINC(DATA_PATH,
                subset=True,
                split=args.split,
                transform=feat_transform,
                pre_transform=preproc)
    data_iter = CRaWlLoader(data, batch_size=50, num_workers=4)

    mean_list, std_list = [], []
    model_list = sorted(glob(args.model_dir))
    for model_dir in model_list:
        print(f'Evaluating {model_dir}...')
        model = CRaWl.load(model_dir)
        mean, std = test(model, data_iter, repeats=args.reps, steps=args.steps)

        mean_list.append(mean)
        std_list.append(std)

    print(
Code Example #14
    def _zinc(self):
        dataset = ZINC('data/ZINC', transform=ZINCTransformer())
        mean = dataset.data.y.mean()
        std = dataset.data.y.std()
        dataset.data.y = (dataset.data.y - mean) / std
        return dataset, std.item(), 28, 4
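
The trailing constants 28 and 4 match the atom- and bond-type vocabulary sizes used for the ZINC benchmark elsewhere on this page (compare FULL_ATOM_FEATURE_DIMS and FULL_BOND_FEATURE_DIMS in Code Example #15); std is returned so that predictions on the standardized target can later be rescaled to the original units.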
Code Example #15
File: train_zinc.py Project: thegodone/phc-gnn
def main():
    args = get_parser()
    # convert some argparse arguments that arrive as "true"/"false" strings
    naive_encoder = not str2bool(args.full_encoder)
    pin_memory = str2bool(args.pin_memory)
    use_bias = str2bool(args.bias)
    downstream_bn = str(args.d_bn)
    same_dropout = str2bool(args.same_dropout)
    mlp_mp = str2bool(args.mlp_mp)
    subset_data = str2bool(args.subset_data)
    phm_dim = args.phm_dim
    learn_phm = str2bool(args.learn_phm)

    base_dir = "zinc/"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)

    if base_dir not in args.save_dir:
        args.save_dir = os.path.join(base_dir, args.save_dir)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    set_logging(save_dir=args.save_dir)
    logging.info(f"Creating log directory at {args.save_dir}.")
    with open(os.path.join(args.save_dir, "params.json"), 'w') as fp:
        json.dump(args.__dict__, fp)

    mp_layers = [int(item) for item in args.mp_units.split(',')]
    downstream_layers = [int(item) for item in args.d_units.split(',')]
    mp_dropout = [float(item) for item in args.dropout_mpnn.split(',')]
    dn_dropout = [float(item) for item in args.dropout_dn.split(',')]
    logging.info(f'Initialising model with {mp_layers} hidden units with dropout {mp_dropout} '
                 f'and downstream units: {downstream_layers} with dropout {dn_dropout}.')

    if args.pooling == "globalsum":
        logging.info("Using GlobalSum Pooling")
    else:
        logging.info("Using SoftAttention Pooling")


    logging.info(f"Using Adam optimizer with weight_decay ({args.weightdecay}) and regularization "
                 f"norm ({args.regularization})")
    logging.info(f"Weight init: {args.w_init} \n Contribution init: {args.c_init}")

    # data
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'dataset', 'ZINC')
    train_data = ZINC(path, subset=subset_data, split='train')
    valid_data = ZINC(path, subset=subset_data, split='val')
    test_data = ZINC(path, subset=subset_data, split='test')
    evaluator = Evaluator()

    train_loader = DataLoader(train_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=True, num_workers=args.nworkers, pin_memory=pin_memory)
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=False, num_workers=args.nworkers, pin_memory=pin_memory)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, drop_last=False,
                             shuffle=False, num_workers=args.nworkers, pin_memory=pin_memory)

    #transform = RemoveIsolatedNodes()
    transform = None

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    #device = "cpu"

    # for hypercomplex model
    unique_phm = str2bool(args.unique_phm)
    if unique_phm:
        phm_rule = get_multiplication_matrices(phm_dim=args.phm_dim, type="phm")
        phm_rule = torch.nn.ParameterList(
            [torch.nn.Parameter(a, requires_grad=learn_phm) for a in phm_rule]
        )
    else:
        phm_rule = None


    #https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/data/molecules/prepare_molecules_ZINC_full.ipynb
    FULL_ATOM_FEATURE_DIMS = [28]
    FULL_BOND_FEATURE_DIMS = [4]

    if args.aggr_msg == "pna" or args.aggr_node == "pna":
        # if PNA is used
        # Compute in-degree histogram over training data.
        deg = torch.zeros(5, dtype=torch.long)
        for data in train_data:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
            deg += torch.bincount(d, minlength=deg.numel())
    else:
        deg = None

    aggr_kwargs = {"aggregators": ['mean', 'min', 'max', 'std'],
                   "scalers": ['identity', 'amplification', 'attenuation'],
                   "deg": deg,
                   "post_layers": 1,
                   "msg_scalers": str2bool(args.msg_scale),  # this key is for directional messagepassing layers.

                   "initial_beta": 1.0,  # Softmax
                   "learn_beta": True
                   }


    if "quaternion" in args.type:
        if args.aggr_msg == "pna" or args.aggr_node == "pna":
            logging.info("PNA not implemented for quaternion models.")
            raise NotImplementedError

    if args.type == "undirectional-quaternion-sc-add":
        logging.info("Using Quaternion Undirectional MPNN with Skip Connection through Addition")
        model = UQ_SC_ADD(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder,
                          mp_layers=mp_layers, dropout_mpnn=mp_dropout,
                          init=args.w_init, same_dropout=same_dropout,
                          norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node,
                          mlp=mlp_mp, pooling=args.pooling, activation=args.activation,
                          real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=1,
                          dropout_dn=dn_dropout, norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-quaternion-sc-cat":
        logging.info("Using Quaternion Undirectional MPNN with Skip Connection through Concatenation")
        model = UQ_SC_CAT(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                          atom_encoded_dim=args.input_embed_dim,
                          bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder,
                          mp_layers=mp_layers, dropout_mpnn=mp_dropout,
                          init=args.w_init, same_dropout=same_dropout,
                          norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node,
                          mlp=mlp_mp, pooling=args.pooling, activation=args.activation,
                          real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=1,
                          dropout_dn=dn_dropout, norm_dn=downstream_bn,
                          msg_encoder=args.msg_encoder,
                          **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-add":
        logging.info("Using PHM Undirectional MPNN with Skip Connection through Addition")
        model = UPH_SC_ADD(phm_dim=phm_dim, learn_phm=learn_phm, phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder,
                           mp_layers=mp_layers, dropout_mpnn=mp_dropout,
                           w_init=args.w_init, c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node,
                           mlp=mlp_mp, pooling=args.pooling, activation=args.activation,
                           real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=1,
                           dropout_dn=dn_dropout, norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           sc_type=args.sc_type,
                           **aggr_kwargs)
    elif args.type == "undirectional-phm-sc-cat":
        logging.info("Using PHM Undirectional MPNN with Skip Connection through Concatenation")
        model = UPH_SC_CAT(phm_dim=phm_dim, learn_phm=learn_phm, phm_rule=phm_rule,
                           atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                           atom_encoded_dim=args.input_embed_dim,
                           bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder,
                           mp_layers=mp_layers, dropout_mpnn=mp_dropout,
                           w_init=args.w_init, c_init=args.c_init,
                           same_dropout=same_dropout,
                           norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node,
                           mlp=mlp_mp, pooling=args.pooling, activation=args.activation,
                           real_trafo=args.real_trafo, downstream_layers=downstream_layers,
                           target_dim=1,
                           dropout_dn=dn_dropout, norm_dn=downstream_bn,
                           msg_encoder=args.msg_encoder,
                           **aggr_kwargs)
    else:
        raise ValueError(f"Unknown model type: {args.type}")

    logging.info(f"Model consists of {model.get_number_of_params_()} trainable parameters")
    # do runs
    test_best_epoch_metrics_arr = []
    test_last_epoch_metrics_arr = []
    val_metrics_arr = []

    t0 = time.time()
    for i in range(1, args.n_runs + 1):
        ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(i, model, args,
                                                                                        transform,
                                                                                        train_loader, valid_loader,
                                                                                        test_loader, device, evaluator,
                                                                                        t0)

        test_best_epoch_metrics_arr.append(ogb_bestEpoch_test_metrics)
        test_last_epoch_metrics_arr.append(ogb_lastEpoch_test_metric)
        val_metrics_arr.append(ogb_val_metrics)


    logging.info(f"Performance of model across {args.n_runs} runs:")
    test_bestEpoch_perf = torch.tensor(test_best_epoch_metrics_arr)
    test_lastEpoch_perf = torch.tensor(test_last_epoch_metrics_arr)
    valid_perf = torch.tensor(val_metrics_arr)
    logging.info('===========================')
    logging.info(f'Final Test (best val-epoch) '
                 f'"{evaluator.eval_metric}": {test_bestEpoch_perf.mean():.4f} ± {test_bestEpoch_perf.std():.4f}')
    logging.info(f'Final Test (last-epoch) '
                 f'"{evaluator.eval_metric}": {test_lastEpoch_perf.mean():.4f} ± {test_lastEpoch_perf.std():.4f}')
    logging.info(f'Final (best) Valid "{evaluator.eval_metric}": {valid_perf.mean():.4f} ± {valid_perf.std():.4f}')
Code Example #16
def load_dataset(args):
    # automatic data loading and splitting
    transform = add_zeros if args.dataset == 'ogbg-ppa' else None
    cls_criterion = get_loss_function(args.dataset)
    idx2word_mapper = None

    if args.dataset == 'mnist':
        train_data = MNISTSuperpixels(root='dataset',
                                      train=True,
                                      transform=T.Polar())
        dataset = train_data
        dataset.name = 'mnist'
        dataset.eval_metric = 'acc'
        validation_data = []
        test_data = MNISTSuperpixels(root='dataset',
                                     train=False,
                                     transform=T.Polar())

        train_data = list(train_data)
        test_data = list(test_data)

    elif args.dataset == 'QM9':
        # Contains 19 targets. Use only the first 12 (0-11)
        QM9_VALIDATION_START = 110000
        QM9_VALIDATION_END = 120000
        dataset = QM9(root='dataset',
                      transform=ExtractTargetTransform(args.target)).shuffle()
        dataset.name = 'QM9'
        dataset.eval_metric = 'mae'

        train_data = dataset[:QM9_VALIDATION_START]
        validation_data = dataset[QM9_VALIDATION_START:QM9_VALIDATION_END]
        test_data = dataset[QM9_VALIDATION_END:]

        train_data = list(train_data)
        validation_data = list(validation_data)
        test_data = list(test_data)

    elif args.dataset == 'zinc':
        train_data = ZINC(root='dataset', subset=True, split='train')

        dataset = train_data
        dataset.name = 'zinc'
        validation_data = ZINC(root='dataset', subset=True, split='val')
        test_data = ZINC(root='dataset', subset=True, split='test')
        dataset.eval_metric = 'mae'

        train_data = list(train_data)
        validation_data = list(validation_data)
        test_data = list(test_data)

    elif args.dataset in [
            'ogbg-molhiv', 'ogbg-molpcba', 'ogbg-ppa', 'ogbg-code2'
    ]:
        dataset = PygGraphPropPredDataset(name=args.dataset,
                                          transform=transform)

        if args.dataset == 'ogbg-code2':
            seq_len_list = np.array([len(seq) for seq in dataset.data.y])
            max_seq_len = args.max_seq_len
            frac_less_or_equal_to_max = np.sum(
                seq_len_list <= args.max_seq_len) / len(seq_len_list)
            print(f'Target sequences of length <= {max_seq_len}: '
                  f'{100 * frac_less_or_equal_to_max:.1f}%.')

        split_idx = dataset.get_idx_split()
        # The following is only used in the evaluation of the ogbg-code classifier.
        if args.dataset == 'ogbg-code2':
            vocab2idx, idx2vocab = get_vocab_mapping(
                [dataset.data.y[i] for i in split_idx['train']],
                args.num_vocab)
            # specific transformations for the ogbg-code dataset
            dataset.transform = transforms.Compose([
                augment_edge,
                lambda data: encode_y_to_arr(data, vocab2idx, args.max_seq_len)
            ])
            idx2word_mapper = partial(decode_arr_to_seq, idx2vocab=idx2vocab)

        train_data = list(dataset[split_idx["train"]])
        validation_data = list(dataset[split_idx["valid"]])
        test_data = list(dataset[split_idx["test"]])

    else:
        raise ValueError(f'Unsupported dataset: {args.dataset}')

    return dataset, train_data, validation_data, test_data, cls_criterion, idx2word_mapper
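
For the 'zinc' branch, a call sketch (args is assumed to be an argparse namespace with at least a dataset field, and get_loss_function is assumed to handle 'zinc'):

dataset, train_data, validation_data, test_data, cls_criterion, idx2word_mapper = load_dataset(args)
print(dataset.name, dataset.eval_metric)  # 'zinc', 'mae'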