Example #1
    def prepare_gnn_training(self):
        if verbose:
            print("\n\n==>> Clustering the graph and preparing dataloader....")
            
        self.data = Data(x=self.x_data.float(),
                         edge_index=self.edge_index_data.long(),
                         edge_attr=self.edge_type_data,
                         y=self.y_data)
        new_num_nodes, _ = self.data.x.shape
        
        self.data.train_mask = torch.FloatTensor(self.split_masks['train_mask'])
        self.data.val_mask = torch.FloatTensor(self.split_masks['val_mask'])
        self.data.representation_mask = torch.FloatTensor(self.split_masks['repr_mask']) 
        self.data.node2id = torch.tensor(list(self.node2id.values()))
        # self.data.node_type = self.node_type
            
        
        if not self.config['full_graph']:
            if self.config['cluster']:
                cluster_data = ClusterData(self.data, num_parts=self.config['clusters'], recursive=False)
                self.loader = ClusterLoader(cluster_data, batch_size=self.config['batch_size'], shuffle=self.config['shuffle'], num_workers=0)
            elif self.config['saint'] == 'random_walk':
                self.loader = GraphSAINTRandomWalkSampler(self.data, batch_size=6000, walk_length=2, num_steps=5, sample_coverage=100, num_workers=0)
            elif self.config['saint'] == 'node':
                self.loader = GraphSAINTNodeSampler(self.data, batch_size=6000, num_steps=5, sample_coverage=100, num_workers=0)
            elif self.config['saint'] == 'edge':
                self.loader = GraphSAINTEdgeSampler(self.data, batch_size=6000, num_steps=5, sample_coverage=100, num_workers=0)
        else:
            self.loader = None
        

        return self.loader, self.vocab_size, self.data

    def __init__(self, graph, model, args, criterion=None):

        self.args = args
        self.graph = graph
        self.model = model.to(self.args.device)

        # Build the data loader on CPU
        self.loader = GraphSAINTRandomWalkSampler(graph,
                                                  batch_size=self.args.GraphSAINT['batch_size'],
                                                  walk_length=self.args.GraphSAINT['walk_length'],
                                                  num_steps=self.args.GraphSAINT['num_steps'])
        print('Data loader created for learner...')

        # Set loss function
        if not criterion:
            # BCEWithLogitsLoss supports both single and multi-label binary prediction tasks
            criterion = nn.BCEWithLogitsLoss()
        self.criterion = criterion

        # Set optimizer
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.w_decay)

        self.train_loss = []
        self.val_loss = []
        self.train_complete = False 
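
A minimal sketch (not from either source repo) of how a learner like the one above would consume its GraphSAINT loader; it assumes the masks on each sampled subgraph are boolean, and it reuses the `model`, `criterion`, `optim`, and `train_loss` attributes created in `__init__`:

    def train_epoch(self):
        # Iterate once over the sampled subgraphs and optimize on each.
        self.model.train()
        total_loss = 0.0
        for batch in self.loader:
            batch = batch.to(self.args.device)
            self.optim.zero_grad()
            out = self.model(batch.x, batch.edge_index)
            # BCEWithLogitsLoss expects float targets.
            loss = self.criterion(out[batch.train_mask],
                                  batch.y[batch.train_mask].float())
            loss.backward()
            self.optim.step()
            total_loss += loss.item()
        self.train_loss.append(total_loss / len(self.loader))
        return self.train_loss[-1]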
Example #3
def build_sampler(args, data, save_dir):
    if args.sampler == 'rw-my':
        msg = 'Use GraphSaint randomwalk sampler(mysaint sampler)'
        loader = MySAINTSampler(data, batch_size=args.batch_size, sample_type='random_walk',
                                walk_length=2, sample_coverage=1000, save_dir=save_dir)
    elif args.sampler == 'node-my':
        msg = 'Use random node sampler(mysaint sampler)'
        loader = MySAINTSampler(data, sample_type='node', batch_size=args.batch_size * 3,
                                walk_length=2, sample_coverage=1000, save_dir=save_dir)
    elif args.sampler == 'rw':
        msg = 'Use GraphSaint randomwalk sampler'
        loader = GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size, walk_length=2,
                                             num_steps=5, sample_coverage=1000,
                                             save_dir=save_dir)
    elif args.sampler == 'node':
        msg = 'Use GraphSaint node sampler'
        loader = GraphSAINTNodeSampler(data, batch_size=args.batch_size * 3,
                                       num_steps=5, sample_coverage=1000, num_workers=0, save_dir=save_dir)

    elif args.sampler == 'edge':
        msg = 'Use GraphSaint edge sampler'
        loader = GraphSAINTEdgeSampler(data, batch_size=args.batch_size,
                                       num_steps=5, sample_coverage=1000,
                                       save_dir=save_dir, num_workers=0)
    elif args.sampler == 'cluster':
        msg = 'Use cluster sampler'
        cluster_data = ClusterData(data, num_parts=args.num_parts, save_dir=save_dir)
        loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
                               num_workers=0)
    else:
        raise KeyError('Sampler type error')

    return loader, msg
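
A hedged usage sketch for `build_sampler` (its call site is not shown in the source); `args` and `data` are assumed to be set up as in the surrounding examples, and the cache path is hypothetical:

loader, msg = build_sampler(args, data, save_dir='./saint_cache')
print(msg)
for batch in loader:
    # With sample_coverage > 0, each subgraph carries node_norm and edge_norm
    # tensors that debias the minibatch loss and message aggregation.
    pass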
def test_graph_saint():
    adj = torch.tensor([
        [1, 1, 1, 0, 1, 0],
        [1, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
    ])

    edge_index = adj.nonzero(as_tuple=False).t()
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=x, num_nodes=6)

    torch.manual_seed(12345)
    loader = GraphSAINTNodeSampler(data,
                                   batch_size=2,
                                   num_steps=4,
                                   sample_coverage=10,
                                   log=False)

    for sample in loader:
        assert len(sample) == 4
        assert sample.num_nodes <= 2
        assert sample.num_edges <= 3 * 2
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTEdgeSampler(data,
                                   batch_size=2,
                                   num_steps=4,
                                   sample_coverage=10,
                                   log=False)

    for sample in loader:
        assert len(sample) == 4
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 3 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=2,
                                         walk_length=1,
                                         num_steps=4,
                                         sample_coverage=10,
                                         log=False)

    for sample in loader:
        assert len(sample) == 4
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 3 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges
Example #5
    def sample_subgraph(self, whole_data):
        data = whole_data.data
        loader = GraphSAINTRandomWalkSampler(
            data,
            batch_size=self.sample_batch_size,
            walk_length=self.sample_walk_length,
            num_steps=self.subgraphs,
            save_dir=whole_data.processed_dir,
        )
        # Collect the subgraphs produced by the random-walk sampler.
        results = []
        for batch in loader:
            results.append(batch)
        return results
Example #6
def build_sampler(args, data, save_dir):
    if args.sampler == 'rw-my':
        msg = 'Use GraphSaint randomwalk sampler(mysaint sampler)'
        loader = MySAINTSampler(data,
                                batch_size=args.batch_size,
                                sample_type='random_walk',
                                walk_length=2,
                                sample_coverage=1000,
                                save_dir=save_dir)
    elif args.sampler == 'node-my':
        msg = 'Use random node sampler(mysaint sampler)'
        loader = MySAINTSampler(data,
                                sample_type='node',
                                batch_size=args.batch_size * 3,
                                walk_length=2,
                                sample_coverage=1000,
                                save_dir=save_dir)
    elif args.sampler == 'rw':
        msg = 'Use GraphSaint randomwalk sampler'
        loader = GraphSAINTRandomWalkSampler(data,
                                             batch_size=args.batch_size,
                                             walk_length=2,
                                             num_steps=5,
                                             sample_coverage=1000,
                                             save_dir=save_dir)
    elif args.sampler == 'node':
        msg = 'Use GraphSaint node sampler'
        loader = GraphSAINTNodeSampler(data,
                                       batch_size=args.batch_size * 3,
                                       num_steps=5,
                                       sample_coverage=1000,
                                       num_workers=0,
                                       save_dir=save_dir)

    elif args.sampler == 'edge':
        msg = 'Use GraphSaint edge sampler'
        loader = GraphSAINTEdgeSampler(data,
                                       batch_size=args.batch_size,
                                       num_steps=5,
                                       sample_coverage=1000,
                                       save_dir=save_dir,
                                       num_workers=0)
    # elif args.sampler == 'cluster':
    #     logger.info('Use cluster sampler')
    #     cluster_data = ClusterData(data, num_parts=args.num_parts, save_dir=dataset.processed_dir)
    #     raise NotImplementedError('Cluster loader not implement yet')
    else:
        raise KeyError('Sampler type error')

    return loader, msg
Example #7
homo_data = Data(edge_index=edge_index,
                 edge_attr=edge_type,
                 node_type=node_type,
                 local_node_idx=local_node_idx,
                 num_nodes=node_type.size(0))

homo_data.y = node_type.new_full((node_type.size(0), 1), -1)
homo_data.y[local2global['paper']] = data.y_dict['paper']

homo_data.train_mask = torch.zeros((node_type.size(0)), dtype=torch.bool)
homo_data.train_mask[local2global['paper'][split_idx['train']['paper']]] = True

print(homo_data)

train_loader = GraphSAINTRandomWalkSampler(homo_data,
                                           batch_size=args.batch_size,
                                           walk_length=args.num_layers,
                                           num_steps=args.num_steps,
                                           sample_coverage=0,
                                           save_dir=dataset.processed_dir)

# Map information to its canonical type.
x_dict = {}
for key, x in data.x_dict.items():
    x_dict[key2int[key]] = x

num_nodes_dict = {}
for key, N in data.num_nodes_dict.items():
    num_nodes_dict[key2int[key]] = N


class RGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels, num_node_types,
Example #8
def main():
    parser = argparse.ArgumentParser(description='OGBN-Products (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--inductive', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=20000)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_steps', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--eval_steps', type=int, default=2)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(
        name='ogbn-products', root='/srv/scratch/ogb/datasets/nodeproppred')
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Convert split indices to boolean masks and add them to `data`.
    for key, idx in split_idx.items():
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[idx] = True
        data[f'{key}_mask'] = mask

    # We omit normalization factors here since those are only defined for the
    # inductive learning setup.
    sampler_data = data
    if args.inductive:
        sampler_data = to_inductive(data)

    loader = GraphSAINTRandomWalkSampler(sampler_data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    model = SAGE(data.x.size(-1), args.hidden_channels, dataset.num_classes,
                 args.num_layers, args.dropout).to(device)

    subgraph_loader = NeighborSampler(data.edge_index,
                                      sizes=[-1],
                                      batch_size=4096,
                                      shuffle=False,
                                      num_workers=12)

    evaluator = Evaluator(name='ogbn-products')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, loader, optimizer, device)
            if epoch % args.log_steps == 0:
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}')

            if epoch > 9 and epoch % args.eval_steps == 0:
                result = test(model, data, evaluator, subgraph_loader, device)
                logger.add_result(run, result)
                train_acc, valid_acc, test_acc = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')

        logger.add_result(run, result)
        logger.print_statistics(run)
    logger.print_statistics()
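
The `train` helper called above is not shown in this snippet; the following is a sketch of the usual GraphSAINT training step for OGBN-Products, assuming `SAGE` returns log-probabilities and `torch.nn.functional` is imported as `F`:

def train(model, loader, optimizer, device):
    model.train()
    total_loss = total_examples = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        # OGB stores labels with shape [num_nodes, 1].
        y = data.y.squeeze(1)[data.train_mask]
        loss = F.nll_loss(out[data.train_mask], y)
        loss.backward()
        optimizer.step()
        num_examples = int(data.train_mask.sum())
        total_loss += loss.item() * num_examples
        total_examples += num_examples
    return total_loss / total_examples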
Example #9
def main():
    parser = argparse.ArgumentParser(description='OGBL-Citation2 (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=16 * 1024)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--num_steps', type=int, default=100)
    parser.add_argument('--eval_steps', type=int, default=10)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation2')
    split_edge = dataset.get_edge_split()
    data = dataset[0]
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    model = GCN(data.x.size(-1), args.hidden_channels, args.hidden_channels,
                args.num_layers, args.dropout).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation2')
    logger = Logger(args.runs, args)

    run_idx = 0

    while run_idx < args.runs:
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(predictor.parameters()),
                                     lr=args.lr)

        run_success = True
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, loader, optimizer, device)
            print(
                f'Run: {run_idx + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}'
            )
            if loss > 2.:
                run_success = False
                logger.reset(run_idx)
                print('Learning failed. Rerun...')
                break

            if epoch > 49 and epoch % args.eval_steps == 0:
                result = test(model,
                              predictor,
                              data,
                              split_edge,
                              evaluator,
                              batch_size=64 * 1024,
                              device=device)
                logger.add_result(run_idx, result)

                train_mrr, valid_mrr, test_mrr = result
                print(f'Run: {run_idx + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {train_mrr:.4f}, '
                      f'Valid: {valid_mrr:.4f}, '
                      f'Test: {test_mrr:.4f}')

        print('GraphSAINT')
        if run_success:
            logger.print_statistics(run_idx)
            run_idx += 1

    print('GraphSAINT')
    logger.print_statistics()
Example #10
def train(epoch, model, optimizer):

    global all_data, best_val_acc, best_embeddings, best_model, curr_hyperparameters, best_hyperparameters

    # Save predictions
    total_loss = 0
    roc_val = []
    ap_val = []
    f1_val = []
    acc_val = []

    # Minibatches
    if config.MINIBATCH == "NeighborSampler":
        loader = NeighborSampler(all_data.edge_index,
                                 sizes=[curr_hyperparameters['nb_size']],
                                 batch_size=curr_hyperparameters['batch_size'],
                                 shuffle=True)
    elif config.MINIBATCH == "GraphSaint":
        all_data.num_classes = torch.tensor([2])
        loader = GraphSAINTRandomWalkSampler(
            all_data,
            batch_size=curr_hyperparameters['batch_size'],
            walk_length=curr_hyperparameters['walk_length'],
            num_steps=curr_hyperparameters['num_steps'])

    # Iterate through minibatches
    for data in loader:

        if config.MINIBATCH == "NeighborSampler":
            data = preprocess.set_data(data, all_data, config.MINIBATCH)
        curr_train_pos = data.edge_index[:, data.train_mask]
        curr_train_neg = negative_sampling(
            curr_train_pos, num_neg_samples=curr_train_pos.size(1) // 4)
        curr_train_total = torch.cat([curr_train_pos, curr_train_neg], dim=-1)
        data.y = torch.zeros(curr_train_total.size(1)).float()
        data.y[:curr_train_pos.size(1)] = 1.

        # Perform training
        data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        curr_dot_embed = utils.el_dot(out, curr_train_total)
        loss = utils.calc_loss_both(data, curr_dot_embed)
        if not torch.isnan(loss):
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        curr_train_pos_mask = torch.zeros(curr_train_total.size(1)).bool()
        curr_train_pos_mask[:curr_train_pos.size(1)] = 1
        curr_train_neg_mask = (curr_train_pos_mask == 0)
        roc_score, ap_score, train_acc, train_f1 = utils.calc_roc_score(
            pred_all=curr_dot_embed.T[1],
            pos_edges=curr_train_pos_mask,
            neg_edges=curr_train_neg_mask)
        print(">>>>>>Train: (ROC) ", roc_score, " (AP) ", ap_score, " (ACC) ",
              train_acc, " (F1) ", train_f1)

        curr_val_pos = data.edge_index[:, data.val_mask]
        curr_val_neg = negative_sampling(
            curr_val_pos, num_neg_samples=curr_val_pos.size(1) // 4)
        curr_val_total = torch.cat([curr_val_pos, curr_val_neg], dim=-1)
        curr_val_pos_mask = torch.zeros(curr_val_total.size(1)).bool()
        curr_val_pos_mask[:curr_val_pos.size(1)] = 1
        curr_val_neg_mask = (curr_val_pos_mask == 0)
        val_dot_embed = utils.el_dot(out, curr_val_total)
        data.y = torch.zeros(curr_val_total.size(1)).float()
        data.y[:curr_val_pos.size(1)] = 1.
        roc_score, ap_score, val_acc, val_f1 = utils.calc_roc_score(
            pred_all=val_dot_embed.T[1],
            pos_edges=curr_val_pos_mask,
            neg_edges=curr_val_neg_mask)
        roc_val.append(roc_score)
        ap_val.append(ap_score)
        acc_val.append(val_acc)
        f1_val.append(val_f1)
    res = "\t".join([
        "Epoch: %04d" % (epoch + 1), "train_loss = {:.5f}".format(total_loss),
        "val_roc = {:.5f}".format(np.mean(roc_val)),
        "val_ap = {:.5f}".format(np.mean(ap_val)),
        "val_f1 = {:.5f}".format(np.mean(f1_val)),
        "val_acc = {:.5f}".format(np.mean(acc_val))
    ])
    print(res)
    log_f.write(res + "\n")

    # Save best model and parameters
    if best_val_acc <= np.mean(acc_val) + eps:
        best_val_acc = np.mean(acc_val)
        with open(str(config.DATASET_DIR / "best_model.pth"), 'wb') as f:
            torch.save(model.state_dict(), f)
        best_hyperparameters = curr_hyperparameters
        best_model = model

    return total_loss
Example #11
def main():
    parser = argparse.ArgumentParser(description='OGBN-Products (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=5)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=8000)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--sample_coverage', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_steps', type=int, default=20)
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--eval_steps', type=int, default=25)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(name='ogbn-products')
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Convert split indices to boolean masks and add them to `data`.
    for key, idx in split_idx.items():
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[idx] = True
        data[f'{key}_mask'] = mask

    # Create "inductive" subgraph containing only train and validation nodes.
    ind_data = to_inductive(copy.copy(data))
    row, col = ind_data.edge_index
    ind_data.edge_attr = 1. / degree(col, ind_data.num_nodes)[col]

    loader = GraphSAINTRandomWalkSampler(ind_data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=args.sample_coverage,
                                         save_dir=dataset.processed_dir,
                                         num_workers=args.num_workers)

    model = SAGE(ind_data.x.size(-1), args.hidden_channels,
                 dataset.num_classes, args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbn-products')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, loader, optimizer, device)
            if epoch % args.log_steps == 0:
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}')

            if epoch % args.eval_steps == 0:
                result = test(model, data, evaluator)
                logger.add_result(run, result)
                train_acc, valid_acc, test_acc = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')

        logger.add_result(run, result)
        logger.print_statistics(run)
    logger.print_statistics()
Example #12
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Flickr
from torch_geometric.data import GraphSAINTRandomWalkSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import degree

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr')
dataset = Flickr(path)
data = dataset[0]
row, col = data.edge_index
data.edge_attr = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.

loader = GraphSAINTRandomWalkSampler(data, batch_size=6000, walk_length=2,
                                     num_steps=5, sample_coverage=1000,
                                     save_dir=dataset.processed_dir,
                                     num_workers=4)


class Net(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(Net, self).__init__()
        in_channels = dataset.num_node_features
        out_channels = dataset.num_classes
        self.conv1 = SAGEConv(in_channels, hidden_channels, concat=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, concat=True)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels, concat=True)
        self.lin = torch.nn.Linear(3 * hidden_channels, out_channels)

    def set_aggr(self, aggr):
        self.conv1.aggr = aggr
        self.conv2.aggr = aggr
        self.conv3.aggr = aggr
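
For completeness, a sketch of the train step that typically accompanies this Flickr example: with normalization enabled, GraphSAINT weights messages by `edge_norm * edge_attr` and the per-node loss by `node_norm`. It assumes `Net.forward` accepts an optional `edge_weight` and that `model`, `optimizer`, and `device` are set up in the usual way:

def train(use_normalization=True):
    model.train()
    # 'add' aggregation is required when edges carry explicit weights.
    model.set_aggr('add' if use_normalization else 'mean')
    total_loss = total_examples = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        if use_normalization:
            edge_weight = data.edge_norm * data.edge_attr
            out = model(data.x, data.edge_index, edge_weight)
            loss = F.nll_loss(out, data.y, reduction='none')
            loss = (loss * data.node_norm)[data.train_mask].sum()
        else:
            out = model(data.x, data.edge_index)
            loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_nodes
        total_examples += data.num_nodes
    return total_loss / total_examples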
Example #13
def main():
    parser = argparse.ArgumentParser(description='OGBN-Products (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--inductive', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=20000)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_steps', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--runs', type=int, default=10)

    parser.add_argument('--step-size', type=float, default=8e-3)
    parser.add_argument('-m', type=int, default=3)
    parser.add_argument('--test-freq', type=int, default=2)
    parser.add_argument('--attack', type=str, default='flag')
    parser.add_argument('--amp', type=float, default=2)

    args = parser.parse_args()
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(name='ogbn-products')
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Convert split indices to boolean masks and add them to `data`.
    for key, idx in split_idx.items():
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[idx] = True
        data[f'{key}_mask'] = mask

    # We omit normalization factors here since those are only defined for the
    # inductive learning setup.
    sampler_data = data
    if args.inductive:
        sampler_data = to_inductive(data)

    loader = GraphSAINTRandomWalkSampler(sampler_data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    model = SAGE(data.x.size(-1), args.hidden_channels, dataset.num_classes,
                 args.num_layers, args.dropout).to(device)

    subgraph_loader = NeighborSampler(data.edge_index,
                                      sizes=[-1],
                                      batch_size=4096,
                                      shuffle=False,
                                      num_workers=12)

    evaluator = Evaluator(name='ogbn-products')

    vals, tests = [], []
    for run in range(args.runs):
        best_val, final_test = 0, 0

        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        for epoch in range(1, args.epochs + 1):
            loss = train_flag(model, loader, optimizer, device, args)
            if (epoch > args.epochs / 2 and epoch % args.test_freq == 0) or epoch == args.epochs:
                result = test(model, data, evaluator, subgraph_loader, device)
                train, val, tst = result
                if val > best_val:
                    best_val = val
                    final_test = tst

        print(f'Run{run} val:{best_val}, test:{final_test}')
        vals.append(best_val)
        tests.append(final_test)

    print('')
    print(f"Average val accuracy: {np.mean(vals)} ± {np.std(vals)}")
    print(f"Average test accuracy: {np.mean(tests)} ± {np.std(tests)}")
Example #14
def test_graph_saint():
    adj = torch.tensor([
        [+1, +2, +3, +0, +4, +0],
        [+5, +6, +0, +7, +0, +8],
        [+9, +0, 10, +0, 11, +0],
        [+0, 12, +0, 13, +0, 14],
        [15, +0, 16, +0, 17, +0],
        [+0, 18, +0, 19, +0, 20],
    ])

    edge_index = adj.nonzero(as_tuple=False).t()
    edge_type = adj[edge_index[0], edge_index[1]]
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=x, edge_type=edge_type, num_nodes=6)

    torch.manual_seed(12345)
    loader = GraphSAINTNodeSampler(data, batch_size=3, num_steps=4,
                                   sample_coverage=10, log=False)

    sample = next(iter(loader))
    assert sample.x.tolist() == [[2, 2], [4, 4], [5, 5]]
    assert sample.edge_index.tolist() == [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]]
    assert sample.edge_type.tolist() == [10, 11, 16, 17, 20]

    assert len(loader) == 4
    for sample in loader:
        assert len(sample) == 5
        assert sample.num_nodes <= 3
        assert sample.num_edges <= 3 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,
                                   sample_coverage=10, log=False)

    sample = next(iter(loader))
    assert sample.x.tolist() == [[0, 0], [2, 2], [3, 3]]
    assert sample.edge_index.tolist() == [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]]
    assert sample.edge_type.tolist() == [1, 3, 9, 10, 13]

    assert len(loader) == 4
    for sample in loader:
        assert len(sample) == 5
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 4 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
                                         num_steps=4, sample_coverage=10,
                                         log=False)

    sample = next(iter(loader))
    assert sample.x.tolist() == [[1, 1], [2, 2], [4, 4]]
    assert sample.edge_index.tolist() == [[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]]
    assert sample.edge_type.tolist() == [6, 10, 11, 16, 17]

    assert len(loader) == 4
    for sample in loader:
        assert len(sample) == 5
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 4 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges
Example #15
splitted_idx = dataset.get_idx_split()
data = dataset[0]
data.n_id = torch.arange(data.num_nodes)
data.node_species = None
data.y = data.y.float()
# Initialize features of nodes by aggregating edge features.
row, col = data.edge_index
# Set split indices to masks.
for split in ['train', 'valid', 'test']:
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

train_loader = GraphSAINTRandomWalkSampler(data,
                                           batch_size=args.batch_size,
                                           walk_length=args.walk_length,
                                           num_steps=args.num_steps,
                                           sample_coverage=0,
                                           save_dir=dataset.processed_dir)
test_loader = GraphSAINTRandomWalkSampler(data,
                                          batch_size=args.batch_size,
                                          walk_length=args.walk_length,
                                          num_steps=args.num_steps,
                                          sample_coverage=0,
                                          save_dir=dataset.processed_dir)

p_train_loader = GraphSAINTRandomWalkSampler(data,
                                             batch_size=args.batch_size * 2,
                                             walk_length=args.walk_length,
                                             num_steps=args.num_steps,
                                             sample_coverage=0,
                                             save_dir=dataset.processed_dir)
Example #16
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default='0')
    parser.add_argument('--model', type=str, default='GraphSAINT')
    parser.add_argument('--dataset', type=str,
                        default='Reddit')  # Reddit or Flickr
    parser.add_argument('--batch', type=int,
                        default=2000)  # Reddit:2000, Flickr:6000
    parser.add_argument('--walk_length', type=int,
                        default=4)  # Reddit:4, Flickr:2
    parser.add_argument('--sample_coverage', type=int,
                        default=50)  # Reddit:50, Flickr:100
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=100)  # 100, 50
    parser.add_argument('--lr', type=float, default=0.01)  # 0.01, 0.001
    parser.add_argument('--weight_decay', type=float, default=0.0005)
    parser.add_argument('--hidden', type=int, default=256)  # 128, 256
    parser.add_argument('--dropout', type=float, default=0.1)  # 0.1, 0.2
    parser.add_argument('--use_normalization', action='store_true')
    parser.add_argument('--binarize', action='store_true')
    args = parser.parse_args()

    assert args.model in ['GraphSAINT']
    assert args.dataset in ['Flickr', 'Reddit']
    path = '/home/wangjunfu/dataset/graph/' + str(args.dataset)
    if args.dataset == 'Flickr':
        dataset = Flickr(path)
    else:
        dataset = Reddit(path)
    data = dataset[0]
    row, col = data.edge_index
    data.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.
    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=args.batch,
                                         walk_length=args.walk_length,
                                         num_steps=5,
                                         sample_coverage=args.sample_coverage,
                                         save_dir=dataset.processed_dir,
                                         num_workers=0)

    device = torch.device(
        f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    model = SAINT(data.num_node_features, args.hidden, dataset.num_classes,
                  args.dropout, args.binarize).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    val_f1s, test_f1s = [], []
    for run in range(1, args.runs + 1):
        best_val, best_test = 0, 0
        model.reset_parameters()
        start_time = time.time()
        for epoch in range(1, args.epochs + 1):
            loss = train(model, loader, optimizer, device,
                         args.use_normalization)
            accs = test(model, data, device, args.use_normalization)
            if accs[1] > best_val:
                best_val = accs[1]
                best_test = accs[2]
            if args.runs == 1:
                print(
                    f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {accs[0]:.4f}, '
                    f'Val: {accs[1]:.4f}, Test: {accs[2]:.4f}')
        test_f1s.append(best_test)
        print("Run: {:d}, best val: {:.4f}, best test: {:.4f}, time cost: {:d}s"
              .format(run, best_val, best_test, int(time.time() - start_time)))

    test_f1s = torch.tensor(test_f1s)
    print("{:.4f} ± {:.4f}".format(test_f1s.mean(), test_f1s.std()))
Example #17
def test_graph_saint():
    adj = torch.tensor([
        [+1, +2, +3, +0, +4, +0],
        [+5, +6, +0, +7, +0, +8],
        [+9, +0, 10, +0, 11, +0],
        [+0, 12, +0, 13, +0, 14],
        [15, +0, 16, +0, 17, +0],
        [+0, 18, +0, 19, +0, 20],
    ])

    edge_index = adj.nonzero(as_tuple=False).t()
    edge_id = adj[edge_index[0], edge_index[1]]
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    n_id = torch.arange(6)
    data = Data(edge_index=edge_index,
                x=x,
                n_id=n_id,
                edge_id=edge_id,
                num_nodes=6)

    loader = GraphSAINTNodeSampler(data,
                                   batch_size=3,
                                   num_steps=4,
                                   sample_coverage=10,
                                   log=False)

    assert len(loader) == 4
    for sample in loader:
        assert sample.num_nodes <= data.num_nodes
        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6
        assert sample.num_nodes == sample.n_id.numel()
        assert sample.x.tolist() == x[sample.n_id].tolist()
        assert sample.edge_index.min() >= 0
        assert sample.edge_index.max() < sample.num_nodes
        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21
        assert sample.edge_id.numel() == sample.num_edges
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    loader = GraphSAINTEdgeSampler(data,
                                   batch_size=2,
                                   num_steps=4,
                                   sample_coverage=10,
                                   log=False)

    assert len(loader) == 4
    for sample in loader:
        assert sample.num_nodes <= data.num_nodes
        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6
        assert sample.num_nodes == sample.n_id.numel()
        assert sample.x.tolist() == x[sample.n_id].tolist()
        assert sample.edge_index.min() >= 0
        assert sample.edge_index.max() < sample.num_nodes
        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21
        assert sample.edge_id.numel() == sample.num_edges
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=2,
                                         walk_length=1,
                                         num_steps=4,
                                         sample_coverage=10,
                                         log=False)

    assert len(loader) == 4
    for sample in loader:
        assert sample.num_nodes <= data.num_nodes
        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6
        assert sample.num_nodes == sample.n_id.numel()
        assert sample.x.tolist() == x[sample.n_id].tolist()
        assert sample.edge_index.min() >= 0
        assert sample.edge_index.max() < sample.num_nodes
        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21
        assert sample.edge_id.numel() == sample.num_edges
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges