def __init__(self, module: MessagePassing, num_relations: int,
             num_bases: int):
    """Build ``num_bases`` independent copies of ``module``.

    Each copy learns one scalar weight per edge type. A message-forward hook
    multiplies every computed message by that scalar.

    Args:
        module: The message passing layer to replicate. Each copy is a
            ``copy.deepcopy`` of it, with ``fuse`` set to ``False`` and a
            message-forward hook registered.
        num_relations: Number of edge types. Each copy gets an
            ``edge_type_weight`` parameter of shape ``[1, num_relations]``.
        num_bases: Number of copies (bases) to create.
    """
    super().__init__()

    self.num_relations = num_relations
    self.num_bases = num_bases

    # We make use of a post-message computation hook to inject the
    # basis re-weighting for each individual edge type.
    # This currently requires us to set `conv.fuse = False`, which leads
    # to a materialization of messages.
    def hook(module, inputs, output):
        # NOTE(review): `_edge_type` is expected to be set on the module by
        # the caller before message passing; it is not assigned here.
        assert isinstance(module._edge_type, Tensor)
        if module._edge_type.size(0) != output.size(0):
            raise ValueError(
                f"Number of messages ({output.size(0)}) does not match "
                f"with the number of original edges "
                f"({module._edge_type.size(0)}). Does your message "
                f"passing layer create additional self-loops? Try to "
                f"remove them via 'add_self_loops=False'")
        # Look up one scalar per message from its edge type. Then reshape
        # to [num_messages, 1, ..., 1] so it broadcasts over the remaining
        # message dimensions.
        weight = module.edge_type_weight.view(-1)[module._edge_type]
        weight = weight.view([-1] + [1] * (output.dim() - 1))
        return weight * output

    # Create the new per-edge-type weights on the same device as the
    # wrapped module's parameters (CPU if it has none).
    params = list(module.parameters())
    device = params[0].device if len(params) > 0 else 'cpu'

    self.convs = torch.nn.ModuleList()
    for _ in range(num_bases):
        conv = copy.deepcopy(module)
        conv.fuse = False  # Disable `message_and_aggregate` functionality.
        # We learn a single scalar weight for each individual edge type,
        # which is used to weight the output message based on edge type.
        # `torch.ones` replaces the legacy `torch.Tensor(..., device=...)`
        # constructor for two reasons:
        # - The legacy constructor rejects non-CPU devices.
        # - It returns uninitialized memory.
        # `reset_parameters()` is only called below when `num_bases > 1`,
        # so with a single basis this value is the starting weight, unless
        # the parameters are reset elsewhere.
        conv.edge_type_weight = Parameter(
            torch.ones(1, num_relations, device=device))
        conv.register_message_forward_hook(hook)
        self.convs.append(conv)

    if self.num_bases > 1:
        self.reset_parameters()
import torch.nn.functional as F from torch_geometric.nn import GCNConv, GraphConv, SAGEConv, ChebConv, SGConv, RGCNConv import matplotlib.pyplot as plt import torch from torch_geometric.data import InMemoryDataset from torch_geometric.data import Data import pandas as pd from numpy import random from sklearn.preprocessing import StandardScaler import networkx as nx from sklearn.metrics import roc_curve, auc, confusion_matrix from matplotlib.colors import ListedColormap from torch_geometric.utils import dropout_adj from torch_geometric.nn.conv import MessagePassing kwarg = MessagePassing(flow='source_to_target', aggr='add') class Net_Conv(torch.nn.Module): def __init__(self, kwarg, p=0, input_size=1, output_size=2): super(Net_Conv, self).__init__() self.p = p self.drop1 = torch.nn.Dropout(p=p) self.conv1 = GCNConv(input_size, 5, kwarg, cached=True) #self.conv1 = SAGEConv(input_size,2,normalize=False) #self.conv1 = ChebConv(input_size,2,K=2) self.linear1 = torch.nn.Linear(int(input_size) * 2, 4, bias=True) self.linear2 = torch.nn.Linear(10, 2, bias=True) def forward(self, data, flag): x, edge_index, edge_orig = data.x, data.edge_index, data.edge_index_orig