class RGCNSAGPooling(torch.nn.Module):
    r"""SAGPool-style top-k pooling whose node scores come from a relational GCN.

    A relational graph convolution with a single output channel
    (``FastRGCNConv`` or ``RGCNConv``) produces one score per node. The raw
    scores go through ``nonlinearity``. When ``min_score`` is given, they are
    instead normalised with ``softmax`` over ``batch``. Nodes are then selected
    with ``topk``, and the kept node features are multiplied by their scores.

    Args:
        in_channels (int): Width of the features fed to the scorer.
        num_relations (int): Number of relation types known to the scorer.
        ratio (float): Pooling ratio passed to ``topk``.
        min_score (float, optional): Passed to ``topk``. When not ``None``,
            scores are softmax-normalised instead of passed through
            ``nonlinearity``.
        multiplier (float): Factor applied to the pooled features. It is
            skipped when equal to 1.
        nonlinearity (callable): Applied to the raw scores when ``min_score``
            is ``None``.
        rgcn_func (str): ``"FastRGCNConv"`` selects ``FastRGCNConv``. Any other
            value selects ``RGCNConv``.
        **kwargs: Extra keyword arguments forwarded to the conv constructor.
    """

    def __init__(self, in_channels, num_relations, ratio=0.5, min_score=None,
                 multiplier=1, nonlinearity=torch.tanh,
                 rgcn_func="FastRGCNConv", **kwargs):
        super(RGCNSAGPooling, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        # Any value other than "FastRGCNConv" falls back to the plain RGCNConv.
        conv_cls = FastRGCNConv if rgcn_func == "FastRGCNConv" else RGCNConv
        # A single output channel, so the conv yields one score per node.
        self.gnn = conv_cls(in_channels, 1, num_relations, **kwargs)
        self.min_score = min_score
        self.multiplier = multiplier
        self.nonlinearity = nonlinearity
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the scoring convolution."""
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """Pool ``x`` down to its top-scoring nodes.

        Args:
            x: Node features, one row per node.
            edge_index: Graph connectivity. It is filtered to the kept nodes
                with ``filter_adj``.
            edge_attr: Passed to the scorer as its third positional argument
                (the ``edge_type`` slot of PyG's RGCN layers). It is filtered
                together with ``edge_index``.
            batch: Graph assignment of every node. Defaults to all zeros,
                meaning a single graph.
            attn: Optional alternative input for the scorer. Defaults to ``x``.
                A 1-D tensor is treated as a single feature column.

        Returns:
            Tuple ``(x, edge_index, edge_attr, batch, perm, score)``
            restricted to the kept nodes. ``perm`` holds the original indices
            of the kept nodes and ``score`` holds their scores.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        if attn is None:
            attn = x
        if attn.dim() == 1:
            attn = attn.unsqueeze(-1)

        raw_score = self.gnn(attn, edge_index, edge_attr).view(-1)
        if self.min_score is not None:
            score = softmax(raw_score, batch)
        else:
            score = self.nonlinearity(raw_score)

        perm = topk(score, self.ratio, batch, self.min_score)
        kept_score = score[perm]
        pooled = x[perm] * kept_score.view(-1, 1)
        if self.multiplier != 1:
            pooled = self.multiplier * pooled

        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))
        return pooled, edge_index, edge_attr, batch[perm], perm, kept_score

    def __repr__(self):
        if self.min_score is None:
            key, value = 'ratio', self.ratio
        else:
            key, value = 'min_score', self.min_score
        return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, '
                f'{self.in_channels}, {key}={value}, '
                f'multiplier={self.multiplier})')
def __init__(self, in_channels, out_channels, num_relations,
             hidden_channels=16, num_bases=30):
    """Build a two-layer relational GCN from ``FastRGCNConv`` layers.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
        num_relations: Number of relation (edge) types.
        hidden_channels: Width of the intermediate representation between the
            two layers. Defaults to 16, the previously hard-coded value.
        num_bases: ``num_bases`` passed to both ``FastRGCNConv`` layers.
            Defaults to 30, the previously hard-coded value.
    """
    super().__init__()
    self.conv1 = FastRGCNConv(in_channels, hidden_channels, num_relations,
                              num_bases=num_bases)
    self.conv2 = FastRGCNConv(hidden_channels, out_channels, num_relations,
                              num_bases=num_bases)
def __init__(self, in_channels, num_relations, ratio=0.5, min_score=None,
             multiplier=1, nonlinearity=torch.tanh,
             rgcn_func="FastRGCNConv", **kwargs):
    """Configure an RGCN-scored SAG pooling layer.

    Args:
        in_channels: Width of the features fed to the scoring convolution.
        num_relations: Number of relation types known to the scorer.
        ratio: Pooling ratio stored for later use.
        min_score: Optional minimum score stored for later use.
        multiplier: Scaling factor stored for later use.
        nonlinearity: Score activation stored for later use.
        rgcn_func: ``"FastRGCNConv"`` builds a ``FastRGCNConv`` scorer. Any
            other value builds an ``RGCNConv`` scorer.
        **kwargs: Extra keyword arguments forwarded to the chosen conv.
    """
    super(RGCNSAGPooling, self).__init__()
    self.in_channels = in_channels
    self.ratio = ratio

    # The scorer has one output channel, giving one score per node.
    if rgcn_func == "FastRGCNConv":
        scorer = FastRGCNConv(in_channels, 1, num_relations, **kwargs)
    else:
        scorer = RGCNConv(in_channels, 1, num_relations, **kwargs)
    self.gnn = scorer

    self.min_score = min_score
    self.multiplier = multiplier
    self.nonlinearity = nonlinearity
    self.reset_parameters()
def test_rgcn_conv_equality(conf):
    """Check that ``RGCNConv`` and ``FastRGCNConv`` produce matching outputs.

    ``conf`` is a ``(num_bases, num_blocks)`` pair. It is presumably supplied
    by a pytest parametrization outside this view. Both layers are built
    right after the same ``torch.manual_seed`` call. Their outputs are
    compared on a 4-node graph that uses all 4 relation types. When
    ``num_blocks`` is None, the featureless ``x=None`` mode is compared too.
    """
    num_bases, num_blocks = conf
    x = torch.randn(4, 4)
    # The 6 base edges appear twice: first with relations from {0, 1}, then
    # with relations from {2, 3}, so every one of the 4 relation types is used.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3],
        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1],
    ])
    edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3])

    # Re-seeding before each constructor is meant to give both layers the
    # same initial weights, so only the forward implementations can differ.
    torch.manual_seed(12345)
    conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks)
    torch.manual_seed(12345)
    conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks)

    out1 = conv1(x, edge_index, edge_type)
    out2 = conv2(x, edge_index, edge_type)
    assert torch.allclose(out1, out2, atol=1e-6)

    if num_blocks is None:
        # The featureless mode (x=None) is only compared without block
        # decomposition. Presumably block decomposition needs real input
        # features.
        out1 = conv1(None, edge_index, edge_type)
        out2 = conv2(None, edge_index, edge_type)
        assert torch.allclose(out1, out2, atol=1e-6)
def __init__(self, in_channels, number_hidden_layers, aggr, hidden_out_channel,
             out_channel, pool_layer, k=1, device=None):
    """Build parallel FastRGCNConv / GraphConv stacks plus a two-layer MLP head.

    The stacks hold ``number_hidden_layers + 2`` convolution pairs: a first
    pair, the hidden pairs, and a last pair. Each pair is one
    ``FastRGCNConv`` (in ``rgcn_list``) and one ``GraphConv`` (in
    ``graphconv_list``) with the same input and output widths.

    Args:
        in_channels: Embedding width of ``AtomEncoder``. It is also the input
            width of the first convolution pair.
        number_hidden_layers: Number of hidden convolution pairs. Values
            <= 0 add none.
        aggr: Accepted but not used in this constructor.
        hidden_out_channel: Output width of the first and hidden pairs. Every
            later pair takes ``2 * hidden_out_channel`` inputs. This is
            presumably the two branch outputs concatenated; confirm in forward.
        out_channel: Output width of the last convolution pair.
        pool_layer: Graph pooling mode: 'add', 'max', 'mean' or 'sort'.
        k: Multiplier in the head's input width ``2 * k * out_channel``.
            Presumably this is the number of nodes kept by sort pooling.
        device: Stored on the instance. It is not used in this constructor.
    """
    super(InceptionNet, self).__init__()
    self.pool_layer = pool_layer  # 'add', 'max', 'mean' or 'sort'
    self.device = device
    self.k = k
    self.atom_encoder = AtomEncoder(emb_dim=in_channels)
    self.batchnorm = BatchNorm(in_channels=2 * hidden_out_channel)
    self.rgcn_list = torch.nn.ModuleList()
    self.graphconv_list = torch.nn.ModuleList()

    # (input width, output width) of every convolution pair, in order:
    # first, hidden..., last. A list repeated a non-positive number of times
    # is empty, so no separate zero check is needed.
    layer_dims = (
        [(in_channels, hidden_out_channel)]
        + [(2 * hidden_out_channel, hidden_out_channel)] * number_hidden_layers
        + [(2 * hidden_out_channel, out_channel)]
    )
    for dim_in, dim_out in layer_dims:
        # Each RGCN layer is built before its GraphConv partner. This keeps
        # the parameter-initialisation (RNG) order fixed.
        self.rgcn_list.append(
            FastRGCNConv(in_channels=dim_in, out_channels=dim_out,
                         num_relations=NUM_RELATIONS))
        self.graphconv_list.append(
            GraphConv(in_channels=dim_in, out_channels=dim_out))

    self.linear1 = nn.Linear(2 * k * out_channel, 16)
    self.linear2 = nn.Linear(16, 1)
def __init__(self, config):
    """Build the encoder: one EncoderLayer, a relational GCN and two DNAConv stacks.

    Args:
        config: Configuration object. ``config.hidden_size`` sets the width of
            every sub-layer, and ``config`` itself is passed to ``EncoderLayer``.
    """
    super(Encoder, self).__init__()
    hidden = config.hidden_size

    # A single EncoderLayer, kept in a ModuleList.
    self.layer = nn.ModuleList([EncoderLayer(config)])
    # 25 relation types with num_bases=128.
    # NOTE(review): 128 bases exceeds the 25 relations; confirm this is intended.
    self.conv3 = FastRGCNConv(hidden, hidden, 25, num_bases=128)

    # Three (conv2, conv22) pairs of DNAConv(hidden, 32, 2, 0.1). Each tuple
    # is built left to right, so the construction order alternates
    # conv2[i], conv22[i].
    dna_pairs = [
        (DNAConv(hidden, 32, 2, 0.1), DNAConv(hidden, 32, 2, 0.1))
        for _ in range(3)
    ]
    self.conv2 = torch.nn.ModuleList(first for first, _ in dna_pairs)
    self.conv22 = torch.nn.ModuleList(second for _, second in dna_pairs)

    self.hidden_size = hidden
    # Normalises jointly over a leading dimension of 512 and the hidden width.
    # NOTE(review): 512 is hard-coded and presumably a fixed sequence length;
    # confirm against config.
    self.norm = nn.LayerNorm([512, hidden], 1e-05)
def __init__(self, in_channels, hidden_channels, out_channels, depth,
             pool_ratios=0.5, sum_res=True, act=F.relu, num_relations=4):
    """Build a Graph U-Net whose first down-convolution is relational.

    Args:
        in_channels: Size of each input node feature. It is the input width of
            the first (``FastRGCNConv``) down-convolution.
        hidden_channels: Width of every down-convolution and of every
            up-convolution except the last.
        out_channels: Output width of the final up-convolution.
        depth: Number of pooling levels. Must be >= 1.
        pool_ratios: ``TopKPooling`` ratio or ratios, expanded to ``depth``
            values with ``repeat``.
        sum_res: Selects how skip connections are merged in forward. The
            up-convolution input width is ``hidden_channels`` when True and
            ``2 * hidden_channels`` (concatenation) when False. TODO: confirm
            against forward.
        act: Activation stored on the instance, presumably applied in forward.
        num_relations: Number of relation types of the first down-convolution.
    """
    super(GraphRUNet, self).__init__()
    assert depth >= 1
    self.in_channels = in_channels
    self.hidden_channels = hidden_channels
    self.out_channels = out_channels
    self.depth = depth
    self.pool_ratios = repeat(pool_ratios, depth)
    self.act = act
    self.sum_res = sum_res

    channels = hidden_channels
    self.down_convs = torch.nn.ModuleList()
    self.pools = torch.nn.ModuleList()
    self.down_convs.append(
        FastRGCNConv(in_channels, channels, num_relations=num_relations))
    for i in range(depth):
        self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
        self.down_convs.append(GCNConv(channels, channels))

    # Up-convolutions consume the merged skip connection. A summed skip keeps
    # width `channels`; a concatenated one doubles it. Previously this width
    # was computed but never used, and the last layer ignored out_channels.
    up_in_channels = channels if sum_res else 2 * channels
    self.up_convs = torch.nn.ModuleList()
    for _ in range(depth - 1):
        self.up_convs.append(GCNConv(up_in_channels, channels))
    self.up_convs.append(GCNConv(up_in_channels, out_channels))

    self.reset_parameters()