Example #1
def load_training_data() -> Tuple[pyg.data.Data, pyg.data.Data, Dict[
    str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    '''
    Returns a tuple of:
        train_graph: graph containing a subset of the training edges
        valid_graph: graph containing all training edges
        train_edges: dict of positive edges across the entire train split
        eval_edges: dict of positive training edges held out of train_graph
        valid_edges: dict of positive and negative validation edges not in train_graph
    '''
    dataset = PygLinkPropPredDataset(name='ogbl-ddi')
    transform = T.ToSparseTensor(remove_edge_index=False)
    edge_split = dataset.get_edge_split()
    train_edges = edge_split['train']
    valid_edges = edge_split['valid']
    train_graph = dataset[0]
    valid_graph = train_graph.clone()

    # Partition training edges
    torch.manual_seed(12345)
    perm = torch.randperm(train_edges['edge'].shape[0])
    num_eval = valid_edges['edge'].shape[0]
    eval_idxs, train_idxs = perm[:num_eval], perm[num_eval:]
    eval_edges = {'edge': train_edges['edge'][eval_idxs]}
    train_edges = {'edge': train_edges['edge'][train_idxs]}

    # Update graph object to have subset of edges and adj_t matrix
    train_edge_index = torch.cat(
        [train_edges['edge'].T, train_edges['edge'][:, [1, 0]].T], dim=1)
    train_graph.edge_index = train_edge_index
    train_graph = transform(train_graph)
    valid_graph = transform(valid_graph)

    return (train_graph, valid_graph, edge_split['train'], eval_edges,
            valid_edges)
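The imports this helper relies on are omitted from the excerpt; a sketch of a plausible preamble and call site, assuming `pyg` aliases `torch_geometric`:

from typing import Dict, Tuple

import torch
import torch_geometric as pyg
import torch_geometric.transforms as T
from ogb.linkproppred import PygLinkPropPredDataset

train_graph, valid_graph, train_edges, eval_edges, valid_edges = load_training_data()
print(train_graph.adj_t)          # sparse adjacency produced by ToSparseTensor
print(eval_edges['edge'].shape)   # held-out positive training edges, [n, 2]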
Example #2
def load_test_data() -> Tuple[pyg.data.Data, pyg.data.Data,
                              Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    '''
    Returns a tuple of:
        valid_graph: graph containing all training edges
        test_graph: graph containing all training edges plus validation edges
        valid_edges: dict of positive and negative edges from the validation split (not in valid_graph)
        test_edges: dict of positive and negative edges from the test split (not in test_graph)
    '''
    dataset = PygLinkPropPredDataset(name='ogbl-ddi')
    transform = T.ToSparseTensor(remove_edge_index=False)
    edge_split = dataset.get_edge_split()
    valid_edges = edge_split['valid']
    test_edges = edge_split['test']
    valid_graph = dataset[0]
    test_graph = valid_graph.clone()

    # Add validation edges to test_graph for test inference
    test_edge_index = torch.cat([
        test_graph.edge_index,
        valid_edges['edge'].T,
        valid_edges['edge'][:, [1, 0]].T,
    ], dim=1)
    test_graph.edge_index = test_edge_index

    valid_graph = transform(valid_graph)
    test_graph = transform(test_graph)
    return valid_graph, test_graph, valid_edges, test_edges
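For reference, the split dict used in both loaders has this shape on ogbl-ddi: the train split carries only positive edges, while the valid and test splits also carry sampled negatives. A quick inspection sketch:

dataset = PygLinkPropPredDataset(name='ogbl-ddi')
edge_split = dataset.get_edge_split()
print(sorted(edge_split['train'].keys()))     # ['edge']
print(sorted(edge_split['valid'].keys()))     # ['edge', 'edge_neg']
print(edge_split['valid']['edge'].shape)      # positive pairs, shape [n, 2]
print(edge_split['valid']['edge_neg'].shape)  # sampled negatives, shape [m, 2]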
Example #3
def __init__(self, path):
    dataset = "ogbl-biokg"
    # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
    # Instantiating the dataset here triggers download/processing; the object
    # itself is discarded and presumably reloaded by the parent class.
    PygLinkPropPredDataset(name=dataset, root=path)
    super(OGBLbiokgDataset, self).__init__(dataset, path)
    setattr(OGBLbiokgDataset, "metric", "MRR")
    setattr(OGBLbiokgDataset, "loss", "pos_neg_loss")
Example #4
def main():

    args = ArgsInit().args

    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.device}")
    else:
        device = torch.device("cpu")

    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]
    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)

    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    print(args)

    model = DeeperGCN(args).to(device)
    predictor = LinkPredictor(args).to(device)

    model.load_state_dict(torch.load(args.model_load_path)['model_state_dict'])
    model.to(device)

    predictor.load_state_dict(
        torch.load(args.predictor_load_path)['model_state_dict'])
    predictor.to(device)

    hits = ['Hits@10', 'Hits@50', 'Hits@100']

    result = test(model, predictor, x, edge_index, split_edge, evaluator,
                  args.batch_size)

    for k in hits:
        train_result, valid_result, test_result = result[k]
        print('{}--Train: {}, Validation: {}, Test: {}'.format(
            k, train_result, valid_result, test_result))
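Both load calls above expect checkpoint files shaped as dicts with a 'model_state_dict' entry; a minimal sketch of how such checkpoints would be produced (the paths are hypothetical):

torch.save({'model_state_dict': model.state_dict()}, 'deeper_gcn.pt')     # hypothetical path
torch.save({'model_state_dict': predictor.state_dict()}, 'predictor.pt')  # hypothetical path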
Example #5
def __init__(self, path):
    dataset = "ogbl-citation"
    # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
    # As in the previous example, instantiation only triggers download/processing.
    PygLinkPropPredDataset(name=dataset,
                           root=path,
                           transform=T.ToSparseTensor())
    super(OGBLcitationDataset, self).__init__(dataset, path)
    setattr(OGBLcitationDataset, "metric", "MRR")
    setattr(OGBLcitationDataset, "loss", "pos_neg_loss")
Example #6
def load_ogb(name, dataset_dir):
    if name[:4] == 'ogbn':
        dataset = PygNodePropPredDataset(name=name, root=dataset_dir)
        splits = dataset.get_idx_split()
        split_names = ['train_mask', 'val_mask', 'test_mask']
        for i, key in enumerate(splits.keys()):
            mask = index2mask(splits[key], size=dataset.data.y.shape[0])
            set_dataset_attr(dataset, split_names[i], mask, len(mask))
        edge_index = to_undirected(dataset.data.edge_index)
        set_dataset_attr(dataset, 'edge_index', edge_index,
                         edge_index.shape[1])

    elif name[:4] == 'ogbg':
        dataset = PygGraphPropPredDataset(name=name, root=dataset_dir)
        splits = dataset.get_idx_split()
        split_names = [
            'train_graph_index', 'val_graph_index', 'test_graph_index'
        ]
        for i, key in enumerate(splits.keys()):
            id = splits[key]
            set_dataset_attr(dataset, split_names[i], id, len(id))

    elif name[:4] == "ogbl":
        dataset = PygLinkPropPredDataset(name=name, root=dataset_dir)
        splits = dataset.get_edge_split()

        id = splits['train']['edge'].T
        if cfg.dataset.resample_negative:
            set_dataset_attr(dataset, 'train_pos_edge_index', id, id.shape[1])
            # todo: applying transform for negative sampling is very slow
            dataset.transform = neg_sampling_transform
        else:
            id_neg = negative_sampling(edge_index=id,
                                       num_nodes=dataset.data.num_nodes[0],
                                       num_neg_samples=id.shape[1])
            id_all = torch.cat([id, id_neg], dim=-1)
            label = get_link_label(id, id_neg)
            set_dataset_attr(dataset, 'train_edge_index', id_all,
                             id_all.shape[1])
            set_dataset_attr(dataset, 'train_edge_label', label, len(label))

        id, id_neg = splits['valid']['edge'].T, splits['valid']['edge_neg'].T
        id_all = torch.cat([id, id_neg], dim=-1)
        label = get_link_label(id, id_neg)
        set_dataset_attr(dataset, 'val_edge_index', id_all, id_all.shape[1])
        set_dataset_attr(dataset, 'val_edge_label', label, len(label))

        id, id_neg = splits['test']['edge'].T, splits['test']['edge_neg'].T
        id_all = torch.cat([id, id_neg], dim=-1)
        label = get_link_label(id, id_neg)
        set_dataset_attr(dataset, 'test_edge_index', id_all, id_all.shape[1])
        set_dataset_attr(dataset, 'test_edge_label', label, len(label))

    else:
        raise ValueError(f'OGB dataset {name} does not exist')

    return dataset
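`get_link_label` is not defined in this excerpt; given how it is used (positives concatenated before negatives), a minimal sketch:

def get_link_label(pos_edge_index, neg_edge_index):
    # 1.0 for each positive edge, 0.0 for each sampled negative,
    # in the same order as torch.cat([id, id_neg], dim=-1) above.
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_label = torch.zeros(num_links, dtype=torch.float)
    link_label[:pos_edge_index.size(1)] = 1.
    return link_label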
Example #7
def main():
    parser = argparse.ArgumentParser(description="OGBL-Citation2 (Node2Vec)")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--walk_length", type=int, default=40)
    parser.add_argument("--context_size", type=int, default=20)
    parser.add_argument("--walks_per_node", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--log_steps", type=int, default=1)
    args = parser.parse_args()

    device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name="ogbl-citation2")
    data = dataset[0]
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    model = Node2Vec(
        data.edge_index,
        args.embedding_dim,
        args.walk_length,
        args.context_size,
        args.walks_per_node,
        sparse=True,
    ).to(device)

    loader = model.loader(batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=4)
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=args.lr)

    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, (pos_rw, neg_rw) in enumerate(loader):
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()

            if (i + 1) % args.log_steps == 0:
                print(f"Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, "
                      f"Loss: {loss:.4f}")

            if (i + 1) % 100 == 0:  # Save model every 100 steps.
                save_embedding(model)
        save_embedding(model)
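`save_embedding` is not shown in this excerpt; in the reference OGB Node2Vec scripts it simply dumps the embedding matrix to disk, roughly:

def save_embedding(model):
    # Persist learned node embeddings for downstream link predictors.
    torch.save(model.embedding.weight.data.cpu(), 'embedding.pt')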
Example #8
def main():
    parser = argparse.ArgumentParser(description='OGBL-COLLAB (Node2Vec)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--embedding_dim', type=int, default=128)
    parser.add_argument('--walk_length', type=int, default=80)
    parser.add_argument('--context_size', type=int, default=20)
    parser.add_argument('--walks_per_node', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--log_steps', type=int, default=1)
    args = parser.parse_args()

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-collab')
    data = dataset[0]

    edge_index = data.edge_index.to(device)
    perm = torch.argsort(edge_index[0] * data.num_nodes + edge_index[1])
    edge_index = edge_index[:, perm]

    model = Node2Vec(data.num_nodes, args.embedding_dim, args.walk_length,
                     args.context_size, args.walks_per_node).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loader = DataLoader(torch.arange(data.num_nodes),
                        batch_size=args.batch_size,
                        shuffle=True)

    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, subset in enumerate(loader):
            optimizer.zero_grad()
            loss = model.loss(edge_index, subset.to(edge_index.device))
            loss.backward()
            optimizer.step()

            if (i + 1) % args.log_steps == 0:
                print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, '
                      f'Loss: {loss:.4f}')

            if (i + 1) % 100 == 0:  # Save model every 100 steps.
                save_embedding(model)
        save_embedding(model)
Example #9
def main():
    parser = argparse.ArgumentParser(description='OGBL-PPA (Node2Vec)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--embedding_dim', type=int, default=128)
    parser.add_argument('--walk_length', type=int, default=40)
    parser.add_argument('--context_size', type=int, default=20)
    parser.add_argument('--walks_per_node', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--log_steps', type=int, default=1)
    args = parser.parse_args()

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ppa')
    data = dataset[0]

    model = Node2Vec(data.edge_index,
                     args.embedding_dim,
                     args.walk_length,
                     args.context_size,
                     args.walks_per_node,
                     sparse=True).to(device)

    loader = model.loader(batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=4)
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=args.lr)

    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, (pos_rw, neg_rw) in enumerate(loader):
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()

            if (i + 1) % args.log_steps == 0:
                print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, '
                      f'Loss: {loss:.4f}')

            if (i + 1) % 100 == 0:  # Save model every 100 steps.
                save_embedding(model)
        save_embedding(model)
Example #10
def load_link_dataset(name,
                      hparams,
                      path="~/Bioinformatics_ExternalData/OGB/"):
    if "ogbl" in name:
        ogbl = PygLinkPropPredDataset(name=name, root=path)

        if isinstance(ogbl, PygLinkPropPredDataset) and not hasattr(ogbl[0], "edge_index_dict") \
                and not hasattr(ogbl[0], "edge_reltype"):
            dataset = EdgeSampler(ogbl,
                                  directed=True,
                                  add_reverse_metapaths=hparams.use_reverse)
            print(dataset.node_types, dataset.metapaths)
        else:
            dataset = TripletSampler(ogbl,
                                     directed=True,
                                     head_node_type=None,
                                     add_reverse_metapaths=hparams.use_reverse)
            print(dataset.node_types, dataset.metapaths)
    else:
        raise Exception(f"dataset {name} not found")

    return dataset
Example #11
def main():
    parser = argparse.ArgumentParser(description="OGBL-COLLAB (GNN)")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--log_steps", type=int, default=1)
    parser.add_argument("--use_sage", action="store_true")
    parser.add_argument("--use_valedges_as_input", action="store_true")
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_channels", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--batch_size", type=int, default=64 * 1024)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--epochs", type=int, default=400)
    parser.add_argument("--eval_steps", type=int, default=1)
    parser.add_argument("--runs", type=int, default=1)
    parser.add_argument("--seed",type=int,default=1)
    args = parser.parse_args()
    print(args)
    
    device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name="ogbl-collab")
    data = dataset[0]
    edge_index = data.edge_index
    data.edge_weight = data.edge_weight.view(-1).to(torch.float)
    data = T.ToSparseTensor()(data)

    split_edge = dataset.get_edge_split()

    
    # Use training + validation edges for inference on test set.
    if args.use_valedges_as_input:
        val_edge_index = split_edge["valid"]["edge"].t()
        full_edge_index = torch.cat([edge_index, val_edge_index], dim=-1)
        data.full_adj_t = SparseTensor.from_edge_index(full_edge_index).t()
        data.full_adj_t = data.full_adj_t.to_symmetric()
    else:
        data.full_adj_t = data.adj_t

    data = data.to(device)

    if args.use_sage:
        model = SAGE(
            data.num_features,
            args.hidden_channels,
            args.hidden_channels,
            args.num_layers,
            args.dropout,
        ).to(device)
    else:
        model = GCN(
            data.num_features,
            args.hidden_channels,
            args.hidden_channels,
            args.num_layers,
            args.dropout,
        ).to(device)

    predictor = LinkPredictor(
        args.hidden_channels, args.hidden_channels, 1, args.num_layers, args.dropout
    ).to(device)

    evaluator = Evaluator(name="ogbl-collab")
    loggers = {
        "Hits@10": Logger(args.runs, args),
        "Hits@50": Logger(args.runs, args),
        "Hits@100": Logger(args.runs, args),
    }

    for run in tqdm(range(args.runs)):
        torch.manual_seed(args.seed + run)
        np.random.seed(args.seed + run)
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()), lr=args.lr
        )

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, data, split_edge, optimizer, args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(
                    model, predictor, data, split_edge, evaluator, args.batch_size
                )
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(
                            f"Run: {run + 1:02d}, "
                            f"Epoch: {epoch:02d}, "
                            f"Loss: {loss:.4f}, "
                            f"Train: {100 * train_hits:.2f}%, "
                            f"Valid: {100 * valid_hits:.2f}%, "
                            f"Test: {100 * test_hits:.2f}%"
                        )
                    print("---")

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
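The `train` helper is not included in this excerpt; a sketch of the edge-batch loop used by the reference OGBL-COLLAB example, matching the call signature above (`DataLoader` here is `torch.utils.data.DataLoader`):

def train(model, predictor, data, split_edge, optimizer, batch_size):
    model.train()
    predictor.train()

    pos_train_edge = split_edge["train"]["edge"].to(data.x.device)

    total_loss = total_examples = 0
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size,
                           shuffle=True):
        optimizer.zero_grad()

        h = model(data.x, data.adj_t)

        edge = pos_train_edge[perm].t()
        pos_out = predictor(h[edge[0]], h[edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        # Uniformly sample negative node pairs of the same shape.
        edge = torch.randint(0, data.num_nodes, edge.size(),
                             dtype=torch.long, device=h.device)
        neg_out = predictor(h[edge[0]], h[edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()

        num_examples = pos_out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples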
Example #12
def main():
    parser = argparse.ArgumentParser(description='OGBL-Citation2 (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=16 * 1024)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--num_steps', type=int, default=100)
    parser.add_argument('--eval_steps', type=int, default=10)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation2')
    split_edge = dataset.get_edge_split()
    data = dataset[0]
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    model = GCN(data.x.size(-1), args.hidden_channels, args.hidden_channels,
                args.num_layers, args.dropout).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation2')
    logger = Logger(args.runs, args)

    run_idx = 0

    while run_idx < args.runs:
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(predictor.parameters()),
                                     lr=args.lr)

        run_success = True
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, loader, optimizer, device)
            print(
                f'Run: {run_idx + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}'
            )
            if loss > 2.:
                run_success = False
                logger.reset(run_idx)
                print('Learning failed. Rerun...')
                break

            if epoch > 49 and epoch % args.eval_steps == 0:
                result = test(model,
                              predictor,
                              data,
                              split_edge,
                              evaluator,
                              batch_size=64 * 1024,
                              device=device)
                logger.add_result(run_idx, result)

                train_mrr, valid_mrr, test_mrr = result
                print(f'Run: {run_idx + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {train_mrr:.4f}, '
                      f'Valid: {valid_mrr:.4f}, '
                      f'Test: {test_mrr:.4f}')

        print('GraphSAINT')
        if run_success:
            logger.print_statistics(run_idx)
            run_idx += 1

    print('GraphSAINT')
    logger.print_statistics()
Example #13
    data_list = []
    for src, dst in tqdm(link_index.t().tolist()):
        tmp = k_hop_subgraph(src, dst, num_hops, A, ratio_per_hop, 
                             max_nodes_per_hop, node_features=x, y=y, 
                             directed=directed, A_csc=A_csc)
        data = construct_pyg_graph(*tmp, node_label)
        data_list.append(data)

    return data_list

#######################
# MAIN
#######################
path = 'dataset/'
use_coalesce = False
dataset = PygLinkPropPredDataset(name='ogbl-ddi')
split_edge = dataset.get_edge_split()
data = dataset[0]

train_dataset = SEALDataset(
    path, 
    data, 
    split_edge, 
    num_hops=p_num_hops, 
    percent=p_train_percent, 
    split='train', 
    use_coalesce=use_coalesce, 
    node_label=p_node_label, 
    ratio_per_hop=p_ratio_per_hop, 
    max_nodes_per_hop=p_max_nodes_per_hop, 
    directed=directed, 
Example #14
def main_get_mask(args, imp_num, rewind_weight_mask=None, rewind_predict_weight=None, resume_train_ckpt=None):

    device = torch.device("cuda:" + str(args.device))
    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]
    
    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)

    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args)
    predictor = LinkPredictor(args).to(device)
    pruning.add_trainable_mask_noise(model, args, c=1e-4)
    optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=args.lr)

    results = {'epoch': 0 }
    keys = ['highest_valid', 'final_train', 'final_test', 'highest_train']
    hits = ['Hits@10', 'Hits@50', 'Hits@100']

    for key in keys:
        results[key] = {k: 0 for k in hits}

    start_epoch = 1
    if resume_train_ckpt:
        
        start_epoch = resume_train_ckpt['epoch']
        rewind_weight_mask = resume_train_ckpt['rewind_weight_mask']
        ori_model_dict = model.state_dict()
        over_lap = {k: v for k, v in resume_train_ckpt['model_state_dict'].items()
                    if k in ori_model_dict.keys()}
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("Resume at IMP:[{}] epoch:[{}] len:[{}/{}]!".format(imp_num, resume_train_ckpt['epoch'], len(over_lap.keys()), len(ori_model_dict.keys())))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(model, args)
    else:
        rewind_weight_mask = copy.deepcopy(model.state_dict())
        rewind_predict_weight = copy.deepcopy(predictor.state_dict())

    for epoch in range(start_epoch, args.mask_epochs + 1):

        t0 = time.time()     
        
        epoch_loss, prune_info_dict = train.train_mask(
            model, predictor, x, edge_index, split_edge, optimizer, args)

        result = train.test(model, predictor, x, edge_index, split_edge,
                            evaluator, args.batch_size, args)

        k = 'Hits@50'
        train_result, valid_result, test_result = result[k]
        if train_result > results['highest_train'][k]:
            results['highest_train'][k] = train_result

        if valid_result > results['highest_valid'][k]:
            results['highest_valid'][k] = valid_result
            results['final_train'][k] = train_result
            results['final_test'][k] = test_result
            results['epoch'] = epoch
            pruning.save_all(model, predictor, rewind_weight_mask, optimizer,
                             imp_num, epoch, args.model_save_path,
                             'IMP{}_train_ckpt'.format(imp_num))

        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (GET Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] | Adj[{:.3f}%] Wei[{:.3f}%] Time:[{:.2f}min]'
              .format(imp_num, epoch, args.mask_epochs, 
                                      epoch_loss, 
                                      train_result * 100,
                                      valid_result * 100,
                                      test_result * 100,
                                      results['final_test'][k] * 100,
                                      results['epoch'],
                                      prune_info_dict['adj_spar'], 
                                      prune_info_dict['wei_spar'],
                                      epoch_time))

    rewind_weight_mask, adj_spar, wei_spar = pruning.change(rewind_weight_mask, model, args)
    print('-' * 100)
    print("INFO : IMP:[{}] (GET MASK) Final Result Train:[{:.2f}]  Valid:[{:.2f}]  Test:[{:.2f}] | Adj:[{:.3f}%] Wei:[{:.3f}%]"
        .format(imp_num, results['final_train'][k] * 100,
                         results['highest_valid'][k] * 100,
                         results['final_test'][k] * 100,
                         adj_spar, 
                         wei_spar))
    print('-' * 100)
    return rewind_weight_mask, rewind_predict_weight
Example #15
def main():
    parser = argparse.ArgumentParser(description='OGBL-PPA (MLP)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ppa')
    splitted_edge = dataset.get_edge_split()
    data = dataset[0]

    x = data.x.to(torch.float)
    embedding = torch.load('embedding.pt', map_location='cpu')
    x = torch.cat([x, embedding], dim=-1)
    x = x.to(device)

    predictor = LinkPredictor(x.size(-1), args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    optimizer = torch.optim.Adam(predictor.parameters(), lr=args.lr)

    evaluator = Evaluator(name='ogbl-ppa')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@50': Logger(args.runs, args),
        'Hits@100': Logger(args.runs, args),
    }

    for run in range(args.runs):
        predictor.reset_parameters()

        for epoch in range(1, 1 + args.epochs):
            loss = train(predictor, x, splitted_edge, optimizer,
                         args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(predictor, x, splitted_edge, evaluator,
                               args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
Example #16
File: mlp.py Project: rryoung98/ogb
def main():
    parser = argparse.ArgumentParser(description="OGBL-Citation2 (MLP)")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--log_steps", type=int, default=1)
    parser.add_argument("--use_node_embedding", action="store_true")
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_channels", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--batch_size", type=int, default=64 * 1024)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--eval_steps", type=int, default=10)
    parser.add_argument("--runs", type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name="ogbl-citation2")
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge["train"]["source_node"].numel())[:86596]
    split_edge["eval_train"] = {
        "source_node": split_edge["train"]["source_node"][idx],
        "target_node": split_edge["train"]["target_node"][idx],
        "target_node_neg": split_edge["valid"]["target_node_neg"],
    }

    x = data.x
    if args.use_node_embedding:
        embedding = torch.load("embedding.pt", map_location="cpu")
        x = torch.cat([x, embedding], dim=-1)
    x = x.to(device)

    predictor = LinkPredictor(x.size(-1), args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name="ogbl-citation2")
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(predictor.parameters(), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(predictor, x, split_edge, optimizer, args.batch_size)
            print(f"Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}")

            if epoch % args.eval_steps == 0:
                result = test(predictor, x, split_edge, evaluator,
                              args.batch_size)
                logger.add_result(run, result)

                if epoch % args.log_steps == 0:
                    train_mrr, valid_mrr, test_mrr = result
                    print(f"Run: {run + 1:02d}, "
                          f"Epoch: {epoch:02d}, "
                          f"Loss: {loss:.4f}, "
                          f"Train: {train_mrr:.4f}, "
                          f"Valid: {valid_mrr:.4f}, "
                          f"Test: {test_mrr:.4f}")

        print("Node2vec" if args.use_node_embedding else "MLP")
        logger.print_statistics(run)
    print("Node2vec" if args.use_node_embedding else "MLP")
    logger.print_statistics()
Example #17
D = np.zeros((theta.size()[0], theta.size()[1]))
theta_numpy = theta.numpy()
theta_sum = np.sum(theta_numpy, axis=1).tolist()
for idx, val in enumerate(theta_sum):
    D[idx, idx] = val

A_h = torch.mm(theta, theta.T)
A_h = A_h - D
theta = A_h

# Load dataset

device = 'cpu' #f'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

dataset = PygLinkPropPredDataset(name='ogbl-ddi',
                                 transform=T.ToSparseTensor())
data = dataset[0]
adj_t = data.adj_t.to(device)

split_edge = dataset.get_edge_split()


# Randomly pick some training samples to perform evaluation on

torch.manual_seed(12345)
idx = torch.randperm(split_edge['train']['edge'].size(0))
idx = idx[:split_edge['valid']['edge'].size(0)]
split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}


# Load model, initial embeddings, the link predictor, and the evaluator
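The degree-matrix loop at the top of this example can be collapsed into a couple of tensor ops; an equivalent vectorized form, assuming `theta` is a 2-D float tensor:

theta_sum = theta.sum(dim=1)   # row sums (weighted degrees)
D = torch.diag(theta_sum)      # degree matrix, built directly in torch
theta = theta @ theta.T - D    # A_h = theta * theta^T - D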
Example #18
def main():
    parser = argparse.ArgumentParser(description="OGBL-PPA (MF)")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--log_steps", type=int, default=1)
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_channels", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--batch_size", type=int, default=64 * 1024)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--eval_steps", type=int, default=1)
    parser.add_argument("--runs", type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name="ogbl-ppa")
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    emb = torch.nn.Embedding(data.num_nodes, args.hidden_channels).to(device)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name="ogbl-ppa")
    loggers = {
        "Hits@10": Logger(args.runs, args),
        "Hits@50": Logger(args.runs, args),
        "Hits@100": Logger(args.runs, args),
    }

    for run in range(args.runs):
        emb.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(list(emb.parameters()) +
                                     list(predictor.parameters()),
                                     lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(emb.weight, predictor, split_edge, optimizer,
                         args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(emb.weight, predictor, split_edge, evaluator,
                               args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f"Run: {run + 1:02d}, "
                              f"Epoch: {epoch:02d}, "
                              f"Loss: {loss:.4f}, "
                              f"Train: {100 * train_hits:.2f}%, "
                              f"Valid: {100 * valid_hits:.2f}%, "
                              f"Test: {100 * test_hits:.2f}%")

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
Example #19
def load_ogb(name, dataset_dir):
    r"""

    Load OGB dataset objects.


    Args:
        name (string): dataset name
        dataset_dir (string): data directory

    Returns: PyG dataset object

    """
    from ogb.graphproppred import PygGraphPropPredDataset
    from ogb.linkproppred import PygLinkPropPredDataset
    from ogb.nodeproppred import PygNodePropPredDataset

    if name[:4] == 'ogbn':
        dataset = PygNodePropPredDataset(name=name, root=dataset_dir)
        splits = dataset.get_idx_split()
        split_names = ['train_mask', 'val_mask', 'test_mask']
        for i, key in enumerate(splits.keys()):
            mask = index_to_mask(splits[key], size=dataset.data.y.shape[0])
            set_dataset_attr(dataset, split_names[i], mask, len(mask))
        edge_index = to_undirected(dataset.data.edge_index)
        set_dataset_attr(dataset, 'edge_index', edge_index,
                         edge_index.shape[1])

    elif name[:4] == 'ogbg':
        dataset = PygGraphPropPredDataset(name=name, root=dataset_dir)
        splits = dataset.get_idx_split()
        split_names = [
            'train_graph_index', 'val_graph_index', 'test_graph_index'
        ]
        for i, key in enumerate(splits.keys()):
            id = splits[key]
            set_dataset_attr(dataset, split_names[i], id, len(id))

    elif name[:4] == "ogbl":
        dataset = PygLinkPropPredDataset(name=name, root=dataset_dir)
        splits = dataset.get_edge_split()
        id = splits['train']['edge'].T
        if cfg.dataset.resample_negative:
            set_dataset_attr(dataset, 'train_pos_edge_index', id, id.shape[1])
            dataset.transform = neg_sampling_transform
        else:
            id_neg = negative_sampling(edge_index=id,
                                       num_nodes=dataset.data.num_nodes,
                                       num_neg_samples=id.shape[1])
            id_all = torch.cat([id, id_neg], dim=-1)
            label = create_link_label(id, id_neg)
            set_dataset_attr(dataset, 'train_edge_index', id_all,
                             id_all.shape[1])
            set_dataset_attr(dataset, 'train_edge_label', label, len(label))

        id, id_neg = splits['valid']['edge'].T, splits['valid']['edge_neg'].T
        id_all = torch.cat([id, id_neg], dim=-1)
        label = create_link_label(id, id_neg)
        set_dataset_attr(dataset, 'val_edge_index', id_all, id_all.shape[1])
        set_dataset_attr(dataset, 'val_edge_label', label, len(label))

        id, id_neg = splits['test']['edge'].T, splits['test']['edge_neg'].T
        id_all = torch.cat([id, id_neg], dim=-1)
        label = create_link_label(id, id_neg)
        set_dataset_attr(dataset, 'test_edge_index', id_all, id_all.shape[1])
        set_dataset_attr(dataset, 'test_edge_label', label, len(label))

    else:
        raise ValueError(f'OGB dataset {name} does not exist')
    return dataset
Example #20
if not os.path.exists(args.res_dir):
    os.makedirs(args.res_dir)
if not args.keep_old:
    # Backup python files.
    copy('seal_link_pred.py', args.res_dir)
    copy('utils.py', args.res_dir)
log_file = os.path.join(args.res_dir, 'log.txt')
# Save command line input.
cmd_input = 'python ' + ' '.join(sys.argv) + '\n'
with open(os.path.join(args.res_dir, 'cmd_input.txt'), 'a') as f:
    f.write(cmd_input)
print('Command line input: ' + cmd_input + ' is saved.')
with open(log_file, 'a') as f:
    f.write('\n' + cmd_input)

if args.dataset.startswith('ogbl'):
    dataset = PygLinkPropPredDataset(name=args.dataset)
    split_edge = dataset.get_edge_split()
    data = dataset[0]

else:
    path = osp.join('dataset', args.dataset)
    dataset = Planetoid(path, args.dataset)
    split_edge = do_edge_split(dataset)
    print(split_edge)
    data = dataset[0]
    #data.edge_index = split_edge['train']['edge'].t()

if args.dataset.startswith('ogbl-citation'):
    args.eval_metric = 'mrr'
    directed = True
elif args.dataset.startswith('ogbl-bvg'):
Example #21
File: gnn.py Project: tjufan/LRGA
def main():
    parser = argparse.ArgumentParser(description='OGBL-COLLAB (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--k', type=int, default=100)
    parser.add_argument('--gpu_id', type=int, default=0)
    args = parser.parse_args()
    print(args)
    device = gpu_setup(args.gpu_id)

    dataset = PygLinkPropPredDataset(name='ogbl-collab')
    data = dataset[0]
    data.edge_weight = data.edge_weight.view(-1).to(torch.float)
    data = T.ToSparseTensor()(data)
    data = data.to(device)

    split_edge = dataset.get_edge_split()

    model = GCNWithAttention(data.num_features, args.hidden_channels,
                             args.hidden_channels, args.num_layers,
                             args.dropout, args.k).to(device)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)
    print("model parameters {}".format(
        sum(p.numel() for p in model.parameters())))
    print("predictor parameters {}".format(
        sum(p.numel() for p in predictor.parameters())))
    print("total parameters {}".format(
        sum(p.numel() for p in model.parameters()) +
        sum(p.numel() for p in predictor.parameters())))
    evaluator = Evaluator(name='ogbl-collab')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@50': Logger(args.runs, args),
        'Hits@100': Logger(args.runs, args),
    }

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(predictor.parameters()),
                                     lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, data, split_edge, optimizer,
                         args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, data, split_edge, evaluator,
                               args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
Example #22
def main():
    parser = argparse.ArgumentParser(description='OGBL-PPA (Full-Batch)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_node_embedding', action='store_true')
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ppa')
    data = dataset[0]
    splitted_edge = dataset.get_edge_split()

    if args.use_node_embedding:
        x = data.x.to(torch.float)
        x = torch.cat([x, torch.load('embedding.pt')], dim=-1)
        x = x.to(device)
    else:
        x = data.x.to(torch.float).to(device)

    edge_index = data.edge_index.to(device)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    if args.use_sage:
        model = SAGE(x.size(-1), args.hidden_channels, args.hidden_channels,
                     args.num_layers, args.dropout).to(device)
    else:
        model = GCN(x.size(-1), args.hidden_channels, args.hidden_channels,
                    args.num_layers, args.dropout).to(device)

        # Pre-compute GCN normalization.
        adj = adj.set_diag()
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ppa')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@50': Logger(args.runs, args),
        'Hits@100': Logger(args.runs, args),
    }

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()),
            lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, x, adj, splitted_edge, optimizer,
                         args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, x, adj, splitted_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
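The "Pre-compute GCN normalization" block in this example implements the symmetric normalization of Kipf and Welling: `adj.set_diag()` adds the self-loops, and the two `deg_inv_sqrt` factors rescale rows and columns, i.e.

\hat{A} = D^{-1/2} (A + I) D^{-1/2}, \qquad D_{ii} = \sum_j (A + I)_{ij}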
Example #23
def main():
    parser = argparse.ArgumentParser(description='OGBL-Citation (Cluster-GCN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_partitions', type=int, default=15000)
    parser.add_argument('--num_workers', type=int, default=12)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=10)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation')
    split_edge = dataset.get_edge_split()
    data = dataset[0]
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    cluster_data = ClusterData(data, num_parts=args.num_partitions,
                               recursive=False, save_dir=dataset.processed_dir)

    loader = ClusterLoader(cluster_data, batch_size=args.batch_size,
                           shuffle=True, num_workers=args.num_workers)

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    model = GCN(data.x.size(-1), args.hidden_channels, args.hidden_channels,
                args.num_layers, args.dropout).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()),
            lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, loader, optimizer, device)
            print(f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}')

            if epoch > 49 and epoch % args.eval_steps == 0:
                result = test(model, predictor, data, split_edge, evaluator,
                              batch_size=64 * 1024, device=device)
                logger.add_result(run, result)

                train_mrr, valid_mrr, test_mrr = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {train_mrr:.4f}, '
                      f'Valid: {valid_mrr:.4f}, '
                      f'Test: {test_mrr:.4f}')

        logger.print_statistics(run)
    logger.print_statistics()
Example #24
def main():
    parser = argparse.ArgumentParser(description='OGBL-Citation2 (MLP)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_node_embedding', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--eval_steps', type=int, default=10)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation2')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    x = data.x
    if args.use_node_embedding:
        embedding = torch.load('embedding.pt', map_location='cpu')
        x = torch.cat([x, embedding], dim=-1)
    x = x.to(device)

    predictor = LinkPredictor(x.size(-1), args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation2')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(predictor.parameters(), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(predictor, x, split_edge, optimizer, args.batch_size)
            print(f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}')

            if epoch % args.eval_steps == 0:
                result = test(predictor, x, split_edge, evaluator,
                              args.batch_size)
                logger.add_result(run, result)

                if epoch % args.log_steps == 0:
                    train_mrr, valid_mrr, test_mrr = result
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Train: {train_mrr:.4f}, '
                          f'Valid: {valid_mrr:.4f}, '
                          f'Test: {test_mrr:.4f}')

        print('Node2vec' if args.use_node_embedding else 'MLP')
        logger.print_statistics(run)
    print('Node2vec' if args.use_node_embedding else 'MLP')
    logger.print_statistics()
Example #25
File: mlp.py Project: rryoung98/ogb
def main():
    parser = argparse.ArgumentParser(description="OGBL-COLLAB (MLP)")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--log_steps", type=int, default=1)
    parser.add_argument("--use_node_embedding", action="store_true")
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_channels", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--batch_size", type=int, default=64 * 1024)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--eval_steps", type=int, default=1)
    parser.add_argument("--runs", type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name="ogbl-collab")
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    x = data.x
    if args.use_node_embedding:
        embedding = torch.load("embedding.pt", map_location="cpu")
        x = torch.cat([x, embedding], dim=-1)
    x = x.to(device)

    predictor = LinkPredictor(x.size(-1), args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name="ogbl-collab")
    loggers = {
        "Hits@10": Logger(args.runs, args),
        "Hits@50": Logger(args.runs, args),
        "Hits@100": Logger(args.runs, args),
    }

    for run in range(args.runs):
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(predictor.parameters(), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(predictor, x, split_edge, optimizer, args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(predictor, x, split_edge, evaluator,
                               args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f"Run: {run + 1:02d}, "
                              f"Epoch: {epoch:02d}, "
                              f"Loss: {loss:.4f}, "
                              f"Train: {100 * train_hits:.2f}%, "
                              f"Valid: {100 * valid_hits:.2f}%, "
                              f"Test: {100 * test_hits:.2f}%")
                    print("---")

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
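
The Hits@K loggers above rely on the ogb Evaluator. A minimal sketch of how Hits@50 is computed for ogbl-collab (illustrative tensors; a positive edge counts as a hit if it outranks the 50th-best negative):

import torch
from ogb.linkproppred import Evaluator

evaluator = Evaluator(name='ogbl-collab')
evaluator.K = 50
pos_pred = torch.rand(100)    # scores for positive edges
neg_pred = torch.rand(5000)   # scores for shared negative edges
hits = evaluator.eval({'y_pred_pos': pos_pred, 'y_pred_neg': neg_pred})['hits@50']
print(f'Hits@50: {100 * hits:.2f}%')
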
Example No. 26
def main():
    parser = argparse.ArgumentParser(description='OGBL-DDI (Full-Batch)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ddi')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['edge'].size(0))
    idx = idx[:split_edge['valid']['edge'].size(0)]
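    # Sizing the subset to the validation split keeps train and validation
    # Hits@K comparable (they are computed over equally many positive edges).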
    split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

    edge_index = data.edge_index
    # Pass sparse_sizes explicitly so the adjacency stays (N, N) even if the
    # highest-index node happens to be isolated.
    adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                       sparse_sizes=(data.num_nodes, data.num_nodes)).to(device)

    if args.use_sage:
        model = SAGE(args.hidden_channels, args.hidden_channels,
                     args.hidden_channels, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(args.hidden_channels, args.hidden_channels,
                    args.hidden_channels, args.num_layers,
                    args.dropout).to(device)

        # Pre-compute GCN normalization.
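        # Symmetric GCN normalization: A_hat = D^{-1/2} (A + I) D^{-1/2}.
        # set_diag() adds the self-loops (A + I); the inf guard below zeroes
        # the inverse-sqrt degree of any zero-degree node.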
        adj = adj.set_diag()
        deg = adj.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    # ogbl-ddi provides no input node features, so a learnable per-node
    # embedding serves as the model input (trained jointly with the GNN).
    emb = torch.nn.Embedding(data.num_nodes, args.hidden_channels).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ddi')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@20': Logger(args.runs, args),
        'Hits@30': Logger(args.runs, args),
    }

    for run in range(args.runs):
        torch.nn.init.xavier_uniform_(emb.weight)
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(emb.parameters()) +
                                     list(predictor.parameters()),
                                     lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, emb.weight, adj, data.edge_index,
                         split_edge, optimizer, args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, emb.weight, adj, split_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
Example No. 27
def main():
    parser = argparse.ArgumentParser(description='OGBL-DDI (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--k', type=int, default=50)
    parser.add_argument('--gpu_id', type=int, default=0)
    args = parser.parse_args()
    print(args)
    device = gpu_setup(args.gpu_id)

    dataset = PygLinkPropPredDataset(name='ogbl-ddi',
                                     transform=T.ToSparseTensor())
    data = dataset[0]
    adj_t = data.adj_t.to(device)

    split_edge = dataset.get_edge_split()

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['edge'].size(0))
    idx = idx[:split_edge['valid']['edge'].size(0)]
    split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

    model = GCNWithAttention(args.hidden_channels, args.hidden_channels,
                             args.hidden_channels, args.num_layers,
                             args.dropout, args.k).to(device)

    emb = torch.nn.Embedding(data.num_nodes, args.hidden_channels).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)
    print("model parameters {}".format(sum(p.numel() for p in model.parameters())))
    print("predictor parameters {}".format(sum(p.numel() for p in predictor.parameters())))
    print("total parameters {}".format(data.num_nodes*args.hidden_channels + 
    sum(p.numel() for p in model.parameters())+sum(p.numel() for p in predictor.parameters())))
    evaluator = Evaluator(name='ogbl-ddi')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@20': Logger(args.runs, args),
        'Hits@30': Logger(args.runs, args),
    }

    for run in range(args.runs):
        torch.nn.init.xavier_uniform_(emb.weight)
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(emb.parameters()) +
            list(predictor.parameters()), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, emb.weight, adj_t, split_edge,
                         optimizer, args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, emb.weight, adj_t, split_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
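
The train() helper these DDI examples call is not shown. A minimal sketch following the usual OGB pattern, assuming mini-batches of positive edges with uniformly sampled random negatives (names are illustrative; the signature matches the call in this example, and the gradient clipping used by some variants is omitted):

import torch
from torch.utils.data import DataLoader

def train(model, predictor, x, adj_t, split_edge, optimizer, batch_size):
    model.train()
    predictor.train()
    pos_train_edge = split_edge['train']['edge'].to(x.device)
    total_loss = total_examples = 0
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, shuffle=True):
        optimizer.zero_grad()
        h = model(x, adj_t)                       # node embeddings
        edge = pos_train_edge[perm].t()
        pos_out = predictor(h[edge[0]], h[edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()
        # Uniform negative sampling: random node pairs serve as negatives.
        edge = torch.randint(0, x.size(0), edge.size(), dtype=torch.long,
                             device=x.device)
        neg_out = predictor(h[edge[0]], h[edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()
        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * pos_out.size(0)
        total_examples += pos_out.size(0)
    return total_loss / total_examples
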
Example No. 28
def main():
    parser = argparse.ArgumentParser(description='OGBL-Citation (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    if args.use_sage:
        model = SAGE(x.size(-1), args.hidden_channels, args.hidden_channels,
                     args.num_layers, args.dropout).to(device)
    else:
        model = GCN(x.size(-1), args.hidden_channels, args.hidden_channels,
                    args.num_layers, args.dropout).to(device)

        # Pre-compute GCN normalization.
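        # set_value(None) drops stored edge weights so the products below act
        # on a binary adjacency; the rest computes A_hat = D^{-1/2} (A + I) D^{-1/2}.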
        adj = adj.set_value(None)
        adj = adj.set_diag()
        deg = adj.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(predictor.parameters()),
                                     lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, x, adj, split_edge, optimizer,
                         args.batch_size)
            print(f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}')

            if epoch % args.eval_steps == 0:
                result = test(model, predictor, x, adj, split_edge, evaluator,
                              args.batch_size)
                logger.add_result(run, result)

                if epoch % args.log_steps == 0:
                    train_mrr, valid_mrr, test_mrr = result
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Train: {train_mrr:.4f}, '
                          f'Valid: {valid_mrr:.4f}, '
                          f'Test: {test_mrr:.4f}')

        logger.print_statistics(run)

    logger.print_statistics()
Example No. 29
def main_fixed_mask(args, imp_num, resume_train_ckpt=None):

    device = torch.device("cuda:" + str(args.device))
    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]
    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)

    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args)
    predictor = LinkPredictor(args).to(device)
    
    rewind_weight_mask, adj_spar, wei_spar = pruning.resume_change(resume_train_ckpt, model, args)
    model.load_state_dict(rewind_weight_mask)
    predictor.load_state_dict(resume_train_ckpt['predictor_state_dict'])

    adj_spar, wei_spar = pruning.print_sparsity(model, args)

    # Freeze the pruning masks so only the surviving (unmasked) weights are
    # updated during this fixed-mask fine-tuning phase.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=args.lr)
    results = {'epoch': 0}
    keys = ['highest_valid', 'final_train', 'final_test', 'highest_train']
    hits = ['Hits@10', 'Hits@50', 'Hits@100']
    
    for key in keys:
        results[key] = {k: 0 for k in hits}
    results['adj_spar'] = adj_spar
    results['wei_spar'] = wei_spar
    
    start_epoch = 1
    
    for epoch in range(start_epoch, args.fix_epochs + 1):

        t0 = time.time()
        epoch_loss = train.train_fixed(model, predictor, x, edge_index, split_edge, optimizer, args.batch_size, args)
        result = train.test(model, predictor, x, edge_index, split_edge, evaluator, args.batch_size, args)
        # train.test returns a (train, valid, test) tuple for each Hits@K
        # metric; model selection below is based on Hits@50.
        k = 'Hits@50'
        train_result, valid_result, test_result = result[k]

        if train_result > results['highest_train'][k]:
            results['highest_train'][k] = train_result

        if valid_result > results['highest_valid'][k]:
            results['highest_valid'][k] = valid_result
            results['final_train'][k] = train_result
            results['final_test'][k] = test_result
            results['epoch'] = epoch
            pruning.save_all(model, predictor,
                             rewind_weight_mask,
                             optimizer,
                             imp_num,
                             epoch,
                             args.model_save_path,
                             'IMP{}_fixed_ckpt'.format(imp_num))

        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (FIX Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Train:[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] Time:[{:.2f}min]'
              .format(imp_num, epoch, args.fix_epochs, epoch_loss,
                      train_result * 100, valid_result * 100, test_result * 100,
                      results['final_test'][k] * 100, results['epoch'],
                      epoch_time))
    print("=" * 120)
    print("syd final: IMP:[{}], Train:[{:.2f}]  Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] Adj:[{:.2f}%] Wei:[{:.2f}%]"
        .format(imp_num,    results['final_train'][k] * 100,
                            results['highest_valid'][k] * 100,
                            results['epoch'],
                            results['final_test'][k] * 100,
                            results['adj_spar'],
                            results['wei_spar']))
    print("=" * 120)
Example No. 30
def main():
    parser = argparse.ArgumentParser(description='OGBL-DDI')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--model',
                        type=str,
                        default='MAD_GCN',
                        choices=[
                            'GCN_Linear', 'SAGE_Linear', 'MAD_GCN', 'MAD_SAGE',
                            'MAD_Model'
                        ])
    parser.add_argument('--train_batch_size', type=int, default=4096)
    parser.add_argument('--test_batch_size', type=int, default=1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=5)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ddi',
                                     transform=T.ToSparseTensor())
    data = dataset[0]
    adj_t = data.adj_t.to(device)

    split_edge = dataset.get_edge_split()

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['edge'].size(0))
    idx = idx[:split_edge['valid']['edge'].size(0)]
    split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

    model = models.get_model(args.model)(data.num_nodes, adj_t).to(device)
    print(f"Parameters: {count_parameters(model)}")

    evaluator = Evaluator(name='ogbl-ddi')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@20': Logger(args.runs, args),
        'Hits@30': Logger(args.runs, args),
    }

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, adj_t, split_edge, optimizer,
                         args.train_batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, adj_t, split_edge, evaluator,
                               args.test_batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')
            print(f'Finished epoch {epoch}')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
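
The count_parameters helper called in this example is assumed rather than defined here; a standard sketch:

def count_parameters(model):
    # Total number of trainable parameters in a torch.nn.Module.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)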