def __init__(self, in_dim, out_dim, norm='None'):
    """Create the linear map and the normalisation layer, then initialise them.

    Args:
        in_dim: Size of the input features (passed to the ``Linear`` and,
            for ``'neighbornorm'``, to ``NeighborNorm``).
        out_dim: Size of the ``Linear`` output (and of the ``BatchNorm1d``
            used when ``norm == 'None'``).
        norm: Either ``'neighbornorm'`` or the literal string ``'None'``.
    """
    super().__init__()
    self.norm = norm
    # Sub-modules are created in a fixed order (Linear first); keep it, since
    # it determines registration order and default-init RNG consumption.
    self.linear = torch.nn.Linear(in_dim, out_dim)
    assert self.norm in ('neighbornorm', 'None')
    self.normlayer = (
        NeighborNorm(in_dim)
        if self.norm == 'neighbornorm'
        else torch.nn.BatchNorm1d(out_dim)
    )
    self.reset_parameters()
def __init__(self, in_dim, out_dim, norm='None'):
    """Build the two-layer MLP and, optionally, the neighbour/self norm layers.

    Args:
        in_dim: Input feature size of the first ``Linear`` (and of the
            ``NeighborNorm``/``SingleNorm`` layers when enabled).
        out_dim: Hidden and output size of the MLP.
        norm: Either ``'neighbornorm'`` or the literal string ``'None'``;
            with ``'None'`` no extra norm layers are created.
    """
    super().__init__()
    self.norm = norm
    assert self.norm in ('neighbornorm', 'None')
    layers = [
        Linear(in_dim, out_dim),
        ReLU(),
        Linear(out_dim, out_dim),
        ReLU(),
        BatchNorm1d(out_dim),
    ]
    self.mlp = Sequential(*layers)
    if self.norm == 'neighbornorm':
        self.normlayer_l = NeighborNorm(in_dim)
        self.normlayer_r = SingleNorm(in_dim)
    self.reset_parameters()
class GCNLayer(torch.nn.Module):
    """Graph convolution with symmetric degree normalisation.

    Every node receives a self-loop. Each message travelling along an edge
    ``(row, col)`` is scaled by ``deg(row)^-1/2 * deg(col)^-1/2`` and summed
    into ``row``. The sum then passes through ``Linear`` followed by ReLU.

    With ``norm='None'`` a ``BatchNorm1d`` is applied after the ReLU. With
    ``norm='neighbornorm'`` a ``NeighborNorm`` layer produces the per-edge
    messages instead of plain feature gathering.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        norm: Either ``'neighbornorm'`` or the literal string ``'None'``.
    """

    def __init__(self, in_dim, out_dim, norm='None'):
        super().__init__()
        self.norm = norm
        self.linear = torch.nn.Linear(in_dim, out_dim)
        assert self.norm in ['neighbornorm', 'None']
        if self.norm == 'neighbornorm':
            self.normlayer = NeighborNorm(in_dim)
        else:
            self.normlayer = torch.nn.BatchNorm1d(out_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and the normalisation layer."""
        glorot(self.linear.weight)
        zeros(self.linear.bias)
        # Both branches of __init__ store a layer in ``self.normlayer``; reset
        # it in either case so BatchNorm running statistics are cleared too.
        self.normlayer.reset_parameters()

    def forward(self, x, edge_index):
        """Apply one propagation step.

        Args:
            x: Node feature matrix; ``x.size(0)`` is the number of nodes.
            edge_index: Pair ``(row, col)`` of index tensors; messages are
                gathered from ``col`` and summed into ``row``.

        Returns:
            Tensor with ``x.size(0)`` rows.
        """
        num_nodes = x.size(0)
        # Pass num_nodes explicitly. Otherwise the node count is inferred
        # from the largest index in edge_index, and trailing isolated nodes
        # would get no self-loop (and thus an all-zero aggregation).
        edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=num_nodes)
        row, col = edge_index
        deg = degree(row, num_nodes=num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Defensive: a zero degree would yield inf and poison the weights.
        deg_inv_sqrt.masked_fill_(torch.isinf(deg_inv_sqrt), 0.0)
        edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        if self.norm == 'neighbornorm':
            x_j = self.normlayer(x, edge_index)
        else:
            x_j = x[col]
        x_j = edge_weight.view(-1, 1) * x_j
        out = scatter_add(src=x_j, index=row, dim=0, dim_size=num_nodes)

        out = F.relu(self.linear(out))
        if self.norm != 'neighbornorm':
            out = self.normlayer(out)
        return out
class SAGELayer(torch.nn.Module):
    """Neighbour-sum plus self-projection layer.

    Output is ``ReLU(lin_l(sum of neighbour messages) + lin_r(self features))``.
    With ``norm='None'`` the messages are the neighbours' raw features and a
    ``BatchNorm1d`` follows the ReLU. With ``norm='neighbornorm'`` the
    messages come from ``NeighborNorm`` and the self features go through
    ``SingleNorm`` first.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        norm: Either ``'neighbornorm'`` or the literal string ``'None'``.
    """

    def __init__(self, in_dim, out_dim, norm='None'):
        super().__init__()
        self.norm = norm
        assert self.norm in ['neighbornorm', 'None']
        self.lin_l = torch.nn.Linear(in_dim, out_dim)
        self.lin_r = torch.nn.Linear(in_dim, out_dim)
        if self.norm == 'neighbornorm':
            self.normlayer_l = NeighborNorm(in_dim)
            self.normlayer_r = SingleNorm(in_dim)
        else:
            self.batchnorm = torch.nn.BatchNorm1d(out_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both projections and whichever norm layers exist."""
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        # Previously lin_r.bias was zeroed twice and lin_l.bias never.
        zeros(self.lin_l.bias)
        zeros(self.lin_r.bias)
        if self.norm == 'neighbornorm':
            self.normlayer_l.reset_parameters()
            self.normlayer_r.reset_parameters()
        else:
            # Clear BatchNorm running statistics and affine parameters too.
            self.batchnorm.reset_parameters()

    def forward(self, x, edge_index):
        """Apply one propagation step.

        Args:
            x: Node feature matrix; ``x.size(0)`` is the number of nodes.
            edge_index: Pair ``(row, col)``; messages are gathered from
                ``col`` and summed into ``row``.

        Returns:
            Tensor with ``x.size(0)`` rows.
        """
        row, col = edge_index
        use_neighbornorm = self.norm == 'neighbornorm'
        if use_neighbornorm:
            x_j = self.normlayer_l(x, edge_index)
        else:
            x_j = x[col]
        out = scatter_add(src=x_j, index=row, dim=0, dim_size=x.size(0))

        if use_neighbornorm:
            out = self.lin_l(out) + self.lin_r(self.normlayer_r(x))
            return F.relu(out)
        out = self.lin_l(out) + self.lin_r(x)
        return self.batchnorm(F.relu(out))
def __init__(self, in_dim, out_dim, heads=4, concat=True, dropout=False,
             norm='None', residual=True, activation=False):
    """Store the configuration and create projections, attention vectors,
    bias and normalisation layer(s).

    Args:
        in_dim: Input feature size.
        out_dim: Per-head feature size.
        heads: Number of heads; ``lin`` projects to ``heads * out_dim``.
        concat: If true, ``lin_residual`` and ``bias`` have
            ``heads * out_dim`` entries; otherwise ``out_dim``.
        dropout: Stored as ``self.dropout``; not used during construction.
        norm: One of ``'batchnorm'``, ``'layernorm'``, ``'dropedge'``,
            ``'pairnorm'``, ``'nodenorm'``, ``'neighbornorm'`` or ``'None'``.
            ``'neighbornorm'`` creates ``normlayer_l``/``normlayer_r`` and no
            ``normlayer``. ``'dropedge'`` and ``'None'`` set ``normlayer`` to
            ``None``.
        residual: Stored as ``self.residual``.
        activation: Stored as ``self.activation``.
    """
    super().__init__()
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.heads = heads
    self.concat = concat
    self.dropout = dropout
    self.norm = norm
    self.residual = residual
    self.activation = activation
    assert self.norm in (
        'batchnorm', 'layernorm', 'dropedge', 'pairnorm', 'nodenorm',
        'neighbornorm', 'None',
    )

    hidden = heads * out_dim  # width of ``lin`` output and of the norm layers
    if self.norm == 'neighbornorm':
        self.normlayer_l = NeighborNorm(hidden)
        self.normlayer_r = SingleNorm(hidden)
    else:
        # Lazy factories: only the selected norm class is looked up/built.
        factories = {
            'batchnorm': lambda: torch.nn.BatchNorm1d(hidden),
            'layernorm': lambda: torch.nn.LayerNorm(hidden),
            'pairnorm': lambda: PairNorm(),
            'nodenorm': lambda: NodeNorm(),
        }
        make = factories.get(self.norm)
        self.normlayer = None if make is None else make()

    out_width = hidden if concat else out_dim
    self.lin = torch.nn.Linear(in_dim, hidden, bias=False)
    self.lin_residual = torch.nn.Linear(in_dim, out_width)
    # torch.Tensor(...) leaves these uninitialised; presumably
    # reset_parameters (not visible here) fills them — TODO confirm.
    self.att_i = torch.nn.Parameter(torch.Tensor(1, heads, out_dim))
    self.att_j = torch.nn.Parameter(torch.Tensor(1, heads, out_dim))
    self.bias = torch.nn.Parameter(torch.Tensor(out_width))
    self.reset_parameters()
class GINLayer(torch.nn.Module):
    """Sum neighbour features (self-loops removed), add the node's own
    features, and pass the result through a two-layer MLP.

    With ``norm='neighbornorm'``, neighbour messages come from ``NeighborNorm``
    and the node's own features go through ``SingleNorm`` before the addition.

    Args:
        in_dim: Input feature size.
        out_dim: Hidden and output size of the MLP.
        norm: Either ``'neighbornorm'`` or the literal string ``'None'``.
    """

    def __init__(self, in_dim, out_dim, norm='None'):
        super().__init__()
        self.norm = norm
        assert self.norm in ('neighbornorm', 'None')
        layers = [
            Linear(in_dim, out_dim),
            ReLU(),
            Linear(out_dim, out_dim),
            ReLU(),
            BatchNorm1d(out_dim),
        ]
        self.mlp = Sequential(*layers)
        if self.norm == 'neighbornorm':
            self.normlayer_l = NeighborNorm(in_dim)
            self.normlayer_r = SingleNorm(in_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the MLP and, when present, both norm layers."""
        reset(self.mlp)
        if self.norm != 'neighbornorm':
            return
        for layer in (self.normlayer_l, self.normlayer_r):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Apply one propagation step.

        Args:
            x: Node feature matrix; ``x.size(0)`` is the number of nodes.
            edge_index: Pair of index tensors; messages are gathered from the
                second row and summed into the first.

        Returns:
            Output of ``self.mlp``, one row per node of ``x``.
        """
        edge_index, _ = remove_self_loops(edge_index)
        receivers, senders = edge_index
        neighbornorm = self.norm == 'neighbornorm'

        messages = self.normlayer_l(x, edge_index) if neighbornorm else x[senders]
        summed = scatter_add(src=messages, index=receivers, dim=0,
                             dim_size=x.size(0))
        centre = self.normlayer_r(x) if neighbornorm else x
        return self.mlp(centre + summed)
def __init__(self, in_dim, out_dim, dropout=False, norm='None',
             residual=False, activation=True):
    """Store the configuration and build two projections plus norm layer(s).

    Args:
        in_dim: Input feature size.
        out_dim: Output size of both ``lin_l`` and ``lin_r``.
        dropout: Stored as ``self.dropout``; not used during construction.
        norm: One of ``'batchnorm'``, ``'layernorm'``, ``'dropedge'``,
            ``'pairnorm'``, ``'nodenorm'``, ``'neighbornorm'`` or ``'None'``.
            ``'neighbornorm'`` creates ``normlayer_l``/``normlayer_r`` over
            ``in_dim`` and no ``normlayer``. ``'dropedge'`` and ``'None'`` set
            ``normlayer`` to ``None``.
        residual: Stored as ``self.residual``; forced to ``False`` when
            ``in_dim != out_dim``.
        activation: Stored as ``self.activation``.
    """
    super().__init__()
    self.dropout = dropout
    self.norm = norm
    self.residual = residual
    self.activation = activation
    self.lin_l = torch.nn.Linear(in_dim, out_dim)
    self.lin_r = torch.nn.Linear(in_dim, out_dim)
    assert self.norm in (
        'batchnorm', 'layernorm', 'dropedge', 'pairnorm', 'nodenorm',
        'neighbornorm', 'None',
    )

    if self.norm == 'neighbornorm':
        self.normlayer_l = NeighborNorm(in_dim)
        self.normlayer_r = SingleNorm(in_dim)
    else:
        # Lazy factories: only the selected norm class is looked up/built.
        factories = {
            'batchnorm': lambda: torch.nn.BatchNorm1d(out_dim),
            'layernorm': lambda: torch.nn.LayerNorm(out_dim),
            'pairnorm': lambda: PairNorm(),
            'nodenorm': lambda: NodeNorm(),
        }
        make = factories.get(self.norm)
        self.normlayer = make() if make is not None else None

    # Residual is switched off when the widths differ (presumably because the
    # skip connection adds an in_dim-wide input to an out_dim-wide output).
    if in_dim != out_dim:
        self.residual = False
    self.reset_parameters()