Example #1
0
 def forward(self, x, data, has_feature):
     """Run ``self.conv1`` for two hops and stack every hop's features.

     Args:
         x: Node feature matrix; ``x.shape[0]`` is used as the node count
             and ``x.dtype`` as the dtype of the normalized weights.
         data: Object providing ``edge_index`` and ``edge_weight``.
         has_feature: When truthy, ``GCNConv.norm`` is applied first and its
             returned ``(edge_index, norm)`` pair replaces the raw graph;
             otherwise ``data.edge_weight`` is passed through unchanged.

     Returns:
         ``torch.cat([x, hop1, hop2], dim=1)``, where ``hop1`` is one
         application of ``self.conv1`` and ``hop2`` is a second application
         on ``hop1``.
     """
     ei = data.edge_index
     weights = data.edge_weight
     if has_feature:
         ei, weights = GCNConv.norm(ei, x.shape[0], weights, dtype=x.dtype)

     # NOTE(review): both hops reuse the same layer (conv1); confirm this
     # weight sharing is intended rather than a typo for a second conv.
     hop1 = self.conv1(x, ei, edge_weight=weights)
     hop2 = self.conv1(hop1, ei, edge_weight=weights)
     return torch.cat([x, hop1, hop2], dim=1)
Example #2
0
    def forward(self, local_preds: torch.FloatTensor, edge_index):
        """Propagate ``local_preds`` for a per-node, learned number of steps.

        Each iteration re-samples edges with ``dropout_adj``, normalizes them
        with ``GCNConv.norm``, propagates, and scores the result with
        ``self.halt``. A node keeps propagating while its accumulated halting
        score plus the new score stays below 0.99. The loop stops after
        ``self.niter`` iterations or once every node has stopped.

        Args:
            local_preds: Per-node predictions; dimension 0 indexes nodes.
            edge_index: Graph connectivity handed to ``dropout_adj``.

        Returns:
            Tuple ``(x, steps - 1, 1 - sum_h)``:
            - ``x``: the accumulated propagated features divided by each
              node's step count;
            - ``steps - 1``: the number of extra steps taken per node;
            - ``1 - sum_h``: the remaining halting mass per node.
        """
        device = local_preds.device
        sz = local_preds.size(0)
        # Create the state tensors directly on the target device, instead of
        # creating them on the CPU and copying them over.
        steps = torch.ones(sz, device=device)
        sum_h = torch.zeros(sz, device=device)
        continue_mask = torch.ones(sz, dtype=torch.bool, device=device)
        x = torch.zeros_like(local_preds)  # already on local_preds.device

        prop = self.dropout(local_preds)
        for _ in range(self.niter):
            old_prop = prop
            # .float() converts in place on the current device (float32, as
            # before). The previous .type('torch.FloatTensor').to(device)
            # bounced every mask through the CPU on GPU runs.
            continue_fmask = continue_mask.float()

            drop_edge_index, _ = dropout_adj(edge_index,
                                             training=self.training)
            drop_edge_index, drop_norm = GCNConv.norm(drop_edge_index, sz)

            prop = self.propagate(drop_edge_index, x=prop, norm=drop_norm)

            h = torch.sigmoid(self.halt(prop)).t().squeeze()
            # Nodes still running whose halting budget is not yet exhausted.
            prob_mask = (((sum_h + h) < 0.99) & continue_mask).squeeze()
            prob_fmask = prob_mask.float()

            steps = steps + prob_fmask
            sum_h = sum_h + prob_fmask * h

            final_iter = steps <= self.niter

            # Continuing nodes weight the new state by sum_h; halting nodes
            # (or nodes past niter steps) weight it by the remainder.
            condition = prob_mask & final_iter
            p = torch.where(condition, sum_h, 1 - sum_h)

            to_update = self.dropout(continue_fmask)[:, None]
            x = x + (prop * p[:, None] + old_prop *
                     (1 - p)[:, None]) * to_update

            continue_mask = continue_mask & prob_mask

            if (~continue_mask).all():
                break

        x = x / steps[:, None]

        return x, (steps - 1), (1 - sum_h)
def main():
    """Train and evaluate GCN + LinkPredictor on ogbl-citation (GraphSAINT).

    Parses command-line options, builds a symmetrized and GCN-normalized
    graph, samples subgraphs with ``GraphSAINTRandomWalkSampler``, and trains
    for ``--runs`` independent runs. Every ``--eval_steps`` epochs it
    evaluates train/valid/test MRR. Pass ``--debug`` to also print the
    normalized graph and the sampler internals.
    """
    parser = argparse.ArgumentParser(description='OGBL-Citation (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=16 * 1024)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--sample_coverage', type=int, default=400)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--num_steps', type=int, default=100)
    parser.add_argument('--eval_steps', type=int, default=10)
    parser.add_argument('--runs', type=int, default=10)
    # Inspection output used to be printed unconditionally; it is now opt-in.
    parser.add_argument('--debug', action='store_true',
                        help='print normalized graph and sampler internals')
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation')
    split_edge = dataset.get_edge_split()
    data = dataset[0]
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    data.edge_index, data.edge_attr = GCNConv.norm(data.edge_index,
                                                   data.num_nodes)
    if args.debug:
        print(data.edge_index)
        print(data.edge_attr)

    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=args.sample_coverage,
                                         save_dir=dataset.processed_dir,
                                         num_workers=args.num_workers)

    if args.debug:
        print(loader.adj)
        print(loader.edge_norm)
        print(loader.edge_norm.min(), loader.edge_norm.max())

    # Randomly pick training edges to evaluate on. The sampled edges are
    # stored next to the validation negatives ('target_node_neg'), so
    # presumably test() pairs them row-for-row. Derive the subset size from
    # that tensor instead of hard-coding it (previously 86596).
    torch.manual_seed(12345)
    num_eval_train = split_edge['valid']['target_node_neg'].size(0)
    num_train = split_edge['train']['source_node'].numel()
    idx = torch.randperm(num_train)[:num_eval_train]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    model = GCN(data.x.size(-1), args.hidden_channels, args.hidden_channels,
                args.num_layers, args.dropout).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(list(model.parameters()) +
                                     list(predictor.parameters()),
                                     lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, loader, optimizer, device)
            print(f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}')

            if epoch % args.eval_steps == 0:
                result = test(model,
                              predictor,
                              data,
                              split_edge,
                              evaluator,
                              batch_size=64 * 1024,
                              device=device)
                logger.add_result(run, result)

                # Detailed line only on epochs that are multiples of both
                # eval_steps and log_steps.
                if epoch % args.log_steps == 0:
                    train_mrr, valid_mrr, test_mrr = result
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Train: {train_mrr:.4f}, '
                          f'Valid: {valid_mrr:.4f}, '
                          f'Test: {test_mrr:.4f}')

        logger.print_statistics(run)
    logger.print_statistics()