コード例 #1
0
    def __init__(self, module: MessagePassing, num_relations: int,
                 num_bases: int):
        """Wrap ``module`` into ``num_bases`` edge-type-weighted copies.

        Every basis is a deep copy of ``module`` with the fused
        ``message_and_aggregate`` path disabled, a learnable
        ``edge_type_weight`` parameter of shape ``(1, num_relations)`` and a
        message forward hook that rescales each message by the weight that
        belongs to its edge type.

        Args:
            module: Message passing layer that is replicated for each basis.
            num_relations: Number of edge types; one scalar weight is learned
                per edge type (and per basis).
            num_bases: Number of copies (bases) of ``module`` to create.
        """
        super().__init__()

        self.num_relations = num_relations
        self.num_bases = num_bases

        # We make use of a post-message computation hook to inject the
        # basis re-weighting for each individual edge type.
        # This currently requires us to set `conv.fuse = False`, which leads
        # to a materialization of messages.
        def hook(module, inputs, output):
            """Scale every message in ``output`` by its edge-type weight.

            Reads ``module._edge_type``, which is expected to have been set
            on the layer before messages are computed (not done in this
            constructor).

            Raises:
                ValueError: If the number of messages differs from the number
                    of entries in ``module._edge_type``.
            """
            assert isinstance(module._edge_type, Tensor)
            if module._edge_type.size(0) != output.size(0):
                raise ValueError(
                    f"Number of messages ({output.size(0)}) does not match "
                    f"with the number of original edges "
                    f"({module._edge_type.size(0)}). Does your message "
                    f"passing layer create additional self-loops? Try to "
                    f"remove them via 'add_self_loops=False'")
            # Look up one scalar per message, indexed by its edge type ...
            weight = module.edge_type_weight.view(-1)[module._edge_type]
            # ... and append singleton dimensions so that it broadcasts over
            # all remaining dimensions of `output`.
            weight = weight.view([-1] + [1] * (output.dim() - 1))
            return weight * output

        # Allocate the new weights on the device of the wrapped module; fall
        # back to the CPU for modules without parameters.
        params = list(module.parameters())
        device = params[0].device if params else 'cpu'

        self.convs = torch.nn.ModuleList()
        for _ in range(num_bases):
            conv = copy.deepcopy(module)
            conv.fuse = False  # Disable `message_and_aggregate` functionality.
            # We learn a single scalar weight for each individual edge type,
            # which is used to weight the output message based on edge type.
            # `torch.empty` is used instead of the legacy `torch.Tensor(...)`
            # constructor, which does not accept a non-CPU `device`.
            conv.edge_type_weight = Parameter(
                torch.empty(1, num_relations, device=device))
            conv.register_message_forward_hook(hook)
            self.convs.append(conv)

        # NOTE(review): `torch.empty` returns uninitialized memory and
        # `reset_parameters()` is only invoked here for more than one basis -
        # confirm that the single-basis case is initialized elsewhere.
        if self.num_bases > 1:
            self.reset_parameters()
コード例 #2
0
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, SAGEConv, ChebConv, SGConv, RGCNConv
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
import pandas as pd
from numpy import random
from sklearn.preprocessing import StandardScaler
import networkx as nx
from sklearn.metrics import roc_curve, auc, confusion_matrix
from matplotlib.colors import ListedColormap
from torch_geometric.utils import dropout_adj
from torch_geometric.nn.conv import MessagePassing

# Module-level MessagePassing instance configured with
# flow='source_to_target' and 'add' aggregation.
# NOTE(review): presumably the object handed to `Net_Conv` as its `kwarg`
# argument - the call site is not visible here; confirm.
kwarg = MessagePassing(flow='source_to_target', aggr='add')


class Net_Conv(torch.nn.Module):
    def __init__(self, kwarg, p=0, input_size=1, output_size=2):
        """Create the dropout, graph-convolution and linear layers.

        Args:
            kwarg: Forwarded as the third positional argument of ``GCNConv``.
            p: Dropout probability; also stored as ``self.p``.
            input_size: Number of input channels of the convolution. The
                first linear layer takes twice this many input features.
            output_size: Not used by this constructor.
        """
        super().__init__()
        nn = torch.nn

        self.p = p
        self.drop1 = nn.Dropout(p)

        # NOTE(review): `kwarg` lands in GCNConv's third positional slot -
        # confirm that this is the parameter it is meant to set.
        self.conv1 = GCNConv(input_size, 5, kwarg, cached=True)
        # Alternative first layers that were tried previously:
        #   SAGEConv(input_size, 2, normalize=False)
        #   ChebConv(input_size, 2, K=2)

        self.linear1 = nn.Linear(2 * int(input_size), 4, bias=True)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, data, flag):
        x, edge_index, edge_orig = data.x, data.edge_index, data.edge_index_orig