import math
from typing import Optional, Tuple, Type

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LayerNorm, Linear

from torch_geometric.nn import GroupAddRev, SAGEConv
from torch_geometric.utils import to_dense_batch


class GNNBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm = LayerNorm(in_channels, elementwise_affine=True)
        self.conv = SAGEConv(in_channels, out_channels)

    def reset_parameters(self):
        self.norm.reset_parameters()
        self.conv.reset_parameters()

    def forward(self, x, edge_index, dropout_mask=None):
        # Pre-normalization and activation; the optional dropout mask is
        # shared across blocks and only applied during training.
        x = self.norm(x).relu()
        if self.training and dropout_mask is not None:
            x = x * dropout_mask
        return self.conv(x, edge_index)
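
# A minimal, illustrative sketch (not part of the model code): running a
# single GNNBlock on a toy 4-node graph with assumed shapes. In eval mode,
# no dropout mask is applied.
def _demo_gnn_block():
    block = GNNBlock(in_channels=8, out_channels=8)
    x = torch.randn(4, 8)                      # node features
    edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                               [1, 2, 3, 0]])  # target nodes
    block.eval()
    out = block(x, edge_index)
    print(out.shape)  # torch.Size([4, 8])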
class RevGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_groups=2):
        super().__init__()
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        self.norm = LayerNorm(hidden_channels, elementwise_affine=True)

        assert hidden_channels % num_groups == 0
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GNNBlock(
                hidden_channels // num_groups,
                hidden_channels // num_groups,
            )
            self.convs.append(GroupAddRev(conv, num_groups=num_groups))

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.norm.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index):
        x = self.lin1(x)

        # Generate a dropout mask which will be shared across GNN blocks:
        mask = None
        if self.training and self.dropout > 0:
            mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)
            mask = mask.requires_grad_(False)
            mask = mask / (1 - self.dropout)

        for conv in self.convs:
            x = conv(x, edge_index, mask)

        x = self.norm(x).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x)
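
# A minimal, illustrative sketch (assumed hyperparameters and a random graph,
# not part of the model code): one forward/backward pass of RevGNN. Note that
# hidden_channels must be divisible by num_groups, and that GroupAddRev
# recomputes block inputs during the backward pass instead of storing them.
def _demo_rev_gnn():
    model = RevGNN(in_channels=16, hidden_channels=64, out_channels=7,
                   num_layers=4, dropout=0.5, num_groups=2)

    num_nodes, num_edges = 100, 500
    x = torch.randn(num_nodes, 16)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    y = torch.randint(0, 7, (num_nodes,))

    model.train()
    out = model(x, edge_index)          # [num_nodes, out_channels]
    loss = F.cross_entropy(out, y)
    loss.backward()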
class MAB(torch.nn.Module):
    r"""Multihead-Attention Block."""
    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int,
                 Conv: Optional[Type] = None, layer_norm: bool = False):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.layer_norm = layer_norm

        self.fc_q = Linear(dim_Q, dim_V)

        if Conv is None:
            self.layer_k = Linear(dim_K, dim_V)
            self.layer_v = Linear(dim_K, dim_V)
        else:
            self.layer_k = Conv(dim_K, dim_V)
            self.layer_v = Conv(dim_K, dim_V)

        if layer_norm:
            self.ln0 = LayerNorm(dim_V)
            self.ln1 = LayerNorm(dim_V)

        self.fc_o = Linear(dim_V, dim_V)

    def reset_parameters(self):
        self.fc_q.reset_parameters()
        self.layer_k.reset_parameters()
        self.layer_v.reset_parameters()
        if self.layer_norm:
            self.ln0.reset_parameters()
            self.ln1.reset_parameters()
        self.fc_o.reset_parameters()

    def forward(
        self,
        Q: Tensor,
        K: Tensor,
        graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        Q = self.fc_q(Q)

        if graph is not None:
            # Keys/values come from graph convolutions over node features,
            # converted to dense per-example batches:
            x, edge_index, batch = graph
            K, V = self.layer_k(x, edge_index), self.layer_v(x, edge_index)
            K, _ = to_dense_batch(K, batch)
            V, _ = to_dense_batch(V, batch)
        else:
            K, V = self.layer_k(K), self.layer_v(K)

        # Split channels into heads and fold them into the batch dimension:
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        if mask is not None:
            mask = torch.cat([mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2))
            attention_score = attention_score / math.sqrt(self.dim_V)
            A = torch.softmax(mask + attention_score, 1)
        else:
            A = torch.softmax(
                Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 1)

        out = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)

        if self.layer_norm:
            out = self.ln0(out)

        out = out + self.fc_o(out).relu()

        if self.layer_norm:
            out = self.ln1(out)

        return out
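
# A minimal, illustrative sketch (assumed shapes, not part of the model code):
# attention between a dense query set and dense keys, i.e. without the graph
# branch (Conv=None). With a graph, one would instead pass
# graph=(x, edge_index, batch) and let the Conv layers produce keys/values.
def _demo_mab():
    mab = MAB(dim_Q=32, dim_K=32, dim_V=64, num_heads=4, layer_norm=True)

    Q = torch.randn(2, 10, 32)   # [batch_size, num_queries, dim_Q]
    K = torch.randn(2, 25, 32)   # [batch_size, num_elements, dim_K]

    out = mab(Q, K)              # keys and values are both derived from K
    print(out.shape)             # torch.Size([2, 10, 64])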