Esempio n. 1
0
    def __init__(self, in_dim, out_dim, norm='None'):
        """Build a GCN-style layer: one linear projection plus a norm layer.

        Args:
            in_dim: size of the input node features.
            out_dim: size of the projected output features.
            norm: ``'neighbornorm'`` to use ``NeighborNorm`` sized for the
                input features, or ``'None'`` to use ``BatchNorm1d`` sized
                for the output features.
        """
        super().__init__()
        self.norm = norm

        self.linear = torch.nn.Linear(in_dim, out_dim)

        assert norm in ('neighbornorm', 'None')
        use_neighbornorm = norm == 'neighbornorm'
        self.normlayer = (NeighborNorm(in_dim) if use_neighbornorm
                          else torch.nn.BatchNorm1d(out_dim))

        self.reset_parameters()
Esempio n. 2
0
    def __init__(self, in_dim, out_dim, norm='None'):
        """Build a GIN-style layer: a two-layer MLP plus optional NeighborNorm.

        Args:
            in_dim: size of the input node features.
            out_dim: hidden and output size of the MLP.
            norm: ``'neighbornorm'`` additionally creates a ``NeighborNorm``
                and a ``SingleNorm`` over the input features; ``'None'``
                creates neither.
        """
        super().__init__()
        self.norm = norm
        assert norm in ('neighbornorm', 'None')

        layers = [
            Linear(in_dim, out_dim), ReLU(),
            Linear(out_dim, out_dim), ReLU(),
            BatchNorm1d(out_dim),
        ]
        self.mlp = Sequential(*layers)

        if norm == 'neighbornorm':
            self.normlayer_l = NeighborNorm(in_dim)
            self.normlayer_r = SingleNorm(in_dim)

        self.reset_parameters()
Esempio n. 3
0
class GCNLayer(torch.nn.Module):
    """Graph convolution layer with symmetric degree normalisation.

    Each message ``x_j`` is scaled by ``deg(row)^-1/2 * deg(col)^-1/2`` and
    summed into node ``row``. The aggregate is projected by ``self.linear``
    and passed through ReLU.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        norm: ``'neighbornorm'`` builds per-edge messages with
            ``NeighborNorm``; ``'None'`` gathers ``x[col]`` directly and
            applies ``BatchNorm1d`` after the activation.
    """

    def __init__(self, in_dim, out_dim, norm='None'):
        super().__init__()
        self.norm = norm

        self.linear = torch.nn.Linear(in_dim, out_dim)

        assert self.norm in ['neighbornorm', 'None']
        if self.norm == 'neighbornorm':
            self.normlayer = NeighborNorm(in_dim)
        else:
            self.normlayer = torch.nn.BatchNorm1d(out_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection and the normalisation layer."""
        glorot(self.linear.weight)
        zeros(self.linear.bias)
        # NeighborNorm and BatchNorm1d both expose reset_parameters().
        # Resetting the BatchNorm as well clears stale running statistics and
        # affine parameters when the layer is re-initialised between runs.
        self.normlayer.reset_parameters()

    def forward(self, x, edge_index):
        """Propagate node features ``x`` along ``edge_index``.

        Args:
            x: node feature matrix; row ``i`` holds node ``i``'s features.
            edge_index: index tensor unpacked as ``(row, col)``; messages
                gathered at ``col`` are summed into ``row``.

        Returns:
            Tensor with ``x.size(0)`` rows and ``out_dim`` columns.
        """
        num_nodes = x.size(0)
        # Pass num_nodes explicitly. Otherwise it is inferred as
        # edge_index.max() + 1, so trailing isolated nodes get no self-loop
        # and `deg` would be shorter than x.
        edge_index, _ = add_remaining_self_loops(edge_index,
                                                 num_nodes=num_nodes)

        row, col = edge_index
        # Every node now has a self-loop, so deg >= 1 and pow(-0.5) is finite.
        deg = degree(row, num_nodes=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        if self.norm == 'neighbornorm':
            # Assumed to return one row per edge; it is scaled by the per-edge
            # `norm` below. TODO confirm against NeighborNorm.
            x_j = self.normlayer(x, edge_index)
        else:
            x_j = x[col]

        x_j = norm.view(-1, 1) * x_j
        out = scatter_add(src=x_j, index=row, dim=0, dim_size=num_nodes)

        out = F.relu(self.linear(out))
        if self.norm != 'neighbornorm':
            out = self.normlayer(out)

        return out
Esempio n. 4
0
class SAGELayer(torch.nn.Module):
    """SAGE-style layer: ``lin_l`` on summed neighbours plus ``lin_r`` on self.

    Neighbour features gathered at ``col`` are summed into ``row``. The sum
    is projected by ``lin_l`` and added to ``lin_r`` applied to the node's
    own features, followed by ReLU.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        norm: ``'neighbornorm'`` normalises neighbour features with
            ``NeighborNorm`` and root features with ``SingleNorm`` before the
            projections; ``'None'`` applies ``BatchNorm1d`` after the ReLU.
    """

    def __init__(self, in_dim, out_dim, norm='None'):
        super().__init__()
        self.norm = norm
        assert self.norm in ['neighbornorm', 'None']
        self.lin_l = torch.nn.Linear(in_dim, out_dim)
        self.lin_r = torch.nn.Linear(in_dim, out_dim)
        if self.norm == 'neighbornorm':
            self.normlayer_l = NeighborNorm(in_dim)
            self.normlayer_r = SingleNorm(in_dim)
        else:
            self.batchnorm = torch.nn.BatchNorm1d(out_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both projections and the active normalisation layers."""
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        # Zero BOTH biases. lin_l must not keep its default random bias.
        zeros(self.lin_l.bias)
        zeros(self.lin_r.bias)
        if self.norm == 'neighbornorm':
            self.normlayer_l.reset_parameters()
            self.normlayer_r.reset_parameters()
        else:
            # Clear running statistics so a re-initialised layer starts fresh.
            self.batchnorm.reset_parameters()

    def forward(self, x, edge_index):
        """Aggregate neighbour features and combine them with each node's own.

        Args:
            x: node feature matrix; row ``i`` holds node ``i``'s features.
            edge_index: index tensor unpacked as ``(row, col)``; features
                gathered at ``col`` are summed into ``row``.

        Returns:
            Tensor with ``x.size(0)`` rows and ``out_dim`` columns.
        """
        row, col = edge_index
        use_neighbornorm = self.norm == 'neighbornorm'

        if use_neighbornorm:
            x_j = self.normlayer_l(x, edge_index)
        else:
            x_j = x[col]

        out = scatter_add(src=x_j, index=row, dim=0, dim_size=x.size(0))

        x_root = self.normlayer_r(x) if use_neighbornorm else x
        out = F.relu(self.lin_l(out) + self.lin_r(x_root))

        if not use_neighbornorm:
            out = self.batchnorm(out)
        return out
Esempio n. 5
0
    def __init__(self,
                 in_dim,
                 out_dim,
                 heads=4,
                 concat=True,
                 dropout=False,
                 norm='None',
                 residual=True,
                 activation=False):
        """Set up a multi-head attention layer and its optional normaliser.

        Args:
            in_dim: input feature size.
            out_dim: feature size per attention head.
            heads: number of attention heads.
            concat: if truthy, the residual projection and bias are sized
                ``heads * out_dim``; otherwise they are sized ``out_dim``.
            dropout: stored on ``self.dropout``; its use is not visible here
                (NOTE(review): default ``False`` — confirm rate vs. flag).
            norm: one of ``'batchnorm'``, ``'layernorm'``, ``'dropedge'``,
                ``'pairnorm'``, ``'nodenorm'``, ``'neighbornorm'``, ``'None'``.
                ``'dropedge'`` and ``'None'`` leave ``self.normlayer`` as None.
            residual: stored on ``self.residual``.
            activation: stored on ``self.activation``.
        """
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.heads, self.concat = heads, concat

        self.dropout = dropout
        self.norm = norm
        self.residual = residual
        self.activation = activation

        assert norm in ('batchnorm', 'layernorm', 'dropedge', 'pairnorm',
                        'nodenorm', 'neighbornorm', 'None')

        # Width of the multi-head projection (all heads side by side).
        hidden = heads * out_dim
        # Width of the residual branch and bias, chosen by `concat`.
        combined = hidden if concat else out_dim

        if norm == 'neighbornorm':
            self.normlayer_l = NeighborNorm(hidden)
            self.normlayer_r = SingleNorm(hidden)
        else:
            # Lazily build only the requested module; unknown keys -> None.
            builders = {
                'batchnorm': lambda: torch.nn.BatchNorm1d(hidden),
                'layernorm': lambda: torch.nn.LayerNorm(hidden),
                'pairnorm': lambda: PairNorm(),
                'nodenorm': lambda: NodeNorm(),
            }
            build = builders.get(norm)
            self.normlayer = None if build is None else build()

        self.lin = torch.nn.Linear(in_dim, hidden, bias=False)
        self.lin_residual = torch.nn.Linear(in_dim, combined)

        self.att_i = torch.nn.Parameter(torch.Tensor(1, heads, out_dim))
        self.att_j = torch.nn.Parameter(torch.Tensor(1, heads, out_dim))
        self.bias = torch.nn.Parameter(torch.Tensor(combined))

        self.reset_parameters()
Esempio n. 6
0
class GINLayer(torch.nn.Module):
    """GIN-style layer: sum neighbour features, add self features, apply an MLP.

    Self-loops are stripped from ``edge_index`` before aggregation, so a
    node's own features enter only through the explicit self term.

    Args:
        in_dim: input feature size.
        out_dim: hidden and output size of the MLP.
        norm: ``'neighbornorm'`` normalises neighbour features with
            ``NeighborNorm`` and self features with ``SingleNorm``;
            ``'None'`` uses the raw features.
    """

    def __init__(self, in_dim, out_dim, norm='None'):
        super().__init__()
        self.norm = norm
        assert norm in ('neighbornorm', 'None')

        self.mlp = Sequential(
            Linear(in_dim, out_dim),
            ReLU(),
            Linear(out_dim, out_dim),
            ReLU(),
            BatchNorm1d(out_dim),
        )

        if norm == 'neighbornorm':
            self.normlayer_l = NeighborNorm(in_dim)
            self.normlayer_r = SingleNorm(in_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the MLP and, when present, both norm modules."""
        reset(self.mlp)
        if self.norm != 'neighbornorm':
            return
        for layer in (self.normlayer_l, self.normlayer_r):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Run one GIN update over ``edge_index`` (self-loops removed first).

        Args:
            x: node feature matrix; row ``i`` holds node ``i``'s features.
            edge_index: index tensor unpacked as ``(row, col)``; features
                gathered at ``col`` are summed into ``row``.

        Returns:
            The MLP output for every node of ``x``.
        """
        edge_index, _ = remove_self_loops(edge_index)
        agg_index, nbr_index = edge_index
        with_norm = self.norm == 'neighbornorm'

        messages = (self.normlayer_l(x, edge_index) if with_norm
                    else x[nbr_index])
        aggregated = scatter_add(src=messages, index=agg_index, dim=0,
                                 dim_size=x.size(0))

        self_term = self.normlayer_r(x) if with_norm else x
        return self.mlp(self_term + aggregated)
Esempio n. 7
0
    def __init__(self,
                 in_dim,
                 out_dim,
                 dropout=False,
                 norm='None',
                 residual=False,
                 activation=True):
        """Set up a layer with two projections (``lin_l``/``lin_r``) and an optional norm.

        Args:
            in_dim: input feature size.
            out_dim: output feature size of both projections.
            dropout: stored on ``self.dropout``; its use is not visible here
                (NOTE(review): default ``False`` — confirm rate vs. flag).
            norm: one of ``'batchnorm'``, ``'layernorm'``, ``'dropedge'``,
                ``'pairnorm'``, ``'nodenorm'``, ``'neighbornorm'``, ``'None'``.
                ``'dropedge'`` and ``'None'`` set ``self.normlayer`` to None.
            residual: requested residual connection; forced to ``False``
                when ``in_dim != out_dim``.
            activation: stored on ``self.activation``.
        """
        super().__init__()
        self.dropout = dropout
        self.norm = norm
        self.residual = residual
        self.activation = activation

        self.lin_l = torch.nn.Linear(in_dim, out_dim)
        self.lin_r = torch.nn.Linear(in_dim, out_dim)

        assert self.norm in [
            'batchnorm', 'layernorm', 'dropedge', 'pairnorm', 'nodenorm',
            'neighbornorm', 'None'
        ]

        # BatchNorm/LayerNorm are sized for out_dim (presumably applied after
        # the projections), while NeighborNorm/SingleNorm are sized for in_dim
        # (presumably applied to inputs first) — TODO confirm against forward.
        if self.norm == 'batchnorm':
            self.normlayer = torch.nn.BatchNorm1d(out_dim)
        elif self.norm == 'layernorm':
            self.normlayer = torch.nn.LayerNorm(out_dim)
        elif self.norm == 'neighbornorm':
            # Note: this branch sets normlayer_l/normlayer_r, not normlayer.
            self.normlayer_l = NeighborNorm(in_dim)
            self.normlayer_r = SingleNorm(in_dim)
        elif self.norm == 'pairnorm':
            self.normlayer = PairNorm()
        elif self.norm == 'nodenorm':
            self.normlayer = NodeNorm()
        else:
            # 'dropedge' and 'None' get no normalisation module.
            self.normlayer = None

        # Silently disable the residual when widths differ; presumably the
        # residual adds the input unprojected, which needs in_dim == out_dim.
        if in_dim != out_dim:
            self.residual = False

        self.reset_parameters()